Description
Problem Description
aiter/csrc/cpp_itfs/pa/pa_kernels.cuh, lines 124 to 125 at d2f5f27:

```cpp
const int warp_mtp_idx = warpid / (4 / MTP_PARALLEL_THREADS);
const int warp_row_idx = warpid % (4 / MTP_PARALLEL_THREADS);
```
The value of MTP_PARALLEL_THREADS is computed in aiter/csrc/cpp_itfs/pa/pa_kernels.cuh, lines 57 to 60 at d2f5f27:

```cpp
constexpr int MAX_ELEMENTS_PER_QUERY = DIVIDE_ROUND_UP(16, GQA_RATIO);
constexpr int MTP_PER_THREAD = DIVIDE_ROUND_UP(MTP, MAX_ELEMENTS_PER_QUERY);
constexpr int MTP_PARALLEL_THREADS = MTP / MTP_PER_THREAD;
```
If MTP_PARALLEL_THREADS > 4 (which can happen with large MTP values and certain GQA ratios), then `4 / MTP_PARALLEL_THREADS` evaluates to 0 under integer division, and the subsequent division and modulo by zero are undefined behavior that can crash the kernel.
Operating System
Ubuntu 22.04.5 LTS (Jammy Jellyfish)
CPU
AMD EPYC 9575F 64-Core Processor
GPU
8 x AMD Instinct MI355X
ROCm Version
ROCm version: 7.0.51831-a3e329ad8
ROCm Component
No response
Steps to Reproduce
Replicator script:
paged_attention_v1.issue2.mi355.py
Logs from running replicator script:
paged_attention_v1.issue2.mi355.log
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
My goal is to enable CUDA graph support for speculative decoding in the ROCM_AITER_FA attention backend by adding spec_as_decode support and changing _cudagraph_support from UNIFORM_SINGLE_TOKEN_DECODE to UNIFORM_BATCH. This should speed up speculative decoding: it removes graph breaks and routes speculative tokens through the more efficient decode path instead of the extend path. The AITER PA API appears to accept MTP > 1, but the implementation is buggy for the configurations described above.