Initializing Parameterized Modules #1921

Answered by jheek
srush asked this question in General
This is an interesting problem. You provide a (potentially fresh) dict of params on each apply, but JAX arrays don't have identity, so there's no fast way to validate that a param didn't change. I don't think you would want to use bind, because in a setup like this:

from jax import jit

@jit
def eval_step(variables, batch):
  ...  # body elided in the original

for batch in batches:
  eval_step(variables, batch)

you can only bind inside eval_step, which means you would still recompute it on every step in the end (you can't pass a bound module across transform boundaries).

So I would do this by using a variable collection to cache the value:

def setup(self):
    # slow_function(param) runs only when "t" is not already in "constants"
    self.t = self.variable("constants", "t", slow_function, param)

then in training you do:

Answer selected by srush