Description
Thanks for the great work on dsv32!
From analyzing the fp8_mqa_logits and fp8_paged_mqa_logits functions, it looks like after the q@k computation, no causal masking is applied before the top-k selection?
I know the q/k fed to the mqa_logits kernel differ from the q/k fed to the MLA attention kernel, but the output of the mqa_logits kernel (after the top-k 2048 selection) is used as an index into the real MLA kv-cache, so the real attention computation still has to respect causality, in both the prefill and decode (MTP) cases.
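For context, here is a minimal sketch (PyTorch) of the indexer flow as I understand it; the function name, shapes, and top-k default are my own illustrative assumptions, not the actual DeepGEMM/vLLM code:

```python
import torch

def select_kv_indices(logits: torch.Tensor, topk: int = 2048) -> torch.Tensor:
    # logits: [num_q_tokens, num_kv_tokens], i.e. the q @ k output of the
    # fp8_mqa_logits kernel. Without a causal mask here, top-k is free to pick
    # kv positions that come *after* a query token, and those indices are then
    # used to gather from the real MLA kv-cache for the actual attention.
    topk = min(topk, logits.size(-1))
    return logits.topk(topk, dim=-1).indices  # [num_q_tokens, topk]
```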
vLLM's prefill dispatch, which uses torch.ops._C.top_k_per_row, does not seem to consider causality, while the decode dispatch looks like it does handle the causal case for MTP, after the logits kernel and before the top-k.
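If the masking is indeed meant to happen on the framework side, I would expect something like the following sketch before the top-k (the names, position tensors, and shapes are hypothetical, just to make the question concrete):

```python
import torch

def causal_topk(logits: torch.Tensor, q_pos: torch.Tensor, k_pos: torch.Tensor,
                topk: int = 2048) -> torch.Tensor:
    # logits: [num_q, num_kv]; q_pos: [num_q] and k_pos: [num_kv] are the
    # token positions of the queries and kv entries (hypothetical inputs).
    # Mask out kv positions that lie in the future of each query token, so
    # the subsequent top-k can only select causally valid indices.
    mask = k_pos.unsqueeze(0) > q_pos.unsqueeze(1)           # [num_q, num_kv]
    masked = logits.masked_fill(mask, float("-inf"))
    return masked.topk(min(topk, logits.size(-1)), dim=-1).indices
```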
I'm not sure whether the framework side is supposed to apply the causal mask before the top-k, or whether causality actually doesn't matter inside the indexer kernel?