Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.1.0"
description = "Open Memory Benchmark"
requires-python = ">=3.11"
dependencies = [
"anthropic>=0.84.0",
"datasets>=2.0",
"typer>=0.12",
"rich>=13",
Expand Down
67 changes: 59 additions & 8 deletions src/memory_bench/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
load_dotenv(dotenv_path=Path(__file__).parents[2] / ".env", override=True)

from .dataset import REGISTRY as DATASET_REGISTRY, get_dataset
from .llm import REGISTRY as LLM_REGISTRY, get_llm, get_answer_llm
from .llm import REGISTRY as LLM_REGISTRY, get_answer_llm
from .memory import REGISTRY as MEMORY_REGISTRY, get_memory_provider
from .modes import REGISTRY as MODE_REGISTRY, get_mode
from .runner import EvalRunner
Expand All @@ -22,12 +22,63 @@
console = Console()


def _resolve_gemini_key() -> None:
key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
if not key:
typer.echo("Error: GEMINI_API_KEY environment variable is not set.", err=True)
def _ensure_provider_env(provider: str, role: str) -> None:
if provider not in LLM_REGISTRY:
typer.echo(
f"Error: unknown {role.lower()} LLM provider '{provider}'. Available: {', '.join(LLM_REGISTRY)}.",
err=True,
)
raise typer.Exit(1)
os.environ["GOOGLE_API_KEY"] = key

if provider == "anthropic":
if not os.environ.get("ANTHROPIC_API_KEY"):
typer.echo(f"Error: {role} LLM provider '{provider}' requires ANTHROPIC_API_KEY.", err=True)
raise typer.Exit(1)
return

if provider == "gemini":
key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
if not key:
typer.echo(f"Error: {role} LLM provider '{provider}' requires GEMINI_API_KEY.", err=True)
raise typer.Exit(1)
os.environ["GOOGLE_API_KEY"] = key
return

if provider == "groq":
if not os.environ.get("GROQ_API_KEY"):
typer.echo(f"Error: {role} LLM provider '{provider}' requires GROQ_API_KEY.", err=True)
raise typer.Exit(1)
return

if provider == "openai":
if not os.environ.get("OPENAI_API_KEY"):
typer.echo(f"Error: {role} LLM provider '{provider}' requires OPENAI_API_KEY.", err=True)
raise typer.Exit(1)
return


def _validate_run_env(memory: str, mode: str, answer_provider: str | None = None) -> None:
    """Validate provider selections and required API keys before a run starts.

    An explicit *answer_provider* (from the --llm CLI option) overrides the
    OMB_ANSWER_LLM environment variable. Exits with status 1 via typer when a
    key is missing or the provider/mode combination is unsupported.
    """
    # The CLI flag wins over the environment; persist it so downstream code
    # that reads OMB_ANSWER_LLM sees the same choice.
    if answer_provider is not None:
        os.environ["OMB_ANSWER_LLM"] = answer_provider

    resolved_answer = os.environ.get("OMB_ANSWER_LLM", "groq")
    resolved_judge = os.environ.get("OMB_JUDGE_LLM", "gemini")
    for provider, role in ((resolved_answer, "Answer"), (resolved_judge, "Judge")):
        _ensure_provider_env(provider, role)

    # agentic-rag drives a tool-calling loop; only gemini is accepted here.
    # NOTE(review): confirm whether other tool-capable providers should also
    # be allowed for this mode.
    if mode == "agentic-rag" and resolved_answer != "gemini":
        typer.echo(
            f"Error: response mode 'agentic-rag' requires a tool-capable LLM provider; '{resolved_answer}' is not supported.",
            err=True,
        )
        raise typer.Exit(1)

    # hindsight runs an embedded Gemini-based extractor, so it needs a key
    # even when neither the answer nor the judge LLM is gemini.
    if memory == "hindsight":
        gemini_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
        if not gemini_key:
            typer.echo("Error: memory provider 'hindsight' requires GEMINI_API_KEY for embedded extraction.", err=True)
            raise typer.Exit(1)
        os.environ["GOOGLE_API_KEY"] = gemini_key


@app.command()
Expand All @@ -36,7 +87,7 @@ def run(
dataset: str = typer.Option("tempo", "--dataset", help=f"Dataset. Available: {', '.join(DATASET_REGISTRY)}"),
memory: str = typer.Option("bm25", "--memory", "-m", help=f"Memory provider. Available: {', '.join(MEMORY_REGISTRY)}"),
mode: str = typer.Option("rag", "--mode", help=f"Response mode. Available: {', '.join(MODE_REGISTRY)}"),
llm: str = typer.Option("gemini", "--llm", help=f"LLM for answer generation. Available: {', '.join(LLM_REGISTRY)}"),
llm: str | None = typer.Option(None, "--llm", help=f"LLM provider for answer generation. Overrides OMB_ANSWER_LLM. Available: {', '.join(LLM_REGISTRY)}"),
category: str = typer.Option(None, "--category", "-c", help="Category filter(s), comma-separated (e.g. 'a,b,c'). With --query-limit, runs N queries per category."),
query_limit: int = typer.Option(None, "--query-limit", "-q", help="Max queries to evaluate. When combined with multiple --category values, applies per category."),
query_id: str = typer.Option(None, "--query-id", help="Run a single specific query by ID"),
Expand All @@ -53,7 +104,7 @@ def run(
description: str = typer.Option(None, "--description", "-d", help="Optional description for this run (stored in the result JSON)"),
) -> None:
"""Run an evaluation on a single split (optionally filtered to a category)."""
_resolve_gemini_key()
_validate_run_env(memory, mode, llm)

ds = get_dataset(dataset)

Expand Down
2 changes: 2 additions & 0 deletions src/memory_bench/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import os

from .anthropic import AnthropicLLM
from .base import LLM, Schema
from .gemini import GeminiLLM
from .groq import GroqLLM
from .openai import OpenAILLM

REGISTRY: dict[str, type[LLM]] = {
"anthropic": AnthropicLLM,
"gemini": GeminiLLM,
"groq": GroqLLM,
"openai": OpenAILLM,
Expand Down
183 changes: 183 additions & 0 deletions src/memory_bench/llm/anthropic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
import json
import os
import re
import time

from .base import LLM, Schema

# Maximum number of attempts for a single request before giving up.
_MAX_RETRIES = 6
# Initial backoff delay in seconds; doubled after each failed attempt.
_RETRY_BASE_DELAY = 5


class _StructuredOutputError(ValueError):
"""Raised when the model response does not match the requested schema."""


def _parse_json_payload(text: str) -> dict:
text = text.strip()

try:
payload = json.loads(text)
except json.JSONDecodeError:
pass
else:
if not isinstance(payload, dict):
raise _StructuredOutputError("Model response must be a JSON object")
return payload

fenced = re.search(r"```(?:json)?\s*(\{.*\})\s*```", text, flags=re.DOTALL | re.IGNORECASE)
if fenced:
payload = json.loads(fenced.group(1))
if not isinstance(payload, dict):
raise _StructuredOutputError("Model response must be a JSON object")
return payload

start = text.find("{")
end = text.rfind("}")
if start != -1 and end != -1 and end > start:
payload = json.loads(text[start : end + 1])
if not isinstance(payload, dict):
raise _StructuredOutputError("Model response must be a JSON object")
return payload

raise json.JSONDecodeError("Could not find JSON object in model response", text, 0)


def _coerce_text_payload(text: str, schema: Schema) -> dict | None:
text = text.strip()
if not text:
return None
if len(schema.required) != 1:
return None

field = schema.required[0]
spec = schema.properties.get(field, {})
field_type = spec.get("type", "string")

if field_type == "string":
return {field: text}

if field_type == "boolean":
lowered = text.lower()
if lowered == "true":
return {field: True}
if lowered == "false":
return {field: False}

return None


def _validate_schema_payload(payload: dict, schema: Schema) -> dict:
extra = sorted(set(payload) - set(schema.properties))
if extra:
raise _StructuredOutputError(f"Model response included unsupported field(s): {', '.join(extra)}")

missing = [field for field in schema.required if field not in payload]
if missing:
raise _StructuredOutputError(f"Model response omitted required field(s): {', '.join(missing)}")

for field, value in payload.items():
spec = schema.properties.get(field, {})
expected_type = spec.get("type", "string")
if expected_type == "string":
valid = isinstance(value, str)
elif expected_type == "boolean":
valid = isinstance(value, bool)
elif expected_type == "integer":
valid = isinstance(value, int) and not isinstance(value, bool)
elif expected_type == "number":
valid = isinstance(value, (int, float)) and not isinstance(value, bool)
elif expected_type == "array":
valid = isinstance(value, list)
elif expected_type == "object":
valid = isinstance(value, dict)
else:
valid = True

if not valid:
raise _StructuredOutputError(
f"Model response field '{field}' must be {expected_type}, got {type(value).__name__}"
)

return payload


class AnthropicLLM(LLM):
    """LLM backend using the Anthropic Messages API with schema-validated JSON output.

    The raw Messages API returns free-form text blocks, so structured output is
    enforced via a system prompt plus local parsing and validation
    (_parse_json_payload / _validate_schema_payload).
    """

    def __init__(self, model: str | None = None):
        """Create a client from ANTHROPIC_API_KEY (and optional ANTHROPIC_BASE_URL).

        Model resolution order: explicit argument, ANTHROPIC_MODEL env var,
        then the built-in default.

        Raises:
            RuntimeError: if ANTHROPIC_API_KEY is not set.
        """
        # Imported lazily so the module loads without the SDK installed.
        from anthropic import Anthropic

        api_key = os.environ.get("ANTHROPIC_API_KEY")
        if not api_key:
            raise RuntimeError("Anthropic provider requires ANTHROPIC_API_KEY")

        base_url = os.environ.get("ANTHROPIC_BASE_URL")
        # SDK-level retries are disabled; generate() implements its own
        # backoff so retry behavior is controlled in one place.
        self._client = Anthropic(
            api_key=api_key,
            base_url=base_url or None,
            max_retries=0,
        )
        self._model = (
            model
            or os.environ.get("ANTHROPIC_MODEL")
            or "claude-sonnet-4-5"
        )

    @property
    def model_id(self) -> str:
        # Provider-prefixed identifier, e.g. "anthropic:claude-sonnet-4-5".
        return f"anthropic:{self._model}"

    def generate(self, prompt: str, schema: Schema) -> dict:
        """Generate a JSON object matching *schema* for *prompt*.

        Retries up to _MAX_RETRIES times with exponential backoff on rate
        limits, connection errors, retryable HTTP statuses (429/5xx), and
        responses that fail parsing or schema validation.

        Raises:
            RuntimeError: when all retries are exhausted.
        """
        from anthropic import APIConnectionError, APIStatusError, RateLimitError

        # Advertise the expected shape to the model; actual enforcement
        # happens locally after the response arrives.
        schema_json = {
            "type": "object",
            "properties": schema.properties,
            "required": schema.required,
            "additionalProperties": False,
        }
        system_prompt = (
            "Return only a valid JSON object matching this schema. "
            "Do not wrap JSON in markdown fences.\n\n"
            f"{json.dumps(schema_json, ensure_ascii=False)}"
        )

        delay = _RETRY_BASE_DELAY
        last_exc: Exception | None = None

        for attempt in range(_MAX_RETRIES):
            try:
                response = self._client.messages.create(
                    model=self._model,
                    max_tokens=4096,
                    temperature=0.0,
                    system=system_prompt,
                    messages=[{"role": "user", "content": prompt}],
                )
                # Concatenate only text blocks; other block types are ignored.
                text = "".join(block.text for block in response.content if getattr(block, "type", None) == "text")
                try:
                    payload = _parse_json_payload(text)
                except json.JSONDecodeError:
                    # Fall back to wrapping a bare text reply, which only
                    # works for single-field string/boolean schemas.
                    coerced = _coerce_text_payload(text, schema)
                    if coerced is None:
                        raise _StructuredOutputError("Model response was not valid JSON") from None
                    payload = coerced
                return _validate_schema_payload(payload, schema)
            except (RateLimitError, APIConnectionError) as e:
                last_exc = e
            except APIStatusError as e:
                last_exc = e
                # Only transient statuses are retried; anything else is fatal.
                if e.status_code not in (429, 500, 502, 503, 504):
                    raise
            except _StructuredOutputError as e:
                # Malformed output is retried hoping a fresh sample validates.
                last_exc = e
            except Exception as e:
                last_exc = e
                msg = str(e)
                # Heuristic catch-all: retry only errors whose message looks
                # rate-limit related; re-raise everything else.
                if "429" not in msg and "rate" not in msg.lower():
                    raise

            if attempt < _MAX_RETRIES - 1:
                time.sleep(delay)
                delay *= 2

        raise RuntimeError(f"Anthropic request failed after {_MAX_RETRIES} retries: {last_exc}")
6 changes: 4 additions & 2 deletions src/memory_bench/modes/agentic_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .base import ResponseMode
from .rag import RAGMode, _OPEN_SCHEMA, _MCQ_SCHEMA
from ..dataset.base import _DEFAULT_OPEN_PROMPT as _OPEN_PROMPT, _DEFAULT_MCQ_PROMPT as _MCQ_PROMPT
from ..llm.base import ToolDef
from ..llm.base import LLM, ToolDef
from ..llm.gemini import GeminiLLM
from ..memory.base import MemoryProvider
from ..models import AnswerResult
Expand All @@ -23,8 +23,10 @@ class AgenticRAGMode(ResponseMode):
name = "agentic-rag"
description = "The LLM acts as an agent with a recall tool and can make multiple retrieval calls with different queries before finalising its answer."

def __init__(self, llm: GeminiLLM | None = None, k: int = 10):
def __init__(self, llm: LLM | None = None, k: int = 10):
self._llm = llm or GeminiLLM()
if type(self._llm).tool_loop is LLM.tool_loop:
raise ValueError(f"{self._llm.model_id} does not support agentic-rag tool calling")
self._rag = RAGMode(llm=self._llm, k=k)
self.k = k

Expand Down
4 changes: 4 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.