A bit of confusion about the Flax semantics of setup

I have a module that has a complex parameterization. During training, I need setup to be called repeatedly to transform my parameters into the right form. During test, I would like to avoid rerunning this transformation since the parameters don't change. Is there a way to do this? I looked at
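For concreteness, the kind of module being described might look roughly like the sketch below; ReparamDense, slow_transform, and the shapes are invented for illustration, not taken from the thread. The raw parameter is transformed in setup, and because setup runs again on every apply, the transformation is recomputed even at test time when the parameters are frozen.

```python
import jax.numpy as jnp
from flax import linen as nn


def slow_transform(raw):
    # stand-in for an expensive reparameterization (e.g. an
    # orthogonalization or spectral normalization of the kernel)
    return raw / (jnp.linalg.norm(raw) + 1e-6)


class ReparamDense(nn.Module):
    in_features: int
    features: int

    def setup(self):
        raw = self.param("raw_kernel",
                         nn.initializers.lecun_normal(),
                         (self.in_features, self.features))
        # setup runs on every apply, so this is recomputed each time
        self.kernel = slow_transform(raw)

    def __call__(self, x):
        return x @ self.kernel
```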
-
This is an interesting problem. You provide a (potentially fresh) dict of params on each apply, but the JAX arrays don't have identity, so there's no super fast way to check that the params didn't change. I don't think you would want to use bind, because in a setup like this you can only bind inside eval_step, which means you would still recompute it every step in the end (you can't pass a bound module across transform boundaries).

So I would do this by using a collection to cache the value:

```python
def setup(self):
    self.t = self.variable("constants", "t", slow_function, param)
```

Then in training you do:

```python
# don't pass old constants back in
y, variables = model.apply({"params": params}, mutable="constants")
# discard constants after the optimizer update...
```

and during eval you use:

```python
y = model.apply({"params": params, "constants": constants})
```

Before eval you could pre-compute the constants, or pass an empty constants dict and make it mutable so it's automatically computed on the first eval_step (in this case JAX would compile 2 versions of eval_step: one that runs slow_function and one that uses the cache).
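Putting the pieces of that reply together, a minimal end-to-end sketch might look like the following; CachedReparamDense, slow_function, and the shapes are illustrative placeholders rather than code from the thread.

```python
import jax
import jax.numpy as jnp
from flax import linen as nn


def slow_function(raw):
    # expensive transformation we only want to rerun when params change
    return raw / (jnp.linalg.norm(raw) + 1e-6)


class CachedReparamDense(nn.Module):
    in_features: int
    features: int

    def setup(self):
        raw = self.param("raw_kernel",
                         nn.initializers.lecun_normal(),
                         (self.in_features, self.features))
        # stored in a separate "constants" collection; the init function
        # (slow_function) only runs when the value is missing and the
        # collection is mutable
        self.kernel = self.variable("constants", "kernel", slow_function, raw)

    def __call__(self, x):
        return x @ self.kernel.value


model = CachedReparamDense(in_features=8, features=4)
x = jnp.ones((2, 8))
variables = model.init(jax.random.PRNGKey(0), x)
params = variables["params"]

# training: params change every step, so recompute the cache and discard it
# after the optimizer update
y, new_vars = model.apply({"params": params}, x, mutable="constants")

# eval: pass the precomputed constants back in; nothing is recomputed
constants = new_vars["constants"]
y_eval = model.apply({"params": params, "constants": constants}, x)
```

Inside a jitted eval_step, the cached path never calls slow_function, which matches the point above about JAX compiling two versions of eval_step.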
-
Thanks! This is a very clear explanation and doesn't seem to make my code much messier. I'm using caches for RNN state anyway, so this is not too much of a change. It's nice that I can have some mutable and some non-mutable variables. Just to clarify two points.
One random comment: