Multiple forward pass methods #2584
Hello, I want to implement a simple Linen module that consists of two methods for the forward pass. A main … Now I'm still confused about how to call both methods correctly. Before training I normally call … I googled a bit but didn't find a great example that shows this use case :(
Hey @nico-bohlinger, I don't think you should `jit` your Module's `apply` function directly; a better option is to create a helper function that is easier to `jit`, e.g.:

```python
import jax

# `mymodule` is an instance of your Linen module
@jax.jit
def forward_torso(params, x):
    return mymodule.apply({'params': params}, x, method=mymodule.torso)
```

You can create a similar one for the normal call. You can even create a more general one that accepts the method as a static argument, e.g.:

```python
from functools import partial

import jax

# `method` is static, so it must be hashable (a bound method like `mymodule.torso` is)
@partial(jax.jit, static_argnums=(2,))
def forward(params, x, method):
    return mymodule.apply({'params': params}, x, method=method)
```