2323
2424import numpy as np
2525
26- import jax
27- from jax import lax
28- import jax ._src .numpy as jnp
26+ from jax ._src .lax import lax
27+ from jax ._src .lax import parallel as lax_parallel
28+ from jax ._src .lax import slicing
29+ from jax ._src .lax .control_flow import loops
2930from jax ._src import api
31+ from jax ._src import config
3032from jax ._src import core
3133from jax ._src import deprecations
3234from jax ._src import dtypes
@@ -63,7 +65,7 @@ def _moveaxis(a: ArrayLike, source: int, destination: int) -> Array:
6365 perm .insert (destination , source )
6466 return lax .transpose (a , perm )
6567
66- def _upcast_f16 (dtype : DTypeLike ) -> DType :
68+ def _upcast_f16 (dtype : DTypeLike ) -> DTypeLike :
6769 if np .dtype (dtype ) in [np .float16 , dtypes .bfloat16 ]:
6870 return np .dtype ('float32' )
6971 return np .dtype (dtype )
@@ -234,7 +236,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
234236 return _reduction (a , "sum" , lax .add , 0 , preproc = _cast_to_numeric ,
235237 bool_op = lax .bitwise_or , upcast_f16_for_computation = (dtype is None ),
236238 axis = axis , dtype = dtype , out = out , keepdims = keepdims ,
237- initial = initial , where_ = where , parallel_reduce = lax .psum ,
239+ initial = initial , where_ = where , parallel_reduce = lax_parallel .psum ,
238240 promote_integers = promote_integers )
239241
240242
@@ -407,7 +409,7 @@ def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None,
407409 where : ArrayLike | None = None ) -> Array :
408410 return _reduction (a , "max" , lax .max , - np .inf , has_identity = False ,
409411 axis = axis , out = out , keepdims = keepdims ,
410- initial = initial , where_ = where , parallel_reduce = lax .pmax )
412+ initial = initial , where_ = where , parallel_reduce = lax_parallel .pmax )
411413
412414
413415@export
@@ -490,7 +492,7 @@ def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None,
490492 where : ArrayLike | None = None ) -> Array :
491493 return _reduction (a , "min" , lax .min , np .inf , has_identity = False ,
492494 axis = axis , out = out , keepdims = keepdims ,
493- initial = initial , where_ = where , parallel_reduce = lax .pmin )
495+ initial = initial , where_ = where , parallel_reduce = lax_parallel .pmin )
494496
495497
496498@export
@@ -797,7 +799,7 @@ def _axis_size(a: ArrayLike, axis: int | Sequence[int]):
797799 size = 1
798800 a_shape = np .shape (a )
799801 for a in axis_seq :
800- size *= maybe_named_axis (a , lambda i : a_shape [i ], lambda name : lax .psum (1 , name ))
802+ size *= maybe_named_axis (a , lambda i : a_shape [i ], lambda name : lax_parallel .psum (1 , name ))
801803 return size
802804
803805
@@ -1140,12 +1142,12 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
11401142 normalizer = lax .sub (normalizer , lax .convert_element_type (correction , computation_dtype ))
11411143 result = sum (centered , axis , dtype = computation_dtype , keepdims = keepdims , where = where )
11421144 result = lax .div (result , normalizer ).astype (dtype )
1143- with jax .debug_nans (False ):
1145+ with config .debug_nans (False ):
11441146 result = _where (normalizer > 0 , result , np .nan )
11451147 return result
11461148
11471149
1148- def _var_promote_types (a_dtype : DTypeLike , dtype : DTypeLike | None ) -> tuple [DType , DType ]:
1150+ def _var_promote_types (a_dtype : DTypeLike , dtype : DTypeLike | None ) -> tuple [Any , Any ]:
11491151 if dtype :
11501152 if (not dtypes .issubdtype (dtype , np .complexfloating ) and
11511153 dtypes .issubdtype (a_dtype , np .complexfloating )):
@@ -2010,8 +2012,8 @@ def _cumulative_reduction(
20102012 if fill_nan :
20112013 a = _where (lax_internal ._isnan (a ), _lax_const (a , fill_value ), a )
20122014
2013- a_type : DType = dtypes .dtype (a )
2014- result_type : DTypeLike = dtypes .dtype (dtype or a )
2015+ a_type = dtypes .dtype (a )
2016+ result_type = dtypes .dtype (dtype or a )
20152017 if dtype is None and promote_integers or dtypes .issubdtype (result_type , np .bool_ ):
20162018 result_type = _promote_integer_dtype (result_type )
20172019 result_type = dtypes .canonicalize_dtype (result_type )
@@ -2062,7 +2064,7 @@ def cumsum(a: ArrayLike, axis: int | None = None,
20622064 Array([[ 1, 3, 6],
20632065 [ 4, 9, 15]], dtype=int32)
20642066 """
2065- return _cumulative_reduction ("cumsum" , lax .cumsum , a , axis , dtype , out )
2067+ return _cumulative_reduction ("cumsum" , loops .cumsum , a , axis , dtype , out )
20662068
20672069
20682070@export
@@ -2098,7 +2100,7 @@ def cumprod(a: ArrayLike, axis: int | None = None,
20982100 Array([[ 1, 2, 6],
20992101 [ 4, 20, 120]], dtype=int32)
21002102 """
2101- return _cumulative_reduction ("cumprod" , lax .cumprod , a , axis , dtype , out )
2103+ return _cumulative_reduction ("cumprod" , loops .cumprod , a , axis , dtype , out )
21022104
21032105
21042106@export
@@ -2147,7 +2149,7 @@ def nancumsum(a: ArrayLike, axis: int | None = None,
21472149 Array([[ 1., 3., 3.],
21482150 [ 4., 4., 10.]], dtype=float32)
21492151 """
2150- return _cumulative_reduction ("nancumsum" , lax .cumsum , a , axis , dtype , out ,
2152+ return _cumulative_reduction ("nancumsum" , loops .cumsum , a , axis , dtype , out ,
21512153 fill_nan = True , fill_value = 0 )
21522154
21532155
@@ -2196,15 +2198,15 @@ def nancumprod(a: ArrayLike, axis: int | None = None,
21962198 Array([[ 1., 2., 2.],
21972199 [ 4., 4., 24.]], dtype=float32)
21982200 """
2199- return _cumulative_reduction ("nancumprod" , lax .cumprod , a , axis , dtype , out ,
2201+ return _cumulative_reduction ("nancumprod" , loops .cumprod , a , axis , dtype , out ,
22002202 fill_nan = True , fill_value = 1 )
22012203
22022204
22032205@partial (api .jit , static_argnames = ('axis' , 'dtype' ))
22042206def _cumsum_with_promotion (a : ArrayLike , axis : int | None = None ,
22052207 dtype : DTypeLike | None = None , out : None = None ) -> Array :
22062208 """Utility function to compute cumsum with integer promotion."""
2207- return _cumulative_reduction ("_cumsum_with_promotion" , lax .cumsum ,
2209+ return _cumulative_reduction ("_cumsum_with_promotion" , loops .cumsum ,
22082210 a , axis , dtype , out , promote_integers = True )
22092211
22102212
@@ -2322,7 +2324,7 @@ def cumulative_prod(
23222324
23232325 axis = _canonicalize_axis (axis , x .ndim )
23242326 dtypes .check_user_dtype_supported (dtype )
2325- out = _cumulative_reduction ("cumulative_prod" , lax .cumprod , x , axis , dtype )
2327+ out = _cumulative_reduction ("cumulative_prod" , loops .cumprod , x , axis , dtype )
23262328 if include_initial :
23272329 zeros_shape = list (x .shape )
23282330 zeros_shape [axis ] = 1
@@ -2486,21 +2488,24 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24862488
24872489 q , = promote_dtypes_inexact (q )
24882490 q = lax_internal .asarray (q )
2489- if getattr (q , "ndim" , 0 ) == 0 :
2491+ q_was_scalar = (getattr (q , "ndim" , 0 ) == 0 )
2492+ if q_was_scalar :
24902493 q = lax .expand_dims (q , (0 ,))
24912494 q_shape = q .shape
24922495 q_ndim = q .ndim
24932496 if q_ndim > 1 :
24942497 raise ValueError (f"q must be have rank <= 1, got shape { q .shape } " )
2495-
24962498 a_shape = a .shape
24972499 # Handle weights
24982500 if weights is None :
24992501 a , = promote_dtypes_inexact (a )
25002502 else :
2501- a , weights = promote_dtypes_inexact (a , weights )
2502- weights = lax .convert_element_type (weights , a .dtype )
2503- a_shape = a .shape
2503+ common_dtype = np .result_type (a , q , weights , np .float32 )
2504+ a = a .astype (common_dtype )
2505+ q = q .astype (common_dtype )
2506+ weights = weights .astype (common_dtype )
2507+ a ,q , weights = promote_dtypes_inexact (a , q , weights )
2508+ #weights = lax.convert_element_type(weights, a.dtype)
25042509 w_shape = np .shape (weights )
25052510 if np .ndim (weights ) == 0 :
25062511 weights = lax .broadcast_in_dim (weights , a_shape , ())
@@ -2511,8 +2516,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25112516 if axis is None :
25122517 raise TypeError ("Axis must be specified when shapes of a and weights differ." )
25132518 if isinstance (axis , tuple ):
2514- if w_shape != tuple (a_shape [i ] for i in axis ):
2515- raise ValueError ("Shape of weights must match the shape of the axes being reduced." )
2519+ expected_shape = tuple (a_shape [i ] for i in axis )
2520+ if w_shape != expected_shape :
2521+ raise ValueError ("Shape of weights must match the shape of the axes being reduced." )
25162522 weights = lax .broadcast_in_dim (
25172523 weights ,
25182524 shape = a_shape ,
@@ -2521,18 +2527,23 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25212527 w_shape = a_shape
25222528 else :
25232529 if len (w_shape ) != 1 or w_shape [0 ] != a_shape [axis ]:
2524- raise ValueError ("Length of weights not compatible with specified axis." )
2530+ raise ValueError ("Length of weights not compatible with specified axis." )
25252531 weights = lax .expand_dims (weights , (axis ,))
25262532 weights = _broadcast_to (weights , a .shape )
25272533 w_shape = a_shape
2528-
2534+
25292535 if squash_nans :
25302536 nan_mask = ~ lax_internal ._isnan (a )
25312537 weights = _where (nan_mask , weights , 0 )
25322538 else :
2533- with jax .debug_nans (False ):
2539+ with config .debug_nans (False ):
25342540 a = _where (any (lax_internal ._isnan (a ), axis = axis , keepdims = True ), np .nan , a )
25352541
2542+ if all (weights == 0 ):
2543+ raise ValueError ("Sum of weights must not be zero" )
2544+ if any (weights < 0 ):
2545+ raise ValueError ("Weights must be non-negative" )
2546+
25362547 total_weight = sum (weights , axis = axis , keepdims = True )
25372548 a_sorted , weights_sorted = lax .sort_key_val (a , weights , dimension = axis )
25382549 cum_weights = cumsum (weights_sorted , axis = axis )
@@ -2549,15 +2560,15 @@ def _weighted_quantile(qi):
25492560 slice_sizes [axis ] = 1
25502561 offset_start = q_ndim
25512562 total_offset_dims = len (a_shape ) + q_ndim if keepdims else len (a_shape ) + q_ndim - 1
2552- dnums = lax .GatherDimensionNumbers (
2563+ dnums = slicing .GatherDimensionNumbers (
25532564 offset_dims = tuple (range (offset_start , total_offset_dims )),
25542565 collapsed_slice_dims = (axis ,),
25552566 start_index_map = (axis ,)
25562567 )
2557- val = lax .gather (a_sorted , idx [..., None ], dimension_numbers = dnums , slice_sizes = slice_sizes )
2558- val_prev = lax .gather (a_sorted , idx_prev [..., None ], dimension_numbers = dnums , slice_sizes = slice_sizes )
2559- cw_prev = lax .gather (cum_weights_norm , idx_prev [..., None ], dimension_numbers = dnums , slice_sizes = slice_sizes )
2560- cw_next = lax .gather (cum_weights_norm , idx [..., None ], dimension_numbers = dnums , slice_sizes = slice_sizes )
2568+ val = slicing .gather (a_sorted , idx [..., None ], dimension_numbers = dnums , slice_sizes = slice_sizes )
2569+ val_prev = slicing .gather (a_sorted , idx_prev [..., None ], dimension_numbers = dnums , slice_sizes = slice_sizes )
2570+ cw_prev = slicing .gather (cum_weights_norm , idx_prev [..., None ], dimension_numbers = dnums , slice_sizes = slice_sizes )
2571+ cw_next = slicing .gather (cum_weights_norm , idx [..., None ], dimension_numbers = dnums , slice_sizes = slice_sizes )
25612572 if method == "linear" :
25622573 denom = cw_next - cw_prev
25632574 denom = _where (denom == 0 , 1 , denom )
@@ -2577,16 +2588,15 @@ def _weighted_quantile(qi):
25772588 raise ValueError (f"{ method = !r} not recognized" )
25782589 return out
25792590
2580- result = jax .vmap (_weighted_quantile )(q )
2591+ result = api .vmap (_weighted_quantile )(q )
25812592 if keepdims and keepdim :
2582- if q_ndim > 0 :
2583- keepdim = [q_shape [0 ], * keepdim ]
2584- result = result .reshape (tuple (keepdim ))
2585- else :
2586- if q_ndim == 0 or (q_ndim == 1 and q_shape [0 ] == 1 ):
2587- if result .ndim > 0 and result .shape [0 ] == 1 :
2588- result = lax .squeeze (result , (0 ,))
2589- return lax .convert_element_type (result , a .dtype )
2593+ keepdim_out = list (keepdim )
2594+ if not q_was_scalar :
2595+ keepdim_out = [q_shape [0 ], * keepdim_out ]
2596+ result = result .reshape (tuple (keepdim_out ))
2597+ elif q_was_scalar and result .ndim > 0 and result .shape [0 ] == 1 :
2598+ result = result .squeeze (axis = 0 )
2599+ return result
25902600
25912601 if squash_nans :
25922602 a = _where (lax_internal ._isnan (a ), np .nan , a ) # Ensure nans are positive so they sort to the end.
@@ -2617,7 +2627,7 @@ def _weighted_quantile(qi):
26172627 index [axis ] = high
26182628 high_value = a [tuple (index )]
26192629 else :
2620- with jax .debug_nans (False ):
2630+ with config .debug_nans (False ):
26212631 a = _where (any (lax_internal ._isnan (a ), axis = axis , keepdims = True ), np .nan , a )
26222632 a = lax .sort (a , dimension = axis )
26232633 n = lax .convert_element_type (a_shape [axis ], lax_internal ._dtype (q ))
@@ -2634,15 +2644,15 @@ def _weighted_quantile(qi):
26342644
26352645 slice_sizes = list (a_shape )
26362646 slice_sizes [axis ] = 1
2637- dnums = lax .GatherDimensionNumbers (
2647+ dnums = slicing .GatherDimensionNumbers (
26382648 offset_dims = tuple (range (
26392649 q_ndim ,
26402650 len (a_shape ) + q_ndim if keepdims else len (a_shape ) + q_ndim - 1 )),
26412651 collapsed_slice_dims = () if keepdims else (axis ,),
26422652 start_index_map = (axis ,))
2643- low_value = lax .gather (a , low [..., None ], dimension_numbers = dnums ,
2653+ low_value = slicing .gather (a , low [..., None ], dimension_numbers = dnums ,
26442654 slice_sizes = slice_sizes )
2645- high_value = lax .gather (a , high [..., None ], dimension_numbers = dnums ,
2655+ high_value = slicing .gather (a , high [..., None ], dimension_numbers = dnums ,
26462656 slice_sizes = slice_sizes )
26472657 if q_ndim == 1 :
26482658 low_weight = lax .broadcast_in_dim (low_weight , low_value .shape ,
@@ -2667,14 +2677,13 @@ def _weighted_quantile(qi):
26672677 else :
26682678 raise ValueError (f"{ method = !r} not recognized" )
26692679 if keepdims and keepdim :
2670- if q_ndim > 0 :
2671- keepdim = [np .shape (q )[0 ], * keepdim ]
2672- result = result .reshape (keepdim )
2673- else :
2674- if q_ndim == 0 or (q_ndim == 1 and q_shape [0 ] == 1 ):
2675- if result .ndim > 0 and result .shape [0 ] == 1 :
2676- result = lax .squeeze (result , (0 ,))
2677- return lax .convert_element_type (result , a .dtype )
2680+ keepdim_out = list (keepdim )
2681+ if not q_was_scalar :
2682+ keepdim_out = [q_shape [0 ], * keepdim_out ]
2683+ result = result .reshape (tuple (keepdim_out ))
2684+ elif q_was_scalar and result .ndim > 0 and result .shape [0 ] == 1 :
2685+ result = result .squeeze (axis = 0 )
2686+ return result
26782687
26792688
26802689# TODO(jakevdp): interpolation argument deprecated 2024-05-16
0 commit comments