@@ -770,5 +770,72 @@ def test_f16_mean(self, dtype):
770770 self .assertAllClose (expected , actual , atol = 0 )
771771
772772
773+
774+ def _is_canonical (dtype ):
775+ _dtype = dtypes .dtype (dtype )
776+ return _dtype == dtypes .canonicalize_dtype (_dtype )
777+
778+ @jtu .sample_product (
779+ [dict (shape = shape , axis = axis )
780+ for shape in all_shapes
781+ for axis in list (
782+ range (- len (shape ), len (shape ))
783+ ) + ([None ] if len (shape ) == 1 else [])],
784+ dtype = filter (_is_canonical , all_dtypes ),
785+ out_dtype = filter (_is_canonical , all_dtypes ),
786+ include_initial = [False , True ],
787+ )
788+ @jtu .ignore_warning (category = NumpyComplexWarning )
789+ @jax .numpy_dtype_promotion ('standard' ) # This test explicitly exercises mixed type promotion
790+ def testCumulativeSum (self , shape , axis , dtype , out_dtype , include_initial ):
791+ rng = jtu .rand_some_zero (self .rng ())
792+
793+ def np_mock_fun (x , axis = None , dtype = None , include_initial = False ):
794+ kind = x .dtype .kind
795+ if (dtype is None and kind in {'i' , 'u' }
796+ and x .dtype .itemsize * 8 < int (config .default_dtype_bits .value )):
797+ dtype = dtypes .canonicalize_dtype (dtypes ._default_types [kind ])
798+ axis = axis or 0
799+ x = x .astype (dtype = dtype or x .dtype )
800+ out = jnp .cumsum (x , axis = axis )
801+ if include_initial :
802+ zeros_shape = list (x .shape )
803+ zeros_shape [axis ] = 1
804+ out = jnp .concat ([jnp .zeros (zeros_shape , dtype = out .dtype ), out ], axis = axis )
805+ return out
806+
807+
808+ # We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as
809+ # input because we rely on JAX-specific casting behavior
810+ args_maker = lambda : [jnp .array (rng (shape , dtype ))]
811+ np_op = getattr (np , "cumulative_sum" , np_mock_fun )
812+ kwargs = dict (axis = axis , dtype = out_dtype , include_initial = include_initial )
813+ np_fun = lambda x : np_op (x , ** kwargs )
814+ jnp_fun = lambda x : jnp .cumulative_sum (x , ** kwargs )
815+ self ._CheckAgainstNumpy (np_fun , jnp_fun , args_maker )
816+ self ._CompileAndCheck (jnp_fun , args_maker )
817+
818+ kwargs = dict (axis = axis , include_initial = include_initial )
819+ self ._CheckAgainstNumpy (np_fun , jnp_fun , args_maker )
820+ self ._CompileAndCheck (jnp_fun , args_maker )
821+
822+
823+ @jtu .sample_product (
824+ shape = all_shapes , dtype = all_dtypes ,
825+ include_initial = [False , True ])
826+ def testCumulativeSumErrors (self , shape , dtype , include_initial ):
827+ rng = jtu .rand_some_zero (self .rng ())
828+ x = rng (shape , dtype )
829+ if jnp .isscalar (x ) or x .ndim == 0 :
830+ msg = r"The input must be non-scalar to take"
831+ with self .assertRaisesRegex (ValueError , msg ):
832+ jnp .cumulative_sum (x , include_initial = include_initial )
833+ elif x .ndim > 1 :
834+ msg = r"The input array has rank \d*, however"
835+ with self .assertRaisesRegex (ValueError , msg ):
836+ jnp .cumulative_sum (x , include_initial = include_initial )
837+
838+
839+
if __name__ == "__main__":
  # Script entry point: run every test in this file through absltest,
  # discovering them with JAX's custom test loader.
  absltest.main(testLoader=jtu.JaxTestLoader())