
How to jit a function that has an instantiated module as an argument? #1153

Answered by marcvanzee
marcvanzee asked this question in Q&A

Answer by @levskaya:

The underlying problem is that you can't hash NumPy arrays: they're mutable, like dicts.
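To make the hashability point concrete, here is a small check using plain Python and NumPy. The helper is_hashable is a hypothetical name introduced only for illustration; it just reports whether the built-in hash() succeeds.

```python
import numpy as np


def is_hashable(obj):
  """Return True if hash(obj) succeeds (hypothetical helper)."""
  try:
    hash(obj)
  except TypeError:
    return False
  return True


arr = np.array([1.0, 2.0, 3.0])
print(is_hashable(arr))                  # False: ndarrays are mutable
print(is_hashable({"a": 1}))             # False: the same goes for dicts
print(is_hashable(tuple(arr.tolist())))  # True: an immutable copy can be hashed
```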

In this case you have an instantiated module, which can never be hashed in general because it can theoretically hold variables, and it also has a NumPy attribute, which can't be hashed either.
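As a rough illustration of how a single NumPy attribute makes the whole object unhashable, the sketch below uses a hypothetical frozen dataclass as a stand-in for a module. It is not Flax's Module class and does not model variables; it only shows that when __hash__ is built from the field values, one ndarray field is enough to make hashing fail.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class HypotheticalModule:
  # Illustrative stand-in only, not Flax's Module class. frozen=True makes
  # dataclasses generate a __hash__ that hashes the tuple of field values.
  features: int
  scale: object


def is_hashable(obj):
  """Return True if hash(obj) succeeds."""
  try:
    hash(obj)
  except TypeError:
    return False
  return True


plain = HypotheticalModule(features=4, scale=2.0)
with_array = HypotheticalModule(features=4, scale=np.array(2.0))

print(is_hashable(plain))       # True: every field is hashable
print(is_hashable(with_array))  # False: the ndarray field breaks hashing
```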

Instead of passing the module to the jitted function, you can close over it, e.g.:

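```python
import jax

def get_inference_fn(network_def):
  # network_def is captured by the closure, so jax.jit never has to hash it.
  @jax.jit
  def inference(params, x):
    return network_def.apply(params, x)
  return inference
```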
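A minimal end-to-end sketch of the closure pattern, assuming a hypothetical TinyDense class in place of a real Flax module. It only mimics an apply(params, x) method and carries a NumPy attribute; as a non-frozen dataclass it gets __hash__ = None, so it could not be used as a static argument, but closing over it works fine.

```python
import dataclasses

import jax
import jax.numpy as jnp
import numpy as np


@dataclasses.dataclass
class TinyDense:
  # Hypothetical stand-in for an instantiated module (not Flax's API).
  # A non-frozen dataclass gets __hash__ = None, so instances are unhashable.
  scale: np.ndarray

  def apply(self, params, x):
    return (x @ params["w"]) * self.scale


def get_inference_fn(network_def):
  @jax.jit
  def inference(params, x):
    # network_def is captured by the closure; only params and x are traced.
    return network_def.apply(params, x)
  return inference


net = TinyDense(scale=np.array(2.0))
params = {"w": jnp.ones((3, 2))}
x = jnp.ones((4, 3))

inference = get_inference_fn(net)
y = inference(params, x)
print(TinyDense.__hash__ is None)  # True
print(y.shape)                     # (4, 2)
```

Each entry of y is 3 * 2.0 = 6.0, since x @ w sums three ones before scaling.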

The only thing to be careful about here is that @jax.jit and friends cache on function identity, so if you call get_inference_fn multiple times, you will trigger a recompile each time (but not if you reuse the returned function).
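One way to observe the caching behaviour is to count traces with a Python side effect, which runs only while JAX traces the function, not on cached calls. In this sketch a plain float scale is a hypothetical stand-in for network_def, and trace_count is an illustrative name.

```python
import jax
import jax.numpy as jnp

trace_count = []  # gets one entry each time a jitted function is traced


def get_inference_fn(scale):
  # scale is a hypothetical stand-in for network_def.
  @jax.jit
  def inference(x):
    trace_count.append(1)  # Python side effect: runs only during tracing
    return x * scale
  return inference


x = jnp.ones(3)

# Reusing one returned function: traced (and compiled) only once.
fn = get_inference_fn(2.0)
fn(x)
fn(x)
print(len(trace_count))  # 1

# Each factory call creates a new function object, so jit cannot reuse
# the earlier cache entry and traces again.
get_inference_fn(2.0)(x)
get_inference_fn(2.0)(x)
print(len(trace_count))  # 3
```

So in practice, build the inference function once and keep the returned function around rather than calling the factory in a loop.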

Answer selected by marcvanzee