minor change so backtracking init and update return the same state types #1175

Open
wants to merge 1 commit into
base: main

bafflingbits commented Jan 14, 2025

Currently, if update is jitted, the first call (using the state from init) sees slightly different types than the second call (using the state returned by update), which triggers an unnecessary recompilation. The only difference is the weak_type flag on some values, which can be controlled by explicitly setting the dtype when initializing.
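The effect can be reproduced outside optax with a minimal sketch. The helpers `make_update` and `count_traces` are hypothetical names made up for this illustration, not optax code. The Python-side `append` runs only while JAX traces the function, so the length of the list counts compilations.

```python
import jax
import jax.numpy as jnp


def make_update():
    """Builds a fresh jitted update plus a list that records each trace."""
    traces = []

    @jax.jit
    def update(value):
        # Python side effect: executed only when JAX traces (compiles) update.
        traces.append(1)
        # Adding a strongly typed float32 makes the output strongly typed.
        return value + jnp.float32(1.0)

    return update, traces


def count_traces(init_value):
    """Calls update twice, feeding the first output back in as the state."""
    update, traces = make_update()
    state = update(init_value)  # state as produced by "init"
    state = update(state)       # state as produced by "update"
    return len(traces)


# A Python scalar with no dtype gives a weakly typed float32 array.
weak_init = jnp.asarray(0.0)
# An explicit dtype gives a strongly typed float32 array.
strong_init = jnp.asarray(0.0, dtype=jnp.float32)

weak_traces = count_traces(weak_init)
strong_traces = count_traces(strong_init)
print(weak_init.aval.weak_type, weak_traces)
print(strong_init.aval.weak_type, strong_traces)
```

In this sketch the weakly typed start value leads to one extra trace, because the second call sees a strongly typed float32 input, while the explicitly typed start value reuses the first compilation.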

(Discovered during the discussion of issue #1171.)

google-cla bot commented Jan 14, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

…te types, this avoids recompilation when calling update
bafflingbits force-pushed the minorfix/backtracking_equalize_init_and_update_state_type branch from c59a153 to 7d090b7 on January 15, 2025 23:07
bafflingbits (Author) commented

It appears that optimization with complex values may have been relying on that type ambiguity:

=========================== short test summary info ============================
FAILED test_venv/lib/python3.11/site-packages/optax/_src/alias_test.py::LBFGSTest::test_lbfgs_complex0 - TypeError: while_loop body function carry input and carry output must have equal types (e.g. shapes and dtypes of arrays), but they differ:

  * the input carry component carry[1][2].value has type complex64[] but the corresponding output carry component has type float32[], so the dtypes do not match;
  * the input carry component carry[1][2].info.decrease_error has type complex64[] but the corresponding output carry component has type float32[], so the dtypes do not match.

Fundamentally, the issue appears to be that init needs to know the output type of the function being optimized, but it isn't given this information. So it guesses, using a weakly typed value that can be promoted to the actual type later. Is that what is going on?

But then, on the second call to update, the type has 'narrowed' to a concrete one, and this triggers a recompilation.
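A small sketch of the promotion behaviour in plain JAX, assuming this is the mechanism at play (it does not reproduce the optax code path or the failing test). `run_loop` is a hypothetical helper: its loop body returns a complex64 value, standing in for a complex-valued objective.

```python
import jax.numpy as jnp
from jax import lax


def run_loop(init_value):
    """Runs three loop steps, carrying (step, value)."""

    def cond(carry):
        step, _ = carry
        return step < 3

    def body(carry):
        step, value = carry
        # The body output is complex64 regardless of the input dtype.
        return step + 1, value + jnp.complex64(1j)

    return lax.while_loop(cond, body, (0, init_value))


# Weakly typed placeholder: while_loop promotes it to match the body output.
_, promoted = run_loop(jnp.asarray(0.0))
print(promoted.dtype)

# Strongly typed float32 placeholder: input and output carry types differ.
try:
    run_loop(jnp.asarray(0.0, dtype=jnp.float32))
    strict_error = None
except TypeError as err:
    strict_error = err
print(type(strict_error).__name__)
```

The weakly typed start value is accepted because it can still be promoted, whereas pinning it to float32 produces the same kind of carry type mismatch error as in the test summary above.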

What is the "best" way to fix this?
Add a new optional argument to init that allows specifying the function's output type?
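One possible shape for such an argument, as a rough sketch only. `SearchState`, `init_state`, and `value_dtype` are hypothetical names for illustration and are not part of the optax API.

```python
from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp


class SearchState(NamedTuple):
    """Hypothetical stand-in for a line-search state holding the last value."""
    value: jax.Array


def init_state(value_dtype: Optional[jnp.dtype] = None) -> SearchState:
    # With value_dtype=None the placeholder stays weakly typed, so it can be
    # promoted later (the current behaviour described above).
    # With an explicit value_dtype the placeholder is strongly typed from the
    # start, so init and update would return matching types.
    return SearchState(value=jnp.asarray(jnp.inf, dtype=value_dtype))


default_state = init_state()
complex_state = init_state(jnp.complex64)
print(default_state.value.dtype, default_state.value.aval.weak_type)
print(complex_state.value.dtype, complex_state.value.aval.weak_type)
```

Keeping the default as `None` would preserve the existing promotable behaviour for callers who do not pass a dtype.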

rdyro self-requested a review on January 16, 2025 00:19
rdyro (Collaborator) commented Jan 16, 2025

Hey, this is a very interesting contribution! Also, I think your analysis is correct.

This test combines L-BFGS and the backtracking linesearch, so one of their update_fn implementations is not respecting the dtype. I haven't managed to test which one (you could try digging deeper into this failing test to find a proper fix).

Overall it might be hard to ensure that all optax gradient transforms output the same type as input (it'd need a more comprehensive review). And because init_fn and update_fn are generally generated via closures it might be difficult to ensure that JAX gets cache hits on these functions every time for all transforms.

I'd be happy to merge your PR if you can find the root issue in that failing test, but overall, jitting the top level might always be necessary.
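One reading of "jitting the top level", sketched in plain JAX under the assumption that both init and the update calls live inside a single jitted function. `init_state`, `update`, and `run_two_steps` are hypothetical stand-ins, not optax functions.

```python
import jax
import jax.numpy as jnp

traces = []  # grows by one each time the top-level function is traced


def init_state():
    # Weakly typed placeholder, as in the init behaviour discussed above.
    return jnp.asarray(0.0)


def update(state, grad):
    # With a float32 grad, the returned state is strongly typed float32.
    return state + grad


@jax.jit
def run_two_steps(grad):
    traces.append(1)  # Python side effect: executed only during tracing
    state = init_state()
    state = update(state, grad)  # sees the weakly typed state
    state = update(state, grad)  # sees the strongly typed state
    return state


grad = jnp.float32(0.5)
first = run_two_steps(grad)
second = run_two_steps(grad)
print(len(traces), float(second))
```

Here the weak/strong mismatch between the two inner update calls is resolved within a single trace, so repeated calls with the same argument types reuse one compilation.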
