How do I save a pretrained DQN agent and extract the weights inside its network?
The following is part of the program. I want to know how to extract the weight values from the trained DQN network.
DQNnet = [
imageInputLayer([1 520 1],"Name","ImageFeatureInput","Normalization","none")
fullyConnectedLayer(1024,"Name","fc1")
reluLayer("Name","relu1")
% fullyConnectedLayer(512,"Name","fc2")
% reluLayer("Name","relu2")
fullyConnectedLayer(14,"Name","fc3")];  % one Q-value per discrete action; Q-values are unbounded, so no softmax or classification layer
ObsInfo = getObservationInfo(env);
ActInfo = getActionInfo(env);
DQNOpts = rlRepresentationOptions('LearnRate',0.0001,'GradientThreshold',1,'UseDevice','gpu');
DQNagent = rlQValueRepresentation(DQNnet,ObsInfo,ActInfo,'Observation',{'ImageFeatureInput'},DQNOpts); % multi-output critic (one output per action), so no action input name is passed
agentOpts = rlDQNAgentOptions( ...
    'UseDoubleDQN',true, ...
    'MiniBatchSize',256);
agentOpts.EpsilonGreedyExploration.Epsilon = 1;
agent = rlDQNAgent(DQNagent,agentOpts);
%% Agent Training
% Training options
trainOpts = rlTrainingOptions( ...
    'MaxEpisodes', 100, ...
    'MaxStepsPerEpisode', 100, ...
    'Verbose', true, ...
    'Plots','training-progress', ...
    'ScoreAveragingWindowLength',400, ...
    'StopTrainingCriteria','AverageSteps', ...
    'StopTrainingValue',1000000000, ...
    'SaveAgentCriteria','EpisodeCount', ... % agents are only written to disk when a save criterion is set
    'SaveAgentValue',100, ...
    'SaveAgentDirectory', fullfile(pwd,"agents"));
% Agent training
trainingStats = train(agent,env,trainOpts);
Accepted Answer
praguna manvi on 28 Aug 2024 (edited 29 Aug 2024)
To save and load a pretrained DQN agent, you can use the “save” and “load” functions:
doTraining = false;
if doTraining
% Train the agent.
trainingStats = train(agent,env,trainOpts);
save myagent.mat agent
else
% Load the pretrained agent for the example.
load("MATLABCartpoleDQNMulti.mat","agent")
end
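The question's training options also set SaveAgentDirectory. As far as I know, train only writes agents to that folder when SaveAgentCriteria is set (its default is "none"). The sketch below assumes R2020b or later and uses the predefined cart-pole environment with a default DQN agent as a stand-in for the question's env and agent; the temporary folder name dqn_demo_agents is hypothetical. The file name AgentN.mat and the variable name saved_agent reflect the behavior I am aware of, so check the rlTrainingOptions documentation for your release.

```matlab
% Sketch: let train() write agents to disk, then reload one.
% ASSUMPTIONS: R2020b or later; the predefined cart-pole environment and
% a default DQN agent stand in for the question's env and agent.
env = rlPredefinedEnv("CartPole-Discrete");
agent = rlDQNAgent(getObservationInfo(env), getActionInfo(env));

% Hypothetical temporary folder, so no existing saved agents are overwritten
saveDir = fullfile(tempdir, "dqn_demo_agents");

trainOpts = rlTrainingOptions( ...
    'MaxEpisodes', 5, ...
    'MaxStepsPerEpisode', 50, ...
    'Plots', 'none', ...
    'SaveAgentCriteria', 'EpisodeCount', ... % without a criterion, nothing is saved
    'SaveAgentValue', 5, ...
    'SaveAgentDirectory', saveDir);
trainingStats = train(agent, env, trainOpts);

% Saved files are named AgentN.mat (N = episode) and hold a variable saved_agent
S = load(fullfile(saveDir, "Agent5.mat"), "saved_agent");
savedAgent = S.saved_agent;
```

With the question's own options, the corresponding file would be Agent100.mat inside the agents folder.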
Refer to the example illustrated here:
To extract the weights from the saved agent, first get its critic with the “getCritic” function and then call the “getLearnableParameters” function on it; refer to:
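A minimal sketch of saving, reloading, and extracting the critic weights follows. It assumes R2020b or later (for the default-agent constructor) and uses the predefined cart-pole environment purely as a stand-in so the snippet runs on its own; with your trained agent in the workspace, skip the first two lines. It saves to a temporary file rather than myagent.mat so that no existing file is overwritten. Depending on the release and on 'UseDevice','gpu', the cell entries may be dlarray or gpuArray values rather than plain numeric arrays.

```matlab
% Sketch: save a DQN agent, reload it, and extract the critic weights.
% ASSUMPTION: a default DQN agent for the predefined cart-pole environment
% stands in for the trained agent from the question (R2020b or later).
env = rlPredefinedEnv("CartPole-Discrete");
agent = rlDQNAgent(getObservationInfo(env), getActionInfo(env));

% Save to a temporary file so no existing .mat file is overwritten
fname = [tempname '.mat'];
save(fname, "agent");

% Reload the agent into a struct field
S = load(fname, "agent");
loadedAgent = S.agent;

% The Q-network is the agent's critic; its learnable parameters
% (weights and biases) come back as a cell array
critic = getCritic(loadedAgent);
params = getLearnableParameters(critic);

% Print the size of each parameter array
for k = 1:numel(params)
    fprintf("Parameter %d: size %s\n", k, mat2str(size(params{k})));
end
```

For the question's network, the printed sizes should let you match each entry to fc1 or fc3 (weights versus biases) before indexing into params.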