@@ -2337,7 +2337,8 @@ def cumulative_prod(
23372337@partial (api .jit , static_argnames = ('axis' , 'overwrite_input' , 'interpolation' , 'keepdims' , 'method' ))
23382338def quantile (a : ArrayLike , q : ArrayLike , axis : int | tuple [int , ...] | None = None ,
23392339 out : None = None , overwrite_input : bool = False , method : str = "linear" ,
2340- keepdims : bool = False , * , interpolation : DeprecatedArg | str = DeprecatedArg ()) -> Array :
2340+ keepdims : bool = False , weights : ArrayLike | None = None , * ,
2341+ interpolation : DeprecatedArg | str = DeprecatedArg ()) -> Array :
23412342 """Compute the quantile of the data along the specified axis.
23422343
23432344 JAX implementation of :func:`numpy.quantile`.
@@ -2387,7 +2388,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
23872388 ("The interpolation= argument to 'quantile' is deprecated. "
23882389 "Use 'method=' instead." ), stacklevel = 2 )
23892390 method = interpolation
2390- return _quantile (lax_internal .asarray (a ), lax_internal .asarray (q ), axis , method , keepdims , False )
2391+ return _quantile (lax_internal .asarray (a ), lax_internal .asarray (q ), axis , method , keepdims , False , weights )
23912392
23922393# TODO(jakevdp): interpolation argument deprecated 2024-05-16
23932394@export
@@ -2449,7 +2450,7 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
24492450 return _quantile (lax_internal .asarray (a ), lax_internal .asarray (q ), axis , method , keepdims , True )
24502451
24512452def _quantile (a : Array , q : Array , axis : int | tuple [int , ...] | None ,
2452- method : str , keepdims : bool , squash_nans : bool ) -> Array :
2453+ method : str , keepdims : bool , squash_nans : bool , weights : ArrayLike | None = None ) -> Array :
24532454 if method not in ["linear" , "lower" , "higher" , "midpoint" , "nearest" ]:
24542455 raise ValueError ("method can only be 'linear', 'lower', 'higher', 'midpoint', or 'nearest'" )
24552456 a , = promote_dtypes_inexact (a )
@@ -2488,6 +2489,66 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24882489 raise ValueError (f"q must be have rank <= 1, got shape { q .shape } " )
24892490
24902491 a_shape = a .shape
2492+ # Handle weights
2493+ if weights is not None :
2494+ a , weights = promote_dtypes_inexact (a , weights )
2495+ if axis is None :
2496+ a = a .ravel ()
2497+ weights = weights .ravel ()
2498+ axis = 0
2499+ else :
2500+ weights = _broadcast_to (weights , a .shape )
2501+ if squash_nans :
2502+ nan_mask = ~ lax_internal ._isnan (a )
2503+ if axis is None :
2504+ a = a [nan_mask ]
2505+ weights = weights [nan_mask ]
2506+ else :
2507+ weights = _where (nan_mask , weights , 0 )
2508+ a_sorted , weights_sorted = lax .sort_key_val (a , weights , dimension = axis )
2509+
2510+ cum_weights = lax .cumsum (weights_sorted , axis = axis )
2511+ total_weight = lax .sum (weights_sorted , axis = axis , keepdims = True )
2512+ if lax_internal ._all (total_weight == 0 ):
2513+ raise ValueError ("Sum of weights must not be zero." )
2514+ cum_weights_norm = cum_weights / total_weight
2515+ quantile_pos = q
2516+ mask = cum_weights_norm >= quantile_pos [..., None ]
2517+ idx = lax .argmin (mask .astype (int ), axis = axis )
2518+ idx_prev = lax .max (idx - 1 , _lax_const (idx , 0 ))
2519+ idx_next = idx
2520+ gather_shape = list (a_sorted .shape )
2521+ gather_shape [axis ] = 1
2522+ dnums = lax .GatherDimensionNumbers (
2523+ offset_dims = tuple (range (len (a_sorted .shape ))),
2524+ collapsed_slice_dims = (axis ,),
2525+ start_index_map = (axis ,))
2526+ prev_value = lax .gather (a_sorted , idx_prev [..., None ], dimension_numbers = dnums , slice_sizes = gather_shape )
2527+ next_value = lax .gather (a_sorted , idx_next [..., None ], dimension_numbers = dnums , slice_sizes = gather_shape )
2528+ prev_cumw = lax .gather (cum_weights_norm , idx_prev [..., None ], dimension_numbers = dnums , slice_sizes = gather_shape )
2529+ next_cumw = lax .gather (cum_weights_norm , idx_next [..., None ], dimension_numbers = dnums , slice_sizes = gather_shape )
2530+
2531+ if method == "linear" :
2532+ denom = next_cumw - prev_cumw
2533+ denom = lax .select (denom == 0 , _lax_const (denom , 1 ), denom )
2534+ weight = (quantile_pos - prev_cumw ) / denom
2535+ result = prev_value * (1 - weight ) + next_value * weight
2536+ elif method == "lower" :
2537+ result = prev_value
2538+ elif method == "higher" :
2539+ result = next_value
2540+ elif method == "nearest" :
2541+ use_prev = (quantile_pos - prev_cumw ) < (next_cumw - quantile_pos )
2542+ result = lax .select (use_prev , prev_value , next_value )
2543+ elif method == "midpoint" :
2544+ result = (prev_value + next_value ) / 2
2545+ else :
2546+ raise ValueError (f"{ method = !r} not recognized" )
2547+
2548+ if not keepdims :
2549+ result = lax .squeeze (result , axis )
2550+ return lax .convert_element_type (result , a .dtype )
2551+
24912552
24922553 if squash_nans :
24932554 a = _where (lax_internal ._isnan (a ), np .nan , a ) # Ensure nans are positive so they sort to the end.
@@ -2578,7 +2639,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25782639def percentile (a : ArrayLike , q : ArrayLike ,
25792640 axis : int | tuple [int , ...] | None = None ,
25802641 out : None = None , overwrite_input : bool = False , method : str = "linear" ,
2581- keepdims : bool = False , * , interpolation : str | DeprecatedArg = DeprecatedArg ()) -> Array :
2642+ keepdims : bool = False , weights : ArrayLike | None = None , * , interpolation : str | DeprecatedArg = DeprecatedArg ()) -> Array :
25822643 """Compute the percentile of the data along the specified axis.
25832644
25842645 JAX implementation of :func:`numpy.percentile`.
@@ -2627,7 +2688,7 @@ def percentile(a: ArrayLike, q: ArrayLike,
26272688 "Use 'method=' instead." ), stacklevel = 2 )
26282689 method = interpolation
26292690 return quantile (a , q / 100 , axis = axis , out = out , overwrite_input = overwrite_input ,
2630- method = method , keepdims = keepdims )
2691+ method = method , keepdims = keepdims , weights = weights )
26312692
26322693
26332694# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@@ -2636,7 +2697,7 @@ def percentile(a: ArrayLike, q: ArrayLike,
26362697def nanpercentile (a : ArrayLike , q : ArrayLike ,
26372698 axis : int | tuple [int , ...] | None = None ,
26382699 out : None = None , overwrite_input : bool = False , method : str = "linear" ,
2639- keepdims : bool = False , * , interpolation : str | DeprecatedArg = DeprecatedArg ()) -> Array :
2700+ keepdims : bool = False , weights : ArrayLike | None = None , * , interpolation : str | DeprecatedArg = DeprecatedArg ()) -> Array :
26402701 """Compute the percentile of the data along the specified axis, ignoring NaN values.
26412702
26422703 JAX implementation of :func:`numpy.nanpercentile`.
@@ -2688,7 +2749,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
26882749 "Use 'method=' instead." ), stacklevel = 2 )
26892750 method = interpolation
26902751 return nanquantile (a , q , axis = axis , out = out , overwrite_input = overwrite_input ,
2691- method = method , keepdims = keepdims )
2752+ method = method , keepdims = keepdims , weights = weights )
26922753
26932754
26942755@export
0 commit comments