From 5845414199e2d564b7b1474ebc777b97e86e9ef8 Mon Sep 17 00:00:00 2001 From: Pasquale Minervini Date: Fri, 29 May 2026 18:43:53 +0100 Subject: [PATCH 1/9] Add llguidance structured outputs --- Makefile | 66 ++- README.md | 20 +- ds4.c | 239 +++++++++++ ds4.h | 7 + ds4_llguidance.c | 468 ++++++++++++++++++++++ ds4_llguidance.h | 37 ++ ds4_server.c | 623 ++++++++++++++++++++++++++++- tests/structured_outputs_stress.py | 424 ++++++++++++++++++++ 8 files changed, 1866 insertions(+), 18 deletions(-) create mode 100644 ds4_llguidance.c create mode 100644 ds4_llguidance.h create mode 100755 tests/structured_outputs_stress.py diff --git a/Makefile b/Makefile index 694faf955..9cad30654 100644 --- a/Makefile +++ b/Makefile @@ -13,6 +13,29 @@ OBJCFLAGS ?= -O3 -ffast-math $(DEBUG_FLAGS) $(NATIVE_CPU_FLAG) -Wall -Wextra -fo LDLIBS ?= -lm -pthread METAL_SRCS := $(wildcard metal/*.metal) +LLGUIDANCE ?= 0 +LLGUIDANCE_REPO ?= https://github.com/guidance-ai/llguidance +LLGUIDANCE_TAG ?= v1.7.5 +SERVER_EXTRA_OBJS := ds4_llguidance.o + +ifeq ($(LLGUIDANCE),1) +ifeq ($(strip $(LLGUIDANCE_DIR)),) +ifneq ($(wildcard ../../llguidance/parser/llguidance.h),) +LLGUIDANCE_DIR := ../../llguidance +else +LLGUIDANCE_DIR := .deps/llguidance +LLGUIDANCE_NEEDS_CLONE := 1 +endif +endif +LLGUIDANCE_LIB := $(LLGUIDANCE_DIR)/target/release/libllguidance.a +LLGUIDANCE_LDLIBS := $(LLGUIDANCE_LIB) +ifneq ($(UNAME_S),Darwin) +LLGUIDANCE_LDLIBS += -ldl +endif +CFLAGS += -DDS4_USE_LLGUIDANCE -I$(LLGUIDANCE_DIR)/parser +LDLIBS += $(LLGUIDANCE_LDLIBS) +DS4_LLGUIDANCE_DEPS := $(LLGUIDANCE_LIB) +endif ifeq ($(UNAME_S),Darwin) METAL_LDLIBS := $(LDLIBS) -framework Foundation -framework Metal @@ -31,6 +54,7 @@ CUDA_SPARK_FLAGS := -DDS4_CUDA_SPARK_HBM_CACHE=1 CORE_OBJS = ds4.o ds4_distributed.o ds4_cuda.o CPU_CORE_OBJS = ds4_cpu.o ds4_distributed.o CUDA_LDLIBS ?= -lm -Xcompiler -pthread -L$(CUDA_HOME)/targets/sbsa-linux/lib -L$(CUDA_HOME)/lib64 -lcudart -lcublas +CUDA_LDLIBS += $(LLGUIDANCE_LDLIBS) METAL_LDLIBS := $(LDLIBS) endif @@ -42,6 +66,7 @@ all: ds4 ds4-server ds4-bench ds4-eval ds4-agent help: @echo "DS4 build targets:" @echo " make Build Metal ./ds4, ./ds4-server, ./ds4-bench, ./ds4-eval, and ./ds4-agent" + @echo " make LLGUIDANCE=1 Build with structured-output constrained decoding" @echo " make cpu Build CPU-only ./ds4, ./ds4-server, ./ds4-bench, ./ds4-eval, and ./ds4-agent" @echo " make test Build and run tests" @echo " make clean Remove build outputs" @@ -49,8 +74,8 @@ help: ds4: ds4_cli.o linenoise.o $(CORE_OBJS) $(CC) $(CFLAGS) -o $@ ds4_cli.o linenoise.o $(CORE_OBJS) $(METAL_LDLIBS) -ds4-server: ds4_server.o ds4_kvstore.o rax.o $(CORE_OBJS) - $(CC) $(CFLAGS) -o $@ ds4_server.o ds4_kvstore.o rax.o $(CORE_OBJS) $(METAL_LDLIBS) +ds4-server: ds4_server.o ds4_kvstore.o rax.o $(SERVER_EXTRA_OBJS) $(CORE_OBJS) + $(CC) $(CFLAGS) -o $@ ds4_server.o ds4_kvstore.o rax.o $(SERVER_EXTRA_OBJS) $(CORE_OBJS) $(METAL_LDLIBS) ds4-bench: ds4_bench.o $(CORE_OBJS) $(CC) $(CFLAGS) -o $@ ds4_bench.o $(CORE_OBJS) $(METAL_LDLIBS) @@ -61,9 +86,9 @@ ds4-eval: ds4_eval.o $(CORE_OBJS) ds4-agent: ds4_agent.o ds4_web.o ds4_kvstore.o linenoise.o $(CORE_OBJS) $(CC) $(CFLAGS) -o $@ ds4_agent.o ds4_web.o ds4_kvstore.o linenoise.o $(CORE_OBJS) $(METAL_LDLIBS) -cpu: ds4_cli_cpu.o ds4_server_cpu.o ds4_bench_cpu.o ds4_eval_cpu.o ds4_agent_cpu.o ds4_web.o ds4_kvstore.o linenoise.o rax.o $(CPU_CORE_OBJS) +cpu: ds4_cli_cpu.o ds4_server_cpu.o ds4_bench_cpu.o ds4_eval_cpu.o ds4_agent_cpu.o ds4_web.o ds4_kvstore.o linenoise.o rax.o $(SERVER_EXTRA_OBJS) $(CPU_CORE_OBJS) $(CC) $(CFLAGS) -o ds4 ds4_cli_cpu.o linenoise.o $(CPU_CORE_OBJS) $(LDLIBS) - $(CC) $(CFLAGS) -o ds4-server ds4_server_cpu.o ds4_kvstore.o rax.o $(CPU_CORE_OBJS) $(LDLIBS) + $(CC) $(CFLAGS) -o ds4-server ds4_server_cpu.o ds4_kvstore.o rax.o $(SERVER_EXTRA_OBJS) $(CPU_CORE_OBJS) $(LDLIBS) $(CC) $(CFLAGS) -o ds4-bench ds4_bench_cpu.o $(CPU_CORE_OBJS) $(LDLIBS) $(CC) $(CFLAGS) -o ds4-eval ds4_eval_cpu.o $(CPU_CORE_OBJS) $(LDLIBS) $(CC) $(CFLAGS) -o ds4-agent ds4_agent_cpu.o ds4_web.o ds4_kvstore.o linenoise.o $(CPU_CORE_OBJS) $(LDLIBS) @@ -76,6 +101,7 @@ all: help help: @echo "DS4 build targets:" @echo " make cuda-spark Build CUDA for DGX Spark / GB10 with Spark HBM weight cache" + @echo " make LLGUIDANCE=1 ... Build with structured-output constrained decoding" @echo " make cuda-generic Build CUDA for a generic local CUDA GPU" @echo " make cuda CUDA_ARCH=sm_N Build CUDA with an explicit nvcc -arch value" @echo " make cpu Build CPU-only ./ds4, ./ds4-server, ./ds4-bench, ./ds4-eval, and ./ds4-agent" @@ -99,7 +125,7 @@ cuda: ds4: ds4_cli.o linenoise.o $(CORE_OBJS) $(NVCC) $(NVCCFLAGS) -o $@ $^ $(CUDA_LDLIBS) -ds4-server: ds4_server.o ds4_kvstore.o rax.o $(CORE_OBJS) +ds4-server: ds4_server.o ds4_kvstore.o rax.o $(SERVER_EXTRA_OBJS) $(CORE_OBJS) $(NVCC) $(NVCCFLAGS) -o $@ $^ $(CUDA_LDLIBS) ds4-bench: ds4_bench.o $(CORE_OBJS) @@ -111,9 +137,9 @@ ds4-eval: ds4_eval.o $(CORE_OBJS) ds4-agent: ds4_agent.o ds4_web.o ds4_kvstore.o linenoise.o $(CORE_OBJS) $(NVCC) $(NVCCFLAGS) -o $@ $^ $(CUDA_LDLIBS) -cpu: ds4_cli_cpu.o ds4_server_cpu.o ds4_bench_cpu.o ds4_eval_cpu.o ds4_agent_cpu.o ds4_web.o ds4_kvstore.o linenoise.o rax.o $(CPU_CORE_OBJS) +cpu: ds4_cli_cpu.o ds4_server_cpu.o ds4_bench_cpu.o ds4_eval_cpu.o ds4_agent_cpu.o ds4_web.o ds4_kvstore.o linenoise.o rax.o $(SERVER_EXTRA_OBJS) $(CPU_CORE_OBJS) $(CC) $(CFLAGS) -o ds4 ds4_cli_cpu.o linenoise.o $(CPU_CORE_OBJS) $(LDLIBS) - $(CC) $(CFLAGS) -o ds4-server ds4_server_cpu.o ds4_kvstore.o rax.o $(CPU_CORE_OBJS) $(LDLIBS) + $(CC) $(CFLAGS) -o ds4-server ds4_server_cpu.o ds4_kvstore.o rax.o $(SERVER_EXTRA_OBJS) $(CPU_CORE_OBJS) $(LDLIBS) $(CC) $(CFLAGS) -o ds4-bench ds4_bench_cpu.o $(CPU_CORE_OBJS) $(LDLIBS) $(CC) $(CFLAGS) -o ds4-eval ds4_eval_cpu.o $(CPU_CORE_OBJS) $(LDLIBS) $(CC) $(CFLAGS) -o ds4-agent ds4_agent_cpu.o ds4_web.o ds4_kvstore.o linenoise.o $(CPU_CORE_OBJS) $(LDLIBS) @@ -131,9 +157,12 @@ ds4_cli.o: ds4_cli.c ds4.h ds4_distributed.h linenoise.h ds4_distributed.o: ds4_distributed.c ds4_distributed.h ds4.h $(CC) $(CFLAGS) -c -o $@ ds4_distributed.c -ds4_server.o: ds4_server.c ds4.h ds4_distributed.h ds4_kvstore.h rax.h +ds4_server.o: ds4_server.c ds4.h ds4_distributed.h ds4_kvstore.h ds4_llguidance.h rax.h $(CC) $(CFLAGS) -c -o $@ ds4_server.c +ds4_llguidance.o: ds4_llguidance.c ds4_llguidance.h ds4.h $(DS4_LLGUIDANCE_DEPS) + $(CC) $(CFLAGS) -c -o $@ ds4_llguidance.c + ds4_bench.o: ds4_bench.c ds4.h $(CC) $(CFLAGS) -c -o $@ ds4_bench.c @@ -149,7 +178,7 @@ ds4_web.o: ds4_web.c ds4_web.h ds4_kvstore.o: ds4_kvstore.c ds4_kvstore.h ds4.h $(CC) $(CFLAGS) -c -o $@ ds4_kvstore.c -ds4_test.o: tests/ds4_test.c ds4_server.c ds4.h ds4_distributed.h ds4_kvstore.h rax.h +ds4_test.o: tests/ds4_test.c ds4_server.c ds4.h ds4_distributed.h ds4_kvstore.h ds4_llguidance.h rax.h $(CC) $(CFLAGS) -Wno-unused-function -c -o $@ tests/ds4_test.c tests/cuda_long_context_smoke.o: tests/cuda_long_context_smoke.c ds4_gpu.h @@ -167,7 +196,7 @@ ds4_cpu.o: ds4.c ds4.h ds4_distributed.h ds4_gpu.h ds4_cli_cpu.o: ds4_cli.c ds4.h ds4_distributed.h linenoise.h $(CC) $(CFLAGS) -DDS4_NO_GPU -c -o $@ ds4_cli.c -ds4_server_cpu.o: ds4_server.c ds4.h ds4_distributed.h ds4_kvstore.h rax.h +ds4_server_cpu.o: ds4_server.c ds4.h ds4_distributed.h ds4_kvstore.h ds4_llguidance.h rax.h $(CC) $(CFLAGS) -DDS4_NO_GPU -c -o $@ ds4_server.c ds4_bench_cpu.o: ds4_bench.c ds4.h @@ -188,11 +217,22 @@ ds4_cuda.o: ds4_cuda.cu ds4_gpu.h ds4_iq2_tables_cuda.inc tests/cuda_long_context_smoke: tests/cuda_long_context_smoke.o ds4_cuda.o $(NVCC) $(NVCCFLAGS) -o $@ $^ $(CUDA_LDLIBS) -ds4_test: ds4_test.o ds4_kvstore.o rax.o $(CORE_OBJS) +ifeq ($(LLGUIDANCE),1) +ifeq ($(LLGUIDANCE_NEEDS_CLONE),1) +$(LLGUIDANCE_DIR): + mkdir -p .deps + git clone --depth 1 --branch $(LLGUIDANCE_TAG) $(LLGUIDANCE_REPO) $(LLGUIDANCE_DIR) +endif + +$(LLGUIDANCE_LIB): | $(LLGUIDANCE_DIR) + cargo build --release --package llguidance --manifest-path $(LLGUIDANCE_DIR)/Cargo.toml +endif + +ds4_test: ds4_test.o ds4_kvstore.o rax.o $(SERVER_EXTRA_OBJS) $(CORE_OBJS) ifeq ($(UNAME_S),Darwin) - $(CC) $(CFLAGS) -o $@ ds4_test.o ds4_kvstore.o rax.o $(CORE_OBJS) $(METAL_LDLIBS) + $(CC) $(CFLAGS) -o $@ ds4_test.o ds4_kvstore.o rax.o $(SERVER_EXTRA_OBJS) $(CORE_OBJS) $(METAL_LDLIBS) else - $(NVCC) $(NVCCFLAGS) -o $@ ds4_test.o ds4_kvstore.o rax.o $(CORE_OBJS) $(CUDA_LDLIBS) + $(NVCC) $(NVCCFLAGS) -o $@ ds4_test.o ds4_kvstore.o rax.o $(SERVER_EXTRA_OBJS) $(CORE_OBJS) $(CUDA_LDLIBS) endif test: ds4_test ds4-eval diff --git a/README.md b/README.md index bbc0e76da..a4aace8b5 100644 --- a/README.md +++ b/README.md @@ -635,9 +635,22 @@ tool calls are mapped back to OpenAI tool calls. `/v1/responses` accepts OpenAI Responses-style `input`, `instructions`, `tools`, `tool_choice`, `max_output_tokens`, `temperature`, `top_p`, `stream`, -and `reasoning`. It is the preferred endpoint for Codex CLI. The server keeps -Responses continuations bound to live state when possible, and can fall back to -the same DSML rendering and KV prefix reuse used by chat completions. +`text.format`, and `reasoning`. It is the preferred endpoint for Codex CLI. +The server keeps Responses continuations bound to live state when possible, and +can fall back to the same DSML rendering and KV prefix reuse used by chat +completions. + +Structured outputs are available when the server is built with llguidance: + +```sh +make LLGUIDANCE=1 +``` + +With that build, `/v1/chat/completions` supports +`response_format.type=json_schema` and `response_format.type=json_object`; +`/v1/responses` supports the same modes through `text.format`. Structured +outputs use constrained decoding, disable thinking for that turn, and currently +cannot be combined with tools. `/v1/messages` is the Anthropic-compatible endpoint used by Claude Code style clients. It accepts `system`, `messages`, `tools`, `tool_choice`, `max_tokens`, @@ -1133,6 +1146,7 @@ extractor self-test run first: make test # ./ds4-eval --self-test-extractors && ./ds4_test --all ./ds4_test --logprob-vectors ./ds4_test --server +python3 tests/structured_outputs_stress.py --base-url http://127.0.0.1:8000/v1 --model ds4 --apis chat,responses ``` ## Debugging Notes diff --git a/ds4.c b/ds4.c index 0953864ae..2913779a0 100644 --- a/ds4.c +++ b/ds4.c @@ -16143,6 +16143,231 @@ static int sample_top_p_min_p( return ids[filtered - 1]; } +static bool sample_mask_allows(const uint32_t *mask, size_t words, uint32_t id) { + if (!mask) return true; + const size_t word = id / 32u; + if (word >= words) return false; + return (mask[word] & (UINT32_C(1) << (id & 31u))) != 0; +} + +static bool sample_filtered_allows( + const uint32_t *allow_mask, + size_t allow_words, + const uint32_t *deny_mask, + size_t deny_words, + uint32_t id) { + return sample_mask_allows(allow_mask, allow_words, id) && + !(deny_mask && sample_mask_allows(deny_mask, deny_words, id)); +} + +static int sample_argmax_filtered( + const float *logits, + uint32_t n_vocab, + const uint32_t *allow_mask, + size_t allow_words, + const uint32_t *deny_mask, + size_t deny_words) { + int best = -1; + float best_v = DS4_NEG_INF; + for (uint32_t i = 0; i < n_vocab; i++) { + if (!sample_filtered_allows(allow_mask, allow_words, deny_mask, deny_words, i)) { + continue; + } + const float v = logits[i]; + if (best < 0 || v > best_v) { + best_v = v; + best = (int)i; + } + } + return best; +} + +static int sample_full_vocab_filtered( + const float *logits, + uint32_t n_vocab, + float temperature, + float top_p, + float min_p, + const uint32_t *allow_mask, + size_t allow_words, + const uint32_t *deny_mask, + size_t deny_words, + uint64_t *rng) { + float max_logit = DS4_NEG_INF; + int best = -1; + uint32_t finite = 0; + for (uint32_t i = 0; i < n_vocab; i++) { + if (!sample_filtered_allows(allow_mask, allow_words, deny_mask, deny_words, i)) { + continue; + } + const float v = logits[i]; + if (!isfinite(v)) continue; + finite++; + if (best < 0 || v > max_logit) { + max_logit = v; + best = (int)i; + } + } + if (finite == 0) return sample_argmax_filtered(logits, n_vocab, allow_mask, + allow_words, deny_mask, + deny_words); + + if (top_p >= 1.0f) { + float sum = 0.0f; + const float min_rel = min_p > 0.0f ? min_p : 0.0f; + for (uint32_t i = 0; i < n_vocab; i++) { + if (!sample_filtered_allows(allow_mask, allow_words, deny_mask, deny_words, i)) { + continue; + } + const float v = logits[i]; + if (!isfinite(v)) continue; + const float p = expf((v - max_logit) / temperature); + if (p < min_rel) continue; + sum += p; + } + if (sum <= 0.0f || !isfinite(sum)) return best; + float r = sample_rng_f32(rng) * sum; + for (uint32_t i = 0; i < n_vocab; i++) { + if (!sample_filtered_allows(allow_mask, allow_words, deny_mask, deny_words, i)) { + continue; + } + const float v = logits[i]; + if (!isfinite(v)) continue; + const float p = expf((v - max_logit) / temperature); + if (p < min_rel) continue; + r -= p; + if (r <= 0.0f) return (int)i; + } + return best; + } + + sample_candidate *cand = xmalloc((size_t)finite * sizeof(cand[0])); + uint32_t n = 0; + float sum = 0.0f; + for (uint32_t i = 0; i < n_vocab; i++) { + if (!sample_filtered_allows(allow_mask, allow_words, deny_mask, deny_words, i)) { + continue; + } + const float v = logits[i]; + if (!isfinite(v)) continue; + const float p = expf((v - max_logit) / temperature); + cand[n++] = (sample_candidate){.id = (int)i, .logit = v, .prob = p}; + sum += p; + } + if (sum <= 0.0f || !isfinite(sum)) { + free(cand); + return best; + } + + qsort(cand, n, sizeof(cand[0]), sample_candidate_cmp_desc); + const float min_prob = (cand[0].prob / sum) * (min_p > 0.0f ? min_p : 0.0f); + float filtered_sum = 0.0f; + uint32_t filtered = 0; + for (uint32_t i = 0; i < n; i++) { + const float p = cand[i].prob / sum; + if (i > 0 && p < min_prob) break; + filtered_sum += cand[i].prob; + filtered++; + if (filtered_sum / sum >= top_p) break; + } + if (filtered == 0) { + free(cand); + return best; + } + + float r = sample_rng_f32(rng) * filtered_sum; + for (uint32_t i = 0; i < filtered; i++) { + r -= cand[i].prob; + if (r <= 0.0f) { + const int id = cand[i].id; + free(cand); + return id; + } + } + const int id = cand[filtered - 1].id; + free(cand); + return id; +} + +static int sample_top_p_min_p_filtered( + const float *logits, + uint32_t n_vocab, + float temperature, + int top_k, + float top_p, + float min_p, + const uint32_t *allow_mask, + size_t allow_words, + const uint32_t *deny_mask, + size_t deny_words, + uint64_t *rng) { + if (temperature <= 0.0f) { + return sample_argmax_filtered(logits, n_vocab, allow_mask, allow_words, + deny_mask, deny_words); + } + if (top_p <= 0.0f || top_p > 1.0f) top_p = 1.0f; + if (min_p < 0.0f) min_p = 0.0f; + if (top_k <= 0) { + return sample_full_vocab_filtered(logits, n_vocab, temperature, top_p, + min_p, allow_mask, allow_words, + deny_mask, deny_words, rng); + } + if (top_k > 1024) top_k = 1024; + if ((uint32_t)top_k > n_vocab) top_k = (int)n_vocab; + + int ids[1024]; + float vals[1024]; + int n = 0; + for (uint32_t i = 0; i < n_vocab; i++) { + if (!sample_filtered_allows(allow_mask, allow_words, deny_mask, deny_words, i)) { + continue; + } + float v = logits[i]; + if (!isfinite(v)) continue; + if (n == top_k && v <= vals[n - 1]) continue; + int j = n < top_k ? n++ : n - 1; + while (j > 0 && vals[j - 1] < v) { + vals[j] = vals[j - 1]; + ids[j] = ids[j - 1]; + j--; + } + vals[j] = v; + ids[j] = (int)i; + } + if (n == 0) { + return sample_argmax_filtered(logits, n_vocab, allow_mask, allow_words, + deny_mask, deny_words); + } + + float probs[1024]; + const float max_logit = vals[0]; + float sum = 0.0f; + for (int i = 0; i < n; i++) { + probs[i] = expf((vals[i] - max_logit) / temperature); + sum += probs[i]; + } + if (sum <= 0.0f || !isfinite(sum)) return ids[0]; + + const float min_prob = (probs[0] / sum) * min_p; + float filtered_sum = 0.0f; + int filtered = 0; + for (int i = 0; i < n; i++) { + float p = probs[i] / sum; + if (i > 0 && p < min_prob) break; + filtered_sum += probs[i]; + filtered++; + if (filtered_sum / sum >= top_p) break; + } + if (filtered <= 0) return ids[0]; + + float r = sample_rng_f32(rng) * filtered_sum; + for (int i = 0; i < filtered; i++) { + r -= probs[i]; + if (r <= 0.0f) return ids[i]; + } + return ids[filtered - 1]; +} + static void print_top_logits( FILE * fp, const char * label, @@ -19738,6 +19963,20 @@ int ds4_session_sample(ds4_session *s, float temperature, int top_k, float top_p return sample_top_p_min_p(s->logits, DS4_N_VOCAB, temperature, top_k, top_p, min_p, rng); } +int ds4_session_sample_masked(ds4_session *s, float temperature, int top_k, + float top_p, float min_p, + const uint32_t *allow_mask, + size_t allow_mask_words, + const uint32_t *deny_mask, + size_t deny_mask_words, + uint64_t *rng) { + if (!s || !s->logits || !allow_mask) return -1; + return sample_top_p_min_p_filtered(s->logits, DS4_N_VOCAB, temperature, + top_k, top_p, min_p, allow_mask, + allow_mask_words, deny_mask, + deny_mask_words, rng); +} + int ds4_session_top_logprobs(ds4_session *s, ds4_token_score *out, int k) { if (!s || !out || k <= 0) return 0; if (k > (int)DS4_N_VOCAB) k = (int)DS4_N_VOCAB; diff --git a/ds4.h b/ds4.h index 7b7233c36..4fed595ae 100644 --- a/ds4.h +++ b/ds4.h @@ -236,6 +236,13 @@ int ds4_session_argmax_excluding(ds4_session *s, int excluded_id); int ds4_sample_logits(const float *logits, int n_vocab, float temperature, int top_k, float top_p, float min_p, uint64_t *rng); int ds4_session_sample(ds4_session *s, float temperature, int top_k, float top_p, float min_p, uint64_t *rng); +int ds4_session_sample_masked(ds4_session *s, float temperature, int top_k, + float top_p, float min_p, + const uint32_t *allow_mask, + size_t allow_mask_words, + const uint32_t *deny_mask, + size_t deny_mask_words, + uint64_t *rng); int ds4_session_top_logprobs(ds4_session *s, ds4_token_score *out, int k); int ds4_session_token_logprob(ds4_session *s, int token, ds4_token_score *out); int ds4_session_copy_logits(ds4_session *s, float *out, int cap); diff --git a/ds4_llguidance.c b/ds4_llguidance.c new file mode 100644 index 000000000..1a39ea3e8 --- /dev/null +++ b/ds4_llguidance.c @@ -0,0 +1,468 @@ +#include "ds4_llguidance.h" + +#include +#include +#include +#include +#include + +#ifdef DS4_USE_LLGUIDANCE +#include +#include "llguidance.h" +#endif + +#ifndef UINT32_C +#include +#endif + +struct ds4_llguidance { +#ifdef DS4_USE_LLGUIDANCE + LlgTokenizer *tokenizer; + LlgMatcher *matcher; + const uint32_t *leading_ws_mask; + size_t leading_ws_words; + size_t mask_words; + int n_vocab; + int eos_token; + bool started; +#else + int unused; +#endif +}; + +bool ds4_llguidance_available(void) { +#ifdef DS4_USE_LLGUIDANCE + return true; +#else + return false; +#endif +} + +const char *ds4_llguidance_build_info(void) { +#ifdef DS4_USE_LLGUIDANCE + return "llguidance enabled"; +#else + return "llguidance disabled"; +#endif +} + +#ifdef DS4_USE_LLGUIDANCE + +typedef struct { + ds4_engine *engine; + LlgTokenizer *tokenizer; + uint32_t *leading_ws_mask; + size_t leading_ws_words; + int n_vocab; +} ds4_llg_cache; + +static pthread_mutex_t g_llg_cache_mu = PTHREAD_MUTEX_INITIALIZER; +static ds4_llg_cache g_llg_cache = {0}; + +static void set_err(char *err, size_t errlen, const char *fmt, ...) { + if (!err || errlen == 0) return; + va_list ap; + va_start(ap, fmt); + vsnprintf(err, errlen, fmt, ap); + va_end(ap); +} + +static bool json_ws_byte(unsigned char c) { + return c == ' ' || c == '\n' || c == '\r' || c == '\t'; +} + +static bool bytes_all_json_ws(const char *p, size_t len) { + if (!p || len == 0) return false; + for (size_t i = 0; i < len; i++) { + if (!json_ws_byte((unsigned char)p[i])) return false; + } + return true; +} + +static bool bytes_have_non_json_ws(const char *p, size_t len) { + if (!p) return false; + for (size_t i = 0; i < len; i++) { + if (!json_ws_byte((unsigned char)p[i])) return true; + } + return false; +} + +static bool token_text_is_special(const char *p, size_t len) { + static const char *specials[] = { + "<|begin▁of▁sentence|>", + "<|end▁of▁sentence|>", + "<|User|>", + "<|Assistant|>", + "", + "", + "|DSML|", + }; + for (size_t i = 0; i < sizeof(specials) / sizeof(specials[0]); i++) { + size_t n = strlen(specials[i]); + if (len == n && memcmp(p, specials[i], n) == 0) return true; + } + + const unsigned char bar[] = {0xef, 0xbd, 0x9c}; + for (size_t i = 0; i + sizeof(bar) <= len; i++) { + if (!memcmp(p + i, bar, sizeof(bar))) return true; + } + return false; +} + +static void bitset_set(uint32_t *mask, int token) { + mask[(uint32_t)token / 32u] |= UINT32_C(1) << ((uint32_t)token & 31u); +} + +static bool bitset_get(const uint32_t *mask, size_t words, uint32_t token) { + const size_t word = token / 32u; + if (!mask || word >= words) return false; + return (mask[word] & (UINT32_C(1) << (token & 31u))) != 0; +} + +static bool mask_has_non_denied_token(const uint32_t *allow, + size_t allow_words, + const uint32_t *deny, + size_t deny_words, + int n_vocab) { + if (!allow) return false; + for (int i = 0; i < n_vocab; i++) { + if (bitset_get(allow, allow_words, (uint32_t)i) && + !bitset_get(deny, deny_words, (uint32_t)i)) + { + return true; + } + } + return false; +} + +static size_t ds4_llg_tokenize_fn(const void *user_data, + const uint8_t *bytes, + size_t bytes_len, + uint32_t *output_tokens, + size_t output_tokens_len) { + ds4_engine *e = (ds4_engine *)user_data; + char *text = malloc(bytes_len + 1); + if (!text) return 0; + memcpy(text, bytes, bytes_len); + text[bytes_len] = '\0'; + + ds4_tokens toks = {0}; + ds4_tokenize_text(e, text, &toks); + free(text); + + const size_t n = toks.len < 0 ? 0 : (size_t)toks.len; + const size_t copy = n < output_tokens_len ? n : output_tokens_len; + for (size_t i = 0; i < copy; i++) output_tokens[i] = (uint32_t)toks.v[i]; + ds4_tokens_free(&toks); + return n; +} + +static LlgTokenizer *build_tokenizer(ds4_engine *e, + uint32_t **leading_ws_mask_out, + size_t *leading_ws_words_out, + int *n_vocab_out, + char *err, + size_t errlen) { + const int n_vocab = ds4_engine_vocab_size(e); + if (n_vocab <= 0) { + set_err(err, errlen, "llguidance tokenizer cannot use an empty vocabulary"); + return NULL; + } + + size_t total = 0; + uint32_t *token_lens = calloc((size_t)n_vocab, sizeof(token_lens[0])); + if (!token_lens) { + set_err(err, errlen, "out of memory"); + return NULL; + } + + const size_t mask_words = ((size_t)n_vocab + 31u) / 32u; + uint32_t *leading_ws = calloc(mask_words, sizeof(leading_ws[0])); + if (!leading_ws) { + free(token_lens); + set_err(err, errlen, "out of memory"); + return NULL; + } + + for (int i = 0; i < n_vocab; i++) { + size_t len = 0; + char *piece = ds4_token_text(e, i, &len); + const bool special = token_text_is_special(piece, len); + token_lens[i] = (uint32_t)(len + (special ? 1u : 0u)); + total += token_lens[i]; + if (!special && bytes_all_json_ws(piece, len)) bitset_set(leading_ws, i); + free(piece); + } + + uint8_t *token_bytes = malloc(total ? total : 1); + if (!token_bytes) { + free(leading_ws); + free(token_lens); + set_err(err, errlen, "out of memory"); + return NULL; + } + + size_t off = 0; + for (int i = 0; i < n_vocab; i++) { + size_t len = 0; + char *piece = ds4_token_text(e, i, &len); + if (token_text_is_special(piece, len)) token_bytes[off++] = 0xffu; + memcpy(token_bytes + off, piece, len); + off += len; + free(piece); + } + + LlgTokenizerInit init = {0}; + init.vocab_size = (uint32_t)n_vocab; + init.tok_eos = (uint32_t)ds4_token_eos(e); + init.token_lens = token_lens; + init.token_bytes = token_bytes; + init.tokenize_assumes_string = true; + init.tokenize_fn = ds4_llg_tokenize_fn; + init.use_approximate_greedy_tokenize_fn = false; + init.tokenize_user_data = e; + init.slices = NULL; + + char llg_err[1024] = {0}; + LlgTokenizer *tok = llg_new_tokenizer(&init, llg_err, sizeof(llg_err)); + free(token_bytes); + free(token_lens); + if (!tok) { + free(leading_ws); + set_err(err, errlen, "llguidance tokenizer error: %s", llg_err); + return NULL; + } + + *leading_ws_mask_out = leading_ws; + *leading_ws_words_out = mask_words; + *n_vocab_out = n_vocab; + return tok; +} + +static LlgTokenizer *cached_tokenizer_clone(ds4_engine *e, + const uint32_t **leading_ws_mask_out, + size_t *leading_ws_words_out, + int *n_vocab_out, + char *err, + size_t errlen) { + LlgTokenizer *clone = NULL; + pthread_mutex_lock(&g_llg_cache_mu); + if (g_llg_cache.engine != e || !g_llg_cache.tokenizer) { + if (g_llg_cache.tokenizer) llg_free_tokenizer(g_llg_cache.tokenizer); + free(g_llg_cache.leading_ws_mask); + memset(&g_llg_cache, 0, sizeof(g_llg_cache)); + + uint32_t *leading_ws = NULL; + size_t leading_ws_words = 0; + int n_vocab = 0; + LlgTokenizer *tok = build_tokenizer(e, &leading_ws, &leading_ws_words, + &n_vocab, err, errlen); + if (!tok) { + pthread_mutex_unlock(&g_llg_cache_mu); + return NULL; + } + g_llg_cache.engine = e; + g_llg_cache.tokenizer = tok; + g_llg_cache.leading_ws_mask = leading_ws; + g_llg_cache.leading_ws_words = leading_ws_words; + g_llg_cache.n_vocab = n_vocab; + } + + clone = llg_clone_tokenizer(g_llg_cache.tokenizer); + if (leading_ws_mask_out) *leading_ws_mask_out = g_llg_cache.leading_ws_mask; + if (leading_ws_words_out) *leading_ws_words_out = g_llg_cache.leading_ws_words; + if (n_vocab_out) *n_vocab_out = g_llg_cache.n_vocab; + pthread_mutex_unlock(&g_llg_cache_mu); + if (!clone) set_err(err, errlen, "llguidance tokenizer clone failed"); + return clone; +} + +ds4_llguidance *ds4_llguidance_create(ds4_engine *e, + const char *constraint_type, + const char *constraint_data, + char *err, + size_t errlen) { + if (!e || !constraint_type || !constraint_type[0]) { + set_err(err, errlen, "invalid structured output constraint"); + return NULL; + } + + const uint32_t *leading_ws_mask = NULL; + size_t leading_ws_words = 0; + int n_vocab = 0; + LlgTokenizer *tok = cached_tokenizer_clone(e, &leading_ws_mask, + &leading_ws_words, + &n_vocab, err, errlen); + if (!tok) return NULL; + + LlgConstraintInit init; + llg_constraint_init_set_defaults(&init, tok); + const char *log_level = getenv("LLGUIDANCE_LOG_LEVEL"); + if (!log_level || !log_level[0]) log_level = getenv("DS4_LLGUIDANCE_LOG_LEVEL"); + if (log_level && log_level[0]) init.log_stderr_level = (uint32_t)atoi(log_level); + + LlgMatcher *matcher = llg_new_matcher(&init, constraint_type, + constraint_data ? constraint_data : ""); + const char *llg_err = matcher ? llg_matcher_get_error(matcher) : "allocation failed"; + if (llg_err) { + set_err(err, errlen, "llguidance grammar error: %s", llg_err); + if (matcher) llg_free_matcher(matcher); + llg_free_tokenizer(tok); + return NULL; + } + + const size_t mask_bytes = llg_matcher_get_mask_byte_size(matcher); + const size_t expected = ((size_t)n_vocab + 31u) / 32u * sizeof(uint32_t); + if (mask_bytes != expected) { + set_err(err, errlen, "llguidance mask size mismatch"); + llg_free_matcher(matcher); + llg_free_tokenizer(tok); + return NULL; + } + + ds4_llguidance *g = calloc(1, sizeof(*g)); + if (!g) { + set_err(err, errlen, "out of memory"); + llg_free_matcher(matcher); + llg_free_tokenizer(tok); + return NULL; + } + g->tokenizer = tok; + g->matcher = matcher; + g->leading_ws_mask = leading_ws_mask; + g->leading_ws_words = leading_ws_words; + g->mask_words = mask_bytes / sizeof(uint32_t); + g->n_vocab = n_vocab; + g->eos_token = ds4_token_eos(e); + g->started = false; + return g; +} + +void ds4_llguidance_free(ds4_llguidance *g) { + if (!g) return; + if (g->matcher) llg_free_matcher(g->matcher); + if (g->tokenizer) llg_free_tokenizer(g->tokenizer); + free(g); +} + +int ds4_llguidance_sample(ds4_llguidance *g, + ds4_session *s, + float temperature, + int top_k, + float top_p, + float min_p, + uint64_t *rng, + char *err, + size_t errlen) { + if (!g || !g->matcher || !s) { + set_err(err, errlen, "structured output decoder is not active"); + return -1; + } + if (llg_matcher_is_stopped(g->matcher)) return g->eos_token; + if (llg_matcher_compute_mask(g->matcher) != 0) { + set_err(err, errlen, "llguidance mask error: %s", + llg_matcher_get_error(g->matcher)); + return -1; + } + const uint32_t *allow = llg_matcher_get_mask(g->matcher); + if (!allow) { + set_err(err, errlen, "llguidance did not return a token mask"); + return -1; + } + + const uint32_t *deny = NULL; + size_t deny_words = 0; + if (!g->started && + mask_has_non_denied_token(allow, g->mask_words, g->leading_ws_mask, + g->leading_ws_words, g->n_vocab)) + { + deny = g->leading_ws_mask; + deny_words = g->leading_ws_words; + } + + int token = ds4_session_sample_masked(s, temperature, top_k, top_p, min_p, + allow, g->mask_words, deny, + deny_words, rng); + if (token < 0) set_err(err, errlen, "llguidance mask allowed no sampleable token"); + return token; +} + +bool ds4_llguidance_accept(ds4_llguidance *g, + ds4_engine *e, + int token, + char *err, + size_t errlen) { + if (!g || !g->matcher) return true; + if (token < 0) return true; + if (llg_matcher_consume_token(g->matcher, (uint32_t)token) != 0) { + set_err(err, errlen, "llguidance consume error: %s", + llg_matcher_get_error(g->matcher)); + return false; + } + if (!g->started && e) { + size_t len = 0; + char *piece = ds4_token_text(e, token, &len); + if (bytes_have_non_json_ws(piece, len)) g->started = true; + free(piece); + } + return true; +} + +#else + +ds4_llguidance *ds4_llguidance_create(ds4_engine *e, + const char *constraint_type, + const char *constraint_data, + char *err, + size_t errlen) { + (void)e; + (void)constraint_type; + (void)constraint_data; + if (err && errlen) { + snprintf(err, errlen, + "structured outputs require building ds4 with LLGUIDANCE=1"); + } + return NULL; +} + +void ds4_llguidance_free(ds4_llguidance *g) { + (void)g; +} + +int ds4_llguidance_sample(ds4_llguidance *g, + ds4_session *s, + float temperature, + int top_k, + float top_p, + float min_p, + uint64_t *rng, + char *err, + size_t errlen) { + (void)g; + (void)s; + (void)temperature; + (void)top_k; + (void)top_p; + (void)min_p; + (void)rng; + if (err && errlen) { + snprintf(err, errlen, + "structured outputs require building ds4 with LLGUIDANCE=1"); + } + return -1; +} + +bool ds4_llguidance_accept(ds4_llguidance *g, + ds4_engine *e, + int token, + char *err, + size_t errlen) { + (void)g; + (void)e; + (void)token; + (void)err; + (void)errlen; + return true; +} + +#endif diff --git a/ds4_llguidance.h b/ds4_llguidance.h new file mode 100644 index 000000000..f677f3b13 --- /dev/null +++ b/ds4_llguidance.h @@ -0,0 +1,37 @@ +#ifndef DS4_LLGUIDANCE_H +#define DS4_LLGUIDANCE_H + +#include +#include +#include + +#include "ds4.h" + +typedef struct ds4_llguidance ds4_llguidance; + +bool ds4_llguidance_available(void); +const char *ds4_llguidance_build_info(void); + +ds4_llguidance *ds4_llguidance_create(ds4_engine *e, + const char *constraint_type, + const char *constraint_data, + char *err, + size_t errlen); +void ds4_llguidance_free(ds4_llguidance *g); + +int ds4_llguidance_sample(ds4_llguidance *g, + ds4_session *s, + float temperature, + int top_k, + float top_p, + float min_p, + uint64_t *rng, + char *err, + size_t errlen); +bool ds4_llguidance_accept(ds4_llguidance *g, + ds4_engine *e, + int token, + char *err, + size_t errlen); + +#endif diff --git a/ds4_server.c b/ds4_server.c index 2cd33b18a..add158ca4 100644 --- a/ds4_server.c +++ b/ds4_server.c @@ -1,6 +1,7 @@ #include "ds4.h" #include "ds4_distributed.h" #include "ds4_kvstore.h" +#include "ds4_llguidance.h" #include "rax.h" /* OpenAI/Anthropic compatible local server. @@ -402,6 +403,392 @@ static char *json_minify_raw_value(const char *json) { return buf_take(&b); } +typedef enum { + DS4_TEXT_FORMAT_TEXT, + DS4_TEXT_FORMAT_JSON_OBJECT, + DS4_TEXT_FORMAT_JSON_SCHEMA, +} ds4_text_format_type; + +typedef struct { + ds4_text_format_type type; + char *name; + char *schema_json; + bool strict; +} ds4_text_format; + +static void ds4_text_format_clear(ds4_text_format *f) { + if (!f) return; + free(f->name); + free(f->schema_json); + memset(f, 0, sizeof(*f)); +} + +static bool ds4_text_format_is_json(const ds4_text_format *f) { + return f && (f->type == DS4_TEXT_FORMAT_JSON_OBJECT || + f->type == DS4_TEXT_FORMAT_JSON_SCHEMA); +} + +static void ds4_text_format_set_schema(ds4_text_format *f, + ds4_text_format_type type, + char *name, + char *schema_json, + bool strict) { + ds4_text_format_clear(f); + f->type = type; + f->name = name; + f->schema_json = schema_json; + f->strict = strict; +} + +static const char *ds4_text_format_constraint_type(const ds4_text_format *f) { + if (!f) return "text"; + if (f->type == DS4_TEXT_FORMAT_JSON_SCHEMA) return "json_schema"; + if (f->type == DS4_TEXT_FORMAT_JSON_OBJECT) { + return f->schema_json ? "json_schema" : "json_object"; + } + return "text"; +} + +static const char *ds4_text_format_constraint_data(const ds4_text_format *f) { + return f && f->schema_json ? f->schema_json : ""; +} + +static bool ds4_text_format_validate_with_llguidance(ds4_engine *e, + const ds4_text_format *f, + char *err, + size_t errlen) { + if (!ds4_text_format_is_json(f)) return true; + if (!ds4_llguidance_available()) { + snprintf(err, errlen, + "structured outputs require building ds4 with LLGUIDANCE=1"); + return false; + } + + char llg_err[160] = {0}; + ds4_llguidance *g = ds4_llguidance_create( + e, + ds4_text_format_constraint_type(f), + ds4_text_format_constraint_data(f), + llg_err, + sizeof(llg_err)); + if (!g) { + snprintf(err, errlen, "invalid structured output schema: %s", + llg_err[0] ? llg_err : "llguidance rejected constraint"); + return false; + } + ds4_llguidance_free(g); + return true; +} + +static bool parse_json_schema_wrapper(const char **p, + ds4_text_format *format, + char *err, + size_t errlen) { + json_ws(p); + if (**p != '{') return false; + (*p)++; + char *name = NULL; + char *schema = NULL; + bool strict = false; + json_ws(p); + while (**p && **p != '}') { + char *key = NULL; + if (!json_string(p, &key)) goto bad; + json_ws(p); + if (**p != ':') { + free(key); + goto bad; + } + (*p)++; + if (!strcmp(key, "name")) { + free(name); + if (!json_string(p, &name)) { + free(key); + goto bad; + } + } else if (!strcmp(key, "schema")) { + free(schema); + if (!json_raw_value(p, &schema)) { + free(key); + goto bad; + } + } else if (!strcmp(key, "strict")) { + if (!json_bool(p, &strict)) { + free(key); + goto bad; + } + } else if (!json_skip_value(p)) { + free(key); + goto bad; + } + free(key); + json_ws(p); + if (**p == ',') (*p)++; + json_ws(p); + } + if (**p != '}') goto bad; + (*p)++; + if (!schema) { + snprintf(err, errlen, "json_schema.schema is required"); + free(name); + return false; + } + ds4_text_format_set_schema(format, DS4_TEXT_FORMAT_JSON_SCHEMA, + name, schema, strict); + return true; +bad: + free(name); + free(schema); + return false; +} + +static bool parse_chat_response_format(const char **p, + ds4_text_format *format, + char *err, + size_t errlen) { + json_ws(p); + if (json_lit(p, "null")) { + ds4_text_format_clear(format); + return true; + } + if (**p != '{') return false; + (*p)++; + + char *type = NULL; + char *schema = NULL; + char *name = NULL; + bool strict = false; + bool saw_json_schema = false; + json_ws(p); + while (**p && **p != '}') { + char *key = NULL; + if (!json_string(p, &key)) goto bad; + json_ws(p); + if (**p != ':') { + free(key); + goto bad; + } + (*p)++; + if (!strcmp(key, "type")) { + free(type); + if (!json_string(p, &type)) { + free(key); + goto bad; + } + } else if (!strcmp(key, "json_schema")) { + saw_json_schema = true; + if (!parse_json_schema_wrapper(p, format, err, errlen)) { + free(key); + goto bad_keep_err; + } + } else if (!strcmp(key, "schema")) { + free(schema); + if (!json_raw_value(p, &schema)) { + free(key); + goto bad; + } + } else if (!strcmp(key, "name")) { + free(name); + if (!json_string(p, &name)) { + free(key); + goto bad; + } + } else if (!strcmp(key, "strict")) { + if (!json_bool(p, &strict)) { + free(key); + goto bad; + } + } else if (!json_skip_value(p)) { + free(key); + goto bad; + } + free(key); + json_ws(p); + if (**p == ',') (*p)++; + json_ws(p); + } + if (**p != '}') goto bad; + (*p)++; + + if (!type || !strcmp(type, "text")) { + ds4_text_format_clear(format); + } else if (!strcmp(type, "json_object")) { + if (schema) { + ds4_text_format_set_schema(format, DS4_TEXT_FORMAT_JSON_SCHEMA, + name, schema, strict); + name = NULL; + schema = NULL; + } else { + ds4_text_format_set_schema(format, DS4_TEXT_FORMAT_JSON_OBJECT, + NULL, NULL, false); + } + } else if (!strcmp(type, "json_schema")) { + if (!saw_json_schema && schema) { + ds4_text_format_set_schema(format, DS4_TEXT_FORMAT_JSON_SCHEMA, + name, schema, strict); + name = NULL; + schema = NULL; + } else if (!format->schema_json) { + snprintf(err, errlen, "response_format json_schema.schema is required"); + goto bad_keep_err; + } + } else { + snprintf(err, errlen, "response_format.type=%s not supported", type); + goto bad_keep_err; + } + + free(type); + free(name); + free(schema); + return true; +bad: + snprintf(err, errlen, "invalid response_format"); +bad_keep_err: + free(type); + free(name); + free(schema); + return false; +} + +static bool parse_responses_text_format_object(const char **p, + ds4_text_format *format, + char *err, + size_t errlen) { + json_ws(p); + if (json_lit(p, "null")) { + ds4_text_format_clear(format); + return true; + } + if (**p != '{') return false; + (*p)++; + char *type = NULL; + char *name = NULL; + char *schema = NULL; + bool strict = false; + json_ws(p); + while (**p && **p != '}') { + char *key = NULL; + if (!json_string(p, &key)) goto bad; + json_ws(p); + if (**p != ':') { + free(key); + goto bad; + } + (*p)++; + if (!strcmp(key, "type")) { + free(type); + if (!json_string(p, &type)) { + free(key); + goto bad; + } + } else if (!strcmp(key, "name")) { + free(name); + if (!json_string(p, &name)) { + free(key); + goto bad; + } + } else if (!strcmp(key, "schema")) { + free(schema); + if (!json_raw_value(p, &schema)) { + free(key); + goto bad; + } + } else if (!strcmp(key, "strict")) { + if (!json_bool(p, &strict)) { + free(key); + goto bad; + } + } else if (!json_skip_value(p)) { + free(key); + goto bad; + } + free(key); + json_ws(p); + if (**p == ',') (*p)++; + json_ws(p); + } + if (**p != '}') goto bad; + (*p)++; + + if (!type || !strcmp(type, "text")) { + ds4_text_format_clear(format); + } else if (!strcmp(type, "json_object")) { + if (schema) { + ds4_text_format_set_schema(format, DS4_TEXT_FORMAT_JSON_SCHEMA, + name, schema, strict); + name = NULL; + schema = NULL; + } else { + ds4_text_format_set_schema(format, DS4_TEXT_FORMAT_JSON_OBJECT, + NULL, NULL, false); + } + } else if (!strcmp(type, "json_schema")) { + if (!schema) { + snprintf(err, errlen, "text.format.schema is required"); + goto bad_keep_err; + } + ds4_text_format_set_schema(format, DS4_TEXT_FORMAT_JSON_SCHEMA, + name, schema, strict); + name = NULL; + schema = NULL; + } else { + snprintf(err, errlen, "text.format.type=%s not supported", type); + goto bad_keep_err; + } + + free(type); + free(name); + free(schema); + return true; +bad: + snprintf(err, errlen, "invalid text.format"); +bad_keep_err: + free(type); + free(name); + free(schema); + return false; +} + +static bool parse_responses_text_value(const char **p, + ds4_text_format *format, + char *err, + size_t errlen) { + json_ws(p); + if (json_lit(p, "null")) { + ds4_text_format_clear(format); + return true; + } + if (**p != '{') return false; + (*p)++; + json_ws(p); + while (**p && **p != '}') { + char *key = NULL; + if (!json_string(p, &key)) return false; + json_ws(p); + if (**p != ':') { + free(key); + return false; + } + (*p)++; + if (!strcmp(key, "format")) { + if (!parse_responses_text_format_object(p, format, err, errlen)) { + free(key); + return false; + } + } else if (!json_skip_value(p)) { + free(key); + return false; + } + free(key); + json_ws(p); + if (**p == ',') (*p)++; + json_ws(p); + } + if (**p != '}') return false; + (*p)++; + return true; +} + static bool json_content(const char **p, char **out) { json_ws(p); if (**p == '"') return json_string(p, out); @@ -601,6 +988,7 @@ typedef struct { int cache_read_tokens; int cache_write_tokens; ds4_think_mode think_mode; + ds4_text_format text_format; bool has_tools; bool prompt_preserves_reasoning; /* For /v1/responses: emit reasoning_summary_* events / fields only when the @@ -763,6 +1151,7 @@ static void request_free(request *r) { free(r->stops.v); free(r->raw_body); free(r->prompt_text); + ds4_text_format_clear(&r->text_format); stop_list_clear(&r->responses_live_call_ids); free(r->responses_live_call_ids.v); free(r->responses_live_suffix_text); @@ -2726,6 +3115,15 @@ static bool parse_chat_request(ds4_engine *e, server *s, const char *body, int d free(key); goto bad; } + } else if (!strcmp(key, "response_format")) { + if (!parse_chat_response_format(&p, &r->text_format, err, errlen)) { + free(key); + chat_msgs_free(&msgs); + free(tool_schemas); + if (!err[0]) snprintf(err, errlen, "invalid response_format"); + request_free(r); + return false; + } } else if (!strcmp(key, "thinking")) { if (!parse_thinking_control_value(&p, &thinking_enabled)) { free(key); @@ -2766,6 +3164,25 @@ static bool parse_chat_request(ds4_engine *e, server *s, const char *body, int d return false; } r->has_tools = tool_schemas && tool_schemas[0] && !tool_choice_none; + if (ds4_text_format_is_json(&r->text_format)) { + if (r->has_tools) { + snprintf(err, errlen, + "structured outputs with tools are not supported"); + chat_msgs_free(&msgs); + free(tool_schemas); + request_free(r); + return false; + } + if (!ds4_text_format_validate_with_llguidance(e, &r->text_format, + err, errlen)) { + chat_msgs_free(&msgs); + free(tool_schemas); + request_free(r); + return false; + } + thinking_enabled = false; + got_thinking = true; + } if (!got_thinking && model_alias_disables_thinking(r->model)) thinking_enabled = false; if (!got_thinking && model_alias_enables_thinking(r->model)) thinking_enabled = true; r->think_mode = ds4_think_mode_for_context( @@ -3815,6 +4232,17 @@ static bool parse_responses_request(ds4_engine *e, server *s, const char *body, free(key); goto bad; } + } else if (!strcmp(key, "text")) { + if (!parse_responses_text_value(&p, &r->text_format, err, errlen)) { + free(key); + chat_msgs_free(&msgs); + buf_free(&loaded_tool_schemas); + free(instructions); + free(tool_schemas); + if (!err[0]) snprintf(err, errlen, "invalid text"); + request_free(r); + return false; + } } else if (!strcmp(key, "reasoning")) { bool effort_seen = false; if (!parse_responses_reasoning(&p, &reasoning_effort, @@ -3904,6 +4332,32 @@ static bool parse_responses_request(ds4_engine *e, server *s, const char *body, (!tool_choice_none && combined_tool_schemas.len) ? combined_tool_schemas.ptr : NULL; r->has_tools = active_tool_schemas && active_tool_schemas[0]; + if (ds4_text_format_is_json(&r->text_format)) { + if (r->has_tools) { + snprintf(err, errlen, + "structured outputs with tools are not supported"); + chat_msgs_free(&msgs); + buf_free(&combined_tool_schemas); + buf_free(&loaded_tool_schemas); + free(instructions); + free(tool_schemas); + request_free(r); + return false; + } + if (!ds4_text_format_validate_with_llguidance(e, &r->text_format, + err, errlen)) { + chat_msgs_free(&msgs); + buf_free(&combined_tool_schemas); + buf_free(&loaded_tool_schemas); + free(instructions); + free(tool_schemas); + request_free(r); + return false; + } + thinking_enabled = false; + got_thinking = true; + r->reasoning_summary_emit = false; + } if (!got_thinking && model_alias_disables_thinking(r->model)) thinking_enabled = false; if (!got_thinking && model_alias_enables_thinking(r->model)) thinking_enabled = true; r->think_mode = ds4_think_mode_for_context( @@ -5966,6 +6420,10 @@ static bool request_uses_structured_stream(const request *r) { request_uses_openai_live_stream(r)); } +static bool request_uses_structured_decoder(const request *r) { + return r && r->kind == REQ_CHAT && ds4_text_format_is_json(&r->text_format); +} + /* Codex' Responses API uses 24-hex suffixes for response/item ids. Prefix * controls the variant (resp_, rs_, msg_, fc_) so each event references a * stable identifier across output_item.added / .done. */ @@ -9907,6 +10365,7 @@ static bool should_canonicalize_tool_checkpoint(const server *s, const tool_call static void generate_job(server *s, job *j) { char err[160]; err[0] = '\0'; + ds4_llguidance *structured = NULL; const int old_pos = ds4_session_pos(s->session); const int common = ds4_session_common_prefix(s->session, &j->req.prompt); trace_cache_diag cache_diag = {0}; @@ -10064,6 +10523,25 @@ static void generate_job(server *s, job *j) { char req_flags[64]; log_flags(req_flags, sizeof(req_flags), responses_protocol, j->req.has_tools, false, false, false); + if (request_uses_structured_decoder(&j->req)) { + structured = ds4_llguidance_create( + s->engine, + ds4_text_format_constraint_type(&j->req.text_format), + ds4_text_format_constraint_data(&j->req.text_format), + err, + sizeof(err)); + if (!structured) { + ds4_tokens_free(&effective_prompt); + free(disk_cache_path); + trace_event(s, trace_id, "structured output init failed: %s", + err[0] ? err : "unknown error"); + http_error(j->fd, s->enable_cors, 400, + err[0] ? err : "structured output init failed"); + return; + } + trace_event(s, trace_id, "structured output constraint=%s", + ds4_text_format_constraint_type(&j->req.text_format)); + } if (responses_live_continuation) { server_log(DS4_LOG_PREFILL, "ds4-server: responses live continuation RESPPROTO match=%s ids=%d cached=%d prompt=%d", @@ -10149,6 +10627,7 @@ static void generate_job(server *s, job *j) { cold_store_len); kv_cache_discard_failed_disk_entry(s, disk_cache_path); free(disk_cache_path); + ds4_llguidance_free(structured); trace_event(s, trace_id, "prefill failed: %s", err); send_prefill_failure_response(s, j, &progress, ctx_span, req_flags, err); return; @@ -10172,6 +10651,7 @@ static void generate_job(server *s, job *j) { cold_store_len); kv_cache_discard_failed_disk_entry(s, disk_cache_path); free(disk_cache_path); + ds4_llguidance_free(structured); trace_event(s, trace_id, "prefill failed: %s", err); send_prefill_failure_response(s, j, &progress, ctx_span, req_flags, err); return; @@ -10222,6 +10702,7 @@ static void generate_job(server *s, job *j) { req_flags[0] ? " " : "", req_flags); ds4_tokens_free(&effective_prompt); + ds4_llguidance_free(structured); return; } /* The prefill progress callback may have already sent the SSE headers @@ -10235,6 +10716,7 @@ static void generate_job(server *s, job *j) { req_flags[0] ? " " : "", req_flags); ds4_tokens_free(&effective_prompt); + ds4_llguidance_free(structured); return; } progress.headers_sent = true; @@ -10243,12 +10725,14 @@ static void generate_job(server *s, job *j) { prompt_tokens, &anthropic_live)) { server_log(DS4_LOG_GENERATION, "ds4-server: chat ctx=%s anthropic stream start failed", ctx_span); ds4_tokens_free(&effective_prompt); + ds4_llguidance_free(structured); return; } if (j->req.api == API_OPENAI && j->req.kind == REQ_CHAT && !sse_chunk(j->fd, &j->req, id, NULL, NULL)) { server_log(DS4_LOG_GENERATION, "ds4-server: chat ctx=%s openai role chunk failed", ctx_span); ds4_tokens_free(&effective_prompt); + ds4_llguidance_free(structured); return; } if (openai_live_chat) openai_stream_start(&j->req, &openai_live); @@ -10263,6 +10747,7 @@ static void generate_job(server *s, job *j) { req_flags); responses_stream_free(&responses_live); ds4_tokens_free(&effective_prompt); + ds4_llguidance_free(structured); return; } } @@ -10320,7 +10805,15 @@ static void generate_job(server *s, job *j) { if (in_tool_call && !dsml_decode_state_uses_payload_sampling(dsml_state)) { temperature = 0.0f; } - int token = ds4_session_sample(s->session, temperature, top_k, top_p, min_p, &rng); + int token = structured ? + ds4_llguidance_sample(structured, s->session, + temperature, top_k, top_p, min_p, + &rng, err, sizeof(err)) : + ds4_session_sample(s->session, temperature, top_k, top_p, min_p, &rng); + if (token < 0) { + finish = "error"; + break; + } if (token == ds4_token_eos(s->engine)) { finish = "stop"; break; @@ -10328,7 +10821,8 @@ static void generate_job(server *s, job *j) { int toks[17]; int ntok = 0; - if (temperature <= 0.0f && + if (!structured && + temperature <= 0.0f && ds4_engine_mtp_draft_tokens(s->engine) > 1 && getenv("DS4_MTP_SPEC_DISABLE") == NULL) { @@ -10361,6 +10855,13 @@ static void generate_job(server *s, job *j) { stop_decode = true; break; } + if (structured && + !ds4_llguidance_accept(structured, s->engine, token, + err, sizeof(err))) { + finish = "error"; + stop_decode = true; + break; + } size_t piece_len = 0; char *piece = ds4_token_text(s->engine, token, &piece_len); @@ -10916,6 +11417,7 @@ static void generate_job(server *s, job *j) { anthropic_stream_free(&anthropic_live); openai_stream_free(&openai_live); responses_stream_free(&responses_live); + ds4_llguidance_free(structured); buf_free(&text); ds4_tokens_free(&effective_prompt); } @@ -11767,6 +12269,116 @@ static void test_assert(bool cond, const char *file, int line, const char *expr) #define TEST_ASSERT(expr) test_assert((expr), __FILE__, __LINE__, #expr) +static void test_parse_chat_response_format_json_schema(void) { + const char *json = + "{\"type\":\"json_schema\",\"json_schema\":{" + "\"name\":\"calendar_event\",\"strict\":true," + "\"schema\":{\"type\":\"object\",\"properties\":{" + "\"name\":{\"type\":\"string\"}},\"required\":[\"name\"]," + "\"additionalProperties\":false}}}"; + const char *p = json; + ds4_text_format fmt = {0}; + char err[160] = {0}; + + TEST_ASSERT(parse_chat_response_format(&p, &fmt, err, sizeof(err))); + TEST_ASSERT(fmt.type == DS4_TEXT_FORMAT_JSON_SCHEMA); + TEST_ASSERT(fmt.name && !strcmp(fmt.name, "calendar_event")); + TEST_ASSERT(fmt.strict); + TEST_ASSERT(fmt.schema_json && strstr(fmt.schema_json, "\"additionalProperties\"")); + json_ws(&p); + TEST_ASSERT(*p == '\0'); + + ds4_text_format_clear(&fmt); +} + +static void test_parse_chat_response_format_json_object(void) { + const char *json = "{\"type\":\"json_object\"}"; + const char *p = json; + ds4_text_format fmt = {0}; + char err[160] = {0}; + + TEST_ASSERT(parse_chat_response_format(&p, &fmt, err, sizeof(err))); + TEST_ASSERT(fmt.type == DS4_TEXT_FORMAT_JSON_OBJECT); + TEST_ASSERT(fmt.schema_json == NULL); + TEST_ASSERT(!strcmp(ds4_text_format_constraint_type(&fmt), "json_object")); + + ds4_text_format_clear(&fmt); +} + +static void test_parse_chat_response_format_rejects_missing_schema(void) { + const char *json = "{\"type\":\"json_schema\",\"json_schema\":{\"name\":\"bad\"}}"; + const char *p = json; + ds4_text_format fmt = {0}; + char err[160] = {0}; + + TEST_ASSERT(!parse_chat_response_format(&p, &fmt, err, sizeof(err))); + TEST_ASSERT(strstr(err, "schema is required") != NULL); + + ds4_text_format_clear(&fmt); +} + +static void test_parse_responses_text_format_json_schema(void) { + const char *json = + "{\"format\":{\"type\":\"json_schema\"," + "\"name\":\"calendar_event\",\"strict\":true," + "\"schema\":{\"type\":\"object\",\"properties\":{" + "\"date\":{\"type\":\"string\"}},\"required\":[\"date\"]," + "\"additionalProperties\":false}}}"; + const char *p = json; + ds4_text_format fmt = {0}; + char err[160] = {0}; + + TEST_ASSERT(parse_responses_text_value(&p, &fmt, err, sizeof(err))); + TEST_ASSERT(fmt.type == DS4_TEXT_FORMAT_JSON_SCHEMA); + TEST_ASSERT(fmt.name && !strcmp(fmt.name, "calendar_event")); + TEST_ASSERT(fmt.strict); + TEST_ASSERT(fmt.schema_json && strstr(fmt.schema_json, "\"required\"")); + TEST_ASSERT(!strcmp(ds4_text_format_constraint_type(&fmt), "json_schema")); + json_ws(&p); + TEST_ASSERT(*p == '\0'); + + ds4_text_format_clear(&fmt); +} + +static void test_parse_responses_text_format_json_object(void) { + const char *json = "{\"format\":{\"type\":\"json_object\"}}"; + const char *p = json; + ds4_text_format fmt = {0}; + char err[160] = {0}; + + TEST_ASSERT(parse_responses_text_value(&p, &fmt, err, sizeof(err))); + TEST_ASSERT(fmt.type == DS4_TEXT_FORMAT_JSON_OBJECT); + TEST_ASSERT(fmt.schema_json == NULL); + TEST_ASSERT(!strcmp(ds4_text_format_constraint_type(&fmt), "json_object")); + + ds4_text_format_clear(&fmt); +} + +static void test_parse_responses_text_format_rejects_unknown_type(void) { + const char *json = "{\"format\":{\"type\":\"xml\"}}"; + const char *p = json; + ds4_text_format fmt = {0}; + char err[160] = {0}; + + TEST_ASSERT(!parse_responses_text_value(&p, &fmt, err, sizeof(err))); + TEST_ASSERT(strstr(err, "not supported") != NULL); + + ds4_text_format_clear(&fmt); +} + +static void test_parse_responses_text_format_text_is_noop(void) { + const char *json = "{\"format\":{\"type\":\"text\"}}"; + const char *p = json; + ds4_text_format fmt = {0}; + char err[160] = {0}; + + TEST_ASSERT(parse_responses_text_value(&p, &fmt, err, sizeof(err))); + TEST_ASSERT(fmt.type == DS4_TEXT_FORMAT_TEXT); + TEST_ASSERT(fmt.schema_json == NULL); + + ds4_text_format_clear(&fmt); +} + static void test_tool_schema_order_from_anthropic_schema(void) { tool_schema_orders orders = {0}; tool_schema_orders_add_json(&orders, @@ -15554,6 +16166,13 @@ static void ds4_server_unit_tests_run(void) { test_render_drops_old_reasoning_without_tools(); test_render_preserves_reasoning_with_tools(); test_render_chat_prompt_text_renders_tools_before_system(); + test_parse_chat_response_format_json_schema(); + test_parse_chat_response_format_json_object(); + test_parse_chat_response_format_rejects_missing_schema(); + test_parse_responses_text_format_json_schema(); + test_parse_responses_text_format_json_object(); + test_parse_responses_text_format_rejects_unknown_type(); + test_parse_responses_text_format_text_is_noop(); test_tool_schema_order_from_anthropic_schema(); test_tool_schema_order_from_openai_tools(); test_tool_schema_order_from_responses_tool_search(); diff --git a/tests/structured_outputs_stress.py b/tests/structured_outputs_stress.py new file mode 100755 index 000000000..9bc7610fc --- /dev/null +++ b/tests/structured_outputs_stress.py @@ -0,0 +1,424 @@ +#!/usr/bin/env python3 +"""Stress JSON structured outputs on OpenAI-compatible chat/responses APIs. + +Examples: + python3 tests/structured_outputs_stress.py \ + --base-url http://127.0.0.1:8000/v1 --model ds4 --apis chat,responses + + python3 tests/structured_outputs_stress.py \ + --base-url http://127.0.0.1:8080/v1 --model qwen --apis chat +""" + +from __future__ import annotations + +import argparse +import json +import re +import sys +import time +import urllib.error +import urllib.request +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True) +class Case: + name: str + prompt: str + schema: dict[str, Any] | None + json_object: bool = False + + +CASES: list[Case] = [ + Case( + name="calendar_event", + prompt=( + "Create one calendar event for Alice and Bob having lunch on " + "2026-06-01 at noon. Return only the requested JSON object." + ), + schema={ + "type": "object", + "properties": { + "name": {"type": "string"}, + "date": {"type": "string"}, + "participants": { + "type": "array", + "items": {"type": "string"}, + "minItems": 1, + "maxItems": 5, + }, + }, + "required": ["name", "date", "participants"], + "additionalProperties": False, + }, + ), + Case( + name="enum_const_integer_boolean", + prompt=( + "Return a compact health-check result. Use status ok, one priority, " + "a retry count, and whether the system is active." + ), + schema={ + "type": "object", + "properties": { + "status": {"const": "ok"}, + "priority": {"type": "string", "enum": ["low", "medium", "high"]}, + "retry_count": {"type": "integer", "minimum": 0, "maximum": 5}, + "active": {"type": "boolean"}, + }, + "required": ["status", "priority", "retry_count", "active"], + "additionalProperties": False, + }, + ), + Case( + name="nested_arrays", + prompt=( + "Return a 2 by 2 integer matrix and two short labels. Keep values " + "small and return only JSON." + ), + schema={ + "type": "object", + "properties": { + "matrix": { + "type": "array", + "minItems": 2, + "maxItems": 2, + "items": { + "type": "array", + "minItems": 2, + "maxItems": 2, + "items": {"type": "integer", "minimum": -9, "maximum": 9}, + }, + }, + "labels": { + "type": "array", + "minItems": 2, + "maxItems": 2, + "items": {"type": "string"}, + }, + }, + "required": ["matrix", "labels"], + "additionalProperties": False, + }, + ), + Case( + name="nullable_anyof_number_bounds", + prompt=( + "Return a score between zero and one, and use either an owner name " + "or null if unknown." + ), + schema={ + "type": "object", + "properties": { + "owner": {"anyOf": [{"type": "string"}, {"type": "null"}]}, + "score": {"type": "number", "minimum": 0, "maximum": 1}, + }, + "required": ["owner", "score"], + "additionalProperties": False, + }, + ), + Case( + name="pattern_string", + prompt="Return an inventory code in the form two uppercase letters, dash, three digits.", + schema={ + "type": "object", + "properties": { + "code": {"type": "string", "pattern": "^[A-Z]{2}-[0-9]{3}$"} + }, + "required": ["code"], + "additionalProperties": False, + }, + ), + Case( + name="json_object_mode", + prompt="Return a JSON object with two fields describing a tiny task list.", + schema=None, + json_object=True, + ), +] + + +class ValidationError(Exception): + pass + + +def type_matches(value: Any, typ: str) -> bool: + if typ == "object": + return isinstance(value, dict) + if typ == "array": + return isinstance(value, list) + if typ == "string": + return isinstance(value, str) + if typ == "integer": + return isinstance(value, int) and not isinstance(value, bool) + if typ == "number": + return (isinstance(value, int) or isinstance(value, float)) and not isinstance(value, bool) + if typ == "boolean": + return isinstance(value, bool) + if typ == "null": + return value is None + return True + + +def validate_schema(value: Any, schema: dict[str, Any], path: str = "$") -> None: + if "anyOf" in schema: + errors: list[str] = [] + for option in schema["anyOf"]: + try: + validate_schema(value, option, path) + return + except ValidationError as exc: + errors.append(str(exc)) + raise ValidationError(f"{path}: did not match anyOf: {'; '.join(errors)}") + + if "const" in schema and value != schema["const"]: + raise ValidationError(f"{path}: expected const {schema['const']!r}, got {value!r}") + if "enum" in schema and value not in schema["enum"]: + raise ValidationError(f"{path}: expected one of {schema['enum']!r}, got {value!r}") + + typ = schema.get("type") + if isinstance(typ, list): + if not any(type_matches(value, t) for t in typ): + raise ValidationError(f"{path}: wrong type {type(value).__name__}, expected {typ}") + elif isinstance(typ, str) and not type_matches(value, typ): + raise ValidationError(f"{path}: wrong type {type(value).__name__}, expected {typ}") + + if typ == "object" or "properties" in schema: + if not isinstance(value, dict): + raise ValidationError(f"{path}: expected object") + props = schema.get("properties", {}) + for key in schema.get("required", []): + if key not in value: + raise ValidationError(f"{path}: missing required property {key!r}") + if schema.get("additionalProperties") is False: + extra = sorted(set(value) - set(props)) + if extra: + raise ValidationError(f"{path}: extra properties {extra!r}") + for key, sub in props.items(): + if key in value: + validate_schema(value[key], sub, f"{path}.{key}") + + if typ == "array" or "items" in schema: + if not isinstance(value, list): + raise ValidationError(f"{path}: expected array") + min_items = schema.get("minItems") + max_items = schema.get("maxItems") + if min_items is not None and len(value) < min_items: + raise ValidationError(f"{path}: expected at least {min_items} items") + if max_items is not None and len(value) > max_items: + raise ValidationError(f"{path}: expected at most {max_items} items") + items = schema.get("items") + if isinstance(items, dict): + for i, item in enumerate(value): + validate_schema(item, items, f"{path}[{i}]") + + if isinstance(value, str) and "pattern" in schema: + if re.fullmatch(schema["pattern"], value) is None: + raise ValidationError(f"{path}: {value!r} does not match {schema['pattern']!r}") + + if isinstance(value, (int, float)) and not isinstance(value, bool): + if "minimum" in schema and value < schema["minimum"]: + raise ValidationError(f"{path}: {value!r} is below minimum {schema['minimum']!r}") + if "maximum" in schema and value > schema["maximum"]: + raise ValidationError(f"{path}: {value!r} is above maximum {schema['maximum']!r}") + + +def post_json(url: str, payload: dict[str, Any], timeout: float) -> dict[str, Any]: + data = json.dumps(payload, separators=(",", ":")).encode("utf-8") + req = urllib.request.Request( + url, + data=data, + headers={"Content-Type": "application/json"}, + method="POST", + ) + try: + with urllib.request.urlopen(req, timeout=timeout) as resp: + raw = resp.read().decode("utf-8", errors="replace") + except urllib.error.HTTPError as exc: + raw = exc.read().decode("utf-8", errors="replace") + raise RuntimeError(f"HTTP {exc.code}: {raw[:1000]}") from exc + except urllib.error.URLError as exc: + raise RuntimeError(str(exc)) from exc + try: + body = json.loads(raw) + except json.JSONDecodeError as exc: + raise RuntimeError(f"invalid JSON response: {raw[:1000]}") from exc + if isinstance(body, dict) and body.get("error"): + raise RuntimeError(f"API error: {body['error']!r}") + return body + + +def chat_payload(model: str, case: Case, json_object_schema: bool) -> dict[str, Any]: + response_format: dict[str, Any] + if case.json_object: + response_format = {"type": "json_object"} + if json_object_schema: + response_format["schema"] = {"type": "object"} + else: + response_format = { + "type": "json_schema", + "json_schema": { + "name": case.name, + "strict": True, + "schema": case.schema, + }, + } + return { + "model": model, + "messages": [{"role": "user", "content": case.prompt}], + "max_tokens": 256, + "temperature": 0, + "response_format": response_format, + } + + +def responses_payload(model: str, case: Case, json_object_schema: bool) -> dict[str, Any]: + fmt: dict[str, Any] + if case.json_object: + fmt = {"type": "json_object"} + if json_object_schema: + fmt["schema"] = {"type": "object"} + else: + fmt = { + "type": "json_schema", + "name": case.name, + "strict": True, + "schema": case.schema, + } + return { + "model": model, + "input": case.prompt, + "max_output_tokens": 256, + "temperature": 0, + "text": {"format": fmt}, + } + + +def extract_chat_text(body: dict[str, Any]) -> str: + choices = body.get("choices") + if not isinstance(choices, list) or not choices: + raise RuntimeError(f"missing choices in chat response: {body!r}") + message = choices[0].get("message", {}) + content = message.get("content") + if isinstance(content, str): + return content + if isinstance(content, list): + out: list[str] = [] + for part in content: + if isinstance(part, dict) and isinstance(part.get("text"), str): + out.append(part["text"]) + return "".join(out) + raise RuntimeError(f"missing text content in chat response: {body!r}") + + +def extract_responses_text(body: dict[str, Any]) -> str: + if isinstance(body.get("output_text"), str): + return body["output_text"] + out: list[str] = [] + for item in body.get("output", []): + if not isinstance(item, dict): + continue + if item.get("type") == "message": + for part in item.get("content", []): + if isinstance(part, dict) and isinstance(part.get("text"), str): + out.append(part["text"]) + if out: + return "".join(out) + raise RuntimeError(f"missing output text in responses response: {body!r}") + + +def check_case( + api: str, + base_url: str, + model: str, + case: Case, + timeout: float, + json_object_schema: bool, +) -> str: + if api == "chat": + body = post_json( + f"{base_url}/chat/completions", + chat_payload(model, case, json_object_schema), + timeout, + ) + text = extract_chat_text(body) + elif api == "responses": + body = post_json( + f"{base_url}/responses", + responses_payload(model, case, json_object_schema), + timeout, + ) + text = extract_responses_text(body) + else: + raise RuntimeError(f"unknown api {api!r}") + + try: + value = json.loads(text.strip()) + except json.JSONDecodeError as exc: + raise RuntimeError(f"{api}/{case.name}: output is not JSON: {text!r}") from exc + if not isinstance(value, dict): + raise RuntimeError(f"{api}/{case.name}: output is not a JSON object: {value!r}") + if case.schema is not None: + validate_schema(value, case.schema) + return json.dumps(value, ensure_ascii=False, sort_keys=True) + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser() + p.add_argument("--base-url", required=True, help="Base URL, usually http://host:port/v1") + p.add_argument("--model", required=True) + p.add_argument("--apis", default="chat,responses", help="Comma-separated: chat,responses") + p.add_argument("--case", action="append", help="Run only this case name; may repeat") + p.add_argument("--repeat", type=int, default=1) + p.add_argument("--timeout", type=float, default=120.0) + p.add_argument( + "--json-object-schema", + action="store_true", + help="Send {'type':'object'} with json_object mode for servers that require a concrete schema.", + ) + p.add_argument("--verbose", action="store_true") + return p.parse_args() + + +def main() -> int: + args = parse_args() + base_url = args.base_url.rstrip("/") + apis = [x.strip() for x in args.apis.split(",") if x.strip()] + selected = set(args.case or []) + cases = [c for c in CASES if not selected or c.name in selected] + missing = selected - {c.name for c in CASES} + if missing: + print(f"unknown case(s): {', '.join(sorted(missing))}", file=sys.stderr) + return 2 + + failures = 0 + for repeat in range(args.repeat): + for api in apis: + for case in cases: + label = f"{api}/{case.name}" + if args.repeat > 1: + label = f"{label}#{repeat + 1}" + t0 = time.time() + try: + value = check_case( + api, + base_url, + args.model, + case, + args.timeout, + args.json_object_schema, + ) + elapsed = time.time() - t0 + if args.verbose: + print(f"PASS {label} {elapsed:.2f}s {value}") + else: + print(f"PASS {label} {elapsed:.2f}s") + except Exception as exc: + failures += 1 + print(f"FAIL {label}: {exc}", file=sys.stderr) + return 1 if failures else 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From 08e053d31d93d46eae67fa26478e9435fba858b5 Mon Sep 17 00:00:00 2001 From: fry69 <142489379+fry69@users.noreply.github.com> Date: Fri, 29 May 2026 20:43:36 +0200 Subject: [PATCH 2/9] build: add $(DS4_LLGUIDANCE_DEPS) prerequisite to all binary targets Ensure ds4, ds4-server, ds4-bench, ds4-eval, ds4-agent, and cpu targets depend on libllguidance.a when LLGUIDANCE=1, so that `cargo build` runs before linking. Previously only ds4-server triggered the build via ds4_llguidance.o, causing other binaries to fail linking against a nonexistent library. --- Makefile | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/Makefile b/Makefile index 9cad30654..41631d33f 100644 --- a/Makefile +++ b/Makefile @@ -71,22 +71,22 @@ help: @echo " make test Build and run tests" @echo " make clean Remove build outputs" -ds4: ds4_cli.o linenoise.o $(CORE_OBJS) +ds4: ds4_cli.o linenoise.o $(CORE_OBJS) $(DS4_LLGUIDANCE_DEPS) $(CC) $(CFLAGS) -o $@ ds4_cli.o linenoise.o $(CORE_OBJS) $(METAL_LDLIBS) -ds4-server: ds4_server.o ds4_kvstore.o rax.o $(SERVER_EXTRA_OBJS) $(CORE_OBJS) +ds4-server: ds4_server.o ds4_kvstore.o rax.o $(SERVER_EXTRA_OBJS) $(CORE_OBJS) $(DS4_LLGUIDANCE_DEPS) $(CC) $(CFLAGS) -o $@ ds4_server.o ds4_kvstore.o rax.o $(SERVER_EXTRA_OBJS) $(CORE_OBJS) $(METAL_LDLIBS) -ds4-bench: ds4_bench.o $(CORE_OBJS) +ds4-bench: ds4_bench.o $(CORE_OBJS) $(DS4_LLGUIDANCE_DEPS) $(CC) $(CFLAGS) -o $@ ds4_bench.o $(CORE_OBJS) $(METAL_LDLIBS) -ds4-eval: ds4_eval.o $(CORE_OBJS) +ds4-eval: ds4_eval.o $(CORE_OBJS) $(DS4_LLGUIDANCE_DEPS) $(CC) $(CFLAGS) -o $@ ds4_eval.o $(CORE_OBJS) $(METAL_LDLIBS) -ds4-agent: ds4_agent.o ds4_web.o ds4_kvstore.o linenoise.o $(CORE_OBJS) +ds4-agent: ds4_agent.o ds4_web.o ds4_kvstore.o linenoise.o $(CORE_OBJS) $(DS4_LLGUIDANCE_DEPS) $(CC) $(CFLAGS) -o $@ ds4_agent.o ds4_web.o ds4_kvstore.o linenoise.o $(CORE_OBJS) $(METAL_LDLIBS) -cpu: ds4_cli_cpu.o ds4_server_cpu.o ds4_bench_cpu.o ds4_eval_cpu.o ds4_agent_cpu.o ds4_web.o ds4_kvstore.o linenoise.o rax.o $(SERVER_EXTRA_OBJS) $(CPU_CORE_OBJS) +cpu: ds4_cli_cpu.o ds4_server_cpu.o ds4_bench_cpu.o ds4_eval_cpu.o ds4_agent_cpu.o ds4_web.o ds4_kvstore.o linenoise.o rax.o $(SERVER_EXTRA_OBJS) $(CPU_CORE_OBJS) $(DS4_LLGUIDANCE_DEPS) $(CC) $(CFLAGS) -o ds4 ds4_cli_cpu.o linenoise.o $(CPU_CORE_OBJS) $(LDLIBS) $(CC) $(CFLAGS) -o ds4-server ds4_server_cpu.o ds4_kvstore.o rax.o $(SERVER_EXTRA_OBJS) $(CPU_CORE_OBJS) $(LDLIBS) $(CC) $(CFLAGS) -o ds4-bench ds4_bench_cpu.o $(CPU_CORE_OBJS) $(LDLIBS) @@ -122,22 +122,22 @@ cuda: fi $(MAKE) -B ds4 ds4-server ds4-bench ds4-eval ds4-agent CUDA_ARCH="$(CUDA_ARCH)" -ds4: ds4_cli.o linenoise.o $(CORE_OBJS) +ds4: ds4_cli.o linenoise.o $(CORE_OBJS) $(DS4_LLGUIDANCE_DEPS) $(NVCC) $(NVCCFLAGS) -o $@ $^ $(CUDA_LDLIBS) -ds4-server: ds4_server.o ds4_kvstore.o rax.o $(SERVER_EXTRA_OBJS) $(CORE_OBJS) +ds4-server: ds4_server.o ds4_kvstore.o rax.o $(SERVER_EXTRA_OBJS) $(CORE_OBJS) $(DS4_LLGUIDANCE_DEPS) $(NVCC) $(NVCCFLAGS) -o $@ $^ $(CUDA_LDLIBS) -ds4-bench: ds4_bench.o $(CORE_OBJS) +ds4-bench: ds4_bench.o $(CORE_OBJS) $(DS4_LLGUIDANCE_DEPS) $(NVCC) $(NVCCFLAGS) -o $@ $^ $(CUDA_LDLIBS) -ds4-eval: ds4_eval.o $(CORE_OBJS) +ds4-eval: ds4_eval.o $(CORE_OBJS) $(DS4_LLGUIDANCE_DEPS) $(NVCC) $(NVCCFLAGS) -o $@ $^ $(CUDA_LDLIBS) -ds4-agent: ds4_agent.o ds4_web.o ds4_kvstore.o linenoise.o $(CORE_OBJS) +ds4-agent: ds4_agent.o ds4_web.o ds4_kvstore.o linenoise.o $(CORE_OBJS) $(DS4_LLGUIDANCE_DEPS) $(NVCC) $(NVCCFLAGS) -o $@ $^ $(CUDA_LDLIBS) -cpu: ds4_cli_cpu.o ds4_server_cpu.o ds4_bench_cpu.o ds4_eval_cpu.o ds4_agent_cpu.o ds4_web.o ds4_kvstore.o linenoise.o rax.o $(SERVER_EXTRA_OBJS) $(CPU_CORE_OBJS) +cpu: ds4_cli_cpu.o ds4_server_cpu.o ds4_bench_cpu.o ds4_eval_cpu.o ds4_agent_cpu.o ds4_web.o ds4_kvstore.o linenoise.o rax.o $(SERVER_EXTRA_OBJS) $(CPU_CORE_OBJS) $(DS4_LLGUIDANCE_DEPS) $(CC) $(CFLAGS) -o ds4 ds4_cli_cpu.o linenoise.o $(CPU_CORE_OBJS) $(LDLIBS) $(CC) $(CFLAGS) -o ds4-server ds4_server_cpu.o ds4_kvstore.o rax.o $(SERVER_EXTRA_OBJS) $(CPU_CORE_OBJS) $(LDLIBS) $(CC) $(CFLAGS) -o ds4-bench ds4_bench_cpu.o $(CPU_CORE_OBJS) $(LDLIBS) From 96c24df2662158bc4a51a09ef003b90cf855d56e Mon Sep 17 00:00:00 2001 From: fry69 <142489379+fry69@users.noreply.github.com> Date: Fri, 29 May 2026 20:51:18 +0200 Subject: [PATCH 3/9] build: add distclean target to remove .deps Introduce a distclean target that runs clean and then removes the .deps directory (cloned llguidance source + Rust build artifacts). This avoids forcing a re-clone on every `make clean` while giving users an explicit way to fully reset when needed. --- Makefile | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 41631d33f..d14d1009c 100644 --- a/Makefile +++ b/Makefile @@ -58,7 +58,7 @@ CUDA_LDLIBS += $(LLGUIDANCE_LDLIBS) METAL_LDLIBS := $(LDLIBS) endif -.PHONY: all help clean test cpu cuda cuda-spark cuda-generic cuda-regression +.PHONY: all help clean distclean test cpu cuda cuda-spark cuda-generic cuda-regression ifeq ($(UNAME_S),Darwin) all: ds4 ds4-server ds4-bench ds4-eval ds4-agent @@ -241,3 +241,6 @@ test: ds4_test ds4-eval clean: rm -f ds4 ds4-server ds4-bench ds4-eval ds4-agent ds4_cpu ds4_native ds4_server_test ds4_test *.o tests/cuda_long_context_smoke tests/cuda_long_context_smoke.o + +distclean: clean + rm -rf .deps From afab720363c295fb147333a5b38e6145e291ed64 Mon Sep 17 00:00:00 2001 From: Pasquale Minervini Date: Fri, 29 May 2026 22:52:25 +0100 Subject: [PATCH 4/9] build: make llguidance path explicit --- Makefile | 8 ++------ README.md | 4 ++++ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index d14d1009c..76b492b3a 100644 --- a/Makefile +++ b/Makefile @@ -14,19 +14,15 @@ OBJCFLAGS ?= -O3 -ffast-math $(DEBUG_FLAGS) $(NATIVE_CPU_FLAG) -Wall -Wextra -fo LDLIBS ?= -lm -pthread METAL_SRCS := $(wildcard metal/*.metal) LLGUIDANCE ?= 0 +LLGUIDANCE_DIR ?= .deps/llguidance LLGUIDANCE_REPO ?= https://github.com/guidance-ai/llguidance LLGUIDANCE_TAG ?= v1.7.5 SERVER_EXTRA_OBJS := ds4_llguidance.o ifeq ($(LLGUIDANCE),1) -ifeq ($(strip $(LLGUIDANCE_DIR)),) -ifneq ($(wildcard ../../llguidance/parser/llguidance.h),) -LLGUIDANCE_DIR := ../../llguidance -else -LLGUIDANCE_DIR := .deps/llguidance +ifeq ($(LLGUIDANCE_DIR),.deps/llguidance) LLGUIDANCE_NEEDS_CLONE := 1 endif -endif LLGUIDANCE_LIB := $(LLGUIDANCE_DIR)/target/release/libllguidance.a LLGUIDANCE_LDLIBS := $(LLGUIDANCE_LIB) ifneq ($(UNAME_S),Darwin) diff --git a/README.md b/README.md index a4aace8b5..26c8a7b8e 100644 --- a/README.md +++ b/README.md @@ -646,6 +646,10 @@ Structured outputs are available when the server is built with llguidance: make LLGUIDANCE=1 ``` +By default, this clones llguidance into `.deps/llguidance` and builds the +static library there. To use an existing checkout instead, pass +`LLGUIDANCE_DIR=/path/to/llguidance`. + With that build, `/v1/chat/completions` supports `response_format.type=json_schema` and `response_format.type=json_object`; `/v1/responses` supports the same modes through `text.format`. Structured From 50b8035e21793b6ae3aee11d990854db186fd385 Mon Sep 17 00:00:00 2001 From: Pasquale Minervini Date: Sat, 30 May 2026 12:04:17 +0100 Subject: [PATCH 5/9] structured outputs: expose llguidance format types Expose regex, Lark, and llguidance structured-output formats through the existing Chat Completions and Responses structured-output surfaces, reusing the current llguidance constrained decoder. --- README.md | 8 +- ds4_llguidance.c | 15 +- ds4_server.c | 198 +++++++- stress-test-cli.py | 1086 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 1294 insertions(+), 13 deletions(-) create mode 100755 stress-test-cli.py diff --git a/README.md b/README.md index 26c8a7b8e..6d55447ce 100644 --- a/README.md +++ b/README.md @@ -651,10 +651,10 @@ static library there. To use an existing checkout instead, pass `LLGUIDANCE_DIR=/path/to/llguidance`. With that build, `/v1/chat/completions` supports -`response_format.type=json_schema` and `response_format.type=json_object`; -`/v1/responses` supports the same modes through `text.format`. Structured -outputs use constrained decoding, disable thinking for that turn, and currently -cannot be combined with tools. +`response_format.type=json_schema`, `json_object`, `regex`, `lark`, and +`llguidance`; `/v1/responses` supports the same modes through `text.format`. +Structured outputs use constrained decoding, disable thinking for that turn, +and currently cannot be combined with tools. `/v1/messages` is the Anthropic-compatible endpoint used by Claude Code style clients. It accepts `system`, `messages`, `tools`, `tool_choice`, `max_tokens`, diff --git a/ds4_llguidance.c b/ds4_llguidance.c index 1a39ea3e8..5c6854d2f 100644 --- a/ds4_llguidance.c +++ b/ds4_llguidance.c @@ -24,6 +24,7 @@ struct ds4_llguidance { size_t mask_words; int n_vocab; int eos_token; + bool deny_leading_ws; bool started; #else int unused; @@ -109,6 +110,13 @@ static bool token_text_is_special(const char *p, size_t len) { return false; } +static bool constraint_uses_json_leading_ws_rule(const char *constraint_type) { + return constraint_type && + (!strcmp(constraint_type, "json") || + !strcmp(constraint_type, "json_schema") || + !strcmp(constraint_type, "json_object")); +} + static void bitset_set(uint32_t *mask, int token) { mask[(uint32_t)token / 32u] |= UINT32_C(1) << ((uint32_t)token & 31u); } @@ -334,6 +342,8 @@ ds4_llguidance *ds4_llguidance_create(ds4_engine *e, g->mask_words = mask_bytes / sizeof(uint32_t); g->n_vocab = n_vocab; g->eos_token = ds4_token_eos(e); + g->deny_leading_ws = + constraint_uses_json_leading_ws_rule(constraint_type); g->started = false; return g; } @@ -372,7 +382,8 @@ int ds4_llguidance_sample(ds4_llguidance *g, const uint32_t *deny = NULL; size_t deny_words = 0; - if (!g->started && + if (g->deny_leading_ws && + !g->started && mask_has_non_denied_token(allow, g->mask_words, g->leading_ws_mask, g->leading_ws_words, g->n_vocab)) { @@ -399,7 +410,7 @@ bool ds4_llguidance_accept(ds4_llguidance *g, llg_matcher_get_error(g->matcher)); return false; } - if (!g->started && e) { + if (g->deny_leading_ws && !g->started && e) { size_t len = 0; char *piece = ds4_token_text(e, token, &len); if (bytes_have_non_json_ws(piece, len)) g->started = true; diff --git a/ds4_server.c b/ds4_server.c index add158ca4..de9acbc03 100644 --- a/ds4_server.c +++ b/ds4_server.c @@ -407,6 +407,9 @@ typedef enum { DS4_TEXT_FORMAT_TEXT, DS4_TEXT_FORMAT_JSON_OBJECT, DS4_TEXT_FORMAT_JSON_SCHEMA, + DS4_TEXT_FORMAT_REGEX, + DS4_TEXT_FORMAT_LARK, + DS4_TEXT_FORMAT_LLGUIDANCE, } ds4_text_format_type; typedef struct { @@ -423,9 +426,12 @@ static void ds4_text_format_clear(ds4_text_format *f) { memset(f, 0, sizeof(*f)); } -static bool ds4_text_format_is_json(const ds4_text_format *f) { +static bool ds4_text_format_is_structured(const ds4_text_format *f) { return f && (f->type == DS4_TEXT_FORMAT_JSON_OBJECT || - f->type == DS4_TEXT_FORMAT_JSON_SCHEMA); + f->type == DS4_TEXT_FORMAT_JSON_SCHEMA || + f->type == DS4_TEXT_FORMAT_REGEX || + f->type == DS4_TEXT_FORMAT_LARK || + f->type == DS4_TEXT_FORMAT_LLGUIDANCE); } static void ds4_text_format_set_schema(ds4_text_format *f, @@ -440,12 +446,21 @@ static void ds4_text_format_set_schema(ds4_text_format *f, f->strict = strict; } +static void ds4_text_format_set_constraint(ds4_text_format *f, + ds4_text_format_type type, + char *constraint_data) { + ds4_text_format_set_schema(f, type, NULL, constraint_data, false); +} + static const char *ds4_text_format_constraint_type(const ds4_text_format *f) { if (!f) return "text"; if (f->type == DS4_TEXT_FORMAT_JSON_SCHEMA) return "json_schema"; if (f->type == DS4_TEXT_FORMAT_JSON_OBJECT) { return f->schema_json ? "json_schema" : "json_object"; } + if (f->type == DS4_TEXT_FORMAT_REGEX) return "regex"; + if (f->type == DS4_TEXT_FORMAT_LARK) return "lark"; + if (f->type == DS4_TEXT_FORMAT_LLGUIDANCE) return "llguidance"; return "text"; } @@ -457,7 +472,7 @@ static bool ds4_text_format_validate_with_llguidance(ds4_engine *e, const ds4_text_format *f, char *err, size_t errlen) { - if (!ds4_text_format_is_json(f)) return true; + if (!ds4_text_format_is_structured(f)) return true; if (!ds4_llguidance_available()) { snprintf(err, errlen, "structured outputs require building ds4 with LLGUIDANCE=1"); @@ -472,7 +487,7 @@ static bool ds4_text_format_validate_with_llguidance(ds4_engine *e, llg_err, sizeof(llg_err)); if (!g) { - snprintf(err, errlen, "invalid structured output schema: %s", + snprintf(err, errlen, "invalid structured output constraint: %s", llg_err[0] ? llg_err : "llguidance rejected constraint"); return false; } @@ -556,6 +571,8 @@ static bool parse_chat_response_format(const char **p, char *type = NULL; char *schema = NULL; + char *regex = NULL; + char *grammar = NULL; char *name = NULL; bool strict = false; bool saw_json_schema = false; @@ -587,6 +604,18 @@ static bool parse_chat_response_format(const char **p, free(key); goto bad; } + } else if (!strcmp(key, "regex")) { + free(regex); + if (!json_string(p, ®ex)) { + free(key); + goto bad; + } + } else if (!strcmp(key, "grammar")) { + free(grammar); + if (!json_string(p, &grammar)) { + free(key); + goto bad; + } } else if (!strcmp(key, "name")) { free(name); if (!json_string(p, &name)) { @@ -632,6 +661,27 @@ static bool parse_chat_response_format(const char **p, snprintf(err, errlen, "response_format json_schema.schema is required"); goto bad_keep_err; } + } else if (!strcmp(type, "regex")) { + if (!regex) { + snprintf(err, errlen, "response_format.regex is required"); + goto bad_keep_err; + } + ds4_text_format_set_constraint(format, DS4_TEXT_FORMAT_REGEX, regex); + regex = NULL; + } else if (!strcmp(type, "lark")) { + if (!grammar) { + snprintf(err, errlen, "response_format.grammar is required"); + goto bad_keep_err; + } + ds4_text_format_set_constraint(format, DS4_TEXT_FORMAT_LARK, grammar); + grammar = NULL; + } else if (!strcmp(type, "llguidance")) { + if (!grammar) { + snprintf(err, errlen, "response_format.grammar is required"); + goto bad_keep_err; + } + ds4_text_format_set_constraint(format, DS4_TEXT_FORMAT_LLGUIDANCE, grammar); + grammar = NULL; } else { snprintf(err, errlen, "response_format.type=%s not supported", type); goto bad_keep_err; @@ -640,6 +690,8 @@ static bool parse_chat_response_format(const char **p, free(type); free(name); free(schema); + free(regex); + free(grammar); return true; bad: snprintf(err, errlen, "invalid response_format"); @@ -647,6 +699,8 @@ static bool parse_chat_response_format(const char **p, free(type); free(name); free(schema); + free(regex); + free(grammar); return false; } @@ -664,6 +718,8 @@ static bool parse_responses_text_format_object(const char **p, char *type = NULL; char *name = NULL; char *schema = NULL; + char *regex = NULL; + char *grammar = NULL; bool strict = false; json_ws(p); while (**p && **p != '}') { @@ -693,6 +749,18 @@ static bool parse_responses_text_format_object(const char **p, free(key); goto bad; } + } else if (!strcmp(key, "regex")) { + free(regex); + if (!json_string(p, ®ex)) { + free(key); + goto bad; + } + } else if (!strcmp(key, "grammar")) { + free(grammar); + if (!json_string(p, &grammar)) { + free(key); + goto bad; + } } else if (!strcmp(key, "strict")) { if (!json_bool(p, &strict)) { free(key); @@ -731,6 +799,27 @@ static bool parse_responses_text_format_object(const char **p, name, schema, strict); name = NULL; schema = NULL; + } else if (!strcmp(type, "regex")) { + if (!regex) { + snprintf(err, errlen, "text.format.regex is required"); + goto bad_keep_err; + } + ds4_text_format_set_constraint(format, DS4_TEXT_FORMAT_REGEX, regex); + regex = NULL; + } else if (!strcmp(type, "lark")) { + if (!grammar) { + snprintf(err, errlen, "text.format.grammar is required"); + goto bad_keep_err; + } + ds4_text_format_set_constraint(format, DS4_TEXT_FORMAT_LARK, grammar); + grammar = NULL; + } else if (!strcmp(type, "llguidance")) { + if (!grammar) { + snprintf(err, errlen, "text.format.grammar is required"); + goto bad_keep_err; + } + ds4_text_format_set_constraint(format, DS4_TEXT_FORMAT_LLGUIDANCE, grammar); + grammar = NULL; } else { snprintf(err, errlen, "text.format.type=%s not supported", type); goto bad_keep_err; @@ -739,6 +828,8 @@ static bool parse_responses_text_format_object(const char **p, free(type); free(name); free(schema); + free(regex); + free(grammar); return true; bad: snprintf(err, errlen, "invalid text.format"); @@ -746,6 +837,8 @@ static bool parse_responses_text_format_object(const char **p, free(type); free(name); free(schema); + free(regex); + free(grammar); return false; } @@ -3164,7 +3257,7 @@ static bool parse_chat_request(ds4_engine *e, server *s, const char *body, int d return false; } r->has_tools = tool_schemas && tool_schemas[0] && !tool_choice_none; - if (ds4_text_format_is_json(&r->text_format)) { + if (ds4_text_format_is_structured(&r->text_format)) { if (r->has_tools) { snprintf(err, errlen, "structured outputs with tools are not supported"); @@ -4332,7 +4425,7 @@ static bool parse_responses_request(ds4_engine *e, server *s, const char *body, (!tool_choice_none && combined_tool_schemas.len) ? combined_tool_schemas.ptr : NULL; r->has_tools = active_tool_schemas && active_tool_schemas[0]; - if (ds4_text_format_is_json(&r->text_format)) { + if (ds4_text_format_is_structured(&r->text_format)) { if (r->has_tools) { snprintf(err, errlen, "structured outputs with tools are not supported"); @@ -6421,7 +6514,8 @@ static bool request_uses_structured_stream(const request *r) { } static bool request_uses_structured_decoder(const request *r) { - return r && r->kind == REQ_CHAT && ds4_text_format_is_json(&r->text_format); + return r && r->kind == REQ_CHAT && + ds4_text_format_is_structured(&r->text_format); } /* Codex' Responses API uses 24-hex suffixes for response/item ids. Prefix @@ -12305,6 +12399,50 @@ static void test_parse_chat_response_format_json_object(void) { ds4_text_format_clear(&fmt); } +static void test_parse_chat_response_format_llguidance_extensions(void) { + const struct { + const char *json; + ds4_text_format_type type; + const char *constraint_type; + const char *needle; + } cases[] = { + { + "{\"type\":\"regex\",\"regex\":\"INV-[0-9]{4}\"}", + DS4_TEXT_FORMAT_REGEX, + "regex", + "INV-" + }, + { + "{\"type\":\"lark\",\"grammar\":\"%llguidance {}\\nstart: /OK/\"}", + DS4_TEXT_FORMAT_LARK, + "lark", + "start:" + }, + { + "{\"type\":\"llguidance\",\"grammar\":\"{\\\"grammars\\\":[]}\"}", + DS4_TEXT_FORMAT_LLGUIDANCE, + "llguidance", + "grammars" + }, + }; + + for (size_t i = 0; i < sizeof(cases) / sizeof(cases[0]); i++) { + const char *p = cases[i].json; + ds4_text_format fmt = {0}; + char err[160] = {0}; + + TEST_ASSERT(parse_chat_response_format(&p, &fmt, err, sizeof(err))); + TEST_ASSERT(fmt.type == cases[i].type); + TEST_ASSERT(!strcmp(ds4_text_format_constraint_type(&fmt), + cases[i].constraint_type)); + TEST_ASSERT(fmt.schema_json && strstr(fmt.schema_json, cases[i].needle)); + json_ws(&p); + TEST_ASSERT(*p == '\0'); + + ds4_text_format_clear(&fmt); + } +} + static void test_parse_chat_response_format_rejects_missing_schema(void) { const char *json = "{\"type\":\"json_schema\",\"json_schema\":{\"name\":\"bad\"}}"; const char *p = json; @@ -12354,6 +12492,50 @@ static void test_parse_responses_text_format_json_object(void) { ds4_text_format_clear(&fmt); } +static void test_parse_responses_text_format_llguidance_extensions(void) { + const struct { + const char *json; + ds4_text_format_type type; + const char *constraint_type; + const char *needle; + } cases[] = { + { + "{\"format\":{\"type\":\"regex\",\"regex\":\"INV-[0-9]{4}\"}}", + DS4_TEXT_FORMAT_REGEX, + "regex", + "INV-" + }, + { + "{\"format\":{\"type\":\"lark\",\"grammar\":\"%llguidance {}\\nstart: /OK/\"}}", + DS4_TEXT_FORMAT_LARK, + "lark", + "start:" + }, + { + "{\"format\":{\"type\":\"llguidance\",\"grammar\":\"{\\\"grammars\\\":[]}\"}}", + DS4_TEXT_FORMAT_LLGUIDANCE, + "llguidance", + "grammars" + }, + }; + + for (size_t i = 0; i < sizeof(cases) / sizeof(cases[0]); i++) { + const char *p = cases[i].json; + ds4_text_format fmt = {0}; + char err[160] = {0}; + + TEST_ASSERT(parse_responses_text_value(&p, &fmt, err, sizeof(err))); + TEST_ASSERT(fmt.type == cases[i].type); + TEST_ASSERT(!strcmp(ds4_text_format_constraint_type(&fmt), + cases[i].constraint_type)); + TEST_ASSERT(fmt.schema_json && strstr(fmt.schema_json, cases[i].needle)); + json_ws(&p); + TEST_ASSERT(*p == '\0'); + + ds4_text_format_clear(&fmt); + } +} + static void test_parse_responses_text_format_rejects_unknown_type(void) { const char *json = "{\"format\":{\"type\":\"xml\"}}"; const char *p = json; @@ -16168,9 +16350,11 @@ static void ds4_server_unit_tests_run(void) { test_render_chat_prompt_text_renders_tools_before_system(); test_parse_chat_response_format_json_schema(); test_parse_chat_response_format_json_object(); + test_parse_chat_response_format_llguidance_extensions(); test_parse_chat_response_format_rejects_missing_schema(); test_parse_responses_text_format_json_schema(); test_parse_responses_text_format_json_object(); + test_parse_responses_text_format_llguidance_extensions(); test_parse_responses_text_format_rejects_unknown_type(); test_parse_responses_text_format_text_is_noop(); test_tool_schema_order_from_anthropic_schema(); diff --git a/stress-test-cli.py b/stress-test-cli.py new file mode 100755 index 000000000..af853496d --- /dev/null +++ b/stress-test-cli.py @@ -0,0 +1,1086 @@ +#!/usr/bin/env python3 +"""Stress structured-output decoding across ds4 and llama.cpp servers. + +The OpenAI-compatible surfaces standardize JSON Schema structured outputs and +JSON mode. llguidance itself supports a wider set of grammar tags: JSON Schema, +JSON object, regex, Lark, and the internal guidance grammar-list wire format. + +This script keeps those layers explicit: + +* ds4 is exercised through /v1/chat/completions and /v1/responses with the + json_schema/json_object request shapes and ds4's llguidance extension types. +* llama.cpp is exercised with the same OpenAI-compatible JSON cases and, for + broader llguidance grammar-family cases, with llama.cpp's top-level grammar + request extension. +* Unsupported target/API/case combinations are reported as SKIP by default. Use + --strict-skips to make them fail the run, or --force-extensions to send + experimental non-OpenAI response_format types to targets that do not expose + them by default. + +Examples: + python3 stress-test-cli.py + + python3 stress-test-cli.py --start never \ + --ds4-base-url http://127.0.0.1:8000/v1 \ + --llama-base-url http://127.0.0.1:8080/v1 + + python3 stress-test-cli.py --targets llama --families regex,lark,llguidance \ + --llama-hf-model unsloth/Qwen3.5-9B-GGUF:Q8_0 +""" + +from __future__ import annotations + +import argparse +import datetime as _dt +import json +import math +import os +import re +import shlex +import subprocess +import sys +import tempfile +import time +import urllib.error +import urllib.parse +import urllib.request +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable + + +DEFAULT_DS4_BASE_URL = "http://127.0.0.1:8000/v1" +DEFAULT_LLAMA_BASE_URL = "http://127.0.0.1:8080/v1" +DEFAULT_LLAMA_HF_MODEL = "unsloth/Qwen3.5-9B-GGUF:Q8_0" + + +class ValidationError(Exception): + pass + + +class UnsupportedCase(Exception): + pass + + +Validator = Callable[[str], str] + + +@dataclass(frozen=True) +class Case: + name: str + family: str + prompt: str + validator: Validator + schema: dict[str, Any] | None = None + data: str = "" + llama_grammar: str | None = None + oracle_sample: str | None = None + max_tokens: int = 192 + + +@dataclass +class Target: + name: str + base_url: str + model: str + command: list[str] | None + cwd: Path + supports_response_format_extensions: bool + supports_grammar_extension: bool + process: subprocess.Popen[str] | None = None + log_path: Path | None = None + started_by_us: bool = False + + +@dataclass +class Counts: + passed: int = 0 + failed: int = 0 + skipped: int = 0 + + +def compact_json(value: Any) -> str: + return json.dumps(value, ensure_ascii=False, separators=(",", ":")) + + +def type_matches(value: Any, typ: str) -> bool: + if typ == "object": + return isinstance(value, dict) + if typ == "array": + return isinstance(value, list) + if typ == "string": + return isinstance(value, str) + if typ == "integer": + return isinstance(value, int) and not isinstance(value, bool) + if typ == "number": + return (isinstance(value, int) or isinstance(value, float)) and not isinstance(value, bool) + if typ == "boolean": + return isinstance(value, bool) + if typ == "null": + return value is None + return True + + +def _validate_format(value: str, fmt: str, path: str) -> None: + if fmt == "date": + try: + _dt.date.fromisoformat(value) + except ValueError as exc: + raise ValidationError(f"{path}: expected RFC3339 date, got {value!r}") from exc + elif fmt == "time": + try: + _dt.time.fromisoformat(value.replace("Z", "+00:00")) + except ValueError as exc: + raise ValidationError(f"{path}: expected RFC3339 time, got {value!r}") from exc + elif fmt == "date-time": + try: + _dt.datetime.fromisoformat(value.replace("Z", "+00:00")) + except ValueError as exc: + raise ValidationError(f"{path}: expected RFC3339 date-time, got {value!r}") from exc + elif fmt == "email": + if re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", value) is None: + raise ValidationError(f"{path}: expected email, got {value!r}") + elif fmt == "uuid": + if re.fullmatch( + r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-" + r"[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}", + value, + ) is None: + raise ValidationError(f"{path}: expected uuid, got {value!r}") + + +def validate_schema(value: Any, schema: dict[str, Any], path: str = "$") -> None: + if "allOf" in schema: + for option in schema["allOf"]: + validate_schema(value, option, path) + + if "anyOf" in schema: + errors: list[str] = [] + for option in schema["anyOf"]: + try: + validate_schema(value, option, path) + return + except ValidationError as exc: + errors.append(str(exc)) + raise ValidationError(f"{path}: did not match anyOf: {'; '.join(errors)}") + + if "oneOf" in schema: + matches = 0 + errors: list[str] = [] + for option in schema["oneOf"]: + try: + validate_schema(value, option, path) + matches += 1 + except ValidationError as exc: + errors.append(str(exc)) + if matches != 1: + raise ValidationError(f"{path}: expected exactly one oneOf match, got {matches}: {'; '.join(errors)}") + return + + if "const" in schema and value != schema["const"]: + raise ValidationError(f"{path}: expected const {schema['const']!r}, got {value!r}") + if "enum" in schema and value not in schema["enum"]: + raise ValidationError(f"{path}: expected one of {schema['enum']!r}, got {value!r}") + + typ = schema.get("type") + if isinstance(typ, list): + if not any(type_matches(value, t) for t in typ): + raise ValidationError(f"{path}: wrong type {type(value).__name__}, expected {typ!r}") + elif isinstance(typ, str) and not type_matches(value, typ): + raise ValidationError(f"{path}: wrong type {type(value).__name__}, expected {typ!r}") + + if typ == "object" or "properties" in schema: + if not isinstance(value, dict): + raise ValidationError(f"{path}: expected object") + props = schema.get("properties", {}) + for key in schema.get("required", []): + if key not in value: + raise ValidationError(f"{path}: missing required property {key!r}") + min_props = schema.get("minProperties") + max_props = schema.get("maxProperties") + if min_props is not None and len(value) < min_props: + raise ValidationError(f"{path}: expected at least {min_props} properties") + if max_props is not None and len(value) > max_props: + raise ValidationError(f"{path}: expected at most {max_props} properties") + if schema.get("additionalProperties") is False: + extra = sorted(set(value) - set(props)) + if extra: + raise ValidationError(f"{path}: extra properties {extra!r}") + for key, sub in props.items(): + if key in value: + validate_schema(value[key], sub, f"{path}.{key}") + + if typ == "array" or "items" in schema or "prefixItems" in schema: + if not isinstance(value, list): + raise ValidationError(f"{path}: expected array") + min_items = schema.get("minItems") + max_items = schema.get("maxItems") + if min_items is not None and len(value) < min_items: + raise ValidationError(f"{path}: expected at least {min_items} items") + if max_items is not None and len(value) > max_items: + raise ValidationError(f"{path}: expected at most {max_items} items") + prefix_items = schema.get("prefixItems") + if isinstance(prefix_items, list): + for i, sub in enumerate(prefix_items): + if i < len(value): + validate_schema(value[i], sub, f"{path}[{i}]") + items = schema.get("items") + if isinstance(items, dict): + start = len(prefix_items) if isinstance(prefix_items, list) else 0 + for i, item in enumerate(value[start:], start): + validate_schema(item, items, f"{path}[{i}]") + + if isinstance(value, str): + if "minLength" in schema and len(value) < schema["minLength"]: + raise ValidationError(f"{path}: string shorter than minLength {schema['minLength']}") + if "maxLength" in schema and len(value) > schema["maxLength"]: + raise ValidationError(f"{path}: string longer than maxLength {schema['maxLength']}") + if "pattern" in schema and re.fullmatch(schema["pattern"], value) is None: + raise ValidationError(f"{path}: {value!r} does not match {schema['pattern']!r}") + if "format" in schema: + _validate_format(value, schema["format"], path) + + if isinstance(value, (int, float)) and not isinstance(value, bool): + if "minimum" in schema and value < schema["minimum"]: + raise ValidationError(f"{path}: {value!r} is below minimum {schema['minimum']!r}") + if "maximum" in schema and value > schema["maximum"]: + raise ValidationError(f"{path}: {value!r} is above maximum {schema['maximum']!r}") + if "exclusiveMinimum" in schema and value <= schema["exclusiveMinimum"]: + raise ValidationError(f"{path}: {value!r} is not above exclusiveMinimum {schema['exclusiveMinimum']!r}") + if "exclusiveMaximum" in schema and value >= schema["exclusiveMaximum"]: + raise ValidationError(f"{path}: {value!r} is not below exclusiveMaximum {schema['exclusiveMaximum']!r}") + if "multipleOf" in schema: + q = value / schema["multipleOf"] + if not math.isclose(q, round(q), rel_tol=0.0, abs_tol=1e-9): + raise ValidationError(f"{path}: {value!r} is not a multiple of {schema['multipleOf']!r}") + + +def parse_json_strict(text: str) -> Any: + stripped = text.strip() + try: + return json.loads(stripped) + except json.JSONDecodeError as exc: + raise ValidationError(f"output is not JSON: {text!r}") from exc + + +def json_schema_validator(schema: dict[str, Any]) -> Validator: + def _validate(text: str) -> str: + value = parse_json_strict(text) + validate_schema(value, schema) + return compact_json(value) + + return _validate + + +def json_object_validator(text: str) -> str: + value = parse_json_strict(text) + if not isinstance(value, dict): + raise ValidationError(f"output is not a JSON object: {value!r}") + return compact_json(value) + + +def regex_validator(pattern: str) -> Validator: + rx = re.compile(pattern) + + def _validate(text: str) -> str: + value = text.strip() + if rx.fullmatch(value) is None: + raise ValidationError(f"{value!r} does not match /{pattern}/") + return value + + return _validate + + +def choice_validator(choices: set[str]) -> Validator: + def _validate(text: str) -> str: + value = text.strip() + if value not in choices: + raise ValidationError(f"{value!r} is not one of {sorted(choices)!r}") + return value + + return _validate + + +def permutation_validator(chars: str) -> Validator: + expected = sorted(chars) + + def _validate(text: str) -> str: + value = text.strip() + if sorted(value) != expected or len(value) != len(chars): + raise ValidationError(f"{value!r} is not a permutation of {chars!r}") + return value + + return _validate + + +def substring_chunk_validator(prefix: str, words: list[str]) -> Validator: + allowed = {""} + for i in range(len(words)): + for j in range(i + 1, len(words) + 1): + allowed.add(" ".join(words[i:j])) + + def _validate(text: str) -> str: + value = text.strip() + if not value.startswith(prefix): + raise ValidationError(f"{value!r} does not start with {prefix!r}") + tail = value[len(prefix):] + if tail not in allowed: + raise ValidationError(f"{tail!r} is not an allowed contiguous word substring") + return value + + return _validate + + +def make_cases() -> list[Case]: + calendar_schema = { + "type": "object", + "properties": { + "name": {"type": "string", "minLength": 1, "maxLength": 80}, + "date": {"type": "string", "format": "date"}, + "participants": { + "type": "array", + "items": {"type": "string", "minLength": 1}, + "minItems": 1, + "maxItems": 5, + }, + }, + "required": ["name", "date", "participants"], + "additionalProperties": False, + } + status_schema = { + "type": "object", + "properties": { + "status": {"const": "ok"}, + "priority": {"type": "string", "enum": ["low", "medium", "high"]}, + "retry_count": {"type": "integer", "minimum": 0, "maximum": 5}, + "active": {"type": "boolean"}, + }, + "required": ["status", "priority", "retry_count", "active"], + "additionalProperties": False, + } + ticket_schema = { + "type": "object", + "properties": { + "id": {"type": "string", "pattern": "TCK-[0-9]{3}"}, + "owner": {"anyOf": [{"type": "string", "minLength": 1}, {"type": "null"}]}, + "priority": {"oneOf": [{"const": "low"}, {"const": "medium"}, {"const": "high"}]}, + "contact": {"type": "string", "format": "email"}, + }, + "required": ["id", "owner", "priority", "contact"], + "additionalProperties": False, + } + reading_schema = { + "type": "object", + "properties": { + "reading": { + "type": "array", + "prefixItems": [ + {"const": "temperature_c"}, + {"type": "number", "minimum": -40, "exclusiveMaximum": 80, "multipleOf": 0.5}, + ], + "minItems": 2, + "maxItems": 2, + }, + "tags": { + "type": "array", + "items": {"type": "string", "pattern": "[a-z]{3,8}"}, + "minItems": 2, + "maxItems": 3, + }, + }, + "required": ["reading", "tags"], + "additionalProperties": False, + } + + inline_json_schema = { + "type": "object", + "properties": { + "kind": {"const": "metric"}, + "value": {"type": "integer", "minimum": 1, "maximum": 9}, + }, + "required": ["kind", "value"], + "additionalProperties": False, + } + inline_json_lark = f"""%llguidance {{}} +start: %json {compact_json(inline_json_schema)} +""" + choice_lark = """%llguidance {} +start: "red" | "green" | "blue" +""" + regex_lark = """%llguidance {} +start: /INV-[0-9]{4}/ +""" + regex_ext_lark = """%llguidance {} +start: "chunk:" %regex { "substring_words": "alpha beta gamma delta" } +""" + parametric_lark = """%llguidance {} +start: perm::0x0 +perm::_: "" %if is_ones([0:3]) + | "a" perm::set_bit(0) %if bit_clear(0) + | "b" perm::set_bit(1) %if bit_clear(1) + | "c" perm::set_bit(2) %if bit_clear(2) +""" + guidance_lark = """%llguidance {} +start: "YES" | "NO" +""" + guidance_wire = compact_json({"grammars": [{"lark_grammar": guidance_lark}]}) + + return [ + Case( + name="json_schema_calendar", + family="json_schema", + prompt="Return one lunch calendar event for Alice and Bob on 2026-06-01. Return only JSON.", + schema=calendar_schema, + validator=json_schema_validator(calendar_schema), + data=compact_json(calendar_schema), + oracle_sample='{"name":"Lunch","date":"2026-06-01","participants":["Alice","Bob"]}', + ), + Case( + name="json_schema_status", + family="json_schema", + prompt="Return a compact health-check object with status ok. Return only JSON.", + schema=status_schema, + validator=json_schema_validator(status_schema), + data=compact_json(status_schema), + oracle_sample='{"status":"ok","priority":"medium","retry_count":2,"active":true}', + ), + Case( + name="json_schema_anyof_oneof_format", + family="json_schema", + prompt=( + "Return one support ticket with an id like TCK-123, an owner or null, " + "one priority, and a contact email. Return only JSON." + ), + schema=ticket_schema, + validator=json_schema_validator(ticket_schema), + data=compact_json(ticket_schema), + oracle_sample='{"id":"TCK-123","owner":null,"priority":"high","contact":"ops@example.com"}', + ), + Case( + name="json_schema_tuple_numeric", + family="json_schema", + prompt="Return one sensor reading tuple and two short lowercase tags. Return only JSON.", + schema=reading_schema, + validator=json_schema_validator(reading_schema), + data=compact_json(reading_schema), + oracle_sample='{"reading":["temperature_c",21.5],"tags":["lab","green"]}', + ), + Case( + name="json_object_mode", + family="json_object", + prompt="Return a JSON object with a tiny task description and whether it is done.", + validator=json_object_validator, + data="", + oracle_sample='{"task":"check","done":false}', + ), + Case( + name="regex_invoice_id", + family="regex", + prompt="Return exactly one invoice id in the form INV-0427. No quotes, no prose.", + validator=regex_validator(r"INV-[0-9]{4}"), + data=r"INV-[0-9]{4}", + llama_grammar=regex_lark, + oracle_sample="INV-0427", + max_tokens=32, + ), + Case( + name="lark_choice", + family="lark", + prompt="Return exactly one lowercase color: red, green, or blue. No quotes, no prose.", + validator=choice_validator({"red", "green", "blue"}), + data=choice_lark, + llama_grammar=choice_lark, + oracle_sample="green", + max_tokens=16, + ), + Case( + name="lark_inline_json", + family="lark", + prompt='Return a compact JSON object with kind "metric" and a small integer value.', + validator=json_schema_validator(inline_json_schema), + data=inline_json_lark, + llama_grammar=inline_json_lark, + oracle_sample='{"kind":"metric","value":7}', + max_tokens=64, + ), + Case( + name="lark_regex_ext_substring", + family="lark", + prompt="Return exactly chunk:beta gamma. No quotes, no prose.", + validator=substring_chunk_validator("chunk:", ["alpha", "beta", "gamma", "delta"]), + data=regex_ext_lark, + llama_grammar=regex_ext_lark, + oracle_sample="chunk:beta gamma", + max_tokens=32, + ), + Case( + name="lark_parametric_permutation", + family="lark", + prompt="Return exactly one permutation of the letters a, b, and c. No separators, no prose.", + validator=permutation_validator("abc"), + data=parametric_lark, + llama_grammar=parametric_lark, + oracle_sample="cab", + max_tokens=16, + ), + Case( + name="llguidance_internal_wire", + family="llguidance", + prompt="Return exactly YES or NO in uppercase. No punctuation, no prose.", + validator=choice_validator({"YES", "NO"}), + data=guidance_wire, + llama_grammar=guidance_lark, + oracle_sample="YES", + max_tokens=16, + ), + ] + + +def post_json(url: str, payload: dict[str, Any], timeout: float, api_key: str | None = None) -> dict[str, Any]: + data = json.dumps(payload, separators=(",", ":")).encode("utf-8") + headers = {"Content-Type": "application/json"} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + req = urllib.request.Request(url, data=data, headers=headers, method="POST") + try: + with urllib.request.urlopen(req, timeout=timeout) as resp: + raw = resp.read().decode("utf-8", errors="replace") + except urllib.error.HTTPError as exc: + raw = exc.read().decode("utf-8", errors="replace") + raise RuntimeError(f"HTTP {exc.code}: {raw[:1600]}") from exc + except urllib.error.URLError as exc: + raise RuntimeError(str(exc)) from exc + try: + body = json.loads(raw) + except json.JSONDecodeError as exc: + raise RuntimeError(f"invalid JSON response: {raw[:1600]}") from exc + if isinstance(body, dict) and body.get("error"): + raise RuntimeError(f"API error: {body['error']!r}") + return body + + +def get_status(url: str, timeout: float) -> tuple[int | None, str]: + req = urllib.request.Request(url, method="GET") + try: + with urllib.request.urlopen(req, timeout=timeout) as resp: + return resp.status, resp.read(512).decode("utf-8", errors="replace") + except urllib.error.HTTPError as exc: + return exc.code, exc.read(512).decode("utf-8", errors="replace") + except urllib.error.URLError as exc: + return None, str(exc) + + +def extract_chat_text(body: dict[str, Any]) -> str: + choices = body.get("choices") + if not isinstance(choices, list) or not choices: + raise RuntimeError(f"missing choices in chat response: {body!r}") + message = choices[0].get("message", {}) + content = message.get("content") + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for part in content: + if isinstance(part, dict): + if isinstance(part.get("text"), str): + parts.append(part["text"]) + elif isinstance(part.get("content"), str): + parts.append(part["content"]) + if parts: + return "".join(parts) + raise RuntimeError(f"missing text content in chat response: {body!r}") + + +def extract_responses_text(body: dict[str, Any]) -> str: + if isinstance(body.get("output_text"), str): + return body["output_text"] + parts: list[str] = [] + for item in body.get("output", []): + if not isinstance(item, dict): + continue + if item.get("type") == "message": + for part in item.get("content", []): + if isinstance(part, dict) and isinstance(part.get("text"), str): + parts.append(part["text"]) + if parts: + return "".join(parts) + raise RuntimeError(f"missing output text in responses response: {body!r}") + + +def response_format_for_case(case: Case, target: Target, api: str, args: argparse.Namespace) -> dict[str, Any]: + if case.family == "json_object": + fmt: dict[str, Any] = {"type": "json_object"} + if args.json_object_schema: + fmt["schema"] = {"type": "object"} + return fmt + + if case.family == "json_schema": + if not case.schema: + raise RuntimeError(f"{case.name}: missing schema") + if api == "chat": + if target.name == "llama" and args.llama_chat_schema_style == "flat": + return { + "type": "json_schema", + "name": case.name, + "strict": True, + "schema": case.schema, + } + return { + "type": "json_schema", + "json_schema": { + "name": case.name, + "strict": True, + "schema": case.schema, + }, + } + return { + "type": "json_schema", + "name": case.name, + "strict": True, + "schema": case.schema, + } + + if not args.force_extensions and not target.supports_response_format_extensions: + raise UnsupportedCase(f"{target.name} does not expose {case.family!r} through OpenAI response_format") + + if case.family == "regex": + return {"type": "regex", "regex": case.data} + if case.family == "lark": + return {"type": "lark", "grammar": case.data} + if case.family == "llguidance": + return {"type": "llguidance", "grammar": case.data} + raise UnsupportedCase(f"unknown structured-output family {case.family!r}") + + +def add_llama_common_payload_fields(payload: dict[str, Any], target: Target, args: argparse.Namespace) -> None: + if target.name != "llama": + return + if args.llama_disable_thinking: + payload.setdefault("chat_template_kwargs", {})["enable_thinking"] = False + if args.seed is not None: + payload["seed"] = args.seed + + +def chat_payload(target: Target, case: Case, args: argparse.Namespace) -> dict[str, Any]: + payload: dict[str, Any] = { + "model": target.model, + "messages": [{"role": "user", "content": case.prompt}], + "max_tokens": case.max_tokens, + "temperature": 0, + } + if (case.family in {"json_schema", "json_object"} or + args.force_extensions or + target.supports_response_format_extensions): + payload["response_format"] = response_format_for_case(case, target, "chat", args) + elif target.supports_grammar_extension and case.llama_grammar: + payload["grammar"] = case.llama_grammar + else: + raise UnsupportedCase(f"{target.name}/chat cannot carry {case.family!r}") + if args.seed is not None: + payload["seed"] = args.seed + add_llama_common_payload_fields(payload, target, args) + return payload + + +def responses_payload(target: Target, case: Case, args: argparse.Namespace) -> dict[str, Any]: + payload: dict[str, Any] = { + "model": target.model, + "input": case.prompt, + "max_output_tokens": case.max_tokens, + "temperature": 0, + } + if (case.family in {"json_schema", "json_object"} or + args.force_extensions or + target.supports_response_format_extensions): + payload["text"] = {"format": response_format_for_case(case, target, "responses", args)} + elif target.supports_grammar_extension and case.llama_grammar: + payload["grammar"] = case.llama_grammar + else: + raise UnsupportedCase(f"{target.name}/responses cannot carry {case.family!r}") + if args.seed is not None: + payload["seed"] = args.seed + add_llama_common_payload_fields(payload, target, args) + return payload + + +def check_case(target: Target, api: str, case: Case, args: argparse.Namespace) -> str: + if api == "chat": + payload = chat_payload(target, case, args) + body = post_json(f"{target.base_url}/chat/completions", payload, args.timeout, args.api_key) + text = extract_chat_text(body) + elif api == "responses": + payload = responses_payload(target, case, args) + body = post_json(f"{target.base_url}/responses", payload, args.timeout, args.api_key) + text = extract_responses_text(body) + else: + raise RuntimeError(f"unknown api {api!r}") + return case.validator(text) + + +def split_csv(value: str) -> list[str]: + return [x.strip() for x in value.split(",") if x.strip()] + + +def flatten_extra_args(values: list[str] | None) -> list[str]: + out: list[str] = [] + for value in values or []: + out.extend(shlex.split(value)) + return out + + +def base_root(base_url: str) -> str: + if base_url.endswith("/v1"): + return base_url[:-3] + return base_url.rstrip("/") + + +def port_from_url(base_url: str, default: int) -> int: + parsed = urllib.parse.urlparse(base_url) + if parsed.port: + return parsed.port + if parsed.scheme == "https": + return 443 + if parsed.scheme == "http": + return 80 + return default + + +def target_is_ready(target: Target, timeout: float = 2.0) -> bool: + for url in (f"{base_root(target.base_url)}/health", f"{target.base_url}/models"): + status, _body = get_status(url, timeout) + if status == 200: + return True + return False + + +def wait_ready(target: Target, startup_timeout: float) -> None: + deadline = time.time() + startup_timeout + last = "" + while time.time() < deadline: + if target.process and target.process.poll() is not None: + log_hint = f" log={target.log_path}" if target.log_path else "" + raise RuntimeError(f"{target.name} server exited with code {target.process.returncode}.{log_hint}") + for url in (f"{base_root(target.base_url)}/health", f"{target.base_url}/models"): + status, body = get_status(url, 2.0) + last = f"{url}: {status} {body[:200]}" + if status == 200: + return + time.sleep(1.0) + log_hint = f" log={target.log_path}" if target.log_path else "" + raise RuntimeError(f"{target.name} did not become ready within {startup_timeout:.0f}s ({last}).{log_hint}") + + +def start_target_if_needed(target: Target, args: argparse.Namespace) -> None: + already_ready = target_is_ready(target) + if args.start == "never": + if not already_ready: + raise RuntimeError(f"{target.name} is not reachable at {target.base_url} and --start=never was set") + return + if already_ready and args.start != "always": + return + if not target.command: + raise RuntimeError(f"{target.name} has no launch command") + + log_file = tempfile.NamedTemporaryFile( + prefix=f"stress-{target.name}-", + suffix=".log", + mode="w", + encoding="utf-8", + delete=False, + ) + target.log_path = Path(log_file.name) + target.process = subprocess.Popen( + target.command, + cwd=str(target.cwd), + stdout=log_file, + stderr=subprocess.STDOUT, + text=True, + ) + target.started_by_us = True + wait_ready(target, args.startup_timeout) + + +def stop_target(target: Target) -> None: + if not target.process or target.process.poll() is not None: + return + target.process.terminate() + try: + target.process.wait(timeout=15) + except subprocess.TimeoutExpired: + target.process.kill() + target.process.wait(timeout=15) + + +def build_ds4_command(args: argparse.Namespace) -> list[str]: + if args.ds4_cmd: + return shlex.split(args.ds4_cmd.format(port=port_from_url(args.ds4_base_url, 8000))) + return [ + args.ds4_binary, + "--model", + args.ds4_model_path, + "--ctx", + str(args.server_ctx), + "--tokens", + str(args.server_default_tokens), + "--port", + str(port_from_url(args.ds4_base_url, 8000)), + *flatten_extra_args(args.ds4_extra_arg), + ] + + +def build_llama_command(args: argparse.Namespace) -> list[str]: + if args.llama_cmd: + return shlex.split(args.llama_cmd.format(port=port_from_url(args.llama_base_url, 8080))) + cmd = [ + args.llama_binary, + "-hf", + args.llama_hf_model, + "-c", + str(args.server_ctx), + "--port", + str(port_from_url(args.llama_base_url, 8080)), + "--jinja", + ] + if args.llama_ngl is not None: + cmd.extend(["-ngl", str(args.llama_ngl)]) + cmd.extend(flatten_extra_args(args.llama_extra_arg)) + return cmd + + +def selected_cases(args: argparse.Namespace) -> list[Case]: + cases = make_cases() + families = set(split_csv(args.families)) if args.families != "all" else set() + names = set(args.case or []) + out = [ + c for c in cases + if (not families or c.family in families) and (not names or c.name in names) + ] + missing = names - {c.name for c in cases} + if missing: + raise SystemExit(f"unknown case(s): {', '.join(sorted(missing))}") + known_families = {c.family for c in cases} + unknown_families = families - known_families + if unknown_families: + raise SystemExit(f"unknown family/families: {', '.join(sorted(unknown_families))}") + return out + + +def run_llguidance_oracle(cases: list[Case], args: argparse.Namespace) -> Counts: + counts = Counts() + if args.oracle == "never": + return counts + try: + import llguidance # type: ignore + except ModuleNotFoundError: + msg = "SKIP oracle/llguidance: python package is not importable" + if args.oracle == "require": + print(msg, file=sys.stderr) + counts.failed += 1 + elif args.verbose: + print(msg) + counts.skipped += 1 + return counts + + try: + tok = llguidance.LLTokenizer("byte") + except Exception as exc: + msg = f"SKIP oracle/llguidance: failed to create byte tokenizer: {exc}" + if args.oracle == "require": + print(msg, file=sys.stderr) + counts.failed += 1 + elif args.verbose: + print(msg) + counts.skipped += 1 + return counts + + for case in cases: + if case.family == "json_schema": + grammar = llguidance.LLMatcher.grammar_from_json_schema(case.schema) + elif case.family == "json_object": + grammar = llguidance.LLMatcher.grammar_from_json_schema({"type": "object"}) + elif case.family == "regex": + grammar = llguidance.LLMatcher.grammar_from_regex(case.data) + elif case.family == "lark": + grammar = llguidance.LLMatcher.grammar_from_lark(case.data) + elif case.family == "llguidance": + grammar = llguidance.grammar_from("llguidance", case.data) + else: + counts.skipped += 1 + continue + + label = f"oracle/{case.family}/{case.name}" + try: + err = llguidance.LLMatcher.validate_grammar(grammar, tok) + if err: + raise RuntimeError(err) + if case.oracle_sample is not None: + matcher = llguidance.LLMatcher(tok, grammar) + for token in tok.tokenize_str(case.oracle_sample): + bias = matcher.compute_logit_bias() + if token >= len(bias) or bias[token] != 200: + raise RuntimeError(f"sample token {token} is not allowed") + if not matcher.consume_token(token): + raise RuntimeError(f"sample token {token} was rejected") + if not matcher.is_accepting(): + raise RuntimeError("sample did not leave matcher in accepting state") + print(f"PASS {label}") + counts.passed += 1 + except Exception as exc: + print(f"FAIL {label}: {exc}", file=sys.stderr) + counts.failed += 1 + if args.fail_fast: + raise + return counts + + +def run_target(target: Target, cases: list[Case], args: argparse.Namespace) -> Counts: + counts = Counts() + start_target_if_needed(target, args) + apis = split_csv(args.apis) + for repeat in range(args.repeat): + for api in apis: + for case in cases: + label = f"{target.name}/{api}/{case.family}/{case.name}" + if args.repeat > 1: + label = f"{label}#{repeat + 1}" + t0 = time.time() + try: + value = check_case(target, api, case, args) + elapsed = time.time() - t0 + if args.verbose: + print(f"PASS {label} {elapsed:.2f}s {value}") + else: + print(f"PASS {label} {elapsed:.2f}s") + counts.passed += 1 + except UnsupportedCase as exc: + elapsed = time.time() - t0 + msg = f"SKIP {label} {elapsed:.2f}s: {exc}" + if args.strict_skips: + print(msg, file=sys.stderr) + counts.failed += 1 + if args.fail_fast: + raise + else: + print(msg) + counts.skipped += 1 + except Exception as exc: + counts.failed += 1 + print(f"FAIL {label}: {exc}", file=sys.stderr) + if args.fail_fast: + raise + return counts + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument("--targets", default="ds4,llama", help="Comma-separated targets: ds4,llama") + p.add_argument("--apis", default="chat,responses", help="Comma-separated APIs: chat,responses") + p.add_argument("--families", default="all", help="Comma-separated families or all") + p.add_argument("--case", action="append", help="Run only this case name; may repeat") + p.add_argument("--list-cases", action="store_true", help="Print selected cases and exit") + p.add_argument("--repeat", type=int, default=1) + p.add_argument("--timeout", type=float, default=180.0) + p.add_argument("--startup-timeout", type=float, default=900.0) + p.add_argument("--start", choices=["missing", "never", "always"], default="missing") + p.add_argument("--stop-started", action=argparse.BooleanOptionalAction, default=True) + p.add_argument("--strict-skips", action="store_true", help="Treat unsupported matrix entries as failures") + p.add_argument("--force-extensions", action="store_true", help="Send regex/lark/llguidance as experimental response_format types") + p.add_argument("--fail-fast", action="store_true") + p.add_argument("--verbose", action="store_true") + p.add_argument("--api-key", help="Optional bearer token for non-local OpenAI-compatible servers") + p.add_argument("--seed", type=int, default=1) + p.add_argument("--json-object-schema", action="store_true", help="Attach {'type':'object'} to json_object mode") + p.add_argument("--oracle", choices=["auto", "never", "require"], default="auto", help="Local Python llguidance grammar validation") + + p.add_argument("--server-ctx", type=int, default=8192) + p.add_argument("--server-default-tokens", type=int, default=384) + + p.add_argument("--ds4-base-url", default=DEFAULT_DS4_BASE_URL) + p.add_argument("--ds4-model", default="ds4") + p.add_argument("--ds4-binary", default="./ds4-server") + p.add_argument("--ds4-model-path", default="ds4flash.gguf") + p.add_argument("--ds4-cmd", help="Override ds4 launch command; {port} is expanded") + p.add_argument("--ds4-extra-arg", action="append", help="Extra ds4-server args; may repeat") + + p.add_argument("--llama-base-url", default=DEFAULT_LLAMA_BASE_URL) + p.add_argument("--llama-model", default=DEFAULT_LLAMA_HF_MODEL) + p.add_argument("--llama-binary", default="llama-server") + p.add_argument("--llama-hf-model", default=DEFAULT_LLAMA_HF_MODEL) + p.add_argument("--llama-cmd", help="Override llama launch command; {port} is expanded") + p.add_argument("--llama-extra-arg", action="append", help="Extra llama-server args; may repeat") + p.add_argument("--llama-ngl", type=int, default=999, help="llama.cpp GPU layers; set -1 to omit") + p.add_argument( + "--llama-chat-schema-style", + choices=["openai", "flat"], + default="flat", + help="json_schema shape for llama.cpp chat response_format", + ) + p.add_argument("--llama-disable-thinking", action=argparse.BooleanOptionalAction, default=True) + return p.parse_args() + + +def main() -> int: + args = parse_args() + if args.llama_ngl is not None and args.llama_ngl < 0: + args.llama_ngl = None + + repo = Path(__file__).resolve().parent + targets_requested = split_csv(args.targets) + unknown_targets = set(targets_requested) - {"ds4", "llama"} + if unknown_targets: + raise SystemExit(f"unknown target(s): {', '.join(sorted(unknown_targets))}") + + cases = selected_cases(args) + if args.list_cases: + for case in cases: + print(f"{case.family}\t{case.name}") + return 0 + + total = Counts() + + oracle_counts = run_llguidance_oracle(cases, args) + total.passed += oracle_counts.passed + total.failed += oracle_counts.failed + total.skipped += oracle_counts.skipped + + target_map: dict[str, Target] = { + "ds4": Target( + name="ds4", + base_url=args.ds4_base_url.rstrip("/"), + model=args.ds4_model, + command=build_ds4_command(args), + cwd=repo, + supports_response_format_extensions=True, + supports_grammar_extension=False, + ), + "llama": Target( + name="llama", + base_url=args.llama_base_url.rstrip("/"), + model=args.llama_model, + command=build_llama_command(args), + cwd=repo, + supports_response_format_extensions=False, + supports_grammar_extension=True, + ), + } + + try: + for name in targets_requested: + target = target_map[name] + counts = run_target(target, cases, args) + total.passed += counts.passed + total.failed += counts.failed + total.skipped += counts.skipped + if args.stop_started and target.started_by_us: + stop_target(target) + finally: + for target in target_map.values(): + if args.stop_started and target.started_by_us: + stop_target(target) + + print(f"SUMMARY pass={total.passed} fail={total.failed} skip={total.skipped}") + return 1 if total.failed else 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From ce6524980cba8d4f8245db007f83732628309be8 Mon Sep 17 00:00:00 2001 From: Pasquale Minervini Date: Sat, 30 May 2026 15:39:36 +0100 Subject: [PATCH 6/9] Remove stress test script from repo --- stress-test-cli.py | 1086 -------------------------------------------- 1 file changed, 1086 deletions(-) delete mode 100755 stress-test-cli.py diff --git a/stress-test-cli.py b/stress-test-cli.py deleted file mode 100755 index af853496d..000000000 --- a/stress-test-cli.py +++ /dev/null @@ -1,1086 +0,0 @@ -#!/usr/bin/env python3 -"""Stress structured-output decoding across ds4 and llama.cpp servers. - -The OpenAI-compatible surfaces standardize JSON Schema structured outputs and -JSON mode. llguidance itself supports a wider set of grammar tags: JSON Schema, -JSON object, regex, Lark, and the internal guidance grammar-list wire format. - -This script keeps those layers explicit: - -* ds4 is exercised through /v1/chat/completions and /v1/responses with the - json_schema/json_object request shapes and ds4's llguidance extension types. -* llama.cpp is exercised with the same OpenAI-compatible JSON cases and, for - broader llguidance grammar-family cases, with llama.cpp's top-level grammar - request extension. -* Unsupported target/API/case combinations are reported as SKIP by default. Use - --strict-skips to make them fail the run, or --force-extensions to send - experimental non-OpenAI response_format types to targets that do not expose - them by default. - -Examples: - python3 stress-test-cli.py - - python3 stress-test-cli.py --start never \ - --ds4-base-url http://127.0.0.1:8000/v1 \ - --llama-base-url http://127.0.0.1:8080/v1 - - python3 stress-test-cli.py --targets llama --families regex,lark,llguidance \ - --llama-hf-model unsloth/Qwen3.5-9B-GGUF:Q8_0 -""" - -from __future__ import annotations - -import argparse -import datetime as _dt -import json -import math -import os -import re -import shlex -import subprocess -import sys -import tempfile -import time -import urllib.error -import urllib.parse -import urllib.request -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Callable - - -DEFAULT_DS4_BASE_URL = "http://127.0.0.1:8000/v1" -DEFAULT_LLAMA_BASE_URL = "http://127.0.0.1:8080/v1" -DEFAULT_LLAMA_HF_MODEL = "unsloth/Qwen3.5-9B-GGUF:Q8_0" - - -class ValidationError(Exception): - pass - - -class UnsupportedCase(Exception): - pass - - -Validator = Callable[[str], str] - - -@dataclass(frozen=True) -class Case: - name: str - family: str - prompt: str - validator: Validator - schema: dict[str, Any] | None = None - data: str = "" - llama_grammar: str | None = None - oracle_sample: str | None = None - max_tokens: int = 192 - - -@dataclass -class Target: - name: str - base_url: str - model: str - command: list[str] | None - cwd: Path - supports_response_format_extensions: bool - supports_grammar_extension: bool - process: subprocess.Popen[str] | None = None - log_path: Path | None = None - started_by_us: bool = False - - -@dataclass -class Counts: - passed: int = 0 - failed: int = 0 - skipped: int = 0 - - -def compact_json(value: Any) -> str: - return json.dumps(value, ensure_ascii=False, separators=(",", ":")) - - -def type_matches(value: Any, typ: str) -> bool: - if typ == "object": - return isinstance(value, dict) - if typ == "array": - return isinstance(value, list) - if typ == "string": - return isinstance(value, str) - if typ == "integer": - return isinstance(value, int) and not isinstance(value, bool) - if typ == "number": - return (isinstance(value, int) or isinstance(value, float)) and not isinstance(value, bool) - if typ == "boolean": - return isinstance(value, bool) - if typ == "null": - return value is None - return True - - -def _validate_format(value: str, fmt: str, path: str) -> None: - if fmt == "date": - try: - _dt.date.fromisoformat(value) - except ValueError as exc: - raise ValidationError(f"{path}: expected RFC3339 date, got {value!r}") from exc - elif fmt == "time": - try: - _dt.time.fromisoformat(value.replace("Z", "+00:00")) - except ValueError as exc: - raise ValidationError(f"{path}: expected RFC3339 time, got {value!r}") from exc - elif fmt == "date-time": - try: - _dt.datetime.fromisoformat(value.replace("Z", "+00:00")) - except ValueError as exc: - raise ValidationError(f"{path}: expected RFC3339 date-time, got {value!r}") from exc - elif fmt == "email": - if re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", value) is None: - raise ValidationError(f"{path}: expected email, got {value!r}") - elif fmt == "uuid": - if re.fullmatch( - r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-" - r"[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}", - value, - ) is None: - raise ValidationError(f"{path}: expected uuid, got {value!r}") - - -def validate_schema(value: Any, schema: dict[str, Any], path: str = "$") -> None: - if "allOf" in schema: - for option in schema["allOf"]: - validate_schema(value, option, path) - - if "anyOf" in schema: - errors: list[str] = [] - for option in schema["anyOf"]: - try: - validate_schema(value, option, path) - return - except ValidationError as exc: - errors.append(str(exc)) - raise ValidationError(f"{path}: did not match anyOf: {'; '.join(errors)}") - - if "oneOf" in schema: - matches = 0 - errors: list[str] = [] - for option in schema["oneOf"]: - try: - validate_schema(value, option, path) - matches += 1 - except ValidationError as exc: - errors.append(str(exc)) - if matches != 1: - raise ValidationError(f"{path}: expected exactly one oneOf match, got {matches}: {'; '.join(errors)}") - return - - if "const" in schema and value != schema["const"]: - raise ValidationError(f"{path}: expected const {schema['const']!r}, got {value!r}") - if "enum" in schema and value not in schema["enum"]: - raise ValidationError(f"{path}: expected one of {schema['enum']!r}, got {value!r}") - - typ = schema.get("type") - if isinstance(typ, list): - if not any(type_matches(value, t) for t in typ): - raise ValidationError(f"{path}: wrong type {type(value).__name__}, expected {typ!r}") - elif isinstance(typ, str) and not type_matches(value, typ): - raise ValidationError(f"{path}: wrong type {type(value).__name__}, expected {typ!r}") - - if typ == "object" or "properties" in schema: - if not isinstance(value, dict): - raise ValidationError(f"{path}: expected object") - props = schema.get("properties", {}) - for key in schema.get("required", []): - if key not in value: - raise ValidationError(f"{path}: missing required property {key!r}") - min_props = schema.get("minProperties") - max_props = schema.get("maxProperties") - if min_props is not None and len(value) < min_props: - raise ValidationError(f"{path}: expected at least {min_props} properties") - if max_props is not None and len(value) > max_props: - raise ValidationError(f"{path}: expected at most {max_props} properties") - if schema.get("additionalProperties") is False: - extra = sorted(set(value) - set(props)) - if extra: - raise ValidationError(f"{path}: extra properties {extra!r}") - for key, sub in props.items(): - if key in value: - validate_schema(value[key], sub, f"{path}.{key}") - - if typ == "array" or "items" in schema or "prefixItems" in schema: - if not isinstance(value, list): - raise ValidationError(f"{path}: expected array") - min_items = schema.get("minItems") - max_items = schema.get("maxItems") - if min_items is not None and len(value) < min_items: - raise ValidationError(f"{path}: expected at least {min_items} items") - if max_items is not None and len(value) > max_items: - raise ValidationError(f"{path}: expected at most {max_items} items") - prefix_items = schema.get("prefixItems") - if isinstance(prefix_items, list): - for i, sub in enumerate(prefix_items): - if i < len(value): - validate_schema(value[i], sub, f"{path}[{i}]") - items = schema.get("items") - if isinstance(items, dict): - start = len(prefix_items) if isinstance(prefix_items, list) else 0 - for i, item in enumerate(value[start:], start): - validate_schema(item, items, f"{path}[{i}]") - - if isinstance(value, str): - if "minLength" in schema and len(value) < schema["minLength"]: - raise ValidationError(f"{path}: string shorter than minLength {schema['minLength']}") - if "maxLength" in schema and len(value) > schema["maxLength"]: - raise ValidationError(f"{path}: string longer than maxLength {schema['maxLength']}") - if "pattern" in schema and re.fullmatch(schema["pattern"], value) is None: - raise ValidationError(f"{path}: {value!r} does not match {schema['pattern']!r}") - if "format" in schema: - _validate_format(value, schema["format"], path) - - if isinstance(value, (int, float)) and not isinstance(value, bool): - if "minimum" in schema and value < schema["minimum"]: - raise ValidationError(f"{path}: {value!r} is below minimum {schema['minimum']!r}") - if "maximum" in schema and value > schema["maximum"]: - raise ValidationError(f"{path}: {value!r} is above maximum {schema['maximum']!r}") - if "exclusiveMinimum" in schema and value <= schema["exclusiveMinimum"]: - raise ValidationError(f"{path}: {value!r} is not above exclusiveMinimum {schema['exclusiveMinimum']!r}") - if "exclusiveMaximum" in schema and value >= schema["exclusiveMaximum"]: - raise ValidationError(f"{path}: {value!r} is not below exclusiveMaximum {schema['exclusiveMaximum']!r}") - if "multipleOf" in schema: - q = value / schema["multipleOf"] - if not math.isclose(q, round(q), rel_tol=0.0, abs_tol=1e-9): - raise ValidationError(f"{path}: {value!r} is not a multiple of {schema['multipleOf']!r}") - - -def parse_json_strict(text: str) -> Any: - stripped = text.strip() - try: - return json.loads(stripped) - except json.JSONDecodeError as exc: - raise ValidationError(f"output is not JSON: {text!r}") from exc - - -def json_schema_validator(schema: dict[str, Any]) -> Validator: - def _validate(text: str) -> str: - value = parse_json_strict(text) - validate_schema(value, schema) - return compact_json(value) - - return _validate - - -def json_object_validator(text: str) -> str: - value = parse_json_strict(text) - if not isinstance(value, dict): - raise ValidationError(f"output is not a JSON object: {value!r}") - return compact_json(value) - - -def regex_validator(pattern: str) -> Validator: - rx = re.compile(pattern) - - def _validate(text: str) -> str: - value = text.strip() - if rx.fullmatch(value) is None: - raise ValidationError(f"{value!r} does not match /{pattern}/") - return value - - return _validate - - -def choice_validator(choices: set[str]) -> Validator: - def _validate(text: str) -> str: - value = text.strip() - if value not in choices: - raise ValidationError(f"{value!r} is not one of {sorted(choices)!r}") - return value - - return _validate - - -def permutation_validator(chars: str) -> Validator: - expected = sorted(chars) - - def _validate(text: str) -> str: - value = text.strip() - if sorted(value) != expected or len(value) != len(chars): - raise ValidationError(f"{value!r} is not a permutation of {chars!r}") - return value - - return _validate - - -def substring_chunk_validator(prefix: str, words: list[str]) -> Validator: - allowed = {""} - for i in range(len(words)): - for j in range(i + 1, len(words) + 1): - allowed.add(" ".join(words[i:j])) - - def _validate(text: str) -> str: - value = text.strip() - if not value.startswith(prefix): - raise ValidationError(f"{value!r} does not start with {prefix!r}") - tail = value[len(prefix):] - if tail not in allowed: - raise ValidationError(f"{tail!r} is not an allowed contiguous word substring") - return value - - return _validate - - -def make_cases() -> list[Case]: - calendar_schema = { - "type": "object", - "properties": { - "name": {"type": "string", "minLength": 1, "maxLength": 80}, - "date": {"type": "string", "format": "date"}, - "participants": { - "type": "array", - "items": {"type": "string", "minLength": 1}, - "minItems": 1, - "maxItems": 5, - }, - }, - "required": ["name", "date", "participants"], - "additionalProperties": False, - } - status_schema = { - "type": "object", - "properties": { - "status": {"const": "ok"}, - "priority": {"type": "string", "enum": ["low", "medium", "high"]}, - "retry_count": {"type": "integer", "minimum": 0, "maximum": 5}, - "active": {"type": "boolean"}, - }, - "required": ["status", "priority", "retry_count", "active"], - "additionalProperties": False, - } - ticket_schema = { - "type": "object", - "properties": { - "id": {"type": "string", "pattern": "TCK-[0-9]{3}"}, - "owner": {"anyOf": [{"type": "string", "minLength": 1}, {"type": "null"}]}, - "priority": {"oneOf": [{"const": "low"}, {"const": "medium"}, {"const": "high"}]}, - "contact": {"type": "string", "format": "email"}, - }, - "required": ["id", "owner", "priority", "contact"], - "additionalProperties": False, - } - reading_schema = { - "type": "object", - "properties": { - "reading": { - "type": "array", - "prefixItems": [ - {"const": "temperature_c"}, - {"type": "number", "minimum": -40, "exclusiveMaximum": 80, "multipleOf": 0.5}, - ], - "minItems": 2, - "maxItems": 2, - }, - "tags": { - "type": "array", - "items": {"type": "string", "pattern": "[a-z]{3,8}"}, - "minItems": 2, - "maxItems": 3, - }, - }, - "required": ["reading", "tags"], - "additionalProperties": False, - } - - inline_json_schema = { - "type": "object", - "properties": { - "kind": {"const": "metric"}, - "value": {"type": "integer", "minimum": 1, "maximum": 9}, - }, - "required": ["kind", "value"], - "additionalProperties": False, - } - inline_json_lark = f"""%llguidance {{}} -start: %json {compact_json(inline_json_schema)} -""" - choice_lark = """%llguidance {} -start: "red" | "green" | "blue" -""" - regex_lark = """%llguidance {} -start: /INV-[0-9]{4}/ -""" - regex_ext_lark = """%llguidance {} -start: "chunk:" %regex { "substring_words": "alpha beta gamma delta" } -""" - parametric_lark = """%llguidance {} -start: perm::0x0 -perm::_: "" %if is_ones([0:3]) - | "a" perm::set_bit(0) %if bit_clear(0) - | "b" perm::set_bit(1) %if bit_clear(1) - | "c" perm::set_bit(2) %if bit_clear(2) -""" - guidance_lark = """%llguidance {} -start: "YES" | "NO" -""" - guidance_wire = compact_json({"grammars": [{"lark_grammar": guidance_lark}]}) - - return [ - Case( - name="json_schema_calendar", - family="json_schema", - prompt="Return one lunch calendar event for Alice and Bob on 2026-06-01. Return only JSON.", - schema=calendar_schema, - validator=json_schema_validator(calendar_schema), - data=compact_json(calendar_schema), - oracle_sample='{"name":"Lunch","date":"2026-06-01","participants":["Alice","Bob"]}', - ), - Case( - name="json_schema_status", - family="json_schema", - prompt="Return a compact health-check object with status ok. Return only JSON.", - schema=status_schema, - validator=json_schema_validator(status_schema), - data=compact_json(status_schema), - oracle_sample='{"status":"ok","priority":"medium","retry_count":2,"active":true}', - ), - Case( - name="json_schema_anyof_oneof_format", - family="json_schema", - prompt=( - "Return one support ticket with an id like TCK-123, an owner or null, " - "one priority, and a contact email. Return only JSON." - ), - schema=ticket_schema, - validator=json_schema_validator(ticket_schema), - data=compact_json(ticket_schema), - oracle_sample='{"id":"TCK-123","owner":null,"priority":"high","contact":"ops@example.com"}', - ), - Case( - name="json_schema_tuple_numeric", - family="json_schema", - prompt="Return one sensor reading tuple and two short lowercase tags. Return only JSON.", - schema=reading_schema, - validator=json_schema_validator(reading_schema), - data=compact_json(reading_schema), - oracle_sample='{"reading":["temperature_c",21.5],"tags":["lab","green"]}', - ), - Case( - name="json_object_mode", - family="json_object", - prompt="Return a JSON object with a tiny task description and whether it is done.", - validator=json_object_validator, - data="", - oracle_sample='{"task":"check","done":false}', - ), - Case( - name="regex_invoice_id", - family="regex", - prompt="Return exactly one invoice id in the form INV-0427. No quotes, no prose.", - validator=regex_validator(r"INV-[0-9]{4}"), - data=r"INV-[0-9]{4}", - llama_grammar=regex_lark, - oracle_sample="INV-0427", - max_tokens=32, - ), - Case( - name="lark_choice", - family="lark", - prompt="Return exactly one lowercase color: red, green, or blue. No quotes, no prose.", - validator=choice_validator({"red", "green", "blue"}), - data=choice_lark, - llama_grammar=choice_lark, - oracle_sample="green", - max_tokens=16, - ), - Case( - name="lark_inline_json", - family="lark", - prompt='Return a compact JSON object with kind "metric" and a small integer value.', - validator=json_schema_validator(inline_json_schema), - data=inline_json_lark, - llama_grammar=inline_json_lark, - oracle_sample='{"kind":"metric","value":7}', - max_tokens=64, - ), - Case( - name="lark_regex_ext_substring", - family="lark", - prompt="Return exactly chunk:beta gamma. No quotes, no prose.", - validator=substring_chunk_validator("chunk:", ["alpha", "beta", "gamma", "delta"]), - data=regex_ext_lark, - llama_grammar=regex_ext_lark, - oracle_sample="chunk:beta gamma", - max_tokens=32, - ), - Case( - name="lark_parametric_permutation", - family="lark", - prompt="Return exactly one permutation of the letters a, b, and c. No separators, no prose.", - validator=permutation_validator("abc"), - data=parametric_lark, - llama_grammar=parametric_lark, - oracle_sample="cab", - max_tokens=16, - ), - Case( - name="llguidance_internal_wire", - family="llguidance", - prompt="Return exactly YES or NO in uppercase. No punctuation, no prose.", - validator=choice_validator({"YES", "NO"}), - data=guidance_wire, - llama_grammar=guidance_lark, - oracle_sample="YES", - max_tokens=16, - ), - ] - - -def post_json(url: str, payload: dict[str, Any], timeout: float, api_key: str | None = None) -> dict[str, Any]: - data = json.dumps(payload, separators=(",", ":")).encode("utf-8") - headers = {"Content-Type": "application/json"} - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - req = urllib.request.Request(url, data=data, headers=headers, method="POST") - try: - with urllib.request.urlopen(req, timeout=timeout) as resp: - raw = resp.read().decode("utf-8", errors="replace") - except urllib.error.HTTPError as exc: - raw = exc.read().decode("utf-8", errors="replace") - raise RuntimeError(f"HTTP {exc.code}: {raw[:1600]}") from exc - except urllib.error.URLError as exc: - raise RuntimeError(str(exc)) from exc - try: - body = json.loads(raw) - except json.JSONDecodeError as exc: - raise RuntimeError(f"invalid JSON response: {raw[:1600]}") from exc - if isinstance(body, dict) and body.get("error"): - raise RuntimeError(f"API error: {body['error']!r}") - return body - - -def get_status(url: str, timeout: float) -> tuple[int | None, str]: - req = urllib.request.Request(url, method="GET") - try: - with urllib.request.urlopen(req, timeout=timeout) as resp: - return resp.status, resp.read(512).decode("utf-8", errors="replace") - except urllib.error.HTTPError as exc: - return exc.code, exc.read(512).decode("utf-8", errors="replace") - except urllib.error.URLError as exc: - return None, str(exc) - - -def extract_chat_text(body: dict[str, Any]) -> str: - choices = body.get("choices") - if not isinstance(choices, list) or not choices: - raise RuntimeError(f"missing choices in chat response: {body!r}") - message = choices[0].get("message", {}) - content = message.get("content") - if isinstance(content, str): - return content - if isinstance(content, list): - parts: list[str] = [] - for part in content: - if isinstance(part, dict): - if isinstance(part.get("text"), str): - parts.append(part["text"]) - elif isinstance(part.get("content"), str): - parts.append(part["content"]) - if parts: - return "".join(parts) - raise RuntimeError(f"missing text content in chat response: {body!r}") - - -def extract_responses_text(body: dict[str, Any]) -> str: - if isinstance(body.get("output_text"), str): - return body["output_text"] - parts: list[str] = [] - for item in body.get("output", []): - if not isinstance(item, dict): - continue - if item.get("type") == "message": - for part in item.get("content", []): - if isinstance(part, dict) and isinstance(part.get("text"), str): - parts.append(part["text"]) - if parts: - return "".join(parts) - raise RuntimeError(f"missing output text in responses response: {body!r}") - - -def response_format_for_case(case: Case, target: Target, api: str, args: argparse.Namespace) -> dict[str, Any]: - if case.family == "json_object": - fmt: dict[str, Any] = {"type": "json_object"} - if args.json_object_schema: - fmt["schema"] = {"type": "object"} - return fmt - - if case.family == "json_schema": - if not case.schema: - raise RuntimeError(f"{case.name}: missing schema") - if api == "chat": - if target.name == "llama" and args.llama_chat_schema_style == "flat": - return { - "type": "json_schema", - "name": case.name, - "strict": True, - "schema": case.schema, - } - return { - "type": "json_schema", - "json_schema": { - "name": case.name, - "strict": True, - "schema": case.schema, - }, - } - return { - "type": "json_schema", - "name": case.name, - "strict": True, - "schema": case.schema, - } - - if not args.force_extensions and not target.supports_response_format_extensions: - raise UnsupportedCase(f"{target.name} does not expose {case.family!r} through OpenAI response_format") - - if case.family == "regex": - return {"type": "regex", "regex": case.data} - if case.family == "lark": - return {"type": "lark", "grammar": case.data} - if case.family == "llguidance": - return {"type": "llguidance", "grammar": case.data} - raise UnsupportedCase(f"unknown structured-output family {case.family!r}") - - -def add_llama_common_payload_fields(payload: dict[str, Any], target: Target, args: argparse.Namespace) -> None: - if target.name != "llama": - return - if args.llama_disable_thinking: - payload.setdefault("chat_template_kwargs", {})["enable_thinking"] = False - if args.seed is not None: - payload["seed"] = args.seed - - -def chat_payload(target: Target, case: Case, args: argparse.Namespace) -> dict[str, Any]: - payload: dict[str, Any] = { - "model": target.model, - "messages": [{"role": "user", "content": case.prompt}], - "max_tokens": case.max_tokens, - "temperature": 0, - } - if (case.family in {"json_schema", "json_object"} or - args.force_extensions or - target.supports_response_format_extensions): - payload["response_format"] = response_format_for_case(case, target, "chat", args) - elif target.supports_grammar_extension and case.llama_grammar: - payload["grammar"] = case.llama_grammar - else: - raise UnsupportedCase(f"{target.name}/chat cannot carry {case.family!r}") - if args.seed is not None: - payload["seed"] = args.seed - add_llama_common_payload_fields(payload, target, args) - return payload - - -def responses_payload(target: Target, case: Case, args: argparse.Namespace) -> dict[str, Any]: - payload: dict[str, Any] = { - "model": target.model, - "input": case.prompt, - "max_output_tokens": case.max_tokens, - "temperature": 0, - } - if (case.family in {"json_schema", "json_object"} or - args.force_extensions or - target.supports_response_format_extensions): - payload["text"] = {"format": response_format_for_case(case, target, "responses", args)} - elif target.supports_grammar_extension and case.llama_grammar: - payload["grammar"] = case.llama_grammar - else: - raise UnsupportedCase(f"{target.name}/responses cannot carry {case.family!r}") - if args.seed is not None: - payload["seed"] = args.seed - add_llama_common_payload_fields(payload, target, args) - return payload - - -def check_case(target: Target, api: str, case: Case, args: argparse.Namespace) -> str: - if api == "chat": - payload = chat_payload(target, case, args) - body = post_json(f"{target.base_url}/chat/completions", payload, args.timeout, args.api_key) - text = extract_chat_text(body) - elif api == "responses": - payload = responses_payload(target, case, args) - body = post_json(f"{target.base_url}/responses", payload, args.timeout, args.api_key) - text = extract_responses_text(body) - else: - raise RuntimeError(f"unknown api {api!r}") - return case.validator(text) - - -def split_csv(value: str) -> list[str]: - return [x.strip() for x in value.split(",") if x.strip()] - - -def flatten_extra_args(values: list[str] | None) -> list[str]: - out: list[str] = [] - for value in values or []: - out.extend(shlex.split(value)) - return out - - -def base_root(base_url: str) -> str: - if base_url.endswith("/v1"): - return base_url[:-3] - return base_url.rstrip("/") - - -def port_from_url(base_url: str, default: int) -> int: - parsed = urllib.parse.urlparse(base_url) - if parsed.port: - return parsed.port - if parsed.scheme == "https": - return 443 - if parsed.scheme == "http": - return 80 - return default - - -def target_is_ready(target: Target, timeout: float = 2.0) -> bool: - for url in (f"{base_root(target.base_url)}/health", f"{target.base_url}/models"): - status, _body = get_status(url, timeout) - if status == 200: - return True - return False - - -def wait_ready(target: Target, startup_timeout: float) -> None: - deadline = time.time() + startup_timeout - last = "" - while time.time() < deadline: - if target.process and target.process.poll() is not None: - log_hint = f" log={target.log_path}" if target.log_path else "" - raise RuntimeError(f"{target.name} server exited with code {target.process.returncode}.{log_hint}") - for url in (f"{base_root(target.base_url)}/health", f"{target.base_url}/models"): - status, body = get_status(url, 2.0) - last = f"{url}: {status} {body[:200]}" - if status == 200: - return - time.sleep(1.0) - log_hint = f" log={target.log_path}" if target.log_path else "" - raise RuntimeError(f"{target.name} did not become ready within {startup_timeout:.0f}s ({last}).{log_hint}") - - -def start_target_if_needed(target: Target, args: argparse.Namespace) -> None: - already_ready = target_is_ready(target) - if args.start == "never": - if not already_ready: - raise RuntimeError(f"{target.name} is not reachable at {target.base_url} and --start=never was set") - return - if already_ready and args.start != "always": - return - if not target.command: - raise RuntimeError(f"{target.name} has no launch command") - - log_file = tempfile.NamedTemporaryFile( - prefix=f"stress-{target.name}-", - suffix=".log", - mode="w", - encoding="utf-8", - delete=False, - ) - target.log_path = Path(log_file.name) - target.process = subprocess.Popen( - target.command, - cwd=str(target.cwd), - stdout=log_file, - stderr=subprocess.STDOUT, - text=True, - ) - target.started_by_us = True - wait_ready(target, args.startup_timeout) - - -def stop_target(target: Target) -> None: - if not target.process or target.process.poll() is not None: - return - target.process.terminate() - try: - target.process.wait(timeout=15) - except subprocess.TimeoutExpired: - target.process.kill() - target.process.wait(timeout=15) - - -def build_ds4_command(args: argparse.Namespace) -> list[str]: - if args.ds4_cmd: - return shlex.split(args.ds4_cmd.format(port=port_from_url(args.ds4_base_url, 8000))) - return [ - args.ds4_binary, - "--model", - args.ds4_model_path, - "--ctx", - str(args.server_ctx), - "--tokens", - str(args.server_default_tokens), - "--port", - str(port_from_url(args.ds4_base_url, 8000)), - *flatten_extra_args(args.ds4_extra_arg), - ] - - -def build_llama_command(args: argparse.Namespace) -> list[str]: - if args.llama_cmd: - return shlex.split(args.llama_cmd.format(port=port_from_url(args.llama_base_url, 8080))) - cmd = [ - args.llama_binary, - "-hf", - args.llama_hf_model, - "-c", - str(args.server_ctx), - "--port", - str(port_from_url(args.llama_base_url, 8080)), - "--jinja", - ] - if args.llama_ngl is not None: - cmd.extend(["-ngl", str(args.llama_ngl)]) - cmd.extend(flatten_extra_args(args.llama_extra_arg)) - return cmd - - -def selected_cases(args: argparse.Namespace) -> list[Case]: - cases = make_cases() - families = set(split_csv(args.families)) if args.families != "all" else set() - names = set(args.case or []) - out = [ - c for c in cases - if (not families or c.family in families) and (not names or c.name in names) - ] - missing = names - {c.name for c in cases} - if missing: - raise SystemExit(f"unknown case(s): {', '.join(sorted(missing))}") - known_families = {c.family for c in cases} - unknown_families = families - known_families - if unknown_families: - raise SystemExit(f"unknown family/families: {', '.join(sorted(unknown_families))}") - return out - - -def run_llguidance_oracle(cases: list[Case], args: argparse.Namespace) -> Counts: - counts = Counts() - if args.oracle == "never": - return counts - try: - import llguidance # type: ignore - except ModuleNotFoundError: - msg = "SKIP oracle/llguidance: python package is not importable" - if args.oracle == "require": - print(msg, file=sys.stderr) - counts.failed += 1 - elif args.verbose: - print(msg) - counts.skipped += 1 - return counts - - try: - tok = llguidance.LLTokenizer("byte") - except Exception as exc: - msg = f"SKIP oracle/llguidance: failed to create byte tokenizer: {exc}" - if args.oracle == "require": - print(msg, file=sys.stderr) - counts.failed += 1 - elif args.verbose: - print(msg) - counts.skipped += 1 - return counts - - for case in cases: - if case.family == "json_schema": - grammar = llguidance.LLMatcher.grammar_from_json_schema(case.schema) - elif case.family == "json_object": - grammar = llguidance.LLMatcher.grammar_from_json_schema({"type": "object"}) - elif case.family == "regex": - grammar = llguidance.LLMatcher.grammar_from_regex(case.data) - elif case.family == "lark": - grammar = llguidance.LLMatcher.grammar_from_lark(case.data) - elif case.family == "llguidance": - grammar = llguidance.grammar_from("llguidance", case.data) - else: - counts.skipped += 1 - continue - - label = f"oracle/{case.family}/{case.name}" - try: - err = llguidance.LLMatcher.validate_grammar(grammar, tok) - if err: - raise RuntimeError(err) - if case.oracle_sample is not None: - matcher = llguidance.LLMatcher(tok, grammar) - for token in tok.tokenize_str(case.oracle_sample): - bias = matcher.compute_logit_bias() - if token >= len(bias) or bias[token] != 200: - raise RuntimeError(f"sample token {token} is not allowed") - if not matcher.consume_token(token): - raise RuntimeError(f"sample token {token} was rejected") - if not matcher.is_accepting(): - raise RuntimeError("sample did not leave matcher in accepting state") - print(f"PASS {label}") - counts.passed += 1 - except Exception as exc: - print(f"FAIL {label}: {exc}", file=sys.stderr) - counts.failed += 1 - if args.fail_fast: - raise - return counts - - -def run_target(target: Target, cases: list[Case], args: argparse.Namespace) -> Counts: - counts = Counts() - start_target_if_needed(target, args) - apis = split_csv(args.apis) - for repeat in range(args.repeat): - for api in apis: - for case in cases: - label = f"{target.name}/{api}/{case.family}/{case.name}" - if args.repeat > 1: - label = f"{label}#{repeat + 1}" - t0 = time.time() - try: - value = check_case(target, api, case, args) - elapsed = time.time() - t0 - if args.verbose: - print(f"PASS {label} {elapsed:.2f}s {value}") - else: - print(f"PASS {label} {elapsed:.2f}s") - counts.passed += 1 - except UnsupportedCase as exc: - elapsed = time.time() - t0 - msg = f"SKIP {label} {elapsed:.2f}s: {exc}" - if args.strict_skips: - print(msg, file=sys.stderr) - counts.failed += 1 - if args.fail_fast: - raise - else: - print(msg) - counts.skipped += 1 - except Exception as exc: - counts.failed += 1 - print(f"FAIL {label}: {exc}", file=sys.stderr) - if args.fail_fast: - raise - return counts - - -def parse_args() -> argparse.Namespace: - p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) - p.add_argument("--targets", default="ds4,llama", help="Comma-separated targets: ds4,llama") - p.add_argument("--apis", default="chat,responses", help="Comma-separated APIs: chat,responses") - p.add_argument("--families", default="all", help="Comma-separated families or all") - p.add_argument("--case", action="append", help="Run only this case name; may repeat") - p.add_argument("--list-cases", action="store_true", help="Print selected cases and exit") - p.add_argument("--repeat", type=int, default=1) - p.add_argument("--timeout", type=float, default=180.0) - p.add_argument("--startup-timeout", type=float, default=900.0) - p.add_argument("--start", choices=["missing", "never", "always"], default="missing") - p.add_argument("--stop-started", action=argparse.BooleanOptionalAction, default=True) - p.add_argument("--strict-skips", action="store_true", help="Treat unsupported matrix entries as failures") - p.add_argument("--force-extensions", action="store_true", help="Send regex/lark/llguidance as experimental response_format types") - p.add_argument("--fail-fast", action="store_true") - p.add_argument("--verbose", action="store_true") - p.add_argument("--api-key", help="Optional bearer token for non-local OpenAI-compatible servers") - p.add_argument("--seed", type=int, default=1) - p.add_argument("--json-object-schema", action="store_true", help="Attach {'type':'object'} to json_object mode") - p.add_argument("--oracle", choices=["auto", "never", "require"], default="auto", help="Local Python llguidance grammar validation") - - p.add_argument("--server-ctx", type=int, default=8192) - p.add_argument("--server-default-tokens", type=int, default=384) - - p.add_argument("--ds4-base-url", default=DEFAULT_DS4_BASE_URL) - p.add_argument("--ds4-model", default="ds4") - p.add_argument("--ds4-binary", default="./ds4-server") - p.add_argument("--ds4-model-path", default="ds4flash.gguf") - p.add_argument("--ds4-cmd", help="Override ds4 launch command; {port} is expanded") - p.add_argument("--ds4-extra-arg", action="append", help="Extra ds4-server args; may repeat") - - p.add_argument("--llama-base-url", default=DEFAULT_LLAMA_BASE_URL) - p.add_argument("--llama-model", default=DEFAULT_LLAMA_HF_MODEL) - p.add_argument("--llama-binary", default="llama-server") - p.add_argument("--llama-hf-model", default=DEFAULT_LLAMA_HF_MODEL) - p.add_argument("--llama-cmd", help="Override llama launch command; {port} is expanded") - p.add_argument("--llama-extra-arg", action="append", help="Extra llama-server args; may repeat") - p.add_argument("--llama-ngl", type=int, default=999, help="llama.cpp GPU layers; set -1 to omit") - p.add_argument( - "--llama-chat-schema-style", - choices=["openai", "flat"], - default="flat", - help="json_schema shape for llama.cpp chat response_format", - ) - p.add_argument("--llama-disable-thinking", action=argparse.BooleanOptionalAction, default=True) - return p.parse_args() - - -def main() -> int: - args = parse_args() - if args.llama_ngl is not None and args.llama_ngl < 0: - args.llama_ngl = None - - repo = Path(__file__).resolve().parent - targets_requested = split_csv(args.targets) - unknown_targets = set(targets_requested) - {"ds4", "llama"} - if unknown_targets: - raise SystemExit(f"unknown target(s): {', '.join(sorted(unknown_targets))}") - - cases = selected_cases(args) - if args.list_cases: - for case in cases: - print(f"{case.family}\t{case.name}") - return 0 - - total = Counts() - - oracle_counts = run_llguidance_oracle(cases, args) - total.passed += oracle_counts.passed - total.failed += oracle_counts.failed - total.skipped += oracle_counts.skipped - - target_map: dict[str, Target] = { - "ds4": Target( - name="ds4", - base_url=args.ds4_base_url.rstrip("/"), - model=args.ds4_model, - command=build_ds4_command(args), - cwd=repo, - supports_response_format_extensions=True, - supports_grammar_extension=False, - ), - "llama": Target( - name="llama", - base_url=args.llama_base_url.rstrip("/"), - model=args.llama_model, - command=build_llama_command(args), - cwd=repo, - supports_response_format_extensions=False, - supports_grammar_extension=True, - ), - } - - try: - for name in targets_requested: - target = target_map[name] - counts = run_target(target, cases, args) - total.passed += counts.passed - total.failed += counts.failed - total.skipped += counts.skipped - if args.stop_started and target.started_by_us: - stop_target(target) - finally: - for target in target_map.values(): - if args.stop_started and target.started_by_us: - stop_target(target) - - print(f"SUMMARY pass={total.passed} fail={total.failed} skip={total.skipped}") - return 1 if total.failed else 0 - - -if __name__ == "__main__": - raise SystemExit(main()) From 234fbf771f0d8252b0d216c573fc0fa1bd776b3c Mon Sep 17 00:00:00 2001 From: Pasquale Minervini Date: Sat, 30 May 2026 15:49:25 +0100 Subject: [PATCH 7/9] Remove structured output stress test --- tests/structured_outputs_stress.py | 424 ----------------------------- 1 file changed, 424 deletions(-) delete mode 100755 tests/structured_outputs_stress.py diff --git a/tests/structured_outputs_stress.py b/tests/structured_outputs_stress.py deleted file mode 100755 index 9bc7610fc..000000000 --- a/tests/structured_outputs_stress.py +++ /dev/null @@ -1,424 +0,0 @@ -#!/usr/bin/env python3 -"""Stress JSON structured outputs on OpenAI-compatible chat/responses APIs. - -Examples: - python3 tests/structured_outputs_stress.py \ - --base-url http://127.0.0.1:8000/v1 --model ds4 --apis chat,responses - - python3 tests/structured_outputs_stress.py \ - --base-url http://127.0.0.1:8080/v1 --model qwen --apis chat -""" - -from __future__ import annotations - -import argparse -import json -import re -import sys -import time -import urllib.error -import urllib.request -from dataclasses import dataclass -from typing import Any - - -@dataclass(frozen=True) -class Case: - name: str - prompt: str - schema: dict[str, Any] | None - json_object: bool = False - - -CASES: list[Case] = [ - Case( - name="calendar_event", - prompt=( - "Create one calendar event for Alice and Bob having lunch on " - "2026-06-01 at noon. Return only the requested JSON object." - ), - schema={ - "type": "object", - "properties": { - "name": {"type": "string"}, - "date": {"type": "string"}, - "participants": { - "type": "array", - "items": {"type": "string"}, - "minItems": 1, - "maxItems": 5, - }, - }, - "required": ["name", "date", "participants"], - "additionalProperties": False, - }, - ), - Case( - name="enum_const_integer_boolean", - prompt=( - "Return a compact health-check result. Use status ok, one priority, " - "a retry count, and whether the system is active." - ), - schema={ - "type": "object", - "properties": { - "status": {"const": "ok"}, - "priority": {"type": "string", "enum": ["low", "medium", "high"]}, - "retry_count": {"type": "integer", "minimum": 0, "maximum": 5}, - "active": {"type": "boolean"}, - }, - "required": ["status", "priority", "retry_count", "active"], - "additionalProperties": False, - }, - ), - Case( - name="nested_arrays", - prompt=( - "Return a 2 by 2 integer matrix and two short labels. Keep values " - "small and return only JSON." - ), - schema={ - "type": "object", - "properties": { - "matrix": { - "type": "array", - "minItems": 2, - "maxItems": 2, - "items": { - "type": "array", - "minItems": 2, - "maxItems": 2, - "items": {"type": "integer", "minimum": -9, "maximum": 9}, - }, - }, - "labels": { - "type": "array", - "minItems": 2, - "maxItems": 2, - "items": {"type": "string"}, - }, - }, - "required": ["matrix", "labels"], - "additionalProperties": False, - }, - ), - Case( - name="nullable_anyof_number_bounds", - prompt=( - "Return a score between zero and one, and use either an owner name " - "or null if unknown." - ), - schema={ - "type": "object", - "properties": { - "owner": {"anyOf": [{"type": "string"}, {"type": "null"}]}, - "score": {"type": "number", "minimum": 0, "maximum": 1}, - }, - "required": ["owner", "score"], - "additionalProperties": False, - }, - ), - Case( - name="pattern_string", - prompt="Return an inventory code in the form two uppercase letters, dash, three digits.", - schema={ - "type": "object", - "properties": { - "code": {"type": "string", "pattern": "^[A-Z]{2}-[0-9]{3}$"} - }, - "required": ["code"], - "additionalProperties": False, - }, - ), - Case( - name="json_object_mode", - prompt="Return a JSON object with two fields describing a tiny task list.", - schema=None, - json_object=True, - ), -] - - -class ValidationError(Exception): - pass - - -def type_matches(value: Any, typ: str) -> bool: - if typ == "object": - return isinstance(value, dict) - if typ == "array": - return isinstance(value, list) - if typ == "string": - return isinstance(value, str) - if typ == "integer": - return isinstance(value, int) and not isinstance(value, bool) - if typ == "number": - return (isinstance(value, int) or isinstance(value, float)) and not isinstance(value, bool) - if typ == "boolean": - return isinstance(value, bool) - if typ == "null": - return value is None - return True - - -def validate_schema(value: Any, schema: dict[str, Any], path: str = "$") -> None: - if "anyOf" in schema: - errors: list[str] = [] - for option in schema["anyOf"]: - try: - validate_schema(value, option, path) - return - except ValidationError as exc: - errors.append(str(exc)) - raise ValidationError(f"{path}: did not match anyOf: {'; '.join(errors)}") - - if "const" in schema and value != schema["const"]: - raise ValidationError(f"{path}: expected const {schema['const']!r}, got {value!r}") - if "enum" in schema and value not in schema["enum"]: - raise ValidationError(f"{path}: expected one of {schema['enum']!r}, got {value!r}") - - typ = schema.get("type") - if isinstance(typ, list): - if not any(type_matches(value, t) for t in typ): - raise ValidationError(f"{path}: wrong type {type(value).__name__}, expected {typ}") - elif isinstance(typ, str) and not type_matches(value, typ): - raise ValidationError(f"{path}: wrong type {type(value).__name__}, expected {typ}") - - if typ == "object" or "properties" in schema: - if not isinstance(value, dict): - raise ValidationError(f"{path}: expected object") - props = schema.get("properties", {}) - for key in schema.get("required", []): - if key not in value: - raise ValidationError(f"{path}: missing required property {key!r}") - if schema.get("additionalProperties") is False: - extra = sorted(set(value) - set(props)) - if extra: - raise ValidationError(f"{path}: extra properties {extra!r}") - for key, sub in props.items(): - if key in value: - validate_schema(value[key], sub, f"{path}.{key}") - - if typ == "array" or "items" in schema: - if not isinstance(value, list): - raise ValidationError(f"{path}: expected array") - min_items = schema.get("minItems") - max_items = schema.get("maxItems") - if min_items is not None and len(value) < min_items: - raise ValidationError(f"{path}: expected at least {min_items} items") - if max_items is not None and len(value) > max_items: - raise ValidationError(f"{path}: expected at most {max_items} items") - items = schema.get("items") - if isinstance(items, dict): - for i, item in enumerate(value): - validate_schema(item, items, f"{path}[{i}]") - - if isinstance(value, str) and "pattern" in schema: - if re.fullmatch(schema["pattern"], value) is None: - raise ValidationError(f"{path}: {value!r} does not match {schema['pattern']!r}") - - if isinstance(value, (int, float)) and not isinstance(value, bool): - if "minimum" in schema and value < schema["minimum"]: - raise ValidationError(f"{path}: {value!r} is below minimum {schema['minimum']!r}") - if "maximum" in schema and value > schema["maximum"]: - raise ValidationError(f"{path}: {value!r} is above maximum {schema['maximum']!r}") - - -def post_json(url: str, payload: dict[str, Any], timeout: float) -> dict[str, Any]: - data = json.dumps(payload, separators=(",", ":")).encode("utf-8") - req = urllib.request.Request( - url, - data=data, - headers={"Content-Type": "application/json"}, - method="POST", - ) - try: - with urllib.request.urlopen(req, timeout=timeout) as resp: - raw = resp.read().decode("utf-8", errors="replace") - except urllib.error.HTTPError as exc: - raw = exc.read().decode("utf-8", errors="replace") - raise RuntimeError(f"HTTP {exc.code}: {raw[:1000]}") from exc - except urllib.error.URLError as exc: - raise RuntimeError(str(exc)) from exc - try: - body = json.loads(raw) - except json.JSONDecodeError as exc: - raise RuntimeError(f"invalid JSON response: {raw[:1000]}") from exc - if isinstance(body, dict) and body.get("error"): - raise RuntimeError(f"API error: {body['error']!r}") - return body - - -def chat_payload(model: str, case: Case, json_object_schema: bool) -> dict[str, Any]: - response_format: dict[str, Any] - if case.json_object: - response_format = {"type": "json_object"} - if json_object_schema: - response_format["schema"] = {"type": "object"} - else: - response_format = { - "type": "json_schema", - "json_schema": { - "name": case.name, - "strict": True, - "schema": case.schema, - }, - } - return { - "model": model, - "messages": [{"role": "user", "content": case.prompt}], - "max_tokens": 256, - "temperature": 0, - "response_format": response_format, - } - - -def responses_payload(model: str, case: Case, json_object_schema: bool) -> dict[str, Any]: - fmt: dict[str, Any] - if case.json_object: - fmt = {"type": "json_object"} - if json_object_schema: - fmt["schema"] = {"type": "object"} - else: - fmt = { - "type": "json_schema", - "name": case.name, - "strict": True, - "schema": case.schema, - } - return { - "model": model, - "input": case.prompt, - "max_output_tokens": 256, - "temperature": 0, - "text": {"format": fmt}, - } - - -def extract_chat_text(body: dict[str, Any]) -> str: - choices = body.get("choices") - if not isinstance(choices, list) or not choices: - raise RuntimeError(f"missing choices in chat response: {body!r}") - message = choices[0].get("message", {}) - content = message.get("content") - if isinstance(content, str): - return content - if isinstance(content, list): - out: list[str] = [] - for part in content: - if isinstance(part, dict) and isinstance(part.get("text"), str): - out.append(part["text"]) - return "".join(out) - raise RuntimeError(f"missing text content in chat response: {body!r}") - - -def extract_responses_text(body: dict[str, Any]) -> str: - if isinstance(body.get("output_text"), str): - return body["output_text"] - out: list[str] = [] - for item in body.get("output", []): - if not isinstance(item, dict): - continue - if item.get("type") == "message": - for part in item.get("content", []): - if isinstance(part, dict) and isinstance(part.get("text"), str): - out.append(part["text"]) - if out: - return "".join(out) - raise RuntimeError(f"missing output text in responses response: {body!r}") - - -def check_case( - api: str, - base_url: str, - model: str, - case: Case, - timeout: float, - json_object_schema: bool, -) -> str: - if api == "chat": - body = post_json( - f"{base_url}/chat/completions", - chat_payload(model, case, json_object_schema), - timeout, - ) - text = extract_chat_text(body) - elif api == "responses": - body = post_json( - f"{base_url}/responses", - responses_payload(model, case, json_object_schema), - timeout, - ) - text = extract_responses_text(body) - else: - raise RuntimeError(f"unknown api {api!r}") - - try: - value = json.loads(text.strip()) - except json.JSONDecodeError as exc: - raise RuntimeError(f"{api}/{case.name}: output is not JSON: {text!r}") from exc - if not isinstance(value, dict): - raise RuntimeError(f"{api}/{case.name}: output is not a JSON object: {value!r}") - if case.schema is not None: - validate_schema(value, case.schema) - return json.dumps(value, ensure_ascii=False, sort_keys=True) - - -def parse_args() -> argparse.Namespace: - p = argparse.ArgumentParser() - p.add_argument("--base-url", required=True, help="Base URL, usually http://host:port/v1") - p.add_argument("--model", required=True) - p.add_argument("--apis", default="chat,responses", help="Comma-separated: chat,responses") - p.add_argument("--case", action="append", help="Run only this case name; may repeat") - p.add_argument("--repeat", type=int, default=1) - p.add_argument("--timeout", type=float, default=120.0) - p.add_argument( - "--json-object-schema", - action="store_true", - help="Send {'type':'object'} with json_object mode for servers that require a concrete schema.", - ) - p.add_argument("--verbose", action="store_true") - return p.parse_args() - - -def main() -> int: - args = parse_args() - base_url = args.base_url.rstrip("/") - apis = [x.strip() for x in args.apis.split(",") if x.strip()] - selected = set(args.case or []) - cases = [c for c in CASES if not selected or c.name in selected] - missing = selected - {c.name for c in CASES} - if missing: - print(f"unknown case(s): {', '.join(sorted(missing))}", file=sys.stderr) - return 2 - - failures = 0 - for repeat in range(args.repeat): - for api in apis: - for case in cases: - label = f"{api}/{case.name}" - if args.repeat > 1: - label = f"{label}#{repeat + 1}" - t0 = time.time() - try: - value = check_case( - api, - base_url, - args.model, - case, - args.timeout, - args.json_object_schema, - ) - elapsed = time.time() - t0 - if args.verbose: - print(f"PASS {label} {elapsed:.2f}s {value}") - else: - print(f"PASS {label} {elapsed:.2f}s") - except Exception as exc: - failures += 1 - print(f"FAIL {label}: {exc}", file=sys.stderr) - return 1 if failures else 0 - - -if __name__ == "__main__": - raise SystemExit(main()) From 9d7b92b06cb218550da9074a7a53fa6432441691 Mon Sep 17 00:00:00 2001 From: Pasquale Minervini Date: Sat, 30 May 2026 15:54:41 +0100 Subject: [PATCH 8/9] Remove stale structured output stress test reference --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 6d55447ce..900fbbe2c 100644 --- a/README.md +++ b/README.md @@ -1150,7 +1150,6 @@ extractor self-test run first: make test # ./ds4-eval --self-test-extractors && ./ds4_test --all ./ds4_test --logprob-vectors ./ds4_test --server -python3 tests/structured_outputs_stress.py --base-url http://127.0.0.1:8000/v1 --model ds4 --apis chat,responses ``` ## Debugging Notes From 084d7c3db5f1e37f9f5b09a1028ac948b2a84335 Mon Sep 17 00:00:00 2001 From: Pasquale Minervini Date: Sat, 30 May 2026 16:37:57 +0100 Subject: [PATCH 9/9] Allow thinking with structured outputs --- README.md | 5 +++-- ds4_server.c | 28 ++++++++++++++++++++-------- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 900fbbe2c..a478ef6a4 100644 --- a/README.md +++ b/README.md @@ -653,8 +653,9 @@ static library there. To use an existing checkout instead, pass With that build, `/v1/chat/completions` supports `response_format.type=json_schema`, `json_object`, `regex`, `lark`, and `llguidance`; `/v1/responses` supports the same modes through `text.format`. -Structured outputs use constrained decoding, disable thinking for that turn, -and currently cannot be combined with tools. +Structured outputs use constrained decoding. If thinking is enabled, the +constraint applies after `` so the final assistant content is structured. +They currently cannot be combined with tools. `/v1/messages` is the Anthropic-compatible endpoint used by Claude Code style clients. It accepts `system`, `messages`, `tools`, `tool_choice`, `max_tokens`, diff --git a/ds4_server.c b/ds4_server.c index cce33485c..df750a9c8 100644 --- a/ds4_server.c +++ b/ds4_server.c @@ -3274,8 +3274,6 @@ static bool parse_chat_request(ds4_engine *e, server *s, const char *body, int d request_free(r); return false; } - thinking_enabled = false; - got_thinking = true; } if (!got_thinking && model_alias_disables_thinking(r->model)) thinking_enabled = false; if (!got_thinking && model_alias_enables_thinking(r->model)) thinking_enabled = true; @@ -4448,9 +4446,6 @@ static bool parse_responses_request(ds4_engine *e, server *s, const char *body, request_free(r); return false; } - thinking_enabled = false; - got_thinking = true; - r->reasoning_summary_emit = false; } if (!got_thinking && model_alias_disables_thinking(r->model)) thinking_enabled = false; if (!got_thinking && model_alias_enables_thinking(r->model)) thinking_enabled = true; @@ -10873,6 +10868,11 @@ static void generate_job(server *s, job *j) { double last_decode_log_t = decode_t0; int last_decode_log_completion = 0; thinking_state thinking = thinking_state_from_prompt(&j->req); + bool structured_waiting_for_think_close = structured && thinking.inside; + if (structured_waiting_for_think_close) { + trace_event(s, trace_id, + "structured output constraint delayed until "); + } const bool thinking_gates_tool_markers = ds4_think_mode_enabled(j->req.think_mode); bool tool_scan_waiting_for_think_close = thinking_gates_tool_markers && thinking.inside; @@ -10900,7 +10900,8 @@ static void generate_job(server *s, job *j) { if (in_tool_call && !dsml_decode_state_uses_payload_sampling(dsml_state)) { temperature = 0.0f; } - int token = structured ? + bool structured_active = structured && !structured_waiting_for_think_close; + int token = structured_active ? ds4_llguidance_sample(structured, s->session, temperature, top_k, top_p, min_p, &rng, err, sizeof(err)) : @@ -10916,7 +10917,8 @@ static void generate_job(server *s, job *j) { int toks[17]; int ntok = 0; - if (!structured && + if (!structured_active && + !structured_waiting_for_think_close && temperature <= 0.0f && ds4_engine_mtp_draft_tokens(s->engine) > 1 && getenv("DS4_MTP_SPEC_DISABLE") == NULL) @@ -10950,7 +10952,8 @@ static void generate_job(server *s, job *j) { stop_decode = true; break; } - if (structured && + structured_active = structured && !structured_waiting_for_think_close; + if (structured_active && !ds4_llguidance_accept(structured, s->engine, token, err, sizeof(err))) { finish = "error"; @@ -10964,7 +10967,16 @@ static void generate_job(server *s, job *j) { trace_piece(s, trace_id, piece, piece_len); buf_append(&text, piece, piece_len); + bool was_thinking_inside = thinking.inside; thinking_state_feed(&thinking, piece, piece_len); + if (structured_waiting_for_think_close && + was_thinking_inside && + !thinking.inside) + { + structured_waiting_for_think_close = false; + trace_event(s, trace_id, + "structured output constraint activated after "); + } if (j->req.kind == REQ_CHAT && j->req.has_tools) { dsml_decode_tracker_update(&dsml_tracker, text.ptr, text.len); }