Skip to content

Commit 8e9c8d6

Browse files
committed
Add weighted quantile and percentile support with tests
1 parent 7d50a32 commit 8e9c8d6

File tree

2 files changed

+64
-55
lines changed

2 files changed

+64
-55
lines changed

jax/_src/numpy/reductions.py

Lines changed: 63 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@
2323

2424
import numpy as np
2525

26-
import jax
27-
from jax import lax
28-
import jax._src.numpy as jnp
26+
from jax._src.lax import lax
27+
from jax._src.lax import parallel as lax_parallel
28+
from jax._src.lax import slicing
29+
from jax._src.lax.control_flow import loops
2930
from jax._src import api
31+
from jax._src import config
3032
from jax._src import core
3133
from jax._src import deprecations
3234
from jax._src import dtypes
@@ -63,7 +65,7 @@ def _moveaxis(a: ArrayLike, source: int, destination: int) -> Array:
6365
perm.insert(destination, source)
6466
return lax.transpose(a, perm)
6567

66-
def _upcast_f16(dtype: DTypeLike) -> DType:
68+
def _upcast_f16(dtype: DTypeLike) -> DTypeLike:
6769
if np.dtype(dtype) in [np.float16, dtypes.bfloat16]:
6870
return np.dtype('float32')
6971
return np.dtype(dtype)
@@ -234,7 +236,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
234236
return _reduction(a, "sum", lax.add, 0, preproc=_cast_to_numeric,
235237
bool_op=lax.bitwise_or, upcast_f16_for_computation=(dtype is None),
236238
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
237-
initial=initial, where_=where, parallel_reduce=lax.psum,
239+
initial=initial, where_=where, parallel_reduce=lax_parallel.psum,
238240
promote_integers=promote_integers)
239241

240242

@@ -407,7 +409,7 @@ def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None,
407409
where: ArrayLike | None = None) -> Array:
408410
return _reduction(a, "max", lax.max, -np.inf, has_identity=False,
409411
axis=axis, out=out, keepdims=keepdims,
410-
initial=initial, where_=where, parallel_reduce=lax.pmax)
412+
initial=initial, where_=where, parallel_reduce=lax_parallel.pmax)
411413

412414

413415
@export
@@ -490,7 +492,7 @@ def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None,
490492
where: ArrayLike | None = None) -> Array:
491493
return _reduction(a, "min", lax.min, np.inf, has_identity=False,
492494
axis=axis, out=out, keepdims=keepdims,
493-
initial=initial, where_=where, parallel_reduce=lax.pmin)
495+
initial=initial, where_=where, parallel_reduce=lax_parallel.pmin)
494496

495497

496498
@export
@@ -797,7 +799,7 @@ def _axis_size(a: ArrayLike, axis: int | Sequence[int]):
797799
size = 1
798800
a_shape = np.shape(a)
799801
for a in axis_seq:
800-
size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name))
802+
size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax_parallel.psum(1, name))
801803
return size
802804

803805

@@ -1140,12 +1142,12 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
11401142
normalizer = lax.sub(normalizer, lax.convert_element_type(correction, computation_dtype))
11411143
result = sum(centered, axis, dtype=computation_dtype, keepdims=keepdims, where=where)
11421144
result = lax.div(result, normalizer).astype(dtype)
1143-
with jax.debug_nans(False):
1145+
with config.debug_nans(False):
11441146
result = _where(normalizer > 0, result, np.nan)
11451147
return result
11461148

11471149

1148-
def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DType, DType]:
1150+
def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[Any, Any]:
11491151
if dtype:
11501152
if (not dtypes.issubdtype(dtype, np.complexfloating) and
11511153
dtypes.issubdtype(a_dtype, np.complexfloating)):
@@ -2010,8 +2012,8 @@ def _cumulative_reduction(
20102012
if fill_nan:
20112013
a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)
20122014

2013-
a_type: DType = dtypes.dtype(a)
2014-
result_type: DTypeLike = dtypes.dtype(dtype or a)
2015+
a_type = dtypes.dtype(a)
2016+
result_type = dtypes.dtype(dtype or a)
20152017
if dtype is None and promote_integers or dtypes.issubdtype(result_type, np.bool_):
20162018
result_type = _promote_integer_dtype(result_type)
20172019
result_type = dtypes.canonicalize_dtype(result_type)
@@ -2062,7 +2064,7 @@ def cumsum(a: ArrayLike, axis: int | None = None,
20622064
Array([[ 1, 3, 6],
20632065
[ 4, 9, 15]], dtype=int32)
20642066
"""
2065-
return _cumulative_reduction("cumsum", lax.cumsum, a, axis, dtype, out)
2067+
return _cumulative_reduction("cumsum", loops.cumsum, a, axis, dtype, out)
20662068

20672069

20682070
@export
@@ -2098,7 +2100,7 @@ def cumprod(a: ArrayLike, axis: int | None = None,
20982100
Array([[ 1, 2, 6],
20992101
[ 4, 20, 120]], dtype=int32)
21002102
"""
2101-
return _cumulative_reduction("cumprod", lax.cumprod, a, axis, dtype, out)
2103+
return _cumulative_reduction("cumprod", loops.cumprod, a, axis, dtype, out)
21022104

21032105

21042106
@export
@@ -2147,7 +2149,7 @@ def nancumsum(a: ArrayLike, axis: int | None = None,
21472149
Array([[ 1., 3., 3.],
21482150
[ 4., 4., 10.]], dtype=float32)
21492151
"""
2150-
return _cumulative_reduction("nancumsum", lax.cumsum, a, axis, dtype, out,
2152+
return _cumulative_reduction("nancumsum", loops.cumsum, a, axis, dtype, out,
21512153
fill_nan=True, fill_value=0)
21522154

21532155

@@ -2196,15 +2198,15 @@ def nancumprod(a: ArrayLike, axis: int | None = None,
21962198
Array([[ 1., 2., 2.],
21972199
[ 4., 4., 24.]], dtype=float32)
21982200
"""
2199-
return _cumulative_reduction("nancumprod", lax.cumprod, a, axis, dtype, out,
2201+
return _cumulative_reduction("nancumprod", loops.cumprod, a, axis, dtype, out,
22002202
fill_nan=True, fill_value=1)
22012203

22022204

22032205
@partial(api.jit, static_argnames=('axis', 'dtype'))
22042206
def _cumsum_with_promotion(a: ArrayLike, axis: int | None = None,
22052207
dtype: DTypeLike | None = None, out: None = None) -> Array:
22062208
"""Utility function to compute cumsum with integer promotion."""
2207-
return _cumulative_reduction("_cumsum_with_promotion", lax.cumsum,
2209+
return _cumulative_reduction("_cumsum_with_promotion", loops.cumsum,
22082210
a, axis, dtype, out, promote_integers=True)
22092211

22102212

@@ -2322,7 +2324,7 @@ def cumulative_prod(
23222324

23232325
axis = _canonicalize_axis(axis, x.ndim)
23242326
dtypes.check_user_dtype_supported(dtype)
2325-
out = _cumulative_reduction("cumulative_prod", lax.cumprod, x, axis, dtype)
2327+
out = _cumulative_reduction("cumulative_prod", loops.cumprod, x, axis, dtype)
23262328
if include_initial:
23272329
zeros_shape = list(x.shape)
23282330
zeros_shape[axis] = 1
@@ -2486,21 +2488,24 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24862488

24872489
q, = promote_dtypes_inexact(q)
24882490
q = lax_internal.asarray(q)
2489-
if getattr(q, "ndim", 0) == 0:
2491+
q_was_scalar = (getattr(q, "ndim", 0) == 0)
2492+
if q_was_scalar:
24902493
q = lax.expand_dims(q, (0,))
24912494
q_shape = q.shape
24922495
q_ndim = q.ndim
24932496
if q_ndim > 1:
24942497
raise ValueError(f"q must be have rank <= 1, got shape {q.shape}")
2495-
24962498
a_shape = a.shape
24972499
# Handle weights
24982500
if weights is None:
24992501
a, = promote_dtypes_inexact(a)
25002502
else:
2501-
a, weights = promote_dtypes_inexact(a, weights)
2502-
weights = lax.convert_element_type(weights, a.dtype)
2503-
a_shape = a.shape
2503+
common_dtype = np.result_type(a, q, weights, np.float32)
2504+
a = a.astype(common_dtype)
2505+
q = q.astype(common_dtype)
2506+
weights = weights.astype(common_dtype)
2507+
a,q, weights = promote_dtypes_inexact(a, q, weights)
2508+
#weights = lax.convert_element_type(weights, a.dtype)
25042509
w_shape = np.shape(weights)
25052510
if np.ndim(weights) == 0:
25062511
weights = lax.broadcast_in_dim(weights, a_shape, ())
@@ -2511,8 +2516,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25112516
if axis is None:
25122517
raise TypeError("Axis must be specified when shapes of a and weights differ.")
25132518
if isinstance(axis, tuple):
2514-
if w_shape != tuple(a_shape[i] for i in axis):
2515-
raise ValueError("Shape of weights must match the shape of the axes being reduced.")
2519+
expected_shape = tuple(a_shape[i] for i in axis)
2520+
if w_shape != expected_shape:
2521+
raise ValueError("Shape of weights must match the shape of the axes being reduced.")
25162522
weights = lax.broadcast_in_dim(
25172523
weights,
25182524
shape=a_shape,
@@ -2521,18 +2527,23 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25212527
w_shape = a_shape
25222528
else:
25232529
if len(w_shape) != 1 or w_shape[0] != a_shape[axis]:
2524-
raise ValueError("Length of weights not compatible with specified axis.")
2530+
raise ValueError("Length of weights not compatible with specified axis.")
25252531
weights = lax.expand_dims(weights, (axis,))
25262532
weights = _broadcast_to(weights, a.shape)
25272533
w_shape = a_shape
2528-
2534+
25292535
if squash_nans:
25302536
nan_mask = ~lax_internal._isnan(a)
25312537
weights = _where(nan_mask, weights, 0)
25322538
else:
2533-
with jax.debug_nans(False):
2539+
with config.debug_nans(False):
25342540
a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a)
25352541

2542+
if all(weights == 0):
2543+
raise ValueError("Sum of weights must not be zero")
2544+
if any(weights < 0):
2545+
raise ValueError("Weights must be non-negative")
2546+
25362547
total_weight = sum(weights, axis=axis, keepdims=True)
25372548
a_sorted, weights_sorted = lax.sort_key_val(a, weights, dimension=axis)
25382549
cum_weights = cumsum(weights_sorted, axis=axis)
@@ -2549,15 +2560,15 @@ def _weighted_quantile(qi):
25492560
slice_sizes[axis] = 1
25502561
offset_start = q_ndim
25512562
total_offset_dims = len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1
2552-
dnums = lax.GatherDimensionNumbers(
2563+
dnums = slicing.GatherDimensionNumbers(
25532564
offset_dims=tuple(range(offset_start, total_offset_dims)),
25542565
collapsed_slice_dims=(axis,),
25552566
start_index_map=(axis,)
25562567
)
2557-
val = lax.gather(a_sorted, idx[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes)
2558-
val_prev = lax.gather(a_sorted, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes)
2559-
cw_prev = lax.gather(cum_weights_norm, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes)
2560-
cw_next = lax.gather(cum_weights_norm, idx[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes)
2568+
val = slicing.gather(a_sorted, idx[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes)
2569+
val_prev = slicing.gather(a_sorted, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes)
2570+
cw_prev = slicing.gather(cum_weights_norm, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes)
2571+
cw_next = slicing.gather(cum_weights_norm, idx[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes)
25612572
if method == "linear":
25622573
denom = cw_next - cw_prev
25632574
denom = _where(denom == 0, 1, denom)
@@ -2577,16 +2588,15 @@ def _weighted_quantile(qi):
25772588
raise ValueError(f"{method=!r} not recognized")
25782589
return out
25792590

2580-
result = jax.vmap(_weighted_quantile)(q)
2591+
result = api.vmap(_weighted_quantile)(q)
25812592
if keepdims and keepdim:
2582-
if q_ndim > 0:
2583-
keepdim = [q_shape[0], *keepdim]
2584-
result = result.reshape(tuple(keepdim))
2585-
else:
2586-
if q_ndim == 0 or (q_ndim == 1 and q_shape[0] == 1):
2587-
if result.ndim > 0 and result.shape[0] == 1:
2588-
result = lax.squeeze(result, (0,))
2589-
return lax.convert_element_type(result, a.dtype)
2593+
keepdim_out = list(keepdim)
2594+
if not q_was_scalar:
2595+
keepdim_out = [q_shape[0], *keepdim_out]
2596+
result = result.reshape(tuple(keepdim_out))
2597+
elif q_was_scalar and result.ndim > 0 and result.shape[0] == 1:
2598+
result = result.squeeze(axis=0)
2599+
return result
25902600

25912601
if squash_nans:
25922602
a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.
@@ -2617,7 +2627,7 @@ def _weighted_quantile(qi):
26172627
index[axis] = high
26182628
high_value = a[tuple(index)]
26192629
else:
2620-
with jax.debug_nans(False):
2630+
with config.debug_nans(False):
26212631
a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a)
26222632
a = lax.sort(a, dimension=axis)
26232633
n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q))
@@ -2634,15 +2644,15 @@ def _weighted_quantile(qi):
26342644

26352645
slice_sizes = list(a_shape)
26362646
slice_sizes[axis] = 1
2637-
dnums = lax.GatherDimensionNumbers(
2647+
dnums = slicing.GatherDimensionNumbers(
26382648
offset_dims=tuple(range(
26392649
q_ndim,
26402650
len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1)),
26412651
collapsed_slice_dims=() if keepdims else (axis,),
26422652
start_index_map=(axis,))
2643-
low_value = lax.gather(a, low[..., None], dimension_numbers=dnums,
2653+
low_value = slicing.gather(a, low[..., None], dimension_numbers=dnums,
26442654
slice_sizes=slice_sizes)
2645-
high_value = lax.gather(a, high[..., None], dimension_numbers=dnums,
2655+
high_value = slicing.gather(a, high[..., None], dimension_numbers=dnums,
26462656
slice_sizes=slice_sizes)
26472657
if q_ndim == 1:
26482658
low_weight = lax.broadcast_in_dim(low_weight, low_value.shape,
@@ -2667,14 +2677,13 @@ def _weighted_quantile(qi):
26672677
else:
26682678
raise ValueError(f"{method=!r} not recognized")
26692679
if keepdims and keepdim:
2670-
if q_ndim > 0:
2671-
keepdim = [np.shape(q)[0], *keepdim]
2672-
result = result.reshape(keepdim)
2673-
else:
2674-
if q_ndim == 0 or (q_ndim == 1 and q_shape[0] == 1):
2675-
if result.ndim > 0 and result.shape[0] == 1:
2676-
result = lax.squeeze(result, (0,))
2677-
return lax.convert_element_type(result, a.dtype)
2680+
keepdim_out = list(keepdim)
2681+
if not q_was_scalar:
2682+
keepdim_out = [q_shape[0], *keepdim_out]
2683+
result = result.reshape(tuple(keepdim_out))
2684+
elif q_was_scalar and result.ndim > 0 and result.shape[0] == 1:
2685+
result = result.squeeze(axis=0)
2686+
return result
26782687

26792688

26802689
# TODO(jakevdp): interpolation argument deprecated 2024-05-16

tests/lax_numpy_reducers_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,7 @@ def testWeightedQuantile(self, a_shape, a_dtype, q_shape, q_dtype, axis, keepdim
795795
weights = np.abs(rng(weights_shape, a_dtype)) + 1e-3
796796

797797
def np_fun(a, q, weights):
798-
return np.quantile(np.array(a), np.array(q), axis=axis, weights=np.array(weights), method=method, keepdims=keepdims)
798+
return np.quantile(np.array(a), np.array(q), axis=axis, weights=weights, method=method, keepdims=keepdims)
799799
def jnp_fun(a, q, weights):
800800
return jnp.quantile(a, q, axis=axis, weights=weights, method=method, keepdims=keepdims)
801801
args_maker = lambda: [

0 commit comments

Comments
 (0)