Import PyTorch LSTM Model into MATLAB
Hey Guys,
I am currently trying to use my PyTorch LSTM (trained with PyTorch Lightning) in MATLAB, but I have no idea how to use the importNetworkFromPyTorch function with an LSTM. The structure of the model is as follows:
LSTM -> Linear -> Sigmoid
The LSTM hyperparameters (https://docs.pytorch.org/docs/stable/generated/torch.nn.LSTM.html) are num_inputs=3, nhid=5, nlayers=5 (input_size, hidden_size and num_layers in torch.nn.LSTM terms), so the Linear layer has in_features=5 and out_features=1.
The training data has shape [BS, 600, 3], where BS is the batch size, 600 is the number of time steps, and 3 is the number of input features per time step. The hidden state has shape [5, BS, 5], i.e. (num_layers, batch size, hidden size).
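To make the shapes above concrete, here is a minimal pure-Python sketch of the shape bookkeeping for LSTM -> Linear -> Sigmoid, following the shape rules in the torch.nn.LSTM documentation (unidirectional, no projection). The function name model_shapes, the batch size of 32, and the batch_first/last_step_only switches are hypothetical: the post does not say whether the LSTM uses batch_first=True or whether the Linear layer sees every time step or only the last one.

```python
def model_shapes(batch, seq_len, input_size, hidden_size, num_layers,
                 batch_first=True, last_step_only=False):
    """Hypothetical helper: tensor shapes through LSTM -> Linear(hidden_size, 1) -> Sigmoid.

    Follows the torch.nn.LSTM shape rules for a unidirectional LSTM with proj_size=0.
    """
    if batch_first:
        x = (batch, seq_len, input_size)
        lstm_out = (batch, seq_len, hidden_size)
    else:
        x = (seq_len, batch, input_size)
        lstm_out = (seq_len, batch, hidden_size)
    # h_0 and c_0 are (num_layers, batch, hidden_size) even when batch_first=True
    state = (num_layers, batch, hidden_size)
    # Linear acts on the last dimension (hidden_size -> 1);
    # assumption: either every time step or only the last one is fed to it
    if last_step_only:
        head_out = (batch, 1)
    else:
        head_out = lstm_out[:-1] + (1,)
    return {
        "input": x,
        "h0": state,
        "c0": state,
        "lstm_output": lstm_out,
        "linear_in_features": hidden_size,
        "sigmoid_output": head_out,  # Sigmoid is element-wise, shape unchanged
    }


# Hypothetical batch size of 32 with the hyperparameters from the post
for name, shape in model_shapes(batch=32, seq_len=600, input_size=3,
                                hidden_size=5, num_layers=5).items():
    print(name, shape)
```

Note that the hidden and cell states keep the batch dimension in the middle even with batch_first=True, which is why the hidden state is [5, BS, 5] while the data is [BS, 600, 3].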
My problem is that I do not understand which input sizes I have to pass to the importNetworkFromPyTorch function.
I expect it to be something like this:
net = importNetworkFromPyTorch("example/path/model.pt",PyTorchInputSizes={[NaN,3], [2, 5, NaN, 5]})
I exported the traced model with:
traced_model = torch.jit.trace(model.model.forward, (input, hidden_input))
torch.jit.save(traced_model, "model.pt")
input has shape [3], and hidden_input is a tuple of two tensors, each of shape [5, 1, 5]: one for the hidden state and one for the cell state.
Can you please tell me how to use the importNetworkFromPyTorch function?
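One thing worth checking before tracing is whether the example inputs are consistent with each other. Per the torch.nn.LSTM documentation, the input must be 2-D (unbatched) or 3-D (batched), and h_0/c_0 must be (num_layers, hidden_size) for unbatched input or (num_layers, batch, hidden_size) for batched input. The sketch below is a hypothetical pure-Python checker (check_lstm_example_inputs is not part of PyTorch or MATLAB) that applies those rules. It assumes the tensors are passed straight to nn.LSTM; the post does not show forward(), so if it reshapes the [3] vector before calling the LSTM, the trace may still be valid.

```python
def check_lstm_example_inputs(x_shape, h_shape, c_shape,
                              input_size, hidden_size, num_layers,
                              batch_first=True):
    """Hypothetical check: raise ValueError if the shapes do not fit a plain nn.LSTM call.

    Uses the torch.nn.LSTM shape rules (unidirectional, proj_size=0).
    Returns the batch size, or None for unbatched input.
    """
    x_shape, h_shape, c_shape = tuple(x_shape), tuple(h_shape), tuple(c_shape)
    if len(x_shape) == 2:  # unbatched input: (L, H_in)
        batch = None
        expected_state = (num_layers, hidden_size)
    elif len(x_shape) == 3:  # batched input: (N, L, H_in) or (L, N, H_in)
        batch = x_shape[0] if batch_first else x_shape[1]
        expected_state = (num_layers, batch, hidden_size)
    else:
        raise ValueError(f"input must be 2-D or 3-D, got shape {x_shape}")
    if x_shape[-1] != input_size:
        raise ValueError(f"last input dim must be {input_size}, got {x_shape[-1]}")
    for name, shape in (("h_0", h_shape), ("c_0", c_shape)):
        if shape != expected_state:
            raise ValueError(f"{name} must have shape {expected_state}, got {shape}")
    return batch


# Consistent: one sequence of 600 steps with batch size 1, matching hidden_input
print(check_lstm_example_inputs((1, 600, 3), (5, 1, 5), (5, 1, 5), 3, 5, 5))

# A 1-D input of shape [3] is rejected by these rules
try:
    check_lstm_example_inputs((3,), (5, 1, 5), (5, 1, 5), 3, 5, 5)
except ValueError as err:
    print(err)
```

Since torch.jit.trace records the operations run on the example inputs, using an example with the same rank and layout as the real data (for instance a batch of one sequence) seems the safer choice.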
Answers (1)
Gayathri
on 15 May 2025
Can you please confirm which MATLAB release you are using? And are you getting any errors when you run the importNetworkFromPyTorch command in MATLAB?
According to the MATLAB documentation, importing LSTM layers is supported only from MATLAB R2025a onward. Please upgrade to MATLAB R2025a or later to import the LSTM model.
Hope this helps!