|
| 1 | +# Generating reproducers for JAX errors |
| 2 | + |
| 3 | +<!--* freshness: { reviewed: '2025-10-15' } *--> |
| 4 | + |
| 5 | +WARNING: this code is experimental, expect changes or deletion. |
| 6 | + |
| 7 | +Have you encountered a hard-to-debug JAX error in a large user program, |
| 8 | +perhaps using several other libraries on top of JAX? |
| 9 | +Do you believe that there is a small and pure JAX program, without additional |
| 10 | +layers of libraries, that reproduces the same error? |
| 11 | + |
| 12 | + |
| 13 | +## Usage |
| 14 | + |
| 15 | +If you get an uncaught exception from under a JAX API call, |
| 16 | +you can set `JAX_REPRO_DIR` to a directory where JAX should attempt to save a Python source |
| 17 | +file that contains the JAX API calls that ought to reproduce the error. |
| 18 | +This mechanism can be enabled simply by setting the `JAX_REPRO_DIR` variable. |
| 19 | + |
| 20 | +JAX will track the sequence of nested JAX API calls, capturing all user-functions, |
| 21 | +their calls to JAX APIs, and then recursively the user functions that are |
| 22 | +called by JAX during tracing. |
| 23 | +If an uncaught exception arises, then we save a repro that should result in |
| 24 | +the same call tree, and hopefully can reproduce the error. |
| 25 | +One can get the path and source code of the saved repro by |
| 26 | +calling `repro.last_saved_repro()`. |
| 27 | + |
| 28 | +The above use case is the "implicit" repro generation. You can also |
| 29 | +generate repros "explicitly", even in absence of errors: |
| 30 | + |
| 31 | +``` |
| 32 | + from jax._src import repro # TODO: find final location |
| 33 | + col = repro.collector(fun) # fun should be a nullary Callable |
| 34 | + try: |
| 35 | +     result = col() # Executes `fun` and returns its result |
| 36 | + finally: |
| 37 | +     repro_source = col.to_source() |
| 38 | +     repro_path = repro.save() |
| 39 | +``` |
| 40 | + |
| 41 | +`repro.collector` will error if `JAX_REPRO_DIR` is not set. |
| 42 | +In the usage above, all tracked JAX API invocations are saved in the repro. |
| 43 | +In the implicit use (save on uncaught exception), |
| 44 | +successful JAX calls at the top-level are not retained. |
| 45 | + |
| 46 | +One can think of this mechanism as a way to stage out a pure JAX program |
| 47 | +from a large JAX program. |
| 48 | +This is somewhat similar to staging out a Jaxpr with `jax.jit(f).trace(*args).jaxpr` |
| 49 | +(or the old `jax.make_jaxpr`), except that: |
| 50 | + |
| 51 | + * it produces Python source code rather than a dump of a Jaxpr, which |
| 52 | + should be more readable, more editable, and can be executed directly, |
| 53 | + e.g., in a debugger. |
| 54 | + * it works even if there are errors in the user program or in JAX before |
| 55 | + a Jaxpr is produced; the repro source may reproduce the error. |
| 56 | + * the repro is higher-level than the Jaxpr. E.g., instead of seeing the |
| 57 | + `lax.scan_p` primitive with its low-level details, |
| 58 | + you will see a call to `lax.scan`. The higher-order primitives in JAX often have |
| 59 | + complicated parameters, and sometimes even references to Python callables. |
| 60 | + Furthermore, some JAX transformations, e.g., `vmap` or `jvp`, |
| 61 | + do not stage a Jaxpr, and the first Jaxpr produced will reflect the |
| 62 | + result of the transformations. In contrast, the repro source will |
| 63 | + contain calls to `jax.vmap` and `jax.jvp`. |
| 64 | + |
| 65 | + |
| 66 | +## Configuration options |
| 67 | + |
| 68 | +This section is very likely to change. |
| 69 | + |
| 70 | +There are two configuration options: |
| 71 | + |
| 72 | + * `JAX_REPRO_DIR` denotes the directory where reproducers are saved. A non-empty |
| 73 | + value also triggers the tracking of the call tree, so that a reproducer is saved |
| 74 | + on error. It can be `sponge` for use in internal Google tests. |
| 75 | + * `JAX_REPRO_FLAGS` contains comma-separated flags that configure how repro generation works. |
| 76 | + You can specify a flag without a value, in which case it takes a default value, e.g., `True`, |
| 77 | + or you can specify a value using `=value`. For example, `log_calls,log_traceback_frames=10`. |
| 78 | + * `log_calls` (default 0). An integer value that controls the repro tracking logging (for debugging |
| 79 | + the repro module). The recognized values are: 0 (no logging, default), 1 (log all calls except the |
| 80 | + JAX primitive.bind), 2 (log all calls). |
| 81 | + * `log_call_details` (default ""). A sequence of call ids for which to log more details |
| 82 | + (for debugging the repro module). E.g., `log_call_details=3+5+6`. |
| 83 | + * `error_mode` (default "defer"). Configures the handling of repro collection and generation |
| 84 | + errors. The possible values are: |
| 85 | + * "ignore" |
| 86 | + * "log" -- the errors are logged as `logging.error`. Each error message contains |
| 87 | + `log_traceback_frames` stack frames. |
| 88 | + * "defer" -- the errors are logged and at the end of the explicit |
| 89 | + repro collection a `repro.ReproError` will be generated. |
| 90 | + * "raise" -- a `repro.ReproError` is raised when the first error appears. |
| 91 | + |
| 92 | + * `log_traceback_frames` (default 40). How many frames from the traceback to show. |
| 93 | + * `fake_array_threshold` (default 128). Arrays with `.size` larger than this value are replaced |
| 94 | + with `np.ones` with the right shape and dtype. Smaller arrays are emitted as `np` array literals. |
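
To make the `JAX_REPRO_FLAGS` syntax described above concrete, here is a minimal sketch of how such a string could be split into flags. The helper `parse_repro_flags` is hypothetical and is not the parser used by the repro module; it assumes only the rules stated above (comma-separated flags, an optional `=value`, and a default of `True` when no value is given).

```python
def parse_repro_flags(spec):
    """Hypothetical parser for a JAX_REPRO_FLAGS-style string (illustration only)."""
    flags = {}
    for item in spec.split(","):
        item = item.strip()
        if not item:
            continue  # tolerate empty entries such as a trailing comma
        name, sep, value = item.partition("=")
        if not sep:
            flags[name] = True        # flag given without a value
        elif value.isdigit():
            flags[name] = int(value)  # e.g. log_traceback_frames=10
        else:
            flags[name] = value       # e.g. error_mode=raise, log_call_details=3+5+6
    return flags


example_flags = parse_repro_flags("log_calls,log_traceback_frames=10")
```

With the example string from the list above, `example_flags` maps `log_calls` to `True` and `log_traceback_frames` to the integer `10`.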
| 95 | + |
| 96 | +## Limitations |
| 97 | + |
| 98 | +So many ... |
| 99 | + |
| 100 | +Not all errors will be reproduced by this mechanism: |
| 101 | + |
| 102 | + * Errors from numpy won't be captured. |
| 103 | + * Errors from the `jax.numpy` layer won't be reproduced. E.g., rank checking happens in the |
| 104 | + `jax.numpy` layer, so it won't be reproduced by this version of repros, whereas shape checking |
| 105 | + happens after binding the JAX primitives, so it will be reproducible. |
| 106 | + * The repro contains some argument preprocessing, e.g., the `static_argnums` |
| 107 | + are removed. Errors involving the handling of `static_argnums` won't be |
| 108 | + reproduced. |
| 109 | + * We currently do not try to preserve the exact data arrays, and for |
| 110 | + arrays larger than `fake_array_threshold` we will use `np.ones`. |
| 111 | + * `jax.named_call` is not handled currently (it is treated as a no-op). |
| 112 | + |
| 113 | +During call tracking we attempt to alter the execution as little as possible. |
| 114 | +There are a few known differences though: |
| 115 | + |
| 116 | + * JAX's tracing caches are defeated, so you are going to see functions traced multiple times. |
| 117 | + E.g., JAX will normally avoid tracing a function repeatedly with similar arguments, |
| 118 | + but when repros are enabled this cache is disabled. |
| 119 | + This will result in slower tracing, and for functions with side-effects it may even |
| 120 | + alter the execution of the program. |
| 121 | + |
| 122 | + * When using `jax.custom_vjp` with higher-order differentiation, the custom |
| 123 | + fwd and bwd functions are called more than once, and we assume that the |
| 124 | + subsequent calls will produce the same Jaxpr as earlier ones. |
| 125 | + |
| 126 | + * We emit `jax.Array` as `np.ndarray` (e.g., losing sharding). |
| 127 | + |
| 128 | + * If we don't recognize the `jax.checkpoint` policy parameter, we print a warning |
| 129 | + and use `dots_saveable` as a replacement. |
| 130 | + |
| 131 | + |
| 132 | +## Design |
| 133 | + |
| 134 | +There are two phases: repro collection, when we follow certain function calls |
| 135 | +and construct a call tree, and repro generation when we produce Python |
| 136 | +source code that should reproduce the same call tree. |
| 137 | + |
| 138 | +### Repro collection |
| 139 | + |
| 140 | +#### Tracking JAX and USER functions |
| 141 | + |
| 142 | +We label top-level higher-order JAX APIs with |
| 143 | +`traceback_util.api_boundary(repro_api_name="jax.jit")`. These |
| 144 | +APIs take user functions as arguments, which we must |
| 145 | +also wrap and track. In an earlier design, I tried |
| 146 | +to identify these user functions by looking for callables |
| 147 | +among the positional arguments. This is not enough, because |
| 148 | +some custom PyTrees, e.g., `flax.module`, are callable |
| 149 | +yet we sometimes must treat them as containers with arrays. |
| 150 | + |
| 151 | +Instead, for each of these calls we also pass a `wrap_user_func_args` |
| 152 | +argument which takes in the args and kwargs and returns |
| 153 | +the args with the user-functions wrapped. By default, |
| 154 | +`wrap_user_func` wraps the first positional argument, if |
| 155 | +it is a callable. |
| 156 | + |
| 157 | +We keep a call stack of USER and JAX functions as the |
| 158 | +program executes. The bottom of the stack (first call) |
| 159 | +is always a JAX API function. Some of the JAX functions |
| 160 | +will call into USER functions that were passed as arguments, |
| 161 | +in order to trace them. Then the user function may |
| 162 | +call again into JAX functions. |
| 163 | + |
| 164 | +Additionally, we intercept the `core.Primitive.bind` and |
| 165 | +we consider those as first-order JAX calls. |
| 166 | + |
| 167 | +Note that we ignore all calls to JAX functions (primitives |
| 168 | +or JAX APIs) if they are made from a JAX function. This |
| 169 | +can happen because we use, e.g., `jax.jit` or `lax.scan` |
| 170 | +internally in JAX, and we don't need to track those. |
| 171 | +The same mechanism will ignore the calls to the |
| 172 | +higher-order primitives, e.g., `lax.scan_p`, because |
| 173 | +those are always bound from a JAX API call (since we |
| 174 | +annotate all higher-order JAX APIs). |
| 175 | + |
| 176 | +Finally, when we are inside a USER call, we collect |
| 177 | +in a list all the JAX calls made. The end result is |
| 178 | +a top-level call node of a JAX API function with |
| 179 | +some USER functions passed as arguments. Each USER |
| 180 | +function contains a list of JAX calls, each with |
| 181 | +its own USER functions among its arguments. This is the |
| 182 | +data structure that results from repro collection. |
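
The shape of this call tree can be sketched with two small record types. This is only an illustration under stated assumptions: the class names `JaxCall` and `UserFunc`, the `render` helper, and the example tree (a `jax.jit` call whose user function calls `lax.scan` and one first-order primitive) are hypothetical and do not mirror the actual classes in the repro module.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class UserFunc:
    """A USER function passed to a JAX API, with the JAX calls made while tracing it."""
    name: str
    jax_calls: List["JaxCall"] = field(default_factory=list)


@dataclass
class JaxCall:
    """A tracked JAX call (API or primitive bind) and the USER functions it received."""
    api_name: str
    user_funcs: List[UserFunc] = field(default_factory=list)


def render(call, indent=0):
    """Return the call tree as indented lines, alternating JAX and USER levels."""
    lines = [" " * indent + "JAX  " + call.api_name]
    for user_func in call.user_funcs:
        lines.append(" " * (indent + 2) + "USER " + user_func.name)
        for inner in user_func.jax_calls:
            lines.extend(render(inner, indent + 4))
    return lines


# Hypothetical tree: jax.jit(step), where step calls lax.scan(body, ...) and sin.
body = UserFunc("body", [JaxCall("add")])
step = UserFunc("step", [JaxCall("lax.scan", [body]), JaxCall("sin")])
root = JaxCall("jax.jit", [step])
tree_lines = render(root)
```

Note how the levels strictly alternate: a JAX call holds USER functions, and a USER function holds JAX calls, which matches the observation that a USER function never directly calls another USER function.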
| 183 | + |
| 184 | +Note that a USER function cannot call another USER |
| 185 | +function, because USER functions are called during |
| 186 | +tracing and are passed only arrays. |
| 187 | + |
| 188 | +We should see at most one call to each USER function |
| 189 | +object, because JAX should not trace a function |
| 190 | +twice (an exception happens for `fuser.fusible`, |
| 191 | +see below). It is possible that we don't see a |
| 192 | +call for some USER functions, if JAX does not trace |
| 193 | +them. E.g., in `jax.custom_vjp` we pass a user function |
| 194 | +for the forward pass and one for the backwards, but |
| 195 | +the latter will never be traced if we do not differentiate |
| 196 | +the function. |
| 197 | + |
| 198 | +#### Handling caches |
| 199 | + |
| 200 | +One of the trickiest issues was to |
| 201 | +collect repros in the presence of JAX caches. JAX will try |
| 202 | +to memoize the tracing of user functions, e.g., based |
| 203 | +on the shapes and types of the arguments. In reality, |
| 204 | +the cache keys are quite a bit more complicated. |
| 205 | + |
| 206 | +If we just wrap the user functions passed to the JAX |
| 207 | +APIs and track their calls, we may see a function called |
| 208 | +multiple times, and some calls we won't even see if |
| 209 | +they hit the cache. E.g., for the program: |
| 210 | + |
| 211 | +``` |
| 212 | +def fun(x): ... |
| 213 | +j_fun = jax.jit(fun) |
| 214 | +y0 = j_fun(0) |
| 215 | +y1 = j_fun(0.) |
| 216 | +y2 = j_fun(1) |
| 217 | +``` |
| 218 | + |
| 219 | +we will see calls to `fun` coming from `j_fun(0)` and `j_fun(0.)`, |
| 220 | +but we won't see one corresponding to `j_fun(1)` because it will hit |
| 221 | +the cache for the first call. |
| 222 | + |
| 223 | +The solution that I ended up using is to set up a set of |
| 224 | +predefined trampolines for the JAX API calls, indexed by |
| 225 | +the `api_name` (e.g., "jax.jit"). These trampolines will |
| 226 | +behave as if the code above had been: |
| 227 | + |
| 228 | +``` |
| 229 | +def fun(x): ... |
| 230 | +y0 = jax_jit_call(fun, 0) |
| 231 | +y1 = jax_jit_call(fun, 0.) |
| 232 | +y2 = jax_jit_call(fun, 1) |
| 233 | +``` |
| 234 | + |
| 235 | +Furthermore, because `jax_jit_call` will wrap the first argument |
| 236 | +as a fresh object on each call, the underlying JAX caches will miss. |
| 237 | +These new JAX APIs will appear in the generated repro. They |
| 238 | +are defined in `repro_api.py`. |
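
The effect of wrapping the function in a fresh object on every call can be illustrated without JAX. The sketch below uses a toy cache keyed by the function object and the argument type; `toy_jit_call` and `fresh_wrapper` are hypothetical names, and the real JAX cache keys are considerably more involved than this.

```python
trace_cache = {}  # toy cache: (function object, argument type) -> function
trace_log = []    # records every simulated "trace"


def toy_jit_call(fun, x):
    """Toy stand-in for a jitted call: 'traces' once per (function, argument type)."""
    key = (fun, type(x))
    if key not in trace_cache:
        trace_log.append((fun.__name__, type(x).__name__))
        trace_cache[key] = fun
    return trace_cache[key](x)


def fresh_wrapper(fun):
    """Return a new function object on each call, so cache keys never repeat."""
    def wrapped(x):
        return fun(x)
    wrapped.__name__ = fun.__name__
    return wrapped


def fun(x):
    return x + 1


# Same function object each time: the third call hits the cache entry of the first.
for arg in (0, 0.0, 1):
    toy_jit_call(fun, arg)
calls_seen_with_cache = len(trace_log)

# A fresh wrapper per call: every call misses the cache and is "traced".
trace_log.clear()
for arg in (0, 0.0, 1):
    toy_jit_call(fresh_wrapper(fun), arg)
calls_seen_with_fresh_wrappers = len(trace_log)
```

In this toy model only two of the three calls are observed when the same function object is reused, while all three are observed with fresh wrappers. This is also why tracing becomes slower when repros are enabled.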
| 239 | + |
| 240 | +With this system of trampolines (defined in `tracker.py`) we turn the |
| 241 | +JAX program into one that uses modified JAX APIs that take user functions |
| 242 | +as arguments but do not return functions. The trampolines end up being |
| 243 | +quite useful to turn all the various forms of higher-order JAX APIs, |
| 244 | +e.g., `jax.custom_vjp` with multiple user-defined functions, |
| 245 | +into a uniform system of APIs that take all user-functions as |
| 246 | +positional arguments, along with other non-callable arguments. |
| 247 | + |
| 248 | +In some very rare cases, we had to retain functions that |
| 249 | +return other functions. E.g., |
| 250 | + |
| 251 | +``` |
| 252 | +def fun(x): ... |
| 253 | +y, f_vjp = jax.vjp(fun, x) |
| 254 | +x_ct = f_vjp(y_ct) |
| 255 | +``` |
| 256 | + |
| 257 | +We must mark the returned `f_vjp` function as a JAX function (because |
| 258 | +it calls back into JAX internals). We do this in the trampoline |
| 259 | +for `jax.vjp`. |
| 260 | + |
| 261 | +#### Miscellaneous collection issues |
| 262 | + |
| 263 | +For each call, we store the arguments and results. These would |
| 264 | +be tracers for the USER functions, but may be arrays or other |
| 265 | +objects. If we store the arguments literally, we would run into |
| 266 | +issues when they are mutable; user functions may mutate them. |
| 267 | +We would also leak tracers and run into internal JAX error checks. |
| 268 | +There is also no point in storing the custom PyTrees, because |
| 269 | +they can never be part of the generated repro (to keep it pure JAX). |
| 270 | + |
| 271 | +So, we first "normalize" the arguments before storing them into |
| 272 | +the call nodes. Normalization turns custom PyTrees into tuples, |
| 273 | +except for some hardcoded custom PyTrees for which we know |
| 274 | +how to emit source code (see below in [repro generation](#repro-generation)). |
| 275 | + |
| 276 | + |
| 277 | +### Repro generation |
| 278 | + |
| 279 | + |