[debug] Extract reproducers from JAX errors. #31867
Open
gnecula wants to merge 5 commits into jax-ml:main from gnecula:repro_2
gnecula added a commit to gnecula/jax that referenced this pull request on Nov 14, 2025
This is just a draft of an experiment. The purpose is to help debug JAX failures in large user examples by trying to extract a smaller program that contains only JAX API calls. E.g., in a large program using Flax, this would get Flax out of the way. To use it, set the environment variable `JAX_REPRO_DIR` to a directory where the repro files should be saved. You can use the value "sponge" to save test artifacts in Google internal tests. See jax-ml#31867 for more details.
This makes the repro a standalone source file, depending only on the standard JAX.
The goal is to help debug JAX failures in large user examples by trying to extract a smaller reproducing Python program that contains only JAX API calls. E.g., in a large program using Flax, this would get Flax out of the way.
WARNING: this code is experimental, expect changes or removal.
Documentation updates: https://jax--31867.org.readthedocs.build/en/31867/
Usage
If you get an uncaught exception from under a JAX API call, you can set `JAX_REPRO_DIR` to a directory where JAX will attempt to save a Python source file containing the JAX API calls that ought to reproduce the error. The mechanism is enabled simply by setting the `JAX_REPRO_DIR` variable (e.g., to "sponge" if using this in Google tests). JAX will track the sequence of nested JAX API calls, capturing all user functions, their calls to JAX APIs, and then, recursively, the user functions that JAX calls during tracing. If an uncaught exception arises, JAX saves a repro that should result in the same call tree and, hopefully, reproduces the error. You can get the path and source code of the saved repro by calling `repro.last_repro()`.
Alternatively, you can get repros by explicitly calling an API, even in the absence of errors. `repro.save_repro` will error if `JAX_REPRO_DIR` is not set.
In this explicit usage, all tracked JAX API invocations are saved in the repro. In the implicit usage, successful JAX calls at the top level are not retained.
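A minimal sketch of the implicit usage described above. It assumes the variable must be set before JAX is imported, and that `jax.experimental.repro` is importable only on a build that includes this PR; the temporary directory and the failing `lax.add` call are illustrative choices, not part of the PR.

```python
import os
import tempfile

# Assumption: JAX_REPRO_DIR has to be set before JAX is imported.
repro_dir = tempfile.mkdtemp(prefix="jax_repro_")
os.environ["JAX_REPRO_DIR"] = repro_dir

try:
    import numpy as np
    import jax
    from jax.experimental import repro  # present only with this PR applied
except ImportError:
    repro = None  # stock JAX (or no JAX): nothing to demonstrate

if repro is not None:
    try:
        # lax.add requires equal shapes, so this fails under a JAX API call.
        jax.jit(lambda x, y: jax.lax.add(x, y))(np.ones(3), np.ones(4))
    except Exception as err:
        print("JAX raised:", type(err).__name__)
        # Per the description, this gives the path and source of the saved repro.
        print(repro.last_repro())
else:
    print("jax.experimental.repro not available; repros would go to", repro_dir)
```

On stock JAX the sketch only prepares the directory and prints a notice, so it is safe to run anywhere.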
One can think of this mechanism as a way to stage out (slice) a pure JAX program from a large JAX program.
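For reference, here is what ordinary Jaxpr staging looks like in standard JAX for a function that uses `lax.scan`. This sketch uses only `jax.make_jaxpr`; the function `cumulative_sum` is a made-up example, and nothing here touches the repro mechanism itself.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cumulative_sum(xs):
    # A user-level call to lax.scan: carry the running total, emit it each step.
    def step(carry, x):
        new_carry = carry + x
        return new_carry, new_carry

    _, ys = lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return ys


closed_jaxpr = jax.make_jaxpr(cumulative_sum)(jnp.arange(4.0))
print(closed_jaxpr)  # shows the scan primitive together with its parameters

# Names of the primitives that appear at the top level of the Jaxpr.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)
```

The printed Jaxpr contains the `scan` primitive with its low-level parameters, whereas the user code contains a plain call to `lax.scan`.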
The repro mechanism is somewhat similar to staging out a Jaxpr with `jax.jit(f).trace(*args).jaxpr` (or the old `jax.make_jaxpr`), except that:
- Instead of the `lax.scan_p` primitive with its low-level details, you will see a call to `lax.scan`. However, most first-order primitives are going to be represented in a similar way as in a Jaxpr.
- Transformations appear as they were called, e.g., `jax.jvp(jax.vmap(f))`, even when the Jaxpr would reflect the code after these transformations.
Limitations
So many ...
Not all errors will be reproduced by this mechanism:
- `numpy` won't be captured,
- errors raised in the `jax.numpy` layer. E.g., the rank checking happens in the `jax.numpy` layer, so it won't be reproduced by this version of repros, but the shape checking happens after binding the JAX primitives, so it will be reproducible,
- [...] `np.ones`.
- `jax.named_call` is not handled currently.
During call tracking we attempt to alter the execution as little as possible. There are a few known differences though:
- The JAX tracing cache is defeated, so you are going to see functions traced multiple times. E.g., JAX normally avoids tracing a function repeatedly with similar arguments, but when repros are enabled this cache is disabled. This results in slower tracing, and for functions with side effects it may even alter the execution of the program.
- With `custom_vjp`, when we do higher-order differentiation the custom fwd and bwd functions are called more than once, and we assume that the subsequent calls will produce the same Jaxpr as earlier ones.
- We emit `jax.Array` as `np.ndarray` (e.g., losing sharding).
- If we don't recognize the `jax.checkpoint` policy param, we print a warning and we use `dots_saveable`.
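The tracing cache mentioned in the list above can be observed in standard JAX with a Python side effect inside a jitted function. This is a small sketch of stock behavior only; `double` and `trace_log` are illustrative names, and the sketch does not enable repros.

```python
import jax
import jax.numpy as jnp

trace_log = []  # one entry per time the Python body is actually traced


@jax.jit
def double(x):
    # A Python side effect: it runs during tracing, not on cached calls.
    trace_log.append(x.shape)
    return x * 2


double(jnp.ones(3))
double(jnp.zeros(3))  # same shape and dtype: the cached trace is reused
traces_after_same_shape = len(trace_log)

double(jnp.ones(4))  # new shape: JAX has to trace again
traces_after_new_shape = len(trace_log)

print(traces_after_same_shape, traces_after_new_shape)
```

With the cache disabled, as the description says happens when repros are enabled, a side effect like this would be expected to run on more calls, which is how tracking can change a program's behavior.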
Integration details
Since this code is experimental, we want to separate it as much as possible from the core JAX source code. The code is in `jax._src.repro`, but the user-facing APIs are in `jax.experimental.repro`. (I tried to put the whole implementation in `jax.experimental`, but when importing it from there during internal usage we end up importing a whole lot of other experimental APIs and run into circular import issues.) There are a few key integration points, added to `traceback_util.py`:
- `api_boundary` with a couple of kwargs (most notably `repro_api_name`), to mark those API boundaries that we want to intercept. These are primarily the higher-order JAX APIs, e.g., `jax.jit`, `jax.vmap`, ..., and a few others as well.
- `traceback_util.repro_enabled() -> bool`, which is true if we are currently collecting repros.
- `traceback_util.enable_repro(val: bool)` for a context manager to enable and disable repros. This is for internal usage only.
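To make the shape of these integration points concrete, here is a toy, pure-Python sketch of a boundary decorator, an enabled check, and an enabling context manager. Every name (`toy_api_boundary`, `toy_repro_enabled`, `toy_enable_repro`, `toy_map`) is hypothetical, and the logic is a simplified illustration of nested call tracking, not the PR's implementation.

```python
import contextlib
import functools

# Hypothetical stand-ins for illustration; not the PR's actual code.
_state = {"enabled": False, "depth": 0}
call_log = []  # (nesting depth, API name) for every intercepted call


def toy_repro_enabled() -> bool:
    return _state["enabled"]


@contextlib.contextmanager
def toy_enable_repro(val: bool):
    # Temporarily switch tracking on or off, restoring the old value on exit.
    previous = _state["enabled"]
    _state["enabled"] = val
    try:
        yield
    finally:
        _state["enabled"] = previous


def toy_api_boundary(fun=None, *, repro_api_name=None):
    if fun is None:  # used as @toy_api_boundary(repro_api_name=...)
        return functools.partial(toy_api_boundary, repro_api_name=repro_api_name)

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        if repro_api_name is None or not toy_repro_enabled():
            return fun(*args, **kwargs)
        call_log.append((_state["depth"], repro_api_name))
        _state["depth"] += 1
        try:
            return fun(*args, **kwargs)
        finally:
            _state["depth"] -= 1

    return wrapper


@toy_api_boundary(repro_api_name="toy.map")
def toy_map(f, xs):
    # A higher-order "API": it calls back into user code.
    return [f(x) for x in xs]


def user_fn(row):
    # User code that itself calls the API, producing a nested call.
    return sum(toy_map(lambda v: v * 2, row))


with toy_enable_repro(True):
    result = toy_map(user_fn, [[1, 2], [3]])

toy_map(user_fn, [[5]])  # tracking is off again, so nothing is recorded
print(result)
for depth, name in call_log:
    print("  " * depth + name)
```

The log records one top-level call and two nested calls made from user code, which mirrors the call tree a repro is meant to reproduce.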