How to jit a function that has an instantiated module as an argument? #1153
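Original question by @psc-g. Example code (simplified by me):

```python
import jax
from flax import linen


class Foo(linen.Module):
    @linen.compact
    def __call__(self):
        pass


# Passing the instantiated module straight into a jitted function fails.
@jax.jit
def test_inf(network_def):
    return network_def.apply({})


test_inf(Foo())
```

This gives an error.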
Replies: 1 comment
Answer by @levskaya:

The underlying problem is that you can't hash NumPy arrays: they're mutable, like dicts. In this case you have an instantiated module, which can never be hashed because it can theoretically hold variables, and which also has a NumPy attribute that can't be hashed.

You can do things like:

```python
import jax


def get_inference_fn(network_def):
    # Close over the module instead of passing it to the jitted function.
    @jax.jit
    def inference(params, x):
        return network_def.apply(params, x)
    return inference
```

The only thing to be careful about here is that `@jax.jit` and friends cache on function *identity*, so if you call `get_inference_fn` multiple times, you will trigger a recompile each time (but not if you reuse the returned function).
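To make the caching caveat concrete, here is a minimal sketch that counts how often JAX traces the inner function. It is an illustration under assumptions, not part of the original answer: `apply_fn` is a hypothetical plain function standing in for `network_def.apply`, and `trace_log`, `params`, and `x` are made-up names and values. The Python side effect inside the jitted function runs only while JAX traces it, so the length of `trace_log` tracks retraces (which generally also mean recompiles).

```python
import jax
import jax.numpy as jnp

trace_log = []  # one entry is appended each time JAX traces `inference`


def apply_fn(params, x):
    # Hypothetical stand-in for network_def.apply(params, x).
    return x @ params["w"] + params["b"]


def get_inference_fn(apply_fn):
    @jax.jit
    def inference(params, x):
        trace_log.append("traced")  # Python side effect: runs only during tracing
        return apply_fn(params, x)
    return inference


params = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}
x = jnp.ones((4, 3))

# Reusing one returned function: traced on the first call, cached afterwards.
inference = get_inference_fn(apply_fn)
inference(params, x)
inference(params, x)
traces_when_reused = len(trace_log)

# Each factory call builds a new function object, so jit has no cache entry for it.
get_inference_fn(apply_fn)(params, x)
get_inference_fn(apply_fn)(params, x)
traces_after_rebuilding = len(trace_log)
```

In this sketch the reused function is traced once across its two calls, while each fresh call to `get_inference_fn` adds another trace. In practice that suggests building the inference function once (for example at setup time) and holding on to it.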