@@ -909,8 +909,7 @@ def _common(
909909 self ,
910910 at_op : str ,
911911 y = _undef ,
912- copy : bool | None = True ,
913- mode : str = "promise_in_bounds" ,
912+ copy : bool | None = True ,
914913 ** kwargs ,
915914 ):
916915 """Validate kwargs and perform common prepocessing.
@@ -931,19 +930,14 @@ def _common(
931930 " at(x)[idx].set(value)\n "
932931 "(same for all other methods)."
933932 )
934- if mode != "promise_in_bounds" and not is_jax_array (self .x ):
935- xp = array_namespace (self .x )
936- raise NotImplementedError (
937- f"mode='{ mode !r} ' is not supported for backend { xp .__name__ } "
938- )
939933
940934 copy = _parse_copy_param (self .x , copy )
941935
942936 if copy and is_jax_array (self .x ):
943937 # Use JAX's at[]
944938 at_ = self .x .at [self .idx ]
945939 args = (y , ) if y is not _undef else ()
946- return getattr (at_ , at_op )(* args , mode = mode , ** kwargs ), None
940+ return getattr (at_ , at_op )(* args , ** kwargs ), None
947941
948942 # Emulate at[] behaviour for non-JAX arrays
949943 x = self .x .copy () if copy else self .x
0 commit comments