
setCritic

Set critic of reinforcement learning agent

Since R2019a

Description


agent = setCritic(agent,critic) updates the reinforcement learning agent, agent, to use the specified critic object, critic.

Examples


Assume that you have an existing trained reinforcement learning agent. For this example, load the trained agent from the example Compare DDPG Agent to LQR Controller.

load("DoubleIntegDDPG.mat","agent") 

Obtain the critic function approximator from the agent.

critic = getCritic(agent);

Obtain the learnable parameters from the critic.

params = getLearnableParameters(critic)
params=2×1 cell array
    {[-4.9869 -1.5577 -0.3351 -0.1097 -0.0444 7.5240e-04]}
    {[                                                 0]}

Modify the parameter values. For this example, simply multiply all of the parameters by 2.

modifiedParams = cellfun(@(x) x*2,params,"UniformOutput",false);

Set the parameter values of the critic to the new modified values.

critic = setLearnableParameters(critic,modifiedParams);

Set the critic in the agent to the new modified critic.

setCritic(agent,critic);

Display the new parameter values.

getLearnableParameters(getCritic(agent))
ans=2×1 cell array
    {[-9.9737 -3.1153 -0.6702 -0.2194 -0.0888 0.0015]}
    {[                                             0]}

Create an environment with a continuous action space and obtain its observation and action specifications. For this example, load the environment used in the example Compare DDPG Agent to LQR Controller.

Load the predefined environment.

env = rlPredefinedEnv("DoubleIntegrator-Continuous");

Obtain observation and action specifications.

obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);

Create a PPO agent from the environment observation and action specifications. This agent uses default deep neural networks for its actor and critic.

agent = rlPPOAgent(obsInfo,actInfo);

To modify the deep neural networks within a reinforcement learning agent, you must first extract the actor and critic function approximators.

actor = getActor(agent);
critic = getCritic(agent);

Extract the deep neural networks from both the actor and critic function approximators.

actorNet = getModel(actor);
criticNet = getModel(critic);

The networks are dlnetwork objects. To view them using the plot function, you must convert them to layerGraph objects.

For example, view the actor network.

plot(layerGraph(actorNet))

Figure: layer graph of the actor network.

To validate a network, use analyzeNetwork. For example, validate the critic network.

analyzeNetwork(criticNet)

You can modify the actor and critic networks and save them back to the agent. To modify the networks, you can use the Deep Network Designer app. To open the app for each network, use the following commands.

deepNetworkDesigner(criticNet)
deepNetworkDesigner(actorNet)

In Deep Network Designer, modify the networks. For example, you can add layers to your network. When you modify the networks, do not change the input and output layers of the networks returned by getModel. For more information on building networks, see Build Networks with Deep Network Designer.

To validate the modified network in Deep Network Designer, click Analyze in the Analysis section. To export the modified network structures to the MATLAB® workspace, generate code for creating the new networks and run this code from the command line. Do not use the export option in Deep Network Designer. For an example that shows how to generate and run code, see Create DQN Agent Using Deep Network Designer and Train Using Image Observations.

For this example, the code for creating the modified actor and critic networks is in the createModifiedNetworks helper script.

createModifiedNetworks

Each of the modified networks includes an additional fullyConnectedLayer and reluLayer in its main common path. View the modified actor network.

plot(layerGraph(modifiedActorNet))

Figure: layer graph of the modified actor network.

After exporting the networks, insert the networks into the actor and critic function approximators.

actor = setModel(actor,modifiedActorNet);
critic = setModel(critic,modifiedCriticNet);

Finally, insert the modified actor and critic function approximators into the agent.

agent = setActor(agent,actor);
agent = setCritic(agent,critic);

Input Arguments


Reinforcement learning agent that contains a critic, specified as one of the following:

Note

agent is a handle object. Therefore, it is updated by setCritic whether or not agent is returned as an output argument. For more information about handle objects, see Handle Object Behavior.
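The handle behavior described in this note can be checked directly. The following is a minimal sketch, assuming a release in which agents are handle objects (as stated above) and in which rlDDPGAgent can create default networks from the specifications. It doubles the critic parameters, calls setCritic without assigning its output, and then reads the parameters back through the original agent variable. The variable names sameValues and newParams are illustrative only.

```matlab
% Create a default DDPG agent for the predefined double integrator environment
env = rlPredefinedEnv("DoubleIntegrator-Continuous");
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
agent = rlDDPGAgent(obsInfo,actInfo);

% Extract the critic and double all of its learnable parameters
critic = getCritic(agent);
params = getLearnableParameters(critic);
critic = setLearnableParameters(critic, ...
    cellfun(@(x) 2*x,params,"UniformOutput",false));

% Call setCritic WITHOUT assigning the output argument
setCritic(agent,critic);

% Because agent is a handle object, the updated critic is visible through agent
newParams = getLearnableParameters(getCritic(agent));
sameValues = isequal(newParams{1},2*params{1})
```

Writing agent = setCritic(agent,critic) remains valid and makes the update explicit; for a handle object both forms modify the same agent.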

Critic object, specified as one of the following:

  • rlValueFunction object: use when agent is an rlACAgent, rlPGAgent, or rlPPOAgent object.

  • rlQValueFunction object: use when agent is an rlQAgent, rlSARSAAgent, rlDQNAgent, rlDDPGAgent, or rlTD3Agent object with a single critic.

  • rlVectorQValueFunction object: use when agent is an rlQAgent, rlSARSAAgent, or rlDQNAgent object with a discrete action space and a vector Q-value function critic.

  • Two-element row vector of rlQValueFunction objects: use when agent is an rlTD3Agent or rlSACAgent object with two critics.
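For agents with two critics, getCritic returns both critics and setCritic expects both back as a row vector. The following is a minimal sketch, assuming that a default rlTD3Agent created from the specifications uses two Q-value function critics. It modifies only the second critic and then passes both critics back together by concatenating them into a row vector.

```matlab
% Create a default TD3 agent (assumed to use two Q-value critics)
env = rlPredefinedEnv("DoubleIntegrator-Continuous");
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
agent = rlTD3Agent(obsInfo,actInfo);

% getCritic returns a two-element row vector of rlQValueFunction objects
critics = getCritic(agent);

% Halve the learnable parameters of the second critic only
c2 = critics(2);
p = getLearnableParameters(c2);
c2 = setLearnableParameters(c2, ...
    cellfun(@(x) 0.5*x,p,"UniformOutput",false));

% Pass both critics back together as a row vector
agent = setCritic(agent,[critics(1) c2]);

% Read the second critic back from the agent and compare
newCritics = getCritic(agent);
q = getLearnableParameters(newCritics(2));
secondCriticUpdated = isequal(q{1},0.5*p{1})
```

Passing only c2 instead of the two-element vector would not match the expected input for a two-critic agent.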

Output Arguments


Updated agent, returned as an agent object. Note that agent is a handle object. Therefore, its critic is updated by setCritic whether or not agent is returned as an output argument. For more information about handle objects, see Handle Object Behavior.

Version History

Introduced in R2019a