
Conversation

@Aniketsy
Contributor

#32647

This PR adds support for weighted quantile and percentile computations to jax.numpy.quantile, jax.numpy.nanquantile, jax.numpy.percentile, and jax.numpy.nanpercentile.

Please let me know if my approach or fix needs any improvements. I'm open to feedback and happy to make changes based on suggestions.
Thank you!

@gemini-code-assist
Contributor

Summary of Changes

Hello @Aniketsy, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the jax.numpy module by introducing support for weighted quantile and percentile calculations. By integrating an optional weights argument into the existing quantile, nanquantile, percentile, and nanpercentile functions, users can now perform more sophisticated statistical analyses where individual data points contribute differently to the overall distribution. The change involves modifying function signatures and implementing the core weighted logic within the internal _quantile function, ensuring accurate results across various interpolation methods and proper handling of NaN values.

Highlights

  • Weighted Quantile/Percentile Support: The jax.numpy.quantile, jax.numpy.nanquantile, jax.numpy.percentile, and jax.numpy.nanpercentile functions now accept an optional weights parameter, enabling weighted computations.
  • Core Weighted Logic Implementation: The internal _quantile function has been extended to handle weighted inputs, including sorting data and weights, calculating cumulative normalized weights, and determining quantile values based on various interpolation methods.
  • NaN Handling with Weights: The weighted quantile logic correctly handles NaN values by masking them and adjusting weights accordingly when squash_nans is enabled.
  • New Test Case: A new unit test, test_weighted_quantile_linear, has been added to tests/lax_numpy_reducers_test.py to validate the correctness of the weighted quantile calculation.
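
For context, the cumulative-normalized-weights approach named above can be sketched in a few lines of plain NumPy (a minimal 1-D illustration of the inverted-CDF rule, not the PR's actual JAX implementation; the function name is hypothetical):

```python
import numpy as np

def weighted_quantile_1d(a, q, weights):
    """Toy weighted quantile via the inverted-CDF rule (illustration only)."""
    order = np.argsort(a)
    a_sorted = a[order]
    w_sorted = weights[order]
    # Cumulative weights, normalized so the last entry equals 1.
    cw = np.cumsum(w_sorted) / np.sum(w_sorted)
    # Smallest index whose normalized cumulative weight reaches q.
    idx = np.searchsorted(cw, q, side="left")
    return a_sorted[min(idx, a_sorted.size - 1)]
```

With uniform weights this reduces to the unweighted "inverted_cdf" method.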
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces support for weighted quantiles and percentiles, which is a valuable addition. The overall approach is sound, but there are a couple of significant issues in the implementation within _quantile. Firstly, there's a structural issue where the new logic for weighted quantiles is added after the axis variable has been processed, leading to dead code and incorrect behavior for NaN handling when axis is None. Secondly, the implementation for the weighted case appears to be incorrect when q (the quantiles) is a vector. I've provided detailed comments and suggestions for fixes. I also recommend expanding the test suite to cover these edge cases to ensure the feature is robust.

@Aniketsy
Contributor Author

@jakevdp please review these changes when you get a chance. I'm not sure if I've placed the weights handling correctly; also, not all tests are covered yet. Should I add more? Thanks!

Collaborator

@jakevdp left a comment


Hey - thanks for working on this – I got a bit of the way through and then stopped reviewing because this implementation has fundamental flaws. I'd suggest starting over, writing test cases first (including JIT compilation), and making sure they pass before you request review.

# Handle weights
if weights is not None:
  a, weights = promote_dtypes_inexact(a, weights)
  if axis is None:
Collaborator


Axis will never be None here, because of the if axis is None check above.

If we need to change the shape of weights based on the value of axis, that needs to be done before this block.

@Aniketsy
Contributor Author

Thanks for the earlier suggestions!
I've added additional tests and re-implemented the changes; all tests are now passing locally.
However, I'm unsure if I've fully addressed this point: "a has already been promoted; we should promote only once. That requires making the previous promote_dtypes_inexact conditional on whether weights is supplied."
When I tried to apply that change, it led to some test failures, so I'll look more into it.
Please feel free to guide me if I'm missing something or going in the wrong direction.

Collaborator

@jakevdp left a comment


Hi - it looks like you resolved some merge conflicts in a way that led to your branch being invalid. Please take a look!

I'll review further once these issues are addressed.

Collaborator

@jakevdp left a comment


We're getting there! A number of comments below.

@Aniketsy
Contributor Author

Thank you for the detailed feedback and for pointing out the corrections. I’ve updated the changes as per your suggestions. Please let me know if any further improvements are needed.

Collaborator

@jakevdp left a comment


A few more comments!

Also, regarding the tests, you should be aware of this: numpy/numpy#30069

It looks like the NumPy release has some inconsistent behavior in some cases; if you see test failures in relevant cases, it could be due to this.

Array([1.5, 3. , 4.5], dtype=float32)
"""
a, q = ensure_arraylike("nanquantile", a, q)
check_arraylike("nanquantile", a, q)
Collaborator


remove

@Aniketsy
Contributor Author

Thanks! I've updated the changes. I also have a small query: should we add a local import inside functions that use jnp?

Copy link
Collaborator

@jakevdp left a comment


Please run tests locally to catch problems in your code: for a change like this relying on github CI will lead to the development cycle being very slow. There is information here: https://docs.jax.dev/en/latest/contributing.html#contributing-code-using-pull-requests

resh = [1] * a.ndim
resh[axis] = w_shape[0]
weights = lax.expand_dims(weights, axis)
weights = _broadcast_to(weights, a.shape)
Collaborator


Should this be in the if block? If weights.shape is already equal to a.shape, this is not needed.
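
The guarded version the reviewer is suggesting might look roughly like this (a hypothetical NumPy sketch standing in for the lax.expand_dims/_broadcast_to calls in the diff):

```python
import numpy as np

def broadcast_weights(weights, a, axis):
    # If the shapes already match, no reshape/broadcast is needed.
    if weights.shape == a.shape:
        return weights
    # Otherwise assume weights is 1-D along `axis` and broadcast it up.
    shape = [1] * a.ndim
    shape[axis] = weights.shape[0]
    return np.broadcast_to(weights.reshape(shape), a.shape)
```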

@Aniketsy
Contributor Author

I've tested locally and am getting test failures: an AttributeError on lax.count, and an interpolation deprecation error. Please let me know if I should update these.


@cakedev0 left a comment


I took a very quick look because I've been working on this subject these past days/weeks, but I'm not familiar at all with the jax project.

I think this PR is not really aligned with the numpy or scikit-learn implementations (sklearn.utils.stats._weighted_percentile), so I suggest taking a deep look at those before continuing this PR. NumPy's docs are pretty helpful too.
You can also take a look at my PR in array-api-extra: data-apis/array-api-extra#494 ^^ I believe it's the easiest to read, but I'm definitely biased here 😅

Comment on lines 2549 to 2557
if method == "linear":
denom = cw_next - cw_prev
denom = _where(denom == 0, 1, denom)
weight = (qi - cw_prev) / denom
out = val_prev * (1 - weight) + val * weight
elif method == "lower":
out = val_prev
elif method == "higher":
out = val
elif method == "nearest":
out = _where(lax.abs(qi - cw_prev) < lax.abs(qi - cw_next), val_prev, val)
elif method == "midpoint":
out = (val_prev + val) / 2
elif method == "inverted_cdf":
out = val


How did you come up with these formulas?

While they seem relatively sound, it's worth noting that NumPy only supports the "inverted_cdf" method with weights. And scikit-learn's internal _weighted_percentile implements "inverted_cdf" and "average_inverted_cdf" (through the parameter average=True/False).

Typically, in your implementation I see that the methods "higher" and "inverted_cdf" are equivalent, but in the non-weighted case in numpy they aren't; try this for instance:

import numpy as np
a = np.random.rand(10)
q = np.linspace(0, 1, num=11)
np.allclose(np.quantile(a, q, method='higher'), np.quantile(a, q, method='inverted_cdf'))  # False: the methods disagree

So I don't think it's correct for both methods to be equivalent in the weighted case.
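
To make the distinction concrete, here is an unweighted case where the two methods disagree (assuming NumPy ≥ 1.22 for the method argument):

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0, 4.0])
# "higher" rounds the linear-interpolation index up: position 0.5*(4-1)=1.5 -> a[2]
print(np.quantile(a, 0.5, method="higher"))        # 3.0
# "inverted_cdf" takes the smallest value whose ECDF reaches 0.5 -> a[1]
print(np.quantile(a, 0.5, method="inverted_cdf"))  # 2.0
```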

elif method == "midpoint":
result = lax.mul(lax.add(low_value, high_value), lax._const(low_value, 0.5))
elif method == "inverted_cdf":
result = high_value


Same here: this is not consistent with numpy.

idx = lax.clamp(0, idx, a_sorted.shape[axis] - 1)
val = jnp.take_along_axis(a_sorted, idx, axis)

idx_prev = lax.clamp(idx - 1, 0, a_sorted.shape[axis] - 1)


Given that some weights might be zero, doing idx - 1 might not have the expected effect in some cases (typically, adding 0-weight samples shouldn't affect the output, but with this logic it will).
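
A self-contained toy check of the invariance described here (a minimal inverted-CDF sketch, not the PR's code) shows what a correct implementation must satisfy:

```python
import numpy as np

def inverted_cdf_weighted(a, q, w):
    # Minimal weighted inverted-CDF quantile, for illustration only.
    order = np.argsort(a)
    a, w = a[order], w[order]
    cw = np.cumsum(w) / np.sum(w)
    return a[np.searchsorted(cw, q, side="left")]

a1, w1 = np.array([1.0, 2.0, 3.0]), np.array([1.0, 1.0, 1.0])
a2 = np.array([1.0, 1.5, 2.0, 3.0])  # extra sample ...
w2 = np.array([1.0, 0.0, 1.0, 1.0])  # ... with zero weight
# A zero-weight sample must not change the result, but an idx-1 style
# "previous value" lookup would land on the 0-weight 1.5 here.
assert inverted_cdf_weighted(a1, 0.5, w1) == inverted_cdf_weighted(a2, 0.5, w2)
```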

@jakevdp
Collaborator

jakevdp commented Oct 26, 2025

> Using from jax._src.lax import lax causes all tests to fail.

You'll have to figure out why it causes failures and then fix them.

> Also, _const was not applied by me.

Some change you made caused this to fail (did you overwrite the local variable lax?). In order to proceed, you'll have to debug this and fix the issue.

These AttributeErrors are the easiest issues you're going to encounter: we haven't even gotten to the hard stuff yet, which is making sure outputs produced by your code match those produced by NumPy.

@Aniketsy
Contributor Author

Sure, I’ll fix these soon.

@Aniketsy
Contributor Author

I've fixed the AttributeErrors but am still getting some failures; I'm working on them. Please correct me if these changes need improvement. Thanks!

====================== short test summary info =======================
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testMedian0 - AssertionError: Tuples differ: (1, 1) != (1,)
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testMedian1 - AssertionError: Tuples differ: (1, 1, 1) != (1, 1)
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testMedian2 - AssertionError: Tuples differ: (1, 1, 1) != (1, 1)
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testMedian4 - AssertionError: Tuples differ: (1, 1, 1) != (1, 1)
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testQuantile4 - AssertionError: Tuples differ: (4,) != (1, 4)
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testQuantile5 - AssertionError: Tuples differ: (1, 7) != (1, 1, 7)
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testQuantile7 - AssertionError: Tuples differ: (1, 4, 1, 1) != (4, 1, 1)       
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testQuantileDeprecatedArgs0 - TypeError: nanquantile() argument interpolation was removed in JAX...
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testQuantileDeprecatedArgs2 - TypeError: quantile() argument interpolation was removed in JAX v0...
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testWeightedQuantile0 - ValueError: Quantiles must be in the range [0, 1]      
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testWeightedQuantile1 - jax._src.dtypes.TypePromotionError: Input dtypes ('int8', 'float32...
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testWeightedQuantile2 - jax._src.dtypes.TypePromotionError: Input dtypes ('int8', 'float32...
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testWeightedQuantile3 - jax._src.dtypes.TypePromotionError: 

@jakevdp
Collaborator

jakevdp commented Oct 27, 2025

It seems like most of the errors are related to incorrect shapes when the leading dimension is 1. I don't totally follow the reasoning behind all the nested if statements at the end of the main function when you reshape the result array, but I'm pretty sure that's where the issue lies.

@Aniketsy
Contributor Author

FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testQuantileDeprecatedArgs0 - TypeError: nanquantile() argument interpolation was removed in JAX...
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testQuantileDeprecatedArgs2 - TypeError: quantile() argument interpolation was removed in JAX v0...

Just to confirm: do I need to fix these? We already have this in our code:

if not isinstance(interpolation, DeprecatedArg):
    raise TypeError("nanquantile() argument interpolation was removed in JAX"
                    " v0.8.0. Use method instead.")

Also, could you please provide some insight on this TypePromotionError? Thanks!

FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testWeightedQuantile0 - ValueError: Quantiles must be in the range [0, 1]      
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testWeightedQuantile1 - jax._src.dtypes.TypePromotionError: Input dtypes ('int8', 'float32...
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testWeightedQuantile2 - jax._src.dtypes.TypePromotionError: Input dtypes ('int8', 'float32...
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::test_weighted_quantile_all_weights_zero - AssertionError: ValueError not raised
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::test_weighted_quantile_negative_weights - AssertionError: ValueError not raised

@jakevdp
Collaborator

jakevdp commented Oct 27, 2025

> Just to confirm: do I need to fix these? We already have this in our code.

Yes, if tests are failing you need to fix them. I'm not entirely sure how you're hitting that error case: perhaps you're passing an interpolation argument into the function somewhere?

> Also, could you please provide some insight on this TypePromotionError? Thanks!

All JAX tests are run with jax_numpy_dtype_promotion='strict', so implementations within JAX cannot rely on implicit dtype promotion. The reason for this is that if we used implicit promotion within JAX's core, then jax_numpy_dtype_promotion='strict' would be basically useless.
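
Under strict promotion, the fix is to convert explicitly before mixing dtypes. A small sketch (assuming a recent JAX where jax.numpy_dtype_promotion is available as a context manager):

```python
import jax
import jax.numpy as jnp

with jax.numpy_dtype_promotion('strict'):
    a = jnp.arange(4, dtype=jnp.int8)
    w = jnp.ones(4, dtype=jnp.float32)
    # `a * w` would raise TypePromotionError here; convert explicitly instead:
    out = a.astype(w.dtype) * w

print(out.dtype)  # float32
```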

import numpy as np

import jax
from jax import lax
Collaborator


I commented on this previously, but it still applies: you cannot import top-level APIs in source files. For things like jax.vmap, use api.vmap here. And instead of jax.lax, use jax._src.lax and its contents.

weights = np.abs(rng(weights_shape, a_dtype)) + 1e-3

def np_fun(a, q, weights):
  return np.quantile(np.array(a), np.array(q), axis=axis, weights=np.array(weights), method=method, keepdims=keepdims)
Collaborator


These are already numpy arrays; no need to pass them through np.array

@Aniketsy
Contributor Author

I made a few additional modifications in this file apart from the parts where I was implementing weights, mainly to address some attribute errors. A few tests are still failing, but before making further updates, I wanted to check if these extra changes are acceptable or if I should revert them. Thank you!

@Aniketsy
Contributor Author

These are some of the test failures I'm currently working on. Could you please provide some insight into them?
I also apologize for bothering you repeatedly with these test failures; I really appreciate your guidance. Thanks!

========================== short test summary info ========================== 
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testQuantileDeprecatedArgs0 - TypeError: nanquantile() argument interpolation was removed in JAX v0.8.0...
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testQuantileDeprecatedArgs2 - TypeError: quantile() argument interpolation was removed in JAX v0.8.0. U...
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testWeightedQuantile0 - jax.errors.TracerBoolConversionError: Attempted boolean conversion of tra...
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testWeightedQuantile1 - jax.errors.TracerBoolConversionError: Attempted boolean conversion of tra...
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testWeightedQuantile2 - ValueError: Length of weights not compatible with specified axis.    
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testWeightedQuantile3 - UserWarning: Explicitly requested dtype float64 requested in astype is no...
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testWeightedQuantile6 - ValueError: Length of weights not compatible with specified axis.    
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testWeightedQuantile7 - jax.errors.TracerBoolConversionError: Attempted boolean conversion of tra...
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testWeightedQuantile8 - jax.errors.TracerBoolConversionError: Attempted boolean conversion of tra...
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::testWeightedQuantile9 - UserWarning: Explicitly requested dtype float64 requested in astype is no...
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::test_weighted_quantile_all_weights_zero - jax.errors.TracerBoolConversionError: Attempted boolean conversion of tra...
FAILED tests/lax_numpy_reducers_test.py::JaxNumpyReducerTests::test_weighted_quantile_negative_weights - jax.errors.TracerBoolConversionError: Attempted boolean conversion of tra...

@jakevdp
Collaborator

jakevdp commented Oct 29, 2025

> These are some of the test failures I'm currently working on.

The interpolation failures are confusing to me. I don't see anywhere in the current changes where this would arise – do you have local changes that you haven't pushed?

@Aniketsy
Contributor Author

I’ve pushed the latest changes, these failures are from the latest commit. I’ll go through everything again to see what I might be missing, and please feel free to guide me if you notice something.

return lax.transpose(a, perm)

def _upcast_f16(dtype: DTypeLike) -> DType:
def _upcast_f16(dtype: DTypeLike) -> DTypeLike:
Collaborator


This should stay -> DType, because it returns a dtype.

Collaborator

@jakevdp left a comment


It looks like a number of unrelated changes snuck in here (possibly due to a bad resolution of a merge conflict)

Please revert all unrelated changes, and then we can continue debugging your code.

@Aniketsy
Contributor Author

I’ve reverted the changes.

@jakevdp
Collaborator

jakevdp commented Oct 31, 2025

> I took a very quick look because I've been working on this subject these past days/weeks, but I'm not familiar at all with the jax project.
>
> I think this PR is not really aligned with the numpy or scikit-learn implementations (sklearn.utils.stats._weighted_percentile), so I suggest taking a deep look at those before continuing this PR. NumPy's docs are pretty helpful too. You can also take a look at my PR in array-api-extra: data-apis/array-api-extra#494 ^^ I believe it's the easiest to read, but I'm definitely biased here 😅

@Aniketsy I wonder if you saw this comment – it's important. So far my reviews have not focused on the correctness of the implementation, as well-crafted tests will generally tell you whether you've gotten that right. But it's taken a lot of iteration to get to the point where the tests can even be run.

@jakevdp
Collaborator

jakevdp commented Oct 31, 2025

We've been at this for two weeks, and your responses to my reviews are generally two steps forward, one step back. Your commits keep referencing old APIs that no longer exist in the package (most recently, jax._src.numpy.util._complex_elem_type). This makes me think you're using an LLM and blindly pushing the commits without testing locally.

This wastes my time and yours – I think it would be best to not continue with this change.

@jakevdp jakevdp closed this Oct 31, 2025