Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[linen] Linesearch (and lbfgs) support for TrainState #4471

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions flax/training/train_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,20 @@ class TrainState(struct.PyTreeNode):
tx: optax.GradientTransformation = struct.field(pytree_node=False)
opt_state: optax.OptState = struct.field(pytree_node=True)

def apply_gradients(self, *, grads, **kwargs):
def apply_gradients(self, *, grads, grad=None, value=None, value_fn=None, **kwargs):
"""Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value.

Note that internally this function calls ``.tx.update()`` followed by a call
to ``optax.apply_updates()`` to update ``params`` and ``opt_state``.

For ``optax.GradientTransformationExtraArgs``, the optional ``grad``,
``value`` and ``value_fn`` are passed to ``.tx.update()``.

Args:
grads: Gradients that have the same pytree structure as ``.params``.
grad (optional): gradient of the function at the current params.
value (optional): value of the objective associated with the current grads update.
value_fn (optional): function to evaluate the objective given the model.
**kwargs: Additional dataclass attributes that should be ``.replace()``-ed.

Returns:
Expand All @@ -100,9 +106,21 @@ def apply_gradients(self, *, grads, **kwargs):
grads_with_opt = grads
params_with_opt = self.params

update_kwargs = {
"grad": grad,
"value": value,
"value_fn": value_fn
}

update_kwargs = {k: v for k, v in update_kwargs.items() if v is not None}

updates, new_opt_state = self.tx.update(
grads_with_opt, self.opt_state, params_with_opt
grads_with_opt,
self.opt_state,
params_with_opt,
**update_kwargs
)

new_params_with_opt = optax.apply_updates(params_with_opt, updates)

# As implied by the OWG name, the gradients are used directly to update the
Expand Down
Loading