2323
2424import numpy as np
2525
26- from jax ._src . lax import lax
26+ from jax ._src import config
2727from jax ._src import api
2828from jax ._src import core
2929from jax ._src import deprecations
@@ -2483,8 +2483,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24832483 axis = _canonicalize_axis (axis , a .ndim )
24842484
24852485 q , = promote_dtypes_inexact (q )
2486- q = lax_internal .asarray (q )
2487- q_was_scalar = getattr (q , "ndim" , 0 ) == 0
2486+ q_was_scalar = q .ndim == 0
24882487 if q_was_scalar :
24892488 q = lax .expand_dims (q , (0 ,))
24902489 q_shape = q .shape
@@ -2497,40 +2496,37 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24972496 if weights is None :
24982497 a , = promote_dtypes_inexact (a )
24992498 else :
2499+ if method != "inverted_cdf" :
2500+ raise ValueError ("Weighted quantiles are only supported for method='inverted_cdf'" )
2501+ if axis is None :
2502+ raise TypeError ("Axis must be specified when shapes of a and weights differ." )
2503+ axis_tuple = canonicalize_axis_tuple (axis , a .ndim )
2504+
25002505 a , q , weights = promote_dtypes_inexact (a , q , weights )
2501- #weights = lax.convert_element_type(weights, a.dtype)
25022506 a_shape = a .shape
25032507 w_shape = np .shape (weights )
25042508 if np .ndim (weights ) == 0 :
25052509 weights = lax .broadcast_in_dim (weights , a_shape , ())
25062510 w_shape = a_shape
2507- else :
2508- w_shape = np .shape (weights )
25092511 if w_shape != a_shape :
2510- if axis is None :
2511- raise TypeError ("Axis must be specified when shapes of a and weights differ." )
2512- if isinstance (axis , tuple ):
2513- if w_shape != tuple (a_shape [i ] for i in axis ):
2514- raise ValueError ("Shape of weights must match the shape of the axes being reduced." )
2515- weights = lax .broadcast_in_dim (
2516- weights ,
2517- shape = a_shape ,
2518- broadcast_dimensions = axis
2519- )
2520- w_shape = a_shape
2521- else :
2522- if len (w_shape ) != 1 or w_shape [0 ] != a_shape [axis ]:
2523- raise ValueError ("Length of weights not compatible with specified axis." )
2524- weights = lax .expand_dims (weights , (axis ,))
2525- weights = _broadcast_to (weights , a .shape )
2526- w_shape = a_shape
2512+ expected_shape = tuple (a_shape [i ] for i in axis_tuple )
2513+ if w_shape != expected_shape :
2514+ raise ValueError (f"Shape of weights must match the shape of the axes being reduced. "
2515+ f"Expected { expected_shape } , got { w_shape } " )
2516+ weights = lax .broadcast_in_dim (
2517+ weights ,
2518+ shape = a_shape ,
2519+ broadcast_dimensions = axis_tuple
2520+ )
25272521
25282522 if squash_nans :
25292523 nan_mask = ~ lax_internal ._isnan (a )
25302524 weights = _where (nan_mask , weights , 0 )
25312525 else :
2532- with jax .debug_nans (False ):
2533- a = _where (any (lax_internal ._isnan (a ), axis = axis , keepdims = True ), np .nan , a )
2526+ with config .debug_nans (False ):
2527+ has_nan_data = any (lax_internal ._isnan (a ), axis = axis , keepdims = True )
2528+ has_nan_weights = any (lax_internal ._isnan (weights ), axis = axis , keepdims = True )
2529+ a = _where (has_nan_data | has_nan_weights , np .nan , a )
25342530
25352531 total_weight = sum (weights , axis = axis , keepdims = True )
25362532 a_sorted , weights_sorted = lax_internal .sort_key_val (a , weights , dimension = axis )
@@ -2539,49 +2535,23 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25392535
25402536 def _weighted_quantile (qi ):
25412537 qi = lax .convert_element_type (qi , cum_weights_norm .dtype )
2542- index_dtype = dtypes .canonicalize_dtype ( dtypes . int_ )
2543- idx = sum (lax .lt (cum_weights_norm , qi ), axis = axis , dtype = index_dtype , keepdims = keepdims )
2538+ index_dtype = dtypes .default_int_dtype ( )
2539+ idx = _reduce_sum (lax .lt (cum_weights_norm , qi ), axis = axis , dtype = index_dtype , keepdims = keepdims )
25442540 idx = lax .clamp (_lax_const (idx , 0 ), idx , _lax_const (idx , a_sorted .shape [axis ] - 1 ))
2545- idx_prev = lax .clamp (idx - 1 , _lax_const (idx , 0 ), _lax_const (idx , a_sorted .shape [axis ] - 1 ))
2546-
2547- slice_sizes = list (a_shape )
2548- slice_sizes [axis ] = 1
2549- offset_start = q_ndim
2550- total_offset_dims = len (a_shape ) + q_ndim if keepdims else len (a_shape ) + q_ndim - 1
2551- dnums = lax .GatherDimensionNumbers (
2552- offset_dims = tuple (range (offset_start , total_offset_dims )),
2553- collapsed_slice_dims = (axis ,),
2554- start_index_map = (axis ,)
2555- )
2556- val = lax .gather (a_sorted , idx [..., None ], dimension_numbers = dnums , slice_sizes = slice_sizes )
2557- val_prev = lax .gather (a_sorted , idx_prev [..., None ], dimension_numbers = dnums , slice_sizes = slice_sizes )
2558- cw_prev = lax .gather (cum_weights_norm , idx_prev [..., None ], dimension_numbers = dnums , slice_sizes = slice_sizes )
2559- cw_next = lax .gather (cum_weights_norm , idx [..., None ], dimension_numbers = dnums , slice_sizes = slice_sizes )
2560- if method == "linear" :
2561- denom = cw_next - cw_prev
2562- denom = _where (denom == 0 , 1 , denom )
2563- weight = (qi - cw_prev ) / denom
2564- out = val_prev * (1 - weight ) + val * weight
2565- elif method == "lower" :
2566- out = val_prev
2567- elif method == "higher" :
2568- out = val
2569- elif method == "nearest" :
2570- out = _where (lax .abs (qi - cw_prev ) < lax .abs (qi - cw_next ), val_prev , val )
2571- elif method == "midpoint" :
2572- out = (val_prev + val ) / 2
2573- elif method == "inverted_cdf" :
2574- out = val
2575- else :
2576- raise ValueError (f"{ method = !r} not recognized" )
2577- return out
2541+
2542+ idx_expanded = lax .expand_dims (idx , (axis ,)) if not keepdims else idx
2543+ return jnp .take_along_axis (a_sorted , idx_expanded , axis = axis ).squeeze (axis = axis )
25782544 result = api .vmap (_weighted_quantile )(q )
2579- keepdim_out = list (keepdim )
2545+ shape_after = list (a_shape )
2546+ if keepdims :
2547+ shape_after [axis ] = 1
2548+ else :
2549+ del shape_after [axis ]
25802550 if not q_was_scalar :
2581- keepdim_out = [ q_shape [0 ], * keepdim_out ]
2582- result = result . reshape ( tuple ( keepdim_out ))
2583- elif q_was_scalar and result .ndim > 0 and result .shape [0 ] == 1 :
2584- result = result .squeeze ( axis = 0 )
2551+ result = result . reshape (( q_shape [0 ], * shape_after ))
2552+ else :
2553+ if result .ndim > 0 and result .shape [0 ] == 1 :
2554+ result = result .reshape ( tuple ( shape_after ) )
25852555 return result
25862556
25872557 if squash_nans :
0 commit comments