diff --git a/.gitignore b/.gitignore index 5b4ebfbc9..e33564311 100644 --- a/.gitignore +++ b/.gitignore @@ -91,13 +91,14 @@ ipython_config.py # PyPI config .pypirc -# Project-local outputs +# Project local *.log *.out *.pkl batches/ wandb/ checkpoints/ +models/ experiments/ rollout_results/ outputs/ @@ -114,7 +115,7 @@ Megatron-LM/ glm/ # Generated documentation -docs/_build/ +docs/ /site # Dashboard frontend build artifacts diff --git a/README.md b/README.md index 5a136d934..a15a78e5e 100644 --- a/README.md +++ b/README.md @@ -30,29 +30,41 @@ #### 🟩 Install the **Rollout Server** (Polar): ```bash -uv venv +uv venv --python 3.13 uv pip install -e . +source .venv/bin/activate ``` -#### 🟩 Install the **Inference Server** (SGLang): +### 🟩 Install the **Inference Server** (SGLang or vLLM): + +Pick one (that your trainer supports). Avoid installing both under the same environment given dependency conflicts. + +**vLLM** +```bash +uv pip install vllm --torch-backend=auto +``` + +**SGLang** ```bash -uv pip install --prerelease=allow sglang==0.5.10 +uv pip install --prerelease=allow sglang==0.5.10 torch==2.9.1+cu128 bash scripts/patch/patch_sglang.sh ``` -The patch applies necessary TITO and prompt token id emission on pinned `sglang` version. We'll remove this once upstream supports go through. `vllm` integration is on the way. +The patch applies necessary TITO and prompt token id emission on the pinned `sglang` version. We'll remove this once upstream support goes through. + +### 🟩 Install your favorite **Training Framework**: -#### 🟩 Polar is trainer agnostic. So choice of **Trainer** and **Training Backend** are highly flexible given Polar's server boundaries. +Polar is trainer agnostic. So choice of **Trainer** and **Training Backend** are highly flexible given Polar's HTTP server boundaries. Currently, we provide a demo-purpose [Slime](https://github.com/THUDM/slime) integration in [Slime bridge installation guide](src/slime_bridge/README.md#slime-installation). -#### 🟩 (Optional) For SWE-bench official evaluation harness: +#### (Optional) For SWE-bench official evaluation harness: ```bash uv pip install -e ".[swebench]" ``` -#### 🟩 (Optional) To enable **polar dashboard** UI, build the frontend once. +#### (Optional) To enable **polar dashboard** UI, build the frontend once. ```bash cd web && npm install && npm run build @@ -62,7 +74,7 @@ cd web && npm install && npm run build ## Usage Guide -- ⭐ [Choose your Agent Harness](src/polar/agent/README.md): pick a built-in harness, or use the generic shell harness with wrapped agents. +- ⭐ [Choose your Agent Harness](src/polar/agent/README.md): Express your agent using the generic `shell` harness, or pick a preset shortcut. - 🚀 [Trajectory Construction and Eval](src/polar/trajectory/README.md): See [builder](src/polar/trajectory/builder/README.md) and [evaluator](src/polar/trajectory/evaluator/README.md) guides for registered strategies. - 🔧 [Deployment Topology](src/polar/config/README.md): configure the Polar service. @@ -110,7 +122,7 @@ Our development goal for **Polar** is low-intrusion and neutral, finding the low - [x] Slime bridge & RL example. - [x] CUA (VLM / VLA) Support. - [ ] More built-in evaluators (eg. self distillation with textual feedback). -- [ ] vLLM dual inference support. +- [x] vLLM dual inference support. - [ ] More trainer bridges (NemoRL, VERL, etc.). diff --git a/assets/dashboard_calculator.png b/assets/dashboard_calculator.png new file mode 100644 index 000000000..5e12b6173 Binary files /dev/null and b/assets/dashboard_calculator.png differ diff --git a/assets/dashboard_trajectory.png b/assets/dashboard_trajectory.png new file mode 100644 index 000000000..51aff3f06 Binary files /dev/null and b/assets/dashboard_trajectory.png differ diff --git a/examples/calculator/README.md b/examples/calculator/README.md index c50034e97..9a48b5e1b 100644 --- a/examples/calculator/README.md +++ b/examples/calculator/README.md @@ -1,110 +1,65 @@ # Calculator Example -This is a small end-to-end Polar rollout example. Each agent gets a tiny -`calculator.py` file with parser stubs, edits it, and the evaluator runs -`python3 test_calculator.py`. +The smallest end-to-end Polar run. Each harness gets a tiny `calculator.py` +with parser stubs, edits it, and the evaluator runs `python3 test_calculator.py`. +Use it as a quick smoke test that rollout, gateway, runtime, harness execution, +and evaluation all work together. -Use this example when you want a quick local check that rollout, gateway, -runtime setup, harness execution, and evaluation still work together. +## Prerequisites -The topology setup is used on 4 x B200 GPUs. Adjust based on your hardware. +Install **Polar** and **vLLM** as described in the [top-level README](../../README.md#installation). +This example uses 1 node 8×B200 — two vLLM servers (tensor-parallel 4 each). +Adjust the setup and topology for your hardware. -## What It Runs +## Quick Start -- rollout server on `:8080` -- two gateway nodes on `:8100` and `:8101` -- two local SGLang backends on `:8000` and `:8001` -- one shared runtime image: `polar-localhost-calculator:latest` -- six harnesses: `claude_code`, `codex`, `gemini_cli`, `opencode`, `pi`, - `qwen_code` - -The default scripts use Docker. Apptainer is also supported with -`--backend apptainer`. - -## Setup - -From the repo root: - -```bash -uv venv -uv pip install -e . -uv pip install --prerelease=allow sglang==0.5.10 -bash scripts/patch/patch_sglang.sh -``` - -Build the runtime image once: +### 1. Build the runtime image (once) ```bash uv run python examples/calculator/build_image.py ``` -## Start Services - -Start two SGLang servers, one per GPU group: +### 2. Start two vLLM servers ```bash -CUDA_VISIBLE_DEVICES=0 uv run python -m sglang.launch_server \ - --model-path Qwen/Qwen3.5-4B \ - --host 0.0.0.0 \ - --port 8000 \ - --tool-call-parser qwen3_coder \ - --reasoning-parser qwen3 \ - --mem-fraction-static 0.7 \ - --context-length 262144 \ - --trust-remote-code -``` +CUDA_VISIBLE_DEVICES=0,1,2,3 uv run vllm serve Qwen/Qwen3.6-27B --port 8000 \ + --tensor-parallel-size 4 --max-model-len 262144 \ + --reasoning-parser qwen3 --enable-auto-tool-choice --tool-call-parser qwen3_coder -```bash -CUDA_VISIBLE_DEVICES=1 uv run python -m sglang.launch_server \ - --model-path Qwen/Qwen3.5-4B \ - --host 0.0.0.0 \ - --port 8001 \ - --tool-call-parser qwen3_coder \ - --reasoning-parser qwen3 \ - --mem-fraction-static 0.7 \ - --context-length 262144 \ - --trust-remote-code +CUDA_VISIBLE_DEVICES=4,5,6,7 uv run vllm serve Qwen/Qwen3.6-27B --port 8001 \ + --tensor-parallel-size 4 --max-model-len 262144 \ + --reasoning-parser qwen3 --enable-auto-tool-choice --tool-call-parser qwen3_coder ``` -Start Polar: +### 3. Start Polar Servers ```bash uv run polar serve_rollout -c examples/calculator/topology.yaml -``` - -```bash uv run polar serve_gateway -c examples/calculator/topology.yaml --node-id localhost-node-01 -``` - -```bash uv run polar serve_gateway -c examples/calculator/topology.yaml --node-id localhost-node-02 ``` -## Run +### 4. Run -Run every harness: +Submits example harness at once and prints a reward comparison: ```bash -uv run python examples/calculator/submit_all.py +uv run python examples/calculator/run.py ``` -Run one harness: +Use Apptainer instead of Docker with `--backend apptainer`. -```bash -uv run python examples/calculator/submit_calculator_task.py claude_code -``` - -Use Apptainer instead of Docker: +### 5. (Optional) Watch in the dashboard ```bash -uv run python examples/calculator/submit_all.py --backend apptainer +uv run polar dashboard -c examples/calculator/topology.yaml ``` -Results are written under: +Open to inspect live tasks, sessions, trajectories, +and evaluations. -```text -examples/calculator/batches// -``` +

+ Calculator dashboard + Trajectory view +

-Each harness directory contains `request.json`, `response.json`, and -`summary.json`. diff --git a/examples/calculator/run.py b/examples/calculator/run.py new file mode 100644 index 000000000..185f1fd11 --- /dev/null +++ b/examples/calculator/run.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 +"""Run the calculator demo across every harness and print a comparison table. + +Each harness gets a tiny `calculator.py` with parser stubs, edits it, and the +evaluator runs `python3 test_calculator.py`. All harnesses are submitted at +once; live progress and per-session detail are visible in the dashboard +(`polar dashboard -c examples/calculator/topology.yaml`). + + uv run python examples/calculator/run.py # docker (default) + uv run python examples/calculator/run.py --backend apptainer +""" + +from __future__ import annotations + +import argparse +import sys +import time +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +import httpx + +EXAMPLE_DIR = Path(__file__).resolve().parent +ASSETS_DIR = EXAMPLE_DIR / "assets" +TEST_FILE = ASSETS_DIR / "test_calculator.py" +STARTER_FILE = ASSETS_DIR / "calculator.py" +TOPOLOGY = EXAMPLE_DIR / "topology.yaml" +RUNTIME_IMAGE = "polar-localhost-calculator:latest" +NUM_SAMPLES = 4 +# Generous budget: INIT install (npm / pip / venv) shares the per-task budget +# with the agent run and evaluation. +TIMEOUT_SECONDS = 1200.0 +POLL_INTERVAL_SECONDS = 10.0 + +HARNESSES = ( + "claude_code", + "codex", + "gemini_cli", + "opencode", + "pi", + "qwen_code", + "openhands_sdk", + "openclaw", + "hermes", +) + +INSTRUCTION = """\ +`calculator.py` has a `Calculator` class with a tokenizer and three stub methods. +Each stub is marked with a `# TODO` comment and returns `0`. + +Implement the three methods to build a recursive-descent expression parser: + +1. `_parse_expr` — handle `+` and `-` by calling `_parse_term` +2. `_parse_term` — handle `*` and `/` (integer division) by calling `_parse_factor` +3. `_parse_factor` — handle integer literals and parenthesized sub-expressions + +Also fix `__call__` to return the parsed value instead of `0`. + +Requirements: +- Work only in `/polar/session/workspace/calculator.py`. +- Keep the existing file structure, `_tokenize`, `_peek`, and `_consume` as-is. +- Do not add imports. +- Use `//` for division (integer division). +- You must make actual edits. An empty git diff fails the task. + +After editing, run `python3 test_calculator.py` to test. +""" + +# Per-harness INIT install command. npm CLIs install globally into +# ~/.local/bin; the two Python agents install via pip (hermes from PyPI, +# openhands-sdk into ~/.venv where its harness looks for the interpreter). +# Pinned versions keep the quickstart stable. Bump intentionally. +HARNESS_INSTALL: dict[str, str] = { + "claude_code": "npm install -g @anthropic-ai/claude-code@2.1.111", + "codex": "npm install -g @openai/codex@0.121.0", + "gemini_cli": "npm install -g @google/gemini-cli@0.38.1", + "opencode": "npm install -g opencode-ai@1.4.6", + "pi": "npm install -g @mariozechner/pi-coding-agent@0.67.68", + "qwen_code": "npm install -g @qwen-code/qwen-code@0.14.5", + "openclaw": "npm install -g openclaw@2026.5.27", + "hermes": "python3 -m pip install --user --quiet hermes-agent==0.15.1", + # Pin sdk + tools to the same version. Unpinned, pip resolves a mismatched + # pair (sdk 1.17 + tools 1.24) whose imports break; the latest 1.24 needs + # Python 3.13 (lmnr dep conflict on 3.12), so pin to 1.17.0 for this image. + "openhands_sdk": ( + "python3 -m venv $HOME/.venv && " + "$HOME/.venv/bin/pip install --quiet " + "openhands-sdk==1.17.0 openhands-tools==1.17.0" + ), +} + +# Model name the harness CLI sends; the gateway rewrites it to the served model. +HARNESS_MODEL: dict[str, str] = { + "claude_code": "claude-opus-4-5", + "codex": "gpt-5.4", + "gemini_cli": "gemini-2.5-flash-lite", + "opencode": "openai/gpt-5.4", + "pi": "openai/gpt-5.4", + "qwen_code": "qwen3-coder-plus", + "openhands_sdk": "openai/gpt-5.4", + "openclaw": "openai/gpt-5.4", + "hermes": "openai/gpt-5.4", +} + +# INIT stage: install the harness CLI, then set up a clean git workspace. +_WORKSPACE_PREPARE = ( + "rm -rf /polar/session/workspace && " + "mkdir -p /polar/session/workspace /polar/session/logs/agent && " + "cd /polar/session/workspace && " + "git init -q && " + "git config user.email 'polar@test' && " + "git config user.name 'Polar'" +) + +# Config/cache dirs that can leak into the workspace git diff. +_EVAL_EXCLUDES: dict[str, list[str]] = { + "claude_code": [".claude/**", "**/.claude/**"], + "codex": [".codex/**", "**/.codex/**"], + "gemini_cli": [".gemini/**", "**/.gemini/**"], + "opencode": [".opencode/**", "**/.opencode/**", ".config/opencode/**"], + "pi": [".pi/**", "**/.pi/**"], + "qwen_code": [".qwen/**", "**/.qwen/**"], + "openclaw": [".openclaw/**", "**/.openclaw/**"], + "hermes": [".hermes/**", "**/.hermes/**"], + "openhands_sdk": [".openhands/**", "**/.openhands/**"], +} +_COMMON_EXCLUDES = ["node_modules/**", "**/node_modules/**", ".cache/**", "**/.cache/**", ".venv/**", "**/.venv/**"] + + +def runtime_image_for_backend(backend: str) -> str: + if backend == "apptainer": + return f"docker-daemon:{RUNTIME_IMAGE}" + return RUNTIME_IMAGE + + +def build_task_payload(harness: str, batch_id: str, backend: str) -> dict[str, Any]: + return { + "task_id": f"calculator-{harness}-{batch_id}", + "instruction": INSTRUCTION, + "num_samples": NUM_SAMPLES, + "timeout_seconds": TIMEOUT_SECONDS, + "runtime": { + "backend": backend, + "image": runtime_image_for_backend(backend), + "prepare": [ + {"type": "exec", "command": f"{HARNESS_INSTALL[harness]} && {_WORKSPACE_PREPARE}"}, + {"type": "upload_file", "source": str(TEST_FILE), "target": "/polar/session/workspace/test_calculator.py"}, + {"type": "upload_file", "source": str(STARTER_FILE), "target": "/polar/session/workspace/calculator.py"}, + {"type": "exec", "command": "cd /polar/session/workspace && git add -A && git commit -qm 'initial'"}, + ], + "network": "host", + "workdir": "/polar/session/workspace", + }, + "agent": {"harness": harness, "model_name": HARNESS_MODEL[harness]}, + "builder": {"strategy": "prefix_merging"}, + "evaluator": { + "strategy": "test_on_output", + "config": { + "repo_dir": "/polar/session/workspace", + "patch_command": "cd /polar/session/workspace && git add -A && git diff --cached --binary", + "test_command": "cd /polar/session/workspace && python3 test_calculator.py && echo 'PASSED test_calculator'", + "test_timeout": 60.0, + "expected_output_json": {"test_calculator": "PASSED"}, + "exclude_patterns": [*_COMMON_EXCLUDES, *_EVAL_EXCLUDES[harness]], + }, + "refresh_runtime": True, + }, + } + + +def session_reward(session: dict[str, Any]) -> float | None: + traces = (session.get("trajectory") or {}).get("traces") or [] + reward = traces[-1].get("reward") if traces else None + return float(reward) if isinstance(reward, (int, float)) else None + + +def print_comparison(finished: dict[str, dict[str, Any]], elapsed: float) -> None: + header = f"{'Harness':<16} {'Reward':>8} {'Done':>6}" + print("\n" + "=" * len(header)) + print(header) + print("-" * len(header)) + for harness, result in finished.items(): + sessions = result.get("results") or [] + rewards = [r for r in (session_reward(s) for s in sessions) if r is not None] + mean = sum(rewards) / len(rewards) if rewards else 0.0 + done = sum(1 for s in sessions if s.get("status") == "COMPLETED") + print(f"{harness:<16} {mean:>8.3f} {done:>2}/{len(sessions):<2}") + print("=" * len(header)) + print(f"Wall time: {elapsed:.0f}s") + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--backend", choices=["docker", "apptainer"], default="docker") + backend = parser.parse_args().backend + + from polar.config import TopologyConfig + + rollout_url = TopologyConfig.load(TOPOLOGY).rollout.public_url + batch_id = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") + + print(f"Submitting {len(HARNESSES)} harnesses to {rollout_url} (backend={backend})") + timeout = httpx.Timeout(None, connect=30.0) + with httpx.Client(base_url=rollout_url, timeout=timeout) as client: + task_ids: dict[str, str] = {} + for harness in HARNESSES: + payload = build_task_payload(harness, batch_id, backend) + resp = client.post("/rollout/task/submit", json=payload) + resp.raise_for_status() + task_ids[harness] = resp.json()["task_id"] + print(f" {harness:<16} -> {task_ids[harness]}") + + print(f"\nPolling every {POLL_INTERVAL_SECONDS:.0f}s (watch live in the dashboard) ...") + t0 = time.monotonic() + finished: dict[str, dict[str, Any]] = {} + while len(finished) < len(HARNESSES): + time.sleep(POLL_INTERVAL_SECONDS) + for harness, tid in task_ids.items(): + if harness in finished: + continue + status = client.get(f"/rollout/task/{tid}").json() + if status["status"] != "running": + finished[harness] = status + print(f" [{time.monotonic() - t0:>5.0f}s] {harness} done") + elapsed = time.monotonic() - t0 + + print_comparison(finished, elapsed) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/calculator/submit_all.py b/examples/calculator/submit_all.py deleted file mode 100644 index b35985325..000000000 --- a/examples/calculator/submit_all.py +++ /dev/null @@ -1,161 +0,0 @@ -#!/usr/bin/env python3 -"""Submit one calculator task to every supported harness.""" - -from __future__ import annotations - -import argparse -import sys -import time -from datetime import UTC, datetime -from typing import Any - -import httpx - -from submit_calculator_task import ( - DEFAULT_BACKEND, - DEFAULT_NUM_SAMPLES, - DEFAULT_TOPOLOGY, - EXAMPLE_DIR, - SUPPORTED_HARNESSES, - build_task_payload, - summarize_result, - write_json, -) - -POLL_INTERVAL_SECONDS = 10.0 - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument( - "--backend", - choices=["docker", "apptainer"], - default=DEFAULT_BACKEND, - help="Runtime backend. Defaults to docker.", - ) - return parser.parse_args() - - -def resolve_rollout_url() -> str: - from polar.config import TopologyConfig - - topo = TopologyConfig.load(DEFAULT_TOPOLOGY) - return topo.rollout.public_url - - -def print_combined_summary( - results: dict[str, dict[str, Any]], - summaries: dict[str, dict[str, Any]], - elapsed: float, -) -> None: - header = f"{'Harness':<16} {'Rewards':<28} {'Mean':>6} {'Done':>6} {'Err':>4}" - print("\n" + "=" * len(header)) - print(header) - print("-" * len(header)) - for harness in results: - s = summaries[harness] - rtext = ", ".join( - "n/a" if r is None else f"{r:.1f}" for r in s["rewards"] - ) - print( - f"{harness:<16} [{rtext:<26}] " - f"{s['reward_mean']:>5.3f} " - f"{s['completed_sessions']:>2}/{s['total_sessions']:<2} " - f"{s['errors'] or '':>4}" - ) - print("-" * len(header)) - all_rewards = [r for s in summaries.values() for r in s["rewards"] if r is not None] - total_done = sum(s["completed_sessions"] for s in summaries.values()) - total_all = sum(s["total_sessions"] for s in summaries.values()) - mean = sum(all_rewards) / max(1, len(all_rewards)) - print(f"{'TOTAL':<16} {'':28} {mean:>5.3f} {total_done:>2}/{total_all:<2}") - print(f"Wall time: {elapsed:.0f}s") - print("=" * len(header)) - - -def main() -> int: - args = parse_args() - harnesses = list(SUPPORTED_HARNESSES) - batch_id = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") - rollout_url = resolve_rollout_url() - batch_dir = EXAMPLE_DIR / "batches" / batch_id - - n_total = len(harnesses) * DEFAULT_NUM_SAMPLES - print( - f"Submitting {len(harnesses)} harnesses x {DEFAULT_NUM_SAMPLES} " - f"samples = {n_total} sessions" - ) - print(f"Rollout URL: {rollout_url}") - print(f"Runtime backend: {args.backend}") - - # 1. Build and submit all tasks (async endpoint) - timeout = httpx.Timeout(None, connect=30.0) - task_ids: dict[str, str] = {} # harness -> task_id - - with httpx.Client(base_url=rollout_url, timeout=timeout) as client: - for harness in harnesses: - payload = build_task_payload(harness, batch_id, backend=args.backend) - out_dir = batch_dir / harness - write_json(out_dir / "request.json", payload) - - resp = client.post("/rollout/task/submit", json=payload) - resp.raise_for_status() - data = resp.json() - task_ids[harness] = data["task_id"] - print(f" {harness:<16} -> {data['task_id']}") - - # 2. Poll until all tasks finish - print(f"\nPolling every {POLL_INTERVAL_SECONDS:.0f}s ...") - t0 = time.monotonic() - finished: dict[str, dict[str, Any]] = {} - - while len(finished) < len(harnesses): - time.sleep(POLL_INTERVAL_SECONDS) - sessions_done = sum(s["completed_sessions"] for s in finished.values()) - newly_done: list[str] = [] - for harness, tid in task_ids.items(): - if harness in finished: - continue - resp = client.get(f"/rollout/task/{tid}") - resp.raise_for_status() - task_status = resp.json() - sessions_done += task_status["completed_sessions"] - if task_status["status"] != "running": - finished[harness] = task_status - newly_done.append(harness) - - elapsed = time.monotonic() - t0 - if newly_done: - # Clear progress line then print completion - sys.stdout.write("\r" + " " * 60 + "\r") - for h in newly_done: - d = finished[h]["completed_sessions"] - t = finished[h]["total_sessions"] - print(f" [{elapsed:>5.0f}s] {h:<16} done ({d}/{t})") - else: - sys.stdout.write( - f"\r [{elapsed:>5.0f}s] {sessions_done}/{n_total} sessions, " - f"{len(finished)}/{len(harnesses)} tasks done" - ) - sys.stdout.flush() - - elapsed = time.monotonic() - t0 - print() - - # 3. Save results and print summary - summaries: dict[str, dict[str, Any]] = {} - for harness in harnesses: - result = finished[harness] - out_dir = batch_dir / harness - write_json(out_dir / "response.json", result) - summary = summarize_result(result) - write_json(out_dir / "summary.json", summary) - summaries[harness] = summary - - print_combined_summary(finished, summaries, elapsed) - print(f"\nResults saved to {batch_dir}") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/examples/calculator/submit_calculator_task.py b/examples/calculator/submit_calculator_task.py deleted file mode 100644 index 95be87d23..000000000 --- a/examples/calculator/submit_calculator_task.py +++ /dev/null @@ -1,317 +0,0 @@ -#!/usr/bin/env python3 -"""Submit one calculator rollout through the local Polar services.""" - -from __future__ import annotations - -import argparse -import json -import subprocess -import sys -from datetime import UTC, datetime -from pathlib import Path -from typing import Any - -EXAMPLE_DIR = Path(__file__).resolve().parent -ASSETS_DIR = EXAMPLE_DIR / "assets" -TEST_FILE = ASSETS_DIR / "test_calculator.py" -STARTER_FILE = ASSETS_DIR / "calculator.py" -DEFAULT_TOPOLOGY = EXAMPLE_DIR / "topology.yaml" -DEFAULT_IMAGE = "polar-localhost-calculator:latest" -DEFAULT_BACKEND = "docker" -DEFAULT_NUM_SAMPLES = 1 -DEFAULT_TIMEOUT_SECONDS = 600.0 -SUPPORTED_HARNESSES = ( - "claude_code", - "codex", - "gemini_cli", - "opencode", - "pi", - "qwen_code", -) - -BASE_INSTRUCTION = """\ -`calculator.py` has a `Calculator` class with a tokenizer and three stub methods. -Each stub is marked with a `# TODO` comment and returns `0`. - -Implement the three methods to build a recursive-descent expression parser: - -1. `_parse_expr` — handle `+` and `-` by calling `_parse_term` -2. `_parse_term` — handle `*` and `/` (integer division) by calling `_parse_factor` -3. `_parse_factor` — handle integer literals and parenthesized sub-expressions - -Also fix `__call__` to return the parsed value instead of `0`. - -Requirements: -- Work only in `/polar/session/workspace/calculator.py`. -- Keep the existing file structure, `_tokenize`, `_peek`, and `_consume` as-is. -- Do not add imports. -- Use `//` for division (integer division). -- You must make actual edits. An empty git diff fails the task. - -After editing, run `python3 test_calculator.py` and stop. - -These checks must pass exactly: -- `cal("4*3-3") == 9` -- `cal("(2+3)*4") == 20` -- `cal("10/2+7") == 12` -- `cal("18-(3*4)") == 6` -- `cal(" 8 + 2 * 5 ") == 18` -""" - -# Pinned versions keep the quickstart stable. Bump intentionally. -NODE_HARNESS_PACKAGES: dict[str, str] = { - "claude_code": "@anthropic-ai/claude-code@2.1.111", - "codex": "@openai/codex@0.121.0", - "gemini_cli": "@google/gemini-cli@0.38.1", - "opencode": "opencode-ai@1.4.6", - "pi": "@mariozechner/pi-coding-agent@0.67.68", - "qwen_code": "@qwen-code/qwen-code@0.14.5", -} - -WORKSPACE_PREPARE = ( - "rm -rf /polar/session/workspace && " - "mkdir -p /polar/session/workspace /polar/session/logs/agent && " - "cd /polar/session/workspace && " - "git init -q && " - "git config user.email 'polar@test' && " - "git config user.name 'Polar'" -) - - -def prepare_command_for_harness(harness: str) -> str: - install_command = "" - if harness in NODE_HARNESS_PACKAGES: - install_command = f'npm install -g {NODE_HARNESS_PACKAGES[harness]} && ' - return install_command + WORKSPACE_PREPARE - - -# Common stray artifacts that can end up in cwd regardless of harness. -# The evaluator already skips __pycache__, *.pyc, *.pyo, .pytest_cache. -_COMMON_EVAL_EXCLUDES: list[str] = [ - "node_modules/**", - "**/node_modules/**", - ".cache/**", - "**/.cache/**", - ".venv/**", - "**/.venv/**", -] - -# Per-harness config / session dirs that can leak into the workspace git diff. -_HARNESS_EVAL_EXCLUDES: dict[str, list[str]] = { - "claude_code": [".claude/**", "**/.claude/**"], - "codex": [".codex/**", "**/.codex/**"], - "gemini_cli": [".gemini/**", "**/.gemini/**"], - "opencode": [".opencode/**", "**/.opencode/**", ".config/opencode/**"], - "pi": [".pi/**", "**/.pi/**"], - "qwen_code": [".qwen/**", "**/.qwen/**"], -} - - -def evaluator_exclude_patterns_for_harness(harness: str) -> list[str]: - return [*_COMMON_EVAL_EXCLUDES, *_HARNESS_EVAL_EXCLUDES.get(harness, [])] - - -def model_name_for_harness(harness: str) -> str | None: - defaults = { - "codex": "gpt-5.4", - "claude_code": "claude-opus-4-5", - "gemini_cli": "gemini-2.5-flash-lite", - "opencode": "openai/gpt-5.4", - "pi": "openai/gpt-5.4", - "qwen_code": "qwen3-coder-plus", - } - return defaults.get(harness) - - -def agent_spec_for_harness(harness: str) -> dict[str, Any]: - spec: dict[str, Any] = {"harness": harness} - model_name = model_name_for_harness(harness) - if model_name is not None: - spec["model_name"] = model_name - return spec - - -def builder_spec_for_harness(harness: str) -> dict[str, Any]: - if harness not in SUPPORTED_HARNESSES: - raise ValueError(f"Unsupported harness: {harness}") - return {"strategy": "prefix_merging"} - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument( - "harness", - nargs="?", - choices=SUPPORTED_HARNESSES, - default="claude_code", - help="Harness to run. Defaults to claude_code.", - ) - parser.add_argument( - "--backend", - choices=["docker", "apptainer"], - default=DEFAULT_BACKEND, - help="Runtime backend. Defaults to docker.", - ) - return parser.parse_args() - - -def build_task_payload( - harness: str, - batch_id: str, - *, - backend: str = DEFAULT_BACKEND, -) -> dict[str, Any]: - test_file_abs = str(TEST_FILE.resolve()) - starter_file_abs = str(STARTER_FILE.resolve()) - runtime_image = runtime_image_for_backend(DEFAULT_IMAGE, backend) - return { - "task_id": f"calculator-{harness}-{batch_id}", - "instruction": BASE_INSTRUCTION, - "num_samples": DEFAULT_NUM_SAMPLES, - "timeout_seconds": DEFAULT_TIMEOUT_SECONDS, - "runtime": { - "backend": backend, - "image": runtime_image, - "prepare": [ - { - "type": "exec", - "command": prepare_command_for_harness(harness), - }, - { - "type": "upload_file", - "source": test_file_abs, - "target": "/polar/session/workspace/test_calculator.py", - }, - { - "type": "upload_file", - "source": starter_file_abs, - "target": "/polar/session/workspace/calculator.py", - }, - { - "type": "exec", - "command": ( - "cd /polar/session/workspace && " - "git add -A && git commit -qm 'initial'" - ), - }, - ], - "network": "host", - "workdir": "/polar/session/workspace", - }, - "agent": agent_spec_for_harness(harness), - "builder": builder_spec_for_harness(harness), - "evaluator": { - "strategy": "test_on_output", - "config": { - "repo_dir": "/polar/session/workspace", - "patch_command": ( - "cd /polar/session/workspace && " - "git add -A && git diff --cached --binary" - ), - "test_command": ( - "cd /polar/session/workspace && " - "python3 test_calculator.py && echo 'PASSED test_calculator'" - ), - "test_timeout": 60.0, - "expected_output_json": {"test_calculator": "PASSED"}, - "exclude_patterns": evaluator_exclude_patterns_for_harness(harness), - }, - "refresh_runtime": True, - }, - } - - -def runtime_image_for_backend(image: str, backend: str) -> str: - if backend != "apptainer": - return image - if image.startswith(("docker-daemon:", "docker://", "oras://")): - return image - return f"docker-daemon:{image}" - - -def write_json(path: Path, payload: Any) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(payload, indent=2, ensure_ascii=True, sort_keys=True)) - - -def summarize_result(response: dict[str, Any]) -> dict[str, Any]: - sessions = response.get("results") or [] - rewards: list[float | None] = [] - completed = 0 - errors = 0 - for session in sessions: - if session.get("status") == "COMPLETED": - completed += 1 - if session.get("error"): - errors += 1 - trajectory = session.get("trajectory") or {} - if trajectory.get("status") == "ERROR" or trajectory.get("error"): - errors += 1 - traces = trajectory.get("traces") or [] - reward = traces[-1].get("reward") if traces else None - rewards.append(float(reward) if isinstance(reward, (int, float)) else None) - return { - "completed_sessions": completed, - "errors": errors, - "rewards": rewards, - "reward_mean": ( - sum(reward for reward in rewards if reward is not None) - / max(1, sum(1 for reward in rewards if reward is not None)) - ), - "total_sessions": len(sessions), - } - - -def print_reward_summary(harness: str, summary: dict[str, Any]) -> None: - reward_text = ", ".join( - "n/a" if reward is None else f"{reward:.1f}" - for reward in summary["rewards"] - ) - print("\nReward summary") - print(f"Harness: {harness}") - print(f"Rewards: [{reward_text}]") - print(f"Mean: {summary['reward_mean']:.3f}") - print(f"Completed: {summary['completed_sessions']}/{summary['total_sessions']}") - if summary["errors"]: - print(f"Errors: {summary['errors']}") - - -def main() -> int: - args = parse_args() - batch_id = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") - payload = build_task_payload(args.harness, batch_id, backend=args.backend) - output_dir = EXAMPLE_DIR / "batches" / batch_id / args.harness - request_path = output_dir / "request.json" - response_path = output_dir / "response.json" - write_json(request_path, payload) - print(f"Wrote request to {request_path}") - - command = [ - sys.executable, - "-m", - "polar.cli", - "submit", - str(request_path), - "-c", - str(DEFAULT_TOPOLOGY), - "--json", - ] - - completed = subprocess.run( - command, - check=True, - capture_output=True, - text=True, - ) - result = json.loads(completed.stdout) - write_json(response_path, result) - print(f"Task completed. Wrote response to {response_path}") - summary = summarize_result(result) - write_json(output_dir / "summary.json", summary) - print_reward_summary(args.harness, summary) - - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/examples/calculator/topology.yaml b/examples/calculator/topology.yaml index e723e5b75..d0ce92b96 100644 --- a/examples/calculator/topology.yaml +++ b/examples/calculator/topology.yaml @@ -14,8 +14,9 @@ gateway: max_init_workers: 8 max_run_workers: 8 max_postrun_workers: 8 - model_served: Qwen/Qwen3.5-4B - sglang: + model_served: Qwen/Qwen3.6-27B + inference: + engine: vllm base_url: http://127.0.0.1:8000 - id: localhost-node-02 host: 127.0.0.1 @@ -24,6 +25,7 @@ gateway: max_init_workers: 8 max_run_workers: 8 max_postrun_workers: 8 - model_served: Qwen/Qwen3.5-4B - sglang: + model_served: Qwen/Qwen3.6-27B + inference: + engine: vllm base_url: http://127.0.0.1:8001 diff --git a/examples/count_stars/README.md b/examples/count_stars/README.md index 25cb5be26..f260e48b8 100644 --- a/examples/count_stars/README.md +++ b/examples/count_stars/README.md @@ -1,101 +1,59 @@ # Count Stars Example -This is a small Polar rollout example with an image file in the task workspace. -Each supported harness gets the same file at -`/polar/session/workspace/polar_stars.png` and is asked to count the visible -stars. +A minimal image-input (VLM) run. Each harness gets the same +`polar_stars.png` in its workspace, inspects it, and writes the number of +visible stars to `answer.txt`. Use it to check that harnesses can work from an +image through the local vLLM OpenAI-compatible backend. -Use this example to check that coding-agent harnesses can work from an image -path in the runtime workspace through the local SGLang OpenAI-compatible -backend. +## Prerequisites -The topology and model setup match the calculator example. +Install **Polar** and **vLLM** as described in the [top-level README](../../README.md#installation). +This example uses 1 node 8×B200 — two vLLM servers (tensor-parallel 4 each). +Adjust the setup and topology for your hardware. -## What It Runs +## Quick Start -- rollout server on `:8080` -- two gateway nodes on `:8100` and `:8101` -- two local SGLang backends on `:8000` and `:8001` -- one shared runtime image: `polar-localhost-count-stars:latest` -- three harnesses: `claude_code`, `codex`, `gemini_cli` -- evaluator: `session_completed` - -The default scripts use Docker. Apptainer is also supported with -`--backend apptainer`. - -## Setup - -From the repo root: - -```bash -uv venv -uv pip install -e . -uv pip install --prerelease=allow sglang==0.5.10 -bash scripts/patch/patch_sglang.sh -``` - -Build the runtime image once: +### 1. Build the runtime image (once) ```bash uv run python examples/count_stars/build_image.py ``` -## Start Services - -Start two SGLang servers, one per GPU: - -Use the Qwen tool-call parser so `` responses are returned as -structured tool calls instead of plain assistant text. +### 2. Start two vLLM servers ```bash -CUDA_VISIBLE_DEVICES=0 uv run python -m sglang.launch_server \ - --model-path Qwen/Qwen3.6-27B \ - --host 127.0.0.1 \ - --port 8000 \ - --context-length 262144 \ - --tool-call-parser qwen3_coder \ - --reasoning-parser qwen3 \ - --trust-remote-code -``` +CUDA_VISIBLE_DEVICES=0,1,2,3 vllm serve Qwen/Qwen3.6-27B --port 8000 \ + --tensor-parallel-size 4 --max-model-len 262144 \ + --reasoning-parser qwen3 --enable-auto-tool-choice --tool-call-parser qwen3_coder -```bash -CUDA_VISIBLE_DEVICES=1 uv run python -m sglang.launch_server \ - --model-path Qwen/Qwen3.6-27B \ - --host 127.0.0.1 \ - --port 8001 \ - --context-length 262144 \ - --tool-call-parser qwen3_coder \ - --reasoning-parser qwen3 \ - --trust-remote-code +CUDA_VISIBLE_DEVICES=4,5,6,7 vllm serve Qwen/Qwen3.6-27B --port 8001 \ + --tensor-parallel-size 4 --max-model-len 262144 \ + --reasoning-parser qwen3 --enable-auto-tool-choice --tool-call-parser qwen3_coder ``` -If Hugging Face download access is rate-limited, set `HF_TOKEN` before starting -SGLang. - -Start Polar: +### 3. Start Polar Servers ```bash uv run polar serve_rollout -c examples/count_stars/topology.yaml -``` - -```bash uv run polar serve_gateway -c examples/count_stars/topology.yaml --node-id localhost-node-01 -``` - -```bash uv run polar serve_gateway -c examples/count_stars/topology.yaml --node-id localhost-node-02 ``` -## Run +### 4. Run -Run every harness: +Submits example harness at once and prints a completion comparison: ```bash -uv run python examples/count_stars/submit_all.py +uv run python examples/count_stars/run.py ``` -Run one harness: +Use Apptainer instead of Docker with `--backend apptainer`. + +### 5. (Optional) Watch in the dashboard ```bash -uv run python examples/count_stars/submit_count_stars_task.py codex +uv run polar dashboard -c examples/count_stars/topology.yaml ``` + +Open to inspect each harness's image reasoning and +the answer it wrote. diff --git a/examples/count_stars/run.py b/examples/count_stars/run.py new file mode 100644 index 000000000..96e3affce --- /dev/null +++ b/examples/count_stars/run.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +"""Run the count-stars demo across every harness and print a comparison table. + +Each harness gets the same image at `/polar/session/workspace/polar_stars.png`, +inspects it, and writes its star count to `answer.txt`. This exercises image +input through the local vLLM OpenAI-compatible backend. All harnesses are +submitted at once; per-session detail is visible in the dashboard +(`polar dashboard -c examples/count_stars/topology.yaml`). + + uv run python examples/count_stars/run.py # docker (default) + uv run python examples/count_stars/run.py --backend apptainer +""" + +from __future__ import annotations + +import argparse +import sys +import time +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +import httpx + +EXAMPLE_DIR = Path(__file__).resolve().parent +IMAGE_FILE = EXAMPLE_DIR / "assets" / "polar_stars.png" +TOPOLOGY = EXAMPLE_DIR / "topology.yaml" +RUNTIME_IMAGE = "polar-localhost-count-stars:latest" +RUNTIME_IMAGE_PATH = "/polar/session/workspace/polar_stars.png" +NUM_SAMPLES = 4 +TIMEOUT_SECONDS = 300.0 +POLL_INTERVAL_SECONDS = 10.0 + +HARNESSES = ("claude_code", "codex", "gemini_cli") + +INSTRUCTION = """\ +Use your image viewing tool to inspect `/polar/session/workspace/polar_stars.png`. +Count the visible stars in that image. + +Write the answer as a single integer line to `/polar/session/workspace/answer.txt`. +Do not write any other text to that file. Stop after writing the file. +""" + +# Pinned versions keep the quickstart stable. Bump intentionally. +HARNESS_NPM_PACKAGE: dict[str, str] = { + "claude_code": "@anthropic-ai/claude-code@2.1.111", + "codex": "@openai/codex@0.121.0", + "gemini_cli": "@google/gemini-cli@0.38.1", +} + +# Model name the harness CLI sends; the gateway rewrites it to the served model. +HARNESS_MODEL: dict[str, str] = { + "claude_code": "claude-opus-4-5", + "codex": "gpt-5.4", + "gemini_cli": "gemini-2.5-flash-lite", +} + +# INIT stage: install the harness CLI, then set up a clean git workspace. +_WORKSPACE_PREPARE = ( + "rm -rf /polar/session/workspace && " + "mkdir -p /polar/session/workspace /polar/session/logs/agent && " + "cd /polar/session/workspace && " + "git init -q && " + "git config user.email 'polar@test' && " + "git config user.name 'Polar'" +) + + +def runtime_image_for_backend(backend: str) -> str: + if backend == "apptainer": + return f"docker-daemon:{RUNTIME_IMAGE}" + return RUNTIME_IMAGE + + +def build_task_payload(harness: str, batch_id: str, backend: str) -> dict[str, Any]: + return { + "task_id": f"count-stars-{harness}-{batch_id}", + "instruction": INSTRUCTION, + "num_samples": NUM_SAMPLES, + "timeout_seconds": TIMEOUT_SECONDS, + "runtime": { + "backend": backend, + "image": runtime_image_for_backend(backend), + "prepare": [ + {"type": "exec", "command": f"npm install -g {HARNESS_NPM_PACKAGE[harness]} && {_WORKSPACE_PREPARE}"}, + {"type": "upload_file", "source": str(IMAGE_FILE), "target": RUNTIME_IMAGE_PATH}, + ], + "network": "host", + "workdir": "/polar/session/workspace", + }, + "agent": {"harness": harness, "model_name": HARNESS_MODEL[harness]}, + "builder": {"strategy": "prefix_merging"}, + "evaluator": {"strategy": "session_completed"}, + } + + +def print_comparison(finished: dict[str, dict[str, Any]], elapsed: float) -> None: + header = f"{'Harness':<16} {'Completed':>10}" + print("\n" + "=" * len(header)) + print(header) + print("-" * len(header)) + for harness, result in finished.items(): + sessions = result.get("results") or [] + done = sum(1 for s in sessions if s.get("status") == "COMPLETED") + print(f"{harness:<16} {done:>5}/{len(sessions):<4}") + print("=" * len(header)) + print(f"Wall time: {elapsed:.0f}s") + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--backend", choices=["docker", "apptainer"], default="docker") + backend = parser.parse_args().backend + + from polar.config import TopologyConfig + + rollout_url = TopologyConfig.load(TOPOLOGY).rollout.public_url + batch_id = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") + + print(f"Submitting {len(HARNESSES)} harnesses to {rollout_url} (backend={backend})") + timeout = httpx.Timeout(None, connect=30.0) + with httpx.Client(base_url=rollout_url, timeout=timeout) as client: + task_ids: dict[str, str] = {} + for harness in HARNESSES: + payload = build_task_payload(harness, batch_id, backend) + resp = client.post("/rollout/task/submit", json=payload) + resp.raise_for_status() + task_ids[harness] = resp.json()["task_id"] + print(f" {harness:<16} -> {task_ids[harness]}") + + print(f"\nPolling every {POLL_INTERVAL_SECONDS:.0f}s (watch live in the dashboard) ...") + t0 = time.monotonic() + finished: dict[str, dict[str, Any]] = {} + while len(finished) < len(HARNESSES): + time.sleep(POLL_INTERVAL_SECONDS) + for harness, tid in task_ids.items(): + if harness in finished: + continue + status = client.get(f"/rollout/task/{tid}").json() + if status["status"] != "running": + finished[harness] = status + print(f" [{time.monotonic() - t0:>5.0f}s] {harness} done") + elapsed = time.monotonic() - t0 + + print_comparison(finished, elapsed) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/count_stars/submit_all.py b/examples/count_stars/submit_all.py deleted file mode 100644 index e22bd1a3c..000000000 --- a/examples/count_stars/submit_all.py +++ /dev/null @@ -1,156 +0,0 @@ -#!/usr/bin/env python3 -"""Submit one count_stars task to every supported harness.""" - -from __future__ import annotations - -import argparse -import sys -import time -from datetime import UTC, datetime -from typing import Any - -import httpx - -from submit_count_stars_task import ( - DEFAULT_BACKEND, - DEFAULT_NUM_SAMPLES, - DEFAULT_TOPOLOGY, - EXAMPLE_DIR, - SUPPORTED_HARNESSES, - build_task_payload, - summarize_result, - write_json, -) - -POLL_INTERVAL_SECONDS = 10.0 - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument( - "--backend", - choices=["docker", "apptainer"], - default=DEFAULT_BACKEND, - help="Runtime backend. Defaults to docker.", - ) - return parser.parse_args() - - -def resolve_rollout_url() -> str: - from polar.config import TopologyConfig - - topo = TopologyConfig.load(DEFAULT_TOPOLOGY) - return topo.rollout.public_url - - -def print_combined_summary( - results: dict[str, dict[str, Any]], - summaries: dict[str, dict[str, Any]], - elapsed: float, -) -> None: - header = f"{'Harness':<16} {'Rewards':<28} {'Mean':>6} {'Done':>6} {'Err':>4}" - print("\n" + "=" * len(header)) - print(header) - print("-" * len(header)) - for harness in results: - s = summaries[harness] - rtext = ", ".join("n/a" if r is None else f"{r:.1f}" for r in s["rewards"]) - print( - f"{harness:<16} [{rtext:<26}] " - f"{s['reward_mean']:>5.3f} " - f"{s['completed_sessions']:>2}/{s['total_sessions']:<2} " - f"{s['errors'] or '':>4}" - ) - print("-" * len(header)) - all_rewards = [r for s in summaries.values() for r in s["rewards"] if r is not None] - total_done = sum(s["completed_sessions"] for s in summaries.values()) - total_all = sum(s["total_sessions"] for s in summaries.values()) - mean = sum(all_rewards) / max(1, len(all_rewards)) - print(f"{'TOTAL':<16} {'':28} {mean:>5.3f} {total_done:>2}/{total_all:<2}") - print(f"Wall time: {elapsed:.0f}s") - print("=" * len(header)) - - -def main() -> int: - args = parse_args() - harnesses = list(SUPPORTED_HARNESSES) - batch_id = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") - rollout_url = resolve_rollout_url() - batch_dir = EXAMPLE_DIR / "batches" / batch_id - - n_total = len(harnesses) * DEFAULT_NUM_SAMPLES - print( - f"Submitting {len(harnesses)} harnesses x {DEFAULT_NUM_SAMPLES} " - f"samples = {n_total} sessions" - ) - print(f"Rollout URL: {rollout_url}") - print(f"Runtime backend: {args.backend}") - - timeout = httpx.Timeout(None, connect=30.0) - task_ids: dict[str, str] = {} - - with httpx.Client(base_url=rollout_url, timeout=timeout) as client: - for harness in harnesses: - payload = build_task_payload(harness, batch_id, backend=args.backend) - out_dir = batch_dir / harness - write_json(out_dir / "request.json", payload) - - resp = client.post("/rollout/task/submit", json=payload) - resp.raise_for_status() - data = resp.json() - task_ids[harness] = data["task_id"] - print(f" {harness:<16} -> {data['task_id']}") - - print(f"\nPolling every {POLL_INTERVAL_SECONDS:.0f}s ...") - t0 = time.monotonic() - finished: dict[str, dict[str, Any]] = {} - - while len(finished) < len(harnesses): - time.sleep(POLL_INTERVAL_SECONDS) - sessions_done = sum(s["completed_sessions"] for s in finished.values()) - newly_done: list[str] = [] - for harness, tid in task_ids.items(): - if harness in finished: - continue - resp = client.get(f"/rollout/task/{tid}") - resp.raise_for_status() - task_status = resp.json() - sessions_done += task_status["completed_sessions"] - if task_status["status"] != "running": - finished[harness] = task_status - newly_done.append(harness) - - elapsed = time.monotonic() - t0 - if newly_done: - sys.stdout.write("\r" + " " * 60 + "\r") - for harness in newly_done: - done = finished[harness]["completed_sessions"] - total = finished[harness]["total_sessions"] - print(f" [{elapsed:>5.0f}s] {harness:<16} done ({done}/{total})") - else: - sys.stdout.write( - f"\r [{elapsed:>5.0f}s] {sessions_done}/{n_total} sessions, " - f"{len(finished)}/{len(harnesses)} tasks done" - ) - sys.stdout.flush() - - elapsed = time.monotonic() - t0 - print() - - summaries: dict[str, dict[str, Any]] = {} - for harness in harnesses: - result = finished[harness] - out_dir = batch_dir / harness - write_json(out_dir / "response.json", result) - summary = summarize_result(result) - write_json(out_dir / "summary.json", summary) - summaries[harness] = summary - - print_combined_summary(finished, summaries, elapsed) - print(f"\nResults saved to {batch_dir}") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) - diff --git a/examples/count_stars/submit_count_stars_task.py b/examples/count_stars/submit_count_stars_task.py deleted file mode 100644 index 3f528bf72..000000000 --- a/examples/count_stars/submit_count_stars_task.py +++ /dev/null @@ -1,235 +0,0 @@ -#!/usr/bin/env python3 -"""Submit one count_stars rollout through the local Polar services.""" - -from __future__ import annotations - -import argparse -import json -import subprocess -import sys -from datetime import UTC, datetime -from pathlib import Path -from typing import Any - -EXAMPLE_DIR = Path(__file__).resolve().parent -ASSETS_DIR = EXAMPLE_DIR / "assets" -IMAGE_FILE = ASSETS_DIR / "polar_stars.png" -DEFAULT_TOPOLOGY = EXAMPLE_DIR / "topology.yaml" -DEFAULT_IMAGE = "polar-localhost-count-stars:latest" -DEFAULT_BACKEND = "docker" -DEFAULT_NUM_SAMPLES = 1 -DEFAULT_TIMEOUT_SECONDS = 300.0 -RUNTIME_IMAGE_PATH = "/polar/session/workspace/polar_stars.png" -SUPPORTED_HARNESSES = ( - "claude_code", - "codex", - "gemini_cli", -) - -TASK_BODY = """\ -Use your image viewing tool to inspect `/polar/session/workspace/polar_stars.png`. -Count the visible stars in that image. - -Write the answer as a single integer line to `/polar/session/workspace/answer.txt`. -Do not write any other text to that file. Stop after writing the file. -""" - -NODE_HARNESS_PACKAGES: dict[str, str] = { - "claude_code": "@anthropic-ai/claude-code@2.1.111", - "codex": "@openai/codex@0.121.0", - "gemini_cli": "@google/gemini-cli@0.38.1", -} - -WORKSPACE_PREPARE = ( - "rm -rf /polar/session/workspace && " - "mkdir -p /polar/session/workspace /polar/session/logs/agent && " - "cd /polar/session/workspace && " - "git init -q && " - "git config user.email 'polar@test' && " - "git config user.name 'Polar'" -) - - -def prepare_command_for_harness(harness: str) -> str: - install_command = "" - if harness in NODE_HARNESS_PACKAGES: - install_command = f"npm install -g {NODE_HARNESS_PACKAGES[harness]} && " - return install_command + WORKSPACE_PREPARE - - -def instruction_for_harness(harness: str) -> str: - if harness not in SUPPORTED_HARNESSES: - raise ValueError(f"Unsupported harness: {harness}") - return TASK_BODY - - -def model_name_for_harness(harness: str) -> str | None: - defaults = { - "codex": "gpt-5.4", - "claude_code": "claude-opus-4-5", - "gemini_cli": "gemini-2.5-flash-lite", - } - return defaults.get(harness) - - -def agent_spec_for_harness(harness: str) -> dict[str, Any]: - spec: dict[str, Any] = {"harness": harness} - model_name = model_name_for_harness(harness) - if model_name is not None: - spec["model_name"] = model_name - return spec - - -def builder_spec_for_harness(harness: str) -> dict[str, Any]: - if harness not in SUPPORTED_HARNESSES: - raise ValueError(f"Unsupported harness: {harness}") - return {"strategy": "prefix_merging"} - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument( - "harness", - nargs="?", - choices=SUPPORTED_HARNESSES, - default="codex", - help="Harness to run. Defaults to codex.", - ) - parser.add_argument( - "--backend", - choices=["docker", "apptainer"], - default=DEFAULT_BACKEND, - help="Runtime backend. Defaults to docker.", - ) - return parser.parse_args() - - -def build_task_payload( - harness: str, - batch_id: str, - *, - backend: str = DEFAULT_BACKEND, -) -> dict[str, Any]: - image_file_abs = str(IMAGE_FILE.resolve()) - runtime_image = runtime_image_for_backend(DEFAULT_IMAGE, backend) - return { - "task_id": f"count-stars-{harness}-{batch_id}", - "instruction": instruction_for_harness(harness), - "num_samples": DEFAULT_NUM_SAMPLES, - "timeout_seconds": DEFAULT_TIMEOUT_SECONDS, - "runtime": { - "backend": backend, - "image": runtime_image, - "prepare": [ - { - "type": "exec", - "command": prepare_command_for_harness(harness), - }, - { - "type": "upload_file", - "source": image_file_abs, - "target": RUNTIME_IMAGE_PATH, - }, - ], - "network": "host", - "workdir": "/polar/session/workspace", - }, - "agent": agent_spec_for_harness(harness), - "builder": builder_spec_for_harness(harness), - "evaluator": {"strategy": "session_completed"}, - } - - -def runtime_image_for_backend(image: str, backend: str) -> str: - if backend != "apptainer": - return image - if image.startswith(("docker-daemon:", "docker://", "oras://")): - return image - return f"docker-daemon:{image}" - - -def write_json(path: Path, payload: Any) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(payload, indent=2, ensure_ascii=True, sort_keys=True)) - - -def summarize_result(response: dict[str, Any]) -> dict[str, Any]: - sessions = response.get("results") or [] - rewards: list[float | None] = [] - completed = 0 - errors = 0 - for session in sessions: - if session.get("status") == "COMPLETED": - completed += 1 - if session.get("error"): - errors += 1 - trajectory = session.get("trajectory") or {} - if trajectory.get("status") == "ERROR" or trajectory.get("error"): - errors += 1 - traces = trajectory.get("traces") or [] - reward = traces[-1].get("reward") if traces else None - rewards.append(float(reward) if isinstance(reward, (int, float)) else None) - return { - "completed_sessions": completed, - "errors": errors, - "rewards": rewards, - "reward_mean": ( - sum(reward for reward in rewards if reward is not None) - / max(1, sum(1 for reward in rewards if reward is not None)) - ), - "total_sessions": len(sessions), - } - - -def print_reward_summary(harness: str, summary: dict[str, Any]) -> None: - reward_text = ", ".join( - "n/a" if reward is None else f"{reward:.1f}" for reward in summary["rewards"] - ) - print("\nReward summary") - print(f"Harness: {harness}") - print(f"Rewards: [{reward_text}]") - print(f"Mean: {summary['reward_mean']:.3f}") - print(f"Completed: {summary['completed_sessions']}/{summary['total_sessions']}") - if summary["errors"]: - print(f"Errors: {summary['errors']}") - - -def main() -> int: - args = parse_args() - batch_id = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") - payload = build_task_payload(args.harness, batch_id, backend=args.backend) - output_dir = EXAMPLE_DIR / "batches" / batch_id / args.harness - request_path = output_dir / "request.json" - response_path = output_dir / "response.json" - write_json(request_path, payload) - print(f"Wrote request to {request_path}") - - command = [ - sys.executable, - "-m", - "polar.cli", - "submit", - str(request_path), - "-c", - str(DEFAULT_TOPOLOGY), - "--json", - ] - - completed = subprocess.run( - command, - check=True, - capture_output=True, - text=True, - ) - result = json.loads(completed.stdout) - write_json(response_path, result) - print(f"Task completed. Wrote response to {response_path}") - summary = summarize_result(result) - write_json(output_dir / "summary.json", summary) - print_reward_summary(args.harness, summary) - - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/examples/count_stars/topology.yaml b/examples/count_stars/topology.yaml index 50524767d..d0ce92b96 100644 --- a/examples/count_stars/topology.yaml +++ b/examples/count_stars/topology.yaml @@ -15,7 +15,8 @@ gateway: max_run_workers: 8 max_postrun_workers: 8 model_served: Qwen/Qwen3.6-27B - sglang: + inference: + engine: vllm base_url: http://127.0.0.1:8000 - id: localhost-node-02 host: 127.0.0.1 @@ -25,5 +26,6 @@ gateway: max_run_workers: 8 max_postrun_workers: 8 model_served: Qwen/Qwen3.6-27B - sglang: + inference: + engine: vllm base_url: http://127.0.0.1:8001 diff --git a/examples/swebench_verified/README.md b/examples/swebench_verified/README.md index 66d79fd4d..79537d57c 100644 --- a/examples/swebench_verified/README.md +++ b/examples/swebench_verified/README.md @@ -1,49 +1,45 @@ # SWE-bench Verified Example -Evaluate Polar agent harnesses on the full [SWE-bench Verified](https://huggingface.co/datasets/princeton-nlp/SWE-bench_Verified) benchmark (500 human-validated tasks). +Evaluate Polar agent harnesses on [SWE-bench Verified](https://huggingface.co/datasets/princeton-nlp/SWE-bench_Verified) +(500 human-validated tasks). Each task runs an agent inside a per-instance +container at the repo's `base_commit`, then grades the patch with the official +`swebench` harness. -Each task runs an agent inside a per-instance container with the repo at `base_commit`, then grades the resulting patch via `swebench.harness.grading`. +## Prerequisites -The topology setup is used on 4 x B200 GPUs. Adjust based on your hardware. - -## Installation +Install **Polar** + the SWE-bench extra and **vLLM** as described in the +[top-level README](../../README.md#installation): ```bash -uv venv uv pip install -e ".[swebench]" -uv pip install --prerelease=allow sglang==0.5.10 -bash scripts/patch/patch_sglang.sh ``` +This example assumes 1 node **8×B200** — two vLLM servers (tensor-parallel 4 each). + ## Quick Start -### 1. Start SGLang backends +### 1. Build runtime images + +Each runtime image layers Node.js on the per-instance SWE-bench image; harness +CLIs install at task time during the **INIT** stage. Build a subset first: + +```bash +uv run python examples/swebench_verified/build_images.py --max-tasks 10 # or no flag for all 500 +``` + +### 2. Start two vLLM servers ```bash -CUDA_VISIBLE_DEVICES=0,1 uv run python -m sglang.launch_server \ - --model-path Qwen/Qwen3.5-4B \ - --host 0.0.0.0 \ - --port 8000 \ - --tp-size 2 \ - --tool-call-parser qwen3_coder \ - --reasoning-parser qwen3 \ - --mem-fraction-static 0.7 \ - --context-length 262144 \ - --trust-remote-code - -CUDA_VISIBLE_DEVICES=2,3 uv run python -m sglang.launch_server \ - --model-path Qwen/Qwen3.5-4B \ - --host 0.0.0.0 \ - --port 8001 \ - --tp-size 2 \ - --tool-call-parser qwen3_coder \ - --reasoning-parser qwen3 \ - --mem-fraction-static 0.7 \ - --context-length 262144 \ - --trust-remote-code +CUDA_VISIBLE_DEVICES=0,1,2,3 vllm serve Qwen/Qwen3.6-27B --port 8000 \ + --tensor-parallel-size 4 --max-model-len 262144 \ + --reasoning-parser qwen3 --enable-auto-tool-choice --tool-call-parser qwen3_coder + +CUDA_VISIBLE_DEVICES=4,5,6,7 vllm serve Qwen/Qwen3.6-27B --port 8001 \ + --tensor-parallel-size 4 --max-model-len 262144 \ + --reasoning-parser qwen3 --enable-auto-tool-choice --tool-call-parser qwen3_coder ``` -### 2. Start Polar services +### 3. Start Polar ```bash uv run polar serve_rollout -c examples/swebench_verified/topology.yaml @@ -51,32 +47,29 @@ uv run polar serve_gateway -c examples/swebench_verified/topology.yaml --node-id uv run polar serve_gateway -c examples/swebench_verified/topology.yaml --node-id localhost-node-02 ``` -### 3. Build runtime images +### 4. Submit tasks + +Pick a harness and how many tasks to run; the resolved-rate summary prints to +the console when the batch finishes. Supported harnesses: `claude_code`, `codex`, `opencode`, `qwen_code`. + ```bash -# Build all 500 -uv run python examples/swebench_verified/build_images.py +# pass@1 over the first 10 tasks +uv run python examples/swebench_verified/submit_swebench_tasks.py --harness claude_code --max-tasks 10 + +# pass@8 over the first 10 tasks +uv run python examples/swebench_verified/submit_swebench_tasks.py --harness claude_code --max-tasks 10 --num-samples 8 -# Or build a subset -uv run python examples/swebench_verified/build_images.py --max-tasks 10 +# a single instance +uv run python examples/swebench_verified/submit_swebench_tasks.py --harness codex --instance-id django__django-15098 ``` -### 4. Submit tasks +Use Apptainer instead of Docker with `--runtime-backend apptainer`. + +### 5. (Optional) Watch in the dashboard ```bash -# Run all 500 tasks for pass@1 -uv run python examples/swebench_verified/submit_swebench_tasks.py \ - --harness claude_code \ - --topology examples/swebench_verified/topology.yaml \ - --runtime-backend docker \ - --num-samples 1 \ - --max-tasks 10 - -# pass@8 for first 10 tasks -uv run python examples/swebench_verified/submit_swebench_tasks.py \ - --harness claude_code \ - --topology examples/swebench_verified/topology.yaml \ - --runtime-backend docker \ - --num-samples 8 \ - --max-tasks 10 +uv run polar dashboard -c examples/swebench_verified/topology.yaml ``` + +Open for per-task patches, trajectories, and grading. diff --git a/examples/swebench_verified/submit_swebench_tasks.py b/examples/swebench_verified/submit_swebench_tasks.py index 2ce4edfe5..3fad97574 100644 --- a/examples/swebench_verified/submit_swebench_tasks.py +++ b/examples/swebench_verified/submit_swebench_tasks.py @@ -1,26 +1,29 @@ #!/usr/bin/env python3 -"""Submit SWE-bench Verified tasks through the Polar rollout server. +"""Submit SWE-bench Verified tasks to the Polar rollout server. -Usage: - python submit_swebench_tasks.py --harness opencode - python submit_swebench_tasks.py --harness codex --max-tasks 50 --num-samples 4 - python submit_swebench_tasks.py --harness claude_code --instance-id django__django-15098 +Each task runs an agent in a per-instance container and is graded by the +official `swebench` harness. Tasks are submitted at once; live progress and +per-session detail are visible in the dashboard +(`polar dashboard -c examples/swebench_verified/topology.yaml`). + + uv run python examples/swebench_verified/submit_swebench_tasks.py --harness claude_code --max-tasks 10 + uv run python examples/swebench_verified/submit_swebench_tasks.py --harness codex --max-tasks 50 --num-samples 4 + uv run python examples/swebench_verified/submit_swebench_tasks.py --harness claude_code --instance-id django__django-15098 """ from __future__ import annotations import argparse -import json import os import subprocess import sys -from concurrent.futures import ThreadPoolExecutor, as_completed -from datetime import UTC, datetime +import time from pathlib import Path from typing import Any +import httpx + from dataset import ( - DATASET_NAME, SUPPORTED_HARNESSES, load_swebench_verified, runtime_image_for_instance, @@ -29,14 +32,17 @@ EXAMPLE_DIR = Path(__file__).resolve().parent DEFAULT_TOPOLOGY = EXAMPLE_DIR / "topology.yaml" +POLL_INTERVAL_SECONDS = 15.0 +# Pinned versions keep the quickstart stable. Bump intentionally. HARNESS_NPM_PACKAGE: dict[str, str] = { - "codex": "@openai/codex@latest", - "opencode": "opencode-ai@latest", - "claude_code": "@anthropic-ai/claude-code@latest", + "codex": "@openai/codex@0.121.0", + "opencode": "opencode-ai@1.4.6", + "claude_code": "@anthropic-ai/claude-code@2.1.111", "qwen_code": "@qwen-code/qwen-code@0.14.5", } +# INIT stage: install the harness CLI, then stage the repo into the workspace. _PREPARE_BASE = ( "rm -rf /polar/session/workspace && " "mkdir -p /polar/session/logs/agent /polar/session/workspace \"$HOME/.venv/bin\" && " @@ -49,74 +55,43 @@ def prepare_command_for_harness(harness: str) -> str: - pkg = HARNESS_NPM_PACKAGE[harness] - return f"npm install -g {pkg} && {_PREPARE_BASE}" + return f"npm install -g {HARNESS_NPM_PACKAGE[harness]} && {_PREPARE_BASE}" def runtime_env_for_harness(harness: str) -> dict[str, str]: - env: dict[str, str] = {} - if harness == "opencode": - env["OPENCODE_FAKE_VCS"] = "git" - return env + return {"OPENCODE_FAKE_VCS": "git"} if harness == "opencode" else {} def evaluator_exclude_patterns_for_harness(harness: str) -> list[str]: patterns: list[str] = [] if harness == "claude_code": - patterns.extend([".claude/**", "**/.claude/**"]) + patterns += [".claude/**", "**/.claude/**"] if harness == "qwen_code": - patterns.extend([".qwen/**", "**/.qwen/**"]) + patterns += [".qwen/**", "**/.qwen/**"] return patterns - def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--harness", required=True, choices=SUPPORTED_HARNESSES) - parser.add_argument( - "--topology", - default=os.environ.get("POLAR_TOPOLOGY", str(DEFAULT_TOPOLOGY)), - ) - parser.add_argument( - "--num-samples", - type=int, - default=int(os.environ.get("NUM_SAMPLES", "1")), - ) - parser.add_argument( - "--max-tasks", - type=int, - default=int(os.environ.get("MAX_TASKS", "-1")), - help="Maximum tasks to submit. -1 = all 500.", - ) + parser.add_argument("--topology", default=os.environ.get("POLAR_TOPOLOGY", str(DEFAULT_TOPOLOGY))) + parser.add_argument("--num-samples", type=int, default=1, help="Samples per task (pass@k).") + parser.add_argument("--max-tasks", type=int, default=-1, help="Maximum tasks to submit. -1 = all 500.") parser.add_argument("--instance-id", action="append", default=[]) parser.add_argument("--timeout-seconds", type=float, default=3600.0) - parser.add_argument( - "--runtime-backend", - choices=["docker", "apptainer"], - default=os.environ.get("RUNTIME_BACKEND", "apptainer"), - ) - parser.add_argument("--output-dir", default=None) - parser.add_argument( - "--max-concurrent", - type=int, - default=int(os.environ.get("MAX_CONCURRENT", "32")), - help="Max parallel task submissions.", - ) + parser.add_argument("--runtime-backend", choices=["docker", "apptainer"], default="docker") parser.add_argument( "--model-name", - default=os.environ.get("MODEL_NAME", "gpt-5.4"), - help="Model name passed to the agent harness; Polar rewrites it to the served model.", + default="gpt-5.4", + help="Model name the harness sends; the gateway rewrites it to the served model.", ) return parser.parse_args() -# --------------------------------------------------------------------------- -# IO helpers -# --------------------------------------------------------------------------- - -def write_json(path: Path, payload: Any) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(payload, indent=2, ensure_ascii=True, sort_keys=True)) +def runtime_image_for_backend(image: str, backend: str) -> str: + if backend == "apptainer" and not image.startswith(("docker-daemon:", "docker://", "oras://")): + return f"docker-daemon:{image}" + return image def docker_image_exists(image_ref: str) -> bool: @@ -128,18 +103,6 @@ def docker_image_exists(image_ref: str) -> bool: ).returncode == 0 -def runtime_image_for_backend(image: str, backend: str) -> str: - if backend != "apptainer": - return image - if image.startswith(("docker-daemon:", "docker://", "oras://")): - return image - return f"docker-daemon:{image}" - - -# --------------------------------------------------------------------------- -# Task construction -# --------------------------------------------------------------------------- - def select_instances(args: argparse.Namespace) -> list[dict[str, Any]]: instances = load_swebench_verified() if args.instance_id: @@ -154,12 +117,7 @@ def select_instances(args: argparse.Namespace) -> list[dict[str, Any]]: return instances -def build_task_request( - args: argparse.Namespace, - *, - instance: dict[str, Any], - batch_id: str, -) -> dict[str, Any]: +def build_task_request(args: argparse.Namespace, instance: dict[str, Any], batch_id: str) -> dict[str, Any]: instance_id = str(instance["instance_id"]) image = runtime_image_for_instance(instance_id) return { @@ -175,21 +133,13 @@ def build_task_request( "network": "host", "workdir": "/polar/session/workspace", }, - "agent": { - "harness": args.harness, - "model_name": args.model_name, - "settings": {}, - "env": {}, - }, + "agent": {"harness": args.harness, "model_name": args.model_name}, "builder": {"strategy": "prefix_merging"}, "evaluator": { "strategy": "swebench_harness", "config": { "repo_dir": "/testbed", - "patch_command": ( - "cd /polar/session/workspace && " - "git add -A && git diff --cached --binary" - ), + "patch_command": "cd /polar/session/workspace && git add -A && git diff --cached --binary", "instance": instance, "exclude_patterns": evaluator_exclude_patterns_for_harness(args.harness), }, @@ -198,182 +148,88 @@ def build_task_request( } -# --------------------------------------------------------------------------- -# Submission & result handling -# --------------------------------------------------------------------------- - -def submit_task_file( - request_path: Path, - *, - topology: str | None, -) -> dict[str, Any]: - command = [sys.executable, "-m", "polar.cli", "submit", str(request_path), "--json"] - if topology: - command.extend(["-c", topology]) - completed = subprocess.run(command, check=True, capture_output=True, text=True) - return json.loads(completed.stdout) - - -def summarize_result(response: dict[str, Any]) -> dict[str, Any]: - sessions = response.get("results") or [] - reward_one = completed = session_errors = trajectory_errors = 0 +def task_stats(result: dict[str, Any]) -> tuple[int, int]: + """Return (sessions with reward==1, total sessions) for one finished task.""" + sessions = result.get("results") or [] + reward_one = 0 for session in sessions: - if session.get("status") == "COMPLETED": - completed += 1 - if session.get("error"): - session_errors += 1 - traj = session.get("trajectory") or {} - if traj.get("status") == "ERROR" or traj.get("error"): - trajectory_errors += 1 - traces = traj.get("traces") or [] + traces = (session.get("trajectory") or {}).get("traces") or [] if traces and traces[-1].get("reward") == 1.0: reward_one += 1 - return { - "total_sessions": len(sessions), - "completed_sessions": completed, - "reward_one_sessions": reward_one, - "session_errors": session_errors, - "trajectory_errors": trajectory_errors, - } + return reward_one, len(sessions) -def print_reward_summary(summaries: list[dict[str, Any]]) -> None: - total_tasks = len(summaries) - resolved = sum(1 for s in summaries if s.get("reward_one_sessions", 0) > 0) - total_sessions = sum(s.get("total_sessions", 0) for s in summaries) - total_reward_one = sum(s.get("reward_one_sessions", 0) for s in summaries) - total_errors = sum(s.get("session_errors", 0) + s.get("trajectory_errors", 0) for s in summaries) +def print_summary(stats: dict[str, tuple[int, int]], elapsed: float) -> None: + total_tasks = len(stats) + resolved = sum(1 for r1, _ in stats.values() if r1 > 0) + total_sessions = sum(total for _, total in stats.values()) + reward_one = sum(r1 for r1, _ in stats.values()) print("\n" + "=" * 72) print(" SWE-bench Verified — Reward Summary") print("=" * 72) - print(f" Tasks submitted: {total_tasks}") - print( - f" Tasks resolved (≥1): {resolved}/{total_tasks}" - f" ({100 * resolved / max(total_tasks, 1):.1f}%)" - ) - print(f" Total sessions: {total_sessions}") - print( - f" Sessions reward=1: {total_reward_one}/{total_sessions}" - f" ({100 * total_reward_one / max(total_sessions, 1):.1f}%)" - ) - if total_errors: - print(f" Errors: {total_errors}") + print(f" Tasks resolved (>=1): {resolved}/{total_tasks} ({100 * resolved / max(total_tasks, 1):.1f}%)") + print(f" Sessions reward=1: {reward_one}/{total_sessions} ({100 * reward_one / max(total_sessions, 1):.1f}%)") + print(f" Wall time: {elapsed:.0f}s") print("=" * 72) - - print(f"\n {'Instance ID':<45} {'Resolved':>10} {'Sessions':>10}") - print(" " + "-" * 65) - for s in sorted(summaries, key=lambda x: x.get("instance_id", "")): - iid = s.get("instance_id", "???") - r1 = s.get("reward_one_sessions", 0) - total = s.get("total_sessions", 0) - print(f" {iid:<45} {r1:>5}/{total:<4} {total:>10}") - print() + print(f"\n {'Instance ID':<45} {'Resolved':>12}") + print(" " + "-" * 59) + for iid in sorted(stats): + r1, total = stats[iid] + print(f" {iid:<45} {f'{r1}/{total}':>12}") + print("\n Per-session detail: polar dashboard -c examples/swebench_verified/topology.yaml") -# --------------------------------------------------------------------------- -# Main -# --------------------------------------------------------------------------- - def main() -> int: args = parse_args() - batch_id = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") + batch_id = time.strftime("%Y%m%dT%H%M%SZ", time.gmtime()) instances = select_instances(args) if not instances: raise SystemExit("No instances selected.") - print(f"Selected {len(instances)} SWE-bench Verified instance(s).") - ready: list[dict[str, Any]] = [] - missing_images: list[str] = [] + ready, missing = [], [] for instance in instances: image_ref = runtime_image_for_instance(str(instance["instance_id"])) - if docker_image_exists(image_ref): - ready.append(instance) - else: - missing_images.append(image_ref) - - if args.instance_id and missing_images: - raise SystemExit( - f"Missing runtime image(s):\n" - + "\n".join(f" - {img}" for img in missing_images) - + "\nRun: python build_images.py" - ) + (ready if docker_image_exists(image_ref) else missing).append(instance) if not ready: - raise SystemExit( - "No runtime images found.\n" - "Run: python build_images.py" - ) - if missing_images: - print(f"Skipping {len(missing_images)} instance(s) with missing images.") + raise SystemExit("No runtime images found. Run: python build_images.py") + if missing: + print(f"Skipping {len(missing)} instance(s) with missing images. Build them with: python build_images.py") instances = ready - output_dir = ( - Path(args.output_dir) if args.output_dir - else EXAMPLE_DIR / args.harness / "batches" / batch_id - ) - manifest = { - "dataset": DATASET_NAME, - "batch_id": batch_id, - "harness": args.harness, - "num_samples": args.num_samples, - "total_tasks": len(instances), - } - write_json(output_dir / "manifest.json", manifest) - - prepared: list[tuple[str, dict[str, Any], Path, Path]] = [] - for instance in instances: - iid = str(instance["instance_id"]) - task_dir = output_dir / sanitize_instance_id(iid) - req_path = task_dir / "request.json" - resp_path = task_dir / "response.json" - payload = build_task_request(args, instance=instance, batch_id=batch_id) - write_json(req_path, payload) - prepared.append((iid, payload, req_path, resp_path)) - print(f"Wrote {len(prepared)} request file(s) under {output_dir}") - - max_workers = min(len(prepared), args.max_concurrent) - print(f"Submitting {len(prepared)} task(s) (max_concurrent={max_workers}) ...") - - summaries: list[dict[str, Any]] = [] - - def _submit_one(item: tuple[str, dict[str, Any], Path, Path]) -> dict[str, Any]: - iid, payload, req_path, resp_path = item - result = submit_task_file(req_path, topology=args.topology) - write_json(resp_path, result) - summary = { - "instance_id": iid, - "task_id": payload["task_id"], - "response_path": str(resp_path), - **summarize_result(result), - } - print( - f" [{iid}] reward_1={summary['reward_one_sessions']}" - f"/{summary['total_sessions']}" - ) - return summary - - with ThreadPoolExecutor(max_workers=max_workers) as pool: - futures = {pool.submit(_submit_one, item): item for item in prepared} - for future in as_completed(futures): - try: - summaries.append(future.result()) - except Exception as exc: - iid = futures[future][0] - print(f" [{iid}] FAILED: {exc}") - summaries.append({ - "instance_id": iid, - "task_id": futures[future][1]["task_id"], - "error": str(exc), - "total_sessions": 0, - "completed_sessions": 0, - "reward_one_sessions": 0, - "session_errors": 1, - "trajectory_errors": 0, - }) - - write_json(output_dir / "summary.json", summaries) - print_reward_summary(summaries) - print(f"Batch summary: {output_dir / 'summary.json'}") + from polar.config import TopologyConfig + + rollout_url = TopologyConfig.load(args.topology).rollout.public_url + print(f"Submitting {len(instances)} task(s) to {rollout_url} " + f"(harness={args.harness}, samples={args.num_samples}, backend={args.runtime_backend})") + + timeout = httpx.Timeout(None, connect=30.0) + with httpx.Client(base_url=rollout_url, timeout=timeout) as client: + task_ids: dict[str, str] = {} # instance_id -> rollout task_id + for instance in instances: + iid = str(instance["instance_id"]) + payload = build_task_request(args, instance, batch_id) + resp = client.post("/rollout/task/submit", json=payload) + resp.raise_for_status() + task_ids[iid] = resp.json()["task_id"] + + print(f"Polling every {POLL_INTERVAL_SECONDS:.0f}s (watch live in the dashboard) ...") + t0 = time.monotonic() + stats: dict[str, tuple[int, int]] = {} + while len(stats) < len(task_ids): + time.sleep(POLL_INTERVAL_SECONDS) + for iid, tid in task_ids.items(): + if iid in stats: + continue + status = client.get(f"/rollout/task/{tid}").json() + if status["status"] != "running": + r1, total = task_stats(status) + stats[iid] = (r1, total) + print(f" [{time.monotonic() - t0:>5.0f}s] {iid:<45} resolved={r1}/{total} " + f"({len(stats)}/{len(task_ids)} done)") + elapsed = time.monotonic() - t0 + + print_summary(stats, elapsed) return 0 diff --git a/examples/swebench_verified/topology.yaml b/examples/swebench_verified/topology.yaml index e723e5b75..d0ce92b96 100644 --- a/examples/swebench_verified/topology.yaml +++ b/examples/swebench_verified/topology.yaml @@ -14,8 +14,9 @@ gateway: max_init_workers: 8 max_run_workers: 8 max_postrun_workers: 8 - model_served: Qwen/Qwen3.5-4B - sglang: + model_served: Qwen/Qwen3.6-27B + inference: + engine: vllm base_url: http://127.0.0.1:8000 - id: localhost-node-02 host: 127.0.0.1 @@ -24,6 +25,7 @@ gateway: max_init_workers: 8 max_run_workers: 8 max_postrun_workers: 8 - model_served: Qwen/Qwen3.5-4B - sglang: + model_served: Qwen/Qwen3.6-27B + inference: + engine: vllm base_url: http://127.0.0.1:8001 diff --git a/examples/swegym_slime_grpo/README.md b/examples/swegym_slime_grpo/README.md index 7d0b08b93..b3f3216a7 100644 --- a/examples/swegym_slime_grpo/README.md +++ b/examples/swegym_slime_grpo/README.md @@ -1,41 +1,71 @@ # SWE-Gym Slime GRPO -This example connects **Polar** rollout sessions with **Slime** GRPO training on -SWE-Gym tasks. +End-to-end **training** example: train **Qwen3.5-4B** with async **GRPO** on +**SWE-Gym** tasks, using **Polar** for agent rollouts and **Slime** for training. +Targets a single node with 8× B200 (2 GPUs train, 6 serve). -The exact model, topology, dependency pins, training arguments, worker counts, -ports, and harness settings are intentionally kept in the executable scripts and -YAML files. Treat those files as the source of truth, since the configuration is -expected to change as the example evolves. +> Unlike the rollout demos (calculator / count_stars / swebench_verified), this +> path serves the model with **SGLang**: Slime owns the inference engines and +> syncs the freshly trained weights into them every step (GPU-to-GPU NCCL). -## Installation +## Prerequisites -Follow [Slime bridge installation guide](../../src/slime_bridge/README.md#slime-installation) to install the proper versions of Slime and Megatron. +Install Polar and SGLang per [Polar installation](../../README.md#installation). + +Make sure to install Polar's optional swebench dependency for evaluation. +``` +uv pip install -e ".[swebench]" +``` ## Quick Start +Log into wandb and one command sets everything up and starts training: + ```bash +export WANDB_API_KEY= bash examples/swegym_slime_grpo/launch_e2e.sh ``` -`launch_e2e.sh` is the single-entry setup and run script for this example on a single node 8 x B200. It -creates the Slime and Megatron-LM checkouts, installs Polar, applies the -Slime and SGLang custom patches, builds the SWE-Gym data and Apptainer images, converts Qwen model weights -into Megatron format, and then hands off to `run.sh` to -launch the Polar rollout workers and Slime GRPO training job. +It clones Slime + Megatron-LM, installs the training-stack extras (Transformer +Engine; Flash Linear Attention; flash-attn on B200), applies the Slime/SGLang +patches, builds the +293-task SWE-Gym JSONL, pulls the Apptainer images + shared agent CLIs, converts +the Qwen weights to torch_dist, then hands off to `run.sh` (Polar services + Ray + the Slime training job). + +## (Optional) Watch rollouts in the dashboard + +While training runs, start the dashboard **from the repo root** (so its +`./rollout_results` path matches the rollout server's) to inspect live agent +sessions, trajectories, and the rewards feeding each training step: -- Adjust `topology.yaml` based on your hardware setups. -- Adjust `polar_config.yaml` for Polar side configs like harness to use (codex / claude_code / qwen_code / opencode / pi), async level, timeout, etc. -- Adjust `run.sh` for Slime side training arguments. +```bash +uv run polar dashboard -c tmp/swegym_slime_grpo/topology.yaml +``` + +Open . (`tmp/swegym_slime_grpo/topology.yaml` is the +rendered topology that `run.sh` writes at launch.) ## Files | File | Purpose | |---|---| -| `launch_e2e.sh` | Single-entry launcher for setup and execution | -| `run.sh` | Main training and rollout launch script | -| `polar_config.yaml` | Polar rollout task template and harness configuration | -| `topology.yaml` | Polar cluster and gateway topology | -| `prepare_data.py` | Builds the SWE-Gym JSONL data used by Slime | -| `prepare_apptainer_images.py` | Prepares local Apptainer SIF images and shared agent CLI assets | -| `convert_weights.sh` | Converts model weights into the format used by Slime | +| `launch_e2e.sh` | One-shot entry: setup + run | +| `run.sh` | Launches Polar services + Ray + Slime training job | +| `convert_weights.sh` | HF checkpoint → Megatron torch_dist | +| `model_args.sh` | Qwen3.5-4B Megatron args, shared by `run.sh` + `convert_weights.sh` | +| `topology.yaml` | Polar topology template (`${SGLANG_ROUTER_BASE_URL}` filled at runtime) | +| `polar_config.yaml` | Polar bridge config template (`${AGENT_CLI_DIR}`, `${APPTAINER_IMAGE_DIR}` filled at runtime) | +| `prepare_data.py` | Builds `swegym_train_293.jsonl` | +| `prepare_apptainer_images.py` | Pulls per-task SIF images, builds shared Node + agent CLI dir | +| `sample_tasks.py` | Dataset helpers (HF fetch, registry image lookup) | + +## Common knobs + +| What you want to tune | Where | +|---|---| +| Train/rollout GPU split, batch size, KL coef, LR | `run.sh` (env vars near top + Slime args at bottom) | +| Which agent harness (qwen_code / claude_code / codex / opencode / pi) | `polar_config.yaml` → `agent.harness` | +| Per-task timeout, async level, callback host | `polar_config.yaml` → `polar_*` keys | +| Gateway/rollout host & port, model served | `topology.yaml` | +| Which SWE-Gym dataset / split | `sample_tasks.py` → `DATASET_NAME`, `DATASET_SPLITS` | +| Model architecture args (don't change unless swapping models) | `model_args.sh` | diff --git a/examples/swegym_slime_grpo/convert_weights.sh b/examples/swegym_slime_grpo/convert_weights.sh index 9a58dfa66..d962e3849 100755 --- a/examples/swegym_slime_grpo/convert_weights.sh +++ b/examples/swegym_slime_grpo/convert_weights.sh @@ -10,6 +10,12 @@ PROJECT_ROOT="$(cd -- "${SCRIPT_DIR}/../.." && pwd)" SLIME_DIR="${SLIME_DIR:-${PROJECT_ROOT}/slime}" MEGATRON_DIR="${MEGATRON_DIR:-${PROJECT_ROOT}/Megatron-LM}" +PYTHON_BIN="${PYTHON_BIN:-${PROJECT_ROOT}/.venv/bin/python3}" +if [ ! -x "${PYTHON_BIN}" ]; then + PYTHON_BIN="$(command -v python3 || command -v python)" +fi +PYTHON_BIN_DIR="$(cd -- "$(dirname -- "${PYTHON_BIN}")" &>/dev/null && pwd)" +export PATH="${PYTHON_BIN_DIR}:${PATH}" if [ ! -f "${SLIME_DIR}/tools/convert_hf_to_torch_dist.py" ]; then echo "ERROR: Slime not found at ${SLIME_DIR}. Clone it first:" @@ -21,31 +27,8 @@ HF_CHECKPOINT="${HF_CHECKPOINT:-Qwen/Qwen3.5-4B}" OUTPUT_DIR="${TORCH_DIST_DIR:-${PROJECT_ROOT}/tmp/checkpoints/Qwen3.5-4B_torch_dist}" mkdir -p "$OUTPUT_DIR" -# Mirrors slime/slime/scripts/models/qwen3.5-4B.sh. -# --spec installs the hybrid (GatedDeltaNet + full) attention layer layout. -# tie_word_embeddings=true in the HF config → do NOT pass --untie-embeddings-and-output-weights. -MODEL_ARGS=( - --spec "slime_plugins.models.qwen3_5" "get_qwen3_5_spec" - --disable-bias-linear - --qk-layernorm - --group-query-attention - --num-attention-heads 16 - --num-query-groups 4 - --kv-channels 256 - --num-layers 32 - --hidden-size 2560 - --ffn-hidden-size 9216 - --use-gated-attention - --normalization RMSNorm - --apply-layernorm-1p - --position-embedding-type rope - --norm-epsilon 1e-6 - --rotary-percent 0.25 - --swiglu - --vocab-size 248320 - --rotary-base 10000000 - --attention-output-gate -) +# shellcheck source=./model_args.sh +source "${SCRIPT_DIR}/model_args.sh" echo "Converting ${HF_CHECKPOINT} -> ${OUTPUT_DIR}" diff --git a/examples/swegym_slime_grpo/launch_e2e.sh b/examples/swegym_slime_grpo/launch_e2e.sh index edec8d765..04f57c9e0 100755 --- a/examples/swegym_slime_grpo/launch_e2e.sh +++ b/examples/swegym_slime_grpo/launch_e2e.sh @@ -15,6 +15,8 @@ PYTHON_BIN="${PYTHON_BIN:-${PROJECT_ROOT}/.venv/bin/python3}" if [ ! -x "${PYTHON_BIN}" ]; then PYTHON_BIN="$(command -v python3 || command -v python)" fi +PYTHON_BIN_DIR="$(cd -- "$(dirname -- "${PYTHON_BIN}")" &>/dev/null && pwd)" +export PATH="${PYTHON_BIN_DIR}:${PATH}" SLIME_DIR="${SLIME_DIR:-${PROJECT_ROOT}/slime}" SLIME_REPO="${SLIME_REPO:-https://github.com/THUDM/slime.git}" SLIME_REF="${SLIME_REF:-v0.2.4}" @@ -22,28 +24,30 @@ SLIME_REF="${SLIME_REF:-v0.2.4}" MEGATRON_DIR="${MEGATRON_DIR:-${PROJECT_ROOT}/Megatron-LM}" MEGATRON_REPO="${MEGATRON_REPO:-https://github.com/NVIDIA/Megatron-LM.git}" MEGATRON_REF="${MEGATRON_REF:-main}" +# SWE-Gym's fork of the SWE-bench harness (grades the SWE-Gym instances). Not on +# PyPI, so installed from git; commit-pinned for reproducibility. +SWEGYM_PACKAGE_SPEC="${SWEGYM_PACKAGE_SPEC:-swegym @ git+https://github.com/SWE-Gym/SWE-Bench-Package.git@16dd480cce9b27bf111a362d280881c6def5d2a7}" HF_CHECKPOINT="${HF_CHECKPOINT:-Qwen/Qwen3.5-4B}" REF_LOAD="${REF_LOAD:-${TORCH_DIST_DIR:-${PROJECT_ROOT}/tmp/checkpoints/Qwen3.5-4B_torch_dist}}" TORCH_DIST_DIR="${TORCH_DIST_DIR:-${REF_LOAD}}" -if [ -n "${WANDB_RUN_ID:-}" ]; then - SAVE_DIR="${SAVE_DIR:-${PROJECT_ROOT}/tmp/ckpt/swegym_slime_grpo_qwen35_4b/${WANDB_RUN_ID}}" -fi -SAVE_DIR="${SAVE_DIR:-${PROJECT_ROOT}/tmp/ckpt/swegym_slime_grpo_qwen35_4b}" +RUN_ID="${RUN_ID:-${WANDB_RUN_ID:-swegym-slime-grpo-$(date -u +%Y%m%dT%H%M%SZ)}}" +SAVE_ROOT="${SAVE_ROOT:-${PROJECT_ROOT}/tmp/ckpt/swegym_slime_grpo_qwen35_4b}" +SAVE_DIR="${SAVE_DIR:-${SAVE_ROOT}/${RUN_ID}}" AGENT_CLI_DIR="${AGENT_CLI_DIR:-${PROJECT_ROOT}/tmp/swegym_agent_cli/opt_node}" APPTAINER_IMAGE_DIR="${APPTAINER_IMAGE_DIR:-${PROJECT_ROOT}/tmp/swegym_apptainer_images}" APPTAINER_CACHEDIR="${APPTAINER_CACHEDIR:-${PROJECT_ROOT}/tmp/apptainer_cache}" APPTAINER_TMPDIR="${APPTAINER_TMPDIR:-${PROJECT_ROOT}/tmp/apptainer_tmp}" -POLAR_APPTAINER_BIN="${POLAR_APPTAINER_BIN:-/usr/bin/apptainer}" +POLAR_APPTAINER_BIN="${POLAR_APPTAINER_BIN:-$(command -v apptainer || echo /usr/bin/apptainer)}" INSTALL_EDITABLE="${INSTALL_EDITABLE:-1}" +INSTALL_TRAINING_STACK="${INSTALL_TRAINING_STACK:-1}" # TE + FLA + flash-attn (SM100 only) +FLASH_LINEAR_ATTENTION_VERSION="${FLASH_LINEAR_ATTENTION_VERSION:-0.5.0}" +MBRIDGE_VERSION="${MBRIDGE_VERSION:-0.15.1}" # HF<->Megatron weight bridge (slime conversion) APPLY_SGLANG_PATCH="${APPLY_SGLANG_PATCH:-1}" PREPARE_IMAGES="${PREPARE_IMAGES:-1}" APPTAINER_PREPARE_JOBS="${APPTAINER_PREPARE_JOBS:-2}" CONVERT_WEIGHTS="${CONVERT_WEIGHTS:-auto}" -MONITOR_GPU="${MONITOR_GPU:-0}" -TRAIN_GPUS="${TRAIN_GPUS:-0,1}" -ROLLOUT_GPUS="${ROLLOUT_GPUS:-2,3,4,5,6,7}" export APPTAINER_CACHEDIR APPTAINER_TMPDIR POLAR_APPTAINER_BIN require_cmd() { @@ -92,19 +96,143 @@ if key and hasattr(wandb, "login"): PY } +flash_attn2_ready() { + "${PYTHON_BIN}" - <<'PY' +from importlib.metadata import PackageNotFoundError, version + +try: + installed = version("flash-attn") +except PackageNotFoundError: + raise SystemExit(1) + +if installed != "2.7.4.post1": + raise SystemExit(1) + +try: + import torch # noqa: F401 + import flash_attn_2_cuda # noqa: F401 + import flash_attn.flash_attn_interface # noqa: F401 +except Exception: + raise SystemExit(1) +PY +} + +swegym_harness_ready() { + "${PYTHON_BIN}" - <<'PY' +try: + from swegym.harness.constants import MAP_REPO_VERSION_TO_SPECS + from swegym.harness.grading import get_eval_report # noqa: F401 + from swegym.harness.test_spec import make_test_spec # noqa: F401 +except Exception: + raise SystemExit(1) + +needed = {"dask/dask", "python/mypy", "pandas-dev/pandas"} +if not needed.issubset(MAP_REPO_VERSION_TO_SPECS): + raise SystemExit(1) +PY +} + +ensure_swegym_harness() { + if swegym_harness_ready; then + echo "SWE-Gym harness package present; skipping." + else + echo "Installing SWE-Gym harness package..." + uv pip install --python "${PYTHON_BIN}" "${SWEGYM_PACKAGE_SPEC}" + fi +} + +# Install the GPU training-stack extras the editable installs do NOT pull. +# Idempotent — skips whatever is already importable. +# - Transformer Engine: required by Megatron on ANY GPU (its torch bindings build +# from source and need cuDNN headers from the pip nvidia-cudnn package). +# - Flash Linear Attention: required by Qwen3.5 GatedDeltaNet linear-attention layers. +# - flash-attn 2.x from source: ONLY on SM100/B200, where TE's cuDNN backend has no +# head_dim=256 kernel. Built for the detected arch; skipped on every other GPU so +# this script stays safe on H100 etc. (no wasted/failed builds). +ensure_training_stack() { + local cuda_home cudnn_path cc="" + cuda_home="${CUDA_HOME:-/usr/local/cuda}" + if [ ! -d "$cuda_home" ] && command -v nvcc >/dev/null 2>&1; then + cuda_home="$(dirname "$(dirname "$(command -v nvcc)")")" + fi + cudnn_path="$("${PYTHON_BIN}" -c 'import nvidia.cudnn; print(list(nvidia.cudnn.__path__)[0])' 2>/dev/null || true)" + + # --- Transformer Engine 2.5.0 (general) --- + if LD_LIBRARY_PATH="${cudnn_path}/lib:${LD_LIBRARY_PATH:-}" \ + "${PYTHON_BIN}" -c "import transformer_engine.pytorch" >/dev/null 2>&1; then + echo "Transformer Engine present; skipping." + else + if ! command -v nvcc >/dev/null 2>&1; then + echo "ERROR: nvcc not found — needed to build transformer-engine-torch." >&2 + echo " Install the CUDA toolkit (or set CUDA_HOME), then re-run." >&2 + exit 1 + fi + if [ -z "$cudnn_path" ]; then + echo "ERROR: pip 'nvidia-cudnn' not found in venv — is torch a CUDA build?" >&2 + echo " TE's source build needs its cuDNN headers; fix torch first (see README)." >&2 + exit 1 + fi + echo "Installing Transformer Engine 2.5.0 (building transformer-engine-torch)..." + uv pip install --python "${PYTHON_BIN}" ninja pybind11 setuptools wheel >/dev/null 2>&1 || true + CUDA_HOME="$cuda_home" \ + CPATH="${cudnn_path}/include:${cuda_home}/include:${CPATH:-}" \ + LIBRARY_PATH="${cudnn_path}/lib:${cuda_home}/lib64:${LIBRARY_PATH:-}" \ + uv pip install --python "${PYTHON_BIN}" --no-build-isolation "transformer-engine[pytorch]==2.5.0" + fi + + # --- Flash Linear Attention (Qwen3.5 linear-attention layers) --- + if "${PYTHON_BIN}" -c \ + "from fla.modules import FusedRMSNormGated, ShortConvolution; from fla.ops.gated_delta_rule import chunk_gated_delta_rule" \ + >/dev/null 2>&1; then + echo "Flash Linear Attention present; skipping." + else + echo "Installing Flash Linear Attention ${FLASH_LINEAR_ATTENTION_VERSION}..." + uv pip install --python "${PYTHON_BIN}" "flash-linear-attention==${FLASH_LINEAR_ATTENTION_VERSION}" + fi + + # --- flash-attn 2.x from source: SM100 (B200) only --- + cc="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null | head -1 | tr -d '[:space:]')" || true + if [ "$cc" = "10.0" ]; then + if flash_attn2_ready; then + echo "flash-attn (2.x) present; skipping." + else + echo "SM100 detected: building flash-attn 2.7.4.post1 (head_dim=256 fallback)..." + TORCH_CUDA_ARCH_LIST=10.0 FLASH_ATTN_CUDA_ARCHS=100 FLASH_ATTENTION_FORCE_BUILD=TRUE \ + MAX_JOBS="${FA_MAX_JOBS:-32}" NVCC_THREADS=4 \ + uv pip install --python "${PYTHON_BIN}" --no-build-isolation flash-attn==2.7.4.post1 + fi + else + echo "compute_cap=${cc:-unknown} is not SM100; skipping flash-attn 2.x source build." + echo " (TE's cuDNN backend serves head_dim=256 off SM100. If training later aborts with" + echo " 'No dot product attention backend', build flash-attn 2.x for your arch — see README.)" + fi +} + require_cmd git require_cmd "${PYTHON_BIN}" require_cmd uv require_cmd "${POLAR_APPTAINER_BIN}" -require_cmd ray +require_cmd envsubst # run.sh uses it to render YAML templates clone_if_missing "Slime" "${SLIME_REPO}" "${SLIME_REF}" "${SLIME_DIR}" clone_if_missing "Megatron-LM" "${MEGATRON_REPO}" "${MEGATRON_REF}" "${MEGATRON_DIR}" if [ "${INSTALL_EDITABLE}" = "1" ]; then - uv pip install -e . - uv pip install -e "${SLIME_DIR}" - uv pip install -e "${MEGATRON_DIR}" + # [swebench] is load-bearing even though swegym (installed below) does the actual + # grading: swegym is a swebench fork that ships NO deps of its own, so it reuses + # swebench's dependency tree (datasets, docker, ghapi, unidiff, dotenv, requests...). + # So we need both — [swebench] for the deps, swegym for the SWE-Gym repo specs. + uv pip install --python "${PYTHON_BIN}" -e ".[swebench]" + uv pip install --python "${PYTHON_BIN}" -e "${SLIME_DIR}" + uv pip install --python "${PYTHON_BIN}" -e "${MEGATRON_DIR}" + # mbridge: HF<->Megatron weight map slime needs to convert Qwen3.5 (slime_plugins.mbridge). + # --no-deps keeps the pinned torch / TE / flash-attn stack untouched. + uv pip install --python "${PYTHON_BIN}" --no-deps "mbridge==${MBRIDGE_VERSION}" + ensure_swegym_harness +fi + +if [ "${INSTALL_TRAINING_STACK}" = "1" ]; then + ensure_training_stack fi bash "${PROJECT_ROOT}/scripts/patch/patch_slime.sh" "${SLIME_DIR}" @@ -133,26 +261,12 @@ fi maybe_login_wandb -MONITOR_PID="" -if [ "${MONITOR_GPU}" = "1" ]; then - uv run python "${PROJECT_ROOT}/scripts/monitor_wandb_gpu.py" \ - --no-wandb \ - --train-gpus "${TRAIN_GPUS}" \ - --rollout-gpus "${ROLLOUT_GPUS}" & - MONITOR_PID="$!" -fi - -cleanup() { - if [ -n "${MONITOR_PID}" ]; then - kill "${MONITOR_PID}" 2>/dev/null || true - fi -} -trap cleanup EXIT - HF_CHECKPOINT="${HF_CHECKPOINT}" \ REF_LOAD="${REF_LOAD}" \ TORCH_DIST_DIR="${TORCH_DIST_DIR}" \ SAVE_DIR="${SAVE_DIR}" \ +RUN_ID="${RUN_ID}" \ +SAVE_ROOT="${SAVE_ROOT}" \ PYTHON_BIN="${PYTHON_BIN}" \ SLIME_DIR="${SLIME_DIR}" \ MEGATRON_DIR="${MEGATRON_DIR}" \ diff --git a/examples/swegym_slime_grpo/model_args.sh b/examples/swegym_slime_grpo/model_args.sh new file mode 100644 index 000000000..c580c2fcc --- /dev/null +++ b/examples/swegym_slime_grpo/model_args.sh @@ -0,0 +1,28 @@ +# shellcheck shell=bash +# Qwen3.5-4B Megatron model args, shared by convert_weights.sh and run.sh. +# +# Mirrors slime/slime/scripts/models/qwen3.5-4B.sh. +# --spec wires in the hybrid GatedDeltaNet + full-attention layer layout. +# tie_word_embeddings=true in HF config → do NOT add --untie-embeddings-and-output-weights. +MODEL_ARGS=( + --spec "slime_plugins.models.qwen3_5" "get_qwen3_5_spec" + --disable-bias-linear + --qk-layernorm + --group-query-attention + --num-attention-heads 16 + --num-query-groups 4 + --kv-channels 256 + --num-layers 32 + --hidden-size 2560 + --ffn-hidden-size 9216 + --use-gated-attention + --normalization RMSNorm + --apply-layernorm-1p + --position-embedding-type rope + --norm-epsilon 1e-6 + --rotary-percent 0.25 + --swiglu + --vocab-size 248320 + --rotary-base 10000000 + --attention-output-gate +) diff --git a/examples/swegym_slime_grpo/polar_config.yaml b/examples/swegym_slime_grpo/polar_config.yaml index e31dea450..1203e4b18 100644 --- a/examples/swegym_slime_grpo/polar_config.yaml +++ b/examples/swegym_slime_grpo/polar_config.yaml @@ -10,9 +10,11 @@ polar_callback_host: "127.0.0.1" polar_min_complete_accept_fraction: 0.6 polar_allow_weight_update_overlap: false polar_weight_update_pause_timeout: 300 -# run.sh/launch_e2e.sh render this to an absolute host path before launch. -polar_agent_cli_dir: "tmp/swegym_agent_cli/opt_node" -polar_apptainer_image_dir: "tmp/swegym_apptainer_images" +# Filled by run.sh via envsubst — absolute host paths to assets prepared by +# prepare_apptainer_images.py. The container mounts and image lookups below +# reference these through {args.polar_agent_cli_dir} / {args.polar_apptainer_image_dir}. +polar_agent_cli_dir: ${AGENT_CLI_DIR} +polar_apptainer_image_dir: ${APPTAINER_IMAGE_DIR} # Task template rendered per sample group. Placeholders are filled # from the context: {instruction}, {num_samples}, {sample.*}, {args.*}. @@ -53,7 +55,7 @@ polar_task_template: ln -sf /opt/miniconda3/envs/testbed/bin/python "$HOME/.venv/bin/python3" && git config --global core.pager '' agent: - harness: "qwen_code" + harness: "codex" model_name: "gpt-5.4" builder: strategy: "prefix_merging" diff --git a/examples/swegym_slime_grpo/run.sh b/examples/swegym_slime_grpo/run.sh index be9027b45..f54f4a5a1 100755 --- a/examples/swegym_slime_grpo/run.sh +++ b/examples/swegym_slime_grpo/run.sh @@ -32,6 +32,7 @@ if [ ! -x "${PYTHON_BIN}" ]; then PYTHON_BIN="$(command -v python3 || command -v python)" fi PYTHON_BIN_DIR="$(cd -- "$(dirname -- "${PYTHON_BIN}")" &>/dev/null && pwd)" +export PATH="${PYTHON_BIN_DIR}:${PATH}" is_path_like() { case "$1" in @@ -78,7 +79,9 @@ fi # through slime_plugins.mbridge.qwen3_5 (text_config-aware) at convert-time. HF_CHECKPOINT="${HF_CHECKPOINT:-Qwen/Qwen3.5-4B}" REF_LOAD="${REF_LOAD:-${PROJECT_ROOT}/tmp/checkpoints/Qwen3.5-4B_torch_dist}" -SAVE_DIR="${SAVE_DIR:-${PROJECT_ROOT}/tmp/ckpt/swegym_slime_grpo_qwen35_4b}" +RUN_ID="${RUN_ID:-swegym-slime-grpo-$(date -u +%Y%m%dT%H%M%SZ)}" +SAVE_ROOT="${SAVE_ROOT:-${PROJECT_ROOT}/tmp/ckpt/swegym_slime_grpo_qwen35_4b}" +SAVE_DIR="${SAVE_DIR:-${SAVE_ROOT}/${RUN_ID}}" mkdir -p "$SAVE_DIR" if is_path_like "$HF_CHECKPOINT" && [ ! -e "$HF_CHECKPOINT" ]; then echo "ERROR: HF checkpoint not found at $HF_CHECKPOINT" @@ -91,31 +94,8 @@ if [ ! -d "$REF_LOAD" ] || [ ! -f "$REF_LOAD/latest_checkpointed_iteration.txt" exit 1 fi -# Mirrors slime/slime/scripts/models/qwen3.5-4B.sh. --spec wires in the hybrid -# GatedDeltaNet + full-attention layer layout. tie_word_embeddings=true in HF -# config → do NOT pass --untie-embeddings-and-output-weights. -MODEL_ARGS=( - --spec "slime_plugins.models.qwen3_5" "get_qwen3_5_spec" - --disable-bias-linear - --qk-layernorm - --group-query-attention - --num-attention-heads 16 - --num-query-groups 4 - --kv-channels 256 - --num-layers 32 - --hidden-size 2560 - --ffn-hidden-size 9216 - --use-gated-attention - --normalization RMSNorm - --apply-layernorm-1p - --position-embedding-type rope - --norm-epsilon 1e-6 - --rotary-percent 0.25 - --swiglu - --vocab-size 248320 - --rotary-base 10000000 - --attention-output-gate -) +# shellcheck source=./model_args.sh +source "${SCRIPT_DIR}/model_args.sh" # First run has an empty SAVE_DIR — slime's load_checkpoint asserts on empty. # Pick REF_LOAD (torch_dist) until the first save lands. @@ -126,62 +106,37 @@ else fi # ── Data ─────────────────────────────────────────────────────────── -PROMPT_DATA="${SCRIPT_DIR}/swegym_train_293.jsonl" +PROMPT_DATA="${PROMPT_DATA:-${SCRIPT_DIR}/swegym_train_293.jsonl}" if [ ! -f "$PROMPT_DATA" ]; then echo "Preparing train data..." "${PYTHON_BIN}" "${SCRIPT_DIR}/prepare_data.py" fi # ── Runtime configs ───────────────────────────────────────────────── -AGENT_CLI_DIR="${AGENT_CLI_DIR:-${PROJECT_ROOT}/tmp/swegym_agent_cli/opt_node}" -APPTAINER_IMAGE_DIR="${APPTAINER_IMAGE_DIR:-${PROJECT_ROOT}/tmp/swegym_apptainer_images}" -POLAR_APPTAINER_BIN="${POLAR_APPTAINER_BIN:-/usr/bin/apptainer}" -export POLAR_APPTAINER_BIN +export AGENT_CLI_DIR="${AGENT_CLI_DIR:-${PROJECT_ROOT}/tmp/swegym_agent_cli/opt_node}" +export APPTAINER_IMAGE_DIR="${APPTAINER_IMAGE_DIR:-${PROJECT_ROOT}/tmp/swegym_apptainer_images}" +# Prefer apptainer in PATH (HPC modules etc.); fall back to /usr/bin for Ubuntu defaults. +export POLAR_APPTAINER_BIN="${POLAR_APPTAINER_BIN:-$(command -v apptainer || echo /usr/bin/apptainer)}" SGLANG_ROUTER_PORT="${SGLANG_ROUTER_PORT:-9000}" SGLANG_ROUTER_HOST="${SGLANG_ROUTER_HOST:-$(detect_host_ip)}" -SGLANG_ROUTER_BASE_URL="${SGLANG_ROUTER_BASE_URL:-http://${SGLANG_ROUTER_HOST}:${SGLANG_ROUTER_PORT}}" +export SGLANG_ROUTER_BASE_URL="${SGLANG_ROUTER_BASE_URL:-http://${SGLANG_ROUTER_HOST}:${SGLANG_ROUTER_PORT}}" TOPOLOGY_TEMPLATE="${TOPOLOGY_TEMPLATE:-${SCRIPT_DIR}/topology.yaml}" POLAR_CONFIG_TEMPLATE="${POLAR_CONFIG_TEMPLATE:-${SCRIPT_DIR}/polar_config.yaml}" TOPOLOGY_PATH="${TOPOLOGY_PATH:-${RUN_DIR}/topology.yaml}" CUSTOM_CONFIG_PATH="${CUSTOM_CONFIG_PATH:-${RUN_DIR}/polar_config.yaml}" -"${PYTHON_BIN}" - "$TOPOLOGY_TEMPLATE" "$TOPOLOGY_PATH" "$SGLANG_ROUTER_BASE_URL" \ - "$POLAR_CONFIG_TEMPLATE" "$CUSTOM_CONFIG_PATH" "$AGENT_CLI_DIR" \ - "$APPTAINER_IMAGE_DIR" <<'PY' -from pathlib import Path -import sys -import yaml - -( - topology_template, - topology_out, - router_url, - polar_template, - polar_out, - agent_cli_dir, - apptainer_image_dir, -) = sys.argv[1:] - -with open(topology_template, encoding="utf-8") as fh: - topology = yaml.safe_load(fh) or {} -for node in topology.get("gateway", {}).get("nodes", []): - node.setdefault("sglang", {})["base_url"] = router_url -Path(topology_out).parent.mkdir(parents=True, exist_ok=True) -with open(topology_out, "w", encoding="utf-8") as fh: - yaml.safe_dump(topology, fh, sort_keys=False) - -with open(polar_template, encoding="utf-8") as fh: - polar_config = yaml.safe_load(fh) or {} -polar_config["polar_agent_cli_dir"] = agent_cli_dir -polar_config["polar_apptainer_image_dir"] = apptainer_image_dir -Path(polar_out).parent.mkdir(parents=True, exist_ok=True) -with open(polar_out, "w", encoding="utf-8") as fh: - yaml.safe_dump(polar_config, fh, sort_keys=False) -PY +# Render YAML templates: only the listed ${VARS} are expanded, so literal +# $HOME / $... inside polar_config.yaml are left untouched. +command -v envsubst >/dev/null || { echo "ERROR: envsubst not found (install gettext-base)"; exit 1; } +TEMPLATE_VARS='${SGLANG_ROUTER_BASE_URL} ${AGENT_CLI_DIR} ${APPTAINER_IMAGE_DIR}' +mkdir -p "$(dirname "$TOPOLOGY_PATH")" "$(dirname "$CUSTOM_CONFIG_PATH")" +envsubst "$TEMPLATE_VARS" < "$TOPOLOGY_TEMPLATE" > "$TOPOLOGY_PATH" +envsubst "$TEMPLATE_VARS" < "$POLAR_CONFIG_TEMPLATE" > "$CUSTOM_CONFIG_PATH" echo "Using topology: ${TOPOLOGY_PATH}" echo "Using Polar config: ${CUSTOM_CONFIG_PATH}" echo "Using Apptainer image dir: ${APPTAINER_IMAGE_DIR}" +echo "Using run id: ${RUN_ID}" echo "Using save dir: ${SAVE_DIR}" echo "Using SGLang router URL for Polar gateway: ${SGLANG_ROUTER_BASE_URL}" @@ -209,38 +164,46 @@ sleep 2 curl -sf http://127.0.0.1:8080/health || { echo "Polar rollout server not healthy"; exit 1; } # ── Step 2: Ray + Slime (manages SGLang engines + training) ─────── -echo "=== Starting Ray (all 8 GPUs) ===" +# GPU split — defaults are 2 training + 6 rollout (8x B200 single node). +ACTOR_NUM_GPUS_PER_NODE="${ACTOR_NUM_GPUS_PER_NODE:-2}" +ROLLOUT_NUM_GPUS="${ROLLOUT_NUM_GPUS:-6}" +ROLLOUT_NUM_GPUS_PER_ENGINE="${ROLLOUT_NUM_GPUS_PER_ENGINE:-1}" +ROLLOUT_BATCH_SIZE="${ROLLOUT_BATCH_SIZE:-4}" +N_SAMPLES_PER_PROMPT="${N_SAMPLES_PER_PROMPT:-16}" +MAX_TOKENS_PER_GPU="${MAX_TOKENS_PER_GPU:-60000}" +SGLANG_CONTEXT_LENGTH="${SGLANG_CONTEXT_LENGTH:-50000}" + +# Ray sizing — derive total GPUs from the actor/rollout split. +RAY_NUM_GPUS="${RAY_NUM_GPUS:-$((ACTOR_NUM_GPUS_PER_NODE + ROLLOUT_NUM_GPUS))}" +RAY_HEAD_IP="${RAY_HEAD_IP:-127.0.0.1}" + +echo "=== Starting Ray on ${RAY_HEAD_IP} (${RAY_NUM_GPUS} GPUs) ===" ray stop --force 2>/dev/null || true sleep 1 -ray start --head --node-ip-address 127.0.0.1 --num-gpus 8 --disable-usage-stats +ray start --head --node-ip-address "$RAY_HEAD_IP" --num-gpus "$RAY_NUM_GPUS" --disable-usage-stats -CUDNN_LIB="${CUDNN_LIB:-${PROJECT_ROOT}/.venv/lib/python3.13/site-packages/nvidia/cudnn/lib}" +# cuDNN lib path — probe the active venv instead of hardcoding python3.13. +if [ -z "${CUDNN_LIB:-}" ]; then + CUDNN_LIB="$("${PYTHON_BIN}" -c 'import nvidia.cudnn, os; print(os.path.join(list(nvidia.cudnn.__path__)[0], "lib"))' 2>/dev/null || true)" +fi RUNTIME_LD_LIBRARY_PATH="${LD_LIBRARY_PATH:-}" -if [ -d "$CUDNN_LIB" ]; then +if [ -n "${CUDNN_LIB}" ] && [ -d "$CUDNN_LIB" ]; then RUNTIME_LD_LIBRARY_PATH="${CUDNN_LIB}:${RUNTIME_LD_LIBRARY_PATH}" fi RUNTIME_ENV_JSON="{ \"env_vars\": { \"PYTHONPATH\": \"${MEGATRON_DIR}:${PROJECT_ROOT}/src\", \"PATH\": \"${PYTHON_BIN_DIR}:${PATH}\", - \"VIRTUAL_ENV\": \"${PROJECT_ROOT}/.venv\", + \"VIRTUAL_ENV\": \"${VIRTUAL_ENV:-${PROJECT_ROOT}/.venv}\", \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", \"WANDB_DIR\": \"${PROJECT_ROOT}/logs\", \"LD_LIBRARY_PATH\": \"${RUNTIME_LD_LIBRARY_PATH}\", - \"PYTORCH_CUDA_ALLOC_CONF\": \"max_split_size_mb:2048,expandable_segments:True\", + \"PYTORCH_ALLOC_CONF\": \"max_split_size_mb:2048,expandable_segments:True\", \"NVTE_DEBUG\": \"1\", \"NVTE_DEBUG_LEVEL\": \"2\" } }" -ACTOR_NUM_GPUS_PER_NODE="${ACTOR_NUM_GPUS_PER_NODE:-2}" -ROLLOUT_NUM_GPUS="${ROLLOUT_NUM_GPUS:-6}" -ROLLOUT_NUM_GPUS_PER_ENGINE="${ROLLOUT_NUM_GPUS_PER_ENGINE:-1}" -ROLLOUT_BATCH_SIZE="${ROLLOUT_BATCH_SIZE:-4}" -N_SAMPLES_PER_PROMPT="${N_SAMPLES_PER_PROMPT:-16}" -MAX_TOKENS_PER_GPU="${MAX_TOKENS_PER_GPU:-60000}" -SGLANG_CONTEXT_LENGTH="${SGLANG_CONTEXT_LENGTH:-50000}" - # Rollout sizing: 4 prompts × 16 trajectories = 64 trajectories/rollout. # This matches the earlier high-util baseline and keeps request groups smaller # so long tails do not collapse usable token throughput. @@ -249,7 +212,7 @@ SGLANG_CONTEXT_LENGTH="${SGLANG_CONTEXT_LENGTH:-50000}" # The custom data source rounds epoch length up to 37 rollout batches, so all # 293 train prompts are consumed once; the final fixed-size batch wraps 3 prompts. echo "=== Launching train_async.py ===" -ray job submit --address="http://127.0.0.1:8265" \ +ray job submit --address="http://${RAY_HEAD_IP}:8265" \ --runtime-env-json="${RUNTIME_ENV_JSON}" \ -- "${PYTHON_BIN}" "${SLIME_DIR}/train_async.py" \ --actor-num-nodes 1 \ @@ -274,7 +237,7 @@ ray job submit --address="http://127.0.0.1:8265" \ --metadata-key metadata \ --rollout-shuffle \ --reward-key score \ - --num-epoch 1 \ + --num-epoch "${NUM_EPOCH:-1}" \ --rollout-batch-size "$ROLLOUT_BATCH_SIZE" \ --n-samples-per-prompt "$N_SAMPLES_PER_PROMPT" \ --rollout-max-response-len 16000 \ diff --git a/examples/swegym_slime_grpo/topology.yaml b/examples/swegym_slime_grpo/topology.yaml index 6b2ea179b..d4dc6fcfc 100644 --- a/examples/swegym_slime_grpo/topology.yaml +++ b/examples/swegym_slime_grpo/topology.yaml @@ -15,6 +15,6 @@ gateway: max_run_workers: 16 max_postrun_workers: 16 model_served: Qwen/Qwen3.5-4B - sglang: - # run.sh/launch_e2e.sh render this to the host-reachable router URL. - base_url: http://127.0.0.1:9000 + inference: + engine: sglang + base_url: ${SGLANG_ROUTER_BASE_URL} diff --git a/src/polar/agent/README.md b/src/polar/agent/README.md index e1a812a30..df2bf4056 100644 --- a/src/polar/agent/README.md +++ b/src/polar/agent/README.md @@ -1,57 +1,128 @@ # Agent Harnesses -`polar.agent` defines how Polar launches an agent inside a prepared runtime. +In Polar, an **agent harness** is whatever launches your agent inside a prepared +runtime. The key idea is that you **do not integrate agents into Polar** — you +run them unmodified. A harness only has to: + +1. start the agent process, and +2. let the agent's LLM calls flow through the gateway proxy. + +Polar injects the proxy endpoints as environment variables +(`OPENAI_BASE_URL`, `ANTHROPIC_BASE_URL`, `GOOGLE_API_URL`, and matching +`*_API_KEY`s set to the session id). The gateway serves the model, rewrites the +request to the served model, and **captures the trajectory** from the wire-level +calls. So the harness never parses transcripts or implements agent logic — that +all lives in the agent. + The public task field is `agent`, validated by `models.AgentSpec`. -## Main Files +## Three ways to run an agent + +| Path | When to use | How | +|---|---|---| +| **Preset** | A popular agent we already ship a launcher for | `agent.harness: ""` | +| **`shell`** | Any agent you can express as a shell command | `agent.harness: "shell"` + `agent.custom_shell` | +| **`import_path`** | Your own harness class, kept in your repo | `agent.import_path: "your.module:YourHarness"` | -- `base.py`: base harness contract. -- `models.py`: `AgentSpec`, `MCPServerSpec`, and `AgentRunResult`. -- `factory.py`: built-in harness lookup and custom import loading. -- `harnesses/`: implementations for `claude_code`, `codex`, `gemini_cli`, - `opencode`, `openhands_sdk`, `pi`, `qwen_code`, and `shell`. +Presets are **conveniences, not integrations** — each is a thin `BaseHarness` +(a few dozen lines) that writes the agent's config and emits its run command. +If your agent isn't listed below, you don't add code to Polar: reach for `shell` +or `import_path`. -## Built-In Harnesses +## Presets API type names match `polar.gateway.detection.APIType`: `anthropic`, -`openai_chat`, `openai_responses`, and `google`. `require_streaming` describes the -request style the harness sends to/from the Polar gateway. Package versions are -verified external CLI/SDK releases for harnesses that need one; examples may -choose their own pins or `latest`. +`openai_chat`, `openai_responses`, and `google`. *Streaming* is the wire style +the agent sends to the proxy. *Version* is the external CLI/SDK release verified +end-to-end by the [calculator example](../../../examples/calculator/README.md); +examples may pin their own, but these are the known-good ones. -| Harness | API type | require_streaming | Package version | +| Preset | API type | Streaming | Verified version | |---|---|---|---| -| `claude_code` | `anthropic` | `true` | `@anthropic-ai/claude-code@2.1.116` | -| `codex` | `openai_responses` | `true` | `@openai/codex@0.122.0` | +| `claude_code` | `anthropic` | `true` | `@anthropic-ai/claude-code@2.1.111` | +| `codex` | `openai_responses` | `true` | `@openai/codex@0.121.0` | | `gemini_cli` | `google` | `true` | `@google/gemini-cli@0.38.1` | -| `opencode` | `openai_chat` | `true` | `opencode-ai@1.14.19` | -| `openhands_sdk` | `openai_chat` | `false` | `openhands-sdk==1.18.0` | +| `opencode` | `openai_chat` | `true` | `opencode-ai@1.4.6` | +| `openclaw` | `openai_chat` | `true` | `openclaw@2026.5.27` | +| `openhands_sdk` | `openai_chat` | `false` | `openhands-sdk==1.17.0` ¹ | +| `hermes` | `openai_chat` | `true` | `hermes-agent==0.15.1` | | `pi` | `openai_chat` | `false` | `@mariozechner/pi-coding-agent@0.67.68` | | `qwen_code` | `openai_chat` | `true` | `@qwen-code/qwen-code@0.14.5` | -| `shell` | chosen by `agent.custom_shell` | chosen by `agent.custom_shell` | chosen by `agent.custom_shell` | +| `shell` | set by `agent.custom_shell` | set by `agent.custom_shell` | — | + +¹ Install `openhands-tools==1.17.0` at the same version. `1.18+` needs Python +3.13 (a transitive `lmnr` pin is unsatisfiable on 3.12); pin `1.17.0` on a +Python 3.12 image. + +Each preset routes to the proxy a little differently because each agent reads a +different env var / config key — e.g. `gemini_cli` maps `GOOGLE_API_*` onto the +CLI's `GEMINI_API_KEY`/`GOOGLE_GEMINI_BASE_URL`; `openclaw` and `hermes` write the +gateway URL into their config files because they don't read `OPENAI_BASE_URL`; +`codex` declares a custom `responses`-wire provider. The per-file comments +explain each piece — that glue is the *only* reason a preset is more than five +lines. + +## The harness contract + +A harness receives the task instruction, a runtime execution helper, the model +name, environment, settings, and optional MCP servers. It returns an +`AgentRunResult` with status `completed`, `failed`, or `timeout`. + +- The harness starts the agent process. +- Polar owns runtime setup, the model proxy endpoints, completion capture, and + evaluation. + +```python +class BaseHarness: + async def setup(self, runtime) -> None: # write config, install nothing heavy + ... + def run_steps(self, instruction) -> list[ExecInput]: # the command(s) to run the agent + ... +``` + +### Anatomy of a preset + +A preset is just those two methods. For example, a CLI agent that already reads +`OPENAI_BASE_URL`/`OPENAI_API_KEY` needs almost nothing: + +```python +class MyAgentHarness(BaseHarness): + def run_steps(self, instruction: str) -> list[ExecInput]: + return [ExecInput(command=f"myagent --yolo -p {shlex.quote(instruction)}")] +``` -`shell` is built in as an escape hatch. Use this for your wrapped agents as execution commands. +`setup()` is where a preset writes a config file (MCP servers, a custom provider +base URL, skills). `run_steps()` returns the shell command(s); the injected proxy +env vars are merged in automatically. Note `setup()` runs *before* the proxy env +is available, so anything that needs `$OPENAI_BASE_URL` must be written inside a +`run_steps()` command (see `openclaw`/`hermes`/`pi`). -## Harness Contract +## Bring your own agent -A harness receives the task instruction, runtime execution helper, model name, -environment, settings, and optional MCP server definitions. It returns an -`AgentRunResult` with `completed`, `failed`, or `timeout`. +You don't need a preset. Two no-Polar-code paths: -Harnesses are responsible for starting the agent process. Polar is responsible -for runtime setup, model proxy endpoints, completion capture, and evaluation. +**`shell`** — wrap any command. Requires `agent.custom_shell`; cannot be combined +with MCP servers or a skills path. -## Adding A Harness +```yaml +agent: + harness: shell + custom_shell: + command: "my-agent run --task {{INSTRUCTION}} 2>&1 | tee $AGENT_LOG_DIR/agent.txt" +``` -Use one of two paths: +**`import_path`** — keep your harness class in your own repo and point at it: -- Add a built-in harness under `harnesses/` and register it in `factory.py`. -- Keep the code outside Polar and set `agent.import_path` in the task. +```yaml +agent: + import_path: "my_pkg.harness:MyAgentHarness" +``` -The import path should resolve to a harness class that follows the base -contract. +The import path must resolve to a `BaseHarness` subclass. -## Shell Harness +## Main files -The `shell` harness is for simple commands or custom wrappers. It requires -`agent.custom_shell` and cannot be combined with MCP servers or skills paths. +- `base.py` — the harness contract. +- `models.py` — `AgentSpec`, `MCPServerSpec`, `AgentRunResult`. +- `factory.py` — preset name lookup and `import_path` loading. +- `presets/` — the ready-made launchers in the table above. diff --git a/src/polar/agent/factory.py b/src/polar/agent/factory.py index 4fd15853e..014fbd7e4 100644 --- a/src/polar/agent/factory.py +++ b/src/polar/agent/factory.py @@ -9,19 +9,23 @@ def _builtin_harness_map() -> dict[str, type[BaseHarness]]: """Lazy import to avoid circular imports at module level.""" - from polar.agent.harnesses.claude_code import ClaudeCodeHarness - from polar.agent.harnesses.codex import CodexHarness - from polar.agent.harnesses.gemini_cli import GeminiCliHarness - from polar.agent.harnesses.openhands_sdk import OpenHandsSdkHarness - from polar.agent.harnesses.opencode import OpenCodeHarness - from polar.agent.harnesses.pi import PiHarness - from polar.agent.harnesses.qwen_code import QwenCodeHarness - from polar.agent.harnesses.shell import ShellHarness + from polar.agent.presets.claude_code import ClaudeCodeHarness + from polar.agent.presets.codex import CodexHarness + from polar.agent.presets.gemini_cli import GeminiCliHarness + from polar.agent.presets.hermes import HermesHarness + from polar.agent.presets.openclaw import OpenClawHarness + from polar.agent.presets.openhands_sdk import OpenHandsSdkHarness + from polar.agent.presets.opencode import OpenCodeHarness + from polar.agent.presets.pi import PiHarness + from polar.agent.presets.qwen_code import QwenCodeHarness + from polar.agent.presets.shell import ShellHarness return { "claude_code": ClaudeCodeHarness, "codex": CodexHarness, "gemini_cli": GeminiCliHarness, + "hermes": HermesHarness, + "openclaw": OpenClawHarness, "openhands_sdk": OpenHandsSdkHarness, "opencode": OpenCodeHarness, "pi": PiHarness, diff --git a/src/polar/agent/harnesses/__init__.py b/src/polar/agent/harnesses/__init__.py deleted file mode 100644 index 30caa8e76..000000000 --- a/src/polar/agent/harnesses/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Built-in agent harness implementations.""" diff --git a/src/polar/agent/presets/__init__.py b/src/polar/agent/presets/__init__.py new file mode 100644 index 000000000..43671c1f1 --- /dev/null +++ b/src/polar/agent/presets/__init__.py @@ -0,0 +1,38 @@ +"""Preset harnesses — ready-made launchers for popular agents. + +These are *conveniences*, not integrations. Polar runs any agent unmodified: +a harness only starts the agent process and lets its LLM calls flow through +the gateway proxy (which serves the model and captures the trajectory). Each +preset here is a thin ``BaseHarness`` that writes the agent's config and emits +its run command — typically a few dozen lines. + +To run an agent that isn't listed here, you do **not** add code to Polar: +- use the ``shell`` preset to wrap any command, or +- point ``agent.import_path`` at your own ``BaseHarness`` subclass. + +See ``polar/agent/README.md`` for the contract and a "bring your own" guide. +""" + +from polar.agent.presets.claude_code import ClaudeCodeHarness +from polar.agent.presets.codex import CodexHarness +from polar.agent.presets.gemini_cli import GeminiCliHarness +from polar.agent.presets.hermes import HermesHarness +from polar.agent.presets.openclaw import OpenClawHarness +from polar.agent.presets.opencode import OpenCodeHarness +from polar.agent.presets.openhands_sdk import OpenHandsSdkHarness +from polar.agent.presets.pi import PiHarness +from polar.agent.presets.qwen_code import QwenCodeHarness +from polar.agent.presets.shell import ShellHarness + +__all__ = [ + "ClaudeCodeHarness", + "CodexHarness", + "GeminiCliHarness", + "HermesHarness", + "OpenClawHarness", + "OpenCodeHarness", + "OpenHandsSdkHarness", + "PiHarness", + "QwenCodeHarness", + "ShellHarness", +] diff --git a/src/polar/agent/harnesses/claude_code.py b/src/polar/agent/presets/claude_code.py similarity index 100% rename from src/polar/agent/harnesses/claude_code.py rename to src/polar/agent/presets/claude_code.py diff --git a/src/polar/agent/harnesses/codex.py b/src/polar/agent/presets/codex.py similarity index 100% rename from src/polar/agent/harnesses/codex.py rename to src/polar/agent/presets/codex.py diff --git a/src/polar/agent/harnesses/gemini_cli.py b/src/polar/agent/presets/gemini_cli.py similarity index 92% rename from src/polar/agent/harnesses/gemini_cli.py rename to src/polar/agent/presets/gemini_cli.py index f38fcf7fc..535bfa9e1 100644 --- a/src/polar/agent/harnesses/gemini_cli.py +++ b/src/polar/agent/presets/gemini_cli.py @@ -67,6 +67,9 @@ def run_steps(self, instruction: str) -> list[ExecInput]: return [ ExecInput( command=( + # The gateway injects GOOGLE_API_KEY / GOOGLE_API_URL; the + # Gemini CLI reads GEMINI_API_KEY / GOOGLE_GEMINI_BASE_URL, + # so map one onto the other to route calls at the proxy. 'export GEMINI_API_KEY="$GOOGLE_API_KEY" ' 'GOOGLE_GEMINI_BASE_URL="$GOOGLE_API_URL" && ' f"gemini {flags_str} --prompt={escaped} " diff --git a/src/polar/agent/presets/hermes.py b/src/polar/agent/presets/hermes.py new file mode 100644 index 000000000..a2d13edec --- /dev/null +++ b/src/polar/agent/presets/hermes.py @@ -0,0 +1,107 @@ +"""Hermes harness — https://github.com/NousResearch/hermes-agent""" + +from __future__ import annotations + +import json +import shlex + +from polar.agent.base import BaseHarness +from polar.runtime.base import BaseRuntime, RUNTIME_AGENT_LOG_DIR +from polar.runtime.models import ExecInput + +# Isolated agent home so Hermes' sessions/skills/memory stay out of the +# workspace git diff. Read via the HERMES_HOME env var at run time. +_HERMES_HOME = "/tmp/hermes" +# A user-defined provider pins the OpenAI-compatible transport + gateway base +# URL explicitly. Hermes' built-in providers and `provider: auto` route by +# models.dev heuristics (and can mis-detect the model vendor), so we bypass them. +_PROVIDER = "polar" +# Substituted with $OPENAI_BASE_URL at exec time (the gateway env is only +# present during run steps, not setup), keeping the config static. +_BASE_URL_PLACEHOLDER = "__POLAR_GATEWAY_BASE_URL__" + + +class HermesHarness(BaseHarness): + """Run NousResearch's Hermes agent non-interactively (``hermes chat -q``). + + Routes through a user-defined ``providers.polar`` entry whose ``base_url`` is + the gateway proxy and whose ``key_env`` is ``OPENAI_API_KEY`` (injected by the + gateway as the session id), forcing the ``openai_chat`` wire format. + """ + + async def setup(self, runtime: BaseRuntime) -> None: + if self.skills_path: + await runtime.exec( + f"mkdir -p {_HERMES_HOME}/skills && " + f"cp -r {shlex.quote(self.skills_path)}/* " + f"{_HERMES_HOME}/skills/ 2>/dev/null || true" + ) + + def run_steps(self, instruction: str) -> list[ExecInput]: + # Bare model id (no provider prefix); the gateway rewrites it anyway. + model_id = (self.model_name or "gpt-5.4").rsplit("/", 1)[-1] + config_json = json.dumps(self._build_config()) + escaped = shlex.quote(instruction) + + # -q: single non-interactive query; -Q: quiet (final response only); + # --yolo: bypass tool-approval prompts. + flags = [ + "--yolo", + "chat", + f"-q {escaped}", + "-Q", + f"--model {shlex.quote(model_id)}", + f"--provider {_PROVIDER}", + ] + toolsets = self.settings.get("toolsets") + if toolsets is not None: + flags.append(f"--toolsets {shlex.quote(str(toolsets))}") + flags_str = " ".join(flags) + + return [ + ExecInput( + command=( + f"mkdir -p {_HERMES_HOME} && " + # base_url placeholder -> $OPENAI_BASE_URL at exec time. + f"printf '%s' {shlex.quote(config_json)} " + f'| sed "s|{_BASE_URL_PLACEHOLDER}|$OPENAI_BASE_URL|g" ' + f"> {_HERMES_HOME}/config.yaml && " + 'export PATH="$HOME/.local/bin:$PATH" && ' + f"hermes {flags_str} " + f"2>&1 | tee {RUNTIME_AGENT_LOG_DIR}/hermes.txt" + ), + env={**self.env, "HERMES_HOME": _HERMES_HOME, "TERMINAL_ENV": "local"}, + ) + ] + + def _build_config(self) -> dict: + config: dict = { + "providers": { + _PROVIDER: { + "base_url": _BASE_URL_PLACEHOLDER, + "key_env": "OPENAI_API_KEY", + "transport": "openai_chat", + } + }, + "toolsets": ["hermes-cli"], + "agent": {"max_turns": int(self.settings.get("max_turns", 90))}, + # Disable the self-improvement loop so runs stay stateless and don't + # write memory/profile files into the agent home. + "memory": {"memory_enabled": False, "user_profile_enabled": False}, + "terminal": {"backend": "local", "timeout": 180}, + "checkpoints": {"enabled": False}, + } + + if self.mcp_servers: + servers: dict[str, dict] = {} + for server in self.mcp_servers: + if server.transport == "stdio": + entry: dict = {"command": server.command} + if server.args: + entry["args"] = server.args + else: + entry = {"url": server.url} + servers[server.name] = entry + config["mcp_servers"] = servers + + return config diff --git a/src/polar/agent/presets/openclaw.py b/src/polar/agent/presets/openclaw.py new file mode 100644 index 000000000..468e5f823 --- /dev/null +++ b/src/polar/agent/presets/openclaw.py @@ -0,0 +1,104 @@ +"""OpenClaw harness — https://github.com/openclaw/openclaw (Node 22.19+)""" + +from __future__ import annotations + +import json +import shlex + +from polar.agent.base import BaseHarness +from polar.runtime.base import BaseRuntime, RUNTIME_AGENT_LOG_DIR +from polar.runtime.models import ExecInput + + +class OpenClawHarness(BaseHarness): + """Run OpenClaw's embedded agent in headless mode (``openclaw agent --local``). + + OpenClaw reads its config from ``~/.openclaw/openclaw.json`` and never picks + up ``OPENAI_BASE_URL`` from the environment, so the gateway endpoint must be + written into ``models.providers.openai.baseUrl``. The provider's API key is + read from ``OPENAI_API_KEY`` (injected by the gateway as the session id). + + ``model_name`` is ``openai/``; everything after the first ``/`` is the + model id registered under the ``openai`` provider. + """ + + _CONFIG_PATH = "$HOME/.openclaw/openclaw.json" + # Placeholder substituted with $OPENAI_BASE_URL at exec time (the gateway + # env is only present during run steps, not setup), keeping the JSON static. + _BASE_URL_PLACEHOLDER = "__POLAR_GATEWAY_BASE_URL__" + + async def setup(self, runtime: BaseRuntime) -> None: + # `setup` creates the baseline config, workspace, and per-agent session + # folders ("main" agent dir) that the embedded `--local` run expects. + await runtime.exec("openclaw setup --workspace . /dev/null || true" + ) + + def run_steps(self, instruction: str) -> list[ExecInput]: + model_id = (self.model_name or "openai/gpt-5.4").split("/", 1)[-1] + config = self._build_config(model_id) + config_json = json.dumps(config) + + # `openclaw agent --local` requires a session target; `setup` registers + # the "main" agent, so default to it. + agent_id = str(self.settings.get("openclaw_agent_id", "main")) + flags = [ + "--local", + "--json", + f"--agent {shlex.quote(agent_id)}", + f"--model openai/{shlex.quote(model_id)}", + ] + thinking = self.settings.get("thinking") + if thinking is not None: + flags.append(f"--thinking {shlex.quote(str(thinking))}") + flags_str = " ".join(flags) + escaped = shlex.quote(instruction) + + return [ + ExecInput( + command=( + "mkdir -p $HOME/.openclaw && " + # baseUrl placeholder -> $OPENAI_BASE_URL at exec time. + f"printf '%s' {shlex.quote(config_json)} " + f'| sed "s|{self._BASE_URL_PLACEHOLDER}|$OPENAI_BASE_URL|g" ' + f"> {self._CONFIG_PATH} && " + f"openclaw agent {flags_str} --message {escaped} " + f"2>&1 dict: + config: dict = { + "agents": {"defaults": {"workspace": "."}}, + "gateway": {"mode": "local"}, + "models": { + "providers": { + "openai": { + "baseUrl": self._BASE_URL_PLACEHOLDER, + "api": "openai-completions", + "models": [{"id": model_id, "name": model_id}], + } + } + }, + } + + if self.mcp_servers: + servers: dict[str, dict] = {} + for server in self.mcp_servers: + if server.transport == "stdio": + entry: dict = {"command": server.command} + if server.args: + entry["args"] = server.args + else: + entry = {"url": server.url, "transport": server.transport} + servers[server.name] = entry + config["mcp"] = {"servers": servers} + + return config diff --git a/src/polar/agent/harnesses/opencode.py b/src/polar/agent/presets/opencode.py similarity index 100% rename from src/polar/agent/harnesses/opencode.py rename to src/polar/agent/presets/opencode.py diff --git a/src/polar/agent/harnesses/openhands_sdk.py b/src/polar/agent/presets/openhands_sdk.py similarity index 95% rename from src/polar/agent/harnesses/openhands_sdk.py rename to src/polar/agent/presets/openhands_sdk.py index facd36f49..9dda1b4d8 100644 --- a/src/polar/agent/harnesses/openhands_sdk.py +++ b/src/polar/agent/presets/openhands_sdk.py @@ -1,4 +1,9 @@ -"""OpenHands SDK harness — lightweight SDK-based agent.""" +"""OpenHands SDK harness — https://github.com/OpenHands/software-agent-sdk + +Unlike the CLI presets, OpenHands ships as a Python SDK, so this preset writes a +small embedded runner script that builds the SDK agent and points its LLM at the +gateway (LLM_BASE_URL=$OPENAI_BASE_URL). Same idea, different launch shape. +""" from __future__ import annotations diff --git a/src/polar/agent/harnesses/pi.py b/src/polar/agent/presets/pi.py similarity index 100% rename from src/polar/agent/harnesses/pi.py rename to src/polar/agent/presets/pi.py diff --git a/src/polar/agent/harnesses/qwen_code.py b/src/polar/agent/presets/qwen_code.py similarity index 100% rename from src/polar/agent/harnesses/qwen_code.py rename to src/polar/agent/presets/qwen_code.py diff --git a/src/polar/agent/harnesses/shell.py b/src/polar/agent/presets/shell.py similarity index 100% rename from src/polar/agent/harnesses/shell.py rename to src/polar/agent/presets/shell.py diff --git a/src/polar/cli.py b/src/polar/cli.py index 4e873da44..fb9c33f40 100644 --- a/src/polar/cli.py +++ b/src/polar/cli.py @@ -260,7 +260,7 @@ def _build_topology_lines( reachable=bool(status), ) gateway_label = _gateway_label(node.id) - inference_label = _endpoint_label("inference", node.sglang_base_url) + inference_label = _endpoint_label(node.engine, node.inference_base_url) lines.append(f"{branch} {gateway_label} [{badge}] ── {inference_label}") return lines diff --git a/src/polar/config/README.md b/src/polar/config/README.md index f2e313684..f1f1d5707 100644 --- a/src/polar/config/README.md +++ b/src/polar/config/README.md @@ -1,29 +1,72 @@ # Configuration -`polar.config` loads one topology file that describes the rollout server and -all gateway nodes. +`polar.config` loads and validates the single `topology.yaml` that describes a +whole Polar deployment: one **rollout** server plus one or more **gateway** +nodes (each with its own inference backend). `TopologyConfig.load()` is the entry +point every `polar` command uses. -## Main Files +## Mental model -- `topology.py`: Pydantic models for rollout and gateway configuration. +One file, two halves: + +- `rollout:` — the central orchestrator that clients submit tasks to. +- `gateway:` — the worker fleet. `gateway.nodes[]` is a list; each entry is an + independent gateway process with its own ports, worker pools, and inference + endpoint. + +The schema is **strict and immutable**: unknown keys are rejected +(`extra="forbid"`, so a typo fails fast) and every model is frozen after load. +Convenience defaulting fills the gaps — a blank `public_url` is derived from +`host:port` (mapping `0.0.0.0`/`::` → `127.0.0.1`), and `gateway.rollout_server_url` +falls back to `rollout.public_url` when omitted. + +## Main files + +- `topology.py`: the Pydantic models (`TopologyConfig`, `RolloutServiceConfig`, + `GatewayConfig`, `GatewayNodeConfig`), `load()`, and the URL/selection helpers. - `__init__.py`: package exports. -## Topology Schema +## Schema + +**`rollout`** — `RolloutServiceConfig` -The top-level fields are: +| field | type | default | +|---|---|---| +| `host` | str | `0.0.0.0` | +| `port` | int | `8080` | +| `public_url` | str | derived from `host:port` | +| `save_dir` | str? | `None` (no result persistence) | +| `dispatch_poll_interval_seconds` | float | `1.0` | +| `callback_grace_seconds` | float | `120.0` | -- `rollout`: host, port, public URL, save directory, dispatch polling, and - callback timing. -- `gateway`: heartbeat interval, optional rollout URL override, and gateway - node list. -- `gateway.nodes[]`: node id, host, port, public URL, served model name, worker - limits, SGLang endpoint, and optional default runtime. +**`gateway`** — `GatewayConfig` -Unknown keys are rejected so removed or misspelled options fail early. +| field | type | default | +|---|---|---| +| `heartbeat_interval_seconds` | int | `30` | +| `rollout_server_url` | str? | `rollout.public_url` | +| `nodes` | list | **required**, ≥1, unique ids | +| `completion_persistence` | block | `enabled` | -## Example Topology +`completion_persistence` controls the async on-disk capture of model calls: +`enabled` (`true`), `max_field_bytes` (`1048576`), `queue_size` (`1024`). -A topology file declares one rollout server and one or more gateway nodes: +**`gateway.nodes[]`** — `GatewayNodeConfig` + +| field | type | default | +|---|---|---| +| `id` | str | hostname (must be unique) | +| `host` / `port` | str / int | `0.0.0.0` / `8081` | +| `public_url` | str | derived from `host:port` | +| `model_served` | str | `""` | +| `inference.engine` | `sglang` \| `vllm` | `sglang` | +| `inference.base_url` | str | `http://127.0.0.1:8000` | +| `max_init_workers` | int | `4` | +| `max_run_workers` | int | `2` | +| `max_postrun_workers` | int | `4` | +| `default_runtime` | `RuntimeSpec`? | `None` | + +## Example ```yaml rollout: @@ -43,23 +86,18 @@ gateway: max_init_workers: 8 max_run_workers: 4 max_postrun_workers: 4 - sglang: + inference: + engine: sglang # or vllm base_url: http://127.0.0.1:8000 ``` -## Public URL Rules - -`public_url` values must be reachable by the caller: - -- The rollout server calls each gateway node's `public_url`. -- Gateway nodes call the rollout server callback URL. -- Each gateway calls its configured `sglang.base_url`. - -When `public_url` is omitted or empty, Polar derives a local URL from host and -port. For multi-node deployments, set explicit reachable URLs. +## Reachable URLs and multi-node -## Multi-Node Selection +`public_url`s must be reachable by whoever calls them: the rollout server calls +each node's `public_url`; each node calls back to `rollout_server_url` and its +own `inference.base_url`. Locally the derived `127.0.0.1` URLs work; for +multi-host deployments set explicit reachable URLs. -`polar serve_gateway` needs `--node-id` when the topology contains more than one -gateway node. This prevents a gateway process from accidentally starting with -the wrong SGLang endpoint or worker limits. +`polar serve_gateway` requires `--node-id` when the topology has more than one +node, so a gateway process always starts with the right ports, worker limits, +and inference endpoint. diff --git a/src/polar/config/topology.py b/src/polar/config/topology.py index f16a84918..ffa3f0826 100644 --- a/src/polar/config/topology.py +++ b/src/polar/config/topology.py @@ -4,7 +4,7 @@ from pathlib import Path import socket -from typing import Any +from typing import Any, Literal from urllib.parse import urlparse import yaml @@ -19,13 +19,14 @@ class _StrictModel(BaseModel): model_config = ConfigDict(extra="forbid", frozen=True) -class _SGLangConfig(_StrictModel): +class _InferenceConfig(_StrictModel): + engine: Literal["sglang", "vllm"] = "sglang" base_url: str = "http://127.0.0.1:8000" @field_validator("base_url") @classmethod def _validate_url(cls, value: str) -> str: - return _normalize_http_url(value, "gateway.nodes[].sglang.base_url") + return _normalize_http_url(value, "gateway.nodes[].inference.base_url") class GatewayNodeConfig(_StrictModel): @@ -34,7 +35,7 @@ class GatewayNodeConfig(_StrictModel): port: int = Field(default=8081, ge=1, le=65535) public_url: str model_served: str = "" - sglang: _SGLangConfig = Field(default_factory=_SGLangConfig) + inference: _InferenceConfig = Field(default_factory=_InferenceConfig) max_init_workers: int = Field(default=4, gt=0) max_run_workers: int = Field(default=2, gt=0) max_postrun_workers: int = Field(default=4, gt=0) @@ -68,8 +69,12 @@ def _validate_public_url(cls, value: str) -> str: return _normalize_http_url(value, "gateway.nodes[].public_url") @property - def sglang_base_url(self) -> str: - return self.sglang.base_url + def inference_base_url(self) -> str: + return self.inference.base_url + + @property + def engine(self) -> str: + return self.inference.engine class _CompletionPersistenceConfig(_StrictModel): diff --git a/src/polar/gateway/README.md b/src/polar/gateway/README.md index 78d453eec..5c90f68e4 100644 --- a/src/polar/gateway/README.md +++ b/src/polar/gateway/README.md @@ -1,31 +1,70 @@ # Gateway Service -`polar.gateway` runs sessions on worker hosts. A gateway accepts dispatches from -the rollout server, prepares a runtime, runs an agent harness, proxies model -requests, builds a trajectory, and returns the terminal result. - -## Main Files - -- `server.py`: FastAPI app and gateway endpoints. -- `node.py`: gateway node lifecycle. -- `dispatcher.py`: INIT, READY, RUNNING, and POSTRUN stage orchestration. -- `session.py`: session state and status accounting. -- `proxy.py`: OpenAI-compatible proxy surface used by agent harnesses. -- `storage.py`: completion record storage. -- `detection.py`: request API-family detection. -- `transform/`: request and response transformers. - -## Responsibilities - -- Register with the rollout server and send heartbeats. -- Keep runtime preparation, active generation, and post-run work in separate - worker pools. -- Capture normalized completion records from proxied model calls. -- Build and evaluate trajectories after the agent exits. -- Call back to the rollout server with `SessionResult`. - -## Pause And Resume - -The gateway exposes controls used by training bridges to pause or resume model -generation. This lets a trainer stop new generation while weights are being -updated, then resume when the backend is ready. +`polar.gateway` is the per-worker FastAPI service that runs a session. It accepts +a dispatch from the rollout server, prepares a runtime, runs the agent harness, +**transparently proxies the agent's LLM calls** to a local inference server +(capturing every one), then builds and evaluates a trajectory and reports the +result back. + +## Mental model + +The agent never knows Polar is in the middle. Before running it, the gateway +injects proxy endpoints as environment variables — `OPENAI_BASE_URL`, +`ANTHROPIC_BASE_URL`, `GOOGLE_API_URL`, with the **API key set to the session +id**. The agent thinks it's calling OpenAI/Anthropic/Google, but every request +lands on the gateway's catch-all route, which: + +1. **detects** the API family from the path/headers/body (`detection.py`), +2. **transforms** the request to the served model and adds training fields + (`transform/`), +3. forwards it to the configured inference server (`engine.py` handles + SGLang/vLLM specifics), +4. **captures** the request + response as a completion record, and +5. transforms the response back into the shape the agent expects. + +Streaming is **synthetic**: even when the agent asks for a token stream, the +gateway makes one non-streaming backend call and replays the full answer as +well-formed SSE — simpler, and enough for capture. + +A session moves through staged worker pools: **INIT** (start runtime + run the +prepare recipe) → **READY** (wait for a run slot) → **RUNNING** (harness setup + +run) → **POST-RUN** (build trajectory, evaluate, tear down, call back). Terminal +statuses are `COMPLETED`, `ERROR`, or `TIMEOUT`. + +## Main files + +- `server.py`: the FastAPI app — the catch-all LLM proxy route, the + session/admin/health/events endpoints, and synthetic streaming. +- `node.py`: `GatewayNodeManager` — stage handlers, runtime prepare, trajectory + build + eval, rollout registration/heartbeat, result callback, and the agent + env injection. +- `dispatcher.py`: stage-isolated worker pools and the + INIT→READY→RUNNING→POST-RUN transitions. +- `session.py`: in-memory session registry, id validation, and resolving the + session id from an incoming proxied request. +- `detection.py`: API-family detection (`anthropic` / `openai_chat` / + `openai_responses` / `google`). +- `transform/`: per-API request/response transformers (see + [transform](transform/README.md)). +- `engine.py`: inference-backend strategy (SGLang / vLLM) — injects + token-id/logprob params and canonicalizes responses. +- `proxy.py`: `InferenceClient`, the HTTP client to the inference server, with + pause/resume generation gating. +- `storage.py`: in-memory completion-record store (the authoritative copy). +- `completion_writer.py`: background task that persists completions to disk off + the hot path. + +## What it captures + +Each proxied call is stored as a `CompletionRecord` that keeps both the agent's +original request and the served request, plus the response. Records live in +memory (used to build the trajectory) and, when `gateway.completion_persistence` +is enabled, are also written to +`/task_/sessions//completions/-.json`. + +## Pause and resume + +`POST /admin/inference/pause` and `/resume` gate the gateway's **outbound +generation** (the calls in `InferenceClient`). A training bridge pauses new +generation while it syncs weights, lets in-flight calls drain, then resumes — +this pauses inference, not the gateway process. diff --git a/src/polar/gateway/engine.py b/src/polar/gateway/engine.py new file mode 100644 index 000000000..a5e5aaabf --- /dev/null +++ b/src/polar/gateway/engine.py @@ -0,0 +1,148 @@ +"""Inference backend strategies for the Polar gateway. + +The gateway speaks the OpenAI Chat Completions API to a local inference server. +Two backends are supported, and they differ only in: + + 1. the request params that make them emit the token ids + per-token logprobs + Polar needs for training, and + 2. the exact shape of those fields in the response. + +The base implements the canonical contract -- request ``logprobs`` (the one +training param every backend needs) and a pass-through response. A backend +overrides only what it does differently, via two hooks: ``prepare_request`` +(extra request params) and ``normalize_response`` (response canonicalization). +Everything downstream -- storage, trace builder, transforms, slime adapter -- +then sees one shape. The canonical shape is SGLang's patched output: + + - prompt token ids: ``choice.input_token_ids`` (or ``response.prompt_token_ids``) + - response token ids: ``choice.token_ids`` (or ``logprobs.content[].token_id``) + - per-token logprobs: ``choice.logprobs.content[]`` with ``{token, token_id, logprob, ...}`` + +SGLang reaches this shape via ``scripts/patch/patch_sglang.sh``; vLLM reaches it +natively via the ``return_token_ids`` request flag plus a light response rename. +""" + +from __future__ import annotations + +from abc import ABC +from typing import Any + + +class InferenceEngine(ABC): + """Strategy for one OpenAI-compatible inference backend. + + The base encodes the canonical contract: request ``logprobs`` (the one + training param every backend needs) and pass the response through + unchanged. A backend overrides only what it does differently. + """ + + name: str + + def prepare_request(self, request: dict[str, Any]) -> dict[str, Any]: + """Inject the request params this backend needs to emit training signals. + + ``logprobs`` is universal; subclasses add backend-specific params (e.g. + token-id flags) on top via ``super().prepare_request(...)``. + """ + request["logprobs"] = True + return request + + def normalize_response(self, response: dict[str, Any]) -> dict[str, Any]: + """Canonicalize the backend's response (in place) and return it. + + The canonical shape is SGLang's patched output, so the default is a + pass-through; a backend that differs overrides this. + """ + return response + + +class SGLangEngine(InferenceEngine): + """Canonical backend: it emits Polar's training shape with no per-request or + response adaptation here -- so it inherits the base hooks unchanged. The one + thing SGLang lacks natively is token-id *output* (no request flag exists for + it); ``scripts/patch/patch_sglang.sh`` adds that on the response side. + """ + + name = "sglang" + + +class VLLMEngine(InferenceEngine): + """vLLM via its native OpenAI-compatible server. + + ``return_token_ids`` makes vLLM emit ``response.prompt_token_ids`` and + ``choice.token_ids`` -- the same ids SGLang's patch produces, with no source + patch needed. ``top_logprobs`` must be set (not None) for vLLM to populate + ``logprobs.content[]`` given ``logprobs=True``; 0 returns just the sampled + token's logprob, which is all training needs. + """ + + name = "vllm" + + def prepare_request(self, request: dict[str, Any]) -> dict[str, Any]: + request = super().prepare_request(request) # logprobs=True + request["return_token_ids"] = True + request.setdefault("top_logprobs", 0) + # vLLM reads input reasoning from `reasoning`, not Polar's canonical + # `reasoning_content`; rename it so prior turns' interleaved thinking + # survives templating (else they render an empty ``). + for message in request.get("messages") or []: + if isinstance(message, dict) and message.get("reasoning_content") is not None: + message["reasoning"] = message.pop("reasoning_content") + return request + + def normalize_response(self, response: dict[str, Any]) -> dict[str, Any]: + choices = response.get("choices") + if not isinstance(choices, list): + return response + for choice in choices: + if not isinstance(choice, dict): + continue + self._canonicalize_reasoning(choice.get("message")) + self._stamp_token_ids_onto_logprobs(choice) + return response + + @staticmethod + def _canonicalize_reasoning(message: Any) -> None: + """vLLM names the field ``reasoning``; Polar's canonical field is ``reasoning_content``.""" + if not isinstance(message, dict): + return + if message.get("reasoning_content") is None and message.get("reasoning") is not None: + message["reasoning_content"] = message.pop("reasoning") + + @staticmethod + def _stamp_token_ids_onto_logprobs(choice: dict[str, Any]) -> None: + """Parity with SGLang: copy token_id onto each logprob entry. + + vLLM builds ``logprobs.content`` and ``choice.token_ids`` from the same + ``output.token_ids``, so they align; guard on equal length regardless. + Not load-bearing for training (which reads ``choice.token_ids`` and the + per-entry ``logprob``) -- it keeps stored traces one shape across engines. + """ + token_ids = choice.get("token_ids") + logprobs = choice.get("logprobs") + if not isinstance(token_ids, list) or not isinstance(logprobs, dict): + return + content = logprobs.get("content") + if not isinstance(content, list) or len(content) != len(token_ids): + return + for entry, token_id in zip(content, token_ids): + if isinstance(entry, dict): + entry.setdefault("token_id", token_id) + + +_ENGINES: dict[str, type[InferenceEngine]] = { + SGLangEngine.name: SGLangEngine, + VLLMEngine.name: VLLMEngine, +} + + +def get_engine(name: str) -> InferenceEngine: + """Return the inference engine strategy for ``name`` (``sglang`` | ``vllm``).""" + try: + engine_cls = _ENGINES[name] + except KeyError: + supported = ", ".join(sorted(_ENGINES)) + raise ValueError( + f"Unknown inference engine {name!r}; supported: {supported}" + ) from None + return engine_cls() diff --git a/src/polar/gateway/proxy.py b/src/polar/gateway/proxy.py index badefc3ea..7b7ab3331 100644 --- a/src/polar/gateway/proxy.py +++ b/src/polar/gateway/proxy.py @@ -1,4 +1,9 @@ -"""HTTP client for forwarding requests to SGLang with SSE streaming support.""" +"""HTTP client for forwarding requests to an OpenAI-compatible inference server. + +Backend differences (request params, response shape) are isolated in the +``InferenceEngine`` strategy this client holds; the HTTP/streaming/pause logic +here is backend-agnostic. +""" from __future__ import annotations @@ -9,6 +14,8 @@ import httpx +from polar.gateway.engine import InferenceEngine + logger = logging.getLogger(__name__) @@ -48,19 +55,21 @@ class UpstreamTransportError(UpstreamError): """Raised for connection and transport failures.""" -class SGLangClient: - """Direct httpx client to SGLang's OpenAI-compatible API. +class InferenceClient: + """Direct httpx client to an inference server's OpenAI-compatible API. Per-call bound comes from the session's remaining-timeout budget (`_await_with_budget` at the gateway node). The internal httpx timeout - is a high liveness ceiling so that a stuck SGLang can't pin a request - past the session deadline. + is a high liveness ceiling so that a stuck engine can't pin a request + past the session deadline. The ``engine`` strategy injects backend-specific + request params and canonicalizes responses. """ _LIVENESS_TIMEOUT_SECONDS = 900.0 - def __init__(self, base_url: str): + def __init__(self, base_url: str, engine: InferenceEngine): self.base_url = base_url.rstrip("/") + self.engine = engine self._client: httpx.AsyncClient | None = None self._generation_paused = False self._inflight_generations = 0 @@ -106,10 +115,12 @@ async def completion(self, request: dict[str, Any]) -> dict[str, Any]: """Non-streaming chat completion. Returns the full JSON response.""" await self._acquire_generation_slot() client = await self._get_client() - request_copy = request.copy() + from copy import deepcopy + + request_copy = deepcopy(request) request_copy.pop("stream", None) request_copy["stream"] = False - + request_copy = self.engine.prepare_request(request_copy) try: resp = await client.post( "/v1/chat/completions", @@ -122,7 +133,7 @@ async def completion(self, request: dict[str, Any]) -> dict[str, Any]: await self._release_generation_slot() await self._raise_for_status(resp) - return resp.json() + return self.engine.normalize_response(resp.json()) async def _acquire_generation_slot(self) -> None: async with self._generation_condition: @@ -135,7 +146,7 @@ async def _release_generation_slot(self) -> None: self._generation_condition.notify_all() async def pause_generation(self, *, timeout_seconds: float = 300.0) -> dict[str, Any]: - """Block new generation requests and wait for current SGLang calls to drain.""" + """Block new generation requests and wait for current inference calls to drain.""" async with self._generation_condition: self._generation_paused = True self._generation_condition.notify_all() @@ -156,6 +167,7 @@ def generation_status(self) -> dict[str, Any]: "paused": self._generation_paused, "inflight": self._inflight_generations, "base_url": self.base_url, + "engine": self.engine.name, } async def list_models(self) -> dict[str, Any]: diff --git a/src/polar/gateway/server.py b/src/polar/gateway/server.py index 046504a87..1bc9466ea 100644 --- a/src/polar/gateway/server.py +++ b/src/polar/gateway/server.py @@ -18,9 +18,10 @@ from polar.config import GatewayNodeConfig, TopologyConfig from polar.gateway.completion_writer import CompletionWriter from polar.gateway.detection import APIType, detect, extract_model +from polar.gateway.engine import get_engine from polar.gateway.node import GatewayNodeManager from polar.gateway.proxy import ( - SGLangClient, + InferenceClient, UpstreamError, UpstreamHTTPError, UpstreamTimeoutError, @@ -57,7 +58,7 @@ class GatewayState: topology: TopologyConfig node: GatewayNodeConfig - sglang: SGLangClient + inference: InferenceClient storage: SessionStore transform_manager: TransformManager session_registry: SessionRegistry @@ -80,7 +81,7 @@ def configure_server(topology_path: str = "topology.yaml", *, node_id: str | Non def _build_state(topology: TopologyConfig, node_id: str | None) -> GatewayState: node = topology.select_gateway_node(node_id) - sglang = SGLangClient(node.sglang_base_url) + inference = InferenceClient(node.inference_base_url, get_engine(node.engine)) persistence_config = topology.gateway.completion_persistence save_dir = topology.rollout.save_dir completion_writer = CompletionWriter( @@ -114,7 +115,7 @@ def _build_state(topology: TopologyConfig, node_id: str | None) -> GatewayState: return GatewayState( topology=topology, node=node, - sglang=sglang, + inference=inference, storage=storage, transform_manager=transform_manager, session_registry=session_registry, @@ -187,7 +188,7 @@ async def _lifespan(_: FastAPI): yield finally: await state.node_manager.close() - await state.sglang.close() + await state.inference.close() state.storage.close() await state.completion_writer.close() @@ -399,7 +400,7 @@ def format_stream_output( async def list_models(): state = get_state() try: - return await state.sglang.list_models() + return await state.inference.list_models() except Exception as exc: logger.error("Failed to list models: %s", exc) return JSONResponse({"error": str(exc)}, status_code=502) @@ -410,14 +411,14 @@ async def health(): state = get_state() metrics = await state.node_manager.stage_metrics() try: - upstream = await state.sglang.health() + upstream = await state.inference.health() except Exception as exc: upstream = {"status": "error", "error": str(exc)} return { "status": "ok", "node_id": state.node.id, "gateway_url": state.node.public_url, - "sglang": upstream, + "inference": upstream, "metrics": metrics.model_dump(mode="json"), "active_status_counts": state.session_registry.active_status_counts(), "active_sessions": state.session_registry.active_sessions(), @@ -427,32 +428,32 @@ async def health(): } -@app.get("/admin/sglang/status") -async def sglang_generation_status(): - return get_state().sglang.generation_status() +@app.get("/admin/inference/status") +async def inference_generation_status(): + return get_state().inference.generation_status() -@app.post("/admin/sglang/pause") -async def pause_sglang_generation(timeout_seconds: float = 300.0): +@app.post("/admin/inference/pause") +async def pause_inference_generation(timeout_seconds: float = 300.0): state = get_state() try: - status = await state.sglang.pause_generation(timeout_seconds=timeout_seconds) + status = await state.inference.pause_generation(timeout_seconds=timeout_seconds) except TimeoutError as exc: raise HTTPException( status_code=504, - detail=f"Timed out waiting for SGLang requests to drain after {timeout_seconds}s", + detail=f"Timed out waiting for inference requests to drain after {timeout_seconds}s", ) from exc logger.info( - "Paused SGLang generation proxy for weight update; inflight=%s", + "Paused inference generation proxy for weight update; inflight=%s", status["inflight"], ) return status -@app.post("/admin/sglang/resume") -async def resume_sglang_generation(): - status = await get_state().sglang.resume_generation() - logger.info("Resumed SGLang generation proxy") +@app.post("/admin/inference/resume") +async def resume_inference_generation(): + status = await get_state().inference.resume_generation() + logger.info("Resumed inference generation proxy") return status @@ -679,7 +680,7 @@ async def _handle_non_streaming( ) -> JSONResponse: state = get_state() try: - response = await state.sglang.completion(openai_request) + response = await state.inference.completion(openai_request) except UpstreamError as exc: logger.warning("Non-streaming upstream error for session %s: %s", session_id, exc) return _upstream_error_response(api_type, exc) @@ -714,7 +715,7 @@ async def _handle_streaming( non_stream_request = {k: v for k, v in openai_request.items() if k != "stream_options"} non_stream_request["stream"] = False try: - response = await state.sglang.completion(non_stream_request) + response = await state.inference.completion(non_stream_request) except UpstreamError as exc: logger.warning("Upstream error for streaming session %s: %s", session_id, exc) return _upstream_error_response(api_type, exc) diff --git a/src/polar/gateway/transform/README.md b/src/polar/gateway/transform/README.md index 95944e18b..ce9b6f5be 100644 --- a/src/polar/gateway/transform/README.md +++ b/src/polar/gateway/transform/README.md @@ -1,33 +1,47 @@ # API Transforms -`polar.gateway.transform` keeps agent-facing APIs stable while adding the fields -needed for trainable SGLang completions. - -## Main Files - -- `base.py`: common transformer interface and training request enhancement. -- `openai_chat.py`: OpenAI Chat Completions passthrough and response repair. -- `openai_responses.py`: OpenAI Responses conversion and streaming events. -- `anthropic.py`: Anthropic-style request and response conversion. -- `google.py`: Google-style request and response conversion. -- `__init__.py`: transformer registry by detected API type. - -## Responsibilities - -Request transforms: - -- Preserve the user-requested model for agent compatibility. -- Forward to the served model expected by the gateway when needed. -- Add training fields such as logprobs. -- Normalize API-specific message shapes before proxying. - -Response transforms: - -- Return a shape the agent harness expects. -- Preserve tool calls, text chunks, finish reasons, and streaming events. -- Keep original requested model names where clients depend on them. +`polar.gateway.transform` is the adapter layer inside the proxy. It converts each +intercepted request from its native API (Anthropic / OpenAI Chat / OpenAI +Responses / Google) into **OpenAI Chat Completions** for the served model, then +converts the response back — so every agent sees the API shape it expects while +one backend serves them all. + +## Mental model + +- **One transformer per API type**, selected by `TransformManager` from the + detected `APIType`. OpenAI Chat is near-passthrough; Anthropic / Responses / + Google fully restructure messages, tools, and system prompts. +- The canonical internal format is **OpenAI Chat Completions**. +- Shared request normalization lives in `base.py` (`_normalize_request`): merge + `developer`→`system` roles, drop internal keys, and for Qwen3.5 models force + `enable_thinking=False`. Training-signal params (`logprobs`, token ids) and all + backend-specific request/response handling live in the gateway's `engine.py`, + not here. +- The model swap to the served model happens in the proxy (`server.py`); + transformers carry the requested model through and **restore it on the + response**, so clients still see the name they asked for. + +## Main files + +- `base.py`: the transformer interface + shared training enhancement (role + merge, logprobs, Qwen3.5 thinking-off). +- `openai_chat.py`: near-passthrough (e.g. `max_completion_tokens`→`max_tokens`). +- `openai_responses.py`: OpenAI Responses ↔ Chat, including reasoning items and + shell/function tools (used by Codex). +- `anthropic.py`: Anthropic Messages ↔ Chat — tool_use/tool_result and + Claude-Code header handling. +- `google.py`: Gemini `generateContent` ↔ Chat — functionDeclarations / Call / + Response. +- `images.py`: cross-API image and document content normalization. +- `reasoning.py`: round-trips reasoning content across the per-API thinking / + signature shapes. +- `__init__.py`: the `APIType` → transformer registry. ## Streaming -Streaming transforms operate chunk by chunk. They must preserve event ordering -and emit terminal events that client SDKs expect. +Each non-OpenAI transformer carries stream state so it can emit a correctly +ordered SSE sequence (open block → deltas → close). In practice the gateway +calls the backend once and drives the transformer with a single chunk + +finalize (see the synthetic-streaming note in the +[gateway README](../README.md)), so this machinery turns one complete response +into the event stream a client SDK expects. diff --git a/src/polar/gateway/transform/anthropic.py b/src/polar/gateway/transform/anthropic.py index 1a2a0ef8c..59c2d6791 100644 --- a/src/polar/gateway/transform/anthropic.py +++ b/src/polar/gateway/transform/anthropic.py @@ -17,6 +17,10 @@ anthropic_content_to_openai_chat, openai_chat_content_to_anthropic_blocks, ) +from polar.gateway.transform.reasoning import ( + extract_reasoning_from_anthropic_content, + make_signature, +) # Claude Code SDK leaks `x-anthropic-billing-header: ...cch=;` as the # first line of the system prompt. The cch= hash changes per request, so @@ -52,6 +56,9 @@ def __init__(self, model: str, finish_to_stop_reason: dict[str, str]): self.next_block_index = 0 self.text_block_index: int | None = None self.text_block_started = False + self.thinking_block_index: int | None = None + self.thinking_block_started = False + self.thinking_buffer = "" self.tool_calls: dict[int, _AnthropicToolCallState] = {} self.stop_reason = "end_turn" self.output_tokens = 0 @@ -92,8 +99,25 @@ def process_chunk(self, chunk: dict[str, Any], is_first: bool = False) -> list[d if finish_reason: self.stop_reason = self.finish_to_stop_reason.get(finish_reason, "end_turn") + # Thinking blocks must precede text and tool_use per Anthropic spec. + reasoning = delta.get("reasoning_content") + if reasoning: + if not self.thinking_block_started: + events.append(self._open_thinking_block()) + events.append( + { + "type": "content_block_delta", + "index": self.thinking_block_index, + "delta": {"type": "thinking_delta", "thinking": reasoning}, + } + ) + self.thinking_buffer += reasoning + content = delta.get("content") if content: + thinking_stop = self._close_thinking_block() + if thinking_stop: + events.extend(thinking_stop) if not self.text_block_started: events.append(self._open_text_block()) events.append( @@ -119,6 +143,10 @@ def finalize(self) -> list[dict[str, Any]]: events: list[dict[str, Any]] = [] + thinking_stop = self._close_thinking_block() + if thinking_stop: + events.extend(thinking_stop) + text_stop = self._close_text_block() if text_stop: events.append(text_stop) @@ -176,6 +204,36 @@ def _close_text_block(self) -> dict[str, Any] | None: self.text_block_index = None return event + def _open_thinking_block(self) -> dict[str, Any]: + self.thinking_block_started = True + self.thinking_block_index = self.next_block_index + self.next_block_index += 1 + self.any_block_started = True + return { + "type": "content_block_start", + "index": self.thinking_block_index, + "content_block": {"type": "thinking", "thinking": "", "signature": ""}, + } + + def _close_thinking_block(self) -> list[dict[str, Any]] | None: + if not self.thinking_block_started or self.thinking_block_index is None: + return None + idx = self.thinking_block_index + events = [ + { + "type": "content_block_delta", + "index": idx, + "delta": { + "type": "signature_delta", + "signature": make_signature(self.thinking_buffer), + }, + }, + {"type": "content_block_stop", "index": idx}, + ] + self.thinking_block_started = False + self.thinking_block_index = None + return events + def _process_tool_call(self, tool_call_delta: dict[str, Any]) -> list[dict[str, Any]]: events: list[dict[str, Any]] = [] @@ -208,6 +266,10 @@ def _process_tool_call(self, tool_call_delta: dict[str, Any]) -> list[dict[str, tool_state.buffered_arguments += args_str if tool_state.name and not tool_state.started: + thinking_stop = self._close_thinking_block() + if thinking_stop: + events.extend(thinking_stop) + text_stop = self._close_text_block() if text_stop: events.append(text_stop) @@ -264,7 +326,8 @@ class AnthropicTransformer(BaseTransformer): "stop": "end_turn", "length": "max_tokens", "tool_calls": "tool_use", - "content_filter": "end_turn", + "content_filter": "refusal", + "stop_sequence": "stop_sequence", } def transform_request(self, body: dict[str, Any]) -> dict[str, Any]: @@ -293,16 +356,30 @@ def transform_request(self, body: dict[str, Any]) -> dict[str, Any]: "messages": messages, "max_tokens": body.get("max_tokens", 4096), } + if "model" in body: + result["model"] = body["model"] if "temperature" in body: result["temperature"] = body["temperature"] if "top_p" in body: result["top_p"] = body["top_p"] + if "top_k" in body: + result["top_k"] = body["top_k"] if "stop_sequences" in body: result["stop"] = body["stop_sequences"] if body.get("stream", False): result["stream"] = True + # Anthropic `thinking` request param → enable_thinking on chat template. + thinking_cfg = body.get("thinking") + if isinstance(thinking_cfg, dict) and thinking_cfg.get("type") in { + "enabled", + "adaptive", + }: + chat_template_kwargs = dict(result.get("chat_template_kwargs") or {}) + chat_template_kwargs["enable_thinking"] = True + result["chat_template_kwargs"] = chat_template_kwargs + # Tools. Claude Code sometimes sends tools=[] on compaction/summary # turns; forwarding tool_choice without a non-empty tools list makes # SGLang reject with "tool_choice only allowed when tools specified". @@ -314,7 +391,7 @@ def transform_request(self, body: dict[str, Any]) -> dict[str, Any]: body.get("tool_choice", {"type": "auto"}) ) - return self._enhance_for_training( + return self._normalize_request( result, body.get("_polar_model_served"), ) @@ -332,6 +409,16 @@ def transform_response( message = choice.get("message", {}) content = [] + reasoning = message.get("reasoning_content") + if isinstance(reasoning, str) and reasoning: + content.append( + { + "type": "thinking", + "thinking": reasoning, + "signature": make_signature(reasoning), + } + ) + text = message.get("content") if text or (isinstance(text, list) and text): content.extend(openai_chat_content_to_anthropic_blocks(text)) @@ -351,6 +438,7 @@ def transform_response( finish_reason = choice.get("finish_reason", "stop") stop_reason = self.FINISH_TO_STOP_REASON.get(finish_reason, "end_turn") usage = response.get("usage", {}) + anthropic_usage = self._usage_to_anthropic(usage) if not content: content.append({"type": "text", "text": ""}) @@ -363,10 +451,7 @@ def transform_response( "model": original_request.get("model", "claude-3"), "stop_reason": stop_reason, "stop_sequence": None, - "usage": { - "input_tokens": usage.get("prompt_tokens", 0), - "output_tokens": usage.get("completion_tokens", 0), - }, + "usage": anthropic_usage, } def create_stream_state(self, original_request: dict[str, Any]) -> AnthropicStreamState: @@ -413,6 +498,11 @@ def _transform_message(self, msg: dict[str, Any]) -> Optional[dict | list]: messages = [] + # Assistant `thinking` blocks → reasoning_content (kept for replay). + reasoning_text = "" + if role == "assistant": + reasoning_text = extract_reasoning_from_anthropic_content(content) + # Handle assistant messages with tool_use blocks if role == "assistant" and tool_uses: tool_calls = [] @@ -436,6 +526,8 @@ def _transform_message(self, msg: dict[str, Any]) -> Optional[dict | list]: "role": "assistant", "content": "\n".join(text_parts) if text_parts else None, } + if reasoning_text: + msg_dict["reasoning_content"] = reasoning_text if tool_calls: msg_dict["tool_calls"] = tool_calls return msg_dict @@ -446,11 +538,19 @@ def _transform_message(self, msg: dict[str, Any]) -> Optional[dict | list]: for tr in tool_results: tool_content = tr.get("content", "") converted_content = anthropic_content_to_openai_chat(tool_content) + text_content = self._flatten_content(converted_content) + # Anthropic marks failed tool results with is_error=true. + # Surface this to the model so it can see the call failed + # rather than treating the payload as normal output. + if tr.get("is_error"): + text_content = ( + f"[Tool Error] {text_content}" if text_content else "[Tool Error]" + ) messages.append( { "role": "tool", "tool_call_id": tr.get("tool_use_id", ""), - "content": self._flatten_content(converted_content), + "content": text_content, } ) # OpenAI tool messages stay text-only; images are sent as a @@ -466,7 +566,13 @@ def _transform_message(self, msg: dict[str, Any]) -> Optional[dict | list]: return messages if messages else None # Regular content blocks — keep images when present. - return {"role": role, "content": anthropic_content_to_openai_chat(content)} + result: dict[str, Any] = { + "role": role, + "content": anthropic_content_to_openai_chat(content), + } + if role == "assistant" and reasoning_text: + result["reasoning_content"] = reasoning_text + return result def _flatten_content(self, content: Any) -> str: if isinstance(content, str): @@ -494,11 +600,24 @@ def _image_parts(self, content: Any) -> list[dict[str, Any]]: def _transform_tools_to_openai(self, tools: list[dict]) -> list[dict]: result = [] for tool in tools: + # Anthropic server tools (web_search_*, code_execution_*) carry an + # explicit `type` and have no `input_schema`. SGLang can't dispatch + # them, so drop rather than forwarding a stub function tool. + tool_type = tool.get("type") + if ( + tool_type + and tool_type not in ("custom", "function") + and "input_schema" not in tool + ): + continue + name = tool.get("name") + if not isinstance(name, str) or not name: + continue result.append( { "type": "function", "function": { - "name": tool.get("name", ""), + "name": name, "description": tool.get("description", ""), "parameters": tool.get("input_schema", {}), }, @@ -513,6 +632,8 @@ def _transform_tool_choice_to_openai(self, tool_choice: Any) -> Any: return "auto" elif tc_type == "any": return "required" + elif tc_type == "none": + return "none" elif tc_type == "tool": return { "type": "function", @@ -526,6 +647,32 @@ def _parse_json_safe(self, s: str) -> dict: except (json.JSONDecodeError, TypeError): return {} + def _usage_to_anthropic(self, usage: dict[str, Any]) -> dict[str, Any]: + prompt_tokens = usage.get("prompt_tokens", 0) + completion_tokens = usage.get("completion_tokens", 0) + cache_read = self._cached_prompt_tokens(usage) + input_tokens = max(prompt_tokens - cache_read, 0) if cache_read else prompt_tokens + + result: dict[str, Any] = { + "input_tokens": input_tokens, + "output_tokens": completion_tokens, + } + if cache_read: + result["cache_read_input_tokens"] = cache_read + cache_creation = usage.get("cache_creation_input_tokens") + if isinstance(cache_creation, int) and cache_creation: + result["cache_creation_input_tokens"] = cache_creation + return result + + def _cached_prompt_tokens(self, usage: dict[str, Any]) -> int: + details = usage.get("prompt_tokens_details") + if isinstance(details, dict): + cached = details.get("cached_tokens") + if isinstance(cached, int): + return cached + cached = usage.get("cached_tokens") + return cached if isinstance(cached, int) else 0 + def _error_response(self, message: str) -> dict[str, Any]: return { "type": "error", diff --git a/src/polar/gateway/transform/base.py b/src/polar/gateway/transform/base.py index cfba57e20..91d700fcc 100644 --- a/src/polar/gateway/transform/base.py +++ b/src/polar/gateway/transform/base.py @@ -1,4 +1,4 @@ -"""Base transformer interface with SGLang request enhancement.""" +"""Base transformer interface with inference-backend request enhancement.""" from __future__ import annotations @@ -9,13 +9,13 @@ class BaseTransformer(ABC): """Abstract base class for API transformers. - Transforms requests from source API format to OpenAI format (for SGLang), - and transforms responses back to source API format. + Transforms requests from source API format to OpenAI format (for the + inference backend), and transforms responses back to source API format. """ @abstractmethod def transform_request(self, body: dict[str, Any]) -> dict[str, Any]: - """Transform request body to OpenAI/SGLang format.""" + """Transform request body to OpenAI format for the inference backend.""" pass @abstractmethod @@ -46,13 +46,29 @@ def create_stream_state(self, original_request: dict[str, Any]) -> Any | None: return None @staticmethod - def _is_qwen_model(model_name: str | None) -> bool: + def _is_qwen35_model(model_name: str | None) -> bool: if not model_name: return False - return "qwen" in model_name.lower() + return "qwen3.5" in model_name.lower() @staticmethod - def _merge_developer_role(request: dict[str, Any]) -> dict[str, Any]: + def _content_to_text(content: Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for block in content: + if isinstance(block, str): + parts.append(block) + elif isinstance(block, dict): + text = block.get("text") + if isinstance(text, str): + parts.append(text) + return "\n".join(parts) + return str(content) if content else "" + + @classmethod + def _merge_developer_role(cls, request: dict[str, Any]) -> dict[str, Any]: """Rename 'developer' role to 'system' and merge all system messages into one.""" messages = request.get("messages") if not isinstance(messages, list): @@ -69,34 +85,39 @@ def _merge_developer_role(request: dict[str, Any]) -> dict[str, Any]: non_system: list[Any] = [] for msg in normalized: if isinstance(msg, dict) and msg.get("role") == "system": - content = msg.get("content", "") - text = content if isinstance(content, str) else str(content) if content else "" + text = cls._content_to_text(msg.get("content", "")) if text: system_parts.append(text) else: non_system.append(msg) - if len(system_parts) > 1: - request["messages"] = [{"role": "system", "content": "\n\n".join(system_parts)}, *non_system] + if system_parts: + request["messages"] = [ + {"role": "system", "content": "\n\n".join(system_parts)}, + *non_system, + ] else: - request["messages"] = normalized + request["messages"] = non_system return request - def _enhance_for_training( + def _normalize_request( self, request: dict[str, Any], model_name: str | None = None, ) -> dict[str, Any]: - """Apply model compatibility fixes and request fields needed for training.""" + """Normalize the OpenAI request: drop internal keys, merge system roles, + and apply per-model template fixes. Training-signal params (logprobs, + token ids) are added later by the inference engine. + """ request.pop("_polar_model_served", None) - if self._is_qwen_model(model_name): - # Qwen chat templates do not support the developer role and need - # to be thinking disabled. - request = self._merge_developer_role(request) + request = self._merge_developer_role(request) + + if self._is_qwen35_model(model_name): + # Qwen3.5 outputs tool calls inside thinking; disable thinking. + # https://www.reddit.com/r/LocalLLaMA/comments/1sccqt2/i_think_i_got_solutions_for_qwen_35_tool_call_in/ chat_template_kwargs = dict(request.get("chat_template_kwargs") or {}) chat_template_kwargs.setdefault("enable_thinking", False) request["chat_template_kwargs"] = chat_template_kwargs - request["logprobs"] = True return request diff --git a/src/polar/gateway/transform/google.py b/src/polar/gateway/transform/google.py index 8da2bdd57..1a2d66ffa 100644 --- a/src/polar/gateway/transform/google.py +++ b/src/polar/gateway/transform/google.py @@ -12,6 +12,10 @@ google_part_to_openai_chat, openai_chat_content_to_google_parts, ) +from polar.gateway.transform.reasoning import ( + extract_reasoning_from_gemini_parts, + make_signature, +) @dataclass @@ -54,6 +58,15 @@ def process_chunk( self._usage = usage parts: list[dict[str, Any]] = [] + reasoning = delta.get("reasoning_content") + if isinstance(reasoning, str) and reasoning: + parts.append( + { + "thought": True, + "text": reasoning, + "thoughtSignature": make_signature(reasoning), + } + ) content = delta.get("content") if isinstance(content, str) and content: parts.append({"text": content}) @@ -161,12 +174,13 @@ def finalize(self) -> list[dict[str, Any]]: class GoogleTransformer(BaseTransformer): """Transform between Google Generative AI and OpenAI API formats.""" - ROLE_MAP = {"user": "user", "model": "assistant"} + ROLE_MAP = {"user": "user", "model": "assistant", "system": "system", "developer": "system"} FINISH_REASON_MAP_REVERSE = { "stop": "STOP", "length": "MAX_TOKENS", "content_filter": "SAFETY", "tool_calls": "STOP", + "stop_sequence": "STOP", } def transform_request(self, body: dict[str, Any]) -> dict[str, Any]: @@ -174,18 +188,22 @@ def transform_request(self, body: dict[str, Any]) -> dict[str, Any]: config = body.get("config") config_section = config if isinstance(config, dict) else {} - system_instruction = body.get("systemInstruction") or config_section.get( - "systemInstruction" + system_instruction = ( + body.get("systemInstruction") + or body.get("system_instruction") + or config_section.get("systemInstruction") + or config_section.get("system_instruction") ) - if isinstance(system_instruction, dict): - system_text = self._extract_text_from_parts(system_instruction.get("parts", [])) - if system_text: - messages.append({"role": "system", "content": system_text}) + system_text = self._extract_system_instruction_text(system_instruction) + if system_text: + messages.append({"role": "system", "content": system_text}) for content in body.get("contents", []): messages.extend(self._convert_content_to_messages(content)) result: dict[str, Any] = {"messages": messages} + if "model" in body: + result["model"] = body["model"] gen_config: dict[str, Any] = {} for source in ( @@ -203,12 +221,35 @@ def transform_request(self, body: dict[str, Any]) -> dict[str, Any]: result["temperature"] = gen_config["temperature"] if "topP" in gen_config: result["top_p"] = gen_config["topP"] + if "topK" in gen_config: + result["top_k"] = gen_config["topK"] if "stopSequences" in gen_config: result["stop"] = gen_config["stopSequences"] + if "candidateCount" in gen_config: + result["n"] = gen_config["candidateCount"] + if "presencePenalty" in gen_config: + result["presence_penalty"] = gen_config["presencePenalty"] + if "frequencyPenalty" in gen_config: + result["frequency_penalty"] = gen_config["frequencyPenalty"] + if "seed" in gen_config: + result["seed"] = gen_config["seed"] + if "logprobs" in gen_config: + result["top_logprobs"] = gen_config["logprobs"] + + response_format = self._convert_response_format(gen_config) + if response_format is not None: + result["response_format"] = response_format if body.get("_streaming", False): result["stream"] = True + # Gemini `thinkingConfig.includeThoughts: true` → enable_thinking. + thinking_cfg = gen_config.get("thinkingConfig") or config_section.get("thinkingConfig") + if isinstance(thinking_cfg, dict) and thinking_cfg.get("includeThoughts"): + chat_template_kwargs = dict(result.get("chat_template_kwargs") or {}) + chat_template_kwargs["enable_thinking"] = True + result["chat_template_kwargs"] = chat_template_kwargs + # SGLang rejects tool_choice without a non-empty tools list; bind # the pair so one can't be forwarded without the other. tools = self._convert_tools( @@ -222,11 +263,72 @@ def transform_request(self, body: dict[str, Any]) -> dict[str, Any]: if tool_choice is not None: result["tool_choice"] = tool_choice - return self._enhance_for_training( + return self._normalize_request( result, body.get("_polar_model_served"), ) + def _convert_response_format(self, gen_config: dict[str, Any]) -> dict[str, Any] | None: + response_format_cfg = gen_config.get("responseFormat") + if isinstance(response_format_cfg, dict): + text_format = response_format_cfg.get("text") + if isinstance(text_format, dict): + mime_type = text_format.get("mimeType") + if self._is_json_mime_type(mime_type): + schema = text_format.get("schema") + return self._chat_response_format_from_schema(schema) + + mime_type = gen_config.get("responseMimeType") + if not self._is_json_mime_type(mime_type): + return None + + schema = ( + gen_config.get("responseJsonSchema") + or gen_config.get("_responseJsonSchema") + or gen_config.get("responseSchema") + ) + return self._chat_response_format_from_schema(schema) + + @staticmethod + def _is_json_mime_type(value: Any) -> bool: + if not isinstance(value, str): + return False + return value.lower() == "application/json" or value.upper() == "APPLICATION_JSON" + + def _chat_response_format_from_schema(self, schema: Any) -> dict[str, Any]: + if isinstance(schema, dict) and schema: + return { + "type": "json_schema", + "json_schema": { + "name": "google_response", + "schema": self._normalize_google_schema(schema), + }, + } + return {"type": "json_object"} + + def _normalize_google_schema(self, schema: dict[str, Any]) -> dict[str, Any]: + normalized: dict[str, Any] = {} + for key, value in schema.items(): + if key == "type" and isinstance(value, str): + normalized[key] = value.lower() + elif key == "properties" and isinstance(value, dict): + normalized[key] = { + prop_name: self._normalize_google_schema(prop_schema) + if isinstance(prop_schema, dict) + else prop_schema + for prop_name, prop_schema in value.items() + } + elif key == "items" and isinstance(value, dict): + normalized[key] = self._normalize_google_schema(value) + elif key in {"anyOf", "oneOf", "allOf"} and isinstance(value, list): + normalized[key] = [ + self._normalize_google_schema(item) if isinstance(item, dict) else item + for item in value + ] + else: + normalized[key] = value + return normalized + def transform_response( self, response: dict[str, Any], @@ -238,6 +340,15 @@ def transform_response( for i, choice in enumerate(response.get("choices", [])): message = choice.get("message", {}) parts = [] + reasoning = message.get("reasoning_content") + if isinstance(reasoning, str) and reasoning: + parts.append( + { + "thought": True, + "text": reasoning, + "thoughtSignature": make_signature(reasoning), + } + ) content = message.get("content") if content or isinstance(content, list): parts.extend(openai_chat_content_to_google_parts(content)) @@ -256,13 +367,17 @@ def transform_response( ) usage = response.get("usage", {}) + usage_metadata = { + "promptTokenCount": usage.get("prompt_tokens", 0), + "candidatesTokenCount": usage.get("completion_tokens", 0), + "totalTokenCount": usage.get("total_tokens", 0), + } + cached_tokens = self._cached_prompt_tokens(usage) + if cached_tokens: + usage_metadata["cachedContentTokenCount"] = cached_tokens result = { "candidates": candidates, - "usageMetadata": { - "promptTokenCount": usage.get("prompt_tokens", 0), - "candidatesTokenCount": usage.get("completion_tokens", 0), - "totalTokenCount": usage.get("total_tokens", 0), - }, + "usageMetadata": usage_metadata, } function_calls = self._response_function_calls(candidates) if function_calls: @@ -281,6 +396,15 @@ def transform_stream_chunk( for choice in chunk.get("choices", []): delta = choice.get("delta", {}) or {} parts = [] + reasoning_chunk = delta.get("reasoning_content") + if isinstance(reasoning_chunk, str) and reasoning_chunk: + parts.append( + { + "thought": True, + "text": reasoning_chunk, + "thoughtSignature": make_signature(reasoning_chunk), + } + ) content = delta.get("content") if content: parts.append({"text": content}) @@ -379,14 +503,21 @@ def _convert_content_to_messages(self, content: Any) -> list[dict[str, Any]]: if openai_role == "assistant": message_content = google_content_parts_to_openai_chat(parts) - if message_content or tool_calls: + reasoning_text = extract_reasoning_from_gemini_parts(parts) + if message_content or tool_calls or reasoning_text: assistant_message: dict[str, Any] = { "role": "assistant", "content": message_content, } + if reasoning_text: + assistant_message["reasoning_content"] = reasoning_text if tool_calls: assistant_message["tool_calls"] = tool_calls messages.append(assistant_message) + elif openai_role == "system": + system_text = self._extract_text_from_parts(parts) + if system_text: + messages.append({"role": "system", "content": system_text}) else: if user_parts: messages.append({ @@ -438,9 +569,9 @@ def _convert_tool_choice(self, tool_config: Any) -> Any | None: return None mode = str(function_calling_config.get("mode", "")).upper() - allowed_names = function_calling_config.get("allowedFunctionNames") or function_calling_config.get( - "allowed_function_names" - ) + allowed_names = function_calling_config.get( + "allowedFunctionNames" + ) or function_calling_config.get("allowed_function_names") if mode == "NONE": return "none" if mode in {"ANY", "VALIDATED"}: @@ -549,3 +680,21 @@ def _extract_text_from_parts(self, parts: list) -> str: elif isinstance(part, str): texts.append(part) return "\n".join(texts) + + def _extract_system_instruction_text(self, system_instruction: Any) -> str: + if isinstance(system_instruction, str): + return system_instruction + if isinstance(system_instruction, dict): + return self._extract_text_from_parts(system_instruction.get("parts", [])) + if isinstance(system_instruction, list): + return self._extract_text_from_parts(system_instruction) + return "" + + def _cached_prompt_tokens(self, usage: dict[str, Any]) -> int: + details = usage.get("prompt_tokens_details") + if isinstance(details, dict): + cached = details.get("cached_tokens") + if isinstance(cached, int): + return cached + cached = usage.get("cached_tokens") + return cached if isinstance(cached, int) else 0 diff --git a/src/polar/gateway/transform/images.py b/src/polar/gateway/transform/images.py index 5c7848dec..87675194f 100644 --- a/src/polar/gateway/transform/images.py +++ b/src/polar/gateway/transform/images.py @@ -28,9 +28,15 @@ def parse_data_url(url: str) -> tuple[str, str] | None: return mime_type, match.group("data") +# OpenAI's image_url.detail only accepts these values; vLLM rejects anything +# else. Harnesses send their own (e.g. codex's "original"), so drop unknowns +# rather than forward a value that 400s the whole image request. +_VALID_IMAGE_DETAILS = frozenset({"auto", "low", "high"}) + + def openai_image_url_block(url: str, *, detail: Any = None) -> dict[str, Any]: image_url: dict[str, Any] = {"url": url} - if isinstance(detail, str) and detail: + if isinstance(detail, str) and detail in _VALID_IMAGE_DETAILS: image_url["detail"] = detail return {"type": "image_url", "image_url": image_url} @@ -140,10 +146,42 @@ def anthropic_content_to_openai_chat(content: Any) -> str | list[dict[str, Any]] image = anthropic_image_to_openai_chat(block) if image: parts.append(image) + continue + if block_type == "document": + text = anthropic_document_to_text(block) + if text: + parts.append(openai_text_block(text)) return openai_content_from_text_and_images(parts) +def anthropic_document_to_text(block: dict[str, Any]) -> str: + """Extract text from an Anthropic `document` block. + + Handles `source.type == "text"` and `source.type == "content"`. Base64 + PDFs are dropped — SGLang can't render binary docs through the chat + template. + """ + source = block.get("source") + if not isinstance(source, dict): + return "" + source_type = source.get("type") + if source_type == "text": + data = source.get("data") + return data if isinstance(data, str) else "" + if source_type == "content": + inner = source.get("content") + if isinstance(inner, list): + pieces: list[str] = [] + for inner_block in inner: + if isinstance(inner_block, dict) and inner_block.get("type") == "text": + text = inner_block.get("text") + if isinstance(text, str): + pieces.append(text) + return "\n".join(pieces) + return "" + + def anthropic_image_to_openai_chat(block: dict[str, Any]) -> dict[str, Any] | None: source = block.get("source") if not isinstance(source, dict): @@ -174,6 +212,10 @@ def google_content_parts_to_openai_chat(parts: Any) -> str | list[dict[str, Any] if not isinstance(part, dict): continue + # Thought parts are reasoning_content, not user-visible content. + if part.get("thought") is True: + continue + text = part.get("text") if isinstance(text, str): openai_parts.append(openai_text_block(text)) diff --git a/src/polar/gateway/transform/openai_chat.py b/src/polar/gateway/transform/openai_chat.py index a3eff12e6..b3ada9089 100644 --- a/src/polar/gateway/transform/openai_chat.py +++ b/src/polar/gateway/transform/openai_chat.py @@ -12,7 +12,9 @@ class OpenAIChatTransformer(BaseTransformer): def transform_request(self, body: dict[str, Any]) -> dict[str, Any]: result = body.copy() - return self._enhance_for_training( + if "max_tokens" not in result and "max_completion_tokens" in result: + result["max_tokens"] = result["max_completion_tokens"] + return self._normalize_request( result, body.get("_polar_model_served"), ) diff --git a/src/polar/gateway/transform/openai_responses.py b/src/polar/gateway/transform/openai_responses.py index 84bdebdf4..434473be4 100644 --- a/src/polar/gateway/transform/openai_responses.py +++ b/src/polar/gateway/transform/openai_responses.py @@ -14,6 +14,10 @@ from polar.gateway.transform.base import BaseTransformer from polar.gateway.transform.images import openai_responses_input_content_to_chat +from polar.gateway.transform.reasoning import ( + encrypt_reasoning, + extract_reasoning_from_responses_item, +) @dataclass @@ -33,9 +37,14 @@ def __init__(self, model: str): self.model = model self.text_started = False self.text_content = "" + self.message_output_index = 0 self.output_index_offset = 0 self.tool_calls: dict[int, _ResponsesToolCallState] = {} self.usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + self.reasoning_started = False + self.reasoning_closed = False + self.reasoning_content = "" + self.reasoning_id = "" self.completed = False def process_chunk(self, chunk: dict[str, Any], is_first: bool = False) -> list[dict[str, Any]]: @@ -71,16 +80,59 @@ def process_chunk(self, chunk: dict[str, Any], is_first: bool = False) -> list[d choice = choices[0] delta = choice.get("delta", {}) or {} + # Reasoning item must come first so the harness sees the chain-of-thought + # before any output_text or function_call items. + reasoning_delta = delta.get("reasoning_content") + if isinstance(reasoning_delta, str) and reasoning_delta: + if not self.reasoning_started: + self.reasoning_started = True + self.reasoning_id = f"rs_{uuid.uuid4().hex[:24]}" + events.append( + { + "type": "response.output_item.added", + "output_index": 0, + "item": { + "type": "reasoning", + "id": self.reasoning_id, + "summary": [], + "content": [], + "status": "in_progress", + }, + } + ) + events.append( + { + "type": "response.reasoning_summary_part.added", + "item_id": self.reasoning_id, + "output_index": 0, + "summary_index": 0, + "part": {"type": "summary_text", "text": ""}, + } + ) + self.reasoning_content += reasoning_delta + events.append( + { + "type": "response.reasoning_summary_text.delta", + "item_id": self.reasoning_id, + "output_index": 0, + "summary_index": 0, + "delta": reasoning_delta, + } + ) + content = delta.get("content") if content: + # Close reasoning before opening message. + events.extend(self._close_reasoning()) if not self.text_started: self.text_started = True - self.output_index_offset = 1 + self.message_output_index = 1 if self.reasoning_started else 0 + self.output_index_offset = self.message_output_index + 1 message_id = f"msg_{uuid.uuid4().hex[:24]}" events.append( { "type": "response.output_item.added", - "output_index": 0, + "output_index": self.message_output_index, "item": { "type": "message", "id": message_id, @@ -93,7 +145,7 @@ def process_chunk(self, chunk: dict[str, Any], is_first: bool = False) -> list[d events.append( { "type": "response.content_part.added", - "output_index": 0, + "output_index": self.message_output_index, "content_index": 0, "part": {"type": "output_text", "text": ""}, } @@ -103,7 +155,7 @@ def process_chunk(self, chunk: dict[str, Any], is_first: bool = False) -> list[d events.append( { "type": "response.output_text.delta", - "output_index": 0, + "output_index": self.message_output_index, "content_index": 0, "delta": content, } @@ -113,6 +165,12 @@ def process_chunk(self, chunk: dict[str, Any], is_first: bool = False) -> list[d if not isinstance(tool_calls_delta, list): tool_calls_delta = [tool_calls_delta] + if tool_calls_delta and self.reasoning_started and not self.reasoning_closed: + events.extend(self._close_reasoning()) + if not self.text_started: + # No text — tools come immediately after reasoning. + self.output_index_offset = 1 + for tool_call in tool_calls_delta: if not isinstance(tool_call, dict): continue @@ -188,11 +246,14 @@ def finalize(self) -> list[dict[str, Any]]: events: list[dict[str, Any]] = [] + # Close reasoning if it never got closed by content/tools. + events.extend(self._close_reasoning()) + if self.text_started: events.append( { "type": "response.content_part.done", - "output_index": 0, + "output_index": self.message_output_index, "content_index": 0, "part": {"type": "output_text", "text": self.text_content}, } @@ -200,7 +261,7 @@ def finalize(self) -> list[dict[str, Any]]: events.append( { "type": "response.output_item.done", - "output_index": 0, + "output_index": self.message_output_index, "item": { "type": "message", "role": "assistant", @@ -239,6 +300,17 @@ def finalize(self) -> list[dict[str, Any]]: ) output: list[dict[str, Any]] = [] + if self.reasoning_started: + output.append( + { + "type": "reasoning", + "id": self.reasoning_id, + "summary": [{"type": "summary_text", "text": self.reasoning_content}], + "content": [{"type": "reasoning_text", "text": self.reasoning_content}], + "encrypted_content": encrypt_reasoning(self.reasoning_content), + "status": "completed", + } + ) if self.text_started: output.append( { @@ -279,6 +351,39 @@ def finalize(self) -> list[dict[str, Any]]: self.completed = True return events + def _close_reasoning(self) -> list[dict[str, Any]]: + if not self.reasoning_started or self.reasoning_closed: + return [] + self.reasoning_closed = True + return [ + { + "type": "response.reasoning_summary_text.done", + "item_id": self.reasoning_id, + "output_index": 0, + "summary_index": 0, + "text": self.reasoning_content, + }, + { + "type": "response.reasoning_summary_part.done", + "item_id": self.reasoning_id, + "output_index": 0, + "summary_index": 0, + "part": {"type": "summary_text", "text": self.reasoning_content}, + }, + { + "type": "response.output_item.done", + "output_index": 0, + "item": { + "type": "reasoning", + "id": self.reasoning_id, + "summary": [{"type": "summary_text", "text": self.reasoning_content}], + "content": [{"type": "reasoning_text", "text": self.reasoning_content}], + "encrypted_content": encrypt_reasoning(self.reasoning_content), + "status": "completed", + }, + }, + ] + class OpenAIResponsesTransformer(BaseTransformer): """Transform OpenAI Responses API to/from SGLang chat completions.""" @@ -297,6 +402,8 @@ def transform_request(self, body: dict[str, Any]) -> dict[str, Any]: messages.extend(self._convert_input_items_to_messages(input_data)) result: dict[str, Any] = {"messages": messages} + if "model" in body: + result["model"] = body["model"] if "max_tokens" in body: result["max_tokens"] = body["max_tokens"] @@ -306,22 +413,72 @@ def transform_request(self, body: dict[str, Any]) -> dict[str, Any]: result["temperature"] = body["temperature"] if "top_p" in body: result["top_p"] = body["top_p"] + if "top_logprobs" in body: + result["top_logprobs"] = body["top_logprobs"] + if "parallel_tool_calls" in body: + result["parallel_tool_calls"] = body["parallel_tool_calls"] if "stream" in body: result["stream"] = body["stream"] + text_cfg = body.get("text") + if isinstance(text_cfg, dict): + response_format = self._response_format_from_text_config(text_cfg) + if response_format is not None: + result["response_format"] = response_format + + # Responses `reasoning` request param → enable_thinking. + reasoning_cfg = body.get("reasoning") + if isinstance(reasoning_cfg, dict) and self._reasoning_config_enables_thinking( + reasoning_cfg + ): + chat_template_kwargs = dict(result.get("chat_template_kwargs") or {}) + chat_template_kwargs["enable_thinking"] = True + result["chat_template_kwargs"] = chat_template_kwargs + # SGLang rejects tool_choice without a non-empty tools list; bind # the pair so one can't be forwarded without the other. tools = self._convert_tools(body.get("tools", [])) if tools: result["tools"] = tools if "tool_choice" in body: - result["tool_choice"] = body["tool_choice"] + result["tool_choice"] = self._tool_choice_to_openai_chat( + body["tool_choice"] + ) - return self._enhance_for_training( + return self._normalize_request( result, body.get("_polar_model_served"), ) + def _response_format_from_text_config( + self, + text_cfg: dict[str, Any], + ) -> dict[str, Any] | None: + format_cfg = text_cfg.get("format") + if not isinstance(format_cfg, dict): + return None + + format_type = format_cfg.get("type") + if format_type == "text": + return None + if format_type == "json_object": + return {"type": "json_object"} + if format_type != "json_schema": + return None + + json_schema = format_cfg.get("json_schema") + if isinstance(json_schema, dict): + return {"type": "json_schema", "json_schema": json_schema} + + converted = { + key: format_cfg[key] + for key in ("name", "description", "schema", "strict") + if key in format_cfg + } + if not converted: + return None + return {"type": "json_schema", "json_schema": converted} + def transform_response( self, response: dict[str, Any], @@ -336,6 +493,19 @@ def transform_response( output_items: list[dict[str, Any]] = [] + reasoning = message.get("reasoning_content") + if isinstance(reasoning, str) and reasoning: + output_items.append( + { + "type": "reasoning", + "id": f"rs_{uuid.uuid4().hex[:24]}", + "summary": [{"type": "summary_text", "text": reasoning}], + "content": [{"type": "reasoning_text", "text": reasoning}], + "encrypted_content": encrypt_reasoning(reasoning), + "status": "completed", + } + ) + content = message.get("content") if content: output_items.append( @@ -351,14 +521,7 @@ def transform_response( func = tc.get("function", {}) name = func.get("name", "") if name in ("shell", "execute", "run_command"): - output_items.append( - { - "type": "local_shell_call", - "call_id": tc.get("id", ""), - "status": "completed", - "action": {"type": "execute", "command": func.get("arguments", "{}")}, - } - ) + output_items.append(self._local_shell_call_from_tool_call(tc)) else: output_items.append( { @@ -372,6 +535,14 @@ def transform_response( ) usage = response.get("usage", {}) + response_usage = { + "input_tokens": usage.get("prompt_tokens", 0), + "output_tokens": usage.get("completion_tokens", 0), + "total_tokens": usage.get("total_tokens", 0), + } + cached_tokens = self._cached_prompt_tokens(usage) + if cached_tokens: + response_usage["input_tokens_details"] = {"cached_tokens": cached_tokens} return { "id": response.get("id", f"resp_{uuid.uuid4().hex}"), "object": "response", @@ -379,11 +550,7 @@ def transform_response( "status": "completed", "model": original_request.get("model", response.get("model", "unknown")), "output": output_items, - "usage": { - "input_tokens": usage.get("prompt_tokens", 0), - "output_tokens": usage.get("completion_tokens", 0), - "total_tokens": usage.get("total_tokens", 0), - }, + "usage": response_usage, } def create_stream_state(self, original_request: dict[str, Any]) -> ResponsesStreamState: @@ -413,17 +580,48 @@ def _convert_input_items_to_messages( pending_tool_calls: list[dict[str, Any]] = [] pending_tool_outputs: list[dict[str, Any]] = [] pending_input_content: list[dict[str, Any]] = [] + pending_reasoning: str = "" for item in items: item_type = item.get("type") + if item_type == "reasoning": + # A new reasoning item starts a new turn block. If the prior + # block already has its function_call_output, flush it now so + # this reasoning attaches to the NEXT function_call, not the + # previous one. (Otherwise codex's per-fc reasoning gets + # accumulated and dumped onto the wrong assistant message, + # breaking the prefix_merging chain.) + if pending_tool_outputs: + messages.extend( + self._flush_tool_block( + pending_tool_calls, + pending_tool_outputs, + pending_reasoning, + ) + ) + pending_tool_calls = [] + pending_tool_outputs = [] + pending_reasoning = "" + reasoning_text = extract_reasoning_from_responses_item(item) + if reasoning_text: + pending_reasoning = ( + f"{pending_reasoning}\n{reasoning_text}" + if pending_reasoning + else reasoning_text + ) + continue + if item_type in {"input_text", "input_image"}: if pending_tool_calls or pending_tool_outputs: messages.extend( - self._flush_tool_block(pending_tool_calls, pending_tool_outputs) + self._flush_tool_block( + pending_tool_calls, pending_tool_outputs, pending_reasoning + ) ) pending_tool_calls = [] pending_tool_outputs = [] + pending_reasoning = "" pending_input_content.append(item) continue @@ -433,14 +631,21 @@ def _convert_input_items_to_messages( pending_input_content = [] if pending_tool_calls or pending_tool_outputs: messages.extend( - self._flush_tool_block(pending_tool_calls, pending_tool_outputs) + self._flush_tool_block( + pending_tool_calls, pending_tool_outputs, pending_reasoning + ) ) pending_tool_calls = [] pending_tool_outputs = [] + pending_reasoning = "" role = item.get("role", "user") content = openai_responses_input_content_to_chat(item.get("content", "")) - messages.append({"role": role, "content": content}) + msg: dict[str, Any] = {"role": role, "content": content} + if role == "assistant" and pending_reasoning: + msg["reasoning_content"] = pending_reasoning + pending_reasoning = "" + messages.append(msg) elif item_type == "function_call": if pending_input_content: @@ -451,10 +656,12 @@ def _convert_input_items_to_messages( self._flush_tool_block( pending_tool_calls, pending_tool_outputs, + pending_reasoning, ) ) pending_tool_calls = [] pending_tool_outputs = [] + pending_reasoning = "" pending_tool_calls.append( { "id": item.get("call_id", f"call_{uuid.uuid4().hex[:24]}"), @@ -466,12 +673,35 @@ def _convert_input_items_to_messages( } ) + elif item_type in {"local_shell_call", "shell_call"}: + if pending_input_content: + messages.extend(self._flush_input_content(pending_input_content)) + pending_input_content = [] + if pending_tool_outputs: + messages.extend( + self._flush_tool_block( + pending_tool_calls, + pending_tool_outputs, + pending_reasoning, + ) + ) + pending_tool_calls = [] + pending_tool_outputs = [] + pending_reasoning = "" + pending_tool_calls.append(self._local_shell_call_to_tool_call(item)) + elif item_type == "function_call_output": if pending_input_content: messages.extend(self._flush_input_content(pending_input_content)) pending_input_content = [] pending_tool_outputs.extend(self._function_call_output_messages(item)) + elif item_type in {"local_shell_call_output", "shell_call_output"}: + if pending_input_content: + messages.extend(self._flush_input_content(pending_input_content)) + pending_input_content = [] + pending_tool_outputs.extend(self._local_shell_output_messages(item)) + else: if pending_input_content: messages.extend(self._flush_input_content(pending_input_content)) @@ -481,10 +711,12 @@ def _convert_input_items_to_messages( self._flush_tool_block( pending_tool_calls, pending_tool_outputs, + pending_reasoning, ) ) pending_tool_calls = [] pending_tool_outputs = [] + pending_reasoning = "" converted = self._convert_response_item_to_message(item) if isinstance(converted, list): messages.extend(converted) @@ -495,7 +727,22 @@ def _convert_input_items_to_messages( messages.extend(self._flush_input_content(pending_input_content)) if pending_tool_calls or pending_tool_outputs: - messages.extend(self._flush_tool_block(pending_tool_calls, pending_tool_outputs)) + messages.extend( + self._flush_tool_block( + pending_tool_calls, pending_tool_outputs, pending_reasoning + ) + ) + pending_reasoning = "" + + # Trailing reasoning with no following assistant message: synthesize one. + if pending_reasoning: + messages.append( + { + "role": "assistant", + "content": None, + "reasoning_content": pending_reasoning, + } + ) return messages @@ -515,6 +762,94 @@ def _function_call_output_messages(self, item: dict[str, Any]) -> list[dict[str, messages.append({"role": "user", "content": image_parts}) return messages + def _local_shell_call_to_tool_call(self, item: dict[str, Any]) -> dict[str, Any]: + return { + "id": item.get("call_id") or item.get("id") or f"call_{uuid.uuid4().hex[:24]}", + "type": "function", + "function": { + "name": "shell", + "arguments": self._local_shell_action_to_arguments(item.get("action")), + }, + } + + def _local_shell_output_messages(self, item: dict[str, Any]) -> list[dict[str, Any]]: + call_id = item.get("call_id") or item.get("id") or "" + return self._function_call_output_messages( + {"call_id": call_id, "output": item.get("output", "")} + ) + + def _local_shell_action_to_arguments(self, action: Any) -> str: + if isinstance(action, str): + return action + if not isinstance(action, dict): + return "{}" + + command = action.get("command") + if isinstance(command, str): + stripped = command.strip() + if stripped.startswith(("{", "[")): + try: + json.loads(stripped) + return stripped + except json.JSONDecodeError: + pass + return json.dumps({"cmd": command}) + + commands = action.get("commands") + if isinstance(commands, list): + command_values = [cmd for cmd in commands if isinstance(cmd, str)] + args: dict[str, Any] + if len(command_values) == 1: + args = {"cmd": command_values[0]} + else: + args = {"commands": command_values} + for key in ("timeout_ms", "max_output_length"): + if key in action: + args[key] = action[key] + return json.dumps(args) + + args = {key: value for key, value in action.items() if key != "type"} + return json.dumps(args) if args else "{}" + + def _local_shell_call_from_tool_call(self, tool_call: dict[str, Any]) -> dict[str, Any]: + function = tool_call.get("function", {}) + arguments = function.get("arguments", "{}") if isinstance(function, dict) else "{}" + call_id = tool_call.get("id", "") + return { + "type": "local_shell_call", + "id": f"lsh_{uuid.uuid4().hex[:24]}", + "call_id": call_id, + "status": "completed", + "action": self._local_shell_action_from_arguments(arguments), + } + + def _local_shell_action_from_arguments(self, arguments: Any) -> dict[str, Any]: + parsed: Any = None + if isinstance(arguments, str): + try: + parsed = json.loads(arguments) + except json.JSONDecodeError: + parsed = None + elif isinstance(arguments, dict): + parsed = arguments + + if isinstance(parsed, dict): + commands = parsed.get("commands") + if isinstance(commands, list): + action = {"commands": [cmd for cmd in commands if isinstance(cmd, str)]} + else: + command = parsed.get("cmd") or parsed.get("command") + action = {"commands": [command]} if isinstance(command, str) else {} + for key in ("timeout_ms", "max_output_length"): + if key in parsed: + action[key] = parsed[key] + if action.get("commands"): + return action + + if isinstance(arguments, str) and arguments: + return {"commands": [arguments]} + return {"commands": []} + def _function_call_output_content(self, output: Any) -> Any: if isinstance(output, dict): if self._is_responses_content_block(output): @@ -566,16 +901,18 @@ def _flush_tool_block( self, tool_calls: list[dict[str, Any]], tool_outputs: list[dict[str, Any]], + reasoning: str = "", ) -> list[dict[str, Any]]: messages: list[dict[str, Any]] = [] if tool_calls: - messages.append( - { - "role": "assistant", - "content": None, - "tool_calls": list(tool_calls), - } - ) + assistant_msg: dict[str, Any] = { + "role": "assistant", + "content": None, + "tool_calls": list(tool_calls), + } + if reasoning: + assistant_msg["reasoning_content"] = reasoning + messages.append(assistant_msg) messages.extend(tool_outputs) return messages @@ -614,6 +951,39 @@ def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: converted.append({"type": "function", "function": tool["function"]}) continue + tool_type = tool.get("type") + if tool_type in {"shell", "local_shell"}: + converted.append( + { + "type": "function", + "function": { + "name": "shell", + "description": tool.get( + "description", "Run shell commands in the local workspace." + ), + "parameters": { + "type": "object", + "properties": { + "cmd": {"type": "string"}, + "commands": { + "type": "array", + "items": {"type": "string"}, + }, + "timeout_ms": {"type": "number"}, + "max_output_length": {"type": "number"}, + }, + }, + }, + } + ) + continue + + # Drop server-side tool types Polar can't dispatch (web_search, + # file_search, computer_use, mcp, code_interpreter, image_generation, + # custom, etc.). Only client-side functions/shell are convertible. + if tool_type and tool_type != "function": + continue + name = tool.get("name") or tool.get("id", "") if not name: continue @@ -638,6 +1008,44 @@ def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: return converted + def _tool_choice_to_openai_chat(self, tool_choice: Any) -> Any: + if isinstance(tool_choice, str): + if tool_choice == "shell": + return {"type": "function", "function": {"name": "shell"}} + return tool_choice + + if not isinstance(tool_choice, dict): + return tool_choice + + choice_type = tool_choice.get("type") + if choice_type == "function": + function = tool_choice.get("function") + if isinstance(function, dict): + return tool_choice + name = tool_choice.get("name") + if isinstance(name, str) and name: + return {"type": "function", "function": {"name": name}} + if choice_type in {"shell", "local_shell"}: + return {"type": "function", "function": {"name": "shell"}} + return tool_choice + + def _reasoning_config_enables_thinking(self, reasoning_cfg: dict[str, Any]) -> bool: + if not reasoning_cfg: + return False + effort = reasoning_cfg.get("effort") + if isinstance(effort, str) and effort.lower() == "none": + return False + return True + + def _cached_prompt_tokens(self, usage: dict[str, Any]) -> int: + details = usage.get("prompt_tokens_details") + if isinstance(details, dict): + cached = details.get("cached_tokens") + if isinstance(cached, int): + return cached + cached = usage.get("cached_tokens") + return cached if isinstance(cached, int) else 0 + def _make_error_response(self, message: str) -> dict[str, Any]: return { "type": "response.failed", diff --git a/src/polar/gateway/transform/reasoning.py b/src/polar/gateway/transform/reasoning.py new file mode 100644 index 000000000..040d67af5 --- /dev/null +++ b/src/polar/gateway/transform/reasoning.py @@ -0,0 +1,110 @@ +"""Reasoning round-trip helpers shared across API transformers. + +SGLang's `--reasoning-parser` (split mode: `qwen3`, `minimax`, `deepseek-r1`, +etc.) splits the model's chain-of-thought into the assistant message's +`reasoning_content` field. This module helps each transform convert that +field to / from the API-specific reasoning shape: + +- Anthropic: `thinking` content block with `thinking` + `signature` +- Gemini: part with `thought: true`, `text`, `thoughtSignature` +- Responses: `reasoning` output item with `summary`, `content`, `encrypted_content` +- OAI Chat: `reasoning_content` field on the assistant message (passthrough) + +Signatures and `encrypted_content` only need to round-trip opaquely through +the harness (the gateway is the API server on both ends), so we use +deterministic synthetic tokens — no real cryptography is necessary. +""" + +from __future__ import annotations + +import base64 +import hashlib +from typing import Any + + +def make_signature(reasoning_text: str) -> str: + """Deterministic synthetic signature for an Anthropic/Gemini thought block.""" + if not reasoning_text: + return "" + digest = hashlib.sha256(reasoning_text.encode("utf-8")).digest() + return "sg_polar_" + base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=") + + +def encrypt_reasoning(reasoning_text: str) -> str: + """Pack reasoning into Responses-style `encrypted_content`. + + Base64-encoded so it survives transport. Decoded by `decrypt_reasoning` + when the harness replays it on the next turn. + """ + if not reasoning_text: + return "" + return "polar:" + base64.urlsafe_b64encode(reasoning_text.encode("utf-8")).decode("ascii") + + +def decrypt_reasoning(encrypted: str | None) -> str: + """Reverse of `encrypt_reasoning`. Returns empty string on any failure.""" + if not isinstance(encrypted, str) or not encrypted.startswith("polar:"): + return "" + try: + return base64.urlsafe_b64decode(encrypted[len("polar:") :].encode("ascii")).decode("utf-8") + except Exception: + return "" + + +def extract_reasoning_from_anthropic_content(content: Any) -> str: + """Extract reasoning_content from Anthropic assistant content blocks.""" + if not isinstance(content, list): + return "" + parts: list[str] = [] + for block in content: + if not isinstance(block, dict): + continue + if block.get("type") == "thinking": + text = block.get("thinking", "") + if isinstance(text, str) and text: + parts.append(text) + return "\n".join(parts) + + +def extract_reasoning_from_gemini_parts(parts: Any) -> str: + """Extract reasoning_content from Gemini content parts (thought:true).""" + if not isinstance(parts, list): + return "" + pieces: list[str] = [] + for part in parts: + if not isinstance(part, dict): + continue + if part.get("thought") is True: + text = part.get("text", "") + if isinstance(text, str) and text: + pieces.append(text) + return "\n".join(pieces) + + +def extract_reasoning_from_responses_item(item: dict[str, Any]) -> str: + """Extract reasoning_content text from a Responses `reasoning` input item. + + Prefers `content[*].text` (full chain), falls back to `summary[*].text`, + finally tries `encrypted_content` (decoded by `decrypt_reasoning`). + """ + content = item.get("content") + if isinstance(content, list): + chunks = [ + b.get("text", "") + for b in content + if isinstance(b, dict) and isinstance(b.get("text"), str) + ] + joined = "\n".join(c for c in chunks if c) + if joined: + return joined + summary = item.get("summary") + if isinstance(summary, list): + chunks = [ + b.get("text", "") + for b in summary + if isinstance(b, dict) and isinstance(b.get("text"), str) + ] + joined = "\n".join(c for c in chunks if c) + if joined: + return joined + return decrypt_reasoning(item.get("encrypted_content")) diff --git a/src/polar/platform/README.md b/src/polar/platform/README.md index 51300e125..a8426309e 100644 --- a/src/polar/platform/README.md +++ b/src/polar/platform/README.md @@ -1,35 +1,62 @@ # Polar Dashboard -A read-only observability dashboard for the Polar rollout stack. It bundles a -React SPA and exposes a single FastAPI service that proxies (read-only) to the -rollout and gateway processes. +A read-only **observability dashboard** for a running Polar stack: a React SPA +plus a single FastAPI service that proxies (read-only) to the rollout server and +gateway nodes and reads finished tasks straight off disk. -The launch command (`polar dashboard -c topology.yaml [--port 8090]`) is -documented in the [top-level README](../../../README.md#cli-interface); -defaults: bind `127.0.0.1:8090`, rollout URL and `save_dir` pulled from -topology. +Launch it with `polar dashboard -c topology.yaml [--port 8090]` (see the +[top-level README](../../../README.md#cli-interface)). Defaults: binds +`127.0.0.1:8090`; the rollout URL and `save_dir` come from the topology (override +with `--rollout-url` / `--save-dir`). + +## Mental model + +The dashboard reads from **two sources and merges them**: + +- **Live HTTP** to the rollout server and each gateway (current tasks, sessions, + completions) via `UpstreamClient`. +- **On disk** under `save_dir` (finished tasks/sessions) via `FsIndex`, which + scans `task_*/` every couple of seconds. + +For live updates it runs an **SSE fan-out**: `SseFanout` opens one `/events` +stream per upstream (rollout + each gateway), tags each event with its source, +and republishes onto an in-process `EventBus` that the browser consumes at +`GET /api/events`. Everything is read-only; the only state-changing call is the +cancel proxy. + +## Main files + +- `cli.py`: registers the `dashboard` subcommand and starts the server. +- `config.py`: resolves host/port/rollout-url/save-dir from the topology + CLI + overrides. +- `server.py`: builds the FastAPI app, starts the upstream clients + fs poller + + SSE fan-outs, and serves the bundled SPA. +- `api/`: the HTTP route handlers (`topology`, `tasks`, `sessions`, `events`) — + where the `/api/*` endpoints below actually live. +- `fs_index.py`: the on-disk index of `task_*` dirs (`TaskSummary` / + `SessionSummary`). +- `sse_fanout.py` + `events.py`: upstream `/events` subscription and the + in-process event bus. +- `upstream.py`: the async httpx client used to read rollout + gateways. ## Frontend build -Production build (required for the wheel to ship a real UI; missing `web/dist/` -falls back to a small JSON placeholder): +The wheel ships a real UI only if `web/dist/` is built; without it the service +serves a small JSON placeholder. -``` +```bash cd web && npm install && npm run build # writes web/dist/ ``` -Dev loop with hot reload — runs at and proxies -`/api/*` to `http://127.0.0.1:8090/`: +Dev loop with hot reload (runs at , proxies `/api/*` to +`:8090`): -``` +```bash cd web && npm install && npm run dev ``` ## API surface (under `/api`) -The dashboard is read-only. The only state-changing endpoint is the cancel -proxy, so a running session can be aborted from the Session detail page. - | Path | Method | Purpose | | --- | --- | --- | | `/api/health` | GET | Service health + upstream reachability | @@ -38,27 +65,25 @@ proxy, so a running session can be aborted from the Session detail page. | `/api/tasks/{id}` | GET | Single task + session summaries | | `/api/sessions/{id}` | GET | Session detail | | `/api/sessions/{id}/trajectory` | GET | Built trajectory traces | -| `/api/sessions/{id}/completions` | GET | Completion records (gateway then disk) | +| `/api/sessions/{id}/completions` | GET | Completion records (gateway, then disk) | | `/api/sessions/{id}/evaluation` | GET | Evaluator outcome / strategy | | `/api/sessions/{id}/raw` | GET | Raw on-disk session payload | | `/api/sessions/{id}` | DELETE | Cancel a running session | | `/api/events` | GET (SSE) | Fan-out of rollout + gateway events | -## Polar-side additions (read-only) +## Read-only endpoints it depends on -The dashboard depends on a small set of read-only endpoints added to the other -Polar services: +The dashboard relies on a few read-only endpoints in the other services: - Rollout: `GET /tasks`, `GET /tasks/{id}/sessions`, `GET /events` (SSE). - Gateway: `GET /sessions`, `GET /sessions/{id}/completions`, `GET /events` (SSE). -- Gateway: completion records persist to - `/task_/sessions//completions/-.json` - via the `CompletionWriter` background task. Controlled by - `gateway.completion_persistence` in topology.yaml. +- Gateway completion records persist to + `/task_/sessions//completions/-.json` via the + `CompletionWriter` background task, controlled by + `gateway.completion_persistence` in `topology.yaml`. ## Task submission -Submission stays in the existing channels — `polar submit`, the example -scripts under `examples//`, or any client that posts to the rollout -server's `POST /rollout/task/submit`. The dashboard surfaces tasks as soon as -they appear in the rollout's memory or in `/`. +Submission stays in the usual channels — `polar submit`, the example scripts, or +any client posting to `POST /rollout/task/submit`. The dashboard surfaces a task +as soon as it appears in the rollout's memory or in `/`. diff --git a/src/polar/platform/api/topology.py b/src/polar/platform/api/topology.py index 616825383..14afe6e7c 100644 --- a/src/polar/platform/api/topology.py +++ b/src/polar/platform/api/topology.py @@ -1,4 +1,4 @@ -"""Live topology view: rollout + gateways + SGLang + worker pools.""" +"""Live topology view: rollout + gateways + inference engine + worker pools.""" from __future__ import annotations @@ -17,7 +17,8 @@ def _node_view(node) -> dict[str, Any]: "port": node.port, "gateway_url": node.public_url, "model_served": node.model_served, - "sglang_base_url": node.sglang_base_url, + "engine": node.engine, + "inference_base_url": node.inference_base_url, "max_init_workers": node.max_init_workers, "max_run_workers": node.max_run_workers, "max_postrun_workers": node.max_postrun_workers, diff --git a/src/polar/platform/fs_index.py b/src/polar/platform/fs_index.py index 188458279..92f3e57fb 100644 --- a/src/polar/platform/fs_index.py +++ b/src/polar/platform/fs_index.py @@ -29,6 +29,8 @@ class TaskSummary: completed_sessions: int = 0 errored_sessions: int = 0 mean_reward: float | None = None + mean_traces: float | None = None + mean_completions: float | None = None created_at: float | None = None updated_at: float | None = None save_dir_path: str = "" @@ -45,6 +47,8 @@ def to_dict(self) -> dict[str, Any]: "completed_sessions": self.completed_sessions, "errored_sessions": self.errored_sessions, "mean_reward": self.mean_reward, + "mean_traces": self.mean_traces, + "mean_completions": self.mean_completions, "created_at": self.created_at, "updated_at": self.updated_at, "save_dir_path": self.save_dir_path, @@ -184,6 +188,8 @@ def _scan_task_dir(task_dir: Path) -> TaskSummary | None: summary.num_samples = int(index.get("num_samples") or 0) summary.completed_sessions = int(index.get("completed_sessions") or 0) summary.mean_reward = index.get("mean_reward") + summary.mean_traces = index.get("mean_traces") + summary.mean_completions = index.get("mean_completions") if index.get("created_at") is not None: try: summary.created_at = float(index["created_at"]) @@ -198,6 +204,8 @@ def _scan_task_dir(task_dir: Path) -> TaskSummary | None: # No index file — fall back to scanning each session file. rewards: list[float] = [] + trace_counts: list[int] = [] + completion_counts: list[int] = [] completed = 0 errored = 0 earliest = summary.created_at @@ -216,6 +224,11 @@ def _scan_task_dir(task_dir: Path) -> TaskSummary | None: reward = _trace_reward(data) if reward is not None: rewards.append(reward) + traj = data.get("trajectory") or {} + traces = traj.get("traces") or [] + trace_counts.append(len(traces)) + record_count = (traj.get("metadata") or {}).get("record_count") + completion_counts.append(int(record_count) if isinstance(record_count, int) else len(traces)) try: mtime = session_path.stat().st_mtime ctime = session_path.stat().st_ctime @@ -236,6 +249,10 @@ def _scan_task_dir(task_dir: Path) -> TaskSummary | None: summary.completed_sessions = completed summary.errored_sessions = errored summary.mean_reward = sum(rewards) / len(rewards) if rewards else None + summary.mean_traces = sum(trace_counts) / len(trace_counts) if trace_counts else None + summary.mean_completions = ( + sum(completion_counts) / len(completion_counts) if completion_counts else None + ) summary.created_at = earliest summary.updated_at = latest if completed + errored == len(session_files) and session_files: diff --git a/src/polar/rollout/README.md b/src/polar/rollout/README.md index e264bf7b2..e544d92e5 100644 --- a/src/polar/rollout/README.md +++ b/src/polar/rollout/README.md @@ -1,71 +1,91 @@ # Rollout Service -`polar.rollout` owns task submission, gateway scheduling, callback collection, -status reporting, and result persistence. +`polar.rollout` is the **central orchestrator**. A client (a trainer, +`polar submit`, or an example script) posts a task here; the rollout server fans +it out into one session per sample, schedules each onto a healthy gateway node, +collects the terminal result, and optionally persists it. -## Main Files +## Mental model -- `server.py`: FastAPI app for task, callback, node, health, and status APIs. -- `manager.py`: task lifecycle and session expansion. -- `pipeline.py`: dispatch to gateways, wait for callbacks, poll fallback, and - result persistence. -- `balancer.py`: node health, pressure tracking, and scheduling. -- `models.py`: public task, session, node, heartbeat, and result models. -- `timer.py`: per-stage timing helpers. +A task becomes N sessions, each placed on a gateway: + +1. A `TaskRequest` is submitted — `POST /rollout/task/submit` returns + immediately with `{task_id, status: "running"}`; you poll or get a callback. +2. The manager creates one session per `num_samples`. +3. The scheduler picks a healthy, non-draining gateway with spare capacity. +4. The pipeline dispatches the session and waits for the gateway's callback, + **interleaving a poll** of the gateway as a live safety net (covers a dropped + callback, or a status flip before the payload is serialized). +5. The terminal `SessionResult` is recorded; if `TaskRequest.callback_url` is + set, the aggregate `TaskResult` is POSTed back to the trainer. +6. Results persist under `rollout.save_dir` when configured + (`save_dir/task_/ses_.json`). + +The `timeout_seconds` budget starts when a session enters **INIT**, not at +dispatch — time spent queued on a busy gateway (`REGISTERED`) isn't charged +against the agent's wall-clock. -## Lifecycle +## Main files -1. `TaskRequest` is accepted by the rollout server. -2. The manager creates one session per requested sample. -3. The scheduler picks a healthy non-draining gateway with available capacity. -4. The pipeline dispatches the session and waits for a callback. -5. Missing callbacks can fall back to gateway polling during the grace window. -6. Results are persisted when `rollout.save_dir` is configured. +- `server.py`: FastAPI app — task submit/get/status, node register/heartbeat/ + list/drain, the session-result callback, the `/tasks` + `/events` + observability routes, and `/health`. +- `manager.py`: task lifecycle, session expansion, and the trainer-facing + terminal callback. +- `pipeline.py`: dispatch to a gateway, wait (callback + interleaved poll), + persist. +- `balancer.py`: `NodeScheduler` — registration, heartbeat health, pressure-based + scheduling, draining. +- `models.py`: the public `TaskRequest`, `SessionResult`, `TaskResult`, + `TaskStatus`, and node/heartbeat models. +- `timer.py`: per-stage timing helpers. + +## HTTP endpoints -## Task Request Shape +| Method | Path | Purpose | +|---|---|---| +| POST | `/rollout/task/submit` | Submit a task (returns immediately) | +| GET | `/rollout/task/{task_id}` | Task status / progress | +| GET | `/rollout/status` | Service + fleet summary | +| POST | `/nodes/register`, `/nodes/{id}/heartbeat` | Gateway registration + heartbeat | +| GET / DELETE | `/nodes`, `/nodes/{id}` | List / inspect / drain nodes | +| POST | `/callbacks/session_result` | Gateway → rollout result callback | +| GET | `/tasks`, `/tasks/{id}/sessions`, `/events` | Observability (used by the dashboard) | +| GET | `/health` | Health check | -A minimal task looks like: +## Task request shape ```json { "task_id": "example-task-001", - "instruction": "Write a calculator and save it as calculator.py", + "instruction": "Implement calculator.py and make the tests pass.", "num_samples": 8, "timeout_seconds": 900, - "runtime": { - "backend": "docker", - "image": "polar-localhost-calculator:latest", - "workdir": "/polar/session/workspace", - "network": "host" - }, - "agent": { - "harness": "codex", - "model_name": "openai/gpt-5.4" - }, - "builder": { - "strategy": "prefix_merging" - }, - "evaluator": { - "strategy": "test_on_output", - "config": { - "test_command": "python3 test_calculator.py && echo 'PASSED test_calculator'", - "expected_output_json": {"test_calculator": "PASSED"} - } - } + "runtime": {"backend": "docker", "image": "...", "workdir": "/polar/session/workspace", "network": "host"}, + "agent": {"harness": "codex", "model_name": "openai/gpt-5.4"}, + "builder": {"strategy": "prefix_merging"}, + "evaluator": {"strategy": "test_on_output", "config": {"test_command": "...", "expected_output_json": {"test_calculator": "PASSED"}}}, + "callback_url": "http://trainer:9000/done", + "metadata": {"group_id": "g1", "rollout_step": 42} } ``` -Task fields: - -- `runtime` selects Docker or Apptainer and describes the sandbox image. -- `agent` selects a built-in harness or a custom import path. -- `builder` selects how completion records become trajectories. -- `evaluator` attaches rewards after the agent run. -- `metadata` can carry training fields such as group id, rollout step, or policy - version. +- `runtime`: the Docker/Apptainer sandbox (see [runtime](../runtime/README.md)). + Optional — a node's `default_runtime` is used if omitted. +- `agent`: a preset harness, the generic `shell` harness, or a custom + `import_path` (see [agent](../agent/README.md)). +- `builder` / `evaluator`: how completions become trajectories and how reward is + attached (see [trajectory](../trajectory/README.md)). +- `callback_url` (optional): where the final `TaskResult` is POSTed. +- `metadata` (optional): free-form key/values carried through onto the + trajectory — convenient for training fields like group id or policy version + (it's an unvalidated pass-through dict, not a fixed schema). ## Scheduling -Scheduling prefers healthy nodes with lower run pressure, then lower init -pressure. Nodes become ineligible when stale, draining, or blocked by post-run -backlog. +The scheduler prefers the least-loaded healthy node, comparing pressures in this +order — **run → post-run → init** — with a ready-buffer gap and total pressure as +tiebreakers. A node is eligible only while it is **healthy** (a heartbeat within +`heartbeat_interval × 2.5`), **not draining**, and **under its admission and +post-run-backlog limits**. A draining node stops receiving work and is removed +once its in-flight sessions finish. diff --git a/src/polar/rollout/manager.py b/src/polar/rollout/manager.py index 32ad8cd37..1d8b7bae3 100644 --- a/src/polar/rollout/manager.py +++ b/src/polar/rollout/manager.py @@ -71,6 +71,25 @@ def _mean_reward(results: list[SessionResult]) -> float | None: return sum(rewards) / len(rewards) +def _mean_traces(results: list[SessionResult]) -> float | None: + """Average number of traces per session for the task.""" + if not results: + return None + counts = [len(r.trajectory.traces) for r in results] + return sum(counts) / len(counts) + + +def _mean_completions(results: list[SessionResult]) -> float | None: + """Average number of raw completions (LLM requests) per session.""" + if not results: + return None + counts = [ + int(r.trajectory.metadata.get("record_count") or len(r.trajectory.traces)) + for r in results + ] + return sum(counts) / len(counts) + + class RolloutManager: """Manage the lifecycle of rollout sessions for a single submitted task.""" @@ -265,6 +284,8 @@ def list_tasks(self) -> list[dict[str, Any]]: "completed_sessions": record.completed_sessions, "errored_sessions": record.errored_sessions, "mean_reward": _mean_reward(record.results), + "mean_traces": _mean_traces(record.results), + "mean_completions": _mean_completions(record.results), "created_at": record.created_at, "updated_at": record.updated_at, "source": "live", diff --git a/src/polar/runtime/README.md b/src/polar/runtime/README.md index 6d89cdcfa..5cdc5e9f9 100644 --- a/src/polar/runtime/README.md +++ b/src/polar/runtime/README.md @@ -1,33 +1,57 @@ # Runtime Backends -`polar.runtime` provides isolated sandboxes for agent sessions. - -## Main Files - -- `models.py`: `RuntimeSpec`, `PrepareAction`, `ExecInput`, and `ExecResult`. -- `base.py`: runtime backend contract. -- `docker.py`: Docker runtime implementation. -- `apptainer.py`: Apptainer runtime implementation. -- `factory.py`: backend lookup. - -## Runtime Contract - -A runtime backend prepares files and directories, executes commands, exposes a -workspace, and cleans up after the session. It should hide container-specific -details from agent harnesses and evaluators. - -## Prepare Steps - -`RuntimeSpec.prepare` and `RuntimeSpec.eval_prepare` accept ordered actions: - -- `upload_file`: copy one host file into the runtime. -- `upload_dir`: copy one host directory into the runtime. -- `exec`: run a command inside the runtime. - -Prepare steps run before the agent. Eval-prepare steps run before evaluation -when an evaluator needs extra setup. - -## Docker And Apptainer - -Docker is the default backend for local examples. Apptainer is supported for -clusters where container execution must avoid Docker daemon access. +`polar.runtime` gives each rollout session its own **sandbox** — one container +(Docker or Apptainer) that lives for the whole session. The gateway uses it to +run the prepare recipe, execute the agent and evaluator commands, move files in +and out, then tear it down. + +## Mental model + +- **One `RuntimeSpec` → one container**, shared across the init → run → eval + stages of a session. +- The host session directory is **bind-mounted** to a fixed in-container path, + `/polar/session` (`RUNTIME_SESSION_DIR`). Uploads/downloads under that path + are plain host-side file copies (fast); paths outside it fall back to + `docker cp` / `tar` streaming. +- Commands run in a login shell (`bash -lc`) with working directory + `cwd or spec.workdir or /polar/session`. +- The factory verifies the chosen backend actually supports what the spec asks + for (GPUs, CPU/memory limits, internet-off) before building it. + +## Main files + +- `models.py`: `RuntimeSpec`, `PrepareAction`, `ExecInput`, `ExecResult`. +- `base.py`: the `BaseRuntime` contract, the `/polar/session` path constants, and + the bind-mount copy helpers. +- `docker.py`: `DockerRuntime` — the default backend. +- `apptainer.py`: `ApptainerRuntime` — daemonless, for clusters. +- `factory.py`: backend lookup + capability validation; also loads a custom + backend via `RuntimeSpec.import_path`. + +## The contract + +A backend implements `start`, `stop`, `exec`, `upload_file`, `upload_dir`, +`download_file`, `download_dir` (plus `cancel`), hiding container details from +harnesses and evaluators. Well-known in-container paths (from `base.py`) are +`/polar/session` and, under it, `artifacts/`, `logs/`, `logs/agent/`, +`logs/eval/`, and `eval_artifacts/`. + +## Prepare recipe + +`RuntimeSpec.prepare` and `RuntimeSpec.eval_prepare` are ordered lists of +`PrepareAction` steps: + +- `upload_file`: copy one host file in. +- `upload_dir`: copy one host directory in. +- `exec`: run a command inside the container. + +`prepare` runs before the agent. `eval_prepare` runs before evaluation — and if +it's omitted, the eval runtime simply replays `prepare`. + +## Docker vs Apptainer + +Docker is the default for local examples and supports `--cpus` / `--memory` +limits. Apptainer is daemonless (good for clusters that forbid the Docker +socket), uses a host-backed overlay, and exposes GPUs with `--nv`. Both +bind-mount the session directory and run commands via `bash -lc`, so harnesses +and evaluators behave the same on either. diff --git a/src/polar/trajectory/README.md b/src/polar/trajectory/README.md index 5d6411154..9332f8034 100644 --- a/src/polar/trajectory/README.md +++ b/src/polar/trajectory/README.md @@ -1,30 +1,57 @@ # Trajectories -`polar.trajectory` defines the data shape used after gateway completion records -are reconstructed into trainable traces. - -## Main Files - -- `models.py`: completion sessions, completion records, traces, trajectories, - builder specs, evaluator specs, and evaluator results. -- `registry.py`: builder and evaluator registration. -- `builder/`: trajectory construction strategies. -- `evaluator/`: reward and validation strategies. - -## Data Model - -- `CompletionRecord`: one normalized model call and response. -- `CompletionSession`: all completion records captured during one agent run. -- `Trace`: token ids, loss mask, messages, logprobs, reward, and metadata. -- `Trajectory`: terminal status plus one or more traces. - -## Reward Attachment - -Evaluators return an `EvalResult`. The gateway merges outcome rewards or -per-trace rewards into the built trajectory before sending the session result -back to the rollout server. - -## Extension Points - -Register new builders or evaluators through `registry.py`. Keep strategy names -stable because task files and Slime configs refer to them by string. +`polar.trajectory` defines the data shapes and the strategy registry used to turn +the model calls captured during a run into **trainable, reward-bearing traces**. +In the pipeline: the gateway captures a `CompletionSession` → a **builder** turns +it into a `Trajectory` → an **evaluator** scores it → the gateway merges the +reward onto the traces and sends the result back. + +## Mental model + +- Two plugin families — **builders** and **evaluators** — each chosen by a string + name in the task (`builder.strategy`, `evaluator.strategy`). +- A `StrategyRegistry` maps names to classes and constructs one per request from + `spec.config`. A name you didn't register also works if you pass a + `"module:ClassName"` import path. +- A `Trajectory` is a terminal status plus a list of `Trace`s. Each `Trace` + carries parallel token arrays (`response_ids` / `loss_mask` / + `response_logprobs`) and the messages. Reward lands on each `Trace.reward`, + attached by the gateway from the evaluator's result — builders don't set it. + +## Main files + +- `models.py`: the schemas — `CompletionRecord`, `CompletionSession`, `Trace`, + `Trajectory`, plus `StrategySpec`, `EvaluatorSpec`, `EvalResult`. +- `registry.py`: the generic `StrategyRegistry` + the default builder/evaluator + registries. +- `builder/`: trajectory construction strategies (see [builder](builder/README.md)). +- `evaluator/`: reward / validation strategies (see [evaluator](evaluator/README.md)). + +## Data model + +- `CompletionRecord`: one captured model call — `completion_id`, the original and + served `request`s, and the `response`. +- `CompletionSession`: every record from one run; it auto-sorts records by + timestamp so builders see them in order. +- `Trace`: `prompt_ids`, `response_ids`, `loss_mask`, `prompt_messages`, + `response_messages`, `tools`, `finish_reason`, `response_logprobs`, `reward`, + `metadata`. (`loss_mask`, when present, matches `response_ids` length and holds + only 0/1.) +- `Trajectory`: `status` (one of `COMPLETED` / `TIMEOUT` / `ERROR`), `traces` + (zero or more), `metadata`, `error`. + +## Reward attachment + +An evaluator returns an `EvalResult` with an `outcome_reward` and/or per-trace +`trace_rewards`. The gateway merges it: `trace_rewards` must have exactly one +entry per trace (otherwise the session is marked `ERROR`); a single +`outcome_reward` is broadcast to every trace. + +## Extension points + +Register new builders/evaluators in `registry.py`, or reference any +`BaseTrajectoryBuilder` / `BaseTrajectoryEvaluator` subclass directly by +`"module:ClassName"`. Keep strategy names stable — task files and Slime configs +refer to them by string. Registered out of the box: builders `per_request`, +`prefix_merging`; evaluators `session_completed`, `test_on_output`, +`swebench_harness`. diff --git a/src/polar/trajectory/builder/README.md b/src/polar/trajectory/builder/README.md index e7e7bc221..e27e7b6e9 100644 --- a/src/polar/trajectory/builder/README.md +++ b/src/polar/trajectory/builder/README.md @@ -1,29 +1,42 @@ # Trajectory Builders -Builders convert a `CompletionSession` into a `Trajectory`. +Builders convert a captured `CompletionSession` into a `Trajectory` of trainable +`Trace`s — the first reconstruction step, before evaluation. -## Main Files +## Main files -- `base.py`: builder contract. +- `base.py`: the builder contract (`async build(session) -> Trajectory`). - `per_request.py`: one trace per completion. -- `prefix_merging.py`: merge strict append-only prompt/response chains. -- `record_utils.py`: helpers for extracting messages, token ids, logprobs, and - response metadata from completion records. +- `prefix_merging.py`: stitch an append-only agent chain into one token-level trace. +- `record_utils.py`: helpers to pull messages, token ids, logprobs, and metadata + out of a completion record. ## `per_request` -`per_request` is the simplest strategy. Each completion becomes its own trace. -Use it when preserving every request independently is more important than -building longer multi-turn training examples. +The simplest strategy: each completion becomes its own trace. Use it when you +want every request preserved independently rather than merged into longer +multi-turn examples. ## `prefix_merging` -`prefix_merging` joins consecutive completions when each prompt is exactly the -previous prompt plus response. It starts a separate trace when that prefix -relationship breaks, such as after context compaction. - -## Loss Mask And Logprobs - -Builders should set `loss_mask = 1` for sampled assistant response tokens that -can train the policy. Interstitial or copied tokens should use `loss_mask = 0`. -Training bridges expect trainable tokens to have matching logprob data. +A multi-turn agent resends the growing conversation on every step, so +consecutive requests share a common prefix. `prefix_merging` detects this and +merges the chain into a single trace with one prompt and the concatenated turns. + +The join test is a **strict token prefix**: a new request joins the chain only +when its `prompt_ids` start with the previous completion's prompt (append-only). +A message-level key gates candidates first, and it deliberately ignores +tool-result and empty assistant messages so ordinary tool loops still merge. When +the prefix relationship breaks — e.g. after context compaction rewrites earlier +turns — a new trace is started (and a partially merged chain can be truncated at +the break). + +## Loss mask and logprobs + +Builders set `loss_mask = 1` for the sampled assistant tokens that should train +the policy and `loss_mask = 0` for interstitial/copied tokens. Sampled tokens +keep their real logprobs; `prefix_merging` fills interstitial positions with +`0.0` placeholders so the arrays stay aligned. The turn boundary is found via an +end-of-turn token (auto-detected, or set explicitly with the builder's +`end_of_turn_token_id` config). Training bridges expect trainable tokens to have +matching logprob data. diff --git a/src/polar/trajectory/builder/prefix_merging.py b/src/polar/trajectory/builder/prefix_merging.py index 7ea9e1797..d58499dc3 100644 --- a/src/polar/trajectory/builder/prefix_merging.py +++ b/src/polar/trajectory/builder/prefix_merging.py @@ -9,11 +9,16 @@ Design in two stages: -1. **Grouping** — detect which completions belong to the same append-only - agent chain. A cheap message-level key is used as an O(1) index, and a - strict token-prefix check (``C_{k+1}.prompt_ids`` must start with - ``C_k.prompt_ids``) is the final arbiter. Completions whose tokens - diverge start a fresh chain instead of silently polluting an existing one. +1. **Grouping** — route each completion to the chain it append-extends, tested + purely on tokens: a completion joins the chain whose last prompt is a prefix + of it (``C_k.prompt_ids`` is a prefix of ``C_{k+1}.prompt_ids``). This routes + correctly even when parallel agents / sub-agents interleave (each has a + distinct prompt prefix), and is robust to BPE re-tokenization because it + compares only server-tokenized prompts, whose shared prefix is stable across + the special-token generation-prompt boundary. We never compare the *sampled* + ``response_ids`` (those can re-tokenize in the next prompt, e.g. + ``[fish, ing]`` → ``[fishing]``); a completion that extends no open chain + starts a fresh one. 2. **Finalization** — walk each chain and build a merged token stream: @@ -28,14 +33,11 @@ - Interstitial slots get synthesized logprobs and a zero ``loss_mask``; sampled assistant slots keep their real logprobs and a one ``loss_mask``. -See ``docs/prefix_merging_algorithm.md`` for a full walkthrough with -examples, invariants, and edge cases. """ from __future__ import annotations import logging -from collections import defaultdict, deque from copy import deepcopy from typing import Any @@ -49,85 +51,6 @@ _NATURAL_STOP_REASONS = frozenset({"stop", "tool_calls", "stop_sequence"}) - -# --------------------------------------------------------------------------- -# Message-level grouping helpers — used to detect which completions belong -# to the same agentic chain (C_{i+1}'s prompt == C_i's prompt + response). -# --------------------------------------------------------------------------- - - -def _flatten_message_content(content: Any) -> str: - """Extract text from a message content field (string or content-part array).""" - if isinstance(content, str): - return content - if isinstance(content, list): - return "".join( - item.get("text", "") - for item in content - if isinstance(item, dict) and item.get("type") == "text" - ) - return str(content) if content is not None else "" - - -def _expand_messages_for_grouping(message: dict[str, Any]) -> list[dict[str, Any]]: - role = message.get("role") - if role != "assistant" or not message.get("tool_calls"): - return [message] - - expanded: list[dict[str, Any]] = [] - content = message.get("content") - if content not in (None, "", []): - expanded.append({"role": role, "content": content}) - expanded.append( - {"role": role, "content": None, "tool_calls": message.get("tool_calls")} - ) - return expanded - - -def _is_grouping_noise_message(message: dict[str, Any]) -> bool: - role = message.get("role") - if role in ["tool"]: - return True - if role == "assistant" and message.get("tool_calls"): - return False - content = _flatten_message_content(message.get("content")).strip() - if role == "assistant" and not content and not message.get("tool_calls"): - return True - return False - - -def _normalize_messages(messages: list[dict[str, Any]]) -> str: - """Flatten a message list into a deterministic key string. - - Format: ``role:contentrole:content...`` - """ - parts = [] - for msg in messages: - role = msg.get("role", "") - if role == "assistant" and msg.get("tool_calls"): - content = "" - else: - content = _flatten_message_content(msg.get("content")) - parts.append(f"{role}:{content}") - return "".join(parts) - - -def _grouping_key(messages: list[dict[str, Any]]) -> str: - """Normalize the structural conversation context used for chaining. - - Tool-result messages are omitted because they are harness artifacts that - appear between assistant turns in the next request prompt. - """ - return _normalize_messages( - [ - expanded_message - for message in messages - for expanded_message in _expand_messages_for_grouping(message) - if not _is_grouping_noise_message(expanded_message) - ], - ) - - class PrefixMergingBuilder(BaseTrajectoryBuilder): """Rebuild a chain's merged token stream using raw + canonical-interstitial. @@ -165,26 +88,17 @@ async def build(self, session: CompletionSession) -> Trajectory: ) chains: list[list[CompletionRecord]] = [] - waiting_chains: dict[str, deque[int]] = defaultdict(deque) + chain_tips: list[list[int]] = [] # last completion's prompt_ids, per chain for completion in session.completions: - trace = build_trace_from_completion(completion) - prompt_key = _grouping_key(trace.prompt_messages) - chain_idx = self._pop_compatible_chain( - prompt_key=prompt_key, - prompt_ids=trace.prompt_ids, - chains=chains, - waiting_chains=waiting_chains, - ) - - if chain_idx is not None: - chains[chain_idx].append(completion) - else: + prompt_ids = build_trace_from_completion(completion).prompt_ids + chain_idx = self._find_extendable_chain(prompt_ids, chain_tips) + if chain_idx is None: chain_idx = len(chains) - chains.append([completion]) - - next_key = _grouping_key(trace.prompt_messages + trace.response_messages) - waiting_chains[next_key].append(chain_idx) + chains.append([]) + chain_tips.append([]) + chains[chain_idx].append(completion) + chain_tips[chain_idx] = prompt_ids stats: dict[str, int] = { "chains_total": len(chains), @@ -233,7 +147,7 @@ def _finalize_chain( prompt_ids = list(first_trace.prompt_ids) stream_ids: list[int] = list(prompt_ids) - response_slots: list[dict[str, Any] | None] = [] + response_slots: list[float | None] = [] loss_mask: list[int] = [] response_messages: list[dict[str, Any]] = [] @@ -312,7 +226,7 @@ def _finalize_chain( stats["chains_reconstructed_truncated"] += 1 response_ids = stream_ids[len(prompt_ids):] - response_logprobs = self._finalize_logprobs(response_slots, response_ids) + response_logprobs = self._finalize_logprobs(response_slots) last_kept_trace = build_trace_from_completion(chain[kept - 1]) return Trace( @@ -385,7 +299,7 @@ def _slice_interstitial( def _append_response_tokens( trace: Trace, stream_ids: list[int], - response_slots: list[dict[str, Any] | None], + response_slots: list[float | None], loss_mask: list[int], ) -> None: """Append a completion's response_ids and parallel logprob slots.""" @@ -397,20 +311,18 @@ def _append_response_tokens( loss_mask.extend(trace_loss_mask) logprobs = trace.response_logprobs or [] for pos in range(len(response_ids)): - entry = logprobs[pos] if pos < len(logprobs) else None - response_slots.append(deepcopy(entry) if isinstance(entry, dict) else None) + value = logprobs[pos] if pos < len(logprobs) else None + response_slots.append(float(value) if isinstance(value, (int, float)) else None) @staticmethod def _finalize_logprobs( - slots: list[dict[str, Any] | None], - response_ids: list[int], - ) -> list[dict[str, Any]] | None: + slots: list[float | None], + ) -> list[float] | None: + # Interstitial slots (tool results, chat glue) get 0.0; loss_mask=0 + # makes the trainer ignore them. if not any(slot is not None for slot in slots): return None - return [ - slot if slot is not None else {"token_id": response_ids[i], "logprob": 0.0} - for i, slot in enumerate(slots) - ] + return [slot if slot is not None else 0.0 for slot in slots] @staticmethod def _chain_metadata(chain: list[CompletionRecord]) -> dict[str, Any]: @@ -420,49 +332,30 @@ def _chain_metadata(chain: list[CompletionRecord]) -> dict[str, Any]: return merged @staticmethod - def _pop_compatible_chain( - *, - prompt_key: str, + def _find_extendable_chain( prompt_ids: list[int], - chains: list[list[CompletionRecord]], - waiting_chains: dict[str, deque[int]], + chain_tips: list[list[int]], ) -> int | None: - """Pop a waiting chain that matches both at message-key and token levels. - - The message-level key (produced by ``_grouping_key``) is only a - *necessary* condition for joining a chain. Its normalization drops - tool messages and empty/```` assistants — both of which can - hide genuine token-level divergence (cache-control shifts, tools - schema rewrites, ```` injections). - - The *sufficient* condition is the strict append-only token-prefix - invariant: ``C_{k+1}.prompt_ids`` must start with ``C_k.prompt_ids``. - Enforcing this at chain-join time means a completion whose raw - tokenization diverges from the waiting chain's tail starts its own - new chain, instead of being silently appended (only to be dropped - later in finalization). - - Scans candidates in FIFO order; returns the first compatible index - and pops it. Returns None if no candidate passes the token check. + """Return the open chain this completion append-extends, else None. + + A completion continues a chain iff its prompt begins with that chain's + last prompt (``tip`` is a token-prefix of ``prompt_ids``). This routes + completions to the right chain even when parallel agents / sub-agents + interleave — each conversation has a distinct prompt prefix — and + tolerates the just-finished turn being re-serialized in history (tool-call + argument reformatting, whitespace), since that divergence falls *after* + the prompt. The compared prefix is two server-side tokenizations of the + same text, so BPE re-tokenization of the sampled response never enters + the decision. On overlap the longest matching tip wins (most advanced + chain). """ - queue = waiting_chains.get(prompt_key) - if not queue: - return None - for pos, chain_idx in enumerate(queue): - last_trace = build_trace_from_completion(chains[chain_idx][-1]) - last_pids = last_trace.prompt_ids - if ( - not prompt_ids - or not last_pids - or len(prompt_ids) < len(last_pids) - or prompt_ids[: len(last_pids)] != last_pids - ): - continue - del queue[pos] - if not queue: - waiting_chains.pop(prompt_key, None) - return chain_idx - return None + best_idx: int | None = None + best_len = -1 + for idx, tip in enumerate(chain_tips): + n = len(tip) + if n > best_len and 0 < n <= len(prompt_ids) and prompt_ids[:n] == tip: + best_idx, best_len = idx, n + return best_idx def _top_level_scheduler_metadata(metadata: dict[str, Any]) -> dict[str, Any]: diff --git a/src/polar/trajectory/builder/record_utils.py b/src/polar/trajectory/builder/record_utils.py index 4dd73995e..b036b4049 100644 --- a/src/polar/trajectory/builder/record_utils.py +++ b/src/polar/trajectory/builder/record_utils.py @@ -27,12 +27,20 @@ def _extract_response_ids(response: dict[str, Any], choice: dict[str, Any]) -> l return [] -def _extract_response_logprobs(choice: dict[str, Any]) -> list[dict[str, Any]] | None: +def _extract_response_logprobs(choice: dict[str, Any]) -> list[float] | None: + """Sampled-token logprob per position, aligned 1:1 with response_ids. + + The token id is intentionally dropped -- it is already in ``response_ids`` + at the same index; only the float is needed for training. + """ logprobs = choice.get("logprobs") if isinstance(logprobs, dict): content = logprobs.get("content") if isinstance(content, list): - return [deepcopy(item) for item in content if isinstance(item, dict)] + return [ + float(item.get("logprob", 0.0)) if isinstance(item, dict) else 0.0 + for item in content + ] return None diff --git a/src/polar/trajectory/evaluator/README.md b/src/polar/trajectory/evaluator/README.md index ad26385da..a77b0dd64 100644 --- a/src/polar/trajectory/evaluator/README.md +++ b/src/polar/trajectory/evaluator/README.md @@ -1,24 +1,46 @@ # Trajectory Evaluators -Evaluators attach rewards or validation metadata after an agent run finishes. +Evaluators score a built `Trajectory` into an `EvalResult` (an outcome reward +and/or per-trace rewards, plus metadata). The gateway then merges that reward +onto the trajectory's traces. -## Main Files +## Main files -- `base.py`: evaluator contract. -- `session_completed.py`: simple success reward when the session completes. -- `test_on_output.py`: run a command and reward successful test output. -- `swebench_harness.py`: SWE-bench grading integration. -- `_patch_utils.py`: patch helpers used by SWE-style evaluators. +- `base.py`: the evaluator contract (`async evaluate(trajectory, **runtime) -> EvalResult`). +- `session_completed.py`: reward by terminal status. +- `test_on_output.py`: apply the agent's changes and grade test output. +- `swebench_harness.py`: grade a patch with the SWE-bench harness. +- `_patch_utils.py`: `BasePatchEvaluator` — the shared extract → filter → apply → + test flow both grading evaluators build on. -## Built-In Strategies +## Built-in strategies -- `session_completed`: rewards sessions that reached a completed status. -- `test_on_output`: runs a configured command in the runtime and uses the exit - result as reward signal. -- `swebench_harness`: grades SWE-bench style patches with the SWE-bench harness. +**`session_completed`** — reward `1.0` if the session reached `COMPLETED`, else +`0.0`. Needs no runtime; handy as a smoke-test signal. -## Adding An Evaluator +**`test_on_output`** — for custom/toy tasks. It extracts the agent's git diff, +(optionally) applies it on a fresh runtime, runs a test command, and **grades by +matching parsed test output — not the exit code**: it reads +`PASSED`/`FAILED`/`ERROR`/`SKIPPED ` lines and rewards `1.0` only when the +parsed result **exactly equals** the expected map. -Implement the base contract, return an `EvalResult`, and register the strategy -name in the evaluator registry. Keep external services, GPUs, and large -datasets out of default unit tests. +| config key | required | meaning | +|---|---|---| +| `test_command` | yes | the command to run | +| `expected_output_json` | yes | `{node: "PASSED", ...}` the output must match | +| `repo_dir` | no | where the diff/test run (default `/testbed`) | +| `patch_command` | no | how to extract the diff (default a `git diff`) | +| `test_timeout` / `apply_timeout` | no | timeouts | +| `exclude_patterns` | no | paths to drop from the diff | + +**`swebench_harness`** — grades real SWE-bench-style patches with the SWE-bench +(or SWE-Gym) harness. Takes an `instance` dict plus the same patch config keys. + +Both grading evaluators need a live runtime (and a `fresh_eval_runtime` when the +task sets `refresh_runtime`); an empty diff scores `0.0`. + +## Adding an evaluator + +Implement the base contract, return an `EvalResult`, and register the name in +`registry.py` (or pass a `"module:ClassName"` import path). Keep external +services, GPUs, and large datasets out of default unit tests. diff --git a/src/polar/trajectory/models.py b/src/polar/trajectory/models.py index 0cabe5f4c..d247bd9d0 100644 --- a/src/polar/trajectory/models.py +++ b/src/polar/trajectory/models.py @@ -96,7 +96,7 @@ class Trace(BaseModel): response_messages: list[dict[str, Any]] = Field(default_factory=list) tools: list[dict[str, Any]] | None = None finish_reason: str | None = None - response_logprobs: list[dict[str, Any]] | None = None + response_logprobs: list[float] | None = None reward: float | None = None metadata: dict[str, Any] = Field(default_factory=dict) @@ -112,9 +112,14 @@ def _validate_loss_mask_values(cls, value: list[int]) -> list[int]: return normalized @model_validator(mode="after") - def _validate_loss_mask_length(self) -> "Trace": + def _validate_response_lengths(self) -> "Trace": if self.loss_mask and len(self.loss_mask) != len(self.response_ids): raise ValueError("loss_mask length must match response_ids length") + if ( + self.response_logprobs is not None + and len(self.response_logprobs) != len(self.response_ids) + ): + raise ValueError("response_logprobs length must match response_ids length") return self diff --git a/src/slime_bridge/README.md b/src/slime_bridge/README.md index 76e909e5b..c81fe4a25 100644 --- a/src/slime_bridge/README.md +++ b/src/slime_bridge/README.md @@ -1,34 +1,61 @@ # Slime Bridge -`slime_bridge` connects Slime rollout calls to a running Polar rollout service. -It is intentionally kept outside the core `polar` package namespace because -Slime, Ray, Megatron, and training dependencies are installed separately. +`slime_bridge` connects [Slime](https://github.com/THUDM/slime)'s RL training +loop to a running Polar rollout server over HTTP. It lives **outside** the +`polar` package because Slime, Ray, Megatron, and torch are installed separately +— Polar depends on none of them. -## Main Files +## How it fits -- `config.py`: bridge configuration helpers. -- `rollout.py`: async worker lifecycle, task submission, policy update - coordination, and Slime-facing rollout functions. -- `_messages.py`: prompt and message conversion. -- `adapter.py`: conversion from Polar session results to Slime-like samples. -- `data_source.py`: Slime data source integration. -- `reward.py`: reward helpers. -- `reward_post_process.py`: reward normalization after rollout. +Slime calls one entry point, `generate_rollout_polar_async`, wired in via +`--rollout-function-path`. From there the bridge: -## What The Bridge Owns +- submits async task batches to `polar_rollout_url` (or a node derived from + `polar_topology_path`) and collects each result through a local callback + listener with a polling safety net; +- tracks rollout ids and policy versions, stamps `{group_id, policy_version, + rollout_step}` onto every task, and **drops groups that drift too far + off-policy** (`max_off_policy_steps`) while keeping async admission bounded + (`max_async_level`); +- **pauses/resumes gateway generation** around weight updates (the gateway's + `/admin/inference/pause` + `/resume`) when overlap is enabled; +- converts each Polar `Trajectory` back into Slime `Sample`s (one per trace, + grouped so the reward post-processor treats them as one trajectory), dropping + empty or oversized traces; +- normalizes rewards GRPO-style per group and zeroes out failed/aborted + trajectories. -- Convert Slime samples and prompt messages into Polar task requests. -- Submit async batches to `polar_rollout_url`. -- Track rollout ids, policy versions, and scheduler metadata. -- Pause and resume gateway generation during configured weight update windows. -- Convert Polar trajectories back into sample objects expected by Slime. -- Normalize or filter rewards for failed and oversized trajectories. +## Main files -## Slime Installation +- `config.py`: `PolarSlimeConfig` + `resolve_polar_slime_config`; also renders the + task payload, the instruction, and the topology that points gateways at Slime's + SGLang router. +- `rollout.py`: the async worker (submit → callback/poll → convert), the + evaluation path, policy-update coordination, the acceptance filters, and the + Slime entry point. +- `_messages.py`: prompt/message flattening shared by rollout + adapter. +- `adapter.py`: convert a Polar `SessionResult` into Slime `Sample`s. +- `data_source.py`: `CeilEpochRolloutDataSourceWithBuffer` — rounds the epoch + length up so the dataset tail isn't skipped. +- `reward.py`: reward hook that reads the reward Polar already embedded. +- `reward_post_process.py`: trajectory-aware, group-normalized reward shaping. -Install Slime from the THUDM git checkout, not from the unrelated PyPI `slime` -package. The SWE-Gym Slime GRPO example uses `launch_e2e.sh` to automate this -setup; the manual equivalent from the repository root is: +## What the bridge owns + +- Turn Slime samples + prompts into Polar task requests and submit async batches. +- Track rollout ids / policy versions; drop off-policy-stale groups; bound async + admission. +- Filter unusable groups (zero trainable tokens, too few completed samples, + logprob errors) with per-category metrics. +- Pause/resume gateway generation during weight-update windows. +- Convert Polar trajectories back into Slime samples; normalize and zero rewards. +- Run the evaluation path over `eval_datasets` and emit W&B metrics. + +## Slime installation + +Install Slime from the THUDM git checkout (not the unrelated PyPI `slime` +package). The SWE-Gym Slime GRPO example automates this with `launch_e2e.sh`; the +manual equivalent from the repository root is: ```bash git clone --branch v0.2.4 --depth 1 https://github.com/THUDM/slime.git slime @@ -41,8 +68,6 @@ uv pip install -e Megatron-LM bash scripts/patch/patch_slime.sh slime ``` -Use `SLIME_DIR=/path/to/slime` and `MEGATRON_DIR=/path/to/Megatron-LM` when -working with existing checkouts outside the repository root. - -The Slime training environment is expected to provide heavy dependencies such -as `torch`. Polar does not add those dependencies for the first beta release. +Use `SLIME_DIR=/path/to/slime` and `MEGATRON_DIR=/path/to/Megatron-LM` for +checkouts outside the repository root. The Slime training environment provides +the heavy dependencies (e.g. `torch`); Polar does not add them. diff --git a/src/slime_bridge/adapter.py b/src/slime_bridge/adapter.py index e01e28fcb..0ef2324a7 100644 --- a/src/slime_bridge/adapter.py +++ b/src/slime_bridge/adapter.py @@ -94,7 +94,7 @@ def _build_sample( max_tokens: int | None = None, ) -> Any | None: prompt_ids = list(trace.prompt_ids) - response_ids = list(trace.response_ids) or _response_ids_from_logprobs(trace) + response_ids = list(trace.response_ids) if not prompt_ids or not response_ids: logger.warning( @@ -281,23 +281,10 @@ def _extract_rollout_log_probs( f"{len(logprobs)} != response length {response_len}" ) - values: list[float] = [] - for pos, (entry, mask_value) in enumerate(zip(logprobs, loss_mask, strict=True)): - if not isinstance(entry, dict): - if mask_value: - raise RolloutLogprobError( - f"Session {session_id} trace {trace_index}: logprob entry {pos} " - "is not a mapping" - ) - values.append(0.0) - continue - if mask_value and "logprob" not in entry: - raise RolloutLogprobError( - f"Session {session_id} trace {trace_index}: trainable token {pos} " - "is missing logprob" - ) - values.append(float(entry.get("logprob", 0.0))) - return values + # response_logprobs is one float per response token (interstitials are 0.0, + # masked out by loss_mask); the builder guarantees trainable tokens carry + # their real sampled logprob. + return [float(value) for value in logprobs] def _loss_mask_from_trace( @@ -324,16 +311,6 @@ def _loss_mask_from_trace( return [1 if int(value) else 0 for value in mask] -def _response_ids_from_logprobs(trace: "Trace") -> list[int]: - if not trace.response_logprobs: - return [] - return [ - int(item["token_id"]) - for item in trace.response_logprobs - if isinstance(item, dict) and item.get("token_id") is not None - ] - - def _load_sample_type() -> Any: try: from slime.utils.types import Sample diff --git a/src/slime_bridge/config.py b/src/slime_bridge/config.py index 17f8086f7..aec006457 100644 --- a/src/slime_bridge/config.py +++ b/src/slime_bridge/config.py @@ -208,7 +208,8 @@ def render_topology_template(topology_path: str | Path, args: Any) -> dict[str, "port": node.port, "public_url": node.public_url, "model_served": node.model_served, - "sglang": { + "inference": { + "engine": "sglang", "base_url": router_url, }, "max_init_workers": node.max_init_workers, diff --git a/src/slime_bridge/rollout.py b/src/slime_bridge/rollout.py index f08057ad4..7c25c364a 100644 --- a/src/slime_bridge/rollout.py +++ b/src/slime_bridge/rollout.py @@ -113,7 +113,7 @@ def update_policy_version(args: Any, policy_version: int) -> None: def prepare_policy_update(args: Any, policy_version: int) -> None: - """Optional hook called by Slime before overlapping SGLang weight sync.""" + """Optional hook called by Slime before overlapping inference weight sync.""" logger.info("Preparing Polar bridge for policy_version=%s weight update", policy_version) with _worker_lock: worker = _global_async_worker @@ -135,7 +135,7 @@ def prepare_policy_update(args: Any, policy_version: int) -> None: def finish_policy_update(args: Any, policy_version: int) -> None: - """Optional hook called by Slime after overlapping SGLang weight sync.""" + """Optional hook called by Slime after overlapping inference weight sync.""" try: _resume_gateway_generation(args) finally: @@ -171,11 +171,11 @@ def _pause_gateway_generation(args: Any) -> None: request_timeout = max(timeout_seconds + 5.0, 10.0) with httpx.Client(timeout=request_timeout) as client: response = client.post( - f"{gateway_url}/admin/sglang/pause", + f"{gateway_url}/admin/inference/pause", params={"timeout_seconds": timeout_seconds}, ) response.raise_for_status() - logger.info("Paused Polar gateway generation for SGLang weight update: %s", response.json()) + logger.info("Paused Polar gateway generation for inference weight update: %s", response.json()) def _resume_gateway_generation(args: Any) -> None: @@ -185,9 +185,9 @@ def _resume_gateway_generation(args: Any) -> None: request_timeout = float(getattr(args, "polar_gateway_control_timeout", 30.0)) with httpx.Client(timeout=max(request_timeout, 5.0)) as client: - response = client.post(f"{gateway_url}/admin/sglang/resume") + response = client.post(f"{gateway_url}/admin/inference/resume") response.raise_for_status() - logger.info("Resumed Polar gateway generation after SGLang weight update: %s", response.json()) + logger.info("Resumed Polar gateway generation after inference weight update: %s", response.json()) # --------------------------------------------------------------------------- diff --git a/tests/config/test_topology.py b/tests/config/test_topology.py index f059aabb4..6065ee990 100644 --- a/tests/config/test_topology.py +++ b/tests/config/test_topology.py @@ -80,6 +80,113 @@ def test_select_gateway_requires_node_id_for_multi_node_topology(tmp_path: Path) assert topology.select_gateway_node("node-b").port == 8081 +def test_inference_block_selects_engine_and_base_url(tmp_path: Path) -> None: + path = _write_yaml( + tmp_path / "topology.yaml", + { + "gateway": { + "nodes": [ + { + "id": "node-a", + "public_url": "http://127.0.0.1:8100", + "inference": {"engine": "vllm", "base_url": "http://127.0.0.1:8000"}, + } + ], + }, + }, + ) + node = TopologyConfig.load(path).gateway.nodes[0] + assert node.engine == "vllm" + assert node.inference_base_url == "http://127.0.0.1:8000" + + +def test_inference_engine_defaults_to_sglang(tmp_path: Path) -> None: + path = _write_yaml( + tmp_path / "topology.yaml", + { + "gateway": { + "nodes": [ + { + "id": "node-a", + "public_url": "http://127.0.0.1:8100", + "inference": {"base_url": "http://127.0.0.1:8000"}, + } + ], + }, + }, + ) + node = TopologyConfig.load(path).gateway.nodes[0] + assert node.engine == "sglang" + + +def test_inference_defaults_when_block_omitted(tmp_path: Path) -> None: + path = _write_yaml( + tmp_path / "topology.yaml", + {"gateway": {"nodes": [{"id": "node-a", "public_url": "http://127.0.0.1:8100"}]}}, + ) + node = TopologyConfig.load(path).gateway.nodes[0] + assert node.engine == "sglang" + assert node.inference_base_url == "http://127.0.0.1:8000" + + +def test_invalid_inference_engine_is_rejected(tmp_path: Path) -> None: + path = _write_yaml( + tmp_path / "topology.yaml", + { + "gateway": { + "nodes": [ + { + "id": "node-a", + "public_url": "http://127.0.0.1:8100", + "inference": {"engine": "tgi", "base_url": "http://127.0.0.1:8000"}, + } + ], + }, + }, + ) + with pytest.raises(ValueError): + TopologyConfig.load(path) + + +def test_invalid_inference_base_url_is_rejected(tmp_path: Path) -> None: + path = _write_yaml( + tmp_path / "topology.yaml", + { + "gateway": { + "nodes": [ + { + "id": "node-a", + "public_url": "http://127.0.0.1:8100", + "inference": {"engine": "vllm", "base_url": "not-a-url"}, + } + ], + }, + }, + ) + with pytest.raises(ValueError, match="inference.base_url"): + TopologyConfig.load(path) + + +def test_legacy_sglang_block_is_rejected(tmp_path: Path) -> None: + # No backward compatibility: the old `sglang:` block is now an unknown key. + path = _write_yaml( + tmp_path / "topology.yaml", + { + "gateway": { + "nodes": [ + { + "id": "node-a", + "public_url": "http://127.0.0.1:8100", + "sglang": {"base_url": "http://127.0.0.1:8000"}, + } + ], + }, + }, + ) + with pytest.raises(ValueError, match="sglang"): + TopologyConfig.load(path) + + def test_duplicate_gateway_node_ids_are_rejected(tmp_path: Path) -> None: path = _write_yaml( tmp_path / "topology.yaml", diff --git a/tests/gateway/test_engine.py b/tests/gateway/test_engine.py new file mode 100644 index 000000000..13c02581f --- /dev/null +++ b/tests/gateway/test_engine.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import pytest + +from polar.gateway.engine import SGLangEngine, VLLMEngine, get_engine + + +def test_get_engine_returns_the_right_strategy() -> None: + assert isinstance(get_engine("sglang"), SGLangEngine) + assert isinstance(get_engine("vllm"), VLLMEngine) + + +def test_get_engine_rejects_unknown_name() -> None: + with pytest.raises(ValueError, match="Unknown inference engine"): + get_engine("tgi") + + +def test_sglang_engine_requests_logprobs_and_passes_through() -> None: + engine = SGLangEngine() + request = {"messages": []} + out = engine.prepare_request(request) + assert out is request and out["logprobs"] is True + response = {"choices": [{"message": {"role": "assistant", "content": "hi"}}]} + assert engine.normalize_response(response) is response + + +def test_vllm_prepare_request_requests_token_ids_and_logprobs() -> None: + out = VLLMEngine().prepare_request({"messages": [], "logprobs": True}) + assert out["logprobs"] is True + assert out["return_token_ids"] is True + assert out["top_logprobs"] == 0 + + +def test_vllm_prepare_request_keeps_explicit_top_logprobs() -> None: + out = VLLMEngine().prepare_request({"logprobs": True, "top_logprobs": 5}) + assert out["top_logprobs"] == 5 + + +def test_vllm_prepare_request_forces_logprobs_when_absent() -> None: + out = VLLMEngine().prepare_request({"messages": []}) + assert out["logprobs"] is True + assert out["return_token_ids"] is True + assert out["top_logprobs"] == 0 + + +def test_vllm_normalize_renames_reasoning_to_reasoning_content() -> None: + response = { + "choices": [ + {"message": {"role": "assistant", "content": "a", "reasoning": "because"}} + ] + } + message = VLLMEngine().normalize_response(response)["choices"][0]["message"] + assert message["reasoning_content"] == "because" + assert "reasoning" not in message + + +def test_vllm_normalize_keeps_existing_reasoning_content() -> None: + response = {"choices": [{"message": {"reasoning": "new", "reasoning_content": "kept"}}]} + message = VLLMEngine().normalize_response(response)["choices"][0]["message"] + assert message["reasoning_content"] == "kept" + + +def test_vllm_normalize_without_reasoning_is_noop() -> None: + response = {"choices": [{"message": {"role": "assistant", "content": "hi"}}]} + out = VLLMEngine().normalize_response(response) + assert out["choices"][0]["message"] == {"role": "assistant", "content": "hi"} + + +def test_vllm_normalize_stamps_token_ids_onto_logprobs() -> None: + response = { + "choices": [ + { + "message": {"role": "assistant", "content": "hi"}, + "token_ids": [10, 11], + "logprobs": { + "content": [ + {"token": "h", "logprob": -0.1}, + {"token": "i", "logprob": -0.2}, + ] + }, + } + ] + } + content = VLLMEngine().normalize_response(response)["choices"][0]["logprobs"]["content"] + assert [entry["token_id"] for entry in content] == [10, 11] + + +def test_vllm_normalize_skips_token_id_stamp_on_length_mismatch() -> None: + response = { + "choices": [ + { + "token_ids": [10, 11, 12], + "logprobs": {"content": [{"token": "h", "logprob": -0.1}]}, + } + ] + } + content = VLLMEngine().normalize_response(response)["choices"][0]["logprobs"]["content"] + assert "token_id" not in content[0] diff --git a/tests/gateway/test_transform_anthropic.py b/tests/gateway/test_transform_anthropic.py index 26ebf33bf..2b28842d0 100644 --- a/tests/gateway/test_transform_anthropic.py +++ b/tests/gateway/test_transform_anthropic.py @@ -65,6 +65,7 @@ def test_anthropic_request_maps_all_fields_and_image_input_to_chat() -> None: "max_tokens": 128, "temperature": 0.2, "top_p": 0.9, + "top_k": 40, "stop_sequences": ["END"], "stream": True, "tools": [ @@ -108,6 +109,7 @@ def test_anthropic_request_maps_all_fields_and_image_input_to_chat() -> None: assert transformed["max_tokens"] == 128 assert transformed["temperature"] == 0.2 assert transformed["top_p"] == 0.9 + assert transformed["top_k"] == 40 assert transformed["stop"] == ["END"] assert transformed["stream"] is True assert transformed["tools"] == [ @@ -124,10 +126,143 @@ def test_anthropic_request_maps_all_fields_and_image_input_to_chat() -> None: "type": "function", "function": {"name": "write_answer"}, } - assert transformed["logprobs"] is True assert transformed["chat_template_kwargs"]["enable_thinking"] is False +def test_anthropic_request_maps_multi_turn_reasoning_and_parallel_tools() -> None: + transformer = AnthropicTransformer() + + transformed = transformer.transform_request( + { + "messages": [ + {"role": "user", "content": "Plan and call tools."}, + { + "role": "assistant", + "content": [ + { + "type": "thinking", + "thinking": "Need two independent lookups.", + "signature": "sig-1", + }, + { + "type": "tool_use", + "id": "toolu-a", + "name": "lookup", + "input": {"q": "a"}, + }, + { + "type": "tool_use", + "id": "toolu-b", + "name": "lookup", + "input": {"q": "b"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu-a", + "content": "A", + }, + { + "type": "tool_result", + "tool_use_id": "toolu-b", + "content": [{"type": "text", "text": "B"}], + }, + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "thinking", + "thinking": "Combine both results.", + "signature": "sig-2", + }, + {"type": "text", "text": "A and B"}, + ], + }, + ], + "max_tokens": 256, + } + ) + + assert transformed["messages"] == [ + {"role": "user", "content": "Plan and call tools."}, + { + "role": "assistant", + "content": None, + "reasoning_content": "Need two independent lookups.", + "tool_calls": [ + { + "id": "toolu-a", + "type": "function", + "function": {"name": "lookup", "arguments": '{"q": "a"}'}, + }, + { + "id": "toolu-b", + "type": "function", + "function": {"name": "lookup", "arguments": '{"q": "b"}'}, + }, + ], + }, + {"role": "tool", "tool_call_id": "toolu-a", "content": "A"}, + {"role": "tool", "tool_call_id": "toolu-b", "content": "B"}, + { + "role": "assistant", + "content": "A and B", + "reasoning_content": "Combine both results.", + }, + ] + + +def test_anthropic_tool_choice_variants_map_to_openai() -> None: + transformer = AnthropicTransformer() + base_body = { + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 10, + "tools": [{"name": "lookup", "input_schema": {"type": "object"}}], + } + + assert ( + transformer.transform_request({**base_body, "tool_choice": {"type": "auto"}})[ + "tool_choice" + ] + == "auto" + ) + assert ( + transformer.transform_request({**base_body, "tool_choice": {"type": "any"}})[ + "tool_choice" + ] + == "required" + ) + assert ( + transformer.transform_request({**base_body, "tool_choice": {"type": "none"}})[ + "tool_choice" + ] + == "none" + ) + assert transformer.transform_request( + {**base_body, "tool_choice": {"type": "tool", "name": "lookup"}} + )["tool_choice"] == {"type": "function", "function": {"name": "lookup"}} + + +def test_anthropic_adaptive_thinking_request_param_enables_thinking() -> None: + transformer = AnthropicTransformer() + + transformed = transformer.transform_request( + { + "thinking": {"type": "adaptive", "display": "summarized"}, + "messages": [{"role": "user", "content": "think"}], + "max_tokens": 2048, + } + ) + + assert transformed["chat_template_kwargs"]["enable_thinking"] is True + + def test_anthropic_response_maps_openai_content_and_usage_back() -> None: transformer = AnthropicTransformer() @@ -177,6 +312,28 @@ def test_anthropic_response_maps_openai_content_and_usage_back() -> None: ] +def test_anthropic_response_preserves_cached_usage_tokens() -> None: + transformer = AnthropicTransformer() + + response = transformer.transform_response( + { + "choices": [{"message": {"content": "Done"}, "finish_reason": "stop"}], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 3, + "prompt_tokens_details": {"cached_tokens": 4}, + }, + }, + {"model": "claude-test"}, + ) + + assert response["usage"] == { + "input_tokens": 6, + "output_tokens": 3, + "cache_read_input_tokens": 4, + } + + def test_anthropic_response_skips_empty_openai_content_with_tool_call() -> None: transformer = AnthropicTransformer() @@ -274,3 +431,368 @@ def test_anthropic_stream_state_emits_ordered_text_tool_and_usage_events() -> No "usage": {"output_tokens": 3}, } assert events[-1] == {"type": "message_stop"} + + +def test_anthropic_request_handles_url_image_source() -> None: + transformer = AnthropicTransformer() + transformed = transformer.transform_request( + { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is this?"}, + { + "type": "image", + "source": {"type": "url", "url": "https://example.test/cat.png"}, + }, + ], + } + ], + "max_tokens": 16, + } + ) + + assert transformed["messages"][0]["content"] == [ + {"type": "text", "text": "What is this?"}, + {"type": "image_url", "image_url": {"url": "https://example.test/cat.png"}}, + ] + + +def test_anthropic_request_handles_document_text_source() -> None: + transformer = AnthropicTransformer() + transformed = transformer.transform_request( + { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "document", + "source": { + "type": "text", + "media_type": "text/plain", + "data": "Doc body.", + }, + }, + {"type": "text", "text": "Summarize."}, + ], + } + ], + "max_tokens": 16, + } + ) + + # Document text flattens into the user message text content. + assert transformed["messages"][0]["content"] == "Doc body.\nSummarize." + + +def test_anthropic_request_handles_document_content_source() -> None: + transformer = AnthropicTransformer() + transformed = transformer.transform_request( + { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "document", + "source": { + "type": "content", + "content": [ + {"type": "text", "text": "Page 1"}, + {"type": "text", "text": "Page 2"}, + ], + }, + }, + ], + } + ], + "max_tokens": 16, + } + ) + + assert transformed["messages"][0]["content"] == "Page 1\nPage 2" + + +def test_anthropic_request_drops_base64_pdf_documents() -> None: + transformer = AnthropicTransformer() + transformed = transformer.transform_request( + { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "document", + "source": { + "type": "base64", + "media_type": "application/pdf", + "data": "JVBERi0xLjQK", + }, + }, + {"type": "text", "text": "What's in the PDF?"}, + ], + } + ], + "max_tokens": 16, + } + ) + + # Binary PDFs can't be rendered to the chat template; drop the document + # and forward the surrounding text so the model still sees the question. + assert transformed["messages"][0]["content"] == "What's in the PDF?" + + +def test_anthropic_tool_result_is_error_marks_content() -> None: + transformer = AnthropicTransformer() + transformed = transformer.transform_request( + { + "messages": [ + {"role": "user", "content": "Run a command."}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu-1", + "name": "shell", + "input": {"cmd": "fail"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu-1", + "content": "Permission denied", + "is_error": True, + }, + ], + }, + ], + "max_tokens": 16, + } + ) + + tool_msg = next(m for m in transformed["messages"] if m.get("role") == "tool") + assert tool_msg["tool_call_id"] == "toolu-1" + assert tool_msg["content"].startswith("[Tool Error]") + assert "Permission denied" in tool_msg["content"] + + +def test_anthropic_response_maps_extended_stop_reasons() -> None: + transformer = AnthropicTransformer() + + refusal = transformer.transform_response( + { + "choices": [ + {"message": {"content": "blocked"}, "finish_reason": "content_filter"} + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 1}, + }, + {"model": "claude-3"}, + ) + assert refusal["stop_reason"] == "refusal" + + stop_seq = transformer.transform_response( + { + "choices": [ + {"message": {"content": "x"}, "finish_reason": "stop_sequence"} + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 1}, + }, + {"model": "claude-3"}, + ) + assert stop_seq["stop_reason"] == "stop_sequence" + + +def test_anthropic_request_drops_server_side_tools() -> None: + transformer = AnthropicTransformer() + transformed = transformer.transform_request( + { + "messages": [{"role": "user", "content": "do it"}], + "max_tokens": 16, + "tools": [ + {"name": "lookup", "input_schema": {"type": "object"}}, + {"type": "web_search_20250305", "name": "web_search"}, + {"type": "code_execution_20250522", "name": "code_execution"}, + ], + } + ) + + # Only the custom function tool reaches SGLang; server-side tools are + # dropped because Polar can't execute them. + assert transformed["tools"] == [ + { + "type": "function", + "function": { + "name": "lookup", + "description": "", + "parameters": {"type": "object"}, + }, + }, + ] + + +def test_anthropic_request_system_list_with_cache_control_annotations() -> None: + transformer = AnthropicTransformer() + transformed = transformer.transform_request( + { + "system": [ + { + "type": "text", + "text": "You are a careful assistant.", + "cache_control": {"type": "ephemeral"}, + }, + {"type": "text", "text": "Always cite sources."}, + ], + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 16, + } + ) + + # cache_control is dropped silently; both text blocks are joined into one + # system message (Anthropic supports per-block cache markers; SGLang does + # not, so we forward only the prompt text). + assert transformed["messages"][0] == { + "role": "system", + "content": "You are a careful assistant.\nAlways cite sources.", + } + + +def test_anthropic_stream_state_emits_parallel_tool_use_blocks() -> None: + transformer = AnthropicTransformer() + state = transformer.create_stream_state({"model": "claude-test"}) + + # Both tools opened in a single chunk; arguments split across two chunks. + events = state.process_chunk( + { + "choices": [ + { + "delta": { + "tool_calls": [ + { + "index": 0, + "id": "toolu-a", + "function": {"name": "lookup_a", "arguments": '{"q":'}, + }, + { + "index": 1, + "id": "toolu-b", + "function": {"name": "lookup_b", "arguments": '{"q":'}, + }, + ] + } + } + ] + }, + is_first=True, + ) + events.extend( + state.process_chunk( + { + "choices": [ + { + "delta": { + "tool_calls": [ + {"index": 0, "function": {"arguments": '"a"}'}}, + {"index": 1, "function": {"arguments": '"b"}'}}, + ] + }, + "finish_reason": "tool_calls", + } + ], + "usage": {"completion_tokens": 5}, + } + ) + ) + events.extend(state.finalize()) + + # Two distinct content blocks: index 0 carries toolu-a, index 1 carries toolu-b. + starts = [ + event for event in events if event["type"] == "content_block_start" + ] + assert len(starts) == 2 + assert (starts[0]["index"], starts[0]["content_block"]["id"], starts[0]["content_block"]["name"]) == ( + 0, + "toolu-a", + "lookup_a", + ) + assert (starts[1]["index"], starts[1]["content_block"]["id"], starts[1]["content_block"]["name"]) == ( + 1, + "toolu-b", + "lookup_b", + ) + + # Per-index argument deltas land on the right block. + deltas_by_index: dict[int, list[str]] = {} + for event in events: + if event["type"] == "content_block_delta" and event["delta"].get( + "type" + ) == "input_json_delta": + deltas_by_index.setdefault(event["index"], []).append( + event["delta"]["partial_json"] + ) + assert "".join(deltas_by_index[0]) == '{"q":"a"}' + assert "".join(deltas_by_index[1]) == '{"q":"b"}' + + # Both blocks are explicitly closed before the final message_delta. + stops = [ + event["index"] + for event in events + if event["type"] == "content_block_stop" + ] + assert stops == [0, 1] + assert events[-2]["delta"]["stop_reason"] == "tool_use" + assert events[-1] == {"type": "message_stop"} + + +def test_anthropic_stream_state_closes_thinking_before_tool_use() -> None: + transformer = AnthropicTransformer() + state = transformer.create_stream_state({"model": "claude-test"}) + + events = state.process_chunk( + { + "choices": [ + { + "delta": { + "reasoning_content": "I should call a tool.", + "tool_calls": [ + { + "index": 0, + "id": "toolu-1", + "function": {"name": "lookup", "arguments": '{"q":"x"}'}, + } + ], + }, + "finish_reason": "tool_calls", + } + ], + "usage": {"completion_tokens": 3}, + }, + is_first=True, + ) + events.extend(state.finalize()) + + indexed_events = [ + (event["type"], event.get("index"), event.get("delta", {}).get("type")) + for event in events + if event["type"].startswith("content_block") + ] + thinking_stop_pos = indexed_events.index(("content_block_stop", 0, None)) + tool_start_pos = next( + i + for i, event in enumerate(indexed_events) + if event == ("content_block_start", 1, None) + ) + + assert indexed_events[:4] == [ + ("content_block_start", 0, None), + ("content_block_delta", 0, "thinking_delta"), + ("content_block_delta", 0, "signature_delta"), + ("content_block_stop", 0, None), + ] + assert thinking_stop_pos < tool_start_pos diff --git a/tests/gateway/test_transform_google.py b/tests/gateway/test_transform_google.py index 6df2fcd38..2a91e1c65 100644 --- a/tests/gateway/test_transform_google.py +++ b/tests/gateway/test_transform_google.py @@ -19,7 +19,19 @@ def test_google_request_maps_all_fields_and_image_input_to_chat() -> None: "maxOutputTokens": 128, "temperature": 0.2, "topP": 0.9, + "topK": 40, + "candidateCount": 2, + "presencePenalty": 0.1, + "frequencyPenalty": 0.2, + "seed": 123, + "logprobs": 4, "stopSequences": ["END"], + "responseMimeType": "application/json", + "responseSchema": { + "type": "OBJECT", + "properties": {"answer": {"type": "STRING"}}, + "required": ["answer"], + }, }, "tools": [ { @@ -116,7 +128,24 @@ def test_google_request_maps_all_fields_and_image_input_to_chat() -> None: assert transformed["max_tokens"] == 128 assert transformed["temperature"] == 0.2 assert transformed["top_p"] == 0.9 + assert transformed["top_k"] == 40 + assert transformed["n"] == 2 + assert transformed["presence_penalty"] == 0.1 + assert transformed["frequency_penalty"] == 0.2 + assert transformed["seed"] == 123 + assert transformed["top_logprobs"] == 4 assert transformed["stop"] == ["END"] + assert transformed["response_format"] == { + "type": "json_schema", + "json_schema": { + "name": "google_response", + "schema": { + "type": "object", + "properties": {"answer": {"type": "string"}}, + "required": ["answer"], + }, + }, + } assert transformed["stream"] is True assert transformed["tools"] == [ { @@ -132,10 +161,166 @@ def test_google_request_maps_all_fields_and_image_input_to_chat() -> None: "type": "function", "function": {"name": "write_answer"}, } - assert transformed["logprobs"] is True assert transformed["chat_template_kwargs"]["enable_thinking"] is False +def test_google_request_maps_tool_choice_modes() -> None: + transformer = GoogleTransformer() + base_body = { + "contents": [{"role": "user", "parts": [{"text": "hi"}]}], + "tools": [ + { + "functionDeclarations": [ + {"name": "lookup", "parametersJsonSchema": {"type": "object"}} + ] + } + ], + } + + assert ( + transformer.transform_request( + { + **base_body, + "toolConfig": {"functionCallingConfig": {"mode": "NONE"}}, + } + )["tool_choice"] + == "none" + ) + assert ( + transformer.transform_request( + { + **base_body, + "toolConfig": {"functionCallingConfig": {"mode": "ANY"}}, + } + )["tool_choice"] + == "required" + ) + assert transformer.transform_request( + { + **base_body, + "toolConfig": { + "functionCallingConfig": { + "mode": "ANY", + "allowedFunctionNames": ["lookup"], + } + }, + } + )["tool_choice"] == {"type": "function", "function": {"name": "lookup"}} + + +def test_google_request_maps_system_instruction_and_system_content_role() -> None: + transformer = GoogleTransformer() + + transformed = transformer.transform_request( + { + "systemInstruction": "Top-level system.", + "contents": [ + {"role": "system", "parts": [{"text": "Inline system."}]}, + {"role": "user", "parts": [{"text": "Hi"}]}, + ], + } + ) + + assert transformed["messages"] == [ + {"role": "system", "content": "Top-level system.\n\nInline system."}, + {"role": "user", "content": "Hi"}, + ] + + +def test_google_request_maps_multi_turn_reasoning_and_parallel_tools() -> None: + transformer = GoogleTransformer() + + transformed = transformer.transform_request( + { + "contents": [ + {"role": "user", "parts": [{"text": "Plan and call tools."}]}, + { + "role": "model", + "parts": [ + { + "thought": True, + "text": "Need two independent lookups.", + "thoughtSignature": "sig-1", + }, + { + "functionCall": { + "id": "call-a", + "name": "lookup", + "args": {"q": "a"}, + } + }, + { + "functionCall": { + "id": "call-b", + "name": "lookup", + "args": {"q": "b"}, + } + }, + ], + }, + { + "role": "user", + "parts": [ + { + "functionResponse": { + "id": "call-a", + "name": "lookup", + "response": {"text": "A"}, + } + }, + { + "functionResponse": { + "id": "call-b", + "name": "lookup", + "response": {"text": "B"}, + } + }, + ], + }, + { + "role": "model", + "parts": [ + { + "thought": True, + "text": "Combine both results.", + "thoughtSignature": "sig-2", + }, + {"text": "A and B"}, + ], + }, + ] + } + ) + + assert transformed["messages"] == [ + {"role": "user", "content": "Plan and call tools."}, + { + "role": "assistant", + "content": "", + "reasoning_content": "Need two independent lookups.", + "tool_calls": [ + { + "id": "call-a", + "type": "function", + "function": {"name": "lookup", "arguments": '{"q": "a"}'}, + }, + { + "id": "call-b", + "type": "function", + "function": {"name": "lookup", "arguments": '{"q": "b"}'}, + }, + ], + }, + {"role": "tool", "tool_call_id": "call-a", "content": '{"text": "A"}'}, + {"role": "tool", "tool_call_id": "call-b", "content": '{"text": "B"}'}, + { + "role": "assistant", + "content": "A and B", + "reasoning_content": "Combine both results.", + }, + ] + + def test_google_response_maps_openai_content_tools_finish_and_usage_back() -> None: transformer = GoogleTransformer() @@ -297,3 +482,105 @@ def test_google_stream_state_emits_finish_only_text_event() -> None: "candidatesTokenCount": 2, "totalTokenCount": 6, } + + +def test_google_response_maps_extended_finish_reasons() -> None: + transformer = GoogleTransformer() + + stop_seq = transformer.transform_response( + { + "choices": [ + {"message": {"content": "x"}, "finish_reason": "stop_sequence"} + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + }, + {}, + ) + assert stop_seq["candidates"][0]["finishReason"] == "STOP" + + unknown = transformer.transform_response( + { + "choices": [ + {"message": {"content": "x"}, "finish_reason": "weird_reason"} + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + }, + {}, + ) + # Unknown finish reasons fall through to STOP rather than crashing. + assert unknown["candidates"][0]["finishReason"] == "STOP" + + +def test_google_response_preserves_cached_usage_tokens() -> None: + transformer = GoogleTransformer() + + response = transformer.transform_response( + { + "choices": [{"message": {"content": "x"}, "finish_reason": "stop"}], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 2, + "total_tokens": 12, + "prompt_tokens_details": {"cached_tokens": 5}, + }, + }, + {}, + ) + + assert response["usageMetadata"]["cachedContentTokenCount"] == 5 + + +def test_google_request_drops_server_side_tools() -> None: + transformer = GoogleTransformer() + transformed = transformer.transform_request( + { + "contents": [{"role": "user", "parts": [{"text": "search"}]}], + "tools": [ + { + "functionDeclarations": [ + {"name": "lookup", "parameters": {"type": "object"}} + ] + }, + {"googleSearch": {}}, + {"codeExecution": {}}, + {"urlContext": {}}, + ], + } + ) + + # googleSearch, codeExecution, urlContext are server-side built-ins + # without functionDeclarations; only the custom function survives. + assert transformed["tools"] == [ + { + "type": "function", + "function": {"name": "lookup", "parameters": {"type": "object"}}, + } + ] + + +def test_google_direct_stream_chunk_handles_reasoning_content() -> None: + transformer = GoogleTransformer() + + transformed = transformer.transform_stream_chunk( + { + "choices": [ + { + "index": 0, + "delta": { + "reasoning_content": "Think first.", + "content": "Then answer.", + }, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 4, "completion_tokens": 2, "total_tokens": 6}, + }, + {}, + is_first=True, + ) + + candidate = transformed["candidates"][0] + assert candidate["content"]["parts"][0]["thought"] is True + assert candidate["content"]["parts"][0]["text"] == "Think first." + assert candidate["content"]["parts"][1] == {"text": "Then answer."} + assert candidate["finishReason"] == "STOP" diff --git a/tests/gateway/test_transform_openai_chat.py b/tests/gateway/test_transform_openai_chat.py index 6a55b8997..362989827 100644 --- a/tests/gateway/test_transform_openai_chat.py +++ b/tests/gateway/test_transform_openai_chat.py @@ -45,6 +45,7 @@ def test_openai_chat_request_preserves_fields_and_image_content() -> None: assert "_polar_model_served" not in transformed assert transformed["model"] == "client-visible-model" + # `developer` role is normalized to `system` and merged with adjacent systems. assert transformed["messages"] == [ {"role": "system", "content": "Use short answers.\n\nBe precise."}, body["messages"][2], @@ -56,7 +57,6 @@ def test_openai_chat_request_preserves_fields_and_image_content() -> None: assert transformed["stop"] == ["END"] assert transformed["tools"] == body["tools"] assert transformed["tool_choice"] == "auto" - assert transformed["logprobs"] is True assert transformed["chat_template_kwargs"] == {"foo": "bar", "enable_thinking": False} @@ -82,3 +82,81 @@ def test_openai_chat_response_and_stream_preserve_requested_model() -> None: assert response == {**upstream, "model": "requested-model"} assert stream_chunk["model"] == "requested-model" assert stream_chunk["choices"][0]["delta"]["content"] == "D" + + +def test_openai_chat_request_aliases_max_completion_tokens() -> None: + transformer = OpenAIChatTransformer() + + transformed = transformer.transform_request( + { + "messages": [{"role": "user", "content": "answer as JSON"}], + "max_completion_tokens": 32, + "response_format": {"type": "json_object"}, + } + ) + + assert transformed["max_completion_tokens"] == 32 + assert transformed["max_tokens"] == 32 + assert transformed["response_format"] == {"type": "json_object"} + + +def test_openai_chat_merges_developer_role_for_non_qwen_models() -> None: + transformer = OpenAIChatTransformer() + + transformed = transformer.transform_request( + { + "_polar_model_served": "MiniMax-M2.5", + "messages": [ + {"role": "user", "content": "hi"}, + {"role": "developer", "content": "Use short answers."}, + {"role": "system", "content": [{"type": "text", "text": "Be precise."}]}, + ], + } + ) + + assert transformed["messages"] == [ + {"role": "system", "content": "Use short answers.\n\nBe precise."}, + {"role": "user", "content": "hi"}, + ] + assert "chat_template_kwargs" not in transformed + + +def test_openai_chat_preserves_tool_turns_and_reasoning_content() -> None: + transformer = OpenAIChatTransformer() + body = { + "messages": [ + {"role": "system", "content": "Use tools when needed."}, + {"role": "user", "content": "lookup x"}, + { + "role": "assistant", + "content": None, + "reasoning_content": "I need the lookup tool.", + "tool_calls": [ + { + "id": "call-1", + "type": "function", + "function": {"name": "lookup", "arguments": '{"q":"x"}'}, + } + ], + }, + {"role": "tool", "tool_call_id": "call-1", "content": "result"}, + {"role": "user", "content": "summarize"}, + ], + "tools": [ + { + "type": "function", + "function": { + "name": "lookup", + "parameters": {"type": "object"}, + "strict": True, + }, + } + ], + "tool_choice": {"type": "function", "function": {"name": "lookup"}}, + } + + transformed = transformer.transform_request(body) + + assert transformed["messages"] == body["messages"] + assert transformed["tools"] == body["tools"] + assert transformed["tool_choice"] == body["tool_choice"] diff --git a/tests/gateway/test_transform_openai_responses.py b/tests/gateway/test_transform_openai_responses.py index 8773f9522..4338526ec 100644 --- a/tests/gateway/test_transform_openai_responses.py +++ b/tests/gateway/test_transform_openai_responses.py @@ -1,5 +1,6 @@ from __future__ import annotations +from polar.gateway.transform.reasoning import encrypt_reasoning from polar.gateway.transform.openai_responses import OpenAIResponsesTransformer IMAGE_URL = "data:image/png;base64,abc123" @@ -38,7 +39,21 @@ def test_responses_request_maps_all_fields_and_image_input_to_chat() -> None: "max_output_tokens": 128, "temperature": 0.2, "top_p": 0.9, + "top_logprobs": 3, + "parallel_tool_calls": False, "stream": True, + "text": { + "format": { + "type": "json_schema", + "name": "answer", + "schema": { + "type": "object", + "properties": {"count": {"type": "integer"}}, + "required": ["count"], + }, + "strict": True, + } + }, "tool_choice": "auto", "tools": [ { @@ -83,7 +98,21 @@ def test_responses_request_maps_all_fields_and_image_input_to_chat() -> None: assert transformed["max_tokens"] == 128 assert transformed["temperature"] == 0.2 assert transformed["top_p"] == 0.9 + assert transformed["top_logprobs"] == 3 + assert transformed["parallel_tool_calls"] is False assert transformed["stream"] is True + assert transformed["response_format"] == { + "type": "json_schema", + "json_schema": { + "name": "answer", + "schema": { + "type": "object", + "properties": {"count": {"type": "integer"}}, + "required": ["count"], + }, + "strict": True, + }, + } assert transformed["tools"] == [ { "type": "function", @@ -96,7 +125,6 @@ def test_responses_request_maps_all_fields_and_image_input_to_chat() -> None: } ] assert transformed["tool_choice"] == "auto" - assert transformed["logprobs"] is True assert transformed["chat_template_kwargs"]["enable_thinking"] is False @@ -167,7 +195,366 @@ def test_responses_request_drops_tool_choice_when_tools_are_empty() -> None: assert transformed["messages"] == [{"role": "user", "content": "hello"}] assert "tool_choice" not in transformed assert "tools" not in transformed - assert transformed["logprobs"] is True + + +def test_responses_request_converts_nested_function_schema_and_preserves_strict() -> None: + transformer = OpenAIResponsesTransformer() + + transformed = transformer.transform_request( + { + "input": "use a tool", + "tools": [ + { + "type": "function", + "name": "lookup", + "description": "Lookup data", + "input_schema": { + "jsonSchema": { + "type": "object", + "properties": {"q": {"type": "string"}}, + "required": ["q"], + } + }, + "strict": True, + }, + {"type": "web_search"}, + ], + "tool_choice": "required", + } + ) + + assert transformed["tools"] == [ + { + "type": "function", + "function": { + "name": "lookup", + "description": "Lookup data", + "parameters": { + "type": "object", + "properties": {"q": {"type": "string"}}, + "required": ["q"], + }, + "strict": True, + }, + } + ] + assert transformed["tool_choice"] == "required" + + +def test_responses_request_normalizes_flat_function_tool_choice() -> None: + transformer = OpenAIResponsesTransformer() + + transformed = transformer.transform_request( + { + "input": "use a tool", + "tools": [{"type": "function", "name": "lookup", "parameters": {}}], + "tool_choice": {"type": "function", "name": "lookup"}, + } + ) + + assert transformed["tool_choice"] == { + "type": "function", + "function": {"name": "lookup"}, + } + + +def test_responses_text_format_text_is_omitted_for_sglang() -> None: + transformer = OpenAIResponsesTransformer() + + transformed = transformer.transform_request( + { + "input": "plain text", + "text": {"format": {"type": "text"}}, + } + ) + + assert "response_format" not in transformed + + +def test_responses_reasoning_effort_none_does_not_enable_thinking() -> None: + transformer = OpenAIResponsesTransformer() + + transformed = transformer.transform_request( + { + "_polar_model_served": "MiniMax-M2.5", + "input": "answer directly", + "reasoning": {"effort": "none"}, + } + ) + + assert "chat_template_kwargs" not in transformed + + +def test_responses_request_round_trips_local_shell_items() -> None: + transformer = OpenAIResponsesTransformer() + + transformed = transformer.transform_request( + { + "input": [ + {"type": "message", "role": "user", "content": "run tests"}, + { + "type": "local_shell_call", + "call_id": "call-shell", + "action": {"commands": ["pytest tests/gateway -q"], "timeout_ms": 1000}, + "status": "completed", + }, + { + "type": "local_shell_call_output", + "id": "call-shell", + "output": "ok", + "status": "completed", + }, + ], + } + ) + + assert transformed["messages"] == [ + {"role": "user", "content": "run tests"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call-shell", + "type": "function", + "function": { + "name": "shell", + "arguments": '{"cmd": "pytest tests/gateway -q", "timeout_ms": 1000}', + }, + } + ], + }, + {"role": "tool", "tool_call_id": "call-shell", "content": "ok"}, + ] + + +def test_responses_response_emits_local_shell_call_for_shell_function() -> None: + transformer = OpenAIResponsesTransformer() + + response = transformer.transform_response( + { + "id": "chatcmpl-1", + "choices": [ + { + "message": { + "content": None, + "tool_calls": [ + { + "id": "call-shell", + "function": { + "name": "shell", + "arguments": '{"cmd": "pytest tests/gateway -q"}', + }, + } + ], + } + } + ], + }, + {"model": "requested-model"}, + ) + + shell_call = response["output"][0] + assert shell_call["type"] == "local_shell_call" + assert shell_call["call_id"] == "call-shell" + assert shell_call["action"] == {"commands": ["pytest tests/gateway -q"]} + + +def test_responses_request_recovers_reasoning_from_encrypted_content_only() -> None: + transformer = OpenAIResponsesTransformer() + encrypted = encrypt_reasoning("Recovered private reasoning.") + + transformed = transformer.transform_request( + { + "input": [ + {"type": "message", "role": "user", "content": "continue"}, + { + "type": "reasoning", + "id": "rs_1", + "summary": [], + "content": [], + "encrypted_content": encrypted, + }, + {"type": "message", "role": "assistant", "content": "ok"}, + ] + } + ) + + assert transformed["messages"][1] == { + "role": "assistant", + "content": "ok", + "reasoning_content": "Recovered private reasoning.", + } + + +def test_responses_request_keeps_reasoning_with_each_tool_turn() -> None: + transformer = OpenAIResponsesTransformer() + + transformed = transformer.transform_request( + { + "input": [ + {"type": "message", "role": "user", "content": "do two steps"}, + { + "type": "reasoning", + "summary": [{"type": "summary_text", "text": "Plan first call."}], + }, + { + "type": "function_call", + "call_id": "call-a", + "name": "lookup", + "arguments": '{"step": 1}', + }, + { + "type": "function_call_output", + "call_id": "call-a", + "output": "first result", + }, + { + "type": "reasoning", + "summary": [{"type": "summary_text", "text": "Plan second call."}], + }, + { + "type": "function_call", + "call_id": "call-b", + "name": "lookup", + "arguments": '{"step": 2}', + }, + { + "type": "function_call_output", + "call_id": "call-b", + "output": "second result", + }, + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "done"}], + }, + ] + } + ) + + assert transformed["messages"] == [ + {"role": "user", "content": "do two steps"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call-a", + "type": "function", + "function": {"name": "lookup", "arguments": '{"step": 1}'}, + } + ], + "reasoning_content": "Plan first call.", + }, + {"role": "tool", "tool_call_id": "call-a", "content": "first result"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call-b", + "type": "function", + "function": {"name": "lookup", "arguments": '{"step": 2}'}, + } + ], + "reasoning_content": "Plan second call.", + }, + {"role": "tool", "tool_call_id": "call-b", "content": "second result"}, + {"role": "assistant", "content": "done"}, + ] + + +def test_responses_request_groups_parallel_function_calls_and_outputs() -> None: + transformer = OpenAIResponsesTransformer() + + transformed = transformer.transform_request( + { + "input": [ + {"type": "message", "role": "user", "content": "parallel"}, + { + "type": "function_call", + "call_id": "call-a", + "name": "lookup", + "arguments": '{"q": "a"}', + }, + { + "type": "function_call", + "call_id": "call-b", + "name": "lookup", + "arguments": '{"q": "b"}', + }, + { + "type": "function_call_output", + "call_id": "call-a", + "output": {"body": "A"}, + }, + { + "type": "function_call_output", + "call_id": "call-b", + "output": {"content": [{"type": "output_text", "text": "B"}]}, + }, + ] + } + ) + + assert transformed["messages"] == [ + {"role": "user", "content": "parallel"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call-a", + "type": "function", + "function": {"name": "lookup", "arguments": '{"q": "a"}'}, + }, + { + "id": "call-b", + "type": "function", + "function": {"name": "lookup", "arguments": '{"q": "b"}'}, + }, + ], + }, + {"role": "tool", "tool_call_id": "call-a", "content": "A"}, + {"role": "tool", "tool_call_id": "call-b", "content": "B"}, + ] + + +def test_responses_request_drops_server_side_tools() -> None: + transformer = OpenAIResponsesTransformer() + + transformed = transformer.transform_request( + { + "input": "search the web", + "tools": [ + {"type": "function", "name": "lookup", "parameters": {"type": "object"}}, + {"type": "web_search"}, + {"type": "file_search"}, + {"type": "computer_use_preview"}, + {"type": "mcp", "server_url": "https://example.test"}, + {"type": "code_interpreter"}, + {"type": "image_generation"}, + # Typed tool with a name should still be dropped — name must + # not be enough to override the type filter. + {"type": "computer_use", "name": "computer_use"}, + ], + "tool_choice": "auto", + } + ) + + # Only the custom function tool reaches SGLang. + assert transformed["tools"] == [ + { + "type": "function", + "function": { + "name": "lookup", + "description": "", + "parameters": {"type": "object"}, + }, + } + ] + assert transformed["tool_choice"] == "auto" def test_responses_response_maps_chat_result_back_to_response_shape() -> None: @@ -287,3 +674,133 @@ def test_responses_stream_state_emits_response_events_for_text_and_tools() -> No assert events[-1]["response"]["model"] == "requested-model" assert events[-1]["response"]["output"][0]["content"][0]["text"] == "Hello" assert events[-1]["response"]["output"][1]["arguments"] == '{"q": "x"}' + + +def test_responses_stream_state_emits_text_only_response() -> None: + transformer = OpenAIResponsesTransformer() + state = transformer.create_stream_state({"model": "requested-model"}) + + events = state.process_chunk( + {"choices": [{"delta": {"content": "Hel"}}]}, + is_first=True, + ) + events.extend( + state.process_chunk( + { + "choices": [{"delta": {"content": "lo."}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5}, + } + ) + ) + events.extend(state.finalize()) + + types = [event["type"] for event in events] + assert types[0] == "response.created" + # Exactly one message item is opened and closed; no reasoning, no tool items. + assert types.count("response.output_item.added") == 1 + assert types.count("response.output_item.done") == 1 + added_item_types = [ + event["item"]["type"] + for event in events + if event["type"] == "response.output_item.added" + ] + assert added_item_types == ["message"] + assert "response.reasoning_summary_text.delta" not in types + assert "response.function_call_arguments.delta" not in types + + completed = events[-1] + assert completed["type"] == "response.completed" + assert [item["type"] for item in completed["response"]["output"]] == ["message"] + assert completed["response"]["output"][0]["content"][0]["text"] == "Hello." + assert completed["response"]["usage"]["output_tokens"] == 2 + + +def test_responses_stream_state_emits_reasoning_only_response() -> None: + transformer = OpenAIResponsesTransformer() + state = transformer.create_stream_state({"model": "requested-model"}) + + events = state.process_chunk( + {"choices": [{"delta": {"reasoning_content": "Think "}}]}, + is_first=True, + ) + events.extend( + state.process_chunk( + { + "choices": [ + {"delta": {"reasoning_content": "harder."}, "finish_reason": "stop"} + ], + "usage": {"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5}, + } + ) + ) + events.extend(state.finalize()) + + types = [event["type"] for event in events] + assert types[0] == "response.created" + # Reasoning summary lifecycle is complete: added → delta(s) → done. + assert "response.reasoning_summary_part.added" in types + assert "response.reasoning_summary_text.delta" in types + assert "response.reasoning_summary_text.done" in types + assert "response.reasoning_summary_part.done" in types + # No text or tool events because the model never emitted content/tool_calls. + assert "response.output_text.delta" not in types + assert "response.function_call_arguments.delta" not in types + + added_item_types = [ + event["item"]["type"] + for event in events + if event["type"] == "response.output_item.added" + ] + assert added_item_types == ["reasoning"] + + completed = events[-1] + assert completed["type"] == "response.completed" + output = completed["response"]["output"] + assert [item["type"] for item in output] == ["reasoning"] + assert output[0]["summary"][0]["text"] == "Think harder." + assert output[0]["content"][0]["text"] == "Think harder." + assert output[0]["encrypted_content"] + + +def test_responses_stream_state_orders_reasoning_before_tool_without_text() -> None: + transformer = OpenAIResponsesTransformer() + state = transformer.create_stream_state({"model": "requested-model"}) + + events = state.process_chunk( + { + "choices": [ + { + "delta": { + "reasoning_content": "Need a lookup.", + "tool_calls": [ + { + "index": 0, + "id": "call-1", + "function": {"name": "lookup", "arguments": '{"q":"x"}'}, + } + ], + }, + "finish_reason": "tool_calls", + } + ], + "usage": {"prompt_tokens": 4, "completion_tokens": 2, "total_tokens": 6}, + }, + is_first=True, + ) + events.extend(state.finalize()) + + added = [ + (event["output_index"], event["item"]["type"]) + for event in events + if event["type"] == "response.output_item.added" + ] + done = [ + (event["output_index"], event["item"]["type"]) + for event in events + if event["type"] == "response.output_item.done" + ] + + assert added == [(0, "reasoning"), (1, "function_call")] + assert done[:2] == [(0, "reasoning"), (1, "function_call")] + assert events[-1]["response"]["output"][0]["type"] == "reasoning" + assert events[-1]["response"]["output"][1]["type"] == "function_call" diff --git a/tests/gateway/test_transform_reasoning.py b/tests/gateway/test_transform_reasoning.py new file mode 100644 index 000000000..01959c0fb --- /dev/null +++ b/tests/gateway/test_transform_reasoning.py @@ -0,0 +1,327 @@ +"""Round-trip tests for reasoning_content across all four transformers. + +Each test: +1. Builds a SGLang chat completion with `reasoning_content`. +2. Calls transform_response → expects API-specific reasoning block emitted. +3. Feeds the harness-shaped output back as input via transform_request. +4. Verifies reasoning_content survives the round trip. +""" + +from __future__ import annotations + +from polar.gateway.transform.anthropic import ( + AnthropicStreamState, + AnthropicTransformer, +) +from polar.gateway.transform.google import GoogleTransformer +from polar.gateway.transform.openai_chat import OpenAIChatTransformer +from polar.gateway.transform.openai_responses import ( + OpenAIResponsesTransformer, + ResponsesStreamState, +) + +REASONING_TEXT = "Let me think step by step. 1+1 must equal 2." +ANSWER_TEXT = "The answer is 2." + + +def _sglang_response(*, with_tool_call: bool = False) -> dict: + message: dict = { + "role": "assistant", + "content": ANSWER_TEXT, + "reasoning_content": REASONING_TEXT, + } + if with_tool_call: + message["content"] = None + message["tool_calls"] = [ + { + "id": "call_x", + "type": "function", + "function": {"name": "answer", "arguments": '{"x": 2}'}, + } + ] + return { + "id": "chatcmpl-1", + "model": "MiniMax-M2.5", + "choices": [{"index": 0, "message": message, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12}, + } + + +# ---------- OpenAI Chat: pure passthrough ---------- + + +def test_openai_chat_preserves_reasoning_content_on_response() -> None: + t = OpenAIChatTransformer() + out = t.transform_response(_sglang_response(), {"model": "anything"}) + assert out["choices"][0]["message"]["reasoning_content"] == REASONING_TEXT + + +def test_openai_chat_preserves_reasoning_content_on_request() -> None: + t = OpenAIChatTransformer() + req = { + "messages": [ + {"role": "user", "content": "compute"}, + { + "role": "assistant", + "content": ANSWER_TEXT, + "reasoning_content": REASONING_TEXT, + }, + {"role": "user", "content": "explain"}, + ], + } + out = t.transform_request(req) + assert out["messages"][1]["reasoning_content"] == REASONING_TEXT + + +# ---------- Anthropic: emit thinking blocks, ingest them back ---------- + + +def test_anthropic_response_emits_thinking_block_before_text() -> None: + t = AnthropicTransformer() + out = t.transform_response(_sglang_response(), {"model": "claude-3"}) + blocks = out["content"] + assert blocks[0]["type"] == "thinking" + assert blocks[0]["thinking"] == REASONING_TEXT + assert blocks[0]["signature"] # signature populated + assert any(b.get("type") == "text" for b in blocks) + + +def test_anthropic_response_emits_thinking_with_tool_use() -> None: + t = AnthropicTransformer() + out = t.transform_response( + _sglang_response(with_tool_call=True), {"model": "claude-3"} + ) + types = [b["type"] for b in out["content"]] + assert types[0] == "thinking" + assert "tool_use" in types + + +def test_anthropic_request_recovers_reasoning_from_thinking_block() -> None: + t = AnthropicTransformer() + req = { + "_polar_model_served": "MiniMax-M2.5", + "messages": [ + {"role": "user", "content": "compute 1+1"}, + { + "role": "assistant", + "content": [ + { + "type": "thinking", + "thinking": REASONING_TEXT, + "signature": "sg_polar_xxx", + }, + {"type": "text", "text": ANSWER_TEXT}, + ], + }, + {"role": "user", "content": "ok next"}, + ], + "max_tokens": 100, + } + out = t.transform_request(req) + # Find the assistant message + assistant_msg = next(m for m in out["messages"] if m.get("role") == "assistant") + assert assistant_msg["reasoning_content"] == REASONING_TEXT + + +def test_anthropic_thinking_request_param_enables_thinking() -> None: + t = AnthropicTransformer() + out = t.transform_request( + { + "_polar_model_served": "MiniMax-M2.5", + "thinking": {"type": "enabled", "budget_tokens": 1024}, + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 10, + } + ) + assert out["chat_template_kwargs"]["enable_thinking"] is True + + +def test_anthropic_streaming_emits_thinking_signature() -> None: + state = AnthropicStreamState( + model="claude-3", finish_to_stop_reason=AnthropicTransformer.FINISH_TO_STOP_REASON + ) + chunk = { + "choices": [ + { + "delta": { + "role": "assistant", + "content": ANSWER_TEXT, + "reasoning_content": REASONING_TEXT, + }, + "finish_reason": "stop", + } + ], + "usage": {"completion_tokens": 7}, + } + events = state.process_chunk(chunk, is_first=True) + state.finalize() + types = [e.get("type") for e in events] + deltas = [e.get("delta", {}).get("type") for e in events if e.get("type") == "content_block_delta"] + assert "content_block_start" in types + assert "thinking_delta" in deltas + assert "signature_delta" in deltas + assert "text_delta" in deltas + + +# ---------- Gemini: emit thought parts, ingest them back ---------- + + +def test_gemini_response_emits_thought_part_before_text() -> None: + t = GoogleTransformer() + out = t.transform_response(_sglang_response(), {}) + parts = out["candidates"][0]["content"]["parts"] + assert parts[0].get("thought") is True + assert parts[0]["text"] == REASONING_TEXT + assert parts[0]["thoughtSignature"] + # User-facing text must NOT include reasoning text. + visible_text = next( + p for p in parts[1:] if isinstance(p, dict) and "text" in p and not p.get("thought") + ) + assert visible_text["text"] == ANSWER_TEXT + + +def test_gemini_request_recovers_reasoning_from_thought_part() -> None: + t = GoogleTransformer() + out = t.transform_request( + { + "_polar_model_served": "MiniMax-M2.5", + "contents": [ + {"role": "user", "parts": [{"text": "compute"}]}, + { + "role": "model", + "parts": [ + { + "thought": True, + "text": REASONING_TEXT, + "thoughtSignature": "sg_polar_xxx", + }, + {"text": ANSWER_TEXT}, + ], + }, + {"role": "user", "parts": [{"text": "ok"}]}, + ], + } + ) + assistant_msg = next(m for m in out["messages"] if m.get("role") == "assistant") + assert assistant_msg["reasoning_content"] == REASONING_TEXT + # Visible content must not leak the reasoning text. + assert REASONING_TEXT not in (assistant_msg.get("content") or "") + + +def test_gemini_thinking_config_enables_thinking() -> None: + t = GoogleTransformer() + out = t.transform_request( + { + "_polar_model_served": "MiniMax-M2.5", + "contents": [{"role": "user", "parts": [{"text": "hi"}]}], + "generationConfig": {"thinkingConfig": {"includeThoughts": True}}, + } + ) + assert out["chat_template_kwargs"]["enable_thinking"] is True + + +# ---------- Responses: emit reasoning output items, ingest them back ---------- + + +def test_responses_response_emits_reasoning_item_first() -> None: + t = OpenAIResponsesTransformer() + out = t.transform_response(_sglang_response(), {"model": "gpt-5.4"}) + items = out["output"] + assert items[0]["type"] == "reasoning" + assert items[0]["summary"][0]["text"] == REASONING_TEXT + assert items[0]["content"][0]["text"] == REASONING_TEXT + assert items[0]["encrypted_content"] + assert items[1]["type"] == "message" + + +def test_responses_request_recovers_reasoning_attached_to_assistant() -> None: + t = OpenAIResponsesTransformer() + out = t.transform_request( + { + "_polar_model_served": "MiniMax-M2.5", + "input": [ + {"type": "message", "role": "user", "content": "compute"}, + { + "type": "reasoning", + "id": "rs_1", + "summary": [{"type": "summary_text", "text": REASONING_TEXT}], + "content": [{"type": "reasoning_text", "text": REASONING_TEXT}], + "encrypted_content": "polar:LWE=", + }, + {"type": "message", "role": "assistant", "content": ANSWER_TEXT}, + {"type": "message", "role": "user", "content": "explain"}, + ], + } + ) + assistant_msg = next(m for m in out["messages"] if m.get("role") == "assistant") + assert assistant_msg["reasoning_content"] == REASONING_TEXT + + +def test_responses_request_attaches_reasoning_to_function_call() -> None: + t = OpenAIResponsesTransformer() + out = t.transform_request( + { + "_polar_model_served": "MiniMax-M2.5", + "input": [ + {"type": "message", "role": "user", "content": "use tool"}, + { + "type": "reasoning", + "id": "rs_1", + "summary": [{"type": "summary_text", "text": REASONING_TEXT}], + }, + { + "type": "function_call", + "call_id": "call_1", + "name": "answer", + "arguments": '{"x": 2}', + }, + ], + } + ) + assistant_msg = next(m for m in out["messages"] if m.get("role") == "assistant") + assert assistant_msg.get("reasoning_content") == REASONING_TEXT + assert assistant_msg["tool_calls"][0]["function"]["name"] == "answer" + + +def test_responses_reasoning_request_param_enables_thinking() -> None: + t = OpenAIResponsesTransformer() + out = t.transform_request( + { + "_polar_model_served": "MiniMax-M2.5", + "input": "compute", + "reasoning": {"effort": "medium"}, + } + ) + assert out["chat_template_kwargs"]["enable_thinking"] is True + + +def test_responses_streaming_emits_reasoning_item_events() -> None: + state = ResponsesStreamState(model="gpt-5.4") + chunk = { + "choices": [ + { + "delta": { + "role": "assistant", + "content": ANSWER_TEXT, + "reasoning_content": REASONING_TEXT, + }, + "finish_reason": "stop", + } + ] + } + events = state.process_chunk(chunk, is_first=True) + state.finalize() + types = [e.get("type") for e in events] + assert "response.reasoning_summary_text.delta" in types + assert "response.reasoning_summary_text.done" in types + # output_item.added/done for reasoning + added_items = [ + e["item"]["type"] + for e in events + if e.get("type") == "response.output_item.added" + ] + assert "reasoning" in added_items + assert "message" in added_items + # Final response.completed includes both items. + completed = next(e for e in events if e.get("type") == "response.completed") + output_types = [it["type"] for it in completed["response"]["output"]] + assert output_types[0] == "reasoning" + assert "message" in output_types diff --git a/tests/platform/test_api.py b/tests/platform/test_api.py index 1b37538f0..70f0c13b8 100644 --- a/tests/platform/test_api.py +++ b/tests/platform/test_api.py @@ -66,7 +66,8 @@ def topology_with_results(tmp_path: Path) -> Path: max_run_workers: 4 max_postrun_workers: 4 model_served: Qwen/Qwen3.5-4B - sglang: + inference: + engine: sglang base_url: http://127.0.0.1:9000 """.strip() ) diff --git a/tests/slime_bridge/test_adapter.py b/tests/slime_bridge/test_adapter.py index 0bf60c860..b867b1ccf 100644 --- a/tests/slime_bridge/test_adapter.py +++ b/tests/slime_bridge/test_adapter.py @@ -46,10 +46,7 @@ def test_session_result_to_samples_converts_trace_to_slime_like_sample(monkeypat loss_mask=[1, 0], prompt_messages=[{"role": "user", "content": "Say hi"}], response_messages=[{"role": "assistant", "content": "Hi"}], - response_logprobs=[ - {"token_id": 3, "logprob": -0.1}, - {"token_id": 4, "logprob": -0.2}, - ], + response_logprobs=[-0.1, -0.2], reward=1.0, metadata={"group_id": "group-1"}, ) diff --git a/tests/slime_bridge/test_config.py b/tests/slime_bridge/test_config.py index 534d7e321..3a8b6c8bf 100644 --- a/tests/slime_bridge/test_config.py +++ b/tests/slime_bridge/test_config.py @@ -7,6 +7,7 @@ from slime_bridge.config import ( render_instruction, render_task_payload, + render_topology_template, resolve_polar_slime_config, resolve_sglang_router_base_url, ) @@ -119,3 +120,24 @@ def test_render_instruction_uses_optional_template() -> None: def test_resolve_sglang_router_base_url_requires_both_ip_and_port() -> None: assert resolve_sglang_router_base_url(_args()) == "http://127.0.0.1:30000" assert resolve_sglang_router_base_url(_args(sglang_router_port=None)) is None + + +def test_render_topology_template_emits_inference_block(tmp_path) -> None: + topology_path = tmp_path / "topology.yaml" + topology_path.write_text( + """ +rollout: {host: 127.0.0.1, port: 8080, public_url: http://127.0.0.1:8080} +gateway: + nodes: + - id: n1 + host: 127.0.0.1 + port: 8100 + public_url: http://127.0.0.1:8100 + model_served: Qwen/Qwen3.5-4B + inference: {engine: sglang, base_url: http://127.0.0.1:8000} +""".strip() + ) + rendered = render_topology_template(str(topology_path), _args()) + node = rendered["gateway"]["nodes"][0] + assert node["inference"] == {"engine": "sglang", "base_url": "http://127.0.0.1:30000"} + assert "sglang" not in node diff --git a/tests/trajectory/test_engine_trajectory_equivalence.py b/tests/trajectory/test_engine_trajectory_equivalence.py new file mode 100644 index 000000000..3bf013e35 --- /dev/null +++ b/tests/trajectory/test_engine_trajectory_equivalence.py @@ -0,0 +1,229 @@ +"""Both backends must produce the same trajectory from the same generation. + +SGLang (patched) and vLLM (`return_token_ids` + `VLLMEngine.normalize_response`) +expose the training fields in different response shapes. These tests pin that +the two shapes collapse to byte-identical ``Trace`` objects through the real +builders, so downstream training sees one trajectory regardless of engine. +""" + +from __future__ import annotations + +import asyncio + +from polar.gateway.engine import VLLMEngine +from polar.trajectory.builder.per_request import PerRequestBuilder +from polar.trajectory.builder.prefix_merging import PrefixMergingBuilder +from polar.trajectory.builder.record_utils import build_trace_from_completion +from polar.trajectory.models import CompletionRecord, CompletionSession + +_EOT = 99 # synthetic end-of-turn token id + + +def _sglang_record( + completion_id: str, + prompt_ids: list[int], + response_ids: list[int], + logprobs: list[float], + *, + content: str, + reasoning: str | None, + finish_reason: str, + prompt_messages: list[dict], + response_message: dict, +) -> CompletionRecord: + """Canonical SGLang shape: input_token_ids on the choice, token_id in logprobs.""" + message = {"role": "assistant", "content": content, **response_message} + if reasoning is not None: + message["reasoning_content"] = reasoning + response = { + "choices": [ + { + "input_token_ids": list(prompt_ids), + "message": message, + "finish_reason": finish_reason, + "logprobs": { + "content": [ + {"token": f"t{tid}", "token_id": tid, "logprob": lp, "bytes": []} + for tid, lp in zip(response_ids, logprobs) + ] + }, + } + ] + } + return CompletionRecord( + completion_id=completion_id, + request={"messages": prompt_messages}, + response=response, + ) + + +def _vllm_record( + completion_id: str, + prompt_ids: list[int], + response_ids: list[int], + logprobs: list[float], + *, + content: str, + reasoning: str | None, + finish_reason: str, + prompt_messages: list[dict], + response_message: dict, +) -> CompletionRecord: + """Native vLLM shape, passed through the gateway's normalize_response.""" + message = {"role": "assistant", "content": content, **response_message} + if reasoning is not None: + message["reasoning"] = reasoning # vLLM names it `reasoning` + response = { + "prompt_token_ids": list(prompt_ids), # top-level in vLLM + "choices": [ + { + "token_ids": list(response_ids), # on the choice in vLLM + "message": message, + "finish_reason": finish_reason, + "logprobs": { + "content": [ + {"token": f"t{tid}", "logprob": lp, "bytes": []} + for tid, lp in zip(response_ids, logprobs) + ] + }, + } + ], + } + response = VLLMEngine().normalize_response(response) + return CompletionRecord( + completion_id=completion_id, + request={"messages": prompt_messages}, + response=response, + ) + + +def _assert_traces_equal(a, b) -> None: + assert a.prompt_ids == b.prompt_ids + assert a.response_ids == b.response_ids + assert a.loss_mask == b.loss_mask + assert a.finish_reason == b.finish_reason + assert a.response_messages == b.response_messages + assert a.response_logprobs == b.response_logprobs + + +def test_single_turn_trace_is_identical_across_engines() -> None: + common = dict( + prompt_ids=[1, 2, 3], + response_ids=[10, 11, 12, 13], + logprobs=[-0.1, -0.2, -0.3, -0.4], + content="4", + reasoning="thinking", + finish_reason="stop", + prompt_messages=[{"role": "user", "content": "2+2?"}], + response_message={}, + ) + sg = build_trace_from_completion(_sglang_record("c1", **common)) + vllm = build_trace_from_completion(_vllm_record("c1", **common)) + + _assert_traces_equal(sg, vllm) + # And the actual values are the correct ones. + assert vllm.prompt_ids == [1, 2, 3] + assert vllm.response_ids == [10, 11, 12, 13] + assert vllm.loss_mask == [1, 1, 1, 1] + assert vllm.response_messages[0]["reasoning_content"] == "thinking" + assert vllm.response_logprobs == [-0.1, -0.2, -0.3, -0.4] + + +def test_per_request_builder_is_identical_across_engines() -> None: + common = dict( + prompt_ids=[1, 2, 3], + response_ids=[10, 11, 12, 13], + logprobs=[-0.1, -0.2, -0.3, -0.4], + content="4", + reasoning=None, + finish_reason="stop", + prompt_messages=[{"role": "user", "content": "2+2?"}], + response_message={}, + ) + sg_traj = asyncio.run( + PerRequestBuilder().build( + CompletionSession(session_id="s", completions=[_sglang_record("c1", **common)]) + ) + ) + vllm_traj = asyncio.run( + PerRequestBuilder().build( + CompletionSession(session_id="s", completions=[_vllm_record("c1", **common)]) + ) + ) + _assert_traces_equal(sg_traj.traces[0], vllm_traj.traces[0]) + + +def test_adapter_rollout_log_probs_are_identical_across_engines() -> None: + from slime_bridge.adapter import _extract_rollout_log_probs + + common = dict( + prompt_ids=[1, 2, 3], + response_ids=[10, 11, 12, 13], + logprobs=[-0.1, -0.2, -0.3, -0.4], + content="4", + reasoning="thinking", + finish_reason="stop", + prompt_messages=[{"role": "user", "content": "2+2?"}], + response_message={}, + ) + sg = build_trace_from_completion(_sglang_record("c1", **common)) + vllm = build_trace_from_completion(_vllm_record("c1", **common)) + + kwargs = dict( + response_len=4, + loss_mask=[1, 1, 1, 1], + require_trainable_logprobs=True, + session_id="s", + trace_index=0, + ) + sg_lp = _extract_rollout_log_probs(sg, **kwargs) + vllm_lp = _extract_rollout_log_probs(vllm, **kwargs) + assert sg_lp == vllm_lp == [-0.1, -0.2, -0.3, -0.4] + + +def _two_turn_chain(record_fn) -> list[CompletionRecord]: + """A valid 2-completion agent chain (C2.prompt == C1.prompt + C1.resp + tool).""" + q1 = {"role": "user", "content": "Q1"} + a1 = {"role": "assistant", "content": "A1"} + tool = {"role": "tool", "content": "result"} # interstitial, dropped by grouping + c1 = record_fn( + "c1", + prompt_ids=[1, 2, 3], + response_ids=[10, 11, _EOT], + logprobs=[-0.1, -0.2, -0.3], + content="A1", + reasoning=None, + finish_reason="stop", + prompt_messages=[q1], + response_message={}, + ) + # canonical_tail = [10, 11, _EOT, 50, 51] -> interstitial after _EOT = [50, 51] + c2 = record_fn( + "c2", + prompt_ids=[1, 2, 3, 10, 11, _EOT, 50, 51], + response_ids=[20, 21, _EOT], + logprobs=[-0.5, -0.6, -0.7], + content="A2", + reasoning=None, + finish_reason="stop", + prompt_messages=[q1, a1, tool], + response_message={}, + ) + return [c1, c2] + + +def test_prefix_merging_chain_is_identical_across_engines() -> None: + builder = PrefixMergingBuilder(end_of_turn_token_id=_EOT) + sg = asyncio.run( + builder.build(CompletionSession(session_id="s", completions=_two_turn_chain(_sglang_record))) + ) + vllm = asyncio.run( + builder.build(CompletionSession(session_id="s", completions=_two_turn_chain(_vllm_record))) + ) + + assert len(sg.traces) == len(vllm.traces) == 1 + _assert_traces_equal(sg.traces[0], vllm.traces[0]) + # The merged stream is the prompt + raw responses + canonical interstitial, + # with interstitial tokens masked out. + assert vllm.traces[0].response_ids == [10, 11, _EOT, 50, 51, 20, 21, _EOT] + assert vllm.traces[0].loss_mask == [1, 1, 1, 0, 0, 1, 1, 1] diff --git a/tests/trajectory/test_per_request_builder.py b/tests/trajectory/test_per_request_builder.py index f667dbe37..d59bbdf5c 100644 --- a/tests/trajectory/test_per_request_builder.py +++ b/tests/trajectory/test_per_request_builder.py @@ -71,7 +71,4 @@ def test_per_request_builder_emits_one_trace_per_completion() -> None: assert trace.prompt_messages == [{"role": "user", "content": "Say hi"}] assert trace.response_messages == [{"role": "assistant", "content": "Hi"}] assert trace.tools == [{"type": "function", "function": {"name": "lookup"}}] - assert trace.response_logprobs == [ - {"token_id": 3, "logprob": -0.1}, - {"token_id": 4, "logprob": -0.2}, - ] + assert trace.response_logprobs == [-0.1, -0.2] diff --git a/web/src/api/types.ts b/web/src/api/types.ts index bf9d02cb3..2b58ba19d 100644 --- a/web/src/api/types.ts +++ b/web/src/api/types.ts @@ -7,6 +7,8 @@ export interface TaskSummary { completed_sessions: number; errored_sessions?: number; mean_reward?: number | null; + mean_traces?: number | null; + mean_completions?: number | null; created_at?: number | null; updated_at?: number | null; save_dir_path?: string; @@ -93,7 +95,8 @@ export interface TopologyPayload { port: number; gateway_url: string; model_served: string; - sglang_base_url: string; + engine: string; + inference_base_url: string; max_init_workers: number; max_run_workers: number; max_postrun_workers: number; diff --git a/web/src/components/CompletionDiff.tsx b/web/src/components/CompletionDiff.tsx index 79ccc052a..b9780e1f6 100644 --- a/web/src/components/CompletionDiff.tsx +++ b/web/src/components/CompletionDiff.tsx @@ -46,8 +46,8 @@ export function CompletionDiff({ completion }: Props) { {open && (
- - + +
)} diff --git a/web/src/components/TopologyGraph.tsx b/web/src/components/TopologyGraph.tsx index fcb92166a..634919bad 100644 --- a/web/src/components/TopologyGraph.tsx +++ b/web/src/components/TopologyGraph.tsx @@ -81,8 +81,8 @@ export function TopologyGraph({ topology, isLoading }: Props) {
- SGLang - {gw.sglang_base_url} + {(gw.engine || "engine").toUpperCase()} + {gw.inference_base_url}
model diff --git a/web/src/routes/Dashboard.tsx b/web/src/routes/Dashboard.tsx index 0ad6093c6..89c0d1d9d 100644 --- a/web/src/routes/Dashboard.tsx +++ b/web/src/routes/Dashboard.tsx @@ -45,7 +45,8 @@ export function Dashboard() { task_id status harness - reward + reward (avg) + trace / completion progress updated @@ -65,6 +66,11 @@ export function Dashboard() { {formatReward(task.mean_reward)} + + {task.mean_traces != null ? task.mean_traces.toFixed(1) : "—"} + {" / "} + {task.mean_completions != null ? task.mean_completions.toFixed(1) : "—"} + {task.completed_sessions}/{task.num_samples}