[linen] Linesearch (and lbfgs) support for TrainState #4471
base: main
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
flax/training/train_state.py
Outdated
updates, new_opt_state = self.tx.update(
    grads_with_opt, self.opt_state, params_with_opt
)
if value is None or value_fn is None:
I think if either of them is not `None`, we should pass them.
Good idea. I'll update the PR
I figure the same logic should hold for `grad`, since e.g. `scale_by_polyak` also uses `GradientTransformationExtraArgs` and only takes `value` as an additional keyword argument in `.update()`.

I moved all three optional args to an `update_kwargs` dictionary, adding each to `.update()` only if they are not `None`. Additionally, my former commit should not have sent the `**kwargs` to `.update()`, since `TrainState` uses these for something else; I mixed up the usage from the merged PR I was referring to.

See the changes in my latest commit and let me know if there are any issues.
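For readers skimming the thread, here is a minimal sketch of the logic described above, assuming the argument names `grad`, `value`, and `value_fn` from this discussion; the authoritative diff lives in the linked commits, and the helper name below is hypothetical:

```python
from typing import Any, Callable, Optional

# Hedged sketch of the apply_gradients logic described above, not the
# exact PR diff. `tx` is assumed to be an optax
# GradientTransformationExtraArgs; names follow this discussion.
def _update_with_extra_args(
    tx, opt_state, grads_with_opt, params_with_opt,
    grad: Optional[Any] = None,
    value: Optional[Any] = None,
    value_fn: Optional[Callable[..., Any]] = None,
):
  update_kwargs = {}
  # Forward each extra argument only if it was provided, so optimizers
  # built from a plain GradientTransformation are unaffected.
  if grad is not None:
    update_kwargs['grad'] = grad
  if value is not None:
    update_kwargs['value'] = value
  if value_fn is not None:
    update_kwargs['value_fn'] = value_fn
  updates, new_opt_state = tx.update(
      grads_with_opt, opt_state, params_with_opt, **update_kwargs)
  return updates, new_opt_state
```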
Force-pushed from 52c4bc4 to 9e0def5.
This version looks better. Can you fix the pre-commit issue? Run […]
Tangent: curious if […]
Fixed, see my latest commit!
I haven't used it either, but I see no reason why it wouldn't work, since the last commit of #4351 uses the same logic. I think the first commit of that PR (the one I mistakenly replicated in my first commit) suggested a solution that wouldn't work for […]. However, this potential issue was fixed in the last commit of that PR, when the author moved everything to their added […].
What does this PR do?
This PR modifies `.update()` to support the optional additional arguments `value` and `value_fn` for the `GradientTransformationExtraArgs` class. These changes are in the same vein as the merged PR #4351, addressing #4144 for the `TrainState` class.
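To make the intended use concrete, here is a hypothetical end-to-end sketch. It assumes the new keyword arguments surface on `apply_gradients` as `grad`, `value`, and `value_fn`, per the discussion above; the toy model and data are placeholders, not the final API:

```python
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training import train_state

# Hypothetical usage once this PR lands; the grad/value/value_fn
# keywords on apply_gradients are assumptions based on this thread.
model = nn.Dense(features=1)
x = jnp.ones((4, 3))
y = jnp.zeros((4, 1))
params = model.init(jax.random.PRNGKey(0), x)

def loss_fn(params):
  preds = model.apply(params, x)
  return jnp.mean((preds - y) ** 2)

state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.lbfgs(),  # linesearch-based optimizer needing extra args
)

value, grads = jax.value_and_grad(loss_fn)(state.params)
# optax.lbfgs expects value, grad, and value_fn in tx.update; with
# this PR, apply_gradients would forward them.
state = state.apply_gradients(
    grads=grads, grad=grads, value=value, value_fn=loss_fn)
```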
Checklist

- This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other checks if that's the case).
- This change is discussed in a GitHub issue/discussion: Linesearch (and lbfgs) support #4351, Support for optax lbfgs and related optimizers with NNX #4144.
- The documentation and docstrings adhere to the documentation guidelines.
- This change includes necessary high-coverage tests: Linesearch (and lbfgs) support #4351 includes a test with these modifications. Additionally, the modifications follow the spec for `GradientTransformationExtraArgs`. If you want me to write a file with test cases for the `TrainState` class, let me know. That said, I cannot find any tests for this class in the repository.