Skip to content

fix(ssm): batched multi-seq decode with corrected buffer layout#74

Open
pragmaxim wants to merge 4 commits into
Avarok-Cybersecurity:mainfrom
pragmaxim-com:fix/ssm-decode-multi-seq-cleanup
Open

fix(ssm): batched multi-seq decode with corrected buffer layout#74
pragmaxim wants to merge 4 commits into
Avarok-Cybersecurity:mainfrom
pragmaxim-com:fix/ssm-decode-multi-seq-cleanup

Conversation

@pragmaxim
Copy link
Copy Markdown
Contributor

@pragmaxim pragmaxim commented May 20, 2026

Summary

Revives the disabled batched SSM multi-seq decode path with a corrected scratch-buffer layout that addresses the original aliasing bug (#6).

Original bug

The earlier batched implementation wrote conv1d output into ctx.buffers.attn_output(), which is sized for m * mamba2_d_inner * bf16 — half of what n * conv_dim * bf16 needs on Qwen3.5-A3B (conv_dim = 2*key_dim + value_dim ≈ 8K, mamba2_d_inner ≈ value_dim). For n ≥ 2 the conv1d write for seq 1+ scribbled past attn_output's end and corrupted whatever buffer followed in the arena — producing the multilingual gibberish that disabled the path.

Fix

  • Drop attn_output as the conv/GDN target.
  • Write both conv1d output and GDN output into per-seq slices of ssm_conv_out_f32 (sized m * ssm_qkvz_size * 4 = FP32 worst case). Per-seq slice is exactly ssm_qkvz_size * 4, with conv at the base and GDN appended at offset (2*key_dim + value_dim) * 4 — mirroring the single-seq ssm_forward layout.
  • All other scratch buffers (ssm_deinterleaved, ssm_gates, ssm_qkvz, moe_output) use their natural per-token strides, all sized for max_batch_tokens by BufferSizes::from_config.

Batching wins

  • Batched outer ops: rms_norm_residual and residual_add_rms_norm fire once per layer for all N sequences (saves 2*(N-1) launches/layer).
  • Per-seq inner ops: conv1d_update, gdn_decode, etc. stay per-seq — each kernel takes a single state pointer, so multi-state fan-out would need new kernels.
  • Interleaved MoE: MoE.forward writes to moe_output[0..h] regardless of which seq it processes; per-seq residual_add happens before the next MoE call (same pattern decode_batched_inner uses for non-fused K).

Safety fallback

If FP32 conv1d / GDN / gated-RMS kernels aren't loaded (e.g. Metal backend), the path falls back to per-sequence single-decode. The BF16 GDN output lives in attn_output, which can't safely host N>1 concurrent slices.

Commits

  1. refactor(ssm): drop dead batched decode path — removes the 250 lines of unreachable code.
  2. fix(ssm): batched multi-seq decode with corrected buffer layout — adds the corrected batched implementation alongside the per-seq fallback.

Test plan

  • cargo fmt --all -- --check
  • ATLAS_SKIP_BUILD=1 CUDARC_CUDA_VERSION=13000 cargo clippy --workspace --tests (passes locally)
  • Run Qwen3.5-35B-A3B with --max-batch-size > 1; confirm no multilingual gibberish on creative prompts.
  • Run Nemotron-H with --max-batch-size > 1; confirm long-context (≤8K) responses remain coherent.
  • Decode-throughput delta vs the per-seq fallback at N=4 (set ATLAS_GEMMA4_* style override to force fallback if a runtime knob is needed).
  • Verify the BF16 fallback path on a backend without the FP32 conv1d kernel (Metal smoke).

🤖 Generated with Claude Code

The unreachable branch after `return Ok(())` carried 250+ lines of the
prior batched SSM decode implementation gated behind
`#[allow(unreachable_code, unused_variables)]`. The buffer-aliasing bug
that disabled it (Avarok-Cybersecurity#6) is preserved in git history and summarized in the
doc comment, so the dead block adds noise without protecting any
behavior.

Single-sequence delegation and BF16/FP32 residual stride handling are
unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 20, 2026

All contributors have signed the CLA. Thank you!
Posted by the CLA Assistant Lite bot.

@pragmaxim
Copy link
Copy Markdown
Contributor Author

I have read the CLA Document and I hereby sign the CLA

The previous batched implementation (Avarok-Cybersecurity#6) wrote conv1d output into
`ctx.buffers.attn_output()`, which is sized for `m * mamba2_d_inner *
bf16` — half of what `n * conv_dim * bf16` needs on Qwen3.5-A3B
(conv_dim = 2*key_dim + value_dim ≈ 8K, mamba2_d_inner ≈ value_dim).
The out-of-bounds writes corrupted GDN reads on subsequent sequences
and produced the multilingual gibberish that disabled the path.

The corrected layout drops `attn_output` and instead writes both conv1d
output and GDN output into per-seq slices of `ssm_conv_out_f32`
(sized for `m * ssm_qkvz_size * 4` = the FP32 worst case). The
per-seq slice is exactly `ssm_qkvz_size * fp32`, with conv1d at the
slice base and GDN appended at offset `(2*key_dim + value_dim) * fp32`
— mirroring the single-seq `ssm_forward` layout.

Outer ops (`rms_norm_residual`, `residual_add_rms_norm`) are batched
across all N sequences in one launch. State-bearing inner ops
(`conv1d_update`, `gdn_decode`) run per-seq because each carries its
own recurrent state. The final MoE+residual is interleaved per-seq —
same pattern decode_batched uses for non-fused K — so the moe_output
slot can be safely reused between sequences.

If FP32 conv1d / GDN / gated-RMS kernels aren't loaded (Metal backend,
or any future target without the FP32 variants), fall back to the
proven per-sequence single-decode path: the BF16 GDN output goes to
`attn_output`, whose per-token sizing can't accommodate N>1 without
the same overflow.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@pragmaxim pragmaxim changed the title refactor(ssm): drop dead batched decode path fix(ssm): batched multi-seq decode with corrected buffer layout May 21, 2026
@pragmaxim
Copy link
Copy Markdown
Contributor Author

pragmaxim commented May 21, 2026

Pivoted from the dead-code removal to actually reviving the batched path with a corrected buffer layout — see the updated PR description.

The root cause of #6 was that the disabled implementation used ctx.buffers.attn_output() as conv1d output, which is sized for m * mamba2_d_inner * bf16 ≈ half of what n * conv_dim * bf16 needs on the GDN models. The new layout writes conv + GDN output into per-seq slices of ssm_conv_out_f32 (sized m * ssm_qkvz_size * 4), mirroring the single-seq ssm_forward layout.

Two commits on the branch — the first still removes the unreachable dead code as before, the second adds the corrected batched implementation. Happy to squash if you'd prefer a single landing commit.

Note: I cannot currently test it on dgx to validate the fix end-to-end. The CI clippy/fmt jobs pass. Multi-seq SSM gibberish smoke-tests on Qwen3.5-A3B and Nemotron-H are still required before merge.

@tbraun96
Copy link
Copy Markdown
Contributor

Note: I cannot currently test it on dgx to validate the fix end-to-end. The CI clippy/fmt jobs pass. Multi-seq SSM gibberish smoke-tests on Qwen3.5-A3B and Nemotron-H are still required before merge.

I'll give this a test once #63 is merged

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants