@@ -147,6 +147,7 @@ def _update_debug_special_thread_local(_):
147147float0 = dtypes .float0
148148
149149
150+ @partial (core .jax_boundary , api_name = "jax.jit" )
150151def jit (
151152 fun : Callable , / , * ,
152153 in_shardings : Any = sharding_impls .UNSPECIFIED ,
@@ -347,6 +348,7 @@ def disable_jit(disable: bool = True):
347348 yield
348349
349350
351+ @partial (core .jax_boundary , api_name = "jax.grad" )
350352def grad (fun : Callable , argnums : int | Sequence [int ] = 0 ,
351353 has_aux : bool = False , holomorphic : bool = False ,
352354 allow_int : bool = False ,
@@ -413,6 +415,7 @@ def grad_f_aux(*args, **kwargs):
413415
414416 return grad_f_aux if has_aux else grad_f
415417
418+ @partial (core .jax_boundary , api_name = "jax.value_and_grad" )
416419def value_and_grad (fun : Callable , argnums : int | Sequence [int ] = 0 ,
417420 has_aux : bool = False , holomorphic : bool = False ,
418421 allow_int : bool = False , reduce_axes : Sequence [AxisName ] = ()
@@ -545,6 +548,7 @@ def _check_output_dtype_revderiv(name, holomorphic, x):
545548 "jax.vjp directly." )
546549_check_output_dtype_grad = partial (_check_output_dtype_revderiv , "grad" )
547550
551+ @partial (core .jax_boundary , api_name = "jax.fwd_and_bwd" )
548552def fwd_and_bwd (
549553 fun : Callable , argnums : int | Sequence [int ], has_aux : bool = False ,
550554 jitted : bool = True ,
@@ -619,6 +623,7 @@ def bwd(f_vjp, outgrad):
619623 return fwd , bwd
620624
621625
626+ @partial (core .jax_boundary , api_name = "jaxfwd" )
622627def jacfwd (fun : Callable , argnums : int | Sequence [int ] = 0 ,
623628 has_aux : bool = False , holomorphic : bool = False ) -> Callable :
624629 """Jacobian of ``fun`` evaluated column-by-column using forward-mode AD.
@@ -709,6 +714,7 @@ def _check_output_dtype_jacfwd(holomorphic, x):
709714 raise TypeError ("jacfwd with holomorphic=True requires outputs with complex dtype, "
710715 f"but got { aval .dtype .name } ." )
711716
717+ @partial (core .jax_boundary , api_name = "jax.jacrev" )
712718def jacrev (fun : Callable , argnums : int | Sequence [int ] = 0 ,
713719 has_aux : bool = False , holomorphic : bool = False ,
714720 allow_int : bool = False ) -> Callable :
@@ -789,7 +795,7 @@ def jacobian(fun: Callable, argnums: int | Sequence[int] = 0,
789795_check_input_dtype_jacrev = partial (_check_input_dtype_revderiv , "jacrev" )
790796_check_output_dtype_jacrev = partial (_check_output_dtype_revderiv , "jacrev" )
791797
792-
798+ @ partial ( core . jax_boundary , api_name = "jax.hessian" )
793799def hessian (fun : Callable , argnums : int | Sequence [int ] = 0 ,
794800 has_aux : bool = False , holomorphic : bool = False ) -> Callable :
795801 """Hessian of ``fun`` as a dense array.
@@ -915,6 +921,7 @@ def _split(x, indices, axis):
915921 return x ._split (indices , axis )
916922
917923
924+ @partial (core .jax_boundary , api_name = "jax.vmap" )
918925def vmap (fun : F ,
919926 in_axes : int | None | Sequence [Any ] = 0 ,
920927 out_axes : Any = 0 ,
@@ -1234,7 +1241,7 @@ def _all_sizes_index(sz):
12341241 msg .append (f" * some axes ({ ct } of them) had size { sz } , e.g. axis { ax } of { ex } ;\n " )
12351242 raise ValueError ('' .join (msg )[:- 2 ]) # remove last semicolon and newline
12361243
1237-
1244+ @ partial ( core . jax_boundary , api_name = "jax.pmap" )
12381245def pmap (
12391246 fun : Callable ,
12401247 axis_name : AxisName | None = None ,
@@ -1789,6 +1796,7 @@ def _cpp_mapped_lower(pmap_f, *args, **kwargs):
17891796
17901797
17911798@api_boundary
1799+ @partial (core .jax_boundary , api_name = "jax.jvp" )
17921800def jvp (
17931801 fun : Callable , primals , tangents , has_aux : bool = False
17941802 ) -> tuple [Any , ...]:
@@ -1883,6 +1891,7 @@ def linearize(fun: Callable, *primals, has_aux: Literal[True]
18831891 ) -> tuple [Any , Callable , Any ]:
18841892 ...
18851893
1894+ @partial (core .jax_boundary , api_name = "jax.linearize" )
18861895def linearize (fun : Callable , * primals , has_aux : bool = False
18871896 ) -> tuple [Any , Callable ] | tuple [Any , Callable , Any ]:
18881897 """Produces a linear approximation to ``fun`` using :py:func:`jvp` and partial eval.
@@ -2073,6 +2082,7 @@ def vjp(fun: Callable[..., tuple[T, U]], *primals: Any,
20732082 ...
20742083
20752084@api_boundary
2085+ @partial (core .jax_boundary , api_name = "jax.vjp" )
20762086def vjp (
20772087 fun : Callable , * primals , has_aux : bool = False , reduce_axes = ()
20782088 ) -> tuple [Any , Callable ] | tuple [Any , Callable , Any ]:
@@ -2146,7 +2156,7 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
21462156 else :
21472157 return out_primal_py , vjp_py , tree_unflatten (aux_tree , aux )
21482158
2149-
2159+ @ partial ( core . jax_boundary , api_name = "jax.saved_input_vjp" )
21502160def saved_input_vjp (f : Callable , which : Sequence [bool ], * primals ,
21512161 allow_unused : bool = True , allow_opaque : bool = True ):
21522162 if len (which ) != len (primals ):
@@ -2490,6 +2500,7 @@ def make_jaxpr(
24902500) -> Callable [..., tuple [core .ClosedJaxpr , Any ]]:
24912501 ...
24922502
2503+ @partial (core .jax_boundary , api_name = "jax.make_japr" )
24932504def make_jaxpr (
24942505 fun : Callable ,
24952506 static_argnums : int | Iterable [int ] = (),
@@ -3001,6 +3012,7 @@ def eval_shape(fun, *args, **kwargs):
30013012 return jit (fun ).trace (* args , ** kwargs ).out_info
30023013
30033014
3015+ @partial (core .jax_boundary , api_name = "jax.named_call" )
30043016def named_call (
30053017 fun : F ,
30063018 * ,
0 commit comments