diff --git a/environments/careqa/README.md b/environments/careqa/README.md new file mode 100644 index 00000000..65d593e1 --- /dev/null +++ b/environments/careqa/README.md @@ -0,0 +1,89 @@ +# careqa + +Evaluation environment for the [HPAI-BSC/CareQA](https://huggingface.co/datasets/HPAI-BSC/CareQA) dataset. + +### Overview +- **Environment ID**: `careqa` +- **Short description**: CareQA is a healthcare QA dataset with **multiple-choice** and **open-ended clinical reasoning questions**. This environment supports both modes through the `mode` parameter. +- **Tags**: healthcare, medical QA, clinical reasoning, MCQ, single-turn + +### Datasets +- **Primary dataset(s)**: + - `CareQA_en` – multiple-choice clinical questions with 4 options and correct answer labels + - `CareQA_en_open` – open-ended clinical questions with reference answers +- **Source links**: + - [Hugging Face CareQA dataset](https://huggingface.co/datasets/HPAI-BSC/CareQA) + +### Task +- **Type**: single-turn +- **Parser**: + - MCQ mode: `vf.Parser()` or `vf.ThinkParser()` for extracting boxed answers + - Open-ended mode: `XMLParser()` for judge responses +- **Rubric overview**: + - **MCQ mode (`en`)**: `vf.Rubric()` measuring **accuracy** (letter match A–D) + - **Open-ended mode (`open`)**: `vf.JudgeRubric()` using an LLM-as-judge to score free-text answers for correctness and clinical reasoning + +### Quickstart + +**Multiple-choice evaluation:** +```bash +medarc-eval careqa --mode en --model gpt-4.1-mini --num-examples 10 -s +``` + +**Open-ended evaluation:** +```bash +medarc-eval careqa --mode open --model gpt-4.1-mini --num-examples 10 -s +``` + +**With think-mode prompting (MCQ only):** +```bash +medarc-eval careqa --mode en --use-think --model gpt-4.1-mini --num-examples 10 -s +``` + +**With shuffled answer options (MCQ only):** +```bash +medarc-eval careqa --mode en --shuffle-answers --shuffle-seed 42 --model gpt-4.1-mini -n 10 -s +``` + +### Configuration Options + +#### Common Parameters +- 
`--mode`: Select mode: `en` (multiple-choice) or `open` (open-ended). Default: `open`
- `--split`: Dataset split to use. Default: `test`
- `--system-prompt`: Custom system prompt (uses mode-appropriate default if not specified)

#### MCQ-Specific Parameters
- `--use-think`: Enable think-style prompting with boxed answers
- `--shuffle-answers`: Randomly shuffle answer options
- `--shuffle-seed`: Seed for answer shuffling (default: 1618)

#### Open-Ended-Specific Parameters
- `--judge-model`: Model for LLM-as-judge evaluation (default: `gpt-4o-mini`)
- `--judge-base-url`: Base URL for judge API
- `--judge-api-key`: API key for judge (falls back to `OPENAI_API_KEY` env var)

### Metrics

#### MCQ Mode
| Metric | Meaning |
|---------------|---------|
| `reward` | Main scalar reward (weighted sum of rubric criteria) |
| `accuracy` | Exact match on target MCQ answer (letter A–D) |

#### Open-Ended Mode
| Metric | Meaning |
|---------------|---------|
| `reward` | Main scalar reward (weighted sum of rubric criteria) |
| `judge_score` | LLM-assigned score evaluating answer quality, correctness, and clinical reasoning |

### Example Usage

```python
import verifiers as vf

# Load MCQ environment (the mode selector parameter is named `split`)
env_mcq = vf.load_environment("careqa", split="en", shuffle_answers=True)

# Load open-ended environment
env_open = vf.load_environment("careqa", split="open", judge_model="gpt-4o-mini")
```
diff --git a/environments/careqa/careqa.py b/environments/careqa/careqa.py
new file mode 100644
index 00000000..e67178af
--- /dev/null
+++ b/environments/careqa/careqa.py
@@ -0,0 +1,251 @@
import re
from enum import Enum
from typing import Optional

from datasets import load_dataset
from openai import AsyncOpenAI
import verifiers as vf
from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy
from medarc_verifiers.utils.randomize_multiple_choice import randomize_multiple_choice
from medarc_verifiers.parsers.xml_parser
import XMLParser +from verifiers.types import Info, State +from verifiers.utils.data_utils import extract_boxed_answer, BOXED_SYSTEM_PROMPT + + +class CareQASplit(Enum): + """Mode selector for CareQA environment.""" + + EN = "en" + OPEN = "open" + + +# --- MCQ Helpers --- + + +def _build_mcq_prompt(question: str, options: dict[str, str]) -> str: + """Create an MCQ prompt.""" + formatted_opts = "\n".join(f"{k}. {v}" for k, v in options.items()) + return f"Question: {question}\nChoices:\n{formatted_opts}\nAnswer:" + + +def accuracy(completion, answer: str, parser: vf.Parser, info: dict | None = None, **kwargs) -> float: + """Reward based on shared multiple-choice accuracy grading.""" + parsed = parser.parse_answer(completion) or "" + answer_text = info.get("answer_text", None) if info else None + is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text) + return 1.0 if is_correct else 0.0 + + +# --- Open-Ended Helpers --- + + +JUDGE_TEMPLATE = """You are grading an AI assistant's answer to a medical/science exam questions. + +Input: +- : The exam question. +- : The correct answer. +- : The AI's response to grade. + +Task: Determine if the assistant's answer is correct or incorrect by comparing it to the reference answer and output your grade in ... tags. + +Grading Rules: +- Assume the reference answer is correct and reflects the expected exam solution. +- Focus on factual content and meaning, not style, length, or confidence. 
+ +Correct if the assistant's answer conveys the same essential fact(s) as the reference, including: +- Synonyms, acronyms (expanded or abbreviated), or rephrasing with equivalent meaning +- Slightly more general/specific phrasing that captures the key concept +- Shorter or longer answers that express the tested fact without contradictions +- Additional supporting details that don't contradict the reference + +Incorrect if any of these apply: +- Different main concept, mechanism, structure, or relationship +- Contradicts the reference on key points (wrong organ, drug, effect, process, etc.) +- Contains clearly incorrect information +- Too vague/incomplete to match the reference +- Merely repeats question words without the core information from the reference + +Be strict: clear mismatches on main concepts or incorrect claims = Incorrect. + +{question} +{answer} +{response} + +Briefly explain whether the assistant's answer matches or conflicts with the reference. Then output your grade as: + +[Correct or Incorrect] +""".strip() + + +def extract_answer_section(completion_text: str) -> str: + """Extract final answer after think tags.""" + if not completion_text: + return "" + if "" in completion_text and "" in completion_text: + return re.sub(r".*?", "", completion_text, flags=re.DOTALL).strip() + return completion_text.strip() + + +def load_environment( + split: str | CareQASplit, + system_prompt: Optional[str] = None, + # MCQ-specific options + shuffle_answers: bool = False, + shuffle_seed: int | None = 1618, + # Open-ended specific options + judge_model: str = "gpt-4o-mini", + judge_base_url: str | None = None, + judge_api_key: str | None = None, + **kwargs, +) -> vf.Environment: + """ + CareQA evaluation environment supporting both MCQ and Open-Ended modes. + + Args: + split: CareQASplit.EN for multiple-choice or CareQASplit.OPEN for open-ended QA. + system_prompt: Custom system prompt (uses mode-appropriate default if None). 
+ shuffle_answers: Shuffle MCQ answer options (MCQ mode only). + shuffle_seed: Seed for answer shuffling (MCQ mode only). + judge_model: Model to use for LLM-as-judge evaluation (Open-ended mode only). + judge_base_url: Base URL for judge API (Open-ended mode only). + judge_api_key: API key for judge (Open-ended mode only). + + Returns: + A vf.Environment configured for the selected mode. + """ + split = CareQASplit(split) if isinstance(split, str) else split + if split == CareQASplit.EN: + return _load_mcq_environment( + system_prompt=system_prompt, + shuffle_answers=shuffle_answers, + shuffle_seed=shuffle_seed, + ) + elif split == CareQASplit.OPEN: + return _load_open_ended_environment( + system_prompt=system_prompt, + judge_model=judge_model, + judge_base_url=judge_base_url, + judge_api_key=judge_api_key, + ) + else: + raise ValueError(f"Invalid mode: {split}") + + +def _load_mcq_environment( + system_prompt: Optional[str], + shuffle_answers: bool, + shuffle_seed: int | None, +) -> vf.Environment: + """Load CareQA multiple-choice environment.""" + eval_dataset = load_dataset("HPAI-BSC/CareQA", "CareQA_en", split="test") + + def _map(ex, idx=None): + options = {"A": ex["op1"], "B": ex["op2"], "C": ex["op3"], "D": ex["op4"]} + gold_letter = ["A", "B", "C", "D"][ex["cop"] - 1] + + if shuffle_answers and gold_letter in options: + options, gold_letter, _ = randomize_multiple_choice( + options=options, + answer_choice=gold_letter, + seed=shuffle_seed, + row_id=ex.get("id", idx), + ) + + return { + "question": _build_mcq_prompt(ex["question"], options), + "answer": gold_letter, + "info": { + "answer_text": options.get(gold_letter, None), + **({"options": options} if shuffle_answers else {}), + }, + } + + load_from_cache_file = not shuffle_answers + eval_dataset = eval_dataset.map( + _map, + with_indices=True, + remove_columns=eval_dataset.column_names, + load_from_cache_file=load_from_cache_file, + ) + + parser = vf.Parser(extract_boxed_answer) + final_system_prompt = 
BOXED_SYSTEM_PROMPT or system_prompt + + rubric = vf.Rubric(funcs=[accuracy], weights=[1.0], parser=parser) + + return vf.SingleTurnEnv( + eval_dataset=eval_dataset, + rubric=rubric, + parser=parser, + system_prompt=final_system_prompt, + ) + + +def _load_open_ended_environment( + system_prompt: Optional[str], + judge_model: str, + judge_base_url: str | None, + judge_api_key: str | None, +) -> vf.Environment: + """Load CareQA open-ended environment with LLM-as-judge evaluation.""" + eval_dataset = load_dataset("HPAI-BSC/CareQA", "CareQA_en_open", split="test") + + def _map(ex): + info = {} + info["question"] = ex["question"].strip() + return { + "question": ex["question"].strip(), + "answer": ex.get("answer_explanation", ex.get("answer", "")), + "task": "careqa_open", + "info": info, + } + + eval_dataset = eval_dataset.map(_map, remove_columns=eval_dataset.column_names) + + final_system_prompt = system_prompt or ( + "Instructions: The following text is a medical question. Answer it in the most factual, concise, and informative way possible." 
+ ) + + # Judge client setup + judge_client = AsyncOpenAI(base_url=judge_base_url, api_key=judge_api_key) + judge_parser = XMLParser(fields=["grade"], answer_field="grade") + + judge_rubric = vf.JudgeRubric( + parser=judge_parser, + judge_client=judge_client, + judge_model=judge_model, + judge_prompt="{question}", + ) + + async def accuracy(judge, prompt, completion, answer, state: State, info: Info) -> float: + """Evaluate medical equivalence using LLM-as-judge.""" + completion_text = completion if isinstance(completion, str) else str(completion) + response = extract_answer_section(completion_text) + + try: + judge_prompt = JUDGE_TEMPLATE.format(question=info.get("question", ""), answer=answer, response=response) + judge_response = await judge_rubric.judge(judge_prompt, "", "", state) + grade = judge_parser.parse_answer(judge_response).strip().lower() + except AttributeError: + judge_response = await judge_rubric.judge(judge_prompt, "", "", state) + grade = judge_parser.parse_answer(judge_response).strip().lower() + + info.setdefault("judge_feedback", []).append( + { + "grade": grade, + "raw_judge": str(judge_response), + } + ) + + if "correct" in grade and "incorrect" not in grade: + return 1.0 + else: + return 0.0 + + judge_rubric.add_reward_func(accuracy, weight=1.0) + + return vf.SingleTurnEnv( + eval_dataset=eval_dataset, + system_prompt=final_system_prompt, + rubric=judge_rubric, + ) diff --git a/environments/careqa/pyproject.toml b/environments/careqa/pyproject.toml new file mode 100644 index 00000000..a875b5ff --- /dev/null +++ b/environments/careqa/pyproject.toml @@ -0,0 +1,26 @@ +[project] +name = "careqa" +description = "Evaluation environment for the HPAI-BSC/CareQA MCQ dataset" +tags = ["healthcare", "medical-qa", "mcq", "clinical", "single-turn", "open-ended"] +version = "0.1.0" +requires-python = ">=3.11" +dependencies = [ + "verifiers>=0.1.4", + "datasets>=2.13.0", + "medarc_verifiers>=0.1.0", +] + +[tool.prime.environment] +loader = 
"careqa:load_environment" +display_name = "CareQA" +visibility = "PUBLIC" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = ["careqa.py"] + +[tool.uv.sources] +medarc_verifiers = { git = "https://github.com/MedARC-AI/med-lm-envs" } \ No newline at end of file