diff --git a/.gitignore b/.gitignore index 98d674f..072b00a 100644 --- a/.gitignore +++ b/.gitignore @@ -77,6 +77,11 @@ finetune/eval/results_full/*/ft_chemqa.jsonl finetune/eval/results_full/*/*.log finetune/eval/results_full/*/*.jsonl +# Phase 4 grounding-audit outputs (sample, judged claims, summaries, +# logs). Reproducible from the orchestrator + a fresh OpenRouter run; +# never pushed. +phase4_grounding/outputs/ + # OS .DS_Store Thumbs.db diff --git a/DATASHEET.md b/DATASHEET.md index 08ebc79..5f0724f 100644 --- a/DATASHEET.md +++ b/DATASHEET.md @@ -73,7 +73,12 @@ No pre-defined split. Recommended: - Held-out human-annotated: scheduled for Round 2. **Are there any errors, sources of noise, redundancies?** -Yes, many — see `LIMITATIONS.md` in full. +Yes, many — see `LIMITATIONS.md` in full. The Phase 4 grounding audit +(`phase4_grounding/RESULTS.md`) measures **55.20% of claims as +UNSUPPORTED** by the cited evidence (95% CI 53.4–57.0%, n=3,076 claims +across 300 Q&As, keep-structural view) — i.e. the gold subset contains +substantial training-recall content. Engineering, ADME, and metabolism +topics carry the worst grounding. **Does the dataset rely on external resources?** Pipeline inputs are PubMed baseline XML, PMC open-access full text, and diff --git a/LIMITATIONS.md b/LIMITATIONS.md index c033122..23188ef 100644 --- a/LIMITATIONS.md +++ b/LIMITATIONS.md @@ -35,15 +35,42 @@ including the three models in this pipeline. Consequences: See `CONTAMINATION.md` for the proposed canary-based validation methodology. -## 3. The "soft rule" permits training-recall +## 3. The "soft rule" permits training-recall — measured Phase 1 and Phase 2 system prompts explicitly allow functional claims "supported by the evidence ... used silently as background knowledge". This phrasing admits recall from pretraining. There is no mechanism at generation time to distinguish a claim supported by a specific evidence -sentence from one the model would have made anyway. A Phase 4 grounding -check (LLM-based claim-to-evidence alignment scoring) is scheduled for -Round 2 but is not yet implemented. +sentence from one the model would have made anyway. + +A Phase 4 grounding audit decomposes Phase-2 answers claim-by-claim and +labels each claim as STATED, IMPLIED, STRUCTURAL (derivable from SMILES +alone), or UNSUPPORTED (training-recall candidate). Headline result on +a 300-Q&A stratified sample (3,076 claims): + +| View | UNSUPPORTED | 95% Wilson CI | +|---|---|---| +| keep-structural (clean training-recall proxy) | **55.20%** | 53.43–56.97% | +| drop-structural (PLAN-spec view) | **44.70%** | 43.11–46.30% | + +Both views are well above the 20% threshold that would have permitted a +narrow grounding claim. **The paper's grounding language is therefore +narrowed, and training-recall risk is flagged in `RESPONSIBLE_AI.md`.** + +Engineering / design / metabolism topics carry the worst grounding (75 / +67 / 69% UNSUPPORTED). Q&As *with* evidence attached are *more* +UNSUPPORTED than those without, consistent with the model elaborating +beyond what the evidence sentence states. + +Cross-check validation: an independent judge (`google/gemini-2.5-pro`) +re-judged 30 of the 300 Q&As; macro UNSUPPORTED rates differ by only ++3.68pp from the primary judge, with 26/30 per-row rates agreeing to +within 20pp. The headline is robust to judge choice. + +Full results, per-topic / per-split breakdowns, methodology, and a note +on dual-use refusals during judging are in +`phase4_grounding/RESULTS.md`. Code and orchestrator are in +`phase4_grounding/`. ## 4. Compound coverage is biased toward well-studied drugs diff --git a/RESPONSIBLE_AI.md b/RESPONSIBLE_AI.md index aa95dc2..0fcb071 100644 --- a/RESPONSIBLE_AI.md +++ b/RESPONSIBLE_AI.md @@ -99,16 +99,29 @@ model as clinically validated.** (`scripts/audit_redaction.py`). - Coverage-analysis script quantifies therapeutic-area and molecular-property skew (`scripts/analyze_coverage.py`). +- **Phase 4 grounding audit** measures the rate at which Phase-2 + answers contain claims not traceable to the cited evidence + (`phase4_grounding/RESULTS.md`). Headline: **55.20% UNSUPPORTED** in + the keep-structural view (95% CI 53.4–57.0%) on a 300-Q&A sample; + cross-validated by an independent judge to within +3.7pp. This + empirically substantiates Misuse Risk #2 (fabricated mechanisms) and + is the basis for narrowing the paper's grounding claim. Engineering / + design / metabolism Q&As carry the highest training-recall risk. +- **Dual-use refusal protocol** — the Phase 4 audit revealed that + `claude-sonnet-4.6` refuses to judge ~3% of dual-use chemistry Q&As + (toxin engineering, pesticide modifications, controlled-substance + analog reasoning). Falling back to `gemini-2.5-pro` recovers all of + them. Audits run on a single model will systematically miss this + topic; reproducers should use a heterogeneous-judge protocol. ### Deferred to Round 2 - Human-evaluated accuracy on a safety-critical-claim sub-sample (scheduled). -- Phase 4 grounding check that verifies each functional claim is - traceable to an evidence sentence (proposed; requires additional LLM - compute). - RAI review of the engineering-question category for synthesis-uplift - risk (proposed). + risk (proposed). The Phase 4 audit measured engineering Q&As at 74.6% + UNSUPPORTED, the worst of any topic — a strong prior for prioritizing + this review. - Dataset card fields per the Croissant RAI schema (skeleton provided in `croissant.json`; full population after full-run execution). diff --git a/croissant.json b/croissant.json index ed3186f..55d0c37 100644 --- a/croissant.json +++ b/croissant.json @@ -51,13 +51,14 @@ "rai:dataImputationProtocol": "Missing molecular_formula / molecular_weight from CID-Mass.gz are left null; compounds with zero matching evidence sentences after redaction are dropped (30.2% of premium-tier compounds).", "rai:dataPreprocessingProtocol": "Compound-name redaction to [COMPOUND] using longest-match-first synonym regex. Sentence-level dedup by redacted text. Random sampling (per-CID-seeded) to cap at 500 sentences per compound.", "rai:dataManipulationProtocol": "Four-phase LLM pipeline: Phase 1 generation, Phase 2 blind re-answer, Phase 3 agreement judge. All prompts versioned in the source tree at chem2textqa/qa_pipeline/phase_*/ .", - "rai:dataSocialImpact": "Drug-related training data with public-health implications. A model fine-tuned on this data could produce plausible-but-incorrect clinical claims. See RESPONSIBLE_AI.md for intended use, misuse risks, and mitigations.", + "rai:dataSocialImpact": "Drug-related training data with public-health implications. A Phase 4 grounding audit on a 300-Q&A stratified sample measured 55.20% of claims as UNSUPPORTED by the cited evidence (95% Wilson CI 53.4–57.0%, n=3076 claims; cross-validated by an independent judge to within +3.7pp). A model fine-tuned on this data could produce plausible-but-incorrect clinical claims, and the published gold subset contains substantial training-recall content. See RESPONSIBLE_AI.md and phase4_grounding/RESULTS.md for intended use, misuse risks, mitigations, and the full audit.", "rai:dataBiases": [ "Therapeutic-area bias: oncology / CV / CNS over-represented", "Approval-status bias: FDA-approved drugs vs research chemicals", "Publication bias: English-language biomedical literature", "Model-consensus bias: 'gold' labels reflect what two LLMs agree on, not ground truth", - "Training-data overlap: evidence sentences are likely in LLM pretraining corpora" + "Training-data overlap: evidence sentences are likely in LLM pretraining corpora", + "Training-recall content (measured): 55.20% of Phase-2 answer claims are not traceable to the cited evidence (Phase 4 audit, n=3076 claims). Engineering / ADME / metabolism Q&As are worst-grounded (>67% UNSUPPORTED); mechanism / therapeutic-use / toxicity are better-grounded (<41%). See phase4_grounding/RESULTS.md." ], "rai:dataUseCases": [ "Intended: instruction tuning for medicinal-chemistry LLM research.", diff --git a/phase4_grounding/PLAN.md b/phase4_grounding/PLAN.md new file mode 100644 index 0000000..8442521 --- /dev/null +++ b/phase4_grounding/PLAN.md @@ -0,0 +1,242 @@ +# Phase 4 — Grounding Audit Plan + +## Goal +For each functional Q&A, decompose the `phase2_answer` into **atomic claims**; label each claim as +`STATED` / `IMPLIED` / `UNSUPPORTED` (and optionally `STRUCTURAL`) against the parent compound's +evidence sentences. Aggregate to measure how often the soft rule allows training-recall claims to +slip into the dataset. + +This audit addresses C1 (claim grounding) and C3 (training-recall risk) directly. + +## Dataset facts (verified) +- `data/dataset_gold.jsonl`: **15,509 compounds, 188,541 Q&A pairs** (train 10,820 / val 2,340 / test 2,349). +- Each compound has `evidence_sentences` with fields `{id, pmid, source, text}`. Per-QA `evidence_ids` is mostly empty (the soft rule). +- Functional topic counts (top): engineering 21,904 · mechanism 17,669 · therapeutic_use 14,691 · metabolism 12,236 · toxicity 11,114 · adme 8,850 · design_levers 4,860 · drug_interactions 3,944 · prodrug_activation · resistance_mechanism · etc. +- `scripts/topic_bucket.py::bucket_topic` already classifies a topic as `structural` / `functional` / `other` — reuse it. +- The audit target is `phase2_answer` (the long, soft-rule answer). Phase 1 answer and judge reasoning are **not** shown to the judge, to keep the test honest. + +## Output layout +All code under `phase4_grounding/`. Outputs co-located so the deliverable is one folder. Code is organized as a small package with a clear separation between **library modules** (testable) and **entry-point scripts** (each script has a `main()` that parses args and delegates to the library). +``` +phase4_grounding/ +├── PLAN.md # this file +├── __init__.py +├── grounding/ # library package — pure modules, no I/O at import +│ ├── __init__.py +│ ├── sampling.py # Sampler class — stratified sampler +│ ├── evidence.py # EvidenceAttacher — selects/numbers evidence per QA +│ ├── prompt.py # PromptBuilder — renders claim_decomp prompt +│ ├── openrouter_client.py # OpenRouterClient — async, retries, cost tracker +│ ├── judge.py # ClaimJudge — orchestrates client + prompt + parsing +│ ├── parser.py # ClaimParser — strict JSON parse + schema validate +│ ├── aggregator.py # Aggregator — both keep/drop views, Wilson CI +│ ├── reporter.py # Reporter — markdown writers +│ └── models.py # @dataclass: Claim, JudgedQA, SampleRow, ViewConfig +├── prompts/ +│ └── claim_decomp.txt +├── scripts/ # entry-point scripts, each with main() +│ ├── sample_qa.py +│ ├── judge_claims.py +│ └── aggregate.py +├── tests/ # pytest, mirrors grounding/ layout +│ ├── conftest.py # fixtures: tiny dataset, fake OpenRouter client +│ ├── test_sampling.py +│ ├── test_evidence.py +│ ├── test_prompt.py +│ ├── test_parser.py +│ ├── test_judge.py # uses fake client (no network) +│ ├── test_aggregator.py +│ ├── test_reporter.py +│ └── data/ +│ └── tiny_dataset.jsonl # 5–10 hand-crafted compounds for deterministic tests +├── run_phase4_grounding.sh # orchestrator +└── outputs/ + ├── sample.jsonl # sampled QA + attached evidence bundle + ├── claims_per_qa.jsonl # primary judge output + ├── claims_per_qa.gemini.jsonl # cross-check subset (optional) + ├── claims_per_qa.errors.jsonl # rows where JSON parse failed twice + ├── grounding_summary_keep_structural.md # Option A view + └── grounding_summary_drop_structural.md # Option B view +``` + +### Module boundaries +- **`models.py`** — typed dataclasses passed between modules; the library's contract surface. +- **`sampling.py`** — `Sampler` class: `__init__(dataset_path, seed)`, `sample(n, weights) -> list[SampleRow]`. Pure; no network. +- **`evidence.py`** — `EvidenceAttacher.attach(qa, compound) -> list[EvidenceItem]`. Pure. +- **`prompt.py`** — `PromptBuilder.build(sample_row) -> str`. Loads `claim_decomp.txt` once. +- **`openrouter_client.py`** — `OpenRouterClient(api_key, concurrency, max_usd)`; `async chat(model, prompt, **kw) -> ChatResult`. Has a `FakeOpenRouterClient` subclass in tests/conftest for deterministic outputs. +- **`parser.py`** — `ClaimParser.parse(raw: str, attached_ids: set[int]) -> ParseResult`. Strict JSON + schema check + cross-ref of `evidence_id` against attached ids. +- **`judge.py`** — `ClaimJudge(client, prompt_builder, parser)`; `async judge(sample_row, model) -> JudgedQA`. One retry on parse failure. +- **`aggregator.py`** — `Aggregator(judged_qas)`; `.compute(view: Literal["keep","drop"]) -> ViewMetrics`. Wilson CI in a small helper. +- **`reporter.py`** — `Reporter(metrics_keep, metrics_drop, judged_qas)`; `.write(out_dir)` writes both summary files. + +### Entry-point scripts +Each script in `scripts/` follows the same shape: +```python +def main(argv: list[str] | None = None) -> int: + args = _build_parser().parse_args(argv) + # construct library objects, call them, write outputs + return 0 + +if __name__ == "__main__": + raise SystemExit(main()) +``` +This makes scripts importable in tests (`from phase4_grounding.scripts.sample_qa import main`) and exits with proper return codes for CI. + +## CLI parameters (all user-overridable) +Each script accepts: + +| Flag | Default | Notes | +|---|---|---| +| `--n` | required for `sample_qa.py` | total sample size; user supplies after manager approval | +| `--api-key-file` | `~/.openrouter_key` | path to text file containing the OpenRouter API key | +| `--primary-model` | `anthropic/claude-sonnet-4.6` | OpenRouter model id for primary judge | +| `--cross-check-model` | `google/gemini-2.5-pro` | only used if cross-check enabled | +| `--skip-cross-check` | `false` | when set, the second-model run is skipped entirely | +| `--cross-check-n` | `30` | size of the overlap subset for the second model | +| `--seed` | `0` | RNG seed for sampling | +| `--max-usd` | `15` | abort runner if cumulative OpenRouter spend exceeds this | +| `--concurrency` | `5` | async semaphore for OpenRouter calls | +| `--data-path` | `data/dataset_gold.jsonl` | input dataset | +| `--out-dir` | `phase4_grounding/outputs` | output directory | + +The API key is read **only** from the file at `--api-key-file`; never logged, never echoed. + +## Step 1 — Stratified sample (`sample_qa.py --n N`) +- Filter QA where `bucket_topic(topic) == 'functional'`. +- Strata = topic. Allocate the `N` slots **proportionally within the functional pool** but with a fixed weighting that emphasizes the four headline topics: + - `mechanism : metabolism : toxicity : engineering : therapeutic_use : (others)` = `0.20 : 0.17 : 0.17 : 0.20 : 0.13 : 0.13` + - "others" pool = `adme + design_levers + drug_interactions + prodrug_activation + resistance_mechanism + remaining functional topics`, allocated proportionally inside the 0.13 slice. + - Allocations are computed from `N` so the user just passes `--n` (e.g. `--n 300` or `--n 600`). +- Within each topic, split the slot 50/50 between `evidence_ids` non-empty vs empty (this is the headline metric breakdown). Fall back to whichever pool has rows when the other is exhausted; record the actual split in `sample.jsonl`. +- Sample across all splits (train/val/test) — this is a dataset audit, not a model eval. +- Persist `sample.jsonl` with the attached evidence bundle so judging is fully reproducible: + ```json + {"cid":..., "qa_index":..., "topic":..., "split":..., "evidence_ids_nonempty":bool, + "compound":{"name":..., "smiles":..., "molecular_formula":...}, + "question":"...", "phase2_answer":"...", + "evidence_attached":[{"id":1,"text":"..."}, ...]} + ``` +- Evidence attachment rule: + - `evidence_ids` non-empty → keep only those parent `evidence_sentences` (matched by `id`). + - else → attach all `evidence_sentences` for that compound. + +## Step 2 — Judge prompt (`prompts/claim_decomp.txt`) +The judge sees: compound name, SMILES, molecular formula, question, phase2 answer, numbered evidence (`[E1] ...`, `[E2] ...`). Required output is strict JSON: +```json +{"claims":[ + {"claim":"", + "label":"STATED|IMPLIED|UNSUPPORTED|STRUCTURAL", + "evidence_id": 2, + "rationale":"<<=25 words>"} +]} +``` +Label semantics enforced in the prompt: +- **STATED** — directly written in some evidence sentence (must cite the id). +- **IMPLIED** — inferable from evidence by one routine domain-reasoning step (cite id, give the step). +- **STRUCTURAL** — derivable from SMILES / formula / molecular weight alone (no evidence needed). See clarification below. +- **UNSUPPORTED** — none of the above; flagged as a **training-recall candidate**. + +Atomicity rules in the prompt: one assertion per claim; split conjunctions; numerical values, named entities, and mechanisms are atomic units; preserve hedging ("may", "is thought to") inside the claim — do not decompose hedges. + +### Clarification of the STRUCTURAL label +**Why this label exists.** C3 worries about *training recall* — the model writing functional claims from memorized literature instead of from the evidence. But many "functional" Q&A answers also contain claims that are pure structural inference, e.g. "the compound contains a primary amine" inside a *mechanism* answer. If those structural claims get bucketed as `UNSUPPORTED`, we systematically **overestimate** training recall. A separate `STRUCTURAL` label keeps `UNSUPPORTED` a clean proxy for "model used memorized literature, not the evidence and not the molecule itself." + +**Decision: produce both views as separate outputs — no choice required up-front.** +The judge always emits all 4 labels (`STATED` / `IMPLIED` / `UNSUPPORTED` / `STRUCTURAL`). The aggregator post-processes the same `claims_per_qa.jsonl` two ways and writes two summary files: +- **`grounding_summary_keep_structural.md`** — STRUCTURAL kept as its own bucket. Headline: `UNSUPPORTED / (STATED + IMPLIED + UNSUPPORTED)`; STRUCTURAL excluded from the denominator (it isn't an evidence-grounding question). +- **`grounding_summary_drop_structural.md`** — STRUCTURAL collapsed into IMPLIED (closer to the original spec; treats SMILES as "evidence" in a loose sense). Headline: `UNSUPPORTED / (STATED + IMPLIED + UNSUPPORTED)` over the collapsed labels. + +Both summaries cite each other so the reader can compare. The 10% / 20% decision rule (§Step 4) is applied to **each** view independently — if the conclusion flips between the two views, that's itself a finding worth flagging in the paper. + +## Step 3 — Judge runner (`judge_claims.py`) +- `httpx.AsyncClient`, semaphore = `--concurrency` (default 5), exponential backoff on 429 / 5xx. +- **Resumable**: read existing `claims_per_qa.jsonl`, skip `(cid, qa_index)` already judged. +- Strict JSON parsing. On parse failure, one retry appending "your previous output was not valid JSON, return only the JSON object". Second failure → write to `claims_per_qa.errors.jsonl` with the raw response, continue. +- Per-row record: + ```json + {"cid":..., "qa_index":..., "topic":..., "evidence_ids_nonempty":bool, + "num_evidence_attached": int, "model":"...", "claims":[...], + "usage":{"prompt_tokens":int,"completion_tokens":int}, + "latency_ms": int} + ``` +- **Cost ceiling**: abort if cumulative cost (estimated from token usage and OpenRouter's per-model price) exceeds `--max-usd`. Print running total every 25 calls. +- Cross-check pass (skipped if `--skip-cross-check`): + - Take the first `--cross-check-n` rows of `sample.jsonl`, re-judge with `--cross-check-model`, write to `claims_per_qa.gemini.jsonl`. Same retry / cost logic. + +## Step 4 — Aggregation (`aggregate.py` → two summary files) +Claim-level denominator = all atomic claims across the run. Aggregator runs twice over the same `claims_per_qa.jsonl`, once per view, writing the two summary files listed in the output layout. + +### Headline metrics (computed for both views) +- % STATED · % IMPLIED · % UNSUPPORTED · % STRUCTURAL (STRUCTURAL row collapses into IMPLIED in the "drop" view) +- **Grounded fraction** = STATED + IMPLIED +- **Training-recall candidate fraction** = UNSUPPORTED + +### Breakdowns +- By topic (mechanism / metabolism / toxicity / engineering / therapeutic_use / others). +- By `evidence_ids_nonempty` true vs false — this is the spec-required cut. +- By split (train / val / test). +- Per-QA: histogram of UNSUPPORTED claim count; **top-20 QA by UNSUPPORTED rate** for qualitative spot-checking. + +### Statistical honesty +- Wilson 95% CI on the UNSUPPORTED rate. With `N` chosen by the user, half-width ≈ `1.96 * sqrt(p(1-p)/n_claims)`; reported in the summary so a reader can see whether `N` is large enough to discriminate the 10% / 20% cutoffs. + +### Cross-check (only if not skipped) +- On the overlap subset, collapse to {grounded, ungrounded} and compute Cohen's κ between primary and cross-check models. Disagreement examples listed in an appendix. + +### Decision rule (printed at the top of the summary) +- UNSUPPORTED > 20% → narrow the paper's grounding claim; note training-recall risk in DATASHEET / RESPONSIBLE_AI. +- UNSUPPORTED < 10% → soft rule "well-behaved"; quote the number in the dataset card. +- 10–20% → add a caveat to DATASHEET / RESPONSIBLE_AI; do not claim full grounding. + +## Step 5 — Manual validation +- Spot-check **20 random claims** from `claims_per_qa.jsonl` by hand; record agreement with the judge in the appendix of `grounding_summary.md`. +- Sanity gate: every claim with a non-null `evidence_id` must reference an attached evidence sentence; otherwise auto-relabel that row as a parse error. + +## Orchestrator (`run_phase4_grounding.sh`) +Three-step pipeline (resumable; safe to re-run): +1. `python -m phase4_grounding.scripts.sample_qa --n "$N" --seed 0` +2. `python -m phase4_grounding.scripts.judge_claims --api-key-file "$KEY" --primary-model "$PRIMARY" [--skip-cross-check] [--cross-check-model "$XCHECK"] --max-usd "$BUDGET"` +3. `python -m phase4_grounding.scripts.aggregate` + +User runs e.g.: +``` +bash phase4_grounding/run_phase4_grounding.sh --n 300 --api-key-file ~/.openrouter_key --skip-cross-check +``` + +## Testing strategy +- **Framework**: pytest (already in `[project.optional-dependencies] dev`); tests live under `phase4_grounding/tests/` and are discovered when the existing `[tool.pytest.ini_options] testpaths` is extended in `pyproject.toml`. +- **Fixtures (`conftest.py`)**: + - `tiny_dataset` — 5–10 hand-crafted compounds, all functional topics represented, both `evidence_ids` non-empty and empty cases. + - `fake_openrouter_client` — returns scripted responses (well-formed JSON, malformed JSON, 429 → success). No network in any test. + - `tmp_out_dir` — pytest `tmp_path` for output writers. +- **Coverage targets per module**: + - `sampling`: stratification correctness, deterministic with seed, graceful when a stratum is empty, exact-N total. + - `evidence`: selects only listed `evidence_ids` when non-empty, attaches all when empty, stable numbering. + - `prompt`: contains all required fields, evidence is numbered `[E1]`, `[E2]`...; snapshot test against a small golden string. + - `parser`: accepts well-formed JSON; rejects malformed; rejects `evidence_id` not in attached set; preserves all 4 labels. + - `judge`: end-to-end with fake client — parse-success path, retry-then-success path, retry-then-error-log path. + - `openrouter_client`: cost tracker arithmetic; `--max-usd` raises `BudgetExceeded`; backoff sleeps respect `Retry-After` (mocked clock). + - `aggregator`: keep view excludes STRUCTURAL from denominator; drop view collapses STRUCTURAL → IMPLIED; Wilson CI matches a known reference for fixed inputs; per-topic and per-`evidence_ids_nonempty` breakdowns sum to the total. + - `reporter`: writes both files; decision-rule string matches the computed UNSUPPORTED rate. +- **Network policy**: every test that touches `OpenRouterClient` uses the fake; one optional `@pytest.mark.live` smoke test (skipped by default; opt-in with `--live`) hits OpenRouter with one cheap call. +- **Run locally**: `pytest phase4_grounding/tests -q --cov=phase4_grounding/grounding`. + +## Step-by-step implementation +The work is sliced into small, ordered steps. **Acceptance gate for each step: its unit tests pass (`pytest phase4_grounding/tests -q`).** No PR workflow, no CI pipeline, no coverage thresholds — just a green test run before moving to the next step. No step touches production data or makes network calls in tests (everything uses the fake client). + +| Step | Scope | Gate | +|---|---|---| +| 1 | Skeleton: package layout, `models.py` dataclasses, `tests/conftest.py` with `tiny_dataset` fixture | tests collect cleanly; one smoke test on a dataclass passes | +| 2 | `sampling.Sampler` + `test_sampling.py` | tests cover stratification, seed determinism, exhausted-stratum | +| 3 | `evidence.EvidenceAttacher` + `test_evidence.py` | covers both `evidence_ids` branches | +| 4 | `prompt.PromptBuilder` + `prompts/claim_decomp.txt` + `test_prompt.py` | snapshot test passes | +| 5 | `parser.ClaimParser` + `test_parser.py` | covers well/malformed JSON and bogus `evidence_id` | +| 6 | `openrouter_client.OpenRouterClient` + `FakeOpenRouterClient` + `test_openrouter_client.py` | retry, backoff, budget cap | +| 7 | `judge.ClaimJudge` + `test_judge.py` | end-to-end with fake client | +| 8 | `scripts/sample_qa.py` + `scripts/judge_claims.py` (entry points, `main()`, argparse) + integration test against `tiny_dataset` | integration test produces expected `claims_per_qa.jsonl` shape | +| 9 | `aggregator.Aggregator` + `reporter.Reporter` + `scripts/aggregate.py` + tests | both summary files generated; numbers match hand-computed reference | +| 10 | `run_phase4_grounding.sh` orchestrator | bash script runs all three steps end-to-end on `tiny_dataset` | + +Once step 10 passes, the real audit is a one-off manual run of the orchestrator with the user-approved `--n` and the production API key. + diff --git a/phase4_grounding/RESULTS.md b/phase4_grounding/RESULTS.md new file mode 100644 index 0000000..2e99ac1 --- /dev/null +++ b/phase4_grounding/RESULTS.md @@ -0,0 +1,130 @@ +# Phase 4 grounding audit — results + +Empirical measurement of how grounded Chem2TextQA's Phase-2 answers are +against the cited evidence. This document fixes the headline numbers +referenced in `DATASHEET.md`, `RESPONSIBLE_AI.md`, and `LIMITATIONS.md`. + +The orchestrator and code live alongside this file in +`phase4_grounding/`; see `USAGE.md` to reproduce. The detailed per-topic +/ per-split summary tables and per-row cross-check comparison are +regenerated into `phase4_grounding/outputs/` by the orchestrator +(gitignored — reproducible, not redistributed). + +## Headline + +A model judges every Phase-2 answer claim-by-claim against the attached +evidence and labels each claim as **STATED**, **IMPLIED**, **STRUCTURAL** +(derivable from SMILES alone), or **UNSUPPORTED** (training-recall +candidate). UNSUPPORTED is the proxy for training-recall risk; the PLAN +threshold for narrowing the paper's grounding claim is 20%. + +| View | n claims | UNSUPPORTED | 95% Wilson CI | Decision | +|---|---|---|---|---| +| **keep-structural** *(STRUCTURAL excluded from denominator — clean training-recall proxy)* | 3,076 | **55.20%** | 53.43–56.97% | **NARROW** | +| **drop-structural** *(STRUCTURAL collapsed into IMPLIED — original PLAN spec)* | 3,799 | **44.70%** | 43.11–46.30% | **NARROW** | + +Both views are well above the 20% PLAN threshold. The paper's grounding +claim must be narrowed and training-recall risk explicitly flagged. + +## Sample + +- **300 Q&As** sampled with seed 0 from the gold dataset + (`data/dataset_gold.jsonl`), stratified across topic / split / + evidence-attached buckets. +- All 300 judged successfully (292 by the primary model, 8 by the + cross-check model after primary refusals — see "Refusals" below). +- 30 Q&As were independently re-judged by a second model for agreement + validation. + +## Methodology + +| Stage | Model | Role | +|---|---|---| +| Sample | (deterministic) | `sample_qa.py --n 300 --seed 0` | +| Primary judge | `anthropic/claude-sonnet-4.6` | per-claim label | +| Cross-check | `google/gemini-2.5-pro` (n=30) | independent re-judgment | +| Aggregate | (deterministic) | rates, CIs, by-topic / by-split splits | + +Total cost: **$10.16** ($9.76 primary + cross-check + $0.39 rejudge +fallback for the 8 refused rows). + +## Where the recall risk concentrates (keep-structural) + +UNSUPPORTED rate by topic: + +| Topic | UNSUPPORTED | n | +|---|---|---| +| **engineering** | **74.6%** | ~640 | +| adme | 69.5% | ~120 | +| metabolism | 68.8% | ~460 | +| design_levers | 67.1% | ~165 | +| drug_interactions | 64.8% | ~55 | +| toxicity | 40.6% | ~540 | +| mechanism | 40.7% | ~615 | +| therapeutic_use | 38.2% | ~415 | + +Engineering / design / metabolism Q&As have the worst grounding — +unsurprisingly, since these topics ask for analog-design or pathway +reasoning that goes beyond what an evidence sentence directly states. + +Q&As **with** evidence attached have a *higher* UNSUPPORTED rate +(62.94%) than those without (46.49%) in the keep-structural view — +counterintuitive but consistent with training-recall: when evidence is +attached, the model is also more likely to add unsourced elaboration on +top of it. + +Full breakdown by topic / evidence / split, the per-Q&A UNSUPPORTED +histogram, and the top-20 worst-grounded Q&As regenerate into +`outputs/grounding_summary_keep_structural.md` on each run. + +## Cross-check validation + +30 of the 300 Q&As were independently re-judged by `gemini-2.5-pro`. +Macro UNSUPPORTED rates: + +| View | sonnet (primary) | gemini (cross-check) | macro Δ | +|---|---|---|---| +| keep-structural | 59.73% | 63.41% | **+3.68pp** | +| drop-structural | 47.09% | 48.40% | **+1.31pp** | + +The two models agree closely on the macro rate (within 4pp, both above +the 20% threshold). Per-row agreement is looser — mean absolute diff +~11pp (keep) — but **26/30 rows agree to within 20pp**, and the few +larger per-row disagreements average out at the macro level. + +This validates the headline UNSUPPORTED rate as model-agnostic to the +choice of judge, not an artifact of one model's labeling style. + +Full per-row table regenerates into `outputs/cross_check_agreement.md`. + +## Refusals — a methodological note + +The primary model (`claude-sonnet-4.6`) refused on 8 of 300 Q&As (2.7%) +by returning empty/null content. All 8 are **dual-use chemistry +queries**: engineering of toxins or controlled substances (ziconotide, +yessotoxin, anthopleurin B), pesticide modifications (carbofuran), +heavy-metal substitutions (lead chloride), and toxicity questions about +known toxins. The pattern is consistent with Anthropic's safety filter +on synthesis-uplift adjacent prompts. + +Falling back to `gemini-2.5-pro` on the same 8 prompts succeeded on all +8 with structurally reasonable claim decompositions. + +Implication for users: + +- Single-model audits of dual-use chemistry datasets will systematically + undercount the "engineering" topic. A heterogeneous-judge protocol is + required for full coverage. +- Anyone reproducing this audit with a single model should expect ~3% + loss on that topic and report it. + +## Reproducing + +```bash +phase4_grounding/run_phase4_grounding.sh --n 300 --max-usd 100 +python -m phase4_grounding.scripts.rejudge_errors # gemini fallback for refusals +python -m phase4_grounding.scripts.aggregate --out-dir phase4_grounding/outputs +python -m phase4_grounding.scripts.analyze_cross_check_agreement +``` + +Final cost: ~$10. See `USAGE.md` for argument reference. diff --git a/phase4_grounding/USAGE.md b/phase4_grounding/USAGE.md new file mode 100644 index 0000000..2aa5a21 --- /dev/null +++ b/phase4_grounding/USAGE.md @@ -0,0 +1,173 @@ +# Phase 4 — Grounding Audit Usage + +End-to-end usage for the claim-grounding audit. Decomposes each functional +`phase2_answer` into atomic claims, labels them against the parent compound's +evidence sentences, and aggregates the result into two summary markdown files. + +See `PLAN.md` for the full design rationale. + +## Prerequisites + +```bash +pip install httpx pytest pytest-asyncio +``` + +- `data/dataset_gold.jsonl` available at the repo root (or pass `--data-path`). +- An OpenRouter API key in a plain-text file (default: `~/.openrouter_key`). + The key is read **only** from this file — never echoed, never logged. + +## One-shot run (recommended) + +```bash +bash phase4_grounding/run_phase4_grounding.sh \ + --n 300 \ + --api-key-file ~/.openrouter_key \ + --skip-cross-check +``` + +The orchestrator runs three steps in order. Each step is resumable, so it is +safe to re-run after a crash, network blip, or budget abort. + +| Flag | Default | Description | +|---|---|---| +| `--n` | required | total sample size | +| `--api-key-file` | `~/.openrouter_key` | path to the OpenRouter API key | +| `--data-path` | `data/dataset_gold.jsonl` | input dataset | +| `--out-dir` | `phase4_grounding/outputs` | output directory | +| `--primary-model` | `anthropic/claude-sonnet-4.6` | primary judge | +| `--cross-check-model` | `google/gemini-2.5-pro` | second-opinion judge | +| `--skip-cross-check` | off | disable the cross-check pass | +| `--cross-check-n` | `30` | size of the cross-check overlap subset | +| `--max-usd` | `15` | hard budget cap; runner aborts past this | +| `--concurrency` | `5` | async semaphore for OpenRouter calls | +| `--seed` | `0` | RNG seed for sampling | + +## Per-step invocation + +If you want to run the steps individually (e.g. tweak the sample, then judge): + +### 1. Stratified sample + +```bash +python -m phase4_grounding.scripts.sample_qa \ + --n 300 \ + --seed 0 \ + --data-path data/dataset_gold.jsonl \ + --out-dir phase4_grounding/outputs +``` + +Writes `sample.jsonl`. Stratifies by topic with a fixed weighting that +emphasizes the four headline topics (mechanism / engineering / metabolism / +toxicity) and splits 50/50 between QAs with non-empty vs empty `evidence_ids`. + +### 2. Judge claims + +```bash +python -m phase4_grounding.scripts.judge_claims \ + --api-key-file ~/.openrouter_key \ + --primary-model anthropic/claude-sonnet-4.6 \ + --skip-cross-check \ + --max-usd 15 \ + --concurrency 5 \ + --out-dir phase4_grounding/outputs +``` + +Reads `sample.jsonl`. Skips `(cid, qa_index)` rows already in +`claims_per_qa.jsonl` so re-runs are idempotent. Writes: +- `claims_per_qa.jsonl` — successful rows (one JSON per line) +- `claims_per_qa.errors.jsonl` — rows where JSON parsing failed twice +- `claims_per_qa.gemini.jsonl` — cross-check pass (only if not skipped) + +If cumulative spend crosses `--max-usd`, the runner refuses the next call, +prints the abort reason, and preserves all rows written so far. + +### 3. Aggregate + +```bash +python -m phase4_grounding.scripts.aggregate \ + --out-dir phase4_grounding/outputs +``` + +Writes both summary files: +- `grounding_summary_keep_structural.md` — STRUCTURAL claims kept as their own + bucket, **excluded** from the STATED / IMPLIED / UNSUPPORTED denominator. +- `grounding_summary_drop_structural.md` — STRUCTURAL collapsed into IMPLIED. + +Each summary prints the decision-rule banner at the top: +- UNSUPPORTED > 20% → **NARROW** the paper's grounding claim. +- UNSUPPORTED 10–20% → **CAVEAT** in DATASHEET / RESPONSIBLE_AI. +- UNSUPPORTED < 10% → **WELL-BEHAVED**; quote in the dataset card. + +## Output layout + +``` +phase4_grounding/outputs/ +├── sample.jsonl +├── claims_per_qa.jsonl +├── claims_per_qa.errors.jsonl +├── claims_per_qa.gemini.jsonl # only if cross-check enabled +├── claims_per_qa.gemini.errors.jsonl +├── grounding_summary_keep_structural.md +└── grounding_summary_drop_structural.md +``` + +## Tests + +```bash +python -m pytest phase4_grounding/tests -q +``` + +All tests are offline — `OpenRouterClient` is replaced with a scripted fake. +No live API keys or network access are needed to run the suite. + +## Library API (for programmatic use) + +```python +from pathlib import Path + +from phase4_grounding.grounding.aggregator import Aggregator +from phase4_grounding.grounding.evidence import EvidenceAttacher +from phase4_grounding.grounding.judge import ClaimJudge +from phase4_grounding.grounding.openrouter_client import OpenRouterClient +from phase4_grounding.grounding.prompt import PromptBuilder +from phase4_grounding.grounding.reporter import Reporter +from phase4_grounding.grounding.sampling import Sampler + +# 1. Sample +sampler = Sampler(dataset_path="data/dataset_gold.jsonl", seed=0) +rows = sampler.sample(n=300) + +# 2. Judge (async) +async with OpenRouterClient(api_key="...", max_usd=15.0) as client: + judge = ClaimJudge(client, PromptBuilder()) + judged = [await judge.judge(r, model="anthropic/claude-sonnet-4.6") for r in rows] + +# 3. Aggregate + report +agg = Aggregator(judged) +Reporter(agg.compute("keep"), agg.compute("drop"), judged).write(Path("outputs")) +``` + +## Cost estimate (rule of thumb) + +For `--n 300` with `claude-sonnet-4.6` as the primary and a 30-row Gemini +cross-check, expect prompts of ~1–2k tokens and completions of ~500 tokens per +row. At list price that is ~$3–6 USD. Set `--max-usd` to your comfort level; +the runner aborts cleanly the first time spend crosses the cap. + +## Resumability + +Both `judge_claims.py` and the orchestrator are resumable: +- `judge_claims.py` reads existing `claims_per_qa.jsonl` and skips + `(cid, qa_index)` already judged. +- A budget abort, network failure, or `Ctrl-C` leaves the partial output + intact. Just re-run the same command — only the missing rows are judged. + +## Troubleshooting + +- **`pytest.mark.asyncio` warnings / failures** — install `pytest-asyncio`. +- **`api-key-file` not found** — pass an explicit `--api-key-file` pointing at + a plain-text file containing only the key. +- **All rows go to `claims_per_qa.errors.jsonl`** — usually the model is + emitting markdown fences. The parser is strict; the judge auto-retries once + with a "no markdown fences" hint, but a persistently-broken model will fail + twice. Check the raw text in the errors file and consider switching models. diff --git a/phase4_grounding/__init__.py b/phase4_grounding/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/phase4_grounding/environment.yml b/phase4_grounding/environment.yml new file mode 100644 index 0000000..aff56f1 --- /dev/null +++ b/phase4_grounding/environment.yml @@ -0,0 +1,13 @@ +name: chem2text-phase4 +channels: + - conda-forge +dependencies: + - python=3.11 + - pip + - pip: + - httpx>=0.27 + - pytest>=7.4 + - pytest-asyncio>=0.23 + - pytest-cov + - tqdm>=4.66 + - tenacity>=8.2 diff --git a/phase4_grounding/grounding/__init__.py b/phase4_grounding/grounding/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/phase4_grounding/grounding/aggregator.py b/phase4_grounding/grounding/aggregator.py new file mode 100644 index 0000000..f889c6a --- /dev/null +++ b/phase4_grounding/grounding/aggregator.py @@ -0,0 +1,171 @@ +"""Aggregate judged claims into headline metrics + breakdowns. + +Two views over the same `claims_per_qa.jsonl`: +- ``keep``: STRUCTURAL is its own bucket; the denominator excludes it. The + headline UNSUPPORTED rate is `UNSUPPORTED / (STATED + IMPLIED + UNSUPPORTED)`. +- ``drop``: STRUCTURAL is collapsed into IMPLIED; everything counted. + +`Aggregator(judged_qas).compute(view)` returns a `ViewMetrics`. The reporter +consumes both views to write the dual summary files. +""" +from __future__ import annotations + +import math +from collections.abc import Iterable +from typing import Literal + +from .models import Claim, JudgedQA, ViewMetrics + +_VIEW_LABELS_KEEP: tuple[str, ...] = ("STATED", "IMPLIED", "UNSUPPORTED") +_VIEW_LABELS_DROP: tuple[str, ...] = ("STATED", "IMPLIED", "UNSUPPORTED") +_TOP_N = 20 + + +def wilson_ci(successes: int, total: int, z: float = 1.96) -> tuple[float, float]: + """Wilson score 95% CI for a binomial proportion. + + Returns (0.0, 0.0) when total is 0 — the caller decides how to display it. + """ + if total == 0: + return (0.0, 0.0) + p = successes / total + denom = 1 + z * z / total + center = (p + z * z / (2 * total)) / denom + half = (z / denom) * math.sqrt(p * (1 - p) / total + z * z / (4 * total * total)) + return (max(0.0, center - half), min(1.0, center + half)) + + +class Aggregator: + """Compute label rates + breakdowns + per-QA stats over judged Q&A rows.""" + + def __init__(self, judged_qas: Iterable[JudgedQA]) -> None: + self.judged_qas: tuple[JudgedQA, ...] = tuple(judged_qas) + + def compute(self, view: Literal["keep", "drop"]) -> ViewMetrics: + if view not in ("keep", "drop"): + raise ValueError(f"view must be 'keep' or 'drop', got {view!r}") + + labels = _VIEW_LABELS_KEEP if view == "keep" else _VIEW_LABELS_DROP + + claims_for_metrics: list[tuple[JudgedQA, Claim, str]] = [] + structural_count = 0 + for qa in self.judged_qas: + for claim in qa.claims: + projected = self._project_label(claim.label, view) + if projected is None: + structural_count += 1 + continue + claims_for_metrics.append((qa, claim, projected)) + + total = len(claims_for_metrics) + counts = {label: 0 for label in labels} + for _, _, projected in claims_for_metrics: + counts[projected] += 1 + + rates = { + label: (counts[label] / total) if total else 0.0 for label in labels + } + grounded_rate = rates.get("STATED", 0.0) + rates.get("IMPLIED", 0.0) + unsupported_rate = rates.get("UNSUPPORTED", 0.0) + unsupported_ci = wilson_ci(counts.get("UNSUPPORTED", 0), total) + + by_topic = self._breakdown( + claims_for_metrics, key=lambda qa, _c: qa.topic, labels=labels + ) + by_evidence = self._breakdown( + claims_for_metrics, + key=lambda qa, _c: qa.evidence_ids_nonempty, + labels=labels, + ) + by_split = self._breakdown( + claims_for_metrics, key=lambda qa, _c: qa.split or "(unknown)", labels=labels + ) + + histogram, top_qa = self._per_qa_stats(view) + + return ViewMetrics( + view=view, + total_claims=total, + counts=counts, + rates=rates, + grounded_rate=grounded_rate, + unsupported_rate=unsupported_rate, + unsupported_ci=unsupported_ci, + structural_count=structural_count, + by_topic=by_topic, + by_evidence_ids_nonempty=by_evidence, + by_split=by_split, + per_qa_unsupported_histogram=histogram, + top_qa_by_unsupported=top_qa, + ) + + @staticmethod + def _project_label(label: str, view: Literal["keep", "drop"]) -> str | None: + """Map a raw judge label to the view-specific label. + + Returns None when the claim should be excluded from the denominator + (the keep view drops STRUCTURAL). + """ + if label == "STRUCTURAL": + return None if view == "keep" else "IMPLIED" + return label + + @staticmethod + def _breakdown( + claims_for_metrics: list[tuple[JudgedQA, Claim, str]], + *, + key, + labels: tuple[str, ...], + ) -> dict: + """Group claims by `key(qa, claim)` and compute per-group rates + total.""" + buckets: dict[object, dict[str, int]] = {} + for qa, claim, projected in claims_for_metrics: + k = key(qa, claim) + bucket = buckets.setdefault(k, {label: 0 for label in labels}) + bucket[projected] = bucket.get(projected, 0) + 1 + + out: dict = {} + for k, counts in buckets.items(): + total = sum(counts.values()) + entry = {label: counts.get(label, 0) for label in labels} + entry["total"] = total + for label in labels: + entry[f"{label}_rate"] = (counts.get(label, 0) / total) if total else 0.0 + out[k] = entry + return out + + def _per_qa_stats( + self, view: Literal["keep", "drop"] + ) -> tuple[dict[int, int], tuple[dict, ...]]: + """Histogram of UNSUPPORTED count per Q&A + top-N rows by UNSUPPORTED rate.""" + histogram: dict[int, int] = {} + rows: list[dict] = [] + for qa in self.judged_qas: + unsupported = 0 + denom = 0 + for claim in qa.claims: + projected = self._project_label(claim.label, view) + if projected is None: + continue + denom += 1 + if projected == "UNSUPPORTED": + unsupported += 1 + histogram[unsupported] = histogram.get(unsupported, 0) + 1 + rate = (unsupported / denom) if denom else 0.0 + rows.append( + { + "cid": qa.cid, + "qa_index": qa.qa_index, + "topic": qa.topic, + "split": qa.split, + "unsupported": unsupported, + "total_claims": denom, + "unsupported_rate": rate, + } + ) + + rows.sort( + key=lambda r: (r["unsupported_rate"], r["unsupported"]), + reverse=True, + ) + return histogram, tuple(rows[:_TOP_N]) diff --git a/phase4_grounding/grounding/evidence.py b/phase4_grounding/grounding/evidence.py new file mode 100644 index 0000000..ecfa117 --- /dev/null +++ b/phase4_grounding/grounding/evidence.py @@ -0,0 +1,59 @@ +"""Attach evidence sentences to a sampled Q&A. + +Rules (mirrors §Step 1 of PLAN.md): +- If the QA's `evidence_ids` list is non-empty, select only the parent + `evidence_sentences` whose source `id` is listed, preserving the order given + in `evidence_ids`. Unknown ids are skipped silently. +- Otherwise attach all `evidence_sentences` for the compound, in their original + order. + +The returned items are renumbered with **display ids** `1..N` so the judge +prompt can label them `[E1]`, `[E2]`, ... and the parser can validate that the +judge's `evidence_id` belongs to `{1..N}`. + +This module is pure: no I/O at import, no network. +""" +from __future__ import annotations + +from .models import EvidenceItem + + +class EvidenceAttacher: + """Pure helper that selects and renumbers evidence for a single QA.""" + + @staticmethod + def attach(qa: dict, compound: dict) -> tuple[EvidenceItem, ...]: + sentences = compound.get("evidence_sentences") or [] + id_to_sent: dict[int, dict] = {} + for s in sentences: + try: + id_to_sent[int(s["id"])] = s + except (KeyError, TypeError, ValueError): + continue + + evidence_ids = qa.get("evidence_ids") or [] + if evidence_ids: + selected: list[dict] = [] + for eid in evidence_ids: + try: + key = int(eid) + except (TypeError, ValueError): + continue + s = id_to_sent.get(key) + if s is not None: + selected.append(s) + else: + selected = list(sentences) + + attached: list[EvidenceItem] = [] + for display_id, s in enumerate(selected, start=1): + pmid = s.get("pmid") + attached.append( + EvidenceItem( + id=display_id, + text=s.get("text", ""), + pmid=str(pmid) if pmid is not None else None, + source=s.get("source"), + ) + ) + return tuple(attached) diff --git a/phase4_grounding/grounding/judge.py b/phase4_grounding/grounding/judge.py new file mode 100644 index 0000000..26719bd --- /dev/null +++ b/phase4_grounding/grounding/judge.py @@ -0,0 +1,111 @@ +"""End-to-end judging of a single Q&A: prompt → model call → parse. + +The orchestration logic is intentionally tiny: it composes already-tested +modules (prompt builder, OpenRouter client, parser) and adds the one-shot +retry on parse failure. + +`ClaimJudge.judge(row, model) -> JudgedQA` is the public surface. + +On the second consecutive parse failure the judge raises `JudgeError` with +the raw response on it; the runner script catches this and writes a row to +`claims_per_qa.errors.jsonl` (per PLAN §Step 3). +""" +from __future__ import annotations + +from typing import Protocol + +from .models import ChatResult, JudgedQA, ParseResult, SampleRow +from .parser import ClaimParser +from .prompt import PromptBuilder + +_RETRY_INSTRUCTION = ( + "\n\nYour previous output was not valid JSON. Return only the JSON object — " + "no markdown fences, no commentary." +) + + +class JudgeError(Exception): + """Raised when both judging attempts fail to produce parseable output.""" + + def __init__( + self, + message: str, + *, + cid: int, + qa_index: int, + model: str, + raw: str, + first_error: str | None, + second_error: str | None, + ) -> None: + super().__init__(message) + self.cid = cid + self.qa_index = qa_index + self.model = model + self.raw = raw + self.first_error = first_error + self.second_error = second_error + + +class _ChatClient(Protocol): + async def chat(self, *, model: str, prompt: str, **kwargs: object) -> ChatResult: ... + + +class ClaimJudge: + """Compose prompt + client + parser; one retry on parse failure.""" + + def __init__( + self, + client: _ChatClient, + prompt_builder: PromptBuilder, + parser: ClaimParser | None = None, + ) -> None: + self.client = client + self.prompt_builder = prompt_builder + self.parser = parser or ClaimParser() + + async def judge(self, row: SampleRow, model: str) -> JudgedQA: + prompt = self.prompt_builder.build(row) + attached_ids = {e.id for e in row.evidence_attached} + + first = await self.client.chat(model=model, prompt=prompt) + first_parsed = self.parser.parse(first.text, attached_ids) + if first_parsed.ok: + return self._judged_qa(row, model, first_parsed, [first]) + + retry_prompt = prompt + _RETRY_INSTRUCTION + second = await self.client.chat(model=model, prompt=retry_prompt) + second_parsed = self.parser.parse(second.text, attached_ids) + if second_parsed.ok: + return self._judged_qa(row, model, second_parsed, [first, second]) + + raise JudgeError( + "judge failed to produce valid JSON after one retry", + cid=row.cid, + qa_index=row.qa_index, + model=model, + raw=second.text, + first_error=first_parsed.error, + second_error=second_parsed.error, + ) + + @staticmethod + def _judged_qa( + row: SampleRow, + model: str, + parsed: ParseResult, + chats: list[ChatResult], + ) -> JudgedQA: + return JudgedQA( + cid=row.cid, + qa_index=row.qa_index, + topic=row.topic, + evidence_ids_nonempty=row.evidence_ids_nonempty, + num_evidence_attached=len(row.evidence_attached), + model=model, + claims=parsed.claims, + prompt_tokens=sum(c.prompt_tokens for c in chats), + completion_tokens=sum(c.completion_tokens for c in chats), + latency_ms=sum(c.latency_ms for c in chats), + split=row.split, + ) diff --git a/phase4_grounding/grounding/models.py b/phase4_grounding/grounding/models.py new file mode 100644 index 0000000..82ccf3a --- /dev/null +++ b/phase4_grounding/grounding/models.py @@ -0,0 +1,92 @@ +"""Typed dataclasses passed between phase4_grounding modules.""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Literal + +ClaimLabel = Literal["STATED", "IMPLIED", "UNSUPPORTED", "STRUCTURAL"] +LABEL_VALUES: tuple[ClaimLabel, ...] = ("STATED", "IMPLIED", "UNSUPPORTED", "STRUCTURAL") + + +@dataclass(frozen=True) +class EvidenceItem: + id: int + text: str + pmid: str | None = None + source: str | None = None + + +@dataclass(frozen=True) +class Compound: + cid: int + name: str + smiles: str + molecular_formula: str + + +@dataclass(frozen=True) +class SampleRow: + cid: int + qa_index: int + topic: str + split: str + evidence_ids_nonempty: bool + compound: Compound + question: str + phase2_answer: str + evidence_attached: tuple[EvidenceItem, ...] + + +@dataclass(frozen=True) +class Claim: + claim: str + label: ClaimLabel + evidence_id: int | None + rationale: str | None + + +@dataclass(frozen=True) +class JudgedQA: + cid: int + qa_index: int + topic: str + evidence_ids_nonempty: bool + num_evidence_attached: int + model: str + claims: tuple[Claim, ...] + prompt_tokens: int + completion_tokens: int + latency_ms: int + split: str = "" + + +@dataclass(frozen=True) +class ParseResult: + ok: bool + claims: tuple[Claim, ...] = () + error: str | None = None + + +@dataclass(frozen=True) +class ChatResult: + text: str + prompt_tokens: int + completion_tokens: int + latency_ms: int + + +@dataclass(frozen=True) +class ViewMetrics: + view: Literal["keep", "drop"] + total_claims: int + counts: dict[str, int] + rates: dict[str, float] + grounded_rate: float + unsupported_rate: float + unsupported_ci: tuple[float, float] + structural_count: int = 0 + by_topic: dict[str, dict[str, float]] = field(default_factory=dict) + by_evidence_ids_nonempty: dict[bool, dict[str, float]] = field(default_factory=dict) + by_split: dict[str, dict[str, float]] = field(default_factory=dict) + per_qa_unsupported_histogram: dict[int, int] = field(default_factory=dict) + top_qa_by_unsupported: tuple[dict, ...] = field(default_factory=tuple) diff --git a/phase4_grounding/grounding/openrouter_client.py b/phase4_grounding/grounding/openrouter_client.py new file mode 100644 index 0000000..262f2b0 --- /dev/null +++ b/phase4_grounding/grounding/openrouter_client.py @@ -0,0 +1,184 @@ +"""Async OpenRouter chat client with retries, backoff, and a USD budget cap. + +Public surface: +- `OpenRouterClient(api_key, concurrency, max_usd)` — instantiate once per run. +- `await client.chat(model=..., prompt=...) -> ChatResult` — single completion. +- `client.spend_usd` — cumulative cost so far (for periodic logging). +- `BudgetExceeded` — raised on the call **after** cumulative spend crosses + `max_usd`. The result of the call that pushed spend over the line is still + returned, so the runner can persist it before aborting. + +Behavior: +- Single shared `httpx.AsyncClient`; retries on `429` and `5xx` with + exponential backoff (`backoff_base * 2**attempt`). When the response + carries a `Retry-After` header, that value is honored verbatim. +- Concurrency is bounded by an `asyncio.Semaphore`. +- `pricing` is per-million-token USD; unknown models cost zero. Override via + the `pricing` constructor argument when invoking real models so the budget + check is meaningful. + +The HTTP transport and sleep function are injectable so unit tests can run +without network or wall-clock waits. +""" +from __future__ import annotations + +import asyncio +import time +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import Any + +import httpx + +from .models import ChatResult + + +class BudgetExceeded(Exception): + """Raised when cumulative spend exceeds the configured `max_usd` cap.""" + + +@dataclass(frozen=True) +class ModelPricing: + """Per-million-token USD prices.""" + + prompt_per_mtok: float + completion_per_mtok: float + + +DEFAULT_PRICING: dict[str, ModelPricing] = { + "anthropic/claude-sonnet-4.6": ModelPricing(prompt_per_mtok=3.0, completion_per_mtok=15.0), + "google/gemini-2.5-pro": ModelPricing(prompt_per_mtok=1.25, completion_per_mtok=10.0), +} + + +class OpenRouterClient: + """Async OpenRouter completions client with retries and a budget cap.""" + + def __init__( + self, + api_key: str, + concurrency: int = 5, + max_usd: float = 15.0, + pricing: dict[str, ModelPricing] | None = None, + base_url: str = "https://openrouter.ai/api/v1", + max_retries: int = 5, + backoff_base: float = 1.0, + timeout_s: float = 60.0, + sleep: Callable[[float], Awaitable[None]] = asyncio.sleep, + transport: httpx.AsyncBaseTransport | None = None, + ) -> None: + self.api_key = api_key + self.max_usd = float(max_usd) + self.pricing = dict(pricing) if pricing is not None else dict(DEFAULT_PRICING) + self.max_retries = max_retries + self.backoff_base = backoff_base + self._sleep = sleep + self._sem = asyncio.Semaphore(concurrency) + self._client = httpx.AsyncClient( + base_url=base_url, + transport=transport, + timeout=httpx.Timeout(timeout_s), + ) + self.spend_usd: float = 0.0 + self.calls: int = 0 + + async def aclose(self) -> None: + await self._client.aclose() + + async def __aenter__(self) -> "OpenRouterClient": + return self + + async def __aexit__(self, *exc: object) -> None: + await self.aclose() + + def estimate_cost(self, model: str, prompt_tokens: int, completion_tokens: int) -> float: + p = self.pricing.get(model) + if p is None: + return 0.0 + return ( + prompt_tokens * p.prompt_per_mtok / 1_000_000.0 + + completion_tokens * p.completion_per_mtok / 1_000_000.0 + ) + + async def chat( + self, + *, + model: str, + prompt: str, + temperature: float = 0.0, + max_tokens: int | None = None, + **kw: Any, + ) -> ChatResult: + if self.spend_usd > self.max_usd: + raise BudgetExceeded( + f"cumulative spend ${self.spend_usd:.4f} exceeded budget ${self.max_usd:.2f}" + ) + async with self._sem: + return await self._chat_with_retries(model, prompt, temperature, max_tokens, kw) + + async def _chat_with_retries( + self, + model: str, + prompt: str, + temperature: float, + max_tokens: int | None, + kw: dict[str, Any], + ) -> ChatResult: + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + body: dict[str, Any] = { + "model": model, + "messages": [{"role": "user", "content": prompt}], + "temperature": temperature, + **kw, + } + if max_tokens is not None: + body["max_tokens"] = max_tokens + + attempt = 0 + while True: + t0 = time.monotonic() + response = await self._client.post("/chat/completions", headers=headers, json=body) + + if response.status_code == 429 or 500 <= response.status_code < 600: + if attempt >= self.max_retries: + response.raise_for_status() + delay = self._compute_backoff(response, attempt) + await self._sleep(delay) + attempt += 1 + continue + + response.raise_for_status() + data = response.json() + return self._record_and_build(model, data, t0) + + def _compute_backoff(self, response: httpx.Response, attempt: int) -> float: + retry_after = response.headers.get("Retry-After") + if retry_after: + try: + return float(retry_after) + except ValueError: + pass + return self.backoff_base * (2 ** attempt) + + def _record_and_build(self, model: str, data: dict, t0: float) -> ChatResult: + try: + text = data["choices"][0]["message"]["content"] + except (KeyError, IndexError, TypeError) as exc: + raise RuntimeError(f"unexpected OpenRouter response shape: {exc}") from exc + usage = data.get("usage") or {} + prompt_tokens = int(usage.get("prompt_tokens", 0)) + completion_tokens = int(usage.get("completion_tokens", 0)) + latency_ms = int((time.monotonic() - t0) * 1000) + + self.spend_usd += self.estimate_cost(model, prompt_tokens, completion_tokens) + self.calls += 1 + + return ChatResult( + text=text, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + latency_ms=latency_ms, + ) diff --git a/phase4_grounding/grounding/parser.py b/phase4_grounding/grounding/parser.py new file mode 100644 index 0000000..a027757 --- /dev/null +++ b/phase4_grounding/grounding/parser.py @@ -0,0 +1,110 @@ +"""Strict parser for the claim-decomposition judge response. + +`ClaimParser.parse(raw, attached_ids)` returns a `ParseResult`: +- `ok=True` with a tuple of `Claim` objects if and only if `raw` is valid JSON + matching the documented schema and every non-null `evidence_id` is one of + `attached_ids`. +- `ok=False` with a short `error` string otherwise. The caller (judge runner) + uses the error to decide whether to retry once and, on second failure, log + the raw response to `claims_per_qa.errors.jsonl`. +""" +from __future__ import annotations + +import json + +from .models import LABEL_VALUES, Claim, ParseResult + +_REQUIRED_FIELDS = ("claim", "label", "evidence_id", "rationale") + + +class ClaimParser: + """Strict JSON + schema validator for the judge's response.""" + + @staticmethod + def parse(raw: str | None, attached_ids: set[int]) -> ParseResult: + if not isinstance(raw, str): + return ParseResult( + ok=False, + error=f"raw response is not a string (got {type(raw).__name__})", + ) + try: + payload = json.loads(_strip_fences(raw)) + except json.JSONDecodeError as exc: + return ParseResult(ok=False, error=f"invalid JSON: {exc.msg}") + + if not isinstance(payload, dict): + return ParseResult(ok=False, error="root must be a JSON object") + if "claims" not in payload: + return ParseResult(ok=False, error="missing 'claims' key") + claims_raw = payload["claims"] + if not isinstance(claims_raw, list): + return ParseResult(ok=False, error="'claims' must be a list") + + claims: list[Claim] = [] + for i, item in enumerate(claims_raw): + err = _validate_claim_shape(item, i) + if err is not None: + return ParseResult(ok=False, error=err) + + label = item["label"] + if label not in LABEL_VALUES: + return ParseResult( + ok=False, + error=f"claim[{i}].label must be one of {LABEL_VALUES}, got {label!r}", + ) + + evidence_id = item["evidence_id"] + if evidence_id is not None: + if not isinstance(evidence_id, int) or isinstance(evidence_id, bool): + return ParseResult( + ok=False, + error=f"claim[{i}].evidence_id must be int or null, got {type(evidence_id).__name__}", + ) + if evidence_id not in attached_ids: + return ParseResult( + ok=False, + error=( + f"claim[{i}].evidence_id={evidence_id} not in attached " + f"ids {sorted(attached_ids)}" + ), + ) + + claims.append( + Claim( + claim=item["claim"], + label=label, + evidence_id=evidence_id, + rationale=item["rationale"], + ) + ) + + return ParseResult(ok=True, claims=tuple(claims)) + + +def _strip_fences(raw: str) -> str: + # Some models (notably gemini-2.5-pro) wrap JSON in ```json ... ``` despite + # being told not to. Strip a leading fence line and a trailing fence so the + # JSON body is what hits json.loads. + s = raw.strip() + if s.startswith("```"): + nl = s.find("\n") + s = s[nl + 1 :] if nl != -1 else s[3:] + s = s.rstrip() + if s.endswith("```"): + s = s[:-3].rstrip() + return s + + +def _validate_claim_shape(item: object, i: int) -> str | None: + if not isinstance(item, dict): + return f"claim[{i}] must be a JSON object" + for f in _REQUIRED_FIELDS: + if f not in item: + return f"claim[{i}] missing field {f!r}" + if not isinstance(item["claim"], str): + return f"claim[{i}].claim must be a string" + if item["rationale"] is not None and not isinstance(item["rationale"], str): + return f"claim[{i}].rationale must be a string or null" + if not isinstance(item["label"], str): + return f"claim[{i}].label must be a string" + return None diff --git a/phase4_grounding/grounding/prompt.py b/phase4_grounding/grounding/prompt.py new file mode 100644 index 0000000..bc61761 --- /dev/null +++ b/phase4_grounding/grounding/prompt.py @@ -0,0 +1,49 @@ +"""Render the claim-decomposition judge prompt for a single sample row. + +The template lives at `phase4_grounding/prompts/claim_decomp.txt` and uses +`{{NAME}}`-style placeholders so the embedded JSON example does not collide +with `str.format`'s `{` / `}` syntax. +""" +from __future__ import annotations + +from pathlib import Path + +from .models import EvidenceItem, SampleRow + +_DEFAULT_TEMPLATE_PATH = ( + Path(__file__).resolve().parents[1] / "prompts" / "claim_decomp.txt" +) + + +class PromptBuilder: + """Render the claim-decomposition prompt for a `SampleRow`.""" + + def __init__(self, template_path: str | Path = _DEFAULT_TEMPLATE_PATH) -> None: + self.template_path = Path(template_path) + self._template: str | None = None + + def _load(self) -> str: + if self._template is None: + self._template = self.template_path.read_text() + return self._template + + @staticmethod + def _render_evidence(items: tuple[EvidenceItem, ...]) -> str: + if not items: + return "(no evidence sentences attached)" + return "\n".join(f"[E{e.id}] {e.text}" for e in items) + + def build(self, row: SampleRow) -> str: + template = self._load() + replacements = { + "{{COMPOUND_NAME}}": row.compound.name, + "{{SMILES}}": row.compound.smiles, + "{{MOLECULAR_FORMULA}}": row.compound.molecular_formula, + "{{QUESTION}}": row.question, + "{{PHASE2_ANSWER}}": row.phase2_answer, + "{{EVIDENCE_BLOCK}}": self._render_evidence(row.evidence_attached), + } + out = template + for key, value in replacements.items(): + out = out.replace(key, value) + return out diff --git a/phase4_grounding/grounding/reporter.py b/phase4_grounding/grounding/reporter.py new file mode 100644 index 0000000..60c7dfa --- /dev/null +++ b/phase4_grounding/grounding/reporter.py @@ -0,0 +1,202 @@ +"""Markdown writers for the dual grounding-summary view. + +`Reporter(metrics_keep, metrics_drop, judged_qas).write(out_dir)` writes: +- `grounding_summary_keep_structural.md` +- `grounding_summary_drop_structural.md` + +Each summary cites the other so the reader can compare the two views. The +decision rule (>20% / 10–20% / <10% UNSUPPORTED) is printed at the top. +""" +from __future__ import annotations + +from pathlib import Path + +from .models import JudgedQA, ViewMetrics + +_KEEP_FILE = "grounding_summary_keep_structural.md" +_DROP_FILE = "grounding_summary_drop_structural.md" + + +def _decision(unsupported_rate: float) -> str: + pct = unsupported_rate * 100 + if pct > 20: + return ( + f"**Decision: NARROW.** UNSUPPORTED = {pct:.1f}% > 20%. Narrow the " + "paper's grounding claim; flag training-recall risk in DATASHEET / " + "RESPONSIBLE_AI." + ) + if pct < 10: + return ( + f"**Decision: WELL-BEHAVED.** UNSUPPORTED = {pct:.1f}% < 10%. The " + "soft rule looks safe; quote this number in the dataset card." + ) + return ( + f"**Decision: CAVEAT.** UNSUPPORTED = {pct:.1f}% (10–20%). Add a caveat " + "to DATASHEET / RESPONSIBLE_AI; do not claim full grounding." + ) + + +def _fmt_pct(x: float) -> str: + return f"{x * 100:.2f}%" + + +def _fmt_count_pct(count: int, total: int) -> str: + rate = (count / total) if total else 0.0 + return f"{count} ({rate * 100:.2f}%)" + + +class Reporter: + """Renders both summary markdown files from a pair of `ViewMetrics`.""" + + def __init__( + self, + metrics_keep: ViewMetrics, + metrics_drop: ViewMetrics, + judged_qas: list[JudgedQA] | tuple[JudgedQA, ...], + ) -> None: + if metrics_keep.view != "keep": + raise ValueError("metrics_keep must be a 'keep' view") + if metrics_drop.view != "drop": + raise ValueError("metrics_drop must be a 'drop' view") + self.metrics_keep = metrics_keep + self.metrics_drop = metrics_drop + self.judged_qas = tuple(judged_qas) + + def write(self, out_dir: Path) -> tuple[Path, Path]: + out_dir = Path(out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + keep_path = out_dir / _KEEP_FILE + drop_path = out_dir / _DROP_FILE + + keep_path.write_text( + self._render( + metrics=self.metrics_keep, + title="Grounding Summary — KEEP STRUCTURAL view", + description=( + "STRUCTURAL claims (derivable from SMILES / formula alone) are " + "kept as their own bucket and **excluded from the denominator** " + "for STATED / IMPLIED / UNSUPPORTED rates. This view treats " + "UNSUPPORTED as a clean proxy for training-recall risk." + ), + cross_link=f"See also `{_DROP_FILE}` for the alternate view.", + ) + ) + + drop_path.write_text( + self._render( + metrics=self.metrics_drop, + title="Grounding Summary — DROP STRUCTURAL view", + description=( + "STRUCTURAL claims are collapsed into IMPLIED — i.e. SMILES / " + "formula are treated as 'evidence' in a loose sense. Closer to " + "the original PLAN spec." + ), + cross_link=f"See also `{_KEEP_FILE}` for the alternate view.", + ) + ) + + return keep_path, drop_path + + def _render( + self, + *, + metrics: ViewMetrics, + title: str, + description: str, + cross_link: str, + ) -> str: + lo, hi = metrics.unsupported_ci + lines: list[str] = [] + lines.append(f"# {title}") + lines.append("") + lines.append(_decision(metrics.unsupported_rate)) + lines.append("") + lines.append(description) + lines.append("") + lines.append(cross_link) + lines.append("") + lines.append("## Headline metrics") + lines.append("") + lines.append(f"- Total claims (denominator): **{metrics.total_claims}**") + lines.append(f"- Q&A judged: **{len(self.judged_qas)}**") + if metrics.view == "keep": + lines.append( + f"- STRUCTURAL claims (excluded from denominator): " + f"**{metrics.structural_count}**" + ) + for label, count in metrics.counts.items(): + lines.append( + f"- {label}: {_fmt_count_pct(count, metrics.total_claims)}" + ) + lines.append(f"- Grounded (STATED + IMPLIED): **{_fmt_pct(metrics.grounded_rate)}**") + lines.append( + f"- UNSUPPORTED: **{_fmt_pct(metrics.unsupported_rate)}** " + f"(95% Wilson CI: {_fmt_pct(lo)} – {_fmt_pct(hi)})" + ) + lines.append("") + + lines.extend(self._breakdown_table("By topic", metrics.by_topic, sort_keys=True)) + lines.extend( + self._breakdown_table( + "By evidence_ids non-empty", + {str(k): v for k, v in metrics.by_evidence_ids_nonempty.items()}, + sort_keys=False, + ) + ) + lines.extend(self._breakdown_table("By split", metrics.by_split, sort_keys=True)) + + lines.append("## Per-Q&A UNSUPPORTED histogram") + lines.append("") + lines.append("| UNSUPPORTED claims in QA | # of QAs |") + lines.append("|---|---|") + for k in sorted(metrics.per_qa_unsupported_histogram): + lines.append(f"| {k} | {metrics.per_qa_unsupported_histogram[k]} |") + lines.append("") + + lines.append(f"## Top {len(metrics.top_qa_by_unsupported)} Q&A by UNSUPPORTED rate") + lines.append("") + if metrics.top_qa_by_unsupported: + lines.append( + "| cid | qa_index | topic | split | UNSUPPORTED | total | rate |" + ) + lines.append("|---|---|---|---|---|---|---|") + for r in metrics.top_qa_by_unsupported: + lines.append( + f"| {r['cid']} | {r['qa_index']} | {r['topic']} | " + f"{r['split']} | {r['unsupported']} | " + f"{r['total_claims']} | {_fmt_pct(r['unsupported_rate'])} |" + ) + else: + lines.append("(no Q&A judged)") + lines.append("") + + return "\n".join(lines) + + @staticmethod + def _breakdown_table( + title: str, data: dict, *, sort_keys: bool + ) -> list[str]: + if not data: + return [f"## {title}", "", "(no data)", ""] + keys = sorted(data.keys()) if sort_keys else list(data.keys()) + sample = data[keys[0]] + label_keys = [k for k in sample if not k.endswith("_rate") and k != "total"] + out = [f"## {title}", ""] + header = ( + "| key | total | " + + " | ".join(label_keys) + + " | " + + " | ".join(f"{lbl}%" for lbl in label_keys) + + " |" + ) + out.append(header) + out.append("|" + "---|" * (2 + 2 * len(label_keys))) + for key in keys: + row = data[key] + cells = [f"{key}", f"{row['total']}"] + cells.extend(str(row.get(lbl, 0)) for lbl in label_keys) + cells.extend(_fmt_pct(row.get(f"{lbl}_rate", 0.0)) for lbl in label_keys) + out.append("| " + " | ".join(cells) + " |") + out.append("") + return out diff --git a/phase4_grounding/grounding/sampling.py b/phase4_grounding/grounding/sampling.py new file mode 100644 index 0000000..e0da101 --- /dev/null +++ b/phase4_grounding/grounding/sampling.py @@ -0,0 +1,187 @@ +"""Stratified sampling of functional Q&A pairs from the gold dataset. + +The sampler: +- Filters QA where `bucket_topic(topic) == 'functional'`. +- Allocates the requested N across topic strata using a fixed weighting that + emphasizes the four headline topics (mechanism, engineering, metabolism, toxicity) + plus therapeutic_use, with all remaining functional topics in an "other" bucket. +- Within each topic, splits 50/50 between QA with non-empty `evidence_ids` and + QA with empty `evidence_ids`. Falls back to the available pool when one side + is exhausted, preserving the requested per-topic count. +- Is deterministic for a given seed. + +Returns a list of SampleRow with a placeholder for `evidence_attached`; the actual +evidence selection is the EvidenceAttacher's job (Step 3). +""" +from __future__ import annotations + +import json +import random +from collections.abc import Iterable +from dataclasses import replace +from pathlib import Path +from typing import Mapping + +from .models import Compound, EvidenceItem, SampleRow +from .topic_bucket import bucket_topic + +DEFAULT_TOPIC_WEIGHTS: Mapping[str, float] = { + "mechanism": 0.20, + "engineering": 0.20, + "metabolism": 0.17, + "toxicity": 0.17, + "therapeutic_use": 0.13, + "_other": 0.13, +} + +HEADLINE_TOPICS = frozenset({"mechanism", "engineering", "metabolism", "toxicity", "therapeutic_use"}) + + +class Sampler: + """Stratified sampler over functional Q&A pairs.""" + + def __init__( + self, + dataset_path: str | Path, + seed: int = 0, + topic_weights: Mapping[str, float] = DEFAULT_TOPIC_WEIGHTS, + ) -> None: + self.dataset_path = Path(dataset_path) + self.seed = seed + self._weights = dict(topic_weights) + self._validate_weights() + + def _validate_weights(self) -> None: + total = sum(self._weights.values()) + if abs(total - 1.0) > 1e-6: + raise ValueError(f"topic weights must sum to 1.0, got {total}") + for k, v in self._weights.items(): + if v < 0: + raise ValueError(f"weight for {k} is negative: {v}") + + def _iter_records(self) -> Iterable[dict]: + with self.dataset_path.open() as f: + for line in f: + line = line.strip() + if line: + yield json.loads(line) + + @staticmethod + def _topic_key(topic: str) -> str: + """Return the stratum key: a headline topic name or '_other'.""" + normalized = topic.strip().lower().replace("-", "_") + if normalized in HEADLINE_TOPICS: + return normalized + return "_other" + + def _allocate(self, n: int) -> dict[str, int]: + """Allocate n slots to strata using the configured weights. + + Uses largest-remainder rounding so allocations sum exactly to n. + """ + if n <= 0: + raise ValueError(f"n must be positive, got {n}") + raw = {k: w * n for k, w in self._weights.items()} + floors = {k: int(v) for k, v in raw.items()} + remainder = n - sum(floors.values()) + # distribute remainder to strata with largest fractional parts + fracs = sorted( + ((k, raw[k] - floors[k]) for k in self._weights), + key=lambda kv: kv[1], + reverse=True, + ) + for k, _ in fracs[:remainder]: + floors[k] += 1 + return floors + + def _index_qa(self) -> dict[str, list[tuple[dict, dict]]]: + """Group functional QA by stratum key. + + Returns {stratum_key: [(record, qa_pair), ...]}. + """ + index: dict[str, list[tuple[dict, dict]]] = {k: [] for k in self._weights} + for rec in self._iter_records(): + for qa in rec.get("qa_pairs", []): + if bucket_topic(qa.get("topic")) != "functional": + continue + key = self._topic_key(qa["topic"]) + index.setdefault(key, []).append((rec, qa)) + return index + + @staticmethod + def _split_by_evidence( + items: list[tuple[dict, dict]], + ) -> tuple[list[tuple[dict, dict]], list[tuple[dict, dict]]]: + nonempty = [it for it in items if it[1].get("evidence_ids")] + empty = [it for it in items if not it[1].get("evidence_ids")] + return nonempty, empty + + @staticmethod + def _draw( + rng: random.Random, + pool_a: list[tuple[dict, dict]], + pool_b: list[tuple[dict, dict]], + target_a: int, + target_b: int, + ) -> list[tuple[dict, dict]]: + """Draw target_a from pool_a and target_b from pool_b. If one side is short, + backfill from the other. Pools are sampled without replacement. + """ + rng.shuffle(pool_a) + rng.shuffle(pool_b) + take_a = pool_a[: min(target_a, len(pool_a))] + take_b = pool_b[: min(target_b, len(pool_b))] + # backfill shortfall + short_a = target_a - len(take_a) + short_b = target_b - len(take_b) + if short_a > 0: + extra = pool_b[len(take_b) : len(take_b) + short_a] + take_b = take_b + extra + if short_b > 0: + extra = pool_a[len(take_a) : len(take_a) + short_b] + take_a = take_a + extra + return take_a + take_b + + def sample(self, n: int) -> list[SampleRow]: + rng = random.Random(self.seed) + allocations = self._allocate(n) + index = self._index_qa() + + chosen: list[tuple[dict, dict]] = [] + for stratum, k in allocations.items(): + if k == 0: + continue + pool = index.get(stratum, []) + nonempty, empty = self._split_by_evidence(pool) + half = k // 2 + target_nonempty = half + target_empty = k - half + chosen.extend(self._draw(rng, nonempty, empty, target_nonempty, target_empty)) + + return [self._row(rec, qa) for rec, qa in chosen] + + @staticmethod + def _row(rec: dict, qa: dict) -> SampleRow: + compound = Compound( + cid=int(rec["cid"]), + name=rec.get("name", ""), + smiles=rec.get("smiles", ""), + molecular_formula=rec.get("molecular_formula", ""), + ) + # SampleRow ships without evidence; EvidenceAttacher fills it in Step 3. + return SampleRow( + cid=int(rec["cid"]), + qa_index=int(qa["qa_index"]), + topic=qa["topic"], + split=rec.get("split", ""), + evidence_ids_nonempty=bool(qa.get("evidence_ids")), + compound=compound, + question=qa["question"], + phase2_answer=qa["phase2_answer"], + evidence_attached=(), + ) + + @staticmethod + def with_evidence(row: SampleRow, evidence: tuple[EvidenceItem, ...]) -> SampleRow: + """Convenience for downstream code: clone a row with attached evidence.""" + return replace(row, evidence_attached=evidence) diff --git a/phase4_grounding/grounding/topic_bucket.py b/phase4_grounding/grounding/topic_bucket.py new file mode 100644 index 0000000..226863e --- /dev/null +++ b/phase4_grounding/grounding/topic_bucket.py @@ -0,0 +1,22 @@ +"""Re-export `bucket_topic` from the repo's `scripts/topic_bucket.py` without depending +on `scripts/` being an importable package. + +A single source of truth avoids drift between phase4_grounding and the rest of the +pipeline. The loader uses importlib so this module works whether or not the project +is installed in editable mode. +""" +from __future__ import annotations + +import importlib.util +from pathlib import Path + +_REPO_ROOT = Path(__file__).resolve().parents[2] +_SCRIPT_PATH = _REPO_ROOT / "scripts" / "topic_bucket.py" + +_spec = importlib.util.spec_from_file_location("_topic_bucket_repo", _SCRIPT_PATH) +if _spec is None or _spec.loader is None: + raise ImportError(f"Could not load topic_bucket from {_SCRIPT_PATH}") +_module = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_module) + +bucket_topic = _module.bucket_topic diff --git a/phase4_grounding/prompts/claim_decomp.txt b/phase4_grounding/prompts/claim_decomp.txt new file mode 100644 index 0000000..57cb630 --- /dev/null +++ b/phase4_grounding/prompts/claim_decomp.txt @@ -0,0 +1,42 @@ +You are auditing a chemistry Q&A dataset. Your job is to decompose an answer into atomic claims and label each claim against the cited evidence. + +# Compound +- name: {{COMPOUND_NAME}} +- SMILES: {{SMILES}} +- molecular formula: {{MOLECULAR_FORMULA}} + +# Question +{{QUESTION}} + +# Answer to audit +{{PHASE2_ANSWER}} + +# Evidence +{{EVIDENCE_BLOCK}} + +# Task +Decompose the answer into atomic claims. For each claim, assign exactly one label: + +- STATED: the claim is directly written in some evidence sentence. Set `evidence_id` to the [E#] integer that supports it. +- IMPLIED: the claim is inferable from evidence by one routine domain-reasoning step. Set `evidence_id` to the [E#] integer used and describe the step in `rationale`. +- STRUCTURAL: the claim is derivable from the SMILES, molecular formula, or molecular weight alone, without any evidence sentence. Set `evidence_id` to null. +- UNSUPPORTED: none of the above. Set `evidence_id` to null. This flags a training-recall candidate. + +Atomicity rules: +- One assertion per claim. Split conjunctions ("X and Y") into two claims. +- Numerical values, named entities (gene / protein / enzyme names), and mechanisms are each atomic units. +- Preserve hedging in place ("may", "is thought to", "approximately") inside the claim — do not decompose hedges into separate claims. + +Constraints: +- `rationale` must be at most 25 words. +- `evidence_id` must be one of the [E#] integers shown above, or null. +- Return STRICT JSON only — no commentary, no markdown fences, no trailing text. + +Output schema: + +{"claims":[ + {"claim":"", + "label":"STATED|IMPLIED|UNSUPPORTED|STRUCTURAL", + "evidence_id":, + "rationale":"<<=25 words>"} +]} diff --git a/phase4_grounding/run_phase4_grounding.sh b/phase4_grounding/run_phase4_grounding.sh new file mode 100755 index 0000000..04765e2 --- /dev/null +++ b/phase4_grounding/run_phase4_grounding.sh @@ -0,0 +1,88 @@ +#!/usr/bin/env bash +# Phase 4 grounding-audit orchestrator. +# +# Three-step pipeline (resumable; safe to re-run): +# 1. sample_qa.py +# 2. judge_claims.py +# 3. aggregate.py +# +# Usage: +# bash phase4_grounding/run_phase4_grounding.sh \ +# --n 300 \ +# --api-key-file ~/.openrouter_key \ +# [--skip-cross-check] \ +# [--data-path data/dataset_gold.jsonl] \ +# [--out-dir phase4_grounding/outputs] \ +# [--primary-model anthropic/claude-sonnet-4.6] \ +# [--cross-check-model google/gemini-2.5-pro] \ +# [--cross-check-n 30] \ +# [--max-usd 15] \ +# [--concurrency 5] \ +# [--seed 0] +set -euo pipefail + +N="" +API_KEY_FILE="$HOME/.openrouter_key" +DATA_PATH="data/dataset_gold.jsonl" +OUT_DIR="phase4_grounding/outputs" +PRIMARY_MODEL="anthropic/claude-sonnet-4.6" +CROSS_CHECK_MODEL="google/gemini-2.5-pro" +CROSS_CHECK_N=30 +MAX_USD=15 +CONCURRENCY=5 +SEED=0 +SKIP_CROSS_CHECK=0 + +while [[ $# -gt 0 ]]; do + case "$1" in + --n) N="$2"; shift 2 ;; + --api-key-file) API_KEY_FILE="$2"; shift 2 ;; + --data-path) DATA_PATH="$2"; shift 2 ;; + --out-dir) OUT_DIR="$2"; shift 2 ;; + --primary-model) PRIMARY_MODEL="$2"; shift 2 ;; + --cross-check-model) CROSS_CHECK_MODEL="$2"; shift 2 ;; + --cross-check-n) CROSS_CHECK_N="$2"; shift 2 ;; + --max-usd) MAX_USD="$2"; shift 2 ;; + --concurrency) CONCURRENCY="$2"; shift 2 ;; + --seed) SEED="$2"; shift 2 ;; + --skip-cross-check) SKIP_CROSS_CHECK=1; shift ;; + -h|--help) + sed -n '2,18p' "$0" + exit 0 ;; + *) echo "unknown argument: $1" >&2; exit 2 ;; + esac +done + +if [[ -z "$N" ]]; then + echo "error: --n is required" >&2 + exit 2 +fi + +mkdir -p "$OUT_DIR" + +echo "==> [1/3] sample_qa.py --n $N --seed $SEED" +python -m phase4_grounding.scripts.sample_qa \ + --n "$N" \ + --seed "$SEED" \ + --data-path "$DATA_PATH" \ + --out-dir "$OUT_DIR" + +echo "==> [2/3] judge_claims.py" +xcheck_flag=() +if [[ "$SKIP_CROSS_CHECK" -eq 1 ]]; then + xcheck_flag+=(--skip-cross-check) +fi +python -m phase4_grounding.scripts.judge_claims \ + --api-key-file "$API_KEY_FILE" \ + --primary-model "$PRIMARY_MODEL" \ + --cross-check-model "$CROSS_CHECK_MODEL" \ + --cross-check-n "$CROSS_CHECK_N" \ + --max-usd "$MAX_USD" \ + --concurrency "$CONCURRENCY" \ + --out-dir "$OUT_DIR" \ + "${xcheck_flag[@]}" + +echo "==> [3/3] aggregate.py" +python -m phase4_grounding.scripts.aggregate --out-dir "$OUT_DIR" + +echo "==> done. Outputs under: $OUT_DIR" diff --git a/phase4_grounding/scripts/__init__.py b/phase4_grounding/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/phase4_grounding/scripts/aggregate.py b/phase4_grounding/scripts/aggregate.py new file mode 100644 index 0000000..97f54fa --- /dev/null +++ b/phase4_grounding/scripts/aggregate.py @@ -0,0 +1,89 @@ +"""Aggregate `claims_per_qa.jsonl` into the dual grounding-summary markdown files.""" +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +from phase4_grounding.grounding.aggregator import Aggregator +from phase4_grounding.grounding.models import Claim, JudgedQA +from phase4_grounding.grounding.reporter import Reporter + + +def _build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description="Aggregate judged claims into summaries") + p.add_argument("--out-dir", type=Path, default=Path("phase4_grounding/outputs")) + p.add_argument( + "--claims-file", + type=str, + default="claims_per_qa.jsonl", + help="filename inside --out-dir to read", + ) + return p + + +def load_judged(path: Path) -> list[JudgedQA]: + judged: list[JudgedQA] = [] + with path.open() as f: + for line in f: + line = line.strip() + if not line: + continue + d = json.loads(line) + judged.append(_judged_from_dict(d)) + return judged + + +def _judged_from_dict(d: dict) -> JudgedQA: + claims = tuple( + Claim( + claim=c["claim"], + label=c["label"], + evidence_id=c.get("evidence_id"), + rationale=c.get("rationale", ""), + ) + for c in d.get("claims", []) + ) + usage = d.get("usage") or {} + return JudgedQA( + cid=int(d["cid"]), + qa_index=int(d["qa_index"]), + topic=d["topic"], + evidence_ids_nonempty=bool(d.get("evidence_ids_nonempty", False)), + num_evidence_attached=int(d.get("num_evidence_attached", 0)), + model=d.get("model", ""), + claims=claims, + prompt_tokens=int(usage.get("prompt_tokens", 0)), + completion_tokens=int(usage.get("completion_tokens", 0)), + latency_ms=int(d.get("latency_ms", 0)), + split=d.get("split", ""), + ) + + +def main(argv: list[str] | None = None) -> int: + args = _build_parser().parse_args(argv) + claims_path = args.out_dir / args.claims_file + judged = load_judged(claims_path) + + aggregator = Aggregator(judged) + metrics_keep = aggregator.compute("keep") + metrics_drop = aggregator.compute("drop") + + reporter = Reporter(metrics_keep, metrics_drop, judged) + keep_path, drop_path = reporter.write(args.out_dir) + + print(f"wrote {keep_path}") + print(f"wrote {drop_path}") + print( + f"keep-view UNSUPPORTED rate: {metrics_keep.unsupported_rate * 100:.2f}% " + f"(n_claims={metrics_keep.total_claims})" + ) + print( + f"drop-view UNSUPPORTED rate: {metrics_drop.unsupported_rate * 100:.2f}% " + f"(n_claims={metrics_drop.total_claims})" + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/phase4_grounding/scripts/analyze_cross_check_agreement.py b/phase4_grounding/scripts/analyze_cross_check_agreement.py new file mode 100644 index 0000000..abb0a9d --- /dev/null +++ b/phase4_grounding/scripts/analyze_cross_check_agreement.py @@ -0,0 +1,162 @@ +"""Compute primary (sonnet) vs cross-check (gemini) agreement on the 30 +dual-judged rows from the Phase 4 grounding audit. + +The two models produce independent claim decompositions, so claim-level +1:1 alignment is undefined. Instead we measure rate-level agreement +per Q&A: each model produces a UNSUPPORTED rate and a Grounded rate +for the same row; we compare those. + +Writes a Markdown summary to outputs/cross_check_agreement.md. +""" +from __future__ import annotations + +import json +import statistics +from pathlib import Path + +OUT_DIR = Path(__file__).resolve().parents[1] / "outputs" +PRIMARY = OUT_DIR / "claims_per_qa.jsonl" +GEMINI = OUT_DIR / "claims_per_qa.gemini.jsonl" +REPORT = OUT_DIR / "cross_check_agreement.md" + + +def _load(path): + rows = {} + with path.open() as f: + for line in f: + line = line.strip() + if not line: + continue + d = json.loads(line) + rows[(int(d["cid"]), int(d["qa_index"]))] = d + return rows + + +def _label_counts(claims): + counts = {"STATED": 0, "IMPLIED": 0, "UNSUPPORTED": 0, "STRUCTURAL": 0} + for c in claims: + counts[c["label"]] = counts.get(c["label"], 0) + 1 + return counts + + +def _rates(counts, view): + if view == "keep": + denom = counts["STATED"] + counts["IMPLIED"] + counts["UNSUPPORTED"] + else: + denom = sum(counts.values()) + if denom == 0: + return None, None, denom + if view == "keep": + unsupported = counts["UNSUPPORTED"] / denom + grounded = (counts["STATED"] + counts["IMPLIED"]) / denom + else: + unsupported = counts["UNSUPPORTED"] / denom + grounded = (counts["STATED"] + counts["IMPLIED"] + counts["STRUCTURAL"]) / denom + return unsupported, grounded, denom + + +def main(): + primary = _load(PRIMARY) + gemini = _load(GEMINI) + keys = sorted(set(primary) & set(gemini)) + + rows_keep = [] + rows_drop = [] + primary_models = set() + gemini_models = set() + for k in keys: + p = primary[k] + g = gemini[k] + primary_models.add(p["model"]) + gemini_models.add(g["model"]) + pc = _label_counts(p["claims"]) + gc = _label_counts(g["claims"]) + for view, sink in (("keep", rows_keep), ("drop", rows_drop)): + pu, pg, pd = _rates(pc, view) + gu, gg, gd = _rates(gc, view) + if pu is None or gu is None: + continue + sink.append( + { + "cid": k[0], + "qa": k[1], + "p_unsupported": pu, + "g_unsupported": gu, + "diff": gu - pu, + "p_n": pd, + "g_n": gd, + } + ) + + def _summary(rows, view_name): + diffs = [r["diff"] for r in rows] + abs_diffs = [abs(d) for d in diffs] + p_macro = ( + sum(r["p_unsupported"] * r["p_n"] for r in rows) / sum(r["p_n"] for r in rows) + ) + g_macro = ( + sum(r["g_unsupported"] * r["g_n"] for r in rows) / sum(r["g_n"] for r in rows) + ) + return { + "view": view_name, + "n_rows": len(rows), + "primary_macro_unsupported": p_macro, + "gemini_macro_unsupported": g_macro, + "macro_diff": g_macro - p_macro, + "mean_abs_per_row_diff": statistics.fmean(abs_diffs) if abs_diffs else 0.0, + "median_abs_per_row_diff": statistics.median(abs_diffs) if abs_diffs else 0.0, + "max_abs_per_row_diff": max(abs_diffs) if abs_diffs else 0.0, + "rows_within_10pp": sum(1 for d in abs_diffs if d <= 0.10), + "rows_within_20pp": sum(1 for d in abs_diffs if d <= 0.20), + } + + summ_keep = _summary(rows_keep, "keep-structural") + summ_drop = _summary(rows_drop, "drop-structural") + + lines = [] + lines.append("# Cross-check agreement — primary (sonnet-4.6) vs cross-check (gemini-2.5-pro)\n") + lines.append( + f"Comparison on {len(keys)} Q&A rows judged independently by both models.\n" + ) + lines.append(f"- Primary models present in subset: {sorted(primary_models)}") + lines.append(f"- Cross-check models: {sorted(gemini_models)}\n") + + for s in (summ_keep, summ_drop): + lines.append(f"## {s['view']} view") + lines.append("") + lines.append("| Metric | Value |") + lines.append("|---|---|") + lines.append(f"| n rows compared | {s['n_rows']} |") + lines.append(f"| Primary macro UNSUPPORTED | **{s['primary_macro_unsupported']*100:.2f}%** |") + lines.append(f"| Gemini macro UNSUPPORTED | **{s['gemini_macro_unsupported']*100:.2f}%** |") + lines.append(f"| Macro diff (gemini − primary) | {s['macro_diff']*100:+.2f}pp |") + lines.append(f"| Mean abs per-row diff | {s['mean_abs_per_row_diff']*100:.2f}pp |") + lines.append(f"| Median abs per-row diff | {s['median_abs_per_row_diff']*100:.2f}pp |") + lines.append(f"| Max abs per-row diff | {s['max_abs_per_row_diff']*100:.2f}pp |") + lines.append(f"| Rows within 10pp | {s['rows_within_10pp']} / {s['n_rows']} |") + lines.append(f"| Rows within 20pp | {s['rows_within_20pp']} / {s['n_rows']} |") + lines.append("") + + lines.append("## Per-row UNSUPPORTED rates (keep-structural)") + lines.append("") + lines.append("| cid | qa | primary | gemini | diff |") + lines.append("|---|---|---|---|---|") + for r in sorted(rows_keep, key=lambda r: -abs(r["diff"])): + lines.append( + f"| {r['cid']} | {r['qa']} | {r['p_unsupported']*100:.1f}% " + f"| {r['g_unsupported']*100:.1f}% | {r['diff']*100:+.1f}pp |" + ) + + REPORT.write_text("\n".join(lines) + "\n") + print(f"wrote {REPORT}") + for s in (summ_keep, summ_drop): + print( + f" {s['view']}: primary {s['primary_macro_unsupported']*100:.2f}% | " + f"gemini {s['gemini_macro_unsupported']*100:.2f}% | " + f"diff {s['macro_diff']*100:+.2f}pp | " + f"mean abs per-row {s['mean_abs_per_row_diff']*100:.2f}pp" + ) + + +if __name__ == "__main__": + main() diff --git a/phase4_grounding/scripts/judge_claims.py b/phase4_grounding/scripts/judge_claims.py new file mode 100644 index 0000000..190a32a --- /dev/null +++ b/phase4_grounding/scripts/judge_claims.py @@ -0,0 +1,253 @@ +"""Judge runner entry point: decomposes each sampled phase2 answer into claims. + +Reads `sample.jsonl`, judges each row via `ClaimJudge`, writes successful +rows to `claims_per_qa.jsonl` (resumable) and parse failures to +`claims_per_qa.errors.jsonl`. Aborts on budget-cap violation, preserving +partial output so a re-run picks up where it left off. + +`run_judge_pass` is the testable core — `main` just wires up the live +OpenRouterClient. +""" +from __future__ import annotations + +import argparse +import asyncio +import json +from pathlib import Path +from typing import Protocol + +from phase4_grounding.grounding.judge import ClaimJudge, JudgeError +from phase4_grounding.grounding.models import ( + ChatResult, + Compound, + EvidenceItem, + SampleRow, +) +from phase4_grounding.grounding.openrouter_client import BudgetExceeded, OpenRouterClient +from phase4_grounding.grounding.prompt import PromptBuilder + + +class _ChatClientLike(Protocol): + spend_usd: float + + async def chat(self, *, model: str, prompt: str, **kwargs: object) -> ChatResult: ... + + +def _build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description="Judge sampled QA for grounding audit") + p.add_argument("--api-key-file", type=Path, default=Path("~/.openrouter_key")) + p.add_argument("--primary-model", default="anthropic/claude-sonnet-4.6") + p.add_argument("--cross-check-model", default="google/gemini-2.5-pro") + p.add_argument("--skip-cross-check", action="store_true") + p.add_argument("--cross-check-n", type=int, default=30) + p.add_argument("--max-usd", type=float, default=15.0) + p.add_argument("--concurrency", type=int, default=5) + p.add_argument("--out-dir", type=Path, default=Path("phase4_grounding/outputs")) + return p + + +def row_from_dict(d: dict) -> SampleRow: + return SampleRow( + cid=int(d["cid"]), + qa_index=int(d["qa_index"]), + topic=d["topic"], + split=d["split"], + evidence_ids_nonempty=bool(d["evidence_ids_nonempty"]), + compound=Compound( + cid=int(d["cid"]), + name=d["compound"]["name"], + smiles=d["compound"]["smiles"], + molecular_formula=d["compound"]["molecular_formula"], + ), + question=d["question"], + phase2_answer=d["phase2_answer"], + evidence_attached=tuple( + EvidenceItem( + id=int(e["id"]), + text=e["text"], + pmid=e.get("pmid"), + source=e.get("source"), + ) + for e in d["evidence_attached"] + ), + ) + + +def load_sample(path: Path) -> list[SampleRow]: + rows: list[SampleRow] = [] + with path.open() as f: + for line in f: + line = line.strip() + if line: + rows.append(row_from_dict(json.loads(line))) + return rows + + +def already_judged(path: Path) -> set[tuple[int, int]]: + if not path.exists(): + return set() + done: set[tuple[int, int]] = set() + with path.open() as f: + for line in f: + line = line.strip() + if line: + d = json.loads(line) + done.add((int(d["cid"]), int(d["qa_index"]))) + return done + + +def _judged_to_dict(judged) -> dict: + return { + "cid": judged.cid, + "qa_index": judged.qa_index, + "topic": judged.topic, + "split": judged.split, + "evidence_ids_nonempty": judged.evidence_ids_nonempty, + "num_evidence_attached": judged.num_evidence_attached, + "model": judged.model, + "claims": [ + { + "claim": c.claim, + "label": c.label, + "evidence_id": c.evidence_id, + "rationale": c.rationale, + } + for c in judged.claims + ], + "usage": { + "prompt_tokens": judged.prompt_tokens, + "completion_tokens": judged.completion_tokens, + }, + "latency_ms": judged.latency_ms, + } + + +def _error_to_dict(exc: JudgeError) -> dict: + return { + "cid": exc.cid, + "qa_index": exc.qa_index, + "model": exc.model, + "raw": exc.raw, + "first_error": exc.first_error, + "second_error": exc.second_error, + } + + +async def run_judge_pass( + rows: list[SampleRow], + judge: ClaimJudge, + client: _ChatClientLike, + *, + model: str, + out_path: Path, + err_path: Path, + concurrency: int, +) -> int: + """Judge each row and append its record to out_path (or err_path on parse failure). + + Returns the number of successfully judged rows. Raises `BudgetExceeded` if + the client refuses a call because the cumulative spend cap was crossed; + rows already written before the abort are preserved on disk. + """ + out_path.parent.mkdir(parents=True, exist_ok=True) + err_path.parent.mkdir(parents=True, exist_ok=True) + + sem = asyncio.Semaphore(concurrency) + lock = asyncio.Lock() + success = 0 + + with out_path.open("a") as out_f, err_path.open("a") as err_f: + + async def run_one(row: SampleRow) -> None: + nonlocal success + async with sem: + try: + judged = await judge.judge(row, model=model) + except JudgeError as exc: + async with lock: + err_f.write(json.dumps(_error_to_dict(exc)) + "\n") + err_f.flush() + return + async with lock: + out_f.write(json.dumps(_judged_to_dict(judged)) + "\n") + out_f.flush() + success += 1 + if success % 25 == 0: + print( + f"judged {success} rows | spend ${client.spend_usd:.4f}" + ) + + await asyncio.gather(*(run_one(r) for r in rows)) + + return success + + +async def _run(args: argparse.Namespace) -> int: + args.out_dir.mkdir(parents=True, exist_ok=True) + sample_path = args.out_dir / "sample.jsonl" + primary_out = args.out_dir / "claims_per_qa.jsonl" + primary_err = args.out_dir / "claims_per_qa.errors.jsonl" + + all_rows = load_sample(sample_path) + done = already_judged(primary_out) + pending = [r for r in all_rows if (r.cid, r.qa_index) not in done] + print( + f"primary pass: {len(pending)} pending / {len(all_rows)} total " + f"(model: {args.primary_model})" + ) + + api_key = Path(args.api_key_file).expanduser().read_text().strip() + async with OpenRouterClient( + api_key=api_key, + concurrency=args.concurrency, + max_usd=args.max_usd, + ) as client: + judge = ClaimJudge(client, PromptBuilder()) + try: + await run_judge_pass( + pending, + judge, + client, + model=args.primary_model, + out_path=primary_out, + err_path=primary_err, + concurrency=args.concurrency, + ) + except BudgetExceeded as exc: + print(f"aborting primary pass: {exc}") + + if not args.skip_cross_check: + xcheck_out = args.out_dir / "claims_per_qa.gemini.jsonl" + xcheck_err = args.out_dir / "claims_per_qa.gemini.errors.jsonl" + done_x = already_judged(xcheck_out) + subset_head = all_rows[: args.cross_check_n] + pending_x = [r for r in subset_head if (r.cid, r.qa_index) not in done_x] + print( + f"cross-check pass: {len(pending_x)} pending / {len(subset_head)} " + f"(model: {args.cross_check_model})" + ) + try: + await run_judge_pass( + pending_x, + judge, + client, + model=args.cross_check_model, + out_path=xcheck_out, + err_path=xcheck_err, + concurrency=args.concurrency, + ) + except BudgetExceeded as exc: + print(f"aborting cross-check pass: {exc}") + + print(f"final spend: ${client.spend_usd:.4f}") + + return 0 + + +def main(argv: list[str] | None = None) -> int: + args = _build_parser().parse_args(argv) + return asyncio.run(_run(args)) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/phase4_grounding/scripts/rejudge_errors.py b/phase4_grounding/scripts/rejudge_errors.py new file mode 100644 index 0000000..475639b --- /dev/null +++ b/phase4_grounding/scripts/rejudge_errors.py @@ -0,0 +1,167 @@ +"""One-shot rejudge for rows that landed in claims_per_qa.errors.jsonl. + +Targets two failure modes seen on the 300-item production run: + * `raw=None` from OpenRouter (refusal / empty content) + * truncated JSON (provider-side default max_tokens cutoff) + +Strategy: re-run primary judging on those rows with an explicit +max_tokens=8000 cap. Successful rejudgments are appended to +claims_per_qa.jsonl; the errors file is rewritten with only items that +still failed after this attempt. +""" +from __future__ import annotations + +import asyncio +import json +import sys +from pathlib import Path + +PHASE4_ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(PHASE4_ROOT.parent)) + +from phase4_grounding.grounding.judge import JudgeError +from phase4_grounding.grounding.openrouter_client import OpenRouterClient +from phase4_grounding.grounding.parser import ClaimParser +from phase4_grounding.grounding.prompt import PromptBuilder +from phase4_grounding.scripts.judge_claims import ( + _error_to_dict, + _judged_to_dict, + row_from_dict, +) +from phase4_grounding.grounding.models import JudgedQA + +OUT_DIR = PHASE4_ROOT / "outputs" +SAMPLE = OUT_DIR / "sample.jsonl" +PRIMARY_OUT = OUT_DIR / "claims_per_qa.jsonl" +ERRORS = OUT_DIR / "claims_per_qa.errors.jsonl" +KEY_PATH = Path("~/.openrouter_key").expanduser() +MODEL = "google/gemini-2.5-pro" +MAX_TOKENS = 8000 +CONCURRENCY = 4 + + +async def _judge_once(client, prompt_builder, parser, row, model, max_tokens): + prompt = prompt_builder.build(row) + attached_ids = {e.id for e in row.evidence_attached} + first = await client.chat(model=model, prompt=prompt, max_tokens=max_tokens) + first_parsed = parser.parse(first.text, attached_ids) + if first_parsed.ok: + return _make_judged(row, model, first_parsed, [first]), None + + retry_prompt = prompt + ( + "\n\nYour previous output was not valid JSON. Return only the JSON " + "object — no markdown fences, no commentary." + ) + second = await client.chat(model=model, prompt=retry_prompt, max_tokens=max_tokens) + second_parsed = parser.parse(second.text, attached_ids) + if second_parsed.ok: + return _make_judged(row, model, second_parsed, [first, second]), None + + err = JudgeError( + "rejudge: failed JSON parse on both attempts", + cid=row.cid, + qa_index=row.qa_index, + model=model, + raw=second.text, + first_error=first_parsed.error, + second_error=second_parsed.error, + ) + return None, err + + +def _make_judged(row, model, parsed, chats) -> JudgedQA: + return JudgedQA( + cid=row.cid, + qa_index=row.qa_index, + topic=row.topic, + evidence_ids_nonempty=row.evidence_ids_nonempty, + num_evidence_attached=len(row.evidence_attached), + model=model, + claims=parsed.claims, + prompt_tokens=sum(c.prompt_tokens for c in chats), + completion_tokens=sum(c.completion_tokens for c in chats), + latency_ms=sum(c.latency_ms for c in chats), + split=row.split, + ) + + +async def main() -> int: + error_keys = [] + with ERRORS.open() as f: + for line in f: + line = line.strip() + if not line: + continue + d = json.loads(line) + error_keys.append((d["cid"], d["qa_index"])) + + if not error_keys: + print("no errored rows — nothing to do") + return 0 + + error_set = set(error_keys) + print(f"rejudging {len(error_set)} rows: {sorted(error_set)}") + + rows_by_key = {} + with SAMPLE.open() as f: + for line in f: + line = line.strip() + if not line: + continue + d = json.loads(line) + key = (int(d["cid"]), int(d["qa_index"])) + if key in error_set: + rows_by_key[key] = row_from_dict(d) + + missing = error_set - rows_by_key.keys() + if missing: + print(f"WARNING: {len(missing)} errored rows not in sample.jsonl: {sorted(missing)}") + + rows = [rows_by_key[k] for k in error_set if k in rows_by_key] + + api_key = KEY_PATH.read_text().strip() + + builder = PromptBuilder() + parser = ClaimParser() + sem = asyncio.Semaphore(CONCURRENCY) + + successes: list = [] + failures: list = [] + + async with OpenRouterClient(api_key=api_key, concurrency=CONCURRENCY, max_usd=2.0) as client: + async def go(row): + async with sem: + return await _judge_once(client, builder, parser, row, MODEL, MAX_TOKENS) + + results = await asyncio.gather(*(go(r) for r in rows), return_exceptions=True) + + for row, result in zip(rows, results): + if isinstance(result, Exception): + print(f" {row.cid}/{row.qa_index}: EXCEPTION {type(result).__name__}: {result}") + continue + judged, err = result + if judged is not None: + successes.append(judged) + print(f" {row.cid}/{row.qa_index}: OK ({len(judged.claims)} claims)") + else: + failures.append(err) + print(f" {row.cid}/{row.qa_index}: STILL FAILED ({err.first_error}; {err.second_error})") + + print(f"\nrejudge spend: ${client.spend_usd:.4f}") + + if successes: + with PRIMARY_OUT.open("a") as f: + for j in successes: + f.write(json.dumps(_judged_to_dict(j)) + "\n") + + with ERRORS.open("w") as f: + for err in failures: + f.write(json.dumps(_error_to_dict(err)) + "\n") + + print(f"\nappended {len(successes)} rows to {PRIMARY_OUT.name}") + print(f"rewrote {ERRORS.name} with {len(failures)} remaining errors") + return 0 + + +if __name__ == "__main__": + raise SystemExit(asyncio.run(main())) diff --git a/phase4_grounding/scripts/sample_qa.py b/phase4_grounding/scripts/sample_qa.py new file mode 100644 index 0000000..3d96cef --- /dev/null +++ b/phase4_grounding/scripts/sample_qa.py @@ -0,0 +1,104 @@ +"""Stratified sampling entry point: writes `sample.jsonl` under --out-dir. + +Composes the tested library modules `Sampler` + `EvidenceAttacher`. Output +layout matches PLAN §Step 1 so the judge runner and aggregator can consume +it without reopening the gold dataset. +""" +from __future__ import annotations + +import argparse +import json +from dataclasses import replace +from pathlib import Path + +from phase4_grounding.grounding.evidence import EvidenceAttacher +from phase4_grounding.grounding.models import SampleRow +from phase4_grounding.grounding.sampling import Sampler + + +def _build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description="Sample functional QA for grounding audit") + p.add_argument("--n", type=int, required=True, help="total sample size") + p.add_argument("--seed", type=int, default=0) + p.add_argument( + "--data-path", + type=Path, + default=Path("data/dataset_gold.jsonl"), + help="input dataset (JSONL)", + ) + p.add_argument( + "--out-dir", + type=Path, + default=Path("phase4_grounding/outputs"), + help="directory to write sample.jsonl", + ) + return p + + +def _index_records(dataset_path: Path) -> dict[int, dict]: + idx: dict[int, dict] = {} + with dataset_path.open() as f: + for line in f: + line = line.strip() + if not line: + continue + rec = json.loads(line) + idx[int(rec["cid"])] = rec + return idx + + +def _attach_evidence(row: SampleRow, rec: dict) -> SampleRow: + qa = next( + (q for q in rec.get("qa_pairs", []) if int(q.get("qa_index")) == row.qa_index), + None, + ) + if qa is None: + raise RuntimeError( + f"qa_index {row.qa_index} not found in cid={row.cid}" + ) + attached = EvidenceAttacher.attach(qa, rec) + return replace(row, evidence_attached=attached) + + +def row_to_dict(row: SampleRow) -> dict: + return { + "cid": row.cid, + "qa_index": row.qa_index, + "topic": row.topic, + "split": row.split, + "evidence_ids_nonempty": row.evidence_ids_nonempty, + "compound": { + "name": row.compound.name, + "smiles": row.compound.smiles, + "molecular_formula": row.compound.molecular_formula, + }, + "question": row.question, + "phase2_answer": row.phase2_answer, + "evidence_attached": [ + {"id": e.id, "text": e.text, "pmid": e.pmid, "source": e.source} + for e in row.evidence_attached + ], + } + + +def main(argv: list[str] | None = None) -> int: + args = _build_parser().parse_args(argv) + args.out_dir.mkdir(parents=True, exist_ok=True) + + sampler = Sampler(dataset_path=args.data_path, seed=args.seed) + rows = sampler.sample(args.n) + + records = _index_records(args.data_path) + attached_rows = [_attach_evidence(row, records[row.cid]) for row in rows] + + out_path = args.out_dir / "sample.jsonl" + with out_path.open("w") as f: + for row in attached_rows: + f.write(json.dumps(row_to_dict(row)) + "\n") + + print(f"wrote {len(attached_rows)} rows to {out_path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/phase4_grounding/tests/__init__.py b/phase4_grounding/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/phase4_grounding/tests/conftest.py b/phase4_grounding/tests/conftest.py new file mode 100644 index 0000000..b5a5263 --- /dev/null +++ b/phase4_grounding/tests/conftest.py @@ -0,0 +1,70 @@ +"""Shared fixtures for phase4_grounding tests. + +Notes: +- `tiny_dataset_path` returns the path to the hand-crafted JSONL. +- `tiny_dataset_records` returns the parsed dataset as a list of dicts. +- `fake_openrouter_client` returns a scriptable async client that never hits the network. +""" +from __future__ import annotations + +import json +import sys +from collections.abc import Iterable +from pathlib import Path + +import pytest + +# Make the project root importable so `from phase4_grounding.grounding... import` works +# regardless of how pytest is invoked. +_ROOT = Path(__file__).resolve().parents[2] +if str(_ROOT) not in sys.path: + sys.path.insert(0, str(_ROOT)) + +from phase4_grounding.grounding.models import ChatResult # noqa: E402 + + +@pytest.fixture +def tiny_dataset_path() -> Path: + return Path(__file__).parent / "data" / "tiny_dataset.jsonl" + + +@pytest.fixture +def tiny_dataset_records(tiny_dataset_path: Path) -> list[dict]: + with tiny_dataset_path.open() as f: + return [json.loads(line) for line in f if line.strip()] + + +@pytest.fixture +def tmp_out_dir(tmp_path: Path) -> Path: + out = tmp_path / "outputs" + out.mkdir() + return out + + +class FakeOpenRouterClient: + """Scriptable async stand-in for OpenRouterClient. + + Construct with a list of responses; each call to `chat` pops the next. + A response can be a ChatResult (success) or an Exception instance (raised). + """ + + def __init__(self, responses: Iterable[ChatResult | Exception]) -> None: + self._responses = list(responses) + self.calls: list[dict] = [] + + async def chat(self, *, model: str, prompt: str, **kwargs) -> ChatResult: + self.calls.append({"model": model, "prompt": prompt, "kwargs": kwargs}) + if not self._responses: + raise AssertionError("FakeOpenRouterClient ran out of scripted responses") + nxt = self._responses.pop(0) + if isinstance(nxt, Exception): + raise nxt + return nxt + + +@pytest.fixture +def fake_openrouter_client(): + def _factory(responses: Iterable[ChatResult | Exception]) -> FakeOpenRouterClient: + return FakeOpenRouterClient(responses) + + return _factory diff --git a/phase4_grounding/tests/data/tiny_dataset.jsonl b/phase4_grounding/tests/data/tiny_dataset.jsonl new file mode 100644 index 0000000..e15fb76 --- /dev/null +++ b/phase4_grounding/tests/data/tiny_dataset.jsonl @@ -0,0 +1,3 @@ +{"cid": 1001, "split": "train", "name": "Aspirin", "iupac_name": "2-acetoxybenzoic acid", "smiles": "CC(=O)Oc1ccccc1C(=O)O", "molecular_formula": "C9H8O4", "molecular_weight": 180.16, "inchi_key": "BSYNRYMUTXBXSQ-UHFFFAOYSA-N", "num_pmids": 2, "num_synonyms": 5, "num_evidence_sentences": 3, "evidence_sentences": [{"id": 1, "pmid": "11111111", "source": "abstract", "text": "[COMPOUND] irreversibly acetylates serine 530 of cyclooxygenase-1 (COX-1)."}, {"id": 2, "pmid": "11111111", "source": "abstract", "text": "Inhibition of COX-1 by [COMPOUND] reduces thromboxane A2 production in platelets."}, {"id": 3, "pmid": "22222222", "source": "abstract", "text": "[COMPOUND] is hydrolyzed by carboxylesterases to salicylic acid in the liver."}], "qa_pairs": [{"qa_index": 1, "topic": "mechanism", "question": "How does aspirin inhibit COX-1?", "phase1_answer": "Aspirin acetylates a serine residue in COX-1.", "phase2_answer": "Aspirin acts by irreversibly acetylating serine 530 of cyclooxygenase-1 (COX-1), which blocks the enzyme's active site. This inhibition prevents conversion of arachidonic acid to prostaglandins. In platelets, the resulting reduction in thromboxane A2 explains aspirin's antiplatelet activity. The acetyl group is transferred from the acetyloxy substituent on the benzoic-acid scaffold.", "verdict": "agree", "judge_reasoning": "Both answers identify the acetylation mechanism.", "evidence_ids": [1, 2]}, {"qa_index": 2, "topic": "metabolism", "question": "How is aspirin metabolized?", "phase1_answer": "It is hydrolyzed to salicylic acid.", "phase2_answer": "Aspirin is rapidly hydrolyzed by hepatic carboxylesterases to salicylic acid. Salicylic acid is further conjugated with glycine to form salicyluric acid, the major urinary metabolite. A minor fraction undergoes glucuronidation. Renal clearance is pH-dependent due to the ionizable carboxylic acid.", "verdict": "agree", "judge_reasoning": "Both identify hydrolysis to salicylate.", "evidence_ids": [3]}, {"qa_index": 3, "topic": "engineering", "question": "How could you redesign aspirin to reduce gastric irritation?", "phase1_answer": "Mask the carboxylic acid as a prodrug.", "phase2_answer": "Reducing direct exposure of the gastric mucosa to the free carboxylic acid is the standard lever. An ester or amide prodrug of the carboxyl group, hydrolyzed systemically, lowers local irritation. Enteric coating is a formulation alternative. The acetyloxy group should be preserved since acetylation of COX-1 is the active mechanism.", "verdict": "agree", "judge_reasoning": "Both suggest masking the acid.", "evidence_ids": []}]} +{"cid": 1002, "split": "val", "name": "Caffeine", "iupac_name": "1,3,7-trimethylpurine-2,6-dione", "smiles": "Cn1cnc2c1c(=O)n(C)c(=O)n2C", "molecular_formula": "C8H10N4O2", "molecular_weight": 194.19, "inchi_key": "RYYVLZVUVIJVGH-UHFFFAOYSA-N", "num_pmids": 1, "num_synonyms": 12, "num_evidence_sentences": 2, "evidence_sentences": [{"id": 1, "pmid": "33333333", "source": "abstract", "text": "[COMPOUND] is a competitive antagonist at adenosine A1 and A2A receptors."}, {"id": 2, "pmid": "33333333", "source": "abstract", "text": "CYP1A2 is the primary enzyme responsible for [COMPOUND] N-3 demethylation in humans."}], "qa_pairs": [{"qa_index": 1, "topic": "mechanism", "question": "What is caffeine's primary mechanism of action?", "phase1_answer": "It blocks adenosine receptors.", "phase2_answer": "Caffeine is a competitive antagonist at adenosine A1 and A2A receptors. Blocking A2A in the striatum increases dopaminergic transmission, which underlies its psychostimulant effect. The trimethylxanthine scaffold is essential for adenosine-receptor affinity.", "verdict": "agree", "judge_reasoning": "Both identify adenosine antagonism.", "evidence_ids": [1]}, {"qa_index": 2, "topic": "toxicity", "question": "What is the main toxicity concern at high doses?", "phase1_answer": "Cardiovascular effects.", "phase2_answer": "At supratherapeutic doses, caffeine causes tachyarrhythmias and seizures, mediated by combined adenosine-receptor antagonism and phosphodiesterase inhibition. Lethal dose in adults is approximately 150 mg/kg.", "verdict": "agree", "judge_reasoning": "Both flag cardiovascular toxicity.", "evidence_ids": []}, {"qa_index": 3, "topic": "metabolism", "question": "Which CYP enzyme demethylates caffeine?", "phase1_answer": "CYP1A2.", "phase2_answer": "CYP1A2 catalyzes the N-3 demethylation of caffeine to paraxanthine, which accounts for roughly 80% of caffeine clearance in humans. CYP2E1 contributes a minor fraction.", "verdict": "agree", "judge_reasoning": "Both name CYP1A2.", "evidence_ids": [2]}]} +{"cid": 1003, "split": "test", "name": "Methanol", "iupac_name": "methanol", "smiles": "CO", "molecular_formula": "CH4O", "molecular_weight": 32.04, "inchi_key": "OKKJLVBELUTLKV-UHFFFAOYSA-N", "num_pmids": 1, "num_synonyms": 3, "num_evidence_sentences": 1, "evidence_sentences": [{"id": 1, "pmid": "44444444", "source": "abstract", "text": "Alcohol dehydrogenase oxidizes [COMPOUND] to formaldehyde, which aldehyde dehydrogenase further converts to formic acid."}], "qa_pairs": [{"qa_index": 1, "topic": "toxicity", "question": "Why is methanol toxic?", "phase1_answer": "It is metabolized to formate.", "phase2_answer": "Methanol toxicity is driven by its hepatic oxidation: alcohol dehydrogenase converts it to formaldehyde, then aldehyde dehydrogenase to formic acid. Formate accumulation causes metabolic acidosis and inhibits cytochrome c oxidase, producing optic-nerve injury. Ethanol is a competitive antidote at ADH.", "verdict": "agree", "judge_reasoning": "Both identify formate as the toxic metabolite.", "evidence_ids": [1]}, {"qa_index": 2, "topic": "engineering", "question": "Could a structural analog avoid this toxicity?", "phase1_answer": "Yes, blocking ADH oxidation would help.", "phase2_answer": "Replacing the alcohol with a sterically hindered tertiary alcohol would prevent ADH oxidation, avoiding formate generation. The single-carbon scaffold of methanol leaves no room for such substitution, so an analog necessarily becomes a different compound class.", "verdict": "agree", "judge_reasoning": "Both note ADH avoidance.", "evidence_ids": []}]} diff --git a/phase4_grounding/tests/test_aggregator.py b/phase4_grounding/tests/test_aggregator.py new file mode 100644 index 0000000..5b9c358 --- /dev/null +++ b/phase4_grounding/tests/test_aggregator.py @@ -0,0 +1,157 @@ +"""Tests for grounding.aggregator.Aggregator and the wilson_ci helper.""" +from __future__ import annotations + +import math + +import pytest + +from phase4_grounding.grounding.aggregator import Aggregator, wilson_ci +from phase4_grounding.grounding.models import Claim, JudgedQA + + +def _qa( + cid: int, + *, + topic: str = "mechanism", + split: str = "train", + evidence_ids_nonempty: bool = True, + labels: list[str], +) -> JudgedQA: + claims = tuple( + Claim(claim=f"c{i}", label=lbl, evidence_id=None, rationale="r") + for i, lbl in enumerate(labels) + ) + return JudgedQA( + cid=cid, + qa_index=1, + topic=topic, + evidence_ids_nonempty=evidence_ids_nonempty, + num_evidence_attached=2, + model="m", + claims=claims, + prompt_tokens=10, + completion_tokens=10, + latency_ms=10, + split=split, + ) + + +def test_wilson_ci_known_reference(): + # 50/100 → centre 0.5, 95% CI ≈ (0.404, 0.596) + lo, hi = wilson_ci(50, 100) + assert math.isclose(lo, 0.4038, abs_tol=1e-3) + assert math.isclose(hi, 0.5962, abs_tol=1e-3) + + +def test_wilson_ci_zero_total_returns_zero_zero(): + assert wilson_ci(0, 0) == (0.0, 0.0) + + +def test_wilson_ci_zero_successes(): + lo, hi = wilson_ci(0, 100) + assert lo == 0.0 + assert hi > 0.0 + + +def test_keep_view_excludes_structural_from_denominator(): + qas = [ + _qa(1, labels=["STATED", "STRUCTURAL", "UNSUPPORTED"]), + _qa(2, labels=["IMPLIED", "STRUCTURAL"]), + ] + metrics = Aggregator(qas).compute("keep") + # 2 STRUCTURAL excluded from denominator (3 remaining) + assert metrics.total_claims == 3 + assert metrics.structural_count == 2 + assert metrics.counts == {"STATED": 1, "IMPLIED": 1, "UNSUPPORTED": 1} + assert metrics.unsupported_rate == pytest.approx(1 / 3) + assert metrics.grounded_rate == pytest.approx(2 / 3) + + +def test_drop_view_collapses_structural_into_implied(): + qas = [ + _qa(1, labels=["STATED", "STRUCTURAL", "UNSUPPORTED"]), + _qa(2, labels=["IMPLIED", "STRUCTURAL"]), + ] + metrics = Aggregator(qas).compute("drop") + # all 5 claims counted; STRUCTURAL → IMPLIED + assert metrics.total_claims == 5 + assert metrics.counts["IMPLIED"] == 3 # 1 IMPLIED + 2 STRUCTURAL + assert metrics.counts["STATED"] == 1 + assert metrics.counts["UNSUPPORTED"] == 1 + assert metrics.unsupported_rate == pytest.approx(1 / 5) + + +def test_unsupported_ci_matches_wilson(): + qas = [_qa(1, labels=["UNSUPPORTED"] * 25 + ["STATED"] * 75)] + metrics = Aggregator(qas).compute("keep") + expected = wilson_ci(25, 100) + assert metrics.unsupported_ci == pytest.approx(expected) + + +def test_by_topic_breakdown_sums_to_total(): + qas = [ + _qa(1, topic="mechanism", labels=["STATED", "STATED"]), + _qa(2, topic="toxicity", labels=["UNSUPPORTED"]), + _qa(3, topic="mechanism", labels=["IMPLIED"]), + ] + metrics = Aggregator(qas).compute("keep") + assert metrics.by_topic["mechanism"]["total"] == 3 + assert metrics.by_topic["toxicity"]["total"] == 1 + sum_total = sum(v["total"] for v in metrics.by_topic.values()) + assert sum_total == metrics.total_claims + + +def test_by_evidence_ids_nonempty_breakdown(): + qas = [ + _qa(1, evidence_ids_nonempty=True, labels=["STATED"]), + _qa(2, evidence_ids_nonempty=False, labels=["UNSUPPORTED", "UNSUPPORTED"]), + ] + metrics = Aggregator(qas).compute("keep") + assert metrics.by_evidence_ids_nonempty[True]["total"] == 1 + assert metrics.by_evidence_ids_nonempty[False]["total"] == 2 + assert metrics.by_evidence_ids_nonempty[False]["UNSUPPORTED"] == 2 + + +def test_by_split_breakdown(): + qas = [ + _qa(1, split="train", labels=["STATED"]), + _qa(2, split="val", labels=["UNSUPPORTED"]), + _qa(3, split="test", labels=["IMPLIED"]), + ] + metrics = Aggregator(qas).compute("keep") + assert set(metrics.by_split) == {"train", "val", "test"} + + +def test_per_qa_unsupported_histogram(): + qas = [ + _qa(1, labels=["STATED"]), # 0 UNSUPPORTED + _qa(2, labels=["UNSUPPORTED"]), # 1 + _qa(3, labels=["UNSUPPORTED", "UNSUPPORTED"]), # 2 + _qa(4, labels=["UNSUPPORTED"]), # 1 + ] + metrics = Aggregator(qas).compute("keep") + assert metrics.per_qa_unsupported_histogram == {0: 1, 1: 2, 2: 1} + + +def test_top_qa_sorted_by_unsupported_rate(): + qas = [ + _qa(1, labels=["STATED", "STATED"]), # rate 0 + _qa(2, labels=["UNSUPPORTED"]), # rate 1.0 + _qa(3, labels=["UNSUPPORTED", "STATED"]), # rate 0.5 + ] + metrics = Aggregator(qas).compute("keep") + top = metrics.top_qa_by_unsupported + assert top[0]["cid"] == 2 + assert top[1]["cid"] == 3 + + +def test_invalid_view_raises(): + with pytest.raises(ValueError): + Aggregator([]).compute("bogus") # type: ignore[arg-type] + + +def test_empty_input_yields_zero_metrics(): + metrics = Aggregator([]).compute("keep") + assert metrics.total_claims == 0 + assert metrics.unsupported_rate == 0.0 + assert metrics.unsupported_ci == (0.0, 0.0) diff --git a/phase4_grounding/tests/test_evidence.py b/phase4_grounding/tests/test_evidence.py new file mode 100644 index 0000000..471107d --- /dev/null +++ b/phase4_grounding/tests/test_evidence.py @@ -0,0 +1,130 @@ +"""Tests for grounding.evidence.EvidenceAttacher.""" +from __future__ import annotations + +from phase4_grounding.grounding.evidence import EvidenceAttacher +from phase4_grounding.grounding.models import EvidenceItem + + +def _compound(*sentences: dict) -> dict: + return {"evidence_sentences": list(sentences)} + + +def _qa(evidence_ids: list[int] | None = None) -> dict: + return {"evidence_ids": list(evidence_ids) if evidence_ids is not None else []} + + +def test_attaches_only_listed_ids_when_evidence_ids_nonempty(): + compound = _compound( + {"id": 1, "text": "alpha", "pmid": "111", "source": "abstract"}, + {"id": 2, "text": "beta", "pmid": "222", "source": "abstract"}, + {"id": 3, "text": "gamma", "pmid": "333", "source": "abstract"}, + ) + qa = _qa([1, 3]) + attached = EvidenceAttacher.attach(qa, compound) + assert [e.text for e in attached] == ["alpha", "gamma"] + + +def test_attaches_all_when_evidence_ids_empty(): + compound = _compound( + {"id": 1, "text": "alpha", "pmid": "111", "source": "abstract"}, + {"id": 2, "text": "beta", "pmid": "222", "source": "abstract"}, + ) + qa = _qa([]) + attached = EvidenceAttacher.attach(qa, compound) + assert [e.text for e in attached] == ["alpha", "beta"] + + +def test_attaches_all_when_evidence_ids_missing(): + compound = _compound({"id": 1, "text": "alpha", "pmid": "111", "source": "abstract"}) + qa: dict = {} # no evidence_ids key at all + attached = EvidenceAttacher.attach(qa, compound) + assert len(attached) == 1 + assert attached[0].text == "alpha" + + +def test_display_ids_are_sequential_starting_at_one(): + compound = _compound( + {"id": 7, "text": "alpha", "pmid": "p", "source": "abstract"}, + {"id": 9, "text": "beta", "pmid": "p", "source": "abstract"}, + {"id": 11, "text": "gamma", "pmid": "p", "source": "abstract"}, + ) + # Non-empty branch: display ids must be 1..N regardless of source ids. + attached = EvidenceAttacher.attach(_qa([9, 11]), compound) + assert [e.id for e in attached] == [1, 2] + # Empty branch: same renumbering rule. + attached_all = EvidenceAttacher.attach(_qa([]), compound) + assert [e.id for e in attached_all] == [1, 2, 3] + + +def test_preserves_order_of_evidence_ids_list(): + compound = _compound( + {"id": 1, "text": "alpha", "pmid": "p", "source": "abstract"}, + {"id": 2, "text": "beta", "pmid": "p", "source": "abstract"}, + {"id": 3, "text": "gamma", "pmid": "p", "source": "abstract"}, + ) + # listed in 3, 1, 2 order — output must follow that order, not source order. + attached = EvidenceAttacher.attach(_qa([3, 1, 2]), compound) + assert [e.text for e in attached] == ["gamma", "alpha", "beta"] + assert [e.id for e in attached] == [1, 2, 3] + + +def test_unknown_evidence_ids_are_skipped(): + compound = _compound( + {"id": 1, "text": "alpha", "pmid": "p", "source": "abstract"}, + {"id": 2, "text": "beta", "pmid": "p", "source": "abstract"}, + ) + attached = EvidenceAttacher.attach(_qa([1, 99, 2]), compound) + assert [e.text for e in attached] == ["alpha", "beta"] + assert [e.id for e in attached] == [1, 2] + + +def test_returns_empty_when_no_sentences_at_all(): + attached = EvidenceAttacher.attach(_qa([]), _compound()) + assert attached == () + + +def test_returns_empty_when_listed_ids_all_unknown(): + compound = _compound({"id": 1, "text": "alpha", "pmid": "p", "source": "abstract"}) + attached = EvidenceAttacher.attach(_qa([42, 43]), compound) + assert attached == () + + +def test_pmid_and_source_are_carried_through(): + compound = _compound( + {"id": 1, "text": "alpha", "pmid": "12345", "source": "fulltext"}, + ) + attached = EvidenceAttacher.attach(_qa([1]), compound) + assert attached[0].pmid == "12345" + assert attached[0].source == "fulltext" + + +def test_pmid_none_when_missing(): + compound = _compound({"id": 1, "text": "alpha", "source": "abstract"}) + attached = EvidenceAttacher.attach(_qa([1]), compound) + assert attached[0].pmid is None + assert attached[0].source == "abstract" + + +def test_returned_items_are_evidence_item_instances(): + compound = _compound({"id": 1, "text": "alpha", "pmid": "p", "source": "abstract"}) + attached = EvidenceAttacher.attach(_qa([1]), compound) + assert all(isinstance(e, EvidenceItem) for e in attached) + + +def test_works_with_tiny_dataset_records(tiny_dataset_records): + """Smoke test against the hand-crafted tiny_dataset.jsonl fixture.""" + aspirin = tiny_dataset_records[0] + assert aspirin["name"] == "Aspirin" + + # qa_index 1 (mechanism) cites evidence_ids [1, 2] + qa1 = aspirin["qa_pairs"][0] + attached = EvidenceAttacher.attach(qa1, aspirin) + assert len(attached) == 2 + assert [e.id for e in attached] == [1, 2] + assert "irreversibly acetylates" in attached[0].text + + # qa_index 3 (engineering) has empty evidence_ids → all 3 sentences + qa3 = aspirin["qa_pairs"][2] + attached_all = EvidenceAttacher.attach(qa3, aspirin) + assert len(attached_all) == 3 + assert [e.id for e in attached_all] == [1, 2, 3] diff --git a/phase4_grounding/tests/test_integration.py b/phase4_grounding/tests/test_integration.py new file mode 100644 index 0000000..9967dee --- /dev/null +++ b/phase4_grounding/tests/test_integration.py @@ -0,0 +1,264 @@ +"""Integration tests for the Step-8 entry-point scripts. + +Exercises `scripts.sample_qa.main` and the core of `scripts.judge_claims` +(`run_judge_pass`) against `tiny_dataset.jsonl`. The judge pass uses a +scripted fake client so no network is touched. +""" +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from phase4_grounding.grounding.judge import ClaimJudge +from phase4_grounding.grounding.models import ChatResult +from phase4_grounding.grounding.prompt import PromptBuilder +from phase4_grounding.scripts import aggregate, judge_claims, sample_qa + + +# Tiny dataset has only 4 functional topics (mechanism/toxicity/engineering/ +# metabolism), so weight all of them equally and zero out the rest to avoid +# "silent shortfall" on under-represented strata. +_EVEN_WEIGHTS = { + "mechanism": 0.25, + "engineering": 0.25, + "metabolism": 0.25, + "toxicity": 0.25, + "therapeutic_use": 0.0, + "_other": 0.0, +} + + +def _good_response(text: str = "STATED X", evidence_id: int | None = 1) -> ChatResult: + payload = { + "claims": [ + { + "claim": text, + "label": "STATED" if evidence_id is not None else "UNSUPPORTED", + "evidence_id": evidence_id, + "rationale": "because the evidence says so", + } + ] + } + return ChatResult( + text=json.dumps(payload), + prompt_tokens=100, + completion_tokens=50, + latency_ms=5, + ) + + +def test_sample_qa_writes_expected_shape( + tiny_dataset_path: Path, tmp_out_dir: Path, monkeypatch: pytest.MonkeyPatch +): + # Use an even-weights sampler so the tiny dataset's 4 topics each get slots. + monkeypatch.setattr(sample_qa, "Sampler", _EvenWeightedSamplerFactory()) + + rc = sample_qa.main( + [ + "--n", + "4", + "--seed", + "0", + "--data-path", + str(tiny_dataset_path), + "--out-dir", + str(tmp_out_dir), + ] + ) + assert rc == 0 + + out_path = tmp_out_dir / "sample.jsonl" + assert out_path.exists() + records = [json.loads(line) for line in out_path.read_text().splitlines() if line] + assert len(records) == 4 + + for r in records: + # Required top-level fields per PLAN §Step 1 + assert set(r) >= { + "cid", + "qa_index", + "topic", + "split", + "evidence_ids_nonempty", + "compound", + "question", + "phase2_answer", + "evidence_attached", + } + assert set(r["compound"]) == {"name", "smiles", "molecular_formula"} + # evidence is renumbered 1..N + for i, ev in enumerate(r["evidence_attached"], start=1): + assert ev["id"] == i + assert "text" in ev + + +class _EvenWeightedSamplerFactory: + """Swap the default weights with the even ones for the integration test.""" + + def __call__(self, *, dataset_path, seed=0, topic_weights=None): + from phase4_grounding.grounding.sampling import Sampler + + return Sampler( + dataset_path=dataset_path, + seed=seed, + topic_weights=_EVEN_WEIGHTS, + ) + + +@pytest.mark.asyncio +async def test_run_judge_pass_writes_claims_and_errors( + tiny_dataset_path: Path, tmp_out_dir: Path, fake_openrouter_client, monkeypatch +): + # 1. Build a sample file via sample_qa + monkeypatch.setattr(sample_qa, "Sampler", _EvenWeightedSamplerFactory()) + sample_qa.main( + [ + "--n", + "4", + "--seed", + "0", + "--data-path", + str(tiny_dataset_path), + "--out-dir", + str(tmp_out_dir), + ] + ) + + # 2. Load sample rows + rows = judge_claims.load_sample(tmp_out_dir / "sample.jsonl") + assert len(rows) == 4 + + # 3. Script the fake client: 3 good, 1 bad (twice) → 3 success + 1 error row + responses = [ + _good_response(), + _good_response(), + _good_response(), + ChatResult(text="not json", prompt_tokens=10, completion_tokens=5, latency_ms=1), + ChatResult(text="still not json", prompt_tokens=10, completion_tokens=5, latency_ms=1), + ] + client = fake_openrouter_client(responses) + client.spend_usd = 0.0 # attribute required by run_judge_pass + judge = ClaimJudge(client, PromptBuilder()) + + out_path = tmp_out_dir / "claims_per_qa.jsonl" + err_path = tmp_out_dir / "claims_per_qa.errors.jsonl" + + # Concurrency 1 keeps the scripted-response order deterministic. + success = await judge_claims.run_judge_pass( + rows, + judge, + client, + model="anthropic/claude-sonnet-4.6", + out_path=out_path, + err_path=err_path, + concurrency=1, + ) + + assert success == 3 + + good_lines = [ + json.loads(line) for line in out_path.read_text().splitlines() if line + ] + assert len(good_lines) == 3 + for r in good_lines: + assert set(r) >= { + "cid", + "qa_index", + "topic", + "evidence_ids_nonempty", + "num_evidence_attached", + "model", + "claims", + "usage", + "latency_ms", + } + assert r["model"] == "anthropic/claude-sonnet-4.6" + assert r["usage"]["prompt_tokens"] == 100 + assert r["claims"][0]["label"] in ("STATED", "IMPLIED", "UNSUPPORTED", "STRUCTURAL") + + err_lines = [ + json.loads(line) for line in err_path.read_text().splitlines() if line + ] + assert len(err_lines) == 1 + assert err_lines[0]["raw"] == "still not json" + assert err_lines[0]["first_error"] is not None + assert err_lines[0]["second_error"] is not None + + # 4. Aggregate: ensure both summary files appear and are non-empty + rc = aggregate.main(["--out-dir", str(tmp_out_dir)]) + assert rc == 0 + keep = tmp_out_dir / "grounding_summary_keep_structural.md" + drop = tmp_out_dir / "grounding_summary_drop_structural.md" + assert keep.exists() and keep.read_text().strip() + assert drop.exists() and drop.read_text().strip() + + +@pytest.mark.asyncio +async def test_run_judge_pass_is_resumable( + tiny_dataset_path: Path, tmp_out_dir: Path, fake_openrouter_client, monkeypatch +): + """Rows already in claims_per_qa.jsonl are not re-judged.""" + monkeypatch.setattr(sample_qa, "Sampler", _EvenWeightedSamplerFactory()) + sample_qa.main( + [ + "--n", + "4", + "--seed", + "0", + "--data-path", + str(tiny_dataset_path), + "--out-dir", + str(tmp_out_dir), + ] + ) + rows = judge_claims.load_sample(tmp_out_dir / "sample.jsonl") + + out_path = tmp_out_dir / "claims_per_qa.jsonl" + err_path = tmp_out_dir / "claims_per_qa.errors.jsonl" + + # Pretend the first two rows were already judged on a prior run. + with out_path.open("w") as f: + for r in rows[:2]: + f.write( + json.dumps( + { + "cid": r.cid, + "qa_index": r.qa_index, + "topic": r.topic, + "evidence_ids_nonempty": r.evidence_ids_nonempty, + "num_evidence_attached": len(r.evidence_attached), + "model": "anthropic/claude-sonnet-4.6", + "claims": [], + "usage": {"prompt_tokens": 0, "completion_tokens": 0}, + "latency_ms": 0, + } + ) + + "\n" + ) + + done = judge_claims.already_judged(out_path) + pending = [r for r in rows if (r.cid, r.qa_index) not in done] + assert len(pending) == 2 + + client = fake_openrouter_client([_good_response(), _good_response()]) + client.spend_usd = 0.0 + judge = ClaimJudge(client, PromptBuilder()) + + success = await judge_claims.run_judge_pass( + pending, + judge, + client, + model="anthropic/claude-sonnet-4.6", + out_path=out_path, + err_path=err_path, + concurrency=1, + ) + + assert success == 2 + # out_path now has 4 total: 2 preexisting + 2 new + all_lines = [line for line in out_path.read_text().splitlines() if line] + assert len(all_lines) == 4 + # fake client saw exactly 2 calls (not 4) + assert len(client.calls) == 2 diff --git a/phase4_grounding/tests/test_judge.py b/phase4_grounding/tests/test_judge.py new file mode 100644 index 0000000..5b0cea1 --- /dev/null +++ b/phase4_grounding/tests/test_judge.py @@ -0,0 +1,178 @@ +"""Tests for grounding.judge.ClaimJudge. + +Uses the in-conftest FakeOpenRouterClient — no network. The fake's `calls` +attribute lets us assert what was sent on each attempt. +""" +from __future__ import annotations + +import json + +import pytest + +from phase4_grounding.grounding.judge import ClaimJudge, JudgeError +from phase4_grounding.grounding.models import ( + ChatResult, + Compound, + EvidenceItem, + SampleRow, +) +from phase4_grounding.grounding.prompt import PromptBuilder + + +def _row(*, evidence: tuple[EvidenceItem, ...] = ()) -> SampleRow: + return SampleRow( + cid=42, + qa_index=3, + topic="mechanism", + split="train", + evidence_ids_nonempty=bool(evidence), + compound=Compound(cid=42, name="Aspirin", smiles="C", molecular_formula="CH4"), + question="How does it work?", + phase2_answer="It does X and Y.", + evidence_attached=evidence, + ) + + +def _chat(text: str, *, prompt_tokens: int = 100, completion_tokens: int = 50) -> ChatResult: + return ChatResult( + text=text, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + latency_ms=12, + ) + + +def _good_json() -> str: + return json.dumps( + { + "claims": [ + {"claim": "X", "label": "STATED", "evidence_id": 1, "rationale": "r"}, + {"claim": "Y", "label": "UNSUPPORTED", "evidence_id": None, "rationale": "r"}, + ] + } + ) + + +@pytest.mark.asyncio +async def test_judge_success_on_first_attempt(fake_openrouter_client): + client = fake_openrouter_client([_chat(_good_json())]) + judge = ClaimJudge(client, PromptBuilder()) + row = _row(evidence=(EvidenceItem(id=1, text="e", pmid="p", source="abstract"),)) + + judged = await judge.judge(row, model="anthropic/claude-sonnet-4.6") + + assert len(judged.claims) == 2 + assert judged.cid == 42 + assert judged.qa_index == 3 + assert judged.topic == "mechanism" + assert judged.evidence_ids_nonempty is True + assert judged.num_evidence_attached == 1 + assert judged.model == "anthropic/claude-sonnet-4.6" + assert judged.prompt_tokens == 100 + assert judged.completion_tokens == 50 + # only one call was made — no retry + assert len(client.calls) == 1 + + +@pytest.mark.asyncio +async def test_judge_retries_once_on_parse_failure_then_succeeds(fake_openrouter_client): + client = fake_openrouter_client( + [ + _chat("```json\n{...}\n```"), # malformed (markdown fence + ellipsis) + _chat(_good_json()), + ] + ) + judge = ClaimJudge(client, PromptBuilder()) + row = _row(evidence=(EvidenceItem(id=1, text="e", pmid="p", source="abstract"),)) + + judged = await judge.judge(row, model="anthropic/claude-sonnet-4.6") + assert len(judged.claims) == 2 + # both calls happened, and the second prompt carried the retry instruction + assert len(client.calls) == 2 + assert "Your previous output was not valid JSON" in client.calls[1]["prompt"] + assert "Your previous output was not valid JSON" not in client.calls[0]["prompt"] + # token tally aggregates both calls + assert judged.prompt_tokens == 200 + assert judged.completion_tokens == 100 + + +@pytest.mark.asyncio +async def test_judge_raises_judge_error_after_two_failures(fake_openrouter_client): + client = fake_openrouter_client( + [ + _chat("not json at all"), + _chat("still not json"), + ] + ) + judge = ClaimJudge(client, PromptBuilder()) + row = _row(evidence=(EvidenceItem(id=1, text="e", pmid="p", source="abstract"),)) + + with pytest.raises(JudgeError) as excinfo: + await judge.judge(row, model="anthropic/claude-sonnet-4.6") + + err = excinfo.value + assert err.cid == 42 + assert err.qa_index == 3 + assert err.model == "anthropic/claude-sonnet-4.6" + assert err.raw == "still not json" + assert err.first_error is not None + assert err.second_error is not None + assert len(client.calls) == 2 + + +@pytest.mark.asyncio +async def test_judge_raises_judge_error_when_evidence_id_unknown_twice(fake_openrouter_client): + """Bogus evidence_id is a parse failure (per parser); the same error twice + should bubble up as JudgeError.""" + bad = json.dumps( + {"claims": [{"claim": "x", "label": "STATED", "evidence_id": 99, "rationale": "r"}]} + ) + client = fake_openrouter_client([_chat(bad), _chat(bad)]) + judge = ClaimJudge(client, PromptBuilder()) + row = _row(evidence=(EvidenceItem(id=1, text="e", pmid="p", source="abstract"),)) + + with pytest.raises(JudgeError) as excinfo: + await judge.judge(row, model="anthropic/claude-sonnet-4.6") + assert "99" in (excinfo.value.second_error or "") + + +@pytest.mark.asyncio +async def test_judge_passes_correct_attached_ids_to_parser(fake_openrouter_client): + """If the model cites E2 and the row has 2 evidence items, parsing must succeed.""" + payload = json.dumps( + { + "claims": [ + {"claim": "x", "label": "STATED", "evidence_id": 2, "rationale": "r"}, + ] + } + ) + client = fake_openrouter_client([_chat(payload)]) + judge = ClaimJudge(client, PromptBuilder()) + row = _row( + evidence=( + EvidenceItem(id=1, text="a", pmid="p", source="abstract"), + EvidenceItem(id=2, text="b", pmid="p", source="abstract"), + ) + ) + judged = await judge.judge(row, model="anthropic/claude-sonnet-4.6") + assert judged.claims[0].evidence_id == 2 + + +@pytest.mark.asyncio +async def test_judge_handles_no_attached_evidence(fake_openrouter_client): + """When no evidence is attached, the judge can still emit STRUCTURAL/UNSUPPORTED claims.""" + payload = json.dumps( + { + "claims": [ + {"claim": "has methyl group", "label": "STRUCTURAL", "evidence_id": None, "rationale": "r"}, + {"claim": "cures cancer", "label": "UNSUPPORTED", "evidence_id": None, "rationale": "r"}, + ] + } + ) + client = fake_openrouter_client([_chat(payload)]) + judge = ClaimJudge(client, PromptBuilder()) + row = _row(evidence=()) # nothing attached + + judged = await judge.judge(row, model="anthropic/claude-sonnet-4.6") + assert judged.num_evidence_attached == 0 + assert {c.label for c in judged.claims} == {"STRUCTURAL", "UNSUPPORTED"} diff --git a/phase4_grounding/tests/test_models.py b/phase4_grounding/tests/test_models.py new file mode 100644 index 0000000..314e84d --- /dev/null +++ b/phase4_grounding/tests/test_models.py @@ -0,0 +1,75 @@ +"""Smoke tests: dataclass construction and tiny_dataset shape.""" +from __future__ import annotations + +import pytest + +from phase4_grounding.grounding.models import ( + LABEL_VALUES, + ChatResult, + Claim, + Compound, + EvidenceItem, + JudgedQA, + SampleRow, +) + + +def test_claim_label_values_exact(): + assert LABEL_VALUES == ("STATED", "IMPLIED", "UNSUPPORTED", "STRUCTURAL") + + +def test_dataclasses_are_frozen(): + c = Claim(claim="x", label="STATED", evidence_id=1, rationale="r") + with pytest.raises(Exception): + c.label = "IMPLIED" # type: ignore[misc] + + +def test_sample_row_construct(): + compound = Compound(cid=1, name="X", smiles="C", molecular_formula="CH4") + row = SampleRow( + cid=1, + qa_index=1, + topic="mechanism", + split="train", + evidence_ids_nonempty=True, + compound=compound, + question="Q?", + phase2_answer="A.", + evidence_attached=(EvidenceItem(id=1, text="t"),), + ) + assert row.compound.smiles == "C" + assert row.evidence_attached[0].id == 1 + + +def test_judged_qa_construct(): + j = JudgedQA( + cid=1, + qa_index=1, + topic="mechanism", + evidence_ids_nonempty=False, + num_evidence_attached=0, + model="m", + claims=(Claim(claim="x", label="UNSUPPORTED", evidence_id=None, rationale=""),), + prompt_tokens=10, + completion_tokens=5, + latency_ms=42, + ) + assert j.claims[0].label == "UNSUPPORTED" + + +def test_chat_result_construct(): + r = ChatResult(text="hi", prompt_tokens=1, completion_tokens=1, latency_ms=1) + assert r.text == "hi" + + +def test_tiny_dataset_loads(tiny_dataset_records): + assert len(tiny_dataset_records) >= 3 + topics = {qa["topic"] for rec in tiny_dataset_records for qa in rec["qa_pairs"]} + assert {"mechanism", "metabolism", "toxicity", "engineering"} <= topics + has_nonempty = any( + qa.get("evidence_ids") for rec in tiny_dataset_records for qa in rec["qa_pairs"] + ) + has_empty = any( + not qa.get("evidence_ids") for rec in tiny_dataset_records for qa in rec["qa_pairs"] + ) + assert has_nonempty and has_empty diff --git a/phase4_grounding/tests/test_openrouter_client.py b/phase4_grounding/tests/test_openrouter_client.py new file mode 100644 index 0000000..8d762fa --- /dev/null +++ b/phase4_grounding/tests/test_openrouter_client.py @@ -0,0 +1,231 @@ +"""Tests for grounding.openrouter_client.OpenRouterClient. + +The HTTP layer is faked with `httpx.MockTransport` and `asyncio.sleep` is +replaced with a recorder so backoff tests are wall-clock free. +""" +from __future__ import annotations + +import json +from collections.abc import Callable + +import httpx +import pytest + +from phase4_grounding.grounding.models import ChatResult +from phase4_grounding.grounding.openrouter_client import ( + DEFAULT_PRICING, + BudgetExceeded, + ModelPricing, + OpenRouterClient, +) + + +def _ok_response(text: str = "hello", prompt_tokens: int = 100, completion_tokens: int = 50) -> dict: + return { + "choices": [{"message": {"content": text}}], + "usage": {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens}, + } + + +def _make_transport(responses: list[httpx.Response]) -> tuple[httpx.MockTransport, list[httpx.Request]]: + """Build a MockTransport that returns the given responses in order. + Returns the transport and a list that records each received request.""" + received: list[httpx.Request] = [] + queue = list(responses) + + def handler(request: httpx.Request) -> httpx.Response: + received.append(request) + if not queue: + raise AssertionError("transport ran out of scripted responses") + return queue.pop(0) + + return httpx.MockTransport(handler), received + + +def _make_recording_sleep() -> tuple[Callable, list[float]]: + durations: list[float] = [] + + async def sleep(d: float) -> None: + durations.append(d) + + return sleep, durations + + +@pytest.mark.asyncio +async def test_cost_tracker_arithmetic_uses_per_million_pricing(): + pricing = {"test/model": ModelPricing(prompt_per_mtok=2.0, completion_per_mtok=4.0)} + transport, _ = _make_transport( + [httpx.Response(200, json=_ok_response(prompt_tokens=1_000_000, completion_tokens=500_000))] + ) + sleep, _ = _make_recording_sleep() + async with OpenRouterClient( + api_key="x", pricing=pricing, transport=transport, sleep=sleep + ) as client: + await client.chat(model="test/model", prompt="hi") + # 1M prompt @ $2 + 0.5M completion @ $4 = $2 + $2 = $4 + assert client.spend_usd == pytest.approx(4.0) + assert client.calls == 1 + + +@pytest.mark.asyncio +async def test_unknown_model_costs_zero(): + transport, _ = _make_transport( + [httpx.Response(200, json=_ok_response(prompt_tokens=1_000, completion_tokens=1_000))] + ) + sleep, _ = _make_recording_sleep() + async with OpenRouterClient(api_key="x", transport=transport, sleep=sleep) as client: + await client.chat(model="some/unknown-model", prompt="hi") + assert client.spend_usd == 0.0 + + +@pytest.mark.asyncio +async def test_budget_exceeded_raised_on_call_after_overspend(): + """Spend may exceed the cap by one call; the next call raises so the + runner can persist the result of the call that pushed over.""" + pricing = {"m": ModelPricing(prompt_per_mtok=10.0, completion_per_mtok=10.0)} + transport, _ = _make_transport( + [ + httpx.Response(200, json=_ok_response(prompt_tokens=1_000_000, completion_tokens=0)), + httpx.Response(200, json=_ok_response()), + ] + ) + sleep, _ = _make_recording_sleep() + async with OpenRouterClient( + api_key="x", pricing=pricing, max_usd=5.0, transport=transport, sleep=sleep + ) as client: + # First call: cost = $10, exceeds cap, but result still returned. + result = await client.chat(model="m", prompt="hi") + assert isinstance(result, ChatResult) + assert client.spend_usd == pytest.approx(10.0) + # Second call: refuses before issuing the request. + with pytest.raises(BudgetExceeded): + await client.chat(model="m", prompt="hi again") + + +@pytest.mark.asyncio +async def test_retry_after_header_is_honored_verbatim(): + transport, _ = _make_transport( + [ + httpx.Response(429, headers={"Retry-After": "7"}, json={"error": "rate limited"}), + httpx.Response(200, json=_ok_response()), + ] + ) + sleep, sleeps = _make_recording_sleep() + async with OpenRouterClient( + api_key="x", transport=transport, sleep=sleep, backoff_base=99.0 + ) as client: + await client.chat(model="anthropic/claude-sonnet-4.6", prompt="hi") + assert sleeps == [7.0] + + +@pytest.mark.asyncio +async def test_exponential_backoff_when_no_retry_after(): + transport, _ = _make_transport( + [ + httpx.Response(500, json={"error": "boom"}), + httpx.Response(503, json={"error": "boom"}), + httpx.Response(200, json=_ok_response()), + ] + ) + sleep, sleeps = _make_recording_sleep() + async with OpenRouterClient( + api_key="x", transport=transport, sleep=sleep, backoff_base=2.0 + ) as client: + await client.chat(model="anthropic/claude-sonnet-4.6", prompt="hi") + # base * 2^0 = 2, base * 2^1 = 4 + assert sleeps == [2.0, 4.0] + + +@pytest.mark.asyncio +async def test_retries_exhausted_raises_http_error(): + responses = [httpx.Response(500, json={"error": "boom"}) for _ in range(6)] + transport, _ = _make_transport(responses) + sleep, _ = _make_recording_sleep() + async with OpenRouterClient( + api_key="x", transport=transport, sleep=sleep, max_retries=3, backoff_base=0.1 + ) as client: + with pytest.raises(httpx.HTTPStatusError): + await client.chat(model="anthropic/claude-sonnet-4.6", prompt="hi") + + +@pytest.mark.asyncio +async def test_4xx_other_than_429_does_not_retry(): + transport, received = _make_transport( + [httpx.Response(400, json={"error": "bad request"})] + ) + sleep, sleeps = _make_recording_sleep() + async with OpenRouterClient(api_key="x", transport=transport, sleep=sleep) as client: + with pytest.raises(httpx.HTTPStatusError): + await client.chat(model="anthropic/claude-sonnet-4.6", prompt="hi") + # exactly one HTTP attempt, no sleeps + assert len(received) == 1 + assert sleeps == [] + + +@pytest.mark.asyncio +async def test_authorization_header_sent(): + transport, received = _make_transport([httpx.Response(200, json=_ok_response())]) + sleep, _ = _make_recording_sleep() + async with OpenRouterClient(api_key="my-secret-key", transport=transport, sleep=sleep) as client: + await client.chat(model="anthropic/claude-sonnet-4.6", prompt="hi") + assert received[0].headers["Authorization"] == "Bearer my-secret-key" + + +@pytest.mark.asyncio +async def test_request_body_carries_model_and_prompt(): + transport, received = _make_transport([httpx.Response(200, json=_ok_response())]) + sleep, _ = _make_recording_sleep() + async with OpenRouterClient(api_key="x", transport=transport, sleep=sleep) as client: + await client.chat( + model="anthropic/claude-sonnet-4.6", + prompt="decompose this answer", + temperature=0.0, + max_tokens=512, + ) + body = json.loads(received[0].content) + assert body["model"] == "anthropic/claude-sonnet-4.6" + assert body["messages"] == [{"role": "user", "content": "decompose this answer"}] + assert body["temperature"] == 0.0 + assert body["max_tokens"] == 512 + + +@pytest.mark.asyncio +async def test_chat_returns_token_usage_in_chat_result(): + transport, _ = _make_transport( + [httpx.Response(200, json=_ok_response(prompt_tokens=42, completion_tokens=17))] + ) + sleep, _ = _make_recording_sleep() + async with OpenRouterClient(api_key="x", transport=transport, sleep=sleep) as client: + result = await client.chat(model="anthropic/claude-sonnet-4.6", prompt="hi") + assert result.prompt_tokens == 42 + assert result.completion_tokens == 17 + assert result.text == "hello" + assert result.latency_ms >= 0 + + +@pytest.mark.asyncio +async def test_default_pricing_includes_audit_models(): + """Sanity: DEFAULT_PRICING has both models named in PLAN.md so the budget + cap is meaningful out of the box.""" + assert "anthropic/claude-sonnet-4.6" in DEFAULT_PRICING + assert "google/gemini-2.5-pro" in DEFAULT_PRICING + for p in DEFAULT_PRICING.values(): + assert p.prompt_per_mtok > 0 + assert p.completion_per_mtok > 0 + + +@pytest.mark.asyncio +async def test_invalid_retry_after_falls_back_to_exponential(): + transport, _ = _make_transport( + [ + httpx.Response(429, headers={"Retry-After": "garbage"}, json={"error": "x"}), + httpx.Response(200, json=_ok_response()), + ] + ) + sleep, sleeps = _make_recording_sleep() + async with OpenRouterClient( + api_key="x", transport=transport, sleep=sleep, backoff_base=3.0 + ) as client: + await client.chat(model="anthropic/claude-sonnet-4.6", prompt="hi") + # falls back to backoff_base * 2^0 + assert sleeps == [3.0] diff --git a/phase4_grounding/tests/test_parser.py b/phase4_grounding/tests/test_parser.py new file mode 100644 index 0000000..2ca5391 --- /dev/null +++ b/phase4_grounding/tests/test_parser.py @@ -0,0 +1,222 @@ +"""Tests for grounding.parser.ClaimParser.""" +from __future__ import annotations + +import json + +import pytest + +from phase4_grounding.grounding.models import LABEL_VALUES +from phase4_grounding.grounding.parser import ClaimParser + + +def _wrap(claims: list[dict]) -> str: + return json.dumps({"claims": claims}) + + +def test_parses_well_formed_json_with_all_four_labels(): + payload = _wrap( + [ + {"claim": "stated one", "label": "STATED", "evidence_id": 1, "rationale": "directly cited"}, + {"claim": "implied one", "label": "IMPLIED", "evidence_id": 2, "rationale": "one step"}, + {"claim": "structural one", "label": "STRUCTURAL", "evidence_id": None, "rationale": "from SMILES"}, + {"claim": "unsupported one", "label": "UNSUPPORTED", "evidence_id": None, "rationale": "no support"}, + ] + ) + result = ClaimParser.parse(payload, attached_ids={1, 2}) + assert result.ok, result.error + labels = [c.label for c in result.claims] + assert labels == ["STATED", "IMPLIED", "STRUCTURAL", "UNSUPPORTED"] + # all four allowed labels are preserved by name + assert set(labels) <= set(LABEL_VALUES) + + +def test_rejects_malformed_json(): + result = ClaimParser.parse("{not json", attached_ids={1}) + assert not result.ok + assert "invalid JSON" in (result.error or "") + + +def test_rejects_non_object_root(): + result = ClaimParser.parse("[]", attached_ids={1}) + assert not result.ok + assert "object" in (result.error or "") + + +def test_rejects_missing_claims_key(): + result = ClaimParser.parse('{"foo":1}', attached_ids={1}) + assert not result.ok + assert "claims" in (result.error or "") + + +def test_rejects_claims_not_a_list(): + result = ClaimParser.parse('{"claims":"oops"}', attached_ids={1}) + assert not result.ok + assert "list" in (result.error or "") + + +def test_rejects_unknown_label(): + payload = _wrap( + [{"claim": "x", "label": "MAYBE", "evidence_id": 1, "rationale": "r"}] + ) + result = ClaimParser.parse(payload, attached_ids={1}) + assert not result.ok + assert "label" in (result.error or "") + + +def test_rejects_evidence_id_not_in_attached_set(): + payload = _wrap( + [{"claim": "x", "label": "STATED", "evidence_id": 99, "rationale": "r"}] + ) + result = ClaimParser.parse(payload, attached_ids={1, 2}) + assert not result.ok + assert "evidence_id" in (result.error or "") + assert "99" in (result.error or "") + + +def test_rejects_evidence_id_wrong_type(): + payload = _wrap( + [{"claim": "x", "label": "STATED", "evidence_id": "1", "rationale": "r"}] + ) + result = ClaimParser.parse(payload, attached_ids={1}) + assert not result.ok + assert "evidence_id" in (result.error or "") + + +def test_rejects_boolean_evidence_id(): + """JSON `true` would silently become Python True, which `isinstance(_, int)` + accepts. The parser must guard against this corner.""" + payload = _wrap( + [{"claim": "x", "label": "STATED", "evidence_id": True, "rationale": "r"}] + ) + result = ClaimParser.parse(payload, attached_ids={1}) + assert not result.ok + + +def test_rejects_missing_required_field(): + payload = _wrap([{"claim": "x", "label": "STATED", "evidence_id": 1}]) + result = ClaimParser.parse(payload, attached_ids={1}) + assert not result.ok + assert "rationale" in (result.error or "") + + +def test_rejects_non_string_claim(): + payload = _wrap( + [{"claim": 123, "label": "STATED", "evidence_id": 1, "rationale": "r"}] + ) + result = ClaimParser.parse(payload, attached_ids={1}) + assert not result.ok + + +def test_accepts_null_evidence_id_for_unsupported(): + payload = _wrap( + [{"claim": "x", "label": "UNSUPPORTED", "evidence_id": None, "rationale": "r"}] + ) + result = ClaimParser.parse(payload, attached_ids={1, 2}) + assert result.ok + assert result.claims[0].evidence_id is None + + +def test_accepts_null_evidence_id_for_structural(): + payload = _wrap( + [{"claim": "x", "label": "STRUCTURAL", "evidence_id": None, "rationale": "r"}] + ) + result = ClaimParser.parse(payload, attached_ids={1, 2}) + assert result.ok + + +def test_accepts_null_rationale(): + payload = _wrap( + [ + {"claim": "x", "label": "STRUCTURAL", "evidence_id": None, "rationale": None}, + {"claim": "y", "label": "UNSUPPORTED", "evidence_id": None, "rationale": None}, + ] + ) + result = ClaimParser.parse(payload, attached_ids={1, 2}) + assert result.ok + assert result.claims[0].rationale is None + assert result.claims[1].rationale is None + + +def test_handles_none_raw_response(): + # OpenRouter occasionally returns choices[0].message.content = null + # (refusal, truncation, empty tool-call). Must surface as a clean parse + # error so the judge's retry/error path takes over instead of crashing. + result = ClaimParser.parse(None, attached_ids={1}) + assert not result.ok + assert "not a string" in (result.error or "") + + +def test_strips_markdown_json_fence(): + payload = ( + '```json\n' + + _wrap([{"claim": "x", "label": "STATED", "evidence_id": 1, "rationale": "r"}]) + + '\n```' + ) + result = ClaimParser.parse(payload, attached_ids={1}) + assert result.ok + assert len(result.claims) == 1 + + +def test_strips_bare_markdown_fence(): + payload = ( + '```\n' + + _wrap([{"claim": "x", "label": "STATED", "evidence_id": 1, "rationale": "r"}]) + + '\n```' + ) + result = ClaimParser.parse(payload, attached_ids={1}) + assert result.ok + + +def test_strips_fence_with_trailing_whitespace(): + payload = ( + ' ```json\n' + + _wrap([{"claim": "x", "label": "STATED", "evidence_id": 1, "rationale": "r"}]) + + '\n``` \n' + ) + result = ClaimParser.parse(payload, attached_ids={1}) + assert result.ok + + +def test_rejects_non_string_non_null_rationale(): + payload = _wrap( + [{"claim": "x", "label": "STATED", "evidence_id": 1, "rationale": 42}] + ) + result = ClaimParser.parse(payload, attached_ids={1}) + assert not result.ok + assert "rationale" in (result.error or "") + + +def test_empty_claims_list_is_valid(): + """A judge that decomposed nothing returns an empty list — not an error.""" + result = ClaimParser.parse('{"claims":[]}', attached_ids={1}) + assert result.ok + assert result.claims == () + + +def test_attached_ids_can_be_empty_when_all_evidence_id_null(): + """If no evidence was attached and the judge correctly emitted only null + evidence_id values, parsing must still succeed.""" + payload = _wrap( + [ + {"claim": "a", "label": "STRUCTURAL", "evidence_id": None, "rationale": "r"}, + {"claim": "b", "label": "UNSUPPORTED", "evidence_id": None, "rationale": "r"}, + ] + ) + result = ClaimParser.parse(payload, attached_ids=set()) + assert result.ok + assert len(result.claims) == 2 + + +def test_rejects_when_attached_ids_empty_but_evidence_id_set(): + payload = _wrap( + [{"claim": "x", "label": "STATED", "evidence_id": 1, "rationale": "r"}] + ) + result = ClaimParser.parse(payload, attached_ids=set()) + assert not result.ok + + +def test_rejects_claim_item_not_object(): + payload = '{"claims":[1]}' + result = ClaimParser.parse(payload, attached_ids={1}) + assert not result.ok + assert "object" in (result.error or "") diff --git a/phase4_grounding/tests/test_prompt.py b/phase4_grounding/tests/test_prompt.py new file mode 100644 index 0000000..5caae54 --- /dev/null +++ b/phase4_grounding/tests/test_prompt.py @@ -0,0 +1,136 @@ +"""Tests for grounding.prompt.PromptBuilder.""" +from __future__ import annotations + +from pathlib import Path + +import pytest + +from phase4_grounding.grounding.evidence import EvidenceAttacher +from phase4_grounding.grounding.models import Compound, EvidenceItem, SampleRow +from phase4_grounding.grounding.prompt import PromptBuilder + + +def _row( + *, + name: str = "Aspirin", + smiles: str = "CC(=O)Oc1ccccc1C(=O)O", + formula: str = "C9H8O4", + question: str = "How does aspirin inhibit COX-1?", + answer: str = "Aspirin acetylates serine 530.", + evidence: tuple[EvidenceItem, ...] = (), +) -> SampleRow: + return SampleRow( + cid=1, + qa_index=1, + topic="mechanism", + split="train", + evidence_ids_nonempty=bool(evidence), + compound=Compound(cid=1, name=name, smiles=smiles, molecular_formula=formula), + question=question, + phase2_answer=answer, + evidence_attached=evidence, + ) + + +def test_prompt_contains_all_required_fields(): + evidence = ( + EvidenceItem(id=1, text="Aspirin acetylates COX-1.", pmid="111", source="abstract"), + EvidenceItem(id=2, text="COX-1 inhibition reduces TXA2.", pmid="111", source="abstract"), + ) + row = _row(evidence=evidence) + out = PromptBuilder().build(row) + + assert "Aspirin" in out + assert "CC(=O)Oc1ccccc1C(=O)O" in out + assert "C9H8O4" in out + assert "How does aspirin inhibit COX-1?" in out + assert "Aspirin acetylates serine 530." in out + + +def test_evidence_is_numbered_E1_E2(): + evidence = ( + EvidenceItem(id=1, text="First sentence.", pmid="p", source="abstract"), + EvidenceItem(id=2, text="Second sentence.", pmid="p", source="abstract"), + EvidenceItem(id=3, text="Third sentence.", pmid="p", source="abstract"), + ) + out = PromptBuilder().build(_row(evidence=evidence)) + assert "[E1] First sentence." in out + assert "[E2] Second sentence." in out + assert "[E3] Third sentence." in out + # ordering is preserved + assert out.index("[E1]") < out.index("[E2]") < out.index("[E3]") + + +def test_no_unfilled_placeholders(): + row = _row(evidence=(EvidenceItem(id=1, text="e", pmid="p", source="abstract"),)) + out = PromptBuilder().build(row) + # any leftover {{FOO}} marker would be a templating bug + assert "{{" not in out + assert "}}" not in out + + +def test_label_definitions_present_in_prompt(): + out = PromptBuilder().build(_row(evidence=(EvidenceItem(id=1, text="e"),))) + for label in ("STATED", "IMPLIED", "UNSUPPORTED", "STRUCTURAL"): + assert label in out + + +def test_strict_json_directive_present(): + out = PromptBuilder().build(_row(evidence=(EvidenceItem(id=1, text="e"),))) + # The judge must be told to emit JSON only — this is load-bearing for the parser. + assert "STRICT JSON" in out or "strict JSON" in out.lower() or "JSON" in out + + +def test_handles_empty_evidence(): + out = PromptBuilder().build(_row(evidence=())) + # When nothing is attached, the prompt should still render a placeholder + # so the judge knows there are no [E#] ids to cite. + assert "no evidence" in out.lower() + + +def test_template_loaded_from_disk_only_once(tmp_path): + """The template should be cached after the first read.""" + template = tmp_path / "tpl.txt" + template.write_text("name: {{COMPOUND_NAME}}\n") + builder = PromptBuilder(template_path=template) + assert "name: Aspirin" in builder.build(_row()) + # Mutate the file on disk; cached template should win on the next call. + template.write_text("name: {{COMPOUND_NAME}} mutated\n") + assert "mutated" not in builder.build(_row()) + + +def test_snapshot_against_aspirin_qa(tiny_dataset_records): + """End-to-end render using EvidenceAttacher + PromptBuilder against a real + fixture row. Compares the full rendered string to a golden snapshot.""" + aspirin = tiny_dataset_records[0] + qa = aspirin["qa_pairs"][0] # mechanism QA, evidence_ids=[1,2] + attached = EvidenceAttacher.attach(qa, aspirin) + row = SampleRow( + cid=int(aspirin["cid"]), + qa_index=int(qa["qa_index"]), + topic=qa["topic"], + split=aspirin["split"], + evidence_ids_nonempty=True, + compound=Compound( + cid=int(aspirin["cid"]), + name=aspirin["name"], + smiles=aspirin["smiles"], + molecular_formula=aspirin["molecular_formula"], + ), + question=qa["question"], + phase2_answer=qa["phase2_answer"], + evidence_attached=attached, + ) + out = PromptBuilder().build(row) + + # Spot-check the rendered evidence block specifically (this is the part the + # parser depends on for [E#] integrity). + expected_block = ( + "[E1] [COMPOUND] irreversibly acetylates serine 530 of cyclooxygenase-1 (COX-1).\n" + "[E2] Inhibition of COX-1 by [COMPOUND] reduces thromboxane A2 production in platelets." + ) + assert expected_block in out + + # Ensure the rendered prompt is stable: re-rendering with the same row gives + # byte-identical output. + assert out == PromptBuilder().build(row) diff --git a/phase4_grounding/tests/test_reporter.py b/phase4_grounding/tests/test_reporter.py new file mode 100644 index 0000000..ea00032 --- /dev/null +++ b/phase4_grounding/tests/test_reporter.py @@ -0,0 +1,113 @@ +"""Tests for grounding.reporter.Reporter.""" +from __future__ import annotations + +from pathlib import Path + +import pytest + +from phase4_grounding.grounding.aggregator import Aggregator +from phase4_grounding.grounding.models import Claim, JudgedQA +from phase4_grounding.grounding.reporter import Reporter + + +def _qa(cid: int, labels: list[str], *, topic: str = "mechanism") -> JudgedQA: + return JudgedQA( + cid=cid, + qa_index=1, + topic=topic, + evidence_ids_nonempty=True, + num_evidence_attached=1, + model="m", + claims=tuple( + Claim(claim=f"c{i}", label=lbl, evidence_id=None, rationale="r") + for i, lbl in enumerate(labels) + ), + prompt_tokens=1, + completion_tokens=1, + latency_ms=1, + split="train", + ) + + +@pytest.fixture +def sample_metrics(): + qas = [ + _qa(1, ["STATED", "STATED", "UNSUPPORTED"]), + _qa(2, ["IMPLIED", "STRUCTURAL"]), + _qa(3, ["UNSUPPORTED"]), + ] + agg = Aggregator(qas) + return agg.compute("keep"), agg.compute("drop"), qas + + +def test_writes_both_summary_files(sample_metrics, tmp_out_dir: Path): + keep, drop, qas = sample_metrics + reporter = Reporter(keep, drop, qas) + keep_path, drop_path = reporter.write(tmp_out_dir) + assert keep_path.exists() + assert drop_path.exists() + assert keep_path.name == "grounding_summary_keep_structural.md" + assert drop_path.name == "grounding_summary_drop_structural.md" + + +def test_each_summary_cross_links_the_other(sample_metrics, tmp_out_dir: Path): + keep, drop, qas = sample_metrics + keep_path, drop_path = Reporter(keep, drop, qas).write(tmp_out_dir) + assert "grounding_summary_drop_structural.md" in keep_path.read_text() + assert "grounding_summary_keep_structural.md" in drop_path.read_text() + + +def test_decision_string_above_20pct(tmp_out_dir: Path): + qas = [_qa(1, ["UNSUPPORTED"] * 30 + ["STATED"] * 70)] + agg = Aggregator(qas) + reporter = Reporter(agg.compute("keep"), agg.compute("drop"), qas) + keep_path, _ = reporter.write(tmp_out_dir) + text = keep_path.read_text() + assert "NARROW" in text + assert "30.0%" in text or "30.00%" in text + + +def test_decision_string_below_10pct(tmp_out_dir: Path): + qas = [_qa(1, ["UNSUPPORTED"] * 5 + ["STATED"] * 95)] + agg = Aggregator(qas) + reporter = Reporter(agg.compute("keep"), agg.compute("drop"), qas) + keep_path, _ = reporter.write(tmp_out_dir) + text = keep_path.read_text() + assert "WELL-BEHAVED" in text + + +def test_decision_string_in_caveat_band(tmp_out_dir: Path): + qas = [_qa(1, ["UNSUPPORTED"] * 15 + ["STATED"] * 85)] + agg = Aggregator(qas) + reporter = Reporter(agg.compute("keep"), agg.compute("drop"), qas) + keep_path, _ = reporter.write(tmp_out_dir) + text = keep_path.read_text() + assert "CAVEAT" in text + + +def test_keep_view_mentions_structural_count(sample_metrics, tmp_out_dir: Path): + keep, drop, qas = sample_metrics + keep_path, _ = Reporter(keep, drop, qas).write(tmp_out_dir) + text = keep_path.read_text() + assert "STRUCTURAL claims" in text + + +def test_rejects_swapped_views(sample_metrics): + keep, drop, qas = sample_metrics + # Pass them swapped — should raise. + with pytest.raises(ValueError): + Reporter(metrics_keep=drop, metrics_drop=keep, judged_qas=qas) + + +def test_summary_includes_topic_breakdown(tmp_out_dir: Path): + qas = [ + _qa(1, ["STATED"], topic="mechanism"), + _qa(2, ["UNSUPPORTED"], topic="toxicity"), + ] + agg = Aggregator(qas) + reporter = Reporter(agg.compute("keep"), agg.compute("drop"), qas) + keep_path, _ = reporter.write(tmp_out_dir) + text = keep_path.read_text() + assert "By topic" in text + assert "mechanism" in text + assert "toxicity" in text diff --git a/phase4_grounding/tests/test_sampling.py b/phase4_grounding/tests/test_sampling.py new file mode 100644 index 0000000..aae6324 --- /dev/null +++ b/phase4_grounding/tests/test_sampling.py @@ -0,0 +1,190 @@ +"""Tests for grounding.sampling.Sampler.""" +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from phase4_grounding.grounding.sampling import ( + DEFAULT_TOPIC_WEIGHTS, + HEADLINE_TOPICS, + Sampler, +) + + +def test_allocate_sums_to_n_for_various_sizes(tiny_dataset_path): + s = Sampler(tiny_dataset_path) + for n in (1, 6, 10, 17, 300, 600): + alloc = s._allocate(n) + assert sum(alloc.values()) == n + assert set(alloc) == set(DEFAULT_TOPIC_WEIGHTS) + + +def test_allocate_rejects_nonpositive(tiny_dataset_path): + s = Sampler(tiny_dataset_path) + with pytest.raises(ValueError): + s._allocate(0) + with pytest.raises(ValueError): + s._allocate(-3) + + +def test_weights_must_sum_to_one(tiny_dataset_path): + with pytest.raises(ValueError): + Sampler(tiny_dataset_path, topic_weights={"mechanism": 0.5}) + + +def test_topic_key_collapses_non_headline(tiny_dataset_path): + s = Sampler(tiny_dataset_path) + for t in HEADLINE_TOPICS: + assert s._topic_key(t) == t + assert s._topic_key("adme") == "_other" + assert s._topic_key("design_levers") == "_other" + assert s._topic_key("Therapeutic-Use") == "therapeutic_use" + + +def test_index_qa_only_functional(tiny_dataset_path): + s = Sampler(tiny_dataset_path) + idx = s._index_qa() + # tiny_dataset has mechanism (x2), metabolism (x2), engineering (x2), toxicity (x2) + assert len(idx["mechanism"]) == 2 + assert len(idx["metabolism"]) == 2 + assert len(idx["engineering"]) == 2 + assert len(idx["toxicity"]) == 2 + # therapeutic_use absent in tiny_dataset + assert idx["therapeutic_use"] == [] + + +def test_sample_returns_exact_n_when_pools_sufficient(tmp_path: Path): + """Build a synthetic dataset large enough to satisfy any allocation.""" + records = [] + cid = 1 + topics_to_seed = ["mechanism", "engineering", "metabolism", "toxicity", "therapeutic_use", "adme"] + for t in topics_to_seed: + for evidence_branch in (True, False): + for _ in range(20): # 20 of each (topic, evidence-branch) + evidence_ids = [1] if evidence_branch else [] + records.append({ + "cid": cid, + "split": "train", + "name": f"C{cid}", + "smiles": "C", + "molecular_formula": "CH4", + "evidence_sentences": [{"id": 1, "text": "evidence", "pmid": "1", "source": "abstract"}], + "qa_pairs": [{ + "qa_index": 1, + "topic": t, + "question": f"Q for {t}?", + "phase1_answer": "A1", + "phase2_answer": "A2", + "verdict": "agree", + "judge_reasoning": "r", + "evidence_ids": evidence_ids, + }], + }) + cid += 1 + + path = tmp_path / "ds.jsonl" + path.write_text("\n".join(json.dumps(r) for r in records) + "\n") + s = Sampler(path, seed=42) + rows = s.sample(60) + assert len(rows) == 60 + + +def test_sample_is_deterministic_with_seed(tmp_path: Path): + records = [] + for i in range(50): + records.append({ + "cid": i + 1, + "split": "train", + "name": f"C{i}", + "smiles": "C", + "molecular_formula": "CH4", + "evidence_sentences": [{"id": 1, "text": "e", "pmid": "1", "source": "abstract"}], + "qa_pairs": [{ + "qa_index": 1, + "topic": "mechanism", + "question": "Q?", + "phase1_answer": "a", + "phase2_answer": "a", + "verdict": "agree", + "judge_reasoning": "r", + "evidence_ids": [1] if i % 2 == 0 else [], + }], + }) + path = tmp_path / "ds.jsonl" + path.write_text("\n".join(json.dumps(r) for r in records) + "\n") + + rows_a = Sampler(path, seed=7).sample(8) + rows_b = Sampler(path, seed=7).sample(8) + rows_c = Sampler(path, seed=8).sample(8) + + keys_a = [(r.cid, r.qa_index) for r in rows_a] + keys_b = [(r.cid, r.qa_index) for r in rows_b] + keys_c = [(r.cid, r.qa_index) for r in rows_c] + assert keys_a == keys_b + assert keys_a != keys_c + + +def test_sample_handles_exhausted_strata_gracefully(tiny_dataset_path): + """tiny_dataset has only 4 functional topics with 2 QA each and no + therapeutic_use or _other. Sampling more than the pool size should fall back + instead of raising.""" + s = Sampler(tiny_dataset_path, seed=0) + # request 4 to keep the math simple; allocation will request slots from + # therapeutic_use / _other strata that are empty in tiny_dataset + rows = s.sample(4) + # we should still get at most as many rows as functional QA in the dataset (8) + assert 0 < len(rows) <= 8 + # and each returned row must be a functional topic + from phase4_grounding.grounding.topic_bucket import bucket_topic + for r in rows: + assert bucket_topic(r.topic) == "functional" + + +def test_evidence_branch_split_50_50_when_both_sides_have_supply(tmp_path: Path): + """Within a single topic, half of the slot should come from each evidence branch.""" + records = [] + cid = 1 + for evidence_branch in (True, False): + for _ in range(10): + records.append({ + "cid": cid, + "split": "train", + "name": f"C{cid}", + "smiles": "C", + "molecular_formula": "CH4", + "evidence_sentences": [{"id": 1, "text": "e", "pmid": "1", "source": "abstract"}], + "qa_pairs": [{ + "qa_index": 1, + "topic": "mechanism", + "question": "Q?", + "phase1_answer": "a", + "phase2_answer": "a", + "verdict": "agree", + "judge_reasoning": "r", + "evidence_ids": [1] if evidence_branch else [], + }], + }) + cid += 1 + path = tmp_path / "ds.jsonl" + path.write_text("\n".join(json.dumps(r) for r in records) + "\n") + + # Use a custom weight: 100% mechanism, so allocation == n + s = Sampler( + path, + seed=0, + topic_weights={ + "mechanism": 1.0, + "engineering": 0.0, + "metabolism": 0.0, + "toxicity": 0.0, + "therapeutic_use": 0.0, + "_other": 0.0, + }, + ) + rows = s.sample(8) + nonempty = sum(1 for r in rows if r.evidence_ids_nonempty) + empty = len(rows) - nonempty + assert nonempty == 4 + assert empty == 4