Autodifferentiating through a model that requires the conversion of a traced array #3903
Peter-Vincent asked this question in Q&A
This is a copy of a question asked in the JAX GitHub repo, but I figured this might be a more appropriate place to ask it!
I have a model I am trying to optimise. A component of that model is a neural network. That neural network is trained beforehand (since it can be used by a variety of other models) and then saved as a TensorFlow SavedModel.
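A minimal sketch of that export, following the pattern in the Orbax examples (`params`, `apply_fn`, `input_dim`, and `model_path` are placeholders, not the original code):

```python
import tensorflow as tf
from orbax.export import ExportManager, JaxModule, ServingConfig

# Wrap the trained JAX model (parameters plus apply function) for export.
jax_module = JaxModule(params, apply_fn)

# Declare a serving signature: one float32 input of shape (None, input_dim).
export_manager = ExportManager(
    jax_module,
    [
        ServingConfig(
            "serving_default",
            input_signature=[tf.TensorSpec((None, input_dim), tf.float32)],
        )
    ],
)

# Write the model to disk as a TensorFlow SavedModel.
export_manager.save(model_path)
```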
And then loaded when I need it with the following code:

```python
import tensorflow as tf

loaded_model = tf.saved_model.load(model_path)
```
This follows the documentation and examples in the Orbax repository.
The problem arises when I try to optimise this second model, for which the neural network is a small component.
If I just run the model without wrapping it in `grad` (or `value_and_grad`), it works fine. However, if the function is wrapped in `value_and_grad`, then I get errors when I try to call the model with `model(inputs)["outputs"]`.
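For concreteness, a minimal sketch of the failing pattern (`inputs` and `loss_fn` are illustrative names, not the real model code):

```python
import jax
import jax.numpy as jnp
import tensorflow as tf

loaded_model = tf.saved_model.load(model_path)

def loss_fn(x):
    # Under value_and_grad, `x` is an abstract JAX tracer rather than a
    # concrete array; TensorFlow cannot convert a tracer to a tf.Tensor,
    # so the SavedModel call raises during tracing.
    nn_out = loaded_model(x)["outputs"]
    return jnp.sum(jnp.asarray(nn_out))

out = loaded_model(inputs)["outputs"]            # plain call: works
val, grad = jax.value_and_grad(loss_fn)(inputs)  # wrapped call: fails
```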
I understand at a high level why this is happening, but I don't see a way around it. It would also be a problem even if that line succeeded, since the outputs of the model are `tf.float32` arrays, which `jax.numpy` functions (for example `jnp.mod`) can't process, so the arrays need to be cast back into types JAX recognises.
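In eager (non-`grad`) execution that cast can be done by hand, e.g. (a sketch of the conversion I mean, not a fix for the `grad` case):

```python
import jax.numpy as jnp

tf_out = loaded_model(inputs)["outputs"]  # tf.Tensor with dtype tf.float32
jax_out = jnp.asarray(tf_out.numpy())     # convert back to a JAX array
result = jnp.mod(jax_out, 2.0)            # jnp functions now accept it
```

Inside `value_and_grad`, though, the call never produces a concrete `tf.Tensor`, so there is nothing to convert.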
Any recommendations on how to resolve this would be greatly appreciated!
Many thanks