Skip to content

Commit

Permalink
Update train_state.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
emiresenov committed Jan 6, 2025
1 parent 53bde74 commit 9e0def5
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions flax/training/train_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,19 @@ class TrainState(struct.PyTreeNode):
tx: optax.GradientTransformation = struct.field(pytree_node=False)
opt_state: optax.OptState = struct.field(pytree_node=True)

def apply_gradients(self, *, grads, **kwargs):
def apply_gradients(self, *, grads, value=None, value_fn=None, **kwargs):
"""Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value.
Note that internally this function calls ``.tx.update()`` followed by a call
to ``optax.apply_updates()`` to update ``params`` and ``opt_state``.
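For ``optax.GradientTransformationExtraArgs``, the optional ``value`` and
``value_fn`` are forwarded to ``.tx.update()``, but only when both are provided;
if either is ``None``, ``.tx.update()`` is called without them.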
Args:
grads: Gradients that have the same pytree structure as ``.params``.
value (optional): Value of the objective associated with the current ``grads`` update. Only used when ``value_fn`` is also given.
value_fn (optional): Function to evaluate the objective given the model. Only used when ``value`` is also given.
**kwargs: Additional dataclass attributes that should be ``.replace()``-ed.
Returns:
Expand All @@ -100,9 +105,20 @@ def apply_gradients(self, *, grads, **kwargs):
grads_with_opt = grads
params_with_opt = self.params

updates, new_opt_state = self.tx.update(
grads_with_opt, self.opt_state, params_with_opt
)
if value is None or value_fn is None:
updates, new_opt_state = self.tx.update(
grads_with_opt, self.opt_state, params_with_opt
)
else:
updates, new_opt_state = self.tx.update(
grads, self.opt_state,
params_with_opt,
grad=grads,
value=value,
value_fn=value_fn,
**kwargs,
)

new_params_with_opt = optax.apply_updates(params_with_opt, updates)

# As implied by the OWG name, the gradients are used directly to update the
Expand Down

0 comments on commit 9e0def5

Please sign in to comment.