Skip to content

Commit 1d03e17

Browse files
committed
improve readthedocs behavior for jax.remat / jax.checkpoint
1 parent fcfb0b7 commit 1d03e17

3 files changed

Lines changed: 15 additions & 64 deletions

File tree

docs/jax.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ Automatic differentiation
102102
custom_gradient
103103
closure_convert
104104
checkpoint
105+
remat
105106

106107
``custom_jvp``
107108
~~~~~~~~~~~~~~

jax/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@
8181

8282
from jax._src.api import effects_barrier as effects_barrier
8383
from jax._src.api import block_until_ready as block_until_ready
84-
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint
84+
from jax._src.ad_checkpoint import checkpoint as checkpoint
85+
from jax._src.ad_checkpoint import remat as remat
8586
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
8687
from jax._src.api import clear_backends as _deprecated_clear_backends
8788
from jax._src.api import clear_caches as clear_caches
@@ -122,7 +123,6 @@
122123
from jax._src.xla_bridge import process_index as process_index
123124
from jax._src.xla_bridge import process_indices as process_indices
124125
from jax._src.callback import pure_callback as pure_callback
125-
from jax._src.ad_checkpoint import checkpoint_wrapper as remat
126126
from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
127127
from jax._src.api import value_and_grad as value_and_grad
128128
from jax._src.api import vjp as vjp

jax/_src/ad_checkpoint.py

Lines changed: 12 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from __future__ import annotations
1616

1717
from collections.abc import Callable, Sequence
18-
import functools
1918
from functools import partial
2019
import logging
2120
from typing import Any
@@ -168,6 +167,7 @@ def policy(prim, *args, **params):
168167
def checkpoint(fun: Callable, *, prevent_cse: bool = True,
169168
policy: Callable[..., bool] | None = None,
170169
static_argnums: int | tuple[int, ...] = (),
170+
concrete: bool = False,
171171
) -> Callable:
172172
"""Make ``fun`` recompute internal linearization points when differentiated.
173173
@@ -222,6 +222,7 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
222222
returns a boolean indicating whether the corresponding output value(s) can
223223
be saved as residuals (or instead must be recomputed in the (co)tangent
224224
computation if needed).
225+
concrete: Ignored vestigial argument. It does nothing.
225226
226227
Returns:
227228
A function (callable) with the same input/output behavior as ``fun`` but
@@ -309,6 +310,8 @@ def foo(x, y):
309310
``jax.ensure_compile_time_eval``), it may be easier to compute some values
310311
outside the :func:`jax.checkpoint`-decorated function and then close over them.
311312
"""
313+
del concrete # Ignored.
314+
312315
@wraps(fun)
313316
@api_boundary
314317
def fun_remat(*args, **kwargs):
@@ -322,7 +325,14 @@ def fun_remat(*args, **kwargs):
322325
return tree_unflatten(out_tree, out_flat)
323326
return fun_remat
324327

325-
remat = checkpoint # alias
328+
def remat(fun: Callable, *, prevent_cse: bool = True,
329+
policy: Callable[..., bool] | None = None,
330+
static_argnums: int | tuple[int, ...] = (),
331+
) -> Callable:
332+
"""Alias of :func:`~jax.checkpoint`."""
333+
return checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
334+
static_argnums=static_argnums)
335+
326336

327337
# This function is similar to api_util.argnums_partial, except the error
328338
# messages are specific to jax.remat (and thus more actionable), the
@@ -855,65 +865,5 @@ def name_batcher(args, dims, *, name):
855865
return name_p.bind(x, name=name), d
856866
batching.primitive_batchers[name_p] = name_batcher
857867

858-
859-
@functools.wraps(checkpoint)
860-
def checkpoint_wrapper(
861-
fun: Callable,
862-
*,
863-
concrete: bool = False,
864-
prevent_cse: bool = True,
865-
static_argnums: int | tuple[int, ...] = (),
866-
policy: Callable[..., bool] | None = None,
867-
) -> Callable:
868-
if concrete:
869-
msg = ("The 'concrete' option to jax.checkpoint / jax.remat is deprecated; "
870-
"in its place, you can use its `static_argnums` option, and if "
871-
"necessary the `jax.ensure_compile_time_eval()` context manager.\n"
872-
"\n"
873-
"For example, if using `concrete=True` for an `is_training` flag:\n"
874-
"\n"
875-
" from functools import partial\n"
876-
"\n"
877-
" @partial(jax.checkpoint, concrete=True)\n"
878-
" def foo(x, is_training):\n"
879-
" if is_training:\n"
880-
" return f(x)\n"
881-
" else:\n"
882-
" return g(x)\n"
883-
"\n"
884-
"replace it with a use of `static_argnums`:\n"
885-
"\n"
886-
" @partial(jax.checkpoint, static_argnums=(1,))\n"
887-
" def foo(x, is_training):\n"
888-
" ...\n"
889-
"\n"
890-
"If jax.numpy operations need to be performed on static arguments, "
891-
"we can use the `jax.ensure_compile_time_eval()` context manager. "
892-
"For example, we can replace this use of `concrete=True`\n:"
893-
"\n"
894-
" @partial(jax.checkpoint, concrete=True)\n"
895-
" def foo(x, y):\n"
896-
" if y > 0:\n"
897-
" return f(x)\n"
898-
" else:\n"
899-
" return g(x)\n"
900-
"\n"
901-
"with this combination of `static_argnums` and "
902-
"`jax.ensure_compile_time_eval()`:\n"
903-
"\n"
904-
" @partial(jax.checkpoint, static_argnums=(1,))\n"
905-
" def foo(x, y):\n"
906-
" with jax.ensure_compile_time_eval():\n"
907-
" y_pos = y > 0\n"
908-
" if y_pos:\n"
909-
" return f(x)\n"
910-
" else:\n"
911-
" return g(x)\n"
912-
"\n"
913-
"See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html\n")
914-
raise NotImplementedError(msg)
915-
return checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
916-
static_argnums=static_argnums)
917-
918868
# TODO(phawkins): update users to refer to the public name.
919869
_optimization_barrier = lax_internal.optimization_barrier

0 commit comments

Comments
 (0)