Add quantile op #287
Please follow the instructions in section 2.1 of CONTRIBUTING.md and run pre-commit.
Problems: #341
I noticed that the test case failures occur only when the interpolation mode is set to either
Reproducible Example
Test Input:
# Sorted tensor with values from 0 to 255, shape [1, 256]
inp = torch.arange(256, dtype=torch.float32).reshape(1, 256)
q = torch.tensor([0.2], dtype=torch.float32)
dim = 1
keep_dim = False  # or True
Expected result:
Actual Result:
The issue likely stems from incorrect handling of
Please fix the
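For readers following the interpolation discussion, here is a minimal pure-Python sketch of how the interpolation modes documented for torch.quantile select a value from sorted data. The function name `quantile_1d` is hypothetical and is not part of this PR, and the tie-breaking used for "nearest" (Python's `round`) is an assumption of the sketch.

```python
import math


def quantile_1d(sorted_vals, q, interpolation="linear"):
    """Quantile of an already-sorted list (illustrative sketch only)."""
    n = len(sorted_vals)
    pos = q * (n - 1)        # fractional index into the sorted data
    lo = math.floor(pos)     # index of the lower neighbour
    hi = math.ceil(pos)      # index of the upper neighbour
    a, b = sorted_vals[lo], sorted_vals[hi]
    if interpolation == "linear":
        return a + (b - a) * (pos - lo)
    if interpolation == "lower":
        return a
    if interpolation == "higher":
        return b
    if interpolation == "midpoint":
        return (a + b) / 2
    if interpolation == "nearest":
        # Assumption: exact ties follow Python's round().
        return sorted_vals[round(pos)]
    raise ValueError(f"unknown interpolation mode: {interpolation}")


# Same data as the reproducible example: values 0..255, q = 0.2.
data = [float(i) for i in range(256)]
modes = ["linear", "lower", "higher", "midpoint", "nearest"]
results = {m: quantile_1d(data, 0.2, m) for m in modes}
print(results)
```

In this sketch the position `0.2 * 255` evaluates to exactly 51.0 in double precision, so every mode returns 51.0 for this input. Any disagreement between modes would therefore have to come from how the position is computed or rounded.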
I’ve identified why the quantile kernel test fails in our case. The issue arises because the reference tensor's dtype is upcast to
To resolve this, update the test case as follows:
ref_inp = to_reference(inp)
ref_q = to_reference(q)
This ensures consistent dtype and resolves the discrepancy.
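As an illustration of why dtype consistency can matter here, the sketch below shows one plausible mechanism, not a confirmed diagnosis of this PR. It assumes the upcast is to float64 (the comment above does not say) and emulates float32 rounding with `struct`; the helper `to_f32` is hypothetical.

```python
import math
import struct


def to_f32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


n = 256
q32 = to_f32(0.2)  # 0.2 stored as float32 is slightly above 0.2

# Position computed in float32: the product rounds back to 51.0.
pos_f32 = to_f32(q32 * (n - 1))

# Same float32 q, upcast and multiplied in float64: the small excess survives.
pos_f64 = q32 * (n - 1)

print(pos_f32, math.floor(pos_f32), math.ceil(pos_f32))
print(pos_f64, math.floor(pos_f64), math.ceil(pos_f64))
```

Under these assumptions the float32 path lands exactly on index 51, while the upcast path lands just above it, so its upper neighbour becomes index 52. Modes that depend on the upper neighbour would then disagree between the two paths even though both started from the same `q`.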
tests/test_general_reduction_ops.py
@pytest.mark.parametrize("interpolation", QUANTILE_INTERPOLATION)
def test_accuracy_quantile_dim(shape, dim, keepdim, dtype, q, interpolation):
    inp = torch.randn(shape, dtype=dtype, device="cuda")
    ref_inp = to_reference(inp, True)
You should change this to use to_reference(inp) for the float32 case to ensure consistency in the dtype. Additionally, consider adding more comprehensive test coverage for various dtypes to ensure the robustness of the implementation.
src/flag_gems/ops/quantile.py
@libentry()
@triton.autotune(configs=cfggen(), key=["N", "M", "Q"])
It would be better to replace autotune with heuristics to improve runtime performance.
Also, please provide benchmark results for the operator.
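To make the suggestion concrete, here is a minimal sketch of the kind of functions that `triton.heuristics` accepts: each takes the kernel's argument dictionary and returns a value for one meta-parameter. The names `BLOCK_Q` and `BLOCK_N`, the caps, and the rules themselves are hypothetical and are not taken from this PR. The functions are written in plain Python so they run without Triton or a GPU.

```python
def next_power_of_2(n):
    """Smallest power of two that is >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def heur_block_q(args):
    # Hypothetical rule: cover all Q quantiles in one block, capped at 16.
    return min(next_power_of_2(args["Q"]), 16)


def heur_block_n(args):
    # Hypothetical rule: scale the block with N, capped at 64.
    return min(next_power_of_2(args["N"]), 64)


# With Triton installed, functions of this shape could be attached as
#   @triton.heuristics(values={"BLOCK_Q": heur_block_q, "BLOCK_N": heur_block_n})
# in place of @triton.autotune(...), so no configurations are timed at runtime.
example_args = {"N": 1000, "M": 8, "Q": 3}
print(heur_block_q(example_args), heur_block_n(example_args))
```

The design trade-off is that heuristics pick a configuration from the problem size alone, which avoids the tuning runs that autotune performs for each new key, at the cost of possibly missing the best configuration.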
src/flag_gems/ops/quantile.py
    assert Q > 0
    assert torch.all(q >= 0.0) and torch.all(q <= 1.0)

    inp, _ = inp.sort()  # Sort the input with torch.sort()
A better approach might involve sorting the data and calculating the quantile result within a single kernel (though I understand this is very challenging). The current implementation separates the process into two distinct kernels. So, please share the benchmark results first.
Thank you for your thorough and patient review. Below are the benchmark results (with autotune enabled):
It appears that the Triton operator struggles particularly with small inputs, though the results in the screenshot are among the best I’ve achieved for float32 inputs. Replacing autotune with heuristics typically yields minimal improvement.
"0" denotes "dim = 0".
@kiddyjinjin Apologies for reaching out again, but I’ve encountered a strange bug while running the operator unit tests in test_general_reduction_ops.py. Specifically, even when I intentionally made the operator output incorrect, the unit tests still passed. For example, I set quantile_dim to output 0 consistently, then ran the unit test command. It looks like the unit test might be comparing the reference outputs to torch.quantile instead of the Triton quantile outputs. I also intentionally introduced an error in the output of the "sum" operation and ran its unit tests, and the faults were detected correctly in that case. Could you help me identify where the issue might be? Thank you.
Have you updated the code in
By the way, you can write a small script and enable debug mode by adding the following code to verify if your kernel is actually being invoked:
import logging
logging.basicConfig(level=logging.DEBUG)
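A minimal sketch of such a verification script is shown below. It assumes the operator wrapper emits a DEBUG-level log message when it runs; the function `quantile_stub` and the message text "GEMS QUANTILE" are hypothetical stand-ins so that the example runs without a GPU. The log output is captured in a buffer here so it can be checked programmatically; in a real script the default stderr output is enough.

```python
import io
import logging

# Send DEBUG output to a buffer so it can be inspected afterwards.
log_buffer = io.StringIO()
logging.basicConfig(level=logging.DEBUG, stream=log_buffer, force=True)


def quantile_stub(values, q):
    """Hypothetical stand-in for the custom operator wrapper."""
    logging.debug("GEMS QUANTILE")  # assumed marker that this code path ran
    data = sorted(values)
    return data[int(q * (len(data) - 1))]


result = quantile_stub([3.0, 1.0, 2.0], 0.5)

# If the marker is missing, the call was served by some other implementation.
kernel_was_invoked = "GEMS QUANTILE" in log_buffer.getvalue()
print(result, kernel_was_invoked)
```

If the marker never appears while the tests still pass, the tests are most likely exercising the reference implementation rather than the new kernel, which matches the symptom described above.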
reopen
Updating
Thank you very much for your assistance. Unfortunately, after updating ops/__init__.py, my operator began to fail, producing errors like the one below:
To make matters worse, I cannot reproduce these errors with my local code. (I’m unable to run quantile.py directly, since attempting to import from "..utils" causes an error.)
I’ve tested various input shapes, including (200, 2560, 3) and (10, 64, 196), which trigger errors in the unit tests. However, when I run the same shapes locally, the results are always correct. I am completely perplexed by this discrepancy. For additional context, I replaced tl.program_id with tle.program_id, but this change did not seem to have any effect. Once again, I find myself in need of your help. My apologies for the repeated disturbances. @kiddyjinjin
tests/test_general_reduction_ops.py
@pytest.mark.parametrize("q", QUANTILE_Q)
@pytest.mark.parametrize("interpolation", QUANTILE_INTERPOLATION)
def test_accuracy_quantile_without_dim(shape, dtype, q, interpolation):
    inp = torch.randn(shape, dtype=dtype, device="cuda")
Please use device=flag_gems.device instead of device="cuda".
benchmark/performance_utils.py
@@ -424,7 +424,7 @@ def set_more_shapes(self):

 def generate_tensor_input(shape, dtype, device):
-    if dtype in FLOAT_DTYPES:
+    if dtype in FLOAT_DTYPES + [torch.float64]:
I think we don’t need to focus on torch.float64 accuracy for this case. In real-world model training and inference scenarios, float64 precision is rarely used.
PR Category
Operator
Type of Change
New Feature
Description
Created the 'quantile' operator
Issue
Resolves #251
Progress
Performance