
Commit b3bc1f8

vadiklyutiy authored and juliendenize committed
[PERF] Decouple projections from GDN custom op (vllm-project#27512)
Signed-off-by: Vadim Gimpelson <[email protected]>
1 parent 3dbea51 commit b3bc1f8

3 files changed: 204 additions and 53 deletions


vllm/config/compilation.py

Lines changed: 1 addition & 1 deletion
@@ -462,7 +462,7 @@ class CompilationConfig:
         "vllm::short_conv",
         "vllm::linear_attention",
         "vllm::plamo2_mamba_mixer",
-        "vllm::gdn_attention",
+        "vllm::gdn_attention_core",
         "vllm::kda_attention",
         "vllm::sparse_attn_indexer",
     ]

vllm/model_executor/layers/layernorm.py

Lines changed: 102 additions & 0 deletions
@@ -12,6 +12,7 @@
     rms_norm_batch_invariant,
     vllm_is_batch_invariant,
 )
+from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import direct_register_custom_op

@@ -369,6 +370,107 @@ def forward_cuda(
         return self.forward_native(x, residual)


+@CustomOp.register("rms_norm_gated")
+class RMSNormGated(CustomOp):
+    """RMS Normalization with optional gating.
+
+    This is a native PyTorch implementation that supports:
+    - Standard RMS normalization
+    - Group RMS normalization
+    - Optional gating with SiLU activation
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-5,
+        group_size: int | None = None,
+        norm_before_gate: bool = False,
+        device: torch.device | None = None,
+        dtype: torch.dtype | None = None,
+    ):
+        """Initialize RMSNormGated.
+
+        Args:
+            hidden_size: Size of the hidden dimension
+            eps: Epsilon for numerical stability
+            group_size: If not None, do GroupNorm with each group
+                having group_size elements.
+                group_size=None is equivalent to group_size=hidden_size
+                (i.e. there's only 1 group).
+            norm_before_gate: If True and z is provided: out = norm(x) * silu(z)
+                If False and z is provided: out = norm(x * silu(z))
+            device: Device to create parameters on
+            dtype: Data type for parameters
+        """
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
+        self.register_parameter("bias", None)
+        self.group_size = group_size
+        self.norm_before_gate = norm_before_gate
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        torch.nn.init.ones_(self.weight)
+
+    def forward_native(
+        self, x: torch.Tensor, z: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        """
+        Native PyTorch implementation of RMS normalization with gating.
+
+        Args:
+            x: Input tensor
+            z: Optional gating tensor
+
+        Returns:
+            Normalized (and optionally gated) tensor
+
+        If z is not None:
+            - norm_before_gate=True: out = norm(x) * silu(z)
+            - norm_before_gate=False: out = norm(x * silu(z))
+        """
+        # Apply gating before normalization if needed
+        if z is not None and not self.norm_before_gate:
+            x = x * F.silu(z)
+
+        # RMS Normalization
+        if self.group_size is None:
+            # Standard RMS norm across the last dimension
+            variance = x.pow(2).mean(dim=-1, keepdim=True)
+            x_normed = x * torch.rsqrt(variance + self.eps)
+            out = x_normed * self.weight
+        else:
+            # Group RMS norm
+            from einops import rearrange
+
+            x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
+            variance = x_group.pow(2).mean(dim=-1, keepdim=True)
+            x_normed = x_group * torch.rsqrt(variance + self.eps)
+            out = rearrange(x_normed, "... g d -> ... (g d)") * self.weight
+
+        # Apply gating after normalization if needed
+        if z is not None and self.norm_before_gate:
+            out = out * F.silu(z)
+
+        return out
+
+    def forward_cuda(
+        self, x: torch.Tensor, z: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        return rmsnorm_fn(
+            x,
+            self.weight,
+            self.bias,
+            z=z,
+            eps=self.eps,
+            group_size=self.group_size,
+            norm_before_gate=self.norm_before_gate,
+        )
+
+
 class LayerNorm(nn.Module):
     """
     Layer Normalization.
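For orientation, the new RMSNormGated can be exercised on its own through the native path. A minimal sketch, assuming a vLLM checkout that contains this commit; the sizes and the norm_before_gate choice below are illustrative, not taken from any model config:

import torch

from vllm.model_executor.layers.layernorm import RMSNormGated

hidden_size = 64
norm = RMSNormGated(hidden_size, eps=1e-6, norm_before_gate=False)

x = torch.randn(4, hidden_size)   # activations to normalize
z = torch.randn(4, hidden_size)   # gate input
# norm_before_gate=False: the gate is applied first, out = rms_norm(x * silu(z)) * weight
out = norm.forward_native(x, z)
print(out.shape)                  # torch.Size([4, 64])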

vllm/model_executor/models/qwen3_next.py

Lines changed: 101 additions & 52 deletions
@@ -30,12 +30,14 @@
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fla.ops import (
-    RMSNormGated,
     chunk_gated_delta_rule,
     fused_recurrent_gated_delta_rule,
 )
 from vllm.model_executor.layers.fused_moe import SharedFusedMoE
-from vllm.model_executor.layers.layernorm import GemmaRMSNorm as Qwen3NextRMSNorm
+from vllm.model_executor.layers.layernorm import (
+    GemmaRMSNorm as Qwen3NextRMSNorm,
+)
+from vllm.model_executor.layers.layernorm import RMSNormGated
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -436,17 +438,66 @@ def forward(
         hidden_states: torch.Tensor,
         output: torch.Tensor,
     ):
-        return torch.ops.vllm.gdn_attention(
-            hidden_states,
-            output,
+        """
+        Forward pass with three parts:
+        1. Input projection
+        2. Core attention (custom op)
+        3. Output projection
+        """
+        num_tokens = hidden_states.size(0)
+
+        # ============================================================
+        # Part 1: Input Projection
+        # ============================================================
+        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
+        projected_states_ba, _ = self.in_proj_ba(hidden_states)
+        query, key, value, z, b, a = self.fix_query_key_value_ordering(
+            projected_states_qkvz, projected_states_ba
+        )
+        query, key, value = map(
+            lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
+        )
+        mixed_qkv = torch.cat((query, key, value), dim=-1)
+
+        # ============================================================
+        # Part 2: Core Attention (Custom Op)
+        # ============================================================
+        core_attn_out = torch.zeros(
+            (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        torch.ops.vllm.gdn_attention_core(
+            mixed_qkv,
+            b,
+            a,
+            core_attn_out,
             self.prefix,
         )

-    def _forward(
+        # ============================================================
+        # Part 3: Output Projection
+        # ============================================================
+        z_shape_og = z.shape
+        # Reshape input data into 2D tensor
+        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
+        z = z.reshape(-1, z.shape[-1])
+        core_attn_out = self.norm(core_attn_out, z)
+        core_attn_out = core_attn_out.reshape(z_shape_og)
+        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
+        output[:num_tokens], _ = self.out_proj(core_attn_out)
+
+    def _forward_core(
         self,
-        hidden_states: torch.Tensor,
-        output: torch.Tensor,
+        mixed_qkv: torch.Tensor,
+        b: torch.Tensor,
+        a: torch.Tensor,
+        core_attn_out: torch.Tensor,
     ):
+        """
+        Core attention computation (called by custom op).
+        """
         forward_context = get_forward_context()
         attn_metadata: AttentionMetadata = forward_context.attn_metadata

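The new forward() above allocates core_attn_out itself and hands it to the custom op, which fills only the rows for tokens the attention metadata marks as actual. A toy, pure-PyTorch sketch of that caller-allocates/callee-fills contract; the sizes and the fake_core helper are made up for illustration:

import torch

num_tokens, num_heads, head_dim = 10, 2, 4   # padded token count, toy head shape
num_actual_tokens = 7                        # tokens the metadata says are real


def fake_core(mixed_qkv: torch.Tensor, out: torch.Tensor) -> None:
    # Stand-in for _forward_core: write results for the actual tokens only,
    # mutating the caller's buffer instead of returning a new tensor.
    out[:num_actual_tokens] = mixed_qkv[:num_actual_tokens]


core_attn_out = torch.zeros(num_tokens, num_heads, head_dim)  # as in Part 2 above
fake_core(torch.randn(num_tokens, num_heads, head_dim), core_attn_out)
assert torch.all(core_attn_out[num_actual_tokens:] == 0)      # padding rows stay zero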

@@ -471,18 +522,11 @@ def _forward(
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens

-        # 1. Set up dimensions for reshapes later
-        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states[:num_actual_tokens])
-        projected_states_ba, _ = self.in_proj_ba(hidden_states[:num_actual_tokens])
-        query, key, value, z, b, a = self.fix_query_key_value_ordering(
-            projected_states_qkvz, projected_states_ba
-        )
-        query, key, value = map(
-            lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
-        )
-        mixed_qkv = torch.cat((query, key, value), dim=-1)
+        mixed_qkv = mixed_qkv[:num_actual_tokens]
+        b = b[:num_actual_tokens]
+        a = a[:num_actual_tokens]

-        # 2. Convolution sequence transformation
+        # 1. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(
             self.conv1d.weight.size(0), self.conv1d.weight.size(2)
         )
@@ -498,7 +542,7 @@ def _forward(
             mixed_qkv_spec = None
             mixed_qkv_non_spec = mixed_qkv

-        # 2.1: process the mutli-query part
+        # 1.1: Process the multi-query part
         if spec_sequence_masks is not None:
             mixed_qkv_spec = causal_conv1d_update(
                 mixed_qkv_spec,
@@ -515,7 +559,7 @@ def _forward(
                 validate_data=False,
             )

-        # 2.2: process the remaining part
+        # 1.2: Process the remaining part
         if attn_metadata.num_prefills > 0:
             mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
             # - "cache_indices" updates the conv_state cache in positions
@@ -573,9 +617,9 @@ def _forward(
             g_non_spec = g
             beta_non_spec = beta

-        # 3. Recurrent attention
+        # 2. Recurrent attention

-        # 3.1: process the mutlti-query part
+        # 2.1: Process the multi-query part
         if spec_sequence_masks is not None:
             core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
                 q=query_spec,
@@ -593,7 +637,7 @@ def _forward(
         else:
             core_attn_out_spec, last_recurrent_state = None, None

-        # 3.2: process the remaining part
+        # 2.2: Process the remaining part
         if attn_metadata.num_prefills > 0:
             initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
             initial_state[~has_initial_state, ...] = 0
@@ -636,30 +680,20 @@ def _forward(
         else:
             core_attn_out_non_spec, last_recurrent_state = None, None

-        # Merge core attention output
+        # 3. Merge core attention output
         if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
-            core_attn_out = torch.empty(
+            merged_out = torch.empty(
                 (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
                 dtype=core_attn_out_non_spec.dtype,
                 device=core_attn_out_non_spec.device,
             )
-            core_attn_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
-            core_attn_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
-
+            merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
+            merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
+            core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
         elif spec_sequence_masks is not None:
-            core_attn_out = core_attn_out_spec
+            core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
         else:
-            core_attn_out = core_attn_out_non_spec
-
-        z_shape_og = z.shape
-        # reshape input data into 2D tensor
-        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
-        z = z.reshape(-1, z.shape[-1])
-        core_attn_out = self.norm(core_attn_out, z)
-        core_attn_out = core_attn_out.reshape(z_shape_og)
-        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
-
-        output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
+            core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)


 class Qwen3NextAttention(nn.Module):
@@ -1270,29 +1304,44 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         return self.model.get_expert_mapping()


-def gdn_attention(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
+def gdn_attention_core(
+    mixed_qkv: torch.Tensor,
+    b: torch.Tensor,
+    a: torch.Tensor,
+    core_attn_out: torch.Tensor,
     layer_name: str,
 ) -> None:
+    """
+    Custom op for the core attention computation.
+    Only handles the convolution + recurrent attention part.
+    Input/output projections are handled outside this op.
+    """
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self._forward(hidden_states=hidden_states, output=output)
+    self._forward_core(
+        mixed_qkv=mixed_qkv,
+        b=b,
+        a=a,
+        core_attn_out=core_attn_out,
+    )


-def gdn_attention_fake(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
+def gdn_attention_core_fake(
+    mixed_qkv: torch.Tensor,
+    b: torch.Tensor,
+    a: torch.Tensor,
+    core_attn_out: torch.Tensor,
     layer_name: str,
 ) -> None:
+    """Fake implementation for torch.compile."""
     return


 direct_register_custom_op(
-    op_name="gdn_attention",
-    op_func=gdn_attention,
-    mutates_args=["output"],
-    fake_impl=gdn_attention_fake,
+    op_name="gdn_attention_core",
+    op_func=gdn_attention_core,
+    mutates_args=["core_attn_out"],
+    fake_impl=gdn_attention_core_fake,
 )
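The registration above follows the mutated-output pattern: the op writes into core_attn_out, so the fake implementation can be a no-op because the output buffer's shape and dtype are already fixed by the caller. A self-contained sketch of the same pattern using plain torch.library (PyTorch 2.4+) rather than vLLM's direct_register_custom_op helper; the demo::toy_scale op is hypothetical:

import torch


@torch.library.custom_op("demo::toy_scale", mutates_args=("out",))
def toy_scale(x: torch.Tensor, out: torch.Tensor) -> None:
    # Eager implementation: fill the caller-provided buffer in place.
    out.copy_(x * 2.0)


@toy_scale.register_fake
def _(x: torch.Tensor, out: torch.Tensor) -> None:
    # Nothing to compute here: "out" already carries the shape and dtype.
    return


x = torch.randn(4)
out = torch.empty_like(x)
torch.ops.demo.toy_scale(x, out)   # same call style as gdn_attention_core above
print(out)                         # x * 2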
