Skip to content

Commit 10d2ede

Browse files
committed
Add new cumulative_sum function to numpy and array_api namespaces
1 parent cd9dcd2 commit 10d2ede

10 files changed

Lines changed: 116 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ Remember to align the itemized text with the first line of an item within a list
1111
* New Functionality
1212
* Added {func}`jax.numpy.unstack`, following the addition of this function in
1313
the array API 2023 standard, soon to be adopted by NumPy.
14+
* Added {func}`jax.numpy.cumulative_sum`, following the addition of this
15+
function in the array API 2023 standard, soon to be adopted by NumPy.
1416

1517
* Changes
1618
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`

docs/jax.numpy.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ namespace; they are listed below.
138138
csingle
139139
cumprod
140140
cumsum
141+
cumulative_sum
141142
deg2rad
142143
degrees
143144
delete

jax/_src/numpy/reductions.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from jax import lax
2828
from jax._src import api
29-
from jax._src import core
29+
from jax._src import core, config
3030
from jax._src import dtypes
3131
from jax._src.numpy import ufuncs
3232
from jax._src.numpy.util import (
@@ -708,6 +708,40 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
708708
nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
709709
fill_nan=True, fill_value=1)
710710

711+
@implements(getattr(np, 'cumulative_sum', None))
712+
def cumulative_sum(
713+
x: ArrayLike, /, *, axis: int | None = None,
714+
dtype: DTypeLike | None = None,
715+
include_initial: bool = False) -> Array:
716+
check_arraylike("cumulative_sum", x)
717+
x = lax_internal.asarray(x)
718+
if x.ndim == 0:
719+
raise ValueError(
720+
"The input must be non-scalar to take a cumulative sum, however a "
721+
"scalar value or scalar array was given."
722+
)
723+
if axis is None and x.ndim > 1:
724+
raise ValueError(
725+
f"The input array has rank {x.ndim}, however axis was not set to an "
726+
"explicit value. The axis argument is only optional for one-dimensional "
727+
"arrays.")
728+
axis = axis or 0
729+
axis = _canonicalize_axis(axis, x.ndim)
730+
dtypes.check_user_dtype_supported(dtype)
731+
kind = x.dtype.kind
732+
if (dtype is None and kind in {'i', 'u'}
733+
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
734+
dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
735+
x = x.astype(dtype=dtype or x.dtype)
736+
out = cumsum(x, axis=axis)
737+
if include_initial:
738+
zeros_shape = list(x.shape)
739+
zeros_shape[axis] = 1
740+
out = lax_internal.concatenate(
741+
[lax_internal.full(zeros_shape, 0, dtype=out.dtype), out],
742+
dimension=axis)
743+
return out
744+
711745
# Quantiles
712746
@implements(np.quantile, skip_params=['out', 'overwrite_input'])
713747
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',

jax/experimental/array_api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@
204204
)
205205

206206
from jax.experimental.array_api._statistical_functions import (
207+
cumulative_sum as cumulative_sum,
207208
max as max,
208209
mean as mean,
209210
min as min,

jax/experimental/array_api/_statistical_functions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
)
1919

2020

21+
def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False):
22+
"""Calculates the cumulative sum of elements in the input array x."""
23+
return jax.numpy.cumulative_sum(x, axis=axis, dtype=dtype, include_initial=include_initial)
24+
2125
def max(x, /, *, axis=None, keepdims=False):
2226
"""Calculates the maximum value of the input array x."""
2327
return jax.numpy.max(x, axis=axis, keepdims=keepdims)

jax/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@
296296
count_nonzero as count_nonzero,
297297
cumsum as cumsum,
298298
cumprod as cumprod,
299+
cumulative_sum as cumulative_sum,
299300
max as max,
300301
mean as mean,
301302
median as median,

jax/numpy/__init__.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,9 @@ def cumprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
241241
cumproduct = cumprod
242242
def cumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
243243
out: None = ...) -> Array: ...
244+
def cumulative_sum(x: ArrayLike, /, *, axis: int | None = ...,
245+
dtype: DTypeLike | None = ...,
246+
include_initial: bool = ...) -> Array: ...
244247

245248
def deg2rad(x: ArrayLike, /) -> Array: ...
246249
degrees = rad2deg

tests/array_api_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
'copysign',
6969
'cos',
7070
'cosh',
71+
'cumulative_sum',
7172
'divide',
7273
'e',
7374
'empty',

tests/lax_numpy_reducers_test.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,5 +770,72 @@ def test_f16_mean(self, dtype):
770770
self.assertAllClose(expected, actual, atol=0)
771771

772772

773+
774+
def _is_canonical(dtype):
775+
_dtype = dtypes.dtype(dtype)
776+
return _dtype == dtypes.canonicalize_dtype(_dtype)
777+
778+
@jtu.sample_product(
779+
[dict(shape=shape, axis=axis)
780+
for shape in all_shapes
781+
for axis in list(
782+
range(-len(shape), len(shape))
783+
) + ([None] if len(shape) == 1 else [])],
784+
dtype=filter(_is_canonical, all_dtypes),
785+
out_dtype=filter(_is_canonical, all_dtypes),
786+
include_initial=[False, True],
787+
)
788+
@jtu.ignore_warning(category=NumpyComplexWarning)
789+
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
790+
def testCumulativeSum(self, shape, axis, dtype, out_dtype, include_initial):
791+
rng = jtu.rand_some_zero(self.rng())
792+
793+
def np_mock_fun(x, axis=None, dtype=None, include_initial=False):
794+
kind = x.dtype.kind
795+
if (dtype is None and kind in {'i', 'u'}
796+
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
797+
dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
798+
axis = axis or 0
799+
x = x.astype(dtype=dtype or x.dtype)
800+
out = jnp.cumsum(x, axis=axis)
801+
if include_initial:
802+
zeros_shape = list(x.shape)
803+
zeros_shape[axis] = 1
804+
out = jnp.concat([jnp.zeros(zeros_shape, dtype=out.dtype), out], axis=axis)
805+
return out
806+
807+
808+
# We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as
809+
# input because we rely on JAX-specific casting behavior
810+
args_maker = lambda: [jnp.array(rng(shape, dtype))]
811+
np_op = getattr(np, "cumulative_sum", np_mock_fun)
812+
kwargs = dict(axis=axis, dtype=out_dtype, include_initial=include_initial)
813+
np_fun = lambda x: np_op(x, **kwargs)
814+
jnp_fun = lambda x: jnp.cumulative_sum(x, **kwargs)
815+
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
816+
self._CompileAndCheck(jnp_fun, args_maker)
817+
818+
kwargs = dict(axis=axis, include_initial=include_initial)
819+
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
820+
self._CompileAndCheck(jnp_fun, args_maker)
821+
822+
823+
@jtu.sample_product(
824+
shape=all_shapes, dtype=all_dtypes,
825+
include_initial=[False, True])
826+
def testCumulativeSumErrors(self, shape, dtype, include_initial):
827+
rng = jtu.rand_some_zero(self.rng())
828+
x = rng(shape, dtype)
829+
if jnp.isscalar(x) or x.ndim == 0:
830+
msg = r"The input must be non-scalar to take"
831+
with self.assertRaisesRegex(ValueError, msg):
832+
jnp.cumulative_sum(x, include_initial=include_initial)
833+
elif x.ndim > 1:
834+
msg = r"The input array has rank \d*, however"
835+
with self.assertRaisesRegex(ValueError, msg):
836+
jnp.cumulative_sum(x, include_initial=include_initial)
837+
838+
839+
773840
if __name__ == "__main__":
774841
absltest.main(testLoader=jtu.JaxTestLoader())

tests/lax_numpy_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ def np_fun(x):
307307
atol={dtypes.bfloat16: 1e-1, np.float16: 1e-2})
308308
self._CompileAndCheck(jnp_fun, args_maker, atol={dtypes.bfloat16: 1e-1})
309309

310+
310311
@jtu.sample_product(
311312
[dict(shape=shape, axis=axis)
312313
for shape in all_shapes

0 commit comments

Comments
 (0)