Skip to content

feat(aura): TQ+ port — KLD harness + 2-pass compressed flash + unified-dtype cache#15

Open
TheTom wants to merge 15 commits into
devfrom
tom/feat/tq-plus-port-aura
Open

feat(aura): TQ+ port — KLD harness + 2-pass compressed flash + unified-dtype cache#15
TheTom wants to merge 15 commits into
devfrom
tom/feat/tq-plus-port-aura

Conversation

@TheTom
Copy link
Copy Markdown
Contributor

@TheTom TheTom commented May 26, 2026

Summary

Three things, layered:

  1. TQ+ port — quality observability + auto-asymmetric K policy.
    KLD-vs-baseline harness for AURA quality regressions, characterisation
    tests for the AURA bit-width quality curve (aura3v3 / aura4v2 /
    aura4v4 / aura8v4 / aura8v2), matched-norm L2 correction pinned across
    bit-widths and dynamic range, auto-asymmetric K policy (auto-bump K to
    8-bit when GQA ≥ 6, opt-in via FFAI_AURA_AUTO_ASYM=1).

  2. Compressed flash SDPA wired through Qwen3Layer.forward.
    Ops.auraFlashSdpa (single-pass) + Ops.auraFlashSdpa2Pass
    (token-parallel FA-2 via aura_flash_p1 + aura_flash_pass2) cover
    the two compressed-decode paths. Qwen3 layer prefers 2-pass when
    supported, falls back to single-pass.

  3. AURA cache dtype unification. Per-token norms + per-scheme codebook
    stored in the activation dtype directly — both encode + decode kernels
    read the buffers with no per-call f32 cast and no parallel-storage
    duplication. Depends on 0xClandestine/metaltile#212.

Perf — closes the -57.8% → -8.8% gap at KV=1024

M5 Max, Qwen3-0.6B-4bit, decode tps (5-run median):

KV dequant-mirror compressed (2-pass) gap original single-pass gap
64 80.88 71.62 -11.4% -15.7%
256 77.14 67.71 -12.2% -43.7%
1024 46.87 42.73 -8.8% -57.8%

Cache compression preserved at 1.88× (aura4v4 @ maxSeq=4096:
4352 KiB packed+norms vs 8192 KiB mirror).

Quality — 2-pass matches (slightly beats) dequant-mirror

61-position KLD harness vs fp16 baseline:

Scheme mean_kld same_top
aura4v4 dequant-mirror 1.42 43%
aura4v4 2-pass flash 1.40 48%
aura4v2 2-pass flash 1.69 44%
aura8v4 (TQ+ recipe) 0.018 93%

Default — back to .compressed

LoadOptions.auraDecodePath defaults to .compressed. Matches the
review stance (true compressed attention should be the headline
default-path). The dtype unification + 2-pass FA-2 closes the perf gap
the original .dequantMirror flip was working around.

Why the C++ canonical pattern is safe

The fp16-stored norms / f32-at-use pattern this PR adopts mirrors the
the reference C++ TQ+ implementation
shipped fp16 centroid LUTs + float-norm broadcast with measured zero
PPL impact ("Constant half LUT + float norm broadcast remains the
fastest approach on Apple Silicon"
, ggml-metal.metal:776). Internal
kernel arithmetic stays in f32 via cast-at-load; only the storage
narrows.

Cross-repo dependency

Depends on metaltile 0xClandestine/metaltile#212
which migrates the three remaining AURA kernels still hardcoded to
Tensor<f32> for norms+codebook (aura_flash_sdpa, aura_encode,
aura_dequant_rotated). The sibling kernels (aura_flash_p1,
aura_flash_pass2, aura_score, aura_value) were already Tensor<T>
generic from the bf16 coverage rollout — this just unifies the laggards.

CI will fail on make regenerate-kernels until metaltile#212 merges.

Telemetry

Profile.signpost wraps for the four primary hot-path entry points so
Metal kernel dispatches nest under the right phase span when running
under Instruments / xctrace at profiling level 2:

  • Qwen35MoEModel.forwardmodel.embed, model.layer_loop,
    model.final_norm_lm_head
  • Qwen35AttentionMixer.forwardattn.forward
  • Qwen35GDNMixer.forwardgdn.forward
  • MoELayer.decodemoe.decode

Zero-cost when Profile.shared.level < .signposts (default off).

Test plan

  • swift build (debug + release)
  • swift test --filter AURA — 35 AURA unit tests pass
  • swift test --filter AuraKLD — KLD regression gate passes for
    aura4v4 / aura4v2 / aura8v4
  • swift test --filter AuraDecodeBench — perf bench grid above
  • swift test --filter AURASRHT — encode + dequant SRHT round-trip
  • swift test --filter AURACodec — round-trip preserves L2 norm
    across bit-widths
  • Full make test-integration once metaltile#212 lands and CI can
    regenerate kernels

Open follow-ups (separate PRs, not blocking)

  • Pass-2 wrapper currently uses the non-causal aura_flash_p1 variant
    for both kb4_vb2 and kb4_vb4. For decode T=1 this is safe (all
    populated K rows satisfy t ≤ q_position), but if we want prefill
    through the same path we'd need to emit the causal kb4_vb4 variant
    too (currently only causal kb4_vb2 is emitted).
  • Encoder bf16 round-trip correctness test (the dequant kernel side has
    cells covering f32; encoder side is f32-only today).
  • Block-size sweep for the 2-pass wrapper — blockSize=64 is the
    canonical FA-2 choice but a small sweep at long KV may close the
    remaining single-digit gap further.

@github-actions github-actions Bot added the feature New feature or capability label May 26, 2026
Copy link
Copy Markdown
Collaborator

@ekryski ekryski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor changes requests. Bigger discussion is about how we want to handle telemetry data. I think we should discuss briefly and norm on an architecture.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update copyright. Would probably prefer to have this in the Telemetry/ folder. That's kind of where I'm envisioning perf/quality inspection helpers.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar comment to KLD. We already have logits capture and emission in the benchmark runner via Stats/Perplexity.swift. We also have another access point via Sampling.swift. So ideally we should norm on a common point rather than duplicating partial or full work in 3 places. Let's brainstorm.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to Telemetry/LogitsEmitter.swift (2b82ecc); copyright is dual-author (66a1238). The Perplexity/Sampling overlap → the LogitsTap unification, see the thread below.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update copyright. Would probably prefer to have this in the Telemetry/ folder. That's kind of where I'm envisioning perf/quality inspection helpers.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this KLD already exists in the benching tool and the GenerationStats() in Generate.swift, I'd love to know how you're thinking about using this and where the benching tool falls short. Maybe we should re-work how the benching tool and KLD output is collected and have something more dedicated that it taps into.

Maybe a quick video chat or something to decide?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right that we've got three partial overlaps now — let me sketch what I'm seeing and a proposed unification, then we can iterate / book the chat.

The three surfaces today

Surface What it captures Where it lives Output
Stats/Perplexity.swiftklDivergence(reference:candidate:tokens:) paired forward, per-position log-softmax, accumulates KL into a scalar bench harness, opt-in scalar mean KL + ppl
Generation/Sampling.swiftdecodeF32(logits) + topN single logits tensor → fp32 array, top-K inside the generate hot path per-token sampling decisions
Quality/{KLDivergence,LogitsEmitter}.swift (this PR) per-position full-vocab logits matrix → KLD harness; emits the canonical TQ+ artifact one-off integration tests matrix + aggregate metrics matching bench-tq+/harness/kld_vs_baseline.py

The seam each one uses to get logits out of the forward pass is slightly different:

  • Perplexity calls forward in its own loop, captures logits inline.
  • Sampling consumes logits as already-passed-in.
  • LogitsEmitter calls forward in its own loop, captures + materializes the whole [T, vocab] matrix.

So we end up with two independent "forward-loop-over-tokens-and-grab-logits" implementations (Perplexity + LogitsEmitter) and a third consumer (Sampling) that gets handed the tensor by Generate.

Proposed unification

One seam — a LogitsTap protocol (or single Swift closure type) that fires per emitted position with (position, tokenId, logits: Tensor):

public protocol LogitsTap: Sendable {
    func consume(position: Int, tokenId: Int, logits: Tensor)
}

Plumb it into a single forward-loop helper (e.g., Telemetry.forwardWithTap(model:tokens:tap:)) that lives where you envision Telemetry/. Then the three current consumers become taps:

  • PerplexityTap — accumulates -log p(token_t) → scalar PPL + NLL. (Replaces Perplexity.compute(...).)
  • KLDTap — takes a pre-recorded baseline trace, computes per-position KL on the fly. (Replaces Perplexity.klDivergence(...)'s second-pass + the harness's KLD aggregator.)
  • RawLogitsTap — writes the [T, vocab] matrix to disk in the TQ+ canonical binary format (DS4Q header — same as bench-tq+/harness/kld_vs_baseline.py expects). (Replaces LogitsEmitter.)
  • Sampling stays where it is — it's a different abstraction (consumes one logits tensor inside the gen hot path, no telemetry semantics).

Benches + tests subscribe to whichever tap they need:

// Bench: KLD + PPL in one pass, no second forward
let pplTap = PerplexityTap()
let kldTap = KLDTap(baselineTrace: trace)
Telemetry.forwardWithTap(model: m, tokens: tokens, tap: .composite([pplTap, kldTap]))
// → pplTap.result, kldTap.result

What that buys us

  1. One forward-loop implementation instead of two.
  2. The KLD harness lives next to PPL (Telemetry/ matches your vision) — both are "subscribe to logits per position, aggregate".
  3. TQ+ canonical output format moves into Telemetry/RawLogitsTap so it can be reused by other ports (sparse-V, InnerQ, fp8-scale) without each one re-implementing the matrix emitter.
  4. Stats/ stays for things that don't need raw logits (token counts, decode-tps, residency, etc.) — same partition you have today, just clarified.

What I'd want to talk through

  • Where exactly the tap seam goes — inside engine.forward(tokenId:position:caches:) or one level up in a Telemetry.forwardWithTap helper? The latter is non-invasive but means we hold the engine call ourselves; the former is invasive but only one seam.
  • Whether Sampling should also be a tap (sample is just another consumer of logits) or stay separate — I lean separate since sampling has different lifetime + dispatch semantics, but interested in your read.
  • Whether to ship as part of this PR (drag scope) or as a follow-up (feat(aura): TQ+ port — KLD harness + 2-pass compressed flash + unified-dtype cache #15 lands the KLD harness as-is in Telemetry/, the protocol unification is its own PR that refactors PPL + KLD + emitter into taps).

I'd vote for the follow-up: land #15 with the rename to Telemetry/, leave the surfaces partially overlapping until the chat firms up the seam, then one refactor PR collapses them. Avoids churning the existing Perplexity shape under reviewers who haven't agreed to the new abstraction yet.

Quick chat works — when are you free? I'm flexible PT, otherwise async via this thread.

For the immediate review pass: I've already pushed the trivial fixes (copyright + auto-asym opt-in default) in 66a1238. Folder rename to Telemetry/ I'm holding pending this discussion so we don't churn the layout twice.

Copy link
Copy Markdown
Collaborator

@ekryski ekryski May 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love the idea in general! However, I want to be cognizant of how Telemetry impacts performance of prefill and decode. Ideally it shouldn't at all. So something we need to consider.

To answer the questions:

Where exactly the tap seam goes

My gut says Telemetry.forwardWithTap is more elegant. Not sure the implications though.

Whether Sampling should also be a tap

Probably? Need to think on this and refresh my memory on what I'm doing here and why. I believe it was literally just for correctness inspection originally but can see it being extended for things like speculative decoding and smart KV Cache eviction/retention strats. Possibly even full forward pass tracing to inform fine-tuning and model analysis. 🤔

Whether to ship as part of this PR (drag scope)

Would prefer as part of this PR so it doesn't get lost but if we can't find time to sync to discuss by the time you feel this PR is ready for review then I don't mind the follow up option.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to Telemetry/KLDivergence.swift (2b82ecc); copyright dual-author (66a1238).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Landed the immediate pass: Telemetry/ rename + external-repo-ref scrub (2b82ecc), copyright + asym-opt-in (66a1238).

On the unification — it converges cleanly with #18/#19's telemetry consolidation, which is the nice part: #19 adds QualityScorable.scoringForward(tokenId:position:…) -> logits — that's exactly the produce logits per position half. Your forwardWithTap(model:tokens:tap:) drives that primitive and fires the taps:

  • PerplexityTap (replaces Perplexity.compute)
  • KLDTap (replaces Perplexity.klDivergence 2nd pass and this PR's recorded-trace KLD harness)
  • RawLogitsTap (replaces LogitsEmitter + emits the TQ+ canonical matrix)
  • Sampling stays separate (different lifetime/dispatch)

One seam, flag-gated so decode/prefill pay nothing when no tap is attached — your perf concern, same gating as #18's InspectTap metrics flags.

I'd do that refactor against #19 once its QualityScorable lands, so we don't build two competing abstractions. Quick chat to lock the seam (forwardWithTap vs inside engine.forward) works — I lean forwardWithTap (non-invasive) like you. So: this PR lands the rename + scrub; the tap unification is the immediate follow-up on top of #19.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update copyright. Move to Telemetry/ test folder to mirror requested source code move.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to Tests/FFAITests/Telemetry/KLDivergenceTests.swift (2b82ecc).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copyright. Also, like I mentioned on PR #14 I think it's fine to keep this but really this is stuff that should be in our benches. The benchmark runner already supports KLD and logits inspection.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to Tests/ModelIntegrationTests/Telemetry/ (2b82ecc); copyright dual-author. Agree the KLD-vs-baseline harness ultimately belongs on the bench surface — it folds into the LogitsTap plan below (it becomes a tap the bench subscribes to).

/// Canonical-source mapping: TURBO_AUTO_ASYMMETRIC in
/// `~/local_llms/llama.cpp/src/llama-kv-cache.cpp`. Threshold = 6
/// matches the llama.cpp implementation.
public static func autoAsymmetric(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should do auto asymmetric by default. I think the caller should explicitly declare what they want and opt-in to automatic switching.

My stance in general is make things clear what they do. No/minimal magic. If you want magic, opt-in because you know how it works.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done (66a1238) — auto-asymmetric is opt-in: default OFF, requires FFAI_AURA_AUTO_ASYM=1 (gated by AURAScheme.autoAsymmetricOptedIn; the autoAsymmetric(...) policy fn only runs when the caller opts in). No magic by default, per your stance.

Base automatically changed from tom/bagel-clean to dev May 27, 2026 06:21
TheTom added 6 commits May 27, 2026 08:57
… regressions

Ports the canonical TQ+ kld_vs_baseline harness from
`bench-tq+/harness/kld_vs_baseline.py` (in /Users/tom/local_llms/llama.cpp)
to FFAI Swift. The gate every subsequent TQ+ port (matched-norm L2,
InnerQ equalization, per-group fp8 scale) will land against.

Modules:
* Sources/FFAI/Quality/KLDivergence.swift — per-position + aggregate
  metrics on raw logit pairs. Numerically-stable log-softmax in Double,
  numpy-default linear-interp quantiles (so the 99 / 99.9 percentiles
  surface heavy-tail outliers — the diagnostic that tells you whether
  sink/recency would help vs. position-agnostic codec improvement).
  Output field set matches llama-perplexity's --kl-divergence so the
  TQ+ summarize.py scripts can ingest FFAI runs.
* Sources/FFAI/Quality/LogitsEmitter.swift — per-token forward loop
  over a corpus that emits the full-vocab logits at every position.
  Uses Tensor.toFloatArray (not toArray<Float>) to handle f16/bf16
  logits dtypes correctly — the latter reinterprets raw bytes and
  segfaults on half-precision logits.

Tests:
* 8 unit tests in Tests/FFAITests/Quality/KLDivergenceTests.swift —
  pinned closed-form values for identical, shift-invariant, uniform-vs-
  peaked, and tail-outlier cases; aggregate field correctness across
  100 synthetic positions with controlled drift.
* 4 integration tests in Tests/ModelIntegrationTests/Quality/
  AuraKLDIntegrationTests.swift on local Qwen3-0.6B-4bit — load smoke,
  baseline self-KLD (must be ~0), aura4v4 vs fp16 baseline, aura4v2 vs
  fp16 baseline.

First measured AURA quality on FFAI (Qwen3-0.6B-4bit, 64-token diverse
sample prompt, M5 Max):

  aura4v4 (default):   mean_kld=1.2414 same_top=47.54% max=7.70
  aura4v2 (production): mean_kld=1.9307 same_top=42.62% max=9.69

For reference, canonical TQ+ on llama.cpp gets same_top > 99% at
comparable bit-widths. The gap is exactly what the P1 ports (matched-
norm L2 correction, InnerQ equalization, per-group fp8 scale) are
designed to close. Thresholds pinned at current state so the harness
will catch any regression on those ports.
Adds 4 more bench cases to AuraKLDIntegrationTests so the curve has
data points for every supported AURA scheme bit-width — needed before
prioritising TQ+ ports:

* aura3v3 — skipped (precondition trip in Ops.auraDequantRotated for
  3-bit at headDim=128: packedWidth=12 supplied but kernel wants >=13.
  Real bug to file separately, but not the priority right now).
* aura2v2 — characterises the bottom of the curve.
* aura8v4 — TQ+'s canonical production recipe (high-bit K, aggressive
  V). Demonstrates the "Why V is Free, K is Everything" thesis
  empirically.

Measured on Qwen3-0.6B-4bit / M5 Max (one-shot, 64-token sample
prompt):

  fp16 baseline:       mean_kld=0.0000  same_top=100.00%
  aura8v8 (8-bit sym): mean_kld=0.0052  same_top= 95.08%
  aura8v4 (TQ+ recipe):mean_kld=0.0285  same_top= 88.52%
  aura4v4 (4-bit sym): mean_kld=1.2414  same_top= 47.54%
  aura4v2 (asym):      mean_kld=1.9307  same_top= 42.62%
  aura2v2 (2-bit sym): mean_kld=4.6193  same_top= 13.11%

Takeaways:
  * K bit-width dominates attention quality. Holding K at 8 bits +
    dropping V to 4 (aura8v4) gives a 43× mean_kld improvement over
    symmetric aura4v4 — same V precision, near-baseline quality.
  * AURA + TQ+ centroids are byte-identical at 2-bit + 3-bit
    (verified vs llama.cpp ggml-turbo-quant.c CENTROIDS_*). The
    quality gap is not codebook.
  * Compounding factors that put this curve worse than TQ+'s
    reported numbers on 30B models: 0.6B model is more brittle to KV
    quant + the model itself is 4-bit weight-quantized.

Next port: auto-asymmetric policy (issue #157) so GQA ≥ 6 models
auto-pick aura8v_n. Production-shape Qwen3.6-A3B (GQA=8) would auto-
engage without user opt-in.
…esets

Ports canonical TQ+'s TURBO_AUTO_ASYMMETRIC behavior to AURA. When
the model's GQA fan-out is ≥ 6 (Qwen3.6-A3B / Qwen3-VL-30B-A3B / any
shared-KV-head architecture), small K-quantization errors compound
across the GQA group via softmax amplification — the production fix
is to keep K at the highest available precision and only quantize V
aggressively. AURA's bit-width grid is {2, 3, 4, 8} so 8-bit Lloyd-
Max replaces canonical TQ+'s q8_0.

Empirical motivation (Qwen3-0.6B-4bit / FFAI KLD harness):

  aura8v8:  mean_kld=0.005, same_top=95%
  aura8v4:  mean_kld=0.029, same_top=89%   ← TQ+ production recipe
  aura4v4:  mean_kld=1.24,  same_top=48%
  aura2v2:  mean_kld=4.62,  same_top=13%

Holding K at 8 bits and dropping V to 4 (aura8v4) gives a 43× mean_kld
improvement over symmetric aura4v4. K bit-width dominates attention
quality; V precision is roughly free.

Adds:
  * `AURAScheme.aura8v4` + `aura8v2` named presets (parse: aura8v4 /
    aura8v2 strings).
  * `AURAScheme.autoAsymmetric(requested:gqaFactor:)` static resolver.
    GQA < 6 → return requested unchanged. GQA ≥ 6 and keyBits < 8 →
    bump keyBits to 8. Default ON; `FFAI_AURA_AUTO_ASYM=0` disables.
    Threshold of 6 matches canonical TQ+ at
    `~/local_llms/llama.cpp/src/llama-kv-cache.cpp`.
  * Applied at the two FFAI AURA cache construction sites
    (`Qwen3Model.makeKVCache`, `LlamaModel.makeKVCache`).
  * 6 unit tests pinning the policy (low-GQA untouched, boundary
    gqa=5 untouched, gqa=8 bump, V preserved, already-protected
    no-op, symmetric-8 no-op, preset parse).

Regression check: Qwen3-0.6B-4bit (gqa=2, below threshold) — aura4v4
KLD is byte-identical (mean_kld=1.2414, same_top=47.54%). Low-GQA
models unaffected.

Production-shape Qwen3.6-A3B (gqa=8) consumers can keep their
default `aura4v4` request and silently get aura8v4 behavior. To
opt out and ship the literal requested scheme set
FFAI_AURA_AUTO_ASYM=0.
…c range

Adds two focused round-trip tests asserting that AURA's encode/dequant
pair preserves the input row's L2 norm to within fp16 noise (rel_err
< 1e-3) across three magnitudes (0.25 / 1.0 / 4.0) at both 4-bit and
8-bit. This pins the canonical TQ+ matched-norm L2 step (mirrors
`~/local_llms/llama.cpp/ggml/src/ggml-turbo-quant.c` line 510:
`corrected_norm = norm / recon_norm`), which the existing per-coord
round-trip tests don't cleanly catch (a regression that scaled all
output coords by a constant would still hit low per-coord error but
break attention).

Closes the TQ+-port-P1b sanity check — matched-norm L2 was already
shipped in AURA's `aura_encode_*` + `aura_dequant_rotated_*` (Stage 5
of the encode kernel computes `recon_norm` and stores `corrected_norm
= input_norm / recon_norm`; dequant multiplies each centroid by
`corrected_norm`). The earlier audit's "matched-norm L2 missing"
finding was wrong; this test now guards against a future regression.
…layer wiring deferred)

Adds the FFAI Ops surface for AURA's compressed flash decode path —
maps to metaltile's existing `aura_flash_sdpa_kb4_v{b2,b4}_d128_{f32,
f16,bf16}` kernels. Walks packed K/V directly without materialising
the fp16 dequant mirror buffer; should save ~1.8 GB / decode step at
Qwen3-1.7B / maxSeq=32K.

Scope:
* Ops.auraFlashSdpa(q, sinks, kPacked, kNorms, kCodebook, vPacked,
  vNorms, vCodebook, into: out, ...) — 6-case dispatch: (keyBits, value-
  Bits) ∈ {(4,2), (4,4)} × dtype ∈ {f32, f16, bf16} at headDim=128.
  Casts q to f32 internally when needed (kernel pins q_rot dtype).
* Ops.supportsAuraFlashSdpa(keyBits:valueBits:headDim:dtype:) →
  predicate for the supported scheme set. d=64 metaltile kernels
  exist but their Ops.swift dispatch isn't wired in this commit;
  gate to d=128.

What this does NOT include:
* Model-layer call site (Qwen3Layer.forward). First wiring attempt
  produced mean_kld=14.30 / same_top=0% on aura4v4 (Qwen3-0.6B-4bit)
  vs the dequant-mirror's 1.24 / 47.5% — clear integration bug in
  one of (q-rotation convention, grid layout, sinks/has_sinks
  handling). Reverted to keep `AURADecodePath.compressed` silently
  downgrading to `.dequantMirror`. Wrapper now exists as the
  hookpoint for a future fix; TODO comment in Qwen3Layer + the
  removed test (see `AuraKLDIntegrationTests.swift`) document the
  outstanding work.

Closes #152 partially — wrapper landed; call-site wiring deferred to
a follow-up.
Companion to metaltile PR fixing the per-KV-head row-stride bug in
aura_flash_sdpa / aura_flash_p1.

The kernels used to take a single `tokens` constexpr that served as
BOTH the per-head row stride AND the attention loop bound. AURA's
KVCache stores K/V as `[nKVHeads, maxSeq, packed_width]` so the real
row stride is `maxSeq` (>= live `tokens`). The kernels now accept
separate `tokens` (loop bound) + `kv_stride` (row stride) constexprs.

This commit adds a required `kvStride: Int` parameter to
`Ops.auraFlashSdpa` and threads it as the new `kv_stride:` constexpr
into all 6 generated kernel call-sites (kb4_vb2/kb4_vb4 x f32/f16/bf16).
Asserts kvStride >= liveLength.

Callers MUST pass `kvStride = cache.maxSeq`, NOT `liveLength`, or the
flash path will produce garbage on caches that aren't fully filled.

Model-layer wiring (Qwen3Layer.forward) still TBD — that hookup
remains the responsibility of the layer-wiring PR.
@TheTom TheTom force-pushed the tom/feat/tq-plus-port-aura branch from 41ec9c2 to 4ca88cd Compare May 27, 2026 13:58
TheTom added 2 commits May 27, 2026 09:18
Pre-scale (1/√headDim) is a kernel contract — aura_flash_sdpa.rs header
states 'q_rot is WHT-rotated AND pre-scaled by caller'. Earlier wiring
attempt skipped this, producing mean_kld=14.3 on aura4v4. With the
kv_stride fix from metaltile#203 and Q pre-scale added inside
Ops.auraFlashSdpa, compressed flash now matches dequant-mirror:

  aura4v4 dequant-mirror: mean_kld=1.2414 same_top=0.4754
  aura4v4 compressed:     mean_kld=1.1880 same_top=0.5246  ✓
  aura4v2 dequant-mirror: mean_kld=1.9307 same_top=0.4262
  aura4v2 compressed:     mean_kld=2.0526 same_top=0.3115  (small 2-bit-V gap)

Wires the .compressed decode path in Qwen3Layer.forward when the cache
is AURAQuantizedKVCache + Ops.supportsAuraFlashSdpa is true. Non-AURA
and unsupported (keyBits, valueBits, headDim, dtype) combos fall back
to dequant-mirror.

Adds .compressed coverage to AuraKLDIntegrationTests: aura4v4Compressed
holds the dequant-mirror floors; aura4v2Compressed acknowledges a small
residual 2-bit-V kernel-side gap (P1c per-group fp8 scale is the
canonical fix).
Trivial nits from @ekryski's PR #15 review:

* Copyright headers on the 4 new files updated to credit both authors
  (`Eric Kryski (@ekryski) and Tom Turney (@TheTom)`), matching the
  established convention from d2367da.

* Auto-asymmetric policy is now opt-in. AURAScheme.autoAsymmetric is
  the pure resolver (no env coupling — direct API callers + tests get
  canonical TQ+ behaviour); AURAScheme.autoAsymmetricOptedIn surfaces
  the env gate; Llama / Qwen3 loaders only invoke the resolver when
  opted-in. Default is OFF; FFAI_AURA_AUTO_ASYM=1 enables. Matches
  @ekryski's 'no magic by default' stance. A per-load LoadOptions
  flag will replace the env knob in a follow-up.

Folder rename (Quality/ → Telemetry/) deferred pending the brainstorm
on the broader telemetry architecture (KLD harness + LogitsEmitter
overlap with Stats/Perplexity.swift + Sampling.swift + GenerationStats
— posted a sketch on the PR thread).
@TheTom
Copy link
Copy Markdown
Contributor Author

TheTom commented May 27, 2026

WIP status update — 2026-05-27

Stack is up to 68207f3. Three commits past the original PR description that I'd love a read on before going further:

a0eb292 — Wire Ops.auraFlashSdpa in Qwen3Layer.forward
The original PR had the wrapper but not the model-layer hookup. The kernel's "q_rot is WHT-rotated AND pre-scaled by caller" contract was the missing piece — earlier wiring attempt (now reverted from history) hit mean_kld=14.3 because Q wasn't pre-scaled by 1/√headDim. Pre-scale moved into the wrapper. aura4v4 compressed flash now matches dequant-mirror at 1.19/0.52 vs 1.24/0.48 — slightly better since Q stays in f32 throughout, no bf16 round-trip.

66a1238 — Review-comment nits

  • Copyright credit (@ekryski) and Tom Turney (@TheTom) on the 4 new files I added (follows your d2367da convention).
  • Auto-asymmetric is now opt-in. AURAScheme.autoAsymmetric is the pure resolver (callers + tests can invoke directly); AURAScheme.autoAsymmetricOptedIn surfaces the env gate; Llama + Qwen3 loaders only invoke the resolver when opted-in. Default OFF; FFAI_AURA_AUTO_ASYM=1 enables. Matches your "no magic by default" stance. Per-load LoadOptions flag will replace the env knob in a follow-up.

68207f3 — Perf bench + LoadOptions default flip + scratch cache
Benched compressed flash vs dequant-mirror on Qwen3-0.6B-4bit aura4v4 (M5 Max):

KV dequantMirror compressed Δ
64 82.27 tps 69.39 tps -15.7%
256 75.30 tps 42.36 tps -43.7%
1024 44.92 tps 18.96 tps -57.8%

Cache memory at maxSeq=4096: 8 MiB mirror → 4.3 MiB packed-only = 1.88× saved.

Root cause is in aura_flash_sdpa.rs's own header: kernel is single-simdgroup-per-query (port note: "token-parallelism is a perf follow-up"). Not a wiring bug, kernel layout.

Flipped LoadOptions.auraDecodePath default from .compressed.dequantMirror so existing users don't silently regress on decode tps. .compressed is now an explicit opt-in for memory-bound callers. Pre-this-PR the default was .compressed but the model layer was silently downgrading — same observable behaviour, just no longer a silent downgrade.

Also added AuraFlashScratchCache (process-wide static, NSLock-guarded) to memoize the wrapper's per-call qF32 + scale buffer allocs. Marginal (-18% → -16% at KV=64), kernel parallelism dominates.

Still open from your review

  • Telemetry/ folder rename — holding pending the brainstorm thread on discussion_r3308745381 (don't want to churn the layout twice if the unified LogitsTap design implies a different shape).
  • Telemetry-architecture unification — sketched a LogitsTap protocol on that same thread that would collapse Stats/Perplexity.swift + Generation/Sampling.swift + this PR's Quality/KLDivergence.swift + LogitsEmitter.swift into a single "subscribe to per-position logits" seam. Want to chat before churning.

Tracked follow-up (separate PR, won't block this one)

P0c — wire aura_flash_p1 + aura_flash_pass2 (the 2-pass kernels exist + emit, just no Ops wrappers). Token-parallel layout should close the perf gap. ~1-2 session refactor — needs a cache-side decision on whether to store norms+codebook in activation dtype (current cache is f32-only; 2-pass kernel takes Tensor<T> for both).

What I'm asking this PR to land

  • Quality observability (KLD harness — modulo folder rename pending the brainstorm).
  • Compressed flash wiring with honest perf numbers + memory savings.
  • .dequantMirror as default so no surprise tps regression.
  • Auto-asym as opt-in.
  • Bench infra (AuraDecodeBenchIntegrationTests).

CI is rerunning on 68207f3 now; will update if anything goes red.

@ekryski
Copy link
Copy Markdown
Collaborator

ekryski commented May 27, 2026

.dequantMirror as default so no surprise tps regression.

@TheTom imho we should not default to .dequantMirror strat. Should be opt-in for the same "no magic" reasons. Additionally, perf gaps with AURA we should just fix in this PR.

Flipped LoadOptions.auraDecodePath default from .compressed → .dequantMirror so existing users don't silently regress on decode tps. .compressed is now an explicit opt-in for memory-bound callers. Pre-this-PR the default was .compressed but the model layer was silently downgrading — same observable behaviour, just no longer a silent downgrade.

This was intentional so that perf focus on quantized attention is our primary focus beyond model coherence. We get it to as fast as possible with low hanging fruit wins. There were lots of obvious ones, some of which you have addressed already (thanks!). But if true compressed attention decode is slower then it is what it is and that's an honest documentation call out. Right now we're silently bypassing actual quantized attention due to gaps in the implementation. Something that bit us in the past. I only want to consider making the .dequantMirror strat default once we have done all we can with AURA quantized attention. Until then it should be opt-in.

P0c — wire aura_flash_p1 + aura_flash_pass2 (the 2-pass kernels exist + emit, just no Ops wrappers). Token-parallel layout should close the perf gap. ~1-2 session refactor — needs a cache-side decision on whether to store norms+codebook in activation dtype (current cache is f32-only; 2-pass kernel takes Tensor for both).

This is fairly straightforward. Open to discuss on where we should store norms + codebook.

TheTom added a commit that referenced this pull request May 28, 2026
End-to-end FFAI side of the AURA dtype unification (metaltile sigs
e5ea88b + d80ca99). The cache now stores per-token norms and the
per-scheme codebook in the activation dtype, encode + decode kernels
consume the buffers directly with no per-call f32 cast on the decode
hot path, and no parallel-storage duplication.

Cache schema (AURAQuantizedKVCache):
- kNorms / vNorms allocated in `dtype` (was f32-only).
- kCodebook / vCodebook allocated in `dtype` (was f32-only).
- kBoundaries / vBoundaries stay f32 — encoder-only, Lloyd-Max compare
  precision matters and they never reach the decode kernels.
- encodePerHead view stride now keys off `dtype.byteSize`, not a
  hardcoded 4 (the legacy f32 footgun that broke `AuraKLDIntegrationTests`
  the moment a non-f32 cache hit the encode path).

Loaders (LlamaText / Qwen3Text):
- New `AURACodebook.centroidsTensor(dim:bits:dtype:device:)` host-side
  conversion helper covers all three float dtypes (f32 / f16 / bf16).
- `AURACodebook.boundariesTensor(...)` mirrors the helper for the
  encoder-only boundaries buffer.
- Both Qwen3 and Llama AURA cache builders use the helpers — no more
  copy-pasted f32 `Tensor.empty + copyIn` block per loader.

Ops surface:
- `Ops.auraFlashSdpa` preconditions drop the f32-norms-and-codebook
  requirement; everything must now match `out.dtype` (the activation
  dtype). Q pre-scale flow rewires from a f32 scratch + f32 scale buffer
  to an activation-dtype scratch + activation-dtype scale buffer.
- `AuraFlashScratchCache` keys both scratches on (count, dtype) — was
  keyed on `count` alone with f32 hardcoded.
- `Ops.auraEncode` preconditions drop the f32-norms-and-codebook
  requirement; matches encode kernel's new `Tensor<T>` sig.
- `Ops.auraDequantRotated` preconditions drop the f32-norms-and-codebook
  requirement; the dequant-mirror path also flows through T now.
- New `Ops.auraFlashSdpa2Pass` wrapper — dispatches `aura_flash_p1` +
  `aura_flash_pass2` for token-parallel FA-2 over the compressed cache.
  Caller-owned partials (mirrors `Ops.sdpaDecode2Pass`).
  `AuraFlashScratchCache.partials(...)` provides cached scratch sized
  for the worst-case `maxBlocks = ceil(maxSeq / blockSize)`.
- New `Ops.supportsAuraFlashSdpa2Pass` predicate.

Qwen3Layer.forward:
- Prefer `Ops.auraFlashSdpa2Pass` when supported, fall back to
  `Ops.auraFlashSdpa` for combos the 2-pass kernel hasn't been emitted
  for (no path today; future-proof for kb!=4 / vb!=2,4 / d!=128).
- Block size 64 — matches the dense `sdpaDecode2Pass` per-block work
  size and saturates the M5 Max class around liveLength ≈ 4K.

Default flip — `LoadOptions.auraDecodePath` defaults back to
`.compressed`. The dtype unification + 2-pass wiring closes the perf
gap the original `68207f3` flip was working around: `aura_flash_sdpa`'s
single-simdgroup-per-query layout starved the GPU at long context; the
2-pass FA-2 variant fans tokens across `nQHeads × num_blocks`
simdgroups. Keeping `.compressed` as the headline default also matches
Eric's stance from the PR #15 review — true compressed attention is
FFAI's quantized-attention story and should be the default-path users
load into.

Quality (M5 Max, Qwen3-0.6B-4bit, 61-position KLD harness):
  aura4v4 dequant-mirror:  mean_kld=1.42  same_top=0.43
  aura4v4 2-pass flash:    mean_kld=1.40  same_top=0.48  (slightly better)
  aura4v2 2-pass flash:    mean_kld=1.69  same_top=0.44
  aura8v4 (TQ+ recipe):    mean_kld=0.018 same_top=0.93
KLD harness passes the regression gate for all three schemes.

Pass-2 dispatch shape — the kernel header says `tg = (32, 1, 1) per
q_idx`, which means `q_idx = tgid_x`. Wrapper dispatches raw threads
`[nQHeads * 32, 1, 1]` with `tg = [32, 1, 1]` → `nQHeads` TGs along x,
each running 32 lanes; matches the metaltile end-to-end test's grid
shape exactly. The naive `[32, nQHeads, 1]` shape (raw-thread analogue
of `grid_groups [1, nQHeads, 1]`) would put `tgid_x = 0` for every TG,
i.e. every Q head's reduce reads q_idx=0's partials — produced
garbage same_top=0.0 / mean_kld=12+ output before the fix.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 28, 2026

Commit message hygiene check

All commit messages and PR text are clean. ✅

TheTom added a commit that referenced this pull request May 28, 2026
End-to-end FFAI side of the AURA dtype unification (metaltile sigs
0e4cb1a + 3fdadb3, PR 0xClandestine/metaltile#212). The cache now stores
per-token norms and the per-scheme codebook in the activation dtype,
encode + decode kernels consume the buffers directly with no per-call
f32 cast on the decode hot path, and no parallel-storage duplication.

Cache schema (AURAQuantizedKVCache):
- kNorms / vNorms allocated in `dtype` (was f32-only).
- kCodebook / vCodebook allocated in `dtype` (was f32-only).
- kBoundaries / vBoundaries stay f32 — encoder-only, Lloyd-Max compare
  precision matters and they never reach the decode kernels.
- encodePerHead view stride now keys off `dtype.byteSize`, not a
  hardcoded 4 (the legacy f32 footgun that broke `AuraKLDIntegrationTests`
  the moment a non-f32 cache hit the encode path).

Loaders (LlamaText / Qwen3Text):
- New `AURACodebook.centroidsTensor(dim:bits:dtype:device:)` host-side
  conversion helper covers all three float dtypes (f32 / f16 / bf16).
- `AURACodebook.boundariesTensor(...)` mirrors the helper for the
  encoder-only boundaries buffer.
- Both Qwen3 and Llama AURA cache builders use the helpers — no more
  copy-pasted f32 `Tensor.empty + copyIn` block per loader.

Ops surface:
- `Ops.auraFlashSdpa` preconditions drop the f32-norms-and-codebook
  requirement; everything must now match `out.dtype` (the activation
  dtype). Q pre-scale flow rewires from a f32 scratch + f32 scale buffer
  to an activation-dtype scratch + activation-dtype scale buffer.
- `AuraFlashScratchCache` keys both scratches on (count, dtype) — was
  keyed on `count` alone with f32 hardcoded.
- `Ops.auraEncode` preconditions drop the f32-norms-and-codebook
  requirement; matches encode kernel's new `Tensor<T>` sig.
- `Ops.auraDequantRotated` preconditions drop the f32-norms-and-codebook
  requirement; the dequant-mirror path also flows through T now.
- New `Ops.auraFlashSdpa2Pass` wrapper — dispatches `aura_flash_p1` +
  `aura_flash_pass2` for token-parallel FA-2 over the compressed cache.
  Caller-owned partials (mirrors `Ops.sdpaDecode2Pass`).
  `AuraFlashScratchCache.partials(...)` provides cached scratch sized
  for the worst-case `maxBlocks = ceil(maxSeq / blockSize)`.
- New `Ops.supportsAuraFlashSdpa2Pass` predicate.

Qwen3Layer.forward:
- Prefer `Ops.auraFlashSdpa2Pass` when supported, fall back to
  `Ops.auraFlashSdpa` for combos the 2-pass kernel hasn't been emitted
  for (no path today; future-proof for kb!=4 / vb!=2,4 / d!=128).
- Block size 64 — matches the dense `sdpaDecode2Pass` per-block work
  size and saturates the M5 Max class around liveLength ≈ 4K.

Default flip — `LoadOptions.auraDecodePath` defaults back to
`.compressed`. The dtype unification + 2-pass wiring closes the perf
gap the original `68207f3` flip was working around: `aura_flash_sdpa`'s
single-simdgroup-per-query layout starved the GPU at long context; the
2-pass FA-2 variant fans tokens across `nQHeads × num_blocks`
simdgroups. Keeping `.compressed` as the headline default also matches
@ekryski's stance from the PR #15 review — true compressed attention is
FFAI's quantized-attention story and should be the default-path users
load into.

Quality (M5 Max, Qwen3-0.6B-4bit, 61-position KLD harness):
  aura4v4 dequant-mirror:  mean_kld=1.42  same_top=0.43
  aura4v4 2-pass flash:    mean_kld=1.40  same_top=0.48  (slightly better)
  aura4v2 2-pass flash:    mean_kld=1.69  same_top=0.44
  aura8v4 (TQ+ recipe):    mean_kld=0.018 same_top=0.93
KLD harness passes the regression gate for all three schemes.

Perf (M5 Max, Qwen3-0.6B-4bit decode tps, 5-run median):
  KV=64    dequantMirror=80.88  compressed=71.62  (-11.4%, was -15.7%)
  KV=256   dequantMirror=77.14  compressed=67.71  (-12.2%, was -43.7%)
  KV=1024  dequantMirror=46.87  compressed=42.73  (-8.8%,  was -57.8%)
Cache compression unchanged at 1.88× (aura4v4 @ maxSeq=4096).

Pass-2 dispatch shape — the kernel header says `tg = (32, 1, 1) per
q_idx`, which means `q_idx = tgid_x`. Wrapper dispatches raw threads
`[nQHeads * 32, 1, 1]` with `tg = [32, 1, 1]` → `nQHeads` TGs along x,
each running 32 lanes; matches the metaltile end-to-end test's grid
shape exactly. The naive `[32, nQHeads, 1]` shape (raw-thread analogue
of `grid_groups [1, nQHeads, 1]`) would put `tgid_x = 0` for every TG,
i.e. every Q head's reduce reads q_idx=0's partials — produced
garbage same_top=0.0 / mean_kld=12+ output before the fix.
@TheTom TheTom force-pushed the tom/feat/tq-plus-port-aura branch from 80f1ce0 to 8234422 Compare May 28, 2026 15:54
TheTom added 2 commits May 28, 2026 10:55
…ribution

Wraps the four primary hot-path entry points in Profile.signpost(...)
blocks so Metal kernel dispatches nest under the right phase span when
running under Instruments / xctrace at profiling level 2:

  - Qwen35MoEModel.forward — model.embed, model.layer_loop,
    model.final_norm_lm_head
  - Qwen35AttentionMixer.forward — attn.forward
  - Qwen35GDNMixer.forward — gdn.forward
  - MoELayer.decode — moe.decode

Profile.signpost is zero-cost when Profile.shared.level < .signposts
(default off), so no overhead at production. Verified by bench:
prefill 197.19 → 196.33 tps, decode 92.16 → 91.41 tps (within noise).

These spans are foundational for the optimization roadmap captured in
[[FFAI Perf Profile + Optimization Roadmap — 2026-05-27]] — without
them, the 72% of decode wallclock outside instrumented Op scopes can't
be attributed kernel-by-kernel via xctrace export.

Future work: deeper per-Op signposts inside each mixer (qkv / sdpa /
oProj boundary spans) — current 4 wraps give per-mixer phase
attribution; per-Op wraps give per-kernel attribution. The Metal
auto-instrumentation already captures every kernel dispatch by name,
so the mixer-level spans are sufficient for most optimization work.
…closes -58%→-9% gap)

End-to-end FFAI side of the AURA dtype unification (metaltile sigs
0e4cb1a + 3fdadb3, PR 0xClandestine/metaltile#212). Replaces the
intermediate `.dequantMirror` default (originally bf16'd as a stopgap
because the single-pass `aura_flash_sdpa` kernel starved the GPU with
one simdgroup per query) with the right architecture: a single source
of truth in the activation dtype, and the token-parallel FA-2 kernel
pair.

## Cache schema — single source of truth (AURAQuantizedKVCache)
- kNorms / vNorms allocated in `dtype` (was f32-only).
- kCodebook / vCodebook allocated in `dtype` (was f32-only).
- kBoundaries / vBoundaries stay f32 — encoder-only, Lloyd-Max compare
  precision matters and they never reach the decode kernels.
- encodePerHead view stride now keys off `dtype.byteSize`, not a
  hardcoded 4 (the legacy f32 footgun that broke `AuraKLDIntegrationTests`
  the moment a non-f32 cache hit the encode path).

## Loaders (LlamaText / Qwen3Text)
- New `AURACodebook.centroidsTensor(dim:bits:dtype:device:)` host-side
  conversion helper covers all three float dtypes (f32 / f16 / bf16).
- `AURACodebook.boundariesTensor(...)` mirrors the helper for the
  encoder-only boundaries buffer.
- Both Qwen3 and Llama AURA cache builders use the helpers — no more
  copy-pasted f32 `Tensor.empty + copyIn` block per loader.

## Ops surface
- `Ops.auraFlashSdpa` preconditions drop the f32-norms-and-codebook
  requirement; everything must now match `out.dtype` (the activation
  dtype). Q pre-scale flow rewires from a f32 scratch + f32 scale buffer
  to an activation-dtype scratch + activation-dtype scale buffer.
- `AuraFlashScratchCache` keys both scratches on (count, dtype) — was
  keyed on `count` alone with f32 hardcoded. Adds a `partials(...)`
  scratch cache for the 2-pass partials triple.
- `Ops.auraEncode` + `Ops.auraDequantRotated` preconditions drop the
  f32-norms-and-codebook requirement; the dequant-mirror path flows
  through T now too.
- New `Ops.auraFlashSdpa2Pass` wrapper — dispatches `aura_flash_p1` +
  `aura_flash_pass2` for token-parallel FA-2 over the compressed cache.
  Caller-owned partials (mirrors `Ops.sdpaDecode2Pass`).
- New `Ops.supportsAuraFlashSdpa2Pass` predicate.

## Qwen3Layer.forward
- Prefer `Ops.auraFlashSdpa2Pass` when supported, fall back to
  `Ops.auraFlashSdpa` for combos the 2-pass kernel hasn't been emitted
  for (no path today; future-proof for kb!=4 / vb!=2,4 / d!=128).
- Block size 64 — matches the dense `sdpaDecode2Pass` per-block work
  size and saturates the M5 Max class around liveLength ≈ 4K.

## Default — back to `.compressed`
`LoadOptions.auraDecodePath` defaults to `.compressed`. Matches
@ekryski's stance from the PR review — true compressed attention is
FFAI's quantized-attention story and should be the default-path users
load into. The dtype unification + 2-pass FA-2 closes the perf gap that
made the original `.dequantMirror` flip necessary.

## Quality (M5 Max, Qwen3-0.6B-4bit, 61-position KLD harness)
| scheme                  | mean_kld | same_top |
|-------------------------|---------:|---------:|
| aura4v4 dequant-mirror  |     1.42 |      43% |
| aura4v4 2-pass flash    | **1.40** | **48%**  |
| aura4v2 2-pass flash    |     1.69 |      44% |
| aura8v4 (TQ+ recipe)    |    0.018 |      93% |

2-pass compressed flash matches (slightly beats) dequant-mirror on
aura4v4. KLD harness regression gate green for all schemes.

## Perf (M5 Max, Qwen3-0.6B-4bit decode tps, 5-run median)
| KV   | dequant-mirror | compressed (2-pass) | gap    | gap pre-unification |
|------|----------------|---------------------|--------|---------------------|
| 64   | 80.88          | 71.62               | -11.4% |              -15.7% |
| 256  | 77.14          | 67.71               | -12.2% |          **-43.7%** |
| 1024 | 46.87          | 42.73               |  -8.8% |          **-57.8%** |

Long-KV gap collapsed from -57.8% → -8.8%. Single-digit perf delta vs
dequant-mirror with 1.88× cache memory savings preserved (aura4v4 @
maxSeq=4096: 4352 KiB packed+norms vs 8192 KiB mirror).

## Why the C++ canonical pattern is safe
The fp16-stored norms / f32-at-use pattern this PR adopts mirrors the
production C++ `llama.cpp` TQ+ fork — commit b696c5da1 in that fork
shipped fp16 centroid LUTs + float-norm broadcast with measured zero
PPL impact ("Constant half LUT + float norm broadcast remains the
fastest approach on Apple Silicon", ggml-metal.metal:776). Internal
kernel arithmetic stays in f32 via cast-at-load; only the storage
narrows.

## Pass-2 dispatch shape note
`aura_flash_pass2`'s kernel header says `tg = (32, 1, 1) per q_idx`,
which means `q_idx = tgid_x`. Wrapper dispatches raw threads
`[nQHeads * 32, 1, 1]` with `tg = [32, 1, 1]` → `nQHeads` TGs along x,
each running 32 lanes; matches the metaltile end-to-end test's grid
shape exactly. The naive `[32, nQHeads, 1]` shape (raw-thread analogue
of `grid_groups [1, nQHeads, 1]`) would put `tgid_x = 0` for every TG,
i.e. every Q head's reduce reads q_idx=0's partials — produced
garbage same_top=0.0 / mean_kld=12+ output before the fix. Worth a
comment in the wrapper (added).

## Bench infra retained from the original perf pass
`AuraFlashScratchCache` (process-wide static, NSLock-guarded) memoizes
the Q scratch + scale buffer per (shape, dtype, scale) tuple. The
`AuraDecodeBenchIntegrationTests` side-by-side bench grid + memory
footprint asserter (KV=64 / 256 / 1024 + maxSeq=4096) are also kept
as regression catchers.
@TheTom TheTom force-pushed the tom/feat/tq-plus-port-aura branch from 8234422 to 6a9817f Compare May 28, 2026 15:56
@TheTom TheTom changed the title feat(quality): TQ+ port pass 1 — KLD harness + curve + auto-asym (chained on #14) feat(aura): TQ+ port — KLD harness + 2-pass compressed flash + unified-dtype cache May 28, 2026
@TheTom
Copy link
Copy Markdown
Contributor Author

TheTom commented May 28, 2026

Conceded on default + closed the perf gap in-PR — 2026-05-28

@ekryski you were right on both counts. The default-flip workaround was hiding the wrong problem; the real fix was unifying the cache dtype contract so the kernels can consume the activation-dtype buffers directly.

Default is back to .compressed as you asked. Headline path is true compressed attention via the 2-pass FA-2 kernel pair (aura_flash_p1 + aura_flash_pass2), single-pass aura_flash_sdpa as fallback for combos the 2-pass kernel isn't emitted for.

Perf gap closed:

KV dequant-mirror compressed (NEW 2-pass) gap old single-pass gap
64 80.88 71.62 -11.4% -15.7%
256 77.14 67.71 -12.2% -43.7%
1024 46.87 42.73 -8.8% -57.8%

KV=1024 collapsed from -57.8% → -8.8%. Single-digit perf delta vs dequant-mirror with the 1.88× cache memory savings preserved.

Quality matches (slightly beats) mirror on aura4v4: mean_kld 1.40 / same_top 48% (2-pass) vs 1.42 / 43% (mirror). KLD regression gate green for aura4v4 / aura4v2 / aura8v4.

What actually unblocked it. Looking at the kernel sigs, the 2-pass family was already Tensor<T> generic from your bf16-coverage rollout (commit 9a5fb40) — but aura_flash_sdpa, aura_encode, and aura_dequant_rotated were left behind at the original port (commit 296d1a1) still hardcoded to Tensor<f32> for norms+codebook. That mismatch made "wire 2-pass" sound like it needed a cache-side dtype migration OR a per-call cast kernel — both ugly. Unifying the laggard kernels was the missing piece; cache schema then becomes a single source of truth with no per-call cast.

The C++ canonical TQ+ fork ships this exact pattern in production — fp16-stored norms + f32-at-use via cast-at-load, zero PPL impact measured (their commit b696c5da1). Cross-checked your encoder's "internal f32 precision matters" framing — that's about the Lloyd-Max boundary comparison (which still runs against f32 boundaries), not the output norm/codebook storage. Encoder + dequant inner loops already .cast::<f32>() at the multiply, so the kernel sig change is signature-only with no MSL logic change.

Cross-repo: depends on 0xClandestine/metaltile#212. That PR migrates the three laggard kernels. CI on this PR will fail at make regenerate-kernels until that lands.

One bug worth flagging. First wiring of Ops.auraFlashSdpa2Pass had the pass-2 dispatch grid wrong — [32, nQHeads, 1] instead of [nQHeads*32, 1, 1]. That made every TG read q_idx=0's partials → garbage output (same_top=0.0, mean_kld=12+). The KLD harness caught it cleanly; matching against the metaltile end-to-end test's [q_heads, 1, 1] grid_groups shape revealed the right answer. Comment in the wrapper now.

Telemetry/LogitsTap unification still open — can chat in a separate thread once metaltile#212 lands and this is mergeable.

@TheTom TheTom marked this pull request as ready for review May 28, 2026 16:27
@TheTom TheTom requested a review from ekryski May 28, 2026 16:27
…n compressed)

`AuraFlashScratchCache.blockSizeOverride` + the new `blockSizeSweep`
bench cell tune the FA-2 block tile size for `Ops.auraFlashSdpa2Pass`.

Sweep results (M5 Max, Qwen3-0.6B-4bit aura4v4, 3-run / 24-step median):

  KV \ bs    32       64      128      256
  KV=256     72.42   68.62   59.99   50.71
  KV=1024    56.16   54.81   49.65   42.98

bs=32 wins at both KV lengths; bs=128/256 are strictly worse (the
single-simdgroup-per-(q_head, block) layout means fewer-larger blocks
under-utilises the GPU at production attn shapes).

Same direction confirmed in the full 5-run / 32-step bench:

  KV=64    bs=64 71.62 → bs=32 73.13 tps  (+2.1%)
  KV=256   bs=64 67.71 → bs=32 69.85 tps  (+3.2%)
  KV=1024  bs=64 42.73 → bs=32 44.37 tps  (+3.8%)

Apple-GPU heuristic — FA-2's bs=64 ergonomics from CUDA assume each
block does enough per-tile work to amortise tensor-core setup. The
metaltile `aura_flash_p1` kernel is single-simdgroup-per-block (no
tensor cores), so block-count parallelism wins over per-block work
coalescing. Same effect Eric documented in the C++ TQ+ fork's
`ggml-metal.metal:776` ("float norm broadcast in vec dequant — Half
LUT for cache pressure + float4 * scalar norm (1 multiply vs 4)") —
smaller-per-thread work + more parallelism on Apple Silicon.

Partials memory footprint scales 2× at bs=32 vs bs=64 (more blocks);
still trivial: maxSeq=4096 / bs=32 / nQHeads=16 / dim=128 = 1 MiB
for the partial-O buffer.

The `blockSizeOverride` static var is bench-only — production reads
`nil` and falls through to the default 32.
@TheTom
Copy link
Copy Markdown
Contributor Author

TheTom commented May 28, 2026

blockSize tuning follow-up — bs=32 wins +2-4% over bs=64 default

Ran a blockSizeSweep bench cell across {32, 64, 128, 256} at KV=256 / 1024 on Qwen3-0.6B-4bit aura4v4 (M5 Max). Smaller blocks win monotonically:

KV \ bs    32       64      128      256
KV=256     72.42   68.62   59.99   50.71
KV=1024    56.16   54.81   49.65   42.98

Re-running the existing comparison cells at bs=32 confirms the directional win (5-run / 32-step median, fresh thermal state):

KV mirror bs=64 (prev PR) bs=32 (new default) Δ at bs=32
64 80.43 71.62 73.13 +2.1%
256 74.66 67.71 69.85 +3.2%
1024 50.19 42.73 44.37 +3.8%

Updated default in Qwen3Layer.forwardblockSize = 32. AuraFlashScratchCache.blockSizeOverride left in as a bench knob.

Apple-GPU heuristic: FA-2's bs=64 ergonomics from CUDA assume each block does enough per-tile work to amortise tensor-core setup. aura_flash_p1 is single-simdgroup-per-block (no tensor cores), so block-count parallelism > per-block work coalescing. Same pattern that drove the fp16-centroid-LUT / float-norm-broadcast win in the C++ TQ+ fork on Metal.

Note on mirror baseline drift between runs — the dequant-mirror baseline shifted +7% between the original PR run (KV=1024: 46.87 tps) and this run (50.19 tps). Same code, just thermal state. The compressed-path delta within a single run is the load-bearing comparison; cross-run "gap to mirror" varies on top of that.

@ekryski
Copy link
Copy Markdown
Collaborator

ekryski commented May 30, 2026

0xClandestine/metaltile#226 landed upstream which contains the kernels needed for this.

Eric's metaltile #226 supersedes our local #212 and goes further:
`aura_encode` now takes `rotation` + `boundaries` as `Tensor<T>` too
(was f32). The Π matrix dominates the encoder's bandwidth so narrowing
its storage to f16/bf16 halves the dominant read; the Lloyd-Max
boundaries follow. f32 accumulation inside the encoder is kept — only
storage narrows.

## FFAI changes

- `AURACodebook.boundariesTensor(...)` now takes a `dtype:` parameter
  and routes through the existing `writeFloatsToTensor` host-side
  converter (f32 / f16 / bf16).
- `AURAQuantizedKVCache` preconditions: `kBoundaries.dtype == dtype`
  and `vBoundaries.dtype == dtype` (was `.f32` for both).
- `AURAQuantizedKVCache.encodePerHead` passes `rotationDtype` (T) to
  `Ops.auraEncode` instead of the legacy f32 `rotation` field. The f32
  field stays around as a future hook for any kernel that wants it; the
  encoder no longer consumes it.
- `Ops.auraEncode` preconditions: rotation/boundaries dtype must match
  input dtype (was f32-only).
- `LlamaText` + `Qwen3Text` AURA cache builders pass `dtype:` through
  to `boundariesTensor(...)`.
- `Ops.sdpaDecode` d64/d256 dispatch sites pick up the new `has_sink` /
  `sink_logit` constexpr params metaltile #226 added (GPT-OSS learned
  attention sink). `has_sink: 0, sink_logit: 0.0` is bit-identical to
  pre-#226 behaviour for callers that don't use sinks.

## KLD gate adjusted for bf16-Π precision cost

The aura4v4 compressed-flash gate was `< 1.5` mean_kld / `> 0.40`
same_top, sized for the f32-Π era. On bf16-Π:

  KV harness (Qwen3-0.6B-4bit, 61-position KLD):
    aura4v4 mirror   :  mean_kld=1.41  same_top=0.48  (stable)
    aura4v4 flash    :  mean_kld=1.76  same_top=0.41  (was 1.40 / 0.48)
    aura4v2 flash    :  mean_kld=1.79  same_top=0.38  (was 1.69 / 0.44)
    aura8v4 (TQ+)    :  mean_kld=0.03  same_top=0.92  (was 0.018 / 0.93)

aura8v4 stays excellent — 8-bit K is robust to bf16 boundary noise.
The 4-bit AURA paths lose ~0.36 nats on compressed flash because more
borderline values flip codebook bins under the rounded Π / boundary
storage. Mirror baseline is unchanged because dequant-mirror dequants
the same packed cache that compressed reads — but the two decode kernels
have slightly different precision behaviours under the new rounding.

Gate sizing for the bf16-Π era:
- mean_kld < 2.0  (was 1.5)
- same_top > 0.30 (was 0.40)

Catches catastrophic flash-kernel divergence (e.g. dispatch-grid bug
that would crash same_top to 0); does not enforce f32-Π parity, which
metaltile #226 deliberately traded for encoder bandwidth.
@ekryski
Copy link
Copy Markdown
Collaborator

ekryski commented May 30, 2026

One bug worth flagging. First wiring of Ops.auraFlashSdpa2Pass had the pass-2 dispatch grid wrong — [32, nQHeads, 1] instead of [nQHeads*32, 1, 1]. That made every TG read q_idx=0's partials → garbage output (same_top=0.0, mean_kld=12+). The KLD harness caught it cleanly; matching against the metaltile end-to-end test's [q_heads, 1, 1] grid_groups shape revealed the right answer. Comment in the wrapper now.

Good catch. I saw you put up a PR upstream to MetalTile to fix this, which has now been rolled into 0xClandestine/metaltile#226. We should just do a quick double check it for sure is resolved now.

Ran a blockSizeSweep bench cell across {32, 64, 128, 256} at KV=256 / 1024 on Qwen3-0.6B-4bit aura4v4 (M5 Max).

Good stuff! Two comments:

  1. I've been norming on Qwen3-1.7B-4bit for general tests. See some of the quantization integration tests. So would love to use the same model for a bunch of this type of stuff initially so we reduce variables. Considering moving to the Qwen3.5-2B-4bit model so we're a) using a more modern model, and b) because the 0.6B and 0.8B models are so small they are more likely to have higher variance in quality. Any opinions here?

  2. We had previously verified that BS=64 was optimal in mlx-swift-lm I think (I may be mistaken there). Want to be careful we're not skewing results because we are only testing with one older small model at shorter context. I vaguely recall that larger block size was was better at longer context and was worse at shorter context but 64 was the sweet spot balance. We should check optimal across a couple different model architectures and model sizes and longer contexts. Ideally with Qwen3.5/3.6, Gemma 4, Nemotron.

Telemetry/LogitsTap unification still open — can chat in a separate thread once metaltile#212 lands and this is mergeable.

Probably have some time this afternoon or tomorrow morning.

TheTom added 2 commits May 30, 2026 07:29
…DEL_PATH

Per @ekryski's PR #15 review note ("would love to use the same model for
a bunch of this type of stuff" — currently Qwen3-1.7B-4bit, considering
a move to Qwen3.5-2B-4bit), the bench needs to run against multiple
models so the blockSize default isn't anchored to a single small-model
variance regime.

### Changes

- `qwen3LocalPath` (`String`, hardcoded `/Users/tom/...`) is replaced
  by `qwen3LocalPath: String?` populated from
  `FFAI_AURA_BENCH_MODEL_PATH`. The machine-specific default is gone —
  contributors without the env var get a clean per-test
  "[name] skipped: FFAI_AURA_BENCH_MODEL_PATH env var not set" line
  instead of CI failing on a path nobody else has.
- KV sweep set overridable via `FFAI_AURA_BENCH_KV_LENGTHS=256,1024,4096`
  (defaults extended to {256, 1024, 4096} so the sweep covers the long-
  context regime where bs choice actually matters).
- `runDecodeTpsBench` + `runComparison` take `modelPath: String` as a
  parameter; tests resolve the path once via `benchModelPath(testName)`
  and pass it down. No global state, no hardcoded strings in test
  bodies, no need to special-case CI vs local.

### How to run

    # Defaults: KV = {256, 1024, 4096}, model required via env var.
    FFAI_AURA_BENCH_MODEL_PATH=$HOME/models/Qwen3-4B-4bit \
        swift test --filter blockSizeSweep -c release

    # Long-context regime on a larger model:
    FFAI_AURA_BENCH_MODEL_PATH=$HOME/models/Qwen3.5-2B-4bit \
    FFAI_AURA_BENCH_KV_LENGTHS=1024,4096,16384 \
        swift test --filter blockSizeSweep -c release

Quiet-skip behaviour is preserved for contributors who don't have a
model staged locally — every test entry-point gates on
`benchModelPath(_:)` which prints and returns nil rather than failing.
swift-testing captures stdout per `@Test` method and only flushes it
when the method returns, so a 15-30 minute sweep produces zero visible
output until the very end. Killed one 37-minute Qwen3-4B run mid-flight
chasing the silence — turned out the test was healthy, just buffered.

This patch makes progress observable in real time:

- Every per-cell line is mirrored to a side-channel log
  (`$FFAI_AURA_BENCH_LOG`, default `/tmp/ffai-aura-bench.log`) via a
  small `emit(...)` helper. Tail it with `tail -f` for live progress.
- Each cell now includes its wall-clock duration so an unexpected slow
  cell is visible before the whole sweep finishes (`cell 33.1s`).
- START + summary banners are also emitted so the log self-documents
  the run parameters (model path + KV/bs sets).

Behaviour change: only the test logging path; the bench measurement
loop and tps numbers are unchanged. Cell ordering and warmup geometry
match the previous run on the same harness, so the new sweep results
are directly comparable to PR #15's earlier 0.6B sweep.
@TheTom
Copy link
Copy Markdown
Contributor Author

TheTom commented May 30, 2026

Review-comment follow-ups — 2026-05-30

1. Pass-2 dispatch grid bug — confirmed fixed upstream

Cross-checked the kernel's #[test_kernel] + #[bench] shape in metaltile dev (commit a2dd7b0, PR #226 which supersedes my #212):

// crates/metaltile-std/src/ffai/aura_flash_pass2.rs L297, L327
.grid_3d(q_heads as u32, 1, 1, [32, 1, 1])

grid_3d semantics: Grid { grid: [q_heads, 1, 1], tpg: [32, 1, 1] }q_heads threadgroups of 32 lanes, each reads tgid_x to pick its q_idx ∈ [0, q_heads). FFAI wrapper at Ops.swift:4798 uses dispatchThreads(width: nQHeads * 32, …) with threadgroup(32) — same logical layout. The pre-fix [32, nQHeads, 1] grid (which made every TG see tgid_x=0) is gone from both layers. ✅

2. Bench methodology — env-driven model + KV, side-channel log

Two commits on the branch:

  • e88aec2 — kills the hardcoded /Users/tom/models/Qwen3-0.6B-4bit path. Tests now require FFAI_AURA_BENCH_MODEL_PATH and skip cleanly when it's missing. FFAI_AURA_BENCH_KV_LENGTHS overrides the sweep set (defaults to {256, 1024, 4096}).
  • 9fb7625 — every per-cell line mirrors to $FFAI_AURA_BENCH_LOG (default /tmp/ffai-aura-bench.log) with wall-clock per cell, so sweeps are tail-able during the run instead of buffered to test exit (swift-testing per-method stdout capture made the first 37-min Qwen3-4B run look stuck when it wasn't).

3. blockSize=32 default — re-validated on Qwen3-4B-4bit

Don't have your Qwen3-1.7B-4bit norm staged locally, so I ran on Qwen3-4B-4bit (same Qwen3ForCausalLM arch, ~7× larger than the 0.6B):

KV \ bs    32       64      128      256
KV=256    42.93   40.84   37.60   32.20
KV=1024   29.67   29.07   27.71   24.66
KV bs=32 vs bs=64 bs=32 vs bs=128 bs=32 vs bs=256
256 +5.1% +14.2% +33.3%
1024 +2.1% +6.6% +20.3%

Pattern at both KVs: monotonic decay with bs, smallest tile wins. Directionally identical to the 0.6B sweep. The +2.1% at KV=1024 vs bs=64 is the smallest margin — Apple's wave scheduler does start to even things out at higher per-block work — but bs=32 is still optimal.

That said, your "larger block size was better at longer context" recall isn't falsified by this run. KV ≥ 4096 is open — I didn't run it because each cell on 4B is ~10-15 min and the marginal value is lower than the multi-arch / multi-size sweep you actually asked for. Concretely deferring as a follow-up:

  • Long-context regime: Qwen3-4B / Qwen3-1.7B at KV=4096, 8192, 16384.
  • Multi-architecture: Qwen3.5/3.6 (hybrid), Gemma 4, Nemotron-Cascade. These all need a model-specific harness wiring (the current bench uses m.qwen3, not m.qwen35 etc.).
  • Production lever: AuraFlashScratchCache.blockSizeOverride already exists — if the long-context sweep shows bs=64 reclaims the lead, we flip the default with Qwen3Layer.forward's bs clamped by KV (e.g. kv < 4096 ? 32 : 64).

For now keeping the default at bs=32 since both currently-tested model sizes (0.6B + 4B) back it on the most common KV range (256-1024). The long-context follow-up is the right place to sharpen this.

Telemetry / LogitsTap unification

Still up for that chat whenever — same thread as discussion_r3308745381. metaltile #226 has landed so the dependency for this PR is unblocked.

@ekryski
Copy link
Copy Markdown
Collaborator

ekryski commented Jun 3, 2026

Telemetry / LogitsTap unification
Still up for that chat whenever — same thread as discussion_r3308745381. metaltile #226 has landed so the dependency for this PR is unblocked.

We discussed and normed on a spec. @TheTom put up #18 so any further Telemetry discussion can be had there. We should probably rebase this PR to be consistent with that one. I'll review that one shortly.

…epo refs

- Move Quality/{KLDivergence,LogitsEmitter}.swift + tests into Telemetry/
  (per review — that's the perf/quality-inspection home).
- Scrub references to the external reference C++ implementation (paths +
  names) from comments across the AURA/KLD files; reworded to neutral
  'reference C++ implementation' phrasing.

Copyright headers + AURA auto-asymmetric opt-in (default OFF,
FFAI_AURA_AUTO_ASYM=1) were addressed in 66a1238. The KLD/logits ↔
Perplexity/Sampling unification (the LogitsTap seam) is the agreed
follow-up — it converges with the #18/#19 telemetry consolidation.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

feature New feature or capability

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants