fix(ssm): batched multi-seq decode with corrected buffer layout #74
pragmaxim wants to merge 4 commits
The unreachable branch after `return Ok(())` carried 250+ lines of the prior batched SSM decode implementation gated behind `#[allow(unreachable_code, unused_variables)]`. The buffer-aliasing bug that disabled it (Avarok-Cybersecurity#6) is preserved in git history and summarized in the doc comment, so the dead block adds noise without protecting any behavior. Single-sequence delegation and BF16/FP32 residual stride handling are unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous batched implementation (Avarok-Cybersecurity#6) wrote conv1d output into `ctx.buffers.attn_output()`, which is sized for `m * mamba2_d_inner * bf16`: half of what `n * conv_dim * bf16` needs on Qwen3.5-A3B (conv_dim = 2*key_dim + value_dim ≈ 8K, mamba2_d_inner ≈ value_dim). The out-of-bounds writes corrupted GDN reads on subsequent sequences and produced the multilingual gibberish that disabled the path.
The corrected layout drops `attn_output` and instead writes both conv1d output and GDN output into per-seq slices of `ssm_conv_out_f32` (sized for `m * ssm_qkvz_size * 4`, the FP32 worst case). The per-seq slice is exactly `ssm_qkvz_size * fp32`, with conv1d at the slice base and GDN appended at offset `(2*key_dim + value_dim) * fp32`, mirroring the single-seq `ssm_forward` layout.
Outer ops (`rms_norm_residual`, `residual_add_rms_norm`) are batched across all N sequences in one launch. State-bearing inner ops (`conv1d_update`, `gdn_decode`) run per-seq because each carries its own recurrent state. The final MoE+residual is interleaved per-seq (the same pattern `decode_batched` uses for non-fused K), so the `moe_output` slot can be safely reused between sequences.
If FP32 conv1d / GDN / gated-RMS kernels aren't loaded (Metal backend, or any future target without the FP32 variants), fall back to the proven per-sequence single-decode path: the BF16 GDN output goes to `attn_output`, whose per-token sizing can't accommodate N>1 without the same overflow.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
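To make the size mismatch concrete, here is a minimal arithmetic sketch. The dimension values (`key_dim = 2048`, `value_dim = 4096`) and the choice `m == n` are illustrative assumptions picked so that `conv_dim` is 8192 and `mamba2_d_inner` equals `value_dim`; they are not read from the model config, and the helper names are hypothetical.

```rust
/// Bytes per BF16 element.
const BF16_BYTES: usize = 2;

/// Illustrative dimensions (assumed for this sketch, not read from a model config).
struct Dims {
    key_dim: usize,
    value_dim: usize,
    mamba2_d_inner: usize,
}

impl Dims {
    /// conv_dim = 2*key_dim + value_dim, as stated in the commit message.
    fn conv_dim(&self) -> usize {
        2 * self.key_dim + self.value_dim
    }
}

/// Capacity of a buffer sized `m * mamba2_d_inner * bf16`.
fn attn_output_bytes(d: &Dims, m: usize) -> usize {
    m * d.mamba2_d_inner * BF16_BYTES
}

/// Bytes needed to hold BF16 conv1d output for `n` sequences.
fn conv_out_bytes(d: &Dims, n: usize) -> usize {
    n * d.conv_dim() * BF16_BYTES
}

/// Bytes written past the end of the buffer (0 if the write fits).
fn overflow_bytes(d: &Dims, m: usize, n: usize) -> usize {
    conv_out_bytes(d, n).saturating_sub(attn_output_bytes(d, m))
}

/// Hypothetical example values: conv_dim = 8192, mamba2_d_inner = value_dim.
fn example_dims() -> Dims {
    Dims { key_dim: 2048, value_dim: 4096, mamba2_d_inner: 4096 }
}

fn main() {
    let d = example_dims();
    let (m, n) = (2, 2); // assumed: buffer sized for as many tokens as sequences
    println!("capacity: {} B", attn_output_bytes(&d, m));
    println!("needed:   {} B", conv_out_bytes(&d, n));
    println!("overflow: {} B", overflow_bytes(&d, m, n));
}
```

Under these assumptions the conv1d output needs twice the bytes the buffer provides, which matches the "half of what is needed" description.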
Pivoted from the dead-code removal to actually reviving the batched path with a corrected buffer layout; see the updated PR description. The root cause of #6 was that the disabled implementation used `ctx.buffers.attn_output()` as the conv1d target, a buffer too small for more than one sequence.
Two commits on the branch: the first still removes the unreachable dead code as before, the second adds the corrected batched implementation. Happy to squash if you'd prefer a single landing commit.
Note: I cannot currently test it on DGX to validate the fix end-to-end. The CI clippy/fmt jobs pass. Multi-seq SSM gibberish smoke-tests on Qwen3.5-A3B and Nemotron-H are still required before merge.
I'll give this a test once #63 is merged
Summary
Revives the disabled batched SSM multi-seq decode path with a corrected scratch-buffer layout that addresses the original aliasing bug (#6).
Original bug
The earlier batched implementation wrote conv1d output into `ctx.buffers.attn_output()`, which is sized for `m * mamba2_d_inner * bf16`: half of what `n * conv_dim * bf16` needs on Qwen3.5-A3B (`conv_dim = 2*key_dim + value_dim ≈ 8K`, `mamba2_d_inner ≈ value_dim`). For `n ≥ 2` the conv1d write for seq 1+ scribbled past `attn_output`'s end and corrupted whatever buffer followed in the arena, producing the multilingual gibberish that disabled the path.
Fix
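A minimal sketch of the per-seq slice layout this fix relies on, written as plain offset arithmetic. The function and struct names are hypothetical, and two things are assumptions rather than facts from the code base: that `ssm_qkvz_size` equals `2*key_dim + 2*value_dim`, and that the GDN output holds `value_dim` FP32 values.

```rust
/// Bytes per FP32 element.
const F32_BYTES: usize = 4;

/// Byte offsets of one sequence's slice inside the shared scratch buffer.
#[derive(Debug, PartialEq)]
struct SeqSlice {
    conv_start: usize,
    gdn_start: usize,
    end: usize,
}

/// Locate sequence `seq` in a buffer laid out as consecutive
/// `ssm_qkvz_size * 4`-byte slices: conv1d at the slice base, GDN after it.
fn seq_slice(seq: usize, ssm_qkvz_size: usize, key_dim: usize, value_dim: usize) -> SeqSlice {
    let stride = ssm_qkvz_size * F32_BYTES;
    let base = seq * stride;
    SeqSlice {
        conv_start: base,
        gdn_start: base + (2 * key_dim + value_dim) * F32_BYTES,
        end: base + stride,
    }
}

/// True if a GDN output of `gdn_elems` FP32 values stays inside the slice.
fn gdn_fits(s: &SeqSlice, gdn_elems: usize) -> bool {
    s.gdn_start + gdn_elems * F32_BYTES <= s.end
}

fn main() {
    // Assumed example sizes; ssm_qkvz_size = 2*key_dim + 2*value_dim is a guess
    // at how q, k, v and z add up, not a value taken from the code base.
    let (key_dim, value_dim) = (2048, 4096);
    let ssm_qkvz_size = 2 * key_dim + 2 * value_dim;
    for seq in 0..2 {
        let s = seq_slice(seq, ssm_qkvz_size, key_dim, value_dim);
        println!("seq {seq}: {s:?}, gdn fits: {}", gdn_fits(&s, value_dim));
    }
}
```

With these example sizes each slice ends exactly where the next one begins, so neither the conv1d region nor the GDN region of one sequence reaches into its neighbour.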
- Drops `attn_output` as the conv/GDN target.
- Writes conv1d and GDN output into per-seq slices of `ssm_conv_out_f32` (sized `m * ssm_qkvz_size * 4`, the FP32 worst case). Each per-seq slice is exactly `ssm_qkvz_size * 4` bytes, with conv at the base and GDN appended at offset `(2*key_dim + value_dim) * 4`, mirroring the single-seq `ssm_forward` layout.
- The other scratch buffers (`ssm_deinterleaved`, `ssm_gates`, `ssm_qkvz`, `moe_output`) use their natural per-token strides, all sized for `max_batch_tokens` by `BufferSizes::from_config`.
Batching wins
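The batching structure in this section can be illustrated with a small CPU-only model. It counts launches for the per-seq and batched paths and shows why interleaving MoE with the residual add lets one scratch slot be reused. All names here are hypothetical stand-ins, `fake_moe_forward` is not the real MoE, and the count of two inner ops per sequence is a simplification.

```rust
/// Kernel launches per layer, split by kind.
#[derive(Debug, Default, PartialEq)]
struct Launches {
    outer: usize,
    inner: usize,
}

/// Per-sequence path: every op is launched once per sequence.
fn per_seq_layer(n: usize) -> Launches {
    let mut l = Launches::default();
    for _ in 0..n {
        l.outer += 2; // rms_norm_residual + residual_add_rms_norm
        l.inner += 2; // conv1d_update + gdn_decode (other inner ops omitted)
    }
    l
}

/// Batched path: each outer op covers all N sequences in one launch;
/// state-bearing inner ops still run once per sequence.
fn batched_layer(n: usize) -> Launches {
    let mut l = Launches::default();
    l.outer += 2;
    for _ in 0..n {
        l.inner += 2;
    }
    l
}

/// Stand-in for an MoE forward pass: always writes into the same scratch slot.
fn fake_moe_forward(input: &[f32], moe_output: &mut [f32]) {
    for (out, x) in moe_output.iter_mut().zip(input) {
        *out = 2.0 * *x;
    }
}

/// Interleave MoE and residual add per sequence so the single slot can be reused.
fn moe_interleaved(hidden: &mut [Vec<f32>], moe_output: &mut [f32]) {
    for h in hidden.iter_mut() {
        fake_moe_forward(&h[..], moe_output);
        for (x, out) in h.iter_mut().zip(moe_output.iter()) {
            *x += *out; // consume the slot before the next sequence overwrites it
        }
    }
}

fn main() {
    let n = 4;
    let saved = per_seq_layer(n).outer - batched_layer(n).outer;
    println!("outer launches saved per layer: {saved}");

    let mut hidden: Vec<Vec<f32>> = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let mut slot = vec![0.0f32; 2];
    moe_interleaved(&mut hidden, &mut slot);
    println!("{hidden:?}");
}
```

The inner launch count is the same on both paths; only the outer ops shrink, by `2*(N-1)` launches per layer.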
- `rms_norm_residual` and `residual_add_rms_norm` fire once per layer for all N sequences (saves 2*(N-1) launches/layer).
- `conv1d_update`, `gdn_decode`, etc. stay per-seq: each kernel takes a single state pointer, so multi-state fan-out would need new kernels.
- `MoE.forward` writes to `moe_output[0..h]` regardless of which seq it processes; per-seq residual_add happens before the next MoE call (same pattern `decode_batched_inner` uses for non-fused K).
Safety fallback
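A rough sketch of the path selection this section describes. The `Fp32Kernels` struct, the `DecodePath` enum and `choose_path` are hypothetical names for illustration; the real check in the code base may be structured differently.

```rust
/// Which FP32 kernel variants the backend has loaded (hypothetical flags).
#[derive(Clone, Copy)]
struct Fp32Kernels {
    conv1d: bool,
    gdn: bool,
    gated_rms: bool,
}

impl Fp32Kernels {
    /// The batched layout needs all three FP32 variants.
    fn all_loaded(self) -> bool {
        self.conv1d && self.gdn && self.gated_rms
    }
}

#[derive(Debug, PartialEq)]
enum DecodePath {
    /// Batched multi-seq decode using the FP32 scratch layout.
    Batched,
    /// One single-decode call per sequence.
    PerSequence,
}

/// Pick the decode path for `n_seqs` sequences.
fn choose_path(kernels: Fp32Kernels, n_seqs: usize) -> DecodePath {
    if n_seqs > 1 && kernels.all_loaded() {
        DecodePath::Batched
    } else {
        // A single sequence delegates to single-decode; so does any backend
        // missing an FP32 variant (e.g. Metal).
        DecodePath::PerSequence
    }
}

fn main() {
    let cuda_like = Fp32Kernels { conv1d: true, gdn: true, gated_rms: true };
    let metal_like = Fp32Kernels { conv1d: false, gdn: false, gated_rms: false };
    println!("{:?}", choose_path(cuda_like, 4));
    println!("{:?}", choose_path(metal_like, 4));
    println!("{:?}", choose_path(cuda_like, 1));
}
```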
If FP32 conv1d / GDN / gated-RMS kernels aren't loaded (e.g. Metal backend), the path falls back to per-sequence single-decode. The BF16 GDN output lives in `attn_output`, which can't safely host N>1 concurrent slices.
Commits
- `refactor(ssm): drop dead batched decode path`: removes the 250 lines of unreachable code.
- `fix(ssm): batched multi-seq decode with corrected buffer layout`: adds the corrected batched implementation alongside the per-seq fallback.
Test plan
- `cargo fmt --all -- --check`
- `ATLAS_SKIP_BUILD=1 CUDARC_CUDA_VERSION=13000 cargo clippy --workspace --tests` (passes locally)
- Run with `--max-batch-size` > 1; confirm no multilingual gibberish on creative prompts.
- Run with `--max-batch-size` > 1; confirm long-context (≤8K) responses remain coherent.
- Exercise the per-seq fallback (an `ATLAS_GEMMA4_*`-style override to force fallback, if a runtime knob is needed).
🤖 Generated with Claude Code