1515from __future__ import annotations
1616
1717from collections .abc import Callable , Sequence
18- import functools
1918from functools import partial
2019import logging
2120from typing import Any
@@ -168,6 +167,7 @@ def policy(prim, *args, **params):
168167def checkpoint (fun : Callable , * , prevent_cse : bool = True ,
169168 policy : Callable [..., bool ] | None = None ,
170169 static_argnums : int | tuple [int , ...] = (),
170+ concrete : bool = False ,
171171 ) -> Callable :
172172 """Make ``fun`` recompute internal linearization points when differentiated.
173173
@@ -222,6 +222,7 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
222222 returns a boolean indicating whether the corresponding output value(s) can
223223 be saved as residuals (or instead must be recomputed in the (co)tangent
224224 computation if needed).
225+ concrete: Ignored vestigial argument. It does nothing.
225226
226227 Returns:
227228 A function (callable) with the same input/output behavior as ``fun`` but
@@ -309,6 +310,8 @@ def foo(x, y):
309310 ``jax.ensure_compile_time_eval``), it may be easier to compute some values
310311 outside the :func:`jax.checkpoint`-decorated function and then close over them.
311312 """
313+ del concrete # Ignored.
314+
312315 @wraps (fun )
313316 @api_boundary
314317 def fun_remat (* args , ** kwargs ):
@@ -322,7 +325,14 @@ def fun_remat(*args, **kwargs):
322325 return tree_unflatten (out_tree , out_flat )
323326 return fun_remat
324327
325- remat = checkpoint # alias
328+ def remat (fun : Callable , * , prevent_cse : bool = True ,
329+ policy : Callable [..., bool ] | None = None ,
330+ static_argnums : int | tuple [int , ...] = (),
331+ ) -> Callable :
332+ """Alias of :func:`~jax.checkpoint`."""
333+ return checkpoint (fun , prevent_cse = prevent_cse , policy = policy ,
334+ static_argnums = static_argnums )
335+
326336
327337# This function is similar to api_util.argnums_partial, except the error
328338# messages are specific to jax.remat (and thus more actionable), the
@@ -855,65 +865,5 @@ def name_batcher(args, dims, *, name):
855865 return name_p .bind (x , name = name ), d
856866batching .primitive_batchers [name_p ] = name_batcher
857867
858-
859- @functools .wraps (checkpoint )
860- def checkpoint_wrapper (
861- fun : Callable ,
862- * ,
863- concrete : bool = False ,
864- prevent_cse : bool = True ,
865- static_argnums : int | tuple [int , ...] = (),
866- policy : Callable [..., bool ] | None = None ,
867- ) -> Callable :
868- if concrete :
869- msg = ("The 'concrete' option to jax.checkpoint / jax.remat is deprecated; "
870- "in its place, you can use its `static_argnums` option, and if "
871- "necessary the `jax.ensure_compile_time_eval()` context manager.\n "
872- "\n "
873- "For example, if using `concrete=True` for an `is_training` flag:\n "
874- "\n "
875- " from functools import partial\n "
876- "\n "
877- " @partial(jax.checkpoint, concrete=True)\n "
878- " def foo(x, is_training):\n "
879- " if is_training:\n "
880- " return f(x)\n "
881- " else:\n "
882- " return g(x)\n "
883- "\n "
884- "replace it with a use of `static_argnums`:\n "
885- "\n "
886- " @partial(jax.checkpoint, static_argnums=(1,))\n "
887- " def foo(x, is_training):\n "
888- " ...\n "
889- "\n "
890- "If jax.numpy operations need to be performed on static arguments, "
891- "we can use the `jax.ensure_compile_time_eval()` context manager. "
892- "For example, we can replace this use of `concrete=True`\n :"
893- "\n "
894- " @partial(jax.checkpoint, concrete=True)\n "
895- " def foo(x, y):\n "
896- " if y > 0:\n "
897- " return f(x)\n "
898- " else:\n "
899- " return g(x)\n "
900- "\n "
901- "with this combination of `static_argnums` and "
902- "`jax.ensure_compile_time_eval()`:\n "
903- "\n "
904- " @partial(jax.checkpoint, static_argnums=(1,))\n "
905- " def foo(x, y):\n "
906- " with jax.ensure_compile_time_eval():\n "
907- " y_pos = y > 0\n "
908- " if y_pos:\n "
909- " return f(x)\n "
910- " else:\n "
911- " return g(x)\n "
912- "\n "
913- "See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html\n " )
914- raise NotImplementedError (msg )
915- return checkpoint (fun , prevent_cse = prevent_cse , policy = policy ,
916- static_argnums = static_argnums )
917-
918868# TODO(phawkins): update users to refer to the public name.
919869_optimization_barrier = lax_internal .optimization_barrier
0 commit comments