-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathturboquant.cpp
More file actions
646 lines (588 loc) · 28.9 KB
/
turboquant.cpp
File metadata and controls
646 lines (588 loc) · 28.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
// turboquant.cpp — CPU reference implementation of TurboQuant
// Based on: Zandieh et al., "TurboQuant: Online Vector Quantization with
// Near-optimal Distortion Rate", arXiv:2504.19874 (ICLR 2026).
//
// Single-file, C++17, no deps. Build:
// g++ -O2 -std=c++17 turboquant.cpp -o turboquant && ./turboquant
//
// Implements Algorithm 1 (TurboQuant_mse) and Algorithm 2 (TurboQuant_prod).
// Fast rotation via randomized Hadamard (H·D where D is ±1 random diagonal).
// Codebook learned offline with Lloyd's algorithm on the rotated-component
// distribution (approximately N(0, 1/d) for unit-norm inputs at large d).
//
// This is a correctness reference — not tuned for throughput. Use fused
// SIMD / BLAS kernels for production; this just validates the algorithm.
#include <cstdint>
#include <cmath>
#include <cstdio>
#include <chrono>
#include <random>
#include <vector>
#include <cassert>
#include <algorithm>
#include <bitset>
#include <stdexcept>
namespace turboquant {
// ───────────────────────── Fast Walsh-Hadamard Transform ──────────────────
// In-place Fast Walsh–Hadamard Transform.
//
// a.size() must be a power of 2 (an empty vector is accepted and left
// untouched). The output is scaled by 1/sqrt(d), so the transform is
// orthonormal: it preserves the L2 norm and applying it twice returns the
// original vector.
//
// Throws std::runtime_error if a.size() is not a power of 2.
inline void fwht(std::vector<double>& a) {
    const std::size_t d = a.size();
    if (d & (d - 1)) throw std::runtime_error("fwht: d must be power of 2");
    // Iterative butterflies: at stage h, combine pairs (j, j + h) inside
    // consecutive blocks of length 2h.
    for (std::size_t h = 1; h < d; h <<= 1) {
        for (std::size_t i = 0; i < d; i += h << 1) {
            for (std::size_t j = i; j < i + h; ++j) {
                const double x = a[j];
                const double y = a[j + h];
                a[j] = x + y;
                a[j + h] = x - y;
            }
        }
    }
    const double s = 1.0 / std::sqrt(static_cast<double>(d));
    for (auto& v : a) v *= s;
}
// ───────────────────────── Randomized Hadamard rotation ───────────────────
// Π = H · D where D = diag(±1) and H is the orthonormal Walsh–Hadamard matrix.
// Π·x   = FWHT(D·x).
// Π^⊤·x = D·FWHT(x)   (orthonormal H is symmetric and self-inverse).
struct Rotation {
    std::vector<int8_t> D;  // ±1 per dimension

    // Draws one random sign per dimension from mt19937_64(seed).
    // Throws std::runtime_error if d is negative.
    static Rotation make(int d, uint64_t seed) {
        if (d < 0) throw std::runtime_error("Rotation: d must be non-negative");
        Rotation r;
        r.D.resize(static_cast<std::size_t>(d));
        std::mt19937_64 rng(seed);
        std::uniform_int_distribution<int> coin(0, 1);
        for (auto& s : r.D) s = coin(rng) ? 1 : -1;
        return r;
    }

    // x ← Π·x in place.
    // Throws std::runtime_error if x.size() != D.size(); fwht also throws if
    // the size is not a power of 2.
    void apply(std::vector<double>& x) const {
        require_matching_size(x);
        for (std::size_t i = 0; i < x.size(); ++i) x[i] *= D[i];
        fwht(x);  // forward transform (Π·x)
    }

    // x ← Π^⊤·x in place (inverse of apply). Same preconditions as apply().
    void apply_transpose(std::vector<double>& x) const {
        require_matching_size(x);
        fwht(x);  // orthonormal FWHT is its own inverse/transpose
        for (std::size_t i = 0; i < x.size(); ++i) x[i] *= D[i];
    }

private:
    // A longer x would read D out of bounds; a shorter power-of-2 x would
    // silently be rotated with the wrong sign pattern.
    void require_matching_size(const std::vector<double>& x) const {
        if (x.size() != D.size())
            throw std::runtime_error("Rotation: vector size does not match rotation dimension");
    }
};
// ───────────────────────── Scalar codebook (Lloyd-Max on Gaussian) ────────
// For a unit-norm vector rotated by a random orthogonal Π, each component y_j
// is approximately N(0, 1/d). We train 2^b centroids that minimize MSE on
// that marginal distribution (a Lloyd–Max scalar quantizer).
struct Codebook {
    int bits = 0;
    std::vector<double> centroids;  // size 2^bits, sorted ascending

    // Index of the centroid closest to y; ties go to the lowest index.
    // Precondition: centroids is non-empty.
    int nearest(double y) const {
        assert(!centroids.empty());
        // Linear scan — fine for ≤16 centroids (b ≤ 4). For larger b, use
        // std::lower_bound on the sorted centroid vector.
        int best = 0;
        double bd = std::abs(y - centroids[0]);
        const int K = static_cast<int>(centroids.size());
        for (int k = 1; k < K; ++k) {
            const double d = std::abs(y - centroids[k]);
            if (d < bd) { bd = d; best = k; }
        }
        return best;
    }
};
// Trains a 2^bits-level scalar codebook with Lloyd's algorithm on n_samples
// draws from N(0, 1/dim).
//
//   bits      index bits per coordinate; must be in [0, 30] so 1 << bits fits in int.
//   dim       vector dimension d (sample std-dev is 1/sqrt(d)); must be >= 1.
//   n_samples Monte-Carlo sample count; must be >= 1.
//   seed      RNG seed — training is deterministic for a given seed.
//   iters     number of Lloyd iterations (<= 0 keeps the quantile initialization).
//
// Returns a Codebook whose centroids are sorted ascending.
// Throws std::runtime_error if bits, dim or n_samples is out of range.
Codebook train_codebook(int bits, int dim, int n_samples = 50000,
                        uint64_t seed = 0xC0FFEE, int iters = 50) {
    if (bits < 0 || bits > 30)
        throw std::runtime_error("train_codebook: bits must be in [0, 30]");
    if (dim < 1)
        throw std::runtime_error("train_codebook: dim must be >= 1");
    if (n_samples < 1)
        throw std::runtime_error("train_codebook: n_samples must be >= 1");

    Codebook cb;
    cb.bits = bits;
    const int K = 1 << bits;
    cb.centroids.resize(static_cast<std::size_t>(K));

    // 1. Draw n_samples from N(0, 1/d) — the per-component distribution of
    //    Π · x when x is unit-norm and Π is random orthogonal.
    std::mt19937_64 rng(seed);
    std::normal_distribution<double> N01(0.0, 1.0 / std::sqrt(static_cast<double>(dim)));
    std::vector<double> samples(static_cast<std::size_t>(n_samples));
    for (auto& s : samples) s = N01(rng);
    std::sort(samples.begin(), samples.end());

    // 2. Init centroids at the (k + 0.5)/K quantiles; q < 1 keeps the index
    //    strictly below n_samples.
    for (int k = 0; k < K; ++k) {
        const double q = (k + 0.5) / K;
        cb.centroids[k] = samples[static_cast<std::size_t>(q * n_samples)];
    }

    // 3. Lloyd iterations: assign every sample to its nearest centroid, then
    //    move each centroid to the mean of its cell. Empty cells keep their
    //    previous centroid. Buffers are reused across iterations.
    std::vector<double> sum(static_cast<std::size_t>(K));
    std::vector<int> cnt(static_cast<std::size_t>(K));
    for (int it = 0; it < iters; ++it) {
        std::fill(sum.begin(), sum.end(), 0.0);
        std::fill(cnt.begin(), cnt.end(), 0);
        for (const double s : samples) {
            const int k = cb.nearest(s);
            sum[k] += s;
            cnt[k] += 1;
        }
        for (int k = 0; k < K; ++k) {
            if (cnt[k] > 0) cb.centroids[k] = sum[k] / cnt[k];
        }
        std::sort(cb.centroids.begin(), cb.centroids.end());
    }
    return cb;
}
// ───────────────────────── Algorithm 1: TurboQuant_mse ────────────────────
// State for Algorithm 1 (TurboQuant_mse): the random rotation Π and the scalar
// codebook applied to every rotated coordinate.
struct MSEState {
    int d = 0;     // vector dimension (positive power of 2, required by the FWHT)
    int bits = 0;  // bits per coordinate; quant_mse stores indices as uint8_t, so ≤ 8
    Rotation rot;
    Codebook cb;

    // Builds the rotation (seeded by rot_seed) and trains a 2^bits codebook
    // for dimension d.
    // Throws std::runtime_error if d is not a positive power of 2 or if bits
    // is outside [0, 8].
    static MSEState make(int d, int bits, uint64_t rot_seed = 0xDEADBEEF) {
        // Checked up front so a bad d fails here rather than at the first
        // quantize call (d == 0 would otherwise slip past the FWHT check).
        if (d <= 0 || (d & (d - 1)) != 0)
            throw std::runtime_error("MSEState: d must be a positive power of 2");
        // Indices are stored one per uint8_t; more than 8 bits would be
        // silently truncated by quant_mse.
        if (bits < 0 || bits > 8)
            throw std::runtime_error("MSEState: bits must be in [0, 8]");
        MSEState s;
        s.d = d;
        s.bits = bits;
        s.rot = Rotation::make(d, rot_seed);
        s.cb = train_codebook(bits, d);
        return s;
    }
};
// Algorithm 1, quantization step: rotate x by Π, then replace every rotated
// coordinate with the index of its nearest codebook centroid.
// Returns one uint8_t per dimension; only the low `bits` bits are meaningful.
// Packing several indices per byte is a trivial post-step that is skipped
// here to keep the reference readable.
// x is taken by value because the rotation is applied to it in place.
std::vector<uint8_t> quant_mse(const MSEState& s, std::vector<double> x) {
    assert(static_cast<int>(x.size()) == s.d);
    s.rot.apply(x);  // x now holds y = Π · x
    std::vector<uint8_t> codes(static_cast<std::size_t>(s.d));
    std::transform(x.begin(), x.begin() + s.d, codes.begin(),
                   [&s](double yj) { return static_cast<uint8_t>(s.cb.nearest(yj)); });
    return codes;
}
// Algorithm 1, dequantization step: look up each index's centroid to form ỹ,
// then undo the rotation and return x̃ = Π^⊤ · ỹ.
std::vector<double> dequant_mse(const MSEState& s, const std::vector<uint8_t>& idx) {
    assert(static_cast<int>(idx.size()) == s.d);
    std::vector<double> recon(static_cast<std::size_t>(s.d));
    std::transform(idx.begin(), idx.begin() + s.d, recon.begin(),
                   [&s](uint8_t code) { return s.cb.centroids[code]; });
    s.rot.apply_transpose(recon);  // x̃ = Π^⊤ · ỹ
    return recon;
}
// ───────────────────────── Algorithm 2: TurboQuant_prod ───────────────────
// State for Algorithm 2 (TurboQuant_prod): a (b − 1)-bit TurboQuant_mse stage
// plus a dense Gaussian projection S used for the 1-bit QJL residual stage.
struct ProdState {
    MSEState mse;           // (b − 1)-bit MSE stage
    int d;
    int bits;               // total bits per dim (= mse.bits + 1)
    std::vector<double> S;  // d × d, i.i.d. N(0, 1), row-major

    // Sub-seeds are derived from `seed`: seed ^ 0x11 for the MSE rotation and
    // seed ^ 0x22 for S. Throws std::runtime_error if bits < 1.
    static ProdState make(int d, int bits, uint64_t seed = 0xFEEDFACE) {
        if (!(bits >= 1)) throw std::runtime_error("prod needs bits >= 1");
        ProdState out;
        out.d = d;
        out.bits = bits;
        out.mse = MSEState::make(d, bits - 1, seed ^ 0x11);
        const std::size_t total = static_cast<std::size_t>(d) * d;
        out.S.resize(total);
        std::mt19937_64 gen(seed ^ 0x22);
        std::normal_distribution<double> gauss(0.0, 1.0);
        for (std::size_t i = 0; i < total; ++i) out.S[i] = gauss(gen);
        return out;
    }
};
// Compressed form of one vector under TurboQuant_prod.
struct ProdCompressed {
    std::vector<uint8_t> idx;  // (b − 1)-bit MSE indices, stored one byte per dim
    std::vector<uint8_t> qjl;  // sign(S · r), 1 bit per dim, packed 8 per byte (bit i&7 of byte i>>3)
    double gamma = 0.0;        // ‖r‖₂ of the MSE residual r = x − x̃_mse; 0 until filled in
};
// Algorithm 2, quantization step:
//   1. idx   = Quant_mse(x) using the (b − 1)-bit stage;
//   2. r     = x − DeQuant_mse(idx),  γ = ‖r‖₂;
//   3. qjl_i = [(S · r)_i ≥ 0], packed 8 bits per byte (bit i&7 of byte i>>3).
ProdCompressed quant_prod(const ProdState& p, const std::vector<double>& x) {
    assert(static_cast<int>(x.size()) == p.d);
    const std::size_t n = static_cast<std::size_t>(p.d);

    ProdCompressed out;
    out.idx = quant_mse(p.mse, x);
    const std::vector<double> approx = dequant_mse(p.mse, out.idx);

    std::vector<double> resid(n);
    double sq_norm = 0.0;
    for (std::size_t j = 0; j < n; ++j) {
        resid[j] = x[j] - approx[j];
        sq_norm += resid[j] * resid[j];
    }
    out.gamma = std::sqrt(sq_norm);

    out.qjl.assign((n + 7) / 8, 0);
    for (std::size_t i = 0; i < n; ++i) {
        const double* row = p.S.data() + i * n;
        double proj = 0.0;
        for (std::size_t j = 0; j < n; ++j) proj += row[j] * resid[j];
        const bool non_negative = proj >= 0.0;
        if (non_negative) out.qjl[i >> 3] |= static_cast<uint8_t>(1u << (i & 7));
    }
    return out;
}
// Algorithm 2, dequantization step:
//   x̃ = x̃_mse + (√(π/2) / d) · γ · S^⊤ · s,   s_i = +1 if qjl bit i is set, −1 otherwise.
// ⟨y, x̃⟩ is TurboQuant_prod's inner-product estimate of ⟨y, x⟩.
std::vector<double> dequant_prod(const ProdState& p, const ProdCompressed& c) {
    const std::size_t n = static_cast<std::size_t>(p.d);
    assert(c.idx.size() == n);
    // Every sign bit read below must exist, otherwise c.qjl is read past its end.
    assert(c.qjl.size() >= (n + 7) / 8);

    std::vector<double> out = dequant_mse(p.mse, c.idx);

    // Reconstitute the packed sign bits as ±1 doubles.
    std::vector<double> s(n);
    for (std::size_t i = 0; i < n; ++i) {
        const bool bit = (c.qjl[i >> 3] >> (i & 7)) & 1u;
        s[i] = bit ? 1.0 : -1.0;
    }

    // acc = S^⊤ · s, accumulated one row of S at a time so the row-major
    // matrix is streamed contiguously instead of walked column-wise with
    // stride d. For each output j the terms are still added in i = 0, 1, …
    // order, matching the column-wise formulation.
    std::vector<double> acc(n, 0.0);
    for (std::size_t i = 0; i < n; ++i) {
        const double* row = p.S.data() + i * n;
        const double si = s[i];
        for (std::size_t j = 0; j < n; ++j) acc[j] += row[j] * si;
    }

    // xqjl = (√(π/2) / d) · γ · S^⊤ · s, added onto x̃_mse.
    const double scale = std::sqrt(M_PI / 2.0) / static_cast<double>(p.d) * c.gamma;
    for (std::size_t j = 0; j < n; ++j) out[j] += scale * acc[j];
    return out;
}
// ───────────────────────── Standalone QJL key compressor ──────────────────
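// QJL = Quantized Johnson–Lindenstrauss transform, introduced by Zandieh, Daliri & Han,
// "QJL: 1-Bit Quantized JL Transform for KV Cache Quantization with Zero Overhead" (2024).
// TurboQuant reuses it as its residual stage (§3.2 / Thm 3).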
// The idea: compress *only the key* into d binary bits + one fp32 scale γ = ‖k‖.
// Inner products with any query q are then estimated without decompressing the key.
//
// Storage: d bits + 1 fp32 per key → 1 bit / dim + O(1) scalar.
//
// Compression: k → (sign(S·k), ‖k‖)
// IP estimator: ⟨q, k⟩ ≈ (√(π/2) · ‖k‖ / d) · qjl^⊤ · (S·q)
//
// Where qjl_i = +1 if i-th bit is set, −1 otherwise.
//
// The key insight for attention:
// • Precompute Sq = S·q ONCE per query position (cost: O(d²), shared over all T keys).
// • Then each key's contribution is just γ · (√(π/2)/d) · Σ_i(±Sq_i) — O(d) additions/subtractions.
//
// Distortion: variance(estimator) = (π/2)/d (Theorem 3, Zandieh et al. 2026).
// At unit norm: variance·d ≈ π/2 ≈ 1.57 (vs TurboQuant_prod b=2: D_prod·d ≤ 0.56).
// QJL is 3× worse at 2× lower bit cost — pure-binary is the trade-off point.
//
// Production note: replace the dense S (O(d²) space) with a Randomized Hadamard
// projection built from the same Rotation struct above. QJLFastState below uses
// Π = H·D (one random diagonal); D'·H·D with a second random diagonal is a common
// variant. Either way, state drops to O(d) and cost to O(d log d) per key and per query.
//
// This is the kernel OpenVINO PR #35092 implements as a separate SDPA attention
// path (TurboQuant for values, QJL for keys).
// Dense QJL state: a d × d projection matrix S with i.i.d. N(0, 1) entries.
struct QJLState {
    int d = 0;
    std::vector<double> S;  // d×d, i.i.d. N(0,1), row-major

    // Draws S from mt19937_64(seed).
    // Throws std::runtime_error if d < 0. A negative d would otherwise be
    // converted to size_t, and d·d would wrap around to a bogus buffer size.
    static QJLState make(int d, uint64_t seed = 0xABCDEF01) {
        if (d < 0) throw std::runtime_error("QJLState: d must be non-negative");
        QJLState q;
        q.d = d;
        q.S.resize(static_cast<std::size_t>(d) * static_cast<std::size_t>(d));
        std::mt19937_64 rng(seed);
        std::normal_distribution<double> N01(0.0, 1.0);
        for (auto& v : q.S) v = N01(rng);
        return q;
    }
};
struct QJLCompressed {
    std::vector<uint8_t> bits;  // d bits packed 8/byte (1 = positive projection, 0 = negative)
    double gamma = 0.0;         // ‖k‖₂; 0 until filled in
};
// Compress key k to its sign code and norm:
//   bits_i = [(S_i · k) ≥ 0], packed as bit i&7 of byte i>>3;   gamma = ‖k‖₂.
QJLCompressed quant_qjl(const QJLState& st, const std::vector<double>& k) {
    assert(static_cast<int>(k.size()) == st.d);
    QJLCompressed c;
    // γ = ‖k‖₂
    double n2 = 0.0;
    for (const double x : k) n2 += x * x;
    c.gamma = std::sqrt(n2);
    // bits_i = (S_i · k ≥ 0)
    const std::size_t nbytes = (static_cast<std::size_t>(st.d) + 7) / 8;
    c.bits.assign(nbytes, 0);
    for (int i = 0; i < st.d; ++i) {
        double acc = 0.0;
        const double* row = &st.S[static_cast<std::size_t>(i) * st.d];
        for (int j = 0; j < st.d; ++j) acc += row[j] * k[j];
        if (acc >= 0.0) c.bits[i >> 3] |= uint8_t(1u << (i & 7));
    }
    return c;
}
// Precompute Sq = S·q ONCE per query (shared over all T keys in the sequence).
// O(d²) per query.
std::vector<double> precompute_Sq(const QJLState& st, const std::vector<double>& q) {
    assert(static_cast<int>(q.size()) == st.d);
    std::vector<double> Sq(static_cast<std::size_t>(st.d), 0.0);
    for (int i = 0; i < st.d; ++i) {
        const double* row = &st.S[static_cast<std::size_t>(i) * st.d];
        for (int j = 0; j < st.d; ++j) Sq[i] += row[j] * q[j];
    }
    return Sq;
}
// Estimate ⟨q, k⟩ from the precomputed Sq and the compressed key c:
//   ⟨q, k⟩ ≈ (√(π/2) · γ / d) · Σ_i s_i · Sq_i,   s_i = +1 if bit i is set, −1 otherwise.
// d is taken from Sq.size(); c.bits must hold at least d bits.
// Returns 0 for d == 0 (empty vectors), instead of evaluating 0/0.
// Total per-key cost: d multiply-adds (with binary ±1 multiplier).
double estimate_qjl_ip(const QJLCompressed& c, const std::vector<double>& Sq) {
    const int d = static_cast<int>(Sq.size());
    // A code shorter than Sq would be read past its end below.
    assert(c.bits.size() * 8 >= Sq.size());
    if (d == 0) return 0.0;
    double acc = 0.0;
    for (int i = 0; i < d; ++i) {
        const double sign = ((c.bits[i >> 3] >> (i & 7)) & 1u) ? 1.0 : -1.0;
        acc += sign * Sq[i];
    }
    // Scale: √(π/2) · γ / d
    return (std::sqrt(M_PI / 2.0) / d) * c.gamma * acc;
}
// ───────────────────────── Fast QJL via Randomized Hadamard ──────────────
// Production replacement for QJLState. Instead of storing a dense d×d
// Gaussian matrix S (O(d²) space), we reuse the Rotation struct (H/√d · D)
// already defined above.
//
// Dense QJL: S is d×d Gaussian → O(d²) space, O(d²) time per key/query.
// Fast QJL: S = Π = H·D → O(d) space, O(d log d) time per key/query.
//
// Validity: Π is orthogonal (H/√d is orthonormal, D is unitary), so Π·x is a
// random rotation of x. For a uniformly random D and large d, each coordinate
// (Π·x)_i ≈ N(0, ‖x‖²/d) marginally (Ailon & Chazelle, "FJLT", 2006).
// Taking sign(·) of each coordinate gives a binary JL projection with the
// same unbiased inner-product estimator property as the dense version.
//
// Empirical distortion (verified hb #379, 2026-05-20):
// Dense QJL (iid Gaussian S): variance·d ≈ π/2 ≈ 1.5708 (matches theory, n=5000)
// Fast QJL (Hadamard Π = H·D): variance·d ≈ 0.58 (n=5000); a larger study (n=100k per d,
// d=32–1024) gives 0.565–0.571, consistent with π/2 − 1 ≈ 0.5708 (≈ 2.75× lower!)
//
// Mechanism: dense S has independent rows (zero cross-terms in variance sum).
// Hadamard Π has ORTHOGONAL rows — H-orthogonality forces negative cross-covariances
// between sign(Πk)_i·(Πq)_i terms for i≠j, reducing total variance by ~2.75×.
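// Heuristic for the constant (treats Π as a Haar-random rotation; not a proof for H·D):
// write u = Πk and v = Πq = ⟨k,q⟩·u + w with w ⊥ u. The u part contributes ⟨k,q⟩·‖u‖₁,
// which is the unbiased signal. The w part only sees the component of sign(u) orthogonal
// to u, whose squared norm is d − ‖u‖₁² ≈ d(1 − 2/π). Together this gives
// var·d ≈ (π/2 − 1)(1 − ⟨k,q⟩²) + O(1/d), which tends to π/2 − 1 ≈ 0.5708 for random unit k, q.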
//
// Practical implication: at d=128, Fast QJL achieves var·d ≈ 0.57 which matches
// TurboQuant_prod b=2 (D_prod·d ≤ 0.56) at HALF the bit budget (1 vs 2 bits/dim).
// Speedup vs dense at d=512: ≈57× theoretical (log₂(512)/512 vs 1).
//
// This is the kernel used in high-performance attention implementations
// (e.g. OpenVINO PR #35092, Flash-Decode with QJL key cache).
// Fast QJL state: the projection is the randomized Hadamard rotation
// Π = (H/√d)·D, so only the d random signs of D are stored.
struct QJLFastState {
    int d = 0;
    Rotation rot;  // Π = (H/√d)·D — the shared Randomized Hadamard rotation

    // Builds Π from a seeded random sign diagonal.
    // Throws std::runtime_error unless d is a positive power of 2. Powers of 2
    // are required by the FWHT. d == 0 passes a bare power-of-2 bit test but
    // makes estimate_qjl_ip divide by zero, so it is rejected too.
    static QJLFastState make(int d, uint64_t seed = 0xFABBEEF1) {
        if (d <= 0 || (d & (d - 1)) != 0)
            throw std::runtime_error("QJLFast: d must be power of 2");
        QJLFastState s;
        s.d = d;
        s.rot = Rotation::make(d, seed);
        return s;
    }
};
// Compress key k into (sign bits of Π·k, γ = ‖k‖₂) in O(d log d): one FWHT plus
// a single pass to pack the signs (bit i lives at bit i&7 of byte i>>3).
// k is taken by value because rot.apply() rotates it in place.
QJLCompressed quant_qjl_fast(const QJLFastState& st, std::vector<double> k) {
    assert(static_cast<int>(k.size()) == st.d);
    QJLCompressed out;

    double norm_sq = 0.0;
    for (const double v : k) norm_sq += v * v;
    out.gamma = std::sqrt(norm_sq);  // measured before the rotation

    st.rot.apply(k);  // k ← Π·k (in place, O(d log d))

    const std::size_t n = static_cast<std::size_t>(st.d);
    out.bits.assign((n + 7) / 8, 0);
    for (std::size_t i = 0; i < n; ++i) {
        const bool positive = k[i] >= 0.0;
        out.bits[i >> 3] |= static_cast<uint8_t>((positive ? 1u : 0u) << (i & 7));
    }
    return out;
}
// Precompute Sq = √d · Π·q. O(d log d). Call once per query; reuse over all T keys.
//
// Why √d: estimate_qjl_ip() was written for the dense variant, whose rows are
// Gaussian vectors with unit-variance entries.
//   Dense S:     (S·q)_i ~ N(0, ‖q‖²)        → E[(S·q)_i²] = 1 for unit q.
//   Hadamard Π:  orthonormal, ‖Π·q‖₂ = ‖q‖₂  → E[(Π·q)_i²] = 1/d for unit q.
// Multiplying Π·q by √d restores unit-scale components, so estimate_qjl_ip()
// (scale √(π/2)·γ/d) can be reused verbatim.
//
// Unbiasedness after rescaling, for unit k and q. This treats X = (Πk)_i and
// Y = (Πq)_i as jointly Gaussian with variance 1/d and covariance ⟨k,q⟩/d,
// which holds approximately for large d. For such a pair,
// E[sign(X)·Y] = √(2/π) · Cov(X, Y) / σ_X, so:
//   E[sign(Πk)_i · (Πq)_i]       = √(2/π) · (⟨k,q⟩/d) / (1/√d) = √(2/π) · ⟨k,q⟩ / √d
//   E[sign(Πk)_i · (√d · Πq)_i]  = √(2/π) · ⟨k,q⟩
//   Σ over the d coordinates     = d · √(2/π) · ⟨k,q⟩           ← same as dense
//   × √(π/2) / d (estimator)     = ⟨k,q⟩                        ✓
//
// Note: takes q by value — rot.apply() modifies it in place; returns √d · Π·q.
std::vector<double> precompute_Sq_fast(const QJLFastState& st, std::vector<double> q) {
    assert(static_cast<int>(q.size()) == st.d);
    st.rot.apply(q);  // q ← Π·q (each component O(1/√d))
    // Rescale to match dense-QJL magnitude: multiply by √d so components are O(1).
    const double scale = std::sqrt(static_cast<double>(st.d));
    for (auto& x : q) x *= scale;
    return q;  // = √d · Π·q
}
// Estimator is identical to the dense version — call estimate_qjl_ip(c, Sq) directly.
} // namespace turboquant
// ───────────────────────── Test driver ─────────────────────────────────────
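// Validates:
//   1. MSE of TurboQuant_mse matches the paper's Theorem-1 bounds.
//   2. Inner-product estimator from TurboQuant_prod is approximately unbiased
//      and its distortion matches Theorem-2 bounds.
//   3. Standalone QJL (dense Gaussian S and fast Hadamard Π): bias and variance·d.
// It also times dense O(d²) vs Hadamard O(d log d) QJL at d = 512 and prints a
// bit-budget comparison table.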
namespace {
// Draws a random unit vector in R^d: d i.i.d. N(0, 1) samples from rng,
// each scaled by the reciprocal of their joint L2 norm.
std::vector<double> random_unit_vec(int d, std::mt19937_64& rng) {
    std::normal_distribution<double> gauss(0.0, 1.0);
    std::vector<double> out(d);
    double sq_norm = 0.0;
    for (std::size_t i = 0; i < out.size(); ++i) {
        out[i] = gauss(rng);
        sq_norm += out[i] * out[i];
    }
    const double inv_norm = 1.0 / std::sqrt(sq_norm);
    for (double& v : out) v *= inv_norm;
    return out;
}
// Dot product over the first a.size() entries (b must be at least as long).
double dot(const std::vector<double>& a, const std::vector<double>& b) {
    double total = 0.0;
    std::size_t i = 0;
    while (i < a.size()) {
        total += a[i] * b[i];
        ++i;
    }
    return total;
}
}  // namespace
// Test driver. For d = 128 it prints:
//   • TurboQuant_mse empirical MSE per bit-width next to the paper's values;
//   • TurboQuant_prod bias and D_prod·d next to the paper's upper bounds;
//   • dense and fast (Hadamard) QJL bias and variance·d.
// It then times dense vs Hadamard QJL at d = 512 and prints a bit-budget table.
// It only reports numbers and always returns 0.
int main() {
    using namespace turboquant;
    const int d = 128;        // power of 2 for FWHT; 128 or 256 typical for LLM KV head-dim
    const int n_trial = 200;  // test vectors per bit-width
    std::mt19937_64 rng(0xCAFEBABE);
    std::printf("=== TurboQuant CPU reference (d = %d) ===\n\n", d);

    // ── Algorithm 1: MSE ──────────────────────────────────────────────────
    std::printf("Algorithm 1 — TurboQuant_mse\n");
    std::printf(" %-4s %-10s %-10s %-10s\n", "b", "empir_MSE", "theory≤", "rel");
    const double paper_mse[5] = {0.0, 0.36, 0.117, 0.03, 0.009};
    for (int b = 1; b <= 4; ++b) {
        MSEState st = MSEState::make(d, b, 0x1234 + b);
        double sum_mse = 0.0;
        for (int t = 0; t < n_trial; ++t) {
            auto x = random_unit_vec(d, rng);
            auto idx = quant_mse(st, x);
            auto xtilde = dequant_mse(st, idx);
            double e = 0.0;
            for (int j = 0; j < d; ++j) { double df = x[j] - xtilde[j]; e += df * df; }
            sum_mse += e;
        }
        double emp = sum_mse / n_trial;
        std::printf(" %-4d %-10.4f %-10.4f %-10.2f\n",
                    b, emp, paper_mse[b], emp / paper_mse[b]);
    }

    // ── Algorithm 2: product ──────────────────────────────────────────────
    std::printf("\nAlgorithm 2 — TurboQuant_prod (unbiased <y, x̃> estimator)\n");
    std::printf(" %-4s %-14s %-14s %-14s %-14s\n",
                "b", "⟨y,x⟩ (truth)", "⟨y,x̃⟩ (est)", "bias", "D_prod·d");
    const double paper_prod_d[5] = {0.0, 1.57, 0.56, 0.18, 0.047};  // ·1/d
    for (int b = 2; b <= 4; ++b) {
        ProdState ps = ProdState::make(d, b, 0xBEEF + b);
        double sum_truth = 0, sum_est = 0, sum_sqerr = 0;
        int n = 0;
        for (int t = 0; t < n_trial; ++t) {
            auto x = random_unit_vec(d, rng);
            auto y = random_unit_vec(d, rng);  // ‖y‖ = 1
            auto c = quant_prod(ps, x);
            auto xhat = dequant_prod(ps, c);
            double truth = dot(y, x);
            double est = dot(y, xhat);
            sum_truth += truth;
            sum_est += est;
            sum_sqerr += (truth - est) * (truth - est);
            n++;
        }
        double bias = (sum_est - sum_truth) / n;
        double dprod = sum_sqerr / n;  // MSE-style distortion
        std::printf(" %-4d %-14.4f %-14.4f %-14.5f %-14.5f\n",
                    b, sum_truth / n, sum_est / n, bias, dprod * d);
        std::printf(" (paper upper bound at ‖y‖=1: D_prod · d ≤ %.3f)\n",
                    paper_prod_d[b]);
    }

    // ── Standalone QJL ────────────────────────────────────────────────────
    std::printf("\nStandalone QJL (1 bit/dim — pure binary key compression)\n");
    std::printf(" Theoretical variance·d = π/2 ≈ %.4f\n", M_PI / 2.0);
    std::printf(" %-14s %-14s %-14s %-14s\n",
                "⟨q,k⟩ truth", "QJL est", "bias", "var·d (emp)");
    {
        QJLState qst = QJLState::make(d, 0xDEADF00D);
        double sum_truth = 0, sum_est = 0, sum_sqerr = 0;
        for (int t = 0; t < n_trial; ++t) {
            auto k = random_unit_vec(d, rng);  // unit-norm key
            auto q = random_unit_vec(d, rng);  // unit-norm query
            auto c = quant_qjl(qst, k);
            auto Sq = precompute_Sq(qst, q);
            double truth = dot(q, k);
            double est = estimate_qjl_ip(c, Sq);
            sum_truth += truth;
            sum_est += est;
            sum_sqerr += (truth - est) * (truth - est);
        }
        double bias = (sum_est - sum_truth) / n_trial;
        double var_d = (sum_sqerr / n_trial) * d;
        std::printf(" %-14.4f %-14.4f %-14.5f %-14.4f\n",
                    sum_truth / n_trial, sum_est / n_trial, bias, var_d);
    }

    // ── Fast QJL (Hadamard, O(d log d)) — distortion check ────────────────
    std::printf("\nFast QJL (Hadamard, O(d log d) per key/query) [d = %d]\n", d);
    std::printf(" Uses Π = H·D instead of dense d×d Gaussian S.\n");
    // Not π/2: orthogonal Hadamard rows lower the variance to ≈ π/2 − 1
    // (see the bit-budget comparison note below).
    std::printf(" Expected variance·d ≈ π/2 − 1 ≈ %.4f (empirical; dense QJL: π/2)\n",
                M_PI / 2.0 - 1.0);
    std::printf(" %-14s %-14s %-14s %-14s\n",
                "⟨q,k⟩ truth", "FastQJL est", "bias", "var·d (emp)");
    {
        QJLFastState fst = QJLFastState::make(d, 0xFABBEEF1);
        double sum_truth = 0, sum_est = 0, sum_sqerr = 0;
        for (int t = 0; t < n_trial; ++t) {
            auto k = random_unit_vec(d, rng);
            auto q = random_unit_vec(d, rng);
            auto c = quant_qjl_fast(fst, k);
            auto Sq = precompute_Sq_fast(fst, q);
            double truth = dot(q, k);
            double est = estimate_qjl_ip(c, Sq);
            sum_truth += truth; sum_est += est;
            sum_sqerr += (truth - est) * (truth - est);
        }
        double bias = (sum_est - sum_truth) / n_trial;
        double var_d = (sum_sqerr / n_trial) * d;
        std::printf(" %-14.4f %-14.4f %-14.5f %-14.4f\n",
                    sum_truth / n_trial, sum_est / n_trial, bias, var_d);
    }

    // ── Timing comparison: dense O(d²) vs Hadamard O(d log d) ─────────────
    {
        const int d2 = 512;  // larger d to make the speedup visible
        const int n_time = 2000;
        std::printf("\nTiming at d=%d (%d key compressions + query precomputes)\n", d2, n_time);
        // Pre-generate keys and queries (exclude from timing).
        std::vector<std::vector<double>> keys(n_time), queries(n_time);
        std::mt19937_64 rng2(0x1234ABCD);
        for (int i = 0; i < n_time; ++i) {
            keys[i] = random_unit_vec(d2, rng2);
            queries[i] = random_unit_vec(d2, rng2);
        }
        // Dense QJL O(d²)
        {
            QJLState qst = QJLState::make(d2, 0xDEADF00D);
            volatile double sink = 0.0;  // prevent dead-code elimination
            auto t0 = std::chrono::high_resolution_clock::now();
            for (int i = 0; i < n_time; ++i) {
                auto c = quant_qjl(qst, keys[i]);
                auto Sq = precompute_Sq(qst, queries[i]);
                sink += estimate_qjl_ip(c, Sq);
            }
            auto t1 = std::chrono::high_resolution_clock::now();
            double ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
            std::printf(" Dense O(d²): %6.1f ms total, %.4f ms/pair (sink=%.3g)\n",
                        ms, ms / n_time, static_cast<double>(sink));
        }
        // Fast QJL O(d log d)
        {
            QJLFastState fst = QJLFastState::make(d2, 0xFABBEEF1);
            volatile double sink = 0.0;
            auto t0 = std::chrono::high_resolution_clock::now();
            for (int i = 0; i < n_time; ++i) {
                auto c = quant_qjl_fast(fst, keys[i]);
                auto Sq = precompute_Sq_fast(fst, queries[i]);
                sink += estimate_qjl_ip(c, Sq);
            }
            auto t1 = std::chrono::high_resolution_clock::now();
            double ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
            std::printf(" Hadamard O(d log d): %6.1f ms total, %.4f ms/pair (sink=%.3g)\n",
                        ms, ms / n_time, static_cast<double>(sink));
        }
        double theory_ratio = std::log2(static_cast<double>(d2)) / static_cast<double>(d2);
        std::printf(" Asymptotic ops ratio log₂(%d)/%d = %.4f (Hadamard %dx cheaper)\n",
                    d2, d2, theory_ratio,
                    static_cast<int>(std::round(1.0 / theory_ratio)));
    }

    // ── Comparison table at equal bit-budgets ──────────────────────────────
    std::printf("\nBit-budget comparison (d = %d, unit-norm k, q)\n", d);
    std::printf(" %-28s %-8s %-14s %-14s\n",
                "method", "bits/dim", "space/state", "distortion·d");
    std::printf(" %-28s %-8s %-14s %-14.4f\n",
                "QJL dense (standalone)", "1", "O(d²)", M_PI / 2.0);
    // BUG FIX (hb #380, 2026-05-20): Fast QJL distortion is NOT π/2.
    // Empirical study (qjl_variance_study.cpp, n=100k per d, d=32–1024) shows
    // Fast QJL variance·d ≈ 0.565–0.571 across d, NOT π/2 ≈ 1.571.
    // Mechanism: Hadamard orthogonality forces negative cross-covariances between
    // sign(Πk)_i·(Πq)_i terms for i≠j, reducing variance by ≈ 2.75× vs dense Gaussian.
    // Heuristic (treating Π as a Haar-random rotation): var·d ≈ (π/2 − 1)(1 − ⟨k,q⟩²),
    // which tends to π/2 − 1 ≈ 0.5708 for random unit k, q. A proof specific to
    // Π = H·D is still open.
    const double fast_qjl_empirical = M_PI / 2.0 - 1.0;  // ≈ 0.5708, best large-d estimate
    std::printf(" %-28s %-8s %-14s %-14.4f %s\n",
                "QJL fast (Hadamard)", "1", "O(d)", fast_qjl_empirical,
                "(empirical ~0.57; theory: open)");
    std::printf(" %-28s %-8s %-14s %-14.4f\n",
                "TurboQuant_prod b=2", "2", "O(d)", paper_prod_d[2]);
    std::printf(" %-28s %-8s %-14s %-14.4f\n",
                "TurboQuant_prod b=3", "3", "O(d)", paper_prod_d[3]);
    std::printf(" %-28s %-8s %-14s %-14.4f\n",
                "TurboQuant_prod b=4", "4", "O(d)", paper_prod_d[4]);
    std::printf(" (QJL is also residual compressor inside TurboQuant_prod)\n");
    std::printf(" NOTE: Fast QJL distortion corrected from π/2 to empirical ~0.57 (hb #380).\n");
    std::printf(" This changes the key finding: Fast QJL 1-bit ≈ TurboQuant_prod 2-bit!\n");
    std::printf("\nNotes:\n"
                " • Π is H·D (randomized Hadamard). Paper allows any rotation; Hadamard is fast + fine at large d.\n"
                " • Codebook trained by Lloyd's algorithm on N(0, 1/d) samples.\n"
                " • bit-packing: idx stored 1 byte/dim here for clarity; qjl is 1 bit/dim packed 8/byte.\n"
                " • Production path: fuse dequant+matmul; keep γ as fp16 scalar per vector.\n"
                " • Fast QJL: Π = H·D reuses the same Rotation struct as TurboQuant_mse — O(d) state, O(d log d) ops.\n"
                " • OpenVINO PR #35092: TurboQuant for values + QJL for keys in custom SDPA kernel.\n"
                " • Attention asymmetry: QJL for keys (need ranking, not reconstruction), TurboQuant for values.\n"
                " • Fast QJL variance study: qjl_variance_study.cpp — empirically resolves the constant.\n");
    return 0;
}