Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -109,25 +109,44 @@ impl Qwen3AttentionLayer {
} else {
2usize
};
// Phase A: per-token residual_add + post-attention RMS norm.
// Lays out `norm_output[0..n]` as a contiguous [N, h] MoE input.
for i in 0..n {
let hidden_i = hidden.offset(i * h * residual_elem);
let o_out_i = o_out.offset(i * h * bf16); // BF16 attn output
let residual_i = residual.offset(i * h * residual_elem);
let normed2 = fwd.buffers.norm_output().offset(i * h * bf16);
let normed2_i = fwd.buffers.norm_output().offset(i * h * bf16);
ops::residual_add_rms_norm(
fwd.gpu,
self.residual_add_rms_norm_k,
hidden_i,
o_out_i,
&self.post_attn_norm,
normed2,
normed2_i,
residual_i,
1,
h as u32,
eps,
stream,
)?;
let moe_out = self.ffn.forward(normed2, fwd, stream)?;
}
// Phase B+C: per-token MoE + residual. The generic grouped-GEMM
// (forward_prefill) is a NET LOSS for this 256-expert MoE at
// small batch — per-expert M is ~1, so the sort/permute/ptr-table
// overhead dominates (measured at N=4 on GB10: attention block
// ~40ms with grouped-GEMM vs ~20ms with the per-token path).
// N=2/3 already take the fused forward_k2/k3 branches above; this
// `else` only sees N>=4 (or MLA, which must avoid the batched-MoE
// kernels anyway), so the per-token path — identical to decode()'s
// MoE — is fastest here until a true batched-EP MoE kernel exists.
// Mirrors the SSM dispatch in qwen3_ssm/trait_decode_multi_seq.rs.
// Each forward() writes moe_output[0]; consume it immediately,
// before the next iteration overwrites it.
let normed_base = fwd.buffers.norm_output();
for i in 0..n {
let hidden_i = hidden.offset(i * h * residual_elem);
let normed2_i = normed_base.offset(i * h * bf16);
let moe_out = self.ffn.forward(normed2_i, fwd, stream)?;
ops::residual_add(
fwd.gpu,
self.residual_add_k,
Expand Down
Loading
Loading