Skip to content

Lift max_batch_size=1 under EP=2 via head↔worker slot multiplexing (#99)#101

Open
camerono wants to merge 10 commits into
Avarok-Cybersecurity:mainfrom
camerono:pr/ep-batched-decode
Open

Lift max_batch_size=1 under EP=2 via head↔worker slot multiplexing (#99)#101
camerono wants to merge 10 commits into
Avarok-Cybersecurity:mainfrom
camerono:pr/ep-batched-decode

Conversation

@camerono
Copy link
Copy Markdown

Implements #99 — lifts the max_batch_size = 1 clamp under --ep-size 2 by multiplexing the head↔worker protocol, and makes the batched multi-sequence decode path it unlocks both correct and fast.

Behind ATLAS_EP_PROTOCOL=v2 so the wire change is opt-in; v1 behaviour is unchanged when the flag is unset.

Protocol (the design from #99, landed)

SequenceState.slot_idx already identifies a sequence's SSM-pool slot and the whole forward pass indexes by it — the protocol just couldn't express which slot the head was operating on, so the worker only ever used slot 0 and serve.rs was forced to clamp max_batch_size = 1. The change plumbs the existing slot_idx through as the seq_id:

  1. ep_broadcast_seq_and_cmd / ep_recv_seq_and_cmd helpers (preamble is skipped when v2 is off).
  2. Head emits the seq_id preamble at every EP broadcast site.
  3. Worker dispatches multi-slot (Vec<Option<SequenceState>> instead of a singleton).
  4. Explicit alloc-slot / free-slot command codes.
  5. Drop the world_size > 1 clamp under v2; pre-allocate all worker slots and skip retire-compaction (the worker keeps slots in place, keyed by slot_idx).

Two head-side fixes the protocol exposed

These were in the model layer, not the wire protocol:

  • decode_batch_dispatch's EP branch wrote single-row logits to row 0 per sequence, so every sequence sampled the last one's logits. Fixed by staging each sequence's logits row (later, a d2d copy; finally subsumed by the batched path).
  • Broadcasting all seq-id preambles up front put the head's comm-stream op order out of step with the worker's per-layer all-reduces, deadlocking NCCL. Fixed by interleaving the broadcasts inside the dispatch.

Batched-decode MoE: use the fused batch kernels, not the grouped-GEMM

The SSM layers' batched MoE was dead code behind an early return. Reviving it with forward_prefill (the prefill grouped-GEMM) was a ~60% regression — at decode batch sizes the per-expert M is ≈1, and the 256-expert sort/permute/ptr-table overhead dominates, ×36 SSM layers. The dispatch that works:

  • N=2/3 → fused forward_k2 / forward_k3 (one batched all-reduce, no per-token launch overhead)
  • N≥4 → per-token MoE loop

The SSM mixer stays per-sequence (it carries recurrent state); only the stateless MoE is hoisted out and batched. The attention layers had the same forward_prefill-at-N≥4 trap, fixed the same way.

Measured (Qwen3.5-122B-A10B-NVFP4, EP=2 on 2× GB10, MTP on)

Change Effect
Gate lift 4-concurrent burst now batches instead of serializing — removes the batch=1 tail-spike from #99
SSM MoE dispatch (forward_k2) SSM decode step 44 ms → 35 ms at N=2 (~15–20%)
Attention MoE dispatch attention block ~40 ms → ~24 ms at N=4

One expectation to set: lifting the gate buys tail latency and admission, not aggregate throughput at low concurrency. Batched vs serialized at N=2 is ~equal (~32 tok/s either way) — the decode step is memory-bandwidth + inter-node-NCCL bound, and two concurrent tokens share almost no weight loads. Output is coherent and cross-sequence-isolated at N=2 and N=4. Full reasoning, including why CUDA graphs and single-host don't help this model, is in docs/adr/0011-ep-batched-decode-optimization.md.

camerono added 10 commits May 27, 2026 19:11
…ocol

Foundation for atlas#99 (lift max_batch_size=1 under --ep-size N). Adds two
helpers that wrap an optional seq_id preamble around the existing cmd
broadcast:

  ep_broadcast_seq_and_cmd(seq_id, cmd, v2)
  ep_recv_seq_and_cmd(v2) -> (seq_id, cmd)

When v2 is false, the preamble is skipped and the wire shape is byte-
identical to today's protocol — head and worker built before this change
continue to interoperate. When v2 is true, a slot identifier (which will be
SequenceState.slot_idx in subsequent commits) precedes each command so the
worker can route slot-bound dispatch into the right SsmStatePool slot.

No callers yet — this commit is purely additive. Commits 2-5 will:
  - emit the preamble at head-side broadcast sites (scheduler, verify_k*)
  - dispatch multi-slot on the worker (loop in build.rs, slots vec on
    ep_worker_step_impl)
  - split the legacy "free+realloc" 0xFFFFFFF1 into explicit alloc-slot
    and free-slot commands
  - flip the world_size > 1 clamp in serve.rs and switch the scheduler to
    round-robin dispatch over active sequences

v2 wiring lives at the call site for now (caller passes the flag
explicitly, no implicit env-var reads inside the helper) to keep this
commit testable as pure logic and to align with AGENTS.md's PCND
invariant.
Wires up the ep_broadcast_seq_and_cmd helper added in 21e2130 by switching
every command-kickoff site in the scheduler from ep_broadcast_cmd(cmd) to
ep_broadcast_cmd_for_seq(slot_idx, cmd). Follow-on broadcasts within the
same command (chunk metadata, additional tokens, accept/reject result)
keep using ep_broadcast_cmd unchanged — they ride the slot context the
worker picked up from the preamble.

Plumbing:

  * `TransformerModel.ep_protocol_v2: bool` (types.rs) reads
    ATLAS_EP_PROTOCOL env at construction (impl_a1.rs).
  * `Model::ep_broadcast_cmd_for_seq(seq_id, cmd)` trait method added in
    traits/model.rs (default no-op); TransformerModel override in
    trait_impl/mod.rs routes through the helper using its v2 flag.
  * `Model::ep_protocol_v2() -> bool` trait method added (default false).

Call sites migrated (kickoff cmd codes 0xFFFFFFF0..F4 + token-level
decode kickoffs that broadcast a.last_token):

  scheduler/verify_k2_step.rs    : K=2 marker
  scheduler/verify_k3_step.rs    : K=3 marker
  scheduler/verify_k4_step.rs    : K=4 marker
  scheduler/prefill_a_step.rs    : chunk-0 prefill + 2 error-recovery
  scheduler/prefill_b_step.rs    : single-shot prefill + 1 error-recovery
  scheduler/phase_continue_prefills/run_standard.rs : chunked prefill
  scheduler/phase_promote_prefills.rs : error-path free
  scheduler/lifecycle.rs         : finish_sequence + send_error + swap
  scheduler/mod.rs               : scheduler shutdown drain + shutdown cmd
  scheduler/spec_step.rs         : self-spec + ngram bootstrap + ngram K=2
  scheduler/mtp_step.rs          : MTP bootstrap decode

For sites with `&mut ActiveSeq` in scope we use `a.seq.slot_idx as u32`;
for sites with a bare `&mut SequenceState` (prefill paths,
phase_promote_prefills error path) we use `seq.slot_idx as u32`. The
scheduler shutdown command (0xFFFFFFFF) isn't slot-bound; seq_id is
broadcast as 0 by convention and the worker ignores it.

Wire effect:

  * ATLAS_EP_PROTOCOL unset / != "v2": ep_protocol_v2 is false, the
    helper skips the preamble, broadcast shape is byte-identical to today.
  * ATLAS_EP_PROTOCOL=v2: head broadcasts (slot_idx, cmd) as two
    sequential u32s before any follow-on data. Worker doesn't yet
    consume the preamble — that's commit 3.

Single-rank (no comm) and v1 multi-rank both keep working through this
commit. v2 isn't end-to-end until commit 3 lands the worker dispatch.
Closes out the v2 protocol end-to-end by teaching the worker to maintain
N parallel `SequenceState` slots and route incoming commands by the
seq_id preamble emitted in commit 31e44c6.

Refactor in `model/impl_a2.rs`:

  * `ep_worker_step_impl` now takes `&mut [Option<SequenceState>]`. It
    reads `(seq_id, cmd)` via `ep_recv_seq_and_cmd` and handles the two
    slot-independent codes inline: shutdown (`0xFFFFFFFF`, applies to
    the whole worker, seq_id ignored) and alloc-slot (`0xFFFFFFF1`,
    frees any prior occupant of `slots[seq_id]` then allocates a fresh
    `SequenceState`).
  * Per-command dispatch (prefill/decode/verify K=2/3/4) moves into a
    new `ep_worker_dispatch_cmd(cmd, seq)` helper. The body is verbatim
    from before — only the prelude (slot lookup + alloc handling) moved.
  * Under v2 we defensively `bail!` if the worker's freshly-claimed
    SSM-pool slot doesn't equal the seq_id the head broadcast. Head and
    worker both pop from a `free_slots: Mutex<Vec<usize>>` in matched
    order so this should always hold — the check is here so we catch
    the invariant breaking loudly rather than corrupting KV.

Trait + dispatch chain:

  * `Model::ep_worker_step` signature is now `(&mut
    [Option<SequenceState>])`. Default impl returns `Ok(true)` as before.
  * `ep_worker_step_dispatch` (ep_misc.rs) and the TransformerModel
    trait impl (mod.rs) forward the slots slice through.

Worker entry in `spark-server/src/main_modules/serve_phases/build.rs`:

  * Allocates a `Vec<Option<SequenceState>>` of `args.max_batch_size`
    `None`s. Pre-allocates slot 0 with `alloc_sequence()` so v1 (which
    never issues an explicit alloc command before its first decode)
    keeps working without head-side changes.
  * On shutdown, walks every slot and frees occupants.

Backward compatibility:

  * With `ATLAS_EP_PROTOCOL` unset/!=`v2`: `ep_recv_seq_and_cmd` returns
    `seq_id=0` regardless of what the head sends (helper skips the
    preamble read entirely). `slots[0]` is pre-allocated. Every command
    routes to slot 0. Wire shape + worker semantics are byte-identical
    to the pre-PR singleton path.
  * With `ATLAS_EP_PROTOCOL=v2`: worker reads the preamble, routes
    correctly, alloc commands fill in additional slots as the head
    requests them.

The gate clamp in `serve.rs:307-314` still forces `max_batch_size=1`
under EP. Lifting that is commit 5 — at which point the scheduler can
actually drive multiple active sequences. Without commit 5, v2 is
inert because the head never has more than one active sequence to
preface with a nonzero seq_id.

File size: `impl_a2.rs` grew from 359 to 417 LoC. Under the 500-line CI
guideline; the existing match arms account for most of the bulk and
weren't worth splitting into a separate module for this PR.
Two changes that pair conceptually: the worker pre-allocates every slot
in its slots Vec at startup (not just slot 0), and head-side compaction
in retire_finished_sequences is skipped when the model reports
ep_protocol_v2().

Pre-allocation rationale. Under v1 the worker only ever saw slot 0, so
init-time pre-alloc of slot 0 sufficed. With the v2 protocol layer in
place the head broadcasts a per-cmd seq_id preamble — and a future
`max_batch_size > 1` lift will route prefill commands to slots > 0
without a preceding alloc broadcast (head's prefill_step claims a fresh
slot via alloc_sequence + broadcasts 0xFFFFFFF0; the 0xFFFFFFF1 alloc
command only fires on lifecycle events). Without all slots pre-claimed
on the worker, that future routing would bail with "cmd 0xfffffff0 for
unallocated slot N". Pre-allocating every slot up-front matches what
the head's own SSM pool does (`(0..max_slots).rev().collect()` + `pop`)
so sequential alloc_sequence() calls on both ranks return the same
slot_idx order, keeping the new_seq.slot_idx == slot_idx defensive
check in ep_worker_step satisfied.

Skip-compaction rationale. retire_finished_sequences compacts the
active vec so position == slot_idx, then tags the retired entry with
usize::MAX as a do-not-double-free sentinel. Under v2, two things
break: (1) moving SSM states on the head only would leave the worker's
mirror at the original slot since the worker is keyed on slot_idx not
active-set position, and (2) usize::MAX cast to u32 is 0xFFFFFFFF —
the v1 shutdown command — which the worker would read as a real seq_id
and trip its bounds check on the next preamble broadcast. Pre-allocated
slots stay valid in place across the swap_remove, and the per-slot
CUDA graph cache stays warm because the seq never moved.

No behavior change under v1. With max_batch_size=1 (the standing EP
clamp at serve.rs:307-314), the slots Vec has length 1, pre-allocation
claims exactly slot 0, and active.len() never exceeds 1 so the
compaction branch is unreachable. ep_protocol_v2() defaults false, so
skip_compaction is false on v1, identical legacy behavior.
Bench-validated end-to-end on 2× GB10 (GB10 × 2, EP=2,
qwen3.5-122b-a10b NVFP4, MTP nvfp4 speculative): N=4 and N=8 concurrent
decode now correct and coherent. Single-seq baseline unchanged.

Three coordinated changes:

1. decode_a2.rs — decode_batch_dispatch's EP branch was a known dead
   path under v1 (max_batch_size=1 clamp). Three things needed to be
   true at once to make it correct for N>1 EP:

   a) Per-layer NCCL allreduces must align in size and order with the
      worker's matching allreduces. The worker runs decode() per slot
      in ep_worker_step, so the head must also run decode() per seq —
      no batched decode_multi_seq under EP.

   b) The order of ops submitted to the comm matters. Worker submits
      per seq: broadcast(preamble), then N_layer all_reduces. Head
      previously batched all N preambles up-front from the scheduler,
      which made head submit [B,B,B,B,AR,AR,...] while worker
      submitted [B,AR,...,AR,B,AR,...,AR,...]. NCCL collectives match
      by submission position; mismatched positions deadlocked the
      comm. Observed empirically as "NCCL broadcast took 51.1s" on the
      worker followed by stale comm reads. Now the broadcasts live
      inside decode_batch_dispatch's EP branch, interleaved with each
      self.decode() — both ranks submit [B,AR,...,AR,B,AR,...] in
      matching order.

   c) self.decode() writes single-row logits to row 0 of the logits
      buffer on every call. Looping N decodes overwrites the buffer
      so process_decode_logits ends up sampling N rows of the last
      seq's logits. Stage each seq's row to host immediately after
      its decode() (the buffer is fresh within the scope of one
      decode()) then upload the assembled [n, vocab] back to the
      logits buffer before returning. Same pattern as the existing
      MLA per-seq fallback below.

   And one stream subtlety: decode_dispatch overrides the caller's
   `stream` parameter and uses `self.gpu.default_stream()` internally
   for its forward-pass kernels. Issuing the per-seq D2H copy on the
   scheduler's stream=0 (legacy NULL stream) landed it on a different
   CUDA stream than the GEMV that wrote the logits. The copy could
   read stale logits even though both streams "should" have synced.
   Use self.gpu.default_stream() throughout the EP path so the copy
   queues onto the same stream as the GEMV writes.

   Also broadcast in the n=1 short-circuit so the scheduler is fully
   relieved of decode-broadcast responsibility (the EP n>1 branch
   couldn't move broadcasts inline without the n=1 path also handling
   its own).

2. decode_step.rs — remove the per-token broadcast loop. The
   responsibility moved into decode_batch_dispatch, so step_decode_only
   no longer needs to know about EP at the cmd-broadcast layer.

3. serve.rs — honor --max-batch-size under EP when ep_protocol_v2()
   returns true. v1 still clamps to 1 so existing deployments are
   byte-identical without ATLAS_EP_PROTOCOL=v2.

Operational result on the motivating workload (nemoclaw 122B EP=2,
qwen3.5-122b-a10b NVFP4 + MTP nvfp4 speculative, --max-batch-size 4):

  Wall-clock (concurrent users, max_tokens=80 each):
    v1 max_batch=1 k=4: 2 of 4 tail-spiked to 605s (head-of-line)
    v2 max_batch=4 N=4: 7.01s  (all 4 coherent)
    v2 max_batch=4 N=8: 13.12s (all 8 coherent, 4 in-flight + 4 queued)

  Per-seq throughput is lower than the single-seq baseline (5-10 tok/s
  vs ~38 tok/s) because the head still runs N sequential forward passes
  and the per-step host-staging adds overhead. The user-visible win is
  tail-latency elimination, not aggregate throughput. A true batched-EP
  forward pass — Option A in PR_NOTES — remains the follow-up for the
  throughput multiplier; this PR is the structural prerequisite that
  also delivers the head-of-line fix today.
Eliminates the host round-trip in decode_batch_dispatch's EP branch.
Previously: each per-seq decode() wrote row 0 of the logits buffer, the
host code copied row 0 to a staging Vec, then uploaded the assembled
[n, vocab] back to logits. Two PCIe transfers per seq plus one final
upload.

Now: iterate in reverse, decode each seq (still writing to row 0), then
issue a device-to-device copy from row 0 to the target row i. For i=0
(processed last in the reverse iteration), no copy needed — row 0
already holds seq 0's logits. Stays on GPU memory throughout.

Bench on 2× GB10 (qwen3.5-122b-a10b NVFP4, EP=2,
--max-batch-size 4, MTP nvfp4 speculative):

  N=4:  7.01s -> 5.89s (-16%)
  N=8: 13.12s -> 11.43s (-13%)

The eventual true batched-EP forward pass (one decode_multi_seq call
per step instead of N sequential decode()s) subsumes this — N rows
get written directly by the lm_head GEMV loop and no staging is
needed at all. Until that lands, the d2d cuts the visible bench wall
time without touching kernel correctness.
…i-seq decode

Two coordinated changes that together make the multi-seq decode path a
first-class entry point under EP, using the same kernel set that prefill
already uses.

1. Batched-EP protocol (decode_a2.rs + impl_a2.rs)

   Reintroduces `0xFFFFFFE0` — the batched-decode command code originally
   tried (and reverted) in the foundation cycle. Head broadcasts
   `(seq_id=0, 0xFFFFFFE0)` preamble + N + seq_ids[N] + tokens[N] via the
   new `ep_broadcast_decode_batch_dispatch` helper. Worker matches with
   `ep_worker_decode_batch` (handled BEFORE the slot_idx lookup in
   `ep_worker_step_impl`, like shutdown), reads the payload, builds an
   in-order `Vec<&mut SequenceState>` from the addressed slots, and
   dispatches into the shared compute path.

   The shared compute path is `decode_batch_compute_main` — extracted
   from `decode_batch_dispatch`'s former non-EP main branch so both
   ranks now reach it. The head's EP branch broadcasts the protocol
   primitive, then calls it. The worker's batched handler also calls
   it. No host-staging, no per-seq broadcast loop — one batched forward
   pass per step per rank.

   Comm-stream submission order per decode step is identical on both
   ranks: `B(0) B(0xFFFFFFE0) B(N) B*N(seq_ids) B*N(tokens)` then per
   MoE layer N × `comm.all_reduce(h*elem)` from the per-token loop
   inside the batched MoE path (Avarok-Cybersecurity#2 below).

2. Grouped-MoE in multi-seq decode (trait_decode_multi_seq.rs +
   qwen3_attention/trait_impl/multi_seq/ffn.rs)

   The multi-seq decode path on both qwen3_ssm and qwen3_attention
   layers previously called `self.ffn.forward(normed2_i, ctx, stream)`
   inside a per-token loop — N × (gate GEMV + top_k expert GEMVs +
   weighted sum) per MoE sublayer.

   Refactor the loop into three phases:
     A: per-token residual_add_rms_norm, laying out `norm_output[0..n]`
        as a contiguous [N, h] MoE input
     B: ONE call to `self.ffn.forward_prefill(norm_base, n, ctx, stream)`
        — the grouped-GEMM path that the prefill scheduler already uses.
        Sort tokens by expert, one grouped gate+up GEMM, SiLU, one
        grouped down GEMM, unpermute.
     C: per-token residual_add reading `moe_output[i]`.

   Bug-Avarok-Cybersecurity#6 invariant preserved: SSM outputs are still copied to
   `ssm_out_safe` before Phase A so the batched MoE's writes to
   `moe_output[0..n]` don't clobber the SSM outputs the rms_norm reads.

Perf on 2× GB10 (qwen3.5-122b-a10b NVFP4, EP=2,
--max-batch-size 4, MTP nvfp4 speculative, greedy bench warm):

  N=1: ~38 tok/s aggregate
  N=2: ~29 tok/s aggregate
  N=4: ~38 tok/s aggregate (per-seq ~10 tok/s)
  N=8: ~26 tok/s aggregate (4 in-flight + 4 queued)

Same throughput as the d2d-only path the previous commit shipped —
the architectural ceiling on this hardware at low N is roughly the
single-seq decode rate. The per-seq drop is the expected cost of
sharing compute across N tokens; aggregate stays at the GPU's
weight-load-bandwidth ceiling.

What this PR's structural changes enable that d2d-only couldn't:
- A single-call multi-seq decode entry point that uses the same MoE
  kernels (`moe_w4a16_grouped_gemm_ptrtable`, `moe_topk_softmax_batched`,
  `moe_sort_by_expert`, `moe_unpermute_reduce_indexed`) as prefill —
  one less code path to keep in sync, one less code path to optimize.
- A clear hook for future kernel-level wins: a true batched
  `comm.all_reduce(N*h*elem)` instead of N per-token `comm.all_reduce(h)`
  would land cleanly here without changing the dispatch.
- EP=4 / higher-N regimes where expert reuse becomes meaningful
  (N >> 256/top_k) will exercise the same path; today's grouped-GEMM
  kernel is already in production for prefill and known-correct.

No behavior change under v1 (ATLAS_EP_PROTOCOL unset or != "v2"):
the EP n>1 path is still unreachable because `serve.rs` clamps
max_batch_size=1. Under v2 with N=1, the n=1 short-circuit fires;
the batched code path only sees N>1 when the gate is flipped.
The SSM multi-seq decode path delegated every sequence to the full
single-token decode(), running N independent single-token MoE forwards
(N x top_k expert GEMVs + N per-token all_reduces under EP). The batched
grouped-MoE code meant to replace it sat behind an early return,
unreachable since the bug-Avarok-Cybersecurity#6 buffer-aliasing debugging.

Replace the delegation with a per-seq SSM mixer loop (conv1d/GDN
recurrent state is inherently per-sequence, so the proven single-token
kernels stay) feeding a batch-dispatched MoE:
  - N=2/3: fused forward_k2/k3 -- one batched all_reduce, no per-token
    launch overhead. SSM decode-step 44->35ms at N=2 on GB10
    (qwen3.5-122b-a10b NVFP4, EP=2).
  - N>=4: per-token MoE loop. The generic grouped-GEMM (forward_prefill)
    is a net loss for this 256-expert MoE at small batch -- per-expert M
    is ~1 and the sort/permute/ptr-table overhead, paid once per layer
    across 36 SSM layers, pushed the SSM step to ~140ms vs ~88ms
    per-token. forward_prefill is declined here until a true batched-EP
    MoE kernel exists.

Buffer safety: each per-seq mixer writes its MoE input to norm_output[i]
(distinct per-seq offset); ssm_forward never touches norm_output and its
ssm_out is consumed within the same iteration, so nothing survives across
sequences and the old aliasing cannot recur.

Validated coherent + cross-seq isolated at N=2 and N=4. Removes the
unreachable batched-decode block.
Mirror of the SSM-layer fix (parent commit): the attention layers'
multi-seq FFN used the generic grouped-GEMM (forward_prefill) for the
N>=4 branch, which is a net loss for this 256-expert MoE at small batch
-- per-expert M ~1 and the sort/permute/ptr-table overhead dominates.

Replace it with the per-token MoE loop (identical to decode()'s MoE).
N=2/3 keep the fused forward_k2/k3 branches. Measured on GB10
(qwen3.5-122b-a10b NVFP4, EP=2, N=4): attention decode-block ~40->~24ms,
full step ~132->~122ms, no regression. This also makes the N=2/3 MLA
fallback path (force_seq_ffn) avoid the batched-MoE kernels it was never
safe with.

Validated coherent + cross-seq isolated at N=4.
@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 29, 2026

All contributors have signed the CLA. Thank you!
Posted by the CLA Assistant Lite bot.

@camerono
Copy link
Copy Markdown
Author

I have read the CLA Document and I hereby sign the CLA

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant