121 changes: 114 additions & 7 deletions flash_mla/flash_mla_interface.py
@@ -1,16 +1,13 @@
from typing import Optional, Tuple

import torch

import warnings
import flash_mla.cuda as flash_mla_cuda

def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_q_tokens_per_head_k: int,
num_heads_k: int,
num_heads_q: Optional[int] = None,
is_fp8_kvcache: bool = False,
topk: Optional[int] = None
*args,
**kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
@@ -25,8 +22,118 @@ def get_mla_metadata(
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
return flash_mla_cuda.get_mla_decoding_metadata(cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q, is_fp8_kvcache, topk)

if "num_heads_per_head_k" in kwargs:
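        # Back-compat path: the old keyword 'num_heads_per_head_k' routes through this branch.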
"""
        Handle deprecated calls (using num_heads_per_head_k)
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
            num_heads_per_head_k: Equal to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
"""
warnings.warn(
"Parameter 'num_heads_per_head_k' is deprecated. Please use 'num_q_tokens_per_head_k' instead.",
DeprecationWarning,
stacklevel=2,
)
num_heads_per_head_k = kwargs.pop("num_heads_per_head_k")

if "num_heads_k" in kwargs:
num_heads_k = kwargs.pop("num_heads_k")
elif len(args) >= 1:
num_heads_k = args[0]
args = args[1:]
else:
raise TypeError(
"Legacy call missing required 'num_heads_k' (position 2 or keyword argument)"
)

if "num_heads_q" in kwargs or "is_fp8_kvcache" in kwargs or "topk" in kwargs:
            raise TypeError(
                "The legacy call does not support the parameters (num_heads_q, is_fp8_kvcache, topk). "
                "If you want to use them, please replace the parameter name 'num_heads_per_head_k' with 'num_q_tokens_per_head_k'."
)

if len(args) > 0 or len(kwargs) > 0:
extra_args = list(args) + list(kwargs.keys())
            raise TypeError(
                f"Legacy calls do not support extra positional parameters: {extra_args}. "
                "The legacy form only supports (cache_seqlens, num_heads_per_head_k, num_heads_k).\n"
                "If you want to use (num_heads_q, is_fp8_kvcache, topk), please replace the parameter name 'num_heads_per_head_k' with 'num_q_tokens_per_head_k'."
)

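        # Forward the legacy arguments to the backend; the new parameters (num_heads_q, is_fp8_kvcache, topk) keep their defaults (None, False, None).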
return flash_mla_cuda.get_mla_decoding_metadata(
cache_seqlens, num_heads_per_head_k, num_heads_k, None, False, None
)
else:
"""
Handle v32 calls (using num_q_tokens_per_head_k)
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
            num_q_tokens_per_head_k: Equal to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
num_heads_k: The number of k heads.
            num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled.
is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to.
"""
if len(args) > 5:
            raise TypeError(
                f"get_mla_metadata() takes at most 6 positional arguments but {len(args)+1} were given"
            )

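        # Parse num_q_tokens_per_head_k and the remaining new-style arguments from positionals or keywords.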
if "num_q_tokens_per_head_k" in kwargs:
num_q_tokens_per_head_k = kwargs.pop("num_q_tokens_per_head_k")
elif len(args) >= 1:
num_q_tokens_per_head_k = args[0]
args = args[1:]
else:
raise TypeError(
"get_mla_metadata() missing required 'num_q_tokens_per_head_k' (position 1 or keyword argument)"
)

if len(args) >= 1 and "num_heads_k" not in kwargs:
num_heads_k = args[0]
args = args[1:]
elif "num_heads_k" in kwargs:
num_heads_k = kwargs.pop("num_heads_k")
        else:
raise TypeError(
"get_mla_metadata() missing required 'num_heads_k' (position 2 or keyword argument)"
)

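        # Optional new-style parameters; these defaults disable the fp8 KV cache and sparse attention.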
num_heads_q: Optional[int] = None
is_fp8_kvcache: bool = False
topk: Optional[int] = None
if len(args) >= 1 and "num_heads_q" not in kwargs:
num_heads_q = args[0]
args = args[1:]
elif "num_heads_q" in kwargs:
num_heads_q = kwargs.pop("num_heads_q")

if len(args) >= 1 and "is_fp8_kvcache" not in kwargs:
is_fp8_kvcache = args[0]
args = args[1:]
elif "is_fp8_kvcache" in kwargs:
is_fp8_kvcache = kwargs.pop("is_fp8_kvcache")

if len(args) >= 1 and "topk" not in kwargs:
topk = args[0]
args = args[1:]
elif "topk" in kwargs:
topk = kwargs.pop("topk")

if len(kwargs) > 0:
raise TypeError(
f"Unrecognized keyword arguments: {list(kwargs.keys())}. Supported parameters: cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q, is_fp8_kvcache, topk"
)

return flash_mla_cuda.get_mla_decoding_metadata(
cache_seqlens,
num_q_tokens_per_head_k,
num_heads_k,
num_heads_q,
is_fp8_kvcache,
topk,
)
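
For reference, a minimal usage sketch of both calling conventions. This is illustrative only: the head counts, batch size, and the `from flash_mla import get_mla_metadata` re-export are assumptions, and it requires a built `flash_mla.cuda` extension and a CUDA device.

import torch
from flash_mla import get_mla_metadata  # assumed package-level re-export of this wrapper

# Assumed decoding setup: batch of 4, 128 q heads, 1 k head, 1 q token per request.
cache_seqlens = torch.randint(1, 4096, (4,), dtype=torch.int32, device="cuda")

# New-style call: num_q_tokens_per_head_k = num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    num_q_tokens_per_head_k=1 * 128 // 1,
    num_heads_k=1,
)

# Legacy call: still accepted, but emits a DeprecationWarning and cannot pass num_heads_q, is_fp8_kvcache, or topk.
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    num_heads_per_head_k=1 * 128 // 1,
    num_heads_k=1,
)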

def flash_mla_with_kvcache(
q: torch.Tensor,