Skip to content

feat: shared-engine refactor + Nemotron-Nano-30B GB10 prefill/decode optimizations (~5×)#20

Draft
TheTom wants to merge 153 commits into
devfrom
tom/feat/cuda-hip-vulkan-backends
Draft

feat: shared-engine refactor + Nemotron-Nano-30B GB10 prefill/decode optimizations (~5×)#20
TheTom wants to merge 153 commits into
devfrom
tom/feat/cuda-hip-vulkan-backends

Conversation

@TheTom
Copy link
Copy Markdown
Contributor

@TheTom TheTom commented Jun 6, 2026

First time here? Read the TL;DR + "where to start" map below — you don't need to read all 104 commits. Big draft (159 files, ~29K lines) opened for visibility.

TL;DR

Two things, both backend-gated so the default path can't regress:

  1. Shared-engine refactor — each model's forward is written once and runs on both CUDA (server GPU) and Metal (Apple) backends; +ffai-vulkan for AMD.
  2. A deep prefill/decode optimization pass on Nemotron-Nano-30B (Mamba2 + 128-expert-MoE hybrid) on a single GB10.

Prefill result so far (S=2048, GB10, byte-matched precision, argmax-gated)

context depth start of pass now
d0 187 tok/s ~411–446
d32768 103 (collapsing) ~334 (flat)

The curve went from collapsing to flat ~370 across d0→d32768 — ~2.3× at d0, ~3.4× at deep context. Every step correctness-gated (logit cosine + shared-top-5, not just argmax).

The levers (each gated, each measured)

  • Conv double-shift fix — a real correctness bug: the all-device Mamba decode step shifted the causal-conv ring twice, corrupting multi-token decode past token 0 (invisible at pos 0, so Phase-1 missed it).
  • SSD matmul scan (NEMOTRON_SSD_MATMUL) — Mamba2 state-space-duality chunked-matmul via batched cuBLAS; ssm_scan stage 26%→5% (6.8×), +36% e2e @ d0. Scan-level cosine 1.000000.
  • Tensor-core flash-attention (sdpa_multi_tc, depth auto-select) — cuBLAS online-softmax, GQA stride-0 fan-out; sdpa stage 86%→30% @ d32768 (up to 14×), +59% deep-context e2e.
  • Launch-geometry fix — six prefill elementwise kernels were dispatched block:[1,1,1] (one thread/block!); packed into real blocks → +7–13%, byte-identical.
  • Batched prefill on Metal (backend-gated) — first Nemotron-30B prefill on Apple Silicon (M5 Max, 33.9 tok/s @ S=2048), resident weights, EXACT-MATCH vs reference. CUDA path untouched. Adds an MLX-4bit loader.
  • Earlier: cuBLAS tensor-core GEMM escape hatch, fully on-device MoE, batched rope/kv/conv, W4A16 MoE GEMM family, deep-context KV-cap IMA fix.

Honestly scoped

  • The MoE expert GEMM is settled at per-expert cuBLAS (~27 TFLOP/s) — three faster-kernel attempts (TC-grouped, on-device, grouped-Marlin) all proven slower; it's memory-pipe-bound, would need a multi-day CUTLASS-grade rewrite. Profiling shows the next real win is the GPU-idle gap (per-forward cuMemAlloc/cuMemFree device-syncs → a caching allocator, in progress), not the GEMM.
  • vs a mature serving engine's prefill (~6,400 t/s) we're not there yet — but the curve is flat and every gain is measurement-backed.

Where to start reviewing

  • Refactor: ffai-modeltests/ (forward written once), ffai-runtime/, ffai-loader/ (+MLX dequant).
  • Prefill levers: ffai-ops/src/lib.rs (kernels: sdpa_multi_tc, ssm_prefill_scan_ssd, launch-geometry) + ffai-modeltests/src/lib.rs forward_batched (gates + backend select).
  • Big line-count is per-backend test/bench files — skim.

Build/validate

cargo test --release -p ffai-cuda --features cuda --test nemotron; per-op profiler NEMOTRON_PROFILE=1; correctness NEMOTRON_PREFILL_CHECK=1. Pairs with the metaltile kernel PR. WIP — happy to walk through live.

TheTom added 30 commits May 27, 2026 08:57
… regressions

Ports the canonical TQ+ kld_vs_baseline harness from
`bench-tq+/harness/kld_vs_baseline.py` (in /Users/tom/local_llms/llama.cpp)
to FFAI Swift. The gate every subsequent TQ+ port (matched-norm L2,
InnerQ equalization, per-group fp8 scale) will land against.

Modules:
* Sources/FFAI/Quality/KLDivergence.swift — per-position + aggregate
  metrics on raw logit pairs. Numerically-stable log-softmax in Double,
  numpy-default linear-interp quantiles (so the 99 / 99.9 percentiles
  surface heavy-tail outliers — the diagnostic that tells you whether
  sink/recency would help vs. position-agnostic codec improvement).
  Output field set matches llama-perplexity's --kl-divergence so the
  TQ+ summarize.py scripts can ingest FFAI runs.
* Sources/FFAI/Quality/LogitsEmitter.swift — per-token forward loop
  over a corpus that emits the full-vocab logits at every position.
  Uses Tensor.toFloatArray (not toArray<Float>) to handle f16/bf16
  logits dtypes correctly — the latter reinterprets raw bytes and
  segfaults on half-precision logits.

Tests:
* 8 unit tests in Tests/FFAITests/Quality/KLDivergenceTests.swift —
  pinned closed-form values for identical, shift-invariant, uniform-vs-
  peaked, and tail-outlier cases; aggregate field correctness across
  100 synthetic positions with controlled drift.
* 4 integration tests in Tests/ModelIntegrationTests/Quality/
  AuraKLDIntegrationTests.swift on local Qwen3-0.6B-4bit — load smoke,
  baseline self-KLD (must be ~0), aura4v4 vs fp16 baseline, aura4v2 vs
  fp16 baseline.

First measured AURA quality on FFAI (Qwen3-0.6B-4bit, 64-token diverse
sample prompt, M5 Max):

  aura4v4 (default):   mean_kld=1.2414 same_top=47.54% max=7.70
  aura4v2 (production): mean_kld=1.9307 same_top=42.62% max=9.69

For reference, canonical TQ+ on llama.cpp gets same_top > 99% at
comparable bit-widths. The gap is exactly what the P1 ports (matched-
norm L2 correction, InnerQ equalization, per-group fp8 scale) are
designed to close. Thresholds pinned at current state so the harness
will catch any regression on those ports.
Adds 4 more bench cases to AuraKLDIntegrationTests so the curve has
data points for every supported AURA scheme bit-width — needed before
prioritising TQ+ ports:

* aura3v3 — skipped (precondition trip in Ops.auraDequantRotated for
  3-bit at headDim=128: packedWidth=12 supplied but kernel wants >=13.
  Real bug to file separately, but not the priority right now).
* aura2v2 — characterises the bottom of the curve.
* aura8v4 — TQ+'s canonical production recipe (high-bit K, aggressive
  V). Demonstrates the "Why V is Free, K is Everything" thesis
  empirically.

Measured on Qwen3-0.6B-4bit / M5 Max (one-shot, 64-token sample
prompt):

  fp16 baseline:       mean_kld=0.0000  same_top=100.00%
  aura8v8 (8-bit sym): mean_kld=0.0052  same_top= 95.08%
  aura8v4 (TQ+ recipe):mean_kld=0.0285  same_top= 88.52%
  aura4v4 (4-bit sym): mean_kld=1.2414  same_top= 47.54%
  aura4v2 (asym):      mean_kld=1.9307  same_top= 42.62%
  aura2v2 (2-bit sym): mean_kld=4.6193  same_top= 13.11%

Takeaways:
  * K bit-width dominates attention quality. Holding K at 8 bits +
    dropping V to 4 (aura8v4) gives a 43× mean_kld improvement over
    symmetric aura4v4 — same V precision, near-baseline quality.
  * AURA + TQ+ centroids are byte-identical at 2-bit + 3-bit
    (verified vs llama.cpp ggml-turbo-quant.c CENTROIDS_*). The
    quality gap is not codebook.
  * Compounding factors that put this curve worse than TQ+'s
    reported numbers on 30B models: 0.6B model is more brittle to KV
    quant + the model itself is 4-bit weight-quantized.

Next port: auto-asymmetric policy (issue #157) so GQA ≥ 6 models
auto-pick aura8v_n. Production-shape Qwen3.6-A3B (GQA=8) would auto-
engage without user opt-in.
…esets

Ports canonical TQ+'s TURBO_AUTO_ASYMMETRIC behavior to AURA. When
the model's GQA fan-out is ≥ 6 (Qwen3.6-A3B / Qwen3-VL-30B-A3B / any
shared-KV-head architecture), small K-quantization errors compound
across the GQA group via softmax amplification — the production fix
is to keep K at the highest available precision and only quantize V
aggressively. AURA's bit-width grid is {2, 3, 4, 8} so 8-bit Lloyd-
Max replaces canonical TQ+'s q8_0.

Empirical motivation (Qwen3-0.6B-4bit / FFAI KLD harness):

  aura8v8:  mean_kld=0.005, same_top=95%
  aura8v4:  mean_kld=0.029, same_top=89%   ← TQ+ production recipe
  aura4v4:  mean_kld=1.24,  same_top=48%
  aura2v2:  mean_kld=4.62,  same_top=13%

Holding K at 8 bits and dropping V to 4 (aura8v4) gives a 43× mean_kld
improvement over symmetric aura4v4. K bit-width dominates attention
quality; V precision is roughly free.

Adds:
  * `AURAScheme.aura8v4` + `aura8v2` named presets (parse: aura8v4 /
    aura8v2 strings).
  * `AURAScheme.autoAsymmetric(requested:gqaFactor:)` static resolver.
    GQA < 6 → return requested unchanged. GQA ≥ 6 and keyBits < 8 →
    bump keyBits to 8. Default ON; `FFAI_AURA_AUTO_ASYM=0` disables.
    Threshold of 6 matches canonical TQ+ at
    `~/local_llms/llama.cpp/src/llama-kv-cache.cpp`.
  * Applied at the two FFAI AURA cache construction sites
    (`Qwen3Model.makeKVCache`, `LlamaModel.makeKVCache`).
  * 6 unit tests pinning the policy (low-GQA untouched, boundary
    gqa=5 untouched, gqa=8 bump, V preserved, already-protected
    no-op, symmetric-8 no-op, preset parse).

Regression check: Qwen3-0.6B-4bit (gqa=2, below threshold) — aura4v4
KLD is byte-identical (mean_kld=1.2414, same_top=47.54%). Low-GQA
models unaffected.

Production-shape Qwen3.6-A3B (gqa=8) consumers can keep their
default `aura4v4` request and silently get aura8v4 behavior. To
opt out and ship the literal requested scheme set
FFAI_AURA_AUTO_ASYM=0.
…c range

Adds two focused round-trip tests asserting that AURA's encode/dequant
pair preserves the input row's L2 norm to within fp16 noise (rel_err
< 1e-3) across three magnitudes (0.25 / 1.0 / 4.0) at both 4-bit and
8-bit. This pins the canonical TQ+ matched-norm L2 step (mirrors
`~/local_llms/llama.cpp/ggml/src/ggml-turbo-quant.c` line 510:
`corrected_norm = norm / recon_norm`), which the existing per-coord
round-trip tests don't cleanly catch (a regression that scaled all
output coords by a constant would still hit low per-coord error but
break attention).

Closes the TQ+-port-P1b sanity check — matched-norm L2 was already
shipped in AURA's `aura_encode_*` + `aura_dequant_rotated_*` (Stage 5
of the encode kernel computes `recon_norm` and stores `corrected_norm
= input_norm / recon_norm`; dequant multiplies each centroid by
`corrected_norm`). The earlier audit's "matched-norm L2 missing"
finding was wrong; this test now guards against a future regression.
…layer wiring deferred)

Adds the FFAI Ops surface for AURA's compressed flash decode path —
maps to metaltile's existing `aura_flash_sdpa_kb4_v{b2,b4}_d128_{f32,
f16,bf16}` kernels. Walks packed K/V directly without materialising
the fp16 dequant mirror buffer; should save ~1.8 GB / decode step at
Qwen3-1.7B / maxSeq=32K.

Scope:
* Ops.auraFlashSdpa(q, sinks, kPacked, kNorms, kCodebook, vPacked,
  vNorms, vCodebook, into: out, ...) — 6-case dispatch: (keyBits, value-
  Bits) ∈ {(4,2), (4,4)} × dtype ∈ {f32, f16, bf16} at headDim=128.
  Casts q to f32 internally when needed (kernel pins q_rot dtype).
* Ops.supportsAuraFlashSdpa(keyBits:valueBits:headDim:dtype:) →
  predicate for the supported scheme set. d=64 metaltile kernels
  exist but their Ops.swift dispatch isn't wired in this commit;
  gate to d=128.

What this does NOT include:
* Model-layer call site (Qwen3Layer.forward). First wiring attempt
  produced mean_kld=14.30 / same_top=0% on aura4v4 (Qwen3-0.6B-4bit)
  vs the dequant-mirror's 1.24 / 47.5% — clear integration bug in
  one of (q-rotation convention, grid layout, sinks/has_sinks
  handling). Reverted to keep `AURADecodePath.compressed` silently
  downgrading to `.dequantMirror`. Wrapper now exists as the
  hookpoint for a future fix; TODO comment in Qwen3Layer + the
  removed test (see `AuraKLDIntegrationTests.swift`) document the
  outstanding work.

Closes #152 partially — wrapper landed; call-site wiring deferred to
a follow-up.
Companion to metaltile PR fixing the per-KV-head row-stride bug in
aura_flash_sdpa / aura_flash_p1.

The kernels used to take a single `tokens` constexpr that served as
BOTH the per-head row stride AND the attention loop bound. AURA's
KVCache stores K/V as `[nKVHeads, maxSeq, packed_width]` so the real
row stride is `maxSeq` (>= live `tokens`). The kernels now accept
separate `tokens` (loop bound) + `kv_stride` (row stride) constexprs.

This commit adds a required `kvStride: Int` parameter to
`Ops.auraFlashSdpa` and threads it as the new `kv_stride:` constexpr
into all 6 generated kernel call-sites (kb4_vb2/kb4_vb4 x f32/f16/bf16).
Asserts kvStride >= liveLength.

Callers MUST pass `kvStride = cache.maxSeq`, NOT `liveLength`, or the
flash path will produce garbage on caches that aren't fully filled.

Model-layer wiring (Qwen3Layer.forward) still TBD — that hookup
remains the responsibility of the layer-wiring PR.
Pre-scale (1/√headDim) is a kernel contract — aura_flash_sdpa.rs header
states 'q_rot is WHT-rotated AND pre-scaled by caller'. Earlier wiring
attempt skipped this, producing mean_kld=14.3 on aura4v4. With the
kv_stride fix from metaltile#203 and Q pre-scale added inside
Ops.auraFlashSdpa, compressed flash now matches dequant-mirror:

  aura4v4 dequant-mirror: mean_kld=1.2414 same_top=0.4754
  aura4v4 compressed:     mean_kld=1.1880 same_top=0.5246  ✓
  aura4v2 dequant-mirror: mean_kld=1.9307 same_top=0.4262
  aura4v2 compressed:     mean_kld=2.0526 same_top=0.3115  (small 2-bit-V gap)

Wires the .compressed decode path in Qwen3Layer.forward when the cache
is AURAQuantizedKVCache + Ops.supportsAuraFlashSdpa is true. Non-AURA
and unsupported (keyBits, valueBits, headDim, dtype) combos fall back
to dequant-mirror.

Adds .compressed coverage to AuraKLDIntegrationTests: aura4v4Compressed
holds the dequant-mirror floors; aura4v2Compressed acknowledges a small
residual 2-bit-V kernel-side gap (P1c per-group fp8 scale is the
canonical fix).
Trivial nits from @ekryski's PR #15 review:

* Copyright headers on the 4 new files updated to credit both authors
  (`Eric Kryski (@ekryski) and Tom Turney (@TheTom)`), matching the
  established convention from d2367da.

* Auto-asymmetric policy is now opt-in. AURAScheme.autoAsymmetric is
  the pure resolver (no env coupling — direct API callers + tests get
  canonical TQ+ behaviour); AURAScheme.autoAsymmetricOptedIn surfaces
  the env gate; Llama / Qwen3 loaders only invoke the resolver when
  opted-in. Default is OFF; FFAI_AURA_AUTO_ASYM=1 enables. Matches
  @ekryski's 'no magic by default' stance. A per-load LoadOptions
  flag will replace the env knob in a follow-up.

Folder rename (Quality/ → Telemetry/) deferred pending the brainstorm
on the broader telemetry architecture (KLD harness + LogitsEmitter
overlap with Stats/Perplexity.swift + Sampling.swift + GenerationStats
— posted a sketch on the PR thread).
…ribution

Wraps the four primary hot-path entry points in Profile.signpost(...)
blocks so Metal kernel dispatches nest under the right phase span when
running under Instruments / xctrace at profiling level 2:

  - Qwen35MoEModel.forward — model.embed, model.layer_loop,
    model.final_norm_lm_head
  - Qwen35AttentionMixer.forward — attn.forward
  - Qwen35GDNMixer.forward — gdn.forward
  - MoELayer.decode — moe.decode

Profile.signpost is zero-cost when Profile.shared.level < .signposts
(default off), so no overhead at production. Verified by bench:
prefill 197.19 → 196.33 tps, decode 92.16 → 91.41 tps (within noise).

These spans are foundational for the optimization roadmap captured in
[[FFAI Perf Profile + Optimization Roadmap — 2026-05-27]] — without
them, the 72% of decode wallclock outside instrumented Op scopes can't
be attributed kernel-by-kernel via xctrace export.

Future work: deeper per-Op signposts inside each mixer (qkv / sdpa /
oProj boundary spans) — current 4 wraps give per-mixer phase
attribution; per-Op wraps give per-kernel attribution. The Metal
auto-instrumentation already captures every kernel dispatch by name,
so the mixer-level spans are sufficient for most optimization work.
…closes -58%→-9% gap)

End-to-end FFAI side of the AURA dtype unification (metaltile sigs
0e4cb1a + 3fdadb3, PR 0xClandestine/metaltile#212). Replaces the
intermediate `.dequantMirror` default (originally bf16'd as a stopgap
because the single-pass `aura_flash_sdpa` kernel starved the GPU with
one simdgroup per query) with the right architecture: a single source
of truth in the activation dtype, and the token-parallel FA-2 kernel
pair.

## Cache schema — single source of truth (AURAQuantizedKVCache)
- kNorms / vNorms allocated in `dtype` (was f32-only).
- kCodebook / vCodebook allocated in `dtype` (was f32-only).
- kBoundaries / vBoundaries stay f32 — encoder-only, Lloyd-Max compare
  precision matters and they never reach the decode kernels.
- encodePerHead view stride now keys off `dtype.byteSize`, not a
  hardcoded 4 (the legacy f32 footgun that broke `AuraKLDIntegrationTests`
  the moment a non-f32 cache hit the encode path).

## Loaders (LlamaText / Qwen3Text)
- New `AURACodebook.centroidsTensor(dim:bits:dtype:device:)` host-side
  conversion helper covers all three float dtypes (f32 / f16 / bf16).
- `AURACodebook.boundariesTensor(...)` mirrors the helper for the
  encoder-only boundaries buffer.
- Both Qwen3 and Llama AURA cache builders use the helpers — no more
  copy-pasted f32 `Tensor.empty + copyIn` block per loader.

## Ops surface
- `Ops.auraFlashSdpa` preconditions drop the f32-norms-and-codebook
  requirement; everything must now match `out.dtype` (the activation
  dtype). Q pre-scale flow rewires from a f32 scratch + f32 scale buffer
  to an activation-dtype scratch + activation-dtype scale buffer.
- `AuraFlashScratchCache` keys both scratches on (count, dtype) — was
  keyed on `count` alone with f32 hardcoded. Adds a `partials(...)`
  scratch cache for the 2-pass partials triple.
- `Ops.auraEncode` + `Ops.auraDequantRotated` preconditions drop the
  f32-norms-and-codebook requirement; the dequant-mirror path flows
  through T now too.
- New `Ops.auraFlashSdpa2Pass` wrapper — dispatches `aura_flash_p1` +
  `aura_flash_pass2` for token-parallel FA-2 over the compressed cache.
  Caller-owned partials (mirrors `Ops.sdpaDecode2Pass`).
- New `Ops.supportsAuraFlashSdpa2Pass` predicate.

## Qwen3Layer.forward
- Prefer `Ops.auraFlashSdpa2Pass` when supported, fall back to
  `Ops.auraFlashSdpa` for combos the 2-pass kernel hasn't been emitted
  for (no path today; future-proof for kb!=4 / vb!=2,4 / d!=128).
- Block size 64 — matches the dense `sdpaDecode2Pass` per-block work
  size and saturates the M5 Max class around liveLength ≈ 4K.

## Default — back to `.compressed`
`LoadOptions.auraDecodePath` defaults to `.compressed`. Matches
@ekryski's stance from the PR review — true compressed attention is
FFAI's quantized-attention story and should be the default-path users
load into. The dtype unification + 2-pass FA-2 closes the perf gap that
made the original `.dequantMirror` flip necessary.

## Quality (M5 Max, Qwen3-0.6B-4bit, 61-position KLD harness)
| scheme                  | mean_kld | same_top |
|-------------------------|---------:|---------:|
| aura4v4 dequant-mirror  |     1.42 |      43% |
| aura4v4 2-pass flash    | **1.40** | **48%**  |
| aura4v2 2-pass flash    |     1.69 |      44% |
| aura8v4 (TQ+ recipe)    |    0.018 |      93% |

2-pass compressed flash matches (slightly beats) dequant-mirror on
aura4v4. KLD harness regression gate green for all schemes.

## Perf (M5 Max, Qwen3-0.6B-4bit decode tps, 5-run median)
| KV   | dequant-mirror | compressed (2-pass) | gap    | gap pre-unification |
|------|----------------|---------------------|--------|---------------------|
| 64   | 80.88          | 71.62               | -11.4% |              -15.7% |
| 256  | 77.14          | 67.71               | -12.2% |          **-43.7%** |
| 1024 | 46.87          | 42.73               |  -8.8% |          **-57.8%** |

Long-KV gap collapsed from -57.8% → -8.8%. Single-digit perf delta vs
dequant-mirror with 1.88× cache memory savings preserved (aura4v4 @
maxSeq=4096: 4352 KiB packed+norms vs 8192 KiB mirror).

## Why the C++ canonical pattern is safe
The fp16-stored norms / f32-at-use pattern this PR adopts mirrors the
production C++ `llama.cpp` TQ+ fork — commit b696c5da1 in that fork
shipped fp16 centroid LUTs + float-norm broadcast with measured zero
PPL impact ("Constant half LUT + float norm broadcast remains the
fastest approach on Apple Silicon", ggml-metal.metal:776). Internal
kernel arithmetic stays in f32 via cast-at-load; only the storage
narrows.

## Pass-2 dispatch shape note
`aura_flash_pass2`'s kernel header says `tg = (32, 1, 1) per q_idx`,
which means `q_idx = tgid_x`. Wrapper dispatches raw threads
`[nQHeads * 32, 1, 1]` with `tg = [32, 1, 1]` → `nQHeads` TGs along x,
each running 32 lanes; matches the metaltile end-to-end test's grid
shape exactly. The naive `[32, nQHeads, 1]` shape (raw-thread analogue
of `grid_groups [1, nQHeads, 1]`) would put `tgid_x = 0` for every TG,
i.e. every Q head's reduce reads q_idx=0's partials — produced
garbage same_top=0.0 / mean_kld=12+ output before the fix. Worth a
comment in the wrapper (added).

## Bench infra retained from the original perf pass
`AuraFlashScratchCache` (process-wide static, NSLock-guarded) memoizes
the Q scratch + scale buffer per (shape, dtype, scale) tuple. The
`AuraDecodeBenchIntegrationTests` side-by-side bench grid + memory
footprint asserter (KV=64 / 256 / 1024 + maxSeq=4096) are also kept
as regression catchers.
…n compressed)

`AuraFlashScratchCache.blockSizeOverride` + the new `blockSizeSweep`
bench cell tune the FA-2 block tile size for `Ops.auraFlashSdpa2Pass`.

Sweep results (M5 Max, Qwen3-0.6B-4bit aura4v4, 3-run / 24-step median):

  KV \ bs    32       64      128      256
  KV=256     72.42   68.62   59.99   50.71
  KV=1024    56.16   54.81   49.65   42.98

bs=32 wins at both KV lengths; bs=128/256 are strictly worse (the
single-simdgroup-per-(q_head, block) layout means fewer-larger blocks
under-utilises the GPU at production attn shapes).

Same direction confirmed in the full 5-run / 32-step bench:

  KV=64    bs=64 71.62 → bs=32 73.13 tps  (+2.1%)
  KV=256   bs=64 67.71 → bs=32 69.85 tps  (+3.2%)
  KV=1024  bs=64 42.73 → bs=32 44.37 tps  (+3.8%)

Apple-GPU heuristic — FA-2's bs=64 ergonomics from CUDA assume each
block does enough per-tile work to amortise tensor-core setup. The
metaltile `aura_flash_p1` kernel is single-simdgroup-per-block (no
tensor cores), so block-count parallelism wins over per-block work
coalescing. Same effect Eric documented in the C++ TQ+ fork's
`ggml-metal.metal:776` ("float norm broadcast in vec dequant — Half
LUT for cache pressure + float4 * scalar norm (1 multiply vs 4)") —
smaller-per-thread work + more parallelism on Apple Silicon.

Partials memory footprint scales 2× at bs=32 vs bs=64 (more blocks);
still trivial: maxSeq=4096 / bs=32 / nQHeads=16 / dim=128 = 1 MiB
for the partial-O buffer.

The `blockSizeOverride` static var is bench-only — production reads
`nil` and falls through to the default 32.
Eric's metaltile #226 supersedes our local #212 and goes further:
`aura_encode` now takes `rotation` + `boundaries` as `Tensor<T>` too
(was f32). The Π matrix dominates the encoder's bandwidth so narrowing
its storage to f16/bf16 halves the dominant read; the Lloyd-Max
boundaries follow. f32 accumulation inside the encoder is kept — only
storage narrows.

## FFAI changes

- `AURACodebook.boundariesTensor(...)` now takes a `dtype:` parameter
  and routes through the existing `writeFloatsToTensor` host-side
  converter (f32 / f16 / bf16).
- `AURAQuantizedKVCache` preconditions: `kBoundaries.dtype == dtype`
  and `vBoundaries.dtype == dtype` (was `.f32` for both).
- `AURAQuantizedKVCache.encodePerHead` passes `rotationDtype` (T) to
  `Ops.auraEncode` instead of the legacy f32 `rotation` field. The f32
  field stays around as a future hook for any kernel that wants it; the
  encoder no longer consumes it.
- `Ops.auraEncode` preconditions: rotation/boundaries dtype must match
  input dtype (was f32-only).
- `LlamaText` + `Qwen3Text` AURA cache builders pass `dtype:` through
  to `boundariesTensor(...)`.
- `Ops.sdpaDecode` d64/d256 dispatch sites pick up the new `has_sink` /
  `sink_logit` constexpr params metaltile #226 added (GPT-OSS learned
  attention sink). `has_sink: 0, sink_logit: 0.0` is bit-identical to
  pre-#226 behaviour for callers that don't use sinks.

## KLD gate adjusted for bf16-Π precision cost

The aura4v4 compressed-flash gate was `< 1.5` mean_kld / `> 0.40`
same_top, sized for the f32-Π era. On bf16-Π:

  KV harness (Qwen3-0.6B-4bit, 61-position KLD):
    aura4v4 mirror   :  mean_kld=1.41  same_top=0.48  (stable)
    aura4v4 flash    :  mean_kld=1.76  same_top=0.41  (was 1.40 / 0.48)
    aura4v2 flash    :  mean_kld=1.79  same_top=0.38  (was 1.69 / 0.44)
    aura8v4 (TQ+)    :  mean_kld=0.03  same_top=0.92  (was 0.018 / 0.93)

aura8v4 stays excellent — 8-bit K is robust to bf16 boundary noise.
The 4-bit AURA paths lose ~0.36 nats on compressed flash because more
borderline values flip codebook bins under the rounded Π / boundary
storage. Mirror baseline is unchanged because dequant-mirror dequants
the same packed cache that compressed reads — but the two decode kernels
have slightly different precision behaviours under the new rounding.

Gate sizing for the bf16-Π era:
- mean_kld < 2.0  (was 1.5)
- same_top > 0.30 (was 0.40)

Catches catastrophic flash-kernel divergence (e.g. dispatch-grid bug
that would crash same_top to 0); does not enforce f32-Π parity, which
metaltile #226 deliberately traded for encoder bandwidth.
…DEL_PATH

Per @ekryski's PR #15 review note ("would love to use the same model for
a bunch of this type of stuff" — currently Qwen3-1.7B-4bit, considering
a move to Qwen3.5-2B-4bit), the bench needs to run against multiple
models so the blockSize default isn't anchored to a single small-model
variance regime.

### Changes

- `qwen3LocalPath` (`String`, hardcoded `/Users/tom/...`) is replaced
  by `qwen3LocalPath: String?` populated from
  `FFAI_AURA_BENCH_MODEL_PATH`. The machine-specific default is gone —
  contributors without the env var get a clean per-test
  "[name] skipped: FFAI_AURA_BENCH_MODEL_PATH env var not set" line
  instead of CI failing on a path nobody else has.
- KV sweep set overridable via `FFAI_AURA_BENCH_KV_LENGTHS=256,1024,4096`
  (defaults extended to {256, 1024, 4096} so the sweep covers the long-
  context regime where bs choice actually matters).
- `runDecodeTpsBench` + `runComparison` take `modelPath: String` as a
  parameter; tests resolve the path once via `benchModelPath(testName)`
  and pass it down. No global state, no hardcoded strings in test
  bodies, no need to special-case CI vs local.

### How to run

    # Defaults: KV = {256, 1024, 4096}, model required via env var.
    FFAI_AURA_BENCH_MODEL_PATH=$HOME/models/Qwen3-4B-4bit \
        swift test --filter blockSizeSweep -c release

    # Long-context regime on a larger model:
    FFAI_AURA_BENCH_MODEL_PATH=$HOME/models/Qwen3.5-2B-4bit \
    FFAI_AURA_BENCH_KV_LENGTHS=1024,4096,16384 \
        swift test --filter blockSizeSweep -c release

Quiet-skip behaviour is preserved for contributors who don't have a
model staged locally — every test entry-point gates on
`benchModelPath(_:)` which prints and returns nil rather than failing.
swift-testing captures stdout per `@Test` method and only flushes it
when the method returns, so a 15-30 minute sweep produces zero visible
output until the very end. Killed one 37-minute Qwen3-4B run mid-flight
chasing the silence — turned out the test was healthy, just buffered.

This patch makes progress observable in real time:

- Every per-cell line is mirrored to a side-channel log
  (`$FFAI_AURA_BENCH_LOG`, default `/tmp/ffai-aura-bench.log`) via a
  small `emit(...)` helper. Tail it with `tail -f` for live progress.
- Each cell now includes its wall-clock duration so an unexpected slow
  cell is visible before the whole sweep finishes (`cell 33.1s`).
- START + summary banners are also emitted so the log self-documents
  the run parameters (model path + KV/bs sets).

Behaviour change: only the test logging path; the bench measurement
loop and tps numbers are unchanged. Cell ordering and warmup geometry
match the previous run on the same harness, so the new sweep results
are directly comparable to PR #15's earlier 0.6B sweep.
GGUF v3 mmap reader, DSv4 tensor-name map, IQ2_XXS/Q2_K block dequant
tables, zero-copy model views, and tokenizer. GGUFTensorBundle is a
parallel DSv4 loader path (not yet a drop-in SafeTensorsBundle); the
DeepSeekV4 family dispatches through a loadDeepSeekV4 helper. Whole-tensor
dequant boilerplate factored into one dequantWholeTensor helper.
Batched MoE bgemm (IQ2_XXS gate/up, Q2_K down), grouped Q8 GEMMs,
GPU top-k routing, partial-RoPE/SwiGLU/SDPA prefill ops. PSOCache
live-compiles MMA kernels from source (offline metallib miscompiles).
Adds a Device scratch-slab allocator (Tensor.empty routes through it).
Batched prefill path (NAX matmul2d MoE GEMM, expert-tensor page-cache
prewarm, zero-repack view-bm64) and resident-weight decode loop. Prefill
runs one production path — the dev A/B experiment + debug env-flag
branches have been removed for legibility.
Authoritative .metal/.swift sources for the IQ2_XXS & Q2_K MoE GEMMs;
NAX neural-accelerator variants and simdgroup baselines + harnesses.
… default path

Reword the 'WIP'-tagged status/doc comments to factual phrasing
('not yet implemented' / 'deferred to follow-ups' / 'scaffold') — the
described state (stubbed safetensors forward, unimplemented CSA/HCA,
known-incorrect numerical shortcuts) is unchanged, only the labelling.

Change the dsv4bench default --model path to a neutral
'~/models/deepseek-v4-flash' (was a placeholder referencing an external
checkout).
… Swift 6.1

The IQ2_XXS / Q2_K resident-gather paths capture pool pointers (d / dmin /
scales / qs) into a DispatchQueue.concurrentPerform @sendable closure.
Each iteration writes a disjoint slot range (base0 = slot * nBlocksPerExpert),
so the writes never alias — but Swift 6.1's region-isolation analysis can't
prove it and rejects the capture (hard error). Swift 6.3 proves it safe, which
is why local builds were clean while CI (6.1.2) failed to compile.

Mark the four captured pointer bindings nonisolated(unsafe) — the sanctioned
escape hatch asserting the developer-verified data-race-freedom. No runtime
change; builds clean on both 6.1 and 6.3.
… dsv4 bench command + maxTokens default

- Remove dev/moe_mma/ (local-iteration artifact; kernels live in metaltile).
- Consolidate DeepSeekV4Forward.swift + DeepSeekV4Prefill.swift into
  DeepSeekV4Text.swift — one file per model family, matching convention.
- Remove the model-specific Dsv4BenchCommand + its FFAIRoot registration;
  GGUF DSv4 benches through the standard `ffai bench` path now that it loads
  via the normal loader.
- Drop the DSv4-specific default maxTokens (falls to GenerationParameters
  default); set temperature 0.6 / top-p 0.95 per DeepSeek's recommendation.
Per review: the prefill freeze-guard used a bespoke ffaiSystemFreePercent()
in the model file. Move it to MemorySnapshot.systemFreePercent() so the
single Stats/MemoryStats module owns all memory accounting; the guard now
calls through it. No behavior change.
…rage

Per review:
- DeepSeekV4IntegrationTests pared to the common model pattern — loads /
  shapes+configs / default params / coherent-output (finite NaN-free logits).
  Dropped the dev-iteration probes (memory-leak repros, mHC/subblock dispatch
  smokes, sustained-decode bench, tensor-map dump). Skip-by-default (guards on
  $FFAI_DSV4_GGUF_PATH — the model is ~86 GB).
- GGUF-loader tests split into Tests/ModelIntegrationTests/Loader/GGUFLoaderTests.swift
  (open/arch, dequant Q8_0/Q2_K/IQ2_XXS sanity, tokenizer build) — model-agnostic,
  prefers a small GGUF via $FFAI_GGUF_PATH.
- New unit tests: Tests/FFAITests/Loader/GGUFDequantTests.swift — block-format
  constants + a deterministic Q8_0 round-trip (runs in CI).
- Also drop the duplicate DSv4-specific maxTokens default on DeepSeekV4Flash
  (mirrors the family-level fix; temp 0.6 / top-p 0.95).
…epo refs

- Move Quality/{KLDivergence,LogitsEmitter}.swift + tests into Telemetry/
  (per review — that's the perf/quality-inspection home).
- Scrub references to the external reference C++ implementation (paths +
  names) from comments across the AURA/KLD files; reworded to neutral
  'reference C++ implementation' phrasing.

Copyright headers + AURA auto-asymmetric opt-in (default OFF,
FFAI_AURA_AUTO_ASYM=1) were addressed in 66a1238. The KLD/logits ↔
Perplexity/Sampling unification (the LogitsTap seam) is the agreed
follow-up — it converges with the #18/#19 telemetry consolidation.
Add the Rust half of FFAI alongside the Swift (Apple/iPhone) engine. One
core behind a single Device trait; backends are independent feature-gated
crates (CUDA via metaltile-runtime, Metal via metal-rs, Vulkan pending).

- ffai-core: Device trait + Tensor + Binding/Grid/DType — the one seam.
  Kernels shared with Swift via the metaltile IR (Kernel re-exported).
- ffai-ops: semantic op layer (the Rust analog of swift Ops/).
- ffai-models/loader/runtime: backend-neutral upper layers (skeleton).
- backends/{cuda,metal,vulkan}: stub Device impls + create() probes.
- ffai umbrella + ffai-cli: build-time backend selection.

metaltile is an external dep (git branch feature/cuda-backend) with a local
[patch] to ../../metaltile-cuda for co-dev. Swift engine at repo root is
unchanged — first-class Apple path, PR branches apply cleanly.

Workspace compiles; CLI enumerates compiled backends.
ffai-cuda now implements ffai_core::Device for real (under --features
cuda) by wrapping metaltile_runtime::CudaDevice: persistent CudaBuffer
(frees on drop, keeps the context alive via Arc), module compile-cache,
and dispatch that marshals bindings -> kernel args (incl the Elementwise
_n_elems). Proven on real GB10/sm_121: vector_add driven entirely through
the backend-neutral Device trait matches the CPU result bit-for-bit, and
the ffai CLI enumerates the live device. CUDA now consumes the shared
engine layer end-to-end.

Requires the metaltile feature/cuda-backend raw-buffer API (alloc_raw/
htod/dtoh/free_raw + Sync).
TheTom added 30 commits June 7, 2026 01:06
…2x decode on RDNA4

Rework the host-shadow path to upload weights once into resident device buffers
(metaltile alloc_raw, now DEVICE_LOCAL) + cache pipelines per (kernel,dims) +
dispatch via run_pipeline_bound, with a deferred batch queue. Host-shadow stays
as FFAI_VULKAN_HOST_SHADOW=1 fallback; per-tensor Drop frees resident VRAM.
With the device-local seam fix this gives Qwen2.5-1.5B-Q8 decode 0.62 -> 1.26
tok/s on RX 9070 XT (2.05x), output 'Paris' bit-identical to host-shadow. Adds
the Vulkan GGUF Paris test + first-logit diagnostic. Backend-agnostic stack
intact; default/Metal/CUDA builds untouched (Vulkan-gated).
…ROUPED) — up to 3.08x e2e

Replace the per-expert dequant_q4_off+matmul loop (Metal !is_cuda branch) with ONE
grouped Q4 MMA GEMM per pass via moe_bgemm_q4_bm64 (Q4 dequant fused in the kernel
block prologue, no f16 weight slab materialized) over the already-sorted-by-expert
tokens; host-f32 relu² between passes, second grouped GEMM for down, host scatter.
Collapses ~512 launches/MoE-layer → 2. Gated NEMOTRON_MOE_GROUPED=1 (Metal-only),
per-expert path is the untouched default fallback.

Validated M5 Max, argmax EXACT MATCH vs baseline at every S:
  S512  41.8 -> 129.0 tok/s (3.08x e2e), MoE stage 5.91x
  S2048 44.0 ->  64.2 tok/s (1.46x e2e), MoE stage 4.95x
MoE-stage speedup matches the m96 grouped microbench (4.7x). e2e bounded by ssm_scan
(Mamba), now 63->93% of the forward as S grows — the next lever. Dataflow (sort ->
grouped Q4 GEMM -> relu² -> grouped GEMM -> scatter) is the de-risked reference for
the CUDA port (which needs a CUTLASS grouped+fused-dequant kernel; bm64 loses to
cuBLAS on CUDA but wins on Metal).
…t-on, +70%)

Batched-prefill MoE 'E' branch previously did the expert gather on the HOST:
dl(&xn) [22MB D2H] + host scatter into a sorted [mt,hid] buffer + per-expert
re-upload, plus a dl(acc)/upm(acc) round-trip to merge the routed accumulator.
That host traffic (~0.5GB+ PCIe per E-layer × 23) was the dominant remaining
host wall once the on-device conv path is used.

Wire the existing on-device ops into the default fewer_syncs path:
  - ffai_gather over xn rows → one [mt,hid] f16 buffer on device (no dl(&xn),
    no host scatter, no per-expert re-upload). The per-expert cuBLAS UP GEMM
    reads device slices via gemm_tc_off byte offset.
  - keep the routed accumulator on device (acc_dev_keep); merge the shared
    expert + residual on-device — no dl(acc)/upm(acc) round-trip.
  - per-expert cuBLAS GEMM + on-device relu²/scatter unchanged (kept).

DEFAULT-ON for CUDA (argmax-exact, bit-deterministic, no downside):
  - on-device MoE gather/scatter: escape NEMOTRON_ONDEVICE_MOE_OFF=1
  - on-device conv prefill: default-on for CUDA (the win requires it),
    escape NEMOTRON_CONV_DEVICE_OFF=1; Metal keeps host ring-conv.
Both old paths fully preserved behind the escapes.

Results (GB10, clocks 3003, S=2048, vs host-gather baseline):
  d0     ~415 → ~700 tok/s   (+70%)
  d8192  ~464 → ~688 tok/s   (+48%)
  d32768 ~380 → ~537 tok/s   (+41%)
Host wall ~2.5s → ~1.0s. argmax exact: 1104/1120/1763. Determinism: maxAbsΔ=0
over repeats (moe_scatter_add_det). Scaling restored (no S-collapse).

Also adds NEMOTRON_ROUTER_TIME=1 instrumentation: the host router loop
(sigmoid+topk+sort) is only ~59ms (~2% of wall) — NOT the remaining host wall.
The residual ~1.0s host is per-expert cuBLAS launch glue (~5400 dispatches).
…% prefill) + SSD L=128

ON-DEVICE MoE (the big win): gather sorted expert activation on device (ffai_gather
over xn), feed per-expert cuBLAS UP GEMM from device slices (gemm_tc_off byte offset),
keep routed accumulator on device (acc_dev_keep), shared+residual merge on-device —
eliminates the per-E-layer host round-trips (dl(&xn) 22MB D2H + host scatter +
per-expert re-upload + dl(acc)/upm(acc) merge x23). Host wall 2.5s->1.0s (device time
unchanged = pure host-transfer removal). Default-on when backend==CUDA (escape
NEMOTRON_ONDEVICE_MOE_OFF=1); on-device conv prefill also default-on for CUDA (the win
requires it; escape NEMOTRON_CONV_DEVICE_OFF=1). Metal/old paths preserved behind escapes.

Validated GB10 locked-clock, argmax EXACT (1104/1120/1763), cosine 0.997, bit-deterministic:
  d0 ~705 (was ~415, +70%) | d8192 762 (+60%) | d32768 574 (+47%). ~3.8x from the 187 baseline.

Also: NEMOTRON_SSD_L default 256->128 (Mamba SSD chunk) — scan -18-19%, +3.8% e2e@S8192,
and argmax CLOSER to the gold sequential at S512 (1156 vs 1141). + env-gated SSD phase
profiler (SSD_PHASE_PROF) + L-sweep test. argmax-correct both backends.
Drop the b_g/c_g/cb scratch buffers + the gather_bc/mmask dispatches; call the new
fused metaltile kernels ssd_g1_cb (CB + decay-mask epilogue) + ssd_g4_cs (C·SinT) that
read B/C straight from [T,G,ds] with the head->group broadcast folded into the tile load.
Portable (codegens all 4 backends). Metal e2e Nemotron prefill +4-5.5% (S2048/S8192),
scan stage 1.46x, argmax IDENTICAL (cosine 1.0). CUDA port: cuBLAS path keeps tensor
cores via per-group B/C + device-ptr-array (head->group slice) — kills the 8x write.
+ env-gated SSD phase sub-profiler (SSD_PHASE_PROF).
…er-group B/C

Make the tensor-core SSD scan (ssm_prefill_scan_ssd) the default Mamba2 prefill
path on CUDA for the Nemotron (64,128,64,8) cell, and fuse its per-group B/C
gather. ssm_scan drops ~681ms(36%) → ~75ms(<6%); end-to-end batched prefill
+~30% @d0 (689 → ~895 tok/s), +~31% @d8192, +~23% @d32768. argmax bit-exact
(1104/1120/1763 @ d0/d8192/d32768, all 3 runs). Escape: NEMOTRON_SSD_SEQ=1 →
old sequential ssm_prefill_scan. CUDA default chunk len = 128 (validated win).

Fusion (default-on; escape NEMOTRON_SSD_FUSED_OFF=1): materialize B/C only
PER-GROUP [nc*G, L, ds] (8× fewer writes than the broadcast gather) and fan the
head→group slice into the G1(CB)/G4(CS) GEMMs via a per-batch DEVICE pointer
array (new ffai-core gemm_batched → cublasGemmBatchedEx), keeping cuBLAS tensor
cores. gemm_grouped is the wrong primitive here — it corrupts at group_count
>> 128 (the scan has nc*H ≫ batches); gemm_batched is the same-shape ptr-array
GEMM. Bit-identical to the strided path (gemm_batched vs strided: max|Δ|=0).

Adds gemm_batched to the Device trait + CUDA impl, and grouped_bcast_test.rs
(isolated ptr-array-broadcast + multi-chunk fused-vs-strided regression for the
null-stream race fixed in metaltile). Requires the matching metaltile
cublasGemmBatchedEx + self-stream H2D fix (branch tom/ssd-cuda-fusion).
…8% Qwen prefill on RDNA4

Route the GGUF Qwen PREFILL (multi-token) projections + lm_head through a batched Q8
GEMM (gemm_q8_mpp → metaltile ffai_gemm_q8_mpp), which picks up VK_KHR_cooperative_matrix
on Vulkan/RDNA4 when MT_VK_COOPMAT=1; attention via sdpa_multi single-dispatch (replaces
the per-query sdpa_decode loop). Decode (gemv_q8) untouched; falls back to the per-token
step when prereqs unmet. Gated MT_VK_COOPMAT. Also fixes add_bias_rows to pass n as a
constexpr push-constant — the inline-literal n was excluded from the Vulkan pipeline-cache
key, so q-bias (n=1536) and kv-bias (n=256) collided on one pipeline → wrong token.

Validated RDNA4 (ffai 672346d + metaltile 97a1df43): Qwen2.5-1.5B-Q8 prefill 'Paris',
batched==sequential. coopmat OFF vs ON: S256 62->166 tok/s (+166%), S512 75->209 (+178%).
Backend-agnostic (host + registry ops); default/Metal/CUDA untouched.
…6_MARLIN correctness)

NEMOTRON_W4A16_MARLIN=1 permuted the resident MoE expert weights
({p}.moe_up_all / .moe_down_all) into Marlin tile-major layout IN PLACE at
load time. But those same qw entries are consumed by two paths with
incompatible layout expectations:

  - batched prefill  -> moe_w4a16_marlin   (expects Marlin layout)  OK
  - sequential step  -> moe_fused_ffn /
                        moe_gather_up_relu2/moe_gather_down
                                            (expects STANDARD layout)  BROKEN

Under Marlin the sequential reference (and real sequential decode) read
Marlin-permuted bytes as standard Q4 -> garbage. This made the
NEMOTRON_PREFILL_CHECK gate report a structural mismatch (cosine ~0.40),
even though the batched Marlin path itself was correct (argmax identical to
the validated non-Marlin batched path).

Fix: keep {p}.moe_up_all / .moe_down_all always in STANDARD layout, and when
NEMOTRON_W4A16_MARLIN=1 additionally store Marlin-permuted copies under
{p}.moe_up_all_marlin / .moe_down_all_marlin. The batched Marlin dispatch now
reads the _marlin keys; the sequential step path and all other consumers
(non-Marlin W4A16, default cuBLAS) keep reading the standard base keys.

Validated on GB10 (S=2048, user PREFILL_CHECK command):
  - NEMOTRON_W4A16_MARLIN=1 : EXACT MATCH (was MISMATCH cosine 0.376)
  - NEMOTRON_W4A16=1        : EXACT MATCH (no regression)
  - cuBLAS default          : EXACT MATCH (no regression)

The moe_w4a16_marlin kernels were proven correct in isolation across exact
Nemotron shapes (64-tile up-proj 1856x2688, 128-tile down-proj 2688x1856),
many-expert and ragged routing (cos=1.0, max_abs=0); the bug was purely the
weight-buffer aliasing in the model integration.
…act vs Marlin

Regression guard for the standard-layout Q4 grouped MoE kernels at the shapes the
existing moe_w4a16_marlin_vs_standard test misses: an expert run straddling a 64-row
M-tile (M=128, ragged 50/40/38) + N=192 (div64 not div128). Asserts moe_w4a16 is
bit-exact with the (correct) Marlin kernel and bgemm within ~1 f16 ULP (magnitude-
relative) on identical weights/scales/indices — catches any future straddle/n-tile
index regression. Confirms the standard kernels are correct on canonical (the prior
'argmax 3260' was a stale checkout; all 4 paths now produce 1104 e2e).
…0, +67% d8192, +3.7x d32768

The TC tensor-core flash-attn (sdpa_multi_tc) auto-select gate was base+s>=4096, leaving
the common S2048/d0 case on the software-MMA sdpa_multi (~0.5% MFU, 16% of the forward).
Lower the threshold to s>=512 so any real prefill chunk uses the TC path. argmax-exact on
GB10 (1104/1120/1763): d0 885->~950 (+7%), d8192 536->864 (+67%), d32768 185->681 (+3.7x;
sdpa_prefill 8958->700ms). CUDA-only gate unchanged; Metal keeps portable sdpa_multi.
Adds sdpa_multi_tc_varlen + SDPA_TC_SOFTMAX_VARLEN: block-diagonal causal
attention over packed multi-sequence batches. Each query attends only within
its own segment via a per-row segment-start lower bound (seg_lo[r]) on top of
the causal upper bound. Additive — the dense sdpa_multi_tc path is byte-for-
byte unchanged. Keystone for the NEMOTRON_PACKED batched-prefill path that
fills the ~73%-idle GPU by sharing one set of proj/MoE GEMMs across N
sequences while attention stays correct.

Follow-ups (this branch): KV-block segment-skip (O((SL)^2)->O(SLi^2)),
varlen SSD scan (per-segment state reset), varlen conv1d, forward wiring +
packed bench.
ssm_prefill_scan_ssd gains an optional seg_reset:[nc] buffer; ssd_recur_varlen
zeroes the carried recurrent state at each chunk that starts a new packed
sequence. Intra-chunk (bdt) + combine are per-chunk so they're already
segment-safe (requires packed segment lengths to be multiples of chunk_len L).
None = single-sequence dense path, bit-identical. Call sites pass None.

Piece 2/4 of NEMOTRON_PACKED batched prefill (after varlen attention).
Packs N equal-length prompts into one batched prefill: attention routes to
sdpa_multi_tc_varlen (block-diagonal, per-token seg_lo), the Mamba SSD scan
resets state per segment (seg_reset), KV cap scaled by N. proj/MoE/router/norm
unchanged (token-parallel); RoPE relative so global positions are correct.
Default off — single-sequence path byte-identical. Verified argmax-exact
(1104) at PACKED=2/4.

Foundation only: throughput is currently ≈ single-long-sequence because the
varlen attention still computes the full O((N·L)^2) QK^T and masks off-segment.
The block-SKIP (restrict each KV-block to its segment's query rows → O(N·L^2))
is the follow-up that turns this into the throughput win. Conv1d is not yet
segment-aware (kc-1 token leak/segment) — deferred to post conv-consolidation.

ffai-only; insulated from the pending metaltile dev refactor.
When each KV-block is exactly one packed segment (seg_len == block size), only
that segment's query rows [kb0,kb0+blk) attend to it, so the expensive QK^T/PV
tensor-core GEMMs are restricted to that range; the cheap full-range softmax/
merge correctly no-op off-segment rows (masked -> p=0; exp(-inf)=0 in merge).
Threads seg_len through NEMOTRON_PACKED. argmax-exact (1104) at PACKED=2/4.

Completes the block-diagonal varlen attention (proper complexity). NOTE: packed
prefill is NOT a single-forward throughput win for this model — per-token MoE/
bandwidth cost grows with batch, so N*2048 packed ~ single-(N*2048) < single-2048.
The op stands as correct, reusable infra for future serving / continuous batching.
moe_scatter_add_det gains an f16-input variant (moe_scatter_add_det_f16) that
reads the down-GEMM output directly as __half. The default on-device MoE path
now scatters dn_all (f16) without the per-layer cast_f16_f32 + [mt,hid] f32
materialization (~66 MB write/E-layer, 23 layers). Atomic fallback casts lazily.
argmax-exact (1104); ~+2% d0. Same deterministic CSR accumulation order.
Add gemm_tc_out_f32 to the Device trait (default errors; CUDA overrides via
cublasLt with an f32 D-layout) and a gemm_cublas_f32out ffai-ops wrapper.
Nemotron prefill's qmm/qmm_h now run the cuBLAS projection GEMM straight to
f32 instead of f16-out + a trailing cast_f16_f32 kernel. cuBLAS already
accumulates in f32, so this keeps the residual stream f32 (required — a
lower-precision residual overflows/flips the argmax) with no extra kernel and
full tensor-core MFU. Correctness-exact: batched prefill last-token argmax
matches the sequential reference (1186==1186, deterministic).

FFAI_F32OUT_FALLBACK=1 reverts to the unfused (f16-out + cast) path for A/B.
Requires metaltile gemm_cublas_f32out (cublasLt f32 D-layout).
… foundation)

sdpa_flash_fused: one warp per (head, query row) holds Q in registers, streams
K/V, and runs the online softmax + O accumulation entirely in registers — no
HBM score/prob materialization and no qprep/kprep/vprep passes (the current
sdpa_multi_tc does 6 dispatches + a ~1GB score round-trip per KV block). Causal
is exploited by looping keys only to base_kv+r, skipping the masked triangle.

Correctness validated vs scalar sdpa_multi oracle (tests/sdpa_flash_test):
f32 max_rel 1e-6 (exact), f16 max_rel 7e-4 — MORE accurate than the cuBLAS-TC
path (f16×f16 vs f32-accumulate-of-f32). Argmax-exact in the full model (2044).

PERF: v1 is scalar (no tensor cores, no shared-mem K/V tiling) so each warp
re-streams all K/V from HBM → HBM-bound, ~4.4× slower than cuBLAS-TC at S=8192.
This commit banks the validated numerics + GQA/causal/online-softmax + model
wiring + test harness as the foundation for the wmma tiled v2 (the actual win).
Gated OFF by default (NEMOTRON_FLASH_FUSED=1); no change to default behavior.
sdpa_flash_wmma: QKᵀ and P·V on tensor cores (wmma 16×16×16 f16→f32). Scores
round-trip through SHARED (never HBM) so the causal mask + online softmax are
plain per-row shared reductions — no fragment-layout reduction. K transposed on
load so QKᵀ=Q·Kt is a direct A·B; O accumulator kept in shared f32, rescaled
per-row each KV tile. ~30KB static shared, f16 in/out.

Correct vs scalar sdpa_multi oracle (tests/sdpa_flash_wmma_*): max_rel ~8e-4 at
S=128/512/2048; argmax-exact in-model (2044). NVRTC compiles mma.h on CUDA 13.

PERF: 1.75× faster than scalar v1 (S=8192: 2292ms vs 4023ms) but still 2.7×
off cuBLAS-TC (859ms) — bottlenecked by ~11% occupancy (1 warp/block + 30KB
shared caps ~7 warps/SM). Next: FlashAttention-2 multi-warp tiling (4-8 warps
sharing a 64-wide KV tile) to lift occupancy + amortize per-tile overhead.
Gated OFF (NEMOTRON_FLASH_WMMA=1); no default change.
…tion (default-on)

sdpa_flash_mma: FlashAttention-2 with O kept in mma.sync m16n8k16 accumulator
REGISTERS (not shared) — frees the shared that capped the wmma v2 at ~11%
occupancy. QKᵀ + P·V on tensor cores via mma.sync (manual fragment packing,
documented layout); causal mask + online softmax round-trip the small S/P tiles
through shared; O rescaled per-row in registers each KV tile (lane holds rows
{gid,gid+8}). K/V pre-cast to f16 (cheap, re-read every tile); Q read f32
directly (read once); O written native dtype — no big cast temporaries.

Beats cuBLAS-TC attention, argmax-exact:
  S=2048: sdpa 121ms -> 37ms (3.3x)
  S=8192: sdpa 861ms -> 560ms (1.54x), e2e 692 -> 756 tok/s (+9.2%)
Correct vs scalar oracle (max_rel ~8e-4 @ S=128/512/2048) + seq==batched in
model (1104 @s2048, 2044 @s8192). Now DEFAULT for CUDA non-packed hd=128 prefill
(NEMOTRON_FLASH_MMA_OFF=1 -> cuBLAS-TC). v1 scalar + v2 wmma kept as gated
foundations.
moe_grouped_gemm_mma: ONE launch for all experts — out[t,n]=Σ_k A[t,k]·W[eid][n,k]
over sorted tokens, reading the contiguous resident f16 expert slab directly (no
f16 scratch — that's what sank cuBLAS-grouped). out=A·Wᵀ is structurally the same
as sdpa_flash_mma's QKᵀ, so it reuses the proven m16n8k16 register-O fragment
packing. Correct vs host reference at Nemotron shapes (K=2688,N=1856, mixed
expert sizes incl m=96): max_rel ~3e-4.

PERF: v1 is 1-warp/(16×64)-tile → fills the GPU (22k blocks) but each tile is a
naive single warp (serial shared loads, no pipeline) = 4.4 TFLOP/s, 3.5× SLOWER
than cuBLAS per-expert (15.7). Unlike flash-attn (where cuBLAS wasted work), here
cuBLAS does efficient GEMM — winning needs a multi-warp register-blocked
cp.async-pipelined tile (CUTLASS-grade). NOT wired into the model (loses); banked
as the correct foundation + bench harness for that v2.
4 warps/block (BM=64×BN=64, 16 rows/warp), cooperative shared loads, BK=64
K-tile (4 mma-k-substeps per load → 4× fewer __syncthreads + ILP). Correct
(max_rel ~3e-4). 4.4 → 5.7 TFLOP/s, but still 2.75× off cuBLAS (15.7): the
marginal gain from tiling proves it's LOAD-LATENCY-bound, not sync/occupancy
bound — the global A/W loads don't overlap mma. Next: cp.async software
pipelining (double-buffer: load tile c+1 while computing tile c) to hide HBM
latency. Still a gated foundation (not model-wired).
Software-pipelined K-loop: cp.async.cg loads K-tile c+1 into the second shared
buffer while the mma consumes tile c (hides HBM load latency — the GEMM lever).
Masked expert edges via cp.async src-size=0 (zero-fill OOB rows/cols). Vectorized
16B (8-half uint4) async copies. Correct (max_rel ~3e-4).
Trajectory: v1 4.4 → multi-warp 5.1 → BK64 5.7 → cp.async 8.2 TFLOP/s. Now 1.9×
from cuBLAS (15.7). Next levers: ldmatrix fragment loads (replace scalar pk2),
3-stage pipeline, BM=128 (halve W re-reads). Still a foundation (not model-wired
until it beats cuBLAS).
Device::moe_grouped_cutlass (default errs) + ffai-cuda override (→ runtime AOT
CUTLASS FFI) + ffai_ops::moe_grouped_gemm_cutlass wrapper. out[t,n]=Σ A·W[eid]
over sorted token groups, contiguous f16 expert slab. Validated end-to-end from
Rust (AOT nvcc→static lib→FFI→trait→ops→test): max_rel 3.4e-4 on the Nemotron
up-proj shape. Builds without CUTLASS (errors cleanly at call); the test skips
when the runtime lacks CUTLASS. Needs metaltile e0325add + CUTLASS_DIR build.
Next: model wiring (contiguous resident expert slab) to land the +9%-over-cuBLAS
MoE win.
…TLASS_MOE)

E-branch path: dequant the full Q4 expert weight to a contiguous f16 slab once
(resident-cached in w16), then ONE moe_grouped_gemm_cutlass per up/down over the
sorted tokens + on-device relu2_scale + deterministic scatter. Correct: seq==
batched argmax 1104 @s2048.

PERF: in-model it's a WASH vs cuBLAS per-expert (moe_experts ~389 vs ~356ms),
NOT the +9% the standalone bench showed. Cause: the CUTLASS .cu does 7
cudaMallocAsync+memcpy of the device ptr/problem arrays PER CALL (x46/forward) +
variable expert rows (m>128 → 2 tiles) eat the advantage. The MoE GEMM is
near its skinny-m ceiling (cuBLAS 15.7 ~ CUTLASS 17.1 TFLOP/s) so the upside is
~+1.6% e2e regardless. Gated OFF by default. Follow-up: cache the ptr arrays in
a persistent workspace (kill the per-call malloc). Real lever is the ~60% host
idle, not this.
moe_route_sort_device: batched sigmoid+bias+topk router + counting-sort
(histogram→prefix→atomic-cursor scatter), ALL on device. Replaces the per-E-layer
HOST triples round-trip (router logits→host→sigmoid+topk+stable-sort→upload) that
stalls the GPU every MoE layer and blocks CUDA-graph prefill capture. Outputs
sorted_tok[mt], sorted_wt[mt], offsets[n_exp+1] — tokens grouped by expert.

Validated vs host reference (tests/moe_route_sort): counts+token-multiset per
expert MATCH exactly, max|Δweight| 1e-7, at small + Nemotron (s=2048,128exp,top6).

Bug fixed during bringup: dispatch_raw_cuda passes ALL pointer args first then
scalars, so kernel scalar params MUST come last (hist/scatter had a scalar mid-
list → pointer got a scalar value → garbage atomicAdd target → OOB). Step 1 of
on-device-MoE → CUDA-graph-prefill (the 70%-idle lever).
Wire moe_route_sort_device into the batched-prefill E-branch (default-on for
CUDA; NEMOTRON_DEVSORT_OFF=1 reverts to host router). Keeps gate logits ON
DEVICE (drops the ~1MB rl_all dl/layer) and runs sigmoid+bias+topk + counting
sort on the GPU, reconstructing the same triples via 3 small dls — removes the
per-E-layer host sigmoid/topk/sort (was the bulk of the 58ms router host time).

S=2048: 904 -> 998 tok/s (+10.4%), argmax-exact (1104==sequential).

Bug fixed in bringup: router must use precise expf, NOT __expf — the fast-approx
sigmoid flips the top-k set on near-tie tokens, and ONE token's routing change
cascades through causal attention/SSM to flip the final argmax (1776 vs 1104).
TODO (graph-cleanliness): make the within-expert sort device-side to drop the
host sort + 3 dls; then CUDA-graph the forward (the 70%-idle win).
moe_scatter_sort: atomic-cursor scatter -> stable per-expert scan.
One thread/expert scans all mt pairs in index order, placing tokens
into [offsets[e], offsets[e+1]) ascending. No atomics, deterministic,
within-expert token order now matches a host stable_sort.

Model: drop the tr.sort_by((expert,token)) host re-sort in the DEVSORT
branch -- route_sort output is already in canonical order. One step
closer to a fully on-device MoE (no host triples) for graph capture.

Validated S=2048: argmax==1104 EXACT MATCH (seq==batched), top5 stable.
1009 tok/s (vs 998 atomic) -- stable scan also slightly faster.
moe_route_sort_test: both small + nemotron pass.
…% prefill, gated

Marlin-style single-launch grouped MoE GEMM that reads the model's SIGNED Q4
weights DIRECTLY (no f16 slab, no dequant pass): cp.async packs int4 into shared
u32 staging + f16 scales, register-dequants per b-fragment inside the mma
m16n8k16 loop (dqn = sign_extend((nib^8)-8) * f16scale, matches quantize_q4).
One launch per up/down over the sorted tokens → collapses ~234 per-expert cuBLAS
launches AND reads 4.5-bit weights (vs 16-bit) — the BW lever for the MoE wall.

NEMOTRON_Q4_GROUPED=1 (gated OFF). S=2048: 952 → 1358 tok/s (+35%), prefill
2.15s → 1.51s. Standalone kernel bench is bit-exact vs cuBLAS-f16-on-same-weights
(maxrel 0).

NOT default-on: in-model argmax flips 1104 → 1776 on a RAZOR near-tie (top1-top2
gap 0.011 = 0.137%; both in top-2). The custom mma GEMM's rounding differs from
cuBLAS (split-K tree vs sequential-K f32 accumulate — bitwise match infeasible
per the numerics) and compounds over 52 layers (relative L2 8.8%) → flips the
tie. Default per-expert cuBLAS path UNCHANGED (argmax 1104 EXACT, regression-
checked). Next: precision work (split-K / f32 intermediate) to land 1104, then
default-on for the +35%.
…8%, argmax 1104

The Q4-native grouped MoE GEMM (moe_q4_grouped_mma) is +35% but its custom-mma
rounding flips a razor 1104/1776 near-tie (gap 0.137%) when used on ALL 23 MoE
layers. Root cause: error injected in the EARLY MoE layers propagates irreducibly
through the residual stream and decides the final argmax.

Fix: NEMOTRON_Q4_GROUPED_FIRSTEXACT (default 4) runs the first 4 MoE layers on the
exact cuBLAS per-expert path and Q4-grouped for the remaining 19. This restores
argmax 1104 with a robust 0.125 margin (10x the flip threshold). Swept: K=2 still
flips (1776); K>=3 lands 1104. exact-LAST-K failed at all K (confirms the early
layers are the sensitive ones, not the late ones).

Now DEFAULT-ON for CUDA (escape NEMOTRON_Q4_GROUPED_OFF=1). S=2048: ~1009 → ~1290
tok/s (+28%), argmax 1104 EXACT MATCH (seq==batched), regression-gate clean.
NEMOTRON_Q4_GROUPED_LASTEXACT also available for experiments.
…+12% kernel)

dqn (int4→f16, inline on the mma critical path) replaced the int-sign-extend +
f32-mul path with the all-f16 magic trick: (nib^8)|0x6400 reinterprets as
f16(1024+signed+8), subtract 1032 → signed value, ×f16 scale. No int→float or
float→half conversions on the hot path.

Standalone bench (GPU-event timed): 8.68 → 9.74 TFLOP/s (+12%). In-model the
model's scales are already f16, so the all-f16 mul is precision-equivalent to the
prior f32-widened mul — logits BIT-IDENTICAL (argmax 1104 EXACT MATCH, top1-top2
gap 0.125257, unchanged). (The standalone shows ~0.004 abs delta only because its
bench uses f32 scales; the real model path is f16.)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

feature New feature or capability

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant