Put the computation of Q and K norm (in attn) into a single CUDA stream, and get a 5%-8% throughput improvement on Qwen3 4B and Qwen3-MoE 30B-A3B. #4005

Open · wants to merge 6 commits into base: main
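For context on the optimization itself: the aux_stream and ln_events members created in Qwen3Attention.__init__ suggest the familiar pattern of running one of the two small per-head RMSNorm kernels (Q norm and K norm) on an auxiliary CUDA stream so the two kernels can overlap instead of serializing, which is where the reported 5%-8% throughput gain would come from. The snippet below is a minimal, illustrative sketch of that pattern; the helper name overlapped_qk_norm and its call site are assumptions for illustration and are not part of this diff.

import torch

def overlapped_qk_norm(q, k, q_norm, k_norm, aux_stream, ln_events):
    """Illustrative only: overlap the per-head Q and K RMSNorm kernels."""
    main_stream = torch.cuda.current_stream()

    # The aux stream must wait until q and k have been produced on the main stream.
    ln_events[0].record(main_stream)
    aux_stream.wait_event(ln_events[0])

    q = q_norm(q)  # enqueued on the main (current) stream
    with torch.cuda.stream(aux_stream):
        k = k_norm(k)  # enqueued on the aux stream, can overlap with q_norm
        ln_events[1].record(aux_stream)

    # The main stream must not consume k before the aux-stream work has finished.
    main_stream.wait_event(ln_events[1])
    # Real code also has to manage cross-stream tensor lifetimes (e.g. Tensor.record_stream).
    return q, k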
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/models/__init__.py
@@ -13,6 +13,8 @@
Qwen2ForRewardModel)
from .modeling_qwen2vl import Qwen2_5_VLModel, Qwen2VLModel
from .modeling_qwen_moe import Qwen2MoeForCausalLM
from .modeling_qwen3 import Qwen3ForCausalLM
from .modeling_qwen3_moe import Qwen3MoeForCausalLM
from .modeling_utils import get_model_architecture
from .modeling_vila import VilaModel

@@ -35,6 +37,8 @@
"VilaModel",
"Qwen2VLModel",
"Qwen2_5_VLModel",
"Qwen3ForCausalLM",
"Qwen3MoeForCausalLM",
]

if transformers.__version__ >= "4.45.1":
241 changes: 241 additions & 0 deletions tensorrt_llm/_torch/models/modeling_qwen3.py
@@ -0,0 +1,241 @@
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import Qwen3Config

from tensorrt_llm.functional import PositionEmbeddingType

from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import TensorParallelMode
from ..modules.rms_norm import RMSNorm
from ..pipeline_interface import PipelineInterface
from .modeling_utils import DecoderModel, DecoderModelForCausalLM, register_auto_model

class Qwen3Attention(Attention):

def __init__(
self,
model_config: ModelConfig[Qwen3Config],
layer_idx: Optional[int] = None,
):
config = model_config.pretrained_config
if getattr(config, "rope_scaling", None) is not None:
pos_embd_params = PositionalEmbeddingParams(
type=PositionEmbeddingType.from_string(config.rope_scaling["type"]),
rope=RopeParams.from_config(config),
)
else:
pos_embd_params = PositionalEmbeddingParams(
type=PositionEmbeddingType.rope_gpt_neox,
rope=RopeParams.from_config(config),
)

super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
bias=config.attention_bias,
pos_embd_params=pos_embd_params,
layer_idx=layer_idx,
dtype=config.torch_dtype,
dense_bias=config.attention_bias,
config=model_config,
)

self.q_norm = RMSNorm(hidden_size=self.head_dim,
eps=1e-6,
dtype=config.torch_dtype,
has_weights=True)
self.k_norm = RMSNorm(hidden_size=self.head_dim,
eps=1e-6,
dtype=config.torch_dtype,
has_weights=True)
# Auxiliary CUDA stream and events used to overlap the Q-norm and K-norm kernels.
self.aux_stream = torch.cuda.Stream()
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

class Qwen3DecoderLayer(DecoderLayer):

def __init__(
self,
model_config: ModelConfig[Qwen3Config],
layer_idx: int,
) -> None:
super().__init__()
self.layer_idx = layer_idx
config = model_config.pretrained_config
self.self_attn = Qwen3Attention(
model_config,
layer_idx=layer_idx,
)

self.mlp = GatedMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
bias=config.mlp_bias if hasattr(config, "mlp_bias") else False,
dtype=config.torch_dtype,
config=model_config,
)
self.input_layernorm = RMSNorm(
hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
)
self.post_attention_layernorm = RMSNorm(
hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype
)

def forward(
self,
position_ids: torch.LongTensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)

# Self Attention
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
mrope_config=mrope_config,
**kwargs,
)

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)

return hidden_states, residual


class Qwen3Model(DecoderModel):

def __init__(self, model_config: ModelConfig[Qwen3Config]):
super().__init__(model_config)
config = self.model_config
self.padding_idx = config.pretrained_config.pad_token_id

self.embed_tokens = Embedding(
config.pretrained_config.vocab_size,
config.pretrained_config.hidden_size,
dtype=config.pretrained_config.torch_dtype,
mapping=config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
)
self.layers = nn.ModuleList(
[
Qwen3DecoderLayer(
model_config,
layer_idx,
)
for layer_idx in range(config.pretrained_config.num_hidden_layers)
]
)
self.norm = RMSNorm(
hidden_size=config.pretrained_config.hidden_size,
eps=config.pretrained_config.rms_norm_eps,
dtype=config.pretrained_config.torch_dtype,
)

def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

hidden_states = inputs_embeds

residual = None
for decoder_layer in self.layers:
hidden_states, residual = decoder_layer(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
mrope_config=mrope_config,
)

hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states


@register_auto_model("Qwen3ForCausalLM")
class Qwen3ForCausalLM(DecoderModelForCausalLM[Qwen3Model, Qwen3Config]):

def __init__(
self,
model_config: ModelConfig[Qwen3Config],
):
super().__init__(
Qwen3Model(model_config),
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size,
)

# NOTE: Qwen2-VL needs a special mrope_config, so a separate forward() is defined here to accept and pass through 'mrope_config'.
def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
pipeline_interface: Optional[PipelineInterface] = None,
return_context_logits: bool = False,
mrope_config: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:

if self._supports_pp and self.pp_size > 1:
output = self.model(
input_ids=input_ids,
attn_metadata=attn_metadata,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
pipeline_interface=pipeline_interface,
mrope_config=mrope_config,
)

# No need to compute logits for non-last PP ranks
if self.pp_rank < self.pp_size - 1:
return output
else:
output = self.model(
input_ids=input_ids,
attn_metadata=attn_metadata,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
mrope_config=mrope_config,
)

return self.logits_processor.forward(
output,
self.lm_head,
attn_metadata,
return_context_logits,
)
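To exercise the newly registered Qwen3ForCausalLM end to end, something along the lines of the following could be used with TensorRT-LLM's high-level LLM API; the checkpoint name and sampling settings are assumptions for illustration and are not taken from this PR.

from tensorrt_llm import LLM, SamplingParams

# Hypothetical smoke test: the checkpoint name below is an assumption, not part of this PR.
llm = LLM(model="Qwen/Qwen3-4B")
sampling_params = SamplingParams(max_tokens=32)

outputs = llm.generate(["Explain CUDA streams in one sentence."], sampling_params)
for out in outputs:
    print(out.outputs[0].text)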