Cannot get tracing to work on complex custom deep learning layer
I'm trying to get a MATLAB version of https://github.com/jfcrenshaw/pzflow to work because I need something like it buried deep in a MATLAB workflow. My working code is at https://github.com/jeremylea/DLextras/tree/main/mzflow. No matter what I try, I cannot get it to train. I keep getting this error: 'dlgradient' inputs must be traced dlarray objects or cell arrays, structures, or tables containing traced dlarray objects. To enable tracing, use 'dlfeval'. This is despite dlfeval being in the call stack... The main custom layer is quite complex and only uses the learnables tangentially, as inputs to the knot locations in a spline; these affect the loss through the Jacobian rather than through the main output. The call flow is test_flow -> Flow.train -> dlfeval -> loss_fun -> dlgradient. It follows the standard training path used in the examples and in custom training setups I have built that do work.
Hours of debugging tell me that the tape is recording all of the operations across the complete set of layers, that the weight matrices are on the tape, and that the weight matrices are in a recording state when they are used in predict. However, the bijector.Learnables call that feeds dlgradient somehow returns a table of values that are not in a recording state (which they should be inside the dlfeval call?). I can't figure out how or when the matrices get switched out of the recording state, or whether I have two copies. Can anyone help?
I've tried replacing the embedded dlnetwork with a custom set of weights and biases; that doesn't help. I've also tried using the state to capture the log-determinant values (which is much cleaner and would be my preferred design). My next option is to figure out how to make this a direct call to Python and abandon the Deep Learning Toolbox... Unfortunately, the pzflow library uses JAX, which is hard to run on Windows, so that probably also means moving the entire flow to Linux.
Accepted Answer
Richard
on 19 May 2023
You are passing the bijector network into loss_fun as data that is captured inside the closure of an anonymous function. dlfeval cannot see these pieces of data because they are private to the closure, and so it cannot convert them into traced inputs. You need to change the call to dlfeval so that at least the network is an explicit input to dlfeval:
[~,gradients] = dlfeval(@loss_fun,this.bijector,this.latent,Xbat,Cbat);
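For illustration, here is a sketch of the difference (the "before" form is an assumption inferred from the description above; your actual loss_fun arguments may differ):
% Before (assumed): the network only reaches loss_fun through the
% closure of an anonymous function, so dlfeval never sees it and
% cannot convert its Learnables into traced dlarrays.
[~,gradients] = dlfeval(@(X,C) loss_fun(this.bijector,this.latent,X,C),Xbat,Cbat);
% After: the network is an explicit dlfeval input, so its learnable
% parameters are traced before loss_fun runs, and a call like
% dlgradient(loss,bijector.Learnables) inside loss_fun can find them.
[~,gradients] = dlfeval(@loss_fun,this.bijector,this.latent,Xbat,Cbat);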
It also looks like your implementations of log_prob, and possibly also sample, in the distributions.Uniform class may cause an issue: from inspection it looks to me as if they do not derive their output through a traceable chain from the input, which will ultimately result in the loss value not being traced.
In log_prob, the problem line is:
log_prob = repmat(-inf,size(mask));
log_prob is not a tracing variable and thus it will not record the mask application on the next line. I think a better implementation is:
function inputs = log_prob(this, inputs)
    % Flag observations whose channel values all lie inside [0,1].
    mask = all((inputs >= 0) & (inputs <= 1), finddim(inputs,"C"));
    % Assign in place so the output is derived from the traced input.
    inputs(:) = -inf;
    inputs(mask) = -this.input_dim*log(1.0); % == 0 for the unit hypercube
end
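To see why this matters, compare the two ways of building the output (a toy sketch, not mzflow code; x stands for any traced dlarray and mask for a logical index of the same size):
% Untraced: the output is built from scratch, so it never joins the tape.
out = repmat(-inf,size(x)); % plain double, not a traced dlarray
out(mask) = 0;              % still untraced; dlgradient will reject it
% Traced: the output is derived from the traced input, so the
% subscripted assignments are recorded and the result stays on the tape.
out = x;
out(:) = -inf;
out(mask) = 0;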
I was also a bit suspicious about some of the code in the custom layers' forward functions - it isn't clear that they are all correctly tracing everything they do. For example, the use of extractdata in the NeuralSplineCoupling class is a red flag that often indicates something will be lost from the trace. You may need to write custom backward implementations for some of these cases.
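As a minimal illustration of the extractdata issue (a generic example, not taken from the mzflow code):
x = dlarray([1 2 3]);
f = @(x) sum(x.^2);              % stays on the tape
g = @(x) sum(extractdata(x).^2); % extractdata strips the dlarray and its trace
grad = dlfeval(@(x) dlgradient(f(x),x),x)  % works: returns 2*x = [2 4 6]
% dlfeval(@(x) dlgradient(g(x),x),x)       % errors: the loss is not a traced dlarray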