diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 33c0dd2..38eebdb 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -23,7 +23,6 @@ jobs: python-version: "3.11" cache: "pip" cache-dependency-path: | - libs/openant-core/requirements.txt libs/openant-core/pyproject.toml - name: Set up Node.js @@ -35,7 +34,7 @@ jobs: - name: Install Python dependencies working-directory: libs/openant-core - run: pip install -r requirements.txt && pip install ".[dev]" + run: pip install -e ".[dev]" - name: Cache JS parser node_modules id: cache-node-modules @@ -51,7 +50,7 @@ jobs: - name: Run Python and parser tests working-directory: libs/openant-core - run: python -m pytest tests/test_token_tracker.py tests/test_parser_adapter.py tests/test_python_parser.py tests/test_js_parser.py -v + run: python -m pytest tests/test_token_tracker.py tests/test_parser_adapter.py tests/test_python_parser.py tests/test_js_parser.py tests/test_declared_dependencies.py -v go-tests: name: Go build + integration (${{ matrix.os }}) @@ -79,7 +78,6 @@ jobs: python-version: "3.11" cache: "pip" cache-dependency-path: | - libs/openant-core/requirements.txt libs/openant-core/pyproject.toml - name: Set up Node.js @@ -115,7 +113,7 @@ jobs: - name: Install Python dependencies working-directory: libs/openant-core - run: pip install -r requirements.txt && pip install ".[dev]" + run: pip install -e ".[dev]" - name: Cache JS parser node_modules id: cache-node-modules diff --git a/libs/openant-core/README.md b/libs/openant-core/README.md index 9d466ed..77231c7 100644 --- a/libs/openant-core/README.md +++ b/libs/openant-core/README.md @@ -24,7 +24,7 @@ git clone https://github.com/your-org/openant.git cd openant # Install Python dependencies -pip install -r requirements.txt +pip install -e . 
# Set API key echo "ANTHROPIC_API_KEY=your-key-here" > .env diff --git a/libs/openant-core/context/application_context.py b/libs/openant-core/context/application_context.py index f7fa55d..442f3e5 100644 --- a/libs/openant-core/context/application_context.py +++ b/libs/openant-core/context/application_context.py @@ -29,9 +29,17 @@ from pathlib import Path from typing import Any -from anthropic import Anthropic from dotenv import load_dotenv +# Ensure libs/openant-core is on sys.path so `utilities.*` imports resolve +# regardless of how this module is loaded. +_OPENANT_CORE_ROOT = str(Path(__file__).parent.parent) +if _OPENANT_CORE_ROOT not in sys.path: + sys.path.insert(0, _OPENANT_CORE_ROOT) + +from utilities.model_config import MODEL_AUXILIARY # noqa: E402 +from utilities.llm_client import AnthropicClient # noqa: E402 + # Load environment variables load_dotenv() @@ -462,7 +470,7 @@ def _build_type_descriptions() -> str: def generate_application_context( repo_path: Path, - model: str = "claude-sonnet-4-20250514", + model: str = MODEL_AUXILIARY, force_regenerate: bool = False, ) -> ApplicationContext: """Generate application context using LLM analysis. @@ -503,19 +511,14 @@ def generate_application_context( # Call LLM print(f"Generating context with {model}...", file=sys.stderr) - client = Anthropic() - response = client.messages.create( - model=model, + # AnthropicClient is the SDK-backed wrapper; routes through the Claude + # Agent SDK so this works with both API keys and local Claude Code sessions. 
+ client = AnthropicClient(model=model) + response_text = client.analyze_sync( + CONTEXT_GENERATION_PROMPT.format(sources=sources_text), max_tokens=2000, - messages=[{ - "role": "user", - "content": CONTEXT_GENERATION_PROMPT.format(sources=sources_text) - }] ) - # Parse response - response_text = response.content[0].text - # Extract JSON from response json_match = re.search(r'```json\s*(.*?)\s*```', response_text, re.DOTALL) if json_match: diff --git a/libs/openant-core/context/generate_context.py b/libs/openant-core/context/generate_context.py index 78e21d3..47916f0 100644 --- a/libs/openant-core/context/generate_context.py +++ b/libs/openant-core/context/generate_context.py @@ -14,6 +14,7 @@ # Add parent directory to path for imports sys.path.insert(0, str(Path(__file__).parent.parent)) +from utilities.model_config import MODEL_AUXILIARY from context.application_context import ( ApplicationType, APPLICATION_TYPE_INFO, @@ -78,8 +79,8 @@ def main(): parser.add_argument( "--model", "-m", - default="claude-sonnet-4-20250514", - help="Anthropic model to use (default: claude-sonnet-4-20250514)", + default=MODEL_AUXILIARY, + help=f"Anthropic model to use (default: {MODEL_AUXILIARY})", ) parser.add_argument( diff --git a/libs/openant-core/core/analyzer.py b/libs/openant-core/core/analyzer.py index 7fb5966..a6564ba 100644 --- a/libs/openant-core/core/analyzer.py +++ b/libs/openant-core/core/analyzer.py @@ -313,7 +313,8 @@ def run_analysis( checkpoint.dir = checkpoint_path # Select model - model_id = "claude-opus-4-6" if model == "opus" else "claude-sonnet-4-20250514" + from utilities.model_config import MODEL_AUXILIARY, MODEL_PRIMARY + model_id = MODEL_PRIMARY if model == "opus" else MODEL_AUXILIARY print(f"[Analyze] Model: {model_id}", file=sys.stderr) # Initialize client diff --git a/libs/openant-core/core/enhancer.py b/libs/openant-core/core/enhancer.py index fef1453..49e5d2d 100644 --- a/libs/openant-core/core/enhancer.py +++ b/libs/openant-core/core/enhancer.py @@ 
-50,7 +50,8 @@ def enhance_dataset( # Configure global rate limiter configure_rate_limiter(backoff_seconds=float(backoff_seconds)) - model_id = "claude-sonnet-4-20250514" if model == "sonnet" else "claude-opus-4-6" + from utilities.model_config import MODEL_AUXILIARY, MODEL_PRIMARY + model_id = MODEL_AUXILIARY if model == "sonnet" else MODEL_PRIMARY print(f"[Enhance] Mode: {mode}", file=sys.stderr) print(f"[Enhance] Model: {model_id}", file=sys.stderr) diff --git a/libs/openant-core/core/reporter.py b/libs/openant-core/core/reporter.py index 7153dab..6513647 100644 --- a/libs/openant-core/core/reporter.py +++ b/libs/openant-core/core/reporter.py @@ -587,11 +587,12 @@ def _record_usage_in_tracker(usage: dict): """Record usage in the global TokenTracker so step_context captures it.""" try: from utilities.llm_client import get_global_tracker + from utilities.model_config import MODEL_PRIMARY tracker = get_global_tracker() # Record as a single aggregated call if usage.get("total_tokens", 0) > 0: tracker.record_call( - model="claude-opus-4-6", + model=MODEL_PRIMARY, input_tokens=usage["input_tokens"], output_tokens=usage["output_tokens"], ) diff --git a/libs/openant-core/experiment.py b/libs/openant-core/experiment.py index 359d41f..69992ea 100644 --- a/libs/openant-core/experiment.py +++ b/libs/openant-core/experiment.py @@ -474,7 +474,8 @@ def run_experiment( Experiment results with metrics """ # Select model - model_id = "claude-opus-4-20250514" if model == "opus" else "claude-sonnet-4-20250514" + from utilities.model_config import MODEL_AUXILIARY, MODEL_PRIMARY + model_id = MODEL_PRIMARY if model == "opus" else MODEL_AUXILIARY print(f"Using model: {model_id}") print(f"Enhanced context: {enhanced}") print(f"Context correction: {correct_context}") diff --git a/libs/openant-core/generate_report.py b/libs/openant-core/generate_report.py index 633cd9b..83387fe 100644 --- a/libs/openant-core/generate_report.py +++ b/libs/openant-core/generate_report.py @@ -27,16 +27,26 @@ 
import json import html import os +import sys from datetime import datetime +from pathlib import Path -import anthropic from dotenv import load_dotenv +# Ensure libs/openant-core is on sys.path so `utilities.*` imports resolve +# regardless of the caller's working directory. +_OPENANT_CORE_ROOT = str(Path(__file__).parent) +if _OPENANT_CORE_ROOT not in sys.path: + sys.path.insert(0, _OPENANT_CORE_ROOT) + +from utilities.model_config import MODEL_AUXILIARY # noqa: E402 +from utilities.llm_client import AnthropicClient # noqa: E402 + # Load environment variables from .env file load_dotenv() -REPORT_MODEL = "claude-sonnet-4-20250514" +REPORT_MODEL = MODEL_AUXILIARY MAX_TOKENS = 4096 @@ -198,18 +208,10 @@ def generate_remediation_guidance(findings: list) -> str: {findings_text} """ - api_key = os.getenv("ANTHROPIC_API_KEY") - if not api_key: - raise ValueError("ANTHROPIC_API_KEY not found in environment") - - client = anthropic.Anthropic(api_key=api_key) - response = client.messages.create( - model=REPORT_MODEL, - max_tokens=MAX_TOKENS, - messages=[{"role": "user", "content": prompt}] - ) - - return response.content[0].text + # AnthropicClient is the SDK-backed wrapper; it handles auth (API key + # or local Claude Code session) and surfaces structured SDK errors. + client = AnthropicClient(model=REPORT_MODEL) + return client.analyze_sync(prompt, max_tokens=MAX_TOKENS) def _build_pipeline_costs_html(step_reports: list[dict]) -> str: diff --git a/libs/openant-core/openant/cli.py b/libs/openant-core/openant/cli.py index b0ce345..2cd7476 100644 --- a/libs/openant-core/openant/cli.py +++ b/libs/openant-core/openant/cli.py @@ -587,10 +587,9 @@ def cmd_report_data(args): and step reports — everything display-ready. 
""" import html as html_mod - import anthropic from core.schemas import success, error from core.step_report import step_context - from utilities.llm_client import get_global_tracker + from utilities.llm_client import AnthropicClient, get_global_tracker results_path = args.results dataset_path = args.dataset @@ -810,13 +809,10 @@ def cmd_report_data(args): {findings_text} """ print("[Report] Generating remediation guidance (LLM)...", file=sys.stderr) - client = anthropic.Anthropic() - response = client.messages.create( - model="claude-sonnet-4-20250514", - max_tokens=4096, - messages=[{"role": "user", "content": prompt}], - ) - remediation_html = response.content[0].text + from utilities.model_config import MODEL_AUXILIARY + # AnthropicClient handles usage tracking via the global TokenTracker. + remediation_client = AnthropicClient(model=MODEL_AUXILIARY) + remediation_html = remediation_client.analyze_sync(prompt, max_tokens=4096) # Post-process: linkify finding references like #4, #12-#14 import re @@ -825,15 +821,11 @@ def _linkify_finding(m): return f'#{num}' remediation_html = re.sub(r'#(\d+)', _linkify_finding, remediation_html) - # Track usage - usage = response.usage - tracker = get_global_tracker() - tracker.record_call( - model="claude-sonnet-4-20250514", - input_tokens=usage.input_tokens, - output_tokens=usage.output_tokens, + last = remediation_client.get_last_call() or {} + print( + f" Remediation cost: ${last.get('cost_usd', 0.0):.4f}", + file=sys.stderr, ) - print(f" Remediation cost: ${(usage.input_tokens / 1e6) * 3.0 + (usage.output_tokens / 1e6) * 15.0:.4f}", file=sys.stderr) # --- Step reports --- step_reports_data = [] diff --git a/libs/openant-core/prompts/verification_prompts.py b/libs/openant-core/prompts/verification_prompts.py index a0b1097..7d3c570 100644 --- a/libs/openant-core/prompts/verification_prompts.py +++ b/libs/openant-core/prompts/verification_prompts.py @@ -200,4 +200,181 @@ def get_phase1_exploitability_prompt(code, finding, 
attack_vector, files_include def get_phase2_verdict_prompt(exploitability_analysis, original_finding): return "" # Not used in new approach -import json + +# JSON schema for structured output from native Claude Agent SDK verification. +# Used by FindingVerifier.verify_result via run_native_verification, which +# requests structured_output matching this schema. +VERIFICATION_JSON_SCHEMA = { + "type": "object", + "properties": { + "agree": { + "type": "boolean", + "description": "Whether you agree with Stage 1's assessment" + }, + "correct_finding": { + "type": "string", + "enum": ["safe", "protected", "bypassable", "vulnerable", "inconclusive"], + "description": "The correct finding based on exploit path analysis" + }, + "exploit_path": { + "type": "object", + "description": "Analysis of the exploit path from attacker input to sink", + "properties": { + "entry_point": { + "type": ["string", "null"], + "description": "Where attacker input enters (null if none found)" + }, + "data_flow": { + "type": "array", + "items": {"type": "string"}, + "description": "Steps showing how data flows from entry to sink" + }, + "sink_reached": { + "type": "boolean", + "description": "Whether attacker-controlled data reaches the vulnerable operation" + }, + "attacker_control_at_sink": { + "type": "string", + "enum": ["full", "partial", "none"], + "description": "Level of attacker control at the dangerous operation" + }, + "path_broken_at": { + "type": ["string", "null"], + "description": "Where/why the exploit path breaks (null if complete)" + } + } + }, + "explanation": { + "type": "string", + "description": "Detailed explanation of your analysis" + }, + "security_weakness": { + "type": ["string", "null"], + "description": "Any dangerous patterns that exist but aren't currently exploitable" + } + }, + "required": ["agree", "correct_finding", "explanation"] +} + + +def get_native_claude_verification_prompt( + code: str, + finding: str, + attack_vector: str, + reasoning: str, + 
files_included: list = None, + app_context: "ApplicationContext" = None, +) -> str: + """ + Verification prompt for native Claude Agent SDK multi-turn mode. + + Instead of custom tools (search_usages, etc.), instructs Claude Code + to use its native Read/Grep/Glob tools to explore the codebase. + + Args: + code: The code being verified. + finding: The Stage 1 finding (vulnerable/safe/etc). + attack_vector: The claimed attack vector from Stage 1. + reasoning: The reasoning from Stage 1. + files_included: Optional list of files included in context. + app_context: Optional ApplicationContext for reducing false positives. + + Returns: + The formatted verification prompt. + """ + # Build application context section + app_context_section = "" + if app_context: + app_context_section = format_app_context_for_verification(app_context) + "\n---\n\n" + + # Mark the target function clearly + code_parts = code.split("// ========== File Boundary ==========") + if len(code_parts) > 1: + primary_code = code_parts[0].strip() + context_code = "\n// ========== File Boundary ==========".join(code_parts[1:]) + code_section = f""" +>>> TARGET FUNCTION <<< +``` +{primary_code} +``` + +Context: +``` +{context_code} +```""" + else: + code_section = f""" +>>> TARGET FUNCTION <<< +``` +{code} +```""" + + # Build files hint for exploration + files_hint = "" + if files_included: + files_hint = "\n**Files involved:** " + ", ".join(files_included[:10]) + "\n" + + # Adjust attacker description based on app context + if app_context and not app_context.requires_remote_trigger: + attacker_description = """You are an attacker on the internet. You have a browser and nothing else. +No server access, no admin credentials, no ability to modify files on the server, and NO ABILITY TO RUN CLI COMMANDS. + +You must find a way to trigger this vulnerability REMOTELY. 
If the only attack path requires: +- Running CLI commands locally +- Having shell access to the server +- Being the user who runs the application + +Then the vulnerability is NOT EXPLOITABLE by you, because local users can already do anything on their own machine.""" + else: + attacker_description = """You are an attacker on the internet. You have a browser and nothing else. No server access, no admin credentials, no ability to modify files on the server.""" + + return f"""{app_context_section}Stage 1 claims this function is **{finding.upper()}**. + +Their reasoning: {reasoning} + +Claimed attack vector: {attack_vector} + +{code_section} +{files_hint} +--- + +{attacker_description} + +## Your Task + +Verify whether this vulnerability is actually exploitable by exploring the codebase. + +1. **Read the target function** and understand what it does. +2. **Trace the data flow** — use Grep and Read to find: + - Where attacker input enters the application (entry points, route handlers) + - How data flows from entry point to the target function + - Whether any validation, sanitization, or auth checks exist along the path +3. **Check for security controls** — search for middleware, guards, validators that might block the attack. +4. **Try multiple attack approaches** — think about different inputs, properties, and entry points. + +For EACH approach, trace through step by step until you succeed or hit a blocker. + +IMPORTANT: +- Only conclude PROTECTED or SAFE if ALL approaches fail. If ANY approach succeeds, conclude VULNERABLE. +- A vulnerability must harm someone OTHER than the attacker. +- If this is a CLI tool/library and the attack requires local access, it is NOT a vulnerability. + +After your analysis, you MUST output your final verdict as a single JSON object and nothing else. Do not include any markdown, explanation, or text outside the JSON. 
The JSON must match this exact structure: + +```json +{{ + "agree": true/false, + "correct_finding": "safe" | "protected" | "bypassable" | "vulnerable" | "inconclusive", + "exploit_path": {{ + "entry_point": "where attacker input enters, or null", + "data_flow": ["step1", "step2", "..."], + "sink_reached": true/false, + "attacker_control_at_sink": "full" | "partial" | "none", + "path_broken_at": "where/why the exploit path breaks, or null" + }}, + "explanation": "detailed reasoning", + "security_weakness": "any dangerous patterns not currently exploitable, or null" +}} +``` + +Your FINAL output must be ONLY this JSON object — no surrounding text or markdown.""" diff --git a/libs/openant-core/pyproject.toml b/libs/openant-core/pyproject.toml index 266e7db..0ca83a5 100644 --- a/libs/openant-core/pyproject.toml +++ b/libs/openant-core/pyproject.toml @@ -5,7 +5,7 @@ description = "Two-stage SAST tool using Claude for vulnerability analysis" readme = "README.md" requires-python = ">=3.11" dependencies = [ - "anthropic>=0.40.0", + "claude-agent-sdk>=0.1.48", "python-dotenv>=1.0.0", "pydantic>=2.0.0", "httpx>=0.24.0", @@ -16,6 +16,7 @@ dependencies = [ "tree-sitter-cpp>=0.21.0", "tree-sitter-ruby>=0.21.0", "tree-sitter-php>=0.22.0", + "tree-sitter-zig>=0.20.0", ] [project.optional-dependencies] diff --git a/libs/openant-core/report/generator.py b/libs/openant-core/report/generator.py index c996250..54a7dc1 100644 --- a/libs/openant-core/report/generator.py +++ b/libs/openant-core/report/generator.py @@ -8,37 +8,40 @@ import os import re import sys -import anthropic from pathlib import Path from dotenv import load_dotenv +# Ensure libs/openant-core is on sys.path so `utilities.*` imports resolve. 
+_OPENANT_CORE_ROOT = str(Path(__file__).parent.parent) +if _OPENANT_CORE_ROOT not in sys.path: + sys.path.insert(0, _OPENANT_CORE_ROOT) + +from utilities.model_config import MODEL_PRIMARY, MODEL_AUXILIARY # noqa: E402 + from .schema import validate_pipeline_output, ValidationError +from utilities.llm_client import AnthropicClient load_dotenv() PROMPTS_DIR = Path(__file__).parent / "prompts" -MODEL = "claude-opus-4-6" - -# Pricing per million tokens -_PRICING = { - "claude-opus-4-6": {"input": 15.00, "output": 75.00}, - "claude-opus-4-20250514": {"input": 15.00, "output": 75.00}, - "claude-sonnet-4-20250514": {"input": 3.00, "output": 15.00}, -} -_DEFAULT_PRICING = {"input": 3.00, "output": 15.00} - - -def _extract_usage(response, model: str = MODEL) -> dict: - """Extract usage info from an Anthropic API response.""" - usage = response.usage - pricing = _PRICING.get(model, _DEFAULT_PRICING) - input_cost = (usage.input_tokens / 1_000_000) * pricing["input"] - output_cost = (usage.output_tokens / 1_000_000) * pricing["output"] +MODEL = MODEL_PRIMARY + +def _usage_from_last_call(last_call: dict | None) -> dict: + """Adapt TokenTracker.record_call output to the report's usage shape. + + TokenTracker returns {model, input_tokens, output_tokens, cost_usd}; + callers here want {input_tokens, output_tokens, total_tokens, cost_usd}. + Returns zeros when last_call is None (e.g. if the SDK didn't surface usage). 
+ """ + if not last_call: + return {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0, "cost_usd": 0.0} + it = last_call.get("input_tokens", 0) + ot = last_call.get("output_tokens", 0) return { - "input_tokens": usage.input_tokens, - "output_tokens": usage.output_tokens, - "total_tokens": usage.input_tokens + usage.output_tokens, - "cost_usd": round(input_cost + output_cost, 6), + "input_tokens": it, + "output_tokens": ot, + "total_tokens": it + ot, + "cost_usd": last_call.get("cost_usd", 0.0), } @@ -54,10 +57,20 @@ def _merge_usage(usages: list[dict]) -> dict: def _check_api_key(): - """Check that ANTHROPIC_API_KEY is set.""" + """Check that we have a viable auth path for the Claude Agent SDK. + + Two modes are supported (matches utilities.llm_client._build_env): + - OPENANT_LOCAL_CLAUDE=true: SDK uses the local Claude Code session + (no API key needed). + - Otherwise: ANTHROPIC_API_KEY must be set. + """ + local_mode = os.environ.get("OPENANT_LOCAL_CLAUDE", "").lower() == "true" + if local_mode: + return if not os.environ.get("ANTHROPIC_API_KEY"): print("Error: ANTHROPIC_API_KEY environment variable not set.", file=sys.stderr) print("Set it with: export ANTHROPIC_API_KEY=sk-ant-...", file=sys.stderr) + print("Or use OPENANT_LOCAL_CLAUDE=true to use the local Claude Code session.", file=sys.stderr) sys.exit(1) @@ -136,20 +149,14 @@ def generate_summary_report(pipeline_data: dict) -> tuple[str, dict]: output_tokens, total_tokens, cost_usd. 
""" _check_api_key() - client = anthropic.Anthropic() + client = AnthropicClient(model=MODEL) summary_data = _compact_for_summary(pipeline_data) system_prompt = load_prompt("system") user_prompt = load_prompt("summary").replace("{pipeline_data}", json.dumps(summary_data, indent=2)) - response = client.messages.create( - model=MODEL, - max_tokens=4096, - system=system_prompt, - messages=[{"role": "user", "content": user_prompt}] - ) - - return response.content[0].text, _extract_usage(response) + text = client.analyze_sync(user_prompt, system=system_prompt, max_tokens=4096) + return text, _usage_from_last_call(client.get_last_call()) def _splice_code_section(llm_output: str, code_section: str) -> str: @@ -199,7 +206,7 @@ def generate_disclosure(vulnerability_data: dict, product_name: str) -> tuple[st (disclosure_text, usage_dict) """ _check_api_key() - client = anthropic.Anthropic() + client = AnthropicClient(model=MODEL) system_prompt = load_prompt("system") @@ -218,17 +225,10 @@ def generate_disclosure(vulnerability_data: dict, product_name: str) -> tuple[st .replace("{vulnerability_data}", json.dumps(payload, indent=2), 1) ) - response = client.messages.create( - model=MODEL, - max_tokens=4096, - system=system_prompt, - messages=[{"role": "user", "content": user_prompt}] - ) - - llm_output = response.content[0].text + llm_output = client.analyze_sync(user_prompt, system=system_prompt, max_tokens=4096) final_output = _splice_code_section(llm_output, code_section) - return final_output, _extract_usage(response) + return final_output, _usage_from_last_call(client.get_last_call()) def generate_all(pipeline_path: str, output_dir: str) -> None: diff --git a/libs/openant-core/requirements.txt b/libs/openant-core/requirements.txt deleted file mode 100644 index 966904a..0000000 --- a/libs/openant-core/requirements.txt +++ /dev/null @@ -1,24 +0,0 @@ -annotated-types==0.7.0 -anthropic==0.75.0 -anyio==4.12.0 -certifi==2025.11.12 -distro==1.9.0 -docstring_parser==0.17.0 
-h11==0.16.0 -httpcore==1.0.9 -httpx==0.28.1 -idna==3.11 -jiter==0.12.0 -pydantic==2.12.5 -pydantic_core==2.41.5 -python-dotenv==1.2.1 -sniffio==1.3.1 -typing-inspection==0.4.2 -typing_extensions==4.15.0 -PyYAML>=6.0 -requests>=2.31.0 -tree-sitter>=0.21.0 -tree-sitter-c>=0.21.0 -tree-sitter-cpp>=0.21.0 -tree-sitter-ruby>=0.21.0 -tree-sitter-php>=0.22.0 diff --git a/libs/openant-core/tests/report/test_disclosure_source_fidelity.py b/libs/openant-core/tests/report/test_disclosure_source_fidelity.py index 462f958..e20790a 100644 --- a/libs/openant-core/tests/report/test_disclosure_source_fidelity.py +++ b/libs/openant-core/tests/report/test_disclosure_source_fidelity.py @@ -21,16 +21,10 @@ _CORE_ROOT = Path(__file__).resolve().parents[2] sys.path.insert(0, str(_CORE_ROOT)) -# The project's venv has a broken `anthropic` install (ErrorObject import fails -# in some sub-dependency). Stub it before `report.generator` is imported so the -# test suite can run without touching the venv. Real API calls are never made -# in this file — all disclosure generation is mocked. -if "anthropic" not in sys.modules: - stub = types.ModuleType("anthropic") - stub.Anthropic = MagicMock() - stub.RateLimitError = type("RateLimitError", (Exception,), {}) - stub.AuthenticationError = type("AuthenticationError", (Exception,), {}) - sys.modules["anthropic"] = stub +# Real API calls are never made in this file — all disclosure generation is +# mocked. Following the SDK migration, generator.py routes through +# AnthropicClient (which wraps the Claude Agent SDK), so we patch +# AnthropicClient.analyze_sync rather than the legacy `anthropic.Anthropic`. from core import reporter # noqa: E402 from report import generator # noqa: E402 @@ -280,36 +274,37 @@ def test_splice_preserves_other_sections(): # the real code, even when the LLM returns fabricated code. 
# --------------------------------------------------------------------------- -class _FakeAnthropic: - """Replacement for anthropic.Anthropic — returns fabricated code to prove - the post-processor catches it.""" +class _FakeAnthropicClient: + """Replacement for ``utilities.llm_client.AnthropicClient`` — returns + fabricated code to prove the post-processor catches it. - def __init__(self, *args, **kwargs): - self.messages = self - - def create(self, **kwargs): - _FakeAnthropic.last_prompt = kwargs["messages"][0]["content"] - # Return a disclosure WITH fabricated code — the post-processor must fix it. - return _FakeResponse() + Mirrors the new SDK-backed surface (``analyze_sync``, ``get_last_call``) + rather than the legacy ``anthropic.Anthropic`` ``messages.create`` shape. + """ + last_prompt = None -class _FakeResponse: - class _Content: - text = LLM_OUTPUT_WITH_FABRICATED_CODE - - content = [_Content()] + def __init__(self, *args, **kwargs): + pass - class _Usage: - input_tokens = 10 - output_tokens = 50 + def analyze_sync(self, prompt, max_tokens=8192, model=None, system=None): + _FakeAnthropicClient.last_prompt = prompt + # Return a disclosure WITH fabricated code — the post-processor must fix it. 
+ return LLM_OUTPUT_WITH_FABRICATED_CODE - usage = _Usage() + def get_last_call(self): + return { + "model": "fake", + "input_tokens": 10, + "output_tokens": 50, + "cost_usd": 0.0, + } @pytest.fixture def patched_anthropic(monkeypatch): monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-test-key") - monkeypatch.setattr(generator.anthropic, "Anthropic", _FakeAnthropic) + monkeypatch.setattr(generator, "AnthropicClient", _FakeAnthropicClient) def test_generate_disclosure_output_has_real_code(patched_anthropic, pipeline_output): @@ -341,7 +336,7 @@ def test_generate_disclosure_prompt_has_no_source_code(patched_anthropic, pipeli ) generator.generate_disclosure(ping, product_name="fixture") - prompt = _FakeAnthropic.last_prompt + prompt = _FakeAnthropicClient.last_prompt # The actual source code must not appear in the prompt. assert "subprocess.check_output" not in prompt, ( diff --git a/libs/openant-core/tests/test_declared_dependencies.py b/libs/openant-core/tests/test_declared_dependencies.py new file mode 100644 index 0000000..20ead4e --- /dev/null +++ b/libs/openant-core/tests/test_declared_dependencies.py @@ -0,0 +1,128 @@ +"""Guard against pyproject.toml declared deps drifting from actual imports. + +Regression guard for the Claude Agent SDK migration (#25), which dropped +`anthropic` from pyproject.toml while leaving `import anthropic` live in +four files. Every clean install of openant broke at `openant parse`. +""" +import ast +import sys +import tomllib +from pathlib import Path + +import pytest + +PROJECT_ROOT = Path(__file__).parent.parent +PACKAGED_DIRS = ["openant", "core", "utilities", "parsers", "prompts", "context", "report"] + +# Maps PyPI distribution names to their top-level import names when they differ. +# Extend only when adding a new dependency whose import name diverges from its +# PyPI name; the test will tell you which direction it's failing. 
+DIST_TO_IMPORT = { + "python-dotenv": "dotenv", + "pyyaml": "yaml", + "claude-agent-sdk": "claude_agent_sdk", + "tree-sitter": "tree_sitter", + "tree-sitter-c": "tree_sitter_c", + "tree-sitter-cpp": "tree_sitter_cpp", + "tree-sitter-ruby": "tree_sitter_ruby", + "tree-sitter-php": "tree_sitter_php", + "tree-sitter-zig": "tree_sitter_zig", +} + + +def _dist_name_to_import(dist: str) -> str: + key = dist.lower().replace("_", "-") + return DIST_TO_IMPORT.get(key, dist.replace("-", "_").lower()) + + +def _declared_imports() -> set[str]: + with open(PROJECT_ROOT / "pyproject.toml", "rb") as f: + data = tomllib.load(f) + deps = data["project"]["dependencies"] + names = [] + for dep in deps: + for sep in ("[", ">=", "<=", "==", "!=", ">", "<", "~=", ";", " "): + dep = dep.split(sep, 1)[0] + names.append(dep.strip()) + return {_dist_name_to_import(n) for n in names if n} + + +def _collect_top_level_imports(root: Path) -> set[str]: + """Return the set of top-level module names imported anywhere under `root`.""" + imports: set[str] = set() + for py in root.rglob("*.py"): + try: + tree = ast.parse(py.read_text(encoding="utf-8")) + except (SyntaxError, UnicodeDecodeError): + continue + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + imports.add(alias.name.split(".", 1)[0]) + elif isinstance(node, ast.ImportFrom): + # Relative imports (level > 0) have module=None or point at a + # sibling — they can't be third-party by definition. + if node.level == 0 and node.module: + imports.add(node.module.split(".", 1)[0]) + return imports + + +def _first_party_names() -> set[str]: + """Every module/package name reachable in the repo — treated as first-party. + + Parsers use sys.path hackery to import siblings as top-level names + (e.g. `from call_graph_builder import ...`), so the set of first-party + names isn't just the packaged top-level dirs. 
+ """ + names: set[str] = set(PACKAGED_DIRS) + for path in PROJECT_ROOT.rglob("*.py"): + # Skip the managed dev venv and any other nested virtualenvs. + if ".venv" in path.parts or "site-packages" in path.parts: + continue + names.add(path.stem) + for parent in path.parents: + if parent == PROJECT_ROOT: + break + names.add(parent.name) + return names + + +def test_every_third_party_import_is_declared(): + first_party = _first_party_names() + stdlib = set(sys.stdlib_module_names) + declared = _declared_imports() + + all_imports: set[str] = set() + for pkg in PACKAGED_DIRS: + pkg_dir = PROJECT_ROOT / pkg + if pkg_dir.is_dir(): + all_imports |= _collect_top_level_imports(pkg_dir) + + # Deps pulled in transitively that some callsites import by name. These + # aren't direct deps of openant but are guaranteed present by something + # we *do* declare, so it's safe to treat them as allowed. + transitive_allowed = { + # pulled in by claude-agent-sdk + "mcp", + } + + third_party = all_imports - first_party - stdlib - transitive_allowed + missing = sorted(third_party - declared) + assert not missing, ( + f"Imports not declared in pyproject.toml dependencies: {missing}. " + "Either add the distribution to `dependencies`, or remove the import. " + "If a distribution's import name differs from its PyPI name, add it to " + "DIST_TO_IMPORT in this test." + ) + + +@pytest.mark.parametrize("pkg", PACKAGED_DIRS) +def test_package_imports_cleanly(pkg): + """Smoke-test: every packaged top-level module can be imported. + + This catches the specific failure mode from #25 — where a dropped dep + only manifested at `import utilities` time, not at `import openant`. 
+ """ + if not (PROJECT_ROOT / pkg).is_dir(): + pytest.skip(f"{pkg} not present") + __import__(pkg) diff --git a/libs/openant-core/tests/test_local_claude.py b/libs/openant-core/tests/test_local_claude.py new file mode 100644 index 0000000..1ca9155 --- /dev/null +++ b/libs/openant-core/tests/test_local_claude.py @@ -0,0 +1,457 @@ +"""Tests for SDK-backed LLM client.""" +import json +import os +from pathlib import Path +from unittest.mock import patch, MagicMock + +import pytest + +from utilities.llm_client import ( + _build_env, + _build_options, + _run_query_sync, + run_native_verification, + AnthropicClient, + TokenTracker, +) +from utilities.model_config import MODEL_PRIMARY, MODEL_AUXILIARY + + +# --------------------------------------------------------------------------- +# Helpers to build mock SDK messages +# --------------------------------------------------------------------------- + +def _make_result_message( + result="The answer is 42.", + input_tokens=150, + output_tokens=25, + cost=0.003, + structured_output=None, +): + """Create a mock ResultMessage matching the SDK's interface.""" + msg = MagicMock() + msg.result = result + msg.usage = {"input_tokens": input_tokens, "output_tokens": output_tokens} + msg.total_cost_usd = cost + msg.structured_output = structured_output + return msg + + +# --------------------------------------------------------------------------- +# _build_env() +# --------------------------------------------------------------------------- + +class TestBuildEnv: + def test_includes_api_key_when_set(self): + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "sk-test", "OPENANT_LOCAL_CLAUDE": ""}, clear=False): + env = _build_env() + assert env["ANTHROPIC_API_KEY"] == "sk-test" + + def test_no_api_key_in_local_mode(self): + """Local mode should not pass API key even if set.""" + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "sk-test", "OPENANT_LOCAL_CLAUDE": "true"}, clear=False): + env = _build_env() + assert "ANTHROPIC_API_KEY" 
not in env + + def test_no_api_key_when_not_set(self): + env_without_key = {k: v for k, v in os.environ.items() if k != "ANTHROPIC_API_KEY"} + with patch.dict(os.environ, env_without_key, clear=True): + env = _build_env() + assert "ANTHROPIC_API_KEY" not in env + + def test_includes_config_dir_when_set(self): + with patch.dict(os.environ, {"CLAUDE_CONFIG_DIR": "/home/.claude-k"}, clear=False): + env = _build_env() + assert env["CLAUDE_CONFIG_DIR"] == "/home/.claude-k" + + def test_clears_claudecode(self): + with patch.dict(os.environ, {"CLAUDECODE": "1"}, clear=False): + env = _build_env() + assert env["CLAUDECODE"] == "" + + +# --------------------------------------------------------------------------- +# _build_options() +# --------------------------------------------------------------------------- + +class TestBuildOptions: + def test_sets_model(self): + opts = _build_options(MODEL_PRIMARY) + assert opts.model == MODEL_PRIMARY + + def test_sets_system_prompt(self): + opts = _build_options("m", system="You are a security analyst.") + assert opts.system_prompt == "You are a security analyst." 
+ + def test_sets_max_turns(self): + opts = _build_options("m", max_turns=5) + assert opts.max_turns == 5 + + def test_sets_allowed_tools(self): + opts = _build_options("m", allowed_tools=["Read", "Grep"]) + assert opts.allowed_tools == ["Read", "Grep"] + + def test_default_allowed_tools_empty(self): + opts = _build_options("m") + assert opts.allowed_tools == [] + + def test_sets_permission_mode(self): + opts = _build_options("m") + assert opts.permission_mode == "bypassPermissions" + + def test_passes_extra_kwargs(self): + opts = _build_options("m", add_dirs=["/tmp/repo"], max_budget_usd=1.0) + assert opts.add_dirs == ["/tmp/repo"] + assert opts.max_budget_usd == 1.0 + + +# --------------------------------------------------------------------------- +# AnthropicClient — token tracking +# --------------------------------------------------------------------------- + +class TestAnthropicClientTokenTracking: + @patch("utilities.llm_client._run_query_sync") + def test_tracks_tokens(self, mock_query): + mock_query.return_value = ( + _make_result_message(result="ok", input_tokens=200, output_tokens=100, cost=0.01), + "ok", + ) + + tracker = TokenTracker() + client = AnthropicClient(model=MODEL_AUXILIARY, tracker=tracker) + result = client.analyze_sync("test") + + assert result == "ok" + assert tracker.total_input_tokens == 200 + assert tracker.total_output_tokens == 100 + assert tracker.total_cost_usd == 0.01 + assert len(tracker.calls) == 1 + + @patch("utilities.llm_client._run_query_sync") + def test_tracks_sdk_cost(self, mock_query): + """Uses SDK-reported cost, not pricing table estimate.""" + mock_query.return_value = ( + _make_result_message(result="ok", input_tokens=100, output_tokens=50, cost=0.0042), + "ok", + ) + + tracker = TokenTracker() + client = AnthropicClient(model=MODEL_AUXILIARY, tracker=tracker) + client.analyze_sync("test") + + assert tracker.total_cost_usd == 0.0042 + + @patch("utilities.llm_client._run_query_sync") + def 
test_last_call_populated(self, mock_query): + mock_query.return_value = ( + _make_result_message(result="ok", input_tokens=50, output_tokens=20, cost=0.001), + "ok", + ) + + client = AnthropicClient(model=MODEL_AUXILIARY) + client.analyze_sync("test") + + last = client.get_last_call() + assert last is not None + assert last["input_tokens"] == 50 + assert last["output_tokens"] == 20 + + @patch("utilities.llm_client._run_query_sync") + def test_falls_back_to_assistant_text(self, mock_query): + """When ResultMessage.result is None, use last AssistantMessage text.""" + mock_query.return_value = ( + _make_result_message(result=None), + "Fallback text", + ) + + client = AnthropicClient(model=MODEL_AUXILIARY) + result = client.analyze_sync("test") + + assert result == "Fallback text" + + +# --------------------------------------------------------------------------- +# run_native_verification() — SDK-backed +# --------------------------------------------------------------------------- + +class TestRunNativeVerification: + @patch("utilities.llm_client._run_query_sync") + def test_returns_result_text(self, mock_query): + mock_query.return_value = ( + _make_result_message( + result='{"agree": false, "correct_finding": "safe"}', + input_tokens=5000, + output_tokens=2000, + cost=1.25, + ), + "", + ) + + result = run_native_verification( + prompt="Verify this", system="sys", model=MODEL_PRIMARY, + repo_path="/tmp/repo", + ) + + assert result["text"] == '{"agree": false, "correct_finding": "safe"}' + assert result["input_tokens"] == 5000 + assert result["output_tokens"] == 2000 + assert result["cost_usd"] == 1.25 + + @patch("utilities.llm_client._run_query_sync") + def test_uses_structured_output_when_available(self, mock_query): + structured = {"agree": False, "correct_finding": "safe", "explanation": "sanitized"} + mock_query.return_value = ( + _make_result_message( + result="some text", + structured_output=structured, + ), + "", + ) + + result = run_native_verification( + 
prompt="test", system="sys", model="m", repo_path="/tmp/repo", + json_schema={"type": "object"}, + ) + + assert json.loads(result["text"]) == structured + + @patch("utilities.llm_client._run_query_sync") + def test_passes_correct_options(self, mock_query): + mock_query.return_value = (_make_result_message(), "") + + run_native_verification( + prompt="Verify this finding", + system="You are a pentester.", + model=MODEL_PRIMARY, + repo_path="/tmp/target-repo", + max_budget_usd=5.0, + ) + + prompt, options = mock_query.call_args[0] + assert prompt == "Verify this finding" + assert options.model == MODEL_PRIMARY + assert options.system_prompt == "You are a pentester." + assert options.max_turns is None # Multi-turn + assert "Read" in options.allowed_tools + assert "Grep" in options.allowed_tools + assert options.max_budget_usd == 5.0 + assert options.permission_mode == "bypassPermissions" + + @patch("utilities.llm_client._run_query_sync") + def test_sets_output_format_with_json_schema(self, mock_query): + mock_query.return_value = (_make_result_message(), "") + schema = {"type": "object", "properties": {"agree": {"type": "boolean"}}} + + run_native_verification( + prompt="test", system="sys", model="m", repo_path="/tmp/repo", + json_schema=schema, + ) + + options = mock_query.call_args[0][1] + assert options.output_format == {"type": "json_schema", "schema": schema} + + @patch("utilities.llm_client._run_query_sync") + def test_no_output_format_without_schema(self, mock_query): + mock_query.return_value = (_make_result_message(), "") + + run_native_verification( + prompt="test", system="sys", model="m", repo_path="/tmp/repo", + ) + + options = mock_query.call_args[0][1] + assert options.output_format is None + + @patch("utilities.llm_client._run_query_sync") + def test_raises_when_no_result_message(self, mock_query): + mock_query.return_value = (None, "some text") + + with pytest.raises(RuntimeError, match="no ResultMessage"): + run_native_verification( + prompt="test", 
system="sys", model="m", repo_path="/tmp/repo", + ) + + @patch("utilities.llm_client._run_query_sync") + def test_falls_back_to_last_text_when_result_none(self, mock_query): + mock_query.return_value = ( + _make_result_message(result=None), + "fallback text from assistant", + ) + + result = run_native_verification( + prompt="test", system="sys", model="m", repo_path="/tmp/repo", + ) + + assert result["text"] == "fallback text from assistant" + + +# --------------------------------------------------------------------------- +# FindingVerifier with SDK-backed verification +# --------------------------------------------------------------------------- + +class TestVerifyWithNativeClaude: + """Tests for FindingVerifier.verify_result against the SDK-native path. + + Mocks `utilities.finding_verifier.run_native_verification` (the wrapper + around the Claude Agent SDK) so we can exercise the parse/verdict paths + without spawning a real SDK subprocess. + """ + + def _make_verifier(self, repo_path="/tmp/repo"): + from utilities.finding_verifier import FindingVerifier + + index = MagicMock() + index.repo_path = Path(repo_path) if repo_path else None + tracker = TokenTracker() + verifier = FindingVerifier(index=index, tracker=tracker) + return verifier, tracker + + @patch("utilities.finding_verifier.run_native_verification") + def test_structured_json_returns_verdict(self, mock_run): + verifier, tracker = self._make_verifier() + mock_run.return_value = { + "text": json.dumps({ + "agree": False, + "correct_finding": "safe", + "explanation": "Sanitized.", + }), + "input_tokens": 5000, + "output_tokens": 2000, + "cost_usd": 0.12, + } + + result = verifier.verify_result( + code="function upload(req) { fs.writeFile(req.body.name, data); }", + finding="vulnerable", + attack_vector="path traversal via filename", + reasoning="unsanitized filename in writeFile", + ) + + assert result.correct_finding == "safe" + assert result.agree is False + assert tracker.total_input_tokens == 5000 + 
assert tracker.total_output_tokens == 2000 + + @patch("utilities.finding_verifier.run_native_verification") + def test_json_in_code_block_is_parsed(self, mock_run): + verifier, _tracker = self._make_verifier() + verdict_json = json.dumps({ + "agree": False, + "correct_finding": "protected", + "explanation": "Auth check prevents exploitation.", + }) + mock_run.return_value = { + "text": f"Here's my analysis:\n\n```json\n{verdict_json}\n```", + "input_tokens": 500, + "output_tokens": 200, + "cost_usd": 0.01, + } + + result = verifier.verify_result( + code="code", finding="vulnerable", + attack_vector="test", reasoning="test", + ) + + assert result.correct_finding == "protected" + assert result.agree is False + + @patch("utilities.finding_verifier.run_native_verification") + def test_freetext_verdict_fallback(self, mock_run): + verifier, _tracker = self._make_verifier() + mock_run.return_value = { + "text": "After tracing every input path the verdict is: **PROTECTED** " + "because the middleware validates the token.", + "input_tokens": 200, + "output_tokens": 50, + "cost_usd": 0.005, + } + + result = verifier.verify_result( + code="code", finding="vulnerable", + attack_vector="test", reasoning="test", + ) + + assert result.correct_finding == "protected" + + @patch("utilities.finding_verifier.run_native_verification") + def test_unparseable_text_returns_agree(self, mock_run): + verifier, _tracker = self._make_verifier() + mock_run.return_value = { + "text": "not json at all", + "input_tokens": 10, + "output_tokens": 5, + "cost_usd": 0.0001, + } + + result = verifier.verify_result( + code="code", finding="vulnerable", + attack_vector="test", reasoning="test", + ) + + assert result.agree is True + assert result.correct_finding == "vulnerable" + + @patch("utilities.finding_verifier.run_native_verification") + def test_sdk_failure_returns_conservative_agree(self, mock_run): + """Process-level SDK failures (CLI missing, subprocess died) are + absorbed into a conservative 
'agree' verdict so the pipeline doesn't + abort on a single bad finding.""" + verifier, _tracker = self._make_verifier() + mock_run.side_effect = RuntimeError("SDK subprocess died") + + result = verifier.verify_result( + code="code", finding="bypassable", + attack_vector="test", reasoning="test", + ) + + assert result.agree is True + assert result.correct_finding == "bypassable" + assert "SDK subprocess died" in result.explanation + + def test_missing_repo_path_returns_agree(self): + """Without a repo path we can't drive the native SDK call — return a + conservative 'agree' verdict and log a warning.""" + from utilities.finding_verifier import FindingVerifier + + index = MagicMock() + index.repo_path = None + verifier = FindingVerifier(index=index, tracker=TokenTracker()) + + result = verifier.verify_result( + code="code", finding="vulnerable", + attack_vector="test", reasoning="test", + ) + + assert result.agree is True + assert result.correct_finding == "vulnerable" + assert "repo_path" in result.explanation.lower() + + +# --------------------------------------------------------------------------- +# TokenTracker +# --------------------------------------------------------------------------- + +class TestTokenTrackerRestoreFrom: + def test_restore_from_checkpoint(self): + tracker = TokenTracker() + tracker.restore_from({ + "total_input_tokens": 1000, + "total_output_tokens": 500, + "total_cost_usd": 0.05, + }) + + assert tracker.total_input_tokens == 1000 + assert tracker.total_output_tokens == 500 + assert tracker.total_cost_usd == 0.05 + + def test_restore_then_record_accumulates(self): + tracker = TokenTracker() + tracker.restore_from({ + "total_input_tokens": 1000, + "total_output_tokens": 500, + "total_cost_usd": 0.05, + }) + tracker.record_call(MODEL_AUXILIARY, 200, 100, cost_usd=0.01) + + assert tracker.total_input_tokens == 1200 + assert tracker.total_output_tokens == 600 + assert abs(tracker.total_cost_usd - 0.06) < 1e-9 diff --git 
a/libs/openant-core/tests/test_model_config.py b/libs/openant-core/tests/test_model_config.py new file mode 100644 index 0000000..bab1970 --- /dev/null +++ b/libs/openant-core/tests/test_model_config.py @@ -0,0 +1,93 @@ +"""Tests for the central model configuration module.""" +import re +from pathlib import Path + +import pytest + +from utilities import model_config +from utilities.model_config import MODEL_AUXILIARY, MODEL_DEFAULT, MODEL_PRIMARY + + +# Regex for a valid Claude model identifier (e.g. claude-opus-4-20250514, +# claude-sonnet-4-6, claude-haiku-4-5). +_MODEL_ID_RE = re.compile(r"^claude-(opus|sonnet|haiku)-[0-9A-Za-z-]+$") + +# Regex used by the regression test to detect any hardcoded +# claude-opus-* / claude-sonnet-* string literal. +_HARDCODED_LITERAL_RE = re.compile(r"claude-(?:opus|sonnet)-[0-9][0-9A-Za-z-]*") + + +class TestModelConstants: + """Constants must exist, be non-empty strings, and match Claude model id format.""" + + def test_model_primary_is_valid_string(self): + assert isinstance(MODEL_PRIMARY, str) + assert MODEL_PRIMARY, "MODEL_PRIMARY must be non-empty" + assert _MODEL_ID_RE.match(MODEL_PRIMARY), ( + f"MODEL_PRIMARY={MODEL_PRIMARY!r} does not match expected " + f"claude-(opus|sonnet|haiku)-... format" + ) + + def test_model_auxiliary_is_valid_string(self): + assert isinstance(MODEL_AUXILIARY, str) + assert MODEL_AUXILIARY, "MODEL_AUXILIARY must be non-empty" + assert _MODEL_ID_RE.match(MODEL_AUXILIARY), ( + f"MODEL_AUXILIARY={MODEL_AUXILIARY!r} does not match expected " + f"claude-(opus|sonnet|haiku)-... 
format" + ) + + def test_model_default_is_valid_string(self): + assert isinstance(MODEL_DEFAULT, str) + assert MODEL_DEFAULT, "MODEL_DEFAULT must be non-empty" + assert _MODEL_ID_RE.match(MODEL_DEFAULT) + + def test_module_exposes_all_three_constants(self): + for name in ("MODEL_PRIMARY", "MODEL_AUXILIARY", "MODEL_DEFAULT"): + assert hasattr(model_config, name), f"model_config missing {name}" + + +class TestNoHardcodedModelLiterals: + """Regression test: no hardcoded claude-opus-*/claude-sonnet-* literals + may reappear in libs/openant-core/*.py outside of model_config.py. + + If this test fails, replace the offending literal with an import of + MODEL_PRIMARY / MODEL_AUXILIARY / MODEL_DEFAULT from utilities.model_config. + """ + + # Path to libs/openant-core (this file is at libs/openant-core/tests/...) + _CORE_ROOT = Path(__file__).parent.parent + + # Files exempt from the scan (the constants live here, by design) + _EXEMPT = { + _CORE_ROOT / "utilities" / "model_config.py", + # The regression test itself contains the regex pattern source. + Path(__file__).resolve(), + } + + def test_no_hardcoded_model_strings_outside_model_config(self): + offenders: list[tuple[Path, int, str]] = [] + + for py_path in self._CORE_ROOT.rglob("*.py"): + resolved = py_path.resolve() + if resolved in {p.resolve() for p in self._EXEMPT}: + continue + + try: + text = py_path.read_text(encoding="utf-8") + except (OSError, UnicodeDecodeError): + continue + + for lineno, line in enumerate(text.splitlines(), start=1): + if _HARDCODED_LITERAL_RE.search(line): + offenders.append((py_path, lineno, line.strip())) + + if offenders: + details = "\n".join( + f" {path.relative_to(self._CORE_ROOT)}:{lineno}: {snippet}" + for path, lineno, snippet in offenders + ) + pytest.fail( + "Found hardcoded claude-opus-*/claude-sonnet-* literals outside " + "utilities/model_config.py. 
Replace them with imports from " + "utilities.model_config:\n" + details + ) diff --git a/libs/openant-core/tests/test_sdk_error_surfacing.py b/libs/openant-core/tests/test_sdk_error_surfacing.py new file mode 100644 index 0000000..62d691e --- /dev/null +++ b/libs/openant-core/tests/test_sdk_error_surfacing.py @@ -0,0 +1,159 @@ +"""Tests for AssistantMessage.error -> sdk_errors mapping in _run_query. + +The hot path (_run_query) requires an actual ClaudeSDKClient subprocess, so we +can't unit-test it end-to-end without live API access. What we CAN test is the +error-classification and rate-limiter-notification logic in isolation: mock out +the message stream and assert the right exception propagates and the right +rate-limiter call fires. +""" +import asyncio +from unittest.mock import MagicMock, patch + +import pytest +from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock + + +@pytest.fixture(autouse=True) +def reset_rate_limiter(): + """Each test gets a clean rate-limiter singleton. + + The module-level singleton is held in `_rate_limiter` (not `_instance`, + which lives on the GlobalRateLimiter class). We use the public + `reset_rate_limiter()` API plus null out the module-level handle so the + next get_rate_limiter() call constructs a fresh instance — important for + tests that mock report_rate_limit on the singleton. 
+ """ + from utilities import rate_limiter + rate_limiter.reset_rate_limiter() + rate_limiter._rate_limiter = None + rate_limiter.GlobalRateLimiter._instance = None + yield + rate_limiter.reset_rate_limiter() + rate_limiter._rate_limiter = None + rate_limiter.GlobalRateLimiter._instance = None + + +def _assistant_msg(error=None, text=None): + """Build a real AssistantMessage — isinstance checks in _run_query need the real class.""" + content = [TextBlock(text=text)] if text else [] + return AssistantMessage(content=content, model="test-model", error=error) + + +def _result_msg(cost=0.0): + return ResultMessage( + subtype="success", + duration_ms=100, + duration_api_ms=50, + is_error=False, + num_turns=1, + session_id="test", + total_cost_usd=cost, + usage={}, + ) + + +class _FakeClient: + """Drop-in stand-in for ClaudeSDKClient async context manager.""" + + def __init__(self, messages): + self._messages = messages + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + return False + + async def query(self, prompt): + return None + + async def receive_response(self): + for msg in self._messages: + yield msg + + +def _run(messages): + """Invoke _run_query with a scripted message sequence. Returns (result, text) or raises.""" + from utilities.llm_client import _run_query + + options = MagicMock() + options.model = "test-model" + options.max_turns = 1 + + # _run_query imports ClaudeSDKClient from claude_agent_sdk on each call; + # patching the source module swaps it in place. 
+ with patch("claude_agent_sdk.ClaudeSDKClient", lambda options: _FakeClient(messages)): + return asyncio.run(_run_query("prompt", options)) + + +class TestAssistantMessageErrorMapping: + def test_rate_limit_raises_rate_limit_error(self): + from utilities.sdk_errors import RateLimitError + msgs = [_assistant_msg(error="rate_limit", text="too many")] + with pytest.raises(RateLimitError): + _run(msgs) + + def test_rate_limit_notifies_global_limiter(self): + from utilities.sdk_errors import RateLimitError + from utilities.rate_limiter import get_rate_limiter + + limiter = get_rate_limiter() + limiter.report_rate_limit = MagicMock() + + msgs = [_assistant_msg(error="rate_limit", text="slow down")] + with pytest.raises(RateLimitError): + _run(msgs) + + limiter.report_rate_limit.assert_called_once_with(0) + + def test_auth_error(self): + from utilities.sdk_errors import AuthError + msgs = [_assistant_msg(error="authentication_failed", text="bad key")] + with pytest.raises(AuthError): + _run(msgs) + + def test_billing_error(self): + from utilities.sdk_errors import BillingError + msgs = [_assistant_msg(error="billing_error")] + with pytest.raises(BillingError): + _run(msgs) + + def test_server_error(self): + from utilities.sdk_errors import ServerError + msgs = [_assistant_msg(error="server_error")] + with pytest.raises(ServerError): + _run(msgs) + + def test_invalid_request(self): + from utilities.sdk_errors import InvalidRequestError + msgs = [_assistant_msg(error="invalid_request")] + with pytest.raises(InvalidRequestError): + _run(msgs) + + def test_unknown_error_falls_back(self): + from utilities.sdk_errors import UnknownLLMError + msgs = [_assistant_msg(error="unknown")] + with pytest.raises(UnknownLLMError): + _run(msgs) + + def test_non_rate_limit_does_not_notify_limiter(self): + from utilities.sdk_errors import AuthError + from utilities.rate_limiter import get_rate_limiter + + limiter = get_rate_limiter() + limiter.report_rate_limit = MagicMock() + + msgs = 
[_assistant_msg(error="authentication_failed")] + with pytest.raises(AuthError): + _run(msgs) + + limiter.report_rate_limit.assert_not_called() + + def test_clean_message_does_not_raise(self): + msgs = [ + _assistant_msg(error=None, text="hello world"), + _result_msg(cost=0.001), + ] + result_msg, text = _run(msgs) + assert text == "hello world" + assert result_msg is not None diff --git a/libs/openant-core/tests/test_sdk_errors.py b/libs/openant-core/tests/test_sdk_errors.py new file mode 100644 index 0000000..75de23a --- /dev/null +++ b/libs/openant-core/tests/test_sdk_errors.py @@ -0,0 +1,87 @@ +"""Tests for utilities.sdk_errors.""" +import pytest + +from utilities.sdk_errors import ( + OpenAntLLMError, + AuthError, + BillingError, + RateLimitError, + InvalidRequestError, + ServerError, + UnknownLLMError, + error_from_kind, + classify_error, +) + + +class TestErrorFromKind: + def test_each_known_kind_maps_to_correct_class(self): + cases = [ + ("authentication_failed", AuthError), + ("billing_error", BillingError), + ("rate_limit", RateLimitError), + ("invalid_request", InvalidRequestError), + ("server_error", ServerError), + ("unknown", UnknownLLMError), + ] + for kind, expected_cls in cases: + exc = error_from_kind(kind) + assert isinstance(exc, expected_cls), f"{kind} -> {type(exc).__name__}" + assert exc.error_kind == kind + + def test_unrecognized_kind_falls_back_to_unknown(self): + exc = error_from_kind("not_a_real_kind") + assert isinstance(exc, UnknownLLMError) + + def test_custom_message_used_when_provided(self): + exc = error_from_kind("rate_limit", "hit the limit") + assert "hit the limit" in str(exc) + + def test_default_message_mentions_kind(self): + exc = error_from_kind("server_error") + assert "server_error" in str(exc) + + +class TestClassifyError: + def test_rate_limit(self): + info = classify_error(RateLimitError("rate limit hit")) + assert info["type"] == "rate_limit" + assert info["exception_class"] == "RateLimitError" + assert "rate limit 
hit" in info["message"] + + def test_auth_billing_server_invalid_request(self): + assert classify_error(AuthError())["type"] == "auth" + assert classify_error(BillingError())["type"] == "billing" + assert classify_error(ServerError())["type"] == "server" + assert classify_error(InvalidRequestError())["type"] == "invalid_request" + + def test_non_llm_error_type_is_unknown_but_class_name_preserved(self): + info = classify_error(ValueError("bad")) + assert info["type"] == "unknown" + assert info["exception_class"] == "ValueError" + + def test_agent_state_propagated(self): + exc = RateLimitError("hit limit") + exc.agent_state = {"iteration": 3, "tokens_used": 1234} + info = classify_error(exc) + assert info["agent_state"] == {"iteration": 3, "tokens_used": 1234} + + def test_missing_agent_state_not_in_dict(self): + info = classify_error(RateLimitError()) + assert "agent_state" not in info + + +class TestOpenAntLLMError: + def test_kwargs_become_attributes(self): + exc = RateLimitError("msg", request_id="req-123", status_code=429) + assert exc.request_id == "req-123" + assert exc.status_code == 429 + + def test_message_attribute_mirrors_str(self): + exc = RateLimitError("hello") + assert exc.message == "hello" + assert str(exc) == "hello" + + def test_can_be_raised_and_caught_as_base(self): + with pytest.raises(OpenAntLLMError): + raise RateLimitError("x") diff --git a/libs/openant-core/tests/test_silent_401.py b/libs/openant-core/tests/test_silent_401.py index bbb9fe7..49d15fd 100644 --- a/libs/openant-core/tests/test_silent_401.py +++ b/libs/openant-core/tests/test_silent_401.py @@ -7,25 +7,14 @@ import io import sys -import types from pathlib import Path -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest _CORE_ROOT = Path(__file__).resolve().parents[1] sys.path.insert(0, str(_CORE_ROOT)) -if "anthropic" not in sys.modules: - _stub = types.ModuleType("anthropic") - _stub.Anthropic = MagicMock() - sys.modules["anthropic"] = 
_stub -_anth = sys.modules["anthropic"] -if not hasattr(_anth, "RateLimitError"): - _anth.RateLimitError = type("RateLimitError", (Exception,), {}) -if not hasattr(_anth, "AuthenticationError"): - _anth.AuthenticationError = type("AuthenticationError", (Exception,), {}) - from core.schemas import ScanResult, AnalysisMetrics, UsageInfo # noqa: E402 @@ -98,20 +87,27 @@ def test_print_summary_no_warning_on_normal_scan(normal_result): # --------------------------------------------------------------------------- def test_analyze_sync_raises_on_auth_error(): - """When the Anthropic API returns 401, analyze_sync must not swallow it.""" + """When the Claude Agent SDK reports auth failure, analyze_sync must not swallow it. + + The SDK signals authentication failures via ``AssistantMessage.error == + "authentication_failed"``, which ``llm_client._run_query`` re-raises as + ``utilities.sdk_errors.AuthError``. ``AnthropicClient.analyze_sync`` must + propagate that exception unmodified rather than returning an empty string. 
+ """ import os os.environ["ANTHROPIC_API_KEY"] = "sk-test-bad-key" from utilities.llm_client import AnthropicClient - - AuthError = sys.modules["anthropic"].AuthenticationError + from utilities.sdk_errors import AuthError client = AnthropicClient.__new__(AnthropicClient) - client.client = MagicMock() - client.client.messages.create.side_effect = AuthError("invalid x-api-key") client.model = "claude-haiku-4-5-20251001" client.tracker = MagicMock() client.last_call = None - with pytest.raises(AuthError): - client.analyze_sync("test prompt") + with patch( + "utilities.llm_client._run_query_sync", + side_effect=AuthError("invalid x-api-key"), + ): + with pytest.raises(AuthError): + client.analyze_sync("test prompt") diff --git a/libs/openant-core/tests/test_token_tracker.py b/libs/openant-core/tests/test_token_tracker.py index 08fdc9c..410bdbf 100644 --- a/libs/openant-core/tests/test_token_tracker.py +++ b/libs/openant-core/tests/test_token_tracker.py @@ -1,5 +1,6 @@ """Tests for TokenTracker.""" from utilities.llm_client import TokenTracker, MODEL_PRICING +from utilities.model_config import MODEL_PRIMARY, MODEL_AUXILIARY class TestTokenTracker: @@ -13,9 +14,9 @@ def test_initial_state(self): def test_record_call_known_model(self): tracker = TokenTracker() - result = tracker.record_call("claude-sonnet-4-20250514", 1000, 500) + result = tracker.record_call(MODEL_AUXILIARY, 1000, 500) - assert result["model"] == "claude-sonnet-4-20250514" + assert result["model"] == MODEL_AUXILIARY assert result["input_tokens"] == 1000 assert result["output_tokens"] == 500 # Sonnet: $3/M input, $15/M output @@ -31,8 +32,8 @@ def test_record_call_unknown_model_uses_default(self): def test_cumulative_tracking(self): tracker = TokenTracker() - tracker.record_call("claude-sonnet-4-20250514", 1000, 500) - tracker.record_call("claude-sonnet-4-20250514", 2000, 1000) + tracker.record_call(MODEL_AUXILIARY, 1000, 500) + tracker.record_call(MODEL_AUXILIARY, 2000, 1000) assert 
tracker.total_input_tokens == 3000 assert tracker.total_output_tokens == 1500 @@ -41,7 +42,7 @@ def test_cumulative_tracking(self): def test_reset(self): tracker = TokenTracker() - tracker.record_call("claude-sonnet-4-20250514", 1000, 500) + tracker.record_call(MODEL_AUXILIARY, 1000, 500) tracker.reset() assert tracker.total_input_tokens == 0 @@ -51,7 +52,7 @@ def test_reset(self): def test_get_summary_includes_calls(self): tracker = TokenTracker() - tracker.record_call("claude-sonnet-4-20250514", 100, 50) + tracker.record_call(MODEL_AUXILIARY, 100, 50) summary = tracker.get_summary() assert summary["total_calls"] == 1 @@ -60,7 +61,7 @@ def test_get_summary_includes_calls(self): def test_get_totals_excludes_calls(self): tracker = TokenTracker() - tracker.record_call("claude-sonnet-4-20250514", 100, 50) + tracker.record_call(MODEL_AUXILIARY, 100, 50) totals = tracker.get_totals() assert totals["total_calls"] == 1 @@ -68,6 +69,6 @@ def test_get_totals_excludes_calls(self): def test_opus_pricing(self): tracker = TokenTracker() - result = tracker.record_call("claude-opus-4-20250514", 1_000_000, 1_000_000) + result = tracker.record_call(MODEL_PRIMARY, 1_000_000, 1_000_000) # Opus: $15/M input, $75/M output assert result["cost_usd"] == 90.0 diff --git a/libs/openant-core/utilities/agentic_enhancer/agent.py b/libs/openant-core/utilities/agentic_enhancer/agent.py index 62061b7..44b2354 100644 --- a/libs/openant-core/utilities/agentic_enhancer/agent.py +++ b/libs/openant-core/utilities/agentic_enhancer/agent.py @@ -1,8 +1,10 @@ """ Agentic Context Enhancer -Main agent loop that iteratively explores the codebase to gather context. -Uses Claude Sonnet with tool use to search and read code. +Uses the Claude Agent SDK with native tools (Read, Grep, Glob, Bash) to +explore the codebase and gather context for security analysis. Static +dependencies are pre-resolved from the RepositoryIndex and included in +the prompt so the model has a head start before exploring. 
Supports reachability-aware classification to distinguish: - EXPLOITABLE: Vulnerable + reachable from user input @@ -12,25 +14,65 @@ """ import json +import re +import sys from typing import Optional, Set, List -import anthropic - from ..llm_client import TokenTracker, get_global_tracker +from ..model_config import MODEL_AUXILIARY from ..rate_limiter import get_rate_limiter +from ..sdk_errors import RateLimitError from .repository_index import RepositoryIndex -from .tools import TOOL_DEFINITIONS, ToolExecutor +from .tools import ToolExecutor from .prompts import SYSTEM_PROMPT, get_user_prompt from .entry_point_detector import EntryPointDetector from .reachability_analyzer import ReachabilityAnalyzer # Use Sonnet for exploration (cost-effective) -AGENT_MODEL = "claude-sonnet-4-20250514" +AGENT_MODEL = MODEL_AUXILIARY # Safety limits -MAX_ITERATIONS = 20 -MAX_TOKENS_PER_RESPONSE = 4096 +MAX_TURNS = 20 + + +# JSON schema for structured agent output +AGENT_OUTPUT_SCHEMA = { + "type": "object", + "properties": { + "include_functions": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Function identifier in format 'relative/path.ext:functionName'", + }, + "reason": { + "type": "string", + "description": "Why this function is needed for context", + }, + }, + "required": ["id", "reason"], + }, + }, + "usage_context": {"type": "string"}, + "security_classification": { + "type": "string", + "enum": ["exploitable", "vulnerable_internal", "security_control", "neutral"], + }, + "classification_reasoning": {"type": "string"}, + "confidence": {"type": "number"}, + }, + "required": [ + "include_functions", + "usage_context", + "security_classification", + "classification_reasoning", + "confidence", + ], +} class AgentResult: @@ -68,7 +110,7 @@ def __init__( def to_dict(self) -> dict: """Convert to dictionary for JSON serialization.""" - result = { + return { "include_functions": self.include_functions, 
"usage_context": self.usage_context, "security_classification": self.security_classification, @@ -84,16 +126,49 @@ def to_dict(self) -> dict: "reachability": { "is_entry_point": self.is_entry_point, "reachable_from_entry": self.reachable_from_entry, - "entry_point_path": self.entry_point_path - } + "entry_point_path": self.entry_point_path, + }, } - return result + + +def _pre_resolve_deps(tool_executor: ToolExecutor, static_deps: list, static_callers: list) -> str: + """Pre-resolve static dependencies via ToolExecutor and format for prompt.""" + tool_executor.set_unit_context(static_deps, static_callers) + resolved = tool_executor.execute("get_static_dependencies", {}) + parts = [] + for label, key in [("Resolved Dependencies", "dependencies"), ("Resolved Callers", "callers")]: + items = resolved.get(key, {}).get("resolved", []) + if items: + parts.append(f"### {label} (from parsed index)") + for item in items[:15]: + if isinstance(item, dict): + parts.append(f"- `{item.get('id', item.get('name', '?'))}` ({item.get('file', '')})") + else: + parts.append(f"- `{item}`") + return "\n".join(parts) if parts else "" + + +def _try_parse_json(text: str) -> Optional[dict]: + """Try to extract JSON from text, handling code blocks.""" + try: + return json.loads(text) + except (json.JSONDecodeError, ValueError): + pass + match = re.search(r"```(?:json)?\s*\n(.*?)\n\s*```", text, re.DOTALL) + if match: + try: + return json.loads(match.group(1)) + except (json.JSONDecodeError, ValueError): + pass + return None class ContextAgent: """ Agent that explores codebase to gather context for security analysis. - Uses iterative tool use to trace call paths and understand code intent. + Uses Claude Agent SDK with native tools (Read, Grep, Glob, Bash) to + trace call paths and understand code intent. Static dependencies are + pre-resolved from the RepositoryIndex and included in the prompt. 
Supports reachability-aware classification when entry_points and reachability analyzer are provided. @@ -106,7 +181,6 @@ def __init__( verbose: bool = False, entry_points: Optional[Set[str]] = None, reachability: Optional[ReachabilityAnalyzer] = None, - client: Optional[anthropic.Anthropic] = None, ): """ Initialize the agent. @@ -117,8 +191,6 @@ def __init__( verbose: If True, print debug information entry_points: Set of func_ids that are entry points (optional) reachability: ReachabilityAnalyzer for checking user input paths (optional) - client: Shared Anthropic client (reuse across workers to avoid FD exhaustion). - If not provided, creates a new one (only for standalone/test use). """ self.index = index self.tracker = tracker or get_global_tracker() @@ -126,7 +198,6 @@ def __init__( self.tool_executor = ToolExecutor(index) self.entry_points = entry_points or set() self.reachability = reachability - self.client = client or anthropic.Anthropic(max_retries=5) def analyze_unit( self, @@ -134,10 +205,10 @@ def analyze_unit( unit_type: str, primary_code: str, static_deps: list[str], - static_callers: list[str] + static_callers: list[str], ) -> AgentResult: """ - Analyze a code unit to gather context. + Analyze a code unit using Claude Agent SDK with native tools. Args: unit_id: Function identifier @@ -149,6 +220,9 @@ def analyze_unit( Returns: AgentResult with gathered context """ + # Lazy import to avoid breaking non-LLM commands at import time. 
+ from ..llm_client import _run_query_sync, _build_options + # Compute reachability info is_entry_point = unit_id in self.entry_points reachable_from_entry: Optional[bool] = None @@ -161,7 +235,9 @@ def analyze_unit( entry_point_path = self.reachability.get_entry_point_path(unit_id) reaching_entry_point = self.reachability.get_reaching_entry_point(unit_id) - # Build initial prompt with reachability info + # Pre-resolve static deps from the index so the model has a head start + resolved_context = _pre_resolve_deps(self.tool_executor, static_deps, static_callers) + user_prompt = get_user_prompt( unit_id=unit_id, unit_type=unit_type, @@ -171,204 +247,123 @@ def analyze_unit( is_entry_point=is_entry_point, reachable_from_entry=reachable_from_entry, entry_point_path=entry_point_path, - reaching_entry_point=reaching_entry_point + reaching_entry_point=reaching_entry_point, ) + if resolved_context: + user_prompt += f"\n\n{resolved_context}" - # Initialize conversation - messages = [{"role": "user", "content": user_prompt}] - - iterations = 0 - total_input_tokens = 0 - total_output_tokens = 0 + repo_path = str(self.index.repo_path) if self.index.repo_path else None - while iterations < MAX_ITERATIONS: - iterations += 1 - - if self.verbose: - print(f" Iteration {iterations}...") - - # Call Claude with rate limiting - try: - # Wait if we're in a global backoff period - rate_limiter = get_rate_limiter() - rate_limiter.wait_if_needed() - - response = self.client.messages.create( - model=AGENT_MODEL, - max_tokens=MAX_TOKENS_PER_RESPONSE, - system=SYSTEM_PROMPT, - tools=TOOL_DEFINITIONS, - messages=messages - ) - except anthropic.RateLimitError as exc: - # Report to global rate limiter so all workers back off - retry_after = float(exc.response.headers.get("retry-after", 0)) - get_rate_limiter().report_rate_limit(retry_after) - # Attach agent state so the caller knows how far we got - exc.agent_state = { - "iteration": iterations, - "max_iterations": MAX_ITERATIONS, - 
"tokens_used": total_input_tokens + total_output_tokens, - "input_tokens": total_input_tokens, - "output_tokens": total_output_tokens, - } - raise - except Exception as exc: - # Attach agent state so the caller knows how far we got - exc.agent_state = { - "iteration": iterations, - "max_iterations": MAX_ITERATIONS, - "tokens_used": total_input_tokens + total_output_tokens, - "input_tokens": total_input_tokens, - "output_tokens": total_output_tokens, - } - raise - - # Track tokens - total_input_tokens += response.usage.input_tokens - total_output_tokens += response.usage.output_tokens - - # Process response - assistant_content = response.content - stop_reason = response.stop_reason - - if self.verbose: - # Print text blocks - for block in assistant_content: - if hasattr(block, 'text'): - print(f" Agent: {block.text[:200]}...") - - # Check if we're done (finish tool called or no more tool use) - if stop_reason == "end_turn": - # Model finished without calling finish tool - # Return default result - if self.verbose: - print(" Agent ended without calling finish tool") - - return AgentResult( - include_functions=[], - usage_context="Agent did not complete analysis", - security_classification="neutral", - classification_reasoning="Analysis incomplete", - confidence=0.3, - iterations=iterations, - total_tokens=total_input_tokens + total_output_tokens, - is_entry_point=is_entry_point, - reachable_from_entry=reachable_from_entry, - entry_point_path=entry_point_path - ) - - # Process tool calls - tool_results = [] - finish_result = None - - for block in assistant_content: - if block.type == "tool_use": - tool_name = block.name - tool_input = block.input - tool_use_id = block.id - - if self.verbose: - print(f" Tool: {tool_name}({json.dumps(tool_input)[:100]}...)") - - # Execute tool - result = self.tool_executor.execute(tool_name, tool_input) - - if self.verbose: - result_preview = str(result)[:200] - print(f" Result: {result_preview}...") - - # Check for finish - if 
tool_name == "finish" and result.get("status") == "complete": - finish_result = result.get("result", {}) - # Still add to tool_results for the message - tool_results.append({ - "type": "tool_result", - "tool_use_id": tool_use_id, - "content": json.dumps(result) - }) - break - else: - tool_results.append({ - "type": "tool_result", - "tool_use_id": tool_use_id, - "content": json.dumps(result) - }) - - # If finish was called, return result - if finish_result: - # Record token usage - call_record = self.tracker.record_call( - model=AGENT_MODEL, - input_tokens=total_input_tokens, - output_tokens=total_output_tokens - ) - - return AgentResult( - include_functions=finish_result.get("include_functions", []), - usage_context=finish_result.get("usage_context", ""), - security_classification=finish_result.get("security_classification", "neutral"), - classification_reasoning=finish_result.get("classification_reasoning", ""), - confidence=finish_result.get("confidence", 0.5), - iterations=iterations, - total_tokens=total_input_tokens + total_output_tokens, - is_entry_point=is_entry_point, - reachable_from_entry=reachable_from_entry, - entry_point_path=entry_point_path, - input_tokens=total_input_tokens, - output_tokens=total_output_tokens, - cost_usd=call_record.get("cost_usd", 0.0), - ) - - # Add assistant message and tool results to conversation - messages.append({"role": "assistant", "content": assistant_content}) - - # Only add user message with tool results if there are results - # (empty content triggers API error: "user messages must have non-empty content") - if tool_results: - messages.append({"role": "user", "content": tool_results}) - else: - # No tool calls but model didn't end — treat as incomplete - if self.verbose: - print(" No tool calls in response, treating as incomplete") - return AgentResult( - include_functions=[], - usage_context="Agent response had no tool calls", - security_classification="neutral", - classification_reasoning="Analysis incomplete - no 
tool calls", - confidence=0.3, - iterations=iterations, - total_tokens=total_input_tokens + total_output_tokens, - is_entry_point=is_entry_point, - reachable_from_entry=reachable_from_entry, - entry_point_path=entry_point_path - ) + options = _build_options( + model=AGENT_MODEL, + system=SYSTEM_PROMPT, + max_turns=MAX_TURNS, + allowed_tools=["Read", "Grep", "Glob", "Bash"], + add_dirs=[repo_path] if repo_path else [], + cwd=repo_path, + output_format={"type": "json_schema", "schema": AGENT_OUTPUT_SCHEMA}, + max_budget_usd=0.50, + ) - # Max iterations reached if self.verbose: - print(f" Max iterations ({MAX_ITERATIONS}) reached") + print(f" Analyzing {unit_id} via SDK...", file=sys.stderr, flush=True) + + try: + result_message, last_text = _run_query_sync(user_prompt, options, label=unit_id) + except RateLimitError as exc: + # Rate limit is already reported to GlobalRateLimiter inside _run_query. + # Attach agent state so the caller (checkpoint/resume) knows how far we got. + exc.agent_state = { + "unit_id": unit_id, + "tokens_used": 0, + "input_tokens": 0, + "output_tokens": 0, + } + raise + except Exception as exc: + # Attach agent state for non-rate-limit errors too (checkpoint telemetry). 
+ exc.agent_state = { + "unit_id": unit_id, + "tokens_used": 0, + "input_tokens": 0, + "output_tokens": 0, + } + raise + + # Track usage via the SDK-reported totals + usage = (result_message.usage or {}) if result_message else {} + input_tokens = usage.get("input_tokens", 0) + output_tokens = usage.get("output_tokens", 0) + total_tokens = input_tokens + output_tokens + cost_usd_sdk = (result_message.total_cost_usd or 0.0) if result_message else None - # Record token usage call_record = self.tracker.record_call( model=AGENT_MODEL, - input_tokens=total_input_tokens, - output_tokens=total_output_tokens + input_tokens=input_tokens, + output_tokens=output_tokens, + cost_usd=cost_usd_sdk, ) + call_cost = call_record.get("cost_usd", 0.0) + + # Parse structured output (preferred) or fall back to text + JSON extraction + parsed: Optional[dict] = None + structured = getattr(result_message, "structured_output", None) if result_message else None + if structured and isinstance(structured, dict): + parsed = structured + if not parsed: + raw_text = ( + result_message.result + if result_message and result_message.result + else last_text or "" + ) + if raw_text: + parsed = _try_parse_json(raw_text) + + # Num turns from SDK (for telemetry parity with the old iterations field) + num_turns = ( + getattr(result_message, "num_turns", 0) if result_message else 0 + ) or 0 + + if parsed and "security_classification" in parsed: + if self.verbose: + print( + f" Classification: {parsed['security_classification']} " + f"(confidence: {parsed.get('confidence', '?')})", + file=sys.stderr, + flush=True, + ) + return AgentResult( + include_functions=parsed.get("include_functions", []), + usage_context=parsed.get("usage_context", ""), + security_classification=parsed.get("security_classification", "neutral"), + classification_reasoning=parsed.get("classification_reasoning", ""), + confidence=parsed.get("confidence", 0.5), + iterations=num_turns, + total_tokens=total_tokens, + 
is_entry_point=is_entry_point, + reachable_from_entry=reachable_from_entry, + entry_point_path=entry_point_path, + input_tokens=input_tokens, + output_tokens=output_tokens, + cost_usd=call_cost, + ) + if self.verbose: + print(" Could not parse agent response", file=sys.stderr, flush=True) return AgentResult( include_functions=[], - usage_context="Analysis terminated - max iterations reached", + usage_context="Could not parse agent response", security_classification="neutral", - classification_reasoning="Could not complete analysis within iteration limit", + classification_reasoning="Analysis response unparseable", confidence=0.2, - iterations=iterations, - total_tokens=total_input_tokens + total_output_tokens, + iterations=num_turns, + total_tokens=total_tokens, is_entry_point=is_entry_point, reachable_from_entry=reachable_from_entry, entry_point_path=entry_point_path, - input_tokens=total_input_tokens, - output_tokens=total_output_tokens, - cost_usd=call_record.get("cost_usd", 0.0), + input_tokens=input_tokens, + output_tokens=output_tokens, + cost_usd=call_cost, ) @@ -379,22 +374,24 @@ def enhance_unit_with_agent( verbose: bool = False, entry_points: Optional[Set[str]] = None, reachability: Optional[ReachabilityAnalyzer] = None, - client: Optional[anthropic.Anthropic] = None, ) -> dict: """ Enhance a single unit using the agentic approach. + Mutates ``unit`` in place: attaches ``agent_context`` and, if the agent + identified additional functions, extends ``unit["code"]["primary_code"]`` + with those definitions. + Args: - unit: Unit from dataset + unit: Unit from dataset (mutated in place) index: Repository index for searching tracker: Token tracker verbose: Print debug info entry_points: Set of func_ids that are entry points (optional) reachability: ReachabilityAnalyzer for checking user input paths (optional) - client: Shared Anthropic client (reuse across workers to avoid FD exhaustion). 
Returns: - Enhanced unit with agent_context field including reachability info + The same unit dict, with agent_context field including reachability info """ agent = ContextAgent( index=index, @@ -402,16 +399,23 @@ def enhance_unit_with_agent( verbose=verbose, entry_points=entry_points, reachability=reachability, - client=client, ) # Extract unit info unit_id = unit.get("id", "unknown") unit_type = unit.get("unit_type", "function") code_section = unit.get("code", {}) - primary_code = code_section.get("primary_code", "") - static_deps = unit.get("metadata", {}).get("direct_calls", []) - static_callers = unit.get("metadata", {}).get("direct_callers", []) + primary_code = ( + code_section + if isinstance(code_section, str) + else code_section.get("primary_code", "") + ) + metadata = unit.get("metadata", {}) + if isinstance(metadata, str): + static_deps, static_callers = [], [] + else: + static_deps = metadata.get("direct_calls", []) + static_callers = metadata.get("direct_callers", []) # Run agent result = agent.analyze_unit( @@ -419,21 +423,34 @@ def enhance_unit_with_agent( unit_type=unit_type, primary_code=primary_code, static_deps=static_deps, - static_callers=static_callers + static_callers=static_callers, ) # Add result to unit unit["agent_context"] = result.to_dict() # Assemble additional code if functions were identified - if result.include_functions: + if result.include_functions and isinstance(unit.get("code"), dict): additional_code = [] additional_files = set() for func_info in result.include_functions: - func_id = func_info.get("id", "") + func_id = func_info if isinstance(func_info, str) else func_info.get("id", "") func_data = index.get_function(func_id) + # Fuzzy match: SDK native tools may return IDs that don't match the + # index exactly (e.g. "file.py:Class.method" vs index's "file.py:method"). + if not func_data and func_id: + name_part = func_id.rsplit(":", 1)[-1] if ":" in func_id else func_id + if "." 
in name_part: + name_part = name_part.rsplit(".", 1)[-1] + matches = index.search_by_name(name_part, exact=True) + if not matches: + matches = index.search_by_name(name_part, exact=False) + if matches: + func_data = matches[0] + func_id = func_data.get("id", func_id) + if func_data and func_data.get("code"): additional_code.append(func_data["code"]) @@ -445,7 +462,7 @@ def enhance_unit_with_agent( # Append to primary_code with file boundaries if additional_code: FILE_BOUNDARY = "\n\n// ========== File Boundary ==========\n\n" - current_code = unit["code"]["primary_code"] + current_code = unit["code"].get("primary_code", "") assembled = current_code + FILE_BOUNDARY + FILE_BOUNDARY.join(additional_code) unit["code"]["primary_code"] = assembled @@ -463,7 +480,7 @@ def enhance_unit_with_agent( def create_reachability_context( functions: dict, call_graph: dict, - reverse_call_graph: dict + reverse_call_graph: dict, ) -> tuple[Set[str], ReachabilityAnalyzer]: """ Create entry points and reachability analyzer from call graph data. 
@@ -502,7 +519,7 @@ def create_reachability_context( reachability = ReachabilityAnalyzer( functions=functions, reverse_call_graph=reverse_call_graph, - entry_points=entry_points + entry_points=entry_points, ) return entry_points, reachability diff --git a/libs/openant-core/utilities/context_corrector.py b/libs/openant-core/utilities/context_corrector.py index 918dda6..d451cad 100644 --- a/libs/openant-core/utilities/context_corrector.py +++ b/libs/openant-core/utilities/context_corrector.py @@ -17,6 +17,7 @@ from typing import Optional from .llm_client import AnthropicClient, TokenTracker, get_global_tracker +from .model_config import MODEL_AUXILIARY # Maximum characters per batch (leaving room for prompt overhead) @@ -102,7 +103,7 @@ def parse_missing_context_with_llm( prompt = get_missing_context_prompt(reasoning) try: - llm_response = client.analyze_sync(prompt, model="claude-sonnet-4-20250514") + llm_response = client.analyze_sync(prompt, model=MODEL_AUXILIARY) parsed = _parse_json_response(llm_response) if parsed and "missing_context" in parsed: @@ -254,7 +255,7 @@ def search_files_for_context( prompt = get_file_search_prompt(missing_context, files_content, batch_info) try: - response = client.analyze_sync(prompt, model="claude-sonnet-4-20250514") + response = client.analyze_sync(prompt, model=MODEL_AUXILIARY) result = _parse_json_response(response) if result and result.get("found_files"): diff --git a/libs/openant-core/utilities/context_enhancer.py b/libs/openant-core/utilities/context_enhancer.py index 2ffbfe6..dd3adfe 100644 --- a/libs/openant-core/utilities/context_enhancer.py +++ b/libs/openant-core/utilities/context_enhancer.py @@ -23,11 +23,18 @@ from pathlib import Path from typing import Callable, Optional -import anthropic - from .llm_client import AnthropicClient, TokenTracker, get_global_tracker, reset_global_tracker from .agentic_enhancer import RepositoryIndex, enhance_unit_with_agent, load_index_from_file +from .model_config import 
MODEL_AUXILIARY from .rate_limiter import get_rate_limiter, is_rate_limit_error, is_retryable_error +from .sdk_errors import ( + AuthError, + BillingError, + InvalidRequestError, + RateLimitError, + ServerError, + UnknownLLMError, +) # Avoid circular import — import checkpoint at usage site _StepCheckpoint = None @@ -45,14 +52,28 @@ def _get_step_checkpoint(): # Use Sonnet for context enhancement (cost-effective auxiliary task) -CONTEXT_ENHANCEMENT_MODEL = "claude-sonnet-4-20250514" +CONTEXT_ENHANCEMENT_MODEL = MODEL_AUXILIARY def _build_error_info(exc: Exception) -> dict: """Build a structured error dict from an exception. - Captures exception type, message, HTTP status, request ID, and - any agent iteration state attached by agent.py. + Maps utilities.sdk_errors classes (raised by llm_client._run_query) + AND claude_agent_sdk process-level errors onto the diagnostic shape + is_rate_limit_error() and is_retryable_error() expect: + + {"type": "rate_limit" | "connection" | "timeout" | + "api_status" | "auth" | "billing" | "invalid_request" | + "unknown", + "exception_class": "...", + "message": "...", + ["agent_state": {...}]} + + The pre-migration shape also carried request_id and retry_after from + anthropic response headers; the Claude Agent SDK doesn't surface those. + Callers that only check "type" (is_rate_limit_error, is_retryable_error) + keep working unchanged. Callers that read request_id/retry_after won't + find them any more — none in this codebase do. """ info = { "type": "unknown", @@ -60,24 +81,52 @@ def _build_error_info(exc: Exception) -> dict: "message": str(exc), } - # Anthropic SDK specific exceptions - if isinstance(exc, anthropic.APIConnectionError): - info["type"] = "connection" - elif isinstance(exc, anthropic.APITimeoutError): - info["type"] = "timeout" - elif isinstance(exc, anthropic.RateLimitError): + # SDK-layer (API-reported) errors raised by utilities.llm_client._run_query. 
+ if isinstance(exc, RateLimitError): info["type"] = "rate_limit" - info["status_code"] = exc.status_code - if hasattr(exc, "response") and exc.response is not None: - info["request_id"] = exc.response.headers.get("request-id") - retry_after = exc.response.headers.get("retry-after") - if retry_after: - info["retry_after"] = retry_after - elif isinstance(exc, anthropic.APIStatusError): + elif isinstance(exc, AuthError): + info["type"] = "auth" + elif isinstance(exc, BillingError): + info["type"] = "billing" + elif isinstance(exc, ServerError): + # Treat as a 500-class error for is_retryable_error(). + info["type"] = "api_status" + info["status_code"] = 500 + elif isinstance(exc, InvalidRequestError): + # 400-class — caller bug, not retryable. info["type"] = "api_status" - info["status_code"] = exc.status_code - if hasattr(exc, "response") and exc.response is not None: - info["request_id"] = exc.response.headers.get("request-id") + info["status_code"] = 400 + elif isinstance(exc, UnknownLLMError): + info["type"] = "unknown" + else: + # Process-level errors from claude_agent_sdk: classify so the + # retry policy treats transient subprocess/IO failures as retryable. + try: + from claude_agent_sdk import ( + CLIConnectionError, + CLIJSONDecodeError, + CLINotFoundError, + ProcessError, + ) + if isinstance(exc, CLIConnectionError): + info["type"] = "connection" + elif isinstance(exc, ProcessError): + # Subprocess died unexpectedly — usually transient. + info["type"] = "connection" + elif isinstance(exc, CLIJSONDecodeError): + # Malformed SDK response — treat as transient. + info["type"] = "connection" + elif isinstance(exc, CLINotFoundError): + # `claude` binary missing — environmental/config issue, not + # retryable. Leave type as "unknown" (default); the + # exception_class field carries the precise diagnostic so + # operators can distinguish from API-side auth failures. 
+ pass + except ImportError: # pragma: no cover + pass + # Plain TimeoutError still matches the "timeout" retry path. + if isinstance(exc, TimeoutError): + info["type"] = "timeout" # Agent iteration state (attached by agent.py) agent_state = getattr(exc, "agent_state", None) @@ -568,7 +617,7 @@ def enhance_dataset_agentic( remaining = total - len(processed_ids) self._log("info", f"Enhancing {remaining} units with agentic analysis ({len(processed_ids)} already done)", units=remaining) self._log("info", "Mode: Iterative tool use (traces call paths)") - self._log("info", "Model: claude-sonnet-4-20250514") + self._log("info", f"Model: {CONTEXT_ENHANCEMENT_MODEL}") mode = "sequential" if workers <= 1 else f"parallel ({workers} workers)" self._log("info", f"Workers: {mode}") if checkpoint_dir: @@ -580,13 +629,6 @@ def enhance_dataset_agentic( stats = index.get_statistics() self._log("info", f"Indexed {stats['total_functions']} functions from {stats['total_files']} files") - # Create a single shared Anthropic client for all workers. - # Each ContextAgent previously created its own anthropic.Anthropic() instance, - # which spawns a new httpx connection pool. With 1000+ units and 8 workers, - # this exhausted file descriptors (macOS limit ~256). The httpx.Client - # underlying anthropic.Anthropic is thread-safe, so sharing is correct. 
- shared_client = anthropic.Anthropic(max_retries=5) - # Filter to unprocessed units units_to_process = [(i, unit) for i, unit in enumerate(units) if unit.get("id") not in processed_ids] @@ -596,7 +638,7 @@ def _enhance_one(unit): unit_start = time.monotonic() classification = "neutral" try: - enhance_unit_with_agent(unit, index, self.tracker, verbose, client=shared_client) + enhance_unit_with_agent(unit, index, self.tracker, verbose) agent_ctx = unit.get("agent_context", {}) classification = agent_ctx.get("security_classification", "neutral") diff --git a/libs/openant-core/utilities/context_reviewer.py b/libs/openant-core/utilities/context_reviewer.py index b17107d..5b9aa3c 100644 --- a/libs/openant-core/utilities/context_reviewer.py +++ b/libs/openant-core/utilities/context_reviewer.py @@ -13,6 +13,7 @@ from typing import Optional from .llm_client import AnthropicClient +from .model_config import MODEL_AUXILIARY from .context_corrector import gather_source_files, search_files_for_context @@ -176,7 +177,7 @@ def review_context( prompt = get_context_review_prompt(code, route, handler, files_included) try: - response = self.client.analyze_sync(prompt, model="claude-sonnet-4-20250514") + response = self.client.analyze_sync(prompt, model=MODEL_AUXILIARY) review = self._parse_json_response(response) if not review: diff --git a/libs/openant-core/utilities/dynamic_tester/test_generator.py b/libs/openant-core/utilities/dynamic_tester/test_generator.py index c95b88a..9d76869 100644 --- a/libs/openant-core/utilities/dynamic_tester/test_generator.py +++ b/libs/openant-core/utilities/dynamic_tester/test_generator.py @@ -15,8 +15,9 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from utilities.llm_client import AnthropicClient, TokenTracker +from utilities.model_config import MODEL_AUXILIARY -SONNET_MODEL = "claude-sonnet-4-20250514" +SONNET_MODEL = MODEL_AUXILIARY # Map language strings to Dockerfile template names LANGUAGE_MAP = { diff --git 
a/libs/openant-core/utilities/finding_verifier.py b/libs/openant-core/utilities/finding_verifier.py index 2e66b7c..d4059b0 100644 --- a/libs/openant-core/utilities/finding_verifier.py +++ b/libs/openant-core/utilities/finding_verifier.py @@ -1,35 +1,28 @@ """ -Stage 2 Finding Verifier (Enhanced) - -Stage 2 of the two-stage vulnerability analysis pipeline. -Uses Opus with tool access to validate Stage 1 assessments by exploring -the codebase - searching function usages, reading definitions, and -tracing call paths. - -Key Improvements: - 1. Explicit vulnerability definitions (exploitable NOW vs dangerous design) - 2. Required exploit path tracing (entry point -> sink) - 3. Consistency cross-check for similar code patterns - 4. Structured output with exploit_path field - 5. Batch verification with consistency validation - -The verifier asks: "Can an attacker exploit this NOW in the current codebase?" -It validates by tracing the complete exploit path from attacker input to sink. - -Available Tools: - - search_usages: Find where a function is called - - search_definitions: Find where a function is defined - - read_function: Get full function code by ID - - list_functions: List all functions in a file - - finish: Complete verification with verdict and exploit path +Stage 2 Finding Verifier (Enhanced, SDK-native) + +Stage 2 of the two-stage vulnerability analysis pipeline. Validates Stage 1 +assessments by letting Claude Code explore the codebase with its native +Read/Grep/Glob/Bash tools to trace exploit paths from attacker input to sink. + +This module used to drive a manual tool-dispatch loop against the `anthropic` +SDK (search_usages / search_definitions / read_function / list_functions / +finish). That loop has been replaced with a single SDK-native call to +`run_native_verification` from `utilities.llm_client`, which delegates to the +Claude Agent SDK. 
Rate-limit handling is centralised in +`utilities.llm_client._run_query` and surfaces via +`utilities.sdk_errors.RateLimitError`. Classes: - VerificationResult: Dataclass containing verdict, exploit path, explanation - FindingVerifier: Main verifier class with verify_result() and verify_batch() methods + ExploitPath: Structured exploit path analysis. + VerificationResult: Dataclass containing verdict, exploit path, explanation. + ConsistencyCheckResult: Result from cross-pattern consistency check. + FindingVerifier: Main verifier class with verify_result() and verify_batch() methods. """ import json import logging +import os import re import sys import threading @@ -38,21 +31,27 @@ from dataclasses import dataclass, field from typing import Callable, Optional -import anthropic - -from .llm_client import TokenTracker, get_global_tracker +from .llm_client import ( + AnthropicClient, + TokenTracker, + get_global_tracker, + run_native_verification, +) from .rate_limiter import get_rate_limiter # Null logger that discards all messages (used when no logger provided) _null_logger = logging.getLogger("null_verifier") _null_logger.addHandler(logging.NullHandler()) + from .agentic_enhancer.repository_index import RepositoryIndex -from .agentic_enhancer.tools import ToolExecutor +from .model_config import MODEL_PRIMARY from prompts.verification_prompts import ( + VERIFICATION_JSON_SCHEMA, VERIFICATION_SYSTEM_PROMPT, + get_consistency_check_prompt, + get_native_claude_verification_prompt, get_verification_prompt, get_verification_system_prompt, - get_consistency_check_prompt ) # Import application context type for type hints @@ -62,125 +61,13 @@ ApplicationContext = None -VERIFIER_MODEL = "claude-opus-4-6" -MAX_ITERATIONS = 20 -MAX_TOKENS_PER_RESPONSE = 4096 - - -# Enhanced finish tool with exploit_path structure -VERIFICATION_TOOLS = [ - { - "name": "search_usages", - "description": "Search for all places where a function is called/used in the codebase. 
Use this to trace how attacker input flows through the code.", - "input_schema": { - "type": "object", - "properties": { - "function_name": { - "type": "string", - "description": "Name of the function to find usages of" - } - }, - "required": ["function_name"] - } - }, - { - "name": "search_definitions", - "description": "Search for where a function is defined. Use this to understand what a function does.", - "input_schema": { - "type": "object", - "properties": { - "function_name": { - "type": "string", - "description": "Name of the function to find definition of" - } - }, - "required": ["function_name"] - } - }, - { - "name": "read_function", - "description": "Read the full source code of a function by its ID. Use this to analyze function behavior.", - "input_schema": { - "type": "object", - "properties": { - "function_id": { - "type": "string", - "description": "Function identifier in format 'file/path.ts:functionName'" - } - }, - "required": ["function_id"] - } - }, - { - "name": "list_functions", - "description": "List all functions defined in a specific file.", - "input_schema": { - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "Path to the file relative to repository root" - } - }, - "required": ["file_path"] - } - }, - { - "name": "finish", - "description": "Complete the verification with your verdict and exploit path analysis.", - "input_schema": { - "type": "object", - "properties": { - "agree": { - "type": "boolean", - "description": "Whether you agree with Stage 1's assessment" - }, - "correct_finding": { - "type": "string", - "enum": ["safe", "protected", "bypassable", "vulnerable", "inconclusive"], - "description": "The correct finding based on exploit path analysis" - }, - "exploit_path": { - "type": "object", - "description": "Analysis of the exploit path from attacker input to sink", - "properties": { - "entry_point": { - "type": ["string", "null"], - "description": "Where attacker input enters (null if 
none found)" - }, - "data_flow": { - "type": "array", - "items": {"type": "string"}, - "description": "Steps showing how data flows from entry to sink" - }, - "sink_reached": { - "type": "boolean", - "description": "Whether attacker-controlled data reaches the vulnerable operation" - }, - "attacker_control_at_sink": { - "type": "string", - "enum": ["full", "partial", "none"], - "description": "Level of attacker control at the dangerous operation" - }, - "path_broken_at": { - "type": ["string", "null"], - "description": "Where/why the exploit path breaks (null if complete)" - } - } - }, - "explanation": { - "type": "string", - "description": "Detailed explanation of your analysis" - }, - "security_weakness": { - "type": ["string", "null"], - "description": "Any dangerous patterns that exist but aren't currently exploitable (optional)" - } - }, - "required": ["agree", "correct_finding", "explanation"] - } - } -] +VERIFIER_MODEL = MODEL_PRIMARY +# Budget ceiling per-finding for the native SDK verification call. The SDK +# will halt multi-turn exploration if cumulative cost exceeds this. +MAX_BUDGET_USD_PER_FINDING = 0.30 +# Hard timeout per-finding (seconds). Passed through for API compat; the +# SDK's own message loop governs actual wall-clock behaviour. 
+MAX_VERIFICATION_TIMEOUT_S = 600 @dataclass @@ -221,6 +108,7 @@ class VerificationResult: total_tokens: int exploit_path: Optional[ExploitPath] = None security_weakness: Optional[str] = None + raw_response: Optional[str] = None # Full SDK response text (not serialized) def to_dict(self) -> dict: result = { @@ -255,7 +143,7 @@ def to_dict(self) -> dict: class FindingVerifier: - """Validates Stage 1 assessments using Opus with tool access.""" + """Validates Stage 1 assessments using Claude Code's native tools via the Agent SDK.""" def __init__( self, @@ -264,16 +152,24 @@ def __init__( verbose: bool = False, app_context: "ApplicationContext" = None, logger: logging.Logger = None, - client: "anthropic.Anthropic | None" = None, + output_dir: str = None, ): self.index = index self.tracker = tracker or get_global_tracker() self.verbose = verbose self.app_context = app_context - self.tool_executor = ToolExecutor(index) - self.client = client or anthropic.Anthropic(max_retries=5) self.logger = logger or _null_logger self._use_logger = logger is not None + self.output_dir = output_dir + # Single-turn client used for the consistency cross-check step + # (which does not need multi-turn native tool use). 
+ self._consistency_client = AnthropicClient( + model=VERIFIER_MODEL, tracker=self.tracker, + ) + + # ------------------------------------------------------------------ + # Logging + # ------------------------------------------------------------------ def _log(self, level: str, msg: str, **extras): """Log a message, using logger if available, otherwise print if verbose.""" @@ -281,9 +177,60 @@ def _log(self, level: str, msg: str, **extras): log_func = getattr(self.logger, level, self.logger.info) log_func(msg, extra=extras) elif self.verbose: - # Fallback to print for CLI usage + # Fallback to print for CLI usage (stderr to avoid corrupting JSON stdout) suffix = " ".join(f"{k}={v}" for k, v in extras.items() if v is not None) - print(f" {msg} {suffix}" if suffix else f" {msg}") + print(f" {msg} {suffix}" if suffix else f" {msg}", + file=sys.stderr, flush=True) + + def _save_explanation(self, route_key: str, verification: "VerificationResult"): + """Save verification explanation to a file in the output directory.""" + if not self.output_dir: + return + verify_dir = os.path.join(self.output_dir, "verify_explanations") + try: + os.makedirs(verify_dir, exist_ok=True) + except OSError as e: + print(f"[Verify] Could not create explanation dir for {route_key}: {e}", + file=sys.stderr, flush=True) + return + + # Sanitize route_key for use as filename + safe_name = re.sub(r'[\\/:*?"<>|]', '_', route_key) + filepath = os.path.join(verify_dir, f"{safe_name}.md") + + lines = [f"# {route_key}\n"] + lines.append(f"**Verdict:** {verification.correct_finding}") + lines.append(f"**Agrees with Stage 1:** {verification.agree}\n") + if verification.exploit_path: + ep = verification.exploit_path + lines.append("## Exploit Path\n") + if ep.entry_point: + lines.append(f"**Entry point:** {ep.entry_point}\n") + if ep.data_flow: + lines.append("**Data flow:**") + for step in ep.data_flow: + lines.append(f"1. 
{step}") + lines.append("") + lines.append(f"**Sink reached:** {ep.sink_reached}") + lines.append(f"**Attacker control at sink:** {ep.attacker_control_at_sink}") + if ep.path_broken_at: + lines.append(f"**Path broken at:** {ep.path_broken_at}") + lines.append("") + lines.append("## Explanation\n") + lines.append(verification.explanation) + if verification.security_weakness: + lines.append(f"\n## Security Weakness\n\n{verification.security_weakness}") + + try: + with open(filepath, "w", encoding="utf-8") as f: + f.write("\n".join(lines) + "\n") + except OSError as e: + print(f"[Verify] Could not save explanation for {route_key}: {e}", + file=sys.stderr, flush=True) + + # ------------------------------------------------------------------ + # Single-finding verification (native SDK call) + # ------------------------------------------------------------------ def verify_result( self, @@ -296,136 +243,126 @@ def verify_result( """ Validate a Stage 1 assessment with exploit path tracing. + Delegates to the Claude Agent SDK with native tools (Read, Grep, + Glob, Bash). Rate-limit handling is centralised in + `utilities.llm_client._run_query`. + Args: - code: The code that was assessed - finding: Stage 1's finding - attack_vector: Stage 1's attack vector - reasoning: Stage 1's reasoning - files_included: Optional list of files in context + code: The code that was assessed. + finding: Stage 1's finding. + attack_vector: Stage 1's attack vector. + reasoning: Stage 1's reasoning. + files_included: Optional list of files in context. Returns: - VerificationResult with verdict, exploit path, and explanation + VerificationResult with verdict, exploit path, and explanation. 
""" - user_prompt = get_verification_prompt( + repo_path = str(self.index.repo_path) if getattr(self.index, "repo_path", None) else None + if not repo_path: + self._log("warning", "No repo_path available for native SDK verification") + return VerificationResult( + agree=True, + correct_finding=finding, + explanation="No repo_path available for native SDK verification", + iterations=0, + total_tokens=0, + ) + + user_prompt = get_native_claude_verification_prompt( code=code, finding=finding, attack_vector=attack_vector, reasoning=reasoning, files_included=files_included, - app_context=self.app_context + app_context=self.app_context, ) - - # Get system prompt with app context if available system_prompt = get_verification_system_prompt(self.app_context) - messages = [{"role": "user", "content": user_prompt}] - iterations = 0 - total_input_tokens = 0 - total_output_tokens = 0 - - while iterations < MAX_ITERATIONS: - iterations += 1 + # Respect the global rate limiter. _run_query_sync will also raise + # utilities.sdk_errors.RateLimitError (and notify the limiter) if the + # SDK reports a rate-limit mid-flight. + get_rate_limiter().wait_if_needed() - self._log("debug", f"Iteration {iterations}", iterations=iterations) + # Lazy import to keep the module loadable when claude_agent_sdk is + # not installed (e.g. parse-only commands). 
+ try: + from claude_agent_sdk import ClaudeSDKError + except ImportError: # pragma: no cover + ClaudeSDKError = () # type: ignore[assignment] - # Wait if we're in a global backoff period - rate_limiter = get_rate_limiter() - rate_limiter.wait_if_needed() + try: + result = run_native_verification( + prompt=user_prompt, + system=system_prompt, + model=VERIFIER_MODEL, + repo_path=repo_path, + json_schema=VERIFICATION_JSON_SCHEMA, + max_budget_usd=MAX_BUDGET_USD_PER_FINDING, + timeout=MAX_VERIFICATION_TIMEOUT_S, + ) + except (RuntimeError, FileNotFoundError, TimeoutError, ClaudeSDKError) as exc: + # Process-level failures (SDK subprocess died, CLI missing, + # connection error, JSON decode failure, etc.) fall through to a + # conservative "agree" verdict so the pipeline does not abort on + # a single bad finding. ClaudeSDKError is the SDK's umbrella base + # class — catches CLINotFoundError, CLIConnectionError, + # ProcessError, CLIJSONDecodeError. Rate-limit errors come + # through as utilities.sdk_errors.RateLimitError (not a subclass + # of ClaudeSDKError) and propagate unmodified so caller + # backoff/retry logic can see them. 
+ print(f"[Verify] Native SDK verification failed: {exc}", + file=sys.stderr, flush=True) + return VerificationResult( + agree=True, + correct_finding=finding, + explanation=f"Verification failed: {exc}", + iterations=0, + total_tokens=0, + ) - try: - response = self.client.messages.create( - model=VERIFIER_MODEL, - max_tokens=MAX_TOKENS_PER_RESPONSE, - system=system_prompt, - tools=VERIFICATION_TOOLS, - messages=messages - ) - except anthropic.RateLimitError as exc: - # Report to global rate limiter so all workers back off - retry_after = float(exc.response.headers.get("retry-after", 0)) - get_rate_limiter().report_rate_limit(retry_after) - raise - - total_input_tokens += response.usage.input_tokens - total_output_tokens += response.usage.output_tokens - - assistant_content = response.content - stop_reason = response.stop_reason - - # If model finished without calling finish tool, try to parse response - if stop_reason == "end_turn": - result = self._try_parse_text_response( - assistant_content, finding, iterations, - total_input_tokens, total_output_tokens - ) - if result: - return result - - # Default: agree with Stage 1 - return VerificationResult( - agree=True, - correct_finding=finding, - explanation="Verification incomplete", - iterations=iterations, - total_tokens=total_input_tokens + total_output_tokens - ) - - # Process tool calls - tool_results = [] - finish_result = None - - for block in assistant_content: - if block.type == "tool_use": - tool_name = block.name - tool_input = block.input - tool_use_id = block.id - - self._log("debug", f"Tool call: {tool_name}") - - if tool_name == "finish": - finish_result = tool_input - tool_results.append({ - "type": "tool_result", - "tool_use_id": tool_use_id, - "content": json.dumps({"status": "complete"}) - }) - break - else: - result = self.tool_executor.execute(tool_name, tool_input) - tool_results.append({ - "type": "tool_result", - "tool_use_id": tool_use_id, - "content": json.dumps(result) - }) - - if 
finish_result: - self.tracker.record_call( - model=VERIFIER_MODEL, - input_tokens=total_input_tokens, - output_tokens=total_output_tokens - ) - return self._parse_finish_result( - finish_result, finding, iterations, - total_input_tokens + total_output_tokens - ) - - messages.append({"role": "assistant", "content": assistant_content}) - messages.append({"role": "user", "content": tool_results}) - - # Max iterations reached + input_tokens = result.get("input_tokens", 0) + output_tokens = result.get("output_tokens", 0) + total_tokens = input_tokens + output_tokens self.tracker.record_call( model=VERIFIER_MODEL, - input_tokens=total_input_tokens, - output_tokens=total_output_tokens + input_tokens=input_tokens, + output_tokens=output_tokens, + cost_usd=result.get("cost_usd"), ) + + raw_text = result.get("text", "") or "" + + # Preferred path: structured JSON (either from SDK structured_output + # or embedded in the message text). + parsed = self._extract_json(raw_text) + if parsed and "correct_finding" in parsed: + vr = self._parse_finish_result(parsed, finding, 0, total_tokens) + vr.raw_response = raw_text + return vr + + # Fallback: try to extract a verdict from free-form text. + fallback = self._parse_freetext_verdict(raw_text, finding) + if fallback: + vr = self._parse_finish_result(fallback, finding, 0, total_tokens) + vr.raw_response = raw_text + return vr + + # Final fallback: conservative agree. 
+ print(f"[Verify] Could not parse SDK response: {raw_text[:500]}", + file=sys.stderr, flush=True) return VerificationResult( agree=True, correct_finding=finding, - explanation="Max iterations reached", - iterations=iterations, - total_tokens=total_input_tokens + total_output_tokens + explanation="Could not parse verification response", + iterations=0, + total_tokens=total_tokens, + raw_response=raw_text, ) + # ------------------------------------------------------------------ + # Batch verification (parallel, checkpoint-aware) — upstream API + # ------------------------------------------------------------------ + def verify_batch( self, results: list, @@ -442,8 +379,8 @@ def verify_batch( Supports checkpoint/resume via the checkpoint parameter. Args: - results: List of Stage 1 results to verify - code_by_route: Dict mapping route_key to code + results: List of Stage 1 results to verify. + code_by_route: Dict mapping route_key to code. progress_callback: Optional callback(unit_id, detail, unit_elapsed) called after each finding is verified. workers: Number of parallel workers (default: 10). @@ -452,7 +389,7 @@ def verify_batch( loading with the number of restored units. Returns: - Updated results with verification and consistency check + Updated results with verification and consistency check. 
""" total = len(results) @@ -611,9 +548,14 @@ def _verify_one(self, result, code_by_route): unit_id=route_key, total_tokens=verification.total_tokens, iterations=verification.iterations) + # Optionally save explanation to disk (for run-level debugging) + if self.output_dir and verification.explanation: + self._save_explanation(route_key, verification) + except Exception as e: detail = "error" - print(f"[Verify] ERROR {route_key}: {type(e).__name__}: {e}", file=sys.stderr, flush=True) + print(f"[Verify] ERROR {route_key}: {type(e).__name__}: {e}", + file=sys.stderr, flush=True) unit_elapsed = time.monotonic() - unit_start usage = self.tracker.get_unit_usage() @@ -684,6 +626,10 @@ def _verify_batch_parallel(self, results, code_by_route, progress_callback, work return executor.shutdown(wait=False) + # ------------------------------------------------------------------ + # Consistency cross-check + # ------------------------------------------------------------------ + def _check_consistency( self, results: list, @@ -763,9 +709,6 @@ def _has_conclusive_exploit_path(self, result: dict) -> bool: - sink_reached = false (attacker data doesn't reach the sink) - attacker_control_at_sink = "none" (no control at sink) - path_broken_at is set (explicit explanation of where path breaks) - - These findings are based on detailed code analysis and should not be - overridden by superficial pattern matching. """ verification = result.get("verification", {}) @@ -825,49 +768,43 @@ def _resolve_inconsistency( code_by_route: dict ) -> Optional[ConsistencyCheckResult]: """ - Use LLM to resolve inconsistent verdicts for similar code patterns. + Use a single-turn LLM call to resolve inconsistent verdicts for + similar code patterns. + + Rate-limit handling: `_run_query` inside `AnthropicClient.analyze_sync` + raises `utilities.sdk_errors.RateLimitError` and notifies the global + rate limiter automatically. We don't need to catch it here — callers + higher up retry as appropriate. 
""" prompt = get_consistency_check_prompt(group, code_by_route) - try: - # Wait if we're in a global backoff period - rate_limiter = get_rate_limiter() - rate_limiter.wait_if_needed() + # Respect the global rate limiter before dispatching. + get_rate_limiter().wait_if_needed() - response = self.client.messages.create( - model=VERIFIER_MODEL, - max_tokens=MAX_TOKENS_PER_RESPONSE, + try: + text = self._consistency_client.analyze_sync( + prompt, system="You are checking verdict consistency across similar code patterns.", - messages=[{"role": "user", "content": prompt}] ) - - self.tracker.record_call( - model=VERIFIER_MODEL, - input_tokens=response.usage.input_tokens, - output_tokens=response.usage.output_tokens - ) - - # Parse response - text = response.content[0].text if response.content else "" - result = self._parse_json_from_text(text) - - if result: - return ConsistencyCheckResult( - pattern_identified=result.get("pattern_identified", "unknown"), - consistent_verdict=result.get("consistent_verdict", "inconclusive"), - findings_updated=result.get("findings_to_update", []), - explanation=result.get("explanation", "") - ) - - except anthropic.RateLimitError as e: - # Report to global rate limiter so all workers back off - retry_after = float(e.response.headers.get("retry-after", 0)) - get_rate_limiter().report_rate_limit(retry_after) - self._log("error", f"Consistency resolution rate limited", error=str(e)) except Exception as e: + # Non-rate-limit failure — log and leave the group unresolved. 
self._log("error", f"Consistency resolution failed", error=str(e)) + return None - return None + result = self._parse_json_from_text(text) + if not result: + return None + + return ConsistencyCheckResult( + pattern_identified=result.get("pattern_identified", "unknown"), + consistent_verdict=result.get("consistent_verdict", "inconclusive"), + findings_updated=result.get("findings_to_update", []), + explanation=result.get("explanation", ""), + ) + + # ------------------------------------------------------------------ + # Result / response parsing + # ------------------------------------------------------------------ def _parse_finish_result( self, @@ -876,8 +813,7 @@ def _parse_finish_result( iterations: int, total_tokens: int ) -> VerificationResult: - """Parse the finish tool result into VerificationResult.""" - # Parse exploit path if present + """Parse the finish-tool-style dict (or structured JSON) into VerificationResult.""" exploit_path = None if "exploit_path" in finish_result and finish_result["exploit_path"]: ep = finish_result["exploit_path"] @@ -899,32 +835,8 @@ def _parse_finish_result( security_weakness=finish_result.get("security_weakness") ) - def _try_parse_text_response( - self, - assistant_content: list, - original_finding: str, - iterations: int, - total_input_tokens: int, - total_output_tokens: int - ) -> Optional[VerificationResult]: - """Try to parse a text response as JSON.""" - for block in assistant_content: - if hasattr(block, 'text'): - result = self._parse_json_from_text(block.text) - if result: - self.tracker.record_call( - model=VERIFIER_MODEL, - input_tokens=total_input_tokens, - output_tokens=total_output_tokens - ) - return self._parse_finish_result( - result, original_finding, iterations, - total_input_tokens + total_output_tokens - ) - return None - def _parse_json_from_text(self, text: str) -> Optional[dict]: - """Extract JSON object from text, with LLM correction fallback.""" + """Extract a JSON object from text, with LLM 
correction fallback.""" try: start = text.find('{') end = text.rfind('}') + 1 @@ -937,7 +849,7 @@ def _parse_json_from_text(self, text: str) -> Optional[dict]: if text.strip(): try: from utilities.json_corrector import JSONCorrector - corrector = JSONCorrector(self.client) + corrector = JSONCorrector(self._consistency_client) corrected = corrector.attempt_correction(text) if corrected.get("verdict") != "ERROR": corrected["json_corrected"] = True @@ -945,3 +857,89 @@ def _parse_json_from_text(self, text: str) -> Optional[dict]: except Exception: pass return None + + @staticmethod + def _extract_json(text: str) -> Optional[dict]: + """Extract a JSON object from text without LLM fallback. + + Tries: + 1. Parse the entire text as JSON + 2. Extract JSON from a ```json code block + 3. Find the outermost { ... } pair + """ + if not text: + return None + text = text.strip() + + # Try parsing the whole thing + try: + return json.loads(text) + except (json.JSONDecodeError, ValueError): + pass + + # Try extracting from ```json ... ``` code block + json_block = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', text, re.DOTALL) + if json_block: + try: + return json.loads(json_block.group(1)) + except json.JSONDecodeError: + pass + + # Try finding outermost braces + start = text.find('{') + end = text.rfind('}') + 1 + if start >= 0 and end > start: + try: + return json.loads(text[start:end]) + except json.JSONDecodeError: + pass + + return None + + @staticmethod + def _parse_freetext_verdict(text: str, original_finding: str) -> Optional[dict]: + """Extract a verdict from free-text response when JSON parsing fails. + + Looks for keywords like PROTECTED, SAFE, VULNERABLE in the response + and constructs a result dict. 
+ """ + if not text: + return None + text_lower = text.lower() + + # Determine verdict from common patterns + correct_finding = None + agree = None + + if "disagree" in text_lower: + agree = False + elif "agree" in text_lower: + agree = True + + for verdict in ["vulnerable", "bypassable", "protected", "safe", "inconclusive"]: + patterns = [ + rf'(?:verdict|finding|correct_finding|conclusion)[:\s]*\**{verdict}\**', + rf'\*\*{verdict.upper()}\*\*', + ] + for pattern in patterns: + if re.search(pattern, text, re.IGNORECASE): + correct_finding = verdict + break + if correct_finding: + break + + if not correct_finding: + return None + + if agree is None: + agree = correct_finding == original_finding + + explanation = text[:500].strip() + if len(text) > 500: + explanation += "..." + + return { + "agree": agree, + "correct_finding": correct_finding, + "explanation": explanation, + } diff --git a/libs/openant-core/utilities/ground_truth_challenger.py b/libs/openant-core/utilities/ground_truth_challenger.py index b0ad1db..bb808ed 100644 --- a/libs/openant-core/utilities/ground_truth_challenger.py +++ b/libs/openant-core/utilities/ground_truth_challenger.py @@ -19,6 +19,7 @@ from dataclasses import dataclass from .llm_client import AnthropicClient +from .model_config import MODEL_AUXILIARY @dataclass @@ -209,7 +210,7 @@ class GroundTruthChallenger: 2. Validate false negatives - did the model miss something, or is the ground truth wrong? """ - def __init__(self, client: AnthropicClient, model: str = "claude-sonnet-4-20250514"): + def __init__(self, client: AnthropicClient, model: str = MODEL_AUXILIARY): """ Initialize the challenger. 
diff --git a/libs/openant-core/utilities/json_corrector.py b/libs/openant-core/utilities/json_corrector.py index dd35cda..30002ca 100644 --- a/libs/openant-core/utilities/json_corrector.py +++ b/libs/openant-core/utilities/json_corrector.py @@ -16,6 +16,7 @@ from typing import Optional from .llm_client import AnthropicClient +from .model_config import MODEL_AUXILIARY def get_json_extraction_prompt(raw_response: str) -> str: @@ -83,7 +84,7 @@ def extract_json_with_llm( # Use Sonnet for extraction (faster/cheaper) llm_response = client.analyze_sync( prompt, - model="claude-sonnet-4-20250514", + model=MODEL_AUXILIARY, max_tokens=2048 ) return _parse_json_response(llm_response) diff --git a/libs/openant-core/utilities/llm_client.py b/libs/openant-core/utilities/llm_client.py index ea356bf..5f90015 100644 --- a/libs/openant-core/utilities/llm_client.py +++ b/libs/openant-core/utilities/llm_client.py @@ -10,31 +10,259 @@ Usage: from utilities.llm_client import AnthropicClient, get_global_tracker - client = AnthropicClient(model="claude-opus-4-20250514") + client = AnthropicClient() # uses MODEL_DEFAULT response = client.analyze_sync("Analyze this code...") tracker = get_global_tracker() print(f"Total cost: ${tracker.total_cost_usd:.4f}") """ +import json import os +import sys import threading from typing import Optional -import anthropic from dotenv import load_dotenv +# Load .env once at module level. load_dotenv() mutates os.environ which is +# not thread-safe under concurrent client construction; runs at import time so +# tests that need to control env should mock os.environ or patch before import. 
+load_dotenv() + from .rate_limiter import get_rate_limiter +from .model_config import MODEL_PRIMARY, MODEL_AUXILIARY, MODEL_DEFAULT # Pricing per million tokens (as of December 2024) MODEL_PRICING = { - "claude-opus-4-20250514": {"input": 15.00, "output": 75.00}, - "claude-sonnet-4-20250514": {"input": 3.00, "output": 15.00}, + MODEL_PRIMARY: {"input": 15.00, "output": 75.00}, + MODEL_AUXILIARY: {"input": 3.00, "output": 15.00}, # Fallback for unknown models (use Sonnet pricing as conservative estimate) "default": {"input": 3.00, "output": 15.00} } +# --------------------------------------------------------------------------- +# SDK helpers (lazy imports to avoid breaking non-LLM commands like parse) +# --------------------------------------------------------------------------- + +def _build_env() -> dict: + """Build the env dict for ClaudeAgentOptions.""" + env = {} + local_mode = os.getenv("OPENANT_LOCAL_CLAUDE", "").lower() == "true" + + # Pass API key only if not in local mode (local session auth takes precedence) + if not local_mode: + api_key = os.getenv("ANTHROPIC_API_KEY") + if api_key: + env["ANTHROPIC_API_KEY"] = api_key + + # Always forward CLAUDE_CONFIG_DIR — selects which Claude Code + # profile/session to use (e.g. .claude-k) + config_dir = os.getenv("CLAUDE_CONFIG_DIR") + if config_dir: + env["CLAUDE_CONFIG_DIR"] = config_dir + # Prevent "nested session" error when running from within Claude Code + env["CLAUDECODE"] = "" + return env + + +def _build_options(model, system=None, max_turns=1, allowed_tools=None, **kwargs): + """Build ClaudeAgentOptions with auth and config.""" + from claude_agent_sdk import ClaudeAgentOptions + return ClaudeAgentOptions( + model=model, + system_prompt=system, + max_turns=max_turns, + allowed_tools=allowed_tools or [], + permission_mode="bypassPermissions", + env=_build_env(), + **kwargs, + ) + + +async def _run_query(prompt, options, label=""): + """Run a query via ClaudeSDKClient and return the response text. 
+ + Uses the async context manager approach (NOT the query() generator) + to avoid anyio cancel-scope errors. + + If the SDK reports an API error via AssistantMessage.error, this raises + the corresponding openant.utilities.sdk_errors exception class. Rate + limits additionally notify the GlobalRateLimiter so all workers back off. + + Args: + prompt: The prompt text to send. + options: ClaudeAgentOptions instance. + label: Optional label for verbose log lines (e.g. unit ID). + + Returns: + Tuple of (ResultMessage, last_assistant_text). + + Raises: + utilities.sdk_errors.OpenAntLLMError: subclass matching the SDK's + reported AssistantMessageError (RateLimitError, AuthError, + BillingError, InvalidRequestError, ServerError, UnknownLLMError). + """ + import time + from claude_agent_sdk import ClaudeSDKClient, AssistantMessage, ResultMessage + + from .sdk_errors import error_from_kind, RateLimitError + + _verbose = os.getenv("OPENANT_VERBOSE", "").lower() == "true" + tag = f"[SDK:{label}]" if label else "[SDK]" + t0 = time.monotonic() + + if _verbose: + print(f" {tag} Connecting (model={options.model}, max_turns={options.max_turns})...", + file=sys.stderr, flush=True) + + async with ClaudeSDKClient(options=options) as client: + t_connect = time.monotonic() + if _verbose: + print(f" {tag} Connected ({t_connect - t0:.1f}s). Sending query...", + file=sys.stderr, flush=True) + + await client.query(prompt) + t_query = time.monotonic() + if _verbose: + print(f" {tag} Query sent ({t_query - t_connect:.1f}s). Receiving messages...", + file=sys.stderr, flush=True) + + last_assistant_text = "" + result_message = None + turn_count = 0 + async for message in client.receive_response(): + if isinstance(message, AssistantMessage): + # SDK signals API-level errors on AssistantMessage.error. + # Raise the typed openant exception; rate-limit errors also + # notify the global limiter so all threads back off. 
+ msg_error = getattr(message, "error", None) + if msg_error: + text_for_context = "" + for block in getattr(message, "content", []): + if type(block).__name__ == "TextBlock": + text_for_context = block.text + break + exc = error_from_kind(msg_error, text_for_context) + if isinstance(exc, RateLimitError): + # No retry-after from SDK; default backoff applies. + get_rate_limiter().report_rate_limit(0) + raise exc + + turn_count += 1 + parts = [] + tool_names = [] + for block in getattr(message, "content", []): + if type(block).__name__ == "TextBlock": + parts.append(block.text) + elif type(block).__name__ == "ToolUseBlock": + tool_names.append(getattr(block, "name", "?")) + if parts: + last_assistant_text = "\n".join(parts) + if _verbose: + elapsed = time.monotonic() - t0 + tools_str = f" tools=[{','.join(tool_names)}]" if tool_names else "" + text_preview = parts[0][:80] if parts else "" + print(f" {tag} Turn {turn_count} ({elapsed:.1f}s){tools_str} {text_preview}", + file=sys.stderr, flush=True) + elif isinstance(message, ResultMessage): + result_message = message + if _verbose: + elapsed = time.monotonic() - t0 + cost = getattr(message, "total_cost_usd", None) or 0 + turns = getattr(message, "num_turns", "?") + print(f" {tag} Done ({elapsed:.1f}s total, {turns} turns, ${cost:.4f})", + file=sys.stderr, flush=True) + + return result_message, last_assistant_text + + +def _run_query_sync(prompt, options, label=""): + """Synchronous wrapper around _run_query. + + Creates a fresh event loop per call so concurrent threads don't + interfere with each other. 
+ """ + import asyncio + return asyncio.run(_run_query(prompt, options, label=label)) + + +# --------------------------------------------------------------------------- +# Native verification (SDK-backed, multi-turn with native tools) +# --------------------------------------------------------------------------- + +def run_native_verification( + prompt: str, + system: str, + model: str, + repo_path: str, + json_schema: dict = None, + max_budget_usd: float = 0.30, + timeout: int = 600, +) -> dict: + """ + Run Claude Code in native multi-turn mode for verification. + + Lets Claude Code use its own native tools (Read, Grep, Glob, Bash) + to explore the codebase and produce a structured JSON verdict. + + Args: + prompt: The verification prompt. + system: System prompt for the session. + model: Model to use (typically MODEL_PRIMARY from utilities.model_config). + repo_path: Path to the repository root. + json_schema: Optional JSON schema for structured output. + max_budget_usd: Maximum dollar budget per finding. + timeout: Not used with SDK (kept for API compat). + + Returns: + Dict with 'text', 'input_tokens', 'output_tokens', 'cost_usd'. 
+ """ + extra = {} + if json_schema: + extra["output_format"] = {"type": "json_schema", "schema": json_schema} + + options = _build_options( + model=model, + system=system, + max_turns=None, # Let SDK decide (multi-turn) + allowed_tools=["Read", "Grep", "Glob", "Bash"], + add_dirs=[repo_path], + cwd=repo_path, + max_budget_usd=max_budget_usd, + **extra, + ) + + result_message, last_text = _run_query_sync(prompt, options) + + if result_message is None: + raise RuntimeError("SDK returned no ResultMessage") + + usage = result_message.usage or {} + cost = result_message.total_cost_usd or 0.0 + + # With json_schema, the verdict is in structured_output + structured = getattr(result_message, "structured_output", None) + if structured and isinstance(structured, dict): + text = json.dumps(structured) + elif result_message.result: + text = result_message.result + else: + text = last_text + + return { + "text": text, + "input_tokens": usage.get("input_tokens", 0), + "output_tokens": usage.get("output_tokens", 0), + "cost_usd": cost, + } + + +# --------------------------------------------------------------------------- +# Token tracking +# --------------------------------------------------------------------------- + class TokenTracker: """ Tracks token usage and costs across LLM calls. @@ -53,12 +281,24 @@ def reset(self): self.total_output_tokens = 0 self.total_cost_usd = 0.0 + def restore_from(self, totals: dict): + """Restore counters from a previously saved totals dict (e.g. checkpoint). + + Args: + totals: Dict with total_input_tokens, total_output_tokens, total_cost_usd. 
+ """ + with self._lock: + self.total_input_tokens = totals.get("total_input_tokens", 0) + self.total_output_tokens = totals.get("total_output_tokens", 0) + self.total_cost_usd = totals.get("total_cost_usd", 0.0) + @property def total_tokens(self) -> int: """Total tokens (input + output).""" return self.total_input_tokens + self.total_output_tokens - def record_call(self, model: str, input_tokens: int, output_tokens: int) -> dict: + def record_call(self, model: str, input_tokens: int, output_tokens: int, + cost_usd: float = None) -> dict: """ Record a single LLM call. @@ -66,17 +306,23 @@ def record_call(self, model: str, input_tokens: int, output_tokens: int) -> dict model: Model identifier input_tokens: Number of input tokens output_tokens: Number of output tokens + cost_usd: Actual cost from the SDK (`ResultMessage.total_cost_usd`). + If provided, used directly. Otherwise we fall back to the + static pricing table — kept only for legacy code paths. Returns: Dict with call details including cost """ - # Get pricing for model - pricing = MODEL_PRICING.get(model, MODEL_PRICING["default"]) + if cost_usd is not None: + total_cost = cost_usd + else: + # Get pricing for model + pricing = MODEL_PRICING.get(model, MODEL_PRICING["default"]) - # Calculate cost (pricing is per million tokens) - input_cost = (input_tokens / 1_000_000) * pricing["input"] - output_cost = (output_tokens / 1_000_000) * pricing["output"] - total_cost = input_cost + output_cost + # Calculate cost (pricing is per million tokens) + input_cost = (input_tokens / 1_000_000) * pricing["input"] + output_cost = (output_tokens / 1_000_000) * pricing["output"] + total_cost = input_cost + output_cost call_record = { "model": model, @@ -181,34 +427,112 @@ def reset_global_tracker(): _global_tracker.reset() +_auth_mode_logged = False +_auth_mode_lock = threading.Lock() + + +def _log_auth_mode(): + """Print which auth mode the SDK will use (once per process). 
+ + Without a guard, every AnthropicClient construction re-prints — under + parallel workers (10+ threads) this floods stderr. The lock+flag pair + makes this exactly-once even when many threads race AnthropicClient(). + """ + global _auth_mode_logged + if _auth_mode_logged: + return + with _auth_mode_lock: + if _auth_mode_logged: + return + local_mode = os.getenv("OPENANT_LOCAL_CLAUDE", "").lower() == "true" + api_key = os.getenv("ANTHROPIC_API_KEY") + config_dir = os.getenv("CLAUDE_CONFIG_DIR") + if local_mode: + if config_dir: + print(f"Using Claude Agent SDK (local session, config: {config_dir})", file=sys.stderr) + else: + print("Using Claude Agent SDK (local session)", file=sys.stderr) + elif api_key: + print("Using Claude Agent SDK (API key mode)", file=sys.stderr) + else: + # No auth path explicitly configured. The SDK will fall back to + # whatever the local `claude` CLI finds; warn the user so an + # eventual auth failure is not a surprise. + print("Warning: neither ANTHROPIC_API_KEY nor OPENANT_LOCAL_CLAUDE=true set; " + "Claude Agent SDK will rely on the local `claude` CLI's own auth state.", + file=sys.stderr) + _auth_mode_logged = True + + class AnthropicClient: """ - Client for Anthropic Claude API. + Client for Claude API with automatic token tracking. - Uses Claude Opus 4 for vulnerability analysis. - Tracks token usage and costs for all calls. + Routes all calls through the Claude Agent SDK. Authentication is + automatic: uses ANTHROPIC_API_KEY if set, otherwise the local + Claude Code session. + + The class name is kept as ``AnthropicClient`` to avoid churn in callers; + historically this wrapped the ``anthropic`` package but now wraps the + Claude Agent SDK helpers in this module. """ - def __init__(self, model: str = "claude-opus-4-20250514", tracker: TokenTracker = None): + def __init__(self, model: str = MODEL_DEFAULT, tracker: TokenTracker = None): """ - Initialize the Anthropic client. + Initialize the client. 
Args: - model: Model identifier. Default is Claude Opus 4 (highest capability). - Use "claude-sonnet-4-20250514" for cost-effective option. + model: Model identifier. Defaults to MODEL_DEFAULT from model_config + (Claude Opus 4, highest capability). Use MODEL_AUXILIARY for + the cost-effective option. tracker: Optional TokenTracker instance. Uses global tracker if not provided. """ - load_dotenv() - - api_key = os.getenv("ANTHROPIC_API_KEY") - if not api_key: - raise ValueError("ANTHROPIC_API_KEY not found in environment") - - self.client = anthropic.Anthropic(api_key=api_key, max_retries=5) + _log_auth_mode() self.model = model self.tracker = tracker or _global_tracker self.last_call = None # Store last call details + def _call(self, model: str, prompt: str, system: str = None, + max_tokens: int = 8192) -> str: + """Make an SDK call, track usage, and return the response text. + + Rate-limit handling: ``_run_query`` raises ``sdk_errors.RateLimitError`` + and notifies the GlobalRateLimiter; we don't need to catch it here. + + Note on ``max_tokens``: ``ClaudeAgentOptions`` (Claude Agent SDK 0.1.x) + does not expose a top-level max-output-tokens setting — the SDK lets + the model decide. The parameter is accepted for backwards compatibility + with the pre-migration anthropic-SDK signature but is currently a no-op. + See PR #51 for context. 
+ """ + # Wait if we're in a global backoff period + get_rate_limiter().wait_if_needed() + + options = _build_options( + model=model, + system=system, + max_turns=1, + allowed_tools=[], + ) + result_message, last_text = _run_query_sync(prompt, options) + + # Extract usage and cost from SDK + usage = (result_message.usage or {}) if result_message else {} + input_tokens = usage.get("input_tokens", 0) + output_tokens = usage.get("output_tokens", 0) + cost_usd = (result_message.total_cost_usd or 0.0) if result_message else None + + self.last_call = self.tracker.record_call( + model=model, + input_tokens=input_tokens, + output_tokens=output_tokens, + cost_usd=cost_usd, + ) + + if result_message and result_message.result: + return result_message.result + return last_text + async def analyze(self, prompt: str, max_tokens: int = 8192) -> str: """ Send a prompt to Claude and get a response. @@ -220,34 +544,10 @@ async def analyze(self, prompt: str, max_tokens: int = 8192) -> str: Returns: Response text from Claude """ - # Wait if we're in a global backoff period - rate_limiter = get_rate_limiter() - rate_limiter.wait_if_needed() - - try: - message = self.client.messages.create( - model=self.model, - max_tokens=max_tokens, - messages=[ - {"role": "user", "content": prompt} - ] - ) - except anthropic.RateLimitError as exc: - # Report to global rate limiter so all workers back off - retry_after = float(exc.response.headers.get("retry-after", 0)) - get_rate_limiter().report_rate_limit(retry_after) - raise - - # Track token usage - self.last_call = self.tracker.record_call( - model=self.model, - input_tokens=message.usage.input_tokens, - output_tokens=message.usage.output_tokens - ) - - return message.content[0].text + return self._call(self.model, prompt, max_tokens=max_tokens) - def analyze_sync(self, prompt: str, max_tokens: int = 8192, model: str = None, system: str = None) -> str: + def analyze_sync(self, prompt: str, max_tokens: int = 8192, + model: str = None, system: str 
= None) -> str: """ Synchronous version of analyze. @@ -260,38 +560,8 @@ def analyze_sync(self, prompt: str, max_tokens: int = 8192, model: str = None, s Returns: Response text from Claude """ - used_model = model or self.model - - kwargs = { - "model": used_model, - "max_tokens": max_tokens, - "messages": [ - {"role": "user", "content": prompt} - ] - } - if system: - kwargs["system"] = system - - # Wait if we're in a global backoff period - rate_limiter = get_rate_limiter() - rate_limiter.wait_if_needed() - - try: - message = self.client.messages.create(**kwargs) - except anthropic.RateLimitError as exc: - # Report to global rate limiter so all workers back off - retry_after = float(exc.response.headers.get("retry-after", 0)) - get_rate_limiter().report_rate_limit(retry_after) - raise - - # Track token usage - self.last_call = self.tracker.record_call( - model=used_model, - input_tokens=message.usage.input_tokens, - output_tokens=message.usage.output_tokens - ) - - return message.content[0].text + return self._call(model or self.model, prompt, system=system, + max_tokens=max_tokens) def get_last_call(self) -> Optional[dict]: """ @@ -319,18 +589,3 @@ def get_session_summary(self) -> dict: Dict with totals and calls list """ return self.tracker.get_summary() - - def get_usage(self, message) -> dict: - """ - Extract token usage from a message response. - - Args: - message: Response from messages.create() - - Returns: - Dict with input_tokens, output_tokens - """ - return { - "input_tokens": message.usage.input_tokens, - "output_tokens": message.usage.output_tokens - } diff --git a/libs/openant-core/utilities/model_config.py b/libs/openant-core/utilities/model_config.py new file mode 100644 index 0000000..ebef3b7 --- /dev/null +++ b/libs/openant-core/utilities/model_config.py @@ -0,0 +1,15 @@ +""" +Central model configuration. + +All Claude model IDs are defined here. Change these to update which models +are used across the entire pipeline. 
+""" + +# Primary model — high capability, used for critical analysis and verification +MODEL_PRIMARY = "claude-opus-4-20250514" + +# Auxiliary model — cost-effective, used for enhancement, consistency, context +MODEL_AUXILIARY = "claude-sonnet-4-20250514" + +# Default fallback when no model is specified +MODEL_DEFAULT = MODEL_PRIMARY diff --git a/libs/openant-core/utilities/rate_limiter.py b/libs/openant-core/utilities/rate_limiter.py index 3416f1b..c7facd1 100644 --- a/libs/openant-core/utilities/rate_limiter.py +++ b/libs/openant-core/utilities/rate_limiter.py @@ -15,10 +15,15 @@ rate_limiter = get_rate_limiter() rate_limiter.wait_if_needed() - # When catching RateLimitError - except anthropic.RateLimitError as e: - retry_after = float(e.response.headers.get("retry-after", 0)) - rate_limiter.report_rate_limit(retry_after) + # Rate-limit detection happens centrally in llm_client._run_query, which + # raises utilities.sdk_errors.RateLimitError and calls + # rate_limiter.report_rate_limit(0) on every rate-limit event. Callers + # that need to attach state before re-raising: + from utilities.sdk_errors import RateLimitError + try: + ... + except RateLimitError: + # report_rate_limit already fired in _run_query raise """ diff --git a/libs/openant-core/utilities/sdk_errors.py b/libs/openant-core/utilities/sdk_errors.py new file mode 100644 index 0000000..9b58e07 --- /dev/null +++ b/libs/openant-core/utilities/sdk_errors.py @@ -0,0 +1,141 @@ +"""SDK error taxonomy. + +Replaces the `anthropic.*Error` hierarchy that scattered through the codebase +pre-migration. The Claude Agent SDK surfaces API-level errors through +`AssistantMessage.error: AssistantMessageError | None`, a Literal of +"authentication_failed", "billing_error", "rate_limit", "invalid_request", +"server_error", "unknown". We map each one onto an exception class so callers +can `except RateLimitError` instead of inspecting message fields. 
+ +Process-level errors (CLI not found, connection issues, JSON decode failures) +come through `claude_agent_sdk`'s own exception types and propagate up +unwrapped — they're not API errors. +""" + +from typing import Any + + +class OpenAntLLMError(Exception): + """Base class for LLM-layer errors that originated inside the model turn. + + Subclasses correspond 1:1 to `AssistantMessageError` literal values. Raised + from `utilities.llm_client._run_query` when an AssistantMessage carries an + `error` field. + """ + + error_kind: str = "unknown" + + def __init__(self, message: str = "", **kwargs: Any): + super().__init__(message) + self.message = message + # Allow callers to attach arbitrary state (e.g. agent iteration counts). + for key, value in kwargs.items(): + setattr(self, key, value) + + +class AuthError(OpenAntLLMError): + """Maps to AssistantMessageError == "authentication_failed".""" + + error_kind = "authentication_failed" + + +class BillingError(OpenAntLLMError): + """Maps to AssistantMessageError == "billing_error".""" + + error_kind = "billing_error" + + +class RateLimitError(OpenAntLLMError): + """Maps to AssistantMessageError == "rate_limit". + + The SDK does not surface a `retry-after` value; callers that feed + GlobalRateLimiter should pass 0 and let the default backoff apply. + """ + + error_kind = "rate_limit" + + +class InvalidRequestError(OpenAntLLMError): + """Maps to AssistantMessageError == "invalid_request".""" + + error_kind = "invalid_request" + + +class ServerError(OpenAntLLMError): + """Maps to AssistantMessageError == "server_error". Transient; safe to retry.""" + + error_kind = "server_error" + + +class UnknownLLMError(OpenAntLLMError): + """Maps to AssistantMessageError == "unknown". Fallback for unrecognized values.""" + + error_kind = "unknown" + + +# Dispatch table for AssistantMessageError literal -> exception class. 
+_ERROR_KIND_TO_CLASS: dict[str, type[OpenAntLLMError]] = {
+    "authentication_failed": AuthError,
+    "billing_error": BillingError,
+    "rate_limit": RateLimitError,
+    "invalid_request": InvalidRequestError,
+    "server_error": ServerError,
+    "unknown": UnknownLLMError,
+}
+
+
+def error_from_kind(kind: str, message: str = "") -> OpenAntLLMError:
+    """Construct the right exception subclass for an AssistantMessageError value.
+
+    Unknown kinds fall back to UnknownLLMError.
+    """
+    cls = _ERROR_KIND_TO_CLASS.get(kind, UnknownLLMError)
+    return cls(message or f"SDK reported error: {kind}")
+
+
+def classify_error(exc: BaseException) -> dict:
+    """Return a diagnostic dict for an OpenAntLLMError (or any exception).
+
+    Shape matches what `utilities.context_enhancer` logged pre-migration, so
+    existing callers can swap to this function without reshaping downstream
+    consumers:
+
+    {
+        "type": "rate_limit" | "auth" | "billing" | "server" | "invalid_request" | "unknown",
+        "exception_class": "RateLimitError",
+        "message": "...",
+        ...
+    }
+
+    The pre-migration shape also carried `status_code`, `request_id`, and
+    `retry_after` extracted from anthropic response headers. The SDK does not
+    surface any of those, so we drop them.
+    """
+    info: dict = {
+        "type": "unknown",
+        "exception_class": type(exc).__name__,
+        "message": str(exc),
+    }
+
+    if isinstance(exc, RateLimitError):
+        info["type"] = "rate_limit"
+    elif isinstance(exc, AuthError):
+        info["type"] = "auth"
+    elif isinstance(exc, BillingError):
+        info["type"] = "billing"
+    elif isinstance(exc, ServerError):
+        info["type"] = "server"
+    elif isinstance(exc, InvalidRequestError):
+        info["type"] = "invalid_request"
+    elif isinstance(exc, UnknownLLMError):
+        info["type"] = "unknown"
+    else:
+        # Non-LLM error (process-level, SDK framework, caller bug). Leave
+        # "type" as "unknown" but the class name still identifies it.
+ pass + + agent_state = getattr(exc, "agent_state", None) + if agent_state: + info["agent_state"] = agent_state + + return info diff --git a/libs/openant-core/utilities/stage1_consistency.py b/libs/openant-core/utilities/stage1_consistency.py index 96b54b3..0aabeab 100644 --- a/libs/openant-core/utilities/stage1_consistency.py +++ b/libs/openant-core/utilities/stage1_consistency.py @@ -15,10 +15,11 @@ from dataclasses import dataclass from utilities.llm_client import AnthropicClient, TokenTracker +from utilities.model_config import MODEL_AUXILIARY # Use a cheaper/faster model for consistency checks -CONSISTENCY_MODEL = "claude-sonnet-4-20250514" +CONSISTENCY_MODEL = MODEL_AUXILIARY MAX_TOKENS = 4096 @@ -267,22 +268,15 @@ def _resolve_stage1_inconsistency( prompt = get_stage1_consistency_prompt(group, code_by_route) try: - response = client.messages.create( + # AnthropicClient now wraps the Claude Agent SDK; analyze_sync handles + # token + cost tracking against the shared TokenTracker. + text = client.analyze_sync( + prompt, model=CONSISTENCY_MODEL, max_tokens=MAX_TOKENS, system="You are checking verdict consistency across similar code patterns in a security analysis.", - messages=[{"role": "user", "content": prompt}] ) - tracker.record_call( - model=CONSISTENCY_MODEL, - input_tokens=response.usage.input_tokens, - output_tokens=response.usage.output_tokens - ) - - # Parse response - text = response.content[0].text if response.content else "" - # Extract JSON from response json_match = re.search(r'\{[\s\S]*\}', text) if json_match: