
[QUESTION] dsv32 mqa_logits kernel not considering causal masking? #209


Description

@carlushuang

Thanks for the great work on dsv32!

From analyzing the fp8_mqa_logits and fp8_paged_mqa_logits functions, it looks like after the q@k product we don't apply causal masking before top-k?

I know the q/k fed to the mqa_logits kernel are different from the q/k fed to the MLA attention kernel, but the output of the mqa_logits kernel (after top-k 2048) is used as an index into the real MLA KV cache, so the real attention computation must respect causality, in both the prefill and the decode (MTP) case.
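
To make the concern concrete, here is a toy example (shapes and values are mine, not from the kernels): without a mask, top-k over the raw logits can return KV indices beyond a query's own position, which the subsequent MLA attention would then have to filter out again.

```python
import torch

torch.manual_seed(0)
num_q, num_kv, k = 4, 8, 2
logits = torch.randn(num_q, num_kv)        # stand-in for the q@k output
naive = logits.topk(k, dim=-1).indices     # top-k with no causal mask
q_pos = torch.arange(num_q)                # query i sits at position i
# Any index > q_pos[i] in row i points at a "future" KV entry.
print((naive > q_pos.unsqueeze(1)).any())  # typically True
```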

The vLLM prefill dispatch, which uses torch.ops._C.top_k_per_row, does not seem to consider causality; the decode dispatch looks like it does handle the causal mask in the MTP case, after the logits kernel and before top-k.

I'm not sure whether the framework side is supposed to apply the causal mask before top-k, or whether causality actually doesn't matter in the indexer kernel?
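
For concreteness, here is a minimal sketch of the masking-before-top-k step I have in mind (the names topk_indices_causal and q_pos are hypothetical, not from DeepGEMM or vLLM):

```python
import torch

def topk_indices_causal(logits: torch.Tensor, q_pos: torch.Tensor,
                        k: int = 2048) -> torch.Tensor:
    """Hypothetical helper: select top-k KV indices per query row,
    masking out non-causal (future) positions first.

    logits: [num_q, num_kv] raw scores from the mqa_logits kernel
    q_pos:  [num_q] absolute position of each query token
    """
    num_q, num_kv = logits.shape
    kv_pos = torch.arange(num_kv, device=logits.device)
    # Causal mask: query at position p may only index KV positions <= p.
    causal = kv_pos.unsqueeze(0) <= q_pos.unsqueeze(1)   # [num_q, num_kv]
    masked = logits.masked_fill(~causal, float("-inf"))
    # Top-k over the masked logits; -inf entries are never selected
    # (assuming each row has at least k valid positions).
    return masked.topk(min(k, num_kv), dim=-1).indices
```

In the prefill case q_pos would just be the per-token absolute positions; in the decode (MTP) case it would be context_len plus each draft token's offset, which is presumably where the vLLM decode dispatch applies its mask.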
