Skip to content

Commit 7e01c3d

Browse files
committed
[debug] Extract reproducers from JAX errors.
This is just a draft of an experiment. The purpose is to help debug JAX failures in large user programs by trying to extract a smaller program that contains only JAX API calls. E.g., for a large program that uses Flax, this would get Flax out of the way.
1 parent b5b6e1f commit 7e01c3d

13 files changed

Lines changed: 1197 additions & 10 deletions

File tree

jax/_src/BUILD

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,6 +1358,17 @@ pytype_strict_library(
13581358
],
13591359
)
13601360

1361+
pytype_strict_library(
1362+
name = "repro",
1363+
srcs = ["repro.py", "repro_runtime.py"],
1364+
visibility = ["//jax:internal"] + jax_visibility("repro"),
1365+
deps = [
1366+
":source_info_util",
1367+
":traceback_util",
1368+
"//jax/_src/lib",
1369+
],
1370+
)
1371+
13611372
pytype_strict_library(
13621373
name = "state_types",
13631374
srcs = [

jax/_src/ad_checkpoint.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ def policy(prim, *args, **params):
200200
### Main API
201201

202202
@api_boundary
203+
@partial(core.jax_boundary, api_name="jax.checkpoint")
203204
def checkpoint(fun: Callable, *, prevent_cse: bool = True,
204205
policy: Callable[..., bool] | None = None,
205206
static_argnums: int | tuple[int, ...] = (),

jax/_src/api.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def _update_debug_special_thread_local(_):
147147
float0 = dtypes.float0
148148

149149

150+
@partial(core.jax_boundary, api_name="jax.jit")
150151
def jit(
151152
fun: Callable, /, *,
152153
in_shardings: Any = sharding_impls.UNSPECIFIED,
@@ -347,6 +348,7 @@ def disable_jit(disable: bool = True):
347348
yield
348349

349350

351+
@partial(core.jax_boundary, api_name="jax.grad")
350352
def grad(fun: Callable, argnums: int | Sequence[int] = 0,
351353
has_aux: bool = False, holomorphic: bool = False,
352354
allow_int: bool = False,
@@ -413,6 +415,7 @@ def grad_f_aux(*args, **kwargs):
413415

414416
return grad_f_aux if has_aux else grad_f
415417

418+
@partial(core.jax_boundary, api_name="jax.value_and_grad")
416419
def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0,
417420
has_aux: bool = False, holomorphic: bool = False,
418421
allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()
@@ -545,6 +548,7 @@ def _check_output_dtype_revderiv(name, holomorphic, x):
545548
"jax.vjp directly.")
546549
_check_output_dtype_grad = partial(_check_output_dtype_revderiv, "grad")
547550

551+
@partial(core.jax_boundary, api_name="jax.fwd_and_bwd")
548552
def fwd_and_bwd(
549553
fun: Callable, argnums: int | Sequence[int], has_aux: bool = False,
550554
jitted: bool = True,
@@ -619,6 +623,7 @@ def bwd(f_vjp, outgrad):
619623
return fwd, bwd
620624

621625

626+
@partial(core.jax_boundary, api_name="jax.jacfwd")
622627
def jacfwd(fun: Callable, argnums: int | Sequence[int] = 0,
623628
has_aux: bool = False, holomorphic: bool = False) -> Callable:
624629
"""Jacobian of ``fun`` evaluated column-by-column using forward-mode AD.
@@ -709,6 +714,7 @@ def _check_output_dtype_jacfwd(holomorphic, x):
709714
raise TypeError("jacfwd with holomorphic=True requires outputs with complex dtype, "
710715
f"but got {aval.dtype.name}.")
711716

717+
@partial(core.jax_boundary, api_name="jax.jacrev")
712718
def jacrev(fun: Callable, argnums: int | Sequence[int] = 0,
713719
has_aux: bool = False, holomorphic: bool = False,
714720
allow_int: bool = False) -> Callable:
@@ -789,7 +795,7 @@ def jacobian(fun: Callable, argnums: int | Sequence[int] = 0,
789795
_check_input_dtype_jacrev = partial(_check_input_dtype_revderiv, "jacrev")
790796
_check_output_dtype_jacrev = partial(_check_output_dtype_revderiv, "jacrev")
791797

792-
798+
@partial(core.jax_boundary, api_name="jax.hessian")
793799
def hessian(fun: Callable, argnums: int | Sequence[int] = 0,
794800
has_aux: bool = False, holomorphic: bool = False) -> Callable:
795801
"""Hessian of ``fun`` as a dense array.
@@ -915,6 +921,7 @@ def _split(x, indices, axis):
915921
return x._split(indices, axis)
916922

917923

924+
@partial(core.jax_boundary, api_name="jax.vmap")
918925
def vmap(fun: F,
919926
in_axes: int | None | Sequence[Any] = 0,
920927
out_axes: Any = 0,
@@ -1234,7 +1241,7 @@ def _all_sizes_index(sz):
12341241
msg.append(f" * some axes ({ct} of them) had size {sz}, e.g. axis {ax} of {ex};\n")
12351242
raise ValueError(''.join(msg)[:-2]) # remove last semicolon and newline
12361243

1237-
1244+
@partial(core.jax_boundary, api_name="jax.pmap")
12381245
def pmap(
12391246
fun: Callable,
12401247
axis_name: AxisName | None = None,
@@ -1789,6 +1796,7 @@ def _cpp_mapped_lower(pmap_f, *args, **kwargs):
17891796

17901797

17911798
@api_boundary
1799+
@partial(core.jax_boundary, api_name="jax.jvp")
17921800
def jvp(
17931801
fun: Callable, primals, tangents, has_aux: bool = False
17941802
) -> tuple[Any, ...]:
@@ -1883,6 +1891,7 @@ def linearize(fun: Callable, *primals, has_aux: Literal[True]
18831891
) -> tuple[Any, Callable, Any]:
18841892
...
18851893

1894+
@partial(core.jax_boundary, api_name="jax.linearize")
18861895
def linearize(fun: Callable, *primals, has_aux: bool = False
18871896
) -> tuple[Any, Callable] | tuple[Any, Callable, Any]:
18881897
"""Produces a linear approximation to ``fun`` using :py:func:`jvp` and partial eval.
@@ -2073,6 +2082,7 @@ def vjp(fun: Callable[..., tuple[T, U]], *primals: Any,
20732082
...
20742083

20752084
@api_boundary
2085+
@partial(core.jax_boundary, api_name="jax.vjp")
20762086
def vjp(
20772087
fun: Callable, *primals, has_aux: bool = False, reduce_axes=()
20782088
) -> tuple[Any, Callable] | tuple[Any, Callable, Any]:
@@ -2146,7 +2156,7 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
21462156
else:
21472157
return out_primal_py, vjp_py, tree_unflatten(aux_tree, aux)
21482158

2149-
2159+
@partial(core.jax_boundary, api_name="jax.saved_input_vjp")
21502160
def saved_input_vjp(f: Callable, which: Sequence[bool], *primals,
21512161
allow_unused: bool = True, allow_opaque: bool = True):
21522162
if len(which) != len(primals):
@@ -2490,6 +2500,7 @@ def make_jaxpr(
24902500
) -> Callable[..., tuple[core.ClosedJaxpr, Any]]:
24912501
...
24922502

2503+
@partial(core.jax_boundary, api_name="jax.make_jaxpr")
24932504
def make_jaxpr(
24942505
fun: Callable,
24952506
static_argnums: int | Iterable[int] = (),
@@ -3001,6 +3012,7 @@ def eval_shape(fun, *args, **kwargs):
30013012
return jit(fun).trace(*args, **kwargs).out_info
30023013

30033014

3015+
@partial(core.jax_boundary, api_name="jax.named_call")
30043016
def named_call(
30053017
fun: F,
30063018
*,

jax/_src/config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2030,3 +2030,19 @@ def _default_pmap_no_rank_reduction(new_val):
20302030
' ragged_dot_general_p.'
20312031
),
20322032
)
2033+
2034+
repro_dir = string_flag(
2035+
name='jax_repro_dir',
2036+
default=os.getenv("JAX_REPRO_DIR", ""),
2037+
help=(
2038+
'Turn on saving of repros. EXPERIMENTAL.'
2039+
),
2040+
)
2041+
2042+
repro_flags = string_flag(
2043+
name='jax_repro_flags',
2044+
default=os.getenv("JAX_REPRO_FLAGS", ""),
2045+
help=(
2046+
'Comma-separated flags for repros. EXPERIMENTAL.'
2047+
),
2048+
)

jax/_src/core.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
TracerIntegerConversionError, UnexpectedTracerError)
4848
from jax._src import linear_util as lu
4949
from jax._src.tree_util import tree_flatten, tree_unflatten
50+
from jax._src import repro
5051
from jax._src import source_info_util
5152
from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
5253
tuple_delete, cache,
@@ -83,6 +84,8 @@
8384

8485
def identity(x): return x
8586

87+
jax_boundary = repro.jax_boundary
88+
8689
# -------------------- jaxprs --------------------
8790

8891
Effect = effects.Effect
@@ -631,7 +634,10 @@ def __repr__(self):
631634

632635
def bind(self, *args, **params):
633636
args = args if self.skip_canonicalization else map(canonicalize_value, args)
634-
return self._true_bind(*args, **params)
637+
if config.repro_dir.value:
638+
return repro._true_bind_primitive(self, args, params)
639+
else:
640+
return self._true_bind(*args, **params)
635641

636642
def _true_bind(self, *args, **params):
637643
for arg in args:

jax/_src/custom_derivatives.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def f_jvp(primals, tangents):
138138
jvp: Callable[..., tuple[ReturnValue, ReturnValue]] | None = None
139139
symbolic_zeros: bool = False
140140

141+
@partial(core.jax_boundary, api_name="jax.custom_jvp", api_constructor=True)
141142
def __init__(self,
142143
fun: Callable[..., ReturnValue],
143144
nondiff_argnums: Sequence[int] = (),
@@ -162,6 +163,7 @@ def __init__(self,
162163

163164
__getattr__ = custom_api_util.forward_attr
164165

166+
@partial(core.jax_boundary, api_name="jax_defjvp")
165167
def defjvp(self,
166168
jvp: Callable[..., tuple[ReturnValue, ReturnValue]],
167169
symbolic_zeros: bool = False,
@@ -214,6 +216,7 @@ def defjvp(self,
214216
self.symbolic_zeros = symbolic_zeros
215217
return jvp
216218

219+
@partial(core.jax_boundary, api_name="jax_defjvps")
217220
def defjvps(self, *jvps: Callable[..., ReturnValue] | None) -> None:
218221
"""Convenience wrapper for defining JVPs for each argument separately.
219222
@@ -257,6 +260,7 @@ def jvp(primals, tangents):
257260
self.defjvp(jvp)
258261

259262
@traceback_util.api_boundary
263+
@partial(core.jax_boundary, api_name="jax.custom_jvp_call")
260264
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation
261265
debug = debug_info("custom_jvp fun", self.fun, args, kwargs,
262266
static_argnums=self.nondiff_argnums)
@@ -559,7 +563,7 @@ def f_bwd(res, g):
559563
560564
.. _tutorial: https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
561565
"""
562-
566+
@partial(core.jax_boundary, api_name="jax.custom_vjp", api_constructor=True)
563567
def __init__(self,
564568
fun: Callable[..., ReturnValue],
565569
nondiff_argnums: Sequence[int] = (),
@@ -587,6 +591,7 @@ def __init__(self,
587591

588592
__getattr__ = custom_api_util.forward_attr
589593

594+
@partial(core.jax_boundary, api_name="jax_defvjp")
590595
def defvjp(self,
591596
fwd: Callable[..., tuple[ReturnValue, Any]],
592597
bwd: Callable[..., tuple[Any, ...]],
@@ -686,6 +691,7 @@ def defvjp(self,
686691
"remat optimization for custom_vjp does not support symbolic zeros")
687692

688693
@traceback_util.api_boundary
694+
@partial(core.jax_boundary, api_name="jax.custom_vjp_call")
689695
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation
690696
debug_fun = debug_info("custom_vjp fun", self.fun, args, kwargs,
691697
static_argnums=self.nondiff_argnums)

jax/_src/interpreters/partial_eval.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2382,7 +2382,6 @@ def trace_to_jaxpr_dynamic(
23822382
in_tracers = _input_type_to_tracers(
23832383
partial(trace.new_arg, source_info=source_info), in_avals)
23842384
in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
2385-
23862385
with core.set_current_trace(trace):
23872386
ans = fun.call_wrapped(*in_tracers)
23882387
_check_returned_jaxtypes(fun.debug_info, ans)

jax/_src/lax/control_flow/conditionals.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@ def _capitalize(s):
400400
return s[0].capitalize() + s[1:]
401401

402402
@api_boundary
403+
@partial(core.jax_boundary, api_name="jax.lax.cond")
403404
@functools.wraps(_cond)
404405
def cond(*args, **kwargs):
405406
# detect an attempt to call the former, deprecated cond

jax/_src/lax/control_flow/loops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def _promote_weak_typed_inputs(in_vals, in_avals, out_avals):
110110
Y = TypeVar('Y')
111111

112112
@api_boundary
113+
@partial(core.jax_boundary, api_name="jax.lax.scan")
113114
def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
114115
init: Carry,
115116
xs: X | None = None,
@@ -1603,6 +1604,7 @@ def _move_right(lst, to_move):
16031604
### while_loop
16041605

16051606
@api_boundary
1607+
@partial(core.jax_boundary, api_name="jax.lax.while_loop")
16061608
def while_loop(cond_fun: Callable[[T], BooleanNumeric],
16071609
body_fun: Callable[[T], T],
16081610
init_val: T) -> T:

0 commit comments

Comments
 (0)