Skip to content

Commit 6a60ed4

Browse files
committed
Add weighted quantile and percentile support with tests
1 parent e7ec418 commit 6a60ed4

File tree

2 files changed

+77
-7
lines changed

2 files changed

+77
-7
lines changed

jax/_src/numpy/reductions.py

Lines changed: 68 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2337,7 +2337,8 @@ def cumulative_prod(
23372337
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
23382338
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
23392339
out: None = None, overwrite_input: bool = False, method: str = "linear",
2340-
keepdims: bool = False, *, interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array:
2340+
keepdims: bool = False, weights: ArrayLike | None = None, *,
2341+
interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array:
23412342
"""Compute the quantile of the data along the specified axis.
23422343
23432344
JAX implementation of :func:`numpy.quantile`.
@@ -2387,7 +2388,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
23872388
("The interpolation= argument to 'quantile' is deprecated. "
23882389
"Use 'method=' instead."), stacklevel=2)
23892390
method = interpolation
2390-
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False)
2391+
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False, weights)
23912392

23922393
# TODO(jakevdp): interpolation argument deprecated 2024-05-16
23932394
@export
@@ -2449,7 +2450,7 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
24492450
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, True)
24502451

24512452
def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
2452-
method: str, keepdims: bool, squash_nans: bool) -> Array:
2453+
method: str, keepdims: bool, squash_nans: bool, weights: ArrayLike | None = None) -> Array:
24532454
if method not in ["linear", "lower", "higher", "midpoint", "nearest"]:
24542455
raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', or 'nearest'")
24552456
a, = promote_dtypes_inexact(a)
@@ -2488,6 +2489,66 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24882489
raise ValueError(f"q must be have rank <= 1, got shape {q.shape}")
24892490

24902491
a_shape = a.shape
2492+
# Handle weights
2493+
if weights is not None:
2494+
a, weights = promote_dtypes_inexact(a, weights)
2495+
if axis is None:
2496+
a = a.ravel()
2497+
weights = weights.ravel()
2498+
axis = 0
2499+
else:
2500+
weights = _broadcast_to(weights, a.shape)
2501+
if squash_nans:
2502+
nan_mask = ~lax_internal._isnan(a)
2503+
if axis is None:
2504+
a = a[nan_mask]
2505+
weights = weights[nan_mask]
2506+
else:
2507+
weights = _where(nan_mask, weights, 0)
2508+
a_sorted, weights_sorted = lax.sort_key_val(a, weights, dimension=axis)
2509+
2510+
cum_weights = lax.cumsum(weights_sorted, axis=axis)
2511+
total_weight = lax.sum(weights_sorted, axis=axis, keepdims=True)
2512+
if lax_internal._all(total_weight == 0):
2513+
raise ValueError("Sum of weights must not be zero.")
2514+
cum_weights_norm = cum_weights / total_weight
2515+
quantile_pos = q
2516+
mask = cum_weights_norm >= quantile_pos[..., None]
2517+
idx = lax.argmin(mask.astype(int), axis=axis)
2518+
idx_prev = lax.max(idx - 1, _lax_const(idx, 0))
2519+
idx_next = idx
2520+
gather_shape = list(a_sorted.shape)
2521+
gather_shape[axis] = 1
2522+
dnums = lax.GatherDimensionNumbers(
2523+
offset_dims=tuple(range(len(a_sorted.shape))),
2524+
collapsed_slice_dims=(axis,),
2525+
start_index_map=(axis,))
2526+
prev_value = lax.gather(a_sorted, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=gather_shape)
2527+
next_value = lax.gather(a_sorted, idx_next[..., None], dimension_numbers=dnums, slice_sizes=gather_shape)
2528+
prev_cumw = lax.gather(cum_weights_norm, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=gather_shape)
2529+
next_cumw = lax.gather(cum_weights_norm, idx_next[..., None], dimension_numbers=dnums, slice_sizes=gather_shape)
2530+
2531+
if method == "linear":
2532+
denom = next_cumw - prev_cumw
2533+
denom = lax.select(denom == 0, _lax_const(denom, 1), denom)
2534+
weight = (quantile_pos - prev_cumw) / denom
2535+
result = prev_value * (1 - weight) + next_value * weight
2536+
elif method == "lower":
2537+
result = prev_value
2538+
elif method == "higher":
2539+
result = next_value
2540+
elif method == "nearest":
2541+
use_prev = (quantile_pos - prev_cumw) < (next_cumw - quantile_pos)
2542+
result = lax.select(use_prev, prev_value, next_value)
2543+
elif method == "midpoint":
2544+
result = (prev_value + next_value) / 2
2545+
else:
2546+
raise ValueError(f"{method=!r} not recognized")
2547+
2548+
if not keepdims:
2549+
result = lax.squeeze(result, axis)
2550+
return lax.convert_element_type(result, a.dtype)
2551+
24912552

24922553
if squash_nans:
24932554
a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.
@@ -2578,7 +2639,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25782639
def percentile(a: ArrayLike, q: ArrayLike,
25792640
axis: int | tuple[int, ...] | None = None,
25802641
out: None = None, overwrite_input: bool = False, method: str = "linear",
2581-
keepdims: bool = False, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array:
2642+
keepdims: bool = False, weights: ArrayLike | None = None, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array:
25822643
"""Compute the percentile of the data along the specified axis.
25832644
25842645
JAX implementation of :func:`numpy.percentile`.
@@ -2627,7 +2688,7 @@ def percentile(a: ArrayLike, q: ArrayLike,
26272688
"Use 'method=' instead."), stacklevel=2)
26282689
method = interpolation
26292690
return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
2630-
method=method, keepdims=keepdims)
2691+
method=method, keepdims=keepdims, weights=weights)
26312692

26322693

26332694
# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@@ -2636,7 +2697,7 @@ def percentile(a: ArrayLike, q: ArrayLike,
26362697
def nanpercentile(a: ArrayLike, q: ArrayLike,
26372698
axis: int | tuple[int, ...] | None = None,
26382699
out: None = None, overwrite_input: bool = False, method: str = "linear",
2639-
keepdims: bool = False, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array:
2700+
keepdims: bool = False, weights: ArrayLike | None = None, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array:
26402701
"""Compute the percentile of the data along the specified axis, ignoring NaN values.
26412702
26422703
JAX implementation of :func:`numpy.nanpercentile`.
@@ -2688,7 +2749,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
26882749
"Use 'method=' instead."), stacklevel=2)
26892750
method = interpolation
26902751
return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
2691-
method=method, keepdims=keepdims)
2752+
method=method, keepdims=keepdims, weights=weights)
26922753

26932754

26942755
@export

tests/lax_numpy_reducers_test.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828

2929
from jax._src import config
3030
from jax._src import dtypes
31+
from jax._src.numpy.reductions import quantile
3132
from jax._src import test_util as jtu
3233
from jax._src.util import NumpyComplexWarning
3334

@@ -763,6 +764,14 @@ def testPercentilePrecision(self):
763764
x = jnp.float64([1, 2, 3, 4, 7, 10])
764765
self.assertEqual(jnp.percentile(x, 50), 3.5)
765766

767+
def test_weighted_quantile_linear(self):
768+
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
769+
weights = jnp.array([1, 2, 1, 1, 1], dtype=float)
770+
q = jnp.array([0.5])
771+
expected = np.quantile(a, q, weights=weights)
772+
result = quantile(a, q, weights=weights, method="linear")
773+
np.testing.assert_allclose(result, expected, rtol=1e-6)
774+
766775
@jtu.sample_product(
767776
[dict(a_shape=a_shape, axis=axis)
768777
for a_shape, axis in (

0 commit comments

Comments
 (0)