
Multiple forward pass methods #2584

Answered by cgarciae
nico-bohlinger asked this question in Q&A


Hey @nico-bohlinger, I don't think you should jit your Module's `apply` function directly. A better option is to create a small helper function that is easier to jit, e.g.:

```python
import jax

@jax.jit
def forward_torso(params, x):
  return mymodule.apply({'params': params}, x, method=mymodule.torso)
```

You can create a similar helper for the normal `__call__`. You can even create a more general one that accepts the method as a static argument, e.g.:

```python
from functools import partial

import jax

@partial(jax.jit, static_argnums=(2,))
def forward(params, x, method):
  # `method` is static, so it must be hashable (e.g. mymodule.torso)
  return mymodule.apply({'params': params}, x, method=method)
```

Answer selected by nico-bohlinger