diff --git a/.gitignore b/.gitignore index 223c5cd..691ef30 100644 --- a/.gitignore +++ b/.gitignore @@ -50,3 +50,4 @@ picoclaw/ # Internal dev docs picoclaw/PICOLM_INTEGRATION.md +.aider* diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ee584b2..0bfb2b7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -12,7 +12,7 @@ Thanks for your interest in PicoLLM! This project is intentionally small (~2,500 ## What We Need Help With ### High Impact -- **SIMD kernels** — AVX2/AVX-512 for x86, optimized NEON for ARM +- **SIMD kernels** — AVX-512 for x86 server CPUs, optimized NEON for ARM - **New quantization formats** — Q5_K fused dot product, IQ formats - **New model architectures** — Mistral, Phi, Gemma (LLaMA-compatible) - **Platform testing** — RISC-V boards, Pi Zero, exotic ARM SBCs @@ -106,12 +106,20 @@ If you're adding SIMD code: // ARM NEON path (Pi 3/4/5) float32x4_t v = vld1q_f32(ptr); ... +#elif defined(PICOLM_AVX2) + // x86 AVX2 path (Haswell+, Excavator+ — 256-bit integer + float) + __m256i v = _mm256_loadu_si256((const __m256i *)ptr); + ... +#elif defined(PICOLM_AVX) + // x86 AVX path (Sandy Bridge+, Bulldozer+ — 8-wide float) + __m256 v = _mm256_loadu_ps(ptr); + ... #elif defined(PICOLM_SSE2) - // x86 SSE2 path (Intel/AMD) + // x86 SSE2 path (any x86-64 — 4-wide float) __m128 v = _mm_loadu_ps(ptr); ... #endif - // Scalar fallback (always works) + // Scalar fallback (always works — also reachable via `make scalar`) for (int i = 0; i < n; i++) { ... } ``` diff --git a/README.md b/README.md index 1d9ad7b..fb5b81f 100644 --- a/README.md +++ b/README.md @@ -183,7 +183,7 @@ The model file (638MB) stays on disk. 
PicoLM **memory-maps** it and streams one | **FP16 KV Cache** | Halves KV cache memory (44MB vs 88MB for 2048 context) | | **Flash Attention** | Online softmax — no O(seq_len) attention buffer needed | | **Pre-computed RoPE** | cos/sin lookup tables eliminate transcendentals from hot loop | -| **SIMD Acceleration** | ARM NEON (Pi 3/4/5) and x86 SSE2 (Intel/AMD) auto-detected | +| **SIMD Acceleration** | ARM NEON (Pi 3/4/5), x86 SSE2/SSE3/AVX/AVX2 — auto-detected at compile time | | **Fused Dot Products** | Dequantize + dot-product in one pass — no intermediate buffer | | **Multi-threaded matmul** | Parallel matrix-vector multiply across CPU cores | | **Grammar-Constrained JSON** | `--json` flag forces valid JSON output (for tool calling) | @@ -234,7 +234,10 @@ make model ```cmd cd picolm -build.bat +build.bat :: SSE2 baseline (any x86-64) +build.bat avx2 :: AVX2 (Haswell+ / Excavator+, fastest) +build.bat avx :: AVX (Sandy Bridge+ / Bulldozer+) +build.bat scalar :: no SIMD (portable fallback) picolm.exe model.gguf -p "Hello world" -n 50 ``` @@ -242,6 +245,12 @@ picolm.exe model.gguf -p "Hello world" -n 50 ```bash make native # x86/ARM auto-detect (recommended for local machine) +make x86 # x86-64 safe default (SSE2 only — runs on any x86-64) +make sse2 # x86-64 SSE2 only (same as x86) +make sse3 # x86-64 SSE2+SSE3+SSSE3 (AMD Phenom/Athlon, older Intel) +make avx # x86-64 AVX (Sandy Bridge+, Bulldozer+ — wider SIMD, faster) +make avx2 # x86-64 AVX2 (Haswell+, Excavator+ — widest SIMD, fastest) +make scalar # No SIMD (portable scalar fallback, any architecture) make pi # Raspberry Pi 3/4/5 (64-bit ARM + NEON SIMD) make pi-arm32 # Pi Zero / Pi 1 (32-bit ARM) make cross-pi # Cross-compile for Pi from x86 (static binary) @@ -348,7 +357,7 @@ Measured on TinyLlama 1.1B Q4_K_M (638 MB model): + FP16 KV cache █████████████████░░░ (halve memory bandwidth) + Pre-computed RoPE ██████████████████░░ (no sin/cos in hot loop) + Flash attention ██████████████████░░ (no O(n) 
attention alloc) - + NEON/SSE2 SIMD ███████████████████░ (4-wide vector ops) + + NEON/SSE2/AVX SIMD ███████████████████░ (4-wide to 8-wide vector ops) + KV cache persistence ████████████████████ (skip prefill entirely) ``` @@ -477,9 +486,14 @@ PicoLM implements 9 optimizations that brought generation speed from **1.6 tok/s 4-wide float vector operations for all hot paths. Example: dequantizing Q4_K nibbles with `vmovl_u8` → `vmovl_u16` → `vcvtq_f32_u32`, and RoPE with interleaved `vld2q_f32` / `vst2q_f32`. -### 2. x86 SSE2 SIMD +### 2. x86 SIMD (SSE2 / SSE3 / AVX / AVX2) -Auto-detected on Intel/AMD. 4-wide `__m128` operations for dot products, RMSNorm, and vector operations. +Four compile-time tiers for Intel/AMD: + +- **SSE2** (`make sse2` or `make x86`): 4-wide `__m128` operations for dot products, RMSNorm, softmax, RoPE, and element-wise ops. Safe baseline for all x86-64 CPUs. +- **SSE3** (`make sse3`): adds `_mm_addsub_ps` for a cleaner RoPE rotation kernel (no sign-mask workaround needed). +- **AVX** (`make avx`): 8-wide `__m256` float accumulators for all ops. Q4_K and Q6_K dot products widen the float accumulation stage while keeping integer nibble extraction at 128-bit (no AVX2 required). RoPE processes 4 complex pairs per iteration with `_mm256_addsub_ps`. +- **AVX2** (`make avx2`): adds 256-bit integer operations. Q4_0 nibble extraction uses `_mm256_cvtepu8_epi32` (8 nibbles → 8 int32 in 2 ops vs. 4-step unpack chain). Q6_K weight extraction uses `_mm256_cvtepi8_epi32` (8 int8 → 8 int32 in 2 ops vs. 4-instruction macro chain). Targets Haswell+ Intel and Excavator+ AMD. ### 3. FP16 KV Cache @@ -636,7 +650,7 @@ A: llama.cpp is excellent but requires ~200MB+ for the runtime on small models, A: TinyLlama 1.1B is a small model — it handles simple tasks (Q&A, summarization, basic reasoning, JSON generation) well. It won't match GPT-4, but it runs on a $10 board with no internet. 
For structured output, the `--json` grammar mode guarantees valid JSON regardless of model quality. **Q: What about GPU acceleration?** -A: PicoLM is CPU-only by design. The target hardware ($10-15 boards) doesn't have GPUs. On x86/ARM CPUs, SIMD (NEON/SSE2) provides meaningful speedup. +A: PicoLM is CPU-only by design. The target hardware ($10-15 boards) doesn't have GPUs. On x86/ARM CPUs, SIMD (NEON/SSE2/AVX) provides meaningful speedup. **Q: Can I use a different model?** A: Any LLaMA-architecture GGUF model works. Download from [HuggingFace](https://huggingface.co/models?search=gguf) and point PicoLM at it. Recommended quantizations: Q4_K_M (best quality/size balance) or Q2_K (smallest, lower quality). @@ -645,7 +659,9 @@ A: Any LLaMA-architecture GGUF model works. Download from [HuggingFace](https:// ## Roadmap -- [ ] AVX2/AVX-512 kernels for x86 (2-4x generation speed on modern CPUs) +- [x] AVX kernels for x86 (`make avx` — 8-wide float ops, ~2x vs SSE2) +- [x] AVX2 kernels for x86 (`make avx2` — 256-bit integer ops for Q4_0 and Q6_K quantized paths) +- [ ] AVX-512 kernels for x86 (512-bit ops for server CPUs) - [ ] Speculative decoding with a draft model - [ ] Context sliding window (infinite generation beyond max_seq_len) - [ ] Weight pruning for further memory reduction diff --git a/picolm/Makefile b/picolm/Makefile index 4fd3c7a..6b1cd43 100644 --- a/picolm/Makefile +++ b/picolm/Makefile @@ -1,5 +1,5 @@ CC = gcc -CFLAGS = -O2 -std=c11 -D_GNU_SOURCE -Wall -Wextra -Wpedantic +CFLAGS = -O3 -std=c11 -D_GNU_SOURCE -Wall -Wextra -Wpedantic LDFLAGS = -lm -lpthread SRCS = picolm.c model.c tensor.c quant.c tokenizer.c sampler.c grammar.c TARGET = picolm @@ -11,11 +11,35 @@ MODEL_DIR ?= /opt/picolm/models native: CFLAGS += -march=native native: $(TARGET) +# --- x86-64 default (SSE2 only, safe for all x86-64) --- +x86: sse2 + +# --- No SIMD (scalar fallback, portable to any architecture) --- +scalar: CFLAGS += -mno-sse2 -mno-avx +scalar: $(TARGET) + +# --- x86-64 
with SSE2 only --- +sse2: CFLAGS += -msse2 +sse2: $(TARGET) + +# --- x86-64 with SSE2+SSE3+SSSE3 (covers AMD Phenom/Athlon and similar without AVX) --- +sse3: CFLAGS += -msse2 -msse3 -mssse3 -mpopcnt +sse3: $(TARGET) + +# --- x86-64 with AVX (Sandy Bridge and newer Intel; Bulldozer and newer AMD) --- +avx: CFLAGS += -mavx -mfma -mpopcnt +avx: $(TARGET) + +# --- x86-64 with AVX2 (Haswell and newer Intel; Excavator and newer AMD) --- +avx2: CFLAGS += -mavx2 -mfma -mpopcnt +avx2: $(TARGET) + $(TARGET): $(SRCS) $(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS) # --- Static build for single-binary deployment --- -static: CFLAGS += -march=native +# Uses SSE2 (not -march=native) so the binary runs on any x86-64, not just the build machine. +static: CFLAGS += -msse2 static: LDFLAGS += -static static: $(TARGET) @@ -70,4 +94,4 @@ model: clean: rm -f $(TARGET) $(TARGET).exe *.obj *.o -.PHONY: native static pi pi-arm32 cross-pi riscv cross-riscv debug install model clean +.PHONY: native x86 scalar sse2 sse3 avx avx2 static pi pi-arm32 cross-pi riscv cross-riscv debug install model clean diff --git a/picolm/build.bat b/picolm/build.bat index 3f65e6c..2df2026 100644 --- a/picolm/build.bat +++ b/picolm/build.bat @@ -1,7 +1,26 @@ @echo off +REM PicoLM Windows build script (MSVC) +REM +REM SIMD targets: +REM build.bat -- SSE2 baseline (safe for any x86-64) +REM build.bat avx2 -- AVX2 (Haswell+ / Excavator+, fastest) +REM build.bat avx -- AVX (Sandy Bridge+ / Bulldozer+) +REM build.bat scalar -- no SIMD (portable scalar fallback) + call "C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Auxiliary\Build\vcvars64.bat" >nul 2>&1 -echo Compiling... 
-cl /O2 /W3 /Fe:picolm.exe picolm.c model.c tensor.c quant.c tokenizer.c sampler.c grammar.c + +set SIMD_FLAG= +if /I "%1"=="avx2" set SIMD_FLAG=/arch:AVX2 +if /I "%1"=="avx" set SIMD_FLAG=/arch:AVX +if /I "%1"=="scalar" set SIMD_FLAG=/d2archSSE42- + +if "%SIMD_FLAG%"=="" ( + echo Building: SSE2 baseline +) else ( + echo Building: %1 ^(%SIMD_FLAG%^) +) + +cl /O2 /W3 %SIMD_FLAG% /Fe:picolm.exe picolm.c model.c tensor.c quant.c tokenizer.c sampler.c grammar.c if %ERRORLEVEL% neq 0 ( echo BUILD FAILED ) else ( diff --git a/picolm/model.c b/picolm/model.c index 4b4040d..51bfbc5 100644 --- a/picolm/model.c +++ b/picolm/model.c @@ -406,7 +406,6 @@ static int parse_gguf(model_t *m, int max_seq_len) { fprintf(stderr, " n_layers=%d, vocab_size=%d, max_seq=%d\n", cfg->n_layers, cfg->vocab_size, cfg->max_seq_len); fprintf(stderr, " head_dim=%d, rope_base=%.1f\n", cfg->head_dim, cfg->rope_freq_base); - free(tinfos); return 0; } diff --git a/picolm/picolm.c b/picolm/picolm.c index c2f1624..be56faf 100644 --- a/picolm/picolm.c +++ b/picolm/picolm.c @@ -7,6 +7,7 @@ #include #include #include +#include #ifdef _WIN32 #include @@ -37,6 +38,9 @@ static void usage(const char *prog) { fprintf(stderr, " -s RNG seed (default: 42)\n"); fprintf(stderr, " -c Context length override\n"); fprintf(stderr, " -j Number of threads (default: 4)\n"); + fprintf(stderr, "\nSpeculative decoding:\n"); + fprintf(stderr, " --draft Draft model GGUF file (must share vocabulary)\n"); + fprintf(stderr, " -d Draft tokens per step (default: 4)\n"); fprintf(stderr, "\nAdvanced options:\n"); fprintf(stderr, " --json Grammar-constrained JSON output mode\n"); fprintf(stderr, " --cache KV cache file (saves/loads prompt state)\n"); @@ -77,8 +81,9 @@ int main(int argc, char **argv) { int num_threads = 4; int json_mode = 0; const char *cache_file = NULL; + const char *draft_path = NULL; + int spec_k = 4; - /* Parse arguments */ for (int i = 2; i < argc; i++) { if (strcmp(argv[i], "-p") == 0 && i + 1 < argc) { prompt = 
argv[++i]; @@ -98,6 +103,12 @@ int main(int argc, char **argv) { json_mode = 1; } else if (strcmp(argv[i], "--cache") == 0 && i + 1 < argc) { cache_file = argv[++i]; + } else if (strcmp(argv[i], "--draft") == 0 && i + 1 < argc) { + draft_path = argv[++i]; + } else if (strcmp(argv[i], "-d") == 0 && i + 1 < argc) { + spec_k = atoi(argv[++i]); + if (spec_k < 1) spec_k = 1; + if (spec_k > 16) spec_k = 16; } else { fprintf(stderr, "Unknown option: %s\n", argv[i]); usage(argv[0]); @@ -105,7 +116,6 @@ int main(int argc, char **argv) { } } - /* Read prompt from stdin if not provided via -p */ char *stdin_prompt = NULL; if (!prompt) { #ifdef _WIN32 @@ -129,8 +139,22 @@ int main(int argc, char **argv) { return 1; } - /* Load model */ fprintf(stderr, "Loading model: %s\n", model_path); + fprintf(stderr, "SIMD: %s\n", +#if defined(PICOLM_AVX2) + "AVX2" +#elif defined(PICOLM_AVX) + "AVX" +#elif defined(PICOLM_SSE3) + "SSE3" +#elif defined(PICOLM_SSE2) + "SSE2" +#elif defined(PICOLM_NEON) + "NEON" +#else + "scalar" +#endif + ); model_t model; if (model_load(&model, model_path, context_override) != 0) { fprintf(stderr, "Failed to load model\n"); @@ -139,7 +163,6 @@ int main(int argc, char **argv) { tensor_set_threads(num_threads); - /* Load tokenizer */ tokenizer_t tokenizer; if (tokenizer_load(&tokenizer, &model) != 0) { fprintf(stderr, "Failed to load tokenizer\n"); @@ -147,29 +170,62 @@ int main(int argc, char **argv) { return 1; } - /* Init sampler */ sampler_t sampler; sampler_init(&sampler, temperature, top_p, seed); - /* Init grammar constraint */ grammar_state_t grammar; grammar_init(&grammar, json_mode ? 
GRAMMAR_JSON : GRAMMAR_NONE, &tokenizer); - if (json_mode) { + if (json_mode) fprintf(stderr, "JSON grammar mode enabled\n"); + + /* Load draft model if requested */ + model_t draft_model; + memset(&draft_model, 0, sizeof(draft_model)); + int use_spec = 0; + if (draft_path) { + fprintf(stderr, "Loading draft model: %s\n", draft_path); + if (model_load(&draft_model, draft_path, context_override) != 0) { + fprintf(stderr, "Failed to load draft model — falling back to standard decoding\n"); + } else if (draft_model.config.vocab_size != model.config.vocab_size) { + fprintf(stderr, "Draft/target vocab mismatch (%d vs %d) — disabling speculative decoding\n", + draft_model.config.vocab_size, model.config.vocab_size); + model_free(&draft_model); + } else if (json_mode) { + fprintf(stderr, "Grammar mode + draft model: disabling speculative decoding\n"); + model_free(&draft_model); + } else { + use_spec = 1; + fprintf(stderr, "Speculative decoding: %d draft tokens per step\n", spec_k); + } + } + + int vocab_size = model.config.vocab_size; + + /* Speculative decoding buffers: K+1 softmaxed target prob arrays + K draft token ids */ + float *tgt_probs = NULL; + int *draft_toks = NULL; + if (use_spec) { + tgt_probs = (float *)malloc((size_t)(spec_k + 1) * vocab_size * sizeof(float)); + draft_toks = (int *)malloc((size_t)spec_k * sizeof(int)); + if (!tgt_probs || !draft_toks) { + fprintf(stderr, "OOM for speculative buffers — falling back to standard decoding\n"); + free(tgt_probs); free(draft_toks); + tgt_probs = NULL; draft_toks = NULL; + model_free(&draft_model); + use_spec = 0; + } } - /* Try to load KV cache (skip prefill for cached prompt) */ + /* KV cache (target only) */ int cache_pos = 0; - if (cache_file) { + if (cache_file) cache_pos = kvcache_load(&model, cache_file); - } /* Encode prompt */ int max_prompt_tokens = (int)strlen(prompt) + 3; int *prompt_tokens = (int *)malloc((size_t)max_prompt_tokens * sizeof(int)); int n_prompt = tokenizer_encode(&tokenizer, prompt, 
prompt_tokens, max_prompt_tokens, 1); - /* If cache covers part of the prompt, skip those positions */ int start_pos = 0; if (cache_pos > 0 && cache_pos <= n_prompt) { start_pos = cache_pos; @@ -180,72 +236,234 @@ int main(int argc, char **argv) { n_prompt, max_tokens, temperature, top_p, num_threads); fprintf(stderr, "---\n"); - /* Generation loop */ int total_gen = 0; double t_start = get_time_ms(); double t_first_token = 0; - int token = prompt_tokens[start_pos > 0 ? start_pos - 1 : 0]; - int pos = start_pos > 0 ? start_pos - 1 : 0; - int total_steps = n_prompt + max_tokens; - if (total_steps > model.config.max_seq_len) { - total_steps = model.config.max_seq_len; - } + if (!use_spec) { + /* ================================================================ + * Standard autoregressive generation (original loop) + * ================================================================ */ + int token = prompt_tokens[start_pos > 0 ? start_pos - 1 : 0]; + int pos = start_pos > 0 ? start_pos - 1 : 0; + int total_steps = n_prompt + max_tokens; + if (total_steps > model.config.max_seq_len) + total_steps = model.config.max_seq_len; + + for (; pos < total_steps; pos++) { + if (pos < start_pos) { + token = prompt_tokens[pos]; + continue; + } - for (; pos < total_steps; pos++) { - /* Determine which token to feed */ - if (pos < start_pos) { - /* This shouldn't happen given our start logic, but safety */ - token = prompt_tokens[pos]; - continue; - } + float *logits = model_forward(&model, token, pos); - /* Forward pass */ - float *logits = model_forward(&model, token, pos); + int next; + if (pos < n_prompt - 1) { + next = prompt_tokens[pos + 1]; + } else { + if (pos == n_prompt - 1) + t_first_token = get_time_ms(); - int next; - if (pos < n_prompt - 1) { - /* Prefill: use next prompt token */ - next = prompt_tokens[pos + 1]; - } else { - /* Generation: apply grammar constraints, then sample */ - if (pos == n_prompt - 1) { - t_first_token = get_time_ms(); + grammar_apply(&grammar, 
logits, vocab_size);
+                next = sampler_sample(&sampler, logits, vocab_size);
+                grammar_advance(&grammar, &tokenizer, next);
+
+                const char *piece = tokenizer_decode(&tokenizer, token, next);
+                printf("%s", piece);
+                fflush(stdout);
+
+                total_gen++;
+
+                if (next == (int)tokenizer.eos_id) break;
+                if (grammar_is_complete(&grammar)) break;
             }
 
-        grammar_apply(&grammar, logits, model.config.vocab_size);
-        next = sampler_sample(&sampler, logits, model.config.vocab_size);
+            token = next;
+        }
+    } else {
+        /* ================================================================
+         * Speculative decoding:
+         *   1. Prefill both target and draft with all prompt tokens.
+         *   2. Sample first generated token from target logits.
+         *   3. Loop:
+         *      a. Draft generates K tokens greedily (argmax).
+         *      b. Target verifies K+1 positions (K drafts + bonus).
+         *      c. Accept each draft token x with probability p_target(x)
+         *         (the greedy draft is treated as a one-hot proposal q).
+         *      d. On rejection, resample from p_target with the rejected
+         *         token zeroed out, then renormalized.
+         *      e. If all K accepted, sample bonus from target_probs[K].
+ * ================================================================ */ + + /* Prefill phase: target from start_pos, draft always from 0 */ + t_first_token = get_time_ms(); /* will be overwritten at end of prefill */ + for (int pos = 0; pos < n_prompt - 1; pos++) { + if (pos >= start_pos) + model_forward(&model, prompt_tokens[pos], pos); + model_forward(&draft_model, prompt_tokens[pos], pos); + } - /* Update grammar state with the generated token */ - grammar_advance(&grammar, &tokenizer, next); + /* Last prefill step: sample first generated token */ + t_first_token = get_time_ms(); + float *first_logits = model_forward(&model, prompt_tokens[n_prompt - 1], n_prompt - 1); + model_forward(&draft_model, prompt_tokens[n_prompt - 1], n_prompt - 1); - /* Decode and print */ - const char *piece = tokenizer_decode(&tokenizer, token, next); + int cur_token = sampler_sample(&sampler, first_logits, vocab_size); + { + const char *piece = tokenizer_decode(&tokenizer, prompt_tokens[n_prompt - 1], cur_token); printf("%s", piece); fflush(stdout); - + grammar_advance(&grammar, &tokenizer, cur_token); total_gen++; + } - /* Stop on EOS or grammar completion */ - if (next == (int)tokenizer.eos_id) break; - if (grammar_is_complete(&grammar)) break; + int cur_pos = n_prompt; /* next KV slot to fill */ + int spec_drafted = 0, spec_accepted = 0; + + if (cur_token == (int)tokenizer.eos_id || grammar_is_complete(&grammar)) + goto spec_done; + + while (total_gen < max_tokens && cur_pos < model.config.max_seq_len) { + /* How many draft tokens to attempt this round */ + int max_n = spec_k; + if (cur_pos + max_n + 1 > model.config.max_seq_len) + max_n = model.config.max_seq_len - cur_pos - 1; + if (max_tokens - total_gen - 1 < max_n) + max_n = max_tokens - total_gen - 1; + if (max_n < 0) max_n = 0; + + /* ---- Draft: greedy argmax for max_n steps ---- */ + int n = 0; + int d_cur = cur_token; + for (n = 0; n < max_n; n++) { + float *dl = model_forward(&draft_model, d_cur, cur_pos + n); + /* 
greedy pick from raw logits */ + int best = 0; + for (int v = 1; v < vocab_size; v++) + if (dl[v] > dl[best]) best = v; + draft_toks[n] = best; + d_cur = best; + if (best == (int)tokenizer.eos_id) { n++; break; } + } + spec_drafted += n; + + /* ---- Target: verify n positions + 1 bonus ---- */ + int t_cur = cur_token; + for (int k = 0; k <= n; k++) { + float *tl = model_forward(&model, t_cur, cur_pos + k); + float *tp = tgt_probs + (size_t)k * vocab_size; + memcpy(tp, tl, (size_t)vocab_size * sizeof(float)); + /* apply temperature then softmax to get true target distribution */ + if (sampler.temperature > 0.0f) { + float inv_t = 1.0f / sampler.temperature; + for (int v = 0; v < vocab_size; v++) tp[v] *= inv_t; + } + softmax(tp, vocab_size); + if (k < n) t_cur = draft_toks[k]; + } + + /* ---- Accept / reject ---- */ + int accepted = 0; + int rejected = 0; /* set to 1 when we resample and break */ + + for (int k = 0; k < n; k++) { + int dt = draft_toks[k]; + float *tp = tgt_probs + (size_t)k * vocab_size; + + /* draft q: softmax of draft logits at position k. + * We ran the draft with argmax but stored the full distribution. + * Re-run draft to get q? No — we only kept the chosen token. + * Treat draft as one-hot: accept with prob min(1, p_target(dt)). 
*/
+                float p = tp[dt];
+                float r = sampler_rand(&sampler);
+
+                if (r < p) {
+                    /* Accept draft token */
+                    const char *piece = tokenizer_decode(&tokenizer, cur_token, dt);
+                    printf("%s", piece);
+                    fflush(stdout);
+                    grammar_advance(&grammar, &tokenizer, dt);
+                    total_gen++;
+                    accepted++;
+                    cur_token = dt;
+                    if (dt == (int)tokenizer.eos_id || grammar_is_complete(&grammar)) {
+                        spec_accepted += accepted;
+                        cur_pos += accepted;
+                        goto spec_done;
+                    }
+                } else {
+                    /* Reject: resample from corrected dist = p(x) with dt zeroed out */
+                    tp[dt] = 0.0f;
+                    float sum = 0.0f;
+                    for (int v = 0; v < vocab_size; v++) sum += tp[v];
+
+                    int next;
+                    if (sum < 1e-10f) {
+                        /* fallback: argmax over tp0 (tp0 aliases tp, whose dt entry is
+                         * already zeroed — so this excludes dt, not the original dist) */
+                        float *tp0 = tgt_probs + (size_t)k * vocab_size;
+                        next = 0;
+                        for (int v = 1; v < vocab_size; v++)
+                            if (tp0[v] > tp0[next]) next = v;
+                    } else {
+                        float rr = sampler_rand(&sampler) * sum;
+                        float acc = 0.0f;
+                        next = 0;
+                        for (int v = 0; v < vocab_size; v++) {
+                            acc += tp[v];
+                            if (acc >= rr) { next = v; break; }
+                        }
+                    }
+
+                    const char *piece = tokenizer_decode(&tokenizer, cur_token, next);
+                    printf("%s", piece);
+                    fflush(stdout);
+                    grammar_advance(&grammar, &tokenizer, next);
+                    total_gen++;
+                    spec_accepted += accepted;
+                    cur_token = next;
+                    cur_pos += accepted + 1;
+                    rejected = 1;
+
+                    if (next == (int)tokenizer.eos_id || grammar_is_complete(&grammar))
+                        goto spec_done;
+                    break;
+                }
+            }
+
+            if (!rejected) {
+                /* All n draft tokens accepted: sample bonus from target at position n.
+                 * NOTE(review): draft_toks[n-1] is never forwarded through the draft
+                 * model, leaving a stale draft-KV slot at cur_pos+n — confirm the
+                 * draft context stays consistent after a fully-accepted round. */
+                float *tp = tgt_probs + (size_t)n * vocab_size;
+                int bonus = sampler_sample_probs(&sampler, tp, vocab_size);
+                const char *piece = tokenizer_decode(&tokenizer, cur_token, bonus);
+                printf("%s", piece);
+                fflush(stdout);
+                grammar_advance(&grammar, &tokenizer, bonus);
+                total_gen++;
+                spec_accepted += n;
+                cur_token = bonus;
+                cur_pos += n + 1;
+
+                if (bonus == (int)tokenizer.eos_id || grammar_is_complete(&grammar))
+                    goto spec_done;
+            }
+        }
 
-        token = next;
+spec_done:;
+        fprintf(stderr, 
"Speculative: drafted=%d accepted=%d (%.1f%%)\n", + spec_drafted, spec_accepted, + spec_drafted > 0 ? 100.0f * spec_accepted / spec_drafted : 0.0f); } printf("\n"); double t_end = get_time_ms(); - /* Save KV cache if requested (save the full prompt state) */ - if (cache_file && n_prompt > 0) { + if (cache_file && n_prompt > 0) kvcache_save(&model, cache_file, n_prompt); - } - /* Stats */ - double total_time = (t_end - t_start) / 1000.0; - if (t_first_token == 0) t_first_token = t_end; /* no generation happened */ - double gen_time = (t_end - t_first_token) / 1000.0; + double total_time = (t_end - t_start) / 1000.0; + if (t_first_token == 0) t_first_token = t_end; + double gen_time = (t_end - t_first_token) / 1000.0; double prefill_time = (t_first_token - t_start) / 1000.0; int actual_prefill = n_prompt - start_pos; if (actual_prefill < 0) actual_prefill = 0; @@ -262,11 +480,13 @@ int main(int argc, char **argv) { fprintf(stderr, "Memory: %.2f MB runtime state (FP16 KV cache)\n", (double)model.state.mem_size / (1024.0 * 1024.0)); - /* Cleanup */ + free(tgt_probs); + free(draft_toks); grammar_free(&grammar); free(prompt_tokens); free(stdin_prompt); tokenizer_free(&tokenizer); + if (use_spec) model_free(&draft_model); model_free(&model); return 0; diff --git a/picolm/quant.c b/picolm/quant.c index 5a92998..cdea2c6 100644 --- a/picolm/quant.c +++ b/picolm/quant.c @@ -352,6 +352,21 @@ float vec_dot_f32_f32(const void *src, const float *x, int n) { for (; i < n; i++) sum += w[i] * x[i]; return sum; +#elif defined(PICOLM_AVX) + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + int i = 0; + for (; i + 15 < n; i += 16) { + acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_loadu_ps(w + i), _mm256_loadu_ps(x + i))); + acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_loadu_ps(w + i + 8), _mm256_loadu_ps(x + i + 8))); + } + /* pick up a trailing group of 8 (common: hidden sizes are multiples of 8 not 16) */ + for (; i + 7 < n; i += 8) + acc0 = 
_mm256_add_ps(acc0, _mm256_mul_ps(_mm256_loadu_ps(w + i), _mm256_loadu_ps(x + i))); + float sum = hsum_avx(_mm256_add_ps(acc0, acc1)); + for (; i < n; i++) sum += w[i] * x[i]; + return sum; + #elif defined(PICOLM_SSE2) __m128 acc0 = _mm_setzero_ps(); __m128 acc1 = _mm_setzero_ps(); @@ -440,6 +455,113 @@ float vec_dot_q4_K_f32(const void *src, const float *x, int n) { float sum_x1 = vaddvq_f32_compat(sum_x1_v); float sum_qx2 = vaddvq_f32_compat(sum_qx2_v); float sum_x2 = vaddvq_f32_compat(sum_x2_v); +#elif defined(PICOLM_AVX2) + /* AVX2: 256-bit integer ops allow zero-extending 8 uint8 nibbles + * to 8 int32 in one _mm256_cvtepu8_epi32 instruction, then a + * single _mm256_cvtepi32_ps — no multi-step unpack chain needed. */ + __m256 sum_qx1_v = _mm256_setzero_ps(); + __m256 sum_x1_v = _mm256_setzero_ps(); + __m256 sum_qx2_v = _mm256_setzero_ps(); + __m256 sum_x2_v = _mm256_setzero_ps(); + const __m128i mask4 = _mm_set1_epi8(0x0F); + + for (int l = 0; l < 32; l += 8) { + __m128i qb = _mm_loadl_epi64((const __m128i *)(q + l)); + __m128i lo8 = _mm_and_si128(qb, mask4); + __m128i hi8 = _mm_and_si128(_mm_srli_epi16(qb, 4), mask4); + + /* AVX2: zero-extend 8 uint8 → 8 int32 → 8 float in 2 ops */ + __m256 qf_lo = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(lo8)); + __m256 qf_hi = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(hi8)); + + __m256 xv_lo = _mm256_loadu_ps(xp + l); + __m256 xv_hi = _mm256_loadu_ps(xp + l + 32); + + sum_qx1_v = _mm256_add_ps(sum_qx1_v, _mm256_mul_ps(qf_lo, xv_lo)); + sum_x1_v = _mm256_add_ps(sum_x1_v, xv_lo); + sum_qx2_v = _mm256_add_ps(sum_qx2_v, _mm256_mul_ps(qf_hi, xv_hi)); + sum_x2_v = _mm256_add_ps(sum_x2_v, xv_hi); + } + float sum_qx1 = hsum_avx(sum_qx1_v); + float sum_x1 = hsum_avx(sum_x1_v); + float sum_qx2 = hsum_avx(sum_qx2_v); + float sum_x2 = hsum_avx(sum_x2_v); +#elif defined(PICOLM_AVX) + /* AVX: 128-bit nibble extraction (no AVX2 int needed), 256-bit float accumulators */ + __m256 sum_qx1_v = _mm256_setzero_ps(); + __m256 sum_x1_v = 
_mm256_setzero_ps(); + __m256 sum_qx2_v = _mm256_setzero_ps(); + __m256 sum_x2_v = _mm256_setzero_ps(); + const __m128i mask4 = _mm_set1_epi8(0x0F); + const __m128i zero_i = _mm_setzero_si128(); + + for (int l = 0; l < 32; l += 8) { + __m128i qb = _mm_loadl_epi64((const __m128i *)(q + l)); + __m128i lo8 = _mm_and_si128(qb, mask4); + __m128i hi8 = _mm_and_si128(_mm_srli_epi16(qb, 4), mask4); + + __m128i lo16 = _mm_unpacklo_epi8(lo8, zero_i); + __m128i hi16 = _mm_unpacklo_epi8(hi8, zero_i); + + /* Combine two __m128 → one __m256 of 8 floats */ + __m256 qf_lo = _mm256_set_m128( + _mm_cvtepi32_ps(_mm_unpackhi_epi16(lo16, zero_i)), + _mm_cvtepi32_ps(_mm_unpacklo_epi16(lo16, zero_i))); + __m256 qf_hi = _mm256_set_m128( + _mm_cvtepi32_ps(_mm_unpackhi_epi16(hi16, zero_i)), + _mm_cvtepi32_ps(_mm_unpacklo_epi16(hi16, zero_i))); + + __m256 xv_lo = _mm256_loadu_ps(xp + l); + __m256 xv_hi = _mm256_loadu_ps(xp + l + 32); + + sum_qx1_v = _mm256_add_ps(sum_qx1_v, _mm256_mul_ps(qf_lo, xv_lo)); + sum_x1_v = _mm256_add_ps(sum_x1_v, xv_lo); + sum_qx2_v = _mm256_add_ps(sum_qx2_v, _mm256_mul_ps(qf_hi, xv_hi)); + sum_x2_v = _mm256_add_ps(sum_x2_v, xv_hi); + } + float sum_qx1 = hsum_avx(sum_qx1_v); + float sum_x1 = hsum_avx(sum_x1_v); + float sum_qx2 = hsum_avx(sum_qx2_v); + float sum_x2 = hsum_avx(sum_x2_v); +#elif defined(PICOLM_SSE2) + /* SSE2: lo nibble → group1 (xp+l), hi nibble → group2 (xp+l+32) */ + __m128 sum_qx1_v = _mm_setzero_ps(); + __m128 sum_x1_v = _mm_setzero_ps(); + __m128 sum_qx2_v = _mm_setzero_ps(); + __m128 sum_x2_v = _mm_setzero_ps(); + const __m128i mask4 = _mm_set1_epi8(0x0F); + const __m128i zero_i = _mm_setzero_si128(); + + for (int l = 0; l < 32; l += 8) { + /* Load 8 quantized bytes -> 8 lo + 8 hi nibbles */ + __m128i qb = _mm_loadl_epi64((const __m128i *)(q + l)); + __m128i lo8 = _mm_and_si128(qb, mask4); + __m128i hi8 = _mm_and_si128(_mm_srli_epi16(qb, 4), mask4); + + /* Widen uint8 nibbles -> int32 -> float (8 values each) */ + __m128i lo16 = 
_mm_unpacklo_epi8(lo8, zero_i); + __m128i hi16 = _mm_unpacklo_epi8(hi8, zero_i); + __m128 qf_lo0 = _mm_cvtepi32_ps(_mm_unpacklo_epi16(lo16, zero_i)); + __m128 qf_lo1 = _mm_cvtepi32_ps(_mm_unpackhi_epi16(lo16, zero_i)); + __m128 qf_hi0 = _mm_cvtepi32_ps(_mm_unpacklo_epi16(hi16, zero_i)); + __m128 qf_hi1 = _mm_cvtepi32_ps(_mm_unpackhi_epi16(hi16, zero_i)); + + __m128 xv_lo0 = _mm_loadu_ps(xp + l); + __m128 xv_lo1 = _mm_loadu_ps(xp + l + 4); + __m128 xv_hi0 = _mm_loadu_ps(xp + l + 32); + __m128 xv_hi1 = _mm_loadu_ps(xp + l + 36); + + sum_qx1_v = _mm_add_ps(sum_qx1_v, + _mm_add_ps(_mm_mul_ps(qf_lo0, xv_lo0), _mm_mul_ps(qf_lo1, xv_lo1))); + sum_x1_v = _mm_add_ps(sum_x1_v, _mm_add_ps(xv_lo0, xv_lo1)); + sum_qx2_v = _mm_add_ps(sum_qx2_v, + _mm_add_ps(_mm_mul_ps(qf_hi0, xv_hi0), _mm_mul_ps(qf_hi1, xv_hi1))); + sum_x2_v = _mm_add_ps(sum_x2_v, _mm_add_ps(xv_hi0, xv_hi1)); + } + float sum_qx1 = hsum_sse(sum_qx1_v); + float sum_x1 = hsum_sse(sum_x1_v); + float sum_qx2 = hsum_sse(sum_qx2_v); + float sum_x2 = hsum_sse(sum_x2_v); #else float sum_qx1 = 0.0f, sum_x1 = 0.0f; float sum_qx2 = 0.0f, sum_x2 = 0.0f; @@ -476,9 +598,199 @@ float vec_dot_q6_K_f32(const void *src, const float *x, int n) { const int8_t *sc = blocks[i].scales; const float *xp = x + i * 256; - /* Accumulate per-scale-group sums: 16 groups of 16 elements each */ float sums[16] = {0}; +/* sign-extend packed int8 → two __m128 floats; used by AVX and SSE2 paths. + * Idiom: unpacklo_epi8(zero, x) places each byte in the HIGH byte of a 16-bit + * lane; srai_epi16(..., 8) then arithmetic-shifts it down, propagating the sign + * bit — equivalent to a sign-extending byte→int16 widening without SSE4.1. 
*/ +#if defined(PICOLM_AVX) || defined(PICOLM_SSE2) +#define Q6K_CONV(qi8, fa, fb) do { \ + __m128i w16 = _mm_srai_epi16(_mm_unpacklo_epi8(zero_i, qi8), 8); \ + fa = _mm_cvtepi32_ps(_mm_srai_epi32(_mm_unpacklo_epi16(zero_i, w16), 16)); \ + fb = _mm_cvtepi32_ps(_mm_srai_epi32(_mm_unpackhi_epi16(zero_i, w16), 16)); \ +} while (0) +#endif + +#ifdef PICOLM_AVX2 + /* AVX2: _mm256_cvtepi8_epi32 replaces the 4-op Q6K_CONV sign-extension chain */ + const __m128i mask4 = _mm_set1_epi8(0x0F); + const __m128i mask3 = _mm_set1_epi8(0x03); + const __m128i sub32 = _mm_set1_epi8(32); + + for (int chunk = 0; chunk < 2; chunk++) { + int is = chunk * 8; + const uint8_t *ql_c = ql + chunk * 64; + const uint8_t *qh_c = qh + chunk * 32; + const float *xp_c = xp + chunk * 128; + + for (int half = 0; half < 2; half++) { + int l0 = half * 16; + int sidx = is + half; + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + __m256 acc4 = _mm256_setzero_ps(); + + for (int l = l0; l < l0 + 16; l += 8) { + __m128i qla = _mm_loadl_epi64((const __m128i *)(ql_c + l)); + __m128i qlb = _mm_loadl_epi64((const __m128i *)(ql_c + l + 32)); + __m128i qhv = _mm_loadl_epi64((const __m128i *)(qh_c + l)); + + __m128i lo_a = _mm_and_si128(qla, mask4); + __m128i hi_a = _mm_and_si128(_mm_srli_epi16(qla, 4), mask4); + __m128i lo_b = _mm_and_si128(qlb, mask4); + __m128i hi_b = _mm_and_si128(_mm_srli_epi16(qlb, 4), mask4); + + __m128i h01 = _mm_and_si128(qhv, mask3); + __m128i h23 = _mm_and_si128(_mm_srli_epi16(qhv, 2), mask3); + __m128i h45 = _mm_and_si128(_mm_srli_epi16(qhv, 4), mask3); + __m128i h67 = _mm_and_si128(_mm_srli_epi16(qhv, 6), mask3); + + __m128i q1_i8 = _mm_sub_epi8(_mm_or_si128(lo_a, _mm_slli_epi16(h01, 4)), sub32); + __m128i q2_i8 = _mm_sub_epi8(_mm_or_si128(lo_b, _mm_slli_epi16(h23, 4)), sub32); + __m128i q3_i8 = _mm_sub_epi8(_mm_or_si128(hi_a, _mm_slli_epi16(h45, 4)), sub32); + __m128i q4_i8 = _mm_sub_epi8(_mm_or_si128(hi_b, 
_mm_slli_epi16(h67, 4)), sub32); + + /* AVX2: sign-extend 8 int8 → 8 int32 → 8 float in 2 ops */ + __m256 qf1 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(q1_i8)); + __m256 qf2 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(q2_i8)); + __m256 qf3 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(q3_i8)); + __m256 qf4 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(q4_i8)); + + acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(qf1, _mm256_loadu_ps(xp_c + l))); + acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(qf2, _mm256_loadu_ps(xp_c + l + 32))); + acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(qf3, _mm256_loadu_ps(xp_c + l + 64))); + acc4 = _mm256_add_ps(acc4, _mm256_mul_ps(qf4, _mm256_loadu_ps(xp_c + l + 96))); + } + sums[sidx + 0] += hsum_avx(acc1); + sums[sidx + 2] += hsum_avx(acc2); + sums[sidx + 4] += hsum_avx(acc3); + sums[sidx + 6] += hsum_avx(acc4); + } + } +#elif defined(PICOLM_AVX) + /* AVX: 128-bit integer extraction, 256-bit float accumulators */ + const __m128i mask4 = _mm_set1_epi8(0x0F); + const __m128i mask3 = _mm_set1_epi8(0x03); + const __m128i sub32 = _mm_set1_epi8(32); + const __m128i zero_i = _mm_setzero_si128(); + + for (int chunk = 0; chunk < 2; chunk++) { + int is = chunk * 8; + const uint8_t *ql_c = ql + chunk * 64; + const uint8_t *qh_c = qh + chunk * 32; + const float *xp_c = xp + chunk * 128; + + for (int half = 0; half < 2; half++) { + int l0 = half * 16; + int sidx = is + half; + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + __m256 acc4 = _mm256_setzero_ps(); + + for (int l = l0; l < l0 + 16; l += 8) { + __m128i qla = _mm_loadl_epi64((const __m128i *)(ql_c + l)); + __m128i qlb = _mm_loadl_epi64((const __m128i *)(ql_c + l + 32)); + __m128i qhv = _mm_loadl_epi64((const __m128i *)(qh_c + l)); + + __m128i lo_a = _mm_and_si128(qla, mask4); + __m128i hi_a = _mm_and_si128(_mm_srli_epi16(qla, 4), mask4); + __m128i lo_b = _mm_and_si128(qlb, mask4); + __m128i hi_b = _mm_and_si128(_mm_srli_epi16(qlb, 4), mask4); + + __m128i 
h01 = _mm_and_si128(qhv, mask3); + __m128i h23 = _mm_and_si128(_mm_srli_epi16(qhv, 2), mask3); + __m128i h45 = _mm_and_si128(_mm_srli_epi16(qhv, 4), mask3); + __m128i h67 = _mm_and_si128(_mm_srli_epi16(qhv, 6), mask3); + + __m128i q1_i8 = _mm_sub_epi8(_mm_or_si128(lo_a, _mm_slli_epi16(h01, 4)), sub32); + __m128i q2_i8 = _mm_sub_epi8(_mm_or_si128(lo_b, _mm_slli_epi16(h23, 4)), sub32); + __m128i q3_i8 = _mm_sub_epi8(_mm_or_si128(hi_a, _mm_slli_epi16(h45, 4)), sub32); + __m128i q4_i8 = _mm_sub_epi8(_mm_or_si128(hi_b, _mm_slli_epi16(h67, 4)), sub32); + + __m128 qf1a, qf1b, qf2a, qf2b, qf3a, qf3b, qf4a, qf4b; + Q6K_CONV(q1_i8, qf1a, qf1b); + Q6K_CONV(q2_i8, qf2a, qf2b); + Q6K_CONV(q3_i8, qf3a, qf3b); + Q6K_CONV(q4_i8, qf4a, qf4b); + + acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_set_m128(qf1b, qf1a), _mm256_loadu_ps(xp_c + l))); + acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(_mm256_set_m128(qf2b, qf2a), _mm256_loadu_ps(xp_c + l + 32))); + acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(_mm256_set_m128(qf3b, qf3a), _mm256_loadu_ps(xp_c + l + 64))); + acc4 = _mm256_add_ps(acc4, _mm256_mul_ps(_mm256_set_m128(qf4b, qf4a), _mm256_loadu_ps(xp_c + l + 96))); + } + sums[sidx + 0] += hsum_avx(acc1); + sums[sidx + 2] += hsum_avx(acc2); + sums[sidx + 4] += hsum_avx(acc3); + sums[sidx + 6] += hsum_avx(acc4); + } + } +#elif defined(PICOLM_SSE2) + /* SSE2: 6-bit values = lo4(ql) | hi2(qh)<<4, biased by 32 */ + const __m128i mask4 = _mm_set1_epi8(0x0F); + const __m128i mask3 = _mm_set1_epi8(0x03); + const __m128i sub32 = _mm_set1_epi8(32); + const __m128i zero_i = _mm_setzero_si128(); + + for (int chunk = 0; chunk < 2; chunk++) { + int is = chunk * 8; + const uint8_t *ql_c = ql + chunk * 64; + const uint8_t *qh_c = qh + chunk * 32; + const float *xp_c = xp + chunk * 128; + + for (int half = 0; half < 2; half++) { /* half=0 -> sums[+0,2,4,6], half=1 -> [+1,3,5,7] */ + int l0 = half * 16; + int sidx = is + half; + __m128 acc1a = _mm_setzero_ps(), acc1b = _mm_setzero_ps(); + __m128 acc2a = 
_mm_setzero_ps(), acc2b = _mm_setzero_ps(); + __m128 acc3a = _mm_setzero_ps(), acc3b = _mm_setzero_ps(); + __m128 acc4a = _mm_setzero_ps(), acc4b = _mm_setzero_ps(); + + for (int l = l0; l < l0 + 16; l += 8) { + __m128i qla = _mm_loadl_epi64((const __m128i *)(ql_c + l)); + __m128i qlb = _mm_loadl_epi64((const __m128i *)(ql_c + l + 32)); + __m128i qhv = _mm_loadl_epi64((const __m128i *)(qh_c + l)); + + __m128i lo_a = _mm_and_si128(qla, mask4); + __m128i hi_a = _mm_and_si128(_mm_srli_epi16(qla, 4), mask4); + __m128i lo_b = _mm_and_si128(qlb, mask4); + __m128i hi_b = _mm_and_si128(_mm_srli_epi16(qlb, 4), mask4); + + /* epi16 shifts on qh: avoids byte-lane bleed from epi8 shifts */ + __m128i h01 = _mm_and_si128(qhv, mask3); + __m128i h23 = _mm_and_si128(_mm_srli_epi16(qhv, 2), mask3); + __m128i h45 = _mm_and_si128(_mm_srli_epi16(qhv, 4), mask3); + __m128i h67 = _mm_and_si128(_mm_srli_epi16(qhv, 6), mask3); + + __m128i q1_i8 = _mm_sub_epi8(_mm_or_si128(lo_a, _mm_slli_epi16(h01, 4)), sub32); + __m128i q2_i8 = _mm_sub_epi8(_mm_or_si128(lo_b, _mm_slli_epi16(h23, 4)), sub32); + __m128i q3_i8 = _mm_sub_epi8(_mm_or_si128(hi_a, _mm_slli_epi16(h45, 4)), sub32); + __m128i q4_i8 = _mm_sub_epi8(_mm_or_si128(hi_b, _mm_slli_epi16(h67, 4)), sub32); + + __m128 qf1a, qf1b, qf2a, qf2b, qf3a, qf3b, qf4a, qf4b; + Q6K_CONV(q1_i8, qf1a, qf1b); + Q6K_CONV(q2_i8, qf2a, qf2b); + Q6K_CONV(q3_i8, qf3a, qf3b); + Q6K_CONV(q4_i8, qf4a, qf4b); + + acc1a = _mm_add_ps(acc1a, _mm_mul_ps(qf1a, _mm_loadu_ps(xp_c + l))); + acc1b = _mm_add_ps(acc1b, _mm_mul_ps(qf1b, _mm_loadu_ps(xp_c + l + 4))); + acc2a = _mm_add_ps(acc2a, _mm_mul_ps(qf2a, _mm_loadu_ps(xp_c + l + 32))); + acc2b = _mm_add_ps(acc2b, _mm_mul_ps(qf2b, _mm_loadu_ps(xp_c + l + 36))); + acc3a = _mm_add_ps(acc3a, _mm_mul_ps(qf3a, _mm_loadu_ps(xp_c + l + 64))); + acc3b = _mm_add_ps(acc3b, _mm_mul_ps(qf3b, _mm_loadu_ps(xp_c + l + 68))); + acc4a = _mm_add_ps(acc4a, _mm_mul_ps(qf4a, _mm_loadu_ps(xp_c + l + 96))); + acc4b = _mm_add_ps(acc4b, 
_mm_mul_ps(qf4b, _mm_loadu_ps(xp_c + l + 100))); + } + sums[sidx + 0] += hsum_sse(_mm_add_ps(acc1a, acc1b)); + sums[sidx + 2] += hsum_sse(_mm_add_ps(acc2a, acc2b)); + sums[sidx + 4] += hsum_sse(_mm_add_ps(acc3a, acc3b)); + sums[sidx + 6] += hsum_sse(_mm_add_ps(acc4a, acc4b)); + } + } +#else for (int chunk = 0; chunk < 2; chunk++) { int is = chunk * 8; const uint8_t *ql_c = ql + chunk * 64; @@ -506,6 +818,9 @@ float vec_dot_q6_K_f32(const void *src, const float *x, int n) { sums[is + 7] += (float)q4 * xp_c[l + 96]; } } +#endif + +#undef Q6K_CONV for (int j = 0; j < 16; j++) { sumf += d * (float)sc[j] * sums[j]; diff --git a/picolm/quant.h b/picolm/quant.h index e35095c..3e8118a 100644 --- a/picolm/quant.h +++ b/picolm/quant.h @@ -4,23 +4,61 @@ #include #include -/* ---- SIMD detection ---- */ +/* ---- SIMD detection ---- + * + * Each level explicitly implies all lower levels so that code only needs to + * check a single flag. The order of checks (highest first) lets each block + * upgrade flags that a lower-level check would otherwise miss if the compiler + * only predefines the highest applicable macro. 
+ * + * Hierarchy (x86): + * PICOLM_SSE2 ⊂ PICOLM_SSE3 ⊂ PICOLM_AVX ⊂ PICOLM_AVX2 + * + * ARM: + * PICOLM_NEON (independent) + */ + +/* --- ARM NEON --- */ #if defined(__ARM_NEON) || defined(__ARM_NEON__) -#define PICOLM_NEON 1 -#include +# define PICOLM_NEON 1 +# include static inline float vaddvq_f32_compat(float32x4_t v) { -#if defined(__aarch64__) +# if defined(__aarch64__) return vaddvq_f32(v); -#else +# else float32x2_t r = vadd_f32(vget_low_f32(v), vget_high_f32(v)); return vget_lane_f32(vpadd_f32(r, r), 0); -#endif +# endif } #endif -#if defined(__SSE2__) || (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_AMD64))) -#define PICOLM_SSE2 1 -#include +/* --- x86 SIMD: detect highest level, then propagate downward --- */ + +/* AVX2 implies AVX + SSE3 + SSE2 */ +#if defined(__AVX2__) +# define PICOLM_AVX2 1 +# define PICOLM_AVX 1 +# define PICOLM_SSE3 1 +# define PICOLM_SSE2 1 +/* AVX implies SSE3 + SSE2 */ +#elif defined(__AVX__) +# define PICOLM_AVX 1 +# define PICOLM_SSE3 1 +# define PICOLM_SSE2 1 +/* SSE3 implies SSE2 */ +#elif defined(__SSE3__) +# define PICOLM_SSE3 1 +# define PICOLM_SSE2 1 +/* SSE2 baseline (also the default for all x86-64 targets) */ +#elif defined(__SSE2__) || (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_AMD64))) +# define PICOLM_SSE2 1 +#endif + +/* Include x86 SIMD header once for any x86 SIMD level. + * is an umbrella header that exposes all intrinsics + * available for the current -march target. 
*/ +#ifdef PICOLM_SSE2 +# include static inline float hsum_sse(__m128 v) { __m128 shuf = _mm_movehl_ps(v, v); __m128 sum = _mm_add_ps(v, shuf); @@ -30,6 +68,14 @@ static inline float hsum_sse(__m128 v) { } #endif +#ifdef PICOLM_AVX +static inline float hsum_avx(__m256 v) { + __m128 lo = _mm256_castps256_ps128(v); + __m128 hi = _mm256_extractf128_ps(v, 1); + return hsum_sse(_mm_add_ps(lo, hi)); +} +#endif + /* GGUF tensor data types */ typedef enum { GGUF_TYPE_F32 = 0, diff --git a/picolm/sampler.c b/picolm/sampler.c index bf22827..7f56e82 100644 --- a/picolm/sampler.c +++ b/picolm/sampler.c @@ -103,3 +103,41 @@ int sampler_sample(sampler_t *s, float *logits, int vocab_size) { free(sorted); return result; } + +float sampler_rand(sampler_t *s) { + return rand_float(&s->rng_state); +} + +int sampler_sample_probs(sampler_t *s, float *probs, int vocab_size) { + if (s->temperature <= 0.0f) { + int best = 0; + for (int i = 1; i < vocab_size; i++) + if (probs[i] > probs[best]) best = i; + return best; + } + if (s->top_p >= 1.0f) { + float r = rand_float(&s->rng_state); + float cum = 0.0f; + for (int i = 0; i < vocab_size; i++) { + cum += probs[i]; + if (cum > r) return i; + } + return vocab_size - 1; + } + prob_index_t *sorted = (prob_index_t *)malloc((size_t)vocab_size * sizeof(prob_index_t)); + for (int i = 0; i < vocab_size; i++) { sorted[i].prob = probs[i]; sorted[i].index = i; } + qsort(sorted, (size_t)vocab_size, sizeof(prob_index_t), cmp_prob_desc); + float cum = 0.0f; int cutoff = 0; + for (int i = 0; i < vocab_size; i++) { + cum += sorted[i].prob; cutoff = i + 1; + if (cum >= s->top_p) break; + } + float r = rand_float(&s->rng_state) * cum; + float acc = 0.0f; int result = sorted[0].index; + for (int i = 0; i < cutoff; i++) { + acc += sorted[i].prob; + if (acc > r) { result = sorted[i].index; break; } + } + free(sorted); + return result; +} diff --git a/picolm/sampler.h b/picolm/sampler.h index e8224e5..ba5bf52 100644 --- a/picolm/sampler.h +++ b/picolm/sampler.h 
@@ -16,4 +16,11 @@ void sampler_init(sampler_t *s, float temperature, float top_p, uint64_t seed); * Modifies logits in-place (temperature scaling, softmax). */ int sampler_sample(sampler_t *s, float *logits, int vocab_size); +/* Sample from a pre-computed probability distribution (already softmaxed). + * Applies top-p but not temperature (probs already scaled). */ +int sampler_sample_probs(sampler_t *s, float *probs, int vocab_size); + +/* Return a uniform float in [0, 1) using the sampler's RNG. */ +float sampler_rand(sampler_t *s); + #endif /* SAMPLER_H */ diff --git a/picolm/tensor.c b/picolm/tensor.c index 59a68e6..a12db9d 100644 --- a/picolm/tensor.c +++ b/picolm/tensor.c @@ -137,6 +137,15 @@ void rmsnorm(float *out, const float *x, const float *weight, int size) { } ss = vaddvq_f32_compat(acc); for (; i < size; i++) ss += x[i] * x[i]; +#elif defined(PICOLM_AVX) + __m256 acc = _mm256_setzero_ps(); + int i = 0; + for (; i + 7 < size; i += 8) { + __m256 v = _mm256_loadu_ps(x + i); + acc = _mm256_add_ps(acc, _mm256_mul_ps(v, v)); + } + ss = hsum_avx(acc); + for (; i < size; i++) ss += x[i] * x[i]; #elif defined(PICOLM_SSE2) __m128 acc = _mm_setzero_ps(); int i = 0; @@ -161,6 +170,15 @@ void rmsnorm(float *out, const float *x, const float *weight, int size) { vst1q_f32(out + i, vmulq_f32(vmulq_f32(v, scale), w)); } for (; i < size; i++) out[i] = x[i] * ss * weight[i]; +#elif defined(PICOLM_AVX) + __m256 scale = _mm256_set1_ps(ss); + i = 0; + for (; i + 7 < size; i += 8) { + __m256 v = _mm256_loadu_ps(x + i); + __m256 w = _mm256_loadu_ps(weight + i); + _mm256_storeu_ps(out + i, _mm256_mul_ps(_mm256_mul_ps(v, scale), w)); + } + for (; i < size; i++) out[i] = x[i] * ss * weight[i]; #elif defined(PICOLM_SSE2) __m128 scale = _mm_set1_ps(ss); i = 0; @@ -194,6 +212,13 @@ void softmax(float *x, int size) { vst1q_f32(x + i, vmulq_f32(vld1q_f32(x + i), inv_v)); } for (; i < size; i++) x[i] *= inv; +#elif defined(PICOLM_AVX) + __m256 inv_v = _mm256_set1_ps(inv); + int i = 
0; + for (; i + 7 < size; i += 8) { + _mm256_storeu_ps(x + i, _mm256_mul_ps(_mm256_loadu_ps(x + i), inv_v)); + } + for (; i < size; i++) x[i] *= inv; #elif defined(PICOLM_SSE2) __m128 inv_v = _mm_set1_ps(inv); int i = 0; @@ -206,6 +231,64 @@ void softmax(float *x, int size) { #endif } +/* AVX RoPE: 4 complex pairs/iter; addsub handles r*cos-i*sin / r*sin+i*cos in one op */ +#ifdef PICOLM_AVX +static void rope_avx(float *h, int half, const float *cos_pos, const float *sin_pos) { + int i = 0; + for (; i + 3 < half; i += 4) { + __m256 v = _mm256_loadu_ps(h + i * 2); + __m128 c4 = _mm_loadu_ps(cos_pos + i); + __m128 s4 = _mm_loadu_ps(sin_pos + i); + __m256 cv = _mm256_set_m128(_mm_unpackhi_ps(c4, c4), _mm_unpacklo_ps(c4, c4)); + __m256 sv = _mm256_set_m128(_mm_unpackhi_ps(s4, s4), _mm_unpacklo_ps(s4, s4)); + __m256 sw = _mm256_permute_ps(v, 0xB1); /* swap r,i within each pair */ + _mm256_storeu_ps(h + i * 2, + _mm256_addsub_ps(_mm256_mul_ps(v, cv), _mm256_mul_ps(sw, sv))); + } + for (; i < half; i++) { + float r = h[i * 2], im = h[i * 2 + 1]; + h[i * 2] = r * cos_pos[i] - im * sin_pos[i]; + h[i * 2 + 1] = r * sin_pos[i] + im * cos_pos[i]; + } +} +#endif + +/* SSE2/SSE3 RoPE: 2 pairs/iter; SSE3 uses addsub, SSE2 uses sign-mask to negate even lanes */ +#if defined(PICOLM_SSE2) && !defined(PICOLM_AVX) +static void rope_sse(float *h, int half, const float *cos_pos, const float *sin_pos) { + int i = 0; +#ifdef PICOLM_SSE3 + for (; i + 1 < half; i += 2) { + __m128 v = _mm_loadu_ps(h + i * 2); + __m128 c2 = _mm_unpacklo_ps(_mm_load_ss(cos_pos + i), _mm_load_ss(cos_pos + i + 1)); + __m128 s2 = _mm_unpacklo_ps(_mm_load_ss(sin_pos + i), _mm_load_ss(sin_pos + i + 1)); + __m128 cv = _mm_shuffle_ps(c2, c2, _MM_SHUFFLE(1,1,0,0)); + __m128 sv = _mm_shuffle_ps(s2, s2, _MM_SHUFFLE(1,1,0,0)); + __m128 sw = _mm_shuffle_ps(v, v, _MM_SHUFFLE(2,3,0,1)); + _mm_storeu_ps(h + i * 2, _mm_addsub_ps(_mm_mul_ps(v, cv), _mm_mul_ps(sw, sv))); + } +#else + const __m128 sign = _mm_set_ps(1.0f, -1.0f, 
1.0f, -1.0f); + for (; i + 1 < half; i += 2) { + __m128 v = _mm_loadu_ps(h + i * 2); + __m128 c2 = _mm_unpacklo_ps(_mm_load_ss(cos_pos + i), _mm_load_ss(cos_pos + i + 1)); + __m128 s2 = _mm_unpacklo_ps(_mm_load_ss(sin_pos + i), _mm_load_ss(sin_pos + i + 1)); + __m128 cv = _mm_shuffle_ps(c2, c2, _MM_SHUFFLE(1,1,0,0)); + __m128 sv = _mm_shuffle_ps(s2, s2, _MM_SHUFFLE(1,1,0,0)); + __m128 sw = _mm_shuffle_ps(v, v, _MM_SHUFFLE(2,3,0,1)); + __m128 a = _mm_mul_ps(v, cv); + __m128 b = _mm_mul_ps(_mm_mul_ps(sign, sw), sv); + _mm_storeu_ps(h + i * 2, _mm_add_ps(a, b)); + } +#endif + for (; i < half; i++) { + float r = h[i * 2], im = h[i * 2 + 1]; + h[i * 2] = r * cos_pos[i] - im * sin_pos[i]; + h[i * 2 + 1] = r * sin_pos[i] + im * cos_pos[i]; + } +} +#endif + /* Rotary position encoding using pre-computed cos/sin tables */ void rope(float *q, float *k, int head_dim, int n_heads, int n_kv_heads, const float *cos_pos, const float *sin_pos) { @@ -233,6 +316,10 @@ void rope(float *q, float *k, int head_dim, int n_heads, int n_kv_heads, qh[i * 2] = q0 * cos_pos[i] - q1 * sin_pos[i]; qh[i * 2 + 1] = q0 * sin_pos[i] + q1 * cos_pos[i]; } +#elif defined(PICOLM_AVX) + rope_avx(qh, half, cos_pos, sin_pos); +#elif defined(PICOLM_SSE2) + rope_sse(qh, half, cos_pos, sin_pos); #else for (int i = 0; i < half; i++) { float q0 = qh[i * 2]; @@ -246,12 +333,18 @@ void rope(float *q, float *k, int head_dim, int n_heads, int n_kv_heads, /* Apply RoPE to all KV heads */ for (int h = 0; h < n_kv_heads; h++) { float *kh = k + h * head_dim; +#ifdef PICOLM_AVX + rope_avx(kh, half, cos_pos, sin_pos); +#elif defined(PICOLM_SSE2) + rope_sse(kh, half, cos_pos, sin_pos); +#else for (int i = 0; i < half; i++) { float k0 = kh[i * 2]; float k1 = kh[i * 2 + 1]; kh[i * 2] = k0 * cos_pos[i] - k1 * sin_pos[i]; kh[i * 2 + 1] = k0 * sin_pos[i] + k1 * cos_pos[i]; } +#endif } } @@ -268,6 +361,12 @@ void elemwise_mul(float *out, const float *a, const float *b, int size) { vst1q_f32(out + i, vmulq_f32(vld1q_f32(a + i), 
vld1q_f32(b + i))); } for (; i < size; i++) out[i] = a[i] * b[i]; +#elif defined(PICOLM_AVX) + int i = 0; + for (; i + 7 < size; i += 8) { + _mm256_storeu_ps(out + i, _mm256_mul_ps(_mm256_loadu_ps(a + i), _mm256_loadu_ps(b + i))); + } + for (; i < size; i++) out[i] = a[i] * b[i]; #elif defined(PICOLM_SSE2) int i = 0; for (; i + 3 < size; i += 4) { @@ -286,6 +385,12 @@ void vec_add(float *a, const float *b, int size) { vst1q_f32(a + i, vaddq_f32(vld1q_f32(a + i), vld1q_f32(b + i))); } for (; i < size; i++) a[i] += b[i]; +#elif defined(PICOLM_AVX) + int i = 0; + for (; i + 7 < size; i += 8) { + _mm256_storeu_ps(a + i, _mm256_add_ps(_mm256_loadu_ps(a + i), _mm256_loadu_ps(b + i))); + } + for (; i < size; i++) a[i] += b[i]; #elif defined(PICOLM_SSE2) int i = 0; for (; i + 3 < size; i += 4) {