Skip to content

Commit 8da86ea

Browse files
Jake VanderPlasGoogle-ML-Automation
Jake VanderPlas
authored andcommitted
Move _src/interpreters/ad.py to its own BUILD rule.
Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times. This required moving some internal utilities out of dispatch.py, which is part of the main JAX build rule. I chose api_util.py because they seem to fit there. PiperOrigin-RevId: 761722054
1 parent 61a9bd2 commit 8da86ea

File tree

9 files changed

+81
-61
lines changed

9 files changed

+81
-61
lines changed

jax/BUILD

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,6 @@ py_library_providing_imports_info(
316316
"_src/ffi.py",
317317
"_src/flatten_util.py",
318318
"_src/interpreters/__init__.py",
319-
"_src/interpreters/ad.py",
320319
"_src/interpreters/batching.py",
321320
"_src/interpreters/pxla.py",
322321
"_src/pjit.py",
@@ -381,6 +380,7 @@ py_library_providing_imports_info(
381380
visibility = ["//visibility:public"],
382381
deps = [
383382
":abstract_arrays",
383+
":ad",
384384
":ad_util",
385385
":api_util",
386386
":basearray",
@@ -671,6 +671,23 @@ pytype_strict_library(
671671
] + py_deps("numpy"),
672672
)
673673

674+
pytype_strict_library(
675+
name = "ad",
676+
srcs = ["_src/interpreters/ad.py"],
677+
deps = [
678+
":ad_util",
679+
":api_util",
680+
":config",
681+
":core",
682+
":dtypes",
683+
":mesh",
684+
":partial_eval",
685+
":source_info_util",
686+
":tree_util",
687+
":util",
688+
],
689+
)
690+
674691
pytype_strict_library(
675692
name = "mlir",
676693
srcs = ["_src/interpreters/mlir.py"],

jax/_src/api.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import numpy as np
3838
from contextlib import contextmanager
3939

40+
from jax._src import api_util
4041
from jax._src import deprecations
4142
from jax._src import linear_util as lu
4243
from jax._src import stages
@@ -113,14 +114,14 @@ def _nan_check_posthook(fun, args, kwargs, output):
113114

114115
try:
115116
dispatch.check_special(pjit.pjit_p.name, buffers)
116-
except dispatch.InternalFloatingPointError as e:
117+
except api_util.InternalFloatingPointError as e:
117118
assert config.debug_nans.value or config.debug_infs.value
118119
if hasattr(fun, '_fun'):
119120
f = fun._fun
120121
if getattr(f, '_apply_primitive', False):
121122
raise FloatingPointError(f"invalid value ({e.ty}) encountered in {f.__qualname__}") from None
122123
# compiled_fun can only raise in this case
123-
dispatch.maybe_recursive_nan_check(e, f, args, kwargs)
124+
api_util.maybe_recursive_nan_check(e, f, args, kwargs)
124125
raise AssertionError("Unreachable") from e
125126
else:
126127
# TODO(emilyaf): Shouldn't need this fallback.
@@ -1707,7 +1708,7 @@ def cache_miss(*args, **kwargs):
17071708
out = execute(*p.flat_args)
17081709
else:
17091710
out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params)
1710-
except dispatch.InternalFloatingPointError as e:
1711+
except api_util.InternalFloatingPointError as e:
17111712
raise FloatingPointError(f'Invalid value ({e.ty}) encountered in parallel computation.')
17121713

17131714
out_tree, out_flat = p.out_tree, out

jax/_src/api_util.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,3 +767,41 @@ def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo, consts, args) -> Non
767767
f"array reference of type {a.str_short()} was both closed over and "
768768
f"passed as the argument "
769769
f"{dbg.safe_arg_names(len(args))[i]}" if dbg else "at flat index {i}")
770+
771+
class InternalFloatingPointError(Exception):
772+
name: str
773+
ty: str
774+
775+
def __init__(self, name: str, ty: str):
776+
self.name = name
777+
self.ty = ty
778+
779+
def maybe_recursive_nan_check(e: Exception, fun: Callable, args, kwargs,
780+
) -> None: # always raises an exception
781+
print("Invalid nan value encountered in the output of a jax.jit "
782+
"function. Calling the de-optimized version.")
783+
try:
784+
_ = fun(*args, **kwargs)
785+
except (FloatingPointError, ZeroDivisionError) as e2:
786+
raise e2 from None
787+
else:
788+
_raise_no_nan_in_deoptimized(e)
789+
790+
791+
def _raise_no_nan_in_deoptimized(e) -> None:
792+
msg = (f"{str(e)}. Because "
793+
"jax_config.debug_nans.value and/or config.jax_debug_infs is set, the "
794+
"de-optimized function (i.e., the function as if the `jit` "
795+
"decorator were removed) was called in an attempt to get a more "
796+
"precise error message. However, the de-optimized function did not "
797+
"produce invalid values during its execution. This behavior can "
798+
"result from `jit` optimizations causing the invalid value to be "
799+
"produced. It may also arise from having nan/inf literals as "
800+
"inputs or outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. "
801+
"\n\n"
802+
"It may be possible to avoid the invalid value by removing the "
803+
"`jit` decorator, at the cost of losing optimizations. "
804+
"\n\n"
805+
"If you see this error, consider opening a bug report at "
806+
"https://github.com/jax-ml/jax.")
807+
raise FloatingPointError(msg) from None

jax/_src/dispatch.py

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import logging
2525
import threading
2626
import time
27-
from typing import Any, Callable
27+
from typing import Any
2828

2929
import jax
3030
from jax._src import api
@@ -42,6 +42,7 @@
4242
from jax._src.interpreters import mlir
4343
from jax._src.interpreters import pxla
4444
from jax._src.interpreters import xla
45+
from jax._src.api_util import InternalFloatingPointError
4546
from jax._src.layout import DeviceLocalLayout, Layout
4647
from jax._src.lib import xla_client as xc
4748
from jax._src.mesh import AbstractMesh, Mesh
@@ -341,43 +342,6 @@ class CopySemantics(enum.Enum):
341342
COPY = enum.auto()
342343
DONATE = enum.auto()
343344

344-
class InternalFloatingPointError(Exception):
345-
name: str
346-
ty: str
347-
348-
def __init__(self, name: str, ty: str):
349-
self.name = name
350-
self.ty = ty
351-
352-
def maybe_recursive_nan_check(e: Exception, fun: Callable, args, kwargs,
353-
) -> None: # always raises an exception
354-
print("Invalid nan value encountered in the output of a jax.jit "
355-
"function. Calling the de-optimized version.")
356-
try:
357-
_ = fun(*args, **kwargs)
358-
except (FloatingPointError, ZeroDivisionError) as e2:
359-
raise e2 from None
360-
else:
361-
_raise_no_nan_in_deoptimized(e)
362-
363-
def _raise_no_nan_in_deoptimized(e) -> None:
364-
msg = (f"{str(e)}. Because "
365-
"jax_config.debug_nans.value and/or config.jax_debug_infs is set, the "
366-
"de-optimized function (i.e., the function as if the `jit` "
367-
"decorator were removed) was called in an attempt to get a more "
368-
"precise error message. However, the de-optimized function did not "
369-
"produce invalid values during its execution. This behavior can "
370-
"result from `jit` optimizations causing the invalid value to be "
371-
"produced. It may also arise from having nan/inf literals as "
372-
"inputs or outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. "
373-
"\n\n"
374-
"It may be possible to avoid the invalid value by removing the "
375-
"`jit` decorator, at the cost of losing optimizations. "
376-
"\n\n"
377-
"If you see this error, consider opening a bug report at "
378-
"https://github.com/jax-ml/jax.")
379-
raise FloatingPointError(msg) from None
380-
381345
def _identity_fn(x):
382346
return x
383347

jax/_src/interpreters/ad.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@
2121
from functools import partial
2222
from typing import Any
2323

24+
from jax._src import api_util
2425
from jax._src import config
25-
from jax._src import dispatch
2626
from jax._src import linear_util as lu
2727
from jax._src.interpreters import partial_eval as pe
28-
from jax.tree_util import (tree_flatten, tree_unflatten,
29-
register_pytree_node, Partial, PyTreeDef)
28+
from jax._src.tree_util import (tree_flatten, tree_unflatten,
29+
register_pytree_node, Partial, PyTreeDef)
3030
from jax._src import mesh as mesh_lib
3131
from jax._src import core
3232
from jax._src import source_info_util
@@ -1125,7 +1125,7 @@ def out_axes_thunk():
11251125

11261126
try:
11271127
out_flat = primitive.bind(fun, *all_args, **new_params)
1128-
except dispatch.InternalFloatingPointError as e:
1128+
except api_util.InternalFloatingPointError as e:
11291129
print("Invalid nan value encountered in the backward pass of a jax.jit "
11301130
"function. Calling the de-optimized backward pass.")
11311131
try:
@@ -1135,7 +1135,7 @@ def out_axes_thunk():
11351135
else:
11361136
# If control reaches this line, we got a NaN on the output of `compiled`
11371137
# but not `fun.call_wrapped` on the same arguments. Let's tell the user.
1138-
dispatch._raise_no_nan_in_deoptimized(e)
1138+
api_util._raise_no_nan_in_deoptimized(e)
11391139
arg_cts = tree_unflatten(out_tree(), out_flat)
11401140

11411141
# The freevars are being fanned out (not mapped). During transpose the
@@ -1266,11 +1266,3 @@ def __init__(self):
12661266

12671267
# TODO(mattjj): remove this vestigial dict
12681268
reducing_transposes: dict[core.Primitive, Callable] = {}
1269-
1270-
########################### pvary ##################################
1271-
1272-
def _pvary_transpose_rule(cts, *_, axes, axis_index_groups):
1273-
from jax._src.lax import parallel as lax_parallel
1274-
return lax_parallel.psum_invariant_p.bind(
1275-
*cts, axes=axes, axis_index_groups=axis_index_groups)
1276-
deflinear2(core.pvary_p, _pvary_transpose_rule)

jax/_src/lax/parallel.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2059,3 +2059,10 @@ def _psum_invariant_transpose_rule(cts, *args, axes, axis_index_groups):
20592059
del args
20602060
return core.pvary_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups)
20612061
ad.deflinear2(psum_invariant_p, _psum_invariant_transpose_rule)
2062+
2063+
########################### pvary ##################################
2064+
2065+
def _pvary_transpose_rule(cts, *_, axes, axis_index_groups):
2066+
return psum_invariant_p.bind(
2067+
*cts, axes=axes, axis_index_groups=axis_index_groups)
2068+
ad.deflinear2(core.pvary_p, _pvary_transpose_rule)

jax/_src/pjit.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,10 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs):
176176
f"Argument '{name}' of shape {aval.str_short()} of type"
177177
f' {type(arg)} is not a valid JAX type.') from e
178178
raise AssertionError("Unreachable") from e
179-
except dispatch.InternalFloatingPointError as e:
179+
except api_util.InternalFloatingPointError as e:
180180
if getattr(fun, '_apply_primitive', False):
181181
raise FloatingPointError(f"invalid value ({e.ty}) encountered in {fun.__qualname__}") from None
182-
dispatch.maybe_recursive_nan_check(e, fun, args, kwargs)
182+
api_util.maybe_recursive_nan_check(e, fun, args, kwargs)
183183

184184
if p.box_data:
185185
box_treedef, out_tree = p.out_tree.children()
@@ -2562,7 +2562,7 @@ def prune_type(ty, xs, maybe_zeros):
25622562
keep_unused=keep_unused,
25632563
inline=inline,
25642564
compiler_options_kvs=compiler_options_kvs)
2565-
except dispatch.InternalFloatingPointError as e:
2565+
except api_util.InternalFloatingPointError as e:
25662566
print("Invalid nan value encountered in the backward pass of a jax.jit "
25672567
"function. Calling the de-optimized backward pass.")
25682568
try:
@@ -2572,7 +2572,7 @@ def prune_type(ty, xs, maybe_zeros):
25722572
else:
25732573
# If control reaches this line, we got a NaN on the output of `compiled`
25742574
# but not `fun.call_wrapped` on the same arguments. Let's tell the user.
2575-
dispatch._raise_no_nan_in_deoptimized(e)
2575+
api_util._raise_no_nan_in_deoptimized(e)
25762576

25772577
if attrs_tracked:
25782578
final_states, nz_cts_out = split_list(nz_cts_out, [num_attr_outs])

jax/_src/shard_map.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,7 +1054,7 @@ def _maybe_check_special(outs):
10541054
for s in getattr(leaf, 'addressable_shards', [])]
10551055
try:
10561056
dispatch.check_special('shard_map', bufs)
1057-
except dispatch.InternalFloatingPointError as e:
1057+
except api_util.InternalFloatingPointError as e:
10581058
raise FloatingPointError(f'Invalid value ({e.ty}) encountered in sharded computation.') from None
10591059

10601060
class ShardMapTrace(core.Trace):
@@ -1562,7 +1562,7 @@ def new_out_specs_thunk():
15621562
except (FloatingPointError, ZeroDivisionError) as e2:
15631563
raise e2 from None
15641564
else:
1565-
dispatch._raise_no_nan_in_deoptimized(e)
1565+
api_util._raise_no_nan_in_deoptimized(e)
15661566
return tree_unflatten(out_tree(), out_flat)
15671567
ad.primitive_transposes[shard_map_p] = _shard_map_transpose
15681568

jax/extend/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ py_library_providing_imports_info(
4343
deps = [
4444
"//jax",
4545
"//jax:abstract_arrays",
46+
"//jax:ad",
4647
"//jax:ad_util",
4748
"//jax:core",
4849
],

0 commit comments

Comments
 (0)