-
Notifications
You must be signed in to change notification settings - Fork 16
ENH: add quantile function with weights support
#494
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 23 commits
a6f6c93
d30bcbf
dc236da
f92fc4b
06e370a
dc7a1e5
98fe39f
034c064
05ffb7b
89d8410
19fa6ea
fa789fc
1d8fef7
26804fe
3611708
7160bae
0b2cb9b
3226659
c395b84
07f7007
e319529
8ab7d62
1b48267
c71351f
ce55335
0353767
7c18a82
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,7 +4,8 @@ | |
| from types import ModuleType | ||
| from typing import Literal | ||
|
|
||
| from ._lib import _funcs | ||
| from ._lib import _funcs, _quantile | ||
| from ._lib._backends import NUMPY_VERSION | ||
| from ._lib._utils._compat import ( | ||
| array_namespace, | ||
| is_cupy_namespace, | ||
|
|
@@ -768,7 +769,7 @@ def argpartition( | |
| Axis along which to partition. The default is ``-1`` (the last axis). | ||
| If ``None``, the flattened array is used. | ||
| xp : array_namespace, optional | ||
| The standard-compatible namespace for `x`. Default: infer. | ||
| The standard-compatible namespace for `a`. Default: infer. | ||
|
|
||
| Returns | ||
| ------- | ||
|
|
@@ -895,3 +896,273 @@ def isin( | |
| return xp.isin(a, b, assume_unique=assume_unique, invert=invert) | ||
|
|
||
| return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp) | ||
|
|
||
|
|
||
def quantile(
    a: Array,
    q: float | Array,
    /,
    axis: int | None = None,
    method: str = "linear",
    keepdims: bool = False,
    nan_policy: str = "propagate",
    *,
    weights: Array | None = None,
    xp: ModuleType | None = None,
) -> Array:
    """
    Compute the q-th quantile of the data along the specified axis.

    Parameters
    ----------
    a : array_like of real numbers
        Input array or object that can be converted to an array.
        Must be at least 1-dimensional.
    q : array_like of float
        Probability or sequence of probabilities of the quantiles to compute.
        Values must be between 0 and 1 inclusive.
    axis : int or None, optional
        Axis along which the quantiles are computed. The default is
        to compute the quantile(s) along a flattened version of the array.
    method : str, optional
        This parameter specifies the method to use for estimating the
        quantile. There are many different methods.
        The recommended options, numbered as they appear in [1]_, are:

        1. 'inverted_cdf'
        2. 'averaged_inverted_cdf'
        3. 'closest_observation'
        4. 'interpolated_inverted_cdf'
        5. 'hazen'
        6. 'weibull'
        7. 'linear' (default)
        8. 'median_unbiased'
        9. 'normal_unbiased'

        The first three methods are discontinuous.
        Only 'linear', 'inverted_cdf' and 'averaged_inverted_cdf' are implemented.

    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left in
        the result as dimensions with size one. With this option, the
        result will broadcast correctly against the original array `a`.

    nan_policy : str, optional
        'propagate' (default) or 'omit'.
        'omit' is supported only when `weights` are provided.

    weights : array_like, optional
        An array of weights associated with the values in `a`. Each value in
        `a` contributes to the quantile according to its associated weight.
        The weights array can either be 1-D (in which case its length must be
        the size of `a` along the given axis) or of the same shape as `a`.
        If `weights=None`, then all data in `a` are assumed to have a
        weight equal to one.
        Only `method="inverted_cdf"` or `method="averaged_inverted_cdf"`
        support weights, and `a` must then be 1- or 2-dimensional.
        See the notes for more details.

    xp : array_namespace, optional
        The standard-compatible namespace for `a` and `q`. Default: infer.

    Returns
    -------
    scalar or ndarray
        If `q` is a single probability and `axis=None`, then the result
        is a scalar. If multiple probability levels are given, first axis
        of the result corresponds to the quantiles. The other axes are
        the axes that remain after the reduction of `a`. If the input
        contains integers or floats smaller than ``float64``, the output
        data-type is ``float64``. Otherwise, the output data-type is the
        same as that of the input.

    Raises
    ------
    ValueError
        If the backend is sparse, if `method` or `nan_policy` is not one of
        the supported values, if `a`, `q` or `weights` does not have a real
        dtype, if `axis` is out of bounds, if any value of `q` lies outside
        ``[0, 1]`` or is NaN, or if `weights` is incompatible with `a`,
        `axis`, `method` or `nan_policy`.
    TypeError
        If `a` is 0-dimensional, or if `axis` is None while the shapes of
        `a` and `weights` differ.

    Notes
    -----
    Given a sample `a` from an underlying distribution, `quantile` provides a
    nonparametric estimate of the inverse cumulative distribution function.

    By default, this is done by interpolating between adjacent elements in
    ``y``, a sorted copy of `a`::

        (1-g)*y[j] + g*y[j+1]

    where the index ``j`` and coefficient ``g`` are the integral and
    fractional components of ``q * (n-1)``, and ``n`` is the number of
    elements in the sample.

    This is a special case of Equation 1 of H&F [1]_. More generally,

    - ``j = (q*n + m - 1) // 1``, and
    - ``g = (q*n + m - 1) % 1``,

    where ``m`` may be defined according to several different conventions.
    The preferred convention may be selected using the ``method`` parameter:

    =============================== =============== ===============
    ``method``                      number in H&F   ``m``
    =============================== =============== ===============
    ``interpolated_inverted_cdf``   4               ``0``
    ``hazen``                       5               ``1/2``
    ``weibull``                     6               ``q``
    ``linear`` (default)            7               ``1 - q``
    ``median_unbiased``             8               ``q/3 + 1/3``
    ``normal_unbiased``             9               ``q/4 + 3/8``
    =============================== =============== ===============

    Note that indices ``j`` and ``j + 1`` are clipped to the range ``0`` to
    ``n - 1`` when the results of the formula would be outside the allowed
    range of non-negative indices. The ``- 1`` in the formulas for ``j`` and
    ``g`` accounts for Python's 0-based indexing.

    The table above includes only the estimators from H&F that are continuous
    functions of probability `q` (estimators 4-9). NumPy also provides the
    three discontinuous estimators from H&F (estimators 1-3), where ``j`` is
    defined as above, ``m`` is defined as follows, and ``g`` is a function
    of the real-valued ``index = q*n + m - 1`` and ``j``.

    1. ``inverted_cdf``: ``m = 0`` and ``g = int(index - j > 0)``
    2. ``averaged_inverted_cdf``: ``m = 0`` and
       ``g = (1 + int(index - j > 0)) / 2``
    3. ``closest_observation``: ``m = -1/2`` and
       ``g = 1 - int((index == j) & (j%2 == 1))``

    **Weighted quantiles:**
    More formally, the quantile at probability level :math:`q` of a cumulative
    distribution function :math:`F(y)=P(Y \\leq y)` with probability measure
    :math:`P` is defined as any number :math:`x` that fulfills the
    *coverage conditions*

    .. math:: P(Y < x) \\leq q \\quad\\text{and}\\quad P(Y \\leq x) \\geq q

    with random variable :math:`Y\\sim P`.
    Sample quantiles, the result of `quantile`, provide nonparametric
    estimation of the underlying population counterparts, represented by the
    unknown :math:`F`, given a data vector `a` of length ``n``.

    Some of the estimators above arise when one considers :math:`F` as the
    empirical distribution function of the data, i.e.
    :math:`F(y) = \\frac{1}{n} \\sum_i 1_{a_i \\leq y}`.
    Then, different methods correspond to different choices of :math:`x` that
    fulfill the above coverage conditions. Methods that follow this approach
    are ``inverted_cdf`` and ``averaged_inverted_cdf``.

    For weighted quantiles, the coverage conditions still hold. The
    empirical cumulative distribution is simply replaced by its weighted
    version, i.e.
    :math:`P(Y \\leq t) = \\frac{1}{\\sum_i w_i} \\sum_i w_i 1_{x_i \\leq t}`.

    References
    ----------
    .. [1] R. J. Hyndman and Y. Fan,
       "Sample quantiles in statistical packages,"
       The American Statistician, 50(4), pp. 361-365, 1996
    """
    if xp is None:
        xp = array_namespace(a)
    if is_pydata_sparse_namespace(xp):
        msg = "Sparse backend not supported"
        raise ValueError(msg)

    # Tuples rather than sets: the options are interpolated into the error
    # messages, and the iteration order of a set of strings changes between
    # interpreter runs (hash randomization), which made the text unstable.
    methods = ("linear", "inverted_cdf", "averaged_inverted_cdf")
    if method not in methods:
        msg = f"`method` must be one of {methods}"
        raise ValueError(msg)
    nan_policies = ("propagate", "omit")
    if nan_policy not in nan_policies:
        msg = f"`nan_policy` must be one of {nan_policies}"
        raise ValueError(msg)

    a = xp.asarray(a)
    if not xp.isdtype(a.dtype, ("integral", "real floating")):
        msg = "`a` must have real dtype."
        raise ValueError(msg)
    # Converted once; reused for the dtype check and the result-type computation.
    q_in = xp.asarray(q)
    if not xp.isdtype(q_in.dtype, "real floating"):
        msg = "`q` must have real floating dtype."
        raise ValueError(msg)
    weights = None if weights is None else xp.asarray(weights)

    ndim = a.ndim
    if ndim < 1:
        msg = "`a` must be at least 1-dimensional."
        raise TypeError(msg)
    if axis is not None and ((axis >= ndim) or (axis < -ndim)):
        msg = "`axis` is not compatible with the dimension of `a`."
        raise ValueError(msg)
    if weights is None:
        if nan_policy != "propagate":
            msg = "When `weights` aren't provided, `nan_policy` must be 'propagate'."
            raise ValueError(msg)
    else:
        if method not in {"inverted_cdf", "averaged_inverted_cdf"}:
            msg = f"`method` '{method}' not supported with weights."
            raise ValueError(msg)
        if not xp.isdtype(weights.dtype, ("integral", "real floating")):
            msg = "`weights` must have real dtype."
            raise ValueError(msg)
        if ndim > 2:
            msg = "When weights are provided, dimension of `a` must be 1 or 2."
            raise ValueError(msg)
        if a.shape != weights.shape:
            if axis is None:
                msg = "Axis must be specified when shapes of `a` and `weights` differ."
                raise TypeError(msg)
            if weights.shape != eager_shape(a, axis):
                msg = (
                    "Shape of weights must be consistent with shape"
                    " of a along specified axis."
                )
                raise ValueError(msg)
        if axis is None and ndim == 2:
            msg = "Axis must be specified when `a` and `weights` are 2d."
            raise ValueError(msg)

    # Align result dtype with what numpy does:
    dtype = xp.result_type(
        xp.float64 if xp.isdtype(a.dtype, "integral") else a,
        q_in,
        xp.float64,  # at least float64
    )
    device = get_device(a)
    a = xp.asarray(a, dtype=dtype, device=device)
    q_arr = xp.asarray(q, dtype=dtype, device=device)
    # TODO: cast weights here? Assert weights are on the same device as `a`?

    if xp.any((q_arr > 1) | (q_arr < 0) | xp.isnan(q_arr)):
        msg = "`q` values must be in the range [0, 1]"
        raise ValueError(msg)

    # Delegate when possible.
    # Note: No delegation for dask: I couldn't make it work.
    #
    # `weights is None` implies `nan_policy == "propagate"` (enforced by the
    # validation above), so every unweighted delegation below can call the
    # backend's plain `quantile`; no `nanquantile` variant is ever needed.
    basic_case = method == "linear" and weights is None

    np_2 = NUMPY_VERSION >= (2, 0)
    np_handles_weights = np_2 and nan_policy == "propagate" and method == "inverted_cdf"
    if is_numpy_namespace(xp):
        if weights is None:
            if basic_case or np_2:
                return xp.quantile(
                    a, q_arr, axis=axis, method=method, keepdims=keepdims
                )
        elif np_handles_weights:
            # TODO: call nanquantile for nan_policy == "omit" once
            # https://github.com/numpy/numpy/issues/29709 is fixed
            return xp.quantile(
                a, q_arr, axis=axis, method=method, keepdims=keepdims, weights=weights
            )

    if basic_case and (is_jax_namespace(xp) or is_cupy_namespace(xp)):
        return xp.quantile(a, q_arr, axis=axis, method=method, keepdims=keepdims)
    if basic_case and is_torch_namespace(xp):
        # torch spells the keyword arguments differently from numpy.
        return xp.quantile(
            a, q_arr, dim=axis, interpolation=method, keepdim=keepdims
        )

    # Otherwise call our implementation (will sort data).
    # XXX: I'm not sure we want to support dask, it seems utterly slow...
    return _quantile.quantile(
        a,
        q_arr,
        axis=axis,
        method=method,
        keepdims=keepdims,
        nan_policy=nan_policy,
        weights=weights,
        xp=xp,
    )
Uh oh!
There was an error while loading. Please reload this page.