Import Neural Network Models Using ONNX

To create function approximators for reinforcement learning, you can import pre-trained deep neural networks or deep neural network layer architectures using the Deep Learning Toolbox™ network import functionality. You can import:

  • Open Neural Network Exchange (ONNX™) models, which require the Deep Learning Toolbox Converter for ONNX Model Format support package software. For more information, see importONNXLayers.

  • TensorFlow™-Keras networks, which require the Deep Learning Toolbox Converter for TensorFlow Models support package software. For more information, see importKerasLayers.

  • Caffe convolutional networks, which require the Deep Learning Toolbox Importer for Caffe Models support package software. For more information, see importCaffeLayers.
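
For example, assuming the model files named below exist on the MATLAB path (the file names are placeholders), the import calls have the following form. Each function returns the imported layer architecture, which you can then pass to an actor or critic constructor.

onnxLayers = importONNXLayers("myNetwork.onnx");       % ONNX model (hypothetical file)
kerasLayers = importKerasLayers("myNetwork.h5");       % TensorFlow-Keras model (hypothetical file)
caffeLayers = importCaffeLayers("myNetwork.prototxt"); % Caffe layer architecture (hypothetical file)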

After you import a deep neural network, you can create an actor or critic object, such as rlQValueFunction or rlDiscreteCategoricalActor.
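
For instance, assuming you have already imported a network importedNet for an environment env that has a discrete action space (both variable names are placeholders), a minimal sketch of creating a stochastic actor is:

obsInfo = getObservationInfo(env);    % observation specifications
actInfo = getActionInfo(env);         % discrete action specifications
actor = rlDiscreteCategoricalActor(importedNet,obsInfo,actInfo);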

When you import deep neural network architectures, consider the following.

  • The dimensions of the input and output layers of the imported network architecture must match the dimensions of the corresponding action, observation, or reward signals of your environment.

  • After importing the network architecture, you must set the names of the input and output layers to match the names of the corresponding action and observation specifications.
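
For example, before creating an approximator you can display the input layer of an imported network and compare its name and size against the observation specification. This sketch assumes an environment env and the critic network file used later in this example:

criticNetwork = importONNXLayers("criticNetwork.onnx");
criticNetwork.Layers(1)              % image input layer, including its Name and InputSize
obsInfo = getObservationInfo(env);
obsInfo.Dimension                    % observation dimensions the input layer must match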

For more information on the deep neural network architectures supported for reinforcement learning, see Create Policies and Value Functions.

Import Actor and Critic for Image Observation Application

As an example, assume that you have an environment with a 50-by-50 grayscale image observation signal and a continuous action space. To train a policy gradient agent, you require the following function approximators, both of which must have a single 50-by-50 image input observation layer and a single scalar output value.

  • Actor — Selects an action based on the current observation

  • Critic — Estimates the expected long-term reward based on the current observation

Also, assume that you have the following network architectures to import:

  • A deep neural network architecture for the critic with a 50-by-50 image input layer and a scalar output layer, which is saved in the ONNX format (criticNetwork.onnx).

  • A deep neural network architecture for the actor with a 50-by-50 image input layer and a scalar output layer, which is saved in the ONNX format (actorNetwork.onnx).

To import the critic and actor networks, use the importONNXLayers function without specifying an output layer.

criticNetwork = importONNXLayers("criticNetwork.onnx");
actorNetwork = importONNXLayers("actorNetwork.onnx");

These commands generate a warning, which states that the network is not trainable until an output layer is added. When you use an imported network to create an actor or critic, Reinforcement Learning Toolbox™ software automatically adds an output layer for you.
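
Optionally, you can inspect the imported layer graphs to verify the input layer names and dimensions before creating the approximators, for example:

analyzeNetwork(criticNetwork)   % interactive summary of layers, sizes, and names
plot(actorNetwork)              % quick plot of the actor layer graph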

After you import the networks, create the actor and critic function approximators. To do so, first obtain the observation and action specifications from the environment.

obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);

Create the critic, specifying the name of the input layer of the critic network as the observation name. Since the critic network has a single observation input and a single scalar output, use a value function critic.

critic = rlValueFunction(criticNetwork,obsInfo,...
             ObservationInputNames={criticNetwork.Layers(1).Name});
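
As a quick check, you can evaluate the critic for a random observation that matches the specification (the returned value depends on the imported weights):

obs = {rand(obsInfo.Dimension)};   % random 50-by-50 grayscale observation
value = getValue(critic,obs)       % scalar value estimate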

Create the actor, specifying the name of the input layer of the actor network as the observation name. Since the actor network has a single scalar output, use a continuous deterministic actor.

actor = rlContinuousDeterministicActor(actorNetwork,obsInfo,actInfo,...
             ObservationInputNames={actorNetwork.Layers(1).Name});
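
Similarly, you can confirm that the actor returns an action consistent with the action specification (again, the output depends on the imported weights):

action = getAction(actor,{rand(obsInfo.Dimension)});   % cell array containing the action
action{1}                                              % scalar continuous action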

You can then create an agent that uses these actor and critic objects, and train the agent in your environment using the train function.
