Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add quantile op #287

Open
wants to merge 32 commits into
base: master
Choose a base branch
from
Open

Add quantile op #287

wants to merge 32 commits into from

Conversation

CelysPr
Copy link

@CelysPr CelysPr commented Nov 11, 2024

PR Category

Operator

Type of Change

New Feature

Description

Created the 'quantile' operator

Issue

Resolves #251

Progress

  • Compute for n-d input
  • dim, keepdim & interpolation
  • Autotune supported
  • Add a parameter for storing the output
  • Maybe use different way of sort

Performance

Add quantile op
@kiddyjinjin
Copy link
Collaborator

Please follow the instructions in section 2.1 of the CONTRIBUTING.md file and run the pre-commit.

@CelysPr
Copy link
Author

CelysPr commented Nov 19, 2024

11.19 Notes

Some test cases in test_general_reduction_ops.py failed:
image
One example (with the highest mismatch rate)
image
Notice all the failed cases apply either "higher" or "midpoint" as their interpolation method. I tried changing the way we calculate "q_upper" but it didn't work at all.

Besides, adding "CUDA_VISIBLE_DEVICES=7" before pytest command, or specifying "device='cuda:7'" in the codes will cause unexpected errors in both the operation test and the benchmark.
For example, by running CUDA_VISIBLE_DEVICES=7 pytest ../tests/test_general_reduction_ops.py -m quantile --cache-clear, errors indicating "the inputs are too large" occur:
image

I didn't test the benchmark for quantile op, because someone else is using cuda:0 while I cannot utilize other cuda devices anyway. I will refine the codes once I find the way to run them.

@CelysPr
Copy link
Author

CelysPr commented Dec 2, 2024

Problems: #341

@kiddyjinjin
Copy link
Collaborator

#341

I noticed that the test case failures occur only when the interpolation mode is set to either higher or midpoint, with q = 0.2, an input shape of 256 for the second dimension, and reduction applied along dim = 1. Below is a simple, reproducible example that highlights the issue:


Reproducible Example

Test Input:

# Sorted tensor with values from 0 to 255, shape [1, 256]
inp = torch.arange(256, dtype=torch.float32).reshape(1, 256)
q = torch.tensor([0.2], dtype=torch.float32)
dim = 1 
keep_dim = False  # or True

Expected result:
For higher, the result should be 51.0 (the smallest element greater than or equal to 0.2 * (256 - 1)).
For midpoint, the result should be the average of the lower (51) and upper (52) bounds: (51 + 52) / 2 = 51.5.

Actual Result:
Higher: Returns 52.0, which is incorrect.
Midpoint: Propagates the same bug and gives an incorrect result.

The issue likely stems from incorrect handling of inp_upper, specifically the ceil logic, which fails to handle cases where q_block is an exact integer. Since midpoint relies on both the lower and upper bounds, the bug in higher also affects midpoint.

Please fix the inp_upper logic and ensure the inp_lower logic accounts for its edge cases.
Also, you should check the inp_lower logic, as it may encounter some other edge cases.

@kiddyjinjin
Copy link
Collaborator

#341

I noticed that the test case failures occur only when the interpolation mode is set to either higher or midpoint, with q = 0.2, an input shape of 256 for the second dimension, and reduction applied along dim = 1. Below is a simple, reproducible example that highlights the issue:

Reproducible Example

Test Input:

# Sorted tensor with values from 0 to 255, shape [1, 256]
inp = torch.arange(256, dtype=torch.float32).reshape(1, 256)
q = torch.tensor([0.2], dtype=torch.float32)
dim = 1 
keep_dim = False  # or True

Expected result: For higher, the result should be 51.0 (the smallest element greater than or equal to 0.2 * (256 - 1)). For midpoint, the result should be the average of the lower (51) and upper (52) bounds: (51 + 52) / 2 = 51.5.

Actual Result: Higher: Returns 52.0, which is incorrect. Midpoint: Propagates the same bug and gives an incorrect result.

The issue likely stems from incorrect handling of inp_upper, specifically the ceil logic, which fails to handle cases where q_block is an exact integer. Since midpoint relies on both the lower and upper bounds, the bug in higher also affects midpoint.

Please fix the inp_upper logic and ensure the inp_lower logic accounts for its edge cases. Also, you should check the inp_lower logic, as it may encounter some other edge cases.

I’ve identified why the quantile kernel test fails in our case. The issue arises because the reference tensor's dtype is upcast to torch.float64 by calling to_reference(inp, upcast=True). This causes a mismatch between the results of torch.quantile() computed with torch.float64 and torch.float32.

To resolve this, update the test case as follows:

ref_inp = to_reference(inp)
ref_q = to_reference(q)

This ensures consistent dtype and resolves the discrepancy.

@pytest.mark.parametrize("interpolation", QUANTILE_INTERPOLATION)
def test_accuracy_quantile_dim(shape, dim, keepdim, dtype, q, interpolation):
inp = torch.randn(shape, dtype=dtype, device="cuda")
ref_inp = to_reference(inp, True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should change this to use to_reference(inp) for the float32 case to ensure consistency in the dtype. Additionally, consider adding more comprehensive test coverage for various dtypes to ensure the robustness of the implementation.



@libentry()
@triton.autotune(configs=cfggen(), key=["N", "M", "Q"])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to replace autotune with heuristics to enhance runtime performance

@kiddyjinjin
Copy link
Collaborator

What's more, please provide the benchmark results of the operator.

assert Q > 0
assert torch.all(q >= 0.0) and torch.all(q <= 1.0)

inp, _ = inp.sort() # Sort the input with torch.sort()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A better approach might involve sorting the data and calculating the quantile result within a single kernel (though I understand this is very challenging). The current implementation separates the process into two distinct kernels. So, please share the benchmark results first.

@CelysPr
Copy link
Author

CelysPr commented Dec 9, 2024

Thank you for your thorough and patient review. Below are the benchmark results (with autotune enabled):
image

It appears that the Triton operator struggles particularly with small inputs, though these results in the screenshot are among the best I’ve achieved for float32 inputs. Replacing autotune with heuristic typically yields minimal improvement.
@kiddyjinjin

@CelysPr
Copy link
Author

CelysPr commented Dec 9, 2024

Thank you for your thorough and patient review. Below are the benchmark results (with autotune enabled): image

It appears that the Triton operator struggles particularly with small inputs, though these results in the screenshot are among the best I’ve achieved for float32 inputs. Replacing autotune with heuristic typically yields minimal improvement. @kiddyjinjin

"0" denotes "dim = 0".

@CelysPr
Copy link
Author

CelysPr commented Dec 10, 2024

12.10 Update -- Benchmark

The performance improved significantly by removing torch.sort(), with Gems speedup mostly ranging between 0.8 and 0.9. In this case, the input tensors are generated using torch.arange().
float32
image
float64
image

Thus, I have to consider how to calculate quantile within a kernel or find a better way of sort first.

@CelysPr
Copy link
Author

CelysPr commented Dec 10, 2024

More results
image
image

@CelysPr
Copy link
Author

CelysPr commented Dec 10, 2024

@kiddyjinjin Apologies for reaching out again, but I’ve encountered a strange bug while running the operator unit tests in test_general_reduction_ops.py. Specifically, even when I intentionally made the operator output incorrect, the unit tests still passed.

For example, I set the quantile_dim to output 0 consistently, then ran the unit test command.
image
image
However, all the test cases still passed, which seems unusual.
image

It looks like the unit test might be comparing the reference outputs to torch.quantile instead of the triton quantile outputs.

I also intentionally introduced an error in the output of the "sum" operation and ran its unit tests, but the faults were detected correctly in that case.

Could you help me identify where the issue might be? Thank you.

@kiddyjinjin
Copy link
Collaborator

Have you updated the code in src/flag_gems/__init__.py to register your implementation? Based on the latest code in your PR, this part seems to be missing.

By the way, you can write a small script and enable debug mode by adding the following code to verify if your kernel is actually being invoked:

import logging
logging.basicConfig(level=logging.DEBUG)

@CelysPr
Copy link
Author

CelysPr commented Dec 12, 2024

Update 12.13

I’ve made significant updates to my code. The quantile operator now performs well in both unit tests and benchmark tests. However, two small issues remain:

1. Operator Registration in flaggems/init.py

On my device, I’m unable to implicitly import the quantile() interface like other operators. Instead, I had to explicitly import the quantile module as follows:

from .ops import quantile
# Inside the register
("quantile", quantile.quantile, Autograd.disable)

Despite this, I committed code that appears normal without the explicit import. I’m unsure if this discrepancy could cause any issues on other devices.

2. CUDA Error When BLOCK_N is Small

I encountered a CUDA crash under specific conditions. Here are the parameters that trigger the issue:

  • dim = 0
  • input_shape = (64, 64, 4096)
  • q = [0, 0.2, 0.4, 0.6, 0.8, 1.0]
  • Q = 6, N = 64 * 4096
  • BLOCK_Q = 4, BLOCK_N = 4
  • grid = (N / BLOCK_N, Q / BLOCK_Q)

The crash produces an error like this:
image

Here’s what I’ve observed so far:

  • If input_shape = (64, 64, 4095) or (4096**2), there’s no error, but (64, 64, 4097) causes a crash.
  • The error is unrelated to autotune.
  • BLOCK_Q values of 1, 2, 4, and 8 do not cause errors.
  • BLOCK_N = 8 also works without error.

Hypothesis:
It seems that N / BLOCK_N should not exceed 65,535. In the failing case, 64 * 4096 / 4 = 65,536, which crosses this threshold. I tested other shapes to verify this limit, but I still don’t fully understand why this happens. The memory usage for each CTA block should be small enough in my case.

Testing and Benchmark

image image image

@kiddyjinjin

@CelysPr
Copy link
Author

CelysPr commented Dec 12, 2024

reopen

@CelysPr CelysPr reopened this Dec 12, 2024
@kiddyjinjin
Copy link
Collaborator

  1. Operator Registration in flaggems/init.py

Updating src/flag_gems/ops/__init__.py like other ops should solve your problem.

@CelysPr
Copy link
Author

CelysPr commented Dec 13, 2024

  1. Operator Registration in flaggems/init.py

Updating src/flag_gems/ops/__init__.py like other ops should solve your problem.

Thank you very much for your assistance. Unfortunately, after updating ops/init.py, my operator began to fail, producing errors like the one below:

image

To make matters worse, I cannot reproduce these errors with my local code. (I’m unable to directly examine quantile.py since attempting to import from "..utils" causes an error.)

import logging

import torch
import triton
import triton.language as tl
from torch import Tensor

import os
os.environ["TRITON_PRINT_AUTOTUNING"] = "1"


def dim_compress(inp, dims):
    if isinstance(dims, int):
        dims = [dims]
    dim = inp.ndim
    stride = inp.stride()
    batch_dim = [i for i in range(dim) if i not in dims]
    sorted_reduction_dim = sorted(dims, key=lambda x: stride[x], reverse=True)
    order = batch_dim + sorted_reduction_dim
    return inp.permute(order).contiguous()


INTERPOLATION_METHOD = ["linear", "lower", "higher", "nearest", "midpoint"]


def heur_block_q(args):
    return triton.next_power_of_2(min(triton.cdiv(args["Q"], 8), 16))


def heur_block_n(args):
    if args["N"] >= 65536:
        return triton.next_power_of_2(triton.cdiv(args["N"], 512))
    elif args["N"] >= 4096:
        return triton.next_power_of_2(triton.cdiv(args["N"], 128))
    elif args["N"] >= 64:
        return 32
    elif args["N"] >= 32:
        return 4
    else:
        return 1
    
#@libentry()
@triton.heuristics(values={"BLOCK_Q": heur_block_q, "BLOCK_N": heur_block_n})
@triton.jit
def quantile_kernel(
    inp,
    q,
    out,
    N,
    M,
    Q,
    BLOCK_Q: tl.constexpr,
    BLOCK_N: tl.constexpr,
    interpolation: tl.constexpr,
):
    pid_Q = tl.program_id(0)
    pid_N = tl.program_id(1)
    ctype = inp.dtype.element_ty

    offsets_Q = pid_Q * BLOCK_Q + tl.arange(0, BLOCK_Q)
    mask_Q = offsets_Q < Q
    q_ptrs = q + offsets_Q

    offsets_N = pid_N * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_N = offsets_N < N

    out_ptrs = out + offsets_N[:, None] * Q + offsets_Q[None, :]
    mask_out = mask_N[:, None] & mask_Q[None, :]

    q_block = tl.load(q_ptrs, mask_Q, 0.0).to(ctype) * (M - 1)
    q_lower = tl.floor(q_block).to(tl.int32)
    q_upper = tl.ceil(q_block).to(tl.int32)

    inp_lower = tl.load(inp + offsets_N[:, None] * M + q_lower[None, :], mask_N[:, None], 0.0)
    inp_upper = tl.load(inp + offsets_N[:, None] * M + q_upper[None, :], mask_N[:, None], 0.0)

    if interpolation == "linear":
        q_frac = q_block - q_lower
        tl.store(out_ptrs, inp_lower + (inp_upper - inp_lower) * q_frac, mask_out)

    elif interpolation == "lower":
        tl.store(out_ptrs, inp_lower, mask_out)

    elif interpolation == "higher":
        tl.store(out_ptrs, inp_upper, mask_out)

    elif interpolation == "nearest":
        q_round = tl.extra.cuda.libdevice.rint(q_block)
        out_block = tl.where(q_round == q_upper, inp_upper, inp_lower)
        tl.store(out_ptrs, out_block, mask_out)

    elif interpolation == "midpoint":
        tl.store(out_ptrs, (inp_lower + inp_upper) / 2, mask_out)


def quantile(
    inp, q, dim=None, keepdim=False, interpolation="linear", out=None
) -> Tensor:
    logging.debug("GEMS QUANTILE DIM")
    assert torch.is_floating_point(inp)
    assert dim is None or isinstance(dim, int)
    assert isinstance(q, (float, torch.Tensor))
    assert interpolation in INTERPOLATION_METHOD

    M = inp.numel()
    if isinstance(q, float):
        q = torch.tensor(q, device=inp.device)
        Q = 1
    else:
        Q = 1 if q.numel() == 1 else len(q)

    assert M > 0
    assert Q > 0
    assert torch.all(q >= 0.0) and torch.all(q <= 1.0)

    if dim is None:
        inp = inp.ravel()
        dim = 0

    shape = list(inp.shape)

    dim %= inp.ndim
    inp = dim_compress(inp, dim)
    M = shape[dim]
    N = inp.numel() // M

    inp, _ = inp.sort()  # Sort the input with torch.sort()
    output = torch.empty(inp.shape[:-1] + (Q,), dtype=inp.dtype, device=inp.device)

    grid = lambda meta: (
        triton.cdiv(Q, meta["BLOCK_Q"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )

    with torch.cuda.device(inp.device):
        quantile_kernel[grid](inp, q, output, N, M, Q, interpolation=interpolation)

    output = output.permute(
        (-1,) + tuple(range(0, inp.ndim - 1))
    )  # Same as torch.quantile()
    if keepdim:
        output = output.unsqueeze(dim + 1)
    if Q == 1:
        output = output.squeeze(0)

    if out is not None:
        out.copy_(output)
    return output

keep = True
dim = 0
inter = 'midpoint'
inputs = torch.randn((10, 64, 196), device='cuda', dtype=torch.float64)
q = torch.tensor(torch.arange(0.0, 1.0, 65536, dtype=torch.float64), device='cuda')
out = quantile(inputs, q, dim=dim, keepdim=keep, interpolation=inter)
ref = torch.quantile(inputs, q, dim=dim, keepdim=keep, interpolation=inter)

print(torch.allclose(out, ref))
print(torch.max(torch.abs(ref - out)))

I’ve tested various input shapes, including (200, 2560, 3) and (10, 64, 196), which trigger errors in the unit tests. However, when I run the same shapes locally, the calculation results are always correct. I am completely perplexed by this discrepancy.

For additional context, I replaced tl.program_id with tle.program_id, but this change did not seem to have any effect.

Once again, I find myself in need of your help. My apologies for the repeated disturbances. @kiddyjinjin

@pytest.mark.parametrize("q", QUANTILE_Q)
@pytest.mark.parametrize("interpolation", QUANTILE_INTERPOLATION)
def test_accuracy_quantile_without_dim(shape, dtype, q, interpolation):
inp = torch.randn(shape, dtype=dtype, device="cuda")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use ‘’ device=flag_gems.device ‘’ instead of ‘’ device="cuda" ‘’.

@@ -424,7 +424,7 @@ def set_more_shapes(self):


def generate_tensor_input(shape, dtype, device):
if dtype in FLOAT_DTYPES:
if dtype in FLOAT_DTYPES + [torch.float64]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don’t need to focus on torch.float64 accuracy for this case. In real-world model training and inference scenarios, float64 precision is rarely used.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Code Contribution: 【Lv1】【Operator Development】quantile
3 participants