Skip to content

Commit 8732fb7

Browse files
committed
[Debug] Extract reproducers from JAX errors.
This is just a draft of an experiment. The purpose is to help debug JAX failures in large users examples, by trying to extract a smaller program that contains only JAX API calls. E.g., in a large program using Flax, this would get Flax out of the way. To use set the environment variable `JAX_REPRO_DIR` to a directory where the repro files should be saved. You can use the value "sponge" to save test artifacts in Google internal tests. See jax-ml#31867 for more details.
1 parent d21998e commit 8732fb7

27 files changed

Lines changed: 6663 additions & 18 deletions

docs/debugging/index.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Table of contents:
1010
* [Functional error checks with jax.experimental.checkify](checkify_guide)
1111
* [Throwing Python errors with JAX’s debug flags](flags)
1212
* [Attaching XLA metadata with `set_xla_metadata`](xla_metadata)
13+
* [Generating reproducers for JAX errors](repro)
1314

1415
## Interactive inspection with `jax.debug`
1516

@@ -142,5 +143,5 @@ print_breakpoint
142143
checkify_guide
143144
flags
144145
xla_metadata
146+
repro
145147
```
146-

docs/debugging/repro.md

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
# Generating reproducers for JAX errors
2+
3+
<!--* freshness: { reviewed: '2025-10-15' } *-->
4+
5+
WARNING: this code is experimental, expect changes or deletion.
6+
7+
Have you encountered a hard-to-debug JAX error in a large user program,
8+
perhaps using several other libraries on top of JAX?
9+
Do you believe that there is a small and pure JAX program, without additional
10+
layers of libraries, that reproduces the same error?
11+
12+
13+
## Usage
14+
15+
If you get an uncaught exception from under a JAX API call,
16+
you can set `JAX_REPRO_DIR` to a directory where JAX should attempt to save a Python source
17+
file that contains the JAX API calls that ought to reproduce the error.
18+
This mechanism can be enabled simply by setting the `JAX_REPRO_DIR` variable.
19+
20+
JAX will track the sequence of nested JAX API calls, capturing all user-functions,
21+
their calls to JAX APIs, and then recursively the user functions that are
22+
called by JAX during tracing.
23+
If an uncaught exception arises, then we save a repro that should result in
24+
the same call tree, and hopefully can reproduce the error.
25+
One can get the path and source code of the saved repro by
26+
calling `repro.last_saved_repro()`.
27+
28+
The above use case is the "implicit" repro generation. You can also
29+
generate repros "explicitly", even in absence of errors:
30+
31+
```
32+
from jax._src import repro # TODO: find final location
33+
col = repro.collector(fun) # fun should be a nullary Callable
34+
try:
35+
result = col() # Executes `fun` and returns its result
36+
finally:
37+
repro_source = col.to_source()
38+
repro_path = repro.save()
39+
```
40+
41+
`repro.collector` will error if `JAX_REPRO_DIR` is not set.
42+
In the usage above, all tracked JAX API invocations are saved in the repro.
43+
In the implicit use (save on uncaught exception),
44+
successful JAX calls at the top-level are not retained.
45+
46+
One can think of this mechanism as a way to stage out a pure JAX program
47+
from a large JAX program.
48+
This is somewhat similar to staging out a Jaxpr with `jax.jit(f).trace(*args).jaxpr`
49+
(or the old `jax.make_jaxpr`), except that:
50+
51+
* it produces Python source code rather that a dump of a Jaxpr, which
52+
should be more readable, more editable, and can be executed directly,
53+
e.g., in a debugger.
54+
* it works even if there are errors in the user program or in JAX before
55+
a Jaxprs is produced; the repro source may reproduce the error.
56+
* the repro is higher-level than the Jaxpr. E.g., instead of seeing the
57+
`lax.scan_p` primitive with its low-level details,
58+
you will see a call to `lax.scan`. The higher-order primitives in JAX often have
59+
complicated parameters, and sometimes even references to Python callables.
60+
Furthermore, some JAX transformations, e.g., `vmap` or `jvp`,
61+
do not stage a Jaxpr, and the first Jaxpr produced will reflect the
62+
result of the transformations. In contrast, the repro source will
63+
contains calls to `jax.vmap` and `jax.jvp`.
64+
65+
66+
## Configuration options
67+
68+
This section is very likely to change.
69+
70+
There are two configuration options:
71+
72+
* `JAX_REPRO_DIR` denotes the directory where reproducers are saved. A non-empty
73+
value also triggers the tracking of the call tree, so that a reproducer is saved
74+
on error. It can be `sponge` for use in internal Google tests.
75+
* `JAX_REPRO_FLAGS` contains comma-separated flags that configure how repro generation works.
76+
You can specify a flag without a value, in which case it takes a default value, e.g., `True`,
77+
or you can specify a value using `=value`. For example, `log_calls,log_traceback_frames=10`.
78+
* `log_calls` (default 0). An integer value that controls the repro tracking logging (for debugging
79+
the repro module). The recognized values are: 0 (no logging, default), 1 (log all calls except the
80+
JAX primitive.bind), 2 (log all calls).
81+
* `log_call_details` (default ""). A sequence of call ids for which to log more details (
82+
for debugging the repro module). E.g., `log_call_details=3+5+6`.
83+
* `error_mode` (default "defer"). Configures the handling of repro collection and generation
84+
errors. The possible values are:
85+
* "ignore"
86+
* "log" -- the errors are logged as `logging.error`. Each error message contains
87+
`log_traceback_frames` stack frames.
88+
* "defer" -- the errors are logged and at the end of the explicit
89+
repro collection a `repro.ReproError` will be generated.
90+
* "raise" -- a `repro.ReproError` is raised when the first error appears.
91+
92+
* `log_traceback_frames` (default 40) how many frames from the traceback to show.
93+
* `fake_array_threshold` (default 128) arrays with `.size()` larger than this value are replaced
94+
with `np.ones` with the right shape and dtype. Smaller arrays are emitted as `np` array literals.
95+
96+
## Limitations
97+
98+
So many ...
99+
100+
Not all errors will be reproduced by this mechanism:
101+
102+
* Errors from numpy won't be captured,
103+
* Errors from the jax.numpy layer. E.g., the rank checking happens in the jax.numpy layer,
104+
so it won't be reproduced by this version of repros, but the shape checking
105+
happens after binding the JAX primitives, so it will be reproducible,
106+
* The repro contains some argument preprocessing, e.g., the `static_argnums`
107+
are removed. Errors involving the handling of `static_argnums` won't be
108+
reproduced.
109+
* we currently do not try to preserve the exact data arrays, and for
110+
arrays larger than a threshold we will use np.ones.
111+
* `jax.named_call` is not handled currently (treated as a noop)
112+
113+
During call tracking we attempt to alter as little as possible the execution.
114+
There are a few known differences though:
115+
116+
* JAX cache tracing is foiled, so you are going to see functions traced multiple times.
117+
E.g., JAX will normally avoid tracing a function repeatedly with similar arguments,
118+
but when repros are enabled this cache is disabled.
119+
This will result in slower tracing, and for functions with side-effects it may even
120+
alter the execution of the program.
121+
122+
* when using `jax.custom_vjp`, when we do higher-order differentiation the custom
123+
fwd and bwd functions are being called more than once, and we assume that the
124+
subsequent calls will produce the same Jaxpr as earlier ones.
125+
126+
* we emit jax.Array as np.ndarray (e.g., loosing sharding)
127+
128+
* if we don't recognize the jax.checkpoint policy param, we print a warning
129+
and we use `dots_saveable` as a replacement.
130+
131+
132+
## Design
133+
134+
There are two phases: repro collection, when we follow certain function calls
135+
and construct a call tree, and repro generation when we produce Python
136+
source code that should reproduce the same call tree.
137+
138+
### Repro collection
139+
140+
### Tracking JAX and USER functions
141+
142+
We label top-level higher-order JAX APIs with
143+
`traceback_util.api_boundary(repro_api_name="jax.jit")`. These
144+
APIs take user functions as arguments, which we must
145+
also wrap and track. In an earlier design, I tried
146+
to identify these user functions by looking for callables
147+
among the positional arguments. This is not enough, because
148+
some custom PyTrees, e.g., `flax.module`, are callable
149+
yet we sometimes must treat them as containers with arrays.
150+
151+
Instead, for each of these calls we also pass a `wrap_user_func_args`
152+
argument which takes in the args and kwargs and returns
153+
the args with the user-functions wrapped. By default,
154+
`wrap_user_func` wraps the first positional argument, if
155+
it is a callable.
156+
157+
We keep a call stack of USER and JAX functions as the
158+
program executes. The bottom of the stack (first call)
159+
is always a JAX API function. Some of the JAX functions
160+
will call into USER functions that were passed as arguments,
161+
in order to trace them. Then the user function may
162+
call again into JAX functions.
163+
164+
Additionally, we intercept the `core.Primitive.bind` and
165+
we consider those as first-order JAX calls.
166+
167+
Note that we ignore all calls to JAX functions (primitives
168+
or JAX APIs) if they are made from a JAX function. This
169+
can happen because we use, e.g., `jax.jit` or `lax.scan`
170+
internally in JAX, and we don't need to track those.
171+
The same mechanism will ignore the calls to the
172+
higher-order primitives, e.g., `lax.scan_p`, because
173+
those are bound always from a JAX API call (since we
174+
annotate all higher-order JAX APIs).
175+
176+
Finally, when we are inside a USER call, we collect
177+
in a list all the JAX calls made. The end result is
178+
a top-evel call node of a JAX API function with
179+
some USER functions passed as arguments. Each USER
180+
function contains a list of JAX calls, each with
181+
its own USER function among arguments. This is the
182+
date structure that results from repro collection.
183+
184+
Note that a USER function cannot call another USER
185+
function, because USER functions are called during
186+
tracing and are passed only arrays.
187+
188+
We should see at most one call to each USER function
189+
object, because JAX should not trace a function
190+
twice (an exception happens for `fuser.fusible`,
191+
see below). It is possible that we don't see a
192+
call for some USER functions, if JAX do not trace
193+
them. E.g., in `jax.custom_vjp` we pass a user function
194+
for the forward pass and one for the backwards, but
195+
the latter will never be traced if we do not differentiate
196+
the function.
197+
198+
#### Handling caches
199+
200+
One of the trickiest issues was to
201+
collect repros in presence of JAX caches. JAX will try
202+
to memoize the tracing of user function, e.g., based
203+
on the shapes and types of the arguments. In reality,
204+
the cache keys are quite a bit more complicated.
205+
206+
If we just wrap the user functions passed to the JAX
207+
APIs and track their calls, we may see a function called
208+
multiple times, and some calls we won't even see if
209+
they hit the cache. E.g., for the program:
210+
211+
```
212+
def fun(x): ...
213+
j_fun = jax.jit(fun)
214+
y0 = j_fun(0)
215+
y1 = j_fun(0.)
216+
y2 = j_fun(1)
217+
```
218+
219+
we will see calls to `fun` coming from `j_fun(0)` and `j_fun(0.)`,
220+
but we won't see corresponding to `j_fun(1)` because it will hit
221+
the cache for the first call.
222+
223+
The solution that I ended up using, is to set up a set of
224+
predefined trampolines for the JAX API calls, indexed by
225+
the `api_name` (e.g., "jax.jit"). These trampolines will
226+
behave as if the code above had been:
227+
228+
```
229+
def func(x):
230+
y0 = jax_jit_call(fun, 0)
231+
y1 = jax_jit_call(fun, 0.)
232+
y2 = jax_jit_call(fun, 1)
233+
```
234+
235+
Furthermore, because `jax_jit_call` will wrap the firts argument
236+
as a fresh object on each call, the undelying JAX caches will miss.
237+
These new JAX APIs will appear in the generated repro. They
238+
are defined in `repro_api.py`.
239+
240+
With this system of trampolines (defined in `tracker.py`) we turn the
241+
JAX program into one that uses modified JAX APIs that take user functions
242+
as arguments but do not return functions. The trampolines end up being
243+
quite useful to turn all the various forms of higher-order JAX APIs,
244+
e.g., `jax.custom_vjp` with multiple user-defined functions,
245+
into a uniform system of APIs that take all user-functions as
246+
positional arguments, along with other non-callable arguments.
247+
248+
In some very rare cases, we had to retain functions that
249+
return other function. E.g.,
250+
251+
```
252+
def fun(x): ...
253+
y, f_vjp = jax.vjp(f, x)
254+
x_ct = f_vjp(y_ct)
255+
```
256+
257+
We must mark the returned `f_vjp` function as a JAX function (because
258+
it calls back into JAX internals). We do this in the trampoline
259+
for `jax.vjp`.
260+
261+
### Miscellaneous collection issues
262+
263+
For each call, we store the arguments and results. These would
264+
be tracers for the USER functions, but may be arrays or other
265+
objects. If we store the arguments literally, we would run into
266+
issues when they are mutable; user functions may mutate them.
267+
We would also leak tracers and run into internal JAX error checks.
268+
There is also no point in storing the custom PyTrees, because
269+
they can never be part of the generated repro (to keep it pure JAX).
270+
271+
So, we first "normalize" the arguments before storing them into
272+
the call nodes. Normalization turns custom PyTrees into tuples,
273+
except for some hardcoded custom PyTrees for which we know
274+
how to emit source code (see below in [repro generation](#repro-generation)).
275+
276+
277+
### Repro generation
278+
279+

docs/internals/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ is prone to become stale.
1111
:maxdepth: 1
1212

1313
Handling of closed-over constants <constants>
14+
Implementation of repro generation <repro_internals>

docs/internals/repro_internals.md

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
(repro-note)=
2+
3+
# Repro implementation details
4+
5+
The challenge for repro extraction for JAX, compared to a regular compiler,
6+
is that JAX does not get the input as a data structure that we can save.
7+
Instead, we have to augment the JAX tracing mechanism to **track** which
8+
JAX API calls are being made by the user program, and what user functions
9+
JAX calls while tracing the program. The repro tracker (in `tracker.py`)
10+
constructs a representation of the call tree. Then the repro emitter
11+
(in `emitter.py`) outputs a pure JAX program that would result in the
12+
same call tree.
13+
14+
## How do we track?
15+
16+
First, we wrap the JAX API functions that take user functions as arguments,
17+
e.g., `jax.jit`, `jax.vmap`. We do this by adding a `repro_api_name` to the
18+
existing `traceback_util.api_boundary` annotation. This annotation was already
19+
present in most places we needed it, but we had to add it in a few places
20+
where it was missing, e.g., in `lax.loops.while_loop`.
21+
Whenever we call one of these annotated APIs, we scan the arguments looking
22+
for callables, and we wrap those as well. One goal would be to emit repro
23+
code for these callables.
24+
25+
We use the class `tracker.Func` to wrap callables of interest. They are of several
26+
kinds:
27+
28+
* JAX API functions. These are constructed for the JAX API entry points annotated
29+
with `repro_api_name`.
30+
* USER functions. These are constructed for callables passed to JAX API functions.
31+
* JAX non-API functions. These are constructed for callables returned by JAX API
32+
functions, e.g., the returned value from `jax.jit`. Note: this kind of functions
33+
will go away, see below.
34+
35+
When one of the tracker functions is called, we construct a `tracker.Call` object
36+
that has references to the `Func` that was called, the actual arguments and results
37+
of the call (these would be actual tracers, or constants, or even non-JAX values
38+
for the static arguments). The call objects for user functions have a body, which
39+
is a list of calls to JAX functions that the user function makes.
40+
41+
Furthermore, we modified the `core.Primitive._true_bind` method to call into
42+
the repro source code with the primitive and its arguments. If this call happens
43+
while we are currently in a user call, we record the primitive.
44+
45+
Thus, the call objects for a user function will contain a list of calls to
46+
JAX functions and to primitives.
47+
48+
### Dealing with JAX caches
49+
50+
51+
## How do we emit?
52+
53+
TO EXPLAIN ...
54+
55+
## How do we reduce?
56+
57+
TO DO ...

jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ py_library_providing_imports_info(
225225
"//jax/_src:profiler",
226226
"//jax/_src:public_test_util",
227227
"//jax/_src:random",
228+
"//jax/_src:repro",
228229
"//jax/_src:ref",
229230
"//jax/_src:scipy",
230231
"//jax/_src:shard_alike",

0 commit comments

Comments
 (0)