Hello! I first implemented an A2C algorithm in Flax here, but for ease of implementation I used a shared network for the actor-critic model. What I would like to do is define two models as follows:
but only use one … I guess the right way of doing so is to create a custom `TrainState` class, but I'm not sure how to define two distinct …
I know this code is not working, but I would be happy if someone could guide me towards the right solution. Thanks!
-
Hey @valentin-cnt, some notes:

- With the stock `TrainState.apply_gradients`, the gradients would only be applied to `params=actor_params`, not to `critic_params`.
- Mark `apply_critic` as a "static" field, since it's not a valid JAX type:

```python
from typing import Callable
import flax
from flax.struct import field
from flax.training.train_state import TrainState

class TrainState(TrainState):
    critic_params: flax.core.FrozenDict
    apply_critic: Callable = field(pytree_node=False)  # static: not a JAX type
```
-
Ok, actually I didn't think about this step. I guess it would be bad practice to redefine … Thank you for your answer!