diff --git a/crates/spark-model/src/layers/qwen3_attention/trait_impl/multi_seq/ffn.rs b/crates/spark-model/src/layers/qwen3_attention/trait_impl/multi_seq/ffn.rs index f4674aba..672614ea 100644 --- a/crates/spark-model/src/layers/qwen3_attention/trait_impl/multi_seq/ffn.rs +++ b/crates/spark-model/src/layers/qwen3_attention/trait_impl/multi_seq/ffn.rs @@ -109,25 +109,44 @@ impl Qwen3AttentionLayer { } else { 2usize }; + // Phase A: per-token residual_add + post-attention RMS norm. + // Lays out `norm_output[0..n]` as a contiguous [N, h] MoE input. for i in 0..n { let hidden_i = hidden.offset(i * h * residual_elem); let o_out_i = o_out.offset(i * h * bf16); // BF16 attn output let residual_i = residual.offset(i * h * residual_elem); - let normed2 = fwd.buffers.norm_output().offset(i * h * bf16); + let normed2_i = fwd.buffers.norm_output().offset(i * h * bf16); ops::residual_add_rms_norm( fwd.gpu, self.residual_add_rms_norm_k, hidden_i, o_out_i, &self.post_attn_norm, - normed2, + normed2_i, residual_i, 1, h as u32, eps, stream, )?; - let moe_out = self.ffn.forward(normed2, fwd, stream)?; + } + // Phase B+C: per-token MoE + residual. The generic grouped-GEMM + // (forward_prefill) is a NET LOSS for this 256-expert MoE at + // small batch — per-expert M ~1 and the sort/permute/ptr-table + // overhead dominates (measured: attention block ~40ms vs ~20ms + // per-token at N=4 on GB10). N=2/3 already take the fused + // forward_k2/k3 branches above; this `else` only sees N>=4 (or + // MLA, which must avoid the batched-MoE kernels anyway), so the + // per-token path — identical to decode()'s MoE — is fastest here + // until a true batched-EP MoE kernel exists. Mirrors the SSM + // dispatch in qwen3_ssm/trait_decode_multi_seq.rs. Each forward() + // writes moe_output[0]; consume it immediately before the next + // iteration overwrites it. + let normed_base = fwd.buffers.norm_output(); + for i in 0..n { + let hidden_i = hidden.offset(i * h * residual_elem); + let normed2_i = normed_base.offset(i * h * bf16); + let moe_out = self.ffn.forward(normed2_i, fwd, stream)?; ops::residual_add( fwd.gpu, self.residual_add_k, diff --git a/crates/spark-model/src/layers/qwen3_ssm/trait_decode_multi_seq.rs b/crates/spark-model/src/layers/qwen3_ssm/trait_decode_multi_seq.rs index e4e3e5c2..d1ddfbeb 100644 --- a/crates/spark-model/src/layers/qwen3_ssm/trait_decode_multi_seq.rs +++ b/crates/spark-model/src/layers/qwen3_ssm/trait_decode_multi_seq.rs @@ -6,17 +6,30 @@ use super::*; impl Qwen3SsmLayer { #[allow(clippy::too_many_arguments)] - /// Multi-sequence decode: falls back to per-sequence single decode. + /// Multi-sequence decode for SSM (gated-delta-net) layers. /// - /// The batched SSM path had buffer aliasing bugs (#6) where shared scratch - /// buffers (conv_out, gdn_out, moe_output) corrupted across sequences, - /// producing gibberish (Chinese/multilingual tokens). Instead of debugging - /// every buffer interaction, we delegate to the proven single-sequence - /// decode path which has no aliasing issues. + /// The SSM mixer (conv1d + GDN recurrence + in/out projections) carries + /// independent per-sequence recurrent state, so it runs in a per-seq loop + /// using the SAME single-token kernels as `decode()` (proven correct). The + /// MoE sublayer is stateless and shared across sequences, so it is hoisted + /// OUT of the loop and run ONCE as a batched grouped-GEMM over all N + /// tokens — the same `forward_prefill` path the prefill scheduler and the + /// attention layers' multi-seq path already use. /// - /// Performance impact: negligible — SSM decode is memory-bandwidth-bound - /// and per-sequence GEMV weights stay in L2 cache across iterations. - #[allow(unreachable_code, unused_variables)] + /// This supersedes the earlier "delegate every sequence to the full + /// single-token `decode()`" fallback, which ran N separate single-token + /// MoE forwards (N × top_k expert GEMVs + N per-token all_reduces under + /// EP). Phase B collapses those to one grouped gate+up+down GEMM and one + /// batched all_reduce. + /// + /// Buffer safety (the old bug #6): each per-seq mixer writes its MoE input + /// to `norm_output[i]` — a distinct per-seq offset. `ssm_forward` never + /// touches `norm_output` (verified: 0 references) and its returned + /// `ssm_out` (in `moe_output[0]`) is consumed by the same iteration's + /// `residual_add_rms_norm` before the next iteration runs, so nothing + /// needs to survive across sequences and no aliasing is possible. + /// `forward_prefill` then reads the assembled `norm_output[0..n]` and + /// writes `moe_output[0..n]`. pub(super) fn decode_multi_seq_inner<'a, 'b: 'a>( &self, hidden: DevicePtr, @@ -31,338 +44,123 @@ impl Qwen3SsmLayer { ) -> Result<()> { let h = ctx.config.hidden_size; let bf16 = 2usize; + let eps = ctx.config.rms_norm_eps as f32; + let n = num_seqs; - // CONCURRENT-DECODE BUG FIX: per-seq stride must match the ACTUAL - // hidden/residual element size, not always FP32 (4 bytes). When - // `use_fp32_residual()` is false (BF16 hidden — the default for - // GB10 LPDDR5X bandwidth-limited systems via HARDWARE.toml), the - // hardcoded `i * h * 4` skipped to position 2 of the BF16 buffer - // for `i=1`, leaving seq-1's actual position-1 slice UNTOUCHED by - // every SSM layer. Result: seq 1 only got modifications from the - // attention layers (which use the correct n>=2 batched indexing - // internally), producing the position-specific gibberish that - // reproduced even with identical prompts. Use the same `fp32` - // bytes-per-element computation the dispatcher uses at - // model.rs:4250. + // Per-seq hidden/residual stride must match the ACTUAL residual + // element size: BF16 on GB10 (HARDWARE.toml ATLAS_HW_FP32_RESIDUAL= + // false), FP32 otherwise. A hardcoded `* 4` would over-stride into + // the wrong batch slot for i>=1 on BF16 hidden. let residual_elem = if ctx.config.use_fp32_residual() { 4usize } else { 2usize }; - // Delegate to per-sequence single decode (proven correct, no buffer aliasing). - let mut _stub_disk = Vec::::new(); - let mut _stub_last_offloaded = Vec::::new(); - for i in 0..num_seqs { + // ── Phase A: per-sequence SSM mixer ── + // Pre-norm, SSM mixer (recurrent, per-seq state), post-attn-norm. + // Lays out `norm_output[0..n]` as the contiguous [N, h] BF16 MoE + // input. Identical kernel sequence to `decode()`'s mixer; only the + // MoE is deferred to Phase B. + for i in 0..n { let hidden_i = hidden.offset(i * h * residual_elem); let residual_i = residual.offset(i * h * residual_elem); - self.decode( - hidden_i, - residual_i, - states[i], - _kv_cache, - _seq_lens[i], - &mut _block_tables[i].clone(), - &mut _stub_disk, - &mut _stub_last_offloaded, - ctx, - stream, - )?; - } - return Ok(()); - - // ── Original batched path (disabled — buffer aliasing bug #6) ── - let eps = ctx.config.rms_norm_eps as f32; - let fp32 = 4usize; - let n = num_seqs; - - let nk = ctx.config.linear_num_key_heads; - let kd = ctx.config.linear_key_head_dim; - let nv = ctx.config.linear_num_value_heads; - let vd = ctx.config.linear_value_head_dim; - let vpg = nv / nk; - let key_dim = nk * kd; - let value_dim = nv * vd; - let conv_dim = key_dim * 2 + value_dim; - let d_conv = ctx.config.linear_conv_kernel_dim; - let qkvz_size = ctx.config.ssm_qkvz_size(); - let ba_size = ctx.config.ssm_ba_size(); + let normed_i = ctx.buffers.norm_output().offset(i * h * bf16); - // ── 1. RMS norm + residual for N tokens ── - let normed = ctx.buffers.norm_output(); - ops::rms_norm_residual( - ctx.gpu, - self.rms_norm_residual_k, - hidden, - &self.input_norm, - normed, - residual, - n as u32, - h as u32, - eps, - stream, - )?; - - // ── 2-9. Per-sequence SSM forward + projections ── - // GEMV projections are sequential (weights cached in L2 after first call). - // Conv1d, GDN are per-sequence (independent recurrent state). - let qkvz_out = ctx.buffers.ssm_qkvz(); - let deinterleaved = ctx.buffers.ssm_deinterleaved(); - let gates_buf = ctx.buffers.ssm_gates(); - let conv_out_buf = ctx.buffers.attn_output(); // reuse - let gdn_out_buf = ctx.buffers.qkv_output(); // reuse for GDN output - - for i in 0..n { - let normed_i = normed.offset(i * h * bf16); let ssm_state = states[i] .as_any_mut() .downcast_mut::() .ok_or_else(|| anyhow::anyhow!("Expected SsmLayerState for seq {i}"))?; - // QKVZ projection: GEMV (sequential writes directly to deinterleaved) - let deint_i = deinterleaved.offset(i * qkvz_size * bf16); - if self.sequential_qkvz { - if let Some(ref nvfp4) = self.qkvz_nvfp4 { - ops::w4a16_gemv( - ctx.gpu, - self.w4a16_gemv_k, - normed_i, - nvfp4, - deint_i, - qkvz_size as u32, - h as u32, - stream, - )?; - } else { - ops::dense_gemv( - ctx.gpu, - self.dense_gemv_k, - normed_i, - &self.ssm.in_proj_qkvz, - deint_i, - qkvz_size as u32, - h as u32, - stream, - )?; - } - } else { - let qkvz_i = qkvz_out.offset(i * qkvz_size * bf16); - if let Some(ref nvfp4) = self.qkvz_nvfp4 { - ops::w4a16_gemv( - ctx.gpu, - self.w4a16_gemv_k, - normed_i, - nvfp4, - qkvz_i, - qkvz_size as u32, - h as u32, - stream, - )?; - } else { - ops::dense_gemv( - ctx.gpu, - self.dense_gemv_k, - normed_i, - &self.ssm.in_proj_qkvz, - qkvz_i, - qkvz_size as u32, - h as u32, - stream, - )?; - } - ops::deinterleave_qkvz( - ctx.gpu, - self.deinterleave_k, - qkvz_i, - deint_i, - 1, - nk as u32, - kd as u32, - vpg as u32, - vd as u32, - stream, - )?; - } - - // BA projection + GDN gates - let ba_out = ctx.buffers.ssm_ba().offset(i * ba_size * bf16); - ops::dense_gemv( + // normed_i = rms_norm(hidden_i); residual_i = hidden_i + ops::rms_norm_residual( ctx.gpu, - self.dense_gemv_k, + self.rms_norm_residual_k, + hidden_i, + &self.input_norm, normed_i, - &self.ssm.in_proj_ba, - ba_out, - ba_size as u32, - h as u32, - stream, - )?; - let gate_beta_stride = nv * 2 * fp32; - let gate_i = gates_buf.offset(i * gate_beta_stride); - let beta_i = gates_buf.offset(i * gate_beta_stride + nv * fp32); - ops::compute_gdn_gates( - ctx.gpu, - self.compute_gdn_gates_k, - ba_out, - self.ssm.a_log.weight, - self.ssm.dt_bias.weight, - gate_i, - beta_i, - 1, - nv as u32, - nk as u32, - vpg as u32, - ba_size as u32, - stream, - )?; - - // Conv1d update - let qkv_i = deint_i; - let conv_out_i = conv_out_buf.offset(i * conv_dim * bf16); - ops::conv1d_update( - ctx.gpu, - self.conv1d_k, - ssm_state.conv_state, - qkv_i, - &self.ssm.conv1d, - conv_out_i, - conv_dim as u32, - d_conv as u32, - 1, - stream, - )?; - - // L2 norm on Q,K - ops::l2_norm( - ctx.gpu, - self.l2_norm_k, - conv_out_i, - (nk * 2) as u32, - kd as u32, - 1e-6, - 1, - (nk * 2 * kd) as u32, - stream, - )?; - - // GDN decode - let q_i = conv_out_i; - let k_i = conv_out_i.offset(key_dim * bf16); - let v_i = conv_out_i.offset(key_dim * 2 * bf16); - let gdn_out_i = gdn_out_buf.offset(i * value_dim * bf16); - ops::gdn_decode( - ctx.gpu, - self.gdn_k, - ssm_state.h_state, - q_i, - k_i, - v_i, - gate_i, - beta_i, - gdn_out_i, + residual_i, 1, - nk as u32, - nv as u32, - kd as u32, - vd as u32, - stream, - )?; - - // Gated RMS norm - let z_i = deint_i.offset((key_dim * 2 + value_dim) * bf16); - let normed_ssm_i = conv_out_i; // reuse - ops::gated_rms_norm( - ctx.gpu, - self.gated_rms_norm_k, - gdn_out_i, - z_i, - &self.ssm.norm, - normed_ssm_i, - nv as u32, - vd as u32, - vd as u32, + h as u32, eps, - vd as u32, stream, )?; - // Output projection: GEMV - let ssm_out_i = ctx.buffers.moe_output().offset(i * h * bf16); - if let Some(ref dense_out) = self.out_proj_dense { - ops::dense_gemv( - ctx.gpu, - self.dense_gemv_k, - normed_ssm_i, - dense_out, - ssm_out_i, - h as u32, - value_dim as u32, - stream, - )?; - } else { - ops::w4a16_gemv( - ctx.gpu, - self.w4a16_gemv_k, - normed_ssm_i, - &self.ssm.out_proj, - ssm_out_i, - h as u32, - value_dim as u32, - stream, - )?; - } - } + // SSM mixer: consumes normed_i, returns ssm_out (in moe_output[0]). + let ssm_out = self.ssm_forward(normed_i, ssm_state, ctx, stream, false)?; - // ── 10. Residual + post-norm + MoE per-sequence ── - // Bug #6 fix: copy SSM outputs to a safe buffer before running MoE. - // `self.ffn.forward()` writes its result to `moe_output[0]`, which would - // overwrite seq 1's SSM output at `moe_output[h]` if the MoE internally - // uses the full moe_output region as scratch. By copying SSM outputs to - // `ssm_deinterleaved` (no longer needed after step 9), we decouple them. - let ssm_out_safe = ctx.buffers.ssm_deinterleaved(); // reuse, large enough for n*h - for i in 0..n { - let src = ctx.buffers.moe_output().offset(i * h * bf16); - let dst = ssm_out_safe.offset(i * h * bf16); - ctx.gpu.copy_d2d_async(src, dst, h * bf16, stream)?; - } - // STRIDE FIX (mirrors 2026-04-22 fix at lines 47/57): use dynamic - // residual_elem instead of hardcoded `* 4`. On GB10 hidden states - // are BF16 (2 bytes), not FP32. Hardcoded `i * h * 4` causes - // position 1+ in concurrent batched decode to read/write at WRONG - // offsets, producing either silent gibberish (small N) or CUDA-700 - // illegal memory access (large per-seq offsets exceeding allocated - // buffer region). See project_batch_decode_corruption.md memory. - let residual_elem = if ctx.config.use_fp32_residual() { - 4usize - } else { - 2usize - }; - for i in 0..n { - let hidden_i = hidden.offset(i * h * residual_elem); - let ssm_out_i = ssm_out_safe.offset(i * h * bf16); - let residual_i = residual.offset(i * h * residual_elem); - let normed2 = ctx.buffers.norm_output().offset(i * h * bf16); + // hidden_i += ssm_out; normed_i = rms_norm(hidden_i); residual_i = hidden_i ops::residual_add_rms_norm( ctx.gpu, self.residual_add_rms_norm_k, hidden_i, - ssm_out_i, + ssm_out, &self.post_attn_norm, - normed2, + normed_i, residual_i, 1, h as u32, eps, stream, )?; - let moe_out = self.ffn.forward(normed2, ctx, stream)?; - ops::residual_add( - ctx.gpu, - self.residual_add_k, - hidden_i, - moe_out, - h as u32, - stream, - )?; + } + + // ── Phase B+C: MoE + residual, dispatched by batch size ── + // Measured on GB10 (qwen3.5-122b, 256-expert MoE, EP=2): + // N=2/3: the FUSED batch-2/3 expert kernels (forward_k2/k3) win — + // SSM step 44->36.5ms at N=2 (one batched all_reduce, no + // per-token launch overhead). + // N>=4: the generic grouped-GEMM (forward_prefill) is a NET LOSS + // here — per-expert M ~1, and the expert sort/permute/ptr- + // table overhead (paid once per layer, x36 SSM layers) + // dominates (SSM step ~88ms per-token vs ~140ms grouped). + // So fall back to the per-token MoE loop, identical to + // decode()'s MoE — the fastest option at these sizes until + // a true batched-EP MoE kernel exists. + // Mirrors the attention layers' forward_k2/k3 dispatch + // (qwen3_attention/.../multi_seq/ffn.rs); diverges only in declining + // forward_prefill at N>=4, which that path uses but which loses for + // the 36-layer SSM stack. + let normed_base = ctx.buffers.norm_output(); + match n { + 2 | 3 => { + if n == 2 { + self.ffn.forward_k2(normed_base, ctx, stream)?; + } else { + self.ffn.forward_k3(normed_base, ctx, stream)?; + } + // Batched output lives in moe_output[0..n]. + for i in 0..n { + let hidden_i = hidden.offset(i * h * residual_elem); + let moe_out_i = ctx.buffers.moe_output().offset(i * h * bf16); + ops::residual_add( + ctx.gpu, + self.residual_add_k, + hidden_i, + moe_out_i, + h as u32, + stream, + )?; + } + } + _ => { + // Per-token MoE: each seq's forward() writes moe_output[0]; + // consume it immediately with a per-seq residual add before + // the next iteration overwrites it. + for i in 0..n { + let hidden_i = hidden.offset(i * h * residual_elem); + let normed_i = normed_base.offset(i * h * bf16); + let moe_out = self.ffn.forward(normed_i, ctx, stream)?; + ops::residual_add( + ctx.gpu, + self.residual_add_k, + hidden_i, + moe_out, + h as u32, + stream, + )?; + } + } } Ok(()) diff --git a/crates/spark-model/src/model/impl_a1.rs b/crates/spark-model/src/model/impl_a1.rs index 8d4fc30a..4bb39a0e 100644 --- a/crates/spark-model/src/model/impl_a1.rs +++ b/crates/spark-model/src/model/impl_a1.rs @@ -459,6 +459,7 @@ impl TransformerModel { secondary_event, comm, ep_cmd_buf, + ep_protocol_v2: matches!(std::env::var("ATLAS_EP_PROTOCOL").as_deref(), Ok("v2")), self_speculative, last_mtp_hidden_idx: std::sync::atomic::AtomicUsize::new(0), vision_encoder, diff --git a/crates/spark-model/src/model/impl_a2.rs b/crates/spark-model/src/model/impl_a2.rs index 623b21dc..d2703fb8 100644 --- a/crates/spark-model/src/model/impl_a2.rs +++ b/crates/spark-model/src/model/impl_a2.rs @@ -172,6 +172,86 @@ impl TransformerModel { Ok(min_val) } + /// Broadcast a `(seq_id, cmd)` pair from rank 0 to all ranks. + /// + /// When `v2` is true, this fires a `seq_id` broadcast immediately before + /// the existing `cmd` broadcast. Workers reading the stream pick up the + /// preamble via [`Self::ep_recv_seq_and_cmd`] and route the command to + /// the matching `SequenceState` slot. + /// + /// When `v2` is false, the preamble is skipped and the wire shape is + /// byte-identical to the legacy single-sequence protocol — head and + /// worker built before this change continue to interoperate. + /// + /// Both ranks must agree on `v2` at startup (e.g. via the same env + /// var). Disagreement causes the worker to misread the next u32 as a + /// command code and is the kind of misconfiguration we want to fail + /// loudly in development — there's no graceful fallback. + pub(super) fn ep_broadcast_seq_and_cmd(&self, seq_id: u32, cmd: u32, v2: bool) -> Result<()> { + if v2 { + self.ep_broadcast_u32(seq_id)?; + } + self.ep_broadcast_u32(cmd)?; + Ok(()) + } + + /// Wire-protocol shape for v2 batched decode (`0xFFFFFFE0`): + /// + /// ```text + /// preamble seq_id = 0 (ignored — cmd routes the whole batch) + /// cmd = 0xFFFFFFE0 + /// N (u32) + /// seq_ids[N] (one bulk broadcast) + /// tokens[N] (one bulk broadcast) + /// ``` + /// + /// The matched receive on the worker is `ep_worker_decode_batch` in + /// `ep_worker_step_impl`'s dispatch. Both ranks then call the + /// `decode_batch_compute_main` path which runs the existing batched + /// `decode_multi_seq` per-layer with N tokens — same per-layer NCCL + /// allreduce sequence on both ranks, comm-stream order matches. + /// + /// Caller must hold `self.comm.is_some()` (no-op on world_size=1) and + /// `self.ep_protocol_v2 == true` (without the preamble, the worker + /// would mis-parse the seq_id u32 as a cmd code). Both conditions are + /// guaranteed at the only caller — `decode_batch_dispatch`'s EP + /// branch — but asserted defensively here. + pub(super) fn ep_broadcast_decode_batch_dispatch( + &self, + seq_ids: &[u32], + tokens: &[u32], + ) -> Result<()> { + if !(self.comm.is_some() && self.config.ep_world_size > 1) { + return Ok(()); + } + debug_assert!( + self.ep_protocol_v2, + "ep_broadcast_decode_batch_dispatch called without ATLAS_EP_PROTOCOL=v2" + ); + debug_assert_eq!( + seq_ids.len(), + tokens.len(), + "seq_ids and tokens length mismatch" + ); + self.ep_broadcast_seq_and_cmd(0, 0xFFFFFFE0, true)?; + self.ep_broadcast_u32(seq_ids.len() as u32)?; + self.ep_broadcast_tokens(seq_ids)?; + self.ep_broadcast_tokens(tokens)?; + Ok(()) + } + + /// Receive a `(seq_id, cmd)` pair from rank 0. Worker-side counterpart + /// of [`Self::ep_broadcast_seq_and_cmd`]. + /// + /// With `v2` enabled the returned `seq_id` is the slot the head wants + /// the worker to dispatch the command into; with `v2` disabled the + /// returned `seq_id` is always 0 (the legacy singleton slot). + pub(super) fn ep_recv_seq_and_cmd(&self, v2: bool) -> Result<(u32, u32)> { + let seq_id = if v2 { self.ep_broadcast_u32(0)? } else { 0 }; + let cmd = self.ep_broadcast_u32(0)?; + Ok((seq_id, cmd)) + } + /// Broadcast a u32 command from rank 0 to all ranks. /// Rank 0 writes `val` to GPU buffer and broadcasts. /// Other ranks receive the value and return it. @@ -194,28 +274,94 @@ impl TransformerModel { } } - /// EP worker step: receive a command from rank 0 and execute it. + /// EP worker step: receive a (seq_id, cmd) preamble from rank 0 and + /// execute the command in the addressed slot. /// /// Returns false when the worker should shut down. - /// Protocol: rank 0 broadcasts u32 commands before each model operation: - /// - 0..0xFFFFFFF0: token ID → decode - /// - 0xFFFFFFF0: prefill start → next broadcast = length, then length tokens - /// - 0xFFFFFFF1: free+realloc sequence - /// - 0xFFFFFFF2: verify K=2 → next 2 broadcasts = tokens, then accept/reject - /// - 0xFFFFFFF3: verify K=3 → next 3 broadcasts = tokens, then num_accepted - /// - 0xFFFFFFF4: verify K=4 → next 4 broadcasts = tokens, then num_accepted - /// - 0xFFFFFFFF: shutdown - pub(super) fn ep_worker_step_impl(&self, seq: &mut SequenceState) -> Result { - let cmd = self.ep_broadcast_u32(0)?; + /// + /// Protocol (`ATLAS_EP_PROTOCOL=v2`): rank 0 broadcasts the slot + /// identifier first (worker uses it to pick the right `SequenceState` + /// from `slots`), then the command code, then any per-command follow-on + /// data. With v1 (the default) the preamble is skipped and every + /// command targets slot 0 — equivalent to the singleton path this + /// function originally implemented. + /// + /// Command codes: + /// - 0..0xFFFFFFEF: token ID → decode in the addressed slot + /// - 0xFFFFFFF0: prefill start → chunk_len, chunk_start, full_len, then full_len tokens + /// - 0xFFFFFFF1: alloc slot (frees any prior occupant first, then re-allocates) + /// - 0xFFFFFFF2/3/4: verify K=2/3/4 → K tokens, then accept/num_accepted + /// - 0xFFFFFFFF: shutdown (seq_id is ignored; applies to the whole worker) + pub(super) fn ep_worker_step_impl(&self, slots: &mut [Option]) -> Result { + let (seq_id, cmd) = self.ep_recv_seq_and_cmd(self.ep_protocol_v2)?; + + // Shutdown applies to the whole worker — seq_id is ignored. + if cmd == 0xFFFFFFFF { + return Ok(false); + } + + // Batched-decode (`0xFFFFFFE0`): the preamble seq_id is sentinel-0; + // the real per-token routing lives in the seq_ids[N] payload that + // follows. Hand off to the batched handler which reads N + seq_ids + // + tokens off the wire and dispatches the matched compute. + if cmd == 0xFFFFFFE0 { + return self.ep_worker_decode_batch(slots); + } + + let slot_idx = seq_id as usize; + if slot_idx >= slots.len() { + anyhow::bail!( + "ep_worker_step: seq_id {} exceeds slot capacity {} \ + (head and worker likely disagree on max_batch_size)", + seq_id, + slots.len(), + ); + } + + // `alloc-slot` (0xFFFFFFF1): replace the slot's sequence wholesale. + // Frees the prior occupant if any, then allocates a fresh one. The + // SSM-pool slot the new sequence claims may or may not equal + // slot_idx — head and worker stay aligned because both ranks call + // `claim_slot()` from a free-list pop in matched order. Defensive + // bail if they ever diverge so we fail fast rather than corrupt KV. + if cmd == 0xFFFFFFF1 { + if let Some(mut old) = slots[slot_idx].take() { + self.free_sequence(&mut old)?; + } + let new_seq = self.alloc_sequence()?; + if self.ep_protocol_v2 && new_seq.slot_idx != slot_idx { + anyhow::bail!( + "ep_worker_step: SSM-pool slot {} doesn't match head's seq_id {} \ + after alloc — claim_slot ordering invariant violated", + new_seq.slot_idx, + slot_idx, + ); + } + slots[slot_idx] = Some(new_seq); + return Ok(true); + } + + // All other commands operate on an already-allocated slot. + let seq = slots[slot_idx].as_mut().ok_or_else(|| { + anyhow::anyhow!( + "ep_worker_step: cmd {:#x} arrived for unallocated slot {} \ + — head dispatched without a prior alloc", + cmd, + slot_idx, + ) + })?; + + self.ep_worker_dispatch_cmd(cmd, seq) + } + + /// Per-command dispatch for [`Self::ep_worker_step_impl`]. The + /// (seq_id, cmd) preamble + slot lookup + shutdown + alloc are already + /// handled by the caller; this routine assumes `seq` is the right + /// slot's allocated `SequenceState`. + fn ep_worker_dispatch_cmd(&self, cmd: u32, seq: &mut SequenceState) -> Result { let stream = self.gpu.default_stream(); match cmd { - 0xFFFFFFFF => return Ok(false), // shutdown - 0xFFFFFFF1 => { - // Free and realloc sequence - self.free_sequence(seq)?; - *seq = self.alloc_sequence()?; - } 0xFFFFFFF0 => { // Prefill chunk: receive chunk_len, chunk_start, full prompt length, // then ALL prompt tokens via bulk broadcast (single NCCL op). @@ -321,4 +467,69 @@ impl TransformerModel { Ok(true) } + + /// Worker-side handler for the batched-decode protocol (`0xFFFFFFE0`). + /// + /// Reads `N` (u32), `seq_ids[N]` (bulk broadcast), and `tokens[N]` + /// (bulk broadcast) off the wire — matching what the head wrote in + /// `ep_broadcast_decode_batch_dispatch`. Then builds an in-order + /// `Vec<&mut SequenceState>` from the addressed slots and hands off + /// to the shared compute path. The compute does the same per-layer + /// `decode_multi_seq` the non-EP main batched path runs, with the + /// NCCL allreduces inside each layer matching the head's submission + /// order on the comm. + /// + /// Validates seq_ids up-front (bounds + duplicates) so a malformed + /// payload from a buggy head fails before touching slot state. + fn ep_worker_decode_batch(&self, slots: &mut [Option]) -> Result { + let n = self.ep_broadcast_u32(0)? as usize; + let seq_ids = self.ep_broadcast_tokens(&vec![0u32; n])?; + let tokens = self.ep_broadcast_tokens(&vec![0u32; n])?; + + // Validate up front so we fail before touching slot state. + let mut seen = std::collections::HashSet::new(); + for &id in &seq_ids { + let idx = id as usize; + if idx >= slots.len() { + anyhow::bail!( + "ep_worker_decode_batch: seq_id {} exceeds slot capacity {}", + id, + slots.len(), + ); + } + if !seen.insert(id) { + anyhow::bail!("ep_worker_decode_batch: duplicate seq_id {} in batch", id); + } + } + + // Drain populated slots into a (idx, ref) Vec we can index by + // position with `swap_remove`. The borrow checker won't let us + // index `slots[seq_ids[i]]` in a loop because each `&mut` is + // distinct but the indexer can't prove non-overlap. + let mut slot_refs: Vec<(usize, &mut SequenceState)> = slots + .iter_mut() + .enumerate() + .filter_map(|(i, opt)| opt.as_mut().map(|s| (i, s))) + .collect(); + + // Order the refs to match the head's seq_ids order so the + // compute path processes tokens in the same batch index as the + // head — critical for KV-cache row alignment per slot. + let mut refs: Vec<&mut SequenceState> = Vec::with_capacity(n); + for &id in &seq_ids { + let idx = id as usize; + let pos = slot_refs + .iter() + .position(|(i, _)| *i == idx) + .ok_or_else(|| { + anyhow::anyhow!("ep_worker_decode_batch: slot {} not allocated", idx) + })?; + let (_, seq) = slot_refs.swap_remove(pos); + refs.push(seq); + } + + let stream = self.gpu.default_stream(); + self.decode_batch_compute_main(&tokens, &mut refs, stream)?; + Ok(true) + } } diff --git a/crates/spark-model/src/model/trait_impl/decode_a2.rs b/crates/spark-model/src/model/trait_impl/decode_a2.rs index f39d62bf..cdf10c61 100644 --- a/crates/spark-model/src/model/trait_impl/decode_a2.rs +++ b/crates/spark-model/src/model/trait_impl/decode_a2.rs @@ -33,19 +33,39 @@ impl TransformerModel { // Single-sequence: delegate to decode() which uses CUDA graphs. // decode_batch disables graphs for n≥2 (SSM state pointer staleness), // but n=1 is safe and benefits from graph replay (2x throughput). + // + // Broadcast the seq_id preamble + cmd here (rather than in the + // scheduler) so the EP n>1 branch below can interleave broadcasts + // with decode() calls — see that branch for the rationale. if n == 1 { + self.ep_broadcast_cmd_for_seq(seqs[0].slot_idx as u32, tokens[0])?; self.decode(tokens[0], seqs[0], stream)?; return Ok(self.decode_logits_ptr()); } - // EP mode: use per-sequence decode() to match the worker's batch size. - // EP workers run one sequence at a time, so the single-row logits - // buffer is consumed before the next call — no row scatter needed. + // EP mode + n > 1: one batched forward pass per rank. + // + // Both ranks must call the same `decode_multi_seq` per-layer with + // the same N tokens so the per-token NCCL all_reduces inside the + // MoE forward match in shape and submission order across ranks. + // The head announces the batch up-front via the `0xFFFFFFE0` + // protocol primitive (seq_ids[N] + tokens[N] in one shot), then + // both ranks run `decode_batch_compute_main` — the worker reaches + // it via the matching handler in `ep_worker_step_impl`. + // + // Comm-stream op order on both ranks per step: + // B(0) B(0xFFFFFFE0) B(N) B*N(seq_ids) B*N(tokens) + // then per layer: per-token AR*N (forward_batched's inner loop) + // + // Single batched forward amortises weight loads + kernel launches + // across N tokens. Per-token all_reduces (forward.rs:445, + // forward_batched.rs:269) remain at shape `h * elem` per call — + // batching the comm shape would need new MoE kernel work and is + // deliberately out of scope here. if self.comm.is_some() { - for i in 0..n { - self.decode(tokens[i], seqs[i], stream)?; - } - return Ok(self.decode_logits_ptr()); + let seq_ids: Vec = seqs.iter().map(|s| s.slot_idx as u32).collect(); + self.ep_broadcast_decode_batch_dispatch(&seq_ids, tokens)?; + return self.decode_batch_compute_main(tokens, seqs, stream); } // MLA models: as of issue #84 the batched `decode_multi_seq` path @@ -102,6 +122,25 @@ impl TransformerModel { return Ok(logits); } + self.decode_batch_compute_main(tokens, seqs, stream) + } + + /// Shared batched-compute path used by both the head's EP branch and + /// the worker's `0xFFFFFFE0` handler. Contains the per-step embed + + /// KV-block alloc + metadata upload + per-layer `decode_multi_seq` + + /// final norm + per-row LM-head GEMV pipeline. No EP broadcasts here + /// — the head emits the protocol primitive before calling this; the + /// worker reads the matching payload and dispatches into this from + /// `ep_worker_decode_batch`. Both ranks then submit identical + /// per-token `comm.all_reduce(h * elem)` ops on every MoE layer in + /// the same order. + pub(crate) fn decode_batch_compute_main( + &self, + tokens: &[u32], + seqs: &mut [&mut SequenceState], + _stream: u64, + ) -> Result { + let n = tokens.len(); let stream = self.gpu.default_stream(); let h = self.config.hidden_size; let bf16 = 2usize; diff --git a/crates/spark-model/src/model/trait_impl/ep_misc.rs b/crates/spark-model/src/model/trait_impl/ep_misc.rs index c0159a18..a9b1bb8b 100644 --- a/crates/spark-model/src/model/trait_impl/ep_misc.rs +++ b/crates/spark-model/src/model/trait_impl/ep_misc.rs @@ -28,8 +28,11 @@ use crate::traits::{ChunkedPrefillPageMetadata, Model, SequenceState}; use crate::weight_map::{DenseWeight, MtpWeights, QuantizedWeight}; impl TransformerModel { - pub(super) fn ep_worker_step_dispatch(&self, seq: &mut SequenceState) -> Result { - self.ep_worker_step_impl(seq) + pub(super) fn ep_worker_step_dispatch( + &self, + slots: &mut [Option], + ) -> Result { + self.ep_worker_step_impl(slots) } pub(super) fn is_ep_dispatch(&self) -> bool { diff --git a/crates/spark-model/src/model/trait_impl/mod.rs b/crates/spark-model/src/model/trait_impl/mod.rs index d5fc4a0d..7fd6eb56 100644 --- a/crates/spark-model/src/model/trait_impl/mod.rs +++ b/crates/spark-model/src/model/trait_impl/mod.rs @@ -341,8 +341,8 @@ impl Model for TransformerModel { ) -> Result<()> { self.commit_verify_state_async_dispatch(seq, num_accepted, k) } - fn ep_worker_step(&self, seq: &mut SequenceState) -> Result { - self.ep_worker_step_dispatch(seq) + fn ep_worker_step(&self, slots: &mut [Option]) -> Result { + self.ep_worker_step_dispatch(slots) } fn is_ep(&self) -> bool { self.is_ep_dispatch() @@ -359,6 +359,14 @@ impl Model for TransformerModel { fn ep_broadcast_cmd(&self, cmd: u32) -> Result<()> { self.ep_broadcast_cmd_dispatch(cmd) } + fn ep_broadcast_cmd_for_seq(&self, seq_id: u32, cmd: u32) -> Result<()> { + // Routes to the helper added in 21e2130. Behaviour depends on the + // ep_protocol_v2 field set at construction from ATLAS_EP_PROTOCOL. + self.ep_broadcast_seq_and_cmd(seq_id, cmd, self.ep_protocol_v2) + } + fn ep_protocol_v2(&self) -> bool { + self.ep_protocol_v2 + } fn ep_broadcast_tokens(&self, tokens: &[u32]) -> Result> { self.ep_broadcast_tokens_dispatch(tokens) } diff --git a/crates/spark-model/src/model/types.rs b/crates/spark-model/src/model/types.rs index 996184a6..5d408ba9 100644 --- a/crates/spark-model/src/model/types.rs +++ b/crates/spark-model/src/model/types.rs @@ -129,6 +129,13 @@ pub struct TransformerModel { pub(super) comm: Option>, /// Small GPU buffer for EP token broadcast (4 bytes). pub(super) ep_cmd_buf: DevicePtr, + /// EP wire-protocol version. When true, the seq_id-preamble protocol + /// extension from atlas#99 is active — every command broadcast is + /// preceded by a `seq_id` broadcast so the worker can dispatch + /// slot-bound work into the right `SequenceState` slot. When false, + /// the legacy single-sequence protocol is used. Set at construction + /// from `ATLAS_EP_PROTOCOL` env var; both ranks must agree. + pub(super) ep_protocol_v2: bool, /// Self-speculative decoding mode: draft via layer-skipping (no MTP weights needed). pub(super) self_speculative: bool, /// Last token index passed to save_hidden_for_mtp (for EP broadcast to rank 1). diff --git a/crates/spark-model/src/traits/model.rs b/crates/spark-model/src/traits/model.rs index 12801eb1..39056361 100644 --- a/crates/spark-model/src/traits/model.rs +++ b/crates/spark-model/src/traits/model.rs @@ -444,11 +444,16 @@ pub trait Model: Send + Sync { Ok(()) } - /// EP worker step: receive a command from rank 0 and execute it. + /// EP worker step: receive a (seq_id, cmd) preamble from rank 0 and + /// execute the command in the addressed slot. /// /// Returns false when the worker should shut down. /// Only valid on rank > 0 with EP enabled. - fn ep_worker_step(&self, _seq: &mut SequenceState) -> Result { + /// + /// `slots` must be sized to `args.max_batch_size` (same as the head's + /// scheduler `active` capacity); commands with `seq_id >= slots.len()` + /// fail loudly rather than corrupt unrelated state. + fn ep_worker_step(&self, _slots: &mut [Option]) -> Result { Ok(true) // no-op for non-EP models } @@ -505,6 +510,28 @@ pub trait Model: Send + Sync { Ok(()) // no-op for non-EP models } + /// EP broadcast: send a `(seq_id, cmd)` pair to all worker ranks. + /// + /// Use this at the *first* broadcast of a logical command sequence + /// (e.g. the K=2 verify marker, prefill start, decode token, etc.). + /// Follow-up broadcasts within the same command (chunk metadata, more + /// tokens, accept/reject result) keep using [`Self::ep_broadcast_cmd`] + /// — the worker consumes the preamble once per command and routes + /// subsequent reads through the slot it identified. + /// + /// When [`Self::ep_protocol_v2`] returns false (the default), the + /// `seq_id` is ignored on the wire and behaviour matches the legacy + /// single-sequence broadcast. + fn ep_broadcast_cmd_for_seq(&self, _seq_id: u32, _cmd: u32) -> Result<()> { + Ok(()) // no-op for non-EP models + } + + /// Returns true if this model's EP comm path is using the v2 protocol + /// (slot-aware seq_id preamble). Default false — pre-PR behaviour. + fn ep_protocol_v2(&self) -> bool { + false + } + /// EP bulk broadcast: send an array of u32 tokens to all worker ranks. /// Uses a single NCCL broadcast instead of per-token broadcasts. fn ep_broadcast_tokens(&self, _tokens: &[u32]) -> Result> { diff --git a/crates/spark-server/src/main_modules/serve.rs b/crates/spark-server/src/main_modules/serve.rs index 14f71dee..bab2c77b 100644 --- a/crates/spark-server/src/main_modules/serve.rs +++ b/crates/spark-server/src/main_modules/serve.rs @@ -304,11 +304,25 @@ pub(crate) async fn serve(mut args: cli::ServeArgs) -> Result<()> { let scheduler_model = model; let scheduler_eos = eos_tokens; - // EP: force batch_size=1 (worker protocol is single-sequence). - // MTP speculative decoding IS supported with EP via verify broadcast protocol. + // EP gate. v1 single-sequence worker protocol required max_batch_size=1 + // because each cmd targeted one slot and the head's per-token broadcast + // loop had no way to address slot N. v2 adds a per-cmd seq_id preamble + // (set ATLAS_EP_PROTOCOL=v2) so the worker routes commands by slot_idx + // and runs decode() per-seq. The head's decode_batch_dispatch EP branch + // stages each seq's logits row to host between decode() calls so all N + // rows survive into process_decode_logits — without that, the single-row + // logits buffer overwrites and N>1 produces garbage. let max_batch_size = if world_size > 1 { - tracing::info!("EP active: forcing max_batch_size=1"); - 1 + if scheduler_model.ep_protocol_v2() { + tracing::info!( + "EP v2 active: honoring max_batch_size={}", + args.max_batch_size, + ); + args.max_batch_size + } else { + tracing::info!("EP v1 active: forcing max_batch_size=1"); + 1 + } } else { args.max_batch_size }; diff --git a/crates/spark-server/src/main_modules/serve_phases/build.rs b/crates/spark-server/src/main_modules/serve_phases/build.rs index 5304ce95..fe5ddce8 100644 --- a/crates/spark-server/src/main_modules/serve_phases/build.rs +++ b/crates/spark-server/src/main_modules/serve_phases/build.rs @@ -183,6 +183,9 @@ pub(crate) fn maybe_run_ep_worker( ); } let worker_hss_cfg = early_high_speed_swap_cfg.clone(); + // Copy primitives out of `args` so the worker thread (which is + // `'static`) doesn't capture the function-scoped `&ServeArgs` ref. + let max_batch_size = args.max_batch_size; let handle = std::thread::spawn(move || { model_owned .bind_gpu_to_thread() @@ -208,12 +211,35 @@ pub(crate) fn maybe_run_ep_worker( } } } - let mut seq = model_owned - .alloc_sequence() - .expect("Failed to allocate EP worker sequence"); - tracing::info!("EP worker ready (rank {rank}), waiting for commands"); + // Slots vec sized to match the head's scheduler `max_batch_size`. + // Pre-allocate every slot. The head only emits `0xFFFFFFF1` + // (free+realloc) on lifecycle events — sequence finish/error — + // not on first use, so a fresh `prefill_a_step` for slot N + // arrives as `0xFFFFFFF0` with no prior alloc broadcast. Under v1 + // (max_batch_size=1) this is just slot 0, matching the legacy + // behavior. Under v2 (max_batch_size>1) every slot must be + // populated up front for the same reason. + // + // Both ranks' SSM pools start with the same free-list ordering + // (see ssm_pool.rs: `(0..max_slots).rev().collect()` + `pop()`), + // so pre-allocating in `0..max_batch_size` order on the worker + // means `slots[i].slot_idx == i` — matching the slot ids the + // head's `alloc_sequence` returns for its Nth claim. + let mut slots: Vec> = + (0..max_batch_size).map(|_| None).collect(); + for slot in slots.iter_mut() { + *slot = Some( + model_owned + .alloc_sequence() + .expect("Failed to allocate EP worker sequence"), + ); + } + tracing::info!( + "EP worker ready (rank {rank}, {} slots), waiting for commands", + slots.len() + ); loop { - match model_owned.ep_worker_step(&mut seq) { + match model_owned.ep_worker_step(&mut slots) { Ok(true) => {} Ok(false) => break, Err(e) => { @@ -222,7 +248,11 @@ pub(crate) fn maybe_run_ep_worker( } } } - let _ = model_owned.free_sequence(&mut seq); + for slot in slots.iter_mut() { + if let Some(seq) = slot.as_mut() { + let _ = model_owned.free_sequence(seq); + } + } tracing::info!("EP worker stopped (rank {rank})"); }); handle.join().expect("EP worker thread panicked"); diff --git a/crates/spark-server/src/scheduler/decode_step.rs b/crates/spark-server/src/scheduler/decode_step.rs index f6fea847..8c2a5e72 100644 --- a/crates/spark-server/src/scheduler/decode_step.rs +++ b/crates/spark-server/src/scheduler/decode_step.rs @@ -45,16 +45,13 @@ pub fn step_decode_only( tracing::debug!("CONC_DIAG n={n}: {}", diag.join(" ")); } - // EP: broadcast token(s) to worker before decode. - for &t in &tokens { - if let Err(e) = model.ep_broadcast_cmd(t) { - tracing::error!("EP broadcast token: {e:#}"); - for mut a in active.drain(..) { - send_error(model, &mut a, &format!("EP broadcast: {e:#}")); - } - return; - } - } + // EP broadcasts (seq_id preamble + cmd per active seq) are emitted + // inside `decode_batch_dispatch` itself, interleaved with each per-seq + // `decode()` call. Batching them up-front here would diverge the head's + // comm-stream op order ([B,B,...,B,AR,AR,...]) from the worker's + // ([B,AR,...,AR,B,AR,...,AR,...]) and deadlock NCCL — observed + // empirically as a 51s broadcast timeout on the worker followed by + // stale comm data reads. See decode_a2.rs for the full rationale. let mut refs: Vec<&mut SequenceState> = active.iter_mut().map(|a| &mut a.seq).collect(); diff --git a/crates/spark-server/src/scheduler/lifecycle.rs b/crates/spark-server/src/scheduler/lifecycle.rs index be6c4ea5..cd21ba3c 100644 --- a/crates/spark-server/src/scheduler/lifecycle.rs +++ b/crates/spark-server/src/scheduler/lifecycle.rs @@ -74,7 +74,7 @@ pub fn finish_sequence(model: &dyn Model, a: &mut ActiveSeq) { tracing::error!("free_sequence: {e:#}"); } // EP: signal worker to free+realloc its mirrored sequence. - if let Err(e) = model.ep_broadcast_cmd(0xFFFFFFF1) { + if let Err(e) = model.ep_broadcast_cmd_for_seq(a.seq.slot_idx as u32, 0xFFFFFFF1) { tracing::error!("EP broadcast free+realloc: {e:#}"); } } @@ -98,7 +98,7 @@ pub fn send_error(model: &dyn Model, a: &mut ActiveSeq, msg: &str) { if let Err(e) = model.free_sequence(&mut a.seq) { tracing::error!("send_error: free_sequence: {e:#}"); } - if let Err(e) = model.ep_broadcast_cmd(0xFFFFFFF1) { + if let Err(e) = model.ep_broadcast_cmd_for_seq(a.seq.slot_idx as u32, 0xFFFFFFF1) { tracing::error!("send_error: ep_broadcast free+realloc: {e:#}"); } } @@ -155,8 +155,9 @@ pub fn swap_out_sequence( let tokens = a.seq.tokens.clone(); // Free GPU resources (KV blocks + SSM slot). + let slot_idx = a.seq.slot_idx as u32; model.free_sequence(&mut a.seq)?; - let _ = model.ep_broadcast_cmd(0xFFFFFFF1); + let _ = model.ep_broadcast_cmd_for_seq(slot_idx, 0xFFFFFFF1); Ok(SwappedSeq { tokens, diff --git a/crates/spark-server/src/scheduler/mod.rs b/crates/spark-server/src/scheduler/mod.rs index 5f9dfbc1..523f7de6 100644 --- a/crates/spark-server/src/scheduler/mod.rs +++ b/crates/spark-server/src/scheduler/mod.rs @@ -384,8 +384,9 @@ pub fn run( for p in prefilling { let mut seq = p.seq; let _ = model.free_sequence(&mut seq); - let _ = model.ep_broadcast_cmd(0xFFFFFFF1); + let _ = model.ep_broadcast_cmd_for_seq(seq.slot_idx as u32, 0xFFFFFFF1); } - let _ = model.ep_broadcast_cmd(0xFFFFFFFF); + // Shutdown applies to every slot the worker has; seq_id is ignored. + let _ = model.ep_broadcast_cmd_for_seq(0, 0xFFFFFFFF); tracing::info!("Scheduler stopped"); } diff --git a/crates/spark-server/src/scheduler/mod_helpers.rs b/crates/spark-server/src/scheduler/mod_helpers.rs index 633a86ec..5276b8c8 100644 --- a/crates/spark-server/src/scheduler/mod_helpers.rs +++ b/crates/spark-server/src/scheduler/mod_helpers.rs @@ -130,12 +130,26 @@ pub(super) fn drain_pending_requests( /// slots [0..N)). /// /// CRITICAL: compact_sequence MUST run BEFORE finish_sequence (BUG #35). +/// +/// Under v2 EP (`ep_protocol_v2`) the worker pre-allocates every slot at +/// startup and the head-worker mirror is keyed by `slot_idx`, not by the +/// active-set position. Moving SSM states on the head only would leave +/// the worker's mirror at the original slot — the next op against that +/// seq would address different physical memory on each rank. The retired +/// seq also can't be tagged with `usize::MAX` because that sentinel +/// becomes `0xFFFFFFFF` when cast to a u32 seq_id and trips the worker's +/// bounds check on the next `0xFFFFFFF1` broadcast. So v2 skips both +/// the compaction and the sentinel and lets the active vec be +/// non-contiguous w.r.t. `slot_idx` — pre-allocated slots stay valid +/// in place across the swap_remove, and the per-slot CUDA graph cache +/// stays warm because the seq never moved. pub(super) fn retire_finished_sequences(model: &dyn Model, active: &mut Vec) { + let skip_compaction = model.ep_protocol_v2(); let mut i = 0; while i < active.len() { if active[i].finished { let mut a = active.swap_remove(i); - if i < active.len() && active[i].seq.slot_idx != i { + if !skip_compaction && i < active.len() && active[i].seq.slot_idx != i { // Compact the swapped-in sequence to reuse the retired // seq's slot. Mark the retired seq's slot as reused so // free_sequence doesn't double-release it. diff --git a/crates/spark-server/src/scheduler/mtp_step.rs b/crates/spark-server/src/scheduler/mtp_step.rs index 6bb7d2c3..f87c504b 100644 --- a/crates/spark-server/src/scheduler/mtp_step.rs +++ b/crates/spark-server/src/scheduler/mtp_step.rs @@ -21,7 +21,7 @@ pub fn step_mtp(model: &dyn Model, active: &mut [ActiveSeq], num_drafts: usize) for &idx in &bootstrap_idxs { let a = &mut active[idx]; // EP: broadcast token to worker before decode (worker runs decode in lockstep). - if let Err(e) = model.ep_broadcast_cmd(a.last_token) { + if let Err(e) = model.ep_broadcast_cmd_for_seq(a.seq.slot_idx as u32, a.last_token) { tracing::error!("EP broadcast bootstrap token: {e:#}"); a.finished = true; continue; diff --git a/crates/spark-server/src/scheduler/phase_continue_prefills/run_standard.rs b/crates/spark-server/src/scheduler/phase_continue_prefills/run_standard.rs index 3ecc920a..28b0c36b 100644 --- a/crates/spark-server/src/scheduler/phase_continue_prefills/run_standard.rs +++ b/crates/spark-server/src/scheduler/phase_continue_prefills/run_standard.rs @@ -155,7 +155,7 @@ pub(super) fn run_standard_chunk_loop( // ── Standard path: prefill chunk only, decode separately ── // EP: broadcast chunk tokens to worker (bulk, single NCCL op). let ep_ok = (|| -> Result<()> { - model.ep_broadcast_cmd(0xFFFFFFF0)?; + model.ep_broadcast_cmd_for_seq(p.seq.slot_idx as u32, 0xFFFFFFF0)?; model.ep_broadcast_cmd(chunk_len as u32)?; model.ep_broadcast_cmd(p.chunk_offset as u32)?; model.ep_broadcast_cmd(p.prompt_tokens.len() as u32)?; diff --git a/crates/spark-server/src/scheduler/phase_promote_prefills.rs b/crates/spark-server/src/scheduler/phase_promote_prefills.rs index 373668de..707e3115 100644 --- a/crates/spark-server/src/scheduler/phase_promote_prefills.rs +++ b/crates/spark-server/src/scheduler/phase_promote_prefills.rs @@ -30,7 +30,7 @@ pub(super) fn promote_completed_prefills( if let Err(e) = model.free_sequence(&mut seq) { tracing::error!("phase_promote_prefills: free_sequence (error path): {e:#}"); } - if let Err(e) = model.ep_broadcast_cmd(0xFFFFFFF1) { + if let Err(e) = model.ep_broadcast_cmd_for_seq(seq.slot_idx as u32, 0xFFFFFFF1) { tracing::error!( "phase_promote_prefills: ep_broadcast free+realloc (error path): {e:#}" ); diff --git a/crates/spark-server/src/scheduler/prefill_a_step.rs b/crates/spark-server/src/scheduler/prefill_a_step.rs index fcb91290..5e910fa1 100644 --- a/crates/spark-server/src/scheduler/prefill_a_step.rs +++ b/crates/spark-server/src/scheduler/prefill_a_step.rs @@ -127,7 +127,7 @@ pub fn start_chunked_prefill( // identical Marconi prefix-cache lookups (bug #33 fix). // Uses bulk broadcast (single NCCL op) instead of per-token broadcast // which caused NCCL timeouts on long prompts (6K+ tokens = 6K+ broadcasts). - model.ep_broadcast_cmd(0xFFFFFFF0)?; + model.ep_broadcast_cmd_for_seq(seq.slot_idx as u32, 0xFFFFFFF0)?; model.ep_broadcast_cmd(chunk_len as u32)?; model.ep_broadcast_cmd(0)?; // chunk_start model.ep_broadcast_cmd(prompt_tokens.len() as u32)?; // full prompt length @@ -153,7 +153,8 @@ pub fn start_chunked_prefill( "prefill_a_step: free_sequence (after prefill error): {free_err:#}" ); } - if let Err(bcast_err) = model.ep_broadcast_cmd(0xFFFFFFF1) { + if let Err(bcast_err) = model.ep_broadcast_cmd_for_seq(seq.slot_idx as u32, 0xFFFFFFF1) + { tracing::error!( "prefill_a_step: ep_broadcast (after prefill error): {bcast_err:#}" ); @@ -186,7 +187,9 @@ pub fn start_chunked_prefill( "prefill_a_step: free_sequence (after sample error): {free_err:#}" ); } - if let Err(bcast_err) = model.ep_broadcast_cmd(0xFFFFFFF1) { + if let Err(bcast_err) = + model.ep_broadcast_cmd_for_seq(seq.slot_idx as u32, 0xFFFFFFF1) + { tracing::error!( "prefill_a_step: ep_broadcast (after sample error): {bcast_err:#}" ); diff --git a/crates/spark-server/src/scheduler/prefill_b_step.rs b/crates/spark-server/src/scheduler/prefill_b_step.rs index 4e1abc91..807256a2 100644 --- a/crates/spark-server/src/scheduler/prefill_b_step.rs +++ b/crates/spark-server/src/scheduler/prefill_b_step.rs @@ -115,7 +115,7 @@ pub fn prefill_request( } // EP: broadcast prefill command + tokens to worker (bulk, single NCCL op). - model.ep_broadcast_cmd(0xFFFFFFF0)?; + model.ep_broadcast_cmd_for_seq(seq.slot_idx as u32, 0xFFFFFFF0)?; model.ep_broadcast_cmd(prompt_tokens.len() as u32)?; model.ep_broadcast_cmd(0)?; // chunk_start = 0 (non-chunked) model.ep_broadcast_cmd(prompt_tokens.len() as u32)?; // full prompt length @@ -135,7 +135,8 @@ pub fn prefill_request( "prefill_b_step: free_sequence (after prefill error): {free_err:#}" ); } - if let Err(bcast_err) = model.ep_broadcast_cmd(0xFFFFFFF1) { + if let Err(bcast_err) = model.ep_broadcast_cmd_for_seq(seq.slot_idx as u32, 0xFFFFFFF1) + { tracing::error!( "prefill_b_step: ep_broadcast (after prefill error): {bcast_err:#}" ); diff --git a/crates/spark-server/src/scheduler/spec_step.rs b/crates/spark-server/src/scheduler/spec_step.rs index 4801c94e..64a5498e 100644 --- a/crates/spark-server/src/scheduler/spec_step.rs +++ b/crates/spark-server/src/scheduler/spec_step.rs @@ -10,7 +10,7 @@ pub fn step_self_spec(model: &dyn Model, active: &mut [ActiveSeq], num_drafts: u let a = &mut active[0]; // 1. Full-model decode to get token_0 - if let Err(e) = model.ep_broadcast_cmd(a.last_token) { + if let Err(e) = model.ep_broadcast_cmd_for_seq(a.seq.slot_idx as u32, a.last_token) { tracing::error!("EP broadcast self-spec token: {e:#}"); a.finished = true; return; @@ -159,7 +159,7 @@ pub fn step_ngram(model: &dyn Model, active: &mut [ActiveSeq], proposer: &mut Ng step_ngram_verify(model, a, &drafts, proposer); } else { // ── Phase A: Bootstrap decode + N-gram propose ── - if let Err(e) = model.ep_broadcast_cmd(a.last_token) { + if let Err(e) = model.ep_broadcast_cmd_for_seq(a.seq.slot_idx as u32, a.last_token) { tracing::error!("EP broadcast ngram bootstrap: {e:#}"); a.finished = true; return; @@ -220,7 +220,7 @@ pub fn step_ngram_verify( // EP: broadcast verify K=2 command + tokens let tokens_k2 = [a.last_token, drafts[0]]; - if let Err(e) = model.ep_broadcast_cmd(0xFFFFFFF2) { + if let Err(e) = model.ep_broadcast_cmd_for_seq(a.seq.slot_idx as u32, 0xFFFFFFF2) { tracing::error!("EP broadcast ngram verify cmd: {e:#}"); a.finished = true; return; diff --git a/crates/spark-server/src/scheduler/verify_k2_step.rs b/crates/spark-server/src/scheduler/verify_k2_step.rs index 6307825b..4281d1f9 100644 --- a/crates/spark-server/src/scheduler/verify_k2_step.rs +++ b/crates/spark-server/src/scheduler/verify_k2_step.rs @@ -17,7 +17,7 @@ pub fn step_verify_k2(model: &dyn Model, a: &mut ActiveSeq, drafts: &[u32], num_ // EP: broadcast verify K=2 command + tokens so worker runs decode_verify_graphed in lockstep. let t_ep = Instant::now(); let tokens_k2 = [a.last_token, drafts[0]]; - if let Err(e) = model.ep_broadcast_cmd(0xFFFFFFF2) { + if let Err(e) = model.ep_broadcast_cmd_for_seq(a.seq.slot_idx as u32, 0xFFFFFFF2) { tracing::error!("EP broadcast verify_k2 cmd: {e:#}"); a.finished = true; return; diff --git a/crates/spark-server/src/scheduler/verify_k3_step.rs b/crates/spark-server/src/scheduler/verify_k3_step.rs index 957039af..4e98bab0 100644 --- a/crates/spark-server/src/scheduler/verify_k3_step.rs +++ b/crates/spark-server/src/scheduler/verify_k3_step.rs @@ -14,7 +14,7 @@ pub fn step_verify_k3(model: &dyn Model, a: &mut ActiveSeq, drafts: &[u32], num_ // EP: broadcast verify K=3 command + 3 tokens so worker runs decode_verify_graphed_k3 in lockstep. let tokens_k3 = [a.last_token, drafts[0], drafts[1]]; - if let Err(e) = model.ep_broadcast_cmd(0xFFFFFFF3) { + if let Err(e) = model.ep_broadcast_cmd_for_seq(a.seq.slot_idx as u32, 0xFFFFFFF3) { tracing::error!("EP broadcast verify_k3 cmd: {e:#}"); a.finished = true; return; diff --git a/crates/spark-server/src/scheduler/verify_k4_step.rs b/crates/spark-server/src/scheduler/verify_k4_step.rs index 39907152..eb6ce841 100644 --- a/crates/spark-server/src/scheduler/verify_k4_step.rs +++ b/crates/spark-server/src/scheduler/verify_k4_step.rs @@ -16,7 +16,7 @@ pub fn step_verify_k4(model: &dyn Model, a: &mut ActiveSeq, drafts: &[u32], num_ let tokens_k4 = [a.last_token, drafts[0], drafts[1], drafts[2]]; // EP: broadcast verify K=4 command + 4 tokens. - if let Err(e) = model.ep_broadcast_cmd(0xFFFFFFF4) { + if let Err(e) = model.ep_broadcast_cmd_for_seq(a.seq.slot_idx as u32, 0xFFFFFFF4) { tracing::error!("EP broadcast verify_k4 cmd: {e:#}"); a.finished = true; return; diff --git a/docs/adr/0011-ep-batched-decode-optimization.md b/docs/adr/0011-ep-batched-decode-optimization.md new file mode 100644 index 00000000..c3044744 --- /dev/null +++ b/docs/adr/0011-ep-batched-decode-optimization.md @@ -0,0 +1,163 @@ +# ADR-0011: Optimizing batched EP decode, and why it is bandwidth-bound + +**Status:** Accepted +**Date:** 2026-05-29 + +## Context + +Issue #99 lifts the `max_batch_size = 1` clamp under `--ep-size 2` by +multiplexing the head↔worker protocol (the `ATLAS_EP_PROTOCOL=v2` work). +Once the gate is lifted, concurrent requests actually batch instead of +serializing behind a one-slot queue. The motivating workload was a +4-concurrent agent burst whose tail latency spiked to ~605 s under the +batch=1 ceiling. + +Lifting the gate is necessary but not the whole story. The batched +multi-sequence decode path it unlocks had two problems worth recording: +a correctness gap (the SSM layers' batched MoE was dead code) and a +performance trap (the generic grouped-GEMM is a net loss at small batch). +This ADR records the decisions made while making batched decode both +correct and fast on 2× GB10, and the larger finding that reframes where +future decode wins can come from. + +The reference model throughout is Qwen3.5-122B-A10B-NVFP4 (48 layers: 36 +gated-delta-net / SSM, 12 full-attention; 256-expert MoE), EP=2. + +## Contents + +- [Decision 1 — SSM multi-seq decode: per-seq mixer + batch-dispatched MoE](#decision-1) +- [Decision 2 — Attention multi-seq MoE: per-token at N≥4](#decision-2) +- [Decision 3 — Do not pursue CUDA graphs for EP decode](#decision-3) +- [Decision 4 — Leave the SSM projections BF16 (quantization frontier)](#decision-4) +- [Measurements](#measurements) +- [The binding constraint](#the-binding-constraint) +- [Consequences](#consequences) + +## Decision 1 — SSM multi-seq decode: per-seq mixer + batch-dispatched MoE {#decision-1} + +`Qwen3SsmLayer::decode_multi_seq_inner` delegated every sequence to the +single-token `decode()` in a loop, with a full batched path sitting +behind an early `return Ok(())` and `#[allow(unreachable_code)]` — dead +since the bug-#6 buffer-aliasing debugging. Per-sequence decode runs N +independent single-token MoE forwards (N × top_k expert GEMVs + N +per-token all-reduces under EP). + +The decision: keep the per-sequence SSM **mixer** (conv1d + GDN +recurrence + projections — it carries independent recurrent state, so the +proven single-token kernels stay), but hoist the **MoE** out of the loop +and run it once, dispatched by batch size: + +- N=2/3 → the fused `forward_k2`/`forward_k3` expert kernels (one batched + all-reduce, no per-token launch overhead) +- N≥4 → the per-token MoE loop + +Buffer safety (the old bug #6) is structural in the new layout: each +per-seq mixer writes its MoE input to `norm_output[i]` (a distinct +per-seq offset), `ssm_forward` never touches `norm_output`, and the +`ssm_out` it returns is consumed within the same loop iteration — so +nothing survives across sequences and the aliasing cannot recur. + +The non-obvious part is the N≥4 fallback. The generic grouped-GEMM +(`forward_prefill`) is built for prefill, where M (tokens) ≫ the number +of active experts. At decode batch sizes the per-expert M is ~1, and the +expert sort / permute / pointer-table overhead — paid once per layer, +across 36 SSM layers — dominates. Measured: the grouped path pushed the +SSM decode step to ~140 ms at N=4 versus ~88 ms for the per-token loop. +So `forward_prefill` is declined for the SSM MoE until a true batched-EP +MoE kernel exists. + +## Decision 2 — Attention multi-seq MoE: per-token at N≥4 {#decision-2} + +The same grouped-GEMM trap applied to the attention layers' multi-seq +FFN, whose N≥4 branch used `forward_prefill`. Switched to the per-token +MoE loop for N≥4 (N=2/3 keep the fused `forward_k2`/`k3`). Measured at +N=4: the attention decode block dropped from ~40 ms to ~24 ms, ~8% off +the whole step, no regression. This mirrors Decision 1. + +## Decision 3 — Do not pursue CUDA graphs for EP decode {#decision-3} + +CUDA graphs are disabled under EP (`use_graphs = self.comm.is_none()`) +because the path assumed NCCL all-reduce was not capturable. That +assumption was tested and is false: the 2-rank all-reduce +(`ncclSend`/`ncclRecv` + a local add) runs entirely on one stream, and +the event-based async variant fork-joins a comm stream to the compute +stream — a multi-stream-capturable pattern. A prototype captured the +full n=1 decode step, including the inter-node RoCE collective, into a +graph: clean capture, coherent output, no deadlock. + +It still does not help. n=1 throughput was ~40 tok/s with graphs versus +~42 without — a slight regression. The decode step is memory-bandwidth +and inter-node-NCCL bound; CUDA graphs only remove CPU launch overhead, +which is already hidden behind the GPU and network work. The "graphs give +~2×" result holds for launch-bound regimes (smaller models, single host), +not this one. The capability is kept on a side branch for a future +launch-bound model; it is not enabled by default. + +Single-host (no EP) was also ruled out as a vehicle: the model is ~76 GB +of weights in memory, and on one GB10 (~109 GB usable) the KV cache gets +zero allocatable blocks at any usable batch/seq-len. EP=2 is the only +topology that leaves room for KV plus batch, which the issue already +observed. + +## Decision 4 — Leave the SSM projections BF16 (quantization frontier) {#decision-4} + +The largest BF16 weight still loaded every decode step is the SSM mixer +projections — `in_proj_qkv` + `in_proj_z` + `out_proj`, ~6.3 GB across +36 layers. Quantizing them to NVFP4 would cut the most per-step +bandwidth. They are kept BF16 deliberately: the loader already forces +`A_log`/`dt_bias` to FP32 to avoid "exponential error amplification in +the decay gate at 8k+ tokens." The recurrent path is precision-sensitive, +so 4-bit on the projections that feed it is a quality risk that would +pass a short coherence check and degrade at long context. Everything +safely quantizable is already NVFP4 at load (lm_head, the MTP head and +its experts, the routed MoE experts). The model sits at its safe +quantization frontier; this decision is to respect that boundary. + +## Measurements {#measurements} + +2× GB10, EP=2, Qwen3.5-122B-A10B-NVFP4, MTP speculative on (~77% accept). + +- Decode step composition at N=2 (host-sync bucketed): SSM layers ~79% + (within SSM: MoE ~57% / mixer ~43%), attention ~15%, lm_head ~6%. +- SSM decode step at N=2: 44 ms → 35 ms with the fused `forward_k2` + (Decision 1), ~15–20% aggregate. +- Attention block at N=4: ~40 ms → ~24 ms (Decision 2). +- Batched vs serialized aggregate throughput at N=2 is ~equal (~32 tok/s + either way); batching converts first-fast-second-waits into + both-finish-together. The win is tail latency and admission, not + aggregate tok/s. +- CUDA graphs under EP at n=1: ~40 tok/s on, ~42 off (Decision 3). + +## The binding constraint {#the-binding-constraint} + +At decode batch sizes that fit this hardware (N ≤ 8), the step is bound +by weight-load bandwidth and the inter-node all-reduce, not by kernel +launch overhead or arithmetic. Two concurrent tokens share almost no +weight loads — the SSM projections are sequential GEMVs and the MoE +experts at N=2 are mostly disjoint across the 256-expert pool — so +batching does not raise aggregate throughput at low N; it amortizes +admission and launch overhead and flattens the tail. + +This explains the pattern across the whole arc: batching gave no +aggregate win at N=2, the grouped-GEMM lost (more work, same bandwidth), +and graphs did nothing (launch is not the bottleneck). + +## Consequences {#consequences} + +**Better.** Lifting the EP gate removes the tail-latency cliff from the +issue (4-concurrent burst no longer serializes). The MoE-dispatch fixes +make the batched path faster than the grouped-GEMM it replaced (~15–20% +at N=2/3 on SSM, ~8% at N=4 on attention) and delete a large dead-code +block. Output is coherent and cross-sequence-isolated at N=2 and N=4. + +**Bounded.** Aggregate decode throughput at low concurrency is set by +memory bandwidth, not by these changes. Future decode wins on this model +must cut bytes-moved or collective latency — quantizing the recurrent +projections (a quality tradeoff, deferred), a true batched-EP MoE kernel +(one grouped expert GEMM tuned for small M plus a single batched +all-reduce), or a faster interconnect. Launch-overhead work (CUDA graphs, +kernel-launch fusion) does not move this workload. + +**Kept.** The EP CUDA-graph capability and the decode phase-timing +instrumentation live on side branches, off by default, for a future +launch-bound model and for re-measurement.