Skip to content

Conversation

@gnecula
Copy link
Collaborator

@gnecula gnecula commented Sep 17, 2025

The goal is to help debug JAX failures in large users examples, by trying to extract a smaller reproducing Python program that contains only JAX API calls. E.g., in a large program using Flax, this would get Flax out of the way.

WARNING: this code is experimental, expect changes or removal.

Documentation updates: https://jax--31867.org.readthedocs.build/en/31867/

Usage

If you get an uncaught exception from under a JAX API call, you can set JAX_REPRO_DIR to a directory where JAX will attempt to save a Python source file that contains the JAX API calls that ought to reproduce the error. This mechanism can be enabled simply by setting the JAX_REPRO_DIR variable (e.g., to "sponge" if using this in Google tests). JAX will track the sequence of nested JAX API calls, capturing all user-functions, their calls to JAX APIs, and then recursively the user functions that are called by JAX during tracing. If an uncaught exception arises, then we save repro that should result in the same call tree, and hopefully can reproduce the error. One can get the path and source code of the saved repro by calling repro.last_repro().

Alternatively, you can get repros by explicitly calling an API even in absence of errors:

   from jax.experimental import repro
   try:
      result = repro.save_repro(fun, *args, **kwargs)
   finally:
     repro_path, repro_source = repro.last_repro()

repro.save_repro will error if JAX_REPRO_DIR is not set.
In the usage above, all tracked JAX API invocations are saved in the repro. In the implicit use, successful JAX calls at the top-level are not retained.

One can think of this mechanism as a way to stage out (slice) a pure JAX program from a large JAX program.
This is somewhat similar to staging out a Jaxpr withjax.jit(f).trace(*args).jaxpr (or the old jax.make_jaxpr), except that:

  • it produces Python source code rather that a dump of a Jaxpr. This means that the result can be more readable, more editable, and can be executed directly, e.g., in a debugger.
  • it works even if there are errors in the user program or in JAX. Then the produced output may reproduce the error.
  • the repro is higher-level than the Jaxpr: the higher-order Jaxpr primitives are replaced by calls to high-level JAX API. E.g., instead of seeing the lax.scan_p primitive with its low-level details, you will see a call to lax.scan. However, most first-order primitives are going to be represented in a similar way as in a Jaxpr.
  • the repro code will contain the sequence of JAX API calls as they appear in the user code, e.g., jax.jvp(jax.vmap(f)), even when the Jaxpr would reflect the code after these transformations.

Limitations

So many ...

Not all errors will be reproduced by this mechanism:

  • errors from numpy won't be captured,
  • errors from the jax.numpy layer. E.g., the rank checking happens in the jax.numpy layer, so it won't be reproduced by this version of repros, but the shape checking happens after binding the JAX primitives, so it will be reproducible,
  • we currently do not try to preserve the exact data arrays, and for arrays larger than a threshold we will use np.ones.
  • jax.named_call is not handled currently

During call tracking we attempt to alter as little as possible the execution. There are a few known differences though:

  • JAX cache tracing is foiled, so you are going to see functions tracer multiple times. E.g., JAX will avoid tracing a function repeatedly with similar arguments, but when repros are enabled this cache is disabled. This will result in slower tracing, and for functions with side-effects it may even alter the execution of the program.

  • with custom_vjp, when we do higher-order differentiation the custom fwd and bwd functions are being called more than once, and we assume that the subsequent calls will produce the same Jaxpr as earlier ones.

  • we emit jax.Array as np.ndarray (e.g., loosing sharding)

  • if we don't recognize the jax.checkpoint policy param, we print a warning
    and we use `dots_saveable".

Integration details

Since this code is experimental we want to separate it as much as possible from the core JAX source code. The
code is in jax._src.repro, but the user-facing APIs are in jax.experimental.repro. (I have tried to put the whole
implementation in jax.experimental but when importing it from there during internal usage, we end up importing
a whole lot of other exprimental APIs and we run into circular import issues.) There are a few key integration points, added to traceback_util.py:

  • we augment api_boundary with a couple of kwargs (most notably repro_api_name), to mark those API boundaries that we want to intercept. These are primarily the higher-order JAX APIs, e.g., jax.jit, jax.vmap, ..., and a few others as well.
  • we add a new entry point traceback_util.repro_enabled() -> bool, which is true if we are currently collecting repros.
  • we add a new entry point traceback_util.enable_repro(val: bool) for a context manager to enable and disable repros. This is for internal usage only.

@gnecula gnecula self-assigned this Sep 17, 2025
@gnecula gnecula force-pushed the repro_2 branch 2 times, most recently from 1853564 to 7e01c3d Compare September 17, 2025 08:50
@gnecula gnecula added the pull ready Ready for copybara import and testing label Sep 17, 2025
@gnecula gnecula force-pushed the repro_2 branch 20 times, most recently from 5b52d4f to 936d89c Compare September 22, 2025 11:57
@gnecula gnecula force-pushed the repro_2 branch 5 times, most recently from 29b1dd4 to 3b652d0 Compare September 26, 2025 09:33
gnecula added a commit to gnecula/jax that referenced this pull request Nov 14, 2025
This is just a draft of an experiment.

The purpose is to help debug JAX failures in large users examples,
by trying to extract a smaller program that contains only JAX
API calls. E.g., in a large program using Flax, this would get
Flax out of the way.

To use set the environment variable `JAX_REPRO_DIR` to a directory
where the repro files should be saved. You can use the value "sponge"
to save test artifacts in Google internal tests.

See jax-ml#31867 for more details.
@gnecula gnecula force-pushed the repro_2 branch 2 times, most recently from 8d0c90b to f2a172d Compare November 14, 2025 03:31
gnecula added a commit to gnecula/jax that referenced this pull request Nov 14, 2025
This is just a draft of an experiment.

The purpose is to help debug JAX failures in large users examples,
by trying to extract a smaller program that contains only JAX
API calls. E.g., in a large program using Flax, this would get
Flax out of the way.

To use set the environment variable `JAX_REPRO_DIR` to a directory
where the repro files should be saved. You can use the value "sponge"
to save test artifacts in Google internal tests.

See jax-ml#31867 for more details.
gnecula added a commit to gnecula/jax that referenced this pull request Nov 14, 2025
This is just a draft of an experiment.

The purpose is to help debug JAX failures in large users examples,
by trying to extract a smaller program that contains only JAX
API calls. E.g., in a large program using Flax, this would get
Flax out of the way.

To use set the environment variable `JAX_REPRO_DIR` to a directory
where the repro files should be saved. You can use the value "sponge"
to save test artifacts in Google internal tests.

See jax-ml#31867 for more details.
gnecula added a commit to gnecula/jax that referenced this pull request Nov 14, 2025
This is just a draft of an experiment.

The purpose is to help debug JAX failures in large users examples,
by trying to extract a smaller program that contains only JAX
API calls. E.g., in a large program using Flax, this would get
Flax out of the way.

To use set the environment variable `JAX_REPRO_DIR` to a directory
where the repro files should be saved. You can use the value "sponge"
to save test artifacts in Google internal tests.

See jax-ml#31867 for more details.
gnecula added a commit to gnecula/jax that referenced this pull request Nov 14, 2025
This is just a draft of an experiment.

The purpose is to help debug JAX failures in large users examples,
by trying to extract a smaller program that contains only JAX
API calls. E.g., in a large program using Flax, this would get
Flax out of the way.

To use set the environment variable `JAX_REPRO_DIR` to a directory
where the repro files should be saved. You can use the value "sponge"
to save test artifacts in Google internal tests.

See jax-ml#31867 for more details.
gnecula added a commit to gnecula/jax that referenced this pull request Nov 16, 2025
This is just a draft of an experiment.

The purpose is to help debug JAX failures in large users examples,
by trying to extract a smaller program that contains only JAX
API calls. E.g., in a large program using Flax, this would get
Flax out of the way.

To use set the environment variable `JAX_REPRO_DIR` to a directory
where the repro files should be saved. You can use the value "sponge"
to save test artifacts in Google internal tests.

See jax-ml#31867 for more details.
gnecula added a commit to gnecula/jax that referenced this pull request Nov 16, 2025
This is just a draft of an experiment.

The purpose is to help debug JAX failures in large users examples,
by trying to extract a smaller program that contains only JAX
API calls. E.g., in a large program using Flax, this would get
Flax out of the way.

To use set the environment variable `JAX_REPRO_DIR` to a directory
where the repro files should be saved. You can use the value "sponge"
to save test artifacts in Google internal tests.

See jax-ml#31867 for more details.
gnecula added a commit to gnecula/jax that referenced this pull request Nov 16, 2025
This is just a draft of an experiment.

The purpose is to help debug JAX failures in large users examples,
by trying to extract a smaller program that contains only JAX
API calls. E.g., in a large program using Flax, this would get
Flax out of the way.

To use set the environment variable `JAX_REPRO_DIR` to a directory
where the repro files should be saved. You can use the value "sponge"
to save test artifacts in Google internal tests.

See jax-ml#31867 for more details.
gnecula added a commit to gnecula/jax that referenced this pull request Nov 25, 2025
This is just a draft of an experiment.

The purpose is to help debug JAX failures in large users examples,
by trying to extract a smaller program that contains only JAX
API calls. E.g., in a large program using Flax, this would get
Flax out of the way.

To use set the environment variable `JAX_REPRO_DIR` to a directory
where the repro files should be saved. You can use the value "sponge"
to save test artifacts in Google internal tests.

See jax-ml#31867 for more details.
gnecula added a commit to gnecula/jax that referenced this pull request Dec 12, 2025
This is just a draft of an experiment.

The purpose is to help debug JAX failures in large users examples,
by trying to extract a smaller program that contains only JAX
API calls. E.g., in a large program using Flax, this would get
Flax out of the way.

To use set the environment variable `JAX_REPRO_DIR` to a directory
where the repro files should be saved. You can use the value "sponge"
to save test artifacts in Google internal tests.

See jax-ml#31867 for more details.
gnecula added a commit to gnecula/jax that referenced this pull request Dec 13, 2025
This is just a draft of an experiment.

The purpose is to help debug JAX failures in large users examples,
by trying to extract a smaller program that contains only JAX
API calls. E.g., in a large program using Flax, this would get
Flax out of the way.

To use set the environment variable `JAX_REPRO_DIR` to a directory
where the repro files should be saved. You can use the value "sponge"
to save test artifacts in Google internal tests.

See jax-ml#31867 for more details.
gnecula added a commit to gnecula/jax that referenced this pull request Dec 13, 2025
This is just a draft of an experiment.

The purpose is to help debug JAX failures in large users examples,
by trying to extract a smaller program that contains only JAX
API calls. E.g., in a large program using Flax, this would get
Flax out of the way.

To use set the environment variable `JAX_REPRO_DIR` to a directory
where the repro files should be saved. You can use the value "sponge"
to save test artifacts in Google internal tests.

See jax-ml#31867 for more details.
This is just a draft of an experiment.

The purpose is to help debug JAX failures in large users examples,
by trying to extract a smaller program that contains only JAX
API calls. E.g., in a large program using Flax, this would get
Flax out of the way.

To use set the environment variable `JAX_REPRO_DIR` to a directory
where the repro files should be saved. You can use the value "sponge"
to save test artifacts in Google internal tests.

See jax-ml#31867 for more details.
This makes the repro a standalone source file, depending
only on the standard JAX.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation pull ready Ready for copybara import and testing

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant