From 13816eaee65f84027861476642e2ec1d40490b36 Mon Sep 17 00:00:00 2001 From: danielhua23 Date: Fri, 9 May 2025 10:02:45 +0000 Subject: [PATCH 1/4] mla problem --- problems/amd/mla/README.md | 46 ++++ problems/amd/mla/reference.py | 98 +++++++++ problems/amd/mla/rotary_embedding.py | 307 +++++++++++++++++++++++++++ problems/amd/mla/submission.py | 13 ++ problems/amd/mla/task.py | 17 ++ problems/amd/mla/task.yml | 29 +++ 6 files changed, 510 insertions(+) create mode 100644 problems/amd/mla/README.md create mode 100644 problems/amd/mla/reference.py create mode 100644 problems/amd/mla/rotary_embedding.py create mode 100644 problems/amd/mla/submission.py create mode 100644 problems/amd/mla/task.py create mode 100644 problems/amd/mla/task.yml diff --git a/problems/amd/mla/README.md b/problems/amd/mla/README.md new file mode 100644 index 0000000..bf75a36 --- /dev/null +++ b/problems/amd/mla/README.md @@ -0,0 +1,46 @@ +# Description + +You will implement a custom mla decode kernel optimized for MI300, a few things simplified here: + +1. Q, K, V data type as bfloat16 + +2. provide Q, K, V hidden states directly, no Q, K, V up/down projections + +3. decode only with pre-allocated non-paged latent kv cache + +4. no need to update kv cache + +5. no need to implement RoPE in mla kernel, we only show its implementation in ref kernel + +The shapes of all outer and inner dimensions of tensors are from DeepSeek-R1, and split number of heads to fit in one GPU. +To be explicit, you will be given a tuple to tensors: + +```yml +q_input [B, q_seqlen, h_q, kv_lora_rank + qk_rope_hd] +k_input [B, kv_seqlen, 1, kv_lora_rank + qk_rope_hd] +v_input [B, kv_seqlen, 1, kv_lora_rank] +attn_output [B, q_seqlen, h_q, kv_lora_rank] +``` + +where + +0. B::128 # batch size +1. kv_seqlen [1024, 6144] +2. q_seqlen:: 1 # as only consider decoding +3. qk_nope_head_dim:: 512 +4. qk_rope_hd:: 64 +5. kv_lora_rank(v_head_dim):: 512 +6. h_q:: 128 # num of q heads +7. h_kv:: 1 # as it's mla, kv head is 1 + + +The ranking criteria is the geometric mean of the benchmark results. + +For the grand price, your kernel will be evaluated against the speed of light analysis +and the solution closest to the speed of light will be awarded the grand price. 
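+
+To make the computation concrete, the sketch below is a minimal PyTorch illustration of the
+attention a decode kernel has to reproduce for the shapes above: the single latent KV head is
+shared by all 128 query heads, scores are scaled by 1/sqrt(kv_lora_rank + qk_rope_hd), and the
+accumulation is done in fp32. It is only an illustration (it ignores per-request cache_seqlens
+and uses hypothetical variable names), not the evaluation harness:
+
+```python
+import math
+import torch
+
+# Illustrative shapes from the list above (kv_seqlen fixed to 1024 here).
+B, q_seqlen, h_q, kv_lora_rank, qk_rope_hd, kv_seqlen = 128, 1, 128, 512, 64, 1024
+d = kv_lora_rank + qk_rope_hd  # 576
+
+q = torch.randn(B, q_seqlen, h_q, d, dtype=torch.bfloat16)
+k = torch.randn(B, kv_seqlen, 1, d, dtype=torch.bfloat16)
+v = torch.randn(B, kv_seqlen, 1, kv_lora_rank, dtype=torch.bfloat16)
+
+scale = 1.0 / math.sqrt(d)
+qf = q.float().transpose(1, 2)                  # [B, h_q, 1, 576]
+kf = k.float().transpose(1, 2)                  # [B, 1, kv_seqlen, 576], broadcast over heads
+vf = v.float().transpose(1, 2)                  # [B, 1, kv_seqlen, 512]
+scores = qf @ kf.transpose(-2, -1) * scale      # [B, h_q, 1, kv_seqlen]
+attn_output = (scores.softmax(dim=-1) @ vf).transpose(1, 2)  # [B, 1, h_q, 512]
+```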
+ +aiter performance for different kv_seqlen is below: +| batch | kv_seqlen | q_seqlen | dtype | aiter time(us) | +|---|---|---|---|---| +| 128 | 1024 | 1 | bf16 | 152.52 | +| 128 | 6144 | 1 | bf16 | 640.57 | \ No newline at end of file diff --git a/problems/amd/mla/reference.py b/problems/amd/mla/reference.py new file mode 100644 index 0000000..3ec9954 --- /dev/null +++ b/problems/amd/mla/reference.py @@ -0,0 +1,98 @@ +import torch +import math +import random +from task import input_t, output_t +from rotary_embedding import DeepseekScalingRotaryEmbedding +from utils import make_match_reference + +def generate_input(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, seed): + print( + f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}" + ) + + cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) + if varlen: + for i in range(b): + cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q) + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = math.ceil(max_seqlen/256) * 256 + + gen = torch.Generator() + gen.manual_seed(seed) + + q = torch.randn((b, s_q, h_q, d), dtype=torch.bfloat16, generator=gen) + k = torch.randn((b, max_seqlen_pad, h_kv, d), dtype=torch.bfloat16, generator=gen) + v = torch.randn((b, max_seqlen_pad, h_kv, dv), dtype=torch.bfloat16, generator=gen) + positions = ( + torch.tensor([s_q], device=q.device).unsqueeze(0).repeat(b, 1) + ) # only gen 1 token per req + return q, k, v, cache_seqlens, max_seqlen_pad, positions, causal + +def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value + +def ref_kernel(data: input_t, use_rope=False) -> output_t: + """ + q shape: batch_size, q_seqlen, h_q, d + k shape: batch_size, max_seqlen_pad, h_kv, d + v shape: batch_size, max_seqlen_pad, h_kv, d_v + """ + q, k, v, cache_seqlens, max_seqlen_pad, positions, causal = data + b, s_q, h_q, d = q.shape + _, _, h_kv, dv = v.shape + rope_head_dim = d - dv + rotary_dim = rope_head_dim + rope_max_seq_len=16324 + rope_base=1000 + rope_scaling=1.0 + is_neox_style=True + rotary_emb = DeepseekScalingRotaryEmbedding( + rope_head_dim, + rotary_dim, + rope_max_seq_len, + rope_base, + is_neox_style, + rope_scaling, + q.dtype, + device=q.device) + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + ik = k.view(-1, h_kv, d)[begin:end] + iv = v.view(-1, h_kv, dv)[begin:end] + iq = q[i] + if use_rope: + q_nope, q_pe = iq.split([dv, rotary_dim], dim=-1) # [s_q, h_q, d] + k_nope, k_pe = ik.split([dv, rotary_dim], dim=-1) # [s_k, h_kv, d] + q_pe, k_pe = rotary_emb(positions[i], q_pe, k_pe) + iq[..., dv:]=q_pe + ik[..., dv:]=k_pe + O = scaled_dot_product_attention( + iq.transpose(0, 1), + ik.transpose(0, 1), + iv.transpose(0, 1), + h_q=h_q, + h_kv=h_kv, + is_causal=causal, + ) + out[i] = 
O.transpose(0, 1) + return out + + +check_implementation = make_match_reference(ref_kernel) \ No newline at end of file diff --git a/problems/amd/mla/rotary_embedding.py b/problems/amd/mla/rotary_embedding.py new file mode 100644 index 0000000..c235157 --- /dev/null +++ b/problems/amd/mla/rotary_embedding.py @@ -0,0 +1,307 @@ +# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py +"""Rotary Positional Embeddings.""" +import math +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +def _rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + + return torch.cat((-x2, x1), dim=-1) + + +def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. + """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +class RotaryEmbedding(nn.Module): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. 
+ inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim( + num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> Tuple[int, int]: + low = math.floor( + _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask( + low: float, high: float, dim: int, dtype: torch.dtype, device +) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. 
github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + device: Optional[str] = "cuda", + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) + / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) + * attn_factor + ) + self.device = device + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device) + / self.rotary_dim + ) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + self.max_position_embeddings, + ) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 + - _yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, dtype=torch.float, device=self.device + ) + ) * self.extrapolation_factor + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, + device=self.device, + dtype=torch.float32, + ) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale + cache = torch.cat((cos, sin), dim=-1) + + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" + query_rot = query[..., : self.rotary_dim] + key_rot = key[..., : self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim :] + key_pass = key[..., self.rotary_dim :] + + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device) + cos_sin = self.cos_sin_cache[ + torch.add(positions, offsets) if offsets is not None else positions + ] + # (max_seq, 64). 32 sin, 32 cos + cos, sin = cos_sin.chunk(2, dim=-1) + + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. 
+ cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + query_rot = query_rot * cos + rotate_fn(query_rot) * sin + key_rot = key_rot * cos + rotate_fn(key_rot) * sin + + if self.rotary_dim < self.head_size: + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + else: + query = query_rot + key = key_rot + return query, key \ No newline at end of file diff --git a/problems/amd/mla/submission.py b/problems/amd/mla/submission.py new file mode 100644 index 0000000..f7bfcce --- /dev/null +++ b/problems/amd/mla/submission.py @@ -0,0 +1,13 @@ +import torch +from task import input_t, output_t +from reference import ref_kernel + +def custom_kernel(data: input_t) -> output_t: + """ + Reference implementation of rope + mla + Args: + data: q, k_cache, v_cache, block_table, cached_seqlens, max_seqlen_pad, rope_positions + Returns: + mla output + """ + return ref_kernel(data) \ No newline at end of file diff --git a/problems/amd/mla/task.py b/problems/amd/mla/task.py new file mode 100644 index 0000000..8db1a57 --- /dev/null +++ b/problems/amd/mla/task.py @@ -0,0 +1,17 @@ +import torch +from typing import TypeVar, TypedDict + +input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]) +output_t = TypeVar("output_t", bound=tuple[torch.Tensor, torch.Tensor]) + +class TestSpec(TypedDict): + b: int + s_q: int + mean_sk: int + h_q: int + h_kv: int + d: int + dv: int + causal: bool + var_len: bool + seed: int \ No newline at end of file diff --git a/problems/amd/mla/task.yml b/problems/amd/mla/task.yml new file mode 100644 index 0000000..2b198da --- /dev/null +++ b/problems/amd/mla/task.yml @@ -0,0 +1,29 @@ +# name: mla-py + +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "../utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "../eval.py"} + +lang: "py" + +description: | + Implement a custom mla decode that matches the reference implementation. 
+ The function should handle a tuple of input tensors and apply fused attention decode calculation + The shapes of all outer and inner dimensions of tensors are from DeepSeek-R1 + +config: + main: "eval.py" + +templates: + Python: "../template.py" + +tests: + - {"b": 128, "s_q": 1, "mean_sk": 1024, "h_q": 128, "h_kv": 1, "d": 576, "dv": 512, "causal": True, "var_len": False, "seed": 97} + - {"b": 128, "s_q": 1, "mean_sk": 6144, "h_q": 128, "h_kv": 1, "d": 576, "dv": 512, "causal": True, "var_len": False, "seed": 97} + +benchmarks: + - {"b": 128, "s_q": 1, "mean_sk": 1024, "h_q": 128, "h_kv": 1, "d": 576, "dv": 512, "causal": True, "var_len": False, "seed": 97} + - {"b": 128, "s_q": 1, "mean_sk": 6144, "h_q": 128, "h_kv": 1, "d": 576, "dv": 512, "causal": True, "var_len": False, "seed": 97} \ No newline at end of file From 3275a45e94decabe5eb2183b1940f7ea4a9b50ce Mon Sep 17 00:00:00 2001 From: danielhua23 Date: Fri, 9 May 2025 10:44:07 +0000 Subject: [PATCH 2/4] fix typo --- problems/amd/mla/submission.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/problems/amd/mla/submission.py b/problems/amd/mla/submission.py index f7bfcce..b2b2c1d 100644 --- a/problems/amd/mla/submission.py +++ b/problems/amd/mla/submission.py @@ -4,9 +4,9 @@ def custom_kernel(data: input_t) -> output_t: """ - Reference implementation of rope + mla + Reference implementation of mla without RoPE Args: - data: q, k_cache, v_cache, block_table, cached_seqlens, max_seqlen_pad, rope_positions + data: b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, seed Returns: mla output """ From c7d2183205c239590ef58bd94e315ac24343be86 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Fri, 9 May 2025 16:35:12 -0700 Subject: [PATCH 3/4] Update MLA implementation --- problems/amd/mla/README.md | 46 --- problems/amd/mla/reference.py | 421 +++++++++++++++++++++++---- problems/amd/mla/rotary_embedding.py | 307 ------------------- problems/amd/mla/submission.py | 386 +++++++++++++++++++++++- problems/amd/mla/task.py | 23 +- problems/amd/mla/task.yml | 58 +++- problems/amd/mla/template.py | 14 + 7 files changed, 826 insertions(+), 429 deletions(-) delete mode 100644 problems/amd/mla/README.md delete mode 100644 problems/amd/mla/rotary_embedding.py create mode 100644 problems/amd/mla/template.py diff --git a/problems/amd/mla/README.md b/problems/amd/mla/README.md deleted file mode 100644 index bf75a36..0000000 --- a/problems/amd/mla/README.md +++ /dev/null @@ -1,46 +0,0 @@ -# Description - -You will implement a custom mla decode kernel optimized for MI300, a few things simplified here: - -1. Q, K, V data type as bfloat16 - -2. provide Q, K, V hidden states directly, no Q, K, V up/down projections - -3. decode only with pre-allocated non-paged latent kv cache - -4. no need to update kv cache - -5. no need to implement RoPE in mla kernel, we only show its implementation in ref kernel - -The shapes of all outer and inner dimensions of tensors are from DeepSeek-R1, and split number of heads to fit in one GPU. -To be explicit, you will be given a tuple to tensors: - -```yml -q_input [B, q_seqlen, h_q, kv_lora_rank + qk_rope_hd] -k_input [B, kv_seqlen, 1, kv_lora_rank + qk_rope_hd] -v_input [B, kv_seqlen, 1, kv_lora_rank] -attn_output [B, q_seqlen, h_q, kv_lora_rank] -``` - -where - -0. B::128 # batch size -1. kv_seqlen [1024, 6144] -2. q_seqlen:: 1 # as only consider decoding -3. qk_nope_head_dim:: 512 -4. qk_rope_hd:: 64 -5. kv_lora_rank(v_head_dim):: 512 -6. h_q:: 128 # num of q heads -7. 
h_kv:: 1 # as it's mla, kv head is 1 - - -The ranking criteria is the geometric mean of the benchmark results. - -For the grand price, your kernel will be evaluated against the speed of light analysis -and the solution closest to the speed of light will be awarded the grand price. - -aiter performance for different kv_seqlen is below: -| batch | kv_seqlen | q_seqlen | dtype | aiter time(us) | -|---|---|---|---|---| -| 128 | 1024 | 1 | bf16 | 152.52 | -| 128 | 6144 | 1 | bf16 | 640.57 | \ No newline at end of file diff --git a/problems/amd/mla/reference.py b/problems/amd/mla/reference.py index 3ec9954..a7662ca 100644 --- a/problems/amd/mla/reference.py +++ b/problems/amd/mla/reference.py @@ -1,98 +1,415 @@ import torch +import torch.nn as nn import math import random from task import input_t, output_t -from rotary_embedding import DeepseekScalingRotaryEmbedding from utils import make_match_reference +from typing import Optional, Tuple, Union -def generate_input(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, seed): - print( - f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}" + +# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py +## We provide the implementation of the rotary embedding here, you do not need to modify this section +def _rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + + return torch.cat((-x2, x1), dim=-1) + + +def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. + """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +class RotaryEmbedding(nn.Module): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. 
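+        # What follows is the standard RoPE inverse-frequency formula:
+        #   inv_freq[i] = base ** (-2 * i / rotary_dim) for i = 0 .. rotary_dim // 2 - 1,
+        # so low-index dimensions rotate fastest and high-index dimensions slowest.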
+ inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim( + num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) ) - cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) - if varlen: - for i in range(b): - cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q) + +# Find dim range bounds based on rotations +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> Tuple[int, int]: + low = math.floor( + _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask( + low: float, high: float, dim: int, dtype: torch.dtype, device +) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. 
github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + device: Optional[str] = "cuda", + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) + / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) + * attn_factor + ) + self.device = device + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device) + / self.rotary_dim + ) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + self.max_position_embeddings, + ) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 + - _yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, dtype=torch.float, device=self.device + ) + ) * self.extrapolation_factor + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, + device=self.device, + dtype=torch.float32, + ) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale + cache = torch.cat((cos, sin), dim=-1) + + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" + query_rot = query[..., : self.rotary_dim] + key_rot = key[..., : self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim :] + key_pass = key[..., self.rotary_dim :] + + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device) + cos_sin = self.cos_sin_cache[ + torch.add(positions, offsets) if offsets is not None else positions + ] + # (max_seq, 64). 32 sin, 32 cos + cos, sin = cos_sin.chunk(2, dim=-1) + + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. 
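+            # cos/sin hold rotary_dim // 2 values per position; tiling them once
+            # lets both halves of the rotary block (the [x1; x2] split used by
+            # _rotate_neox) see the same per-pair frequency.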
+ cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + query_rot = query_rot * cos + rotate_fn(query_rot) * sin + key_rot = key_rot * cos + rotate_fn(key_rot) * sin + + if self.rotary_dim < self.head_size: + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + else: + query = query_rot + key = key_rot + return query, key +## End of the implementation of the rotary embedding + + +def generate_input(b, d, dv, hq, sq, hkv, meansk, seed): + + cache_seqlens = torch.full((b,), meansk, dtype=torch.int32) max_seqlen = cache_seqlens.max().item() max_seqlen_pad = math.ceil(max_seqlen/256) * 256 gen = torch.Generator() gen.manual_seed(seed) - - q = torch.randn((b, s_q, h_q, d), dtype=torch.bfloat16, generator=gen) - k = torch.randn((b, max_seqlen_pad, h_kv, d), dtype=torch.bfloat16, generator=gen) - v = torch.randn((b, max_seqlen_pad, h_kv, dv), dtype=torch.bfloat16, generator=gen) - positions = ( - torch.tensor([s_q], device=q.device).unsqueeze(0).repeat(b, 1) - ) # only gen 1 token per req - return q, k, v, cache_seqlens, max_seqlen_pad, positions, causal - -def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): - query = query.float() - key = key.float() - value = value.float() - key = key.repeat_interleave(h_q // h_kv, dim=0) - value = value.repeat_interleave(h_q // h_kv, dim=0) - attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + + q = torch.randn((b, sq, hq, d), dtype=torch.bfloat16, generator=gen) + k = torch.randn((b, max_seqlen_pad, hkv, d), dtype=torch.bfloat16, generator=gen) + v = torch.randn((b, max_seqlen_pad, hkv, dv), dtype=torch.bfloat16, generator=gen) + positions = torch.tensor([sq], device=q.device).unsqueeze(0).repeat(b, 1) # only gen 1 token per req + return q, k, v, cache_seqlens, max_seqlen_pad, positions + +def scaled_dot_product_attention(query, key, value, hq, hkv, is_causal=False): + # Convert to higher precision for numerical stability + query = query.to(torch.float32) + key = key.to(torch.float32) + value = value.to(torch.float32) + + # Multi-query attention pattern: repeat the keys and values for each query head + key = key.repeat_interleave(hq // hkv, dim=0) + value = value.repeat_interleave(hq // hkv, dim=0) + + # Scale dot product attention + scale = 1.0 / math.sqrt(query.size(-1)) + attn_weight = torch.matmul(query, key.transpose(-2, -1)) * scale + + # Apply causal mask if needed if is_causal: - s_q = query.shape[-2] - s_k = key.shape[-2] - attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) - temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + sq = query.shape[-2] + sk = key.shape[-2] + attn_bias = torch.zeros(sq, sk, dtype=torch.float32, device=query.device) + temp_mask = torch.ones(sq, sk, dtype=torch.bool, device=query.device).tril(diagonal=sk - sq) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) - attn_bias.to(query.dtype) attn_weight += attn_bias - attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) - return attn_weight @ value + + # Apply softmax for attention weights + attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1) + + # Apply attention to values + out = torch.matmul(attn_weight, value) + + return out def ref_kernel(data: input_t, use_rope=False) -> output_t: """ - q shape: 
batch_size, q_seqlen, h_q, d - k shape: batch_size, max_seqlen_pad, h_kv, d - v shape: batch_size, max_seqlen_pad, h_kv, d_v + q shape: batch_size, q_seqlen, hq, d + k shape: batch_size, max_seqlen_pad, hkv, d + v shape: batch_size, max_seqlen_pad, hkv, d_v + cache_seqlens: tensor containing the actual sequence lengths + max_seqlen_pad: the padded sequence length + positions: tensor containing position information for RoPE """ - q, k, v, cache_seqlens, max_seqlen_pad, positions, causal = data - b, s_q, h_q, d = q.shape - _, _, h_kv, dv = v.shape - rope_head_dim = d - dv - rotary_dim = rope_head_dim + q, k, v, cache_seqlens, max_seqlen_pad, positions = data + causal = False # Default value as it's not provided in the input + b, sq, hq, d = q.shape + _, _, hkv, dv = v.shape + rope_head_dim = d - dv + rotary_dim = rope_head_dim rope_max_seq_len=16324 rope_base=1000 rope_scaling=1.0 is_neox_style=True rotary_emb = DeepseekScalingRotaryEmbedding( - rope_head_dim, - rotary_dim, - rope_max_seq_len, - rope_base, + rope_head_dim, + rotary_dim, + rope_max_seq_len, + rope_base, is_neox_style, rope_scaling, - q.dtype, + q.dtype, device=q.device) - out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + out = torch.empty(b, sq, hq, dv, dtype=torch.float32) for i in range(b): begin = i * max_seqlen_pad end = begin + cache_seqlens[i] - ik = k.view(-1, h_kv, d)[begin:end] - iv = v.view(-1, h_kv, dv)[begin:end] + ik = k.view(-1, hkv, d)[begin:end] + iv = v.view(-1, hkv, dv)[begin:end] iq = q[i] if use_rope: - q_nope, q_pe = iq.split([dv, rotary_dim], dim=-1) # [s_q, h_q, d] - k_nope, k_pe = ik.split([dv, rotary_dim], dim=-1) # [s_k, h_kv, d] + q_nope, q_pe = iq.split([dv, rotary_dim], dim=-1) # [sq, hq, d] + k_nope, k_pe = ik.split([dv, rotary_dim], dim=-1) # [sk, hkv, d] q_pe, k_pe = rotary_emb(positions[i], q_pe, k_pe) iq[..., dv:]=q_pe - ik[..., dv:]=k_pe + ik[..., dv:]=k_pe O = scaled_dot_product_attention( iq.transpose(0, 1), ik.transpose(0, 1), iv.transpose(0, 1), - h_q=h_q, - h_kv=h_kv, + hq=hq, + hkv=hkv, is_causal=causal, ) out[i] = O.transpose(0, 1) return out + + check_implementation = make_match_reference(ref_kernel) \ No newline at end of file diff --git a/problems/amd/mla/rotary_embedding.py b/problems/amd/mla/rotary_embedding.py deleted file mode 100644 index c235157..0000000 --- a/problems/amd/mla/rotary_embedding.py +++ /dev/null @@ -1,307 +0,0 @@ -# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py -"""Rotary Positional Embeddings.""" -import math -from typing import Optional, Tuple, Union - -import torch -import torch.nn as nn - -def _rotate_neox(x: torch.Tensor) -> torch.Tensor: - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - - return torch.cat((-x2, x1), dim=-1) - - -def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: - x1 = x[..., ::2] - x2 = x[..., 1::2] - x = torch.stack((-x2, x1), dim=-1) - return x.flatten(-2) - - -def _apply_rotary_emb( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - is_neox_style: bool, -) -> torch.Tensor: - """ - Args: - x: [num_tokens, num_heads, head_size] - cos: [num_tokens, head_size // 2] - sin: [num_tokens, head_size // 2] - is_neox_style: Whether to use the Neox-style or GPT-J-style rotary - positional embeddings. 
- """ - cos = cos.unsqueeze(-2).to(x.dtype) - sin = sin.unsqueeze(-2).to(x.dtype) - if is_neox_style: - x1, x2 = torch.chunk(x, 2, dim=-1) - else: - x1 = x[..., ::2] - x2 = x[..., 1::2] - o1 = x1 * cos - x2 * sin - o2 = x2 * cos + x1 * sin - if is_neox_style: - return torch.cat((o1, o2), dim=-1) - else: - return torch.stack((o1, o2), dim=-1).flatten(-2) - - -class RotaryEmbedding(nn.Module): - """Original rotary positional embedding.""" - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: int, - is_neox_style: bool, - dtype: torch.dtype, - ) -> None: - super().__init__() - self.head_size = head_size - self.rotary_dim = rotary_dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.is_neox_style = is_neox_style - self.dtype = dtype - - cache = self._compute_cos_sin_cache() - cache = cache.to(dtype) - self.cos_sin_cache: torch.Tensor - self.register_buffer("cos_sin_cache", cache, persistent=False) - - def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: - """Compute the inverse frequency.""" - # NOTE(woosuk): To exactly match the HF implementation, we need to - # use CPU to compute the cache and then move it to GPU. However, we - # create the cache on GPU for faster initialization. This may cause - # a slight numerical difference between the HF implementation and ours. - inv_freq = 1.0 / ( - base - ** ( - torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim - ) - ) - return inv_freq - - def _compute_cos_sin_cache(self) -> torch.Tensor: - """Compute the cos and sin cache.""" - inv_freq = self._compute_inv_freq(self.base) - t = torch.arange(self.max_position_embeddings, dtype=torch.float) - - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() - sin = freqs.sin() - cache = torch.cat((cos, sin), dim=-1) - return cache - - def forward_native( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """A PyTorch-native implementation of forward().""" - if offsets is not None: - positions = positions + offsets - positions = positions.flatten() - num_tokens = positions.shape[0] - cos_sin = self.cos_sin_cache.index_select(0, positions) - cos, sin = cos_sin.chunk(2, dim=-1) - - query_shape = query.shape - query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., : self.rotary_dim] - query_pass = query[..., self.rotary_dim :] - query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) - query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) - - key_shape = key.shape - key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., : self.rotary_dim] - key_pass = key[..., self.rotary_dim :] - key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) - return query, key - - def extra_repr(self) -> str: - s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" - s += f", max_position_embeddings={self.max_position_embeddings}" - s += f", base={self.base}, is_neox_style={self.is_neox_style}" - return s - - -# Inverse dim formula to find dim based on number of rotations -def _yarn_find_correction_dim( - num_rotations: int, - dim: int, - base: float = 10000, - max_position_embeddings: int = 2048, -) -> float: - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( - 2 * math.log(base) - ) - - -# Find dim 
range bounds based on rotations -def _yarn_find_correction_range( - low_rot: int, - high_rot: int, - dim: int, - base: float = 10000, - max_position_embeddings: int = 2048, -) -> Tuple[int, int]: - low = math.floor( - _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) - ) - high = math.ceil( - _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) - ) - return max(low, 0), min(high, dim - 1) # Clamp values just in case - - -def _yarn_linear_ramp_mask( - low: float, high: float, dim: int, dtype: torch.dtype, device -) -> torch.Tensor: - if low == high: - high += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - -def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: - if scale <= 1: - return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 - - -class DeepseekScalingRotaryEmbedding(RotaryEmbedding): - """RotaryEmbedding extended with YaRN method. - - Credits to Peng et al. github.com/jquesnelle/yarn - """ - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: int, - is_neox_style: bool, - scaling_factor: float, - dtype: torch.dtype, - *, - extrapolation_factor: float = 1, - attn_factor: float = 1, - beta_fast: int = 32, - beta_slow: int = 1, - mscale: float = 1, - mscale_all_dim: float = 0, - device: Optional[str] = "cuda", - ) -> None: - self.scaling_factor = scaling_factor - self.extrapolation_factor = extrapolation_factor - self.attn_factor = attn_factor - self.beta_fast = beta_fast - self.beta_slow = beta_slow - # Get n-d magnitude scaling corrected for interpolation. - self.mscale = float( - yarn_get_mscale(self.scaling_factor, float(mscale)) - / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) - * attn_factor - ) - self.device = device - super().__init__( - head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype - ) - - def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: - pos_freqs = self.base ** ( - torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device) - / self.rotary_dim - ) - inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) - - low, high = _yarn_find_correction_range( - self.beta_fast, - self.beta_slow, - self.rotary_dim, - self.base, - self.max_position_embeddings, - ) - # Get n-d rotational scaling corrected for extrapolation - inv_freq_mask = ( - 1 - - _yarn_linear_ramp_mask( - low, high, self.rotary_dim // 2, dtype=torch.float, device=self.device - ) - ) * self.extrapolation_factor - inv_freq = ( - inv_freq_interpolation * (1 - inv_freq_mask) - + inv_freq_extrapolation * inv_freq_mask - ) - return inv_freq - - def _compute_cos_sin_cache(self) -> torch.Tensor: - inv_freq = self._compute_inv_freq(self.scaling_factor) - t = torch.arange( - self.max_position_embeddings * self.scaling_factor, - device=self.device, - dtype=torch.float32, - ) - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() * self.mscale - sin = freqs.sin() * self.mscale - cache = torch.cat((cos, sin), dim=-1) - - return cache - - def forward( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """PyTorch-native implementation equivalent to forward().""" - query_rot = query[..., : self.rotary_dim] - key_rot = key[..., : self.rotary_dim] - if 
self.rotary_dim < self.head_size: - query_pass = query[..., self.rotary_dim :] - key_pass = key[..., self.rotary_dim :] - - self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device) - cos_sin = self.cos_sin_cache[ - torch.add(positions, offsets) if offsets is not None else positions - ] - # (max_seq, 64). 32 sin, 32 cos - cos, sin = cos_sin.chunk(2, dim=-1) - - if self.is_neox_style: - # NOTE(woosuk): Here we assume that the positions tensor has the - # shape [batch_size, seq_len]. - cos = cos.repeat(1, 1, 2).unsqueeze(-2) - sin = sin.repeat(1, 1, 2).unsqueeze(-2) - else: - cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) - sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) - - rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj - query_rot = query_rot * cos + rotate_fn(query_rot) * sin - key_rot = key_rot * cos + rotate_fn(key_rot) * sin - - if self.rotary_dim < self.head_size: - query = torch.cat((query_rot, query_pass), dim=-1) - key = torch.cat((key_rot, key_pass), dim=-1) - else: - query = query_rot - key = key_rot - return query, key \ No newline at end of file diff --git a/problems/amd/mla/submission.py b/problems/amd/mla/submission.py index b2b2c1d..58247b7 100644 --- a/problems/amd/mla/submission.py +++ b/problems/amd/mla/submission.py @@ -1,13 +1,391 @@ import torch +import torch.nn as nn from task import input_t, output_t -from reference import ref_kernel +import math +from typing import Optional, Tuple, Union -def custom_kernel(data: input_t) -> output_t: +# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py +## We provide the implementation of the rotary embedding here, you do not need to modify this section +def _rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + + return torch.cat((-x2, x1), dim=-1) + + +def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. 
+ """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +class RotaryEmbedding(nn.Module): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim( + num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim 
range bounds based on rotations +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> Tuple[int, int]: + low = math.floor( + _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask( + low: float, high: float, dim: int, dtype: torch.dtype, device +) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + device: Optional[str] = "cuda", + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) + / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) + * attn_factor + ) + self.device = device + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device) + / self.rotary_dim + ) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + self.max_position_embeddings, + ) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 + - _yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, dtype=torch.float, device=self.device + ) + ) * self.extrapolation_factor + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, + device=self.device, + dtype=torch.float32, + ) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale + cache = torch.cat((cos, sin), dim=-1) + + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" + query_rot = query[..., : self.rotary_dim] + key_rot = key[..., : self.rotary_dim] + if 
self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim :] + key_pass = key[..., self.rotary_dim :] + + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device) + cos_sin = self.cos_sin_cache[ + torch.add(positions, offsets) if offsets is not None else positions + ] + # (max_seq, 64). 32 sin, 32 cos + cos, sin = cos_sin.chunk(2, dim=-1) + + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. + cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + query_rot = query_rot * cos + rotate_fn(query_rot) * sin + key_rot = key_rot * cos + rotate_fn(key_rot) * sin + + if self.rotary_dim < self.head_size: + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + else: + query = query_rot + key = key_rot + return query, key +## End of the implementation of the rotary embedding + +def scaled_dot_product_attention(query, key, value, hq, hkv, is_causal=False): + # Convert to higher precision for numerical stability + query = query.to(torch.float32) + key = key.to(torch.float32) + value = value.to(torch.float32) + + # Multi-query attention pattern: repeat the keys and values for each query head + key = key.repeat_interleave(hq // hkv, dim=0) + value = value.repeat_interleave(hq // hkv, dim=0) + + # Scale dot product attention + scale = 1.0 / math.sqrt(query.size(-1)) + attn_weight = torch.matmul(query, key.transpose(-2, -1)) * scale + + # Apply causal mask if needed + if is_causal: + sq = query.shape[-2] + sk = key.shape[-2] + attn_bias = torch.zeros(sq, sk, dtype=torch.float32, device=query.device) + temp_mask = torch.ones(sq, sk, dtype=torch.bool, device=query.device).tril(diagonal=sk - sq) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_weight += attn_bias + + # Apply softmax for attention weights + attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1) + + # Apply attention to values + out = torch.matmul(attn_weight, value) + + return out + +def custom_kernel(data: input_t, use_rope=False) -> output_t: """ Reference implementation of mla without RoPE Args: - data: b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, seed + data: q, k, v, cache_seqlens, max_seqlen_pad, positions Returns: mla output """ - return ref_kernel(data) \ No newline at end of file + + q, k, v, cache_seqlens, max_seqlen_pad, positions = data + causal = False # Default value as it's not provided in the input + b, sq, hq, d = q.shape + _, _, hkv, dv = v.shape + rope_head_dim = d - dv + rotary_dim = rope_head_dim + rope_max_seq_len=16324 + rope_base=1000 + rope_scaling=1.0 + is_neox_style=True + rotary_emb = DeepseekScalingRotaryEmbedding( + rope_head_dim, + rotary_dim, + rope_max_seq_len, + rope_base, + is_neox_style, + rope_scaling, + q.dtype, + device=q.device) + out = torch.empty(b, sq, hq, dv, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + ik = k.view(-1, hkv, d)[begin:end] + iv = v.view(-1, hkv, dv)[begin:end] + iq = q[i] + if use_rope: + q_nope, q_pe = iq.split([dv, rotary_dim], dim=-1) # [sq, hq, d] + k_nope, k_pe = ik.split([dv, rotary_dim], dim=-1) # [sk, hkv, d] + q_pe, k_pe = rotary_emb(positions[i], q_pe, k_pe) + iq[..., dv:]=q_pe + ik[..., dv:]=k_pe + O = 
scaled_dot_product_attention(
+        iq.transpose(0, 1),
+        ik.transpose(0, 1),
+        iv.transpose(0, 1),
+        hq=hq,
+        hkv=hkv,
+        is_causal=causal,
+    )
+    out[i] = O.transpose(0, 1)
+    return out
\ No newline at end of file
diff --git a/problems/amd/mla/task.py b/problems/amd/mla/task.py
index 8db1a57..e83654f 100644
--- a/problems/amd/mla/task.py
+++ b/problems/amd/mla/task.py
@@ -1,17 +1,16 @@
 import torch
 from typing import TypeVar, TypedDict
 
-input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int])
-output_t = TypeVar("output_t", bound=tuple[torch.Tensor, torch.Tensor])
+input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, torch.Tensor])
+output_t = TypeVar("output_t", bound=torch.Tensor)
 
+# Define test spec with parameters in the same order as in task.yml
 class TestSpec(TypedDict):
-    b: int
-    s_q: int
-    mean_sk: int
-    h_q: int
-    h_kv: int
-    d: int
-    dv: int
-    causal: bool
-    var_len: bool
-    seed: int
\ No newline at end of file
+    b: int       # batch size
+    d: int       # query/key head dim (kv_lora_rank + qk_rope_hd)
+    dv: int      # value head dim (kv_lora_rank)
+    hq: int      # number of query heads
+    sq: int      # query sequence length
+    hkv: int     # number of key/value heads
+    meansk: int  # mean kv sequence length
+    seed: int    # random seed
\ No newline at end of file
diff --git a/problems/amd/mla/task.yml b/problems/amd/mla/task.yml
index 2b198da..49b93e4 100644
--- a/problems/amd/mla/task.yml
+++ b/problems/amd/mla/task.yml
@@ -10,20 +10,62 @@ files:
 lang: "py"
 
 description: |
-  Implement a custom mla decode that matches the reference implementation.
-  The function should handle a tuple of input tensors and apply fused attention decode calculation
-  The shapes of all outer and inner dimensions of tensors are from DeepSeek-R1
+  You will implement a custom mla decode kernel optimized for MI300. A few things are simplified here:
+
+  1. Q, K, V data type is bfloat16
+  2. Q, K, V hidden states are provided directly; no Q, K, V up/down projections
+  3. decode only, with a pre-allocated non-paged latent kv cache
+  4. no need to update the kv cache
+  5. no need to implement RoPE in the mla kernel; we only show its implementation in the reference kernel
+
+  The shapes of all outer and inner tensor dimensions come from DeepSeek-R1, with the number of heads split to fit on one GPU.
+  To be explicit, you will be given a tuple of tensors:
+
+  ```yaml
+  q_input [B, q_seqlen, h_q, kv_lora_rank + qk_rope_hd]
+  k_input [B, kv_seqlen, 1, kv_lora_rank + qk_rope_hd]
+  v_input [B, kv_seqlen, 1, kv_lora_rank]
+  attn_output [B, q_seqlen, h_q, kv_lora_rank]
+  ```
+
+  where
+
+  0. B:: 128 # batch size
+  1. kv_seqlen:: [1024, 6144]
+  2. q_seqlen:: 1 # as we only consider decoding
+  3. qk_nope_head_dim:: 512
+  4. qk_rope_hd:: 64
+  5. kv_lora_rank (v_head_dim):: 512
+  6. h_q:: 128 # num of q heads
+  7. h_kv:: 1 # as it's mla, there is only 1 kv head
+
+  The ranking criterion is the geometric mean of the benchmark results.
+
+  For the grand prize, your kernel will be evaluated against a speed-of-light analysis,
+  and the solution closest to the speed of light will be awarded the grand prize.
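+
+  For orientation only, here is a minimal, unoptimized PyTorch sketch of the per-batch decode
+  math, assuming the shapes above and ignoring RoPE and cache padding. `naive_decode` is just an
+  illustrative helper, not part of the task interface; your kernel only has to match the
+  reference output, not this structure:
+
+  ```python
+  import math
+  import torch
+
+  def naive_decode(q, k, v):
+      # q: [h_q, d] (q_seqlen dim squeezed), k: [kv_seqlen, 1, d], v: [kv_seqlen, 1, dv]
+      # with d = kv_lora_rank + qk_rope_hd
+      k2d, v2d = k.squeeze(1).float(), v.squeeze(1).float()  # drop the single kv head
+      scores = q.float() @ k2d.T / math.sqrt(q.size(-1))     # [h_q, kv_seqlen]
+      probs = torch.softmax(scores, dim=-1)                   # no causal mask for 1-token decode
+      return probs @ v2d                                      # [h_q, kv_lora_rank]
+  ```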
+
+  aiter performance for different kv_seqlen values is listed below:
+  | batch | kv_seqlen | q_seqlen | dtype | aiter time(us) |
+  |---|---|---|---|---|
+  | 128 | 1024 | 1 | bf16 | 152.52 |
+  | 128 | 6144 | 1 | bf16 | 640.57 |
 
 config:
   main: "eval.py"
 
 templates:
-  Python: "../template.py"
+  Python: "template.py"
+
+test_timeout: 900
+benchmark_timeout: 900
+ranked_timeout: 1200
 
 tests:
-  - {"b": 128, "s_q": 1, "mean_sk": 1024, "h_q": 128, "h_kv": 1, "d": 576, "dv": 512, "causal": True, "var_len": False, "seed": 97}
-  - {"b": 128, "s_q": 1, "mean_sk": 6144, "h_q": 128, "h_kv": 1, "d": 576, "dv": 512, "causal": True, "var_len": False, "seed": 97}
+  - {"b": 128, "d": 576, "dv": 512, "hq": 128, "sq": 1, "hkv": 1, "meansk": 1024, "seed": 97}
+  - {"b": 128, "d": 576, "dv": 512, "hq": 128, "sq": 1, "hkv": 1, "meansk": 6144, "seed": 97}
 
 benchmarks:
-  - {"b": 128, "s_q": 1, "mean_sk": 1024, "h_q": 128, "h_kv": 1, "d": 576, "dv": 512, "causal": True, "var_len": False, "seed": 97}
-  - {"b": 128, "s_q": 1, "mean_sk": 6144, "h_q": 128, "h_kv": 1, "d": 576, "dv": 512, "causal": True, "var_len": False, "seed": 97}
\ No newline at end of file
+  - {"b": 128, "d": 576, "dv": 512, "hq": 128, "sq": 1, "hkv": 1, "meansk": 1024, "seed": 97}
+  - {"b": 128, "d": 576, "dv": 512, "hq": 128, "sq": 1, "hkv": 1, "meansk": 6144, "seed": 97}
+
+ranking_by: "geom"
\ No newline at end of file
diff --git a/problems/amd/mla/template.py b/problems/amd/mla/template.py
new file mode 100644
index 0000000..e5abde3
--- /dev/null
+++ b/problems/amd/mla/template.py
@@ -0,0 +1,14 @@
+from task import input_t, output_t
+
+
+def custom_kernel(data: input_t) -> output_t:
+    """
+    Custom mla decode kernel.
+    Args:
+        data: tuple of (q, k, v, cache_seqlens, max_seqlen_pad, positions)
+
+    Returns: attention output tensor
+    """
+    q, k, v, cache_seqlens, max_seqlen_pad, positions = data
+    # implement the optimized mla decode here
+    raise NotImplementedError

From b97035fca58db743dd74a9ffe7d179251f6d3f8e Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Fri, 9 May 2025 16:53:36 -0700
Subject: [PATCH 4/4] minor fixes

---
 problems/amd/identity/template.py | 14 --------------
 problems/amd/mla/reference.py     | 14 ++------------
 problems/amd/mla/submission.py    | 14 ++------------
 3 files changed, 4 insertions(+), 38 deletions(-)
 delete mode 100644 problems/amd/identity/template.py

diff --git a/problems/amd/identity/template.py b/problems/amd/identity/template.py
deleted file mode 100644
index e5abde3..0000000
--- a/problems/amd/identity/template.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from task import input_t, output_t
-
-
-def custom_kernel(data: input_t) -> output_t:
-    """
-    Copies the contents of `input` into `output`
-    Args:
-        data: tuple of (input, output) tensors
-
-    Returns: output tensor
-    """
-    input, output = data
-    # implement processing
-    return output
diff --git a/problems/amd/mla/reference.py b/problems/amd/mla/reference.py
index a7662ca..d67b4b0 100644
--- a/problems/amd/mla/reference.py
+++ b/problems/amd/mla/reference.py
@@ -327,16 +327,9 @@ def generate_input(b, d, dv, hq, sq, hkv, meansk, seed):
     return q, k, v, cache_seqlens, max_seqlen_pad, positions
 
 def scaled_dot_product_attention(query, key, value, hq, hkv, is_causal=False):
-    # Convert to higher precision for numerical stability
-    query = query.to(torch.float32)
-    key = key.to(torch.float32)
-    value = value.to(torch.float32)
-
-    # Multi-query attention pattern: repeat the keys and values for each query head
     key = key.repeat_interleave(hq // hkv, dim=0)
     value = value.repeat_interleave(hq // hkv, dim=0)
 
-    # Scale dot product attention
     scale = 1.0 / math.sqrt(query.size(-1))
attn_weight = torch.matmul(query, key.transpose(-2, -1)) * scale @@ -349,10 +342,7 @@ def scaled_dot_product_attention(query, key, value, hq, hkv, is_causal=False): attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_weight += attn_bias - # Apply softmax for attention weights attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1) - - # Apply attention to values out = torch.matmul(attn_weight, value) return out @@ -367,7 +357,7 @@ def ref_kernel(data: input_t, use_rope=False) -> output_t: positions: tensor containing position information for RoPE """ q, k, v, cache_seqlens, max_seqlen_pad, positions = data - causal = False # Default value as it's not provided in the input + causal = False b, sq, hq, d = q.shape _, _, hkv, dv = v.shape rope_head_dim = d - dv @@ -385,7 +375,7 @@ def ref_kernel(data: input_t, use_rope=False) -> output_t: rope_scaling, q.dtype, device=q.device) - out = torch.empty(b, sq, hq, dv, dtype=torch.float32) + out = torch.empty(b, sq, hq, dv, dtype=q.dtype) for i in range(b): begin = i * max_seqlen_pad end = begin + cache_seqlens[i] diff --git a/problems/amd/mla/submission.py b/problems/amd/mla/submission.py index 58247b7..6f280bc 100644 --- a/problems/amd/mla/submission.py +++ b/problems/amd/mla/submission.py @@ -308,16 +308,9 @@ def forward( ## End of the implementation of the rotary embedding def scaled_dot_product_attention(query, key, value, hq, hkv, is_causal=False): - # Convert to higher precision for numerical stability - query = query.to(torch.float32) - key = key.to(torch.float32) - value = value.to(torch.float32) - - # Multi-query attention pattern: repeat the keys and values for each query head key = key.repeat_interleave(hq // hkv, dim=0) value = value.repeat_interleave(hq // hkv, dim=0) - # Scale dot product attention scale = 1.0 / math.sqrt(query.size(-1)) attn_weight = torch.matmul(query, key.transpose(-2, -1)) * scale @@ -330,10 +323,7 @@ def scaled_dot_product_attention(query, key, value, hq, hkv, is_causal=False): attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_weight += attn_bias - # Apply softmax for attention weights attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1) - - # Apply attention to values out = torch.matmul(attn_weight, value) return out @@ -348,7 +338,7 @@ def custom_kernel(data: input_t, use_rope=False) -> output_t: """ q, k, v, cache_seqlens, max_seqlen_pad, positions = data - causal = False # Default value as it's not provided in the input + causal = False b, sq, hq, d = q.shape _, _, hkv, dv = v.shape rope_head_dim = d - dv @@ -366,7 +356,7 @@ def custom_kernel(data: input_t, use_rope=False) -> output_t: rope_scaling, q.dtype, device=q.device) - out = torch.empty(b, sq, hq, dv, dtype=torch.float32) + out = torch.empty(b, sq, hq, dv, dtype=q.dtype) for i in range(b): begin = i * max_seqlen_pad end = begin + cache_seqlens[i]