feat(aura): TQ+ port — KLD harness + 2-pass compressed flash + unified-dtype cache#15
feat(aura): TQ+ port — KLD harness + 2-pass compressed flash + unified-dtype cache#15TheTom wants to merge 15 commits into
Conversation
ekryski
left a comment
There was a problem hiding this comment.
Minor changes requests. Bigger discussion is about how we want to handle telemetry data. I think we should discuss briefly and norm on an architecture.
There was a problem hiding this comment.
Update copyright. Would probably prefer to have this in the Telemetry/ folder. That's kind of where I'm envisioning perf/quality inspection helpers.
There was a problem hiding this comment.
Similar comment to KLD. We already have logits capture and emission in the benchmark runner via Stats/Perplexity.swift. We also have another access point via Sampling.swift. So ideally we should norm on a common point rather than duplicating partial or full work in 3 places. Let's brainstorm.
There was a problem hiding this comment.
Update copyright. Would probably prefer to have this in the Telemetry/ folder. That's kind of where I'm envisioning perf/quality inspection helpers.
There was a problem hiding this comment.
Since this KLD already exists in the benching tool and the GenerationStats() in Generate.swift, I'd love to know how you're thinking about using this and where the benching tool falls short. Maybe we should re-work how the benching tool and KLD output is collected and have something more dedicated that it taps into.
Maybe a quick video chat or something to decide?
There was a problem hiding this comment.
You're right that we've got three partial overlaps now — let me sketch what I'm seeing and a proposed unification, then we can iterate / book the chat.
The three surfaces today
| Surface | What it captures | Where it lives | Output |
|---|---|---|---|
Stats/Perplexity.swift — klDivergence(reference:candidate:tokens:) |
paired forward, per-position log-softmax, accumulates KL into a scalar | bench harness, opt-in | scalar mean KL + ppl |
Generation/Sampling.swift — decodeF32(logits) + topN |
single logits tensor → fp32 array, top-K | inside the generate hot path | per-token sampling decisions |
Quality/{KLDivergence,LogitsEmitter}.swift (this PR) |
per-position full-vocab logits matrix → KLD harness; emits the canonical TQ+ artifact | one-off integration tests | matrix + aggregate metrics matching bench-tq+/harness/kld_vs_baseline.py |
The seam each one uses to get logits out of the forward pass is slightly different:
- Perplexity calls forward in its own loop, captures
logitsinline. - Sampling consumes
logitsas already-passed-in. - LogitsEmitter calls forward in its own loop, captures + materializes the whole [T, vocab] matrix.
So we end up with two independent "forward-loop-over-tokens-and-grab-logits" implementations (Perplexity + LogitsEmitter) and a third consumer (Sampling) that gets handed the tensor by Generate.
Proposed unification
One seam — a LogitsTap protocol (or single Swift closure type) that fires per emitted position with (position, tokenId, logits: Tensor):
public protocol LogitsTap: Sendable {
func consume(position: Int, tokenId: Int, logits: Tensor)
}Plumb it into a single forward-loop helper (e.g., Telemetry.forwardWithTap(model:tokens:tap:)) that lives where you envision Telemetry/. Then the three current consumers become taps:
PerplexityTap— accumulates-log p(token_t)→ scalar PPL + NLL. (ReplacesPerplexity.compute(...).)KLDTap— takes a pre-recorded baseline trace, computes per-position KL on the fly. (ReplacesPerplexity.klDivergence(...)'s second-pass + the harness's KLD aggregator.)RawLogitsTap— writes the [T, vocab] matrix to disk in the TQ+ canonical binary format (DS4Qheader — same asbench-tq+/harness/kld_vs_baseline.pyexpects). (ReplacesLogitsEmitter.)- Sampling stays where it is — it's a different abstraction (consumes one logits tensor inside the gen hot path, no telemetry semantics).
Benches + tests subscribe to whichever tap they need:
// Bench: KLD + PPL in one pass, no second forward
let pplTap = PerplexityTap()
let kldTap = KLDTap(baselineTrace: trace)
Telemetry.forwardWithTap(model: m, tokens: tokens, tap: .composite([pplTap, kldTap]))
// → pplTap.result, kldTap.resultWhat that buys us
- One forward-loop implementation instead of two.
- The KLD harness lives next to PPL (
Telemetry/matches your vision) — both are "subscribe to logits per position, aggregate". - TQ+ canonical output format moves into
Telemetry/RawLogitsTapso it can be reused by other ports (sparse-V, InnerQ, fp8-scale) without each one re-implementing the matrix emitter. Stats/stays for things that don't need raw logits (token counts, decode-tps, residency, etc.) — same partition you have today, just clarified.
What I'd want to talk through
- Where exactly the tap seam goes — inside
engine.forward(tokenId:position:caches:)or one level up in aTelemetry.forwardWithTaphelper? The latter is non-invasive but means we hold the engine call ourselves; the former is invasive but only one seam. - Whether
Samplingshould also be a tap (sample is just another consumer of logits) or stay separate — I lean separate since sampling has different lifetime + dispatch semantics, but interested in your read. - Whether to ship as part of this PR (drag scope) or as a follow-up (feat(aura): TQ+ port — KLD harness + 2-pass compressed flash + unified-dtype cache #15 lands the KLD harness as-is in
Telemetry/, the protocol unification is its own PR that refactors PPL + KLD + emitter into taps).
I'd vote for the follow-up: land #15 with the rename to Telemetry/, leave the surfaces partially overlapping until the chat firms up the seam, then one refactor PR collapses them. Avoids churning the existing Perplexity shape under reviewers who haven't agreed to the new abstraction yet.
Quick chat works — when are you free? I'm flexible PT, otherwise async via this thread.
For the immediate review pass: I've already pushed the trivial fixes (copyright + auto-asym opt-in default) in 66a1238. Folder rename to Telemetry/ I'm holding pending this discussion so we don't churn the layout twice.
There was a problem hiding this comment.
Love the idea in general! However, I want to be cognizant of how Telemetry impacts performance of prefill and decode. Ideally it shouldn't at all. So something we need to consider.
To answer the questions:
Where exactly the tap seam goes
My gut says Telemetry.forwardWithTap is more elegant. Not sure the implications though.
Whether Sampling should also be a tap
Probably? Need to think on this and refresh my memory on what I'm doing here and why. I believe it was literally just for correctness inspection originally but can see it being extended for things like speculative decoding and smart KV Cache eviction/retention strats. Possibly even full forward pass tracing to inform fine-tuning and model analysis. 🤔
Whether to ship as part of this PR (drag scope)
Would prefer as part of this PR so it doesn't get lost but if we can't find time to sync to discuss by the time you feel this PR is ready for review then I don't mind the follow up option.
There was a problem hiding this comment.
Landed the immediate pass: Telemetry/ rename + external-repo-ref scrub (2b82ecc), copyright + asym-opt-in (66a1238).
On the unification — it converges cleanly with #18/#19's telemetry consolidation, which is the nice part: #19 adds QualityScorable.scoringForward(tokenId:position:…) -> logits — that's exactly the produce logits per position half. Your forwardWithTap(model:tokens:tap:) drives that primitive and fires the taps:
PerplexityTap(replacesPerplexity.compute)KLDTap(replacesPerplexity.klDivergence2nd pass and this PR's recorded-trace KLD harness)RawLogitsTap(replacesLogitsEmitter+ emits the TQ+ canonical matrix)- Sampling stays separate (different lifetime/dispatch)
One seam, flag-gated so decode/prefill pay nothing when no tap is attached — your perf concern, same gating as #18's InspectTap metrics flags.
I'd do that refactor against #19 once its QualityScorable lands, so we don't build two competing abstractions. Quick chat to lock the seam (forwardWithTap vs inside engine.forward) works — I lean forwardWithTap (non-invasive) like you. So: this PR lands the rename + scrub; the tap unification is the immediate follow-up on top of #19.
There was a problem hiding this comment.
Update copyright. Move to Telemetry/ test folder to mirror requested source code move.
There was a problem hiding this comment.
Moved to Tests/FFAITests/Telemetry/KLDivergenceTests.swift (2b82ecc).
There was a problem hiding this comment.
Copyright. Also, like I mentioned on PR #14 I think it's fine to keep this but really this is stuff that should be in our benches. The benchmark runner already supports KLD and logits inspection.
There was a problem hiding this comment.
Moved to Tests/ModelIntegrationTests/Telemetry/ (2b82ecc); copyright dual-author. Agree the KLD-vs-baseline harness ultimately belongs on the bench surface — it folds into the LogitsTap plan below (it becomes a tap the bench subscribes to).
| /// Canonical-source mapping: TURBO_AUTO_ASYMMETRIC in | ||
| /// `~/local_llms/llama.cpp/src/llama-kv-cache.cpp`. Threshold = 6 | ||
| /// matches the llama.cpp implementation. | ||
| public static func autoAsymmetric( |
There was a problem hiding this comment.
I don't think we should do auto asymmetric by default. I think the caller should explicitly declare what they want and opt-in to automatic switching.
My stance in general is make things clear what they do. No/minimal magic. If you want magic, opt-in because you know how it works.
There was a problem hiding this comment.
Done (66a1238) — auto-asymmetric is opt-in: default OFF, requires FFAI_AURA_AUTO_ASYM=1 (gated by AURAScheme.autoAsymmetricOptedIn; the autoAsymmetric(...) policy fn only runs when the caller opts in). No magic by default, per your stance.
… regressions Ports the canonical TQ+ kld_vs_baseline harness from `bench-tq+/harness/kld_vs_baseline.py` (in /Users/tom/local_llms/llama.cpp) to FFAI Swift. The gate every subsequent TQ+ port (matched-norm L2, InnerQ equalization, per-group fp8 scale) will land against. Modules: * Sources/FFAI/Quality/KLDivergence.swift — per-position + aggregate metrics on raw logit pairs. Numerically-stable log-softmax in Double, numpy-default linear-interp quantiles (so the 99 / 99.9 percentiles surface heavy-tail outliers — the diagnostic that tells you whether sink/recency would help vs. position-agnostic codec improvement). Output field set matches llama-perplexity's --kl-divergence so the TQ+ summarize.py scripts can ingest FFAI runs. * Sources/FFAI/Quality/LogitsEmitter.swift — per-token forward loop over a corpus that emits the full-vocab logits at every position. Uses Tensor.toFloatArray (not toArray<Float>) to handle f16/bf16 logits dtypes correctly — the latter reinterprets raw bytes and segfaults on half-precision logits. Tests: * 8 unit tests in Tests/FFAITests/Quality/KLDivergenceTests.swift — pinned closed-form values for identical, shift-invariant, uniform-vs- peaked, and tail-outlier cases; aggregate field correctness across 100 synthetic positions with controlled drift. * 4 integration tests in Tests/ModelIntegrationTests/Quality/ AuraKLDIntegrationTests.swift on local Qwen3-0.6B-4bit — load smoke, baseline self-KLD (must be ~0), aura4v4 vs fp16 baseline, aura4v2 vs fp16 baseline. First measured AURA quality on FFAI (Qwen3-0.6B-4bit, 64-token diverse sample prompt, M5 Max): aura4v4 (default): mean_kld=1.2414 same_top=47.54% max=7.70 aura4v2 (production): mean_kld=1.9307 same_top=42.62% max=9.69 For reference, canonical TQ+ on llama.cpp gets same_top > 99% at comparable bit-widths. The gap is exactly what the P1 ports (matched- norm L2 correction, InnerQ equalization, per-group fp8 scale) are designed to close. Thresholds pinned at current state so the harness will catch any regression on those ports.
Adds 4 more bench cases to AuraKLDIntegrationTests so the curve has
data points for every supported AURA scheme bit-width — needed before
prioritising TQ+ ports:
* aura3v3 — skipped (precondition trip in Ops.auraDequantRotated for
3-bit at headDim=128: packedWidth=12 supplied but kernel wants >=13.
Real bug to file separately, but not the priority right now).
* aura2v2 — characterises the bottom of the curve.
* aura8v4 — TQ+'s canonical production recipe (high-bit K, aggressive
V). Demonstrates the "Why V is Free, K is Everything" thesis
empirically.
Measured on Qwen3-0.6B-4bit / M5 Max (one-shot, 64-token sample
prompt):
fp16 baseline: mean_kld=0.0000 same_top=100.00%
aura8v8 (8-bit sym): mean_kld=0.0052 same_top= 95.08%
aura8v4 (TQ+ recipe):mean_kld=0.0285 same_top= 88.52%
aura4v4 (4-bit sym): mean_kld=1.2414 same_top= 47.54%
aura4v2 (asym): mean_kld=1.9307 same_top= 42.62%
aura2v2 (2-bit sym): mean_kld=4.6193 same_top= 13.11%
Takeaways:
* K bit-width dominates attention quality. Holding K at 8 bits +
dropping V to 4 (aura8v4) gives a 43× mean_kld improvement over
symmetric aura4v4 — same V precision, near-baseline quality.
* AURA + TQ+ centroids are byte-identical at 2-bit + 3-bit
(verified vs llama.cpp ggml-turbo-quant.c CENTROIDS_*). The
quality gap is not codebook.
* Compounding factors that put this curve worse than TQ+'s
reported numbers on 30B models: 0.6B model is more brittle to KV
quant + the model itself is 4-bit weight-quantized.
Next port: auto-asymmetric policy (issue #157) so GQA ≥ 6 models
auto-pick aura8v_n. Production-shape Qwen3.6-A3B (GQA=8) would auto-
engage without user opt-in.
…esets
Ports canonical TQ+'s TURBO_AUTO_ASYMMETRIC behavior to AURA. When
the model's GQA fan-out is ≥ 6 (Qwen3.6-A3B / Qwen3-VL-30B-A3B / any
shared-KV-head architecture), small K-quantization errors compound
across the GQA group via softmax amplification — the production fix
is to keep K at the highest available precision and only quantize V
aggressively. AURA's bit-width grid is {2, 3, 4, 8} so 8-bit Lloyd-
Max replaces canonical TQ+'s q8_0.
Empirical motivation (Qwen3-0.6B-4bit / FFAI KLD harness):
aura8v8: mean_kld=0.005, same_top=95%
aura8v4: mean_kld=0.029, same_top=89% ← TQ+ production recipe
aura4v4: mean_kld=1.24, same_top=48%
aura2v2: mean_kld=4.62, same_top=13%
Holding K at 8 bits and dropping V to 4 (aura8v4) gives a 43× mean_kld
improvement over symmetric aura4v4. K bit-width dominates attention
quality; V precision is roughly free.
Adds:
* `AURAScheme.aura8v4` + `aura8v2` named presets (parse: aura8v4 /
aura8v2 strings).
* `AURAScheme.autoAsymmetric(requested:gqaFactor:)` static resolver.
GQA < 6 → return requested unchanged. GQA ≥ 6 and keyBits < 8 →
bump keyBits to 8. Default ON; `FFAI_AURA_AUTO_ASYM=0` disables.
Threshold of 6 matches canonical TQ+ at
`~/local_llms/llama.cpp/src/llama-kv-cache.cpp`.
* Applied at the two FFAI AURA cache construction sites
(`Qwen3Model.makeKVCache`, `LlamaModel.makeKVCache`).
* 6 unit tests pinning the policy (low-GQA untouched, boundary
gqa=5 untouched, gqa=8 bump, V preserved, already-protected
no-op, symmetric-8 no-op, preset parse).
Regression check: Qwen3-0.6B-4bit (gqa=2, below threshold) — aura4v4
KLD is byte-identical (mean_kld=1.2414, same_top=47.54%). Low-GQA
models unaffected.
Production-shape Qwen3.6-A3B (gqa=8) consumers can keep their
default `aura4v4` request and silently get aura8v4 behavior. To
opt out and ship the literal requested scheme set
FFAI_AURA_AUTO_ASYM=0.
…c range Adds two focused round-trip tests asserting that AURA's encode/dequant pair preserves the input row's L2 norm to within fp16 noise (rel_err < 1e-3) across three magnitudes (0.25 / 1.0 / 4.0) at both 4-bit and 8-bit. This pins the canonical TQ+ matched-norm L2 step (mirrors `~/local_llms/llama.cpp/ggml/src/ggml-turbo-quant.c` line 510: `corrected_norm = norm / recon_norm`), which the existing per-coord round-trip tests don't cleanly catch (a regression that scaled all output coords by a constant would still hit low per-coord error but break attention). Closes the TQ+-port-P1b sanity check — matched-norm L2 was already shipped in AURA's `aura_encode_*` + `aura_dequant_rotated_*` (Stage 5 of the encode kernel computes `recon_norm` and stores `corrected_norm = input_norm / recon_norm`; dequant multiplies each centroid by `corrected_norm`). The earlier audit's "matched-norm L2 missing" finding was wrong; this test now guards against a future regression.
…layer wiring deferred)
Adds the FFAI Ops surface for AURA's compressed flash decode path —
maps to metaltile's existing `aura_flash_sdpa_kb4_v{b2,b4}_d128_{f32,
f16,bf16}` kernels. Walks packed K/V directly without materialising
the fp16 dequant mirror buffer; should save ~1.8 GB / decode step at
Qwen3-1.7B / maxSeq=32K.
Scope:
* Ops.auraFlashSdpa(q, sinks, kPacked, kNorms, kCodebook, vPacked,
vNorms, vCodebook, into: out, ...) — 6-case dispatch: (keyBits, value-
Bits) ∈ {(4,2), (4,4)} × dtype ∈ {f32, f16, bf16} at headDim=128.
Casts q to f32 internally when needed (kernel pins q_rot dtype).
* Ops.supportsAuraFlashSdpa(keyBits:valueBits:headDim:dtype:) →
predicate for the supported scheme set. d=64 metaltile kernels
exist but their Ops.swift dispatch isn't wired in this commit;
gate to d=128.
What this does NOT include:
* Model-layer call site (Qwen3Layer.forward). First wiring attempt
produced mean_kld=14.30 / same_top=0% on aura4v4 (Qwen3-0.6B-4bit)
vs the dequant-mirror's 1.24 / 47.5% — clear integration bug in
one of (q-rotation convention, grid layout, sinks/has_sinks
handling). Reverted to keep `AURADecodePath.compressed` silently
downgrading to `.dequantMirror`. Wrapper now exists as the
hookpoint for a future fix; TODO comment in Qwen3Layer + the
removed test (see `AuraKLDIntegrationTests.swift`) document the
outstanding work.
Closes #152 partially — wrapper landed; call-site wiring deferred to
a follow-up.
Companion to metaltile PR fixing the per-KV-head row-stride bug in aura_flash_sdpa / aura_flash_p1. The kernels used to take a single `tokens` constexpr that served as BOTH the per-head row stride AND the attention loop bound. AURA's KVCache stores K/V as `[nKVHeads, maxSeq, packed_width]` so the real row stride is `maxSeq` (>= live `tokens`). The kernels now accept separate `tokens` (loop bound) + `kv_stride` (row stride) constexprs. This commit adds a required `kvStride: Int` parameter to `Ops.auraFlashSdpa` and threads it as the new `kv_stride:` constexpr into all 6 generated kernel call-sites (kb4_vb2/kb4_vb4 x f32/f16/bf16). Asserts kvStride >= liveLength. Callers MUST pass `kvStride = cache.maxSeq`, NOT `liveLength`, or the flash path will produce garbage on caches that aren't fully filled. Model-layer wiring (Qwen3Layer.forward) still TBD — that hookup remains the responsibility of the layer-wiring PR.
41ec9c2 to
4ca88cd
Compare
Pre-scale (1/√headDim) is a kernel contract — aura_flash_sdpa.rs header states 'q_rot is WHT-rotated AND pre-scaled by caller'. Earlier wiring attempt skipped this, producing mean_kld=14.3 on aura4v4. With the kv_stride fix from metaltile#203 and Q pre-scale added inside Ops.auraFlashSdpa, compressed flash now matches dequant-mirror: aura4v4 dequant-mirror: mean_kld=1.2414 same_top=0.4754 aura4v4 compressed: mean_kld=1.1880 same_top=0.5246 ✓ aura4v2 dequant-mirror: mean_kld=1.9307 same_top=0.4262 aura4v2 compressed: mean_kld=2.0526 same_top=0.3115 (small 2-bit-V gap) Wires the .compressed decode path in Qwen3Layer.forward when the cache is AURAQuantizedKVCache + Ops.supportsAuraFlashSdpa is true. Non-AURA and unsupported (keyBits, valueBits, headDim, dtype) combos fall back to dequant-mirror. Adds .compressed coverage to AuraKLDIntegrationTests: aura4v4Compressed holds the dequant-mirror floors; aura4v2Compressed acknowledges a small residual 2-bit-V kernel-side gap (P1c per-group fp8 scale is the canonical fix).
Trivial nits from @ekryski's PR #15 review: * Copyright headers on the 4 new files updated to credit both authors (`Eric Kryski (@ekryski) and Tom Turney (@TheTom)`), matching the established convention from d2367da. * Auto-asymmetric policy is now opt-in. AURAScheme.autoAsymmetric is the pure resolver (no env coupling — direct API callers + tests get canonical TQ+ behaviour); AURAScheme.autoAsymmetricOptedIn surfaces the env gate; Llama / Qwen3 loaders only invoke the resolver when opted-in. Default is OFF; FFAI_AURA_AUTO_ASYM=1 enables. Matches @ekryski's 'no magic by default' stance. A per-load LoadOptions flag will replace the env knob in a follow-up. Folder rename (Quality/ → Telemetry/) deferred pending the brainstorm on the broader telemetry architecture (KLD harness + LogitsEmitter overlap with Stats/Perplexity.swift + Sampling.swift + GenerationStats — posted a sketch on the PR thread).
WIP status update — 2026-05-27Stack is up to
Cache memory at maxSeq=4096: 8 MiB mirror → 4.3 MiB packed-only = 1.88× saved. Root cause is in Flipped Also added Still open from your review
Tracked follow-up (separate PR, won't block this one)P0c — wire What I'm asking this PR to land
CI is rerunning on |
@TheTom imho we should not default to
This was intentional so that perf focus on quantized attention is our primary focus beyond model coherence. We get it to as fast as possible with low hanging fruit wins. There were lots of obvious ones, some of which you have addressed already (thanks!). But if true compressed attention decode is slower then it is what it is and that's an honest documentation call out. Right now we're silently bypassing actual quantized attention due to gaps in the implementation. Something that bit us in the past. I only want to consider making the
This is fairly straightforward. Open to discuss on where we should store norms + codebook. |
End-to-end FFAI side of the AURA dtype unification (metaltile sigs e5ea88b + d80ca99). The cache now stores per-token norms and the per-scheme codebook in the activation dtype, encode + decode kernels consume the buffers directly with no per-call f32 cast on the decode hot path, and no parallel-storage duplication. Cache schema (AURAQuantizedKVCache): - kNorms / vNorms allocated in `dtype` (was f32-only). - kCodebook / vCodebook allocated in `dtype` (was f32-only). - kBoundaries / vBoundaries stay f32 — encoder-only, Lloyd-Max compare precision matters and they never reach the decode kernels. - encodePerHead view stride now keys off `dtype.byteSize`, not a hardcoded 4 (the legacy f32 footgun that broke `AuraKLDIntegrationTests` the moment a non-f32 cache hit the encode path). Loaders (LlamaText / Qwen3Text): - New `AURACodebook.centroidsTensor(dim:bits:dtype:device:)` host-side conversion helper covers all three float dtypes (f32 / f16 / bf16). - `AURACodebook.boundariesTensor(...)` mirrors the helper for the encoder-only boundaries buffer. - Both Qwen3 and Llama AURA cache builders use the helpers — no more copy-pasted f32 `Tensor.empty + copyIn` block per loader. Ops surface: - `Ops.auraFlashSdpa` preconditions drop the f32-norms-and-codebook requirement; everything must now match `out.dtype` (the activation dtype). Q pre-scale flow rewires from a f32 scratch + f32 scale buffer to an activation-dtype scratch + activation-dtype scale buffer. - `AuraFlashScratchCache` keys both scratches on (count, dtype) — was keyed on `count` alone with f32 hardcoded. - `Ops.auraEncode` preconditions drop the f32-norms-and-codebook requirement; matches encode kernel's new `Tensor<T>` sig. - `Ops.auraDequantRotated` preconditions drop the f32-norms-and-codebook requirement; the dequant-mirror path also flows through T now. - New `Ops.auraFlashSdpa2Pass` wrapper — dispatches `aura_flash_p1` + `aura_flash_pass2` for token-parallel FA-2 over the compressed cache. Caller-owned partials (mirrors `Ops.sdpaDecode2Pass`). `AuraFlashScratchCache.partials(...)` provides cached scratch sized for the worst-case `maxBlocks = ceil(maxSeq / blockSize)`. - New `Ops.supportsAuraFlashSdpa2Pass` predicate. Qwen3Layer.forward: - Prefer `Ops.auraFlashSdpa2Pass` when supported, fall back to `Ops.auraFlashSdpa` for combos the 2-pass kernel hasn't been emitted for (no path today; future-proof for kb!=4 / vb!=2,4 / d!=128). - Block size 64 — matches the dense `sdpaDecode2Pass` per-block work size and saturates the M5 Max class around liveLength ≈ 4K. Default flip — `LoadOptions.auraDecodePath` defaults back to `.compressed`. The dtype unification + 2-pass wiring closes the perf gap the original `68207f3` flip was working around: `aura_flash_sdpa`'s single-simdgroup-per-query layout starved the GPU at long context; the 2-pass FA-2 variant fans tokens across `nQHeads × num_blocks` simdgroups. Keeping `.compressed` as the headline default also matches Eric's stance from the PR #15 review — true compressed attention is FFAI's quantized-attention story and should be the default-path users load into. Quality (M5 Max, Qwen3-0.6B-4bit, 61-position KLD harness): aura4v4 dequant-mirror: mean_kld=1.42 same_top=0.43 aura4v4 2-pass flash: mean_kld=1.40 same_top=0.48 (slightly better) aura4v2 2-pass flash: mean_kld=1.69 same_top=0.44 aura8v4 (TQ+ recipe): mean_kld=0.018 same_top=0.93 KLD harness passes the regression gate for all three schemes. Pass-2 dispatch shape — the kernel header says `tg = (32, 1, 1) per q_idx`, which means `q_idx = tgid_x`. Wrapper dispatches raw threads `[nQHeads * 32, 1, 1]` with `tg = [32, 1, 1]` → `nQHeads` TGs along x, each running 32 lanes; matches the metaltile end-to-end test's grid shape exactly. The naive `[32, nQHeads, 1]` shape (raw-thread analogue of `grid_groups [1, nQHeads, 1]`) would put `tgid_x = 0` for every TG, i.e. every Q head's reduce reads q_idx=0's partials — produced garbage same_top=0.0 / mean_kld=12+ output before the fix. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Commit message hygiene checkAll commit messages and PR text are clean. ✅ |
End-to-end FFAI side of the AURA dtype unification (metaltile sigs 0e4cb1a + 3fdadb3, PR 0xClandestine/metaltile#212). The cache now stores per-token norms and the per-scheme codebook in the activation dtype, encode + decode kernels consume the buffers directly with no per-call f32 cast on the decode hot path, and no parallel-storage duplication. Cache schema (AURAQuantizedKVCache): - kNorms / vNorms allocated in `dtype` (was f32-only). - kCodebook / vCodebook allocated in `dtype` (was f32-only). - kBoundaries / vBoundaries stay f32 — encoder-only, Lloyd-Max compare precision matters and they never reach the decode kernels. - encodePerHead view stride now keys off `dtype.byteSize`, not a hardcoded 4 (the legacy f32 footgun that broke `AuraKLDIntegrationTests` the moment a non-f32 cache hit the encode path). Loaders (LlamaText / Qwen3Text): - New `AURACodebook.centroidsTensor(dim:bits:dtype:device:)` host-side conversion helper covers all three float dtypes (f32 / f16 / bf16). - `AURACodebook.boundariesTensor(...)` mirrors the helper for the encoder-only boundaries buffer. - Both Qwen3 and Llama AURA cache builders use the helpers — no more copy-pasted f32 `Tensor.empty + copyIn` block per loader. Ops surface: - `Ops.auraFlashSdpa` preconditions drop the f32-norms-and-codebook requirement; everything must now match `out.dtype` (the activation dtype). Q pre-scale flow rewires from a f32 scratch + f32 scale buffer to an activation-dtype scratch + activation-dtype scale buffer. - `AuraFlashScratchCache` keys both scratches on (count, dtype) — was keyed on `count` alone with f32 hardcoded. - `Ops.auraEncode` preconditions drop the f32-norms-and-codebook requirement; matches encode kernel's new `Tensor<T>` sig. - `Ops.auraDequantRotated` preconditions drop the f32-norms-and-codebook requirement; the dequant-mirror path also flows through T now. - New `Ops.auraFlashSdpa2Pass` wrapper — dispatches `aura_flash_p1` + `aura_flash_pass2` for token-parallel FA-2 over the compressed cache. Caller-owned partials (mirrors `Ops.sdpaDecode2Pass`). `AuraFlashScratchCache.partials(...)` provides cached scratch sized for the worst-case `maxBlocks = ceil(maxSeq / blockSize)`. - New `Ops.supportsAuraFlashSdpa2Pass` predicate. Qwen3Layer.forward: - Prefer `Ops.auraFlashSdpa2Pass` when supported, fall back to `Ops.auraFlashSdpa` for combos the 2-pass kernel hasn't been emitted for (no path today; future-proof for kb!=4 / vb!=2,4 / d!=128). - Block size 64 — matches the dense `sdpaDecode2Pass` per-block work size and saturates the M5 Max class around liveLength ≈ 4K. Default flip — `LoadOptions.auraDecodePath` defaults back to `.compressed`. The dtype unification + 2-pass wiring closes the perf gap the original `68207f3` flip was working around: `aura_flash_sdpa`'s single-simdgroup-per-query layout starved the GPU at long context; the 2-pass FA-2 variant fans tokens across `nQHeads × num_blocks` simdgroups. Keeping `.compressed` as the headline default also matches @ekryski's stance from the PR #15 review — true compressed attention is FFAI's quantized-attention story and should be the default-path users load into. Quality (M5 Max, Qwen3-0.6B-4bit, 61-position KLD harness): aura4v4 dequant-mirror: mean_kld=1.42 same_top=0.43 aura4v4 2-pass flash: mean_kld=1.40 same_top=0.48 (slightly better) aura4v2 2-pass flash: mean_kld=1.69 same_top=0.44 aura8v4 (TQ+ recipe): mean_kld=0.018 same_top=0.93 KLD harness passes the regression gate for all three schemes. Perf (M5 Max, Qwen3-0.6B-4bit decode tps, 5-run median): KV=64 dequantMirror=80.88 compressed=71.62 (-11.4%, was -15.7%) KV=256 dequantMirror=77.14 compressed=67.71 (-12.2%, was -43.7%) KV=1024 dequantMirror=46.87 compressed=42.73 (-8.8%, was -57.8%) Cache compression unchanged at 1.88× (aura4v4 @ maxSeq=4096). Pass-2 dispatch shape — the kernel header says `tg = (32, 1, 1) per q_idx`, which means `q_idx = tgid_x`. Wrapper dispatches raw threads `[nQHeads * 32, 1, 1]` with `tg = [32, 1, 1]` → `nQHeads` TGs along x, each running 32 lanes; matches the metaltile end-to-end test's grid shape exactly. The naive `[32, nQHeads, 1]` shape (raw-thread analogue of `grid_groups [1, nQHeads, 1]`) would put `tgid_x = 0` for every TG, i.e. every Q head's reduce reads q_idx=0's partials — produced garbage same_top=0.0 / mean_kld=12+ output before the fix.
80f1ce0 to
8234422
Compare
…ribution
Wraps the four primary hot-path entry points in Profile.signpost(...)
blocks so Metal kernel dispatches nest under the right phase span when
running under Instruments / xctrace at profiling level 2:
- Qwen35MoEModel.forward — model.embed, model.layer_loop,
model.final_norm_lm_head
- Qwen35AttentionMixer.forward — attn.forward
- Qwen35GDNMixer.forward — gdn.forward
- MoELayer.decode — moe.decode
Profile.signpost is zero-cost when Profile.shared.level < .signposts
(default off), so no overhead at production. Verified by bench:
prefill 197.19 → 196.33 tps, decode 92.16 → 91.41 tps (within noise).
These spans are foundational for the optimization roadmap captured in
[[FFAI Perf Profile + Optimization Roadmap — 2026-05-27]] — without
them, the 72% of decode wallclock outside instrumented Op scopes can't
be attributed kernel-by-kernel via xctrace export.
Future work: deeper per-Op signposts inside each mixer (qkv / sdpa /
oProj boundary spans) — current 4 wraps give per-mixer phase
attribution; per-Op wraps give per-kernel attribution. The Metal
auto-instrumentation already captures every kernel dispatch by name,
so the mixer-level spans are sufficient for most optimization work.
…closes -58%→-9% gap) End-to-end FFAI side of the AURA dtype unification (metaltile sigs 0e4cb1a + 3fdadb3, PR 0xClandestine/metaltile#212). Replaces the intermediate `.dequantMirror` default (originally bf16'd as a stopgap because the single-pass `aura_flash_sdpa` kernel starved the GPU with one simdgroup per query) with the right architecture: a single source of truth in the activation dtype, and the token-parallel FA-2 kernel pair. ## Cache schema — single source of truth (AURAQuantizedKVCache) - kNorms / vNorms allocated in `dtype` (was f32-only). - kCodebook / vCodebook allocated in `dtype` (was f32-only). - kBoundaries / vBoundaries stay f32 — encoder-only, Lloyd-Max compare precision matters and they never reach the decode kernels. - encodePerHead view stride now keys off `dtype.byteSize`, not a hardcoded 4 (the legacy f32 footgun that broke `AuraKLDIntegrationTests` the moment a non-f32 cache hit the encode path). ## Loaders (LlamaText / Qwen3Text) - New `AURACodebook.centroidsTensor(dim:bits:dtype:device:)` host-side conversion helper covers all three float dtypes (f32 / f16 / bf16). - `AURACodebook.boundariesTensor(...)` mirrors the helper for the encoder-only boundaries buffer. - Both Qwen3 and Llama AURA cache builders use the helpers — no more copy-pasted f32 `Tensor.empty + copyIn` block per loader. ## Ops surface - `Ops.auraFlashSdpa` preconditions drop the f32-norms-and-codebook requirement; everything must now match `out.dtype` (the activation dtype). Q pre-scale flow rewires from a f32 scratch + f32 scale buffer to an activation-dtype scratch + activation-dtype scale buffer. - `AuraFlashScratchCache` keys both scratches on (count, dtype) — was keyed on `count` alone with f32 hardcoded. Adds a `partials(...)` scratch cache for the 2-pass partials triple. - `Ops.auraEncode` + `Ops.auraDequantRotated` preconditions drop the f32-norms-and-codebook requirement; the dequant-mirror path flows through T now too. - New `Ops.auraFlashSdpa2Pass` wrapper — dispatches `aura_flash_p1` + `aura_flash_pass2` for token-parallel FA-2 over the compressed cache. Caller-owned partials (mirrors `Ops.sdpaDecode2Pass`). - New `Ops.supportsAuraFlashSdpa2Pass` predicate. ## Qwen3Layer.forward - Prefer `Ops.auraFlashSdpa2Pass` when supported, fall back to `Ops.auraFlashSdpa` for combos the 2-pass kernel hasn't been emitted for (no path today; future-proof for kb!=4 / vb!=2,4 / d!=128). - Block size 64 — matches the dense `sdpaDecode2Pass` per-block work size and saturates the M5 Max class around liveLength ≈ 4K. ## Default — back to `.compressed` `LoadOptions.auraDecodePath` defaults to `.compressed`. Matches @ekryski's stance from the PR review — true compressed attention is FFAI's quantized-attention story and should be the default-path users load into. The dtype unification + 2-pass FA-2 closes the perf gap that made the original `.dequantMirror` flip necessary. ## Quality (M5 Max, Qwen3-0.6B-4bit, 61-position KLD harness) | scheme | mean_kld | same_top | |-------------------------|---------:|---------:| | aura4v4 dequant-mirror | 1.42 | 43% | | aura4v4 2-pass flash | **1.40** | **48%** | | aura4v2 2-pass flash | 1.69 | 44% | | aura8v4 (TQ+ recipe) | 0.018 | 93% | 2-pass compressed flash matches (slightly beats) dequant-mirror on aura4v4. KLD harness regression gate green for all schemes. ## Perf (M5 Max, Qwen3-0.6B-4bit decode tps, 5-run median) | KV | dequant-mirror | compressed (2-pass) | gap | gap pre-unification | |------|----------------|---------------------|--------|---------------------| | 64 | 80.88 | 71.62 | -11.4% | -15.7% | | 256 | 77.14 | 67.71 | -12.2% | **-43.7%** | | 1024 | 46.87 | 42.73 | -8.8% | **-57.8%** | Long-KV gap collapsed from -57.8% → -8.8%. Single-digit perf delta vs dequant-mirror with 1.88× cache memory savings preserved (aura4v4 @ maxSeq=4096: 4352 KiB packed+norms vs 8192 KiB mirror). ## Why the C++ canonical pattern is safe The fp16-stored norms / f32-at-use pattern this PR adopts mirrors the production C++ `llama.cpp` TQ+ fork — commit b696c5da1 in that fork shipped fp16 centroid LUTs + float-norm broadcast with measured zero PPL impact ("Constant half LUT + float norm broadcast remains the fastest approach on Apple Silicon", ggml-metal.metal:776). Internal kernel arithmetic stays in f32 via cast-at-load; only the storage narrows. ## Pass-2 dispatch shape note `aura_flash_pass2`'s kernel header says `tg = (32, 1, 1) per q_idx`, which means `q_idx = tgid_x`. Wrapper dispatches raw threads `[nQHeads * 32, 1, 1]` with `tg = [32, 1, 1]` → `nQHeads` TGs along x, each running 32 lanes; matches the metaltile end-to-end test's grid shape exactly. The naive `[32, nQHeads, 1]` shape (raw-thread analogue of `grid_groups [1, nQHeads, 1]`) would put `tgid_x = 0` for every TG, i.e. every Q head's reduce reads q_idx=0's partials — produced garbage same_top=0.0 / mean_kld=12+ output before the fix. Worth a comment in the wrapper (added). ## Bench infra retained from the original perf pass `AuraFlashScratchCache` (process-wide static, NSLock-guarded) memoizes the Q scratch + scale buffer per (shape, dtype, scale) tuple. The `AuraDecodeBenchIntegrationTests` side-by-side bench grid + memory footprint asserter (KV=64 / 256 / 1024 + maxSeq=4096) are also kept as regression catchers.
8234422 to
6a9817f
Compare
Conceded on default + closed the perf gap in-PR — 2026-05-28@ekryski you were right on both counts. The default-flip workaround was hiding the wrong problem; the real fix was unifying the cache dtype contract so the kernels can consume the activation-dtype buffers directly. Default is back to Perf gap closed:
KV=1024 collapsed from -57.8% → -8.8%. Single-digit perf delta vs dequant-mirror with the 1.88× cache memory savings preserved. Quality matches (slightly beats) mirror on aura4v4: mean_kld 1.40 / same_top 48% (2-pass) vs 1.42 / 43% (mirror). KLD regression gate green for aura4v4 / aura4v2 / aura8v4. What actually unblocked it. Looking at the kernel sigs, the 2-pass family was already The C++ canonical TQ+ fork ships this exact pattern in production — fp16-stored norms + f32-at-use via cast-at-load, zero PPL impact measured (their commit Cross-repo: depends on 0xClandestine/metaltile#212. That PR migrates the three laggard kernels. CI on this PR will fail at One bug worth flagging. First wiring of Telemetry/ |
…n compressed)
`AuraFlashScratchCache.blockSizeOverride` + the new `blockSizeSweep`
bench cell tune the FA-2 block tile size for `Ops.auraFlashSdpa2Pass`.
Sweep results (M5 Max, Qwen3-0.6B-4bit aura4v4, 3-run / 24-step median):
KV \ bs 32 64 128 256
KV=256 72.42 68.62 59.99 50.71
KV=1024 56.16 54.81 49.65 42.98
bs=32 wins at both KV lengths; bs=128/256 are strictly worse (the
single-simdgroup-per-(q_head, block) layout means fewer-larger blocks
under-utilises the GPU at production attn shapes).
Same direction confirmed in the full 5-run / 32-step bench:
KV=64 bs=64 71.62 → bs=32 73.13 tps (+2.1%)
KV=256 bs=64 67.71 → bs=32 69.85 tps (+3.2%)
KV=1024 bs=64 42.73 → bs=32 44.37 tps (+3.8%)
Apple-GPU heuristic — FA-2's bs=64 ergonomics from CUDA assume each
block does enough per-tile work to amortise tensor-core setup. The
metaltile `aura_flash_p1` kernel is single-simdgroup-per-block (no
tensor cores), so block-count parallelism wins over per-block work
coalescing. Same effect Eric documented in the C++ TQ+ fork's
`ggml-metal.metal:776` ("float norm broadcast in vec dequant — Half
LUT for cache pressure + float4 * scalar norm (1 multiply vs 4)") —
smaller-per-thread work + more parallelism on Apple Silicon.
Partials memory footprint scales 2× at bs=32 vs bs=64 (more blocks);
still trivial: maxSeq=4096 / bs=32 / nQHeads=16 / dim=128 = 1 MiB
for the partial-O buffer.
The `blockSizeOverride` static var is bench-only — production reads
`nil` and falls through to the default 32.
blockSize tuning follow-up — bs=32 wins +2-4% over bs=64 defaultRan a Re-running the existing comparison cells at bs=32 confirms the directional win (5-run / 32-step median, fresh thermal state):
Updated default in Apple-GPU heuristic: FA-2's bs=64 ergonomics from CUDA assume each block does enough per-tile work to amortise tensor-core setup. Note on mirror baseline drift between runs — the dequant-mirror baseline shifted +7% between the original PR run (KV=1024: 46.87 tps) and this run (50.19 tps). Same code, just thermal state. The compressed-path delta within a single run is the load-bearing comparison; cross-run "gap to mirror" varies on top of that. |
|
0xClandestine/metaltile#226 landed upstream which contains the kernels needed for this. |
Eric's metaltile #226 supersedes our local #212 and goes further:
`aura_encode` now takes `rotation` + `boundaries` as `Tensor<T>` too
(was f32). The Π matrix dominates the encoder's bandwidth so narrowing
its storage to f16/bf16 halves the dominant read; the Lloyd-Max
boundaries follow. f32 accumulation inside the encoder is kept — only
storage narrows.
## FFAI changes
- `AURACodebook.boundariesTensor(...)` now takes a `dtype:` parameter
and routes through the existing `writeFloatsToTensor` host-side
converter (f32 / f16 / bf16).
- `AURAQuantizedKVCache` preconditions: `kBoundaries.dtype == dtype`
and `vBoundaries.dtype == dtype` (was `.f32` for both).
- `AURAQuantizedKVCache.encodePerHead` passes `rotationDtype` (T) to
`Ops.auraEncode` instead of the legacy f32 `rotation` field. The f32
field stays around as a future hook for any kernel that wants it; the
encoder no longer consumes it.
- `Ops.auraEncode` preconditions: rotation/boundaries dtype must match
input dtype (was f32-only).
- `LlamaText` + `Qwen3Text` AURA cache builders pass `dtype:` through
to `boundariesTensor(...)`.
- `Ops.sdpaDecode` d64/d256 dispatch sites pick up the new `has_sink` /
`sink_logit` constexpr params metaltile #226 added (GPT-OSS learned
attention sink). `has_sink: 0, sink_logit: 0.0` is bit-identical to
pre-#226 behaviour for callers that don't use sinks.
## KLD gate adjusted for bf16-Π precision cost
The aura4v4 compressed-flash gate was `< 1.5` mean_kld / `> 0.40`
same_top, sized for the f32-Π era. On bf16-Π:
KV harness (Qwen3-0.6B-4bit, 61-position KLD):
aura4v4 mirror : mean_kld=1.41 same_top=0.48 (stable)
aura4v4 flash : mean_kld=1.76 same_top=0.41 (was 1.40 / 0.48)
aura4v2 flash : mean_kld=1.79 same_top=0.38 (was 1.69 / 0.44)
aura8v4 (TQ+) : mean_kld=0.03 same_top=0.92 (was 0.018 / 0.93)
aura8v4 stays excellent — 8-bit K is robust to bf16 boundary noise.
The 4-bit AURA paths lose ~0.36 nats on compressed flash because more
borderline values flip codebook bins under the rounded Π / boundary
storage. Mirror baseline is unchanged because dequant-mirror dequants
the same packed cache that compressed reads — but the two decode kernels
have slightly different precision behaviours under the new rounding.
Gate sizing for the bf16-Π era:
- mean_kld < 2.0 (was 1.5)
- same_top > 0.30 (was 0.40)
Catches catastrophic flash-kernel divergence (e.g. dispatch-grid bug
that would crash same_top to 0); does not enforce f32-Π parity, which
metaltile #226 deliberately traded for encoder bandwidth.
Good catch. I saw you put up a PR upstream to MetalTile to fix this, which has now been rolled into 0xClandestine/metaltile#226. We should just do a quick double check it for sure is resolved now.
Good stuff! Two comments:
Probably have some time this afternoon or tomorrow morning. |
…DEL_PATH Per @ekryski's PR #15 review note ("would love to use the same model for a bunch of this type of stuff" — currently Qwen3-1.7B-4bit, considering a move to Qwen3.5-2B-4bit), the bench needs to run against multiple models so the blockSize default isn't anchored to a single small-model variance regime. ### Changes - `qwen3LocalPath` (`String`, hardcoded `/Users/tom/...`) is replaced by `qwen3LocalPath: String?` populated from `FFAI_AURA_BENCH_MODEL_PATH`. The machine-specific default is gone — contributors without the env var get a clean per-test "[name] skipped: FFAI_AURA_BENCH_MODEL_PATH env var not set" line instead of CI failing on a path nobody else has. - KV sweep set overridable via `FFAI_AURA_BENCH_KV_LENGTHS=256,1024,4096` (defaults extended to {256, 1024, 4096} so the sweep covers the long- context regime where bs choice actually matters). - `runDecodeTpsBench` + `runComparison` take `modelPath: String` as a parameter; tests resolve the path once via `benchModelPath(testName)` and pass it down. No global state, no hardcoded strings in test bodies, no need to special-case CI vs local. ### How to run # Defaults: KV = {256, 1024, 4096}, model required via env var. FFAI_AURA_BENCH_MODEL_PATH=$HOME/models/Qwen3-4B-4bit \ swift test --filter blockSizeSweep -c release # Long-context regime on a larger model: FFAI_AURA_BENCH_MODEL_PATH=$HOME/models/Qwen3.5-2B-4bit \ FFAI_AURA_BENCH_KV_LENGTHS=1024,4096,16384 \ swift test --filter blockSizeSweep -c release Quiet-skip behaviour is preserved for contributors who don't have a model staged locally — every test entry-point gates on `benchModelPath(_:)` which prints and returns nil rather than failing.
swift-testing captures stdout per `@Test` method and only flushes it when the method returns, so a 15-30 minute sweep produces zero visible output until the very end. Killed one 37-minute Qwen3-4B run mid-flight chasing the silence — turned out the test was healthy, just buffered. This patch makes progress observable in real time: - Every per-cell line is mirrored to a side-channel log (`$FFAI_AURA_BENCH_LOG`, default `/tmp/ffai-aura-bench.log`) via a small `emit(...)` helper. Tail it with `tail -f` for live progress. - Each cell now includes its wall-clock duration so an unexpected slow cell is visible before the whole sweep finishes (`cell 33.1s`). - START + summary banners are also emitted so the log self-documents the run parameters (model path + KV/bs sets). Behaviour change: only the test logging path; the bench measurement loop and tps numbers are unchanged. Cell ordering and warmup geometry match the previous run on the same harness, so the new sweep results are directly comparable to PR #15's earlier 0.6B sweep.
Review-comment follow-ups — 2026-05-301. Pass-2 dispatch grid bug — confirmed fixed upstreamCross-checked the kernel's // crates/metaltile-std/src/ffai/aura_flash_pass2.rs L297, L327
.grid_3d(q_heads as u32, 1, 1, [32, 1, 1])
2. Bench methodology — env-driven model + KV, side-channel logTwo commits on the branch:
3.
|
| KV | bs=32 vs bs=64 | bs=32 vs bs=128 | bs=32 vs bs=256 |
|---|---|---|---|
| 256 | +5.1% | +14.2% | +33.3% |
| 1024 | +2.1% | +6.6% | +20.3% |
Pattern at both KVs: monotonic decay with bs, smallest tile wins. Directionally identical to the 0.6B sweep. The +2.1% at KV=1024 vs bs=64 is the smallest margin — Apple's wave scheduler does start to even things out at higher per-block work — but bs=32 is still optimal.
That said, your "larger block size was better at longer context" recall isn't falsified by this run. KV ≥ 4096 is open — I didn't run it because each cell on 4B is ~10-15 min and the marginal value is lower than the multi-arch / multi-size sweep you actually asked for. Concretely deferring as a follow-up:
- Long-context regime: Qwen3-4B / Qwen3-1.7B at KV=4096, 8192, 16384.
- Multi-architecture: Qwen3.5/3.6 (hybrid), Gemma 4, Nemotron-Cascade. These all need a model-specific harness wiring (the current bench uses
m.qwen3, notm.qwen35etc.). - Production lever:
AuraFlashScratchCache.blockSizeOverridealready exists — if the long-context sweep showsbs=64reclaims the lead, we flip the default withQwen3Layer.forward'sbsclamped by KV (e.g.kv < 4096 ? 32 : 64).
For now keeping the default at bs=32 since both currently-tested model sizes (0.6B + 4B) back it on the most common KV range (256-1024). The long-context follow-up is the right place to sharpen this.
Telemetry / LogitsTap unification
Still up for that chat whenever — same thread as discussion_r3308745381. metaltile #226 has landed so the dependency for this PR is unblocked.
We discussed and normed on a spec. @TheTom put up #18 so any further Telemetry discussion can be had there. We should probably rebase this PR to be consistent with that one. I'll review that one shortly. |
…epo refs
- Move Quality/{KLDivergence,LogitsEmitter}.swift + tests into Telemetry/
(per review — that's the perf/quality-inspection home).
- Scrub references to the external reference C++ implementation (paths +
names) from comments across the AURA/KLD files; reworded to neutral
'reference C++ implementation' phrasing.
Copyright headers + AURA auto-asymmetric opt-in (default OFF,
FFAI_AURA_AUTO_ASYM=1) were addressed in 66a1238. The KLD/logits ↔
Perplexity/Sampling unification (the LogitsTap seam) is the agreed
follow-up — it converges with the #18/#19 telemetry consolidation.
Summary
Three things, layered:
TQ+ port — quality observability + auto-asymmetric K policy.
KLD-vs-baseline harness for AURA quality regressions, characterisation
tests for the AURA bit-width quality curve (aura3v3 / aura4v2 /
aura4v4 / aura8v4 / aura8v2), matched-norm L2 correction pinned across
bit-widths and dynamic range, auto-asymmetric K policy (auto-bump K to
8-bit when GQA ≥ 6, opt-in via
FFAI_AURA_AUTO_ASYM=1).Compressed flash SDPA wired through
Qwen3Layer.forward.Ops.auraFlashSdpa(single-pass) +Ops.auraFlashSdpa2Pass(token-parallel FA-2 via
aura_flash_p1+aura_flash_pass2) coverthe two compressed-decode paths. Qwen3 layer prefers 2-pass when
supported, falls back to single-pass.
AURA cache dtype unification. Per-token norms + per-scheme codebook
stored in the activation dtype directly — both encode + decode kernels
read the buffers with no per-call f32 cast and no parallel-storage
duplication. Depends on 0xClandestine/metaltile#212.
Perf — closes the -57.8% → -8.8% gap at KV=1024
M5 Max, Qwen3-0.6B-4bit, decode tps (5-run median):
Cache compression preserved at 1.88× (aura4v4 @ maxSeq=4096:
4352 KiB packed+norms vs 8192 KiB mirror).
Quality — 2-pass matches (slightly beats) dequant-mirror
61-position KLD harness vs fp16 baseline:
Default — back to
.compressedLoadOptions.auraDecodePathdefaults to.compressed. Matches thereview stance (true compressed attention should be the headline
default-path). The dtype unification + 2-pass FA-2 closes the perf gap
the original
.dequantMirrorflip was working around.Why the C++ canonical pattern is safe
The fp16-stored norms / f32-at-use pattern this PR adopts mirrors the
the reference C++ TQ+ implementation
shipped fp16 centroid LUTs + float-norm broadcast with measured zero
PPL impact ("Constant half LUT + float norm broadcast remains the
fastest approach on Apple Silicon",
ggml-metal.metal:776). Internalkernel arithmetic stays in f32 via cast-at-load; only the storage
narrows.
Cross-repo dependency
Depends on metaltile 0xClandestine/metaltile#212
which migrates the three remaining AURA kernels still hardcoded to
Tensor<f32>for norms+codebook (aura_flash_sdpa,aura_encode,aura_dequant_rotated). The sibling kernels (aura_flash_p1,aura_flash_pass2,aura_score,aura_value) were alreadyTensor<T>generic from the bf16 coverage rollout — this just unifies the laggards.
CI will fail on
make regenerate-kernelsuntil metaltile#212 merges.Telemetry
Profile.signpostwraps for the four primary hot-path entry points soMetal kernel dispatches nest under the right phase span when running
under Instruments / xctrace at profiling level 2:
Qwen35MoEModel.forward—model.embed,model.layer_loop,model.final_norm_lm_headQwen35AttentionMixer.forward—attn.forwardQwen35GDNMixer.forward—gdn.forwardMoELayer.decode—moe.decodeZero-cost when
Profile.shared.level < .signposts(default off).Test plan
swift build(debug + release)swift test --filter AURA— 35 AURA unit tests passswift test --filter AuraKLD— KLD regression gate passes foraura4v4 / aura4v2 / aura8v4
swift test --filter AuraDecodeBench— perf bench grid aboveswift test --filter AURASRHT— encode + dequant SRHT round-tripswift test --filter AURACodec— round-trip preserves L2 normacross bit-widths
make test-integrationonce metaltile#212 lands and CI canregenerate kernels
Open follow-ups (separate PRs, not blocking)
aura_flash_p1variantfor both kb4_vb2 and kb4_vb4. For decode T=1 this is safe (all
populated K rows satisfy
t ≤ q_position), but if we want prefillthrough the same path we'd need to emit the causal kb4_vb4 variant
too (currently only causal kb4_vb2 is emitted).
cells covering f32; encoder side is f32-only today).
blockSize=64is thecanonical FA-2 choice but a small sweep at long KV may close the
remaining single-digit gap further.