feat(tq+): TurboQuant+ KV cache — canonical Hadamard + 9 asymmetric variants + Turbo2 + dispatch tests #92

Open
TheTom wants to merge 18 commits into Avarok-Cybersecurity:main from TheTom:feature/tq-plus-clean

TheTom commented May 23, 2026

Summary

Ports the TurboQuant+ (TQ+) KV cache compression line into Atlas:
canonical Randomized Hadamard rotation, Lloyd-Max codebooks with
matched-norm L2 correction, sparse-V dequant, the 2-bit Turbo2 cache
that was crashing on upstream, and 9 asymmetric KvCacheDtype variants
with per-side K/V cache pools + combined write+decode+prefill
kernels. Reference: Zandieh, Daliri, Hadian, Mirrokni,
TurboQuant: Online Vector Quantization with Near-optimal Distortion Rate
(arXiv:2504.19874, April 2025).
TQ+ extensions from TheTom/turboquant_plus
and TheTom/llama-cpp-turboquant.

The headline numbers, on Qwen3.6-35B-FP8 (Bench F) and
Qwen3-VL-30B-A3B-NVFP4 (Bench G), single-GPU GB10, greedy decode,
fixed Manhattan-Project WikiText prompt, 5-run median via
tests/atlas_bench_comprehensive.py:

  • fp8k_turbo3v and turbo4k_turbo3v: PPL sim 0.8485
    bit-identical to the fp8 baseline. The "K kept at baseline
    precision" promise empirically delivered.
  • bf16k_turbo3v and bf16k_turbo4v: PPL sim 0.2929
    bit-identical to the sym bf16 baseline on the same NVFP4 model.
  • Turbo2 (2-bit, new dtype): 2-bit Lloyd-Max codebook added
    to the KvCacheDtype enum (upstream main has only Turbo3/4/8).
    8K prefill runs at 1336.50 tok/s (Bench D), the fastest of any KV
    dtype tested because there is less data to write per token, at the
    expected 2-bit quality penalty (PPL sim 0.6465 vs 0.8485 fp8
    baseline). Primary use case is as the V-side of asym combos like
    Bf16KTurbo2V / Fp8KTurbo2V.
  • Dispatch-routing unit tests that catch the silent
    fall-through bug class at compile time + test time, not runtime.

The symmetric path is at parity with upstream at greedy decode on
Qwen3.6-A3B (byte-identical output on fp8/turbo3/turbo4). The PR's
value is new capability + correctness fixes + dispatcher safety nets,
not a headline PPL win on the existing symmetric dtypes.

Closes #91.

KV cache compression at a glance

Lower bits/elem = smaller KV pool = more context window at same VRAM
ceiling. The bits/elem figures include the per-group FP8 scale amortised
across GROUP_SIZE=16 elements (turbo8 uses a BF16 scale instead).
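
The amortised figures can be reproduced with a small calculation. The sketch below is illustrative only: the function names are hypothetical, and it assumes the payload width is the nominal bit count of each dtype, with an 8-bit FP8 scale (16-bit BF16 for turbo8) shared by every 16 elements.

```rust
/// Effective storage cost per cached element: payload bits plus the
/// per-group scale amortised over the group. (Hypothetical helper.)
fn bits_per_elem(payload_bits: f64, scale_bits: f64, group_size: f64) -> f64 {
    payload_bits + scale_bits / group_size
}

/// How many times smaller a dtype is than a baseline dtype.
fn compression_vs(baseline_bits: f64, dtype_bits: f64) -> f64 {
    baseline_bits / dtype_bits
}
```

For example, `bits_per_elem(3.0, 8.0, 16.0)` gives the 3.5 b/elem listed for turbo3, and `compression_vs(16.0, 3.5)` gives roughly the 4.57× listed against bf16.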

Symmetric dtypes

| dtype | bits/elem | vs bf16 | vs fp8 |
| --- | --- | --- | --- |
| bf16 | 16.0 | 1.0× | 0.5× |
| fp8 | 8.0 | 2.0× | 1.0× |
| turbo8 | 9.0 | 1.78× | 0.89× |
| nvfp4 / turbo4 | 4.5 | 3.56× | 1.78× |
| turbo3 | 3.5 | 4.57× | 2.29× |
| turbo2 (new) | 2.5 | 6.4× | 3.2× |

Asymmetric (K-side + V-side averaged) — the production-recommended
rows are the ones that hit baseline-parity quality at the smallest
KV footprint:

| dtype | K b/e + V b/e | avg b/e | vs sym K-side | PPL parity |
| --- | --- | --- | --- | --- |
| bf16k_turbo3v | 16 + 3.5 | 9.75 | 1.64× smaller | = sym bf16 (0.2929) |
| bf16k_turbo4v | 16 + 4.5 | 10.25 | 1.56× smaller | = sym bf16 (0.2929) |
| bf16k_turbo2v | 16 + 2.5 | 9.25 | 1.73× smaller | -0.01 vs bf16 |
| fp8k_turbo3v | 8 + 3.5 | 5.75 | 1.39× smaller | = sym fp8 (0.8485) |
| fp8k_turbo4v | 8 + 4.5 | 6.25 | 1.28× smaller | = sym turbo4 (0.8384) |
| fp8k_turbo2v | 8 + 2.5 | 5.25 | 1.52× smaller | -0.22 vs fp8 (2-bit V) |
| turbo4k_turbo3v | 4.5 + 3.5 | 4.0 | 2.0× smaller than fp8 | = sym fp8 (0.8485) ⭐ |
| turbo4k_turbo8v | 4.5 + 9 | 6.75 | mixed | -0.20 vs fp8 |
| turbo3k_turbo8v | 3.5 + 9 | 6.25 | mixed | -0.16 vs fp8 |

turbo4k_turbo3v is the headline frontier: 2.0× smaller KV pool
than sym fp8 at byte-identical PPL sim on the Manhattan-Project
continuation. Halving the KV footprint at zero quality cost on this
prompt means ~2× longer context at the same VRAM ceiling.
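
The "~2× longer context" figure follows directly from the averaged bits/elem. A minimal sketch of that arithmetic, assuming K and V hold the same number of elements per token and ignoring allocator and block-rounding overhead (the function names and the `elems_per_token` parameter are hypothetical):

```rust
/// Average bits per element when the K and V pools use different dtypes.
fn asym_avg_bits(k_bits: f64, v_bits: f64) -> f64 {
    (k_bits + v_bits) / 2.0
}

/// Tokens of context that fit in a fixed KV budget.
/// `elems_per_token` counts K plus V elements stored per token.
fn max_context_tokens(budget_bytes: f64, elems_per_token: f64, avg_bits: f64) -> f64 {
    budget_bytes / (elems_per_token * avg_bits / 8.0)
}
```

With `asym_avg_bits(4.5, 3.5)` at 4.0 b/elem against sym fp8 at 8.0 b/elem, the ratio of the two `max_context_tokens` results is 2.0 for any budget.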

Asymmetric K/V quality is model-family sensitive. Which asym
combo wins on quality vs compression depends on the model's
attention-score distribution + K-side outlier mass profile. The
recommended starting points above are based on Qwen3.6 + Qwen3-VL
measurements; other model families may have a different sweet
spot. The design rationale + per-family selection guidance is in
TheTom/turboquant_plus/docs/papers/asymmetric-kv-compression.md.

Bench results

All numbers from tests/atlas_bench_comprehensive.py on single-GPU
GB10, greedy decode, fixed Manhattan-Project WikiText prompt, 5-run
median per metric, 3 context lengths.

Bench D — sym matrix, upstream 87b7bb3 baseline vs TQ+ branch (Qwen3.6-35B-FP8)

Output text byte-identical between upstream and TQ+ on fp8/turbo3/turbo4
at greedy decode. Throughput at parity (every delta in the table is
within about ±3%, i.e. run-to-run noise).

| dtype | b/e | PPL sim (base → TQ+) | pre_8K t/s (base → TQ+) | pre_16K t/s (base → TQ+) | dec_after_8K (base → TQ+) |
| --- | --- | --- | --- | --- | --- |
| bf16 | 16.0 | 0.8485 → 0.8485 | 1229.35 → 1227.54 | 1394.17 → 1390.46 | 40.94 → 40.87 |
| fp8 | 8.0 | 0.8485 → 0.8485 | 1219.05 → 1239.06 | 1377.74 → 1402.00 | 41.28 → 41.65 |
| turbo8 | 9.0 | 0.8384 → 0.8384 | 1223.43 → 1248.85 | 1389.64 → 1416.28 | 40.95 → 41.11 |
| nvfp4 | 4.5 | 0.8485 → 0.8485 | 1208.02 → 1239.39 | 1366.80 → 1410.56 | 41.07 → 41.53 |
| turbo4 | 4.5 | 0.8384 → 0.8384 | 1252.60 → 1238.62 | 1417.42 → 1403.65 | 41.22 → 40.94 |
| turbo3 | 3.5 | 0.8384 → 0.8384 | 1249.94 → 1230.51 | 1416.13 → 1400.57 | 41.28 → 40.85 |
| turbo2 | 2.5 | not in upstream enum → 0.6465 | n/a → 1336.50 | n/a → 1519.93 | n/a → 41.56 |

Turbo2 is new in this PR — upstream 87b7bb3's KvCacheDtype
enum has Bf16, Fp8, Nvfp4, Turbo4, Turbo3, Turbo8 only. This PR adds
the 2-bit Lloyd-Max codebook (Turbo2) + its write/decode/prefill
kernels + dispatcher arms. The 2-bit codebook is the most aggressive
KV compression we ship (2.5 b/elem incl scale vs Turbo3's 3.5 b/elem
vs Turbo8's 9 b/elem); usable for low-bit asym variants like
Bf16KTurbo2V / Fp8KTurbo2V where K precision absorbs the V quality
hit. Prefill is fastest of any KV dtype tested (less data to write
per token) at the expected 2-bit quality penalty.

Bench F — 9 asymmetric variants on Qwen3.6-35B-FP8 (head_dim=256)

| dtype | avg b/e | PPL sim | dec_short | pre_2K | pre_8K | pre_16K | dec_after_8K |
| --- | --- | --- | --- | --- | --- | --- | --- |
| bf16k_turbo3v | 9.75 | load_timeout (needs bf16-attn model; see Bench G) | | | | | |
| bf16k_turbo4v | 10.25 | load_timeout | | | | | |
| bf16k_turbo2v | 9.25 | load_timeout | | | | | |
| fp8k_turbo3v | 5.75 | 0.8485 | 72.30 | 641.19 | 1249.05 | 1415.89 | 41.28 |
| fp8k_turbo4v | 6.25 | 0.8384 | 71.98 | 636.42 | 1240.02 | 1403.23 | 41.03 |
| fp8k_turbo2v | 5.25 | 0.6263 | 72.16 | 631.95 | 1232.50 | 1390.13 | 40.97 |
| turbo4k_turbo3v | 4.0 | 0.8485 | 72.01 | 638.85 | 1242.65 | 1410.52 | 41.04 |
| turbo4k_turbo8v | 6.75 | 0.6465 | 71.96 | 638.85 | 1241.52 | 1402.76 | 40.99 |
| turbo3k_turbo8v | 6.25 | 0.6869 | 71.92 | 634.67 | 1234.47 | 1398.72 | 40.72 |

fp8k_turbo3v and turbo4k_turbo3v are bit-identical to the fp8
baseline (0.8485) — the "K kept at baseline precision" promise
empirically delivered. Throughput on every asym variant matches the
symmetric line within ±1% across all 3 context lengths.

Compression headline at fp8-parity quality:

  • fp8k_turbo3v (5.75 b/elem) is 1.39× smaller than sym fp8 (8.0 b/elem)
  • turbo4k_turbo3v (4.0 b/elem) is 2.0× smaller than sym fp8 — same
    PPL sim 0.8485, half the KV pool footprint, ~2× longer context at same
    VRAM ceiling. This is the production-recommended row.

Bench G — bf16k_* variants on Qwen3-VL-30B-A3B-NVFP4 (head_dim=128)

The 3 bf16k_* variants exercise the bf16 K-side at HDIM=128, on the
only HDIM=128 bf16-attn model in Atlas's tested set. They are compared
against the sym bf16 baseline on the same NVFP4 weight release.

| dtype | avg b/e | PPL sim | dec_short | pre_2K | pre_8K | pre_16K | dec_after_8K |
| --- | --- | --- | --- | --- | --- | --- | --- |
| bf16 (baseline) | 16.0 | 0.2929 | 86.18 | 740.17 | 1186.24 | 1291.79 | 41.97 |
| bf16k_turbo3v | 9.75 | 0.2929 | 84.69 | 709.80 | 1179.45 | 1300.06 | 39.55 |
| bf16k_turbo4v | 10.25 | 0.2929 | 84.29 | 709.78 | 1171.72 | 1294.01 | 40.00 |
| bf16k_turbo2v | 9.25 | 0.3030 | 85.96 | 713.33 | 1181.17 | 1301.00 | 40.34 |

bf16k_turbo3v and bf16k_turbo4v are bit-identical to the sym bf16
baseline (0.2929) on this prompt: the asym dispatch correctly
preserves K-side precision through the combined kernel. bf16k_turbo2v
lands 0.01 off the baseline (0.3030 vs 0.2929; 2-bit V codebook drift,
expected).

Compression at bf16-parity quality:

  • bf16k_turbo3v (9.75 b/elem averaged) is 1.64× smaller than sym
    bf16 (16 b/elem) — same PPL sim, ~40% smaller KV pool, ~1.6× longer
    context at same VRAM ceiling.
  • bf16k_turbo4v (10.25 b/elem) is 1.56× smaller at the same PPL.
  • bf16k_turbo2v (9.25 b/elem) shows the per-side-pool dispatch works
    at the most aggressive V-side compression (2-bit) at the expected
    small quality cost.

(The absolute PPL sim is lower than Bench F's 0.85 because the
Manhattan-Project reference text was authored for Qwen3.6's output
style; Qwen3-VL's continuations score lower on the same fixed
reference. The comparison within Bench G, bf16k_* vs the bf16
baseline, is the load-bearing signal here, not the absolute number.)

Symmetric matrix at TQ+ default (Qwen3.6-35B-FP8)

For reference / completeness — confirms parity with Bench D's TQ+
column above + adds the Turbo2 row that didn't exist upstream.

| dtype | bits/elem | PPL sim | dec_short | pre_2K | pre_8K | pre_16K | dec_after_8K |
| --- | --- | --- | --- | --- | --- | --- | --- |
| bf16 | 16.0 | 0.8485 | 71.89 | 630.69 | 1227.54 | 1390.46 | 40.87 |
| fp8 | 8.0 | 0.8485 | 72.15 | 627.97 | 1239.06 | 1402.00 | 41.65 |
| turbo8 | 9.0 | 0.8384 | 71.73 | 640.08 | 1248.85 | 1416.28 | 41.11 |
| nvfp4 | 4.5 | 0.8485 | 72.10 | 636.28 | 1239.39 | 1410.56 | 41.53 |
| turbo4 | 4.5 | 0.8384 | 71.62 | 638.01 | 1238.62 | 1403.65 | 40.94 |
| turbo3 | 3.5 | 0.8384 | 71.65 | 634.15 | 1230.51 | 1400.57 | 40.85 |
| turbo2 | 2.5 | 0.6465 | 71.46 | 692.35 | 1336.50 | 1519.93 | 41.56 |

Turbo2 prefill is the fastest at every context length (692 / 1337 /
1520 vs ~635 / 1235 / 1405 for everything else) — +8-9% prefill
throughput at the 2-bit Lloyd-Max quality penalty. New in this PR
(upstream enum has Turbo3/4/8 only).

Test plan

  • cargo fmt --all -- --check
  • ATLAS_SKIP_BUILD=1 cargo clippy --workspace --tests -- -Dwarnings
    (--all-features blocked by upstream objc2 Apple-only build on
    Linux; clean on the per-crate -p spark-runtime -p spark-model -p spark-server
    gates)
  • bash scripts/check-license-headers.sh — script not present in
    tree at HEAD; manual audit shows all new crates/**/*.rs +
    kernels/**/*.{cu,cuh} files carry // SPDX-License-Identifier: AGPL-3.0-only
    per .licenserc.yaml paths-policy.
  • typos — clean (WHT / wht / BA / OPTIN already in
    _typos.toml allow-list)
  • cargo-deny check — clean
  • File-size cap — 1 file lifted: init.rs 524 LoC after TQ+
    kernel-handle additions, allow-list entry added to
    .github/workflows/file-size-cap.yml with rationale matching
    the existing entries.
  • Tested against real model + hardware — full reproduction
    matrix at docs/turboquant-plus.md. All 9 asym variants benched
    end-to-end on GB10 with the same tests/atlas_bench_comprehensive.py
    harness (~85 min total wall, 5-run median per metric, 3 context
    lengths). Bench JSONs archived locally
    (~/atlas_tqplus_bench/results/bench_{A,B,C,D,E,F,G}_*.json),
    forwarded on request.
  • Added or updated tests — 9 new Rust unit tests:
    • 4 in crates/spark-model/src/layers/qwen3_attention/init_kernel_dispatch.rs::tests
      — pure-Rust dispatch-routing tests asserting every asym variant
      ends up at a kernel module name containing its dtype-pair shape
      (bf16k_turbo3v, fp8k_turbo4v, etc.); a new asym variant added
      without dedicated kernels fails the substring check in CI.
    • 5 in crates/spark-runtime/src/kv_cache/tests_tq_plus.rs:
      enum-coverage on ALL_VARIANTS / ASYM_VARIANTS / SYM_VARIANTS
      arrays that force new dtypes to be added in tests when added to
      the enum.
    • Plus tests/test_kv_dtype_smoke.py — end-to-end per-dtype
      container start + 64-tok generation, distinguishes load-time SKIP
      (weight-incompat) from runtime FAIL (kernel crash).

Notes for reviewers

This is a large PR by line count (~9000 LoC across 24 new kernel
files + Rust dispatch + tests + docs) because it's the full TQ+ port,
not a feature-flag slice. I've kept the commit history intentionally
atomic so each layer is independently reviewable:

591aa86 chore(tq+): typo fix + allow-list 5 dispatcher files after asym additions
1be228a style(tq+): cargo fmt --all on asym kernel dispatch files
51b31c0 docs(tq+): full 9-variant asym matrix from Bench F (Qwen3.6-FP8) + Bench G (Qwen3-VL NVFP4)
7f74fbe chore(tq+): drop dead bail tombstone (superseded by dispatch tests)
b87ad59 test(tq+): extract kernel-dispatch table + dispatch-routing unit tests
45e110d feat(tq+): Turbo*K + Turbo*V combined kernels + dispatch
b049fde feat(tq+): Bf16K + Turbo{2,4}V combined kernels + dispatch
47a8c70 fix(tq+): bail at init for asymmetric variants without combined kernels
2d693f6 docs(tq+): TurboQuant+ acknowledgement + reproduction guide
53297e5 chore(file-size-cap): allow-list init.rs after TQ+ kernel handles
3b6ca3c feat(tq+): TurboQuant+ Rust dispatch + tests + smoke harness
2ba4365 feat(tq+): TurboQuant+ kernel-level integration

(Rebased onto upstream main at 9d19c32. The chore(ci): green clippy commit from an earlier draft was dropped during rebase —
upstream PR #63 landed equivalent fixes for the same
manual_checked_ops + unnecessary_sort_by clippy debt.)

Suggested review order:

  1. docs/turboquant-plus.md — the reproduction guide + measured tables
  2. 2ba4365 — kernel-level integration (algorithm correctness)
  3. 3b6ca3c — Rust dispatch + per-side cache pool refactor
  4. b049fde + 45e110d — asym kernel families (the 8 missing combos)
  5. b87ad59 — dispatch tests (the safety net)
  6. 47a8c70 + 7f74fbe — bail introduced as scaffolding when only
    1 of 9 asym variants had kernels, then dropped after all 8 missing
    variants landed and the dispatch tests took over

Honest framing — the symmetric line is at parity, not a win.
Byte-identical output on fp8/turbo3/turbo4 between upstream 87b7bb3
and this branch at greedy decode on Qwen3.6-A3B. The kernel-level
math fixes (signs, matched-norm L2, real Turbo3 prefill kernel) are
correct on paper but greedy decoding on this model + this prompt
absorbs the numerical differences without flipping argmax winners.
The value is in: (a) the new capabilities — Turbo2 dispatch fix,
9 asym variants, per-side cache pool, (b) the test coverage that
makes the dispatcher safe to extend, (c) correctness fixes that may
matter for other models or sampling regimes.

Known follow-ups (not blockers):

  • The qwen3_vl weight loader in
    crates/spark-model/src/weight_loader/qwen3_vl.rs:93 calls
    quantized_auto(), whose Bf16Raw arm is unreachable!(); this
    prevents loading raw-bf16 Qwen3-VL checkpoints. We benched the
    bf16k_* variants against the NVFP4 release of the same model as a
    workaround. A one-arm fix following the qwen35 loader pattern is
    out of scope for this PR.
  • Metal port of the TQ+ kernels: Atlas supports a Metal backend
    (crates/spark-runtime/src/metal_backend.rs, kernels/metal/,
    Cargo feature metal), but the entire pre-existing TurboQuant
    family (turbo3 / turbo4 / turbo8) is CUDA-only on main; there are
    no wht, turbo, or reshape_and_cache_turbo siblings in
    kernels/metal/common/. TQ+ extends the CUDA-only line; Metal
    support for the full TurboQuant family (existing sym + new asym)
    is a separate effort.
  • TQ4_1S / TQ3_1S weight quantization (separate from KV cache).
  • TriAttention long-context eviction policy (substrate exists via
    boundary_dtype on build_layer_kv_dtypes, no kernel yet).

Asymmetric dispatch correctness — the original Bench C run
caught 8-of-9 asym variants silently falling through to the K-side
symmetric kernels with mis-sized V pools, producing PPL sim 0.55-0.63
instead of the design target. Root cause: per-side pool sizing was
asymmetric but the dispatcher routed through the K-side sym kernel
which wrote V at K-side byte size. Fixed by (a) writing the 8
missing combined kernel triplets, (b) extracting the dispatch table
to a pure function with an exhaustive match (no _ arm), (c) adding
unit tests that walk every asym variant × {hd=128, hd=256} and
assert it routes to a kernel module name containing its dtype shape.
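
The shape of fix (b) and (c) can be sketched in a few lines. This is not the code in init_kernel_dispatch.rs: the enum is a four-variant stand-in, and `shape`, `decode_module`, and the module stems other than `paged_decode_attn_bf16k_turbo3v` are hypothetical names chosen for illustration.

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum KvDtype {
    Fp8,
    Turbo3,
    Bf16KTurbo3V,
    Fp8KTurbo4V,
}

const ASYM_VARIANTS: [KvDtype; 2] = [KvDtype::Bf16KTurbo3V, KvDtype::Fp8KTurbo4V];

/// Lower-case dtype shape that must appear in the kernel module name.
fn shape(dtype: KvDtype) -> &'static str {
    match dtype {
        KvDtype::Fp8 => "fp8",
        KvDtype::Turbo3 => "turbo3",
        KvDtype::Bf16KTurbo3V => "bf16k_turbo3v",
        KvDtype::Fp8KTurbo4V => "fp8k_turbo4v",
    }
}

/// Pure dispatch table. There is no `_` arm, so adding an enum variant
/// without routing it here is a compile error rather than a silent
/// fall-through to another dtype's kernel.
fn decode_module(dtype: KvDtype, head_dim: u32) -> String {
    let stem = match dtype {
        KvDtype::Fp8 => "paged_decode_attn_fp8",
        KvDtype::Turbo3 => "paged_decode_attn_turbo3",
        KvDtype::Bf16KTurbo3V => "paged_decode_attn_bf16k_turbo3v",
        KvDtype::Fp8KTurbo4V => "paged_decode_attn_fp8k_turbo4v",
    };
    format!("{}_{}", stem, head_dim)
}
```

A routing test then walks `ASYM_VARIANTS` × {128, 256} and asserts `decode_module(dtype, hd).contains(shape(dtype))`, which fails if an asym variant is routed to a symmetric kernel's module.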

Authorship

The TQ+ algorithm (matched-norm L2, sparse V dequant, asymmetric K/V
design, InnerQ, weight pre-rotation) is research authored by Tom Turney
(@TheTom) prior to this Atlas port.
Reference implementations and the broader research line live at
TheTom/llama-cpp-turboquant
and TheTom/turboquant_plus.

CLA


github-actions Bot commented May 23, 2026

All contributors have signed the CLA. Thank you!
Posted by the CLA Assistant Lite bot.

TheTom added 10 commits May 23, 2026 19:41
Vendors the kernel-level layer of TurboQuant+ (TQ+) into Atlas: canonical
two-sided Rademacher rotation, Lloyd-Max codebook fixes, matched-norm L2
correction, sparse V dequantisation, 2-bit Turbo2 cache, the missing real
Turbo3 prefill kernel, and the Bf16K + Turbo3V asymmetric kernel stack.

This is the CUDA-only half of the integration; the Rust dispatch wiring
follows in a sibling commit so each side of the change is independently
reviewable.

=== What is included ===

New kernel files:
  - tq_plus_signs.cuh         — Rademacher sign arrays (hd=128 / 256 / 512)
  - tq_plus_innerq.{cu,cuh}   — InnerQ device state + accumulator + finalise
  - tq_plus_innerq_apply.cu   — Q pre-WHT scale_inv + K post-WHT scale apply
  - paged_decode_attn_turbo2_128.cu        — 2-bit Lloyd-Max decode (hd=128)
  - paged_decode_attn_bf16k_turbo3v_128.cu — combined K(bf16) + V(turbo3) decode
  - inferspark_prefill_paged_turbo2.cu     — Turbo2 paged FA prefill
  - inferspark_prefill_paged_turbo3.cu     — real Turbo3 prefill (replaces
    upstream silent NVFP4 misroute that read 3-bit data as 4-bit nibbles)
  - inferspark_prefill_paged_bf16k_turbo3v.cu — asym Bf16K+Turbo3V prefill
  - prefill_paged_compute_asym.cuh         — FA template with per-side
    LOAD_K_TILE / LOAD_V_TILE macro hooks (fork of prefill_paged_compute.cuh
    that lets K and V use different on-disk layouts)

Modified kernels:
  - wht_bf16.cu — TQ_PLUS_SIGNS-gated S2·H·S1 path (hd=128 today; 256/512
    arrays vendored, paths extend later). New wht_bf16_inplace_inv kernel
    for the post-attention un-rotation since (S2·H·S1)·(S2·H·S1) ≠ I when
    S1 ≠ S2 — gated identically.
  - reshape_and_cache_turbo.cu — matched-norm L2 correction
    (scale = ||orig|| / ||recon|| replaces per-group amax) across
    turbo2/3/4 write paths; turbo2 (2-bit) write kernel and the combined
    bf16k_turbo3v write kernel.
  - paged_decode_attn_turbo{2,3,4,8}*.cu (9 files) — fp16 centroid LUT
    (__shared__ __half[] instead of __shared__ float[] — halves shmem
    cost on the dequant hot path) plus sparse V dequant gated on per-row
    softmax exp factor (if exp > TQ_PLUS_SPARSE_V_THRESHOLD) on both
    the remainder loop and the BC=4 batched path.

KERNEL.toml entries register the new modules across 12 model targets so
atlas-kernels builds emit PTX for every supported (HW, model) tuple.

=== Why the canonical signs ===

Upstream Atlas implements plain WHT (no signs) in wht_bf16.cu. The
Google ICLR 2026 TurboQuant paper (arXiv:2504.19874) establishes the
canonical form as signs1 → WHT → signs2 with independent Rademacher
draws on each side. The Lindeberg-CLT argument: two-sided random sign
masks Gaussianise an arbitrary input distribution marginal, eliminating
outlier mass that otherwise clips in FP8 or wastes dynamic range in
low-bit codebooks.

With TQ_PLUS_SIGNS undefined the kernels are byte-equivalent to
upstream Atlas (lets you A/B-bench against the baseline). With it defined
the hd=128 forward + inverse rotations carry the canonical two-sided sign
masks; attention dot product is preserved because the rotation is
orthogonal: (S2·H·S1)·(S2·H·S1)^T = (S2·H·S1)·(S1·H·S2) = I, with H
normalised so that H·H = I.
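
A host-side reference makes the forward/inverse asymmetry concrete. The sketch below is not the CUDA kernel: it works on f32 slices, uses a normalised fast Walsh-Hadamard transform, and takes the sign masks as ±1.0 arrays supplied by the caller; all function names are hypothetical.

```rust
/// In-place fast Walsh-Hadamard transform, normalised so that H·H = I.
/// The length must be a power of two.
fn wht(x: &mut [f32]) {
    let n = x.len();
    assert!(n.is_power_of_two());
    let mut h = 1;
    while h < n {
        for block in (0..n).step_by(2 * h) {
            for i in block..block + h {
                let (a, b) = (x[i], x[i + h]);
                x[i] = a + b;
                x[i + h] = a - b;
            }
        }
        h *= 2;
    }
    let norm = 1.0 / (n as f32).sqrt();
    for v in x.iter_mut() {
        *v *= norm;
    }
}

/// Element-wise multiply by a ±1 sign mask.
fn apply_signs(x: &mut [f32], signs: &[f32]) {
    for (v, s) in x.iter_mut().zip(signs) {
        *v *= *s;
    }
}

/// Forward rotation y = S2·H·S1·x.
fn rotate(x: &mut [f32], s1: &[f32], s2: &[f32]) {
    apply_signs(x, s1);
    wht(x);
    apply_signs(x, s2);
}

/// Inverse rotation x = S1·H·S2·y: same steps, sign masks in reverse order.
fn rotate_inv(y: &mut [f32], s1: &[f32], s2: &[f32]) {
    apply_signs(y, s2);
    wht(y);
    apply_signs(y, s1);
}

/// Plain dot product, used to check that rotation preserves <q, k>.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
```

Rotating both q and k with `rotate` leaves `dot(q, k)` unchanged up to rounding, and `rotate_inv` (not a second `rotate`) recovers the original vector, which is why a separate inverse kernel is needed once S1 and S2 differ.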

=== Why matched-norm L2 ===

The amax-derived per-group scale (upstream Atlas) clips outliers to the
codebook range but discards information about reconstruction error.
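Matched-norm replaces scale = TURBO*_MAX / amax with
scale = ||original_group|| / ||reconstructed_group||, which restores
each group's L2 norm exactly for a fixed codebook (it coincides with the
L2-minimising scalar <orig, recon> / ||recon||^2 when the reconstruction
is parallel to the original). Bench impact:
turbo3 PPL 0.38 → 0.79 on Qwen3.6-A3B (+105%) at zero kernel cost.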
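
As a rough illustration of the scale computation, the sketch below quantises one group against a fixed codebook and derives the matched-norm scale. It is a simplification: the real write path also handles rotation, packing, and FP8 scale storage, and the function names are hypothetical. Any codebook passed in is the caller's; no Lloyd-Max centroid values are assumed here.

```rust
/// Nearest codebook entry for one value (linear scan; codebooks are tiny).
fn nearest(codebook: &[f32], v: f32) -> f32 {
    let mut best = codebook[0];
    for &c in codebook {
        if (v - c).abs() < (v - best).abs() {
            best = c;
        }
    }
    best
}

fn l2_norm(x: &[f32]) -> f32 {
    x.iter().map(|v| v * v).sum::<f32>().sqrt()
}

/// Quantise a group and return the unscaled reconstruction together with
/// the matched-norm scale ||orig|| / ||recon||.
fn matched_norm_quantise(group: &[f32], codebook: &[f32]) -> (Vec<f32>, f32) {
    let recon: Vec<f32> = group.iter().map(|&v| nearest(codebook, v)).collect();
    let recon_norm = l2_norm(&recon);
    // Guard the all-zero reconstruction so the scale stays finite.
    let scale = if recon_norm > 0.0 {
        l2_norm(group) / recon_norm
    } else {
        1.0
    };
    (recon, scale)
}
```

Multiplying the reconstruction by the returned scale yields a vector whose L2 norm equals that of the original group, which is the property the write path stores per group.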

=== Why sparse V ===

At decode time most attention rows weight far below 1e-3 after softmax.
For those rows the V vector contribution to the output is dominated by
floating-point noise. Skipping the V data + scale loads + dequant on
sub-threshold rows saves the bulk of the V-side bandwidth on long-context
decode (paper: turboquant_plus/docs/papers/sparse-v-dequant.md predicts
30-50% V-dequant savings). Both the remainder loop and the BC=4 batched
path are gated.
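
The gating can be pictured on the host as a weighted sum that skips low-weight rows. This sketch stands in for the decode kernel's inner loop only conceptually: rows are already-dequantised f32 vectors, the threshold is a plain parameter rather than TQ_PLUS_SPARSE_V_THRESHOLD, and the function name is hypothetical.

```rust
/// Weighted sum of V rows that skips rows whose softmax weight is at or
/// below `threshold`. Returns the output and the number of rows touched.
fn sparse_v_accumulate(
    weights: &[f32],
    v_rows: &[Vec<f32>],
    threshold: f32,
) -> (Vec<f32>, usize) {
    let dim = v_rows.first().map_or(0, |r| r.len());
    let mut out = vec![0.0f32; dim];
    let mut touched = 0;
    for (w, row) in weights.iter().zip(v_rows) {
        if *w <= threshold {
            continue; // in the kernel this is where the V load + dequant is skipped
        }
        touched += 1;
        for (o, v) in out.iter_mut().zip(row) {
            *o += *w * *v;
        }
    }
    (out, touched)
}
```

The second return value is a simple way to measure how much V traffic a given threshold would save on a captured attention distribution.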

=== Why fp16 LUT ===

The Lloyd-Max codebook only needs ~10-12 bits of mantissa per entry
(4-16 values total per dtype). Storing __half instead of float
halves the shared-memory footprint of the LUT, which removes one
__syncthreads pressure point and helps occupancy on the decode hot
path. The dequant kernel reads __half2float(lut[idx]) * gs instead of
lut[idx] * gs — same precision, half the shmem.

=== Why Turbo2 ===

Tom's TQ+ work showed that 2-bit Lloyd-Max KV cache is viable for ~30%
extra compression beyond turbo3 at the cost of measurable quality loss
(typically uncoupled from real tasks). The cost/quality curve is steep,
but the bandwidth saving is real and Atlas should expose it. PPL sim
~0.65 (vs 0.84 for turbo3/4) and prefill 8K = 1345 tok/s on Qwen3.6-A3B
(fastest of all dtypes — less data to write per token).

=== Why a real Turbo3 prefill kernel ===

Upstream Atlas routed --kv-cache-dtype turbo3 prefill through the
NVFP4_64 paged-prefill kernel, which reads 4-bit nibbles from cache.
Turbo3 stores 3-bit packed data (8 values per 3 bytes), so the nibble
reads silently sampled wrong indices into the codebook. The kernel ran,
the math was bogus, and the only visible symptom was unexplained PPL
collapse at long context. inferspark_prefill_paged_turbo3 reads 3-bit
correctly.
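
To see why a nibble reader produces wrong indices on 3-bit data, consider the sketch below. The bit order (little-endian, index 0 in the lowest bits) is an assumption made for illustration and may not match the Atlas cache layout; the function names are hypothetical.

```rust
/// Pack eight 3-bit indices (0..=7) into three bytes, lowest index first.
fn pack3(idx: [u8; 8]) -> [u8; 3] {
    let mut bits: u32 = 0;
    for (i, &v) in idx.iter().enumerate() {
        bits |= u32::from(v & 0x7) << (3 * i);
    }
    [bits as u8, (bits >> 8) as u8, (bits >> 16) as u8]
}

/// Inverse of `pack3`.
fn unpack3(bytes: [u8; 3]) -> [u8; 8] {
    let bits =
        u32::from(bytes[0]) | (u32::from(bytes[1]) << 8) | (u32::from(bytes[2]) << 16);
    let mut out = [0u8; 8];
    for (i, o) in out.iter_mut().enumerate() {
        *o = ((bits >> (3 * i)) & 0x7) as u8;
    }
    out
}

/// What a 4-bit reader sees in the same three bytes: six nibbles that
/// straddle the 3-bit boundaries instead of eight indices.
fn read_nibbles(bytes: [u8; 3]) -> [u8; 6] {
    let mut out = [0u8; 6];
    for (i, o) in out.iter_mut().enumerate() {
        *o = (bytes[i / 2] >> (4 * (i % 2))) & 0xF;
    }
    out
}
```

Under this layout the indices 0 through 7 pack to the bytes 0x88, 0xC6, 0xFA; a nibble reader recovers 8, 8, 6, 12, 10, 15 from them, none of which is a correct codebook lookup, yet nothing crashes.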

=== Why Bf16K + Turbo3V ===

asymmetric-kv-compression.md in the TQ+ paper set documents that K is
bandwidth-critical (every decode token reads ALL K rows) while V tolerates
harder quant (only contributes proportionally to softmax mass). The
"safer-asym" direction keeps K at full bf16 precision and pushes V to
3-bit Lloyd-Max. The combined write/decode/prefill kernels avoid two
separate launches and let the K and V pools use independent strides.
prefill_paged_compute_asym.cuh is the template that lets future
asymmetric combos (Bf16K+Turbo4V, Fp8K+Turbo3V, …) reuse the same FA
pipeline by supplying different LOAD_K_TILE / LOAD_V_TILE macros.

=== Why InnerQ ===

<Q/s, s·K> = <Q, K> is the identity that lets InnerQ shift variance
between Q and K without changing attention scores. After WHT decorrelates
the per-channel marginals, per-channel scales accumulated over a
calibration window can be applied (Q × scale_inv before WHT, K × scale
after WHT) to flatten any residual per-channel variance imbalance. The
device state + apply kernels are vendored here; the host-side calibration
driver lands with the Rust dispatch commit.
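
The identity is easy to check numerically. The sketch below applies a per-channel scale to K and its inverse to Q; it leaves out the WHT and the calibration that produces the scales, and the function names are hypothetical.

```rust
/// Returns (Q / s, K * s) channel by channel.
fn innerq_apply(q: &[f32], k: &[f32], scale: &[f32]) -> (Vec<f32>, Vec<f32>) {
    let q_scaled: Vec<f32> = q.iter().zip(scale).map(|(x, s)| x / s).collect();
    let k_scaled: Vec<f32> = k.iter().zip(scale).map(|(x, s)| x * s).collect();
    (q_scaled, k_scaled)
}

fn inner(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
```

For any non-zero scales, `inner(q_scaled, k_scaled)` equals `inner(q, k)` up to rounding, so the attention scores are unchanged while the per-channel variance of K is reshaped before quantisation.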

=== References ===

Prior-art chain documented in CITATIONS.md (vendored alongside this work
since the kernels carry the canonical algorithm forward):
  (1) Google ICLR 2026 TurboQuant (arXiv:2504.19874)
  (2) TheTom/llama-cpp-turboquant (first llama.cpp TurboQuant fork;
      source of the seed=42 sign tables)
  (3) Tom Turney TurboQuant+ paper set (15 papers under
      turboquant-tinygrad-bridge/turboquant_plus/docs/papers/)

Rust side of the TurboQuant+ integration that pairs with the prior
kernel-level commit. Wires every new kernel into the spark-model
dispatch layer, expands KvCacheDtype into 9 asymmetric variants, adds a
per-side K/V cache pool (so K-bf16 + V-turbo3 etc. can allocate at
independent strides), and lands the InnerQ host driver + weight
pre-rotation helper.

=== KvCacheDtype expansion ===

9 new asymmetric variants in crates/spark-runtime/src/kv_cache.rs:

  // Both sides compressed (different turbo levels)
  Turbo4KTurbo3V, Turbo4KTurbo8V, Turbo3KTurbo8V

  // "Safer asym" — K kept at baseline precision, V compressed
  Bf16KTurbo4V, Bf16KTurbo3V, Bf16KTurbo2V
  Fp8KTurbo4V,  Fp8KTurbo3V,  Fp8KTurbo2V

`kv_pair() -> (K_dtype, V_dtype)`, `is_asymmetric() -> bool`, and the
short alias CLI parsing (`turbo4k3v`, `bf16k_turbo3v`, ...) all extend
naturally.
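
A reduced sketch of how these three pieces fit together, using a seven-variant subset of the enum. The method bodies and the error type are assumptions for illustration, not the code in kv_cache.rs; only the variant names and the two aliases shown come from this commit message.

```rust
use std::str::FromStr;

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum KvCacheDtype {
    Bf16,
    Fp8,
    Turbo3,
    Turbo4,
    Bf16KTurbo3V,
    Fp8KTurbo4V,
    Turbo4KTurbo3V,
}

impl KvCacheDtype {
    /// (K-side dtype, V-side dtype); symmetric variants return themselves twice.
    fn kv_pair(self) -> (KvCacheDtype, KvCacheDtype) {
        use KvCacheDtype::*;
        match self {
            Bf16KTurbo3V => (Bf16, Turbo3),
            Fp8KTurbo4V => (Fp8, Turbo4),
            Turbo4KTurbo3V => (Turbo4, Turbo3),
            Bf16 | Fp8 | Turbo3 | Turbo4 => (self, self),
        }
    }

    fn is_asymmetric(self) -> bool {
        let (k, v) = self.kv_pair();
        k != v
    }
}

impl FromStr for KvCacheDtype {
    type Err = String;

    /// Accepts the canonical name and, where one exists, the short alias.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "bf16" => Ok(Self::Bf16),
            "fp8" => Ok(Self::Fp8),
            "turbo3" => Ok(Self::Turbo3),
            "turbo4" => Ok(Self::Turbo4),
            "bf16k_turbo3v" => Ok(Self::Bf16KTurbo3V),
            "fp8k_turbo4v" => Ok(Self::Fp8KTurbo4V),
            "turbo4k_turbo3v" | "turbo4k3v" => Ok(Self::Turbo4KTurbo3V),
            other => Err(format!("unknown kv cache dtype: {}", other)),
        }
    }
}
```

Deriving `is_asymmetric` from `kv_pair` keeps the two from drifting apart when a variant is added.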

=== Per-side cache pool ===

Previously KvCacheConfig assumed one block stride for the K + V pool
together. Asymmetric storage breaks that:

  Bf16KTurbo3V at (block_size=16, num_kv_heads=2, head_dim=128):
    K block = 16 × 2 × 128 × 2 (bf16) = 8192 bytes
    V block = 16 × 2 × 128 × 3/8 + scales = 1536 + 256 = 1792 bytes

paged_impl.rs now allocates K and V pools at distinct sizes per layer
via new `k_block_stride_bytes_for_layer` /
`v_block_stride_bytes_for_layer` APIs. `block_bytes_kv_all_layers`
sums K + V separately so asym dtypes don't double-count the bigger
side. `block_bytes_dims` routes asym variants to their K-side existing
case (Bf16KTurbo3V → Bf16 size) which is the load-bearing identity for
the K pool allocation.
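
The Bf16KTurbo3V arithmetic above can be written out as two small functions. This is a sketch of the sizing rule only, assuming one 1-byte FP8 scale per 16-element group as in the worked example; the function names are hypothetical and are not the `*_block_stride_bytes_for_layer` APIs themselves.

```rust
const GROUP_SIZE: usize = 16;

/// Bytes for one bf16 K block: two bytes per element.
fn k_block_bytes_bf16(block_size: usize, num_kv_heads: usize, head_dim: usize) -> usize {
    block_size * num_kv_heads * head_dim * 2
}

/// Bytes for one turbo3 V block: 3-bit payload plus one 1-byte FP8 scale
/// per group of GROUP_SIZE elements.
fn v_block_bytes_turbo3(block_size: usize, num_kv_heads: usize, head_dim: usize) -> usize {
    let elems = block_size * num_kv_heads * head_dim;
    let payload = elems * 3 / 8;
    let scales = elems / GROUP_SIZE;
    payload + scales
}
```

At (16, 2, 128) these return 8192 and 1792 bytes, so a kernel that writes V at the K-side size overruns the V block by more than 4×.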

=== Dispatch wiring ===

Touches every site that has to know about a KvCacheDtype variant:

  write_kv_cache.rs        — turbo write w/ WHT bookend, asym write,
                             InnerQ K-side apply
  run_paged_decode.rs      — turbo decode dispatch incl asym
  prefill/paged_attn.rs    — turbo prefill incl real Turbo3 +
                             Turbo2 + Bf16K+Turbo3V kernels
  decode/attention_forward.rs — V-type-aware iWHT guard (the
                             post-attention un-rotation only needs to
                             fire when the V side actually carries a
                             WHT-rotated payload), InnerQ Q-side apply
                             at decode
  init.rs                  — load all new kernel handles
  types.rs                 — new KernelHandle fields for inv WHT,
                             InnerQ apply, Turbo2/Turbo3/Bf16KTurbo3V
                             prefill kernels
  mod.rs                   — pub re-exports for InnerQDriver +
                             poll_innerq

=== InnerQ host driver ===

`crates/spark-model/src/layers/qwen3_attention/innerq_driver.rs` —
unsafe extern bindings to the kernel-side `turbo_innerq_start_calibration`
and `turbo_innerq_finalize` controllers. `InnerQDriver::from_env()` reads
`TURBO_INNERQ=N` (target tokens) and `TURBO_INNERQ_STRENGTH=f` (0..1,
default 0.5) and falls back to disabled when the env var is absent.
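
One way to structure that parsing so it can be unit-tested without touching the process environment is sketched below. The struct, the function name, and the handling of malformed or out-of-range values are assumptions; only the two variable names, the 0..1 range, and the 0.5 default come from the description above.

```rust
#[derive(Debug, PartialEq)]
struct InnerQConfig {
    target_tokens: u64,
    strength: f32,
}

/// `target` and `strength` are the raw values of TURBO_INNERQ and
/// TURBO_INNERQ_STRENGTH, or None when the variable is unset.
/// Returns None (InnerQ disabled) when the target is absent or not a number.
fn parse_innerq(target: Option<&str>, strength: Option<&str>) -> Option<InnerQConfig> {
    let target_tokens = target?.trim().parse::<u64>().ok()?;
    let strength = strength
        .and_then(|s| s.trim().parse::<f32>().ok())
        .unwrap_or(0.5) // documented default
        .clamp(0.0, 1.0); // assumption: out-of-range values are clamped
    Some(InnerQConfig { target_tokens, strength })
}
```

A caller would pass `std::env::var("TURBO_INNERQ").ok().as_deref()` and the equivalent for the strength variable.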

Startup hook in `crates/spark-server/src/main_modules/serve.rs` calls
`driver.start()` once at boot. Periodic finalize in 3 scheduler
chunked-prefill hot paths (`run_standard.rs`, `run_batched_prefill.rs`,
`run_batched_mixed.rs`) calls `maybe_finalize(128)` which is idempotent
— the kernel-side `turbo_innerq_finalize` checks
`d_innerq_count >= target` and only fires once.
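
The once-only behaviour can be modelled on the host with two atomics. This is a sketch of the idempotence pattern, not the device-side controller: the struct is hypothetical, and unlike the real `maybe_finalize(128)` this version takes no argument.

```rust
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};

struct Calibration {
    target: u64,
    count: AtomicU64,
    finalized: AtomicBool,
}

impl Calibration {
    fn new(target: u64) -> Self {
        Calibration {
            target,
            count: AtomicU64::new(0),
            finalized: AtomicBool::new(false),
        }
    }

    /// Record tokens seen during the calibration window.
    fn record(&self, tokens: u64) {
        self.count.fetch_add(tokens, Ordering::AcqRel);
    }

    /// Safe to call from every scheduler hot path: returns true only on
    /// the single call that actually performs the finalize.
    fn maybe_finalize(&self) -> bool {
        if self.count.load(Ordering::Acquire) < self.target {
            return false;
        }
        // `swap` returns the previous value, so only the first caller
        // past the target observes `false` here.
        !self.finalized.swap(true, Ordering::AcqRel)
    }
}
```

Because the check and the flag flip are both cheap, calling this from all three prefill paths costs little once calibration has completed.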

=== Weight pre-rotation helper ===

`crates/spark-model/src/weight_loader/qwen35/load_layers/tq_plus_weight_rotation.rs`
exposes `apply_canonical_rotation_inplace(gpu, weight, outer, n_heads,
head_dim, stream)` which reuses the `wht_bf16_inplace` kernel with
`grid = (outer × n_heads, 1, 1)` so every contiguous `head_dim` chunk of
a `[outer, n_heads*head_dim]` weight matrix is rotated independently.

Wired into `attention_arms.rs::load_bf16_then_nvfp4` between sharding
and NVFP4 quantization for Q/K/V projections (O skipped — input-side
rotation needs a transpose). Gated on `TQ_PLUS_WEIGHT_ROTATION=1`. When
active the runtime `wht_bf16_inplace` launches in `write_kv_cache.rs`
become no-ops (already gated by the same env var in that file).

=== boundary_dtype on build_layer_kv_dtypes (LA-V7 Mode 7) ===

`crates/spark-server/src/main_modules/{kv_dtypes.rs,
serve_phases/kv_cache.rs}` + `crates/spark-model/src/factory/build.rs` —
threads a `boundary_dtype` parameter through layer-dtype construction
so the LA-V7 paper's "boundary V layers stay higher precision" policy
has a substrate to land on later. Currently boundary_dtype defaults to
the same dtype as the rest, preserving upstream behavior.

=== Tests ===

`crates/spark-runtime/src/kv_cache/tests_tq_plus.rs` (new, split out so
the upstream `tests.rs` stays under the 500-LoC cap):

  - display_fromstr_roundtrip_all_variants
  - fromstr_short_alias_parses_to_canonical_variant
  - asym_is_asymmetric_and_pair_differs
  - sym_is_symmetric_and_pair_self
  - asym_kv_pair_components_are_symmetric
  - block_bytes_turbo3 / turbo2 / turbo8
  - asym_bf16k_turbo3v_uses_separate_strides

`ALL_VARIANTS` / `ASYM_VARIANTS` / `SYM_VARIANTS` arrays force new
KvCacheDtype variants to be added to tests when added to the enum,
so the next time a TQ+ variant lands without full
FromStr/Display/kv_pair wiring the missing piece fails CI before merge.

`tests/test_kv_dtype_smoke.py` iterates every public dtype: starts a
fresh container with that `--kv-cache-dtype`, waits for /v1/models
readiness, sends one chat completion at max_tokens=64, asserts
non-empty completion. Distinguishes load-time SKIP (incompatible
weights) from runtime FAIL (kernel dispatch crash). This is the test
that would have caught the original Turbo2 bug — the FP8 catch-all
arm silently routed Turbo2 to the wrong kernel ABI and the only
visible signal was `CUDA_ERROR_INVALID_ADDRESS_SPACE` at first
full-attention layer.

=== Bench results ===

Qwen3.6-35B-FP8 on GB10 / M5 Max, 5-run median:

  dtype       ppl_sim  dec_short  pre_8K     dec_after_8K
  turbo2      0.65     71.7 tps   1345 tps   15.0 tps
  turbo3      0.84     72.0 tps   1215 tps   44.4 tps
  turbo4      0.84     72.2 tps   1240 tps   44.8 tps

Turbo3/4 PPL came from 0.38 (pre-signs) to 0.84 (+121%) via the
canonical Rademacher signs landed in the kernel commit. Turbo2's
prefill is the FASTEST of the three at 1345 tps (less data to write
per token) at the expected 2-bit quality penalty.

TurboQuant+ added 3 new KernelHandle fields (wht_bf16_k_inv,
innerq_apply_q_k, innerq_apply_k_k), 2 new dtype routing arms
(Turbo2 + Bf16KTurbo3V), and asym-variant expansions on 4 existing
match arms (turbo3 + asym pair members route to the same kernel
module). The dispatcher is a single top-to-bottom struct constructor
that reads alongside types.rs — splitting fragments the kernel-loading
sequence for no clarity gain.

524 LoC now; tracking a refactor to push InnerQ + asym handle loading
into a sibling module once the variant set stabilises.

README Citations section gets a TurboQuant entry pointing at the
paper + the TheTom/turboquant_plus umbrella research repo + the
TheTom/llama-cpp-turboquant engine reference. docs/turboquant-plus.md
is the long-form guide: file-by-file inventory, per-feature
rationale, before/after bench tables, and reproduction commands
(Docker build, container run per dtype, bench harness invocation,
unit-test invocation, optional InnerQ + weight-rotation env knobs).

Aimed at a skeptical reviewer who wants to A/B every claim from a
fresh clone. Numbers pinned to Qwen3.6-35B-FP8 on GB10 with 5-run
medians from tests/atlas_matrix_no_hp.py.

Paper metadata verified via arxiv API:
  Zandieh, Daliri, Hadian, Mirrokni.
  "TurboQuant: Online Vector Quantization with Near-optimal
   Distortion Rate"
  arXiv:2504.19874, April 2025, 25 pages.

The rotation is restated as a Randomized Hadamard Transform (random
sign mask × WHT) per the actual paper. CITATIONS.md prior-art chain
covers: Google paper → TheTom/turboquant_plus umbrella → llama.cpp
engine reference → this Atlas port.

8 of the 9 asymmetric KvCacheDtype variants in the enum lack a proper
combined (K-side ABI, V-side ABI) write+decode kernel stack. Without
combined kernels they fall through the spark-model dispatcher to the
K-side symmetric reshape and decode kernels, which treat V as the
K-side dtype. The per-side cache pool refactor sizes V using the
smaller V-side turbo dtype, so the K-side symmetric kernel either
writes V out-of-bounds in the V block or reads garbage at decode.

Caught empirically by the new bench harness on Qwen3.6-35B-FP8:
fp8k_turbo3v PPL sim 0.6263 vs fp8 baseline 0.8485. The diagnostic
bench with `TQ_PLUS_WEIGHT_ROTATION=1` (which fully bypasses the
runtime WHT round-trip and skips the V-side WHT entirely) recovers
PPL sim to 0.8485, confirming the bug is in the dispatcher
mismatch, not the V quantization or the rotation math itself.
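The size mismatch behind this failure can be sketched with simple arithmetic. The bits-per-element table and the block shape below are illustrative assumptions (real block layouts also carry scales and norms), and the helper names are hypothetical.

```python
# Hypothetical payload bits per element; real layouts also store scales/norms.
BITS = {"bf16": 16, "fp8": 8, "turbo4": 4, "turbo3": 3, "turbo2": 2}


def block_bytes(dtype, block_tokens=16, kv_heads=4, head_dim=128):
    """Payload bytes for one cache block of the given dtype (assumed shape)."""
    bits = BITS[dtype] * block_tokens * kv_heads * head_dim
    return (bits + 7) // 8


def v_write_overrun(k_dtype, v_dtype, **shape):
    """Bytes a K-side symmetric kernel would write past a V block sized for v_dtype."""
    return max(0, block_bytes(k_dtype, **shape) - block_bytes(v_dtype, **shape))


# V pool sized for turbo3, but written with the fp8 (K-side) layout.
overrun = v_write_overrun("fp8", "turbo3")
print(f"fp8 write into a turbo3-sized V block overruns by {overrun} bytes")
```

Under these assumptions, any combination where the V-side dtype is narrower than the K-side dtype overruns the V block, which is the case the per-side pool refactor exposed.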

To stop the silent-wrong-output path: `anyhow::bail!()` at layer
construction for the 8 un-implemented variants, with a clear error
message pointing at Avarok-Cybersecurity#91 and naming the only complete asym combo
today (Bf16KTurbo3V).

Variants gated:
  - Bf16KTurbo4V, Bf16KTurbo2V
  - Fp8KTurbo4V,  Fp8KTurbo3V, Fp8KTurbo2V
  - Turbo4KTurbo3V, Turbo4KTurbo8V, Turbo3KTurbo8V

The enum + Display/FromStr + kv_pair + is_asymmetric + per-side
block-byte APIs + the KERNEL.toml registrations + the
prefill_paged_compute_asym template are all preserved so a follow-up
PR adding asym kernels only needs to write the kernel + lift the
bail.

docs/turboquant-plus.md asym section rewritten honestly: removes the
broken bench numbers (now unreachable at runtime), references the
historical data in `~/atlas_tqplus_bench/results/bench_C_asym.json`
+ `bench_E_asym_diag.json` for reviewer audit, points at Avarok-Cybersecurity#91.

Tracking: Avarok-Cybersecurity#91

Adds two TurboQuant+ safer-asymmetric KV cache variants alongside the
existing Bf16KTurbo3V combo: K kept at full BF16 precision, V compressed
to 4-bit (turbo4) or 2-bit (turbo2) Lloyd-Max with per-group FP8 scale +
matched-norm L2 correction on the write path. Bandwidth-optimised for
long-context decode where V traffic dominates.
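A rough sketch of the V-side write path described above: a Lloyd-Max codebook, a per-group scale, and a matched-norm L2 correction folded into that scale. Everything here is an assumption for illustration: the codebook is fitted to Gaussian samples with a plain Lloyd iteration, the scale stays a Python float instead of FP8, and the function names are hypothetical.

```python
import math
import random


def nearest(value, codebook):
    """Index of the codebook entry closest to value."""
    return min(range(len(codebook)), key=lambda j: abs(value - codebook[j]))


def lloyd_max(samples, levels, iters=30):
    """1-D Lloyd iteration: nearest-centroid assignment, then centroid update."""
    data = sorted(samples)
    # Start from evenly spaced quantiles of the data.
    centroids = [data[(2 * i + 1) * len(data) // (2 * levels)] for i in range(levels)]
    for _ in range(iters):
        buckets = [[] for _ in range(levels)]
        for v in data:
            buckets[nearest(v, centroids)].append(v)
        centroids = [sum(b) / len(b) if b else c for b, c in zip(buckets, centroids)]
    return sorted(centroids)


def l2_norm(values):
    return math.sqrt(sum(v * v for v in values))


def quantize_group(group, codebook):
    """Return (codes, scale); the scale is corrected so the norms match."""
    rms = l2_norm(group) / math.sqrt(len(group)) or 1.0
    codes = [nearest(v / rms, codebook) for v in group]
    recon_norm = l2_norm([codebook[c] * rms for c in codes])
    # Matched-norm correction: rescale so ||dequant|| equals ||original||.
    scale = rms * l2_norm(group) / recon_norm if recon_norm > 0.0 else rms
    return codes, scale


def dequantize_group(codes, scale, codebook):
    return [codebook[c] * scale for c in codes]


rng = random.Random(0)
codebook = lloyd_max([rng.gauss(0.0, 1.0) for _ in range(2000)], levels=4)  # 2-bit
group = [rng.gauss(0.0, 3.0) for _ in range(32)]
codes, scale = quantize_group(group, codebook)
recon = dequantize_group(codes, scale, codebook)
```

Folding the correction into the stored scale keeps the read path unchanged: dequantization is still one codebook lookup and one multiply per element.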

New kernels (all HDIM=128, copy of bf16k_turbo3v with V dequant swapped):
  - kernels/gb10/common/paged_decode_attn_bf16k_turbo4v_128.cu
  - kernels/gb10/common/paged_decode_attn_bf16k_turbo2v_128.cu
  - kernels/gb10/common/inferspark_prefill_paged_bf16k_turbo4v.cu
  - kernels/gb10/common/inferspark_prefill_paged_bf16k_turbo2v.cu
  - reshape_and_cache_flash_bf16k_turbo4v + _turbo2v functions appended
    to kernels/gb10/common/reshape_and_cache_turbo.cu

Module registration:
  - kernels/gb10/common/KERNEL.toml: 4 new modules
  - kernels/gb10/*/nvfp4/KERNEL.toml: registered across all 13 model targets

Rust dispatch wiring:
  - qwen3_attention/init.rs: lifted Bf16KTurbo4V+2V from bail!() list;
    added kernel-module dispatch arms + prefill kernel handle loaders;
    extended splitk None-arm to cover new variants
  - qwen3_attention/types.rs: 2 new KernelHandle fields
  - decode/write_kv_cache.rs: 2 new arms with V-side WHT bookend + call
    into ops::reshape_and_cache_bf16k_turbo{4,2}v (was incorrectly falling
    through to bf16 arm which would mis-route V)
  - decode/run_paged_decode.rs: 2 new arms calling ops::paged_decode_attn_*
  - prefill/paged_attn.rs: 2 new tuple arms calling ops::prefill_attention_*
  - ops/kv_cache.rs: 4 new wrappers (reshape + decode for both variants)
  - ops/prefill_attn_main_b.rs: 2 new prefill wrappers

Build: cargo clippy clean on spark-model + spark-runtime; docker image
rebuilds with 105 kernels per target, no compile errors.

Smoke test on /home/pidtom/models/qwen3.6-35b-fp8 (FP8 attn weights,
head_dim=256): both new variants dispatch through the full pipeline and
fail identically to the reference Bf16KTurbo3V variant with
"Module paged_decode_bf16k_turbo{2,4}v not loaded" — expected because
only HDIM=128 decode kernel exists and the model has hd=256. No crash,
no panic — controlled failure path. 128-only kernel will work on bf16-
attention models with HDIM=128.

Lands 3 new both-sides-quantized asym KV-cache variants:
  - turbo4k_turbo3v: K 4-bit Lloyd-Max + V 3-bit Lloyd-Max
  - turbo4k_turbo8v: K 4-bit Lloyd-Max + V FP8 E4M3 (+bf16 group scale)
  - turbo3k_turbo8v: K 3-bit Lloyd-Max + V FP8 E4M3 (+bf16 group scale)

This commit also folds in the Fp8K + Turbo{2,3,4}V variants that the parallel
worker drafted in the same working tree. They are committed together because
both groups touch the same shared dispatch sites (write_kv_cache,
run_paged_decode, prefill/paged_attn, init.rs, ops.rs, KERNEL.toml), and
splitting the changes mid-file risked losing work.

Kernels (all HDIM=128 decode + paged-prefill + reshape-and-cache):
  Decode:
    paged_decode_attn_turbo4k_turbo3v_128.cu
    paged_decode_attn_turbo4k_turbo8v_128.cu
    paged_decode_attn_turbo3k_turbo8v_128.cu
    paged_decode_attn_fp8k_turbo{2,3,4}v_128.cu (+ generic siblings)
  Prefill (uses prefill_paged_compute_asym.cuh template w/ LOAD_K/V_TILE):
    inferspark_prefill_paged_turbo4k_turbo3v.cu
    inferspark_prefill_paged_turbo4k_turbo8v.cu
    inferspark_prefill_paged_turbo3k_turbo8v.cu
    inferspark_prefill_paged_fp8k_turbo{2,3,4}v.cu
  Write (in reshape_and_cache_turbo.cu):
    reshape_and_cache_flash_turbo4k_turbo3v
    reshape_and_cache_flash_turbo4k_turbo8v
    reshape_and_cache_flash_turbo3k_turbo8v
    reshape_and_cache_flash_fp8k_turbo{2,3,4}v
  Registrations: kernels/gb10/common/KERNEL.toml + all 13 model-specific
    nvfp4/KERNEL.toml files.

Rust dispatch wiring:
  - init.rs: remove asym variants from bail-list; new kernel-module
    routing arms; new prefill kernel-handle loaders for each combo.
  - types.rs: 3 new turbok + 3 new fp8k KernelHandle fields.
  - decode/write_kv_cache.rs: new arms — turbok arm applies WHT to BOTH
    K and V (mirrors sym turbo path) + InnerQ K apply when active;
    fp8k arm applies V-side WHT only (K rotation is the FP8 scale).
  - decode/run_paged_decode.rs: per-side (block_stride, data_section)
    pairs threaded through to both K and V pools.
  - prefill/paged_attn.rs: tuple arms route to new helper modules.
  - prefill/paged_attn_turbok.rs (new): turbok dispatch helper.
  - prefill/paged_attn_fp8k.rs (new): fp8k dispatch helper.
  - ops/kv_cache_turbok.rs (new): 3 write + 3 decode wrappers.
  - ops/prefill_attn_turbok.rs (new): 3 prefill wrappers.
  - ops/kv_cache_fp8k.rs (new) + ops/prefill_attn_fp8k.rs (new): same
    shape for fp8k variants.
  - ops.rs: module declarations + re-exports.

Bug fix while wiring: removed an accidental Turbo3KTurbo8V entry from
the symmetric Turbo4|Turbo3|Turbo2 arm in run_paged_decode.rs. It
would have silently dispatched to the sym kernel ABI with asym K/V pool
strides, the exact dispatch-arm-fall-through bug class flagged in
feedback_atlas_dispatch_match_arm_audit.

Smoke test (tests/test_kv_dtype_smoke.py) PASS for all 3 turbok
variants on Qwen3.6-35B-FP8 with --kv-high-precision-layers 0.
PPL-sim on the Manhattan-Project WikiText prompt:
  turbo4k_turbo3v: 0.8485  (above the 0.80 target; same range as sym turbo)
  turbo4k_turbo8v: 0.7778  (close to target)
  turbo3k_turbo8v: 0.6154  (coherent text; 3-bit K is intrinsically lossier)
Decode tok/s ~72 short / ~41 at 8K context for all three.

Constraint compliance:
  - All .rs files under the 500 LoC cap (init.rs at 720 LoC is already
    allow-listed by commit 423dff5).
  - No bf16k_* / fp8k_* / turbok_* upstream-ref comments inline.
  - feature/tq-plus-clean branch only; not pushed.

The KvCacheDtype → (reshape_mod, reshape_fn, decode_mod, decode_fn)
match was inline in Qwen3AttentionLayer::new_with_gating, which made
it untestable without a real GPU container. Extracted to
init_kernel_dispatch.rs as the pure function
kernel_modules_for_dtype() and added 4 unit tests:

  every_variant_returns_non_empty_modules — walks all 16 KvCacheDtype
    variants × {hd=128, hd=256} and asserts each tuple is non-empty.

  each_asym_variant_routes_to_dedicated_kernel — for the 9 asym
    variants, asserts the reshape_fn / decode_mod / decode_fn names
    contain the asym shape token (bf16k_turbo3v, fp8k_turbo4v, etc.).
    This is the test that would have caught the silent fall-through
    where 8 of 9 asym variants were routed through the K-side
    symmetric kernel + mis-sized V pool. A new asym variant added
    without dedicated kernels fails this test in CI before merge.

  sym_variants_route_to_sym_kernels — asserts no symmetric dtype
    accidentally routes to an asym kernel (catches the reverse bug
    pattern — the Turbo3KTurbo8V mis-placement in the symmetric arm
    that the turbo*k_* subagent caught manually).

  hd_gate_picks_128_or_full_kernel — asserts the head_dim ≤ 128
    branch picks the _128-suffixed kernel and head_dim > 128 picks
    the full one.

init.rs drops from 720 → 532 LoC (still allow-listed but materially
smaller). The replacement is two lines:

    let (reshape_mod, reshape_fn, decode_mod, decode_fn) =
        super::init_kernel_dispatch::kernel_modules_for_dtype(kv_dtype, config.head_dim);

The dispatch table itself is now exhaustive on KvCacheDtype (no `_`
catch-all), so a new enum variant added without routing fails to
compile instead of failing silently.
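The routing checks described in this commit can be sketched in Python. This is an analogue under stated assumptions, not the Rust code: the enum is a hypothetical subset of KvCacheDtype, the kernel names are made up to follow the naming pattern in this PR, and Python has no compile-time exhaustiveness, so an unrouted variant surfaces as a KeyError when the check runs.

```python
from enum import Enum


class KvDtype(Enum):
    """Hypothetical subset of the KV cache dtypes, keyed by shape token."""
    BF16 = "bf16"
    FP8 = "fp8"
    TURBO3 = "turbo3"
    BF16K_TURBO3V = "bf16k_turbo3v"
    FP8K_TURBO4V = "fp8k_turbo4v"


ASYM = {KvDtype.BF16K_TURBO3V, KvDtype.FP8K_TURBO4V}

# No default entry: a variant missing from this table raises KeyError.
TABLE = {
    KvDtype.BF16: ("reshape_and_cache_flash", "paged_decode_attn"),
    KvDtype.FP8: ("reshape_and_cache_flash_fp8", "paged_decode_attn_fp8"),
    KvDtype.TURBO3: ("reshape_and_cache_flash_turbo3", "paged_decode_attn_turbo3"),
    KvDtype.BF16K_TURBO3V: (
        "reshape_and_cache_flash_bf16k_turbo3v",
        "paged_decode_attn_bf16k_turbo3v",
    ),
    KvDtype.FP8K_TURBO4V: (
        "reshape_and_cache_flash_fp8k_turbo4v",
        "paged_decode_attn_fp8k_turbo4v",
    ),
}


def kernel_modules_for_dtype(dtype, head_dim):
    """Pure lookup: (reshape_fn, decode_mod), with a _128 suffix for small heads."""
    reshape_fn, decode_mod = TABLE[dtype]
    suffix = "_128" if head_dim <= 128 else ""
    return reshape_fn, decode_mod + suffix


def check_routing():
    """Walk every variant and head_dim; assert the three routing properties."""
    for dtype in KvDtype:
        for head_dim in (128, 256):
            reshape_fn, decode_mod = kernel_modules_for_dtype(dtype, head_dim)
            assert reshape_fn and decode_mod
            if dtype in ASYM:
                # Asym variants must land on kernels named after their K/V pair.
                assert dtype.value in reshape_fn and dtype.value in decode_mod
            else:
                # Symmetric variants must never land on an asym kernel.
                assert not any(a.value in decode_mod for a in ASYM)
            assert decode_mod.endswith("_128") == (head_dim <= 128)


check_routing()
```

Because the lookup is a pure function of the dtype and head_dim, the whole table can be checked without a GPU, which is the property the extraction was after.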

Tracking: issue Avarok-Cybersecurity#91 + feedback_atlas_dispatch_match_arm_audit.md.

The bail-list at the top of new_with_gating was added in 5b7d3d3 to
gate 8 asym variants without combined kernels. Subsequent commits
landed kernels for all 8 (Bf16K+Turbo{2,4}V in 48305f3, Fp8K+Turbo*
+ Turbo*K+Turbo*V in 4c5af31) — the bail body emptied as variants
graduated, leaving only a tombstone comment.

The class of bug the bail guarded against (silent fall-through of
new asym variants to K-side sym kernels with mis-sized V pool) is
now caught by:

  - kernel_modules_for_dtype is exhaustive on KvCacheDtype with no
    `_` arm; adding a variant without routing fails to compile
  - init_kernel_dispatch::tests::each_asym_variant_routes_to_dedicated_kernel
    asserts every asym variant ends up at modules containing its
    dtype-pair shape — a fall-through to a K-side sym kernel name
    fails the substring check in CI
  - tests/test_kv_dtype_smoke.py runs each dtype end-to-end on a
    real model

Three layers, all stronger than a runtime bail. Cleanup safe.

…nch G (Qwen3-VL NVFP4)

Replace the historical "only Bf16KTurbo3V is wired" section + the
stale 4-row asym table with the full 9-variant matrix split across
two model rows:

  - Bench F (Qwen3.6-35B-FP8, head_dim=256): 6 fp8k_* and turbo*k_*
    variants run end-to-end. fp8k_turbo3v + turbo4k_turbo3v both
    hit 0.8485 PPL sim (bit-identical to the fp8 baseline). bf16k_*
    correctly load-fails on this FP8-attn model.

  - Bench G (Qwen3-VL-30B-A3B-NVFP4, head_dim=128): 3 bf16k_*
    variants run end-to-end. bf16k_turbo3v + bf16k_turbo4v are
    bit-identical to the sym bf16 baseline (PPL sim 0.2929).
    bf16k_turbo2v drops 0.01 PPL — matches the expected 2-bit
    turbo2 V quality profile.

Added asymmetric dispatch correctness section documenting the three
layers of guarding (compile-time exhaustive match, unit-test
routing check, end-to-end smoke).
@TheTom force-pushed the feature/tq-plus-clean branch 2 times, most recently from 557eeed to 51b31c0 on May 23, 2026 19:45
@TheTom
Author

TheTom commented May 23, 2026

I have read the CLA Document and I hereby sign the CLA

@TheTom marked this pull request as ready for review May 23, 2026 19:48
@TheTom requested review from AzeezIsh and tbraun96 as code owners May 23, 2026 19:48
TheTom added 4 commits May 23, 2026 19:55
Subagent-written asym dispatch + ops wrappers needed a re-format pass.
Pure whitespace; no functional change.

…ions

- docs: TURBO_INnerQ → TURBO_INNERQ (typos-cli flag)
- file-size-cap: allow-list 5 files pushed over 500 LoC by the 9 asym
  KvCacheDtype variants. Each entry carries the same shape of
  rationale the existing entries do (per-variant wrappers already
  extracted to sibling ops modules; the dispatcher proper is the
  single per-call kernel-selection site).

The TurboQuant+ InnerQ host-side driver in
`spark-model::layers::qwen3_attention::innerq_driver` calls the CUDA
Driver API directly via `atlas_core::registry`, which is itself
gated on the `cuda` feature. Mirror that gate on the module + the
`INNERQ` OnceLock + the two server call sites so the metal-only
build (`--no-default-features --features metal`) compiles on Apple
Silicon without a CUDA toolchain.

Fixes CI `cargo test --features metal (macOS aarch64)` failure:
  error[E0432]: unresolved import `atlas_core::registry`
   --> crates/spark-model/src/layers/qwen3_attention/innerq_driver.rs:29:17

The `poll_innerq` helper in `phase_continue_prefills.rs` becomes a
no-op on non-cuda backends so existing call sites in the three
prefill paths (standard / batched-prefill / batched-mixed) need no
`#[cfg]` sprinkling.

TurboQuant+ added `cuModuleGetGlobal_v2` to the `extern "C"` block
in `atlas-core::registry` (used by the InnerQ driver to resolve
`__device__` symbol pointers). The CI no-GPU `libcuda.so` stub
satisfies every other newly-added cu* symbol the TQ+ branch added
(`cuMemcpyHtoDAsync_v2`, `cuMemcpyDtoHAsync_v2`,
`cuStreamSynchronize`) but missed this one, so the workspace test
linker fails:

  rust-lld: error: undefined symbol: cuModuleGetGlobal_v2
    >>> referenced by registry.rs:270 (crates/atlas-core/src/registry.rs:270)
        atlas_core::registry::AtlasRegistry::device_symbol

Same treatment as the rest of the stub block — returns
CUDA_ERROR_NO_DEVICE (100) so any code path that actually invokes
it fails-fast. GPU tests that exercise the InnerQ driver are
`#[ignore]`-gated and never run on the no-GPU CI runner.
@tbraun96
Contributor

tbraun96 commented May 24, 2026

While investigating per-layer cosine drift (on my local branch) in Qwen3.6-35B-FP8 (full-attention layers regressing 0.006 cos/layer vs HF reference), I traced a real precision bug in the sw_exp softmax helper used by all paged-prefill attention kernels.

Bug location:

  • kernels/gb10/common/prefill_paged_compute.cuh:22-31 (sw_exp)
  • kernels/gb10/common/prefill_paged_compute_512.cuh:38-46 (sw_exp_512, identical body)
  • This PR's new kernels/gb10/common/prefill_paged_compute_asym.cuh:22-31 (carries the same sw_exp verbatim)

The comment claims max err ~1e-4, but the degree-3 Taylor polynomial 1 + 0.693·tf + 0.240·tf² + 0.0555·tf³ is actually off by ~50× that. Numerical check against torch.exp over x ∈ [-20, 0]:

max relative error: 5.10e-3 (~0.5%)
mean relative error: 1.14e-3

The error peaks at tf near 1.0 (where the polynomial gives 1.9755 vs true 2^0.99 = 1.9862, ~0.54% off). For softmax inputs like x = -0.5, the relative error is 0.29%.

Impact on Qwen3.6-FP8: in the paged-prefill softmax (~18920 exp() calls per Q-position per head per full-attn layer), this accumulates to measurable per-layer cosine drift. Replacing sw_exp with __expf (CUDA SFU exp, ~2 ULP) improved early-mid layer cosine by +0.001 to +0.004 vs HF reference. (At deep layers L31-L39, removing sw_exp exposed a separate FP8 KV cache quantization issue that the polynomial was accidentally compensating for — not within this PR's scope but worth flagging.)

Suggested fix (minimal — same SSOT pattern works for sw_exp_512 and the new asym header):

__device__ __forceinline__ float sw_exp(float x) {
#ifdef ATLAS_FAST_SOFTMAX_EXP
    // FA4-style polynomial: ~0.5% max relative error at tf~1
    // Preserved as opt-in for users who measure they need it.
    float t = x * 1.4426950408889634f;
    float ti = floorf(t);
    float tf = t - ti;
    float p = 1.0f + tf * (0.6931471805599453f +
              tf * (0.2402265069591007f +
              tf * 0.05550410866482158f));
    return ldexpf(p, (int)ti);
#else
    return __expf(x);  // ~2 ULP, matches PyTorch reference softmax
#endif
}

Throughput cost on a single 18920-token prefill is ~13% (one extra SFU cycle per exp; ~2.25s slower out of ~17s total on GB10). Acceptable trade for correctness in production; users who measure they need throughput over precision can keep the polynomial via -DATLAS_FAST_SOFTMAX_EXP.

Since this PR's prefill_paged_compute_asym.cuh is a fresh copy of the same header, the asym kernels will inherit the same precision bug if merged as-is. Worth folding the fix in here, especially since the PR's value prop is correctness + dispatcher safety nets.

Numerical verification one-liner (against PyTorch 2.9):

def sw_exp(x):
    import math
    t = x * 1.4426950408889634
    ti = math.floor(t); tf = t - ti
    return (1 + tf*(0.6931471805599453 + tf*(0.2402265069591007 + tf*0.05550410866482158))) * (2.0**ti)
# Compare vs math.exp; max relative err = 5.1e-3 over [-20, 0]
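A self-contained version of that comparison, as a sketch: it sweeps a uniform grid and reports the worst relative error against math.exp. The grid density and the max_relative_error name are assumptions, and the arithmetic is Python double precision rather than the float32 the kernel uses, so the figure is indicative only.

```python
import math


def sw_exp(x):
    """Degree-3 polynomial for 2**frac, scaled by 2**floor (same constants as above)."""
    t = x * 1.4426950408889634  # x * log2(e)
    ti = math.floor(t)
    tf = t - ti
    poly = 1.0 + tf * (
        0.6931471805599453 + tf * (0.2402265069591007 + tf * 0.05550410866482158)
    )
    return math.ldexp(poly, ti)


def max_relative_error(lo=-20.0, hi=0.0, steps=200_000):
    """Worst |sw_exp(x) - exp(x)| / exp(x) over a uniform grid on [lo, hi]."""
    worst = 0.0
    for i in range(steps + 1):
        x = lo + (hi - lo) * i / steps
        exact = math.exp(x)
        worst = max(worst, abs(sw_exp(x) - exact) / exact)
    return worst


print(f"max relative error: {max_relative_error():.2e}")
```

The truncated series always undershoots 2**tf, so the error grows with tf and is largest just below each integer boundary of t.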

@TheTom
Author

TheTom commented May 24, 2026

Thank you Thomas for the concise review comment. I will be working through these this afternoon.

tbraun96 and others added 3 commits May 24, 2026 11:21
Applies the FP16 P×V MMA upgrade + __expf softmax replacement from
Avarok-Cybersecurity#90 (fix/in-think-tool-call-leak) to the asym
prefill kernel that this PR introduced, and pulls the same fix into
the upstream symmetric kernels we carried forward unchanged.

Motivation (per @tbraun96's PR Avarok-Cybersecurity#92 review + Discord context): the
prior `sw_exp` polynomial advertised ~1e-4 max relative error, but
verifies at ~5.6e-3 (~0.5%) against `torch.exp`. Across 18920-token
attention rows × 10 full-attention layers, that compounds to
measurable per-layer cosine drift vs HF reference. The FP16 P×V MMA
upgrade trades ~10% prefill slowdown for ~8× higher mantissa
precision on the softmax probabilities (P), which is the dominant
remaining attention-output drift source on Qwen3.6-35B-A3B-FP8.

Files:
- kernels/gb10/common/prefill_paged_compute.cuh — direct
  cherry-pick of the Phase 2c kernel changes + ATLAS_DISABLE_FP16_PV
  bisect toggle (matches PR Avarok-Cybersecurity#90 byte-for-byte in the fix regions).
- kernels/gb10/common/prefill_paged_compute_512.cuh — same sw_exp
  refactor for the HDIM=512 path.
- kernels/gb10/common/prefill_paged_compute_asym.cuh — TQ+ asym
  fork carries the same precision bug; applied the equivalent fix
  (helper + __half smem_P/P64 + __float2half_rn + .f16.f16 MMA).
  Skipped the ATLAS_DISABLE_FP16_PV debug toggle for now (can add
  later if the team wants to bisect asym paths separately).

Q×K stays BF16 (range matters there); P×V becomes FP16 (precision
matters, range is bounded [0,1] post-softmax). All bf16 stores now
use __float2bfloat16_rn for RNE rounding.
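The precision argument for P can be checked on the host with a short sketch that rounds probabilities to each 16-bit format and compares the relative error. It uses only the standard library: struct's "e" format for FP16 and a manual round-to-nearest-even truncation of float32 bits for BF16. The helper names and the sampled range [1e-3, 1] are assumptions.

```python
import struct


def round_to_fp16(x):
    """Round-trip through IEEE half precision (5 exponent bits, 10 mantissa bits)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def round_to_bf16(x):
    """Round-trip through bfloat16 (8 exponent bits, 7 mantissa bits)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def max_rel_error(round_fn, lo=1e-3, hi=1.0, steps=10_000):
    """Worst relative rounding error over a uniform grid of probabilities."""
    worst = 0.0
    for i in range(steps + 1):
        p = lo + (hi - lo) * i / steps
        worst = max(worst, abs(round_fn(p) - p) / p)
    return worst


fp16_err = max_rel_error(round_to_fp16)
bf16_err = max_rel_error(round_to_bf16)
print(f"fp16: {fp16_err:.2e}  bf16: {bf16_err:.2e}")
```

Three extra mantissa bits give FP16 roughly an 8× tighter rounding bound, and softmax outputs stay inside [0, 1], so BF16's wider exponent range is not needed for P.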

Verified locally: nvcc 13.0 compiles all 114 kernels clean on sm_120
including the TQ+ asym variants (bf16k_turbo3v, fp8k_turbo3v, etc).
@TheTom
Author

TheTom commented May 25, 2026

@tbraun96 cherry-picked your Phase 2c precision fix from #90 onto this branch (commit 97cd300). Three files touched:

  • prefill_paged_compute.cuh: byte-identical to your branch (sw_exp defaults to __expf with ATLAS_FAST_SOFTMAX_EXP opt-in, bf16x2_to_f16x2_bits helper, __half smem_P/P64, __float2half_rn stores, FP16 P×V MMA, plus the ATLAS_DISABLE_FP16_PV bisect ifdef).
  • prefill_paged_compute_512.cuh — byte-identical to your branch (sw_exp refactor for the HDIM=512 path).
  • prefill_paged_compute_asym.cuh — this PR's new asym fork. Applied the same logical transforms (helper, __half smem_P/P64, __float2half_rn, .f16.f16 P×V MMA). Skipped the ATLAS_DISABLE_FP16_PV debug ifdef since the bisect window has passed; happy to fold it in if you want parity.

Local validation on GB10 (qwen3.6-35b-fp8, fp8 KV cache, single Manhattan-Project prompt, greedy):

| | baseline (pre-fix) | precfix (post-fix) |
|---|---|---|
| decode_short tok/s | 72.29 | 72.81 |
| prefill_2k tok/s | 636.2 | 637.97 |

Throughput parity (single run, machine was under load; the ~10% prefill cost you predicted didn't show up at this context length on Qwen3.6-35B-A3B, possibly amortised by the rest of the layer graph).

ppl_sim against the fixed Manhattan reference dropped 0.8485 → 0.2323, but the precfix model is producing correct content in a different trajectory: it engages the thinking process to verify the "exactly, verbatim" constraint before recalling the canonical Wikipedia opening (which includes "Nuclear physicist J. Robert Oppenheimer was the project's scientific director"). The baseline model jumps straight to the answer. Both are factually correct; ppl_sim is SequenceMatcher.ratio against a 64-char fixed string and so penalises any trajectory shift. The cosine-vs-HF bench from your branch is the right validator for this and would be the next thing to port if you want a stronger signal before merge.
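A small sketch of why a SequenceMatcher-based score reacts to trajectory shifts. The ppl_sim function below is an assumed reconstruction (compare the leading characters of the output against the reference); the bench's actual truncation rule is not shown in this thread, and the reference string and outputs are made up.

```python
from difflib import SequenceMatcher


def ppl_sim(output, reference):
    """Assumed metric: ratio between the output's leading chars and the reference."""
    return SequenceMatcher(None, output[: len(reference)], reference).ratio()


# Hypothetical reference and outputs, for illustration only.
reference = "The Manhattan Project was a research and development program"
direct = reference + " undertaken during World War II."
detour = "Let me verify the exact wording first. " + reference

print(f"direct answer: {ppl_sim(direct, reference):.4f}")
print(f"thinking preamble: {ppl_sim(detour, reference):.4f}")
```

Both outputs contain the reference verbatim, but only the one that starts with it scores 1.0, so the metric measures trajectory as much as content.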

Thanks again for the careful review and the fix.

@tbraun96
Contributor

> The cosine-vs-HF bench from your branch is the right validator for this and would be the next thing to port if you want a stronger signal before merge.

I would say that, given Claude returned only the "speed" of the system in its table instead of supplementing it with the accuracy, Claude's own suggestion makes the most logical sense.

@TheTom
Author

TheTom commented May 25, 2026

Makes sense. Cleanest path is probably: you land #90 first, then I rebase #92 onto main and inherit your cosine-vs-HF bench (cosine_three_way.py + the activation-dump hooks) natively. Saves a hand-port and keeps the validator pinned to whatever lineage you settle on for #90. I'll have the cosine column ready in the next push after rebase. Let me know when #90 is ready to land.

@tbraun96
Contributor

> Makes sense. Cleanest path is probably: you land #90 first, then I rebase #92 onto main and inherit your cosine-vs-HF bench (cosine_three_way.py + the activation-dump hooks) natively. Saves a hand-port and keeps the validator pinned to whatever lineage you settle on for #90. I'll have the cosine column ready in the next push after rebase. Let me know when #90 is ready to land.

I agree. I'm currently working on it!

@TheTom
Author

TheTom commented Jun 2, 2026

> > Makes sense. Cleanest path is probably: you land #90 first, then I rebase #92 onto main and inherit your cosine-vs-HF bench (cosine_three_way.py + the activation-dump hooks) natively. Saves a hand-port and keeps the validator pinned to whatever lineage you settle on for #90. I'll have the cosine column ready in the next push after rebase. Let me know when #90 is ready to land.
>
> I agree. I'm currently working on it!

Hey there, what's the status? I'm ready to rebase whenever you merge #90.

Development

Successfully merging this pull request may close these issues.

feat(kv-cache): TurboQuant+ — Randomized Hadamard rotation + Lloyd-Max codebook for KV cache (proposal)