
Commit 9ce1af7

qsang-nv and yzh119 authored
minor fix for xqa (#1994)
## 📌 Description

1. Change xqa_mla comments to be consistent with MLA instead of MHA.
2. Move cudaMemcpyFromSymbol/cudaFuncSetAttribute outside of the launch function to avoid breaking CUDA graph capture.
3. Use int32 as the page table index.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **New Features**
  * Added MLA variant documentation clarifying the SM120 GPU requirement and the fixed head group ratio configuration.
* **Documentation**
  * Updated data type specifications for XQA operations; the page table now requires int32 instead of uint32.
  * Added max sequence length derivation notes for page-table-based configurations.
  * Clarified MLA variant input/output data types (float8_e4m3fn and bfloat16).
* **Bug Fixes**
  * Corrected data type handling in page table processing to ensure compatibility.

Signed-off-by: Qidi Sang <[email protected]>
Co-authored-by: yzh119 <[email protected]>
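For readers following change 2 above: the diffs below move the one-time shared-memory configuration out of the launch functions and into namespace-scope static initializers, since host-synchronous runtime calls such as cudaMemcpyFromSymbol are not allowed while a stream is being captured into a CUDA graph. Below is a minimal, self-contained sketch of the same pattern; `toyKernel`, `launchToy`, and the 64 KiB size are illustrative stand-ins, not FlashInfer's actual symbols.

```cuda
#include <cuda_runtime.h>
#include <cstdint>

// Device-side symbol holding the dynamic shared memory size the kernel needs.
__device__ uint32_t const smemSize = 64 * 1024;

__global__ void toyKernel(float* out) {
  extern __shared__ float smem[];
  smem[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  out[threadIdx.x] = smem[threadIdx.x];
}

// Runs once at program load, before any stream capture can be active.
// cudaMemcpyFromSymbol is a host-synchronous call, so doing it lazily inside
// the launch function would invalidate an ongoing CUDA graph capture.
static uint32_t configureToyKernel() {
  uint32_t size = 0;
  cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize));
  // Opt in to more than 48 KiB of dynamic shared memory (on GPUs that support
  // it); this configuration call also belongs outside the capture path.
  cudaFuncSetAttribute(toyKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, size);
  return size;
}
static uint32_t const hostSmemSize = configureToyKernel();

// The launch path now only enqueues the kernel, so it is safe to capture
// between cudaStreamBeginCapture and cudaStreamEndCapture.
void launchToy(float* out, cudaStream_t stream) {
  toyKernel<<<1, 128, hostSmemSize, stream>>>(out);
}
```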
1 parent 7d9d7af commit 9ce1af7

5 files changed: +42 −36 lines changed


csrc/xqa/mha.cu

Lines changed: 9 additions & 7 deletions
@@ -2655,6 +2655,15 @@ void launchMHA(
 }
 #endif
 
+static uint32_t configureKernel() {
+  uint32_t size;
+  cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize));
+  cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size);
+  return size;
+}
+
+static uint32_t const hostSmemSize = configureKernel();
+
 void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize,
                          float qScale, OutputHead* output,
 #if LOW_PREC_OUTPUT
@@ -2673,13 +2682,6 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32
                          uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
 #endif
                          uint32_t* semaphores, void* scratch, cudaStream_t stream) {
-  static uint32_t const hostSmemSize = [&]() {
-    uint32_t size;
-    checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize)));
-    checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size));
-    return size;
-  }();
-
   uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t {
     if (!allowMultiBlockMode) {
       return 1;
csrc/xqa/mha_sm90.cu

Lines changed: 9 additions & 6 deletions
@@ -3165,6 +3165,15 @@ void launchHopperF8MHA(
 }
 #endif
 
+static uint32_t configureKernel() {
+  uint32_t size;
+  cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize));
+  cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size);
+  return size;
+}
+
+static uint32_t const hostSmemSize = configureKernel();
+
 void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads,
                                  uint32_t slidingWinSize, float qScale, OutputHead* output,
 #if LOW_PREC_OUTPUT
@@ -3183,12 +3192,6 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads
                                  uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
 #endif
                                  uint32_t* semaphores, void* scratch, cudaStream_t stream) {
-  static uint32_t const hostSmemSize = [&]() {
-    uint32_t size;
-    checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize)));
-    checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size));
-    return size;
-  }();
   uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t {
     float const factor = 0.25f;
     return mha::min<uint32_t>(

csrc/xqa/mla_sm120.cu

Lines changed: 9 additions & 7 deletions
@@ -1835,6 +1835,15 @@ void launchMLA(
 #endif
 }
 
+static uint32_t configureKernel() {
+  uint32_t size;
+  cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize));
+  cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size);
+  return size;
+}
+
+static uint32_t const hostSmemSize = configureKernel();
+
 void launchMLAFlashInfer(
     uint32_t multiProcessorCount,
     uint32_t inputSeqLen,  // uniform for all requests and causal mask is assumed
@@ -1860,13 +1869,6 @@ void launchMLAFlashInfer(
   if (beamWidth != 1) {
     throw std::runtime_error("not implemented");
   }
-  static uint32_t const hostSmemSize = [&]() {
-    // printf("smemSize = %u\n", smemSize);
-    uint32_t size;
-    checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize)));
-    checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size));
-    return size;
-  }();
   uint32_t const nbKHeads = 1;
   uint32_t const nbVHeads = nbKHeads;
   uint32_t const nbQHeads = nbKHeads * headGrpSize;

flashinfer/xqa.py

Lines changed: 11 additions & 12 deletions
@@ -153,7 +153,7 @@ def xqa(
         Should be the same data type as k_cache.
     page_table : torch.Tensor
         Page table tensor with shape ``batch_size, nb_pages_per_seq``.
-        Data type should be torch.uint32.
+        Data type should be torch.int32.
         K and V share the same table.
     seq_lens : torch.Tensor
         Sequence lengths tensor with shape ``[batch_size, beam_width]``.
@@ -195,6 +195,7 @@
     - input_dtype from q.dtype
     - kv_cache_dtype from k.dtype
     - head_group_ratio from num_q_heads // num_kv_heads
+    - max_seq_len from page_table.shape[-1] * page_size
     """
     # Handle optional parameters
     if sm_count is None:
@@ -352,31 +353,29 @@ def xqa_mla(
     kv_scale: Optional[torch.Tensor] = None,
     sm_count: Optional[int] = None,
 ) -> None:
-    r"""Apply attention with paged KV cache using XQA kernel.
+    r"""Apply attention with paged KV cache using XQA MLA (Multi-Head Latent Attention) kernel.
     Parameters
     ----------
     q : torch.Tensor
         Query tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]``.
-        Data type should be torch.float16 or torch.bfloat16.
+        Data type should be torch.float8_e4m3fn.
         Now only beam_width 1 is supported.
     k_cache: torch.Tensor
         Paged K cache tensor with shape ``[total_num_cache_heads, head_dim]``.
-        Data type should match query tensor or be torch.float8_e4m3fn, in which case xqa will run fp8 calculation.
-        Should be the same data type as v_cache.
+        Data type should be torch.float8_e4m3fn
     v_cache: torch.Tensor
         Paged V cache tensor with shape ``[total_num_cache_heads, head_dim]``.
-        Data type should match query tensor or be torch.float8_e4m3fn, in which case xqa will run fp8 calculation.
-        Should be the same data type as k_cache.
+        Data type should be torch.float8_e4m3fn
     page_table : torch.Tensor
         Page table tensor with shape ``batch_size, nb_pages_per_seq``.
-        Data type should be torch.uint32.
+        Data type should be torch.int32.
         K and V share the same table.
     seq_lens : torch.Tensor
         Sequence lengths tensor with shape ``[batch_size, beam_width]``.
         Data type should be torch.uint32.
     output : torch.Tensor
         Output tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]``.
-        Data type should match query tensor. This tensor will be modified in-place.
+        Data type should be torch.bfloat16. This tensor will be modified in-place.
     workspace_buffer : torch.Tensor
         Workspace buffer for temporary computations.
         Data type should be torch.uint8.
@@ -399,8 +398,8 @@
     The function automatically infers several parameters from tensor shapes:
     - batch_size from q.shape[0]
     - head_dim from q.shape[-1]
-    - input_dtype from q.dtype
-    - kv_cache_dtype from k.dtype
+    - head_group_ratio is fixed to 128 for MLA
+    - max_seq_len from page_table.shape[-1] * page_size
     """
     # Handle optional parameters
     if sm_count is None:
@@ -423,7 +422,7 @@
     assert k_cache.dtype == v_cache.dtype, "K and V cache must have the same dtype"
 
     if get_compute_capability(torch.device(device="cuda"))[0] not in [12]:
-        raise RuntimeError("XQA is only supported on SM120 GPUs")
+        raise RuntimeError("XQA MLA is only supported on SM120 GPUs")
 
     xqa_module = get_xqa_module_mla(
         q.dtype,
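The docstring and dtype changes above boil down to two conventions: page indices are plain int32, and the longest addressable sequence follows from the page-table width times the page size (max_seq_len = page_table.shape[-1] * page_size). The sketch below mirrors both in a toy paged-K gather; the kernel, names, and page layout are illustrative assumptions, not FlashInfer's actual implementation.

```cuda
#include <cuda_runtime.h>
#include <cstdint>

// Toy paged-K gather. pageTable holds int32 page ids (matching the new
// torch.int32 requirement) with shape [batchSize, nbPagesPerSeq].
// Assumed launch shape: grid = (seqLen, batchSize), block = headDim.
__global__ void gatherPagedK(float const* kPool, int32_t const* pageTable,
                             uint32_t nbPagesPerSeq, uint32_t pageSize,
                             uint32_t headDim, float* out) {
  uint32_t const batch = blockIdx.y;
  uint32_t const token = blockIdx.x;  // token position within the sequence
  int32_t const page = pageTable[batch * nbPagesPerSeq + token / pageSize];
  uint32_t const slot = token % pageSize;
  // Each page stores pageSize tokens of headDim floats, contiguously.
  float const* src = kPool + (uint64_t(page) * pageSize + slot) * headDim;
  out[(uint64_t(batch) * gridDim.x + token) * headDim + threadIdx.x] = src[threadIdx.x];
}

// Host side: the maximum sequence length implied by the page table, as the
// updated docstrings describe it.
uint32_t maxSeqLen(uint32_t nbPagesPerSeq, uint32_t pageSize) {
  return nbPagesPerSeq * pageSize;
}
```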

tests/attention/test_xqa.py

Lines changed: 4 additions & 4 deletions
@@ -253,7 +253,7 @@ def test_xqa(
     cache_k_heads /= 4.0
     cache_v_heads /= 4.0
     page_list_arg = torch.zeros(
-        batch_size, nb_pages_per_seq, dtype=torch.uint32, device="cuda"
+        batch_size, nb_pages_per_seq, dtype=torch.int32, device="cuda"
     )
 
     # Initialize page list sequentially
@@ -265,7 +265,7 @@
 
     flattened = page_list_arg.flatten()
     indices = torch.randperm(flattened.numel())
-    shuffled_flat = flattened.to(torch.int32)[indices].to(torch.uint32)
+    shuffled_flat = flattened[indices]
     page_list_arg = shuffled_flat.view(page_list_arg.shape)
 
     def cache_head_at(
@@ -470,7 +470,7 @@ def test_xqa_mla(
     cache_v_heads /= 4.0
 
     page_list_arg = torch.zeros(
-        batch_size, nb_pages_per_seq, dtype=torch.uint32, device="cuda"
+        batch_size, nb_pages_per_seq, dtype=torch.int32, device="cuda"
     )
 
     # Initialize page list sequentially
@@ -482,7 +482,7 @@
 
     flattened = page_list_arg.flatten()
     indices = torch.randperm(flattened.numel())
-    shuffled_flat = flattened.to(torch.int32)[indices].to(torch.uint32)
+    shuffled_flat = flattened[indices]
     page_list_arg = shuffled_flat.view(page_list_arg.shape)
 
     def cache_head_at(
