Add Weighted Quantile and Percentile Support to jax.numpy #32737
Conversation
| Summary of Changes: Hello @Aniketsy, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! |
Code Review
This pull request introduces support for weighted quantiles and percentiles, which is a valuable addition. The overall approach is sound, but there are a couple of significant issues in the implementation within _quantile. Firstly, there's a structural issue where the new logic for weighted quantiles is added after the axis variable has been processed, leading to dead code and incorrect behavior for NaN handling when axis is None. Secondly, the implementation for the weighted case appears to be incorrect when q (the quantiles) is a vector. I've provided detailed comments and suggestions for fixes. I also recommend expanding the test suite to cover these edge cases to ensure the feature is robust.
| @jakevdp please review these changes when you get a chance. I'm not sure if I placed the weight handling correctly. Also, not all tests are covered yet; should I add them? Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey - thanks for working on this – I got a bit of the way through and then stopped reviewing because this implementation has fundamental flaws. I'd suggest starting over: write test cases first (including JIT compilation) and make sure they pass before you request review.
```python
# Handle weights
if weights is not None:
    a, weights = promote_dtypes_inexact(a, weights)
    if axis is None:
```
`axis` will never be None here, because of the `if axis is None` check above.
If we need to change the shape of `weights` based on the value of `axis`, that needs to be done before this block.
Force-pushed dafc4c1 to f5d2177
    | Thanks for the earlier suggestions! | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi - it looks like you resolved some merge conflicts in a way that led to your branch being invalid. Please take a look!
I'll review further once these issues are addressed.
Force-pushed 85873ce to f7ab683
Force-pushed 4593d40 to 5f881bf
We're getting there! A number of comments below.
Force-pushed fb246ba to 7b967cb
    | Thank you for the detailed feedback and for pointing out the corrections. I’ve updated the changes as per your suggestions. Please let me know if any further improvements are needed. | 
A few more comments!
Also, regarding the tests, you should be aware of this: numpy/numpy#30069
It looks like the NumPy release has some inconsistent behavior in some cases; if you see test failures in relevant cases, it could be due to this.
        
          
jax/_src/numpy/reductions.py (Outdated)
```python
    Array([1.5, 3. , 4.5], dtype=float32)
  """
  a, q = ensure_arraylike("nanquantile", a, q)
  check_arraylike("nanquantile", a, q)
```
remove
Force-pushed 2ab7977 to 4f522e6
| Thanks! I’ve updated the changes. I also have a small query: should we add a local import inside functions that use  |
Please run tests locally to catch problems in your code: for a change like this relying on github CI will lead to the development cycle being very slow. There is information here: https://docs.jax.dev/en/latest/contributing.html#contributing-code-using-pull-requests
        
          
jax/_src/numpy/reductions.py (Outdated)
```python
resh = [1] * a.ndim
resh[axis] = w_shape[0]
weights = lax.expand_dims(weights, axis)
weights = _broadcast_to(weights, a.shape)
```
Should this be in the `if` block? If `weights.shape` is already equal to `a.shape`, this is not needed.
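The guard this comment suggests could look like the following standalone NumPy sketch (variable names are illustrative, not the PR's): only 1-D weights get expanded and broadcast; weights that already match `a`'s shape pass through untouched.

```python
import numpy as np

a = np.zeros((2, 3))
weights = np.array([1.0, 2.0, 3.0])  # 1-D weights along axis=1
axis = 1

if weights.shape != a.shape:
    # Insert length-1 dims everywhere except `axis`, then broadcast.
    expanded = np.expand_dims(
        weights, tuple(i for i in range(a.ndim) if i != axis))
    weights = np.broadcast_to(expanded, a.shape)
assert weights.shape == a.shape
```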
Force-pushed f3c407d to cef1731
| I've tested locally and I'm getting test failures.  |
I took a very quick look because I've been working on this subject those past days/weeks but I'm not familiar at all with the jax project.
I think this PR is not really aligned with the numpy or scikit-learn implementations (sklearn.utils.stats._weighted_percentile), so I suggest taking a deep look at those before continuing this PR. NumPy's docs are pretty helpful too.
You can also take a look at my PR in array-api-extra: data-apis/array-api-extra#494 ^^ I believe it's the easiest to read, but I'm definitely biased here 😅
        
          
jax/_src/numpy/reductions.py (Outdated)
```python
if method == "linear":
    denom = cw_next - cw_prev
    denom = _where(denom == 0, 1, denom)
    weight = (qi - cw_prev) / denom
    out = val_prev * (1 - weight) + val * weight
elif method == "lower":
    out = val_prev
elif method == "higher":
    out = val
elif method == "nearest":
    out = _where(lax.abs(qi - cw_prev) < lax.abs(qi - cw_next), val_prev, val)
elif method == "midpoint":
    out = (val_prev + val) / 2
elif method == "inverted_cdf":
    out = val
```
How did you come up with these formulas?
While they seem relatively sound, it's worth noting that NumPy only supports the method "inverted_cdf" with weights. And scikit-learn's internal _weighted_percentile implements "inverted_cdf" and "average_inverted_cdf" (through the parameter average=True/False).
Notably, in your implementation the methods "higher" and "inverted_cdf" are equivalent, but in the non-weighted case in NumPy they aren't. Try this, for instance:

```python
import numpy as np
a = np.random.rand(10)
q = np.linspace(0, 1, num=11)
np.allclose(np.quantile(a, q, method='higher'), np.quantile(a, q, method='inverted_cdf'))
```

So I don't think it's correct for both methods to be equivalent in the weighted case.
```python
elif method == "midpoint":
    result = lax.mul(lax.add(low_value, high_value), lax._const(low_value, 0.5))
elif method == "inverted_cdf":
    result = high_value
```
Same here: this is not consistent with numpy.
        
          
jax/_src/numpy/reductions.py (Outdated)
```python
idx = lax.clamp(0, idx, a_sorted.shape[axis] - 1)
val = jnp.take_along_axis(a_sorted, idx, axis)

idx_prev = lax.clamp(idx - 1, 0, a_sorted.shape[axis] - 1)
```
Given that some weights might be zero, doing idx - 1 might not have the expected effect in some cases (typically, adding 0-weight samples shouldn't affect the output, but with this logic it will).
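The invariant this comment describes can be turned into a test: inserting a zero-weight sample must leave every quantile unchanged. With integer weights this is easy to express through sample repetition in stock NumPy (`np.repeat` drops weight-0 entries), so any weighted implementation can be checked against it. The data below is illustrative:

```python
import numpy as np

a  = np.array([1.0, 2.0, 3.0])
w  = np.array([2, 3, 1])
a2 = np.array([1.0, 1.5, 2.0, 3.0])  # 1.5 added with weight 0
w2 = np.array([2, 0, 3, 1])

for q in (0.2, 0.5, 0.8):
    lhs = np.quantile(np.repeat(a, w), q, method="inverted_cdf")
    rhs = np.quantile(np.repeat(a2, w2), q, method="inverted_cdf")
    assert lhs == rhs  # zero-weight sample must not change the result
```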
| You'll have to figure out why it causes failures and then fix them. |
| Some change you made caused this to fail (did you overwrite the local variable  These  |
| Sure, I’ll fix these soon. | 
Force-pushed dcbc485 to 7d50a32
| I've fixed the AttributeErrors but am still getting some errors. I'm working on them; please correct me if these changes need improvement. Thanks!  |
| It seems like most of the errors are related to incorrect shapes when the leading dimension is  | 
| Just to confirm: do I need to fix these, as we already have this in our code? Also, could you please provide some insight on this   |
| Yes, if tests are failing you need to fix them. I'm not entirely sure how you're hitting that error case: perhaps you're passing an  |
| All JAX tests are run with  |
        
          
jax/_src/numpy/reductions.py (Outdated)
```python
import numpy as np

import jax
from jax import lax
```
I commented on this previously, but it still applies: you cannot import top-level APIs in source files. For things like `jax.vmap`, use `api.vmap` here. And instead of `jax.lax`, use `jax._src.lax` and its contents.
```python
weights = np.abs(rng(weights_shape, a_dtype)) + 1e-3

def np_fun(a, q, weights):
    return np.quantile(np.array(a), np.array(q), axis=axis,
                       weights=np.array(weights), method=method, keepdims=keepdims)
```
These are already numpy arrays; no need to pass them through np.array
Force-pushed 0a9de57 to 8e9c8d6
    | I made a few additional modifications in this file apart from the parts where I was implementing weights, mainly to address some attribute errors. A few tests are still failing, but before making further updates, I wanted to check if these extra changes are acceptable or if I should revert them. Thank you! | 
| These are some of the test failures I’m currently working on. Could you please provide some insights into them?  | 
| The  |
| I’ve pushed the latest changes, these failures are from the latest commit. I’ll go through everything again to see what I might be missing, and please feel free to guide me if you notice something. | 
        
          
jax/_src/numpy/reductions.py (Outdated)
```diff
     return lax.transpose(a, perm)

-def _upcast_f16(dtype: DTypeLike) -> DType:
+def _upcast_f16(dtype: DTypeLike) -> DTypeLike:
```
This should stay `-> DType`, because it returns a dtype.
It looks like a number of unrelated changes snuck in here (possibly due to a bad resolution of a merge conflict)
Please revert all unrelated changes, and then we can continue debugging your code.
Force-pushed 4d0dfd4 to ca1d95b
    | I’ve reverted the changes. | 
Force-pushed 4b318a4 to 4ebbf21
Force-pushed 19ad7be to 43cdb5c
| @Aniketsy I wonder if you saw this comment – it's important. So far my reviews have not focused on the correctness of the implementation, as well-crafted tests will generally tell you whether you've gotten that right. But it's taken a lot of iteration to get to the point where the tests can even be run. |
| We've been at this for two weeks, and your responses to my reviews are generally two steps forward, one step back. Your commits keep referencing old APIs that no longer exist in the package (most recently,  ). This wastes my time and yours – I think it would be best to not continue with this change. |
#32647
This PR adds support for weighted quantile and percentile computations to jax.numpy.quantile, jax.numpy.nanquantile, jax.numpy.percentile, and jax.numpy.nanpercentile.
Please let me know if my approach or fix needs any improvements. I'm open to feedback and happy to make changes based on suggestions.
Thank you!
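For readers skimming the thread, the semantics the PR targets (mirroring NumPy 2.x, where `weights` is accepted only with method="inverted_cdf") can be illustrated without the new JAX code, using stock NumPy and the weight-as-repetition equivalence. The data and the `jnp.quantile` call shown in the comment are illustrative:

```python
import numpy as np

a = np.array([10.0, 20.0, 30.0, 40.0])
w = np.array([1, 1, 1, 3])  # integer weights act like repeated samples

# Proposed call (per this PR): jnp.quantile(a, 0.5, weights=w, method="inverted_cdf")
# Reference result via repetition with stock NumPy:
out = np.quantile(np.repeat(a, w), 0.5, method="inverted_cdf")
assert out == 30.0
```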