How to jit when my Flax module has two methods #2550
wisdomGEsLA asked this question in Q&A
Hey @wisdomGEsLA, one way to pass static metadata (like a method) is to extend `TrainState` with a field marked `pytree_node=False`:

```python
from typing import Callable

from flax import struct
from flax.training import train_state


class TrainState(train_state.TrainState):
  # Non-pytree field: stored as static metadata, so jit never tries to trace it
  apply_method: Callable = struct.field(pytree_node=False)

...

state = TrainState.create(
  apply_fn=module.apply,
  apply_method=module.calculate_loss,
  ...
)
```

And then use that inside the jitted function:

```python
import jax


@jax.jit
def apply_model(state, users, pos_items, neg_items):
  def loss_fn(params):
    # The method comes from the static field, so it is a plain Python callable here
    loss = state.apply_fn({'params': params},
                          users, pos_items, neg_items, method=state.apply_method)
    return loss
  grad_fn = jax.value_and_grad(loss_fn)
  loss_val, grads = grad_fn(state.params)
  return grads, loss_val
```
My module needs two methods to do its job, `__call__()` and `calculate_loss()`. When I apply `calculate_loss()`, I think I need to pass the `method` parameter to `apply_fn`, but then I run into a problem when using `jit` to accelerate my code, like this:
It raises an error:
Is there a solution to this?
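For context, one restriction that commonly produces this kind of error can be reproduced without Flax. In this minimal sketch, `double` and `apply_with` are hypothetical stand-ins: `jax.jit` only traces array arguments, so a Python callable passed as an ordinary argument is rejected with a `TypeError`, while the same callable marked static via `static_argnums` is accepted. This is the same idea the `pytree_node=False` field relies on, in its simplest form.

```python
import jax
import jax.numpy as jnp


def double(x):
    # Hypothetical stand-in for a module method such as calculate_loss.
    return 2.0 * x


def apply_with(method, x):
    return method(x)


# 1) Callable passed as an ordinary (traced) argument: jit rejects it,
#    because a Python function is not a valid JAX array type.
try:
    jax.jit(apply_with)(double, jnp.ones(3))
    rejected = False
except TypeError:
    rejected = True

# 2) Same callable marked static: jit treats it as a compile-time constant
#    (part of the cache key) and only traces `x`.
apply_static = jax.jit(apply_with, static_argnums=0)
out = apply_static(double, jnp.ones(3))
```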