diff --git a/Makefile b/Makefile index 27283ba0f..aa7075c2d 100644 --- a/Makefile +++ b/Makefile @@ -16,8 +16,8 @@ METAL_SRCS := $(wildcard metal/*.metal) ifeq ($(UNAME_S),Darwin) METAL_LDLIBS := $(LDLIBS) -framework Foundation -framework Metal -CORE_OBJS = ds4.o ds4_metal.o -CPU_CORE_OBJS = ds4_cpu.o +CORE_OBJS = ds4.o ds4_metal.o ds4_planar_quant.o +CPU_CORE_OBJS = ds4_cpu.o ds4_planar_quant.o else CFLAGS += -D_GNU_SOURCE -fno-finite-math-only CUDA_HOME ?= /usr/local/cuda @@ -28,12 +28,12 @@ NVCC_ARCH_FLAGS := -arch=$(CUDA_ARCH) endif NVCCFLAGS ?= -O3 -g -lineinfo --use_fast_math $(NVCC_ARCH_FLAGS) -Xcompiler $(NATIVE_CPU_FLAG) -Xcompiler -pthread CUDA_LDLIBS ?= -lm -Xcompiler -pthread -L$(CUDA_HOME)/targets/sbsa-linux/lib -L$(CUDA_HOME)/lib64 -lcudart -lcublas -CORE_OBJS = ds4.o ds4_cuda.o -CPU_CORE_OBJS = ds4_cpu.o +CORE_OBJS = ds4.o ds4_cuda.o ds4_planar_quant.o +CPU_CORE_OBJS = ds4_cpu.o ds4_planar_quant.o METAL_LDLIBS := $(LDLIBS) endif -.PHONY: all help clean test cpu cuda cuda-spark cuda-generic cuda-regression +.PHONY: all help clean test planar-quant-test planar-eval cpu cuda cuda-spark cuda-generic cuda-regression ifeq ($(UNAME_S),Darwin) all: ds4 ds4-server ds4-bench ds4-eval ds4-agent @@ -154,6 +154,21 @@ tests/cuda_long_context_smoke.o: tests/cuda_long_context_smoke.c ds4_gpu.h rax.o: rax.c rax.h rax_malloc.h $(CC) $(CFLAGS) -c -o $@ rax.c +ds4_planar_quant.o: ds4_planar_quant.c ds4_planar_quant.h + $(CC) $(CFLAGS) -c -o $@ ds4_planar_quant.c + +tests/planar_quant_test: tests/planar_quant_test.c ds4_planar_quant.c ds4_planar_quant.h + $(CC) $(CFLAGS) -I. -o $@ tests/planar_quant_test.c ds4_planar_quant.c $(LDLIBS) + +planar-quant-test: tests/planar_quant_test + ./tests/planar_quant_test + +tools/planar_eval: tools/planar_eval.c ds4_planar_quant.c ds4_planar_quant.h + $(CC) $(CFLAGS) -I. -o $@ tools/planar_eval.c ds4_planar_quant.c $(LDLIBS) + +planar-eval: tools/planar_eval + ./tools/planar_eval --mode ds4_like --rows 10000 --queries 8 + linenoise.o: linenoise.c linenoise.h $(CC) $(CFLAGS) -c -o $@ linenoise.c @@ -191,9 +206,9 @@ else $(NVCC) $(NVCCFLAGS) -o $@ ds4_test.o ds4_kvstore.o rax.o $(CORE_OBJS) $(CUDA_LDLIBS) endif -test: ds4_test ds4-eval +test: ds4_test ds4-eval planar-quant-test ./ds4-eval --self-test-extractors ./ds4_test clean: - rm -f ds4 ds4-server ds4-bench ds4-eval ds4-agent ds4_cpu ds4_native ds4_server_test ds4_test *.o tests/cuda_long_context_smoke tests/cuda_long_context_smoke.o + rm -f ds4 ds4-server ds4-bench ds4-eval ds4-agent ds4_cpu ds4_native ds4_server_test ds4_test *.o tests/cuda_long_context_smoke tests/cuda_long_context_smoke.o tests/planar_quant_test tools/planar_eval diff --git a/ds4.c b/ds4.c index ecbcec3f5..42337a142 100644 --- a/ds4.c +++ b/ds4.c @@ -36,6 +36,7 @@ #include #include "ds4.h" +#include "ds4_planar_quant.h" #ifndef DS4_NO_GPU #include "ds4_gpu.h" @@ -6520,6 +6521,8 @@ typedef struct { uint32_t comp_cap; uint32_t n_comp; float *attn_comp_kv; + ds4_row_planar3 *attn_comp_planar; + float *planar_staging; /* dequant staging, same size as attn_comp_kv */ float *attn_state_kv; float *attn_state_score; @@ -6690,7 +6693,8 @@ static void cpu_decode_scratch_free(ds4_cpu_decode_scratch *scratch) { /* Allocate per-layer KV state: a raw sliding window for all layers, plus * compressed attention/indexer caches for layers whose ratio is nonzero. */ -static void kv_cache_init(ds4_kv_cache *cache, uint32_t ctx_size, uint32_t raw_cap) { +static void kv_cache_init(ds4_kv_cache *cache, uint32_t ctx_size, uint32_t raw_cap, + bool planar_kv_cache, bool planar_kv_cache_only) { memset(cache, 0, sizeof(*cache)); if (raw_cap == 0) raw_cap = ds4_default_raw_cap(ctx_size); if (raw_cap > ctx_size) raw_cap = ctx_size; @@ -6711,7 +6715,17 @@ static void kv_cache_init(ds4_kv_cache *cache, uint32_t ctx_size, uint32_t raw_c const uint32_t attn_rows = coff * ratio; cache->layer[il].comp_cap = comp_cap; - cache->layer[il].attn_comp_kv = xmalloc_zeroed((size_t)comp_cap * DS4_N_HEAD_DIM, sizeof(float)); + cache->layer[il].attn_comp_kv = planar_kv_cache_only ? NULL : + xmalloc_zeroed((size_t)comp_cap * DS4_N_HEAD_DIM, sizeof(float)); + cache->layer[il].attn_comp_planar = planar_kv_cache ? + xmalloc_zeroed(comp_cap, sizeof(ds4_row_planar3)) : NULL; + cache->layer[il].planar_staging = planar_kv_cache ? + xmalloc_zeroed((size_t)comp_cap * DS4_N_HEAD_DIM, sizeof(float)) : NULL; + /* planar-only: ensure staging is always available for dequant */ + if (planar_kv_cache_only && !cache->layer[il].planar_staging) { + cache->layer[il].planar_staging = + xmalloc_zeroed((size_t)comp_cap * DS4_N_HEAD_DIM, sizeof(float)); + } cache->layer[il].attn_state_kv = xmalloc_zeroed((size_t)attn_width * attn_rows, sizeof(float)); cache->layer[il].attn_state_score = xmalloc((size_t)attn_width * attn_rows * sizeof(float)); for (uint64_t i = 0; i < (uint64_t)attn_width * attn_rows; i++) { @@ -6737,6 +6751,8 @@ static void kv_cache_free(ds4_kv_cache *cache) { for (uint32_t il = 0; il < DS4_N_LAYER; il++) { free(cache->layer[il].raw_kv); free(cache->layer[il].attn_comp_kv); + free(cache->layer[il].attn_comp_planar); + free(cache->layer[il].planar_staging); free(cache->layer[il].attn_state_kv); free(cache->layer[il].attn_state_score); free(cache->layer[il].index_comp_kv); @@ -6762,10 +6778,28 @@ static void kv_cache_push_raw(ds4_layer_cache *cache, const float *kv) { for (uint32_t i = 0; i < DS4_N_HEAD_DIM; i++) dst[i] = f16_to_f32(f32_to_f16(kv[i])); } -static void kv_cache_push_comp(float *rows, uint32_t *n_rows, uint32_t cap_rows, uint32_t row_dim, const float *kv) { +/* Return comp_kv pointer for attention: dequantize from Planar3 if present. */ +static const float *comp_kv_for_attn(const ds4_layer_cache *layer) { + if (layer->attn_comp_planar && layer->n_comp > 0) { + if (!layer->planar_staging) ds4_die("Planar3 compressed KV staging buffer is missing"); + ds4_planar3_dequantize(layer->attn_comp_planar, layer->planar_staging, + layer->n_comp, DS4_N_HEAD_DIM); + return layer->planar_staging; + } + return layer->attn_comp_kv; +} + +static void kv_cache_push_comp(float *rows, ds4_row_planar3 *planar_rows, + uint32_t *n_rows, uint32_t cap_rows, + uint32_t row_dim, const float *kv) { if (*n_rows >= cap_rows) ds4_die("compressed KV cache capacity exceeded"); - float *dst = rows + (uint64_t)(*n_rows) * row_dim; - for (uint32_t i = 0; i < row_dim; i++) dst[i] = f16_to_f32(f32_to_f16(kv[i])); + float f16_rounded[DS4_N_HEAD_DIM]; + for (uint32_t i = 0; i < row_dim; i++) f16_rounded[i] = f16_to_f32(f32_to_f16(kv[i])); + if (rows) { + float *dst = rows + (uint64_t)(*n_rows) * row_dim; + memcpy(dst, f16_rounded, (size_t)row_dim * sizeof(float)); + } + if (planar_rows) ds4_planar3_quantize_row(f16_rounded, &planar_rows[*n_rows]); (*n_rows)++; } @@ -7483,7 +7517,7 @@ static void layer_attention_raw_swa_one( ratio, il, pos)) { - kv_cache_push_comp(cache->attn_comp_kv, &cache->n_comp, cache->comp_cap, DS4_N_HEAD_DIM, comp); + kv_cache_push_comp(cache->attn_comp_kv, cache->attn_comp_planar, &cache->n_comp, cache->comp_cap, DS4_N_HEAD_DIM, comp); } free(comp); @@ -7501,7 +7535,7 @@ static void layer_attention_raw_swa_one( ratio, il, pos)) { - kv_cache_push_comp(cache->index_comp_kv, &cache->n_index_comp, cache->comp_cap, DS4_N_INDEXER_HEAD_DIM, index_comp); + kv_cache_push_comp(cache->index_comp_kv, NULL, &cache->n_index_comp, cache->comp_cap, DS4_N_INDEXER_HEAD_DIM, index_comp); } free(index_comp); @@ -7514,7 +7548,7 @@ static void layer_attention_raw_swa_one( layer_attention_mixed_one(heads, model, layer, q, cache->raw_kv, cache->n_raw, - cache->attn_comp_kv, cache->n_comp, + comp_kv_for_attn(cache), cache->n_comp, comp_allowed); } else { layer_attention_rows_one(heads, model, layer, q, cache->raw_kv, cache->n_raw); @@ -7719,7 +7753,7 @@ static void layer_attention_raw_swa_batch( il, pos); if (have_comp) { - kv_cache_push_comp(cache->attn_comp_kv, &cache->n_comp, cache->comp_cap, DS4_N_HEAD_DIM, comp); + kv_cache_push_comp(cache->attn_comp_kv, cache->attn_comp_planar, &cache->n_comp, cache->comp_cap, DS4_N_HEAD_DIM, comp); } if (ratio == 4) { @@ -7737,7 +7771,7 @@ static void layer_attention_raw_swa_batch( il, pos); if (have_index_comp) { - kv_cache_push_comp(cache->index_comp_kv, &cache->n_index_comp, cache->comp_cap, DS4_N_INDEXER_HEAD_DIM, index_comp); + kv_cache_push_comp(cache->index_comp_kv, NULL, &cache->n_index_comp, cache->comp_cap, DS4_N_INDEXER_HEAD_DIM, index_comp); } if (profile) t_tl_compress += now_sec() - tx; @@ -7769,7 +7803,7 @@ static void layer_attention_raw_swa_batch( tx = profile ? now_sec() : 0.0; layer_attention_mixed_one(heads + (uint64_t)t * q_dim, model, layer, q_t, cache->raw_kv, cache->n_raw, - cache->attn_comp_kv, cache->n_comp, + comp_kv_for_attn(cache), cache->n_comp, comp_allowed); if (profile) t_tl_attn_rows += now_sec() - tx; } @@ -7798,7 +7832,8 @@ static void layer_attention_raw_swa_batch( if (prefix_batch_attn) { double tx = profile ? now_sec() : 0.0; - const float *comp_kv_for_prefix = cache->attn_comp_kv ? cache->attn_comp_kv : kv; + const float *comp_kv_for_prefix = comp_kv_for_attn(cache); + if (!comp_kv_for_prefix) comp_kv_for_prefix = kv; if (!heads) { heads = xmalloc((size_t)n_tok * q_dim * sizeof(heads[0])); } @@ -7967,7 +8002,7 @@ static void layer_forward_raw_swa_one( il, pos, scratch)) { - kv_cache_push_comp(cache->attn_comp_kv, &cache->n_comp, cache->comp_cap, DS4_N_HEAD_DIM, scratch->comp); + kv_cache_push_comp(cache->attn_comp_kv, cache->attn_comp_planar, &cache->n_comp, cache->comp_cap, DS4_N_HEAD_DIM, scratch->comp); } if (ratio == 4) { @@ -7984,7 +8019,7 @@ static void layer_forward_raw_swa_one( il, pos, scratch)) { - kv_cache_push_comp(cache->index_comp_kv, &cache->n_index_comp, cache->comp_cap, + kv_cache_push_comp(cache->index_comp_kv, NULL, &cache->n_index_comp, cache->comp_cap, DS4_N_INDEXER_HEAD_DIM, scratch->index_comp); } if (profile) t_compress = now_sec() - t0; @@ -8008,7 +8043,7 @@ static void layer_forward_raw_swa_one( if (ratio != 0) { layer_attention_mixed_one_decode_scratch(scratch->heads, model, layer, scratch->q, cache->raw_kv, cache->n_raw, - cache->attn_comp_kv, cache->n_comp, + comp_kv_for_attn(cache), cache->n_comp, comp_allowed, scratch); } else { @@ -8542,6 +8577,7 @@ typedef struct { * the row counters whenever a checkpoint is saved or partially rewound. */ ds4_gpu_tensor *layer_raw_cache[DS4_MAX_LAYER]; ds4_gpu_tensor *layer_attn_comp_cache[DS4_MAX_LAYER]; + ds4_gpu_tensor *layer_attn_comp_planar[DS4_MAX_LAYER]; ds4_gpu_tensor *layer_attn_state_kv[DS4_MAX_LAYER]; ds4_gpu_tensor *layer_attn_state_score[DS4_MAX_LAYER]; ds4_gpu_tensor *layer_index_comp_cache[DS4_MAX_LAYER]; @@ -8816,6 +8852,9 @@ static void metal_graph_free(ds4_gpu_graph *g) { for (uint32_t il = 0; il < DS4_N_LAYER; il++) { ds4_gpu_tensor_free(g->layer_attn_comp_cache[il]); } + for (uint32_t il = 0; il < DS4_N_LAYER; il++) { + ds4_gpu_tensor_free(g->layer_attn_comp_planar[il]); + } for (uint32_t il = 0; il < DS4_N_LAYER; il++) { ds4_gpu_tensor_free(g->layer_attn_state_kv[il]); } @@ -9150,7 +9189,9 @@ static bool metal_graph_alloc_raw_cap( uint32_t raw_cap, uint32_t ctx_size, uint32_t prefill_cap, - bool enable_mtp) { + bool enable_mtp, + bool planar_kv_cache, + bool planar_kv_cache_only) { memset(g, 0, sizeof(*g)); g->mtp_enabled = enable_mtp; if (raw_cap == 0) raw_cap = 1; @@ -9252,10 +9293,17 @@ static bool metal_graph_alloc_raw_cap( const uint32_t coff = ratio == 4 ? 2u : 1u; const uint64_t attn_width = (uint64_t)coff * DS4_N_HEAD_DIM; const uint64_t attn_rows = (uint64_t)coff * ratio; - g->layer_attn_comp_cache[il] = metal_graph_alloc_kv_cache_tensor( - managed_kv_cache, - (uint64_t)g->layer_comp_cap[il] * DS4_N_HEAD_DIM * - (DS4_GPU_ATTN_COMP_CACHE_F16 ? sizeof(uint16_t) : sizeof(float))); + if (!planar_kv_cache_only) { + g->layer_attn_comp_cache[il] = metal_graph_alloc_kv_cache_tensor( + managed_kv_cache, + (uint64_t)g->layer_comp_cap[il] * DS4_N_HEAD_DIM * + (DS4_GPU_ATTN_COMP_CACHE_F16 ? sizeof(uint16_t) : sizeof(float))); + } + if (planar_kv_cache) { + g->layer_attn_comp_planar[il] = metal_graph_alloc_kv_cache_tensor( + managed_kv_cache, + (uint64_t)g->layer_comp_cap[il] * 4 * 50); + } g->layer_attn_state_kv[il] = ds4_gpu_tensor_alloc(attn_width * attn_rows * sizeof(float)); g->layer_attn_state_score[il] = ds4_gpu_tensor_alloc(attn_width * attn_rows * sizeof(float)); if (enable_mtp) { @@ -9402,9 +9450,10 @@ static bool metal_graph_alloc_raw_cap( layer_cache_ok = g->layer_raw_cache[il] != NULL; const uint32_t ratio = ds4_layer_compress_ratio(il); if (layer_cache_ok && ratio != 0) { - layer_cache_ok = g->layer_attn_comp_cache[il] != NULL && + layer_cache_ok = (planar_kv_cache_only || g->layer_attn_comp_cache[il] != NULL) && g->layer_attn_state_kv[il] != NULL && g->layer_attn_state_score[il] != NULL && + (!planar_kv_cache || g->layer_attn_comp_planar[il] != NULL) && (!enable_mtp || (g->spec_attn_state_kv[il] != NULL && g->spec_attn_state_score[il] != NULL && @@ -9473,7 +9522,7 @@ static bool metal_graph_alloc( ds4_gpu_graph *g, const ds4_weights *weights, const ds4_layer_weights *layer) { - return metal_graph_alloc_raw_cap(g, weights, layer, DS4_N_SWA, DS4_N_SWA, 1, false); + return metal_graph_alloc_raw_cap(g, weights, layer, DS4_N_SWA, DS4_N_SWA, 1, false, false, false); } static uint32_t metal_graph_raw_span_for_batch( @@ -9690,6 +9739,20 @@ static uint32_t metal_graph_attn_comp_cache_is_f16(void) { return DS4_GPU_ATTN_COMP_CACHE_F16 ? 1u : 0u; } +static ds4_gpu_tensor *metal_graph_attn_comp_for_attention( + ds4_gpu_graph *g, + uint32_t il) { + return g->layer_attn_comp_planar[il] + ? g->layer_attn_comp_planar[il] + : g->layer_attn_comp_cache[il]; +} + +static uint32_t metal_graph_attn_comp_is_planar( + ds4_gpu_graph *g, + uint32_t il) { + return g->layer_attn_comp_planar[il] ? 1u : 0u; +} + static bool metal_graph_store_attn_comp_stage( ds4_gpu_graph *g, uint32_t il, @@ -9697,7 +9760,8 @@ static bool metal_graph_store_attn_comp_stage( uint32_t rows) { if (!g || il >= DS4_N_LAYER) return false; if (rows == 0) return true; - if (!g->layer_attn_comp_cache[il] || !g->attn_comp_stage) return false; + if (!g->layer_attn_comp_cache[il]) return true; + if (!g->attn_comp_stage) return false; if (rows > g->attn_comp_stage_cap || first_row > g->layer_comp_cap[il] || rows > g->layer_comp_cap[il] - first_row) { return false; @@ -9742,6 +9806,31 @@ static bool metal_graph_commit_attn_comp_stage( return metal_graph_store_attn_comp_stage(g, il, first_row, rows); } +static bool metal_graph_quantize_attn_comp_planar( + ds4_gpu_graph *g, + uint32_t il, + uint32_t first_row, + uint32_t rows) { + if (!g->layer_attn_comp_planar[il] || rows == 0) return true; + if (DS4_GPU_ATTN_COMP_CACHE_F16 && rows > g->attn_comp_stage_cap) return false; + const uint64_t fp32_row = (uint64_t)DS4_N_HEAD_DIM * sizeof(float); + const uint64_t planar_row = (uint64_t)4 * 50; + ds4_gpu_tensor *src = DS4_GPU_ATTN_COMP_CACHE_F16 + ? ds4_gpu_tensor_view(g->attn_comp_stage, 0, (uint64_t)rows * fp32_row) + : ds4_gpu_tensor_view(g->layer_attn_comp_cache[il], + (uint64_t)first_row * fp32_row, + (uint64_t)rows * fp32_row); + ds4_gpu_tensor *dst = ds4_gpu_tensor_view( + g->layer_attn_comp_planar[il], + (uint64_t)first_row * planar_row, + (uint64_t)rows * planar_row); + bool ok = src && dst && + ds4_gpu_planar3_quantize_tensor(src, dst, rows, DS4_N_HEAD_DIM) != 0; + ds4_gpu_tensor_free(src); + ds4_gpu_tensor_free(dst); + return ok; +} + static ds4_gpu_tensor *metal_graph_attn_comp_row_view( ds4_gpu_graph *g, uint32_t il, @@ -10066,6 +10155,7 @@ static bool metal_graph_encode_decode_layer( ds4_gpu_tensor_free(comp_row_view); } if (ok) ok = metal_graph_commit_attn_comp_stage(g, il, comp_row, 1); + if (ok) ok = metal_graph_quantize_attn_comp_planar(g, il, comp_row, 1); } if (ok && emit) g->layer_n_comp[il]++; @@ -10286,8 +10376,9 @@ static bool metal_graph_encode_decode_layer( layer->attn_sinks->abs_offset, g->q, raw_cache, - g->layer_attn_comp_cache[il], + metal_graph_attn_comp_for_attention(g, il), metal_graph_attn_comp_cache_is_f16(), + metal_graph_attn_comp_is_planar(g, il), comp_selected, 1, pos, @@ -10315,12 +10406,13 @@ static bool metal_graph_encode_decode_layer( g->q, raw_cache, n_raw, raw_cap, raw_start, - n_comp ? comp_cache : NULL, + n_comp ? metal_graph_attn_comp_for_attention(g, il) : NULL, metal_graph_attn_comp_cache_is_f16(), n_comp, NULL, 0, - DS4_N_HEAD, DS4_N_HEAD_DIM) != 0; + DS4_N_HEAD, DS4_N_HEAD_DIM, + n_comp ? metal_graph_attn_comp_is_planar(g, il) : 0) != 0; } } DS4_METAL_PROFILE_DECODE_STAGE("attention"); @@ -12226,6 +12318,7 @@ static bool metal_graph_encode_layer_attention_batch( DS4_RMS_EPS) != 0; if (ok && n_comp != 0) { ok = metal_graph_commit_attn_comp_stage(g, il, 0, n_comp); + if (ok) ok = metal_graph_quantize_attn_comp_planar(g, il, 0, n_comp); } if (ok && ratio == 4) { ok = metal_graph_refresh_ratio4_compressor_state(g, @@ -12337,6 +12430,7 @@ static bool metal_graph_encode_layer_attention_batch( } if (ok && comp_chunk != 0) { ok = metal_graph_commit_attn_comp_stage(g, il, comp_before, comp_chunk); + if (ok) ok = metal_graph_quantize_attn_comp_planar(g, il, comp_before, comp_chunk); } if (ok && ratio == 4) { ok = metal_graph_refresh_ratio4_compressor_state(g, @@ -12428,6 +12522,7 @@ static bool metal_graph_encode_layer_attention_batch( } ds4_gpu_tensor_free(comp_row_view); if (ok) ok = metal_graph_commit_attn_comp_stage(g, il, comp_row, 1); + if (ok) ok = metal_graph_quantize_attn_comp_planar(g, il, comp_row, 1); } if (ok && emit) g->layer_n_comp[il]++; if (comp_counts) comp_counts[t] = g->layer_n_comp[il]; @@ -12818,8 +12913,9 @@ static bool metal_graph_encode_layer_attention_batch( layer->attn_sinks->abs_offset, g->batch_q, g->layer_raw_cache[il], - g->layer_attn_comp_cache[il], + metal_graph_attn_comp_for_attention(g, il), metal_graph_attn_comp_cache_is_f16(), + metal_graph_attn_comp_is_planar(g, il), g->comp_selected, n_tokens, pos0, @@ -12847,7 +12943,7 @@ static bool metal_graph_encode_layer_attention_batch( layer->attn_sinks->abs_offset, g->batch_q, g->layer_raw_cache[il], - g->layer_attn_comp_cache[il], + metal_graph_attn_comp_for_attention(g, il), metal_graph_attn_comp_cache_is_f16(), use_comp_mask ? g->comp_mask : NULL, use_comp_mask, @@ -12860,7 +12956,8 @@ static bool metal_graph_encode_layer_attention_batch( g->raw_window, ratio, DS4_N_HEAD, - DS4_N_HEAD_DIM) != 0; + DS4_N_HEAD_DIM, + metal_graph_attn_comp_is_planar(g, il)) != 0; } } if (ok) batch_attention_done = true; @@ -12932,8 +13029,9 @@ static bool metal_graph_encode_layer_attention_batch( layer->attn_sinks->abs_offset, g->batch_q, g->layer_raw_cache[il], - g->layer_attn_comp_cache[il], + metal_graph_attn_comp_for_attention(g, il), metal_graph_attn_comp_cache_is_f16(), + metal_graph_attn_comp_is_planar(g, il), g->comp_selected, n_tokens, pos0, @@ -12964,8 +13062,9 @@ static bool metal_graph_encode_layer_attention_batch( layer->attn_sinks->abs_offset, g->batch_q, g->batch_kv, - g->layer_attn_comp_cache[il], + metal_graph_attn_comp_for_attention(g, il), metal_graph_attn_comp_cache_is_f16(), + metal_graph_attn_comp_is_planar(g, il), n_tokens, n_comp, g->raw_window, @@ -13059,8 +13158,9 @@ static bool metal_graph_encode_layer_attention_batch( layer->attn_sinks->abs_offset, q_view, g->layer_raw_cache[il], - g->layer_attn_comp_cache[il], + metal_graph_attn_comp_for_attention(g, il), metal_graph_attn_comp_cache_is_f16(), + metal_graph_attn_comp_is_planar(g, il), g->comp_selected, 1, pos, @@ -13083,13 +13183,14 @@ static bool metal_graph_encode_layer_attention_batch( n_raw, g->raw_cap, raw_start, - cur_comp ? g->layer_attn_comp_cache[il] : NULL, + cur_comp ? metal_graph_attn_comp_for_attention(g, il) : NULL, metal_graph_attn_comp_cache_is_f16(), cur_comp, comp_mask, n_selected, DS4_N_HEAD, - DS4_N_HEAD_DIM) != 0; + DS4_N_HEAD_DIM, + cur_comp ? metal_graph_attn_comp_is_planar(g, il) : 0) != 0; } ds4_gpu_tensor_free(heads_view); ds4_gpu_tensor_free(kv_cache_view); @@ -14759,7 +14860,7 @@ static int metal_graph_prompt_logits_test( ds4_gpu_graph g; bool ok = metal_graph_alloc_raw_cap(&g, weights, &weights->layer[0], - raw_cap, (uint32_t)ctx_size, (uint32_t)n_test, false); + raw_cap, (uint32_t)ctx_size, (uint32_t)n_test, false, false, false); if (!ok) { metal_graph_free(&g); fprintf(stderr, "ds4: failed to initialize Metal graph prompt test runtime\n"); @@ -14769,7 +14870,7 @@ static int metal_graph_prompt_logits_test( if (memory_report) ds4_gpu_print_memory_report("after graph alloc"); ds4_kv_cache cpu_cache; - kv_cache_init(&cpu_cache, (uint32_t)ctx_size, raw_cap); + kv_cache_init(&cpu_cache, (uint32_t)ctx_size, raw_cap, false, false); float *cpu_logits = xmalloc((size_t)DS4_N_VOCAB * sizeof(float)); float *gpu_logits = xmalloc((size_t)DS4_N_VOCAB * sizeof(float)); float *oracle_logits = NULL; @@ -15068,6 +15169,9 @@ struct ds4_engine { float directional_steering_ffn_scale; int power_percent; bool quality; + bool planar_kv_cache; + bool planar_kv_cache_only; + const char *dump_comp_kv; bool metal_ready; bool mtp_ready; }; @@ -16035,7 +16139,7 @@ static int generate_raw_swa_cpu( fprintf(stderr, "ds4: using CPU generation with layer-major prefill\n"); ds4_kv_cache cache; - kv_cache_init(&cache, (uint32_t)ctx_size, 0); + kv_cache_init(&cache, (uint32_t)ctx_size, 0, false, false); ds4_cpu_decode_scratch decode_scratch; cpu_decode_scratch_init(&decode_scratch, (uint32_t)ctx_size); @@ -16166,7 +16270,7 @@ static int generate_metal_graph_raw_swa( } ds4_gpu_graph g; bool ok = metal_graph_alloc_raw_cap(&g, weights, &weights->layer[0], - raw_cap, (uint32_t)ctx_size, prefill_cap, false); + raw_cap, (uint32_t)ctx_size, prefill_cap, false, false, false); if (!ok) { fprintf(stderr, "ds4: failed to allocate GPU graph runtime\n"); return 1; @@ -16746,7 +16850,8 @@ static uint64_t session_cpu_payload_live_tensor_bytes(const ds4_session *s) { static void session_cpu_reset_cache(ds4_session *s) { kv_cache_free(&s->cpu_cache); - kv_cache_init(&s->cpu_cache, (uint32_t)s->ctx_size, 0); + kv_cache_init(&s->cpu_cache, (uint32_t)s->ctx_size, 0, + s->engine->planar_kv_cache, s->engine->planar_kv_cache_only); } int ds4_engine_routed_quant_bits(ds4_engine *e) { @@ -17040,20 +17145,32 @@ int ds4_session_save_payload(ds4_session *s, FILE *fp, char *err, size_t errlen) /* Compressed rows are append-only from row zero, so the live prefix is * contiguous. The two compressor state tensors hold the partial window * that will become the next compressed row. */ - if (DS4_GPU_ATTN_COMP_CACHE_F16) { - rc = payload_write_tensor_span_f16_as_f32(fp, - g->layer_attn_comp_cache[il], - 0, - (uint64_t)g->layer_n_comp[il] * DS4_N_HEAD_DIM, - buf, - DS4_SESSION_IO_CHUNK, - err, - errlen); - } else { + if (g->layer_attn_comp_cache[il]) { + if (DS4_GPU_ATTN_COMP_CACHE_F16) { + rc = payload_write_tensor_span_f16_as_f32(fp, + g->layer_attn_comp_cache[il], + 0, + (uint64_t)g->layer_n_comp[il] * DS4_N_HEAD_DIM, + buf, + DS4_SESSION_IO_CHUNK, + err, + errlen); + } else { + rc = payload_write_tensor_span(fp, + g->layer_attn_comp_cache[il], + 0, + (uint64_t)g->layer_n_comp[il] * DS4_N_HEAD_DIM * sizeof(float), + buf, + DS4_SESSION_IO_CHUNK, + err, + errlen); + } + } else if (g->layer_attn_comp_planar[il]) { + /* planar-only: write Planar3 bytes directly */ rc = payload_write_tensor_span(fp, - g->layer_attn_comp_cache[il], + g->layer_attn_comp_planar[il], 0, - (uint64_t)g->layer_n_comp[il] * DS4_N_HEAD_DIM * sizeof(float), + (uint64_t)g->layer_n_comp[il] * 200, buf, DS4_SESSION_IO_CHUNK, err, @@ -17232,6 +17349,10 @@ int ds4_session_load_payload(ds4_session *s, FILE *fp, uint64_t payload_bytes, c token_vec_free(&new_checkpoint); return 1; } + if (layer->attn_comp_planar && n_comp[il] > 0) { + ds4_planar3_quantize(layer->attn_comp_kv, layer->attn_comp_planar, + n_comp[il], DS4_N_HEAD_DIM); + } if (ratio == 4) { if (payload_read_bytes(fp, layer->index_comp_kv, @@ -17373,21 +17494,64 @@ int ds4_session_load_payload(ds4_session *s, FILE *fp, uint64_t payload_bytes, c } const uint32_t ratio = ds4_layer_compress_ratio(il); if (rc != 0 || ratio == 0) continue; - if (DS4_GPU_ATTN_COMP_CACHE_F16) { - rc = payload_read_tensor_span_f32_as_f16(fp, - g->layer_attn_comp_cache[il], - 0, - (uint64_t)n_comp[il] * DS4_N_HEAD_DIM, - buf, - DS4_SESSION_IO_CHUNK, - &remaining, - err, - errlen); - } else { + if (g->layer_attn_comp_cache[il]) { + if (DS4_GPU_ATTN_COMP_CACHE_F16) { + rc = payload_read_tensor_span_f32_as_f16(fp, + g->layer_attn_comp_cache[il], + 0, + (uint64_t)n_comp[il] * DS4_N_HEAD_DIM, + buf, + DS4_SESSION_IO_CHUNK, + &remaining, + err, + errlen); + } else { + rc = payload_read_tensor_span(fp, + g->layer_attn_comp_cache[il], + 0, + (uint64_t)n_comp[il] * DS4_N_HEAD_DIM * sizeof(float), + buf, + DS4_SESSION_IO_CHUNK, + &remaining, + err, + errlen); + } + /* Rebuild Planar3 GPU cache from the restored FP32/F16 comp rows. */ + if (rc == 0 && g->layer_attn_comp_planar[il] && n_comp[il] > 0) { + const uint64_t fp32_bytes = (uint64_t)n_comp[il] * DS4_N_HEAD_DIM * sizeof(float); + float *stage = xmalloc(fp32_bytes); + if (DS4_GPU_ATTN_COMP_CACHE_F16) { + uint64_t f16_bytes = (uint64_t)n_comp[il] * DS4_N_HEAD_DIM * sizeof(uint16_t); + void *f16_buf = xmalloc(f16_bytes); + if (ds4_gpu_tensor_read(g->layer_attn_comp_cache[il], 0, f16_buf, f16_bytes)) { + for (uint64_t i = 0; i < (uint64_t)n_comp[il] * DS4_N_HEAD_DIM; i++) { + stage[i] = f16_to_f32(((uint16_t *)f16_buf)[i]); + } + } else { rc = 1; } + free(f16_buf); + } else { + if (!ds4_gpu_tensor_read(g->layer_attn_comp_cache[il], 0, stage, fp32_bytes)) { + rc = 1; + } + } + if (rc == 0) { + const uint64_t planar_bytes = (uint64_t)n_comp[il] * 200; + void *planar_buf = xmalloc(planar_bytes); + ds4_planar3_quantize(stage, planar_buf, n_comp[il], DS4_N_HEAD_DIM); + if (!ds4_gpu_tensor_write(g->layer_attn_comp_planar[il], 0, planar_buf, planar_bytes)) { + rc = 1; + } + free(planar_buf); + } + free(stage); + } + } else if (g->layer_attn_comp_planar[il]) { + /* planar-only: read Planar3 bytes directly, skip FP16/FP32 cache */ + const uint64_t planar_bytes = (uint64_t)n_comp[il] * 200; rc = payload_read_tensor_span(fp, - g->layer_attn_comp_cache[il], + g->layer_attn_comp_planar[il], 0, - (uint64_t)n_comp[il] * DS4_N_HEAD_DIM * sizeof(float), + planar_bytes, buf, DS4_SESSION_IO_CHUNK, &remaining, @@ -17639,7 +17803,7 @@ int ds4_engine_collect_imatrix(ds4_engine *e, ds4_gpu_graph g; bool ok = metal_graph_alloc_raw_cap(&g, weights, &weights->layer[0], - raw_cap, (uint32_t)ctx_size, prefill_cap, false); + raw_cap, (uint32_t)ctx_size, prefill_cap, false, false, false); if (!ok) { fprintf(stderr, "ds4: failed to allocate imatrix Metal graph runtime\n"); free(dataset); @@ -17988,6 +18152,9 @@ int ds4_engine_open(ds4_engine **out, const ds4_engine_options *opt) { e->mtp_model.fd = -1; e->backend = opt->backend; e->quality = opt->quality; + e->planar_kv_cache = opt->planar_kv_cache; + e->planar_kv_cache_only = opt->planar_kv_cache_only; + e->dump_comp_kv = opt->dump_comp_kv; e->power_percent = opt->power_percent > 0 ? opt->power_percent : 100; if (e->power_percent > 100) e->power_percent = 100; e->mtp_draft_tokens = opt->mtp_draft_tokens > 0 ? opt->mtp_draft_tokens : 1; @@ -18165,7 +18332,8 @@ int ds4_session_create(ds4_session **out, ds4_engine *e, int ctx_size) { s->engine = e; s->ctx_size = ctx_size; s->prefill_cap = ds4_default_prefill_cap_for_prompt(ctx_size); - kv_cache_init(&s->cpu_cache, (uint32_t)ctx_size, 0); + kv_cache_init(&s->cpu_cache, (uint32_t)ctx_size, 0, + e->planar_kv_cache, e->planar_kv_cache_only); cpu_decode_scratch_init(&s->cpu_scratch, (uint32_t)ctx_size); s->logits = xmalloc((size_t)DS4_N_VOCAB * sizeof(s->logits[0])); *out = s; @@ -18182,7 +18350,8 @@ int ds4_session_create(ds4_session **out, ds4_engine *e, int ctx_size) { s->prefill_cap = metal_graph_prefill_cap_for_prompt(ctx_size); const uint32_t raw_cap = metal_graph_raw_cap_for_context(ctx_size, s->prefill_cap); if (!metal_graph_alloc_raw_cap(&s->graph, &e->weights, &e->weights.layer[0], - raw_cap, (uint32_t)ctx_size, s->prefill_cap, e->mtp_ready)) + raw_cap, (uint32_t)ctx_size, s->prefill_cap, e->mtp_ready, + e->planar_kv_cache, e->planar_kv_cache_only)) { free(s); return 1; @@ -19293,3 +19462,59 @@ int ds4_session_pos(ds4_session *s) { int ds4_session_ctx(ds4_session *s) { return s->ctx_size; } + +/* Dump the first compressed layer's FP32 KV rows to a binary file. + * Layout: uint32_t n_rows, uint32_t dim, then n_rows*dim float32 values. + * Only dumps the first layer with compression; different layers may have + * different counts. Requires --cpu backend. */ +int ds4_session_dump_comp_kv(ds4_session *s, const char *path, char *err, size_t errlen) { + if (!s || !path) { + payload_set_err(err, errlen, "invalid compressed KV dump request"); + return 1; + } + if (!s->engine || s->engine->backend != DS4_BACKEND_CPU) { + snprintf(err, errlen, + "--dump-comp-kv requires --cpu backend (CPU cache is not populated under %s)", + s->engine ? ds4_backend_name(s->engine->backend) : "unknown"); + return 1; + } + + /* Find the first layer with compressed rows and use its actual n_comp. */ + const ds4_layer_cache *dump_layer = NULL; + for (uint32_t il = 0; il < DS4_N_LAYER; il++) { + const ds4_layer_cache *layer = &s->cpu_cache.layer[il]; + if (layer->compress_ratio && layer->n_comp > 0) { + dump_layer = layer; + break; + } + } + if (!dump_layer) { + payload_set_err(err, errlen, "no compressed KV rows to dump"); + return 1; + } + + FILE *f = fopen(path, "wb"); + if (!f) { + snprintf(err, errlen, "cannot open dump path %s: %s", path, strerror(errno)); + return 1; + } + + const uint32_t n_rows = dump_layer->n_comp; + const uint32_t dim = DS4_N_HEAD_DIM; + int rc = 0; + if (fwrite(&n_rows, sizeof(uint32_t), 1, f) != 1 || + fwrite(&dim, sizeof(uint32_t), 1, f) != 1 || + fwrite(dump_layer->attn_comp_kv, sizeof(float), (size_t)n_rows * dim, f) != (size_t)n_rows * dim) + { + snprintf(err, errlen, "failed to write compressed KV dump %s: %s", path, strerror(errno)); + rc = 1; + } + if (fclose(f) != 0 && rc == 0) { + snprintf(err, errlen, "failed to close compressed KV dump %s: %s", path, strerror(errno)); + rc = 1; + } + if (rc != 0) return 1; + fprintf(stderr, "ds4: dumped %u compressed KV rows x %u dims to %s\n", + n_rows, dim, path); + return 0; +} diff --git a/ds4.h b/ds4.h index f1a8e9e4b..6c37a7e6c 100644 --- a/ds4.h +++ b/ds4.h @@ -73,6 +73,9 @@ typedef struct { bool warm_weights; bool quality; bool inspect_only; + const char *dump_comp_kv; + bool planar_kv_cache; + bool planar_kv_cache_only; } ds4_engine_options; typedef void (*ds4_token_emit_fn)(void *ud, int token); @@ -201,6 +204,7 @@ void ds4_session_invalidate(ds4_session *s); void ds4_session_rewind(ds4_session *s, int pos); int ds4_session_pos(ds4_session *s); int ds4_session_ctx(ds4_session *s); +int ds4_session_dump_comp_kv(ds4_session *s, const char *path, char *err, size_t errlen); int ds4_engine_routed_quant_bits(ds4_engine *e); bool ds4_engine_has_mtp(ds4_engine *e); int ds4_engine_mtp_draft_tokens(ds4_engine *e); diff --git a/ds4_cli.c b/ds4_cli.c index dfac149b3..4fd86cd3d 100644 --- a/ds4_cli.c +++ b/ds4_cli.c @@ -160,6 +160,13 @@ static void usage(FILE *fp) { "Diagnostics:\n" " --inspect\n" " Load the model and print a summary only.\n" + " --dump-comp-kv FILE\n" + " Dump compressed KV rows (binary) after prefill for offline Planar3 eval.\n" + " --planar-kv-cache\n" + " Enable Planar3 quantization for compressed attention KV cache.\n" + " --planar-kv-cache-only\n" + " Only keep Planar3 compressed KV cache, skip FP16/FP32 allocation.\n" + " Implies --planar-kv-cache. Saves memory at the cost of dequant overhead.\n" " --dump-tokens\n" " Tokenize -p/--prompt-file exactly as written, then exit without inference.\n" " --dump-logits FILE\n" @@ -509,6 +516,15 @@ static int run_sampled_generation(ds4_engine *engine, const cli_config *cfg, con } ds4_session_set_progress(session, NULL, NULL); ds4_session_set_display_progress(session, NULL, NULL); + + if (cfg->engine.dump_comp_kv) { + if (ds4_session_dump_comp_kv(session, cfg->engine.dump_comp_kv, err, sizeof(err)) != 0) { + fprintf(stderr, "ds4: compressed KV dump failed: %s\n", err); + ds4_session_free(session); + return 1; + } + } + const double t_prefill1 = cli_now_sec(); int max_tokens = cfg->gen.n_predict; @@ -1537,6 +1553,13 @@ static cli_config parse_options(int argc, char **argv) { exit(2); } else if (!strcmp(arg, "--inspect")) { c.inspect = true; + } else if (!strcmp(arg, "--dump-comp-kv")) { + c.engine.dump_comp_kv = need_arg(&i, argc, argv, arg); + } else if (!strcmp(arg, "--planar-kv-cache")) { + c.engine.planar_kv_cache = true; + } else if (!strcmp(arg, "--planar-kv-cache-only")) { + c.engine.planar_kv_cache = true; + c.engine.planar_kv_cache_only = true; } else if (!strcmp(arg, "--warm-weights")) { c.engine.warm_weights = true; } else if (!strcmp(arg, "--server")) { diff --git a/ds4_gpu.h b/ds4_gpu.h index 2872b46a4..dc01552e9 100644 --- a/ds4_gpu.h +++ b/ds4_gpu.h @@ -434,7 +434,8 @@ int ds4_gpu_attention_decode_heads_tensor( const ds4_gpu_tensor *comp_mask, uint32_t use_mask, uint32_t n_head, - uint32_t head_dim); + uint32_t head_dim, + uint32_t comp_kv_planar); int ds4_gpu_attention_prefill_raw_heads_tensor( ds4_gpu_tensor *heads, @@ -484,7 +485,8 @@ int ds4_gpu_attention_decode_mixed_batch_heads_tensor( uint32_t window, uint32_t ratio, uint32_t n_head, - uint32_t head_dim); + uint32_t head_dim, + uint32_t comp_kv_planar); int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( ds4_gpu_tensor *heads, @@ -495,6 +497,7 @@ int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( const ds4_gpu_tensor *raw_kv, const ds4_gpu_tensor *comp_kv, uint32_t comp_kv_f16, + uint32_t comp_kv_planar, const ds4_gpu_tensor *topk, uint32_t n_tokens, uint32_t pos0, @@ -517,6 +520,7 @@ int ds4_gpu_attention_prefill_static_mixed_heads_tensor( const ds4_gpu_tensor *raw_kv, const ds4_gpu_tensor *comp_kv, uint32_t comp_kv_f16, + uint32_t comp_kv_planar, uint32_t n_tokens, uint32_t n_comp, uint32_t window, @@ -575,6 +579,12 @@ int ds4_gpu_attention_output_low_q8_tensor( * routing, shared SwiGLU, and the IQ2_XXS/Q2_K/Q4_K routed experts. */ +int ds4_gpu_planar3_quantize_tensor( + const ds4_gpu_tensor *src, + ds4_gpu_tensor *dst, + uint32_t n_rows, + uint32_t head_dim); + int ds4_gpu_swiglu_tensor( ds4_gpu_tensor *out, const ds4_gpu_tensor *gate, diff --git a/ds4_metal.m b/ds4_metal.m index 2c84b453d..a525f4417 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -94,6 +94,8 @@ static id g_dsv4_indexer_weighted_sum_pipeline; static id g_dsv4_indexer_score_one_direct_pipeline; static id g_dsv4_compressor_store_one_pipeline; +static id g_planar3_quantize_pipeline; +static id g_planar3_dequant_to_f16_pipeline; static id g_dsv4_sort_i32_rows_asc_pipeline; static id g_dsv4_indexed_attention_heads8_pipeline; static id g_dsv4_indexed_attention_heads8_rb16_pipeline; @@ -2945,7 +2947,7 @@ static int ds4_gpu_encode_rope_tail_inplace( uint32_t window; uint32_t ratio; uint32_t comp_kv_f16; - uint32_t pad0; + uint32_t comp_kv_planar; uint64_t q_token_stride; uint64_t q_head_stride; uint64_t raw_row_stride; @@ -4145,6 +4147,10 @@ int ds4_gpu_init(void) { ds4_gpu_get_pipeline("kernel_dsv4_indexer_score_one_direct"); g_dsv4_compressor_store_one_pipeline = ds4_gpu_get_pipeline("kernel_dsv4_compressor_store_one"); + g_planar3_quantize_pipeline = + ds4_gpu_get_pipeline("kernel_planar3_quantize_row"); + g_planar3_dequant_to_f16_pipeline = + ds4_gpu_get_pipeline("kernel_planar3_dequant_to_f16_rows"); g_dsv4_sort_i32_rows_asc_pipeline = ds4_gpu_get_pipeline("kernel_dsv4_sort_i32_rows_asc"); g_dsv4_indexed_attention_heads8_pipeline = @@ -4529,6 +4535,8 @@ void ds4_gpu_cleanup(void) { g_dsv4_indexer_weighted_sum_pipeline = nil; g_dsv4_indexer_score_one_direct_pipeline = nil; g_dsv4_compressor_store_one_pipeline = nil; + g_planar3_quantize_pipeline = nil; + g_planar3_dequant_to_f16_pipeline = nil; g_dsv4_sort_i32_rows_asc_pipeline = nil; g_dsv4_indexed_attention_heads8_pipeline = nil; g_dsv4_indexed_attention_heads8_rb16_pipeline = nil; @@ -9733,6 +9741,14 @@ static void ds4_gpu_fill_static_mixed_prefill_mask( } } +static int ds4_gpu_encode_planar3_dequant_to_f16( + id cb, + id src, + NSUInteger src_off, + id dst, + NSUInteger dst_off, + uint32_t n_rows); + static int ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_nonvec_long( id __strong *cbp, ds4_gpu_tensor *heads, @@ -9749,7 +9765,10 @@ static int ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_nonvec_long uint32_t window, uint32_t ratio, uint32_t n_head, - uint32_t head_dim) { + uint32_t head_dim, + bool comp_kv_planar, + id comp_kv_planar_buf, + NSUInteger comp_kv_planar_off) { if (!cbp || !*cbp) return 0; id cb = *cbp; if (head_dim != 512 || n_head == 0 || n_tokens == 0 || ratio == 0) { @@ -9843,17 +9862,24 @@ static int ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_nonvec_long return 0; } DS4_METAL_PROFILE_FLASH_ATTN_STAGE("copy_raw"); - if (n_comp && - !ds4_gpu_encode_copy_to_f16_1d(cb, - compbuf, - ds4_gpu_tensor_offset(comp_kv), - comp_kv_f16 != 0, - g_flash_attn_kv_buffer, - (NSUInteger)n_tokens * row_bytes_f16, - n_comp * head_dim)) { - return 0; - } if (n_comp) { + if (comp_kv_planar) { + if (!ds4_gpu_encode_planar3_dequant_to_f16(cb, + comp_kv_planar_buf, comp_kv_planar_off, + g_flash_attn_kv_buffer, + (NSUInteger)n_tokens * row_bytes_f16, + n_comp)) { + return 0; + } + } else if (!ds4_gpu_encode_copy_to_f16_1d(cb, + compbuf, + ds4_gpu_tensor_offset(comp_kv), + comp_kv_f16 != 0, + g_flash_attn_kv_buffer, + (NSUInteger)n_tokens * row_bytes_f16, + n_comp * head_dim)) { + return 0; + } DS4_METAL_PROFILE_FLASH_ATTN_STAGE("copy_comp"); } @@ -10024,7 +10050,10 @@ static int ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_vec( uint32_t window, uint32_t ratio, uint32_t n_head, - uint32_t head_dim) { + uint32_t head_dim, + bool comp_kv_planar, + id comp_kv_planar_buf, + NSUInteger comp_kv_planar_off) { if (!cbp || !*cbp) return 0; id cb = *cbp; if (head_dim != 512 || n_head == 0 || n_tokens == 0 || ratio == 0) { @@ -10119,13 +10148,21 @@ static int ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_vec( } DS4_METAL_PROFILE_FLASH_ATTN_STAGE("copy_raw"); if (n_comp) { - if (!ds4_gpu_encode_copy_to_f16_1d(cb, - compbuf, - ds4_gpu_tensor_offset(comp_kv), - comp_kv_f16 != 0, - g_flash_attn_kv_buffer, - (NSUInteger)n_tokens * row_bytes_f16, - n_comp * head_dim)) { + if (comp_kv_planar) { + if (!ds4_gpu_encode_planar3_dequant_to_f16(cb, + comp_kv_planar_buf, comp_kv_planar_off, + g_flash_attn_kv_buffer, + (NSUInteger)n_tokens * row_bytes_f16, + n_comp)) { + return 0; + } + } else if (!ds4_gpu_encode_copy_to_f16_1d(cb, + compbuf, + ds4_gpu_tensor_offset(comp_kv), + comp_kv_f16 != 0, + g_flash_attn_kv_buffer, + (NSUInteger)n_tokens * row_bytes_f16, + n_comp * head_dim)) { return 0; } DS4_METAL_PROFILE_FLASH_ATTN_STAGE("copy_comp"); @@ -10283,6 +10320,9 @@ static int ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_nonvec( const ds4_gpu_tensor *raw_kv, const ds4_gpu_tensor *comp_kv, uint32_t comp_kv_f16, + bool comp_kv_planar, + id comp_kv_planar_buf, + NSUInteger comp_kv_planar_off, const ds4_gpu_tensor *comp_mask, uint32_t use_comp_mask, uint32_t n_tokens, @@ -10307,7 +10347,10 @@ static int ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_nonvec( window, ratio, n_head, - head_dim); + head_dim, + comp_kv_planar, + comp_kv_planar_buf, + comp_kv_planar_off); } return ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_vec(cbp, heads, @@ -10324,7 +10367,10 @@ static int ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_nonvec( window, ratio, n_head, - head_dim); + head_dim, + comp_kv_planar, + comp_kv_planar_buf, + comp_kv_planar_off); } static int ds4_gpu_encode_flash_attention_prefill_raw_heads_nonvec( @@ -10805,7 +10851,10 @@ static int ds4_gpu_encode_flash_attention_gathered_heads( const ds4_gpu_tensor *comp_mask, uint32_t use_mask, uint32_t n_head, - uint32_t head_dim) { + uint32_t head_dim, + bool comp_kv_planar, + id comp_kv_planar_buf, + NSUInteger comp_kv_planar_off) { const uint32_t n_keys = n_raw + n_comp; if (head_dim != 512 || n_head == 0 || n_raw == 0 || n_keys == 0 || raw_cap < n_raw || n_keys < n_raw) { @@ -10925,13 +10974,21 @@ static int ds4_gpu_encode_flash_attention_gathered_heads( return 0; } if (n_comp) { - if (!ds4_gpu_encode_copy_to_f16_1d(cb, - compbuf, - ds4_gpu_tensor_offset(comp_kv), - comp_kv_f16 != 0, - g_flash_attn_kv_buffer, - (NSUInteger)n_raw * row_bytes_f16, - n_comp * head_dim)) { + if (comp_kv_planar) { + if (!ds4_gpu_encode_planar3_dequant_to_f16(cb, + comp_kv_planar_buf, comp_kv_planar_off, + g_flash_attn_kv_buffer, + (NSUInteger)n_raw * row_bytes_f16, + n_comp)) { + return 0; + } + } else if (!ds4_gpu_encode_copy_to_f16_1d(cb, + compbuf, + ds4_gpu_tensor_offset(comp_kv), + comp_kv_f16 != 0, + g_flash_attn_kv_buffer, + (NSUInteger)n_raw * row_bytes_f16, + n_comp * head_dim)) { return 0; } } @@ -11310,7 +11367,10 @@ static int ds4_gpu_encode_flash_attention_decode_mixed_batch_heads( uint32_t window, uint32_t ratio, uint32_t n_head, - uint32_t head_dim) { + uint32_t head_dim, + bool comp_kv_planar, + id comp_kv_planar_buf, + NSUInteger comp_kv_planar_off) { if (n_comp == 0) { return ds4_gpu_encode_flash_attention_decode_raw_batch_heads(cb, heads, @@ -11426,16 +11486,28 @@ static int ds4_gpu_encode_flash_attention_decode_mixed_batch_heads( kvoff, g_flash_attn_kv_buffer, 0, - n_raw * head_dim) || - !ds4_gpu_encode_copy_to_f16_1d(cb, - compbuf, - ds4_gpu_tensor_offset(comp_kv), - comp_kv_f16 != 0, - g_flash_attn_kv_buffer, - (NSUInteger)n_raw * row_bytes_f16, - n_comp * head_dim)) { + n_raw * head_dim)) { return 0; } + if (n_comp) { + if (comp_kv_planar) { + if (!ds4_gpu_encode_planar3_dequant_to_f16(cb, + comp_kv_planar_buf, comp_kv_planar_off, + g_flash_attn_kv_buffer, + (NSUInteger)n_raw * row_bytes_f16, + n_comp)) { + return 0; + } + } else if (!ds4_gpu_encode_copy_to_f16_1d(cb, + compbuf, + ds4_gpu_tensor_offset(comp_kv), + comp_kv_f16 != 0, + g_flash_attn_kv_buffer, + (NSUInteger)n_raw * row_bytes_f16, + n_comp * head_dim)) { + return 0; + } + } ds4_gpu_fill_mixed_decode_batch_mask((uint16_t *)[mask_buffer contents], n_tokens, @@ -11714,7 +11786,8 @@ int ds4_gpu_attention_decode_mixed_batch_heads_tensor( uint32_t window, uint32_t ratio, uint32_t n_head, - uint32_t head_dim) { + uint32_t head_dim, + uint32_t comp_kv_planar) { if (!g_initialized && !ds4_gpu_init()) return 0; if (!heads || !q || !raw_kv || !model_map || n_tokens == 0 || n_raw == 0 || raw_cap < n_raw || raw_start >= raw_cap || @@ -11759,7 +11832,14 @@ int ds4_gpu_attention_decode_mixed_batch_heads_tensor( window, ratio, n_head, - head_dim)) { + head_dim, + comp_kv_planar != 0, + comp_kv_planar + ? ds4_gpu_tensor_buffer(comp_kv) + : nil, + comp_kv_planar + ? ds4_gpu_tensor_offset(comp_kv) + : 0)) { return 0; } @@ -11778,6 +11858,7 @@ int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( const ds4_gpu_tensor *raw_kv, const ds4_gpu_tensor *comp_kv, uint32_t comp_kv_f16, + uint32_t comp_kv_planar, const ds4_gpu_tensor *topk, uint32_t n_tokens, uint32_t pos0, @@ -11806,9 +11887,12 @@ int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( const uint64_t row_bytes = (uint64_t)head_dim * sizeof(float); const uint64_t row_bytes_f16 = (uint64_t)head_dim * sizeof(uint16_t); + const uint64_t row_bytes_planar = (uint64_t)4 * 50; const uint64_t q_bytes = (uint64_t)n_tokens * n_head * row_bytes; const uint64_t raw_bytes = (uint64_t)raw_cap * row_bytes; - const uint64_t comp_bytes = (uint64_t)n_comp * (comp_kv_f16 ? row_bytes_f16 : row_bytes); + const uint64_t comp_row_bytes = comp_kv_planar ? row_bytes_planar : + (comp_kv_f16 ? row_bytes_f16 : row_bytes); + const uint64_t comp_bytes = (uint64_t)n_comp * comp_row_bytes; const uint64_t topk_bytes = (uint64_t)top_k * n_tokens * sizeof(int32_t); id qbuf = ds4_gpu_tensor_buffer(q); id rawbuf = ds4_gpu_tensor_buffer(raw_kv); @@ -11883,11 +11967,11 @@ int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( .window = window, .ratio = ratio, .comp_kv_f16 = comp_kv_f16 ? 1u : 0u, - .pad0 = 0, + .comp_kv_planar = comp_kv_planar ? 1u : 0u, .q_token_stride = (uint64_t)n_head * row_bytes, .q_head_stride = row_bytes, .raw_row_stride = row_bytes, - .comp_row_stride = comp_kv_f16 ? row_bytes_f16 : row_bytes, + .comp_row_stride = comp_row_bytes, .topk_token_stride = (uint64_t)top_k * sizeof(int32_t), .dst_token_stride = (uint64_t)n_head * row_bytes, .dst_head_stride = row_bytes, @@ -11944,6 +12028,7 @@ int ds4_gpu_attention_prefill_static_mixed_heads_tensor( const ds4_gpu_tensor *raw_kv, const ds4_gpu_tensor *comp_kv, uint32_t comp_kv_f16, + uint32_t comp_kv_planar, uint32_t n_tokens, uint32_t n_comp, uint32_t window, @@ -11981,6 +12066,13 @@ int ds4_gpu_attention_prefill_static_mixed_heads_tensor( raw_kv, comp_kv, comp_kv_f16, + comp_kv_planar != 0, + comp_kv_planar + ? ds4_gpu_tensor_buffer(comp_kv) + : nil, + comp_kv_planar + ? ds4_gpu_tensor_offset(comp_kv) + : 0, NULL, 0, n_tokens, @@ -12045,6 +12137,9 @@ int ds4_gpu_attention_prefill_masked_mixed_heads_tensor( raw_kv, comp_kv, comp_kv_f16, + false, + nil, + 0, comp_mask, 1, n_tokens, @@ -12078,7 +12173,8 @@ int ds4_gpu_attention_decode_heads_tensor( const ds4_gpu_tensor *comp_mask, uint32_t use_mask, uint32_t n_head, - uint32_t head_dim) { + uint32_t head_dim, + uint32_t comp_kv_planar) { if (!g_initialized && !ds4_gpu_init()) return 0; if (!heads || !model_map || !q || !raw_kv || n_raw == 0 || n_head == 0 || head_dim == 0 || @@ -12162,7 +12258,14 @@ int ds4_gpu_attention_decode_heads_tensor( comp_mask, use_mask, n_head, - head_dim)) { + head_dim, + comp_kv_planar != 0, + comp_kv_planar + ? ds4_gpu_tensor_buffer(comp_kv) + : nil, + comp_kv_planar + ? ds4_gpu_tensor_offset(comp_kv) + : 0)) { return 0; } @@ -12172,6 +12275,76 @@ int ds4_gpu_attention_decode_heads_tensor( return 1; } +int ds4_gpu_planar3_quantize_tensor( + const ds4_gpu_tensor *src, + ds4_gpu_tensor *dst, + uint32_t n_rows, + uint32_t head_dim) { + if (!g_initialized && !ds4_gpu_init()) return 0; + if (!src || !dst || n_rows == 0 || head_dim != 512) return 0; + + @autoreleasepool { + id pipeline = + ds4_gpu_hot_pipeline(g_planar3_quantize_pipeline, + "kernel_planar3_quantize_row"); + if (!pipeline) return 0; + + const uint64_t src_bytes = (uint64_t)n_rows * head_dim * sizeof(float); + const uint64_t dst_bytes = (uint64_t)n_rows * 4 * 50; + if (ds4_gpu_tensor_bytes(src) < src_bytes || + ds4_gpu_tensor_bytes(dst) < dst_bytes) return 0; + + int owned = 0; + id cb = ds4_gpu_command_buffer(&owned); + if (!cb) return 0; + + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:pipeline]; + [enc setBuffer:ds4_gpu_tensor_buffer(src) + offset:ds4_gpu_tensor_offset(src) + atIndex:0]; + [enc setBuffer:ds4_gpu_tensor_buffer(dst) + offset:ds4_gpu_tensor_offset(dst) + atIndex:1]; + [enc setBytes:&n_rows length:sizeof(n_rows) atIndex:2]; + + const NSUInteger threads = (NSUInteger)n_rows; + [enc dispatchThreadgroups:MTLSizeMake((threads + 255u) / 256u, 1, 1) + threadsPerThreadgroup:MTLSizeMake(256u, 1, 1)]; + + return ds4_gpu_finish_command_buffer(cb, owned, "planar3 quantize"); + } +} + +static int ds4_gpu_encode_planar3_dequant_to_f16( + id cb, + id src, + NSUInteger src_off, + id dst, + NSUInteger dst_off, + uint32_t n_rows) { + if (!cb || !src || !dst || n_rows == 0) return 0; + + id pipeline = + ds4_gpu_hot_pipeline(g_planar3_dequant_to_f16_pipeline, + "kernel_planar3_dequant_to_f16_rows"); + if (!pipeline) return 0; + + const uint64_t dst_row_stride = (uint64_t)512 * sizeof(uint16_t); + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:pipeline]; + [enc setBuffer:src offset:src_off atIndex:0]; + [enc setBuffer:dst offset:dst_off atIndex:1]; + [enc setBytes:&n_rows length:sizeof(n_rows) atIndex:2]; + [enc setBytes:&dst_row_stride length:sizeof(dst_row_stride) atIndex:3]; + + const NSUInteger threads = (NSUInteger)n_rows; + [enc dispatchThreadgroups:MTLSizeMake((threads + 255u) / 256u, 1, 1) + threadsPerThreadgroup:MTLSizeMake(256u, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + return 1; +} + int ds4_gpu_swiglu_tensor( ds4_gpu_tensor *out, const ds4_gpu_tensor *gate, diff --git a/ds4_planar_quant.c b/ds4_planar_quant.c new file mode 100644 index 000000000..89fc05aed --- /dev/null +++ b/ds4_planar_quant.c @@ -0,0 +1,285 @@ +#define _USE_MATH_DEFINES +#include "ds4_planar_quant.h" + +#include +#include +#include + + +/* ---- FP16 conversion (standalone, no external lib) ---- */ + +static inline uint16_t fp32_to_fp16(float f) { + uint32_t x; + memcpy(&x, &f, 4); + uint32_t sign = (x >> 16) & 0x8000u; + int32_t exp = (int32_t)((x >> 23) & 0xFFu) - 127 + 15; + uint32_t mant = (x >> 13) & 0x3FFu; + if (exp <= 0) return (uint16_t)sign; + if (exp >= 31) return (uint16_t)(sign | 0x7C00u); + return (uint16_t)(sign | ((uint32_t)exp << 10) | mant); +} + +static inline float fp16_to_fp32(uint16_t h) { + uint32_t sign = ((uint32_t)(h & 0x8000u)) << 16; + uint32_t exp = (h >> 10) & 0x1Fu; + uint32_t mant = h & 0x3FFu; + if (exp == 0) { + if (mant == 0) { + uint32_t f = sign; + float r; + memcpy(&r, &f, 4); + return r; + } + /* Subnormal FP16: (-1)^sign * 2^-14 * (mant/1024) */ + float f = (float)mant / 1024.0f; + uint32_t bits; + memcpy(&bits, &f, 4); + bits = sign | (bits & 0x7FFFFFFFu); + float r; + memcpy(&r, &bits, 4); + return r * 6.103515625e-5f; /* 2^-14 */ + } + if (exp == 31) { + uint32_t f = sign | 0x7F800000u | (mant << 13); + float r; + memcpy(&r, &f, 4); + return r; + } + uint32_t f = sign | ((exp + 127u - 15u) << 23) | (mant << 13); + float r; + memcpy(&r, &f, 4); + return r; +} + +/* ---- Lloyd-Max 3-bit centroids for N(0, 1/128) ---- */ + +static const float DS4_PLANAR3_CENTROIDS[8] = { + -0.190685f, -0.117832f, -0.065717f, -0.021460f, + 0.021460f, 0.065717f, 0.117832f, 0.190685f, +}; + +/* ---- Givens rotation parameters (64 pairs per 128-dim block, reused across 4 blocks) ---- */ + +static const float DS4_PLANAR3_COS[256] = { + 0.7386546135f, 0.8607548475f,-0.7411674857f, 0.9674890637f,-0.7723053098f,-0.8056974411f,-0.0412844308f, 0.2707833052f, + 0.9315500855f, 0.6698185802f, 0.9167487621f,-0.8320636749f, 0.6818146110f,-0.9108457565f,-0.0559285842f,-0.9032276273f, + 0.7519487143f,-0.8941103816f,-0.1039871648f,-0.6961420774f,-0.1230370328f,-0.9328963161f,-0.2905603051f, 0.4910068214f, + 0.7889407277f,-0.1221836656f,-0.6316579580f, 0.3128163815f,-0.9563610554f, 0.9992509484f, 0.9540294409f, 0.8902468085f, + 0.7543080449f,-0.8664138913f,-0.5232898593f, 0.3621287644f,-0.8825117350f, 0.8234673142f,-0.9416025877f,-0.5480425358f, +-0.6644080281f,-0.6585279703f,-0.2460795939f, 0.9438471198f, 0.2427810431f,-0.1960992366f, 0.2403578013f,-0.8461306095f, + 0.0246123374f, 0.3372744620f, 0.9994974732f,-0.3494733870f, 0.7438930869f, 0.8452339768f,-0.6177822948f,-0.2662552595f, +-0.5457068086f,-0.9985070229f, 0.7757105827f, 0.6141811609f,-0.9805000424f, 0.5425475240f,-0.5663578510f,-0.4696439803f, + 0.7386546135f, 0.8607548475f,-0.7411674857f, 0.9674890637f,-0.7723053098f,-0.8056974411f,-0.0412844308f, 0.2707833052f, + 0.9315500855f, 0.6698185802f, 0.9167487621f,-0.8320636749f, 0.6818146110f,-0.9108457565f,-0.0559285842f,-0.9032276273f, + 0.7519487143f,-0.8941103816f,-0.1039871648f,-0.6961420774f,-0.1230370328f,-0.9328963161f,-0.2905603051f, 0.4910068214f, + 0.7889407277f,-0.1221836656f,-0.6316579580f, 0.3128163815f,-0.9563610554f, 0.9992509484f, 0.9540294409f, 0.8902468085f, + 0.7543080449f,-0.8664138913f,-0.5232898593f, 0.3621287644f,-0.8825117350f, 0.8234673142f,-0.9416025877f,-0.5480425358f, +-0.6644080281f,-0.6585279703f,-0.2460795939f, 0.9438471198f, 0.2427810431f,-0.1960992366f, 0.2403578013f,-0.8461306095f, + 0.0246123374f, 0.3372744620f, 0.9994974732f,-0.3494733870f, 0.7438930869f, 0.8452339768f,-0.6177822948f,-0.2662552595f, +-0.5457068086f,-0.9985070229f, 0.7757105827f, 0.6141811609f,-0.9805000424f, 0.5425475240f,-0.5663578510f,-0.4696439803f, + 0.7386546135f, 0.8607548475f,-0.7411674857f, 0.9674890637f,-0.7723053098f,-0.8056974411f,-0.0412844308f, 0.2707833052f, + 0.9315500855f, 0.6698185802f, 0.9167487621f,-0.8320636749f, 0.6818146110f,-0.9108457565f,-0.0559285842f,-0.9032276273f, + 0.7519487143f,-0.8941103816f,-0.1039871648f,-0.6961420774f,-0.1230370328f,-0.9328963161f,-0.2905603051f, 0.4910068214f, + 0.7889407277f,-0.1221836656f,-0.6316579580f, 0.3128163815f,-0.9563610554f, 0.9992509484f, 0.9540294409f, 0.8902468085f, + 0.7543080449f,-0.8664138913f,-0.5232898593f, 0.3621287644f,-0.8825117350f, 0.8234673142f,-0.9416025877f,-0.5480425358f, +-0.6644080281f,-0.6585279703f,-0.2460795939f, 0.9438471198f, 0.2427810431f,-0.1960992366f, 0.2403578013f,-0.8461306095f, + 0.0246123374f, 0.3372744620f, 0.9994974732f,-0.3494733870f, 0.7438930869f, 0.8452339768f,-0.6177822948f,-0.2662552595f, +-0.5457068086f,-0.9985070229f, 0.7757105827f, 0.6141811609f,-0.9805000424f, 0.5425475240f,-0.5663578510f,-0.4696439803f, + 0.7386546135f, 0.8607548475f,-0.7411674857f, 0.9674890637f,-0.7723053098f,-0.8056974411f,-0.0412844308f, 0.2707833052f, + 0.9315500855f, 0.6698185802f, 0.9167487621f,-0.8320636749f, 0.6818146110f,-0.9108457565f,-0.0559285842f,-0.9032276273f, + 0.7519487143f,-0.8941103816f,-0.1039871648f,-0.6961420774f,-0.1230370328f,-0.9328963161f,-0.2905603051f, 0.4910068214f, + 0.7889407277f,-0.1221836656f,-0.6316579580f, 0.3128163815f,-0.9563610554f, 0.9992509484f, 0.9540294409f, 0.8902468085f, + 0.7543080449f,-0.8664138913f,-0.5232898593f, 0.3621287644f,-0.8825117350f, 0.8234673142f,-0.9416025877f,-0.5480425358f, +-0.6644080281f,-0.6585279703f,-0.2460795939f, 0.9438471198f, 0.2427810431f,-0.1960992366f, 0.2403578013f,-0.8461306095f, + 0.0246123374f, 0.3372744620f, 0.9994974732f,-0.3494733870f, 0.7438930869f, 0.8452339768f,-0.6177822948f,-0.2662552595f, +-0.5457068086f,-0.9985070229f, 0.7757105827f, 0.6141811609f,-0.9805000424f, 0.5425475240f,-0.5663578510f,-0.4696439803f, +}; + +static const float DS4_PLANAR3_SIN[256] = { +-0.6740840673f,-0.5090196729f, 0.6713201404f,-0.2529129684f, 0.6352515221f,-0.5923272967f, 0.9991474152f,-0.9626403451f, +-0.3636130989f, 0.7425247431f,-0.3994642496f,-0.5546801090f,-0.7315250039f,-0.4127469361f,-0.9984347820f, 0.4291617870f, +-0.6592215896f,-0.4478466809f, 0.9945786595f,-0.7179040313f, 0.9924020767f, 0.3601450622f, 0.9568566680f,-0.8711557388f, + 0.6144692898f, 0.9925075173f, 0.7752471566f, 0.9498136044f,-0.2921875417f, 0.0386975110f,-0.2997128963f, 0.4554784000f, +-0.6565206647f,-0.4993265271f, 0.8521547318f,-0.9321280718f,-0.4702904224f,-0.5673637390f,-0.3367263079f, 0.8364504576f, +-0.7473700047f, 0.7525562644f,-0.9692496061f,-0.3303825557f,-0.9700810909f, 0.9805840850f,-0.9706843495f,-0.5329755545f, +-0.9996970892f, 0.9414063692f, 0.0316982083f, 0.9369462729f, 0.6682986617f,-0.5343964100f,-0.7863491774f,-0.9639025331f, +-0.8379761577f, 0.0546237342f,-0.6310887933f, 0.7891650796f,-0.1965190321f, 0.8400250673f,-0.8241594434f, 0.8828558922f, +-0.6740840673f,-0.5090196729f, 0.6713201404f,-0.2529129684f, 0.6352515221f,-0.5923272967f, 0.9991474152f,-0.9626403451f, +-0.3636130989f, 0.7425247431f,-0.3994642496f,-0.5546801090f,-0.7315250039f,-0.4127469361f,-0.9984347820f, 0.4291617870f, +-0.6592215896f,-0.4478466809f, 0.9945786595f,-0.7179040313f, 0.9924020767f, 0.3601450622f, 0.9568566680f,-0.8711557388f, + 0.6144692898f, 0.9925075173f, 0.7752471566f, 0.9498136044f,-0.2921875417f, 0.0386975110f,-0.2997128963f, 0.4554784000f, +-0.6565206647f,-0.4993265271f, 0.8521547318f,-0.9321280718f,-0.4702904224f,-0.5673637390f,-0.3367263079f, 0.8364504576f, +-0.7473700047f, 0.7525562644f,-0.9692496061f,-0.3303825557f,-0.9700810909f, 0.9805840850f,-0.9706843495f,-0.5329755545f, +-0.9996970892f, 0.9414063692f, 0.0316982083f, 0.9369462729f, 0.6682986617f,-0.5343964100f,-0.7863491774f,-0.9639025331f, +-0.8379761577f, 0.0546237342f,-0.6310887933f, 0.7891650796f,-0.1965190321f, 0.8400250673f,-0.8241594434f, 0.8828558922f, +-0.6740840673f,-0.5090196729f, 0.6713201404f,-0.2529129684f, 0.6352515221f,-0.5923272967f, 0.9991474152f,-0.9626403451f, +-0.3636130989f, 0.7425247431f,-0.3994642496f,-0.5546801090f,-0.7315250039f,-0.4127469361f,-0.9984347820f, 0.4291617870f, +-0.6592215896f,-0.4478466809f, 0.9945786595f,-0.7179040313f, 0.9924020767f, 0.3601450622f, 0.9568566680f,-0.8711557388f, + 0.6144692898f, 0.9925075173f, 0.7752471566f, 0.9498136044f,-0.2921875417f, 0.0386975110f,-0.2997128963f, 0.4554784000f, +-0.6565206647f,-0.4993265271f, 0.8521547318f,-0.9321280718f,-0.4702904224f,-0.5673637390f,-0.3367263079f, 0.8364504576f, +-0.7473700047f, 0.7525562644f,-0.9692496061f,-0.3303825557f,-0.9700810909f, 0.9805840850f,-0.9706843495f,-0.5329755545f, +-0.9996970892f, 0.9414063692f, 0.0316982083f, 0.9369462729f, 0.6682986617f,-0.5343964100f,-0.7863491774f,-0.9639025331f, +-0.8379761577f, 0.0546237342f,-0.6310887933f, 0.7891650796f,-0.1965190321f, 0.8400250673f,-0.8241594434f, 0.8828558922f, +-0.6740840673f,-0.5090196729f, 0.6713201404f,-0.2529129684f, 0.6352515221f,-0.5923272967f, 0.9991474152f,-0.9626403451f, +-0.3636130989f, 0.7425247431f,-0.3994642496f,-0.5546801090f,-0.7315250039f,-0.4127469361f,-0.9984347820f, 0.4291617870f, +-0.6592215896f,-0.4478466809f, 0.9945786595f,-0.7179040313f, 0.9924020767f, 0.3601450622f, 0.9568566680f,-0.8711557388f, + 0.6144692898f, 0.9925075173f, 0.7752471566f, 0.9498136044f,-0.2921875417f, 0.0386975110f,-0.2997128963f, 0.4554784000f, +-0.6565206647f,-0.4993265271f, 0.8521547318f,-0.9321280718f,-0.4702904224f,-0.5673637390f,-0.3367263079f, 0.8364504576f, +-0.7473700047f, 0.7525562644f,-0.9692496061f,-0.3303825557f,-0.9700810909f, 0.9805840850f,-0.9706843495f,-0.5329755545f, +-0.9996970892f, 0.9414063692f, 0.0316982083f, 0.9369462729f, 0.6682986617f,-0.5343964100f,-0.7863491774f,-0.9639025331f, +-0.8379761577f, 0.0546237342f,-0.6310887933f, 0.7891650796f,-0.1965190321f, 0.8400250673f,-0.8241594434f, 0.8828558922f, +}; + +/* ---- helpers ---- */ + +static int nearest_centroid(float val) { + int best = 0; + float best_d = fabsf(val - DS4_PLANAR3_CENTROIDS[0]); + for (int i = 1; i < 8; i++) { + float d = fabsf(val - DS4_PLANAR3_CENTROIDS[i]); + if (d < best_d) { best_d = d; best = i; } + } + return best; +} + +/* ---- quantize one 128-dim sub-block ---- */ + +static void quantize_block_128(const float *src, ds4_block_planar3 *blk, int rot_offset) { + float norm_sq = 0.0f; + for (int j = 0; j < DS4_PLANAR3_BLOCK_DIM; j++) + norm_sq += src[j] * src[j]; + float grp_norm = sqrtf(norm_sq); + float inv_norm = (grp_norm > 1e-10f) ? 1.0f / grp_norm : 0.0f; + + memset(blk->qs, 0, DS4_PLANAR3_BLOCK_DIM / 4); + memset(blk->signs, 0, DS4_PLANAR3_BLOCK_DIM / 8); + + float recon_sq = 0.0f; + const int n_pairs = DS4_PLANAR3_BLOCK_DIM / 2; + + for (int p = 0; p < n_pairs; p++) { + float v0 = src[p * 2] * inv_norm; + float v1 = src[p * 2 + 1] * inv_norm; + + float c = DS4_PLANAR3_COS[rot_offset + p]; + float s = DS4_PLANAR3_SIN[rot_offset + p]; + float r0 = c * v0 - s * v1; + float r1 = s * v0 + c * v1; + + int idx0 = nearest_centroid(r0); + int idx1 = nearest_centroid(r1); + + int j0 = p * 2; + int j1 = p * 2 + 1; + + blk->qs[j0 / 4] |= (uint8_t)((idx0 & 0x3) << ((j0 % 4) * 2)); + if (idx0 & 0x4) blk->signs[j0 / 8] |= (uint8_t)(1 << (j0 % 8)); + + blk->qs[j1 / 4] |= (uint8_t)((idx1 & 0x3) << ((j1 % 4) * 2)); + if (idx1 & 0x4) blk->signs[j1 / 8] |= (uint8_t)(1 << (j1 % 8)); + + recon_sq += DS4_PLANAR3_CENTROIDS[idx0] * DS4_PLANAR3_CENTROIDS[idx0] + + DS4_PLANAR3_CENTROIDS[idx1] * DS4_PLANAR3_CENTROIDS[idx1]; + } + + float recon_norm = sqrtf(recon_sq); + float corrected = (recon_norm > 1e-10f) ? grp_norm / recon_norm : grp_norm; + blk->norm = fp32_to_fp16(corrected); +} + +/* ---- dequantize one 128-dim sub-block ---- */ + +static void dequantize_block_128(const ds4_block_planar3 *blk, float *dst, int rot_offset) { + float norm = fp16_to_fp32(blk->norm); + const int n_pairs = DS4_PLANAR3_BLOCK_DIM / 2; + + for (int p = 0; p < n_pairs; p++) { + int j0 = p * 2; + int j1 = p * 2 + 1; + + uint8_t low0 = (blk->qs[j0 / 4] >> ((j0 % 4) * 2)) & 0x3; + uint8_t hi0 = (blk->signs[j0 / 8] >> (j0 % 8)) & 0x1; + uint8_t idx0 = low0 | (hi0 << 2); + + uint8_t low1 = (blk->qs[j1 / 4] >> ((j1 % 4) * 2)) & 0x3; + uint8_t hi1 = (blk->signs[j1 / 8] >> (j1 % 8)) & 0x1; + uint8_t idx1 = low1 | (hi1 << 2); + + float q0 = DS4_PLANAR3_CENTROIDS[idx0]; + float q1 = DS4_PLANAR3_CENTROIDS[idx1]; + + float c = DS4_PLANAR3_COS[rot_offset + p]; + float s = DS4_PLANAR3_SIN[rot_offset + p]; + float f0 = c * q0 + s * q1; + float f1 = -s * q0 + c * q1; + + dst[j0] = f0 * norm; + dst[j1] = f1 * norm; + } +} + +/* ---- public API ---- */ + +size_t ds4_planar3_quantize_row(const float *src, ds4_row_planar3 *dst) { + for (int b = 0; b < DS4_PLANAR3_BLOCKS_512; b++) { + quantize_block_128( + src + b * DS4_PLANAR3_BLOCK_DIM, + &dst->blocks[b], + b * (DS4_PLANAR3_BLOCK_DIM / 2) + ); + } + return sizeof(ds4_row_planar3); +} + +void ds4_planar3_dequantize_row(const ds4_row_planar3 *src, float *dst) { + for (int b = 0; b < DS4_PLANAR3_BLOCKS_512; b++) { + dequantize_block_128( + &src->blocks[b], + dst + b * DS4_PLANAR3_BLOCK_DIM, + b * (DS4_PLANAR3_BLOCK_DIM / 2) + ); + } +} + +size_t ds4_planar3_quantize(const float *src, void *dst, + size_t nrows, size_t n_per_row) { + if (n_per_row != DS4_PLANAR3_BLOCK_DIM * DS4_PLANAR3_BLOCKS_512) return 0; + size_t total = 0; + for (size_t row = 0; row < nrows; row++) { + ds4_row_planar3 *out = (ds4_row_planar3 *)((char *)dst + row * sizeof(ds4_row_planar3)); + total += ds4_planar3_quantize_row(src + row * n_per_row, out); + } + return total; +} + +void ds4_planar3_dequantize(const void *src, float *dst, + size_t nrows, size_t n_per_row) { + if (n_per_row != DS4_PLANAR3_BLOCK_DIM * DS4_PLANAR3_BLOCKS_512) return; + for (size_t row = 0; row < nrows; row++) { + const ds4_row_planar3 *in = (const ds4_row_planar3 *) + ((const char *)src + row * sizeof(ds4_row_planar3)); + ds4_planar3_dequantize_row(in, dst + row * n_per_row); + } +} + +float ds4_planar3_roundtrip_cosine(const float *original, float *reconstructed, + size_t dim) { + float dot = 0.0f, n1 = 0.0f, n2 = 0.0f; + for (size_t i = 0; i < dim; i++) { + dot += original[i] * reconstructed[i]; + n1 += original[i] * original[i]; + n2 += reconstructed[i] * reconstructed[i]; + } + float denom = sqrtf(n1) * sqrtf(n2); + return (denom > 1e-10f) ? dot / denom : 0.0f; +} + +float ds4_planar3_roundtrip_mse(const float *original, float *reconstructed, + size_t dim) { + float mse = 0.0f; + for (size_t i = 0; i < dim; i++) { + float d = original[i] - reconstructed[i]; + mse += d * d; + } + return mse / (float)dim; +} diff --git a/ds4_planar_quant.h b/ds4_planar_quant.h new file mode 100644 index 000000000..fe00b8a08 --- /dev/null +++ b/ds4_planar_quant.h @@ -0,0 +1,70 @@ +#ifndef DS4_PLANAR_QUANT_H +#define DS4_PLANAR_QUANT_H + +#include +#include + +/* + * PlanarQuant: KV cache compression via 2D Givens rotation + Lloyd-Max. + * + * Adapted from planar-llama (experolk/planar-llama) ggml-planar-quant.c, + * modified for ds4's head_dim=512 and standalone use (no ggml dependency). + * + * Block layout (per 128-dim block): + * norm: uint16_t (FP16) — 2 bytes + * qs: uint8_t[32] — 2-bit quantized indices per element, 32 bytes + * signs: uint8_t[16] — high bit of 3-bit centroid index per element, 16 bytes + * Total: 50 bytes per 128-dim block + * + * For ds4 head_dim=512: 4 blocks per row = 200 bytes. + * Compared to FP16 (1024 bytes): 5.12x compression. + * + * Adapted from: experolk/planar-llama (MIT License) + * Method: PlanarQuant -- 2D Givens rotation + Lloyd-Max 3-bit centroids, + * block layout compatible with ggml block_planar3_0. + */ + +#define DS4_PLANAR3_BLOCK_DIM 128 +#define DS4_PLANAR3_BLOCKS_512 4 /* 512 / 128 */ + +/* One 128-dim block = 50 bytes. */ +typedef struct { + uint16_t norm; /* FP16 group norm */ + uint8_t qs[DS4_PLANAR3_BLOCK_DIM / 4]; /* 2-bit indices: 32 bytes */ + uint8_t signs[DS4_PLANAR3_BLOCK_DIM / 8]; /* high bit of 3-bit index: 16 bytes */ +} ds4_block_planar3; + +/* One 512-dim row = 4 blocks = 200 bytes. */ +typedef struct { + ds4_block_planar3 blocks[DS4_PLANAR3_BLOCKS_512]; +} ds4_row_planar3; + +/* Quantize a 512-dim FP32 row into Planar3 format. + * src must have 512 floats, dst must point to sizeof(ds4_row_planar3) bytes. + * Returns the compressed size in bytes. */ +size_t ds4_planar3_quantize_row(const float *src, ds4_row_planar3 *dst); + +/* Dequantize a Planar3 512-dim row back to FP32. + * dst must have space for 512 floats. */ +void ds4_planar3_dequantize_row(const ds4_row_planar3 *src, float *dst); + +/* Quantize nrows rows of n_per_row dims each. + * n_per_row must be 512. Returns total compressed bytes, or 0 if n_per_row != 512. */ +size_t ds4_planar3_quantize(const float *src, void *dst, + size_t nrows, size_t n_per_row); + +/* Dequantize nrows rows of n_per_row dims each. + * n_per_row must be 512. No-op if n_per_row != 512. */ +void ds4_planar3_dequantize(const void *src, float *dst, + size_t nrows, size_t n_per_row); + +/* Compute roundtrip quality metrics for one row. + * Returns cosine similarity between original and reconstructed vectors. */ +float ds4_planar3_roundtrip_cosine(const float *original, float *reconstructed, + size_t dim); + +/* Compute MSE between original and reconstructed. */ +float ds4_planar3_roundtrip_mse(const float *original, float *reconstructed, + size_t dim); + +#endif /* DS4_PLANAR_QUANT_H */ diff --git a/metal/dsv4_misc.metal b/metal/dsv4_misc.metal index 7e8cbd8a7..82d5396e0 100644 --- a/metal/dsv4_misc.metal +++ b/metal/dsv4_misc.metal @@ -55,7 +55,7 @@ struct ds4_metal_args_dsv4_indexed_attention { uint32_t window; uint32_t ratio; uint32_t comp_kv_f16; - uint32_t pad0; + uint32_t comp_kv_planar; uint64_t q_token_stride; uint64_t q_head_stride; uint64_t raw_row_stride; @@ -535,6 +535,188 @@ static inline void dsv4_attend_shared_h4_row_at( o0, o1, o2, o3); } +/* ---- Planar3 dequantize constants (matches ds4_planar_quant.c) ---- */ + +constant static const float planar3_centroids[8] = { + -0.190685f, -0.117832f, -0.065717f, -0.021460f, + 0.021460f, 0.065717f, 0.117832f, 0.190685f, +}; + +constant static const float planar3_cos[64] = { + 0.7386546135f, 0.8607548475f,-0.7411674857f, 0.9674890637f,-0.7723053098f,-0.8056974411f,-0.0412844308f, 0.2707833052f, + 0.9315500855f, 0.6698185802f, 0.9167487621f,-0.8320636749f, 0.6818146110f,-0.9108457565f,-0.0559285842f,-0.9032276273f, + 0.7519487143f,-0.8941103816f,-0.1039871648f,-0.6961420774f,-0.1230370328f,-0.9328963161f,-0.2905603051f, 0.4910068214f, + 0.7889407277f,-0.1221836656f,-0.6316579580f, 0.3128163815f,-0.9563610554f, 0.9992509484f, 0.9540294409f, 0.8902468085f, + 0.7543080449f,-0.8664138913f,-0.5232898593f, 0.3621287644f,-0.8825117350f, 0.8234673142f,-0.9416025877f,-0.5480425358f, +-0.6644080281f,-0.6585279703f,-0.2460795939f, 0.9438471198f, 0.2427810431f,-0.1960992366f, 0.2403578013f,-0.8461306095f, + 0.0246123374f, 0.3372744620f, 0.9994974732f,-0.3494733870f, 0.7438930869f, 0.8452339768f,-0.6177822948f,-0.2662552595f, +-0.5457068086f,-0.9985070229f, 0.7757105827f, 0.6141811609f,-0.9805000424f, 0.5425475240f,-0.5663578510f,-0.4696439803f, +}; + +constant static const float planar3_sin[64] = { +-0.6740840673f,-0.5090196729f, 0.6713201404f,-0.2529129684f, 0.6352515221f,-0.5923272967f, 0.9991474152f,-0.9626403451f, +-0.3636130989f, 0.7425247431f,-0.3994642496f,-0.5546801090f,-0.7315250039f,-0.4127469361f,-0.9984347820f, 0.4291617870f, +-0.6592215896f,-0.4478466809f, 0.9945786595f,-0.7179040313f, 0.9924020767f, 0.3601450622f, 0.9568566680f,-0.8711557388f, + 0.6144692898f, 0.9925075173f, 0.7752471566f, 0.9498136044f,-0.2921875417f, 0.0386975110f,-0.2997128963f, 0.4554784000f, +-0.6565206647f,-0.4993265271f, 0.8521547318f,-0.9321280718f,-0.4702904224f,-0.5673637390f,-0.3367263079f, 0.8364504576f, +-0.7473700047f, 0.7525562644f,-0.9692496061f,-0.3303825557f,-0.9700810909f, 0.9805840850f,-0.9706843495f,-0.5329755545f, +-0.9996970892f, 0.9414063692f, 0.0316982083f, 0.9369462729f, 0.6682986617f,-0.5343964100f,-0.7863491774f,-0.9639025331f, +-0.8379761577f, 0.0546237342f,-0.6310887933f, 0.7891650796f,-0.1965190321f, 0.8400250673f,-0.8241594434f, 0.8828558922f, +}; + +/* Planar3 block layout: 2 bytes norm + 32 bytes qs + 16 bytes signs = 50 bytes per 128-dim. + * 4 blocks per 512-dim row = 200 bytes. Cos/sin tables have 64 pairs, reused across 4 blocks. */ + +struct ds4_planar3_block { half norm; uint8_t qs[32]; uint8_t signs[16]; }; + +static inline void dsv4_dequant_planar3_row( + device const char *planar_base, + uint64_t row_stride, + uint row, + uint tid, + threadgroup half4 *kv_shared) { + /* Each thread dequantizes one half4 (4 elements) from the packed row. + * tid 0..127 covers all 512 dims. */ + device const ds4_planar3_block *blocks = (device const ds4_planar3_block *) + (planar_base + (uint64_t)row * row_stride); + + const uint blk_idx = tid / 32; /* 0..3 */ + const uint sub = tid % 32; /* 0..31 -> half4 offset within block */ + + device const ds4_planar3_block *blk = &blocks[blk_idx]; + const float norm = float(blk->norm); + + half4 result; + for (int g = 0; g < 4; g++) { + const uint j = sub * 4 + g; + const uint pair = j / 2; + const uint8_t qb = blk->qs[j / 4]; + const uint shift2 = (j % 4) * 2; + const uint8_t sb = blk->signs[j / 8]; + const uint shift_s = j % 8; + const uint idx = ((qb >> shift2) & 0x3) | (((sb >> shift_s) & 1) << 2); + const float raw = planar3_centroids[idx]; + + const uint j_other = pair * 2 + (1 - (j & 1)); + const uint8_t qb2 = blk->qs[j_other / 4]; + const uint shift2b = (j_other % 4) * 2; + const uint8_t sb2 = blk->signs[j_other / 8]; + const uint shift_sb = j_other % 8; + const uint idx2 = ((qb2 >> shift2b) & 0x3) | (((sb2 >> shift_sb) & 1) << 2); + const float raw2 = planar3_centroids[idx2]; + + const float c = planar3_cos[pair]; + const float s = planar3_sin[pair]; + float f; + if ((j & 1) == 0) { + f = c * raw + s * raw2; + } else { + f = -s * raw + c * raw2; + } + result[g] = (half)(f * norm); + } + kv_shared[tid] = result; +} + +/* ---- Planar3 GPU quantize: FP32 comp KV row → Planar3 blocks ---- */ + +constant static const float planar3_mid[7] = { + -0.154259f, -0.091775f, -0.043589f, 0.0f, 0.043589f, 0.091775f, 0.154259f, +}; + +kernel void kernel_planar3_quantize_row( + device const float *src, + device ds4_planar3_block *dst, + constant uint32_t &n_rows, + uint gid [[thread_position_in_grid]]) { + if (gid >= n_rows) return; + device const float *row = src + (uint64_t)gid * 512; + device ds4_planar3_block *blocks = dst + (uint64_t)gid * 4; + + for (uint blk = 0; blk < 4; blk++) { + device const float *in = row + blk * 128; + device ds4_planar3_block *b = &blocks[blk]; + + float norm_sq = 0.0f; + for (uint j = 0; j < 128; j++) norm_sq += in[j] * in[j]; + float grp_norm = sqrt(norm_sq); + float inv_norm = grp_norm > 1e-10f ? 1.0f / grp_norm : 0.0f; + + float rotated[128]; + for (uint p = 0; p < 64; p++) { + float x0 = in[p*2] * inv_norm; + float x1 = in[p*2+1] * inv_norm; + float c = planar3_cos[p], s = planar3_sin[p]; + rotated[p*2] = c * x0 - s * x1; + rotated[p*2+1] = s * x0 + c * x1; + } + + for (uint j = 0; j < 32; j++) b->qs[j] = 0; + for (uint j = 0; j < 16; j++) b->signs[j] = 0; + + float recon_sq = 0.0f; + for (uint j = 0; j < 128; j++) { + uint8_t idx; + float v = rotated[j]; + if (v < planar3_mid[0]) idx = 0; + else if (v < planar3_mid[1]) idx = 1; + else if (v < planar3_mid[2]) idx = 2; + else if (v < planar3_mid[3]) idx = 3; + else if (v < planar3_mid[4]) idx = 4; + else if (v < planar3_mid[5]) idx = 5; + else if (v < planar3_mid[6]) idx = 6; + else idx = 7; + b->qs[j/4] |= (idx & 0x3) << ((j % 4) * 2); + if (idx & 0x4) b->signs[j/8] |= (1 << (j % 8)); + recon_sq += planar3_centroids[idx] * planar3_centroids[idx]; + } + float recon_norm = sqrt(recon_sq); + b->norm = (half)(recon_norm > 1e-10f ? grp_norm / recon_norm : grp_norm); + } +} + +/* Dequantize Planar3 rows to F16 output buffer. One thread per row. + * Each thread dequantizes 4 blocks (512 dims) and writes 256 half values. */ +kernel void kernel_planar3_dequant_to_f16_rows( + device const ds4_planar3_block *src, + device half *dst, + constant uint32_t &n_rows, + constant uint64_t &dst_row_stride, + uint gid [[thread_position_in_grid]]) { + if (gid >= n_rows) return; + device const ds4_planar3_block *blocks = src + (uint64_t)gid * 4; + device half *out = dst + (uint64_t)gid * (dst_row_stride / sizeof(half)); + + for (uint blk = 0; blk < 4; blk++) { + device const ds4_planar3_block *b = &blocks[blk]; + float norm = float(b->norm); + + for (uint p = 0; p < 64; p++) { + uint j0 = p * 2; + uint j1 = p * 2 + 1; + + uint8_t low0 = (b->qs[j0 / 4] >> ((j0 % 4) * 2)) & 0x3; + uint8_t hi0 = (b->signs[j0 / 8] >> (j0 % 8)) & 0x1; + uint8_t idx0 = low0 | (hi0 << 2); + + uint8_t low1 = (b->qs[j1 / 4] >> ((j1 % 4) * 2)) & 0x3; + uint8_t hi1 = (b->signs[j1 / 8] >> (j1 % 8)) & 0x1; + uint8_t idx1 = low1 | (hi1 << 2); + + float q0 = planar3_centroids[idx0]; + float q1 = planar3_centroids[idx1]; + float c = planar3_cos[p]; + float s = planar3_sin[p]; + float f0 = c * q0 + s * q1; + float f1 = -s * q0 + c * q1; + + uint out_base = blk * 128 + p * 2; + out[out_base] = (half)(f0 * norm); + out[out_base + 1] = (half)(f1 * norm); + } + } +} + static inline half4 dsv4_load_cache_h4( device const char *kv, uint64_t row_stride, @@ -647,7 +829,13 @@ kernel void kernel_dsv4_indexed_mixed_attention_heads8( if ((uint)idx >= visible) { break; } - if (tid < 128) { + if (args.comp_kv_planar != 0u) { + if (tid < 128) dsv4_dequant_planar3_row(comp_kv, + args.comp_row_stride, + (uint)idx, + tid, + kv_shared); + } else if (tid < 128) { kv_shared[tid] = dsv4_load_cache_h4(comp_kv, args.comp_row_stride, (uint)idx, @@ -773,14 +961,35 @@ kernel void kernel_dsv4_indexed_mixed_attention_heads8_rb16( if (n_rows == 0) { continue; } - for (uint off = (uint)tid; off < n_rows * 128u; off += 256u) { - const uint r = off >> 7; - const uint c = off & 127u; - kv_shared[off] = dsv4_load_cache_h4(comp_kv, - args.comp_row_stride, - rows[r], - c, - args.comp_kv_f16 != 0u); + if (args.comp_kv_planar != 0u) { + /* Dequant two rows per iteration: tid 0-127 handle even row, + * tid 128-255 handle odd row. All 256 threads participate. */ + for (uint r = 0; r < n_rows; r += 2) { + if (tid < 128) { + dsv4_dequant_planar3_row(comp_kv, + args.comp_row_stride, + rows[r], + tid, + &kv_shared[r * 128]); + } + if (r + 1 < n_rows && tid >= 128) { + dsv4_dequant_planar3_row(comp_kv, + args.comp_row_stride, + rows[r + 1], + tid - 128, + &kv_shared[(r + 1) * 128]); + } + } + } else { + for (uint off = (uint)tid; off < n_rows * 128u; off += 256u) { + const uint r = off >> 7; + const uint c = off & 127u; + kv_shared[off] = dsv4_load_cache_h4(comp_kv, + args.comp_row_stride, + rows[r], + c, + args.comp_kv_f16 != 0u); + } } threadgroup_barrier(mem_flags::mem_threadgroup); for (uint r = 0; r < n_rows; r++) { diff --git a/tests/planar_quant_test.c b/tests/planar_quant_test.c new file mode 100644 index 000000000..3ea3b6f4a --- /dev/null +++ b/tests/planar_quant_test.c @@ -0,0 +1,320 @@ +/* + * planar_quant_test.c -- unit tests for Planar3 512-dim quantization + * + * Build: + * cc -O2 -Wall -Wextra -std=c99 -I. -o tests/planar_quant_test \ + * tests/planar_quant_test.c ds4_planar_quant.c -lm + */ + +#include +#include +#include +#include +#include "ds4_planar_quant.h" + +#define DIM 512 +#define NROWS_BATCH 100 + +/* ---- test harness (matches ds4 test style) ---- */ + +static int expect(int cond, const char *msg) { + if (!cond) { + fprintf(stderr, "planar_quant_test: FAIL: %s\n", msg); + return 1; + } + return 0; +} + +/* ---- simple deterministic PRNG (xorshift32) ---- */ + +static uint32_t prng_state; + +static void prng_seed(uint32_t s) { + prng_state = s ? s : 1u; +} + +static float prng_float(void) { + /* xorshift32 */ + prng_state ^= prng_state << 13; + prng_state ^= prng_state >> 17; + prng_state ^= prng_state << 5; + /* map to [-1, 1) */ + return (float)((int32_t)prng_state) / (float)0x7FFFFFFF; +} + +/* ---- helpers ---- */ + +static float vec_norm(const float *v, int n) { + float s = 0.0f; + for (int i = 0; i < n; i++) s += v[i] * v[i]; + return sqrtf(s); +} + +/* ---- tests ---- */ + +static int test_block_size(void) { + int fail = 0; + fail += expect(sizeof(ds4_block_planar3) == 50, + "sizeof(ds4_block_planar3) should be 50"); + fail += expect(sizeof(ds4_row_planar3) == 200, + "sizeof(ds4_row_planar3) should be 200"); + + printf("test_block_size: block=%zu row=%zu -- %s\n", + sizeof(ds4_block_planar3), sizeof(ds4_row_planar3), + fail ? "FAIL" : "OK"); + return fail; +} + +static int test_roundtrip_basis_vector(void) { + float src[DIM] = {0}; + float dst[DIM]; + ds4_row_planar3 compressed; + + src[0] = 1.0f; /* e0 */ + + ds4_planar3_quantize_row(src, &compressed); + ds4_planar3_dequantize_row(&compressed, dst); + + float cos = ds4_planar3_roundtrip_cosine(src, dst, DIM); + float mse = ds4_planar3_roundtrip_mse(src, dst, DIM); + + /* Spike vectors (all energy in one dimension) are pathological for + 2-bit rotated quantization: Givens rotation spreads the energy + across many pairs, and coarse 3-bit centroids lose detail. + Cosine ~0.74 is expected; we verify roundtrip is non-degenerate. */ + int fail = expect(cos > 0.70f, + "basis vector cosine similarity should be > 0.70"); + + printf("test_roundtrip_basis_vector: cosine=%.6f mse=%.6f -- %s\n", + cos, mse, fail ? "FAIL" : "OK"); + return fail; +} + +static int test_roundtrip_random(void) { + const int N = 10; + float src[DIM], dst[DIM]; + ds4_row_planar3 compressed; + + prng_seed(123); + + float cos_min = 1.0f, cos_max = 0.0f, cos_sum = 0.0f; + float mse_sum = 0.0f; + int fail = 0; + + for (int v = 0; v < N; v++) { + for (int i = 0; i < DIM; i++) + src[i] = prng_float(); + + ds4_planar3_quantize_row(src, &compressed); + ds4_planar3_dequantize_row(&compressed, dst); + + float cos = ds4_planar3_roundtrip_cosine(src, dst, DIM); + float mse = ds4_planar3_roundtrip_mse(src, dst, DIM); + + if (cos < cos_min) cos_min = cos; + if (cos > cos_max) cos_max = cos; + cos_sum += cos; + mse_sum += mse; + } + + float cos_avg = cos_sum / N; + float mse_avg = mse_sum / N; + + fail += expect(cos_avg > 0.97f, + "random vector avg cosine similarity should be > 0.97"); + + printf("test_roundtrip_random: %d vectors, cosine min=%.3f avg=%.3f max=%.3f, MSE avg=%.6f -- %s\n", + N, cos_min, cos_avg, cos_max, mse_avg, fail ? "FAIL" : "OK"); + return fail; +} + +static int test_roundtrip_large_norm(void) { + float src[DIM], dst[DIM]; + ds4_row_planar3 compressed; + + for (int i = 0; i < DIM; i++) + src[i] = sinf((float)i * 0.1f + 0.5f) * 10.0f; + + float orig_norm = vec_norm(src, DIM); + + ds4_planar3_quantize_row(src, &compressed); + ds4_planar3_dequantize_row(&compressed, dst); + + float cos = ds4_planar3_roundtrip_cosine(src, dst, DIM); + float mse = ds4_planar3_roundtrip_mse(src, dst, DIM); + float recon_norm = vec_norm(dst, DIM); + + float norm_ratio = recon_norm / orig_norm; + int fail = 0; + fail += expect(cos > 0.95f, + "large-norm cosine similarity should be > 0.95"); + fail += expect(norm_ratio > 0.90f && norm_ratio < 1.10f, + "reconstructed norm should be within 10% of original"); + + printf("test_roundtrip_large_norm: cosine=%.6f mse=%.6f " + "norm_orig=%.3f norm_recon=%.3f ratio=%.3f -- %s\n", + cos, mse, orig_norm, recon_norm, norm_ratio, fail ? "FAIL" : "OK"); + return fail; +} + +static int test_batch_quantize(void) { + float src[NROWS_BATCH * DIM]; + float dst[NROWS_BATCH * DIM]; + ds4_row_planar3 compressed[NROWS_BATCH]; + + prng_seed(456); + for (int i = 0; i < NROWS_BATCH * DIM; i++) + src[i] = prng_float(); + + size_t total = ds4_planar3_quantize(src, compressed, NROWS_BATCH, DIM); + ds4_planar3_dequantize(compressed, dst, NROWS_BATCH, DIM); + + float cos_sum = 0.0f; + int fail = 0; + + for (int r = 0; r < NROWS_BATCH; r++) { + float cos = ds4_planar3_roundtrip_cosine( + src + r * DIM, dst + r * DIM, DIM); + cos_sum += cos; + } + + float cos_avg = cos_sum / NROWS_BATCH; + fail += expect(cos_avg > 0.97f, + "batch avg cosine similarity should be > 0.97"); + fail += expect(total == NROWS_BATCH * sizeof(ds4_row_planar3), + "batch total compressed size mismatch"); + + printf("test_batch_quantize: %d rows, avg cosine=%.4f, total_bytes=%zu -- %s\n", + NROWS_BATCH, cos_avg, total, fail ? "FAIL" : "OK"); + return fail; +} + +static int test_compression_ratio(void) { + size_t fp16_bytes = DIM * 2; /* 2 bytes per FP16 */ + size_t planar_bytes = sizeof(ds4_row_planar3); + float ratio = (float)fp16_bytes / (float)planar_bytes; + + int fail = expect(planar_bytes == 200, + "compressed row should be exactly 200 bytes"); + + printf("test_compression_ratio: fp16=%zu planar=%zu ratio=%.2fx -- %s\n", + fp16_bytes, planar_bytes, ratio, fail ? "FAIL" : "OK"); + return fail; +} + +static int test_block_independence(void) { + float src[DIM] = {0}; + float dst[DIM]; + ds4_row_planar3 compressed; + + /* Block 0: e0 = [1,0,0,...,0], blocks 1-3: all zeros */ + src[0] = 1.0f; + /* all other dims already 0 */ + + ds4_planar3_quantize_row(src, &compressed); + ds4_planar3_dequantize_row(&compressed, dst); + + /* Check block 3 (indices 384..511) is near-zero */ + float block3_sq = 0.0f; + for (int i = 384; i < 512; i++) + block3_sq += dst[i] * dst[i]; + float block3_norm = sqrtf(block3_sq); + + int fail = expect(block3_norm < 1e-3f, + "block 3 should be near-zero when only block 0 has data"); + + printf("test_block_independence: block3_norm=%.6e -- %s\n", + block3_norm, fail ? "FAIL" : "OK"); + return fail; +} + +static int test_batch_dim_mismatch(void) { + float src[DIM]; + float dst[DIM]; + ds4_row_planar3 compressed[1]; + + for (int i = 0; i < DIM; i++) src[i] = 1.0f; + + /* quantize with wrong dim should return 0 */ + size_t ret = ds4_planar3_quantize(src, compressed, 1, 511); + int fail = expect(ret == 0, + "quantize with n_per_row=511 should return 0"); + + /* dequantize with wrong dim should be no-op: fill dst with sentinel */ + memset(compressed, 0, sizeof(ds4_row_planar3)); + for (int i = 0; i < DIM; i++) dst[i] = 42.0f; + ds4_planar3_dequantize(compressed, dst, 1, 511); + for (int i = 0; i < DIM; i++) { + if (dst[i] != 42.0f) { + fail += expect(0, "dequantize with n_per_row=511 should not write dst"); + break; + } + } + + printf("test_batch_dim_mismatch: quantize_ret=%zu dst_unchanged=%s -- %s\n", + ret, fail ? "no" : "yes", fail ? "FAIL" : "OK"); + return fail; +} + +static int test_zero_norm(void) { + float src[DIM] = {0}; + float dst[DIM]; + ds4_row_planar3 comp; + memset(&comp, 0, sizeof(comp)); + + ds4_planar3_quantize_row(src, &comp); + ds4_planar3_dequantize_row(&comp, dst); + + int fail = 0; + for (int i = 0; i < DIM; i++) { + if (dst[i] != 0.0f) { + fail += expect(0, "zero-norm dequant should produce zeros"); + break; + } + } + printf("test_zero_norm: all_zeros=%s -- %s\n", fail ? "no" : "yes", fail ? "FAIL" : "OK"); + return fail; +} + +static int test_single_element(void) { + float src[DIM] = {0}; + src[42] = 3.7f; + float dst[DIM]; + ds4_row_planar3 comp; + + ds4_planar3_quantize_row(src, &comp); + ds4_planar3_dequantize_row(&comp, dst); + + float cos = ds4_planar3_roundtrip_cosine(src, dst, DIM); + int fail = expect(cos > 0.5f, + "single-element vector should preserve direction"); + if (!fail) { + printf("test_single_element: cos=%f val[42]=%f -- OK\n", cos, dst[42]); + } + return fail; +} + +/* ---- main ---- */ + +int main(void) { + int fail = 0; + + fail += test_block_size(); + fail += test_roundtrip_basis_vector(); + fail += test_roundtrip_random(); + fail += test_roundtrip_large_norm(); + fail += test_batch_quantize(); + fail += test_compression_ratio(); + fail += test_block_independence(); + fail += test_batch_dim_mismatch(); + fail += test_zero_norm(); + fail += test_single_element(); + + if (fail) { + fprintf(stderr, "\nplanar_quant_test: %d test(s) FAILED\n", fail); + return 1; + } + + printf("\nplanar_quant_test: all tests passed\n"); + return 0; +} diff --git a/tools/planar_eval.c b/tools/planar_eval.c new file mode 100644 index 000000000..0496e86b7 --- /dev/null +++ b/tools/planar_eval.c @@ -0,0 +1,597 @@ +/* + * planar_eval.c - offline quality evaluator for Planar3 KV-cache rows. + * + * The tool either generates synthetic 512-dim rows or reads dumped rows from a + * binary file, applies Planar3 quantize/dequantize, and reports row-level, + * attention-score, and softmax-output drift metrics. + * + * Binary input format: + * uint32_t nrows + * uint32_t n_per_row, must be 512 + * float32 rows[nrows][512] + */ + +#include "ds4_planar_quant.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#define DIM 512 +#define DEFAULT_ROWS 10000 +#define DEFAULT_QUERIES 8 +#define TOPK_MAX 10 + +static uint64_t pcg_state_; + +static inline uint32_t pcg32(void) { + uint64_t s = pcg_state_; + pcg_state_ = s * 6364136223846793005ULL + 1442695040888963407ULL; + uint32_t x = (uint32_t)(((s >> 18) ^ s) >> 27); + int r = (int)(s >> 59); + return (x >> r) | (x << ((-r) & 31)); +} + +static inline float randf01(void) { + return (float)((double)pcg32() / 4294967296.0); +} + +static inline float randf_uniform(void) { + return randf01() * 2.0f - 1.0f; +} + +static inline float randf_normal(void) { + const float u1 = (float)(((double)pcg32() + 1.0) / 4294967297.0); + const float u2 = randf01(); + return sqrtf(-2.0f * logf(u1)) * cosf(6.283185307179586f * u2); +} + +static bool checked_mul_size(size_t a, size_t b, size_t *out) { + if (a != 0 && b > SIZE_MAX / a) return false; + *out = a * b; + return true; +} + +static bool parse_int_arg(const char *s, const char *opt, int minv, int maxv, int *out) { + char *end = NULL; + errno = 0; + long v = strtol(s, &end, 10); + if (errno || !end || *end || v < minv || v > maxv) { + fprintf(stderr, "planar_eval: invalid value for %s: %s\n", opt, s); + return false; + } + *out = (int)v; + return true; +} + +static bool parse_u64_arg(const char *s, const char *opt, uint64_t *out) { + char *end = NULL; + errno = 0; + unsigned long long v = strtoull(s, &end, 10); + if (errno || !end || *end) { + fprintf(stderr, "planar_eval: invalid value for %s: %s\n", opt, s); + return false; + } + *out = (uint64_t)v; + return true; +} + +static float *alloc_rows(int nrows) { + size_t count, bytes; + if (!checked_mul_size((size_t)nrows, DIM, &count) || + !checked_mul_size(count, sizeof(float), &bytes)) { + return NULL; + } + return (float *)malloc(bytes); +} + +static int load_input_file(const char *path, float **data_out, int *nrows_out) { + FILE *f = fopen(path, "rb"); + if (!f) { + fprintf(stderr, "planar_eval: cannot open input file %s: %s\n", + path, strerror(errno)); + return 0; + } + + uint32_t hdr_nrows = 0; + uint32_t hdr_dim = 0; + if (fread(&hdr_nrows, sizeof(hdr_nrows), 1, f) != 1 || + fread(&hdr_dim, sizeof(hdr_dim), 1, f) != 1) { + fprintf(stderr, "planar_eval: failed to read header from %s\n", path); + fclose(f); + return 0; + } + if (hdr_dim != DIM) { + fprintf(stderr, "planar_eval: input file has n_per_row=%u, expected %d\n", + hdr_dim, DIM); + fclose(f); + return 0; + } + if (hdr_nrows == 0 || hdr_nrows > (uint32_t)INT_MAX) { + fprintf(stderr, "planar_eval: input file has invalid nrows=%u\n", hdr_nrows); + fclose(f); + return 0; + } + + int nrows = (int)hdr_nrows; + float *data = alloc_rows(nrows); + if (!data) { + fprintf(stderr, "planar_eval: failed to allocate %d rows x %d dims\n", + nrows, DIM); + fclose(f); + return 0; + } + + const size_t total = (size_t)nrows * DIM; + if (fread(data, sizeof(float), total, f) != total) { + fprintf(stderr, "planar_eval: failed to read %zu float32 values from %s\n", + total, path); + free(data); + fclose(f); + return 0; + } + + fclose(f); + *data_out = data; + *nrows_out = nrows; + return 1; +} + +static void gen_random_normal(float *buf, int nrows, uint64_t seed) { + pcg_state_ = seed; + for (size_t i = 0, n = (size_t)nrows * DIM; i < n; i++) + buf[i] = randf_normal(); +} + +static void gen_random_uniform(float *buf, int nrows, uint64_t seed) { + pcg_state_ = seed; + for (size_t i = 0, n = (size_t)nrows * DIM; i < n; i++) + buf[i] = randf_uniform(); +} + +static void gen_sparse(float *buf, int nrows, uint64_t seed) { + pcg_state_ = seed; + for (size_t i = 0, n = (size_t)nrows * DIM; i < n; i++) + buf[i] = (pcg32() % 10 == 0) ? randf_normal() : 0.0f; +} + +static void gen_ds4_like(float *buf, int nrows, uint64_t seed) { + pcg_state_ = seed; + for (int r = 0; r < nrows; r++) { + float *row = buf + (size_t)r * DIM; + float base[DIM]; + memset(base, 0, sizeof(base)); + + const int nwaves = 3 + (int)(pcg32() % 3); + for (int w = 0; w < nwaves; w++) { + const float freq = 0.01f + randf01() * 0.1f; + const float amp = 0.5f + randf01() * 0.5f; + const float phase = randf01() * 6.283185307179586f; + for (int j = 0; j < DIM; j++) + base[j] += amp * sinf(6.283185307179586f * freq * (float)j + phase); + } + + float ms = 0.0f; + for (int j = 0; j < DIM; j++) ms += base[j] * base[j]; + const float rms = sqrtf(ms / (float)DIM + 1e-6f); + const float learned_norm = 0.9f + randf01() * 0.2f; + for (int j = 0; j < DIM; j++) + row[j] = base[j] / rms * learned_norm; + + for (int j = 0; j < DIM; j += 2) { + if ((pcg32() >> 16) & 1u) { + row[j] = -row[j]; + row[j + 1] = -row[j + 1]; + } + } + + for (int j = 0; j < DIM; j++) + row[j] += randf_normal() * 0.01f; + } +} + +static int cmp_float(const void *a, const void *b) { + const float fa = *(const float *)a; + const float fb = *(const float *)b; + return (fa > fb) - (fa < fb); +} + +static float percentile_sorted(const float *arr, int n, float pct) { + const float idx = (float)(n - 1) * pct; + const int lo = (int)idx; + const int hi = lo + 1; + if (hi >= n) return arr[n - 1]; + const float frac = idx - (float)lo; + return arr[lo] * (1.0f - frac) + arr[hi] * frac; +} + +static float vec_cosine(const float *a, const float *b, int n) { + double dot = 0.0; + double na = 0.0; + double nb = 0.0; + for (int i = 0; i < n; i++) { + dot += (double)a[i] * b[i]; + na += (double)a[i] * a[i]; + nb += (double)b[i] * b[i]; + } + const double denom = sqrt(na) * sqrt(nb); + return denom > 1e-20 ? (float)(dot / denom) : 0.0f; +} + +static float rel_l2_diff(const float *ref, const float *got, int n) { + double num = 0.0; + double den = 0.0; + for (int i = 0; i < n; i++) { + const double d = (double)ref[i] - got[i]; + num += d * d; + den += (double)ref[i] * ref[i]; + } + return den > 1e-20 ? (float)(sqrt(num) / sqrt(den)) : 0.0f; +} + +static void fill_query(float *query) { + double ss = 0.0; + for (int j = 0; j < DIM; j++) { + query[j] = randf_normal(); + ss += (double)query[j] * query[j]; + } + const float scale = ss > 1e-20 ? (float)(sqrt((double)DIM / ss)) : 1.0f; + for (int j = 0; j < DIM; j++) query[j] *= scale; +} + +static void dot_scores(const float *rows, int nrows, const float *query, float *scores) { + const float attn_scale = 1.0f / sqrtf((float)DIM); + for (int r = 0; r < nrows; r++) { + const float *row = rows + (size_t)r * DIM; + float s = 0.0f; + for (int j = 0; j < DIM; j++) s += query[j] * row[j]; + scores[r] = s * attn_scale; + } +} + +static double pearson_corr(const float *a, const float *b, int n) { + double ma = 0.0; + double mb = 0.0; + for (int i = 0; i < n; i++) { + ma += a[i]; + mb += b[i]; + } + ma /= n; + mb /= n; + + double cov = 0.0; + double va = 0.0; + double vb = 0.0; + for (int i = 0; i < n; i++) { + const double da = a[i] - ma; + const double db = b[i] - mb; + cov += da * db; + va += da * da; + vb += db * db; + } + const double denom = sqrt(va * vb); + return denom > 1e-20 ? cov / denom : 0.0; +} + +static void topk_indices(const float *scores, int n, int k, int *idx) { + for (int i = 0; i < k; i++) idx[i] = -1; + for (int i = 0; i < n; i++) { + for (int j = 0; j < k; j++) { + if (idx[j] < 0 || scores[i] > scores[idx[j]]) { + for (int m = k - 1; m > j; m--) idx[m] = idx[m - 1]; + idx[j] = i; + break; + } + } + } +} + +static int topk_overlap(const int *a, const int *b, int k) { + int overlap = 0; + for (int i = 0; i < k; i++) { + for (int j = 0; j < k; j++) { + if (a[i] == b[j]) { + overlap++; + break; + } + } + } + return overlap; +} + +static void softmax_output(const float *scores, const float *rows, int nrows, float *out) { + float max_score = scores[0]; + for (int r = 1; r < nrows; r++) { + if (scores[r] > max_score) max_score = scores[r]; + } + + memset(out, 0, DIM * sizeof(float)); + double denom = 0.0; + for (int r = 0; r < nrows; r++) { + const float w = expf(scores[r] - max_score); + const float *row = rows + (size_t)r * DIM; + denom += w; + for (int j = 0; j < DIM; j++) out[j] += w * row[j]; + } + + if (denom > 0.0) { + const float inv = (float)(1.0 / denom); + for (int j = 0; j < DIM; j++) out[j] *= inv; + } +} + +static void print_usage(const char *prog) { + printf("Usage: %s [options]\n", prog); + printf("Options:\n"); + printf(" --rows N Number of rows to generate (default: %d)\n", DEFAULT_ROWS); + printf(" --queries N Number of random attention probes (default: %d)\n", DEFAULT_QUERIES); + printf(" --mode MODE random_normal, random_uniform, sparse, ds4_like, ds4_realistic\n"); + printf(" default: random_normal; ds4_realistic is an alias for ds4_like\n"); + printf(" --input FILE Load rows from binary file instead of generating\n"); + printf(" format: uint32 nrows, uint32 n_per_row, float32 data\n"); + printf(" --seed N PRNG seed (default: 42)\n"); + printf(" --help Show this help\n"); +} + +int main(int argc, char **argv) { + int nrows = DEFAULT_ROWS; + int nqueries = DEFAULT_QUERIES; + const char *mode_str = "random_normal"; + const char *input_file = NULL; + uint64_t seed = 42; + + for (int i = 1; i < argc; i++) { + if (!strcmp(argv[i], "--rows")) { + if (++i >= argc || !parse_int_arg(argv[i], "--rows", 1, INT_MAX / DIM, &nrows)) + return 1; + } else if (!strcmp(argv[i], "--queries")) { + if (++i >= argc || !parse_int_arg(argv[i], "--queries", 1, 1000000, &nqueries)) + return 1; + } else if (!strcmp(argv[i], "--mode")) { + if (++i >= argc) { + fprintf(stderr, "planar_eval: --mode requires an argument\n"); + return 1; + } + mode_str = argv[i]; + } else if (!strcmp(argv[i], "--input")) { + if (++i >= argc) { + fprintf(stderr, "planar_eval: --input requires an argument\n"); + return 1; + } + input_file = argv[i]; + } else if (!strcmp(argv[i], "--seed")) { + if (++i >= argc || !parse_u64_arg(argv[i], "--seed", &seed)) + return 1; + } else if (!strcmp(argv[i], "--help")) { + print_usage(argv[0]); + return 0; + } else { + fprintf(stderr, "planar_eval: unknown option: %s\n", argv[i]); + print_usage(argv[0]); + return 1; + } + } + + float *data_orig = NULL; + if (input_file) { + if (!load_input_file(input_file, &data_orig, &nrows)) return 1; + } else { + data_orig = alloc_rows(nrows); + if (!data_orig) { + fprintf(stderr, "planar_eval: failed to allocate %d rows x %d dims\n", + nrows, DIM); + return 1; + } + if (!strcmp(mode_str, "random_normal")) { + gen_random_normal(data_orig, nrows, seed); + } else if (!strcmp(mode_str, "random_uniform")) { + gen_random_uniform(data_orig, nrows, seed); + } else if (!strcmp(mode_str, "sparse")) { + gen_sparse(data_orig, nrows, seed); + } else if (!strcmp(mode_str, "ds4_like") || !strcmp(mode_str, "ds4_realistic")) { + gen_ds4_like(data_orig, nrows, seed); + } else { + fprintf(stderr, "planar_eval: unknown mode: %s\n", mode_str); + print_usage(argv[0]); + free(data_orig); + return 1; + } + } + + float *data_recon = alloc_rows(nrows); + ds4_row_planar3 *compressed = (ds4_row_planar3 *)malloc((size_t)nrows * sizeof(ds4_row_planar3)); + float *cosine_arr = (float *)malloc((size_t)nrows * sizeof(float)); + float *mse_arr = (float *)malloc((size_t)nrows * sizeof(float)); + float *maxerr_arr = (float *)malloc((size_t)nrows * sizeof(float)); + float *relnorm_arr = (float *)malloc((size_t)nrows * sizeof(float)); + float *score_orig = (float *)malloc((size_t)nrows * sizeof(float)); + float *score_recon = (float *)malloc((size_t)nrows * sizeof(float)); + if (!data_recon || !compressed || !cosine_arr || !mse_arr || + !maxerr_arr || !relnorm_arr || !score_orig || !score_recon) { + fprintf(stderr, "planar_eval: allocation failure for %d rows\n", nrows); + free(data_orig); + free(data_recon); + free(compressed); + free(cosine_arr); + free(mse_arr); + free(maxerr_arr); + free(relnorm_arr); + free(score_orig); + free(score_recon); + return 1; + } + + const size_t bytes_written = + ds4_planar3_quantize(data_orig, compressed, (size_t)nrows, DIM); + ds4_planar3_dequantize(compressed, data_recon, (size_t)nrows, DIM); + + double cos_mean = 0.0; + double mse_mean = 0.0; + double maxerr_mean = 0.0; + double relnorm_mean = 0.0; + for (int r = 0; r < nrows; r++) { + const float *orig = data_orig + (size_t)r * DIM; + const float *rec = data_recon + (size_t)r * DIM; + + cosine_arr[r] = ds4_planar3_roundtrip_cosine(orig, (float *)rec, DIM); + mse_arr[r] = ds4_planar3_roundtrip_mse(orig, (float *)rec, DIM); + + float maxerr = 0.0f; + float n1 = 0.0f; + float n2 = 0.0f; + for (int j = 0; j < DIM; j++) { + const float d = fabsf(orig[j] - rec[j]); + if (d > maxerr) maxerr = d; + n1 += orig[j] * orig[j]; + n2 += rec[j] * rec[j]; + } + maxerr_arr[r] = maxerr; + n1 = sqrtf(n1); + n2 = sqrtf(n2); + relnorm_arr[r] = n1 > 1e-10f ? fabsf(n1 - n2) / n1 : 0.0f; + + cos_mean += cosine_arr[r]; + mse_mean += mse_arr[r]; + maxerr_mean += maxerr_arr[r]; + relnorm_mean += relnorm_arr[r]; + } + cos_mean /= nrows; + mse_mean /= nrows; + maxerr_mean /= nrows; + relnorm_mean /= nrows; + + qsort(cosine_arr, (size_t)nrows, sizeof(float), cmp_float); + qsort(mse_arr, (size_t)nrows, sizeof(float), cmp_float); + qsort(maxerr_arr, (size_t)nrows, sizeof(float), cmp_float); + qsort(relnorm_arr, (size_t)nrows, sizeof(float), cmp_float); + + double corr_sum = 0.0; + double score_abs_sum = 0.0; + double score_sq_sum = 0.0; + float score_max_diff = 0.0f; + int top1_agree = 0; + int topk_overlap_sum = 0; + double vonly_cos_sum = 0.0; + double vonly_rel_l2_sum = 0.0; + double full_cos_sum = 0.0; + double full_rel_l2_sum = 0.0; + const int topk = nrows < TOPK_MAX ? nrows : TOPK_MAX; + + float query[DIM]; + float out_orig[DIM]; + float out_vonly[DIM]; + float out_full[DIM]; + int top_orig[TOPK_MAX]; + int top_recon[TOPK_MAX]; + + pcg_state_ = seed ^ 0x9E3779B97F4A7C15ULL; + for (int q = 0; q < nqueries; q++) { + fill_query(query); + dot_scores(data_orig, nrows, query, score_orig); + dot_scores(data_recon, nrows, query, score_recon); + + corr_sum += pearson_corr(score_orig, score_recon, nrows); + for (int r = 0; r < nrows; r++) { + const float d = fabsf(score_orig[r] - score_recon[r]); + score_abs_sum += d; + score_sq_sum += (double)d * d; + if (d > score_max_diff) score_max_diff = d; + } + + topk_indices(score_orig, nrows, topk, top_orig); + topk_indices(score_recon, nrows, topk, top_recon); + if (top_orig[0] == top_recon[0]) top1_agree++; + topk_overlap_sum += topk_overlap(top_orig, top_recon, topk); + + softmax_output(score_orig, data_orig, nrows, out_orig); + softmax_output(score_orig, data_recon, nrows, out_vonly); + softmax_output(score_recon, data_recon, nrows, out_full); + vonly_cos_sum += vec_cosine(out_orig, out_vonly, DIM); + vonly_rel_l2_sum += rel_l2_diff(out_orig, out_vonly, DIM); + full_cos_sum += vec_cosine(out_orig, out_full, DIM); + full_rel_l2_sum += rel_l2_diff(out_orig, out_full, DIM); + } + + const double score_count = (double)nrows * (double)nqueries; + const double score_mae = score_abs_sum / score_count; + const double score_rmse = sqrt(score_sq_sum / score_count); + + printf("\n=== Planar3 Quality Evaluation ===\n"); + printf("Source: %s | Mode: %s | Rows: %d | Dim: %d | Queries: %d | Seed: %llu\n", + input_file ? input_file : "synthetic", + input_file ? "input" : mode_str, + nrows, DIM, nqueries, (unsigned long long)seed); + if (!input_file && (!strcmp(mode_str, "ds4_like") || !strcmp(mode_str, "ds4_realistic"))) { + printf("Note: ds4_like is synthetic. Use --input with dumped compressed-KV rows for real evidence.\n"); + } + + printf("\nRow roundtrip cosine:\n"); + printf(" min=%.4f mean=%.4f median=%.4f p99=%.4f max=%.4f\n", + cosine_arr[0], (float)cos_mean, + percentile_sorted(cosine_arr, nrows, 0.5f), + percentile_sorted(cosine_arr, nrows, 0.99f), + cosine_arr[nrows - 1]); + + printf("\nRow roundtrip MSE per element:\n"); + printf(" min=%.3e mean=%.3e median=%.3e p99=%.3e max=%.3e\n", + mse_arr[0], (float)mse_mean, + percentile_sorted(mse_arr, nrows, 0.5f), + percentile_sorted(mse_arr, nrows, 0.99f), + mse_arr[nrows - 1]); + + printf("\nMax element error:\n"); + printf(" min=%.4f mean=%.4f median=%.4f p99=%.4f max=%.4f\n", + maxerr_arr[0], (float)maxerr_mean, + percentile_sorted(maxerr_arr, nrows, 0.5f), + percentile_sorted(maxerr_arr, nrows, 0.99f), + maxerr_arr[nrows - 1]); + + printf("\nRelative norm error:\n"); + printf(" min=%.3e mean=%.3e median=%.3e p99=%.3e max=%.3e\n", + relnorm_arr[0], (float)relnorm_mean, + percentile_sorted(relnorm_arr, nrows, 0.5f), + percentile_sorted(relnorm_arr, nrows, 0.99f), + relnorm_arr[nrows - 1]); + + printf("\nAttention score drift (%d random queries):\n", nqueries); + printf(" corr_mean=%.4f mae=%.4f rmse=%.4f max_diff=%.4f\n", + (float)(corr_sum / nqueries), + (float)score_mae, + (float)score_rmse, + score_max_diff); + printf(" top1_agree=%d/%d top%d_overlap=%.2f/%d\n", + top1_agree, nqueries, topk, + (float)topk_overlap_sum / (float)nqueries, topk); + + printf("\nSoftmax output drift:\n"); + printf(" V-only: cos_mean=%.4f rel_l2_mean=%.4f\n", + (float)(vonly_cos_sum / nqueries), + (float)(vonly_rel_l2_sum / nqueries)); + printf(" K+V path: cos_mean=%.4f rel_l2_mean=%.4f\n", + (float)(full_cos_sum / nqueries), + (float)(full_rel_l2_sum / nqueries)); + + const size_t compressed_bytes = sizeof(ds4_row_planar3); + const size_t fp16_bytes = (size_t)DIM * sizeof(uint16_t); + printf("\nCompression: %zu bytes/row (%.2fx vs FP16 %zu bytes), total=%zu bytes\n", + compressed_bytes, + (double)fp16_bytes / (double)compressed_bytes, + fp16_bytes, + bytes_written); + + free(data_orig); + free(data_recon); + free(compressed); + free(cosine_arr); + free(mse_arr); + free(maxerr_arr); + free(relnorm_arr); + free(score_orig); + free(score_recon); + return 0; +}