
[linen] Linesearch (and lbfgs) support for TrainState #4471

Open · wants to merge 3 commits into main
Conversation


@emiresenov emiresenov commented Jan 6, 2025

What does this PR do?

This PR modifies the tx.update() call in TrainState.apply_gradients() to support the optional extra arguments value and value_fn used by optax's GradientTransformationExtraArgs.

These changes are in the same vein as the merged PR #4351, addressing #4144 for the TrainState class.
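
For context, a minimal usage sketch of what the change enables, assuming the revised apply_gradients() forwards value, grad (added in a later commit, see the discussion below), and value_fn to tx.update() when provided; until this PR lands, these extra keyword arguments are hypothetical:

import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

def loss_fn(params):
  # Toy quadratic objective so the linesearch has something to evaluate.
  return jnp.sum(params['w'] ** 2)

params = {'w': jnp.ones(3)}
# optax.lbfgs() bundles scale_by_lbfgs with a zoom linesearch, whose update()
# requires the extra arguments value, grad, and value_fn.
state = train_state.TrainState.create(
  apply_fn=None, params=params, tx=optax.lbfgs()
)
value, grads = jax.value_and_grad(loss_fn)(state.params)
# value, grad, and value_fn are the optional keyword arguments this PR adds.
state = state.apply_gradients(grads=grads, value=value, grad=grads, value_fn=loss_fn)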



google-cla bot commented Jan 6, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@emiresenov changed the title from "Update train_state.py" to "[nnx] Linesearch (and lbfgs) support for TrainState" on Jan 6, 2025
Review thread on this hunk:

updates, new_opt_state = self.tx.update(
  grads_with_opt, self.opt_state, params_with_opt
)
if value is None or value_fn is None:
Collaborator:

I think if either of them is not None we should pass them.

Author:

Good idea. I'll update the PR

@emiresenov (Author) commented Jan 7, 2025:

I figure the same logic should hold for grad, since e.g. scale_by_polyak also uses GradientTransformationExtraArgs and only takes value as an additional keyword argument in .update().

I moved all three optional args into an update_kwargs dictionary, adding each to .update() only if it is not None. Additionally, my former commit should not have passed the **kwargs through to .update(), since TrainState uses those for something else; I mixed up the usage from the merged PR I was referring to.

See the changes in my latest commit and let me know if there are any issues.
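
For readers following along, a sketch of the pattern described above, inside apply_gradients() (illustrative of the approach, not necessarily the exact diff):

# Collect the optional extra arguments, forwarding each to tx.update()
# only when the caller actually supplied it.
update_kwargs = {}
if value is not None:
  update_kwargs['value'] = value
if grad is not None:
  update_kwargs['grad'] = grad
if value_fn is not None:
  update_kwargs['value_fn'] = value_fn
updates, new_opt_state = self.tx.update(
  grads_with_opt, self.opt_state, params_with_opt, **update_kwargs
)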

@cgarciae (Collaborator) commented Jan 8, 2025:

This version looks better. Can you fix the pre-commit issue? Run:

pip install pre-commit
pre-commit run --all-files

@cgarciae changed the title from "[nnx] Linesearch (and lbfgs) support for TrainState" to "[linen] Linesearch (and lbfgs) support for TrainState" on Jan 8, 2025
@cgarciae (Collaborator) commented Jan 8, 2025:

Tangent: curious if nnx.Optimizer works with scale_by_polyak? I've never used it but we recently added some support for Linesearch in Optimizer as well.

@emiresenov (Author):

> This version looks better. Can you fix the pre-commit issue? Run:
>
> pip install pre-commit
> pre-commit run --all-files

Fixed, see my latest commit!

@emiresenov (Author) commented Jan 8, 2025:

> Tangent: curious if nnx.Optimizer works with scale_by_polyak? I've never used it but we recently added some support for Linesearch in Optimizer as well.

I haven't used it either, but I see no reason why it wouldn't work, since the last commit of #4351 uses the same logic. I think the first commit of that PR (the one I mistakenly replicated in my first commit) suggested a solution that wouldn't work for the scale_by_polyak transform used by optax.polyak_sgd, since it didn't treat grad as optional when only value was passed, presumably because the author only had linesearch in mind.

However, this potential issue was fixed in the last commit of that PR, when the author moved everything into the added kwargs of .update(). With this change, each extra argument to the GradientTransformationExtraArgs update function is passed only if it is not None, which makes the solution more general and also covers scale_by_polyak, which only takes value as an additional keyword argument.
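
As a self-contained illustration of the scale_by_polyak case (raw optax, no TrainState involved), only value is passed as an extra argument:

import jax
import jax.numpy as jnp
import optax

def loss_fn(params):
  return jnp.sum(params ** 2)

params = jnp.ones(3)
tx = optax.polyak_sgd()  # built on scale_by_polyak; needs only `value` as extra
opt_state = tx.init(params)
value, grads = jax.value_and_grad(loss_fn)(params)
# No grad or value_fn here, so a kwargs-based apply_gradients would forward
# value alone, which is exactly what the None checks described above allow.
updates, opt_state = tx.update(grads, opt_state, params, value=value)
params = optax.apply_updates(params, updates)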
