Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,18 @@ regenerate-kernels: ## run `tile build --emit` to regenerate metallib + Swift wr
{ printf '// swift-format-ignore-file\n//\n'; cat "$$gen"; } > "$$tmp" && mv "$$tmp" "$$gen"; \
fi

# ─── Bench ────────────────────────────────────────────────────────────
# Quality benches (wikitext2 perplexity / KLD, NIAH retrieval) forced-decode
# an entire corpus. In debug that's far too slow for the numbers to mean
# anything (design §7), and the CLI hard-refuses a debug quality bench. Always
# go through release. Pass-through args via ARGS=:
#
# make bench ARGS="--method wikitext2 --model <repo> --wikitext2-corpus wiki.test.raw"
# make bench ARGS="--method niah --model <long-ctx-repo>"
.PHONY: bench
bench: regenerate-kernels ## run `ffai bench` in release (quality benches require it)
swift run -c release ffai bench $(ARGS)

# ─── Test ─────────────────────────────────────────────────────────────
#
# Production-parity defaults. The 2026-05-19 GPU-pin root cause —
Expand Down
28 changes: 21 additions & 7 deletions Sources/FFAI/Benchmark/BenchMethod.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
// knows about. Mirrors mlx-swift-lm's `MLX_BENCH_METHOD` set so the
// reports are cross-comparable.
//
// Implemented today: simple, summarization, wikitext2
// Plumbed but not implemented yet: niah, multiTurn, toolCalling,
// Implemented today: simple, summarization, wikitext2, niah
// Plumbed but not implemented yet: multiTurn, toolCalling,
// ngramSpot, ngramSweep,
// ngramSweepSummary, vision
//
Expand Down Expand Up @@ -45,9 +45,25 @@ public enum BenchMethod: String, Sendable, CaseIterable {
/// non-zero rather than silently producing garbage.
public var isImplemented: Bool {
switch self {
case .simple, .summarization, .wikitext2:
case .simple, .summarization, .wikitext2, .niah:
return true
case .niah, .multiTurn, .toolCalling,
case .multiTurn, .toolCalling,
.ngramSpot, .ngramSweep, .ngramSweepSummary, .vision:
return false
}
}

/// `true` for methods that forced-decode (or long-context retrieval-
/// score) an entire corpus and so publish *quality* numbers. These
/// are the ones that must run on a release build — in debug the
/// per-token decode loop is slow enough that the numbers are
/// meaningless as a throughput reference and painfully slow to
/// produce. See `planning/telemetry-quality-metrics-design.md` §7.
public var isQualityMetric: Bool {
switch self {
case .wikitext2, .niah:
return true
case .simple, .summarization, .multiTurn, .toolCalling,
.ngramSpot, .ngramSweep, .ngramSweepSummary, .vision:
return false
}
Expand All @@ -74,10 +90,8 @@ public enum BenchMethod: String, Sendable, CaseIterable {
/// today problem or a future-phase problem.
public var dependency: String? {
switch self {
case .simple, .summarization, .wikitext2:
case .simple, .summarization, .wikitext2, .niah:
return nil
case .niah:
return "sliding-window attention mask + needle-position bookkeeping"
case .multiTurn:
return "ChatSession-style multi-turn cache reuse helper"
case .toolCalling:
Expand Down
135 changes: 127 additions & 8 deletions Sources/FFAI/Benchmark/BenchRunner.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,20 @@ public struct BenchOptions: Sendable {
public var quantization: String?
public var wikitext2Corpus: URL?
public var wikitext2MaxTokens: Int
/// Sliding `n_ctx` window for PPL/KLD scoring (design §6.1). `0` keeps
/// the legacy single-pass scoring (every token from position 0 of one
/// growing cache); `> 0` strides by `window/2` and scores only each
/// window's second half so scored tokens carry real left-context.
public var wikitext2ContextWindow: Int
public var referenceModel: Model?
/// Phase B of two-phase KLD: score the candidate against a reference
/// distribution cached on disk by a prior `--save-ref-logits` run,
/// instead of co-loading a reference model (design §6.3).
public var referenceLogits: URL?
/// Phase A of two-phase KLD: dump this model's per-position log-probs
/// to disk so future candidates can KLD against it without re-running
/// the reference. Use with the bf16 reference as `--model`.
public var saveReferenceLogits: URL?

public init(
prompt: String? = nil,
Expand All @@ -60,15 +73,21 @@ public struct BenchOptions: Sendable {
quantization: String? = nil,
wikitext2Corpus: URL? = nil,
wikitext2MaxTokens: Int = 2048,
referenceModel: Model? = nil
wikitext2ContextWindow: Int = 0,
referenceModel: Model? = nil,
referenceLogits: URL? = nil,
saveReferenceLogits: URL? = nil
) {
self.prompt = prompt
self.maxTokens = maxTokens
self.contextSize = contextSize
self.quantization = quantization
self.wikitext2Corpus = wikitext2Corpus
self.wikitext2MaxTokens = wikitext2MaxTokens
self.wikitext2ContextWindow = wikitext2ContextWindow
self.referenceModel = referenceModel
self.referenceLogits = referenceLogits
self.saveReferenceLogits = saveReferenceLogits
}
}

Expand All @@ -95,6 +114,8 @@ public struct BenchRunner {
return try await runSummarization(options: options)
case .wikitext2:
return try await runWikiText2(options: options)
case .niah:
return try await runNIAH(options: options)
default:
throw BenchRunnerError.notImplemented(
method: method,
Expand Down Expand Up @@ -167,22 +188,56 @@ public struct BenchRunner {
if tokens.count > options.wikitext2MaxTokens {
tokens = Array(tokens.prefix(options.wikitext2MaxTokens))
}
Debug.log(.bench, "wikitext2: scoring \(tokens.count) tokens")
// BOS handling (design §6.2): prepend the model's <bos> once so the
// first scored token has a real left-context anchor. Idempotent —
// skip when the tokenizer already emitted one or the family is not
// BOS-critical.
if model.engine.requiresLeadingBOS,
let bos = model.tokenizer.bosTokenId ?? model.config.bosTokenId,
tokens.first != bos
{
tokens.insert(bos, at: 0)
}
let window = options.wikitext2ContextWindow
Debug.log(
.bench,
"wikitext2: scoring \(tokens.count) tokens, contextWindow=\(window == 0 ? "off" : String(window))"
)

// Capture a pre-PPL memory snapshot so the report still
// surfaces realistic numbers (no decode loop populates them
// for this method otherwise).
let memTracker = PhaseMemoryTracker()
let pplResult = Perplexity.compute(model: model, tokens: tokens)
let pplResult = Perplexity.compute(
model: model, tokens: tokens, contextWindow: window)
memTracker.endPrefill()
memTracker.endDecode()

// Phase A: dump this model's reference distribution and stop — no
// candidate KLD here (this run's --model *is* the reference).
if let saveURL = options.saveReferenceLogits {
let n = try ReferenceLogitCache.write(
reference: model, tokens: tokens,
contextWindow: window, to: saveURL)
print("[BENCH] saved \(n) reference log-prob rows → \(saveURL.path)")
}

var kld: Double?
if let ref = options.referenceModel {
let kldResult = Perplexity.klDivergence(
reference: ref, candidate: model, tokens: tokens
)
kld = kldResult.meanKLDivergence
var kldResult: Perplexity.KLDResult?
if let refLogits = options.referenceLogits {
// Phase B: cached reference, no second model in memory.
kldResult = try ReferenceLogitCache.klDivergence(
candidate: model, tokens: tokens,
referenceFile: refLogits, contextWindow: window)
} else if let ref = options.referenceModel {
// Live paired KLD (both models resident).
kldResult = Perplexity.klDivergence(
reference: ref, candidate: model,
tokens: tokens, contextWindow: window)
}
if let r = kldResult {
kld = r.meanKLDivergence
printKLDDistribution(r)
}

let weightsBytes = model.engine.parameters().reduce(0) { $0 + $1.1.byteCount }
Expand Down Expand Up @@ -211,4 +266,68 @@ public struct BenchRunner {
genKLDivergence: kld
)
}

// MARK: - NIAH — needle-in-a-haystack retrieval

/// Runs the needle-in-a-haystack retrieval bench and packages the outcome
/// as a `BenchRow`.
///
/// Prints one line per trial (context tokens, depth fraction, PASS/miss,
/// answer preview) followed by the summary line, then builds a
/// `GenerationStats` carrying only the memory / weight figures.
///
/// - Parameter options: Bench options; `contextSize` (when set) pins the
///   run to a single context length, and `quantization` is copied into
///   the returned row.
/// - Returns: A row whose `outputPreview` is the NIAH summary line.
private func runNIAH(options: BenchOptions) async throws -> BenchRow {
    // An explicit --ctx selects exactly one context length; without it the
    // default grid below is swept. (NIAH.run is expected to clamp these to
    // the model's max sequence — confirm in NIAH.run.)
    let contextLengths = options.contextSize.map { [$0] } ?? [1024, 2048, 4096]

    // Bracket the whole retrieval run with the memory tracker and close
    // both phases immediately afterwards so the row has memory figures.
    let tracker = PhaseMemoryTracker()
    let outcome = NIAH.run(model: model, contextLengths: contextLengths)
    tracker.endPrefill()
    tracker.endDecode()

    print("[BENCH] niah \(modelLabel)")
    print(" ctx-tokens depth result answer")
    for trial in outcome.trials {
        // Right-align the token count in a 9-character column.
        let tokenText = String(trial.contextTokens)
        let padWidth = max(0, 9 - tokenText.count)
        let paddedTokens = String(repeating: " ", count: padWidth) + tokenText
        let depthText = String(format: "%.2f", trial.depthFraction)
        let verdictText: String
        if trial.passed {
            verdictText = "PASS"
        } else {
            verdictText = "miss"
        }
        print(" \(paddedTokens) \(depthText) \(verdictText) \(trial.answerPreview)")
    }
    print(" \(outcome.summaryLine)")

    let weightsBytes = model.engine.parameters().reduce(0) { $0 + $1.1.byteCount }
    // The largest trial context stands in for the prompt size (0 when there
    // were no trials). Timing and perplexity / KLD fields are zeroed or nil
    // for this method.
    let largestContext = outcome.trials.map(\.contextTokens).max() ?? 0
    let stats = GenerationStats(
        promptTokens: largestContext,
        generatedTokens: 0,
        contextSize: model.engine.maxSeq,
        prefillTimeS: 0,
        decodeTimeS: 0,
        timeToFirstTokenMs: 0,
        steadyTokensPerSecond: nil,
        baselineGPUBytes: tracker.baseline.gpuBytes,
        postPrefillGPUBytes: tracker.postPrefill?.gpuBytes ?? 0,
        postDecodeGPUBytes: tracker.postDecode?.gpuBytes ?? 0,
        prefillPeakGPUBytes: tracker.prefillPeakBytes,
        decodePeakGPUBytes: tracker.decodePeakBytes,
        wiredTicketBytes: tracker.baseline.wiredTicketBytes,
        weightsBytes: weightsBytes,
        kvCacheAllocatedBytes: 0,
        kvCacheUsedBytes: 0,
        thinkPerplexity: nil,
        genPerplexity: nil,
        thinkKLDivergence: nil,
        genKLDivergence: nil,
        thinkTokenCount: nil,
        genTokenCount: nil
    )
    return BenchRow(
        model: modelLabel,
        method: BenchMethod.niah.rawValue,
        quantization: options.quantization,
        stats: stats,
        outputPreview: outcome.summaryLine,
        genPerplexity: nil,
        genKLDivergence: nil
    )
}

/// Dumps the per-position KLD distribution and the top-1 agreement rate
/// to stdout.
///
/// The mean on its own can mask tail distortion; p99, max and the top-1
/// agreement figure are the readable "is this quant broken" indicators
/// (design §6.4). Nothing is printed when the result carries no
/// distribution.
///
/// - Parameter r: The KLD result whose `distribution` is printed.
private func printKLDDistribution(_ r: Perplexity.KLDResult) {
    guard let dist = r.distribution else { return }
    // Build every output line first, then emit them in order.
    let lines: [String] = [
        " KLD distribution over \(r.scoredTokens) positions (nats):",
        String(format: " mean %.4f", dist.mean),
        String(format: " median %.4f", dist.median),
        String(format: " p90 %.4f", dist.p90),
        String(format: " p99 %.4f", dist.p99),
        String(format: " max %.4f", dist.max),
        // Agreement is multiplied by 100 here so it prints as a percentage.
        String(format: " top-1 agreement %.2f%%", dist.top1Agreement * 100),
    ]
    for line in lines {
        print(line)
    }
}
}
43 changes: 40 additions & 3 deletions Sources/FFAI/Generation/Generate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,39 @@ extension Model {
var rng = params.makeRNG()
var tokenHistory = promptTokens // grows as decode produces tokens; used by cpu-sample

// ─── Live quality-metric capture (design §5) ─────────────────────
// The single InspectTap drives this. `FFAI_TELEMETRY=ppl` opts into
// accumulating the model's self-perplexity over its own stream — the
// NLL of each token it actually picks. Off by default: when the flag
// is unset `capturePPL` is false and the hot path is byte-for-byte
// the existing fused-kernel sampler (no logit readback, no softmax).
//
// `.kld` / `.niah` are not actionable in the live decode loop — KLD
// needs a paired reference and NIAH is its own retrieval harness; both
// run through `ffai bench`. The tap flags still gate them there.
let tap = InspectTap.fromEnvironment
let capturePPL = tap.captures(.perplexity)
var liveNLLSum = 0.0
var liveNLLCount = 0

func sampleNext(tokenId t: Int, position i: Int) -> Int {
// Capturing path: route through full logits so we can read the
// chosen token's probability. Only taken when the flag is on.
if capturePPL {
let logits = engine.forward(tokenId: t, position: i, caches: caches)
let chosen: Int
switch path {
case .greedyGPU:
chosen = Perplexity.argmaxLogits(logits)
case .gpuCategorical, .cpuSample:
chosen = Sampling.sample(
logits, parameters: params,
rng: &rng, tokenHistory: tokenHistory)
}
liveNLLSum += Perplexity.negLogSoftmaxAt(logits: logits, index: chosen)
liveNLLCount += 1
return chosen
}
switch path {
case .greedyGPU:
return engine.forwardSample(tokenId: t, position: i, caches: caches)
Expand Down Expand Up @@ -409,13 +441,17 @@ extension Model {
Double(generated.count) / max(decodeTime, 1e-9)))

let split = ThinkingSplit.split(tokens: generated, model: self)
// Live self-perplexity over the generated stream, when captured.
let livePerplexity: Double? =
(capturePPL && liveNLLCount > 0) ? exp(liveNLLSum / Double(liveNLLCount)) : nil
let stats = makeStats(
promptTokens: promptTokens, generatedCount: generated.count,
contextSize: engine.maxSeq, prefillTime: prefillTime,
decodeTime: decodeTime, ttftMs: ttftMs,
perTokenWallclock: perTokenWallclock,
memTracker: memTracker, caches: caches,
weightsBytes: weightsBytes, splitTokens: split
weightsBytes: weightsBytes, splitTokens: split,
genPerplexity: livePerplexity
)

continuation.yield(
Expand All @@ -431,7 +467,8 @@ extension Model {
prefillTime: Double, decodeTime: Double, ttftMs: Double,
perTokenWallclock: [Double],
memTracker: PhaseMemoryTracker, caches: [any LayerCacheProtocol],
weightsBytes: Int, splitTokens: ThinkingSplit.Split?
weightsBytes: Int, splitTokens: ThinkingSplit.Split?,
genPerplexity: Double? = nil
) -> GenerationStats {
let steady: Double? = {
guard perTokenWallclock.count > 10 else { return nil }
Expand All @@ -458,7 +495,7 @@ extension Model {
kvCacheAllocatedBytes: caches.totalBytesAllocated,
kvCacheUsedBytes: caches.totalBytesInUse,
thinkPerplexity: nil,
genPerplexity: nil,
genPerplexity: genPerplexity,
thinkKLDivergence: nil,
genKLDivergence: nil,
thinkTokenCount: splitTokens.map { $0.thinkTokens.count },
Expand Down
Loading
Loading