diff --git a/apps/openant-cli/cmd/scan.go b/apps/openant-cli/cmd/scan.go
index 2a646b5..deb2390 100644
--- a/apps/openant-cli/cmd/scan.go
+++ b/apps/openant-cli/cmd/scan.go
@@ -51,6 +51,7 @@ var (
 	scanDiffBase        string
 	scanPR              int
 	scanDiffScope       string
+	scanLLMReachability bool
 )
 
 func init() {
@@ -79,6 +80,7 @@ func registerScanFlags(cmd *cobra.Command) {
 	cmd.Flags().StringVar(&scanDiffBase, "diff-base", "", "Incremental mode: filter pipeline to units overlapping diff vs this ref (e.g. origin/main, HEAD~5)")
 	cmd.Flags().IntVar(&scanPR, "pr", 0, "Incremental mode against a GitHub PR number (requires gh; mutex with --diff-base)")
 	cmd.Flags().StringVar(&scanDiffScope, "diff-scope", "changed_functions", "Diff scope: changed_files, changed_functions, callers")
+	cmd.Flags().BoolVar(&scanLLMReachability, "llm-reachability", false, "Enable the LLM reachability review stage (Opus). Surfaces additional entry points and external-input sites beyond the structural pass. Off by default — enabling this may incur additional LLM cost (one Opus call per ~25 units).")
 }
 
 func runScan(cmd *cobra.Command, args []string) {
@@ -197,6 +199,9 @@ func runScan(cmd *cobra.Command, args []string) {
 	if manifestPath != "" {
 		pyArgs = append(pyArgs, "--diff-manifest", manifestPath)
 	}
+	if scanLLMReachability {
+		pyArgs = append(pyArgs, "--llm-reachability")
+	}
 
 	// Pass repository metadata from project context so reports don't show
 	// [NOT PROVIDED] placeholders.
diff --git a/libs/openant-core/core/llm_reachability.py b/libs/openant-core/core/llm_reachability.py
new file mode 100644
index 0000000..dccda34
--- /dev/null
+++ b/libs/openant-core/core/llm_reachability.py
@@ -0,0 +1,457 @@
+"""
+LLM-based reachability review stage.
+
+A complementary, advisory pass over the parsed dataset that uses a strong
+LLM (Opus by default) to surface additional reachability signals beyond
+what the structural reachability analysis catches:
+
+- Likely entry points the structural pass missed (framework-specific
+  handlers, plugin registrations, lambdas, message handlers, etc.).
+- External content ingestion sites (HTTP request bodies, file/network
+  reads, env/argv, IPC channels).
+- Cross-process or async data flow indicators.
+
+Signals are **advisory only** — they may PROMOTE a unit's reachability
+(e.g. set ``is_entry_point = True`` for a unit the structural pass didn't
+flag), but they never DEMOTE a unit that structural analysis already
+kept. This matches the "complements, not replaces" intent in issue #17.
+
+Output:
+- ``analyze_reachability(...)`` returns a list of ``ReachabilitySignal``
+  objects.
+- ``apply_signals(dataset, signals)`` mutates the dataset in place so each
+  unit gains an ``llm_reachability_signals`` field, and high-confidence
+  ``entry_point`` signals set ``is_entry_point = True`` on the target unit.
+
+Usage:
+    from core.llm_reachability import analyze_reachability, apply_signals
+
+    signals = analyze_reachability(dataset, app_context=app_ctx)
+    apply_signals(dataset, signals)
+"""
+
+from __future__ import annotations
+
+import json
+import re
+import sys
+from dataclasses import dataclass, field, asdict
+from typing import Any, Callable, Dict, List, Optional
+
+
+# Models — matches the convention in core/analyzer.py / utilities/llm_client.py.
+MODEL_PRIMARY = "claude-opus-4-20250514"
+MODEL_SECONDARY = "claude-sonnet-4-20250514"
+
+
+# Maximum number of units to send in a single LLM call. Larger batches save
+# round trips but risk token-limit errors and degraded recall.
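+# As a rough, illustrative cost model (assuming one Opus call per batch, per
+# the CLI help text): a 1,000-unit dataset at the default size means
+# ceil(1000 / 25) = 40 calls. ``max_units`` (wired to ``--limit`` by the
+# scanner) caps how many units are reviewed at all.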
+DEFAULT_BATCH_SIZE = 25
+
+# Maximum bytes of code we send per unit. Trimmed to keep prompts tractable.
+MAX_CODE_BYTES = 1500
+
+
+# ---------------------------------------------------------------------------
+# Public dataclasses
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class ReachabilitySignal:
+    """A single LLM-emitted reachability signal for one unit.
+
+    ``kind`` is one of:
+    - ``entry_point`` — unit is itself a likely entry point.
+    - ``external_input`` — unit receives external/untrusted input.
+    - ``cross_process`` — unit participates in async / cross-process data flow.
+
+    ``confidence`` is one of ``high``, ``medium``, ``low``.
+    """
+
+    unit_id: str
+    kind: str
+    confidence: str
+    reason: str
+
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)
+
+
+# ---------------------------------------------------------------------------
+# Prompt construction
+# ---------------------------------------------------------------------------
+
+
+PROMPT_TEMPLATE = """You are a senior application-security engineer auditing
+a codebase for REACHABILITY signals — places where untrusted input can enter
+the system. A previous structural pass has already flagged some entry points
+and reachable units; your job is to surface ADDITIONAL signals it may have
+missed (framework-specific handlers, plugin/CLI registrations, message
+queues, async tasks, file/network ingestion, env/argv, IPC, etc.).
+
+Be conservative. Only emit a signal when the code clearly indicates one of:
+
+  - "entry_point"    — this unit is itself a likely entry point reachable
+                       by an external actor (HTTP/CLI/queue/stream handler,
+                       scheduled task, framework lifecycle hook, etc.).
+  - "external_input" — this unit reads or accepts data from an external
+                       source (request body, file, socket, env, argv, stdin,
+                       child-process output, untrusted message, etc.).
+  - "cross_process"  — this unit dispatches or receives data across async
+                       / process / queue boundaries (so taint may flow in
+                       or out via a path the static call-graph misses).
+
+Confidence levels:
+  - "high"   — the code unambiguously demonstrates the pattern.
+  - "medium" — the pattern is present but partially obscured.
+  - "low"    — only suggestive; emit only if you'd want a human reviewer.
+
+Return STRICT JSON of the form:
+
+  {{
+    "signals": [
+      {{"unit_id": "<unit_id>", "kind": "entry_point|external_input|cross_process",
+        "confidence": "high|medium|low", "reason": "<short reason>"}},
+      ...
+    ]
+  }}
+
+If no signals apply, return ``{{"signals": []}}``. Do NOT wrap the JSON in
+markdown fences. Do NOT include any prose outside the JSON.
+
+{app_context_block}
+
+UNITS TO REVIEW (existing structural flags shown for context — your job is to
+ADD signals beyond what those already capture):
+
+{units_block}
+"""
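+
+
+# Illustrative well-formed response (the unit id and reason are hypothetical —
+# they depend entirely on the repo being scanned):
+#
+#   {"signals": [{"unit_id": "app.py:login", "kind": "entry_point",
+#                 "confidence": "high", "reason": "Flask route handler"}]}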
+def _build_app_context_block(app_context: Optional[Dict[str, Any]]) -> str:
+    """Render an optional app-context section for the prompt."""
+    if not app_context:
+        return "APPLICATION CONTEXT: (none provided)"
+    try:
+        ctx_json = json.dumps(app_context, indent=2, sort_keys=True)
+    except (TypeError, ValueError):
+        ctx_json = str(app_context)
+    return f"APPLICATION CONTEXT:\n{ctx_json}"
+
+
+def _trim_code(code: str) -> str:
+    """Truncate a code blob so the batch fits in a reasonable prompt window."""
+    if not code:
+        return ""
+    if len(code) <= MAX_CODE_BYTES:
+        return code
+    return code[:MAX_CODE_BYTES] + "\n# ...[truncated]"
+
+
+def _unit_for_prompt(unit: Dict[str, Any]) -> Dict[str, Any]:
+    """Project a unit into the minimal shape we send to the LLM."""
+    code_blob = ""
+    code = unit.get("code") or {}
+    if isinstance(code, dict):
+        code_blob = code.get("primary_code") or code.get("source") or ""
+    elif isinstance(code, str):
+        code_blob = code
+
+    return {
+        "unit_id": unit.get("id", ""),
+        "unit_type": unit.get("unit_type", "function"),
+        "is_entry_point": bool(unit.get("is_entry_point", False)),
+        "reachable_from_entry": unit.get("reachable_from_entry"),
+        "code": _trim_code(code_blob),
+    }
+
+
+def build_prompt(
+    units: List[Dict[str, Any]],
+    app_context: Optional[Dict[str, Any]] = None,
+) -> str:
+    """Assemble the LLM prompt for a batch of units."""
+    app_block = _build_app_context_block(app_context)
+    payload = [_unit_for_prompt(u) for u in units]
+    units_block = json.dumps(payload, indent=2)
+    return PROMPT_TEMPLATE.format(
+        app_context_block=app_block,
+        units_block=units_block,
+    )
+
+
+# ---------------------------------------------------------------------------
+# Response parsing
+# ---------------------------------------------------------------------------
+
+
+_VALID_KINDS = {"entry_point", "external_input", "cross_process"}
+_VALID_CONFIDENCES = {"high", "medium", "low"}
+
+
+def _extract_json(text: str) -> Optional[Dict[str, Any]]:
+    """Best-effort JSON extraction from a model response.
+
+    Strips common markdown fences and falls back to the first ``{...}``
+    block in the text. Returns ``None`` if nothing valid is found.
+    """
+    if not text:
+        return None
+    cleaned = text.strip()
+
+    # Strip ```json ... ``` or ``` ... ``` fences.
+    fence = re.match(
+        r"^```(?:json)?\s*(?P<body>.*?)\s*```\s*$",
+        cleaned,
+        re.DOTALL | re.IGNORECASE,
+    )
+    if fence:
+        cleaned = fence.group("body").strip()
+
+    try:
+        return json.loads(cleaned)
+    except json.JSONDecodeError:
+        pass
+
+    # Fall back to the span from the first "{" to the last "}".
+    start = cleaned.find("{")
+    end = cleaned.rfind("}")
+    if start != -1 and end > start:
+        snippet = cleaned[start : end + 1]
+        try:
+            return json.loads(snippet)
+        except json.JSONDecodeError:
+            return None
+    return None
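+
+
+# Example of the fallback path (mirrored by the tests):
+#   _extract_json('Sure! Here you go:\n{"signals": []}\nEnd.')
+# returns {"signals": []} — leading/trailing prose is tolerated as long as
+# the outermost braces delimit the JSON object.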
+ """ + log = on_error or (lambda msg: print(f"[LLMReach] {msg}", file=sys.stderr)) + + data = _extract_json(response_text) + if not isinstance(data, dict): + log("malformed response: not a JSON object — skipping batch") + return [] + + raw_signals = data.get("signals") + if not isinstance(raw_signals, list): + log("malformed response: 'signals' missing or not a list — skipping batch") + return [] + + out: List[ReachabilitySignal] = [] + for idx, item in enumerate(raw_signals): + if not isinstance(item, dict): + log(f"signal #{idx}: not an object — skipped") + continue + unit_id = item.get("unit_id") + kind = item.get("kind") + confidence = item.get("confidence") + reason = item.get("reason", "") + + if not isinstance(unit_id, str) or not unit_id: + log(f"signal #{idx}: missing unit_id — skipped") + continue + if kind not in _VALID_KINDS: + log(f"signal #{idx}: invalid kind {kind!r} — skipped") + continue + if confidence not in _VALID_CONFIDENCES: + log(f"signal #{idx}: invalid confidence {confidence!r} — skipped") + continue + if valid_unit_ids is not None and unit_id not in valid_unit_ids: + log(f"signal #{idx}: unknown unit_id {unit_id!r} — skipped") + continue + + out.append( + ReachabilitySignal( + unit_id=unit_id, + kind=kind, + confidence=confidence, + reason=str(reason)[:500], + ) + ) + return out + + +# --------------------------------------------------------------------------- +# Main entry points +# --------------------------------------------------------------------------- + + +def _chunk(items: List[Any], size: int) -> List[List[Any]]: + """Split ``items`` into batches of ``size``. + + A non-positive ``size`` is treated as "everything in one batch" so callers + that disable batching never hit a NameError or empty-output surprise. + """ + if size <= 0: + return [list(items)] if items else [] + return [items[i : i + size] for i in range(0, len(items), size)] + + +def analyze_reachability( + dataset: Dict[str, Any], + app_context: Optional[Dict[str, Any]] = None, + client: Any = None, + model: str = MODEL_PRIMARY, + batch_size: int = DEFAULT_BATCH_SIZE, + max_units: Optional[int] = None, + on_error: Optional[Callable[[str], None]] = None, +) -> List[ReachabilitySignal]: + """Run the LLM reachability review stage over a parsed dataset. + + Args: + dataset: Parsed dataset with a ``units`` list, as produced by the + parser stage. Units are expected to expose ``id``, ``code``, and + optionally ``is_entry_point`` / ``reachable_from_entry``. + app_context: Optional application context dict; included in the + prompt to help the model reason about expected entry points + (e.g. ``{"application_type": "web_app"}``). + client: An object exposing ``analyze_sync(prompt, max_tokens=..., + model=...)``. If omitted, an :class:`AnthropicClient` is + instantiated lazily. + model: Model id to use (defaults to Opus). + batch_size: Units per LLM call. + max_units: Optional cap on how many units to review. + on_error: Optional callback for parse/validation issues. + + Returns: + A flat list of :class:`ReachabilitySignal` for every unit the model + flagged. Unknown unit ids and malformed entries are filtered out. + """ + units = dataset.get("units") or [] + if max_units is not None and max_units >= 0: + units = units[:max_units] + if not units: + return [] + + if client is None: + # Lazy import so unit tests can stub this out without an API key. 
+        from utilities.llm_client import AnthropicClient
+
+        client = AnthropicClient(model=model)
+
+    valid_ids = {u.get("id") for u in units if u.get("id")}
+
+    signals: List[ReachabilitySignal] = []
+    batches = _chunk(units, batch_size)
+    for i, batch in enumerate(batches):
+        prompt = build_prompt(batch, app_context=app_context)
+        try:
+            text = client.analyze_sync(prompt, max_tokens=4096, model=model)
+        except Exception as exc:  # noqa: BLE001 — advisory stage; never crash pipeline
+            msg = f"batch {i + 1}/{len(batches)} failed: {exc}"
+            if on_error:
+                on_error(msg)
+            else:
+                print(f"[LLMReach] {msg}", file=sys.stderr)
+            continue
+
+        parsed = parse_response(
+            text, valid_unit_ids=valid_ids, on_error=on_error
+        )
+        signals.extend(parsed)
+
+    return signals
+
+
+# ---------------------------------------------------------------------------
+# Signal application (promote-only)
+# ---------------------------------------------------------------------------
+
+
+# Confidences at or above this threshold promote ``entry_point`` signals to
+# ``is_entry_point = True`` on the target unit.
+_PROMOTE_ENTRY_POINT_AT = {"high"}
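+# Medium/low ``entry_point`` signals (and all ``external_input`` /
+# ``cross_process`` signals) are still attached to the unit for human review;
+# they just never flip the flag.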
+
+
+def apply_signals(
+    dataset: Dict[str, Any],
+    signals: List[ReachabilitySignal],
+) -> Dict[str, int]:
+    """Merge LLM signals back into ``dataset`` (in place, promote-only).
+
+    For each unit referenced by a signal:
+    - The signal is appended to a per-unit ``llm_reachability_signals`` list.
+    - If the signal kind is ``entry_point`` AND its confidence is in
+      :data:`_PROMOTE_ENTRY_POINT_AT`, the unit's ``is_entry_point`` field
+      is set to ``True`` (never set back to ``False``).
+
+    Crucially, this never DEMOTES a unit. ``is_entry_point=True`` set by the
+    structural pass remains true regardless of what the LLM said.
+
+    Returns a small summary dict::
+
+        {
+            "signals_applied": <int>,
+            "entry_points_promoted": <int>,
+            "units_touched": <int>,
+        }
+    """
+    units = dataset.get("units") or []
+    by_id = {u.get("id"): u for u in units if u.get("id")}
+
+    promoted = 0
+    touched: set = set()
+    applied = 0
+
+    for sig in signals:
+        unit = by_id.get(sig.unit_id)
+        if unit is None:
+            continue
+
+        existing = unit.setdefault("llm_reachability_signals", [])
+        existing.append(sig.to_dict())
+        applied += 1
+        touched.add(sig.unit_id)
+
+        if (
+            sig.kind == "entry_point"
+            and sig.confidence in _PROMOTE_ENTRY_POINT_AT
+            and not unit.get("is_entry_point", False)
+        ):
+            unit["is_entry_point"] = True
+            promoted += 1
+
+    return {
+        "signals_applied": applied,
+        "entry_points_promoted": promoted,
+        "units_touched": len(touched),
+    }
+
+
+def signals_to_json(signals: List[ReachabilitySignal]) -> List[Dict[str, Any]]:
+    """Serialize a list of signals for JSON persistence."""
+    return [s.to_dict() for s in signals]
diff --git a/libs/openant-core/core/scanner.py b/libs/openant-core/core/scanner.py
index f081352..f894bc6 100644
--- a/libs/openant-core/core/scanner.py
+++ b/libs/openant-core/core/scanner.py
@@ -59,6 +59,7 @@ def scan_repository(
     repo_url: str | None = None,
     commit_sha: str | None = None,
     diff_manifest: str | None = None,
+    llm_reachability: bool = False,
 ) -> ScanResult:
     """Scan a repository for vulnerabilities.
 
@@ -106,6 +107,7 @@ def scan_repository(
     # Count total steps for progress display
     total_steps = _count_steps(
         generate_context, enhance, verify, generate_report, dynamic_test,
+        llm_reachability=llm_reachability,
     )
     step_num = 0
@@ -174,7 +176,7 @@ def _step_label(name: str) -> str:
     # ---------------------------------------------------------------
     # Step 2: Application Context (optional)
     # ---------------------------------------------------------------
-    app_context_path = None
+    app_context_path: str | None = None
 
     if generate_context and HAS_APP_CONTEXT:
         print(_step_label("Generating application context..."), file=sys.stderr)
@@ -205,6 +207,89 @@ def _step_label(name: str) -> str:
         result.skipped_steps.append("app-context")
         print(file=sys.stderr)
 
+    # ---------------------------------------------------------------
+    # Step 2.5: LLM Reachability review (optional, opt-in)
+    # ---------------------------------------------------------------
+    # Runs after parse + app-context and before enhance/analyze. Signals are
+    # advisory and PROMOTE-ONLY: they may flag additional entry points or
+    # external-input sites the structural pass missed, but never demote a
+    # unit that structural analysis already kept. Threading app_context into
+    # the LLM prompt helps the model reason about expected entry points
+    # (e.g. "this is a web_app, look for HTTP handlers").
+    if llm_reachability:
+        from core.llm_reachability import (
+            analyze_reachability,
+            apply_signals,
+            signals_to_json,
+        )
+
+        print(_step_label("Running LLM reachability review..."), file=sys.stderr)
+
+        with step_context("llm-reachability", output_dir, inputs={
+            "dataset_path": active_dataset_path,
+            "model": "opus",
+        }) as ctx:
+            try:
+                with open(active_dataset_path, encoding="utf-8") as f:
+                    dataset = json.load(f)
+            except (OSError, json.JSONDecodeError) as exc:
+                print(f"  WARNING: failed to load dataset: {exc}", file=sys.stderr)
+                ctx.summary = {"skipped": True, "reason": str(exc)}
+                dataset = None
+
+            if dataset is not None:
+                app_ctx_payload = None
+                if app_context_path and os.path.exists(app_context_path):
+                    try:
+                        with open(app_context_path, encoding="utf-8") as f:
+                            app_ctx_payload = json.load(f)
+                    except (OSError, json.JSONDecodeError):
+                        app_ctx_payload = None
+
+                signals = analyze_reachability(
+                    dataset=dataset,
+                    app_context=app_ctx_payload,
+                    max_units=limit,
+                )
+                summary = apply_signals(dataset, signals)
+
+                # Persist mutated dataset (so downstream stages see the
+                # promoted entry points and the per-unit signals).
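+                # The write replaces the parse output at active_dataset_path
+                # in place; llm_reachability.json (below) keeps a separate
+                # audit copy of the raw signals.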
+                with open(active_dataset_path, "w", encoding="utf-8") as f:
+                    json.dump(dataset, f, indent=2)
+
+                signals_path = os.path.join(output_dir, "llm_reachability.json")
+                with open(signals_path, "w", encoding="utf-8") as f:
+                    json.dump(
+                        {"signals": signals_to_json(signals)},
+                        f,
+                        indent=2,
+                    )
+
+                ctx.summary = {
+                    "units_reviewed": len(dataset.get("units", [])),
+                    "signals_added": summary["signals_applied"],
+                    "entry_points_promoted": summary["entry_points_promoted"],
+                    "units_touched": summary["units_touched"],
+                }
+                ctx.outputs = {"signals_path": signals_path}
+
+                print(
+                    f"  LLM reachability: {summary['signals_applied']} signals, "
+                    f"{summary['entry_points_promoted']} new entry points",
+                    file=sys.stderr,
+                )
+
+                collected_step_reports.append(
+                    _load_step_report(output_dir, "llm-reachability")
+                )
+            else:
+                result.skipped_steps.append("llm-reachability")
+        print(file=sys.stderr)
+
     # ---------------------------------------------------------------
     # Step 3: Enhance (optional)
     # ---------------------------------------------------------------
@@ -522,6 +604,7 @@ def _count_steps(
     verify: bool,
     generate_report: bool,
     dynamic_test: bool,
+    llm_reachability: bool = False,
 ) -> int:
     """Count total steps for progress display (always includes parse, detect, build-output)."""
     count = 3  # parse + detect + build-output (always run)
@@ -535,6 +618,8 @@ def _count_steps(
         count += 1
     if dynamic_test:
         count += 1
+    if llm_reachability:
+        count += 1
     return count
diff --git a/libs/openant-core/openant/cli.py b/libs/openant-core/openant/cli.py
index b0ce345..99a6199 100644
--- a/libs/openant-core/openant/cli.py
+++ b/libs/openant-core/openant/cli.py
@@ -74,6 +74,7 @@ def cmd_scan(args):
         repo_url=getattr(args, "repo_url", None),
         commit_sha=getattr(args, "commit_sha", None),
         diff_manifest=getattr(args, "diff_manifest", None),
+        llm_reachability=getattr(args, "llm_reachability", False),
     )
 
     scan_payload = result.to_dict()
@@ -994,6 +995,17 @@ def main():
     scan_p.add_argument("--backoff", type=int, default=30,
                         help="Seconds to wait when rate-limited (default: 30)")
     scan_p.add_argument("--diff-manifest", help="Path to diff_manifest.json for incremental scanning")
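+    # Keep this help text in sync with the equivalent Go flag registered in
+    # apps/openant-cli/cmd/scan.go (registerScanFlags).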
+    scan_p.add_argument(
+        "--llm-reachability",
+        action="store_true",
+        dest="llm_reachability",
+        help="Enable the LLM reachability review stage (Opus). "
+             "Surfaces additional entry points and external-input sites "
+             "beyond the structural pass. Off by default — enabling this "
+             "may incur additional LLM cost (one Opus call per ~25 units).",
+    )
     scan_p.set_defaults(func=cmd_scan)
 
     # ---------------------------------------------------------------
diff --git a/libs/openant-core/tests/test_go_cli.py b/libs/openant-core/tests/test_go_cli.py
index fc92113..b7d0f3f 100644
--- a/libs/openant-core/tests/test_go_cli.py
+++ b/libs/openant-core/tests/test_go_cli.py
@@ -79,6 +79,14 @@ def test_scan_help(self):
         output = result.stdout + result.stderr
         assert "pipeline" in output.lower()
 
+    def test_scan_help_advertises_llm_reachability(self):
+        """The opt-in --llm-reachability flag (issue #17) should be discoverable
+        from `openant scan --help`."""
+        result = run_cli("scan", "--help")
+        assert result.returncode == 0
+        output = result.stdout + result.stderr
+        assert "llm-reachability" in output.lower()
+
 
 class TestParse:
     def test_parse_python_repo(self, sample_python_repo, tmp_path):
diff --git a/libs/openant-core/tests/test_llm_reachability.py b/libs/openant-core/tests/test_llm_reachability.py
new file mode 100644
index 0000000..627d084
--- /dev/null
+++ b/libs/openant-core/tests/test_llm_reachability.py
@@ -0,0 +1,481 @@
+"""Tests for the LLM reachability review stage (issue #17).
+
+The stage is opt-in and advisory: signals may PROMOTE a unit's
+reachability but never demote one that the structural analysis kept.
+These tests pin that behavior down with a fully mocked LLM client so they
+run without network access or an API key.
+"""
+
+from __future__ import annotations
+
+import json
+from typing import List
+
+import pytest
+
+from core.llm_reachability import (
+    ReachabilitySignal,
+    analyze_reachability,
+    apply_signals,
+    build_prompt,
+    parse_response,
+    signals_to_json,
+)
+
+
+# ---------------------------------------------------------------------------
+# Test helpers
+# ---------------------------------------------------------------------------
+
+
+class FakeClient:
+    """Minimal stand-in for AnthropicClient.
+
+    Records calls and replays a fixed sequence of canned responses.
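+    Once the canned responses are exhausted, ``analyze_sync`` falls back to
+    an empty ``{"signals": []}`` payload.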
+ """ + + def __init__(self, responses: List[str]): + self._responses = list(responses) + self.calls: List[dict] = [] + + def analyze_sync(self, prompt: str, max_tokens: int = 4096, model: str = ""): + self.calls.append( + {"prompt": prompt, "max_tokens": max_tokens, "model": model} + ) + if not self._responses: + return '{"signals": []}' + return self._responses.pop(0) + + +def _make_unit(unit_id: str, code: str = "pass", **kw) -> dict: + unit = { + "id": unit_id, + "unit_type": kw.pop("unit_type", "function"), + "code": {"primary_code": code}, + } + unit.update(kw) + return unit + + +# --------------------------------------------------------------------------- +# parse_response +# --------------------------------------------------------------------------- + + +class TestParseResponse: + def test_parses_well_formed_signal(self): + text = json.dumps( + { + "signals": [ + { + "unit_id": "app.py:handler", + "kind": "entry_point", + "confidence": "high", + "reason": "Express handler", + } + ] + } + ) + sigs = parse_response(text, valid_unit_ids={"app.py:handler"}) + assert len(sigs) == 1 + assert sigs[0].unit_id == "app.py:handler" + assert sigs[0].kind == "entry_point" + assert sigs[0].confidence == "high" + assert "Express" in sigs[0].reason + + def test_strips_markdown_fences(self): + text = "```json\n" + json.dumps( + {"signals": [ + {"unit_id": "x.py:f", "kind": "external_input", + "confidence": "medium", "reason": "reads argv"}]} + ) + "\n```" + sigs = parse_response(text, valid_unit_ids={"x.py:f"}) + assert len(sigs) == 1 + assert sigs[0].kind == "external_input" + + def test_falls_back_to_first_object(self): + text = "Sure! Here you go:\n" + json.dumps( + {"signals": [ + {"unit_id": "a.py:g", "kind": "cross_process", + "confidence": "low", "reason": "queue"}]} + ) + "\nEnd." 
+    def test_falls_back_to_first_object(self):
+        text = "Sure! Here you go:\n" + json.dumps(
+            {"signals": [
+                {"unit_id": "a.py:g", "kind": "cross_process",
+                 "confidence": "low", "reason": "queue"}]}
+        ) + "\nEnd."
+        sigs = parse_response(text, valid_unit_ids={"a.py:g"})
+        assert len(sigs) == 1
+
+    def test_malformed_json_returns_empty(self):
+        errors: List[str] = []
+        sigs = parse_response(
+            "not json at all",
+            valid_unit_ids={"x"},
+            on_error=errors.append,
+        )
+        assert sigs == []
+        assert any("malformed" in e for e in errors)
+
+    def test_invalid_kind_skipped(self):
+        text = json.dumps(
+            {"signals": [
+                {"unit_id": "x.py:f", "kind": "garbage",
+                 "confidence": "high", "reason": "n/a"}]}
+        )
+        errors: List[str] = []
+        sigs = parse_response(
+            text, valid_unit_ids={"x.py:f"}, on_error=errors.append
+        )
+        assert sigs == []
+        assert any("invalid kind" in e for e in errors)
+
+    def test_unknown_unit_id_skipped(self):
+        text = json.dumps(
+            {"signals": [
+                {"unit_id": "ghost.py:f", "kind": "entry_point",
+                 "confidence": "high", "reason": "n/a"}]}
+        )
+        errors: List[str] = []
+        sigs = parse_response(
+            text, valid_unit_ids={"real.py:f"}, on_error=errors.append
+        )
+        assert sigs == []
+
+    def test_signals_not_a_list_returns_empty(self):
+        text = json.dumps({"signals": "nope"})
+        errors: List[str] = []
+        sigs = parse_response(text, on_error=errors.append)
+        assert sigs == []
+
+
+# ---------------------------------------------------------------------------
+# build_prompt / app_context threading
+# ---------------------------------------------------------------------------
+
+
+class TestBuildPrompt:
+    def test_includes_unit_ids_and_code(self):
+        units = [_make_unit("app.py:handler", code="def handler(): ...")]
+        prompt = build_prompt(units)
+        assert "app.py:handler" in prompt
+        assert "def handler()" in prompt
+
+    def test_no_app_context_marker(self):
+        prompt = build_prompt([_make_unit("a:f")])
+        assert "(none provided)" in prompt
+
+    def test_includes_app_context_when_provided(self):
+        ctx = {"application_type": "web_app", "framework": "Express"}
+        prompt = build_prompt([_make_unit("a:f")], app_context=ctx)
+        assert "web_app" in prompt
+        assert "Express" in prompt
+
+    def test_truncates_overly_long_code(self):
+        big = "x = 1\n" * 5000
+        prompt = build_prompt([_make_unit("a:f", code=big)])
+        assert "[truncated]" in prompt
+
+
+# ---------------------------------------------------------------------------
+# analyze_reachability — full call with a mocked client
+# ---------------------------------------------------------------------------
+
+
+class TestAnalyzeReachability:
+    def test_parses_signals_from_mocked_llm(self):
+        dataset = {
+            "units": [
+                _make_unit("app.py:handler"),
+                _make_unit("util.py:helper"),
+            ]
+        }
+        canned = json.dumps(
+            {
+                "signals": [
+                    {
+                        "unit_id": "app.py:handler",
+                        "kind": "entry_point",
+                        "confidence": "high",
+                        "reason": "Express handler",
+                    },
+                    {
+                        "unit_id": "util.py:helper",
+                        "kind": "external_input",
+                        "confidence": "medium",
+                        "reason": "reads file",
+                    },
+                ]
+            }
+        )
+        client = FakeClient([canned])
+        signals = analyze_reachability(dataset, client=client)
+        assert len(signals) == 2
+        assert {s.kind for s in signals} == {"entry_point", "external_input"}
+        assert len(client.calls) == 1
+
+    def test_app_context_threaded_into_prompt(self):
+        dataset = {"units": [_make_unit("a:f")]}
+        client = FakeClient(['{"signals": []}'])
+        ctx = {"application_type": "web_app", "framework": "Flask"}
+        analyze_reachability(dataset, app_context=ctx, client=client)
+        assert "Flask" in client.calls[0]["prompt"]
+        assert "web_app" in client.calls[0]["prompt"]
+    def test_malformed_response_handled_gracefully(self):
+        dataset = {"units": [_make_unit("a:f")]}
+        errors: List[str] = []
+        client = FakeClient(["this is not JSON"])
+        sigs = analyze_reachability(
+            dataset, client=client, on_error=errors.append
+        )
+        assert sigs == []
+        assert errors  # at least one error logged
+
+    def test_empty_dataset_returns_empty(self):
+        client = FakeClient([])
+        sigs = analyze_reachability({"units": []}, client=client)
+        assert sigs == []
+        assert client.calls == []  # no LLM calls when nothing to review
+
+    def test_batch_size_chunks_units(self):
+        dataset = {"units": [_make_unit(f"a:{i}") for i in range(7)]}
+        client = FakeClient(['{"signals": []}'] * 5)
+        analyze_reachability(dataset, client=client, batch_size=3)
+        # 7 units / 3 per batch = 3 calls
+        assert len(client.calls) == 3
+
+    def test_non_positive_batch_size_uses_single_batch(self):
+        """``batch_size <= 0`` historically tripped a NameError. Guard the
+        contract: non-positive size collapses to a single batch covering all
+        units (and never raises)."""
+        dataset = {"units": [_make_unit(f"a:{i}") for i in range(4)]}
+        client = FakeClient(['{"signals": []}'])
+        analyze_reachability(dataset, client=client, batch_size=0)
+        assert len(client.calls) == 1
+
+    def test_client_exception_does_not_crash(self):
+        class Boom:
+            def analyze_sync(self, *a, **kw):
+                raise RuntimeError("api boom")
+
+        errors: List[str] = []
+        sigs = analyze_reachability(
+            {"units": [_make_unit("a:f")]},
+            client=Boom(),
+            on_error=errors.append,
+        )
+        assert sigs == []
+        assert any("api boom" in e for e in errors)
+
+
+# ---------------------------------------------------------------------------
+# apply_signals — promote-only semantics
+# ---------------------------------------------------------------------------
+
+
+class TestApplySignals:
+    def test_high_confidence_entry_point_promotes(self):
+        dataset = {"units": [_make_unit("a:f", is_entry_point=False)]}
+        sigs = [
+            ReachabilitySignal("a:f", "entry_point", "high", "framework hook")
+        ]
+        summary = apply_signals(dataset, sigs)
+        assert dataset["units"][0]["is_entry_point"] is True
+        assert summary["entry_points_promoted"] == 1
+        assert summary["signals_applied"] == 1
+        assert summary["units_touched"] == 1
+
+    def test_medium_confidence_does_not_promote(self):
+        dataset = {"units": [_make_unit("a:f", is_entry_point=False)]}
+        sigs = [
+            ReachabilitySignal("a:f", "entry_point", "medium", "maybe")
+        ]
+        summary = apply_signals(dataset, sigs)
+        assert dataset["units"][0]["is_entry_point"] is False
+        assert summary["entry_points_promoted"] == 0
+        # but the signal is still attached for the reviewer
+        assert summary["signals_applied"] == 1
+
+    def test_external_input_does_not_set_entry_point(self):
+        dataset = {"units": [_make_unit("a:f", is_entry_point=False)]}
+        sigs = [
+            ReachabilitySignal("a:f", "external_input", "high", "argv")
+        ]
+        apply_signals(dataset, sigs)
+        # external_input never sets is_entry_point regardless of confidence
+        assert dataset["units"][0]["is_entry_point"] is False
+    def test_does_not_demote_existing_entry_point(self):
+        """Crucial promote-only invariant: a unit the structural pass
+        already marked as an entry point must never be unmarked, even if
+        the LLM emits no signal (or a low-confidence one) for it."""
+        dataset = {"units": [_make_unit("a:f", is_entry_point=True)]}
+        # Empty signal list — apply_signals must not flip the flag.
+        apply_signals(dataset, [])
+        assert dataset["units"][0]["is_entry_point"] is True
+
+        # Even a stray "low" entry_point signal must not flip it back.
+        sigs = [ReachabilitySignal("a:f", "entry_point", "low", "weak")]
+        apply_signals(dataset, sigs)
+        assert dataset["units"][0]["is_entry_point"] is True
+
+    def test_signal_attached_to_unit(self):
+        dataset = {"units": [_make_unit("a:f")]}
+        sigs = [
+            ReachabilitySignal("a:f", "external_input", "medium", "reads stdin")
+        ]
+        apply_signals(dataset, sigs)
+        unit = dataset["units"][0]
+        assert "llm_reachability_signals" in unit
+        assert len(unit["llm_reachability_signals"]) == 1
+        attached = unit["llm_reachability_signals"][0]
+        assert attached["kind"] == "external_input"
+        assert attached["reason"] == "reads stdin"
+
+    def test_multiple_signals_accumulate_on_same_unit(self):
+        dataset = {"units": [_make_unit("a:f")]}
+        sigs = [
+            ReachabilitySignal("a:f", "external_input", "medium", "argv"),
+            ReachabilitySignal("a:f", "cross_process", "low", "queue"),
+        ]
+        apply_signals(dataset, sigs)
+        attached = dataset["units"][0]["llm_reachability_signals"]
+        assert len(attached) == 2
+
+    def test_unknown_unit_id_skipped(self):
+        dataset = {"units": [_make_unit("a:f")]}
+        sigs = [ReachabilitySignal("ghost:x", "entry_point", "high", "n/a")]
+        summary = apply_signals(dataset, sigs)
+        assert summary["signals_applied"] == 0
+        assert summary["entry_points_promoted"] == 0
+
+
+class TestSerialization:
+    def test_signals_to_json_roundtrip(self):
+        sigs = [
+            ReachabilitySignal("a:f", "entry_point", "high", "r1"),
+            ReachabilitySignal("b:g", "external_input", "low", "r2"),
+        ]
+        out = signals_to_json(sigs)
+        assert isinstance(out, list)
+        assert all(isinstance(item, dict) for item in out)
+        # Round-trips through JSON cleanly.
+        json.loads(json.dumps(out))
+
+
+# ---------------------------------------------------------------------------
+# CLI flag plumbing — mock scan_repository to confirm wiring without API
+# ---------------------------------------------------------------------------
+
+
+class TestCliPlumbing:
+    """Confirms that the --llm-reachability flag exists in scan --help and
+    that, by default (no flag), the LLM reachability path is not invoked.
+
+    These tests exercise the Python CLI directly (no Go binary required), so
+    they always run in the basic pytest suite.
+    """
+
+    def test_flag_appears_in_scan_help(self, capsys):
+        from openant.cli import main
+
+        with pytest.raises(SystemExit):
+            import sys
+            old = sys.argv
+            try:
+                sys.argv = ["openant", "scan", "--help"]
+                main()
+            finally:
+                sys.argv = old
+        captured = capsys.readouterr()
+        out = captured.out + captured.err
+        assert "--llm-reachability" in out
+
+    def test_default_does_not_invoke_llm_reachability(self, monkeypatch, tmp_path):
+        """When --llm-reachability is NOT passed, ``analyze_reachability`` in
+        the scanner module must not be called.
+
+        We achieve this by monkey-patching ``scan_repository`` to a stub
+        that records its kwargs, then driving ``cmd_scan`` through it.
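+        (Patching ``scan_repository`` itself means the real pipeline never
+        runs; the assertion below is purely about kwarg plumbing.)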
+ """ + captured = {} + + from openant import cli as cli_mod + + def fake_scan(**kwargs): + captured.update(kwargs) + from core.schemas import ScanResult + r = ScanResult(output_dir=str(tmp_path)) + return r + + monkeypatch.setattr( + "core.scanner.scan_repository", fake_scan, raising=True + ) + + # Drive cmd_scan via argparse + import argparse + ns = argparse.Namespace( + repo=str(tmp_path), + output=str(tmp_path / "out"), + language="auto", + level="reachable", + verify=False, + no_context=True, + no_enhance=True, + enhance_mode="agentic", + no_report=True, + dynamic_test=False, + no_skip_tests=False, + limit=None, + model="opus", + workers=1, + repo_name=None, + repo_url=None, + commit_sha=None, + backoff=30, + diff_manifest=None, + llm_reachability=False, + ) + rc = cli_mod.cmd_scan(ns) + # rc 0 or 1 acceptable; we only care about plumbing. + assert rc in (0, 1) + assert captured.get("llm_reachability") is False + + def test_flag_passes_through_when_set(self, monkeypatch, tmp_path): + captured = {} + from openant import cli as cli_mod + + def fake_scan(**kwargs): + captured.update(kwargs) + from core.schemas import ScanResult + return ScanResult(output_dir=str(tmp_path)) + + monkeypatch.setattr( + "core.scanner.scan_repository", fake_scan, raising=True + ) + + import argparse + ns = argparse.Namespace( + repo=str(tmp_path), + output=str(tmp_path / "out"), + language="auto", + level="reachable", + verify=False, + no_context=True, + no_enhance=True, + enhance_mode="agentic", + no_report=True, + dynamic_test=False, + no_skip_tests=False, + limit=None, + model="opus", + workers=1, + repo_name=None, + repo_url=None, + commit_sha=None, + backoff=30, + diff_manifest=None, + llm_reachability=True, + ) + cli_mod.cmd_scan(ns) + assert captured.get("llm_reachability") is True