Add another apply_fn in TrainState #2913

Answered by cgarciae
vcharraut asked this question in Q&A
Hey @valentin-cnt, some notes:

  • If you plan to use TrainState.apply_gradients, note that the gradients are applied only to params (here, actor_params), not to critic_params.
  • You need to mark apply_critic as a "static" field, since it's not a valid JAX type:
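from typing import Callable

import flax
from flax.struct import field
from flax.training.train_state import TrainState

class TrainState(TrainState):
    critic_params: flax.core.FrozenDict
    # A plain callable is not a JAX type, so keep it out of the pytree.
    apply_critic: Callable = field(pytree_node=False)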

Answer selected by vcharraut