# cli.py — environment-validation helpers (reconstructed from a line-mangled
# patch; the original diff text split an f-string literal across a raw
# newline, which is a Python syntax error — repaired here with implicit
# string concatenation).

# Providers whose only requirement is a single API-key environment variable.
# Gemini is handled separately because it accepts two variable names and the
# key must be mirrored into GOOGLE_API_KEY for downstream clients.
_PROVIDER_KEY_VARS = {
    "anthropic": "ANTHROPIC_API_KEY",
    "groq": "GROQ_API_KEY",
    "openai": "OPENAI_API_KEY",
}


def _ensure_provider_env(provider: str, role: str) -> None:
    """Exit with an error unless *provider* is registered and its API key is set.

    provider: key into LLM_REGISTRY (e.g. "anthropic", "gemini").
    role: human-readable role ("Answer" or "Judge") used in error messages.

    Raises typer.Exit(1) on any validation failure.
    """
    if provider not in LLM_REGISTRY:
        typer.echo(
            f"Error: unknown {role.lower()} LLM provider '{provider}'. "
            f"Available: {', '.join(LLM_REGISTRY)}.",
            err=True,
        )
        raise typer.Exit(1)

    if provider == "gemini":
        # Accept either variable name, but normalise to GOOGLE_API_KEY since
        # that is what the Gemini client library reads.
        key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
        if not key:
            typer.echo(f"Error: {role} LLM provider '{provider}' requires GEMINI_API_KEY.", err=True)
            raise typer.Exit(1)
        os.environ["GOOGLE_API_KEY"] = key
        return

    # Table-driven check replaces the copy-pasted per-provider if-blocks;
    # the emitted error messages are byte-identical to the originals.
    required = _PROVIDER_KEY_VARS.get(provider)
    if required and not os.environ.get(required):
        typer.echo(f"Error: {role} LLM provider '{provider}' requires {required}.", err=True)
        raise typer.Exit(1)


def _validate_run_env(memory: str, mode: str, answer_provider: str | None = None) -> None:
    """Validate provider/mode/memory environment before a run starts.

    A non-None *answer_provider* (the --llm CLI option) overrides the
    OMB_ANSWER_LLM environment variable. Defaults: answer LLM "groq",
    judge LLM "gemini".

    Raises typer.Exit(1) when a required API key is missing or the
    mode/provider combination is unsupported.
    """
    if answer_provider is not None:
        # Persist the CLI override so downstream code reading the env agrees.
        os.environ["OMB_ANSWER_LLM"] = answer_provider

    answer_provider = os.environ.get("OMB_ANSWER_LLM", "groq")
    judge_provider = os.environ.get("OMB_JUDGE_LLM", "gemini")
    _ensure_provider_env(answer_provider, "Answer")
    _ensure_provider_env(judge_provider, "Judge")

    # agentic-rag drives a tool-calling loop; only the gemini provider
    # implements it (AgenticRAGMode re-checks via the tool_loop override).
    if mode == "agentic-rag" and answer_provider != "gemini":
        typer.echo(
            f"Error: response mode 'agentic-rag' requires a tool-capable LLM provider; '{answer_provider}' is not supported.",
            err=True,
        )
        raise typer.Exit(1)

    if memory == "hindsight":
        # Hindsight's embedded extraction uses Gemini regardless of the
        # answer/judge providers, so it needs its own key check.
        key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
        if not key:
            typer.echo("Error: memory provider 'hindsight' requires GEMINI_API_KEY for embedded extraction.", err=True)
            raise typer.Exit(1)
        os.environ["GOOGLE_API_KEY"] = key
help=f"Dataset. Available: {', '.join(DATASET_REGISTRY)}"), memory: str = typer.Option("bm25", "--memory", "-m", help=f"Memory provider. Available: {', '.join(MEMORY_REGISTRY)}"), mode: str = typer.Option("rag", "--mode", help=f"Response mode. Available: {', '.join(MODE_REGISTRY)}"), - llm: str = typer.Option("gemini", "--llm", help=f"LLM for answer generation. Available: {', '.join(LLM_REGISTRY)}"), + llm: str | None = typer.Option(None, "--llm", help=f"LLM provider for answer generation. Overrides OMB_ANSWER_LLM. Available: {', '.join(LLM_REGISTRY)}"), category: str = typer.Option(None, "--category", "-c", help="Category filter(s), comma-separated (e.g. 'a,b,c'). With --query-limit, runs N queries per category."), query_limit: int = typer.Option(None, "--query-limit", "-q", help="Max queries to evaluate. When combined with multiple --category values, applies per category."), query_id: str = typer.Option(None, "--query-id", help="Run a single specific query by ID"), @@ -53,7 +104,7 @@ def run( description: str = typer.Option(None, "--description", "-d", help="Optional description for this run (stored in the result JSON)"), ) -> None: """Run an evaluation on a single split (optionally filtered to a category).""" - _resolve_gemini_key() + _validate_run_env(memory, mode, llm) ds = get_dataset(dataset) diff --git a/src/memory_bench/llm/__init__.py b/src/memory_bench/llm/__init__.py index 99be8ec..49b2134 100644 --- a/src/memory_bench/llm/__init__.py +++ b/src/memory_bench/llm/__init__.py @@ -1,11 +1,13 @@ import os +from .anthropic import AnthropicLLM from .base import LLM, Schema from .gemini import GeminiLLM from .groq import GroqLLM from .openai import OpenAILLM REGISTRY: dict[str, type[LLM]] = { + "anthropic": AnthropicLLM, "gemini": GeminiLLM, "groq": GroqLLM, "openai": OpenAILLM, diff --git a/src/memory_bench/llm/anthropic.py b/src/memory_bench/llm/anthropic.py new file mode 100644 index 0000000..b5228a1 --- /dev/null +++ b/src/memory_bench/llm/anthropic.py @@ 
"""Anthropic LLM provider with schema-constrained JSON output and retries."""

import json
import os
import re
import time

from .base import LLM, Schema

# Retry policy for transient failures: exponential backoff starting at
# _RETRY_BASE_DELAY seconds, doubling after each attempt.
_MAX_RETRIES = 6
_RETRY_BASE_DELAY = 5

# Compiled once at import time instead of on every parse call.
_FENCED_JSON_RE = re.compile(r"```(?:json)?\s*(\{.*\})\s*```", flags=re.DOTALL | re.IGNORECASE)


class _StructuredOutputError(ValueError):
    """Raised when the model response does not match the requested schema."""


def _require_object(payload) -> dict:
    """Return *payload* if it is a dict; otherwise raise _StructuredOutputError."""
    if not isinstance(payload, dict):
        raise _StructuredOutputError("Model response must be a JSON object")
    return payload


def _parse_json_payload(text: str) -> dict:
    """Extract a JSON object from a model response.

    Tries, in order: the whole text, a ```json fenced block, and the
    outermost {...} span.

    Raises json.JSONDecodeError when no JSON object can be located, and
    _StructuredOutputError when JSON parses to a non-object value.
    """
    text = text.strip()

    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        pass
    else:
        return _require_object(payload)

    fenced = _FENCED_JSON_RE.search(text)
    if fenced:
        return _require_object(json.loads(fenced.group(1)))

    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end != -1 and end > start:
        return _require_object(json.loads(text[start : end + 1]))

    raise json.JSONDecodeError("Could not find JSON object in model response", text, 0)


def _coerce_text_payload(text: str, schema: Schema) -> dict | None:
    """Fallback for non-JSON responses: wrap bare text into the schema.

    Only applies when the schema has exactly one required field; the bare
    text is coerced to that field's declared type (string or boolean).
    Returns None when coercion is not possible.
    """
    text = text.strip()
    if not text:
        return None
    if len(schema.required) != 1:
        return None

    field = schema.required[0]
    spec = schema.properties.get(field, {})
    field_type = spec.get("type", "string")

    if field_type == "string":
        return {field: text}

    if field_type == "boolean":
        lowered = text.lower()
        if lowered == "true":
            return {field: True}
        if lowered == "false":
            return {field: False}

    return None


# JSON-schema type name -> Python runtime check. bool is excluded from
# integer/number because bool is a subclass of int in Python.
_TYPE_CHECKS = {
    "string": lambda v: isinstance(v, str),
    "boolean": lambda v: isinstance(v, bool),
    "integer": lambda v: isinstance(v, int) and not isinstance(v, bool),
    "number": lambda v: isinstance(v, (int, float)) and not isinstance(v, bool),
    "array": lambda v: isinstance(v, list),
    "object": lambda v: isinstance(v, dict),
}


def _validate_schema_payload(payload: dict, schema: Schema) -> dict:
    """Validate *payload* against *schema*.

    Rejects fields outside schema.properties, missing required fields,
    and type mismatches; unknown declared types are accepted as-is.
    Returns the payload unchanged on success.

    Raises _StructuredOutputError on any violation.
    """
    extra = sorted(set(payload) - set(schema.properties))
    if extra:
        raise _StructuredOutputError(f"Model response included unsupported field(s): {', '.join(extra)}")

    missing = [field for field in schema.required if field not in payload]
    if missing:
        raise _StructuredOutputError(f"Model response omitted required field(s): {', '.join(missing)}")

    for field, value in payload.items():
        expected_type = schema.properties.get(field, {}).get("type", "string")
        check = _TYPE_CHECKS.get(expected_type)
        if check is not None and not check(value):
            raise _StructuredOutputError(
                f"Model response field '{field}' must be {expected_type}, got {type(value).__name__}"
            )

    return payload


class AnthropicLLM(LLM):
    """LLM provider backed by the Anthropic Messages API.

    Structured output is enforced by embedding the JSON schema in the
    system prompt and validating/retrying the response client-side.
    """

    def __init__(self, model: str | None = None):
        # Imported lazily so the anthropic package is only required when
        # this provider is actually selected.
        from anthropic import Anthropic

        api_key = os.environ.get("ANTHROPIC_API_KEY")
        if not api_key:
            raise RuntimeError("Anthropic provider requires ANTHROPIC_API_KEY")

        base_url = os.environ.get("ANTHROPIC_BASE_URL")
        # max_retries=0: SDK retries are disabled because generate()
        # implements its own backoff, which also covers schema failures.
        self._client = Anthropic(
            api_key=api_key,
            base_url=base_url or None,
            max_retries=0,
        )
        self._model = (
            model
            or os.environ.get("ANTHROPIC_MODEL")
            or "claude-sonnet-4-5"
        )

    @property
    def model_id(self) -> str:
        """Provider-qualified model identifier, e.g. 'anthropic:claude-sonnet-4-5'."""
        return f"anthropic:{self._model}"

    def generate(self, prompt: str, schema: Schema) -> dict:
        """Send *prompt* and return a dict validated against *schema*.

        Retries up to _MAX_RETRIES times with exponential backoff on rate
        limits, connection errors, retryable HTTP statuses, and malformed
        structured output (re-sampling often fixes the latter). Other API
        errors propagate immediately.

        Raises RuntimeError when all retries are exhausted.
        """
        from anthropic import APIConnectionError, APIStatusError, RateLimitError

        schema_json = {
            "type": "object",
            "properties": schema.properties,
            "required": schema.required,
            "additionalProperties": False,
        }
        system_prompt = (
            "Return only a valid JSON object matching this schema. "
            "Do not wrap JSON in markdown fences.\n\n"
            f"{json.dumps(schema_json, ensure_ascii=False)}"
        )

        delay = _RETRY_BASE_DELAY
        last_exc = None

        for attempt in range(_MAX_RETRIES):
            try:
                response = self._client.messages.create(
                    model=self._model,
                    max_tokens=4096,
                    temperature=0.0,
                    system=system_prompt,
                    messages=[{"role": "user", "content": prompt}],
                )
                # Concatenate only text blocks; tool-use or other block
                # types are ignored for structured-output generation.
                text = "".join(
                    block.text for block in response.content if getattr(block, "type", None) == "text"
                )
                try:
                    payload = _parse_json_payload(text)
                except json.JSONDecodeError:
                    coerced = _coerce_text_payload(text, schema)
                    if coerced is None:
                        raise _StructuredOutputError("Model response was not valid JSON") from None
                    payload = coerced
                return _validate_schema_payload(payload, schema)
            except (RateLimitError, APIConnectionError) as e:
                last_exc = e
            except APIStatusError as e:
                last_exc = e
                # Only retry statuses that are plausibly transient.
                if e.status_code not in (429, 500, 502, 503, 504):
                    raise
            except _StructuredOutputError as e:
                # Malformed output is retried: sampling again often fixes it.
                last_exc = e
            except Exception as e:
                last_exc = e
                # Best-effort catch for SDK-wrapped rate-limit errors that
                # don't subclass the types above; anything else re-raises.
                msg = str(e)
                if "429" not in msg and "rate" not in msg.lower():
                    raise

            if attempt < _MAX_RETRIES - 1:
                time.sleep(delay)
                delay *= 2

        raise RuntimeError(f"Anthropic request failed after {_MAX_RETRIES} retries: {last_exc}")
and can make multiple retrieval calls with different queries before finalising its answer." - def __init__(self, llm: GeminiLLM | None = None, k: int = 10): + def __init__(self, llm: LLM | None = None, k: int = 10): self._llm = llm or GeminiLLM() + if type(self._llm).tool_loop is LLM.tool_loop: + raise ValueError(f"{self._llm.model_id} does not support agentic-rag tool calling") self._rag = RAGMode(llm=self._llm, k=k) self.k = k diff --git a/uv.lock b/uv.lock index bab77d3..d87b8fc 100644 --- a/uv.lock +++ b/uv.lock @@ -3634,6 +3634,7 @@ name = "omb" version = "0.1.0" source = { editable = "." } dependencies = [ + { name = "anthropic" }, { name = "cognee" }, { name = "datasets" }, { name = "fastapi", extra = ["standard"] }, @@ -3646,6 +3647,7 @@ dependencies = [ { name = "qdrant-client" }, { name = "rank-bm25" }, { name = "rich" }, + { name = "scipy" }, { name = "sentence-transformers" }, { name = "supermemory" }, { name = "tiktoken" }, @@ -3655,6 +3657,7 @@ dependencies = [ [package.metadata] requires-dist = [ + { name = "anthropic", specifier = ">=0.84.0" }, { name = "cognee", specifier = ">=0.5.4" }, { name = "datasets", specifier = ">=2.0" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.135.1" }, @@ -3667,6 +3670,7 @@ requires-dist = [ { name = "qdrant-client", specifier = ">=1.13" }, { name = "rank-bm25", specifier = ">=0.2" }, { name = "rich", specifier = ">=13" }, + { name = "scipy", specifier = ">=1.11" }, { name = "sentence-transformers", specifier = ">=3.0" }, { name = "supermemory", specifier = ">=0.1" }, { name = "tiktoken", specifier = ">=0.12.0" },