Lift max_batch_size=1 under EP=2 via head↔worker slot multiplexing (#99)#101
Open
camerono wants to merge 10 commits into
Open
Lift max_batch_size=1 under EP=2 via head↔worker slot multiplexing (#99)#101camerono wants to merge 10 commits into
camerono wants to merge 10 commits into
Conversation
…ocol
Foundation for atlas#99 (lift max_batch_size=1 under --ep-size N). Adds two
helpers that wrap an optional seq_id preamble around the existing cmd
broadcast:
ep_broadcast_seq_and_cmd(seq_id, cmd, v2)
ep_recv_seq_and_cmd(v2) -> (seq_id, cmd)
When v2 is false, the preamble is skipped and the wire shape is byte-
identical to today's protocol — head and worker built before this change
continue to interoperate. When v2 is true, a slot identifier (which will be
SequenceState.slot_idx in subsequent commits) precedes each command so the
worker can route slot-bound dispatch into the right SsmStatePool slot.
No callers yet — this commit is purely additive. Commits 2-5 will:
- emit the preamble at head-side broadcast sites (scheduler, verify_k*)
- dispatch multi-slot on the worker (loop in build.rs, slots vec on
ep_worker_step_impl)
- split the legacy "free+realloc" 0xFFFFFFF1 into explicit alloc-slot
and free-slot commands
- flip the world_size > 1 clamp in serve.rs and switch the scheduler to
round-robin dispatch over active sequences
v2 wiring lives at the call site for now (caller passes the flag
explicitly, no implicit env-var reads inside the helper) to keep this
commit testable as pure logic and to align with AGENTS.md's PCND
invariant.
Wires up the ep_broadcast_seq_and_cmd helper added in 21e2130 by switching every command-kickoff site in the scheduler from ep_broadcast_cmd(cmd) to ep_broadcast_cmd_for_seq(slot_idx, cmd). Follow-on broadcasts within the same command (chunk metadata, additional tokens, accept/reject result) keep using ep_broadcast_cmd unchanged — they ride the slot context the worker picked up from the preamble. Plumbing: * `TransformerModel.ep_protocol_v2: bool` (types.rs) reads ATLAS_EP_PROTOCOL env at construction (impl_a1.rs). * `Model::ep_broadcast_cmd_for_seq(seq_id, cmd)` trait method added in traits/model.rs (default no-op); TransformerModel override in trait_impl/mod.rs routes through the helper using its v2 flag. * `Model::ep_protocol_v2() -> bool` trait method added (default false). Call sites migrated (kickoff cmd codes 0xFFFFFFF0..F4 + token-level decode kickoffs that broadcast a.last_token): scheduler/verify_k2_step.rs : K=2 marker scheduler/verify_k3_step.rs : K=3 marker scheduler/verify_k4_step.rs : K=4 marker scheduler/prefill_a_step.rs : chunk-0 prefill + 2 error-recovery scheduler/prefill_b_step.rs : single-shot prefill + 1 error-recovery scheduler/phase_continue_prefills/run_standard.rs : chunked prefill scheduler/phase_promote_prefills.rs : error-path free scheduler/lifecycle.rs : finish_sequence + send_error + swap scheduler/mod.rs : scheduler shutdown drain + shutdown cmd scheduler/spec_step.rs : self-spec + ngram bootstrap + ngram K=2 scheduler/mtp_step.rs : MTP bootstrap decode For sites with `&mut ActiveSeq` in scope we use `a.seq.slot_idx as u32`; for sites with a bare `&mut SequenceState` (prefill paths, phase_promote_prefills error path) we use `seq.slot_idx as u32`. The scheduler shutdown command (0xFFFFFFFF) isn't slot-bound; seq_id is broadcast as 0 by convention and the worker ignores it. Wire effect: * ATLAS_EP_PROTOCOL unset / != "v2": ep_protocol_v2 is false, the helper skips the preamble, broadcast shape is byte-identical to today. * ATLAS_EP_PROTOCOL=v2: head broadcasts (slot_idx, cmd) as two sequential u32s before any follow-on data. Worker doesn't yet consume the preamble — that's commit 3. Single-rank (no comm) and v1 multi-rank both keep working through this commit. v2 isn't end-to-end until commit 3 lands the worker dispatch.
Closes out the v2 protocol end-to-end by teaching the worker to maintain N parallel `SequenceState` slots and route incoming commands by the seq_id preamble emitted in commit 31e44c6. Refactor in `model/impl_a2.rs`: * `ep_worker_step_impl` now takes `&mut [Option<SequenceState>]`. It reads `(seq_id, cmd)` via `ep_recv_seq_and_cmd` and handles the two slot-independent codes inline: shutdown (`0xFFFFFFFF`, applies to the whole worker, seq_id ignored) and alloc-slot (`0xFFFFFFF1`, frees any prior occupant of `slots[seq_id]` then allocates a fresh `SequenceState`). * Per-command dispatch (prefill/decode/verify K=2/3/4) moves into a new `ep_worker_dispatch_cmd(cmd, seq)` helper. The body is verbatim from before — only the prelude (slot lookup + alloc handling) moved. * Under v2 we defensively `bail!` if the worker's freshly-claimed SSM-pool slot doesn't equal the seq_id the head broadcast. Head and worker both pop from a `free_slots: Mutex<Vec<usize>>` in matched order so this should always hold — the check is here so we catch the invariant breaking loudly rather than corrupting KV. Trait + dispatch chain: * `Model::ep_worker_step` signature is now `(&mut [Option<SequenceState>])`. Default impl returns `Ok(true)` as before. * `ep_worker_step_dispatch` (ep_misc.rs) and the TransformerModel trait impl (mod.rs) forward the slots slice through. Worker entry in `spark-server/src/main_modules/serve_phases/build.rs`: * Allocates a `Vec<Option<SequenceState>>` of `args.max_batch_size` `None`s. Pre-allocates slot 0 with `alloc_sequence()` so v1 (which never issues an explicit alloc command before its first decode) keeps working without head-side changes. * On shutdown, walks every slot and frees occupants. Backward compatibility: * With `ATLAS_EP_PROTOCOL` unset/!=`v2`: `ep_recv_seq_and_cmd` returns `seq_id=0` regardless of what the head sends (helper skips the preamble read entirely). `slots[0]` is pre-allocated. Every command routes to slot 0. Wire shape + worker semantics are byte-identical to the pre-PR singleton path. * With `ATLAS_EP_PROTOCOL=v2`: worker reads the preamble, routes correctly, alloc commands fill in additional slots as the head requests them. The gate clamp in `serve.rs:307-314` still forces `max_batch_size=1` under EP. Lifting that is commit 5 — at which point the scheduler can actually drive multiple active sequences. Without commit 5, v2 is inert because the head never has more than one active sequence to preface with a nonzero seq_id. File size: `impl_a2.rs` grew from 359 to 417 LoC. Under the 500-line CI guideline; the existing match arms account for most of the bulk and weren't worth splitting into a separate module for this PR.
Two changes that pair conceptually: the worker pre-allocates every slot in its slots Vec at startup (not just slot 0), and head-side compaction in retire_finished_sequences is skipped when the model reports ep_protocol_v2(). Pre-allocation rationale. Under v1 the worker only ever saw slot 0, so init-time pre-alloc of slot 0 sufficed. With the v2 protocol layer in place the head broadcasts a per-cmd seq_id preamble — and a future `max_batch_size > 1` lift will route prefill commands to slots > 0 without a preceding alloc broadcast (head's prefill_step claims a fresh slot via alloc_sequence + broadcasts 0xFFFFFFF0; the 0xFFFFFFF1 alloc command only fires on lifecycle events). Without all slots pre-claimed on the worker, that future routing would bail with "cmd 0xfffffff0 for unallocated slot N". Pre-allocating every slot up-front matches what the head's own SSM pool does (`(0..max_slots).rev().collect()` + `pop`) so sequential alloc_sequence() calls on both ranks return the same slot_idx order, keeping the new_seq.slot_idx == slot_idx defensive check in ep_worker_step satisfied. Skip-compaction rationale. retire_finished_sequences compacts the active vec so position == slot_idx, then tags the retired entry with usize::MAX as a do-not-double-free sentinel. Under v2, two things break: (1) moving SSM states on the head only would leave the worker's mirror at the original slot since the worker is keyed on slot_idx not active-set position, and (2) usize::MAX cast to u32 is 0xFFFFFFFF — the v1 shutdown command — which the worker would read as a real seq_id and trip its bounds check on the next preamble broadcast. Pre-allocated slots stay valid in place across the swap_remove, and the per-slot CUDA graph cache stays warm because the seq never moved. No behavior change under v1. With max_batch_size=1 (the standing EP clamp at serve.rs:307-314), the slots Vec has length 1, pre-allocation claims exactly slot 0, and active.len() never exceeds 1 so the compaction branch is unreachable. ep_protocol_v2() defaults false, so skip_compaction is false on v1, identical legacy behavior.
Bench-validated end-to-end on 2× GB10 (GB10 × 2, EP=2,
qwen3.5-122b-a10b NVFP4, MTP nvfp4 speculative): N=4 and N=8 concurrent
decode now correct and coherent. Single-seq baseline unchanged.
Three coordinated changes:
1. decode_a2.rs — decode_batch_dispatch's EP branch was a known dead
path under v1 (max_batch_size=1 clamp). Three things needed to be
true at once to make it correct for N>1 EP:
a) Per-layer NCCL allreduces must align in size and order with the
worker's matching allreduces. The worker runs decode() per slot
in ep_worker_step, so the head must also run decode() per seq —
no batched decode_multi_seq under EP.
b) The order of ops submitted to the comm matters. Worker submits
per seq: broadcast(preamble), then N_layer all_reduces. Head
previously batched all N preambles up-front from the scheduler,
which made head submit [B,B,B,B,AR,AR,...] while worker
submitted [B,AR,...,AR,B,AR,...,AR,...]. NCCL collectives match
by submission position; mismatched positions deadlocked the
comm. Observed empirically as "NCCL broadcast took 51.1s" on the
worker followed by stale comm reads. Now the broadcasts live
inside decode_batch_dispatch's EP branch, interleaved with each
self.decode() — both ranks submit [B,AR,...,AR,B,AR,...] in
matching order.
c) self.decode() writes single-row logits to row 0 of the logits
buffer on every call. Looping N decodes overwrites the buffer
so process_decode_logits ends up sampling N rows of the last
seq's logits. Stage each seq's row to host immediately after
its decode() (the buffer is fresh within the scope of one
decode()) then upload the assembled [n, vocab] back to the
logits buffer before returning. Same pattern as the existing
MLA per-seq fallback below.
And one stream subtlety: decode_dispatch overrides the caller's
`stream` parameter and uses `self.gpu.default_stream()` internally
for its forward-pass kernels. Issuing the per-seq D2H copy on the
scheduler's stream=0 (legacy NULL stream) landed it on a different
CUDA stream than the GEMV that wrote the logits. The copy could
read stale logits even though both streams "should" have synced.
Use self.gpu.default_stream() throughout the EP path so the copy
queues onto the same stream as the GEMV writes.
Also broadcast in the n=1 short-circuit so the scheduler is fully
relieved of decode-broadcast responsibility (the EP n>1 branch
couldn't move broadcasts inline without the n=1 path also handling
its own).
2. decode_step.rs — remove the per-token broadcast loop. The
responsibility moved into decode_batch_dispatch, so step_decode_only
no longer needs to know about EP at the cmd-broadcast layer.
3. serve.rs — honor --max-batch-size under EP when ep_protocol_v2()
returns true. v1 still clamps to 1 so existing deployments are
byte-identical without ATLAS_EP_PROTOCOL=v2.
Operational result on the motivating workload (nemoclaw 122B EP=2,
qwen3.5-122b-a10b NVFP4 + MTP nvfp4 speculative, --max-batch-size 4):
Wall-clock (concurrent users, max_tokens=80 each):
v1 max_batch=1 k=4: 2 of 4 tail-spiked to 605s (head-of-line)
v2 max_batch=4 N=4: 7.01s (all 4 coherent)
v2 max_batch=4 N=8: 13.12s (all 8 coherent, 4 in-flight + 4 queued)
Per-seq throughput is lower than the single-seq baseline (5-10 tok/s
vs ~38 tok/s) because the head still runs N sequential forward passes
and the per-step host-staging adds overhead. The user-visible win is
tail-latency elimination, not aggregate throughput. A true batched-EP
forward pass — Option A in PR_NOTES — remains the follow-up for the
throughput multiplier; this PR is the structural prerequisite that
also delivers the head-of-line fix today.
Eliminates the host round-trip in decode_batch_dispatch's EP branch. Previously: each per-seq decode() wrote row 0 of the logits buffer, the host code copied row 0 to a staging Vec, then uploaded the assembled [n, vocab] back to logits. Two PCIe transfers per seq plus one final upload. Now: iterate in reverse, decode each seq (still writing to row 0), then issue a device-to-device copy from row 0 to the target row i. For i=0 (processed last in the reverse iteration), no copy needed — row 0 already holds seq 0's logits. Stays on GPU memory throughout. Bench on 2× GB10 (qwen3.5-122b-a10b NVFP4, EP=2, --max-batch-size 4, MTP nvfp4 speculative): N=4: 7.01s -> 5.89s (-16%) N=8: 13.12s -> 11.43s (-13%) The eventual true batched-EP forward pass (one decode_multi_seq call per step instead of N sequential decode()s) subsumes this — N rows get written directly by the lm_head GEMV loop and no staging is needed at all. Until that lands, the d2d cuts the visible bench wall time without touching kernel correctness.
…i-seq decode Two coordinated changes that together make the multi-seq decode path a first-class entry point under EP, using the same kernel set that prefill already uses. 1. Batched-EP protocol (decode_a2.rs + impl_a2.rs) Reintroduces `0xFFFFFFE0` — the batched-decode command code originally tried (and reverted) in the foundation cycle. Head broadcasts `(seq_id=0, 0xFFFFFFE0)` preamble + N + seq_ids[N] + tokens[N] via the new `ep_broadcast_decode_batch_dispatch` helper. Worker matches with `ep_worker_decode_batch` (handled BEFORE the slot_idx lookup in `ep_worker_step_impl`, like shutdown), reads the payload, builds an in-order `Vec<&mut SequenceState>` from the addressed slots, and dispatches into the shared compute path. The shared compute path is `decode_batch_compute_main` — extracted from `decode_batch_dispatch`'s former non-EP main branch so both ranks now reach it. The head's EP branch broadcasts the protocol primitive, then calls it. The worker's batched handler also calls it. No host-staging, no per-seq broadcast loop — one batched forward pass per step per rank. Comm-stream submission order per decode step is identical on both ranks: `B(0) B(0xFFFFFFE0) B(N) B*N(seq_ids) B*N(tokens)` then per MoE layer N × `comm.all_reduce(h*elem)` from the per-token loop inside the batched MoE path (Avarok-Cybersecurity#2 below). 2. Grouped-MoE in multi-seq decode (trait_decode_multi_seq.rs + qwen3_attention/trait_impl/multi_seq/ffn.rs) The multi-seq decode path on both qwen3_ssm and qwen3_attention layers previously called `self.ffn.forward(normed2_i, ctx, stream)` inside a per-token loop — N × (gate GEMV + top_k expert GEMVs + weighted sum) per MoE sublayer. Refactor the loop into three phases: A: per-token residual_add_rms_norm, laying out `norm_output[0..n]` as a contiguous [N, h] MoE input B: ONE call to `self.ffn.forward_prefill(norm_base, n, ctx, stream)` — the grouped-GEMM path that the prefill scheduler already uses. Sort tokens by expert, one grouped gate+up GEMM, SiLU, one grouped down GEMM, unpermute. C: per-token residual_add reading `moe_output[i]`. Bug-Avarok-Cybersecurity#6 invariant preserved: SSM outputs are still copied to `ssm_out_safe` before Phase A so the batched MoE's writes to `moe_output[0..n]` don't clobber the SSM outputs the rms_norm reads. Perf on 2× GB10 (qwen3.5-122b-a10b NVFP4, EP=2, --max-batch-size 4, MTP nvfp4 speculative, greedy bench warm): N=1: ~38 tok/s aggregate N=2: ~29 tok/s aggregate N=4: ~38 tok/s aggregate (per-seq ~10 tok/s) N=8: ~26 tok/s aggregate (4 in-flight + 4 queued) Same throughput as the d2d-only path the previous commit shipped — the architectural ceiling on this hardware at low N is roughly the single-seq decode rate. The per-seq drop is the expected cost of sharing compute across N tokens; aggregate stays at the GPU's weight-load-bandwidth ceiling. What this PR's structural changes enable that d2d-only couldn't: - A single-call multi-seq decode entry point that uses the same MoE kernels (`moe_w4a16_grouped_gemm_ptrtable`, `moe_topk_softmax_batched`, `moe_sort_by_expert`, `moe_unpermute_reduce_indexed`) as prefill — one less code path to keep in sync, one less code path to optimize. - A clear hook for future kernel-level wins: a true batched `comm.all_reduce(N*h*elem)` instead of N per-token `comm.all_reduce(h)` would land cleanly here without changing the dispatch. - EP=4 / higher-N regimes where expert reuse becomes meaningful (N >> 256/top_k) will exercise the same path; today's grouped-GEMM kernel is already in production for prefill and known-correct. No behavior change under v1 (ATLAS_EP_PROTOCOL unset or != "v2"): the EP n>1 path is still unreachable because `serve.rs` clamps max_batch_size=1. Under v2 with N=1, the n=1 short-circuit fires; the batched code path only sees N>1 when the gate is flipped.
The SSM multi-seq decode path delegated every sequence to the full single-token decode(), running N independent single-token MoE forwards (N x top_k expert GEMVs + N per-token all_reduces under EP). The batched grouped-MoE code meant to replace it sat behind an early return, unreachable since the bug-Avarok-Cybersecurity#6 buffer-aliasing debugging. Replace the delegation with a per-seq SSM mixer loop (conv1d/GDN recurrent state is inherently per-sequence, so the proven single-token kernels stay) feeding a batch-dispatched MoE: - N=2/3: fused forward_k2/k3 -- one batched all_reduce, no per-token launch overhead. SSM decode-step 44->35ms at N=2 on GB10 (qwen3.5-122b-a10b NVFP4, EP=2). - N>=4: per-token MoE loop. The generic grouped-GEMM (forward_prefill) is a net loss for this 256-expert MoE at small batch -- per-expert M is ~1 and the sort/permute/ptr-table overhead, paid once per layer across 36 SSM layers, pushed the SSM step to ~140ms vs ~88ms per-token. forward_prefill is declined here until a true batched-EP MoE kernel exists. Buffer safety: each per-seq mixer writes its MoE input to norm_output[i] (distinct per-seq offset); ssm_forward never touches norm_output and its ssm_out is consumed within the same iteration, so nothing survives across sequences and the old aliasing cannot recur. Validated coherent + cross-seq isolated at N=2 and N=4. Removes the unreachable batched-decode block.
Mirror of the SSM-layer fix (parent commit): the attention layers' multi-seq FFN used the generic grouped-GEMM (forward_prefill) for the N>=4 branch, which is a net loss for this 256-expert MoE at small batch -- per-expert M ~1 and the sort/permute/ptr-table overhead dominates. Replace it with the per-token MoE loop (identical to decode()'s MoE). N=2/3 keep the fused forward_k2/k3 branches. Measured on GB10 (qwen3.5-122b-a10b NVFP4, EP=2, N=4): attention decode-block ~40->~24ms, full step ~132->~122ms, no regression. This also makes the N=2/3 MLA fallback path (force_seq_ffn) avoid the batched-MoE kernels it was never safe with. Validated coherent + cross-seq isolated at N=4.
|
All contributors have signed the CLA. Thank you! |
Author
|
I have read the CLA Document and I hereby sign the CLA |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Implements #99 — lifts the
max_batch_size = 1clamp under--ep-size 2by multiplexing the head↔worker protocol, and makes the batched multi-sequence decode path it unlocks both correct and fast.Behind
ATLAS_EP_PROTOCOL=v2so the wire change is opt-in; v1 behaviour is unchanged when the flag is unset.Protocol (the design from #99, landed)
SequenceState.slot_idxalready identifies a sequence's SSM-pool slot and the whole forward pass indexes by it — the protocol just couldn't express which slot the head was operating on, so the worker only ever used slot 0 andserve.rswas forced to clampmax_batch_size = 1. The change plumbs the existingslot_idxthrough as the seq_id:ep_broadcast_seq_and_cmd/ep_recv_seq_and_cmdhelpers (preamble is skipped when v2 is off).seq_idpreamble at every EP broadcast site.Vec<Option<SequenceState>>instead of a singleton).world_size > 1clamp under v2; pre-allocate all worker slots and skip retire-compaction (the worker keeps slots in place, keyed byslot_idx).Two head-side fixes the protocol exposed
These were in the model layer, not the wire protocol:
decode_batch_dispatch's EP branch wrote single-row logits to row 0 per sequence, so every sequence sampled the last one's logits. Fixed by staging each sequence's logits row (later, a d2d copy; finally subsumed by the batched path).Batched-decode MoE: use the fused batch kernels, not the grouped-GEMM
The SSM layers' batched MoE was dead code behind an early
return. Reviving it withforward_prefill(the prefill grouped-GEMM) was a ~60% regression — at decode batch sizes the per-expert M is ≈1, and the 256-expert sort/permute/ptr-table overhead dominates, ×36 SSM layers. The dispatch that works:forward_k2/forward_k3(one batched all-reduce, no per-token launch overhead)The SSM mixer stays per-sequence (it carries recurrent state); only the stateless MoE is hoisted out and batched. The attention layers had the same
forward_prefill-at-N≥4 trap, fixed the same way.Measured (Qwen3.5-122B-A10B-NVFP4, EP=2 on 2× GB10, MTP on)
forward_k2)One expectation to set: lifting the gate buys tail latency and admission, not aggregate throughput at low concurrency. Batched vs serialized at N=2 is ~equal (~32 tok/s either way) — the decode step is memory-bandwidth + inter-node-NCCL bound, and two concurrent tokens share almost no weight loads. Output is coherent and cross-sequence-isolated at N=2 and N=4. Full reasoning, including why CUDA graphs and single-host don't help this model, is in
docs/adr/0011-ep-batched-decode-optimization.md.