Commit d4a3ff4

fix: correct PDL parameter handling in RopeQuantize kernel (#1982)
## 📌 Description

### 1. Fixed Parameter Alignment
- **Issue**: The `stream` argument was being passed to the wrong position in the `RopeQuantize` call because the `enable_pdl` parameter was missing; SGLang would hang before this PR.
- **Fix**: Added the `enable_pdl` parameter to the function signature and aligned all arguments accordingly.

### 2. Fixed PDL Launch Configuration
- **Issue**: With `enable_pdl=true`, the kernel would throw CUDA errors due to incorrect PDL attribute handling.
- **Fix**: Aligned the implementation with `csrc/fmhaReduction.cu`.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **New Features**
  * Added PDL (Programmatic Dependent Launch) benchmarking for rope quantization operations.
  * Extended configuration options to enable or disable PDL.
* **Tests**
  * Updated the test suite to cover rope quantization with PDL enabled and disabled.
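The core of the second fix is visible in the `include/flashinfer/pos_enc.cuh` diff below: instead of attaching the PDL launch attribute only when `enable_pdl` is true, the attribute is always attached and `enable_pdl` merely toggles its value, matching `csrc/fmhaReduction.cu`. In the old code the `cudaLaunchAttribute` array was declared inside the `if (enable_pdl)` block, so `config.attrs` pointed at an out-of-scope local by the time `cudaLaunchKernelEx` ran, which likely explains the CUDA errors. Below is a minimal, self-contained sketch of the adopted launch pattern; the kernel, grid/block dimensions, and the `launch_with_optional_pdl` wrapper are hypothetical and only illustrate the attribute handling (PDL itself takes effect only on GPUs that support Programmatic Dependent Launch):

```cuda
// Sketch only: mirrors the launch-attribute handling adopted in this commit.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void dummy_kernel(float* out) {
  // Placeholder body; the real RopeQuantize kernel applies RoPE and FP8 quantization.
  out[threadIdx.x] = static_cast<float>(threadIdx.x);
}

cudaError_t launch_with_optional_pdl(float* out, bool enable_pdl, cudaStream_t stream) {
  // The attribute is always populated and stays in scope until launch;
  // enable_pdl only flips its value, so the launch path is identical either way.
  cudaLaunchAttribute attribute[1];
  attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attribute[0].val.programmaticStreamSerializationAllowed = enable_pdl ? 1 : 0;

  cudaLaunchConfig_t config;
  config.gridDim = dim3(1);
  config.blockDim = dim3(32);
  config.dynamicSmemBytes = 0;
  config.stream = stream;
  config.attrs = attribute;
  config.numAttrs = 1;

  return cudaLaunchKernelEx(&config, dummy_kernel, out);
}

int main() {
  float* out = nullptr;
  cudaMalloc(&out, 32 * sizeof(float));
  cudaError_t status = launch_with_optional_pdl(out, /*enable_pdl=*/true, /*stream=*/nullptr);
  printf("launch status: %s\n", cudaGetErrorString(status));
  cudaDeviceSynchronize();
  cudaFree(out);
  return 0;
}
```

Keeping a single launch path (always setting `config.attrs` and `config.numAttrs`) also means the non-PDL case exercises the same code as the PDL case, which is the design the tests below parametrize over.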
1 parent d225332 commit d4a3ff4

File tree

6 files changed (+49, -20 lines)


benchmarks/bench_rope_quantize_fp8.py

Lines changed: 23 additions & 1 deletion
@@ -88,7 +88,7 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
-def benchmark_config(config_name, num_tokens, provider):
+def benchmark_config(config_name, num_tokens, provider, enable_pdl=False):
     """Benchmark a specific attention configuration."""
     input_dtype = torch.bfloat16
     device = "cuda"
@@ -177,6 +177,7 @@ def execute():
             k_nope_out=k_nope_out,
             quant_scale_q=1.0,
             quant_scale_kv=1.0,
+            enable_pdl=enable_pdl,
         )
 
         if mode_ncu and run_idx == 20:
@@ -278,6 +279,23 @@ def benchmark_mha(provider, num_tokens):
     return benchmark_config("mha", num_tokens, provider)
 
 
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["num_tokens"],
+        x_vals=[768] if mode_ncu else [1, 2, 4, 8, 16, 32, 64, 128, 256, 384, 512, 768],
+        line_arg="enable_pdl",
+        line_vals=[False, True],
+        line_names=["enable_pdl=False", "enable_pdl=True"],
+        styles=[("blue", "-"), ("red", "-")],
+        ylabel="Latency (ms)",
+        plot_name="rope-pdl-benchmark",
+        args={},
+    )
+)
+def benchmark_pdl(enable_pdl, num_tokens):
+    return benchmark_config("mla", num_tokens, "flashinfer", enable_pdl=enable_pdl)
+
+
 if __name__ == "__main__":
     # Run all benchmarks and generate individual plots
     print("Running MLA benchmark...")
@@ -289,6 +307,9 @@ def benchmark_mha(provider, num_tokens):
     print("Running MHA benchmark...")
     benchmark_mha.run(print_data=False, show_plots=True, save_path=".")
 
+    print("Running PDL benchmark...")
+    benchmark_pdl.run(print_data=False, show_plots=True, save_path=".")
+
     # Collect results for summary table
     token_counts = (
         [1, 2, 4, 8, 16, 32, 64, 128, 256, 384, 512, 768] if not mode_ncu else [768]
@@ -319,3 +340,4 @@ def benchmark_mha(provider, num_tokens):
     print(" mla-rope-benchmark.png (FlashInfer vs PyTorch)")
     print(" gqa-rope-benchmark.png (FlashInfer vs PyTorch)")
     print(" mha-rope-benchmark.png (FlashInfer vs PyTorch)")
+    print(" rope-pdl-benchmark.png (enable_pdl=False vs enable_pdl=True)")

csrc/flashinfer_rope_binding.cu

Lines changed: 2 additions & 2 deletions
@@ -42,8 +42,8 @@ void apply_rope_pos_ids_cos_sin_cache(TensorView q, TensorView k, TensorView q_r
 void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in,
                    TensorView k_nope_in, TensorView q_rope_out, TensorView k_rope_out,
                    TensorView q_nope_out, TensorView k_nope_out, TensorView cos_sin_cache,
-                   TensorView pos_ids, double quant_scale_q, double quant_scale_kv,
-                   bool interleave);
+                   TensorView pos_ids, double quant_scale_q, double quant_scale_kv, bool interleave,
+                   bool enable_pdl);
 
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope, apply_rope);
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_llama31_rope, apply_llama31_rope);

csrc/rope.cu

Lines changed: 3 additions & 3 deletions
@@ -271,8 +271,8 @@ void apply_llama31_rope_pos_ids(TensorView q, TensorView k, TensorView q_rope, T
 void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in,
                    TensorView k_nope_in, TensorView q_rope_out, TensorView k_rope_out,
                    TensorView q_nope_out, TensorView k_nope_out, TensorView cos_sin_cache,
-                   TensorView pos_ids, double quant_scale_q, double quant_scale_kv,
-                   bool interleave) {
+                   TensorView pos_ids, double quant_scale_q, double quant_scale_kv, bool interleave,
+                   bool enable_pdl) {
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_rope_in);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_rope_in);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_nope_in);
@@ -411,7 +411,7 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope
       q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n,
       q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride,
       k_nope_in_stride_h, k_rope_out_stride, k_rope_out_stride_h, k_nope_out_stride,
-      k_nope_out_stride_h, quant_scale_q, quant_scale_kv, interleave, stream);
+      k_nope_out_stride_h, quant_scale_q, quant_scale_kv, interleave, enable_pdl, stream);
 
   TVM_FFI_ICHECK(status == cudaSuccess)
       << "RopeQuantize failed with error code " << cudaGetErrorString(status);

flashinfer/rope.py

Lines changed: 9 additions & 0 deletions
@@ -181,6 +181,7 @@ def _rope_quantize(
     quant_scale_q: float,
     quant_scale_kv: float,
     interleave: bool,
+    enable_pdl: bool,
 ) -> None:
     r"""Custom operator that routes to the CUDA kernel implementation.
 
@@ -201,6 +202,7 @@ def _rope_quantize(
         quant_scale_q,
         quant_scale_kv,
         interleave,
+        enable_pdl,
     )
 
 
@@ -219,6 +221,7 @@ def _fake_rope_quantize(
     quant_scale_q: float,
     quant_scale_kv: float,
     interleave: bool,
+    enable_pdl: bool,
 ) -> None:
     pass
 
@@ -1159,6 +1162,7 @@ def mla_rope_quantize_fp8(
     k_rope_out: Optional[torch.Tensor] = None,
     q_nope_out: Optional[torch.Tensor] = None,
     k_nope_out: Optional[torch.Tensor] = None,
+    enable_pdl: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     return rope_quantize_fp8(
         q_rope,
@@ -1175,6 +1179,7 @@ def mla_rope_quantize_fp8(
         k_rope_out,
         q_nope_out,
         k_nope_out,
+        enable_pdl,
     )
 
 
@@ -1193,6 +1198,7 @@ def rope_quantize_fp8(
     k_rope_out: Optional[torch.Tensor] = None,
     q_nope_out: Optional[torch.Tensor] = None,
     k_nope_out: Optional[torch.Tensor] = None,
+    enable_pdl: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     r"""Apply RoPE (Rotary Positional Embeddings) and quantize to FP8 format.
 
@@ -1237,6 +1243,8 @@ def rope_quantize_fp8(
         Pre-allocated output tensor for quantized query (non-rotary). If ``None``, allocated automatically.
     k_nope_out : Optional[torch.Tensor]
        Pre-allocated output tensor for quantized key (non-rotary). If ``None``, allocated automatically.
+    enable_pdl : bool
+        Whether to enable PDL (Programmatic Dependent Launch). Default: ``False``.
 
     Returns
     -------
@@ -1291,6 +1299,7 @@ def rope_quantize_fp8(
         quant_scale_q,
         quant_scale_kv,
         not is_neox,  # interleave
+        enable_pdl,
     )
 
     return q_rope_out, k_rope_out, q_nope_out, k_nope_out

include/flashinfer/pos_enc.cuh

Lines changed: 6 additions & 14 deletions
@@ -813,24 +813,16 @@ cudaError_t RopeQuantize(
   dim3 nblks(nblks_x, total_blocks_y);
   dim3 nthrs(bdx, bdy);
 
+  cudaLaunchAttribute attribute[1];
+  attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+  attribute[0].val.programmaticStreamSerializationAllowed = enable_pdl ? 1 : 0;
   cudaLaunchConfig_t config;
   config.gridDim = nblks;
   config.blockDim = nthrs;
-  config.dynamicSmemBytes = 0;
   config.stream = stream;
-
-  if (enable_pdl) {
-    // PDL launch config
-    cudaLaunchAttribute attribute[1];
-    attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
-    attribute[0].val.programmaticStreamSerializationAllowed = 1;
-    config.attrs = attribute;
-    config.numAttrs = 1;
-  } else {
-    // Regular launch config
-    config.attrs = nullptr;
-    config.numAttrs = 0;
-  }
+  config.dynamicSmemBytes = 0;
+  config.attrs = attribute;
+  config.numAttrs = 1;
 
   FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(
       &config, kernel, q_rope_in, k_rope_in, q_nope_in, k_nope_in, q_rope_out, k_rope_out,

tests/attention/test_rope.py

Lines changed: 6 additions & 0 deletions
@@ -380,6 +380,7 @@ def test_rope_cos_sin_cache(
 @pytest.mark.parametrize("num_tokens", [1, 19, 128, 199, 899, 2047])
 @pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+@pytest.mark.parametrize("enable_pdl", [True, False])
 def test_generalized_rope_quantize(
     attention_type,
     num_qo_heads,
@@ -389,6 +390,7 @@ def test_generalized_rope_quantize(
     num_tokens,
     input_dtype,
     quant_dtype,
+    enable_pdl,
 ):
     """Test generalized rope + quantization for MLA, GQA, and MHA architectures."""
     device = "cuda:0"
@@ -459,6 +461,7 @@ def test_generalized_rope_quantize(
         k_nope_out=k_nope_out,
         quant_scale_q=1.0,
         quant_scale_kv=1.0,
+        enable_pdl=enable_pdl,
     )
 
     # Verify results
@@ -481,10 +484,12 @@
 @pytest.mark.parametrize("num_tokens", [1, 19, 128, 199, 899, 2047])
 @pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+@pytest.mark.parametrize("enable_pdl", [True, False])
 def test_mla_rope_quantize(
     num_tokens,
     input_dtype,
     quant_dtype,
+    enable_pdl,
 ):
     device = "cuda:0"
     num_qo_heads = 128
@@ -525,6 +530,7 @@ def test_mla_rope_quantize(
         k_nope_out=k_out[..., 64:],
         quant_scale_q=1.0,
         quant_scale_kv=1.0,
+        enable_pdl=enable_pdl,
     )
 
     torch.testing.assert_close(
