diff --git a/.AUDIT/copilot-local-review.sh b/.AUDIT/copilot-local-review.sh new file mode 100755 index 000000000..46613fbac --- /dev/null +++ b/.AUDIT/copilot-local-review.sh @@ -0,0 +1,252 @@ +#!/usr/bin/env bash +set -uo pipefail +# Owner: Codex +# +# Private local gate: run GitHub Copilot CLI against the staged diff before a +# local commit. This is an early-warning review, not a replacement for the +# origin/upstream PR review gate. + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +cd "$ROOT" || exit 2 + +if [[ -f "$ROOT/.venv/bin/activate" ]]; then + # shellcheck disable=SC1091 + source "$ROOT/.venv/bin/activate" +fi + +if [[ "${AUDIT_PRIVATE_GUARD:-1}" != "0" && -x "$SCRIPT_DIR/private-guard.sh" ]]; then + "$SCRIPT_DIR/private-guard.sh" --quiet || exit "$?" +fi + +usage() { + cat <<'EOF' +Usage: .AUDIT/copilot-local-review.sh [--cached|--worktree|--base ] [--advisory] [--max-diff-bytes ] + +Runs GitHub Copilot CLI against a local diff and blocks when Copilot reports +actionable findings. + +Modes: + --cached review staged changes; default and intended for pre-commit + --worktree review unstaged worktree changes + --base review changes from merge-base(, HEAD) to HEAD + --advisory always exit 0 after saving/reporting review output + +Environment: + AUDIT_COPILOT_MAX_DIFF_BYTES default 120000; local review blocks above this + AUDIT_COPILOT_REPORT_DIR default .AUDIT/reports + +Output contract: + Copilot must emit exactly one decision line: + LOCAL_COPILOT_REVIEW_DECISION: PASS + LOCAL_COPILOT_REVIEW_DECISION: BLOCK + +PASS means no actionable correctness/security/regression/test issue was found. +BLOCK or ambiguous output stops the commit gate. +EOF +} + +mode="cached" +base_ref="" +advisory=0 +max_diff_bytes="${AUDIT_COPILOT_MAX_DIFF_BYTES:-2000000}" + +while [[ $# -gt 0 ]]; do + case "$1" in + --cached) + mode="cached" + ;; + --worktree) + mode="worktree" + ;; + --base) + if [[ $# -lt 2 ]]; then + echo "[copilot-local-review] USAGE_ERROR: --base requires a ref" >&2 + exit 2 + fi + mode="base" + base_ref="$2" + shift + ;; + --base=*) + mode="base" + base_ref="${1#--base=}" + ;; + --advisory) + advisory=1 + ;; + --max-diff-bytes) + if [[ $# -lt 2 ]]; then + echo "[copilot-local-review] USAGE_ERROR: --max-diff-bytes requires a number" >&2 + exit 2 + fi + max_diff_bytes="$2" + shift + ;; + --max-diff-bytes=*) + max_diff_bytes="${1#--max-diff-bytes=}" + ;; + --help|-h) + usage + exit 0 + ;; + *) + echo "[copilot-local-review] USAGE_ERROR: unknown option '$1'" >&2 + usage >&2 + exit 2 + ;; + esac + shift +done + +case "$max_diff_bytes" in + ''|*[!0-9]*) + echo "[copilot-local-review] USAGE_ERROR: max diff bytes must be a non-negative integer" >&2 + exit 2 + ;; +esac + +if ! command -v copilot >/dev/null 2>&1; then + echo "[copilot-local-review] BLOCKED: GitHub Copilot CLI is not installed or not on PATH" >&2 + echo "[copilot-local-review] Install/authenticate Copilot CLI or set AUDIT_SKIP_LOCAL_COPILOT=1 for an explicit bypass." >&2 + exit 1 +fi + +diff_file="$(mktemp)" +stat_file="$(mktemp)" +trap 'rm -f "$diff_file" "$stat_file"' EXIT + +if [[ "$mode" == "cached" ]]; then + git diff --cached --stat >"$stat_file" + git diff --cached --no-ext-diff --binary --unified=80 >"$diff_file" +elif [[ "$mode" == "worktree" ]]; then + git diff --stat >"$stat_file" + git diff --no-ext-diff --binary --unified=80 >"$diff_file" +else + merge_base="$(git merge-base HEAD "$base_ref")" || { + echo "[copilot-local-review] GIT_CONTEXT_ERROR: could not merge-base HEAD and $base_ref" >&2 + exit 2 + } + git diff --stat "$merge_base..HEAD" >"$stat_file" + git diff --no-ext-diff --binary --unified=80 "$merge_base..HEAD" >"$diff_file" +fi + +if [[ ! -s "$diff_file" ]]; then + echo "[copilot-local-review] clean: no diff to review for mode=$mode" + exit 0 +fi + +if grep -Eq '^(Binary files |GIT binary patch)' "$diff_file"; then + echo "[copilot-local-review] BLOCKED: binary diff present; Copilot local text review cannot inspect it reliably" >&2 + exit 1 +fi + +diff_bytes="$(wc -c <"$diff_file" | tr -d '[:space:]')" +if (( max_diff_bytes > 0 && diff_bytes > max_diff_bytes )); then + echo "[copilot-local-review] BLOCKED: diff is ${diff_bytes} bytes, above local review limit ${max_diff_bytes}" >&2 + echo "[copilot-local-review] Split the commit or use origin PR review as the authoritative review surface." >&2 + exit 1 +fi + +report_dir="${AUDIT_COPILOT_REPORT_DIR:-$SCRIPT_DIR/reports}" +mkdir -p "$report_dir" +report_file="$report_dir/$(date +%Y%m%d-%H%M%S)-copilot-local-review.md" + +prompt_file="$(mktemp)" +trap 'rm -f "$diff_file" "$stat_file" "$prompt_file"' EXIT + +{ + cat <<'EOF' +You are GitHub Copilot reviewing a local staged diff before commit. + +Review only the supplied diff. Do not edit files. Do not run tools. Do not ask +questions. Focus on correctness, security, data loss, regression risk, broken +tests, missing tests for changed behavior, and user-visible behavior. Ignore +pure style unless it can create functional risk. + +Your response MUST include exactly one decision line: + +LOCAL_COPILOT_REVIEW_DECISION: PASS + +or: + +LOCAL_COPILOT_REVIEW_DECISION: BLOCK + +Use PASS only if you find no actionable issue. Use BLOCK if there is any +actionable issue or if the diff is too incomplete to review safely. + +After the decision line, provide concise findings with file/path references +when blocking. If passing, provide a brief explanation of the risk areas you +checked. + +Diff stat: +EOF + cat "$stat_file" + printf '\nDiff:\n```diff\n' + cat "$diff_file" + printf '\n```\n' +} >"$prompt_file" + +echo "[copilot-local-review] invoking Copilot CLI mode=$mode diff_bytes=$diff_bytes" +set +e +copilot_output="$( + copilot \ + -p "$(cat "$prompt_file")" \ + --disable-builtin-mcps \ + --disallow-temp-dir \ + --no-color \ + --output-format text 2>&1 +)" +copilot_rc=$? +set -e + +{ + echo "# Local Copilot Review" + echo + echo "- Mode: \`$mode\`" + [[ -n "$base_ref" ]] && echo "- Base: \`$base_ref\`" + echo "- Diff bytes: \`$diff_bytes\`" + echo "- Copilot exit code: \`$copilot_rc\`" + echo "- Started: \`$(date -u +%Y-%m-%dT%H:%M:%SZ)\`" + echo + echo "## Diff Stat" + echo + echo '```text' + cat "$stat_file" + echo '```' + echo + echo "## Copilot Output" + echo + echo '```text' + printf '%s\n' "$copilot_output" + echo '```' +} >"$report_file" + +printf '%s\n' "$copilot_output" +echo "[copilot-local-review] report=$report_file" + +if (( advisory == 1 )); then + echo "[copilot-local-review] advisory mode: not blocking" + exit 0 +fi + +if (( copilot_rc != 0 )); then + echo "[copilot-local-review] BLOCKED: Copilot CLI exited $copilot_rc" >&2 + exit 1 +fi + +pass_count="$(grep -c '^LOCAL_COPILOT_REVIEW_DECISION: PASS$' "$report_file" || true)" +block_count="$(grep -c '^LOCAL_COPILOT_REVIEW_DECISION: BLOCK$' "$report_file" || true)" + +if [[ "$pass_count" == "1" && "$block_count" == "0" ]]; then + echo "[copilot-local-review] clean: Copilot returned PASS" + exit 0 +fi + +if [[ "$block_count" != "0" ]]; then + echo "[copilot-local-review] BLOCKED: Copilot returned BLOCK" >&2 + exit 1 +fi + +echo "[copilot-local-review] BLOCKED: Copilot output did not contain the required PASS decision line" >&2 +exit 1 diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..6f0ce17fe --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,32 @@ +## graphify + +This project has a graphify knowledge graph at graphify-out/. + +Rules: +- Before answering architecture or codebase questions, read graphify-out/GRAPH_REPORT.md for god nodes and community structure +- If graphify-out/wiki/index.md exists, navigate it instead of reading raw files +- After modifying code files in this session, run `graphify update .` to keep the graph current (AST-only, no API cost) + +## Remote and Rebase Policy + +Rules: +- `origin` is the active development target unless the user explicitly says otherwise. +- `upstream` is read-only/reference by default. Do not open, update, or describe work as ready for upstream contribution unless the user explicitly reopens that path. +- Still follow upstream Graphify closely: before planning implementation slices or rebases, fetch `origin` and `upstream`, verify current branch/HEAD/ahead-behind state, and compare live upstream changes that touch the same files. +- Prefer rebasing/pulling useful upstream changes into the local/origin development branch when doing so preserves the Graphify direction and reduces future drift. +- Upstream synchronization is preauthorized only for this Graphify/vampyre checkout (`/Users/jonathandirks/Development/vampyre`). Do not generalize this permission to any other repo or project. In this checkout, do not ask for human approval before stashing local work, fetching `origin`/`upstream`, comparing upstream changes, rebasing onto useful upstream updates, resolving conflicts by comparison, reapplying the stash, and rerunning verification. +- If upstream changes are harmful, irrelevant, or incompatible with the local Graphify direction, skip them only after comparing the change and recording the reason in the handoff. +- Helper scripts for stash/fetch/rebase workflows may be created only in local-only paths that are excluded from git, such as `.agent-local/`; do not add those helper scripts to GitHub-bound history. +- Absolute conflict rule: never resolve merge, rebase, cherry-pick, or generated-artifact conflicts by blindly choosing ours/theirs or by assuming the local branch is always correct. Always inspect and compare both sides, identify the intended behavior on each side, preserve useful upstream behavior unless it is deliberately incompatible with the local plan, and document the resolution in the handoff or commit notes. +- If conflict behavior is unclear after comparing both sides and the relevant tests/docs, stop and ask before resolving. +- For this checkout only, publishing a verified local/origin development branch is part of closing an upstream sync. Before pushing, run the full local verification stack, including local Copilot review, tests, lint/type/security/warning gates, pre-commit/pre-push gates, and graph refresh when applicable. +- After verification passes, fetch `origin` and compare. If local `HEAD` contains `origin` or deliberately supersedes it because origin-only commits are already integrated, duplicated by rebase, or intentionally rejected with the reason recorded, update `origin` with a normal push for fast-forwards or `--force-with-lease` for rewritten history. +- Do not leave `origin` behind a verified local stack merely because the local branch was rebased. If the lease fails, or if origin contains new valuable or unclear work that is not in local, stop, fetch, compare, and integrate or ask before publishing. +- This does not authorize publishing to `upstream`, changing remotes, or generalizing the Graphify/vampyre publication rule to any other repo. + +## Verification and Failure Policy + +Rules: +- Do not waive test failures, skips, warnings, linter findings, gate failures, graphify update failures, or audit findings as "pre-existing." If an issue is reproducible in the current workspace, it becomes the current agent's active task until resolved or until the user explicitly redirects the work. +- Do not mark a slice, PR, branch, or handoff as ready while any reproduced failure, skip, warning, or gate finding remains unresolved. +- If another agent reports an issue as pre-existing, independently reproduce it, root-cause it, fix it when it is in scope, and record the invalid waiver in the handoff or audit notes. diff --git a/README.md b/README.md index a25195af2..525b61415 100644 --- a/README.md +++ b/README.md @@ -496,6 +496,8 @@ graphify extract ./docs --mode deep # richer semantic extraction via graphify extract ./docs --no-cluster # raw extraction only, skip clustering graphify extract ./docs --force # overwrite graph.json even if new graph has fewer nodes (use after refactors or to clear ghost duplicates) graphify extract ./docs --dedup-llm # LLM tiebreaker for ambiguous entity pairs (uses same API key) +graphify extract ./docs --multigraph # build a MultiDiGraph that preserves parallel edges (e.g. A calls B AND A imports B) +graphify extract ./docs --simple # force a simple graph even over an existing multigraph (lossy — warns on collapse) graphify extract ./docs --global --as myrepo # extract and register into the cross-project global graph GRAPHIFY_MAX_OUTPUT_TOKENS=32768 graphify extract ./docs --backend claude # raise output cap for dense corpora diff --git a/graphify/__init__.py b/graphify/__init__.py index e34c938ef..e0f698b6e 100644 --- a/graphify/__init__.py +++ b/graphify/__init__.py @@ -22,6 +22,7 @@ def __getattr__(name): } if name in _map: import importlib + mod_name, attr = _map[name] mod = importlib.import_module(mod_name) return getattr(mod, attr) diff --git a/graphify/__main__.py b/graphify/__main__.py index 380572890..b19355518 100644 --- a/graphify/__main__.py +++ b/graphify/__main__.py @@ -1,4 +1,5 @@ """graphify CLI - `graphify install` sets up the Claude Code skill.""" + from __future__ import annotations import json import os @@ -10,6 +11,7 @@ try: from importlib.metadata import version as _pkg_version + __version__ = _pkg_version("graphifyy") except Exception: __version__ = "unknown" @@ -35,6 +37,7 @@ def _enforce_graph_size_cap_or_exit(gp: Path) -> None: and let the ``ValueError`` propagate. """ from graphify.security import check_graph_file_size_cap + try: check_graph_file_size_cap(gp) except ValueError as exc: @@ -42,17 +45,66 @@ def _enforce_graph_size_cap_or_exit(gp: Path) -> None: sys.exit(1) +def _backup_merge_target(target: Path) -> "Path | None": + """Snapshot an existing merge target to a dated ``.bak`` sibling before overwrite. + + Mirrors :func:`graphify.global_graph.backup_global_graph`'s dated-snapshot + contract, parameterized for an arbitrary merge output path (the merge-driver + writes to ``current``; ``merge-graphs`` writes to ``--out``). Neither of the + existing backup helpers fits here: ``export.backup_if_protected`` keys off a + ``graphify-out`` layout (semantic marker / curated labels) and + ``backup_global_graph`` is hard-wired to the global-graph path. + + The backup is written next to *target* as ``..bak``. + Idempotent within a day — identical content is not re-copied; a changed + target refreshes the same-day snapshot in place (one backup per day, always + the latest pre-overwrite state). Returns the backup path, or None when there + is nothing to back up (target absent) or backups are disabled via + ``GRAPHIFY_NO_BACKUP``. Never raises: a backup failure prints a warning and + returns None so it can never block the write it protects. + """ + import hashlib + from datetime import date + + if os.environ.get("GRAPHIFY_NO_BACKUP"): + return None + if not target.exists(): + return None + + today = date.today().isoformat() + backup_path = target.with_name(f"{target.stem}.{today}.bak") + try: + if backup_path.exists(): + src_hash = hashlib.sha256(target.read_bytes()).hexdigest() + bak_hash = hashlib.sha256(backup_path.read_bytes()).hexdigest() + if src_hash == bak_hash: + return backup_path # identical content, nothing to do + shutil.copy2(target, backup_path) + return backup_path + except Exception as exc: + print( + f"[graphify merge] warning: backup failed ({exc}) — continuing with overwrite", + file=sys.stderr, + ) + return None + + def _check_skill_version(skill_dst: Path) -> None: """Warn if the installed skill is from an older graphify version.""" version_file = skill_dst.parent / ".graphify_version" if not version_file.exists(): return if not skill_dst.exists(): - print(" warning: skill dir exists but SKILL.md is missing. Run 'graphify install' to repair.") + print( + " warning: skill dir exists but SKILL.md is missing. Run 'graphify install' to repair." + ) return installed = version_file.read_text(encoding="utf-8").strip() if installed != __version__: - print(f" warning: skill is from graphify {installed}, package is {__version__}. Run 'graphify install' to update.", file=sys.stderr) + print( + f" warning: skill is from graphify {installed}, package is {__version__}. Run 'graphify install' to update.", + file=sys.stderr, + ) def _refresh_all_version_stamps() -> None: @@ -68,7 +120,9 @@ def _refresh_all_version_stamps() -> None: vf.write_text(__version__, encoding="utf-8") -def _platform_skill_destination(platform_name: str, *, project: bool = False, project_dir: Path | None = None) -> Path: +def _platform_skill_destination( + platform_name: str, *, project: bool = False, project_dir: Path | None = None +) -> Path: """Return the skill destination for a platform and scope.""" if platform_name == "gemini": if project: @@ -102,9 +156,13 @@ def _platform_skill_destination(platform_name: str, *, project: bool = False, pr return Path.home() / cfg["skill_dst"] -def _copy_skill_file(platform_name: str, *, project: bool = False, project_dir: Path | None = None) -> Path: +def _copy_skill_file( + platform_name: str, *, project: bool = False, project_dir: Path | None = None +) -> Path: """Copy a packaged skill file and write its version stamp.""" - skill_file = "skill.md" if platform_name == "gemini" else _PLATFORM_CONFIG[platform_name]["skill_file"] + skill_file = ( + "skill.md" if platform_name == "gemini" else _PLATFORM_CONFIG[platform_name]["skill_file"] + ) skill_src = Path(__file__).parent / skill_file if not skill_src.exists(): print(f"error: {skill_file} not found in package - reinstall graphify", file=sys.stderr) @@ -127,7 +185,9 @@ def _copy_skill_file(platform_name: str, *, project: bool = False, project_dir: return skill_dst -def _remove_skill_file(platform_name: str, *, project: bool = False, project_dir: Path | None = None) -> bool: +def _remove_skill_file( + platform_name: str, *, project: bool = False, project_dir: Path | None = None +) -> bool: """Remove a platform skill file and its version stamp without touching other scopes.""" skill_dst = _platform_skill_destination(platform_name, project=project, project_dir=project_dir) removed = False @@ -187,6 +247,7 @@ def _print_project_git_add_hint(paths: list[Path]) -> None: print("Project-scoped install. Add to version control:") print(f" git add {' '.join(unique)}") + _SETTINGS_HOOK = { # Claude Code v2.1.117+ removed dedicated Grep/Glob tools; searches now go through Bash. # We match on Bash and inspect the command string to avoid firing on every shell call. @@ -195,10 +256,10 @@ def _print_project_git_add_hint(paths: list[Path]) -> None: { "type": "command", "command": ( - "CMD=$(python3 -c \"" + 'CMD=$(python3 -c "' "import json,sys; d=json.load(sys.stdin); " "print(d.get('tool_input',d).get('command',''))\" 2>/dev/null || true); " - "case \"$CMD\" in " + 'case "$CMD" in ' r"*grep*|*rg\ *|*ripgrep*|*find\ *|*fd\ *|*ack\ *|*ag\ *) " " [ -f graphify-out/graph.json ] && " r""" echo '{"hookSpecificOutput":{"hookEventName":"PreToolUse","additionalContext":"graphify: knowledge graph at graphify-out/. For focused questions, run `graphify query \"\"` (scoped subgraph, usually much smaller than GRAPH_REPORT.md) instead of grepping raw files. Read GRAPH_REPORT.md only for broad architecture context."}}' """ @@ -209,13 +270,14 @@ def _print_project_git_add_hint(paths: list[Path]) -> None: ], } + def _skill_registration(skill_path: str = "~/.claude/skills/graphify/SKILL.md") -> str: return ( "\n# graphify\n" f"- **graphify** (`{skill_path}`) " "- any input to knowledge graph. Trigger: `/graphify`\n" "When the user types `/graphify`, invoke the Skill tool " - "with `skill: \"graphify\"` before doing anything else.\n" + 'with `skill: "graphify"` before doing anything else.\n' ) @@ -360,7 +422,9 @@ def _replace_or_append_section(content: str, marker: str, new_section: str) -> s return out -def install(platform: str = "claude", *, project: bool = False, project_dir: Path | None = None) -> None: +def install( + platform: str = "claude", *, project: bool = False, project_dir: Path | None = None +) -> None: if platform == "gemini": gemini_install(project_dir=project_dir, project=project) return @@ -383,12 +447,18 @@ def install(platform: str = "claude", *, project: bool = False, project_dir: Pat if cfg["claude_md"]: # Register in the matching Claude Code scope. - claude_md = (project_dir / ".claude" / "CLAUDE.md") if project else Path.home() / ".claude" / "CLAUDE.md" - registration = _skill_registration(".claude/skills/graphify/SKILL.md" if project else "~/.claude/skills/graphify/SKILL.md") + claude_md = ( + (project_dir / ".claude" / "CLAUDE.md") + if project + else Path.home() / ".claude" / "CLAUDE.md" + ) + registration = _skill_registration( + ".claude/skills/graphify/SKILL.md" if project else "~/.claude/skills/graphify/SKILL.md" + ) if claude_md.exists(): content = claude_md.read_text(encoding="utf-8") if "graphify" in content: - print(f" CLAUDE.md -> already registered (no change)") + print(" CLAUDE.md -> already registered (no change)") else: claude_md.write_text(content.rstrip() + registration, encoding="utf-8") print(f" CLAUDE.md -> skill registered in {claude_md}") @@ -495,9 +565,7 @@ def gemini_install(project_dir: Path | None = None, *, project: bool = False) -> if target.exists(): content = target.read_text(encoding="utf-8") - new_content = _replace_or_append_section( - content, _GEMINI_MD_MARKER, _GEMINI_MD_SECTION - ) + new_content = _replace_or_append_section(content, _GEMINI_MD_MARKER, _GEMINI_MD_SECTION) else: new_content = _GEMINI_MD_SECTION @@ -511,7 +579,13 @@ def gemini_install(project_dir: Path | None = None, *, project: bool = False) -> # wording) is replaced on upgrade. _install_gemini_hook(project_dir) if project: - _print_project_git_add_hint([_project_scope_root(skill_dst, project_dir), project_dir / "GEMINI.md", project_dir / ".gemini"]) + _print_project_git_add_hint( + [ + _project_scope_root(skill_dst, project_dir), + project_dir / "GEMINI.md", + project_dir / ".gemini", + ] + ) print() print("Gemini CLI will now check the knowledge graph before answering") print("codebase questions and rebuild it after code changes.") @@ -521,7 +595,9 @@ def _install_gemini_hook(project_dir: Path) -> None: settings_path = project_dir / ".gemini" / "settings.json" settings_path.parent.mkdir(parents=True, exist_ok=True) try: - settings = json.loads(settings_path.read_text(encoding="utf-8")) if settings_path.exists() else {} + settings = ( + json.loads(settings_path.read_text(encoding="utf-8")) if settings_path.exists() else {} + ) except json.JSONDecodeError: settings = {} before_tool = settings.setdefault("hooks", {}).setdefault("BeforeTool", []) @@ -615,7 +691,9 @@ def vscode_install(project_dir: Path | None = None) -> None: print(f" {instructions} -> already configured (no change)") else: instructions.write_text(new_content, encoding="utf-8") - print(f" {instructions} -> graphify section {'updated' if _VSCODE_INSTRUCTIONS_MARKER in content else 'added'}") + print( + f" {instructions} -> graphify section {'updated' if _VSCODE_INSTRUCTIONS_MARKER in content else 'added'}" + ) else: instructions.write_text(_VSCODE_INSTRUCTIONS_SECTION, encoding="utf-8") print(f" {instructions} -> created") @@ -720,7 +798,7 @@ def _kiro_install(project_dir: Path) -> None: steering_dir.mkdir(parents=True, exist_ok=True) steering_dst = steering_dir / "graphify.md" if steering_dst.exists() and steering_dst.read_text(encoding="utf-8") == _KIRO_STEERING: - print(f" .kiro/steering/graphify.md -> already configured (no change)") + print(" .kiro/steering/graphify.md -> already configured (no change)") else: # File is wholly graphify-owned. Overwrite on upgrade so older # report-first wording does not silently linger (issue #580). @@ -801,11 +879,15 @@ def _antigravity_install(project_dir: Path) -> None: print("Antigravity will now check the knowledge graph before answering") print("codebase questions. Run /graphify first to build the graph.") print() - print("To enable full MCP architecture navigation, add this to ~/.gemini/antigravity/mcp_config.json:") + print( + "To enable full MCP architecture navigation, add this to ~/.gemini/antigravity/mcp_config.json:" + ) print(' "graphify": {') print(' "command": "uv",') - print(' "args": ["run", "--with", "graphifyy", "--with", "mcp", "-m", "graphify.serve", "${workspace.path}/graphify-out/graph.json"]') - print(' }') + print( + ' "args": ["run", "--with", "graphifyy", "--with", "mcp", "-m", "graphify.serve", "${workspace.path}/graphify-out/graph.json"]' + ) + print(" }") def _antigravity_uninstall(project_dir: Path, *, project: bool = False) -> None: @@ -1029,6 +1111,7 @@ def _resolve_graphify_exe() -> str: not on PATH (e.g. VS Code Codex extension on Windows). """ import shutil + found = shutil.which("graphify") if found: return found @@ -1086,7 +1169,7 @@ def _uninstall_codex_hook(project_dir: Path) -> None: filtered = [h for h in pre_tool if "graphify" not in str(h)] existing["hooks"]["PreToolUse"] = filtered hooks_path.write_text(json.dumps(existing, indent=2), encoding="utf-8") - print(f" .codex/hooks.json -> PreToolUse hook removed") + print(" .codex/hooks.json -> PreToolUse hook removed") def _agents_install(project_dir: Path, platform: str) -> None: @@ -1095,9 +1178,7 @@ def _agents_install(project_dir: Path, platform: str) -> None: if target.exists(): content = target.read_text(encoding="utf-8") - new_content = _replace_or_append_section( - content, _AGENTS_MD_MARKER, _AGENTS_MD_SECTION - ) + new_content = _replace_or_append_section(content, _AGENTS_MD_MARKER, _AGENTS_MD_SECTION) else: new_content = _AGENTS_MD_SECTION @@ -1136,7 +1217,17 @@ def _project_install(platform_name: str, project_dir: Path | None = None) -> Non elif platform_name == "kiro": _kiro_install(project_dir) _print_project_git_add_hint([project_dir / ".kiro"]) - elif platform_name in ("aider", "amp", "codex", "opencode", "claw", "droid", "trae", "trae-cn", "hermes"): + elif platform_name in ( + "aider", + "amp", + "codex", + "opencode", + "claw", + "droid", + "trae", + "trae-cn", + "hermes", + ): skill_dst = _copy_skill_file(platform_name, project=True, project_dir=project_dir) _agents_install(project_dir, platform_name) hint_paths = [_project_scope_root(skill_dst, project_dir), project_dir / "AGENTS.md"] @@ -1148,7 +1239,9 @@ def _project_install(platform_name: str, project_dir: Path | None = None) -> Non elif platform_name == "devin": skill_dst = _copy_skill_file("devin", project=True, project_dir=project_dir) _devin_rules_install(project_dir) - _print_project_git_add_hint([_project_scope_root(skill_dst, project_dir), project_dir / ".windsurf"]) + _print_project_git_add_hint( + [_project_scope_root(skill_dst, project_dir), project_dir / ".windsurf"] + ) elif platform_name in ("copilot", "pi", "antigravity", "kimi"): skill_dst = _copy_skill_file(platform_name, project=True, project_dir=project_dir) _print_project_git_add_hint([_project_scope_root(skill_dst, project_dir)]) @@ -1169,7 +1262,17 @@ def _project_uninstall(platform_name: str, project_dir: Path | None = None) -> N _cursor_uninstall(project_dir) elif platform_name == "kiro": _kiro_uninstall(project_dir) - elif platform_name in ("aider", "amp", "codex", "opencode", "claw", "droid", "trae", "trae-cn", "hermes"): + elif platform_name in ( + "aider", + "amp", + "codex", + "opencode", + "claw", + "droid", + "trae", + "trae-cn", + "hermes", + ): _remove_skill_file(platform_name, project=True, project_dir=project_dir) _agents_uninstall(project_dir, platform=platform_name) if platform_name == "codex": @@ -1236,9 +1339,7 @@ def claude_install(project_dir: Path | None = None) -> None: if target.exists(): content = target.read_text(encoding="utf-8") - new_content = _replace_or_append_section( - content, _CLAUDE_MD_MARKER, _CLAUDE_MD_SECTION - ) + new_content = _replace_or_append_section(content, _CLAUDE_MD_MARKER, _CLAUDE_MD_SECTION) else: new_content = _CLAUDE_MD_SECTION @@ -1273,10 +1374,14 @@ def _install_claude_hook(project_dir: Path) -> None: hooks = settings.setdefault("hooks", {}) pre_tool = hooks.setdefault("PreToolUse", []) - hooks["PreToolUse"] = [h for h in pre_tool if not (h.get("matcher") in ("Glob|Grep", "Bash") and "graphify" in str(h))] + hooks["PreToolUse"] = [ + h + for h in pre_tool + if not (h.get("matcher") in ("Glob|Grep", "Bash") and "graphify" in str(h)) + ] hooks["PreToolUse"].append(_SETTINGS_HOOK) settings_path.write_text(json.dumps(settings, indent=2), encoding="utf-8") - print(f" .claude/settings.json -> PreToolUse hook registered") + print(" .claude/settings.json -> PreToolUse hook registered") def _uninstall_claude_hook(project_dir: Path) -> None: @@ -1289,12 +1394,16 @@ def _uninstall_claude_hook(project_dir: Path) -> None: except json.JSONDecodeError: return pre_tool = settings.get("hooks", {}).get("PreToolUse", []) - filtered = [h for h in pre_tool if not (h.get("matcher") in ("Glob|Grep", "Bash") and "graphify" in str(h))] + filtered = [ + h + for h in pre_tool + if not (h.get("matcher") in ("Glob|Grep", "Bash") and "graphify" in str(h)) + ] if len(filtered) == len(pre_tool): return settings["hooks"]["PreToolUse"] = filtered settings_path.write_text(json.dumps(settings, indent=2), encoding="utf-8") - print(f" .claude/settings.json -> PreToolUse hook removed") + print(" .claude/settings.json -> PreToolUse hook removed") def uninstall_all(project_dir: Path | None = None, purge: bool = False) -> None: @@ -1317,18 +1426,20 @@ def uninstall_all(project_dir: Path | None = None, purge: bool = False) -> None: # Git hook try: from graphify.hooks import uninstall as hook_uninstall + result = hook_uninstall(pd) if result: print(result) - except Exception: - pass + except Exception as exc: + print(f"[graphify] warning: could not uninstall git hook: {exc}", file=sys.stderr) if purge: import shutil as _shutil + out = pd / "graphify-out" if out.exists(): _shutil.rmtree(out) - print(f"\n graphify-out/ -> deleted (--purge)") + print("\n graphify-out/ -> deleted (--purge)") else: print("\n graphify-out/ -> not found (nothing to purge)") @@ -1403,7 +1514,7 @@ def _clone_repo(url: str, branch: str | None = None, out_dir: Path | None = None cmd = ["git", "-C", str(dest), "pull"] if branch: cmd += ["origin", "--", branch] - result = _sp.run(cmd, capture_output=True, text=True) + result = _sp.run(cmd, capture_output=True, text=True) # nosec B603 if result.returncode != 0: print(f"warning: git pull failed:\n{result.stderr}", file=sys.stderr) else: @@ -1413,7 +1524,7 @@ def _clone_repo(url: str, branch: str | None = None, out_dir: Path | None = None if branch: cmd += ["--branch", branch] cmd += ["--", git_url, str(dest)] - result = _sp.run(cmd, capture_output=True, text=True) + result = _sp.run(cmd, capture_output=True, text=True) # nosec B603 if result.returncode != 0: print(f"error: git clone failed:\n{result.stderr}", file=sys.stderr) sys.exit(1) @@ -1422,13 +1533,138 @@ def _clone_repo(url: str, branch: str | None = None, out_dir: Path | None = None return dest +def _read_provider_registry(path: Path) -> dict: + if not path.is_file(): + return {} + try: + data = json.loads(path.read_text(encoding="utf-8")) + except Exception: + return {} + if isinstance(data, dict): + return data + return {} + + +def _provider_cmd(args: list[str]) -> None: + from graphify.llm import BACKENDS, _custom_providers_path + + subcmd = args[0] if args else "" + global_path = _custom_providers_path(global_=True) + + if subcmd == "list": + global_path.parent.mkdir(parents=True, exist_ok=True) + existing = _read_provider_registry(global_path) + if not existing: + print("No custom providers registered.") + else: + for name in existing: + print(f" {name} ({existing[name].get('base_url', '')})") + + elif subcmd == "show": + name = args[1] if len(args) > 1 else "" + if not name: + print("Usage: graphify provider show ", file=sys.stderr) + sys.exit(1) + existing = _read_provider_registry(global_path) + if name not in existing: + print(f"Provider '{name}' not found.", file=sys.stderr) + sys.exit(1) + print(json.dumps({name: existing[name]}, indent=2)) + + elif subcmd == "add": + add_args = args[1:] + name = add_args[0] if add_args and not add_args[0].startswith("-") else "" + if not name: + print( + "Usage: graphify provider add --base-url URL --default-model MODEL --env-key KEY", + file=sys.stderr, + ) + sys.exit(1) + if name in BACKENDS: + print( + f"Error: '{name}' is a built-in provider and cannot be overridden.", + file=sys.stderr, + ) + sys.exit(1) + base_url = "" + default_model = "" + env_key = "" + pricing_input = 0.0 + pricing_output = 0.0 + i = 1 + while i < len(add_args): + a = add_args[i] + if a == "--base-url" and i + 1 < len(add_args): + base_url = add_args[i + 1] + i += 2 + elif a.startswith("--base-url="): + base_url = a.split("=", 1)[1] + i += 1 + elif a == "--default-model" and i + 1 < len(add_args): + default_model = add_args[i + 1] + i += 2 + elif a.startswith("--default-model="): + default_model = a.split("=", 1)[1] + i += 1 + elif a == "--env-key" and i + 1 < len(add_args): + env_key = add_args[i + 1] + i += 2 + elif a.startswith("--env-key="): + env_key = a.split("=", 1)[1] + i += 1 + elif a == "--pricing-input" and i + 1 < len(add_args): + pricing_input = float(add_args[i + 1]) + i += 2 + elif a == "--pricing-output" and i + 1 < len(add_args): + pricing_output = float(add_args[i + 1]) + i += 2 + else: + i += 1 + if not base_url or not default_model or not env_key: + print( + "Error: --base-url, --default-model, and --env-key are required.", + file=sys.stderr, + ) + sys.exit(1) + global_path.parent.mkdir(parents=True, exist_ok=True) + existing = _read_provider_registry(global_path) + existing[name] = { + "base_url": base_url, + "default_model": default_model, + "env_key": env_key, + "pricing": {"input": pricing_input, "output": pricing_output}, + "temperature": 0, + } + global_path.write_text(json.dumps(existing, indent=2) + "\n", encoding="utf-8") + print(f"Provider '{name}' added. Use with: graphify extract . --backend {name}") + + elif subcmd == "remove": + name = args[1] if len(args) > 1 else "" + if not name: + print("Usage: graphify provider remove ", file=sys.stderr) + sys.exit(1) + existing = _read_provider_registry(global_path) + if name not in existing: + print(f"Provider '{name}' not found.", file=sys.stderr) + sys.exit(1) + del existing[name] + global_path.write_text(json.dumps(existing, indent=2) + "\n", encoding="utf-8") + print(f"Provider '{name}' removed.") + + else: + print("Usage: graphify provider [add|list|show|remove]", file=sys.stderr) + if subcmd: + sys.exit(1) + + def main() -> None: for _stream in (sys.stdout, sys.stderr): - if _stream is not None and hasattr(_stream, "reconfigure"): + reconfigure = getattr(_stream, "reconfigure", None) if _stream is not None else None + if callable(reconfigure): try: - _stream.reconfigure(encoding="utf-8", errors="replace") - except Exception: - pass + reconfigure(encoding="utf-8", errors="replace") + except Exception as exc: + _ = exc # Check all known skill install locations for a stale version stamp. # Skip during install/uninstall (hook writes trigger a fresh check anyway). # Skip during hook-check — it runs on every editor tool use and must be silent. @@ -1446,12 +1682,14 @@ def main() -> None: print("Usage: graphify ") print() print("Commands:") - print(" install [--platform P] copy skill to platform config dir (claude|windows|codex|opencode|aider|claw|droid|trae|trae-cn|gemini|cursor|antigravity|hermes|kiro|pi|devin)") + print( + " install [--platform P] copy skill to platform config dir (claude|windows|codex|opencode|aider|claw|droid|trae|trae-cn|gemini|cursor|antigravity|hermes|kiro|pi|devin)" + ) print(" uninstall remove graphify from all detected platforms in one shot") print(" --purge also delete graphify-out/ directory") - print(" path \"A\" \"B\" shortest path between two nodes in graph.json") + print(' path "A" "B" shortest path between two nodes in graph.json') print(" --graph path to graph.json (default graphify-out/graph.json)") - print(" explain \"X\" plain-language explanation of a node and its neighbors") + print(' explain "X" plain-language explanation of a node and its neighbors') print(" --graph path to graph.json (default graphify-out/graph.json)") print(" diagnose multigraph report same-endpoint edge collapse risk in graph.json") print(" --graph path to graph/extraction JSON") @@ -1463,40 +1701,67 @@ def main() -> None: print(" (default follows JSON directed flag;") print(" raw extraction with no flag defaults directed)") print(" --extract-path PATH extractor source for suppression scan") - print(" clone clone a GitHub repo locally and print its path for /graphify") - print(" merge-driver git merge driver: union-merge two graph.json files (set up via hook install)") - print(" merge-graphs merge two or more graph.json files into one cross-repo graph") + print( + " clone clone a GitHub repo locally and print its path for /graphify" + ) + print( + " merge-driver git merge driver: union-merge two graph.json files (set up via hook install)" + ) + print( + " merge-graphs [--multigraph|--simple] merge two or more graph.json files into one cross-repo graph" + ) print(" --out output path (default: graphify-out/merged-graph.json)") print(" --branch checkout a specific branch (default: repo default)") - print(" --out clone to a custom directory (default: ~/.graphify/repos//)") + print( + " --out clone to a custom directory (default: ~/.graphify/repos//)" + ) print(" add fetch a URL and save it to ./raw, then update the graph") - print(" --author \"Name\" tag the author of the content") - print(" --contributor \"Name\" tag who added it to the corpus") + print(' --author "Name" tag the author of the content') + print(' --contributor "Name" tag who added it to the corpus') print(" --dir target directory (default: ./raw)") print(" watch watch a folder and rebuild the graph on code changes") - print(" update re-extract code files and update the graph (no LLM needed)") - print(" --force overwrite graph.json even if the rebuild has fewer nodes") - print(" (also: GRAPHIFY_FORCE=1 env var; use after refactors that delete code)") + print( + " update re-extract code files and update the graph (no LLM needed)" + ) + print( + " (inherits the existing graph.json profile — a multigraph stays a multigraph)" + ) + print( + " --force overwrite graph.json even if the rebuild has fewer nodes" + ) + print( + " (also: GRAPHIFY_FORCE=1 env var; use after refactors that delete code)" + ) print(" --no-cluster skip clustering, write raw extraction only") - print(" cluster-only rerun clustering on an existing graph.json and regenerate report") - print(" --no-viz skip graph.html generation (useful for >5000 node graphs / CI)") - print(" --graph path to graph.json (default /graphify-out/graph.json)") - print(" query \"\" BFS traversal of graph.json for a question") + print( + " cluster-only rerun clustering on an existing graph.json and regenerate report" + ) + print( + " --no-viz skip graph.html generation (useful for >5000 node graphs / CI)" + ) + print( + " --graph path to graph.json (default /graphify-out/graph.json)" + ) + print(' query "" BFS traversal of graph.json for a question') print(" --dfs use depth-first instead of breadth-first") print(" --context C explicit edge-context filter (repeatable)") print(" --budget N cap output at N tokens (default 2000)") print(" --graph path to graph.json (default graphify-out/graph.json)") - print(" affected \"X\" reverse traversal to find nodes impacted by X") + print(' affected "X" reverse traversal to find nodes impacted by X') print(" --relation R edge relation to traverse in reverse (repeatable)") print(" --depth N reverse traversal depth (default 2)") print(" --graph path to graph.json (default graphify-out/graph.json)") - print(" save-result save a Q&A result to graphify-out/memory/ for graph feedback loop") + print( + " save-result save a Q&A result to graphify-out/memory/ for graph feedback loop" + ) print(" --question Q the question asked") print(" --answer A the answer to save") print(" --type T query type: query|path_query|explain (default: query)") print(" --nodes N1 N2 ... source node labels cited in the answer") print(" --memory-dir DIR memory directory (default: graphify-out/memory)") - print(" check-update check needs_update flag and notify if semantic re-extraction is pending (cron-safe)") + print( + " check-update check needs_update flag and notify if semantic re-extraction is pending (cron-safe)" + ) print(" tree emit a D3 v7 collapsible-tree HTML for graph.json") print(" --graph PATH path to graph.json (default graphify-out/graph.json)") print(" --output HTML output path (default graphify-out/GRAPH_TREE.html)") @@ -1504,44 +1769,79 @@ def main() -> None: print(" --max-children N cap children per node (default 200)") print(" --top-k-edges N per-symbol outbound edges in inspector (default 12)") print(" --label NAME project label in header") - print(" extract headless full extraction (AST + semantic LLM) for CI/scripts") - print(" --backend B gemini|kimi|claude|openai|deepseek|ollama (default: whichever API key is set)") + print( + " extract headless full extraction (AST + semantic LLM) for CI/scripts" + ) + print( + " --backend B gemini|kimi|claude|openai|deepseek|ollama (default: whichever API key is set)" + ) print(" --model M override backend default model") print(" --mode deep aggressive INFERRED-edge semantic extraction") print(" --max-workers N AST extraction subprocess count (default: cpu_count)") - print(" --token-budget N per-chunk token cap for semantic extraction (default: 60000)") - print(" --max-concurrency N parallel semantic chunks in flight (default: 4; set 1 for local LLMs)") - print(" --api-timeout S per-request timeout in seconds for the LLM client (default: 600)") - print(" --out DIR output dir (default: ); writes /graphify-out/") - print(" --google-workspace export .gdoc/.gsheet/.gslides shortcuts via gws before extraction") + print( + " --token-budget N per-chunk token cap for semantic extraction (default: 60000)" + ) + print( + " --max-concurrency N parallel semantic chunks in flight (default: 4; set 1 for local LLMs)" + ) + print( + " --api-timeout S per-request timeout in seconds for the LLM client (default: 600)" + ) + print( + " --out DIR output dir (default: ); writes /graphify-out/" + ) + print( + " --google-workspace export .gdoc/.gsheet/.gslides shortcuts via gws before extraction" + ) print(" --no-cluster skip clustering, write raw extraction only") + print( + " --multigraph build a keyed MultiDiGraph (preserves parallel edges between the same pair)" + ) + print( + " --simple force a simple graph even over an existing multigraph (lossy downgrade)" + ) + print( + " (default: STICKY — inherit the existing graph.json profile)" + ) print(" --global also merge the resulting graph into the global graph") print(" --as repo tag for --global (default: target directory name)") - print(" global add add/update a project graph in the global graph (~/.graphify/global-graph.json)") + print( + " global add add/update a project graph in the global graph (~/.graphify/global-graph.json)" + ) print(" --as repo tag (default: parent directory name)") print(" global remove remove a repo's nodes from the global graph") print(" global list list repos in the global graph") print(" global path print path to the global graph file") print(" benchmark [graph.json] measure token reduction vs naive full-corpus approach") print(" export callflow-html emit Mermaid-based architecture/call-flow HTML") - print(" hook install install post-commit/post-checkout git hooks (all platforms)") + print( + " hook install install post-commit/post-checkout git hooks (all platforms)" + ) print(" hook uninstall remove git hooks") print(" hook status check if git hooks are installed") print(" gemini install write GEMINI.md section + BeforeTool hook (Gemini CLI)") print(" gemini uninstall remove GEMINI.md section + BeforeTool hook") print(" cursor install write .cursor/rules/graphify.mdc (Cursor)") print(" cursor uninstall remove .cursor/rules/graphify.mdc") - print(" claude install write graphify section to CLAUDE.md + PreToolUse hook (Claude Code)") + print( + " claude install write graphify section to CLAUDE.md + PreToolUse hook (Claude Code)" + ) print(" claude uninstall remove graphify section from CLAUDE.md + PreToolUse hook") print(" codex install write graphify section to AGENTS.md (Codex)") print(" codex uninstall remove graphify section from AGENTS.md") - print(" opencode install write graphify section to AGENTS.md + tool.execute.before plugin (OpenCode)") + print( + " opencode install write graphify section to AGENTS.md + tool.execute.before plugin (OpenCode)" + ) print(" opencode uninstall remove graphify section from AGENTS.md + plugin") print(" aider install write graphify section to AGENTS.md (Aider)") print(" aider uninstall remove graphify section from AGENTS.md") - print(" copilot install copy graphify skill to ~/.copilot/skills (GitHub Copilot CLI)") + print( + " copilot install copy graphify skill to ~/.copilot/skills (GitHub Copilot CLI)" + ) print(" copilot uninstall remove graphify skill from ~/.copilot/skills") - print(" vscode install configure VS Code Copilot Chat (skill + .github/copilot-instructions.md)") + print( + " vscode install configure VS Code Copilot Chat (skill + .github/copilot-instructions.md)" + ) print(" vscode uninstall remove VS Code Copilot Chat configuration") print(" claw install write graphify section to AGENTS.md (OpenClaw)") print(" claw uninstall remove graphify section from AGENTS.md") @@ -1551,15 +1851,23 @@ def main() -> None: print(" trae uninstall remove graphify section from AGENTS.md") print(" trae-cn install write graphify section to AGENTS.md (Trae CN)") print(" trae-cn uninstall remove graphify section from AGENTS.md") - print(" antigravity install write .agents/rules + .agents/workflows + skill (Google Antigravity)") + print( + " antigravity install write .agents/rules + .agents/workflows + skill (Google Antigravity)" + ) print(" antigravity uninstall remove .agents/rules, .agents/workflows, and skill") print(" hermes install write skill to ~/.hermes/skills/graphify/ (Hermes)") print(" hermes uninstall remove skill from ~/.hermes/skills/graphify/") - print(" kiro install write skill to .kiro/skills/graphify/ + steering file (Kiro IDE/CLI)") + print( + " kiro install write skill to .kiro/skills/graphify/ + steering file (Kiro IDE/CLI)" + ) print(" kiro uninstall remove skill + steering file") - print(" pi install write skill to ~/.pi/agent/skills/graphify/ (Pi coding agent)") + print( + " pi install write skill to ~/.pi/agent/skills/graphify/ (Pi coding agent)" + ) print(" pi uninstall remove skill from ~/.pi/agent/skills/graphify/") - print(" devin install write skill to ~/.config/devin/skills/graphify/ (Devin CLI)") + print( + " devin install write skill to ~/.config/devin/skills/graphify/ (Devin CLI)" + ) print(" devin uninstall remove skill from ~/.config/devin/skills/graphify/") print() return @@ -1573,7 +1881,7 @@ def main() -> None: # "install"/"uninstall" which have their own per-subcommand help handlers. _FREE_TEXT_CMDS = {"query", "explain", "path", "save-result", "install", "uninstall"} if cmd not in _FREE_TEXT_CMDS and any(a in {"-h", "--help", "-?"} for a in sys.argv[2:]): - print(f"Run 'graphify --help' for full usage.") + print("Run 'graphify --help' for full usage.") return if cmd == "install": @@ -1785,123 +2093,18 @@ def main() -> None: print("Usage: graphify antigravity [install|uninstall]", file=sys.stderr) sys.exit(1) elif cmd == "provider": - from graphify.llm import _custom_providers_path, BACKENDS - import json as _json - subcmd = sys.argv[2] if len(sys.argv) > 2 else "" - global_path = _custom_providers_path(global_=True) - - if subcmd == "list": - global_path.parent.mkdir(parents=True, exist_ok=True) - existing: dict = {} - if global_path.is_file(): - try: - existing = _json.loads(global_path.read_text(encoding="utf-8")) - except Exception: - pass - if not existing: - print("No custom providers registered.") - else: - for name in existing: - print(f" {name} ({existing[name].get('base_url', '')})") - - elif subcmd == "show": - name = sys.argv[3] if len(sys.argv) > 3 else "" - if not name: - print("Usage: graphify provider show ", file=sys.stderr) - sys.exit(1) - existing = {} - if global_path.is_file(): - try: - existing = _json.loads(global_path.read_text(encoding="utf-8")) - except Exception: - pass - if name not in existing: - print(f"Provider '{name}' not found.", file=sys.stderr) - sys.exit(1) - print(_json.dumps({name: existing[name]}, indent=2)) - - elif subcmd == "add": - args = sys.argv[3:] - name = args[0] if args and not args[0].startswith("-") else "" - if not name: - print("Usage: graphify provider add --base-url URL --default-model MODEL --env-key KEY", file=sys.stderr) - sys.exit(1) - if name in BACKENDS: - print(f"Error: '{name}' is a built-in provider and cannot be overridden.", file=sys.stderr) - sys.exit(1) - base_url = "" - default_model = "" - env_key = "" - pricing_input = 0.0 - pricing_output = 0.0 - i = 1 - while i < len(args): - a = args[i] - if a == "--base-url" and i + 1 < len(args): - base_url = args[i + 1]; i += 2 - elif a.startswith("--base-url="): - base_url = a.split("=", 1)[1]; i += 1 - elif a == "--default-model" and i + 1 < len(args): - default_model = args[i + 1]; i += 2 - elif a.startswith("--default-model="): - default_model = a.split("=", 1)[1]; i += 1 - elif a == "--env-key" and i + 1 < len(args): - env_key = args[i + 1]; i += 2 - elif a.startswith("--env-key="): - env_key = a.split("=", 1)[1]; i += 1 - elif a == "--pricing-input" and i + 1 < len(args): - pricing_input = float(args[i + 1]); i += 2 - elif a == "--pricing-output" and i + 1 < len(args): - pricing_output = float(args[i + 1]); i += 2 - else: - i += 1 - if not base_url or not default_model or not env_key: - print("Error: --base-url, --default-model, and --env-key are required.", file=sys.stderr) - sys.exit(1) - global_path.parent.mkdir(parents=True, exist_ok=True) - existing = {} - if global_path.is_file(): - try: - existing = _json.loads(global_path.read_text(encoding="utf-8")) - except Exception: - pass - existing[name] = { - "base_url": base_url, - "default_model": default_model, - "env_key": env_key, - "pricing": {"input": pricing_input, "output": pricing_output}, - "temperature": 0, - } - global_path.write_text(_json.dumps(existing, indent=2) + "\n", encoding="utf-8") - print(f"Provider '{name}' added. Use with: graphify extract . --backend {name}") - - elif subcmd == "remove": - name = sys.argv[3] if len(sys.argv) > 3 else "" - if not name: - print("Usage: graphify provider remove ", file=sys.stderr) - sys.exit(1) - existing = {} - if global_path.is_file(): - try: - existing = _json.loads(global_path.read_text(encoding="utf-8")) - except Exception: - pass - if name not in existing: - print(f"Provider '{name}' not found.", file=sys.stderr) - sys.exit(1) - del existing[name] - global_path.write_text(_json.dumps(existing, indent=2) + "\n", encoding="utf-8") - print(f"Provider '{name}' removed.") - - else: - print("Usage: graphify provider [add|list|show|remove]", file=sys.stderr) - if subcmd: - sys.exit(1) + _provider_cmd(sys.argv[2:]) elif cmd == "prs": from graphify.prs import cmd_prs + cmd_prs(sys.argv[2:]) elif cmd == "hook": - from graphify.hooks import install as hook_install, uninstall as hook_uninstall, status as hook_status + from graphify.hooks import ( + install as hook_install, + uninstall as hook_uninstall, + status as hook_status, + ) + subcmd = sys.argv[2] if len(sys.argv) > 2 else "" if subcmd == "install": print(hook_install(Path("."))) @@ -1914,11 +2117,14 @@ def main() -> None: sys.exit(1) elif cmd == "query": if len(sys.argv) < 3: - print("Usage: graphify query \"\" [--dfs] [--context C] [--budget N] [--graph path]", file=sys.stderr) + print( + 'Usage: graphify query "" [--dfs] [--context C] [--budget N] [--graph path]', + file=sys.stderr, + ) sys.exit(1) from graphify.serve import _query_graph_text - from graphify.security import sanitize_label from networkx.readwrite import json_graph + question = sys.argv[2] use_dfs = "--dfs" in sys.argv budget = 2000 @@ -1931,14 +2137,14 @@ def main() -> None: try: budget = int(args[i + 1]) except ValueError: - print(f"error: --budget must be an integer", file=sys.stderr) + print("error: --budget must be an integer", file=sys.stderr) sys.exit(1) i += 2 elif args[i].startswith("--budget="): try: budget = int(args[i].split("=", 1)[1]) except ValueError: - print(f"error: --budget must be an integer", file=sys.stderr) + print("error: --budget must be an integer", file=sys.stderr) sys.exit(1) i += 1 elif args[i] == "--context" and i + 1 < len(args): @@ -1948,7 +2154,8 @@ def main() -> None: context_filters.append(args[i].split("=", 1)[1]) i += 1 elif args[i] == "--graph" and i + 1 < len(args): - graph_path = args[i + 1]; i += 2 + graph_path = args[i + 1] + i += 2 else: i += 1 gp = Path(graph_path).resolve() @@ -1956,12 +2163,13 @@ def main() -> None: print(f"error: graph file not found: {gp}", file=sys.stderr) sys.exit(1) if not gp.suffix == ".json": - print(f"error: graph file must be a .json file", file=sys.stderr) + print("error: graph file must be a .json file", file=sys.stderr) sys.exit(1) _enforce_graph_size_cap_or_exit(gp) try: import json as _json import networkx as _nx + _raw = _json.loads(gp.read_text(encoding="utf-8")) if "links" not in _raw and "edges" in _raw: _raw = dict(_raw, links=_raw["edges"]) @@ -1984,9 +2192,13 @@ def main() -> None: ) elif cmd == "affected": if len(sys.argv) < 3: - print("Usage: graphify affected \"\" [--relation R] [--depth N] [--graph path]", file=sys.stderr) + print( + 'Usage: graphify affected "" [--relation R] [--depth N] [--graph path]', + file=sys.stderr, + ) sys.exit(1) from graphify.affected import DEFAULT_AFFECTED_RELATIONS, format_affected, load_graph + query = sys.argv[2] graph_path = "graphify-out/graph.json" depth = 2 @@ -2045,6 +2257,7 @@ def main() -> None: elif cmd == "save-result": # graphify save-result --question Q --answer A --type T [--nodes N1 N2 ...] import argparse as _ap + p = _ap.ArgumentParser(prog="graphify save-result") p.add_argument("--question", required=True) p.add_argument("--answer", required=True) @@ -2053,6 +2266,7 @@ def main() -> None: p.add_argument("--memory-dir", default="graphify-out/memory") opts = p.parse_args(sys.argv[2:]) from graphify.ingest import save_query_result as _sqr + out = _sqr( question=opts.question, answer=opts.answer, @@ -2063,11 +2277,12 @@ def main() -> None: print(f"Saved to {out}") elif cmd == "path": if len(sys.argv) < 4: - print("Usage: graphify path \"\" \"\" [--graph path]", file=sys.stderr) + print('Usage: graphify path "" "" [--graph path]', file=sys.stderr) sys.exit(1) from graphify.serve import _score_nodes from networkx.readwrite import json_graph import networkx as _nx + source_label = sys.argv[2] target_label = sys.argv[3] graph_path = _default_graph_path() @@ -2125,32 +2340,51 @@ def main() -> None: hops = len(path_nodes) - 1 segments = [] from graphify.build import edge_data + from graphify.projections import ( + format_relationship_envelope, + relationship_envelope, + ) + for i in range(len(path_nodes) - 1): u, v = path_nodes[i], path_nodes[i + 1] # Check which direction the stored edge points. if G.has_edge(u, v): - edata = edge_data(G, u, v) forward = True + src, tgt = u, v else: - edata = edge_data(G, v, u) forward = False - rel = edata.get("relation", "") - conf = edata.get("confidence", "") - conf_str = f" [{conf}]" if conf else "" + src, tgt = v, u + # Bundle every parallel relationship on this hop so a MultiDiGraph + # never silently shows only the first edge (#PR5 go/no-go gate). + # directed_only=True isolates this hop's stored direction (src->tgt) + # so a reverse edge (tgt->src) never bleeds into the arrow's bundle. + env = relationship_envelope(G, src, tgt, directed_only=True) + if len(env["relations"]) > 1: + # Multiple parallel relations: render the capped bundle; the + # envelope omits per-relation confidence for stability. + rel_str = format_relationship_envelope(G, src, tgt, directed_only=True) + else: + # Single relation (always true for simple DiGraph/Graph): keep + # the historical "rel [CONFIDENCE]" form byte-for-byte stable. + edata = edge_data(G, src, tgt) + rel = edata.get("relation", "") + conf = edata.get("confidence", "") + rel_str = f"{rel} [{conf}]" if conf else rel if i == 0: segments.append(G.nodes[u].get("label", u)) if forward: - segments.append(f"--{rel}{conf_str}--> {G.nodes[v].get('label', v)}") + segments.append(f"--{rel_str}--> {G.nodes[v].get('label', v)}") else: - segments.append(f"<--{rel}{conf_str}-- {G.nodes[v].get('label', v)}") + segments.append(f"<--{rel_str}-- {G.nodes[v].get('label', v)}") print(f"Shortest path ({hops} hops):\n " + " ".join(segments)) elif cmd == "explain": if len(sys.argv) < 3: - print("Usage: graphify explain \"\" [--graph path]", file=sys.stderr) + print('Usage: graphify explain "" [--graph path]', file=sys.stderr) sys.exit(1) from graphify.serve import _find_node from networkx.readwrite import json_graph + label = sys.argv[2] graph_path = _default_graph_path() args = sys.argv[3:] @@ -2184,19 +2418,42 @@ def main() -> None: print(f" Community: {d.get('community', '')}") print(f" Degree: {G.degree(nid)}") from graphify.build import edge_data - connections: list[tuple[str, str, dict]] = [] # (direction, neighbor_id, edge_data) + from graphify.projections import ( + format_relationship_envelope, + relationship_envelope, + ) + + # (direction, neighbor_id, edge_src, edge_tgt) — src/tgt preserve the + # stored edge direction so the relationship envelope reads the correct + # parallel-edge bundle for this neighbor. + connections: list[tuple[str, str, str, str]] = [] for nb in G.successors(nid): - connections.append(("out", nb, edge_data(G, nid, nb))) + connections.append(("out", nb, nid, nb)) for nb in G.predecessors(nid): - connections.append(("in", nb, edge_data(G, nb, nid))) + connections.append(("in", nb, nb, nid)) if connections: print(f"\nConnections ({len(connections)}):") connections.sort(key=lambda c: G.degree(c[1]), reverse=True) - for direction, nb, edata in connections[:20]: - rel = edata.get("relation", "") - conf = edata.get("confidence", "") + for direction, nb, e_src, e_tgt in connections[:20]: arrow = "-->" if direction == "out" else "<--" - print(f" {arrow} {G.nodes[nb].get('label', nb)} [{rel}] [{conf}]") + # Bundle every parallel relationship to this neighbor so a + # MultiDiGraph never shows only the first edge (#PR5 gate). + # directed_only=True isolates this connection's stored direction + # (e_src->e_tgt) so an "out" arrow never merges the reverse "in" + # relations and vice versa. + env = relationship_envelope(G, e_src, e_tgt, directed_only=True) + if len(env["relations"]) > 1: + rel_block = ( + f"[{format_relationship_envelope(G, e_src, e_tgt, directed_only=True)}]" + ) + else: + # Single relation (always true for simple DiGraph/Graph): + # keep the historical "[rel] [conf]" form byte-stable. + edata = edge_data(G, e_src, e_tgt) + rel = edata.get("relation", "") + conf = edata.get("confidence", "") + rel_block = f"[{rel}] [{conf}]" + print(f" {arrow} {G.nodes[nb].get('label', nb)} {rel_block}") if len(connections) > 20: print(f" ... and {len(connections) - 20} more") @@ -2296,9 +2553,13 @@ def main() -> None: elif cmd == "add": if len(sys.argv) < 3: - print("Usage: graphify add [--author Name] [--contributor Name] [--dir ./raw]", file=sys.stderr) + print( + "Usage: graphify add [--author Name] [--contributor Name] [--dir ./raw]", + file=sys.stderr, + ) sys.exit(1) from graphify.ingest import ingest as _ingest + url = sys.argv[2] author: str | None = None contributor: str | None = None @@ -2307,11 +2568,14 @@ def main() -> None: i = 0 while i < len(args): if args[i] == "--author" and i + 1 < len(args): - author = args[i + 1]; i += 2 + author = args[i + 1] + i += 2 elif args[i] == "--contributor" and i + 1 < len(args): - contributor = args[i + 1]; i += 2 + contributor = args[i + 1] + i += 2 elif args[i] == "--dir" and i + 1 < len(args): - target_dir = Path(args[i + 1]); i += 2 + target_dir = Path(args[i + 1]) + i += 2 else: i += 1 try: @@ -2328,6 +2592,7 @@ def main() -> None: print(f"error: path not found: {watch_path}", file=sys.stderr) sys.exit(1) from graphify.watch import watch as _watch + try: _watch(watch_path) except ImportError as exc: @@ -2349,26 +2614,36 @@ def main() -> None: while i_arg < len(args): a = args[i_arg] if a == "--graph" and i_arg + 1 < len(args): - graph_override = Path(args[i_arg + 1]); i_arg += 2 + graph_override = Path(args[i_arg + 1]) + i_arg += 2 elif a == "--resolution" and i_arg + 1 < len(args): - co_resolution = float(args[i_arg + 1]); i_arg += 2 + co_resolution = float(args[i_arg + 1]) + i_arg += 2 elif a.startswith("--resolution="): - co_resolution = float(a.split("=", 1)[1]); i_arg += 1 + co_resolution = float(a.split("=", 1)[1]) + i_arg += 1 elif a == "--exclude-hubs" and i_arg + 1 < len(args): - co_exclude_hubs = float(args[i_arg + 1]); i_arg += 2 + co_exclude_hubs = float(args[i_arg + 1]) + i_arg += 2 elif a.startswith("--exclude-hubs="): - co_exclude_hubs = float(a.split("=", 1)[1]); i_arg += 1 + co_exclude_hubs = float(a.split("=", 1)[1]) + i_arg += 1 elif a == "--no-viz" or a.startswith("--min-community-size="): i_arg += 1 elif a.startswith("--"): i_arg += 1 elif watch_path is None: - watch_path = Path(a); i_arg += 1 + watch_path = Path(a) + i_arg += 1 else: i_arg += 1 if watch_path is None: watch_path = Path(".") - graph_json = graph_override if graph_override is not None else watch_path / "graphify-out" / "graph.json" + graph_json = ( + graph_override + if graph_override is not None + else watch_path / "graphify-out" / "graph.json" + ) if not graph_json.exists(): print(f"error: no graph found at {graph_json} - run /graphify first", file=sys.stderr) sys.exit(1) @@ -2378,6 +2653,7 @@ def main() -> None: from graphify.analyze import god_nodes, surprising_connections, suggest_questions from graphify.report import generate from graphify.export import to_json, to_html + print("Loading existing graph...") _enforce_graph_size_cap_or_exit(graph_json) _raw = json.loads(graph_json.read_text(encoding="utf-8")) @@ -2406,7 +2682,10 @@ def main() -> None: labels_path = out / ".graphify_labels.json" if labels_path.exists(): try: - labels = {int(k): v for k, v in json.loads(labels_path.read_text(encoding="utf-8")).items()} + labels = { + int(k): v + for k, v in json.loads(labels_path.read_text(encoding="utf-8")).items() + } except Exception: labels = {cid: f"Community {cid}" for cid in communities} else: @@ -2414,16 +2693,30 @@ def main() -> None: questions = suggest_questions(G, communities, labels) tokens = {"input": 0, "output": 0} from graphify.export import _git_head as _gh + _commit = _gh() - report = generate(G, communities, cohesion, labels, gods, surprises, - {"warning": "cluster-only mode — file stats not available"}, - tokens, str(watch_path), suggested_questions=questions, - min_community_size=min_community_size, built_at_commit=_commit) + report = generate( + G, + communities, + cohesion, + labels, + gods, + surprises, + {"warning": "cluster-only mode — file stats not available"}, + tokens, + str(watch_path), + suggested_questions=questions, + min_community_size=min_community_size, + built_at_commit=_commit, + ) (out / "GRAPH_REPORT.md").write_text(report, encoding="utf-8") from graphify.export import backup_if_protected as _backup + _backup(out) to_json(G, communities, str(out / "graph.json")) - labels_path.write_text(json.dumps({str(k): v for k, v in labels.items()}, ensure_ascii=False), encoding="utf-8") + labels_path.write_text( + json.dumps({str(k): v for k, v in labels.items()}, ensure_ascii=False), encoding="utf-8" + ) # Mirror watch.py pattern: gate to_html so core outputs (graph.json + # GRAPH_REPORT.md) always land. Honor --no-viz explicitly; otherwise @@ -2433,20 +2726,27 @@ def main() -> None: if no_viz: if html_target.exists(): html_target.unlink() - print(f"Done - {len(communities)} communities. GRAPH_REPORT.md and graph.json updated (--no-viz; graph.html removed).") + print( + f"Done - {len(communities)} communities. GRAPH_REPORT.md and graph.json updated (--no-viz; graph.html removed)." + ) else: try: to_html(G, communities, str(html_target), community_labels=labels or None) - print(f"Done - {len(communities)} communities. GRAPH_REPORT.md, graph.json and graph.html updated.") + print( + f"Done - {len(communities)} communities. GRAPH_REPORT.md, graph.json and graph.html updated." + ) except ValueError as viz_err: if html_target.exists(): html_target.unlink() print(f"Skipped graph.html: {viz_err}") - print(f"Done - {len(communities)} communities. GRAPH_REPORT.md and graph.json updated.") + print( + f"Done - {len(communities)} communities. GRAPH_REPORT.md and graph.json updated." + ) elif cmd == "update": force = os.environ.get("GRAPHIFY_FORCE", "").lower() in ("1", "true", "yes") no_cluster = False + no_viz = False args = sys.argv[2:] watch_arg: str | None = None for a in args: @@ -2456,6 +2756,9 @@ def main() -> None: if a == "--no-cluster": no_cluster = True continue + if a == "--no-viz": + no_viz = True + continue if a.startswith("-"): print(f"error: unknown update option: {a}", file=sys.stderr) sys.exit(2) @@ -2476,14 +2779,33 @@ def main() -> None: if not watch_path.exists(): print(f"error: path not found: {watch_path}", file=sys.stderr) sys.exit(1) + + # PR 7 go/no-go gate: "no silent fallback to simple graph behavior." + # No special handling is needed here: watch._rebuild_code now inherits + # the saved graph.json profile (it reads the on-disk `multigraph` flag + # and rebuilds via build_from_json(multigraph=...), re-stamping + # multigraph/directed + graphify_profile on write). A multidigraph + # graph.json therefore round-trips through `graphify update` as a + # MultiDiGraph with its keyed parallel edges intact — never silently + # collapsed to a simple graph — and simple/digraph graphs update exactly + # as before. from graphify.watch import _rebuild_code + print(f"Re-extracting code files in {watch_path} (no LLM needed)...") # Interactive CLI: block on the per-repo lock rather than skip, so the # user sees their explicit `graphify update` complete instead of # exiting silently when a hook-driven rebuild happens to be running. - ok = _rebuild_code(watch_path, force=force, no_cluster=no_cluster, block_on_lock=True) + ok = _rebuild_code( + watch_path, + force=force, + no_cluster=no_cluster, + no_viz=no_viz, + block_on_lock=True, + ) if ok: - print("Code graph updated. For doc/paper/image changes run /graphify --update in your AI assistant.") + print( + "Code graph updated. For doc/paper/image changes run /graphify --update in your AI assistant." + ) if not ( os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY") @@ -2491,7 +2813,9 @@ def main() -> None: or os.environ.get("DEEPSEEK_API_KEY") or os.environ.get("GRAPHIFY_NO_TIPS") ): - print("Tip: set GEMINI_API_KEY or GOOGLE_API_KEY to use Gemini for semantic extraction.") + print( + "Tip: set GEMINI_API_KEY or GOOGLE_API_KEY to use Gemini for semantic extraction." + ) else: print("Nothing to update or rebuild failed - check output above.", file=sys.stderr) sys.exit(1) @@ -2506,6 +2830,7 @@ def main() -> None: print("Usage: graphify check-update ", file=sys.stderr) sys.exit(1) from graphify.watch import check_update + check_update(Path(sys.argv[2]).resolve()) sys.exit(0) elif cmd == "tree": @@ -2516,6 +2841,7 @@ def main() -> None: # showing top-K outbound edges per symbol. from typing import Optional as _Opt from graphify.tree_html import write_tree_html, DEFAULT_MAX_CHILDREN + graph_path = Path(_GRAPHIFY_OUT) / "graph.json" output_path: "_Opt[Path]" = None root: "_Opt[str]" = None @@ -2527,24 +2853,34 @@ def main() -> None: while i_arg < len(args): a = args[i_arg] if a == "--graph" and i_arg + 1 < len(args): - graph_path = Path(args[i_arg + 1]); i_arg += 2 + graph_path = Path(args[i_arg + 1]) + i_arg += 2 elif a == "--output" and i_arg + 1 < len(args): - output_path = Path(args[i_arg + 1]); i_arg += 2 + output_path = Path(args[i_arg + 1]) + i_arg += 2 elif a == "--root" and i_arg + 1 < len(args): - root = args[i_arg + 1]; i_arg += 2 + root = args[i_arg + 1] + i_arg += 2 elif a == "--max-children" and i_arg + 1 < len(args): - max_children = int(args[i_arg + 1]); i_arg += 2 + max_children = int(args[i_arg + 1]) + i_arg += 2 elif a == "--top-k-edges" and i_arg + 1 < len(args): - top_k_edges = int(args[i_arg + 1]); i_arg += 2 + top_k_edges = int(args[i_arg + 1]) + i_arg += 2 elif a == "--label" and i_arg + 1 < len(args): - project_label = args[i_arg + 1]; i_arg += 2 + project_label = args[i_arg + 1] + i_arg += 2 elif a in ("-h", "--help"): print("Usage: graphify tree [--graph PATH] [--output HTML]") print(" --graph PATH path to graph.json (default graphify-out/graph.json)") print(" --output HTML output path (default graphify-out/GRAPH_TREE.html)") - print(" --root PATH filesystem root (default: longest common dir of all source_files)") + print( + " --root PATH filesystem root (default: longest common dir of all source_files)" + ) print(" --max-children N cap visible children per node (default 200)") - print(" --top-k-edges N pre-compute top-K outbound edges per symbol (default 12)") + print( + " --top-k-edges N pre-compute top-K outbound edges per symbol (default 12)" + ) print(" --label NAME project label shown in the page header") return else: @@ -2556,9 +2892,12 @@ def main() -> None: if output_path is None: output_path = graph_path.parent / "GRAPH_TREE.html" out = write_tree_html( - graph_path=graph_path, output_path=output_path, - root=root, max_children=max_children, - top_k_edges=top_k_edges, project_label=project_label, + graph_path=graph_path, + output_path=output_path, + root=root, + max_children=max_children, + top_k_edges=top_k_edges, + project_label=project_label, ) size_kb = out.stat().st_size / 1024 print(f"wrote {out} ({size_kb:.1f} KB)") @@ -2583,6 +2922,7 @@ def main() -> None: _MERGE_MAX_NODES = 100_000 import networkx as _nx from networkx.readwrite import json_graph as _jg + def _load_graph(p: str): path_obj = Path(p) try: @@ -2598,45 +2938,148 @@ def _load_graph(p: str): return _jg.node_link_graph(data, edges="links"), data except TypeError: return _jg.node_link_graph(data), data + try: - G_cur, _ = _load_graph(_current_path) - G_oth, _ = _load_graph(_other_path) + current_graph, current_data = _load_graph(_current_path) + other_graph, other_data = _load_graph(_other_path) except Exception as exc: print(f"[graphify merge-driver] error loading graphs: {exc}", file=sys.stderr) sys.exit(1) # surface the conflict so git doesn't accept a corrupt merge - merged = _nx.compose(G_cur, G_oth) - if merged.number_of_nodes() > _MERGE_MAX_NODES: + + # Class-safe union. Reuse the Phase A normalizer so a class mismatch + # (e.g. simple current + MultiDiGraph other) can never reach + # ``nx.compose`` raising NetworkXError. ``normalize_graphs_for_global`` + # with no explicit target infers the target (multidigraph if either + # side is multi; else digraph if either directed; else simple) and + # returns BOTH inputs realized as that one class, in order. + from graphify.global_graph import ( + GlobalGraphRecoveryError, + detect_pre_profile, + normalize_graphs_for_global, + refuse_pre_profile_upgrade, + ) + + try: + (current_norm, other_norm), target_type = normalize_graphs_for_global( + [current_graph, other_graph] + ) + except Exception as exc: + print(f"[graphify merge-driver] error normalizing graphs: {exc}", file=sys.stderr) + sys.exit(1) + + # Recovery refusal: refuse to UPGRADE a pre-profile input (one with no + # graphify_profile / multigraph / directed markers — possibly already a + # silently-collapsed simple graph) to a multidigraph target, since the + # lost parallel edges cannot be reconstructed in place. Leave both files + # unmutated. Only the irreversible multidigraph upgrade is refused; + # simple/digraph merges of a pre-profile graph proceed normally. + # Only the OVERWRITTEN file (current) is protected: an in-place + # multidigraph upgrade of a pre-profile current graph is irreversible. + # `other` is read-only (merged in, never rewritten), so its pre-profile + # status implies no unreconstructable in-place loss and must not block. + if target_type == "multidigraph": + try: + refuse_pre_profile_upgrade( + current_data, + target_type, + graph_label="merge target graph", + graph_path=_current_path, + recovery_hint=( + "Regenerate or recreate this graph.json from source before retrying " + "the merge, or resolve the file manually from source-backed inputs" + ), + ) + except GlobalGraphRecoveryError as exc: + print(f"[graphify merge-driver] {exc}", file=sys.stderr) + sys.exit(1) + if detect_pre_profile(current_data): + # Defensive: detect_pre_profile is the same predicate + # refuse_pre_profile_upgrade uses, so the branch above + # already exited; this guards against future divergence. + print( + f"[graphify merge-driver] refusing to upgrade pre-profile graph " + f"{_current_path} to multidigraph", + file=sys.stderr, + ) + sys.exit(1) + + # KEY-AWARE compose, mirroring graphify.global_graph.global_add: start + # from the normalized current graph and replay the normalized other + # graph. For a multidigraph target iterate ``edges(keys=True)`` and + # ``add_edge(u, v, key=key, ...)`` so parallel edges survive distinctly + # AND a repeated merge of the same inputs overwrites the same + # ``(u, v, key)`` slots instead of accumulating fresh auto-int keys — + # that keyless drift is exactly what makes a naive ``nx.compose`` merge + # non-idempotent. Simple/digraph targets keep one edge per pair. + merged_graph = current_norm + for _node, _ndata in other_norm.nodes(data=True): + merged_graph.add_node(_node, **_ndata) + if isinstance(other_norm, (_nx.MultiGraph, _nx.MultiDiGraph)): + for _u, _v, _key, _edata in other_norm.edges(keys=True, data=True): + merged_graph.add_edge(_u, _v, key=_key, **_edata) + else: + for _u, _v, _edata in other_norm.edges(data=True): + merged_graph.add_edge(_u, _v, **_edata) + + if merged_graph.number_of_nodes() > _MERGE_MAX_NODES: print( - f"[graphify merge-driver] merged graph has {merged.number_of_nodes()} nodes, " + f"[graphify merge-driver] merged graph has {merged_graph.number_of_nodes()} nodes, " f"exceeds {_MERGE_MAX_NODES}-node cap; aborting merge.", file=sys.stderr, ) sys.exit(1) - try: - out_data = _jg.node_link_data(merged, edges="links") - except TypeError: - out_data = _jg.node_link_data(merged) - Path(_current_path).write_text(json.dumps(out_data, indent=2), encoding="utf-8") + + # Back up the existing target before the irreversible overwrite, then + # write via ``to_json`` so the graphify_profile (graph_type) is persisted + # and round-trips on the next load. ``force=True`` because a merge + # legitimately overwrites ``current`` (the shrink-guard would otherwise + # block a merge that drops nodes via dedup); ``communities={}`` because a + # merged graph is not (re)clustered here. + from graphify.export import to_json as _to_json + + _backup_merge_target(Path(_current_path)) + _to_json(merged_graph, {}, _current_path, force=True) sys.exit(0) elif cmd == "merge-graphs": # graphify merge-graphs graph1.json graph2.json ... --out merged.json + # [--multigraph | --simple] + # Optional target flag controls the merged graph class. By DEFAULT the + # target is INFERRED (multidigraph if any input is multi; else digraph if + # any directed; else simple) so multigraph inputs never silently + # collapse. --simple forces a simple projection (the normalizer warns and + # collapses parallel edges — an explicit, audible choice); --multigraph + # forces a keyed multidigraph. args = sys.argv[2:] graph_paths: list[Path] = [] out_path = Path(_GRAPHIFY_OUT) / "merged-graph.json" + explicit_target: str | None = None i = 0 while i < len(args): if args[i] == "--out" and i + 1 < len(args): - out_path = Path(args[i + 1]); i += 2 + out_path = Path(args[i + 1]) + i += 2 + elif args[i] == "--multigraph": + explicit_target = "multidigraph" + i += 1 + elif args[i] == "--simple": + explicit_target = "simple" + i += 1 else: - graph_paths.append(Path(args[i])); i += 1 + graph_paths.append(Path(args[i])) + i += 1 if len(graph_paths) < 2: - print("Usage: graphify merge-graphs [...] [--out merged.json]", file=sys.stderr) + print( + "Usage: graphify merge-graphs [...] " + "[--out merged.json] [--multigraph | --simple]", + file=sys.stderr, + ) sys.exit(1) import networkx as _nx from networkx.readwrite import json_graph as _jg from graphify.build import prefix_graph_for_global as _prefix - graphs = [] + + loaded_graphs = [] for gp in graph_paths: if not gp.exists(): print(f"error: not found: {gp}", file=sys.stderr) @@ -2648,27 +3091,70 @@ def _load_graph(p: str): if "links" not in data and "edges" in data: data = dict(data, links=data["edges"]) try: - G = _jg.node_link_graph(data, edges="links") + input_graph = _jg.node_link_graph(data, edges="links") except TypeError: - G = _jg.node_link_graph(data) - graphs.append(G) - merged = _nx.Graph() - for G, gp in zip(graphs, graph_paths): + input_graph = _jg.node_link_graph(data) + loaded_graphs.append(input_graph) + # Prefix every input for cross-repo isolation FIRST (relabel preserves + # multigraph keys), then normalize the whole batch to one common class + # via the Phase A helper. Replacing the old hard-coded ``_nx.Graph()`` + # start removes the silent collapse: the resolved class is multidigraph + # if any input is multi (unless --simple is given, which warns + projects). + from graphify.global_graph import normalize_graphs_for_global + + prefixed_graphs = [] + for input_graph, gp in zip(loaded_graphs, graph_paths): repo_tag = gp.parent.parent.name # graphify-out/../ → repo dir name - prefixed = _prefix(G, repo_tag) - merged = _nx.compose(merged, prefixed) + prefixed_graphs.append(_prefix(input_graph, repo_tag)) + try: - out_data = _jg.node_link_data(merged, edges="links") - except TypeError: - out_data = _jg.node_link_data(merged) + normalized_graphs, target_type = normalize_graphs_for_global( + prefixed_graphs, target_type=explicit_target + ) + except Exception as exc: + print(f"[graphify merge-graphs] error normalizing graphs: {exc}", file=sys.stderr) + sys.exit(1) + + # KEY-AWARE compose into the resolved class (mirrors global_add). For a + # multidigraph target replay ``edges(keys=True)`` so parallel edges keep + # distinct keys and a repeated merge of the same inputs overwrites the + # same ``(u, v, key)`` slots (idempotent) instead of drifting on fresh + # auto-int keys. Prefixing already isolates repos, so cross-graph node + # collisions cannot occur; same-class composition is safe. + target_cls = type(normalized_graphs[0]) if normalized_graphs else _nx.Graph + merged_graph = target_cls() + for ng in normalized_graphs: + merged_graph.graph.update(ng.graph) + for _node, _ndata in ng.nodes(data=True): + merged_graph.add_node(_node, **_ndata) + if isinstance(ng, (_nx.MultiGraph, _nx.MultiDiGraph)): + for _u, _v, _key, _edata in ng.edges(keys=True, data=True): + merged_graph.add_edge(_u, _v, key=_key, **_edata) + else: + for _u, _v, _edata in ng.edges(data=True): + merged_graph.add_edge(_u, _v, **_edata) + + # Back up an existing target before overwrite, then write via ``to_json`` + # so the graphify_profile (graph_type=target_type) persists and the class + # round-trips on reload. ``force=True``: merge-graphs deliberately + # (re)writes the merged output; ``communities={}``: not clustered here. + from graphify.export import to_json as _to_json + out_path.parent.mkdir(parents=True, exist_ok=True) - out_path.write_text(json.dumps(out_data, indent=2), encoding="utf-8") - print(f"Merged {len(graphs)} graphs -> {merged.number_of_nodes()} nodes, {merged.number_of_edges()} edges") + _backup_merge_target(out_path) + _to_json(merged_graph, {}, str(out_path), force=True) + print( + f"Merged {len(loaded_graphs)} graphs -> {merged_graph.number_of_nodes()} nodes, " + f"{merged_graph.number_of_edges()} edges ({target_type})" + ) print(f"Written to: {out_path}") elif cmd == "clone": if len(sys.argv) < 3: - print("Usage: graphify clone [--branch ] [--out ]", file=sys.stderr) + print( + "Usage: graphify clone [--branch ] [--out ]", + file=sys.stderr, + ) sys.exit(1) url = sys.argv[2] branch: str | None = None @@ -2677,9 +3163,11 @@ def _load_graph(p: str): i = 0 while i < len(args): if args[i] == "--branch" and i + 1 < len(args): - branch = args[i + 1]; i += 2 + branch = args[i + 1] + i += 2 elif args[i] == "--out" and i + 1 < len(args): - out_dir = Path(args[i + 1]); i += 2 + out_dir = Path(args[i + 1]) + i += 2 else: i += 1 local_path = _clone_repo(url, branch=branch, out_dir=out_dir) @@ -2689,15 +3177,29 @@ def _load_graph(p: str): subcmd = sys.argv[2] if len(sys.argv) > 2 else "" if subcmd not in ("html", "callflow-html", "obsidian", "wiki", "svg", "graphml", "neo4j"): print("Usage: graphify export ", file=sys.stderr) - print(" html [--graph PATH] [--labels PATH] [--node-limit N] [--no-viz]", file=sys.stderr) - print(" callflow-html [GRAPH|DIR] [--graph PATH] [--labels PATH] [--report PATH] [--sections PATH] [--output HTML]", file=sys.stderr) - print(" [--lang auto|zh-CN|en] [--max-sections N] [--diagram-scale N]", file=sys.stderr) + print( + " html [--graph PATH] [--labels PATH] [--node-limit N] [--no-viz]", + file=sys.stderr, + ) + print( + " callflow-html [GRAPH|DIR] [--graph PATH] [--labels PATH] [--report PATH] [--sections PATH] [--output HTML]", + file=sys.stderr, + ) + print( + " [--lang auto|zh-CN|en] [--max-sections N] [--diagram-scale N]", + file=sys.stderr, + ) print(" obsidian [--graph PATH] [--labels PATH] [--dir PATH]", file=sys.stderr) print(" wiki [--graph PATH] [--labels PATH]", file=sys.stderr) print(" svg [--graph PATH] [--labels PATH]", file=sys.stderr) print(" graphml [--graph PATH]", file=sys.stderr) - print(" neo4j [--graph PATH] [--push URI] [--user U] [--password P]", file=sys.stderr) - print(" (or set NEO4J_PASSWORD instead of --password to keep it off argv)", file=sys.stderr) + print( + " neo4j [--graph PATH] [--push URI] [--user U] [--password P]", file=sys.stderr + ) + print( + " (or set NEO4J_PASSWORD instead of --password to keep it off argv)", + file=sys.stderr, + ) sys.exit(1) # Parse shared args @@ -2741,27 +3243,37 @@ def _load_graph(p: str): report_path_explicit = True i += 2 elif a == "--sections" and i + 1 < len(args): - sections_path = Path(args[i + 1]); i += 2 + sections_path = Path(args[i + 1]) + i += 2 elif a == "--output" and i + 1 < len(args): callflow_output = Path(args[i + 1]).expanduser() if not callflow_output.is_absolute(): callflow_output = Path.cwd() / callflow_output i += 2 elif a == "--lang" and i + 1 < len(args): - callflow_lang = args[i + 1]; i += 2 + callflow_lang = args[i + 1] + i += 2 elif a == "--max-sections" and i + 1 < len(args): - callflow_max_sections = int(args[i + 1]); i += 2 + callflow_max_sections = int(args[i + 1]) + i += 2 elif a == "--diagram-scale" and i + 1 < len(args): - callflow_diagram_scale = float(args[i + 1]); i += 2 + callflow_diagram_scale = float(args[i + 1]) + i += 2 elif a == "--max-diagram-nodes" and i + 1 < len(args): - callflow_max_diagram_nodes = int(args[i + 1]); i += 2 + callflow_max_diagram_nodes = int(args[i + 1]) + i += 2 elif a == "--max-diagram-edges" and i + 1 < len(args): - callflow_max_diagram_edges = int(args[i + 1]); i += 2 + callflow_max_diagram_edges = int(args[i + 1]) + i += 2 elif a in ("-h", "--help") and subcmd == "callflow-html": - print("Usage: graphify export callflow-html [GRAPH|DIR] [--graph PATH] [--labels PATH]") + print( + "Usage: graphify export callflow-html [GRAPH|DIR] [--graph PATH] [--labels PATH]" + ) print(" --report PATH path to GRAPH_REPORT.md") print(" --sections PATH JSON section definitions") - print(" --output HTML output path (default graphify-out/-callflow.html)") + print( + " --output HTML output path (default graphify-out/-callflow.html)" + ) print(" --lang LANG auto, zh-CN, en, etc. (default auto)") print(" --max-sections N maximum auto-derived sections (default 15)") print(" --diagram-scale N Mermaid diagram scale (default 1.0)") @@ -2769,17 +3281,23 @@ def _load_graph(p: str): print(" --max-diagram-edges N representative edges per section (default 24)") sys.exit(0) elif a == "--node-limit" and i + 1 < len(args): - node_limit = int(args[i + 1]); i += 2 + node_limit = int(args[i + 1]) + i += 2 elif a == "--no-viz": - no_viz = True; i += 1 + no_viz = True + i += 1 elif a == "--dir" and i + 1 < len(args): - obsidian_dir = Path(args[i + 1]); i += 2 + obsidian_dir = Path(args[i + 1]) + i += 2 elif a == "--push" and i + 1 < len(args): - neo4j_uri = args[i + 1]; i += 2 + neo4j_uri = args[i + 1] + i += 2 elif a == "--user" and i + 1 < len(args): - neo4j_user = args[i + 1]; i += 2 + neo4j_user = args[i + 1] + i += 2 elif a == "--password" and i + 1 < len(args): - neo4j_password = args[i + 1]; i += 2 + neo4j_password = args[i + 1] + i += 2 elif subcmd == "callflow-html" and not a.startswith("-") and not graph_path_explicit: candidate = Path(a) if candidate.name == "graph.json" or candidate.suffix.lower() == ".json": @@ -2804,11 +3322,15 @@ def _load_graph(p: str): report_path = report_path.expanduser() if not graph_path.exists(): - print(f"error: graph not found: {graph_path}. Run /graphify first.", file=sys.stderr) + print( + f"error: graph not found: {graph_path}. Run /graphify first.", + file=sys.stderr, + ) sys.exit(1) if subcmd == "callflow-html": from graphify.callflow_html import write_callflow_html as _write_callflow_html + out = _write_callflow_html( graph=graph_path, report=report_path, @@ -2826,7 +3348,6 @@ def _load_graph(p: str): sys.exit(0) from networkx.readwrite import json_graph as _jg - from graphify.build import build_from_json as _bfj _enforce_graph_size_cap_or_exit(graph_path) _raw = json.loads(graph_path.read_text(encoding="utf-8")) @@ -2873,36 +3394,52 @@ def _load_graph(p: str): labels: dict[int, str] = {} if labels_path.exists(): - labels = {int(k): v for k, v in json.loads(labels_path.read_text(encoding="utf-8")).items()} + labels = { + int(k): v for k, v in json.loads(labels_path.read_text(encoding="utf-8")).items() + } out_dir = graph_path.parent if subcmd == "html": from graphify.export import to_html as _to_html + if no_viz: html_target = out_dir / "graph.html" if html_target.exists(): html_target.unlink() print("--no-viz: skipped graph.html") else: - _to_html(G, communities, str(out_dir / "graph.html"), - community_labels=labels or None, node_limit=node_limit) + _to_html( + G, + communities, + str(out_dir / "graph.html"), + community_labels=labels or None, + node_limit=node_limit, + ) if G.number_of_nodes() <= node_limit: - print(f"graph.html written - open in any browser, no server needed") + print("graph.html written - open in any browser, no server needed") elif subcmd == "obsidian": from graphify.export import to_obsidian as _to_obsidian, to_canvas as _to_canvas - n = _to_obsidian(G, communities, str(obsidian_dir), - community_labels=labels or None, cohesion=cohesion or None) + + n = _to_obsidian( + G, + communities, + str(obsidian_dir), + community_labels=labels or None, + cohesion=cohesion or None, + ) print(f"Obsidian vault: {n} notes in {obsidian_dir}/") - _to_canvas(G, communities, str(obsidian_dir / "graph.canvas"), - community_labels=labels or None) + _to_canvas( + G, communities, str(obsidian_dir / "graph.canvas"), community_labels=labels or None + ) print(f"Canvas: {obsidian_dir}/graph.canvas") print(f"Open {obsidian_dir}/ as a vault in Obsidian.") elif subcmd == "wiki": from graphify.wiki import to_wiki as _to_wiki from graphify.analyze import god_nodes as _god_nodes + if not communities: print( "error: .graphify_analysis.json is missing or empty — refusing to export wiki to prevent data loss.\n" @@ -2912,39 +3449,53 @@ def _load_graph(p: str): sys.exit(1) if not gods_data: gods_data = _god_nodes(G) - n = _to_wiki(G, communities, str(out_dir / "wiki"), - community_labels=labels or None, cohesion=cohesion or None, - god_nodes_data=gods_data) + n = _to_wiki( + G, + communities, + str(out_dir / "wiki"), + community_labels=labels or None, + cohesion=cohesion or None, + god_nodes_data=gods_data, + ) print(f"Wiki: {n} articles written to {out_dir}/wiki/") print(f" {out_dir}/wiki/index.md -> agent entry point") elif subcmd == "svg": from graphify.export import to_svg as _to_svg - _to_svg(G, communities, str(out_dir / "graph.svg"), - community_labels=labels or None) - print(f"graph.svg written - embeds in Obsidian, Notion, GitHub READMEs") + + _to_svg(G, communities, str(out_dir / "graph.svg"), community_labels=labels or None) + print("graph.svg written - embeds in Obsidian, Notion, GitHub READMEs") elif subcmd == "graphml": from graphify.export import to_graphml as _to_graphml + _to_graphml(G, communities, str(out_dir / "graph.graphml")) - print(f"graph.graphml written - open in Gephi, yEd, or any GraphML tool") + print("graph.graphml written - open in Gephi, yEd, or any GraphML tool") elif subcmd == "neo4j": if neo4j_uri: from graphify.export import push_to_neo4j as _push + if neo4j_password is None: print("error: --password required for --push", file=sys.stderr) sys.exit(1) - result = _push(G, uri=neo4j_uri, user=neo4j_user, - password=neo4j_password, communities=communities) + result = _push( + G, + uri=neo4j_uri, + user=neo4j_user, + password=neo4j_password, + communities=communities, + ) print(f"Pushed to Neo4j: {result['nodes']} nodes, {result['edges']} edges") else: from graphify.export import to_cypher as _to_cypher + _to_cypher(G, str(out_dir / "cypher.txt")) print(f"cypher.txt written - import with: cypher-shell < {out_dir}/cypher.txt") elif cmd == "benchmark": from graphify.benchmark import run_benchmark, print_benchmark + graph_path = sys.argv[2] if len(sys.argv) > 2 else "graphify-out/graph.json" _enforce_graph_size_cap_or_exit(Path(graph_path)) # Try to load corpus_words from detect output @@ -2954,8 +3505,11 @@ def _load_graph(p: str): try: detect_data = json.loads(detect_path.read_text(encoding="utf-8")) corpus_words = detect_data.get("total_words") - except Exception: - pass + except Exception as exc: + print( + f"[graphify] warning: could not read .graphify_detect.json: {exc}", + file=sys.stderr, + ) result = run_benchmark(graph_path, corpus_words=corpus_words) print_benchmark(result) @@ -2967,6 +3521,7 @@ def _load_graph(p: str): global_list as _global_list, global_path as _global_path, ) + if subcmd == "add": # graphify global add [--as ] args = sys.argv[3:] @@ -2975,9 +3530,11 @@ def _load_graph(p: str): i = 0 while i < len(args): if args[i] == "--as" and i + 1 < len(args): - tag = args[i + 1]; i += 2 + tag = args[i + 1] + i += 2 elif not source: - source = Path(args[i]); i += 1 + source = Path(args[i]) + i += 1 else: i += 1 if not source: @@ -2989,19 +3546,24 @@ def _load_graph(p: str): if result["skipped"]: print(f"'{tag}' unchanged since last add - global graph not modified.") else: - print(f"Added '{tag}' to global graph: +{result['nodes_added']} nodes, " - f"-{result['nodes_removed']} pruned. Global: {_global_path()}") + print( + f"Added '{tag}' to global graph: +{result['nodes_added']} nodes, " + f"-{result['nodes_removed']} pruned. Global: {_global_path()}" + ) except Exception as exc: - print(f"error: {exc}", file=sys.stderr); sys.exit(1) + print(f"error: {exc}", file=sys.stderr) + sys.exit(1) elif subcmd == "remove": tag = sys.argv[3] if len(sys.argv) > 3 else "" if not tag: - print("Usage: graphify global remove ", file=sys.stderr); sys.exit(1) + print("Usage: graphify global remove ", file=sys.stderr) + sys.exit(1) try: removed = _global_remove(tag) print(f"Removed '{tag}' from global graph ({removed} nodes pruned).") except KeyError as exc: - print(f"error: {exc}", file=sys.stderr); sys.exit(1) + print(f"error: {exc}", file=sys.stderr) + sys.exit(1) elif subcmd == "list": repos = _global_list() if not repos: @@ -3009,12 +3571,14 @@ def _load_graph(p: str): else: print(f"Global graph: {_global_path()}") for tag, info in repos.items(): - print(f" {tag}: {info.get('node_count', '?')} nodes, added {info.get('added_at', '?')[:10]}") + print( + f" {tag}: {info.get('node_count', '?')} nodes, added {info.get('added_at', '?')[:10]}" + ) elif subcmd == "path": print(_global_path()) else: - print("Usage: graphify global [add|remove|list|path]", file=sys.stderr); sys.exit(1) - + print("Usage: graphify global [add|remove|list|path]", file=sys.stderr) + sys.exit(1) elif cmd == "extract": # Headless full-pipeline extraction for CI / scripts (#698). # Runs detect -> AST extraction on code -> semantic LLM extraction on @@ -3026,6 +3590,7 @@ def _load_graph(p: str): print( "Usage: graphify extract [--backend gemini|kimi|claude|openai|deepseek|ollama] " "[--model M] [--mode deep] [--out DIR] [--google-workspace] [--no-cluster] " + "[--multigraph|--simple] " "[--max-workers N] [--token-budget N] [--max-concurrency N] " "[--api-timeout S]", file=sys.stderr, @@ -3046,6 +3611,13 @@ def _load_graph(p: str): google_workspace = False global_merge = False global_repo_tag: str | None = None + # Graph class selection (PR 9). None = STICKY: inherit the existing + # graphify-out/graph.json profile (multidigraph stays multidigraph, + # otherwise the historical simple/directed default). True = force a + # keyed MultiDiGraph (parallel edges). False = explicit downgrade to a + # simple graph even when the existing graph.json is a multigraph. + # --multigraph and --simple are mutually exclusive. + multigraph_flag: bool | None = None # Performance/tuning knobs (issue #792). None means "use library default". cli_max_workers: int | None = None cli_token_budget: int | None = None @@ -3083,59 +3655,104 @@ def _parse_float(name: str, raw: str) -> float: while i < len(args): a = args[i] if a == "--backend" and i + 1 < len(args): - backend = args[i + 1]; i += 2 + backend = args[i + 1] + i += 2 elif a.startswith("--backend="): - backend = a.split("=", 1)[1]; i += 1 + backend = a.split("=", 1)[1] + i += 1 elif a == "--model" and i + 1 < len(args): - model = args[i + 1]; i += 2 + model = args[i + 1] + i += 2 elif a.startswith("--model="): - model = a.split("=", 1)[1]; i += 1 + model = a.split("=", 1)[1] + i += 1 elif a == "--mode" and i + 1 < len(args): - extract_mode = args[i + 1]; i += 2 + extract_mode = args[i + 1] + i += 2 elif a.startswith("--mode="): - extract_mode = a.split("=", 1)[1]; i += 1 + extract_mode = a.split("=", 1)[1] + i += 1 elif a == "--out" and i + 1 < len(args): - out_dir = Path(args[i + 1]); i += 2 + out_dir = Path(args[i + 1]) + i += 2 elif a.startswith("--out="): - out_dir = Path(a.split("=", 1)[1]); i += 1 + out_dir = Path(a.split("=", 1)[1]) + i += 1 elif a == "--no-cluster": - no_cluster = True; i += 1 + no_cluster = True + i += 1 elif a == "--dedup-llm": - dedup_llm = True; i += 1 + dedup_llm = True + i += 1 elif a == "--google-workspace": - google_workspace = True; i += 1 + google_workspace = True + i += 1 elif a == "--global": - global_merge = True; i += 1 + global_merge = True + i += 1 + elif a == "--multigraph": + if multigraph_flag is False: + print( + "error: --multigraph and --simple are mutually exclusive", + file=sys.stderr, + ) + sys.exit(2) + multigraph_flag = True + i += 1 + elif a == "--simple": + if multigraph_flag is True: + print( + "error: --multigraph and --simple are mutually exclusive", + file=sys.stderr, + ) + sys.exit(2) + multigraph_flag = False + i += 1 elif a == "--as" and i + 1 < len(args): - global_repo_tag = args[i + 1]; i += 2 + global_repo_tag = args[i + 1] + i += 2 elif a == "--max-workers" and i + 1 < len(args): - cli_max_workers = _parse_int("--max-workers", args[i + 1]); i += 2 + cli_max_workers = _parse_int("--max-workers", args[i + 1]) + i += 2 elif a.startswith("--max-workers="): - cli_max_workers = _parse_int("--max-workers", a.split("=", 1)[1]); i += 1 + cli_max_workers = _parse_int("--max-workers", a.split("=", 1)[1]) + i += 1 elif a == "--token-budget" and i + 1 < len(args): - cli_token_budget = _parse_int("--token-budget", args[i + 1]); i += 2 + cli_token_budget = _parse_int("--token-budget", args[i + 1]) + i += 2 elif a.startswith("--token-budget="): - cli_token_budget = _parse_int("--token-budget", a.split("=", 1)[1]); i += 1 + cli_token_budget = _parse_int("--token-budget", a.split("=", 1)[1]) + i += 1 elif a == "--max-concurrency" and i + 1 < len(args): - cli_max_concurrency = _parse_int("--max-concurrency", args[i + 1]); i += 2 + cli_max_concurrency = _parse_int("--max-concurrency", args[i + 1]) + i += 2 elif a.startswith("--max-concurrency="): - cli_max_concurrency = _parse_int("--max-concurrency", a.split("=", 1)[1]); i += 1 + cli_max_concurrency = _parse_int("--max-concurrency", a.split("=", 1)[1]) + i += 1 elif a == "--api-timeout" and i + 1 < len(args): - cli_api_timeout = _parse_float("--api-timeout", args[i + 1]); i += 2 + cli_api_timeout = _parse_float("--api-timeout", args[i + 1]) + i += 2 elif a.startswith("--api-timeout="): - cli_api_timeout = _parse_float("--api-timeout", a.split("=", 1)[1]); i += 1 + cli_api_timeout = _parse_float("--api-timeout", a.split("=", 1)[1]) + i += 1 elif a == "--resolution" and i + 1 < len(args): - cli_resolution = _parse_float("--resolution", args[i + 1]); i += 2 + cli_resolution = _parse_float("--resolution", args[i + 1]) + i += 2 elif a.startswith("--resolution="): - cli_resolution = _parse_float("--resolution", a.split("=", 1)[1]); i += 1 + cli_resolution = _parse_float("--resolution", a.split("=", 1)[1]) + i += 1 elif a == "--exclude-hubs" and i + 1 < len(args): - cli_exclude_hubs = float(args[i + 1]); i += 2 + cli_exclude_hubs = float(args[i + 1]) + i += 2 elif a.startswith("--exclude-hubs="): - cli_exclude_hubs = float(a.split("=", 1)[1]); i += 1 + cli_exclude_hubs = float(a.split("=", 1)[1]) + i += 1 elif a == "--exclude" and i + 1 < len(args): - cli_excludes.append(args[i + 1]); i += 2 + cli_excludes.append(args[i + 1]) + i += 2 elif a.startswith("--exclude="): - cli_excludes.append(a.split("=", 1)[1]); i += 1 + cli_excludes.append(a.split("=", 1)[1]) + i += 1 else: i += 1 @@ -3170,6 +3787,7 @@ def _parse_float(name: str, raw: str) -> float: _format_backend_env_keys, _get_backend_api_key, ) + if backend is None: backend = _detect_backend() if backend is None: @@ -3183,8 +3801,7 @@ def _parse_float(name: str, raw: str) -> float: sys.exit(1) if backend not in _BACKENDS: print( - f"error: unknown backend '{backend}'. " - f"Available: {', '.join(sorted(_BACKENDS))}", + f"error: unknown backend '{backend}'. Available: {', '.join(sorted(_BACKENDS))}", file=sys.stderr, ) sys.exit(1) @@ -3196,18 +3813,18 @@ def _parse_float(name: str, raw: str) -> float: allow_no_key = False if backend == "ollama": from urllib.parse import urlparse - ollama_url = os.environ.get( - "OLLAMA_BASE_URL", - _BACKENDS["ollama"].get("base_url", ""), + + ollama_url = str( + os.environ.get( + "OLLAMA_BASE_URL", + _BACKENDS["ollama"].get("base_url", ""), + ) ) try: host = (urlparse(ollama_url).hostname or "").lower() except Exception: host = "" - allow_no_key = ( - host in ("localhost", "127.0.0.1", "::1") - or host.startswith("127.") - ) + allow_no_key = host in ("localhost", "127.0.0.1", "::1") or host.startswith("127.") elif backend == "bedrock": allow_no_key = bool( os.environ.get("AWS_PROFILE") @@ -3217,6 +3834,7 @@ def _parse_float(name: str, raw: str) -> float: ) elif backend == "claude-cli": import shutil as _shutil + allow_no_key = _shutil.which("claude") is not None if not allow_no_key: print( @@ -3235,7 +3853,7 @@ def _parse_float(name: str, raw: str) -> float: # Resolve output dir. The user-facing contract is "/graphify-out/" # so a fresh checkout writes graphify-out/ at the project root, matching # the skill.md pipeline. - out_root = (out_dir.resolve() if out_dir else target) + out_root = out_dir.resolve() if out_dir else target graphify_out = out_root / "graphify-out" graphify_out.mkdir(parents=True, exist_ok=True) @@ -3244,6 +3862,7 @@ def _parse_float(name: str, raw: str) -> float: detect_incremental as _detect_incremental, save_manifest as _save_manifest, ) + manifest_path = graphify_out / "manifest.json" existing_graph_path = graphify_out / "graph.json" incremental_mode = manifest_path.exists() and existing_graph_path.exists() @@ -3258,7 +3877,11 @@ def _parse_float(name: str, raw: str) -> float: ) else: print(f"[graphify extract] scanning {target}") - detection = _detect(target, google_workspace=google_workspace or None, extra_excludes=cli_excludes or None) + detection = _detect( + target, + google_workspace=google_workspace or None, + extra_excludes=cli_excludes or None, + ) files_by_type = detection.get("files", {}) if incremental_mode: @@ -3296,6 +3919,7 @@ def _parse_float(name: str, raw: str) -> float: ast_result: dict = {"nodes": [], "edges": [], "input_tokens": 0, "output_tokens": 0} if code_files: from graphify.extract import extract as _ast_extract + ast_kwargs: dict = {"cache_root": target} if cli_max_workers is not None: ast_kwargs["max_workers"] = cli_max_workers @@ -3311,16 +3935,20 @@ def _parse_float(name: str, raw: str) -> float: check_semantic_cache as _check_semantic_cache, save_semantic_cache as _save_semantic_cache, ) + sem_result: dict = { - "nodes": [], "edges": [], "hyperedges": [], - "input_tokens": 0, "output_tokens": 0, + "nodes": [], + "edges": [], + "hyperedges": [], + "input_tokens": 0, + "output_tokens": 0, } sem_cache_hits = 0 sem_cache_misses = 0 if semantic_files: sem_paths_str = [str(p) for p in semantic_files] - cached_nodes, cached_edges, cached_hyperedges, uncached_paths = ( - _check_semantic_cache(sem_paths_str, root=target) + cached_nodes, cached_edges, cached_hyperedges, uncached_paths = _check_semantic_cache( + sem_paths_str, root=target ) sem_cache_hits = len(semantic_files) - len(uncached_paths) sem_cache_misses = len(uncached_paths) @@ -3328,10 +3956,14 @@ def _parse_float(name: str, raw: str) -> float: sem_result["edges"].extend(cached_edges) sem_result["hyperedges"].extend(cached_hyperedges) if sem_cache_hits: - print(f"[graphify extract] semantic cache: {sem_cache_hits} hit / {sem_cache_misses} miss") + print( + f"[graphify extract] semantic cache: {sem_cache_hits} hit / {sem_cache_misses} miss" + ) if uncached_paths: - print(f"[graphify extract] semantic extraction on {len(uncached_paths)} files via {backend}...") + print( + f"[graphify extract] semantic extraction on {len(uncached_paths)} files via {backend}..." + ) corpus_kwargs: dict = { "backend": backend, "model": model, @@ -3349,6 +3981,7 @@ def _parse_float(name: str, raw: str) -> float: # Also track per-chunk success so we can fail loudly when # every chunk errors (e.g. missing backend SDK package). _chunk_stats = {"total": 0, "succeeded": 0} + def _progress(idx: int, total: int, _result: dict) -> None: _chunk_stats["total"] = total _chunk_stats["succeeded"] += 1 @@ -3356,6 +3989,7 @@ def _progress(idx: int, total: int, _result: dict) -> None: f"[graphify extract] chunk {idx + 1}/{total} done", flush=True, ) + corpus_kwargs["on_chunk_done"] = _progress try: @@ -3371,7 +4005,13 @@ def _progress(idx: int, total: int, _result: dict) -> None: f"[graphify extract] semantic extraction failed: {exc}", file=sys.stderr, ) - fresh = {"nodes": [], "edges": [], "hyperedges": [], "input_tokens": 0, "output_tokens": 0} + fresh = { + "nodes": [], + "edges": [], + "hyperedges": [], + "input_tokens": 0, + "output_tokens": 0, + } # on_chunk_done only fires after a chunk succeeds. If fresh # semantic extraction was requested and no chunks completed, @@ -3393,7 +4033,10 @@ def _progress(idx: int, total: int, _result: dict) -> None: root=target, ) except Exception as exc: - print(f"[graphify extract] warning: could not write semantic cache: {exc}", file=sys.stderr) + print( + f"[graphify extract] warning: could not write semantic cache: {exc}", + file=sys.stderr, + ) sem_result["nodes"].extend(fresh.get("nodes", [])) sem_result["edges"].extend(fresh.get("edges", [])) sem_result["hyperedges"].extend(fresh.get("hyperedges", [])) @@ -3409,7 +4052,8 @@ def _progress(idx: int, total: int, _result: dict) -> None: "edges": list(ast_result.get("edges", [])) + list(sem_result.get("edges", [])), "hyperedges": list(sem_result.get("hyperedges", [])), "input_tokens": ast_result.get("input_tokens", 0) + sem_result.get("input_tokens", 0), - "output_tokens": ast_result.get("output_tokens", 0) + sem_result.get("output_tokens", 0), + "output_tokens": ast_result.get("output_tokens", 0) + + sem_result.get("output_tokens", 0), } graph_json_path = graphify_out / "graph.json" @@ -3421,9 +4065,7 @@ def _progress(idx: int, total: int, _result: dict) -> None: # their semantic_hash empty so detect_incremental re-queues them (#933). _sem_extracted: set[str] = { n.get("source_file", "") for n in sem_result.get("nodes", []) - } | { - e.get("source_file", "") for e in sem_result.get("edges", []) - } + } | {e.get("source_file", "") for e in sem_result.get("edges", [])} _sem_extracted.discard("") _sem_types = {"document", "paper", "image"} _manifest_files = { @@ -3431,20 +4073,159 @@ def _progress(idx: int, total: int, _result: dict) -> None: for ftype, flist in files_by_type.items() } + # Resolve the effective graph class (PR 9 sticky profile). When neither + # --multigraph nor --simple is given the build must INHERIT the existing + # graph.json profile so a multigraph never silently downgrades to a + # simple graph on a default re-extract (mirrors watch._rebuild_code and + # build_merge's inherit-on-None contract). --multigraph forces multi, + # --simple forces a simple downgrade. + from graphify.watch import _existing_is_multigraph as _detect_multigraph + + _existing_multigraph = False + if existing_graph_path.exists(): + try: + _existing_data = json.loads(existing_graph_path.read_text(encoding="utf-8")) + _existing_multigraph = _detect_multigraph(_existing_data) + except Exception as exc: + print( + f"[graphify extract] warning: could not inspect existing graph.json " + f"profile ({exc}); treating as simple.", + file=sys.stderr, + ) + if multigraph_flag is None: + resolved_multigraph = _existing_multigraph + else: + resolved_multigraph = multigraph_flag + + # Lossy-projection warning: an EXPLICIT --simple over an existing + # multigraph graph.json collapses keyed parallel edges. This is an + # intentional downgrade, so warn loudly (not silently) and proceed. + if multigraph_flag is False and _existing_multigraph: + print( + "[graphify extract] WARNING: --simple requested over an existing " + "multigraph graph.json; parallel edges between the same pair will be " + "collapsed onto a single edge (lossy downgrade). Omit --simple to " + "preserve them, or re-extract with --multigraph.", + file=sys.stderr, + ) + + # Capability gate: surface the MultiDiGraph capability probe failure as a + # clean CLI error (exit 1) instead of letting the RuntimeError raised deep + # inside build_from_json escape as a traceback. The probe is cheap and + # lru_cached, so running it up front costs nothing on the happy path. + if resolved_multigraph: + from graphify.multigraph_compat import require_multigraph_capabilities + + try: + require_multigraph_capabilities() + except RuntimeError as exc: + print(str(exc), file=sys.stderr) + sys.exit(1) + if no_cluster: # --no-cluster: dump the raw merged extraction as graph.json. # No NetworkX, no community detection, no analysis sidecar. from graphify.export import backup_if_protected as _backup + _backup(graphify_out) - graph_json_path.write_text( - json.dumps(merged, indent=2), encoding="utf-8" - ) - cost = _estimate_cost( - backend, merged["input_tokens"], merged["output_tokens"] - ) + if incremental_mode: + # Incremental no-cluster scans are still deltas. If no files + # changed, ``merged`` is empty; writing it directly would erase + # the saved graph. Reuse build_merge so unchanged nodes/edges, + # deleted-file pruning, and sticky multigraph profile handling + # match the clustered path while still skipping clustering. + from graphify.build import build_merge as _nc_build_merge + from graphify.export import to_json as _nc_to_json + + _nc_graph = _nc_build_merge( + [merged], + graph_path=existing_graph_path, + prune_sources=deleted_files or None, + dedup=True, + dedup_llm_backend=backend if dedup_llm else None, + root=target, + multigraph=multigraph_flag, + ) + # RISK 4 — Guard 1 signaling: to_json's empty-merge floor returns + # False (and PRESERVES the populated graph.json) when the merged + # graph has 0 nodes over a populated file. force=True bypasses the + # shrink guard (Guard 2), so under force the ONLY False return is + # that 0-node floor — never a legitimate non-zero shrink. Honor the + # refusal: do NOT fall through to the success line. A 0-node merge + # over a populated graph is an aborted extraction, so signal it + # (exit 1) instead of falsely reporting "wrote ... 0 nodes". + if not _nc_to_json(_nc_graph, {}, str(graph_json_path), force=True): + print( + "[graphify extract] extraction aborted: the merge produced an " + "empty (0-node) graph; the previous graph.json was preserved.", + file=sys.stderr, + ) + sys.exit(1) + n_nodes = _nc_graph.number_of_nodes() + n_edges = _nc_graph.number_of_edges() + elif resolved_multigraph: + # A multigraph profile (sticky-inherited or explicit --multigraph) + # cannot be expressed by the raw-merged dump: parallel edges would + # be written without keys and the file would lack the multigraph + # flag + graphify_profile, silently collapsing on the next load. + # Build a keyed MultiDiGraph and serialize it via to_json (with no + # communities) so the no-cluster file still round-trips losslessly. + from graphify.build import build_from_json as _build_from_json + from graphify.export import to_json as _nc_to_json + + _nc_graph = _build_from_json(merged, multigraph=True, root=target) + # RISK 4 — Guard 1 signaling (multigraph sibling): identical to the + # incremental site above. A non-incremental run can still see a + # populated graph.json on disk (graph.json present, manifest.json + # absent), so to_json's 0-node floor can refuse and preserve it. + # Honor the False return — exit 1 rather than print the misleading + # "wrote ... 0 nodes" success line. Under force=True the only False + # return is the 0-node floor, so a legitimate non-zero multigraph + # build (True) is completely unaffected. + if not _nc_to_json(_nc_graph, {}, str(graph_json_path), force=True): + print( + "[graphify extract] extraction aborted: the merge produced an " + "empty (0-node) graph; the previous graph.json was preserved.", + file=sys.stderr, + ) + sys.exit(1) + n_nodes = _nc_graph.number_of_nodes() + n_edges = _nc_graph.number_of_edges() + else: + # Empty-merge floor (RISK 4 — Guard 3): this raw write is the one + # no-cluster path that does NOT route through to_json (Guard 1) or + # _check_shrink (Guard 2), so a 0-node ``merged`` here would silently + # overwrite a populated graph.json — the exact failed/aborted- + # extraction wipe the clustered sibling already blocks via its + # ``if G.number_of_nodes() == 0`` exit. Refuse the overwrite when the + # merged extraction is empty AND an existing graph.json on disk is + # populated. Read the existing node count defensively: any error + # (missing/corrupt file) is treated as 0 nodes so a fresh or + # unreadable target leaves the floor inert and the write proceeds + # exactly as before (no new exit on a legitimately-empty fresh run). + if len(merged.get("nodes", [])) == 0 and graph_json_path.exists(): + try: + _existing_n = len( + json.loads(graph_json_path.read_text(encoding="utf-8")).get("nodes", []) + ) + except Exception: + _existing_n = 0 + if _existing_n > 0: + print( + f"[graphify] ERROR: refusing to overwrite a populated " + f"graph.json ({_existing_n} nodes) with an EMPTY (0-node) " + f"graph - this is a failed/aborted extraction, not a real " + f"result. The previous graph is preserved.", + file=sys.stderr, + ) + sys.exit(1) + graph_json_path.write_text(json.dumps(merged, indent=2), encoding="utf-8") + n_nodes = len(merged["nodes"]) + n_edges = len(merged["edges"]) + cost = _estimate_cost(backend, merged["input_tokens"], merged["output_tokens"]) print( f"[graphify extract] wrote {graph_json_path} — " - f"{len(merged['nodes'])} nodes, {len(merged['edges'])} edges " + f"{n_nodes} nodes, {n_edges} edges " f"(no clustering)" ) if merged["input_tokens"] or merged["output_tokens"]: @@ -3457,32 +4238,44 @@ def _progress(idx: int, total: int, _result: dict) -> None: try: _save_manifest(_manifest_files, manifest_path=str(manifest_path), kind="both") except Exception as exc: - print(f"[graphify extract] warning: could not write manifest: {exc}", file=sys.stderr) + print( + f"[graphify extract] warning: could not write manifest: {exc}", file=sys.stderr + ) if global_merge: from graphify.global_graph import global_add as _global_add + _tag = global_repo_tag or target.name try: result = _global_add(graphify_out / "graph.json", _tag) if result["skipped"]: print(f"[graphify global] '{_tag}' unchanged since last add - skipped.") else: - print(f"[graphify global] '{_tag}' merged into global graph " - f"(+{result['nodes_added']} nodes, -{result['nodes_removed']} pruned).") + print( + f"[graphify global] '{_tag}' merged into global graph " + f"(+{result['nodes_added']} nodes, -{result['nodes_removed']} pruned)." + ) except Exception as exc: - print(f"[graphify global] warning: failed to merge into global graph: {exc}", file=sys.stderr) + print( + f"[graphify global] warning: failed to merge into global graph: {exc}", + file=sys.stderr, + ) sys.exit(0) # Build graph + cluster + score + write. from graphify.build import ( build as _build, - build_from_json as _build_from_json, build_merge as _build_merge, ) from graphify.cluster import cluster as _cluster, score_all as _score_all from graphify.export import to_json as _to_json from graphify.analyze import god_nodes as _god_nodes, surprising_connections as _surprising + dedup_backend = backend if dedup_llm else None if incremental_mode: + # Pass multigraph_flag straight through: None lets build_merge + # INHERIT the saved graph.json profile (the sticky default), while an + # explicit --multigraph/--simple overrides it (build_merge warns on + # an explicit override of the saved flag). G = _build_merge( [merged], graph_path=existing_graph_path, @@ -3490,9 +4283,19 @@ def _progress(idx: int, total: int, _result: dict) -> None: dedup=True, dedup_llm_backend=dedup_backend, root=target, + multigraph=multigraph_flag, ) else: - G = _build([merged], dedup=True, dedup_llm_backend=dedup_backend, root=target) + # Fresh build: no saved graph.json to inherit from, so the resolved + # value already collapses to the requested flag (or the historical + # simple default when no flag is given). + G = _build( + [merged], + dedup=True, + dedup_llm_backend=dedup_backend, + root=target, + multigraph=resolved_multigraph, + ) if G.number_of_nodes() == 0: print( "[graphify extract] graph is empty — extraction produced no nodes. " @@ -3502,7 +4305,9 @@ def _progress(idx: int, total: int, _result: dict) -> None: ) sys.exit(1) - communities = _cluster(G, resolution=cli_resolution, exclude_hubs_percentile=cli_exclude_hubs) + communities = _cluster( + G, resolution=cli_resolution, exclude_hubs_percentile=cli_exclude_hubs + ) cohesion = _score_all(G, communities) try: gods = _god_nodes(G) @@ -3514,6 +4319,7 @@ def _progress(idx: int, total: int, _result: dict) -> None: surprises = [] from graphify.export import backup_if_protected as _backup + _backup(graphify_out) _to_json(G, communities, str(graph_json_path), force=True) if merged.get("output_tokens", 0) > 0: @@ -3522,16 +4328,22 @@ def _progress(idx: int, total: int, _result: dict) -> None: ) if global_merge: from graphify.global_graph import global_add as _global_add + _tag = global_repo_tag or target.name try: result = _global_add(graphify_out / "graph.json", _tag) if result["skipped"]: print(f"[graphify global] '{_tag}' unchanged since last add - skipped.") else: - print(f"[graphify global] '{_tag}' merged into global graph " - f"(+{result['nodes_added']} nodes, -{result['nodes_removed']} pruned).") + print( + f"[graphify global] '{_tag}' merged into global graph " + f"(+{result['nodes_added']} nodes, -{result['nodes_removed']} pruned)." + ) except Exception as exc: - print(f"[graphify global] warning: failed to merge into global graph: {exc}", file=sys.stderr) + print( + f"[graphify global] warning: failed to merge into global graph: {exc}", + file=sys.stderr, + ) analysis = { "communities": {str(k): v for k, v in communities.items()}, "cohesion": {str(k): v for k, v in cohesion.items()}, @@ -3563,7 +4375,9 @@ def _progress(idx: int, total: int, _result: dict) -> None: f"{len(deleted_files)} deleted" ) elif sem_cache_hits: - print(f"[graphify extract] semantic cache: {sem_cache_hits} cached, {sem_cache_misses} re-extracted") + print( + f"[graphify extract] semantic cache: {sem_cache_hits} cached, {sem_cache_misses} re-extracted" + ) if merged["input_tokens"] or merged["output_tokens"]: print( f"[graphify extract] tokens: " @@ -3580,26 +4394,31 @@ def _progress(idx: int, total: int, _result: dict) -> None: # graphify-out/.graphify_uncached.txt — paths that need extraction # Stdout: "Cache: N hit, M miss" from graphify.cache import check_semantic_cache + if len(sys.argv) < 3: print("Usage: graphify cache-check [--root ]", file=sys.stderr) sys.exit(1) files_from = Path(sys.argv[2]) - root = Path(".") + cache_root = Path(".") i = 3 while i < len(sys.argv): if sys.argv[i] == "--root" and i + 1 < len(sys.argv): - root = Path(sys.argv[i + 1]) + cache_root = Path(sys.argv[i + 1]) i += 2 else: i += 1 files = [f for f in files_from.read_text(encoding="utf-8").splitlines() if f.strip()] - cached_nodes, cached_edges, cached_hyperedges, uncached = check_semantic_cache(files, root) - out = root / "graphify-out" + cached_nodes, cached_edges, cached_hyperedges, uncached = check_semantic_cache( + files, cache_root + ) + out = cache_root / "graphify-out" out.mkdir(parents=True, exist_ok=True) if cached_nodes or cached_edges or cached_hyperedges: (out / ".graphify_cached.json").write_text( - json.dumps({"nodes": cached_nodes, "edges": cached_edges, "hyperedges": cached_hyperedges}, - ensure_ascii=False), + json.dumps( + {"nodes": cached_nodes, "edges": cached_edges, "hyperedges": cached_hyperedges}, + ensure_ascii=False, + ), encoding="utf-8", ) (out / ".graphify_uncached.txt").write_text("\n".join(uncached), encoding="utf-8") @@ -3610,6 +4429,7 @@ def _progress(idx: int, total: int, _result: dict) -> None: # Concatenates .graphify_chunk_*.json files written by semantic subagents. # Deduplicates nodes by id (first writer wins). Sums token counts. import glob as _glob + if len(sys.argv) < 3: print("Usage: graphify merge-chunks --out ", file=sys.stderr) sys.exit(1) @@ -3630,7 +4450,13 @@ def _progress(idx: int, total: int, _result: dict) -> None: for arg in chunk_args: expanded = _glob.glob(arg) chunk_files.extend(sorted(expanded) if expanded else [arg]) - merged: dict = {"nodes": [], "edges": [], "hyperedges": [], "input_tokens": 0, "output_tokens": 0} + merged: dict = { + "nodes": [], + "edges": [], + "hyperedges": [], + "input_tokens": 0, + "output_tokens": 0, + } seen_ids: set[str] = set() for cf in chunk_files: try: @@ -3658,7 +4484,10 @@ def _progress(idx: int, total: int, _result: dict) -> None: # Merges cached semantic results with freshly-extracted chunk results. # Deduplicates nodes by id (cached entries take priority over new ones). if len(sys.argv) < 3: - print("Usage: graphify merge-semantic --cached --new --out ", file=sys.stderr) + print( + "Usage: graphify merge-semantic --cached --new --out ", + file=sys.stderr, + ) sys.exit(1) cached_path: Path | None = None new_path: Path | None = None @@ -3666,19 +4495,30 @@ def _progress(idx: int, total: int, _result: dict) -> None: i = 2 while i < len(sys.argv): if sys.argv[i] == "--cached" and i + 1 < len(sys.argv): - cached_path = Path(sys.argv[i + 1]); i += 2 + cached_path = Path(sys.argv[i + 1]) + i += 2 elif sys.argv[i] == "--new" and i + 1 < len(sys.argv): - new_path = Path(sys.argv[i + 1]); i += 2 + new_path = Path(sys.argv[i + 1]) + i += 2 elif sys.argv[i] == "--out" and i + 1 < len(sys.argv): - out_path2 = Path(sys.argv[i + 1]); i += 2 + out_path2 = Path(sys.argv[i + 1]) + i += 2 else: i += 1 if not out_path2: print("error: --out required", file=sys.stderr) sys.exit(1) empty: dict = {"nodes": [], "edges": [], "hyperedges": []} - cached_data = json.loads(cached_path.read_text(encoding="utf-8")) if cached_path and cached_path.exists() else empty - new_data = json.loads(new_path.read_text(encoding="utf-8")) if new_path and new_path.exists() else empty + cached_data = ( + json.loads(cached_path.read_text(encoding="utf-8")) + if cached_path and cached_path.exists() + else empty + ) + new_data = ( + json.loads(new_path.read_text(encoding="utf-8")) + if new_path and new_path.exists() + else empty + ) seen_ids2: set[str] = set() all_nodes: list[dict] = [] for n in cached_data.get("nodes", []) + new_data.get("nodes", []): diff --git a/graphify/affected.py b/graphify/affected.py index 109eaa95e..0d81e6eda 100644 --- a/graphify/affected.py +++ b/graphify/affected.py @@ -3,7 +3,7 @@ from collections import deque from dataclasses import dataclass from pathlib import Path -from typing import Iterable +from typing import Any, Iterable, cast import networkx as nx @@ -87,8 +87,9 @@ def affected_nodes( current, current_depth = queue.popleft() if current_depth >= depth: continue - if hasattr(graph, "in_edges"): - incoming = graph.in_edges(current, data=True) + graph_any = cast(Any, graph) + if hasattr(graph_any, "in_edges"): + incoming = graph_any.in_edges(current, data=True) else: incoming = ( (source, target, data) diff --git a/graphify/analyze.py b/graphify/analyze.py index f3e08103d..5004dc412 100644 --- a/graphify/analyze.py +++ b/graphify/analyze.py @@ -1,9 +1,12 @@ """Graph analysis: god nodes (most connected), surprising connections (cross-community), suggested questions.""" + from __future__ import annotations from pathlib import Path import networkx as nx from graphify.build import edge_data +from graphify.detect import CODE_EXTENSIONS, IMAGE_EXTENSIONS, PAPER_EXTENSIONS +from graphify.projections import distinct_neighbor_degree, project_for_community # Language families — extensions sharing a runtime can legitimately call each other _LANG_FAMILY: dict[str, str] = { @@ -53,6 +56,7 @@ def _is_file_node(G: nx.Graph, node_id: str) -> bool: source_file = attrs.get("source_file", "") if source_file: from pathlib import Path as _Path + if label == _Path(source_file).name: return True # Method stub: AST extractor labels methods as '.method_name()' @@ -60,17 +64,34 @@ def _is_file_node(G: nx.Graph, node_id: str) -> bool: return True # Module-level function stub: labeled 'function_name()' - only has a contains edge # These are real functions but structurally isolated by definition; not a gap worth flagging - if label.endswith("()") and G.degree(node_id) <= 1: + if label.endswith("()") and distinct_neighbor_degree(G, node_id) <= 1: return True return False -_JSON_NOISE_LABELS: frozenset[str] = frozenset({ - "start", "end", "name", "id", "type", "properties", - "value", "key", "data", "items", "title", "description", "version", - "dependencies", "devdependencies", "peerdependencies", - "optionaldependencies", "bundleddependencies", "bundledependencies", -}) +_JSON_NOISE_LABELS: frozenset[str] = frozenset( + { + "start", + "end", + "name", + "id", + "type", + "properties", + "value", + "key", + "data", + "items", + "title", + "description", + "version", + "dependencies", + "devdependencies", + "peerdependencies", + "optionaldependencies", + "bundleddependencies", + "bundledependencies", + } +) def _is_json_key_node(G: nx.Graph, node_id: str) -> bool: @@ -88,17 +109,23 @@ def god_nodes(G: nx.Graph, top_n: int = 10) -> list[dict]: File-level hub nodes are excluded: they accumulate import/contains edges mechanically and don't represent meaningful architectural abstractions. """ - degree = dict(G.degree()) + degree = {n: distinct_neighbor_degree(G, n) for n in G.nodes()} sorted_nodes = sorted(degree.items(), key=lambda x: x[1], reverse=True) result = [] for node_id, deg in sorted_nodes: - if _is_file_node(G, node_id) or _is_concept_node(G, node_id) or _is_json_key_node(G, node_id): + if ( + _is_file_node(G, node_id) + or _is_concept_node(G, node_id) + or _is_json_key_node(G, node_id) + ): continue - result.append({ - "id": node_id, - "label": G.nodes[node_id].get("label", node_id), - "degree": deg, - }) + result.append( + { + "id": node_id, + "label": G.nodes[node_id].get("label", node_id), + "degree": deg, + } + ) if len(result) >= top_n: break return result @@ -124,9 +151,7 @@ def surprising_connections( """ # Identify unique source files (ignore empty/null source_file) source_files = { - data.get("source_file", "") - for _, data in G.nodes(data=True) - if data.get("source_file", "") + data.get("source_file", "") for _, data in G.nodes(data=True) if data.get("source_file", "") } is_multi_source = len(source_files) > 1 @@ -155,9 +180,6 @@ def _is_concept_node(G: nx.Graph, node_id: str) -> bool: return False -from graphify.detect import CODE_EXTENSIONS, DOC_EXTENSIONS, PAPER_EXTENSIONS, IMAGE_EXTENSIONS - - def _file_category(path: str) -> str: ext = ("." + path.rsplit(".", 1)[-1].lower()) if "." in path else "" if ext in CODE_EXTENSIONS: @@ -237,8 +259,8 @@ def _surprise_score( reasons.append("semantically similar concepts with no structural link") # 5. Peripheral→hub: a low-degree node connecting to a high-degree one - deg_u = degrees[u] if degrees is not None else G.degree(u) - deg_v = degrees[v] if degrees is not None else G.degree(v) + deg_u = degrees[u] if degrees is not None else distinct_neighbor_degree(G, u) + deg_v = degrees[v] if degrees is not None else distinct_neighbor_degree(G, v) if min(deg_u, deg_v) <= 2 and max(deg_u, deg_v) >= 5: score += 1 peripheral = G.nodes[u].get("label", u) if deg_u <= 2 else G.nodes[v].get("label", v) @@ -263,7 +285,7 @@ def _cross_file_surprises(G: nx.Graph, communities: dict[int, list[str]], top_n: Each result includes a 'why' field explaining what makes it non-obvious. """ node_community = _node_community_map(communities) - degrees = dict(G.degree()) + degrees = {n: distinct_neighbor_degree(G, n) for n in G.nodes()} candidates = [] for u, v, data in G.edges(data=True): @@ -288,18 +310,20 @@ def _cross_file_surprises(G: nx.Graph, communities: dict[int, list[str]], top_n: tgt_id = data.get("_tgt", v) if tgt_id not in G.nodes: tgt_id = v - candidates.append({ - "_score": score, - "source": G.nodes[src_id].get("label", src_id), - "target": G.nodes[tgt_id].get("label", tgt_id), - "source_files": [ - G.nodes[src_id].get("source_file", ""), - G.nodes[tgt_id].get("source_file", ""), - ], - "confidence": data.get("confidence", "EXTRACTED"), - "relation": relation, - "why": "; ".join(reasons) if reasons else "cross-file semantic connection", - }) + candidates.append( + { + "_score": score, + "source": G.nodes[src_id].get("label", src_id), + "target": G.nodes[tgt_id].get("label", tgt_id), + "source_files": [ + G.nodes[src_id].get("source_file", ""), + G.nodes[tgt_id].get("source_file", ""), + ], + "confidence": data.get("confidence", "EXTRACTED"), + "relation": relation, + "why": "; ".join(reasons) if reasons else "cross-file semantic connection", + } + ) candidates.sort(key=lambda x: x["_score"], reverse=True) for c in candidates: @@ -329,22 +353,27 @@ def _cross_community_surprises( return [] if G.number_of_nodes() > 5000: return [] - betweenness = nx.edge_betweenness_centrality(G) + # Project to simple graph so betweenness returns 2-tuple keys + # and parallel edges don't inflate centrality scores. + simple = project_for_community(G) if isinstance(G, (nx.MultiGraph, nx.MultiDiGraph)) else G + betweenness = nx.edge_betweenness_centrality(simple) top_edges = sorted(betweenness.items(), key=lambda x: x[1], reverse=True)[:top_n] result = [] for (u, v), score in top_edges: data = edge_data(G, u, v) - result.append({ - "source": G.nodes[u].get("label", u), - "target": G.nodes[v].get("label", v), - "source_files": [ - G.nodes[u].get("source_file", ""), - G.nodes[v].get("source_file", ""), - ], - "confidence": data.get("confidence", "EXTRACTED"), - "relation": data.get("relation", ""), - "note": f"Bridges graph structure (betweenness={score:.3f})", - }) + result.append( + { + "source": G.nodes[u].get("label", u), + "target": G.nodes[v].get("label", v), + "source_files": [ + G.nodes[u].get("source_file", ""), + G.nodes[v].get("source_file", ""), + ], + "confidence": data.get("confidence", "EXTRACTED"), + "relation": data.get("relation", ""), + "note": f"Bridges graph structure (betweenness={score:.3f})", + } + ) return result # Build node → community map @@ -370,18 +399,20 @@ def _cross_community_surprises( tgt_id = data.get("_tgt", v) if tgt_id not in G.nodes: tgt_id = v - surprises.append({ - "source": G.nodes[src_id].get("label", src_id), - "target": G.nodes[tgt_id].get("label", tgt_id), - "source_files": [ - G.nodes[src_id].get("source_file", ""), - G.nodes[tgt_id].get("source_file", ""), - ], - "confidence": confidence, - "relation": relation, - "note": f"Bridges community {cid_u} → community {cid_v}", - "_pair": tuple(sorted([cid_u, cid_v])), - }) + surprises.append( + { + "source": G.nodes[src_id].get("label", src_id), + "target": G.nodes[tgt_id].get("label", tgt_id), + "source_files": [ + G.nodes[src_id].get("source_file", ""), + G.nodes[tgt_id].get("source_file", ""), + ], + "confidence": confidence, + "relation": relation, + "note": f"Bridges community {cid_u} → community {cid_v}", + "_pair": tuple(sorted([cid_u, cid_v])), + } + ) # Sort: AMBIGUOUS first, then INFERRED, then EXTRACTED order = {"AMBIGUOUS": 0, "INFERRED": 1, "EXTRACTED": 2} @@ -411,7 +442,9 @@ def suggest_questions( Each question has a 'type', 'question', and 'why' field. """ if community_labels: - community_labels = {int(k) if isinstance(k, str) else k: v for k, v in community_labels.items()} + community_labels = { + int(k) if isinstance(k, str) else k: v for k, v in community_labels.items() + } questions = [] node_community = _node_community_map(communities) @@ -422,11 +455,13 @@ def suggest_questions( ul = G.nodes[u].get("label", u) vl = G.nodes[v].get("label", v) relation = data.get("relation", "related to") - questions.append({ - "type": "ambiguous_edge", - "question": f"What is the exact relationship between `{ul}` and `{vl}`?", - "why": f"Edge tagged AMBIGUOUS (relation: {relation}) - confidence is low.", - }) + questions.append( + { + "type": "ambiguous_edge", + "question": f"What is the exact relationship between `{ul}` and `{vl}`?", + "why": f"Edge tagged AMBIGUOUS (relation: {relation}) - confidence is low.", + } + ) # 2. Bridge nodes (high betweenness) → cross-cutting concern questions if G.number_of_edges() > 0: @@ -434,27 +469,38 @@ def suggest_questions( betweenness = nx.betweenness_centrality(G, k=k, seed=42) # Top bridge nodes that are NOT file-level hubs bridges = sorted( - [(n, s) for n, s in betweenness.items() - if not _is_file_node(G, n) and not _is_concept_node(G, n) and s > 0], + [ + (n, s) + for n, s in betweenness.items() + if not _is_file_node(G, n) and not _is_concept_node(G, n) and s > 0 + ], key=lambda x: x[1], reverse=True, )[:3] for node_id, score in bridges: label = G.nodes[node_id].get("label", node_id) cid = node_community.get(node_id) - comm_label = community_labels.get(cid, f"Community {cid}") if cid is not None else "unknown" + comm_label = ( + community_labels.get(cid, f"Community {cid}") if cid is not None else "unknown" + ) neighbors = list(G.neighbors(node_id)) - neighbor_comms = {node_community.get(n) for n in neighbors if node_community.get(n) != cid} + neighbor_comms = { + other_cid + for n in neighbors + if (other_cid := node_community.get(n)) is not None and other_cid != cid + } if neighbor_comms: other_labels = [community_labels.get(c, f"Community {c}") for c in neighbor_comms] - questions.append({ - "type": "bridge_node", - "question": f"Why does `{label}` connect `{comm_label}` to {', '.join(f'`{l}`' for l in other_labels)}?", - "why": f"High betweenness centrality ({score:.3f}) - this node is a cross-community bridge.", - }) + questions.append( + { + "type": "bridge_node", + "question": f"Why does `{label}` connect `{comm_label}` to {', '.join(f'`{label}`' for label in other_labels)}?", + "why": f"High betweenness centrality ({score:.3f}) - this node is a cross-community bridge.", + } + ) # 3. God nodes with many INFERRED edges → verification questions - degree = dict(G.degree()) + degree = {n: distinct_neighbor_degree(G, n) for n in G.nodes()} top_nodes = sorted( [(n, d) for n, d in degree.items() if not _is_file_node(G, n)], key=lambda x: x[1], @@ -462,7 +508,8 @@ def suggest_questions( )[:5] for node_id, _ in top_nodes: inferred = [ - (u, v, d) for u, v, d in G.edges(node_id, data=True) + (u, v, d) + for u, v, d in G.edges(node_id, data=True) if d.get("confidence") == "INFERRED" ] if len(inferred) >= 2: @@ -478,48 +525,60 @@ def suggest_questions( tgt_id = v other_id = tgt_id if src_id == node_id else src_id others.append(G.nodes[other_id].get("label", other_id)) - questions.append({ - "type": "verify_inferred", - "question": f"Are the {len(inferred)} inferred relationships involving `{label}` (e.g. with `{others[0]}` and `{others[1]}`) actually correct?", - "why": f"`{label}` has {len(inferred)} INFERRED edges - model-reasoned connections that need verification.", - }) + questions.append( + { + "type": "verify_inferred", + "question": f"Are the {len(inferred)} inferred relationships involving `{label}` (e.g. with `{others[0]}` and `{others[1]}`) actually correct?", + "why": f"`{label}` has {len(inferred)} INFERRED edges - model-reasoned connections that need verification.", + } + ) # 4. Isolated or weakly-connected nodes → exploration questions isolated = [ - n for n in G.nodes() - if G.degree(n) <= 1 and not _is_file_node(G, n) and not _is_concept_node(G, n) + n + for n in G.nodes() + if distinct_neighbor_degree(G, n) <= 1 + and not _is_file_node(G, n) + and not _is_concept_node(G, n) ] if isolated: labels = [G.nodes[n].get("label", n) for n in isolated[:3]] - questions.append({ - "type": "isolated_nodes", - "question": f"What connects {', '.join(f'`{l}`' for l in labels)} to the rest of the system?", - "why": f"{len(isolated)} weakly-connected nodes found - possible documentation gaps or missing edges.", - }) + questions.append( + { + "type": "isolated_nodes", + "question": f"What connects {', '.join(f'`{label}`' for label in labels)} to the rest of the system?", + "why": f"{len(isolated)} weakly-connected nodes found - possible documentation gaps or missing edges.", + } + ) # 5. Low-cohesion communities → structural questions from .cluster import cohesion_score + for cid, nodes in communities.items(): score = cohesion_score(G, nodes) if score < 0.15 and len(nodes) >= 5: label = community_labels.get(cid, f"Community {cid}") - questions.append({ - "type": "low_cohesion", - "question": f"Should `{label}` be split into smaller, more focused modules?", - "why": f"Cohesion score {score} - nodes in this community are weakly interconnected.", - }) + questions.append( + { + "type": "low_cohesion", + "question": f"Should `{label}` be split into smaller, more focused modules?", + "why": f"Cohesion score {score} - nodes in this community are weakly interconnected.", + } + ) if not questions: - return [{ - "type": "no_signal", - "question": None, - "why": ( - "Not enough signal to generate questions. " - "This usually means the corpus has no AMBIGUOUS edges, no bridge nodes, " - "no INFERRED relationships, and all communities are tightly cohesive. " - "Add more files or run with --mode deep to extract richer edges." - ), - }] + return [ + { + "type": "no_signal", + "question": None, + "why": ( + "Not enough signal to generate questions. " + "This usually means the corpus has no AMBIGUOUS edges, no bridge nodes, " + "no INFERRED relationships, and all communities are tightly cohesive. " + "Add more files or run with --mode deep to extract richer edges." + ), + } + ] return questions[:top_n] @@ -542,13 +601,9 @@ def graph_diff(G_old: nx.Graph, G_new: nx.Graph) -> dict: added_node_ids = new_nodes - old_nodes removed_node_ids = old_nodes - new_nodes - new_nodes_list = [ - {"id": n, "label": G_new.nodes[n].get("label", n)} - for n in added_node_ids - ] + new_nodes_list = [{"id": n, "label": G_new.nodes[n].get("label", n)} for n in added_node_ids] removed_nodes_list = [ - {"id": n, "label": G_old.nodes[n].get("label", n)} - for n in removed_node_ids + {"id": n, "label": G_old.nodes[n].get("label", n)} for n in removed_node_ids ] def edge_key(G: nx.Graph, u: str, v: str, data: dict) -> tuple: @@ -556,14 +611,8 @@ def edge_key(G: nx.Graph, u: str, v: str, data: dict) -> tuple: return (u, v, data.get("relation", "")) return (min(u, v), max(u, v), data.get("relation", "")) - old_edge_keys = { - edge_key(G_old, u, v, d) - for u, v, d in G_old.edges(data=True) - } - new_edge_keys = { - edge_key(G_new, u, v, d) - for u, v, d in G_new.edges(data=True) - } + old_edge_keys = {edge_key(G_old, u, v, d) for u, v, d in G_old.edges(data=True)} + new_edge_keys = {edge_key(G_new, u, v, d) for u, v, d in G_new.edges(data=True)} added_edge_keys = new_edge_keys - old_edge_keys removed_edge_keys = old_edge_keys - new_edge_keys @@ -571,22 +620,26 @@ def edge_key(G: nx.Graph, u: str, v: str, data: dict) -> tuple: new_edges_list = [] for u, v, d in G_new.edges(data=True): if edge_key(G_new, u, v, d) in added_edge_keys: - new_edges_list.append({ - "source": u, - "target": v, - "relation": d.get("relation", ""), - "confidence": d.get("confidence", ""), - }) + new_edges_list.append( + { + "source": u, + "target": v, + "relation": d.get("relation", ""), + "confidence": d.get("confidence", ""), + } + ) removed_edges_list = [] for u, v, d in G_old.edges(data=True): if edge_key(G_old, u, v, d) in removed_edge_keys: - removed_edges_list.append({ - "source": u, - "target": v, - "relation": d.get("relation", ""), - "confidence": d.get("confidence", ""), - }) + removed_edges_list.append( + { + "source": u, + "target": v, + "relation": d.get("relation", ""), + "confidence": d.get("confidence", ""), + } + ) parts = [] if new_nodes_list: @@ -594,9 +647,13 @@ def edge_key(G: nx.Graph, u: str, v: str, data: dict) -> tuple: if new_edges_list: parts.append(f"{len(new_edges_list)} new edge{'s' if len(new_edges_list) != 1 else ''}") if removed_nodes_list: - parts.append(f"{len(removed_nodes_list)} node{'s' if len(removed_nodes_list) != 1 else ''} removed") + parts.append( + f"{len(removed_nodes_list)} node{'s' if len(removed_nodes_list) != 1 else ''} removed" + ) if removed_edges_list: - parts.append(f"{len(removed_edges_list)} edge{'s' if len(removed_edges_list) != 1 else ''} removed") + parts.append( + f"{len(removed_edges_list)} edge{'s' if len(removed_edges_list) != 1 else ''} removed" + ) summary = ", ".join(parts) if parts else "no changes" return { @@ -632,6 +689,7 @@ def find_import_cycles( "why": "circular dependency" } """ + def _endpoint_source_file(node_id: str) -> str: attrs = G.nodes.get(node_id, {}) src_file = attrs.get("source_file", "") @@ -703,10 +761,12 @@ def _endpoint_source_file(node_id: str) -> str: result: list[dict] = [] for cycle in unique_cycles: - result.append({ - "cycle": cycle, - "length": len(cycle), - "why": "circular dependency", - }) + result.append( + { + "cycle": cycle, + "length": len(cycle), + "why": "circular dependency", + } + ) return result diff --git a/graphify/benchmark.py b/graphify/benchmark.py index eabade292..cf2ab4666 100644 --- a/graphify/benchmark.py +++ b/graphify/benchmark.py @@ -1,4 +1,5 @@ """Token-reduction benchmark - measures how much context graphify saves vs naive full-corpus approach.""" + from __future__ import annotations import json import sys @@ -66,11 +67,15 @@ def _query_subgraph_tokens(G: nx.Graph, question: str, depth: int = 3) -> int: lines = [] for nid in visited: d = G.nodes[nid] - lines.append(f"NODE {d.get('label', nid)} src={d.get('source_file', '')} loc={d.get('source_location', '')}") + lines.append( + f"NODE {d.get('label', nid)} src={d.get('source_file', '')} loc={d.get('source_location', '')}" + ) for u, v in edges_seen: if u in visited and v in visited: d = edge_data(G, u, v) - lines.append(f"EDGE {G.nodes[u].get('label', u)} --{d.get('relation', '')}--> {G.nodes[v].get('label', v)}") + lines.append( + f"EDGE {G.nodes[u].get('label', u)} --{d.get('relation', '')}--> {G.nodes[v].get('label', v)}" + ) return _estimate_tokens("\n".join(lines)) @@ -99,6 +104,7 @@ def run_benchmark( Returns dict with: corpus_tokens, avg_query_tokens, reduction_ratio, per_question """ from graphify.security import check_graph_file_size_cap + check_graph_file_size_cap(Path(graph_path)) data = json.loads(Path(graph_path).read_text(encoding="utf-8")) try: @@ -108,16 +114,20 @@ def run_benchmark( if corpus_words is None: # Rough estimate: each node label is ~3 words, plus source context - corpus_words = G.number_of_nodes() * 50 + estimated_corpus_words = G.number_of_nodes() * 50 + else: + estimated_corpus_words = corpus_words - corpus_tokens = corpus_words * 100 // 75 # words → tokens (100 words ≈ 133 tokens) + corpus_tokens = estimated_corpus_words * 100 // 75 # words → tokens (100 words ≈ 133 tokens) qs = questions or _SAMPLE_QUESTIONS per_question = [] for q in qs: qt = _query_subgraph_tokens(G, q) if qt > 0: - per_question.append({"question": q, "query_tokens": qt, "reduction": round(corpus_tokens / qt, 1)}) + per_question.append( + {"question": q, "query_tokens": qt, "reduction": round(corpus_tokens / qt, 1)} + ) if not per_question: return {"error": "No matching nodes found for sample questions. Build the graph first."} @@ -127,7 +137,7 @@ def run_benchmark( return { "corpus_tokens": corpus_tokens, - "corpus_words": corpus_words, + "corpus_words": estimated_corpus_words, "nodes": G.number_of_nodes(), "edges": G.number_of_edges(), "avg_query_tokens": avg_query_tokens, @@ -142,14 +152,16 @@ def print_benchmark(result: dict) -> None: print(f"Benchmark error: {result['error']}") return - print(f"\ngraphify token reduction benchmark") + print("\ngraphify token reduction benchmark") print(_hr(50)) arrow = _safe("→", "->") - print(f" Corpus: {result['corpus_words']:,} words {arrow} ~{result['corpus_tokens']:,} tokens (naive)") + print( + f" Corpus: {result['corpus_words']:,} words {arrow} ~{result['corpus_tokens']:,} tokens (naive)" + ) print(f" Graph: {result['nodes']:,} nodes, {result['edges']:,} edges") print(f" Avg query cost: ~{result['avg_query_tokens']:,} tokens") print(f" Reduction: {result['reduction_ratio']}x fewer tokens per query") - print(f"\n Per question:") + print("\n Per question:") for p in result["per_question"]: print(f" [{p['reduction']}x] {p['question'][:55]}") print() diff --git a/graphify/build.py b/graphify/build.py index 07fbb0340..0fe4a9695 100644 --- a/graphify/build.py +++ b/graphify/build.py @@ -22,18 +22,49 @@ # from __future__ import annotations import json +import hashlib import os import re import sys import unicodedata +from collections.abc import Hashable from pathlib import Path import networkx as nx -from .validate import validate_extraction +from .edge_identity import make_stable_key, strip_schema_key +from .validate import is_hashable, validate_extraction # Synonym mapper for known invalid file_type values that LLM subagents commonly # emit. Keeps semantic intent close (markdown→document, tool→code) and falls # back to "concept" for any other invalid value (see #840). +_LANG_FAMILY: dict[str, str] = { + ".py": "py", + ".pyi": "py", + ".js": "js", + ".mjs": "js", + ".cjs": "js", + ".jsx": "js", + ".ts": "js", + ".tsx": "js", + ".go": "go", + ".rs": "rs", + ".java": "jvm", + ".kt": "jvm", + ".scala": "jvm", + ".groovy": "jvm", + ".c": "c", + ".h": "c", + ".cc": "cpp", + ".cpp": "cpp", + ".hpp": "cpp", + ".rb": "rb", + ".php": "php", + ".cs": "cs", + ".swift": "swift", + ".lua": "lua", +} + + _FILE_TYPE_SYNONYMS = { "markdown": "document", "text": "document", @@ -83,7 +114,58 @@ def _norm_source_file(p: str | None, root: str | None = None) -> str | None: return p -def edge_data(G: nx.Graph, u: str, v: str) -> dict: +def _stable_identity_component(value: object) -> str | None: + """Normalize malformed edge identity values before stable-key hashing.""" + if value is None: + return None + if isinstance(value, str): + return value + if isinstance(value, os.PathLike): + # os.fspath can return bytes for bytes-flavored PathLike; coerce to str + # so downstream json.dumps / hashing always sees text. + fs_value = os.fspath(value) + return ( + fs_value.decode("utf-8", errors="replace") if isinstance(fs_value, bytes) else fs_value + ) + if isinstance(value, (set, frozenset)): + return json.dumps(sorted(str(item) for item in value), ensure_ascii=False) + try: + return json.dumps(value, sort_keys=True, ensure_ascii=False, default=str) + except (TypeError, ValueError): + return str(value) + + +def _make_collision_key(base_key: str, attrs: dict, *, salt: int = 0) -> str: + payload = { + "base_key": base_key, + "attrs": attrs, + } + if salt: + payload["salt"] = salt + repair_payload = json.dumps(payload, sort_keys=True, ensure_ascii=False, default=str) + repair_digest = hashlib.sha256(repair_payload.encode()).hexdigest() + return f"{base_key}:alt:{repair_digest}" + + +def _list_field(data: dict, key: str) -> list: + """Return ``data[key]`` if it is a list; otherwise warn to stderr and return ``[]``. + + Extraction dicts come from LLM subagents and can contain malformed shapes; + matching the rest of build_from_json's skip+warn policy keeps a single bad + field from crashing the whole build. + """ + value = data.get(key, []) + if isinstance(value, list): + return value + print( + f"[graphify] WARNING: extraction field '{key}' must be a list, " + f"got {type(value).__name__}; treating as empty.", + file=sys.stderr, + ) + return [] + + +def edge_data(G: nx.Graph, u: Hashable, v: Hashable) -> dict: """Return one edge attribute dict for (u, v), tolerating MultiGraph. For MultiGraph/MultiDiGraph there can be multiple parallel edges; @@ -96,7 +178,7 @@ def edge_data(G: nx.Graph, u: str, v: str) -> dict: return raw -def edge_datas(G: nx.Graph, u: str, v: str) -> list[dict]: +def edge_datas(G: nx.Graph, u: Hashable, v: Hashable) -> list[dict]: """Return every edge attribute dict for (u, v); always a list.""" raw = G[u][v] if isinstance(G, (nx.MultiGraph, nx.MultiDiGraph)): @@ -104,29 +186,47 @@ def edge_datas(G: nx.Graph, u: str, v: str) -> list[dict]: return [raw] -def build_from_json(extraction: dict, *, directed: bool = False, root: str | Path | None = None) -> nx.Graph: +def build_from_json( + extraction: dict, + *, + directed: bool = False, + root: str | Path | None = None, + multigraph: bool = False, +) -> nx.Graph | nx.DiGraph | nx.MultiDiGraph: """Build a NetworkX graph from an extraction dict. directed=True produces a DiGraph that preserves edge direction (source→target). directed=False (default) produces an undirected Graph for backward compatibility. + multigraph=True produces a directed MultiDiGraph with keyed parallel edges for + internal tests/callers; public CLI exposure is intentionally deferred. + In this mode, directed is ignored because MultiDiGraph is always directed. root: if given, absolute source_file paths from semantic subagents are made relative to root so all nodes share a consistent path key (#932). """ + if not isinstance(extraction, dict): + raise TypeError("extraction must be a JSON object") + _root = str(Path(root).resolve()) if root else None # NetworkX <= 3.1 serialised edges as "links"; remap to "edges" for compatibility. if "edges" not in extraction and "links" in extraction: extraction = dict(extraction, edges=extraction["links"]) + nodes = _list_field(extraction, "nodes") + edges = _list_field(extraction, "edges") + extraction = dict(extraction, nodes=nodes, edges=edges) + # Canonicalize legacy node/edge schema before validation. - for node in extraction.get("nodes", []): + for node in nodes: if not isinstance(node, dict): continue if "source" in node and "source_file" not in node: # Count edges that reference this node so the warning is actionable (#479) node_id = node.get("id", "?") affected_edges = sum( - 1 for e in extraction.get("edges", []) - if e.get("source") == node_id or e.get("target") == node_id + 1 + for e in edges + if isinstance(e, dict) + and (e.get("source") == node_id or e.get("target") == node_id) ) print( f"[graphify] WARNING: node '{node_id}' uses field 'source' instead of " @@ -149,29 +249,79 @@ def build_from_json(extraction: dict, *, directed: bool = False, root: str | Pat # Dangling edges (stdlib/external imports) are expected - only warn about real schema errors. real_errors = [e for e in errors if "does not match any node id" not in e] if real_errors: - print(f"[graphify] Extraction warning ({len(real_errors)} issues): {real_errors[0]}", file=sys.stderr) - G: nx.Graph = nx.DiGraph() if directed else nx.Graph() - for node in extraction.get("nodes", []): + print( + f"[graphify] Extraction warning ({len(real_errors)} issues): {real_errors[0]}", + file=sys.stderr, + ) + if multigraph: + from .multigraph_compat import require_multigraph_capabilities + + require_multigraph_capabilities() + G: nx.Graph = nx.MultiDiGraph() if multigraph else nx.DiGraph() if directed else nx.Graph() + for node in nodes: + if not isinstance(node, dict) or "id" not in node: + continue + node_id = node["id"] + if not is_hashable(node_id): + continue if "source_file" in node: - node["source_file"] = _norm_source_file(node["source_file"], _root) - G.add_node(node["id"], **{k: v for k, v in node.items() if k != "id"}) + node["source_file"] = _norm_source_file( + _stable_identity_component(node["source_file"]), _root + ) + node_attrs = {k: v for k, v in node.items() if k != "id"} + # Reject node ids that JSON-serialize but won't round-trip to the same + # hashable type. Tuples serialize as JSON arrays and come back as lists + # (unhashable), so they cannot be used as NetworkX node ids after a + # save/load cycle even though json.dumps would accept them. + if isinstance(node_id, (list, tuple, set, frozenset, dict)): + print( + f"[graphify] WARNING: node id {node_id!r} ({type(node_id).__name__}) " + f"would not round-trip through JSON as the same hashable type; skipping.", + file=sys.stderr, + ) + continue + # Check id AND attrs are JSON-serializable. NetworkX allows hashable but + # non-JSON-safe ids (e.g., custom objects); accepting them here would + # break later node_link_data + json.dump. + try: + json.dumps({"id": node_id, **node_attrs}, ensure_ascii=False) + except (TypeError, ValueError) as exc: + print( + f"[graphify] WARNING: node {node_id!r} has non-JSON-serializable " + f"id or attrs ({exc}); skipping.", + file=sys.stderr, + ) + continue + G.add_node(node_id, **node_attrs) node_set = set(G.nodes()) # Normalized ID map: lets edges survive when the LLM generates IDs with # slightly different casing or punctuation than the AST extractor. # e.g. "Session_ValidateToken" maps to "session_validatetoken". - norm_to_id: dict[str, str] = {_normalize_id(nid): nid for nid in node_set} + norm_to_id: dict[str, Hashable] = { + _normalize_id(nid): nid for nid in node_set if isinstance(nid, str) + } + multigraph_groups: dict[tuple[Hashable, Hashable, str], list[dict]] = {} + multigraph_explicit_keys: set[tuple[Hashable, Hashable, str]] = set() + multigraph_diagnostics = {"exact_duplicate_edges": 0, "key_collision_edges": 0} + # Iterate edges in a deterministic order. The graph is undirected and stores # direction in _src/_tgt; when two edges collapse onto the same node pair the # last write wins, so an unstable iteration order flips _src/_tgt run-to-run - # and makes the serialized graph churn. Sorting fixes the last-write outcome. - for edge in sorted( - extraction.get("edges", []), - key=lambda e: ( - str(e.get("source", e.get("from", ""))), - str(e.get("target", e.get("to", ""))), - str(e.get("relation", "")), - ), - ): + # and makes the serialized graph churn. Sorting also stabilizes multigraph + # key-collision grouping before keyed emission. + def _edge_sort_key(edge: object) -> tuple[str, str, str, str]: + if not isinstance(edge, dict): + return ("", "", "", repr(edge)) + return ( + str(edge.get("source", edge.get("from", ""))), + str(edge.get("target", edge.get("to", ""))), + str(edge.get("relation", "")), + json.dumps(edge, sort_keys=True, ensure_ascii=False, default=str), + ) + + for edge in sorted(edges, key=_edge_sort_key): + if not isinstance(edge, dict): + continue if "source" not in edge and "from" in edge: edge["source"] = edge["from"] if "target" not in edge and "to" in edge: @@ -179,29 +329,36 @@ def build_from_json(extraction: dict, *, directed: bool = False, root: str | Pat if "source" not in edge or "target" not in edge: continue src, tgt = edge["source"], edge["target"] + srcis_hashable = is_hashable(src) + tgtis_hashable = is_hashable(tgt) + if not srcis_hashable or not tgtis_hashable: + endpoint = "source" if not srcis_hashable else "target" + endpoint_value = src if not srcis_hashable else tgt + print( + "[graphify] WARNING: skipped edge with unhashable " + f"{endpoint} endpoint ({type(endpoint_value).__name__})", + file=sys.stderr, + ) + continue # Remap mismatched IDs via normalization before dropping the edge. - if src not in node_set: + if isinstance(src, str) and src not in node_set: src = norm_to_id.get(_normalize_id(src), src) - if tgt not in node_set: + if isinstance(tgt, str) and tgt not in node_set: tgt = norm_to_id.get(_normalize_id(tgt), tgt) if src not in node_set or tgt not in node_set: continue # skip edges to external/stdlib nodes - expected, not an error - attrs = {k: v for k, v in edge.items() if k not in ("source", "target")} + # Exclude legacy from/to alongside source/target so they don't survive + # as ordinary edge attrs after legacy-shape remap above. + base_attrs = {k: v for k, v in edge.items() if k not in ("source", "target", "from", "to")} + raw_key, attrs = strip_schema_key(base_attrs) if "source_file" in attrs: - attrs["source_file"] = _norm_source_file(attrs["source_file"], _root) + attrs["source_file"] = _norm_source_file( + _stable_identity_component(attrs["source_file"]), _root + ) # Drop cross-language INFERRED `calls` edges — same short names (render, # parse, etc.) appear across language boundaries in multi-language chunks, # producing phantom edges that don't represent real call relationships. if attrs.get("relation") == "calls" and attrs.get("confidence") == "INFERRED": - _LANG_FAMILY: dict[str, str] = { - ".py": "py", ".pyi": "py", - ".js": "js", ".mjs": "js", ".cjs": "js", ".jsx": "js", - ".ts": "js", ".tsx": "js", - ".go": "go", ".rs": "rs", - ".java": "jvm", ".kt": "jvm", ".scala": "jvm", ".groovy": "jvm", - ".c": "c", ".h": "c", ".cc": "cpp", ".cpp": "cpp", ".hpp": "cpp", - ".rb": "rb", ".php": "php", ".cs": "cs", ".swift": "swift", ".lua": "lua", - } src_ext = Path(G.nodes[src].get("source_file") or "").suffix.lower() tgt_ext = Path(G.nodes[tgt].get("source_file") or "").suffix.lower() if src_ext and tgt_ext and _LANG_FAMILY.get(src_ext) != _LANG_FAMILY.get(tgt_ext): @@ -210,23 +367,131 @@ def build_from_json(extraction: dict, *, directed: bool = False, root: str | Pat # causing display functions to show edges backwards. attrs["_src"] = src attrs["_tgt"] = tgt - # When the graph is undirected and the same node pair appears twice with - # the same relation but opposite directions (e.g. a `calls` b and b `calls` a), - # nx.Graph collapses them into one edge. The deterministic sort above means - # the lexicographically-later direction would systematically overwrite the - # earlier one's _src/_tgt, silently flipping the surviving edge's caller - # and callee. First-seen direction wins instead — drop the redundant - # reverse-direction duplicate so the original direction is preserved (#1061). - if not G.is_directed() and G.has_edge(src, tgt): - existing = edge_data(G, src, tgt) - if existing.get("relation") == attrs.get("relation") and ( - existing.get("_src") == tgt and existing.get("_tgt") == src - ): - continue - G.add_edge(src, tgt, **attrs) + # Refuse to store any edge whose attrs cannot round-trip through JSON. + # Mutating attrs in place would silently change the user's stored value; + # skipping with a warning matches the rest of the build's defensive policy + # and prevents later json.dump crashes during export, identically in + # simple-graph and multigraph modes. + try: + json.dumps(attrs, ensure_ascii=False) + except (TypeError, ValueError) as exc: + print( + f"[graphify] WARNING: edge ({src}->{tgt}) has non-JSON-serializable " + f"attrs ({exc}); skipping.", + file=sys.stderr, + ) + continue + if multigraph: + if raw_key is not None and not isinstance(raw_key, str): + raise TypeError( + f"multigraph edge 'key' must be a string, got " + f"{type(raw_key).__name__} ({raw_key!r})" + ) + base_key = ( + raw_key + if raw_key is not None + else make_stable_key( + _stable_identity_component(attrs.get("relation")), + _stable_identity_component(attrs.get("source_file")), + _stable_identity_component(attrs.get("source_location")), + ) + ) + if raw_key is not None: + multigraph_explicit_keys.add((src, tgt, base_key)) + multigraph_groups.setdefault((src, tgt, base_key), []).append(dict(attrs)) + else: + # When the graph is undirected and the same node pair appears twice with + # the same relation but opposite directions (e.g. a `calls` b and b `calls` a), + # nx.Graph collapses them into one edge. The deterministic sort above means + # the lexicographically-later direction would systematically overwrite the + # earlier one's _src/_tgt, silently flipping the surviving edge's caller + # and callee. First-seen direction wins instead — drop the redundant + # reverse-direction duplicate so the original direction is preserved (#1061). + if not G.is_directed() and G.has_edge(src, tgt): + existing = edge_data(G, src, tgt) + if existing.get("relation") == attrs.get("relation") and ( + existing.get("_src") == tgt and existing.get("_tgt") == src + ): + continue + G.add_edge(src, tgt, **attrs) + if multigraph: + singleton_groups: list[tuple[Hashable, Hashable, str, dict]] = [] + multi_groups: list[tuple[Hashable, Hashable, str, list[dict]]] = [] + used_keys_by_pair: dict[tuple[Hashable, Hashable], set[str]] = {} + for (src, tgt, base_key), group_attrs in multigraph_groups.items(): + unique_attrs: list[dict] = [] + seen_attr_fingerprints: set[str] = set() + for attrs in group_attrs: + attr_fingerprint = json.dumps( + attrs, sort_keys=True, ensure_ascii=False, default=str + ) + if attr_fingerprint in seen_attr_fingerprints: + multigraph_diagnostics["exact_duplicate_edges"] += 1 + else: + seen_attr_fingerprints.add(attr_fingerprint) + unique_attrs.append(attrs) + if len(unique_attrs) > 1: + multigraph_diagnostics["key_collision_edges"] += 1 + unique_attrs.sort( + key=lambda attrs: json.dumps( + attrs, sort_keys=True, ensure_ascii=False, default=str + ) + ) + multi_groups.append((src, tgt, base_key, unique_attrs)) + elif unique_attrs: + # Reserve the singleton's base_key so any later multi-attr + # collision-repair on the same (src, tgt) avoids it. + used_keys_by_pair.setdefault((src, tgt), set()).add(base_key) + singleton_groups.append((src, tgt, base_key, unique_attrs[0])) + # Sort both lists deterministically. + singleton_groups.sort( + key=lambda item: ( + repr(item[0]), + repr(item[1]), + item[2], + json.dumps(item[3], sort_keys=True, ensure_ascii=False, default=str), + ) + ) + multi_groups.sort( + key=lambda item: ( + repr(item[0]), + repr(item[1]), + item[2], + json.dumps(item[3], sort_keys=True, ensure_ascii=False, default=str), + ) + ) + # Emit singletons first: they use base_key directly and were reserved + # in the pre-loop above, so collision-repair from multi groups will + # see those reservations and salt around them. + for src, tgt, base_key, attrs in singleton_groups: + G.add_edge(src, tgt, key=base_key, **attrs) + # Then emit multi-attr groups with collision-repair salting against + # both reserved singleton base_keys and earlier multi-group repair + # keys on the same (src, tgt) pair. + for src, tgt, base_key, unique_attrs in multi_groups: + used_keys = used_keys_by_pair.setdefault((src, tgt), set()) + preserve_explicit = (src, tgt, base_key) in multigraph_explicit_keys + for index, attrs in enumerate(unique_attrs): + # When the user passed an explicit `key` shared across multiple + # distinct edges, preserve it on the first emit so at least one + # edge per group keeps the canonical user-supplied key. + # Derived base_keys (from make_stable_key) always go through + # collision-repair so emission stays order-independent. + if preserve_explicit and index == 0 and base_key not in used_keys: + key = base_key + else: + key = _make_collision_key(base_key, attrs) + salt = 0 + while key in used_keys: + salt += 1 + key = _make_collision_key(base_key, attrs, salt=salt) + used_keys.add(key) + G.add_edge(src, tgt, key=key, **attrs) hyperedges = extraction.get("hyperedges", []) if hyperedges: G.graph["hyperedges"] = hyperedges + if multigraph: + G.graph["graphify_multigraph_diagnostics"] = multigraph_diagnostics return G @@ -237,7 +502,8 @@ def build( dedup: bool = True, dedup_llm_backend: str | None = None, root: str | Path | None = None, -) -> nx.Graph: + multigraph: bool = False, +) -> nx.Graph | nx.DiGraph | nx.MultiDiGraph: """Merge multiple extraction results into one graph. directed=True produces a DiGraph that preserves edge direction (source→target). @@ -253,19 +519,40 @@ def build( reverse the order if you prefer AST source_location precision to win. """ from graphify.dedup import deduplicate_entities - combined: dict = {"nodes": [], "edges": [], "hyperedges": [], "input_tokens": 0, "output_tokens": 0} + + combined = _combine_extractions(extractions) + dedup_diagnostics: dict = {} + if dedup and combined["nodes"]: + combined["nodes"], combined["edges"] = deduplicate_entities( + combined["nodes"], + combined["edges"], + communities={}, + dedup_llm_backend=dedup_llm_backend, + diagnostics=dedup_diagnostics, + ) + G = build_from_json(combined, directed=directed, root=root, multigraph=multigraph) + if multigraph and dedup_diagnostics: + existing = G.graph.get("graphify_multigraph_diagnostics", {}) + existing.update(dedup_diagnostics) + G.graph["graphify_multigraph_diagnostics"] = existing + return G + + +def _combine_extractions(extractions: list[dict]) -> dict: + combined: dict = { + "nodes": [], + "edges": [], + "hyperedges": [], + "input_tokens": 0, + "output_tokens": 0, + } for ext in extractions: combined["nodes"].extend(ext.get("nodes", [])) combined["edges"].extend(ext.get("edges", [])) combined["hyperedges"].extend(ext.get("hyperedges", [])) combined["input_tokens"] += ext.get("input_tokens", 0) combined["output_tokens"] += ext.get("output_tokens", 0) - if dedup and combined["nodes"]: - combined["nodes"], combined["edges"] = deduplicate_entities( - combined["nodes"], combined["edges"], communities={}, - dedup_llm_backend=dedup_llm_backend, - ) - return build_from_json(combined, directed=directed, root=root) + return combined def _norm_label(label: str) -> str: @@ -282,7 +569,7 @@ def deduplicate_by_label(nodes: list[dict], edges: list[dict]) -> tuple[list[dic """ _CHUNK_SUFFIX = re.compile(r"_c\d+$") canonical: dict[str, dict] = {} # norm_label -> surviving node - remap: dict[str, str] = {} # old_id -> surviving_id + remap: dict[str, str] = {} # old_id -> surviving_id for node in nodes: key = _norm_label(node.get("label", node.get("id", ""))) @@ -320,21 +607,43 @@ def deduplicate_by_label(nodes: list[dict], edges: list[dict]) -> tuple[list[dic return deduped_nodes, deduped_edges +def _chunk_has_graph_records(chunk: dict) -> bool: + return bool( + chunk.get("nodes") or chunk.get("edges") or chunk.get("links") or chunk.get("hyperedges") + ) + + def build_merge( new_chunks: list[dict], graph_path: str | Path = "graphify-out/graph.json", prune_sources: list[str] | None = None, *, - directed: bool = False, + directed: bool | None = None, dedup: bool = True, dedup_llm_backend: str | None = None, root: str | Path | None = None, -) -> nx.Graph: - """Load existing graph.json, merge new chunks into it, and save back. + multigraph: bool | None = None, +) -> nx.Graph | nx.DiGraph | nx.MultiDiGraph: + """Load existing graph.json, merge new chunks into it, and return the merged graph. + + Persistence is the caller's responsibility (e.g., via ``export.to_json``); + this function does not write back to disk. Never replaces - only grows (or prunes deleted-file nodes via prune_sources). Safe to call repeatedly: existing nodes and edges are preserved. root: if given, absolute source_file paths in new_chunks are made relative (#932). + + ``directed`` defaults to inheriting the saved graph's flag when an + existing graph.json is present, so updating a directed simple graph with + default args no longer silently downgrades it to undirected. + + ``multigraph`` likewise defaults to inheriting the saved graph's flag. When + the saved graph.json has ``multigraph: true`` the merge produces a + MultiDiGraph that preserves keyed parallel edges end-to-end — existing edges + keep their stored ``key`` (so distinct parallel edges between the same pair + survive the re-feed), new chunks are merged without collapsing parallels, and + the result round-trips back out as multigraph. There is no silent fallback to + simple-graph behavior. """ graph_path = Path(graph_path) if graph_path.exists(): @@ -346,18 +655,97 @@ def build_merge( # attrs are popped before saving in export.py, so going through the # NetworkX round-trip loses direction permanently (#760). from graphify.security import check_graph_file_size_cap + check_graph_file_size_cap(graph_path) data = json.loads(graph_path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + raise TypeError( + f"saved graph.json at {graph_path} must be a JSON object, got {type(data).__name__}" + ) + # Honor the saved graph's `multigraph` flag so a stateful update of a + # multigraph graph.json preserves keyed parallel edges instead of + # collapsing to a simple graph. Existing edges keep their stored `key` + # when re-fed through build(multigraph=True), so distinct parallel edges + # between the same node pair survive the merge round-trip. + saved_multigraph = data.get("multigraph", False) + if saved_multigraph is not True and saved_multigraph is not False: + raise TypeError( + f"'multigraph' in {graph_path} must be a boolean, " + f"got {type(saved_multigraph).__name__} ({saved_multigraph!r})" + ) + if multigraph is None: + multigraph = saved_multigraph + elif multigraph != saved_multigraph: + print( + f"[graphify] WARNING: build_merge multigraph={multigraph} overrides " + f"saved graph.json multigraph={saved_multigraph}", + file=sys.stderr, + ) + # Honor the saved graph's `directed` flag unless the caller explicitly + # overrides. Without this, an update with default args on a directed + # graph silently downgrades it and loses edge direction on next export. + saved_directed_raw = data.get("directed", False) + if saved_directed_raw is not True and saved_directed_raw is not False: + raise TypeError( + f"'directed' in {graph_path} must be a boolean, " + f"got {type(saved_directed_raw).__name__} ({saved_directed_raw!r})" + ) + saved_directed = saved_directed_raw + if directed is None: + directed = saved_directed + elif directed != saved_directed: + print( + f"[graphify] WARNING: build_merge directed={directed} overrides " + f"saved graph.json directed={saved_directed}", + file=sys.stderr, + ) links_key = "links" if "links" in data else "edges" existing_nodes = list(data.get("nodes", [])) existing_edges = list(data.get(links_key, [])) base = [{"nodes": existing_nodes, "edges": existing_edges}] else: + if directed is None: + directed = False + if multigraph is None: + multigraph = False existing_nodes = [] base = [] - all_chunks = base + list(new_chunks) - G = build(all_chunks, directed=directed, dedup=dedup, dedup_llm_backend=dedup_llm_backend, root=root) + incoming_chunks = list(new_chunks) + incoming_has_records = any(_chunk_has_graph_records(chunk) for chunk in incoming_chunks) + dedup_diagnostics: dict = {} + if graph_path.exists() and dedup: + effective_dedup = False + if incoming_has_records: + from graphify.dedup import deduplicate_entities + + incoming = _combine_extractions(incoming_chunks) + if incoming["nodes"]: + incoming["nodes"], incoming["edges"] = deduplicate_entities( + incoming["nodes"], + incoming["edges"], + communities={}, + dedup_llm_backend=dedup_llm_backend, + diagnostics=dedup_diagnostics, + ) + all_chunks = base + [incoming] + else: + all_chunks = base + incoming_chunks + else: + effective_dedup = dedup + all_chunks = base + incoming_chunks + G = build( + all_chunks, + directed=directed, + dedup=effective_dedup, + dedup_llm_backend=dedup_llm_backend, + root=root, + multigraph=multigraph, + ) + if multigraph and dedup_diagnostics: + existing = G.graph.get("graphify_multigraph_diagnostics", {}) + existing.update(dedup_diagnostics) + G.graph["graphify_multigraph_diagnostics"] = existing # Prune nodes and edges from deleted source files if prune_sources: @@ -376,10 +764,7 @@ def build_merge( norm = _norm_source_file(p, _root_str) if norm: prune_set.add(norm) - to_remove = [ - n for n, d in G.nodes(data=True) - if d.get("source_file") in prune_set - ] + to_remove = [n for n, d in G.nodes(data=True) if d.get("source_file") in prune_set] G.remove_nodes_from(to_remove) n_files = len(prune_sources) n_nodes = len(to_remove) @@ -389,27 +774,48 @@ def build_merge( file=sys.stderr, ) - edges_to_remove = [ - (u, v) for u, v, d in G.edges(data=True) - if d.get("source_file") in prune_set - ] - if edges_to_remove: - G.remove_edges_from(edges_to_remove) + # Prune edges belonging to changed/deleted source files. On a + # MultiDiGraph a single (u, v) pair can carry MULTIPLE parallel edges + # from DIFFERENT source files, so removal MUST be keyed: drop only the + # parallel edges whose source_file is in prune_set and leave parallel + # edges from other files between the same pair intact. The two-tuple + # remove_edges_from used by simple graphs would drop only one edge per + # pair on a multigraph (first key) and could evict the wrong file's edge. + # remove_all_parallel_edges is deliberately NOT used here — it is too + # broad and would delete other-file parallels between the same pair. + if isinstance(G, (nx.MultiGraph, nx.MultiDiGraph)): + keyed_to_remove = [ + (u, v, k) + for u, v, k, d in G.edges(keys=True, data=True) + if d.get("source_file") in prune_set + ] + for u, v, k in keyed_to_remove: + G.remove_edge(u, v, key=k) + n_edges_removed = len(keyed_to_remove) + else: + edges_to_remove = [ + (u, v) for u, v, d in G.edges(data=True) if d.get("source_file") in prune_set + ] + if edges_to_remove: + G.remove_edges_from(edges_to_remove) + n_edges_removed = len(edges_to_remove) + if n_edges_removed: print( - f"[graphify] Pruned {len(edges_to_remove)} edge(s) from deleted source file(s).", + f"[graphify] Pruned {n_edges_removed} edge(s) from deleted source file(s).", file=sys.stderr, ) - if not n_nodes and not edges_to_remove: + if not n_nodes and not n_edges_removed: print( f"[graphify] {n_files} source file(s) deleted since last run — " f"no matching nodes or edges in graph, already clean.", file=sys.stderr, ) - # Safety check: refuse to shrink the graph silently (#479) - # Skip when dedup or prune_sources is active — shrinkage is intentional there. - if graph_path.exists() and not dedup and not prune_sources: + # Safety check: refuse to shrink the graph silently (#479). + # Stateful dedup applies only to incoming chunks, so only explicit pruning + # may reduce the saved graph's node count. + if graph_path.exists() and not prune_sources: existing_n = len(existing_nodes) new_n = G.number_of_nodes() if new_n < existing_n: @@ -418,6 +824,7 @@ def build_merge( f"Pass prune_sources explicitly if you intend to remove nodes." ) + # No write to graph_path here; persistence is the caller's responsibility. return G diff --git a/graphify/cache.py b/graphify/cache.py index 2052cf7aa..18c1a6d72 100644 --- a/graphify/cache.py +++ b/graphify/cache.py @@ -13,6 +13,32 @@ # absolute path ("/shared/graphify-out"). _GRAPHIFY_OUT = os.environ.get("GRAPHIFY_OUT", "graphify-out") +# Cache schema version — bump this whenever the PRODUCER (AST/semantic +# extraction) output format or content changes in a way that makes existing +# cache entries invalid. Entries are stamped with this version on write and +# revalidated on read; any entry whose recorded version != the current value +# (including legacy entries written before versioning, which have no version +# field) is treated as a cache MISS and rebuilt. +# +# Why this matters for graph profiles (PR 7): raw extraction output is +# PROFILE-INDEPENDENT. The simple-graph vs MultiDiGraph distinction is a +# build-time assembly choice (`build_from_json(multigraph=...)`), not an +# extraction-time choice — the same nodes/edges are extracted regardless of how +# they are later assembled. So the raw cache is intentionally NOT keyed by graph +# profile, and reusing it across profiles is correct and safe (it protects cache +# hit rate). This version constant is the escape hatch: if a future producer +# change ever makes cached output differ by profile (or otherwise incompatible), +# bumping CACHE_SCHEMA_VERSION forces a clean rebuild for everyone, fulfilling +# the design-doc clause "add profile/version invalidation where graph outputs +# can differ". +CACHE_SCHEMA_VERSION = 1 + +# Reserved metadata key stamped into each cache entry's JSON. Chosen with +# dunder bracketing so it cannot collide with extraction payload keys (which are +# plain identifiers like "nodes", "edges", "hyperedges", "source_file"). It is +# stripped back out on read so callers see only their original result dict. +_SCHEMA_VERSION_KEY = "__cache_schema_version__" + def _body_content(content: bytes) -> bytes: """Strip YAML frontmatter from Markdown content, returning only the body.""" @@ -20,7 +46,7 @@ def _body_content(content: bytes) -> bytes: if text.startswith("---"): end = text.find("\n---", 3) if end != -1: - return text[end + 4:].encode() + return text[end + 4 :].encode() return content @@ -86,6 +112,7 @@ def _flush_stat_index() -> None: def _normalize_path(path: Path) -> Path: """Normalize path for consistent cache keys across Windows path spellings.""" import sys + if sys.platform != "win32": return path s = str(path) @@ -120,9 +147,7 @@ def file_hash(path: Path, root: Path = Path(".")) -> str: try: st = p.stat() entry = _stat_index.get(abs_key) - if (entry - and entry.get("size") == st.st_size - and entry.get("mtime_ns") == st.st_mtime_ns): + if entry and entry.get("size") == st.st_size and entry.get("mtime_ns") == st.st_mtime_ns: return entry["hash"] except OSError: pass @@ -146,6 +171,34 @@ def file_hash(path: Path, root: Path = Path(".")) -> str: return digest +def _stamp_schema_version(result: dict) -> dict: + """Return a shallow copy of result with the current schema version stamped in. + + Used on write so every cache entry records the schema version it was + produced under. A shallow copy avoids mutating the caller's dict. + """ + stamped = dict(result) + stamped[_SCHEMA_VERSION_KEY] = CACHE_SCHEMA_VERSION + return stamped + + +def _validate_schema_version(data: dict) -> dict | None: + """Validate a loaded cache entry's schema version, returning the payload. + + Returns the result dict with the reserved version key stripped out (so + callers see only their original payload) if the recorded version matches the + current CACHE_SCHEMA_VERSION. Returns None — a cache MISS — when the version + is missing (legacy entries written before versioning) or mismatched (a stale + entry from an older/newer producer). Treating both as a miss triggers a safe + rebuild rather than silently reusing potentially-incompatible cached output. + """ + if data.get(_SCHEMA_VERSION_KEY) != CACHE_SCHEMA_VERSION: + return None + payload = dict(data) + payload.pop(_SCHEMA_VERSION_KEY, None) + return payload + + def cache_dir(root: Path = Path("."), kind: str = "ast") -> Path: """Returns graphify-out/cache/{kind}/ - creates it if needed. @@ -167,7 +220,10 @@ def load_cached(path: Path, root: Path = Path("."), kind: str = "ast") -> dict | For kind="ast", also checks the legacy flat cache/ directory so users upgrading from pre-0.5.3 don't lose their existing AST cache entries. - Returns None if no cache entry or file has changed. + Returns None if no cache entry, file has changed, or the entry's recorded + schema version does not match the current CACHE_SCHEMA_VERSION (including + pre-versioning entries that lack the field — these are treated as a miss and + rebuilt, the backward-compatible choice). """ try: h = file_hash(path, root) @@ -176,7 +232,7 @@ def load_cached(path: Path, root: Path = Path("."), kind: str = "ast") -> dict | entry = cache_dir(root, kind) / f"{h}.json" if entry.exists(): try: - return json.loads(entry.read_text(encoding="utf-8")) + return _validate_schema_version(json.loads(entry.read_text(encoding="utf-8"))) except (json.JSONDecodeError, OSError): return None # Migration fallback: check legacy flat cache/ dir for AST entries @@ -184,7 +240,7 @@ def load_cached(path: Path, root: Path = Path("."), kind: str = "ast") -> dict | legacy = Path(root).resolve() / _GRAPHIFY_OUT / "cache" / f"{h}.json" if legacy.exists(): try: - return json.loads(legacy.read_text(encoding="utf-8")) + return _validate_schema_version(json.loads(legacy.read_text(encoding="utf-8"))) except (json.JSONDecodeError, OSError): return None return None @@ -194,7 +250,10 @@ def save_cached(path: Path, result: dict, root: Path = Path("."), kind: str = "a """Save extraction result for this file. Stores as graphify-out/cache/{kind}/{hash}.json where hash = SHA256 of current file contents. - result should be a dict with 'nodes' and 'edges' lists. + result should be a dict with 'nodes' and 'edges' lists. The current + CACHE_SCHEMA_VERSION is stamped into the stored JSON (under a reserved key) + so load_cached can invalidate stale entries after a producer change; the + caller's `result` dict is not mutated. No-ops if `path` is not a regular file. Subagent-produced semantic fragments occasionally carry a directory path in `source_file`; skipping them prevents @@ -208,7 +267,7 @@ def save_cached(path: Path, result: dict, root: Path = Path("."), kind: str = "a entry = target_dir / f"{h}.json" fd, tmp_path = tempfile.mkstemp(dir=target_dir, prefix=f"{h}.", suffix=".tmp") try: - os.write(fd, json.dumps(result).encode()) + os.write(fd, json.dumps(_stamp_schema_version(result)).encode()) os.close(fd) try: os.replace(tmp_path, entry) @@ -216,6 +275,7 @@ def save_cached(path: Path, result: dict, root: Path = Path("."), kind: str = "a # Windows: os.replace can fail with WinError 5 if the target is # briefly locked. Fall back to copy-then-delete. import shutil + shutil.copy2(tmp_path, entry) os.unlink(tmp_path) except Exception: @@ -313,7 +373,7 @@ def save_semantic_cache( src = e.get("source_file", "") if src: by_file[src]["edges"].append(e) - for h in (hyperedges or []): + for h in hyperedges or []: src = h.get("source_file", "") if src: by_file[src]["hyperedges"].append(h) diff --git a/graphify/callflow_html.py b/graphify/callflow_html.py index 6195adb98..9f67a9c07 100644 --- a/graphify/callflow_html.py +++ b/graphify/callflow_html.py @@ -28,6 +28,7 @@ from pathlib import Path from collections import Counter, defaultdict from datetime import datetime, timezone +from typing import Any, cast from html import escape @@ -89,7 +90,8 @@ # 2. Data loading and normalization helpers # ────────────────────────────────────────────── -def read_json(path: str | Path, default=None): + +def read_json(path: str | Path | None, default: Any = None) -> Any: """Read JSON with a useful error message.""" if not path: return default @@ -204,7 +206,9 @@ def normalize_edge(raw: dict, index: int) -> dict | None: if not source or not target: return None - relation = first_present(edge, "relation", "type", "kind", "label", "predicate", default="relates") + relation = first_present( + edge, "relation", "type", "kind", "label", "predicate", default="relates" + ) confidence = first_present(edge, "confidence", "evidence", "provenance", default="EXTRACTED") score = first_present(edge, "confidence_score", "score", "weight", "probability", default=1.0) @@ -254,6 +258,7 @@ def load_graph(path: str | Path) -> tuple: """Load graph.json. Returns normalized (nodes, edges, hyperedges, metadata).""" if path: from graphify.security import check_graph_file_size_cap + try: check_graph_file_size_cap(Path(path)) except ValueError as exc: @@ -262,16 +267,32 @@ def load_graph(path: str | Path) -> tuple: if not isinstance(data, dict): raise SystemExit(f"ERROR: graph file must contain a JSON object: {path}") - graph_block = data.get("graph") if isinstance(data.get("graph"), dict) else {} - meta_block = data.get("metadata") if isinstance(data.get("metadata"), dict) else {} + graph_block: dict[str, Any] = ( + cast(dict[str, Any], data.get("graph")) if isinstance(data.get("graph"), dict) else {} + ) + meta_block: dict[str, Any] = ( + cast(dict[str, Any], data.get("metadata")) if isinstance(data.get("metadata"), dict) else {} + ) node_link = _node_link_payload(data) if node_link: raw_nodes, raw_edges = node_link else: - raw_nodes = first_list(data.get("nodes"), data.get("vertices"), graph_block.get("nodes"), graph_block.get("vertices")) - raw_edges = first_list(data.get("links"), data.get("edges"), graph_block.get("links"), graph_block.get("edges")) - hyperedges = first_list(data.get("hyperedges"), graph_block.get("hyperedges"), data.get("groups"), graph_block.get("groups")) + raw_nodes = first_list( + data.get("nodes"), + data.get("vertices"), + graph_block.get("nodes"), + graph_block.get("vertices"), + ) + raw_edges = first_list( + data.get("links"), data.get("edges"), graph_block.get("links"), graph_block.get("edges") + ) + hyperedges = first_list( + data.get("hyperedges"), + graph_block.get("hyperedges"), + data.get("groups"), + graph_block.get("groups"), + ) nodes = [normalize_node(n, i) for i, n in enumerate(raw_nodes) if isinstance(n, dict)] edges = [] @@ -282,9 +303,16 @@ def load_graph(path: str | Path) -> tuple: if edge: edges.append(edge) - meta = dict(graph_block) + meta: dict[str, Any] = dict(graph_block) meta.update(meta_block) - for key in ("built_at_commit", "commit", "project_name", "repo", "repository", "language_breakdown"): + for key in ( + "built_at_commit", + "commit", + "project_name", + "repo", + "repository", + "language_breakdown", + ): if data.get(key) and not meta.get(key): meta[key] = data.get(key) if meta.get("commit") and not meta.get("built_at_commit"): @@ -331,6 +359,7 @@ def load_report(path: str | Path | None) -> str: # 3. Mermaid-safe label helpers # ────────────────────────────────────────────── + def safe_mermaid_text(text: str) -> str: """Sanitize text for use inside a Mermaid node label. @@ -344,10 +373,10 @@ def safe_mermaid_text(text: str) -> str: """ text = str(text or "") text = text.replace('"', "'") - text = text.replace('`', '') - text = text.replace('#', '') - text = text.replace('|', ' ') - text = text.replace('{', '').replace('}', '') + text = text.replace("`", "") + text = text.replace("#", "") + text = text.replace("|", " ") + text = text.replace("{", "").replace("}", "") text = text.replace("->>", " to ").replace("-->", " to ").replace("->", " to ") text = " ".join(text.split()) return escape(text, quote=False) @@ -424,7 +453,9 @@ def resolve_graphify_paths(args) -> dict: project_root = graphify_out.parent if graphify_out.name == "graphify-out" else base graph = Path(args.graph).expanduser() if args.graph else graphify_out / "graph.json" report = Path(args.report).expanduser() if args.report else graphify_out / "GRAPH_REPORT.md" - labels = Path(args.labels).expanduser() if args.labels else graphify_out / ".graphify_labels.json" + labels = ( + Path(args.labels).expanduser() if args.labels else graphify_out / ".graphify_labels.json" + ) sections = Path(args.sections).expanduser() if args.sections else None return { "base": project_root, @@ -509,13 +540,39 @@ def node_kind(node: dict) -> str: if any(word in label for word in ("async", "await", "stream", "sse")): return "async" raw_label = str(node.get("label") or "") - hook_like = raw_label.startswith("use") and len(raw_label) > 3 and (raw_label[3].isupper() or raw_label[3] in "_-") - if any(word in label for word in ("component", "props", "hook", "store")) or hook_like or source_file.endswith((".tsx", ".jsx", ".vue", ".svelte")): + hook_like = ( + raw_label.startswith("use") + and len(raw_label) > 3 + and (raw_label[3].isupper() or raw_label[3] in "_-") + ) + if ( + any(word in label for word in ("component", "props", "hook", "store")) + or hook_like + or source_file.endswith((".tsx", ".jsx", ".vue", ".svelte")) + ): return "ui" raw = raw_label if raw[:1].isupper() and not raw.endswith("()"): return "klass" - if raw.endswith((".py", ".ts", ".tsx", ".js", ".jsx", ".go", ".rs", ".java", ".kt", ".rb", ".php", ".cs", ".swift", ".vue", ".svelte")): + if raw.endswith( + ( + ".py", + ".ts", + ".tsx", + ".js", + ".jsx", + ".go", + ".rs", + ".java", + ".kt", + ".rb", + ".php", + ".cs", + ".swift", + ".vue", + ".svelte", + ) + ): return "module" return "function" @@ -632,6 +689,7 @@ def mermaid_class_defs() -> list: # 4. Community and section indexing # ────────────────────────────────────────────── + def build_community_index(nodes: list) -> dict: """Map community_id (str) -> list of nodes.""" idx = defaultdict(list) @@ -652,7 +710,9 @@ def html_anchor_id(raw: str, fallback: str, used: set) -> str: base = base[:48].strip("-") or "section" candidate = base if candidate in used: - candidate = f"{base}-{hashlib.sha1(raw.encode('utf-8'), usedforsecurity=False).hexdigest()[:6]}" + candidate = ( + f"{base}-{hashlib.sha1(raw.encode('utf-8'), usedforsecurity=False).hexdigest()[:6]}" + ) suffix = 2 while candidate in used: candidate = f"{base}-{suffix}" @@ -688,11 +748,13 @@ def normalize_sections(sections: list, lang: str) -> list: continue sid = html_anchor_id(raw_id, f"section-{index}", used) - normalized.append({ - "id": sid, - "name": raw_name, - "communities": normalize_communities(raw.get("communities", raw.get("community"))), - }) + normalized.append( + { + "id": sid, + "name": raw_name, + "communities": normalize_communities(raw.get("communities", raw.get("community"))), + } + ) return normalized @@ -712,9 +774,22 @@ def label_for_community(cid: str, labels: dict, nodes: list, lang: str) -> str: "提取管线", "Extraction Pipeline", { - "extract", "extractor", "tree", "sitter", "parser", "language", - "python", "javascript", "typescript", "rust", "java", "go", - "ast", "calls", "imports", "multilang", + "extract", + "extractor", + "tree", + "sitter", + "parser", + "language", + "python", + "javascript", + "typescript", + "rust", + "java", + "go", + "ast", + "calls", + "imports", + "multilang", }, ), ( @@ -722,8 +797,17 @@ def label_for_community(cid: str, labels: dict, nodes: list, lang: str) -> str: "图谱构建", "Graph Build", { - "build", "graph", "merge", "dedup", "node", "edge", "hyperedge", - "json", "schema", "normalize", "confidence", + "build", + "graph", + "merge", + "dedup", + "node", + "edge", + "hyperedge", + "json", + "schema", + "normalize", + "confidence", }, ), ( @@ -731,8 +815,18 @@ def label_for_community(cid: str, labels: dict, nodes: list, lang: str) -> str: "分析聚类", "Analysis & Clustering", { - "cluster", "community", "leiden", "cohesion", "analyze", "god", - "surprise", "question", "query", "path", "explain", "benchmark", + "cluster", + "community", + "leiden", + "cohesion", + "analyze", + "god", + "surprise", + "question", + "query", + "path", + "explain", + "benchmark", }, ), ( @@ -740,8 +834,18 @@ def label_for_community(cid: str, labels: dict, nodes: list, lang: str) -> str: "输出文档", "Outputs & Docs", { - "export", "html", "wiki", "obsidian", "canvas", "svg", "graphml", - "report", "callflow", "mermaid", "tree", "documentation", + "export", + "html", + "wiki", + "obsidian", + "canvas", + "svg", + "graphml", + "report", + "callflow", + "mermaid", + "tree", + "documentation", }, ), ( @@ -749,9 +853,20 @@ def label_for_community(cid: str, labels: dict, nodes: list, lang: str) -> str: "CLI 与技能安装", "CLI & Skill Installers", { - "main", "install", "uninstall", "skill", "agent", "claude", - "codex", "opencode", "aider", "copilot", "kiro", "vscode", - "hook", "command", + "main", + "install", + "uninstall", + "skill", + "agent", + "claude", + "codex", + "opencode", + "aider", + "copilot", + "kiro", + "vscode", + "hook", + "command", }, ), ( @@ -759,9 +874,21 @@ def label_for_community(cid: str, labels: dict, nodes: list, lang: str) -> str: "摄取与增量更新", "Ingestion & Updates", { - "ingest", "fetch", "download", "url", "html", "markdown", - "cache", "manifest", "watch", "update", "incremental", - "transcribe", "video", "audio", "google", + "ingest", + "fetch", + "download", + "url", + "html", + "markdown", + "cache", + "manifest", + "watch", + "update", + "incremental", + "transcribe", + "video", + "audio", + "google", }, ), ( @@ -769,8 +896,17 @@ def label_for_community(cid: str, labels: dict, nodes: list, lang: str) -> str: "服务 API", "Serving API", { - "serve", "api", "request", "response", "endpoint", "router", - "handle", "upload", "search", "delete", "enrich", + "serve", + "api", + "request", + "response", + "endpoint", + "router", + "handle", + "upload", + "search", + "delete", + "enrich", }, ), ( @@ -778,8 +914,17 @@ def label_for_community(cid: str, labels: dict, nodes: list, lang: str) -> str: "安全与全局图", "Security & Global Graph", { - "security", "safe", "ssrf", "xss", "path", "traversal", - "global", "prefix", "prune", "repo", "clone", + "security", + "safe", + "ssrf", + "xss", + "path", + "traversal", + "global", + "prefix", + "prune", + "repo", + "clone", }, ), ( @@ -787,8 +932,14 @@ def label_for_community(cid: str, labels: dict, nodes: list, lang: str) -> str: "测试与样例", "Tests & Fixtures", { - "test", "tests", "fixture", "fixtures", "sample", "assert", - "pytest", "mock", + "test", + "tests", + "fixture", + "fixtures", + "sample", + "assert", + "pytest", + "mock", }, ), ] @@ -826,14 +977,24 @@ def _rank_grouped_sections(grouped: dict, max_sections: int) -> tuple[list, list return selected, overflow_communities -def derive_sections_from_communities(nodes: list, labels: dict, lang: str, max_sections: int) -> list: +def derive_sections_from_communities( + nodes: list, labels: dict, lang: str, max_sections: int +) -> list: """Derive architecture-oriented sections when no sections JSON is supplied.""" comm_idx = build_community_index(nodes) - sections = [{"id": "overview", "name": pick_text(lang, "架构总览", "Architecture Overview"), "communities": []}] + sections = [ + { + "id": "overview", + "name": pick_text(lang, "架构总览", "Architecture Overview"), + "communities": [], + } + ] grouped = {} unassigned = [] - for cid, community_nodes in sorted(comm_idx.items(), key=lambda item: (-len(item[1]), str(item[0]))): + for cid, community_nodes in sorted( + comm_idx.items(), key=lambda item: (-len(item[1]), str(item[0])) + ): label = label_for_community(cid, labels, community_nodes, lang) text = _community_text(community_nodes, label) best = None @@ -861,7 +1022,9 @@ def derive_sections_from_communities(nodes: list, labels: dict, lang: str, max_s else: unassigned.append((cid, community_nodes, label)) - selected, overflow_communities = _rank_grouped_sections(grouped, max(1, int(max_sections or 15)) - 1) + selected, overflow_communities = _rank_grouped_sections( + grouped, max(1, int(max_sections or 15)) - 1 + ) sections.extend( {"id": sec["id"], "name": sec["name"], "communities": sec["communities"]} for sec in selected @@ -869,15 +1032,19 @@ def derive_sections_from_communities(nodes: list, labels: dict, lang: str, max_s remaining_slots = max(0, int(max_sections or 15) - (len(sections) - 1) - 1) for cid, community_nodes, label in unassigned[:remaining_slots]: - sections.append({"id": str(label or f"community-{cid}"), "name": label, "communities": [cid]}) + sections.append( + {"id": str(label or f"community-{cid}"), "name": label, "communities": [cid]} + ) other_communities = overflow_communities + [cid for cid, _, _ in unassigned[remaining_slots:]] if other_communities: - sections.append({ - "id": "other", - "name": pick_text(lang, "其他", "Other"), - "communities": other_communities, - }) + sections.append( + { + "id": "other", + "name": pick_text(lang, "其他", "Other"), + "communities": other_communities, + } + ) return sections @@ -905,6 +1072,7 @@ def node_in_section(node_id: str, section_node_ids: set) -> bool: # 5. Edge analysis # ────────────────────────────────────────────── + def classify_edges(edges: list, section_nodes_map: dict) -> dict: """Classify edges as intra-section or inter-section. @@ -958,14 +1126,15 @@ def should_include_edge(edge: dict) -> bool: # 6. Mermaid diagram generators # ────────────────────────────────────────────── -def node_degree_scores(edges: list) -> Counter: + +def node_degree_scores(edges: list) -> dict[str, float]: """Score nodes by useful edge participation.""" - scores = Counter() + scores: defaultdict[str, float] = defaultdict(float) for edge in edges: score = edge_score(edge) - scores[edge.get("source", "")] += score - scores[edge.get("target", "")] += score - return scores + scores[str(edge.get("source", ""))] += score + scores[str(edge.get("target", ""))] += score + return dict(scores) def node_importance(node: dict) -> float: @@ -1037,7 +1206,10 @@ def fallback_key(node: dict) -> tuple: def node_label(node: dict) -> str: """Build a readable Mermaid node label.""" - label = humanize_label(node.get("label") or node.get("id"), node.get("source_file", "")) + label = humanize_label( + str(node.get("label") or node.get("id") or ""), + node.get("source_file", ""), + ) source_file = safe_file_path(node.get("source_file", "")) if source_file and not label.endswith(Path(source_file).name): return f"{safe_mermaid_text(label)}
{safe_mermaid_text(source_file)}" @@ -1056,7 +1228,9 @@ def group_nodes_by_file(nodes: list) -> dict: def section_edge_summary(classified_edges: dict) -> dict: """Aggregate inter-section edge counts and relation names.""" node_section = classified_edges.get("node_section", {}) - summary = defaultdict(lambda: {"count": 0, "relations": Counter()}) + summary: defaultdict[tuple[Any, Any], dict[str, Any]] = defaultdict( + lambda: {"count": 0, "relations": Counter()} + ) for edge in classified_edges.get("inter", []): if not should_include_edge(edge): continue @@ -1070,9 +1244,14 @@ def section_edge_summary(classified_edges: dict) -> dict: return summary -def generate_overview_graph(sections: list, section_nodes_map: dict, - classified_edges: dict, labels: dict, lang: str, - diagram_scale: float) -> str: +def generate_overview_graph( + sections: list, + section_nodes_map: dict, + classified_edges: dict, + labels: dict, + lang: str, + diagram_scale: float, +) -> str: """Generate a readable section-level architecture overview.""" lines = [mermaid_init(diagram_scale, "LR")] section_defs = [sec for sec in sections if sec["id"] != "overview"] @@ -1088,7 +1267,9 @@ def generate_overview_graph(sections: list, section_nodes_map: dict, lines.append(f" class {sid} module;") aggregated = section_edge_summary(classified_edges) - for (src, tgt), data in sorted(aggregated.items(), key=lambda item: item[1]["count"], reverse=True)[:12]: + for (src, tgt), data in sorted( + aggregated.items(), key=lambda item: item[1]["count"], reverse=True + )[:12]: src_id = mermaid_section_id(src) tgt_id = mermaid_section_id(tgt) relation, _ = data["relations"].most_common(1)[0] @@ -1099,19 +1280,29 @@ def generate_overview_graph(sections: list, section_nodes_map: dict, if not aggregated and len(section_defs) > 1: for prev, cur in zip(section_defs, section_defs[1:]): - lines.append(f" {mermaid_section_id(prev['id'])} -.-> {mermaid_section_id(cur['id'])}") + lines.append( + f" {mermaid_section_id(prev['id'])} -.-> {mermaid_section_id(cur['id'])}" + ) lines.extend(mermaid_class_defs()) return "\n".join(lines) -def generate_section_flowchart(section_id: str, section_name: str, - nodes: list, edges: list, lang: str, - diagram_scale: float, max_nodes: int, - max_edges: int) -> str: +def generate_section_flowchart( + section_id: str, + section_name: str, + nodes: list, + edges: list, + lang: str, + diagram_scale: float, + max_nodes: int, + max_edges: int, +) -> str: """Generate a compact, human-readable call-flow chart for a section.""" lines = [mermaid_init(diagram_scale, "LR")] - lines.append(f" %% Section: {safe_mermaid_text(section_name)} ({len(nodes)} nodes, {len(edges)} edges)") + lines.append( + f" %% Section: {safe_mermaid_text(section_name)} ({len(nodes)} nodes, {len(edges)} edges)" + ) if not nodes: empty_label = pick_text(lang, f"{section_name} - 无节点", f"{section_name} - no nodes") @@ -1122,12 +1313,14 @@ def generate_section_flowchart(section_id: str, section_name: str, selected_nodes = select_diagram_nodes(nodes, edges, max_nodes) selected_ids = {node.get("id") for node in selected_nodes} visible_edges = [ - edge for edge in preferred_edges(edges, allow_structure=False) + edge + for edge in preferred_edges(edges, allow_structure=False) if edge.get("source") in selected_ids and edge.get("target") in selected_ids ] if not visible_edges: visible_edges = [ - edge for edge in preferred_edges(edges, allow_structure=True) + edge + for edge in preferred_edges(edges, allow_structure=True) if edge.get("source") in selected_ids and edge.get("target") in selected_ids ] @@ -1160,7 +1353,9 @@ def generate_section_flowchart(section_id: str, section_name: str, omitted_nodes = max(0, len(nodes) - len(selected_nodes)) omitted_edges = max(0, len(visible_edges) - included) if omitted_nodes or omitted_edges: - lines.append(f" %% Omitted for readability: {omitted_nodes} nodes, {omitted_edges} edges") + lines.append( + f" %% Omitted for readability: {omitted_nodes} nodes, {omitted_edges} edges" + ) lines.extend(class_lines) lines.extend(mermaid_class_defs()) return "\n".join(lines) @@ -1170,6 +1365,7 @@ def generate_section_flowchart(section_id: str, section_name: str, # 7. HTML generators # ────────────────────────────────────────────── + def generate_nav(sections: list) -> str: """Generate the sticky navigation bar.""" links = [] @@ -1186,21 +1382,33 @@ def node_display_name(node: dict | None, fallback: str = "") -> str: return humanize_label(label, node.get("source_file", "")) -def format_node_refs(node_ids: set, node_by_id: dict, lang: str, empty_text: str, limit: int = 3) -> str: +def format_node_refs( + node_ids: set, node_by_id: dict, lang: str, empty_text: str, limit: int = 3 +) -> str: """Render node references as readable labels instead of internal IDs.""" if not node_ids: return escape(empty_text) parts = [] - for nid in sorted(node_ids, key=lambda item: node_display_name(node_by_id.get(item), item).lower())[:limit]: + for nid in sorted( + node_ids, key=lambda item: node_display_name(node_by_id.get(item), item).lower() + )[:limit]: node = node_by_id.get(nid) label = node_display_name(node, nid) source = safe_file_path((node or {}).get("source_file", "")) if source: - parts.append(f"{escape(label)}
{escape(source)}") + parts.append( + f'{escape(label)}
{escape(source)}' + ) else: parts.append(f"{escape(label)}") if len(node_ids) > limit: - parts.append(escape(pick_text(lang, f"+{len(node_ids) - limit} 个更多", f"+{len(node_ids) - limit} more"))) + parts.append( + escape( + pick_text( + lang, f"+{len(node_ids) - limit} 个更多", f"+{len(node_ids) - limit} more" + ) + ) + ) return "
".join(parts) @@ -1297,31 +1505,71 @@ def _describe_node(label: str, source_file: str, file_type: str, lang: str) -> s if file_type == "rationale": return pick_text(lang, f"设计说明:{label}", f"Design note for {label}.") if file_type == "document": - return pick_text(lang, f"文档入口,描述 {label} 相关能力。", f"Documentation node describing {label}.") + return pick_text( + lang, f"文档入口,描述 {label} 相关能力。", f"Documentation node describing {label}." + ) if label.endswith(".py") or label.endswith(".tsx") or label.endswith(".ts"): - return pick_text(lang, f"{source} 中的模块文件,承载该层主要实现。", f"Module file in {source}.") + return pick_text( + lang, f"{source} 中的模块文件,承载该层主要实现。", f"Module file in {source}." + ) if "config" in lower: - return pick_text(lang, "读取、解析或持久化项目配置。", "Reads, resolves, or persists project configuration.") + return pick_text( + lang, + "读取、解析或持久化项目配置。", + "Reads, resolves, or persists project configuration.", + ) if "scan" in lower: - return pick_text(lang, "触发项目扫描或处理扫描状态。", "Starts scanning or handles scan status.") + return pick_text( + lang, "触发项目扫描或处理扫描状态。", "Starts scanning or handles scan status." + ) if "ingest" in lower or "clone" in lower or "git" in lower: - return pick_text(lang, "把本地目录或远程仓库转换为分析上下文。", "Turns a local path or remote repository into analysis context.") + return pick_text( + lang, + "把本地目录或远程仓库转换为分析上下文。", + "Turns a local path or remote repository into analysis context.", + ) if "prompt" in lower: - return pick_text(lang, "构造发送给 LLM 的结构化提示。", "Builds structured prompts for model calls.") + return pick_text( + lang, "构造发送给 LLM 的结构化提示。", "Builds structured prompts for model calls." + ) if "analy" in lower: - return pick_text(lang, "编排分析流程并产出结构化文档数据。", "Orchestrates analysis and returns structured documentation data.") + return pick_text( + lang, + "编排分析流程并产出结构化文档数据。", + "Orchestrates analysis and returns structured documentation data.", + ) if "graph" in lower or "dependency" in lower: - return pick_text(lang, "构建依赖关系并提供排序或图形化数据。", "Builds dependency relationships and graph data.") + return pick_text( + lang, + "构建依赖关系并提供排序或图形化数据。", + "Builds dependency relationships and graph data.", + ) if "export" in lower or "markdown" in lower or "html" in lower: - return pick_text(lang, "将文档数据导出为目标格式。", "Exports documentation data to a target format.") + return pick_text( + lang, "将文档数据导出为目标格式。", "Exports documentation data to a target format." + ) if "chat" in lower or "rag" in lower or "retrieve" in lower: - return pick_text(lang, "支撑检索增强问答或流式聊天。", "Supports retrieval-augmented Q&A or streaming chat.") + return pick_text( + lang, + "支撑检索增强问答或流式聊天。", + "Supports retrieval-augmented Q&A or streaming chat.", + ) if "wiki" in lower or "page" in lower or "sidebar" in lower: - return pick_text(lang, "组织文档页面、侧边栏或内容读取。", "Organizes documentation pages, navigation, or content lookup.") + return pick_text( + lang, + "组织文档页面、侧边栏或内容读取。", + "Organizes documentation pages, navigation, or content lookup.", + ) if "cache" in lower or "hash" in lower: - return pick_text(lang, "缓存分析结果或生成缓存键。", "Caches analysis results or computes cache keys.") + return pick_text( + lang, "缓存分析结果或生成缓存键。", "Caches analysis results or computes cache keys." + ) if "test" in lower: - return pick_text(lang, "验证导入、入口点或版本等基础行为。", "Verifies imports, entry points, or version behavior.") + return pick_text( + lang, + "验证导入、入口点或版本等基础行为。", + "Verifies imports, entry points, or version behavior.", + ) return pick_text(lang, f"{source} 中的 {label} 节点。", f"{label} node in {source}.") @@ -1370,7 +1618,9 @@ def derive_flow_chain(sections: list, classified_edges: dict) -> str: seen = {start} current = start while len(chain) < min(7, len(order)): - candidates = [(count, tgt) for tgt, count in outgoing.get(current, {}).items() if tgt not in seen] + candidates = [ + (count, tgt) for tgt, count in outgoing.get(current, {}).items() if tgt not in seen + ] if candidates: _, nxt = max(candidates) else: @@ -1384,9 +1634,14 @@ def derive_flow_chain(sections: list, classified_edges: dict) -> str: return " -> ".join(section_names.get(sid, sid) for sid in chain) -def generate_overview_cards(meta: dict, report_text: str, sections: list, - section_nodes_map: dict, classified_edges: dict, - lang: str) -> str: +def generate_overview_cards( + meta: dict, + report_text: str, + sections: list, + section_nodes_map: dict, + classified_edges: dict, + lang: str, +) -> str: """Generate generic overview cards.""" rows = [] for sec in sections: @@ -1400,14 +1655,18 @@ def generate_overview_cards(meta: dict, report_text: str, sections: list, flow = derive_flow_chain(sections, classified_edges) layer_title = pick_text(lang, "架构层次", "Architecture Layers") - layer_cols = pick_text(lang, "层节点社区", "LayerNodesCommunities") + layer_cols = pick_text( + lang, + "层节点社区", + "LayerNodesCommunities", + ) flow_title = pick_text(lang, "核心数据流", "Core Flow") return f"""

{layer_title}

{layer_cols} - {''.join(rows)} + {"".join(rows)}
@@ -1421,12 +1680,40 @@ def section_keywords(nodes: list, limit: int = 5) -> list: """Pick representative words from labels and file names.""" counts = Counter() stopwords = { - "the", "and", "for", "with", "from", "this", "that", "class", "function", - "method", "file", "src", "lib", "core", "index", "main", "init", "py", - "ts", "tsx", "js", "jsx", "go", "rs", "java", "html", "css", + "the", + "and", + "for", + "with", + "from", + "this", + "that", + "class", + "function", + "method", + "file", + "src", + "lib", + "core", + "index", + "main", + "init", + "py", + "ts", + "tsx", + "js", + "jsx", + "go", + "rs", + "java", + "html", + "css", } for node in nodes: - text = f"{node.get('label', '')} {node.get('source_file', '')}".replace("/", " ").replace("_", " ").replace("-", " ") + text = ( + f"{node.get('label', '')} {node.get('source_file', '')}".replace("/", " ") + .replace("_", " ") + .replace("-", " ") + ) for raw in text.split(): word = "".join(ch for ch in raw.lower() if ch.isalnum()) if len(word) < 3 or word in stopwords: @@ -1475,10 +1762,16 @@ def generate_section_cards(sec: dict, nodes: list, section_edges: list, lang: st else: file_rows = f'{escape(pick_text(lang, "无源文件映射", "No source file mapping"))}' - relation_counts = Counter(edge.get("relation", "relates") for edge in section_edges if should_include_edge(edge)) - relation_text = ", ".join(f"{relation_label(rel, lang)} x{count}" for rel, count in relation_counts.most_common(4)) + relation_counts = Counter( + edge.get("relation", "relates") for edge in section_edges if should_include_edge(edge) + ) + relation_text = ", ".join( + f"{relation_label(rel, lang)} x{count}" for rel, count in relation_counts.most_common(4) + ) if not relation_text: - relation_text = pick_text(lang, "未检测到高置信调用边", "No high-confidence call edges detected") + relation_text = pick_text( + lang, "未检测到高置信调用边", "No high-confidence call edges detected" + ) note = pick_text( lang, f"本节由 graphify 社区聚类生成。关系概况:{relation_text}。图表优先展示高置信、跨节点调用或使用关系,完整节点清单位于表格中。", @@ -1506,6 +1799,7 @@ def generate_section_cards(sec: dict, nodes: list, section_edges: list, lang: st # 8. Main entry point # ────────────────────────────────────────────── + class CallflowOptions: """Options for call-flow architecture HTML generation.""" @@ -1615,27 +1909,39 @@ def write_callflow_html( # Load data nodes, edges, hyperedges, meta = load_graph(paths["graph"]) - labels = load_labels(paths["labels"]) - lang = detect_lang(args.lang, nodes, labels) + loaded_labels = load_labels(paths["labels"]) + lang = detect_lang(args.lang, nodes, loaded_labels) if paths["sections"]: - sections = load_sections(paths["sections"]) + flow_sections = load_sections(paths["sections"]) else: - sections = derive_sections_from_communities(nodes, labels, lang, args.max_sections) - sections = normalize_sections(sections, lang) + flow_sections = derive_sections_from_communities( + nodes, loaded_labels, lang, args.max_sections + ) + flow_sections = normalize_sections(flow_sections, lang) report_text = load_report(paths["report"]) if not nodes: raise ValueError("graph.json contains 0 nodes") - if len(sections) <= 1: + if len(flow_sections) <= 1: raise ValueError("no sections defined") if verbose and len(nodes) >= 5000: - print("WARNING: Large graph -- Mermaid rendering may be slow. Consider --max-sections 5.", file=sys.stderr) + print( + "WARNING: Large graph -- Mermaid rendering may be slow. Consider --max-sections 5.", + file=sys.stderr, + ) node_ids = {node.get("id") for node in nodes} - missing_endpoint_edges = [edge for edge in edges if edge.get("source") not in node_ids or edge.get("target") not in node_ids] + missing_endpoint_edges = [ + edge + for edge in edges + if edge.get("source") not in node_ids or edge.get("target") not in node_ids + ] if verbose and missing_endpoint_edges: - print(f"WARNING: {len(missing_endpoint_edges)} edges reference nodes not present in graph.json.", file=sys.stderr) + print( + f"WARNING: {len(missing_endpoint_edges)} edges reference nodes not present in graph.json.", + file=sys.stderr, + ) meta["project_name"] = infer_project_name(str(paths["graph"]), meta) meta["node_count"] = len(nodes) @@ -1650,13 +1956,13 @@ def write_callflow_html( output_path = paths["graphify_out"] / f"{safe_filename(meta['project_name'])}-callflow.html" if verbose: - print(f"Loaded: {len(nodes)} nodes, {len(edges)} edges, {len(sections)} sections") + print(f"Loaded: {len(nodes)} nodes, {len(edges)} edges, {len(flow_sections)} sections") print(f"Graph: {paths['graph']}") # Build index comm_idx = build_community_index(nodes) meta["community_count"] = len(comm_idx) - section_nodes_map = build_section_node_map(sections, comm_idx) + section_nodes_map = build_section_node_map(flow_sections, comm_idx) classified = classify_edges(edges, section_nodes_map) # Build HTML @@ -1684,19 +1990,31 @@ def write_callflow_html( """) # Header + nav - html.append(generate_header(sections, meta, lang)) + html.append(generate_header(flow_sections, meta, lang)) # ── Architecture Overview (Section "overview") ── - overview_name = sections[0].get("name", "Architecture Overview") if sections else "Architecture Overview" + overview_name = ( + flow_sections[0].get("name", "Architecture Overview") + if flow_sections + else "Architecture Overview" + ) html.append(f"""

1. {escape(str(overview_name))}

""") - html.append(generate_overview_graph(sections, section_nodes_map, classified, labels, lang, args.diagram_scale)) + html.append( + generate_overview_graph( + flow_sections, section_nodes_map, classified, loaded_labels, lang, args.diagram_scale + ) + ) html.append("""
""") - html.append(generate_overview_cards(meta, report_text, sections, section_nodes_map, classified, lang)) + html.append( + generate_overview_cards( + meta, report_text, flow_sections, section_nodes_map, classified, lang + ) + ) report_card = _report_highlights(report_text, lang) if report_card: html.append(f'
\n {report_card}\n
') @@ -1704,7 +2022,7 @@ def write_callflow_html( # ── Per-section content ── section_num = 1 # overview was #1 - for sec in sections: + for sec in flow_sections: if sec["id"] == "overview": continue section_num += 1 @@ -1769,7 +2087,7 @@ def write_callflow_html( html.append("
\n
") # ── Section: Statistics ── - total_sections = sum(1 for s in sections if s["id"] != "overview") + total_sections = sum(1 for s in flow_sections if s["id"] != "overview") html.append(f"""

Project Statistics

@@ -1786,9 +2104,9 @@ def write_callflow_html(

Edge Confidence

- - - + + +
EXTRACTED{sum(1 for e in edges if e.get('confidence') == 'EXTRACTED')}
INFERRED{sum(1 for e in edges if e.get('confidence') == 'INFERRED')}
AMBIGUOUS{sum(1 for e in edges if e.get('confidence') == 'AMBIGUOUS')}
EXTRACTED{sum(1 for e in edges if e.get("confidence") == "EXTRACTED")}
INFERRED{sum(1 for e in edges if e.get("confidence") == "INFERRED")}
AMBIGUOUS{sum(1 for e in edges if e.get("confidence") == "AMBIGUOUS")}
@@ -1796,8 +2114,8 @@ def write_callflow_html( # ── Footer ── html.append(f"""
-

{escape(str(meta.get('project_name', 'Project')))} — Architecture Documentation

-

Generated: {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M UTC')} · graphify callflow-html

+

{escape(str(meta.get("project_name", "Project")))} — Architecture Documentation

+

Generated: {datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")} · graphify callflow-html

""") @@ -1967,11 +2285,13 @@ def write_callflow_html( # Summary mermaid_count = output.count('
') table_count = output.count('') - section_count = output.count('

dict[str, int]: """Run community detection. Returns {node_id: community_id}. @@ -45,7 +68,9 @@ def _partition(G: nx.Graph, resolution: float = 1.0) -> dict[str, int]: stable.add_edge(src, tgt, **attrs) try: - from graspologic.partition import leiden + with _suppress_graspologic_dependency_warnings(): + from graspologic.partition import leiden + lsig = inspect.signature(leiden).parameters kwargs: dict = {} if "random_seed" in lsig: @@ -59,12 +84,12 @@ def _partition(G: nx.Graph, resolution: float = 1.0) -> dict[str, int]: old_stderr = sys.stderr try: sys.stderr = io.StringIO() - with _suppress_output(): + with _suppress_graspologic_dependency_warnings(), _suppress_output(): result = leiden(stable, **kwargs) finally: sys.stderr = old_stderr return result - except ImportError: + except (ImportError, SyntaxError, Warning): pass # Fallback: networkx louvain (available since networkx 2.7). @@ -77,10 +102,10 @@ def _partition(G: nx.Graph, resolution: float = 1.0) -> dict[str, int]: return {node: cid for cid, nodes in enumerate(communities) for node in nodes} -_MAX_COMMUNITY_FRACTION = 0.25 # communities larger than 25% of graph get split -_MIN_SPLIT_SIZE = 10 # only split if community has at least this many nodes -_COHESION_SPLIT_THRESHOLD = 0.05 # re-split communities with cohesion below this -_COHESION_SPLIT_MIN_SIZE = 50 # only cohesion-split if community has at least this many nodes +_MAX_COMMUNITY_FRACTION = 0.25 # communities larger than 25% of graph get split +_MIN_SPLIT_SIZE = 10 # only split if community has at least this many nodes +_COHESION_SPLIT_THRESHOLD = 0.05 # re-split communities with cohesion below this +_COHESION_SPLIT_MIN_SIZE = 50 # only cohesion-split if community has at least this many nodes def cluster( @@ -106,7 +131,11 @@ def cluster( """ if G.number_of_nodes() == 0: return {} - if G.is_directed(): + # Project multigraphs to simple undirected graph so parallel edges + # don't inflate Louvain/Leiden community detection. + if isinstance(G, (nx.MultiGraph, nx.MultiDiGraph)): + G = project_for_community(G) + elif G.is_directed(): G = G.to_undirected() if G.number_of_edges() == 0: return {i: [n] for i, n in enumerate(sorted(G.nodes))} @@ -171,7 +200,10 @@ def cluster( # that bridge otherwise-unrelated subsystems (e.g. CLAUDE.md connected to everything). second_pass: list[list[str]] = [] for nodes in final_communities: - if len(nodes) >= _COHESION_SPLIT_MIN_SIZE and cohesion_score(G, nodes) < _COHESION_SPLIT_THRESHOLD: + if ( + len(nodes) >= _COHESION_SPLIT_MIN_SIZE + and cohesion_score(G, nodes) < _COHESION_SPLIT_THRESHOLD + ): splits = _split_community(G, nodes) second_pass.extend(splits if len(splits) > 1 else [nodes]) else: @@ -207,6 +239,9 @@ def cohesion_score(G: nx.Graph, community_nodes: list[str]) -> float: if n <= 1: return 1.0 subgraph = G.subgraph(community_nodes) + # Project multigraphs to simple graph so parallel edges don't inflate cohesion + if isinstance(subgraph, (nx.MultiGraph, nx.MultiDiGraph)): + subgraph = project_for_community(subgraph) actual = subgraph.number_of_edges() possible = n * (n - 1) / 2 return actual / possible if possible > 0 else 0.0 diff --git a/graphify/dedup.py b/graphify/dedup.py index d5b82766c..34ea20423 100644 --- a/graphify/dedup.py +++ b/graphify/dedup.py @@ -3,11 +3,14 @@ Pipeline: exact normalization → entropy gate → MinHash/LSH blocking → Jaro-Winkler verification → same-community boost → union-find merge. """ + from __future__ import annotations +import json import math import re import unicodedata from collections import defaultdict +from typing import Any from datasketch import MinHash, MinHashLSH from rapidfuzz.distance import JaroWinkler @@ -15,6 +18,7 @@ # ── helpers ─────────────────────────────────────────────────────────────────── + def _norm(label: str) -> str: """Lowercase + collapse non-alphanumeric runs to space (Unicode-aware).""" label = unicodedata.normalize("NFKC", label) @@ -80,6 +84,7 @@ def _short_label_blocked(a: str, b: str, jw_score: float) -> bool: if max(len(a), len(b)) >= 12: return False from rapidfuzz.distance import DamerauLevenshtein + # Allow only same-length single-char substitutions (true typos like "Extractor"/"Extractar"). # Block length-differing pairs regardless of score. if jw_score >= 97.0 and len(a) == len(b) and DamerauLevenshtein.distance(a, b) <= 1: @@ -89,6 +94,7 @@ def _short_label_blocked(a: str, b: str, jw_score: float) -> bool: # ── union-find ──────────────────────────────────────────────────────────────── + class _UF: def __init__(self) -> None: self._parent: dict[str, str] = {} @@ -118,20 +124,22 @@ def components(self) -> dict[str, list[str]]: _ENTROPY_THRESHOLD = 2.5 _LSH_THRESHOLD = 0.7 -_MERGE_THRESHOLD = 92.0 # rapidfuzz normalized_similarity * 100 -_COMMUNITY_BOOST = 5.0 # score bonus when both nodes share community +_MERGE_THRESHOLD = 92.0 # rapidfuzz normalized_similarity * 100 +_COMMUNITY_BOOST = 5.0 # score bonus when both nodes share community _NUM_PERM = 128 _CHUNK_SUFFIX = re.compile(r"_c\d+$") # ── main entry point ────────────────────────────────────────────────────────── + def deduplicate_entities( nodes: list[dict], edges: list[dict], *, communities: dict[str, int], dedup_llm_backend: str | None = None, + diagnostics: dict | None = None, ) -> tuple[list[dict], list[dict]]: """Deduplicate near-identical entities in a knowledge graph. @@ -147,7 +155,7 @@ def deduplicate_entities( # Guard: cross-project dedup is not supported — nodes from different repos # share label names by coincidence and must never be merged by string similarity. # If you need to dedup a global graph, run deduplicate_entities per-repo first. - repos_seen = {n.get("repo") for n in nodes if n.get("repo")} + repos_seen = {str(repo) for n in nodes if (repo := n.get("repo"))} if len(repos_seen) > 1: raise ValueError( f"deduplicate_entities: nodes span multiple repos {sorted(repos_seen)!r}. " @@ -160,7 +168,7 @@ def deduplicate_entities( # Pre-deduplicate: keep first occurrence of each id seen_ids: dict[str, dict] = {} for node in nodes: - nid = node.get("id", "") + nid = str(node.get("id") or "") if nid and nid not in seen_ids: seen_ids[nid] = node unique_nodes = list(seen_ids.values()) @@ -190,7 +198,7 @@ def deduplicate_entities( if len(file_group) > 1: winner = _pick_winner(file_group) for node in file_group: - uf.union(winner["id"], node["id"]) + uf.union(str(winner["id"]), str(node["id"])) exact_merges += len(file_group) - 1 # ── pass 2: MinHash/LSH + Jaro-Winkler (high-entropy nodes only) ───────── @@ -211,18 +219,20 @@ def deduplicate_entities( for node in candidates: norm_label = _norm(node.get("label", node.get("id", ""))) m = _make_minhash(norm_label) - minhashes[node["id"]] = m + node_id = str(node["id"]) + minhashes[node_id] = m try: - lsh.insert(node["id"], m) + lsh.insert(node_id, m) except ValueError: pass # duplicate key in LSH — already inserted for node in candidates: - node_id = node["id"] + node_id = str(node["id"]) norm_label = _norm(node.get("label", node.get("id", ""))) - neighbors = lsh.query(minhashes[node_id]) + neighbors: list[Any] = lsh.query(minhashes[node_id]) for neighbor_id in neighbors: + neighbor_id = str(neighbor_id) if neighbor_id == node_id: continue if uf.find(node_id) == uf.find(neighbor_id): @@ -242,8 +252,12 @@ def deduplicate_entities( c1 = communities.get(node_id) c2 = communities.get(neighbor_id) - if (c1 is not None and c2 is not None and c1 == c2 - and min(len(norm_label), len(neighbor_norm)) >= 12): + if ( + c1 is not None + and c2 is not None + and c1 == c2 + and min(len(norm_label), len(neighbor_norm)) >= 12 + ): score += _COMMUNITY_BOOST if score >= _MERGE_THRESHOLD: @@ -256,11 +270,12 @@ def deduplicate_entities( sf_b = neighbor.get("source_file") or "" if sf_a != sf_b: continue - all_group = norm_to_nodes.get(norm_label, [node]) + \ - norm_to_nodes.get(neighbor_norm, [neighbor]) + all_group = norm_to_nodes.get(norm_label, [node]) + norm_to_nodes.get( + neighbor_norm, [neighbor] + ) winner = _pick_winner(all_group) - uf.union(winner["id"], node_id) - uf.union(winner["id"], neighbor_id) + uf.union(str(winner["id"]), node_id) + uf.union(str(winner["id"]), neighbor_id) fuzzy_merges += 1 # ── pass 3: LLM tiebreaker for ambiguous pairs (opt-in) ────────────────── @@ -283,6 +298,12 @@ def deduplicate_entities( # ── apply remap ─────────────────────────────────────────────────────────── if not remap: + if diagnostics is not None: + diagnostics["remap_self_loop_drops"] = 0 + diagnostics["remap_self_loop_drops_by_relation"] = {} + diagnostics["remap_self_loop_drops_by_source"] = {} + diagnostics["remap_exact_duplicate_collapses"] = 0 + diagnostics["remap_exact_duplicate_collapses_by_relation"] = {} return unique_nodes, edges total = len(remap) @@ -295,25 +316,57 @@ def deduplicate_entities( print(msg + ".", flush=True) deduped_nodes = [n for n in unique_nodes if n["id"] not in remap] - deduped_edges = [] + deduped_edges: list[dict] = [] + seen_fingerprints: set[str] = set() + self_loop_drops = 0 + self_loop_by_relation: dict[str, int] = defaultdict(int) + self_loop_by_source: dict[str, int] = defaultdict(int) + exact_dup_collapses = 0 + exact_dup_by_relation: dict[str, int] = defaultdict(int) + for edge in edges: e = dict(edge) - # Tolerate "from"/"to" keys from LLM backends that don't follow the - # schema exactly — build_from_json normalises later but dedup runs - # first so bracket access would KeyError here (#803). - # Use explicit key presence check (not `or`) so empty-string src/tgt - # aren't silently replaced by the fallback key. src = e["source"] if "source" in e else e.get("from") tgt = e["target"] if "target" in e else e.get("to") if src is None or tgt is None: continue e["source"] = remap.get(src, src) e["target"] = remap.get(tgt, tgt) - # Remove legacy keys so they don't leak into edge attrs in graph.json. e.pop("from", None) e.pop("to", None) - if e["source"] != e["target"]: - deduped_edges.append(e) + + relation = e.get("relation", "") + source_file = e.get("source_file", "") + + if e["source"] == e["target"] and src != tgt: + self_loop_drops += 1 + self_loop_by_relation[relation] += 1 + self_loop_by_source[source_file] += 1 + continue + + fingerprint = json.dumps(e, sort_keys=True, ensure_ascii=False, default=str) + if fingerprint in seen_fingerprints: + exact_dup_collapses += 1 + exact_dup_by_relation[relation] += 1 + continue + seen_fingerprints.add(fingerprint) + + deduped_edges.append(e) + + if diagnostics is not None: + diagnostics["remap_self_loop_drops"] = self_loop_drops + diagnostics["remap_self_loop_drops_by_relation"] = dict(self_loop_by_relation) + diagnostics["remap_self_loop_drops_by_source"] = dict(self_loop_by_source) + diagnostics["remap_exact_duplicate_collapses"] = exact_dup_collapses + diagnostics["remap_exact_duplicate_collapses_by_relation"] = dict(exact_dup_by_relation) + + if self_loop_drops or exact_dup_collapses: + parts = [] + if self_loop_drops: + parts.append(f"dropped {self_loop_drops} self-loop edge(s)") + if exact_dup_collapses: + parts.append(f"collapsed {exact_dup_collapses} exact-duplicate edge(s)") + print(f"[graphify] Remap: {'; '.join(parts)}.", flush=True) return deduped_nodes, deduped_edges @@ -343,12 +396,18 @@ def _llm_tiebreak( """Batch-resolve ambiguous pairs (score in [low, high)) via LLM.""" try: from graphify.llm import BACKENDS, _format_backend_env_keys, _get_backend_api_key + if backend not in BACKENDS: - print(f"[graphify] --dedup-llm: unknown backend {backend!r}, skipping LLM tiebreaker.", flush=True) + print( + f"[graphify] --dedup-llm: unknown backend {backend!r}, skipping LLM tiebreaker.", + flush=True, + ) return if not _get_backend_api_key(backend): env_keys = _format_backend_env_keys(backend) - print(f"[graphify] --dedup-llm: {env_keys} not set, skipping LLM tiebreaker.", flush=True) + print( + f"[graphify] --dedup-llm: {env_keys} not set, skipping LLM tiebreaker.", flush=True + ) return except ImportError: return @@ -368,8 +427,12 @@ def _llm_tiebreak( continue c1 = communities.get(node["id"]) c2 = communities.get(neighbor["id"]) - if (c1 is not None and c2 is not None and c1 == c2 - and min(len(norm_i), len(norm_j)) >= 12): + if ( + c1 is not None + and c2 is not None + and c1 == c2 + and min(len(norm_i), len(norm_j)) >= 12 + ): score += _COMMUNITY_BOOST if low <= score < high: ambiguous.append((node, neighbor, score)) @@ -392,8 +455,7 @@ def _llm_tiebreak( for batch_start in range(0, len(ambiguous), batch_size): batch = ambiguous[batch_start : batch_start + batch_size] pairs_text = "\n".join( - f"{i+1}. \"{a['label']}\" vs \"{b['label']}\"" - for i, (a, b, _) in enumerate(batch) + f'{i + 1}. "{a["label"]}" vs "{b["label"]}"' for i, (a, b, _) in enumerate(batch) ) prompt = ( "For each pair below, answer only 'yes' or 'no': are they the same real-world concept?\n\n" diff --git a/graphify/detect.py b/graphify/detect.py index 36a0a184f..3b9ed81a8 100644 --- a/graphify/detect.py +++ b/graphify/detect.py @@ -5,6 +5,7 @@ import os import re import shlex +import sys from enum import Enum from pathlib import Path @@ -25,23 +26,110 @@ class FileType(str, Enum): _MANIFEST_PATH = "graphify-out/manifest.json" -CODE_EXTENSIONS = {'.py', '.ts', '.tsx', '.js', '.jsx', '.mjs', '.ejs', '.ets', '.go', '.rs', '.java', '.groovy', '.gradle', '.cpp', '.cc', '.cxx', '.c', '.h', '.hpp', '.rb', '.swift', '.kt', '.kts', '.cs', '.scala', '.php', '.lua', '.luau', '.toc', '.zig', '.ps1', '.ex', '.exs', '.m', '.mm', '.jl', '.vue', '.svelte', '.astro', '.dart', '.v', '.sv', '.svh', '.sql', '.r', '.f', '.F', '.f90', '.F90', '.f95', '.F95', '.f03', '.F03', '.f08', '.F08', '.pas', '.pp', '.dpr', '.dpk', '.lpr', '.inc', '.dfm', '.lfm', '.lpk', '.sh', '.bash', '.json', '.dm', '.dme', '.dmi', '.dmm', '.dmf', '.sln', '.csproj', '.fsproj', '.vbproj', '.razor', '.cshtml'} -DOC_EXTENSIONS = {'.md', '.mdx', '.qmd', '.txt', '.rst', '.html', '.yaml', '.yml'} -PAPER_EXTENSIONS = {'.pdf'} -IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.webp', '.svg'} -OFFICE_EXTENSIONS = {'.docx', '.xlsx'} -VIDEO_EXTENSIONS = {'.mp4', '.mov', '.webm', '.mkv', '.avi', '.m4v', '.mp3', '.wav', '.m4a', '.ogg'} +CODE_EXTENSIONS = { + ".py", + ".ts", + ".tsx", + ".js", + ".jsx", + ".mjs", + ".ejs", + ".ets", + ".go", + ".rs", + ".java", + ".groovy", + ".gradle", + ".cpp", + ".cc", + ".cxx", + ".c", + ".h", + ".hpp", + ".rb", + ".swift", + ".kt", + ".kts", + ".cs", + ".scala", + ".php", + ".lua", + ".luau", + ".toc", + ".zig", + ".ps1", + ".ex", + ".exs", + ".m", + ".mm", + ".jl", + ".vue", + ".svelte", + ".astro", + ".dart", + ".v", + ".sv", + ".svh", + ".sql", + ".r", + ".f", + ".F", + ".f90", + ".F90", + ".f95", + ".F95", + ".f03", + ".F03", + ".f08", + ".F08", + ".pas", + ".pp", + ".dpr", + ".dpk", + ".lpr", + ".inc", + ".dfm", + ".lfm", + ".lpk", + ".sh", + ".bash", + ".json", + ".dm", + ".dme", + ".dmi", + ".dmm", + ".dmf", + ".sln", + ".csproj", + ".fsproj", + ".vbproj", + ".razor", + ".cshtml", +} +DOC_EXTENSIONS = {".md", ".mdx", ".qmd", ".txt", ".rst", ".html", ".yaml", ".yml"} +PAPER_EXTENSIONS = {".pdf"} +IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".gif", ".webp", ".svg"} +OFFICE_EXTENSIONS = {".docx", ".xlsx"} +VIDEO_EXTENSIONS = {".mp4", ".mov", ".webm", ".mkv", ".avi", ".m4v", ".mp3", ".wav", ".m4a", ".ogg"} -CORPUS_WARN_THRESHOLD = 50_000 # words - below this, warn "you may not need a graph" +CORPUS_WARN_THRESHOLD = 50_000 # words - below this, warn "you may not need a graph" CORPUS_UPPER_THRESHOLD = 500_000 # words - above this, warn about token cost -FILE_COUNT_UPPER = 500 # files - above this, warn about token cost +FILE_COUNT_UPPER = 500 # files - above this, warn about token cost # Parent directories whose contents are always sensitive. # Checked against path.parts[:-1] (parents only) so a root-level file named # "credentials" or "secrets" is not falsely flagged by this stage. -_SENSITIVE_DIRS = frozenset({ - ".ssh", ".gnupg", ".aws", ".gcloud", "secrets", ".secrets", "credentials", -}) +_SENSITIVE_DIRS = frozenset( + { + ".ssh", + ".gnupg", + ".aws", + ".gcloud", + "secrets", + ".secrets", + "credentials", + } +) # Files that may contain secrets - skip silently. # Uses lookarounds instead of \b so underscore-prefixed names like api_token.txt @@ -51,30 +139,33 @@ class FileType(str, Enum): # `token` is kept separate because its longer suffix "izer"/"ize" is the only # common false-positive; other keywords have no such well-known derivatives. _SENSITIVE_PATTERNS = [ - re.compile(r'(^|[\\/])\.(env|envrc)(\.|$)', re.IGNORECASE), - re.compile(r'\.(pem|key|p12|pfx|cert|crt|der|p8)$', re.IGNORECASE), - re.compile(r'(? bool: _SHEBANG_CODE_INTERPRETERS = { - "python", "python3", "python2", - "ruby", "perl", "node", "nodejs", - "bash", "sh", "dash", "zsh", "fish", "ksh", "tcsh", - "lua", "php", "julia", "Rscript", + "python", + "python3", + "python2", + "ruby", + "perl", + "node", + "nodejs", + "bash", + "sh", + "dash", + "zsh", + "fish", + "ksh", + "tcsh", + "lua", + "php", + "julia", + "Rscript", } @@ -152,7 +257,7 @@ def _env_command_args(args: list[str], *, allow_split: bool = True) -> list[str] arg = args[i] if arg == "--": - return args[i + 1:] + return args[i + 1 :] # Split-string forms: tokenize the packed payload, then re-parse it # as env args (so leading assignments/flags inside the payload are @@ -162,36 +267,36 @@ def _env_command_args(args: list[str], *, allow_split: bool = True) -> list[str] if i + 1 >= len(args): return [] return _env_command_args( - _split_env_s(" ".join(args[i + 1:]), []), + _split_env_s(" ".join(args[i + 1 :]), []), allow_split=False, ) if arg.startswith("-S") and len(arg) > 2: return _env_command_args( - _split_env_s(arg[2:], args[i + 1:]), + _split_env_s(arg[2:], args[i + 1 :]), allow_split=False, ) if arg == "-vS": if i + 1 >= len(args): return [] return _env_command_args( - _split_env_s(" ".join(args[i + 1:]), []), + _split_env_s(" ".join(args[i + 1 :]), []), allow_split=False, ) if arg.startswith("-vS") and len(arg) > 3: return _env_command_args( - _split_env_s(arg[3:], args[i + 1:]), + _split_env_s(arg[3:], args[i + 1 :]), allow_split=False, ) if arg.startswith("--split-string="): return _env_command_args( - _split_env_s(arg.split("=", 1)[1], args[i + 1:]), + _split_env_s(arg.split("=", 1)[1], args[i + 1 :]), allow_split=False, ) if arg == "--split-string": if i + 1 >= len(args): return [] return _env_command_args( - _split_env_s(args[i + 1], args[i + 2:]), + _split_env_s(args[i + 1], args[i + 2 :]), allow_split=False, ) @@ -203,11 +308,7 @@ def _env_command_args(args: list[str], *, allow_split: bool = True) -> list[str] continue # Clumped short option + operand - if ( - arg.startswith(("-u", "-C", "-P", "-a")) - and len(arg) > 2 - and not arg.startswith("--") - ): + if arg.startswith(("-u", "-C", "-P", "-a")) and len(arg) > 2 and not arg.startswith("--"): i += 1 continue @@ -217,8 +318,16 @@ def _env_command_args(args: list[str], *, allow_split: bool = True) -> list[str] continue # No-operand flags - if arg in {"-", "-i", "-0", "-v", "--ignore-environment", "--null", - "--debug", "--list-signal-handling"}: + if arg in { + "-", + "-i", + "-0", + "-v", + "--ignore-environment", + "--null", + "--debug", + "--list-signal-handling", + }: i += 1 continue @@ -320,6 +429,7 @@ def extract_pdf_text(path: Path) -> str: """Extract plain text from a PDF file using pypdf.""" try: from pypdf import PdfReader + reader = PdfReader(str(path)) pages = [] for page in reader.pages: @@ -335,11 +445,11 @@ def docx_to_markdown(path: Path) -> str: """Convert a .docx file to markdown text using python-docx.""" try: from docx import Document - from docx.oxml.ns import qn + doc = Document(str(path)) lines = [] for para in doc.paragraphs: - style = para.style.name if para.style else "" + style = str(para.style.name or "") if para.style else "" text = para.text.strip() if not text: lines.append("") @@ -375,6 +485,7 @@ def xlsx_to_markdown(path: Path) -> str: """Convert an .xlsx file to markdown text using openpyxl.""" try: import openpyxl + wb = openpyxl.load_workbook(str(path), read_only=True, data_only=True) sections = [] for sheet_name in wb.sheetnames: @@ -407,6 +518,7 @@ def xlsx_extract_structure(path: Path) -> dict: Returns a nodes/edges dict compatible with the graphify extract pipeline. Used in addition to xlsx_to_markdown so Claude sees both structure and content. """ + def _nid(*parts: str) -> str: return re.sub(r"[^a-z0-9_]", "_", "_".join(p.lower() for p in parts).strip("_")) @@ -428,21 +540,43 @@ def _nid(*parts: str) -> str: stem = re.sub(r"[^a-z0-9]", "_", path.stem.lower()) str_path = str(path) file_nid = _nid(str_path) - nodes: list[dict] = [{"id": file_nid, "label": path.name, "file_type": "document", - "source_file": str_path, "source_location": None}] + nodes: list[dict] = [ + { + "id": file_nid, + "label": path.name, + "file_type": "document", + "source_file": str_path, + "source_location": None, + } + ] edges: list[dict] = [] seen: set[str] = {file_nid} def _add(nid: str, label: str) -> None: if nid not in seen: seen.add(nid) - nodes.append({"id": nid, "label": label, "file_type": "document", - "source_file": str_path, "source_location": None}) + nodes.append( + { + "id": nid, + "label": label, + "file_type": "document", + "source_file": str_path, + "source_location": None, + } + ) def _edge(src: str, tgt: str, relation: str) -> None: - edges.append({"source": src, "target": tgt, "relation": relation, - "confidence": "EXTRACTED", "source_file": str_path, - "source_location": None, "weight": 1.0}) + edges.append( + { + "source": src, + "target": tgt, + "relation": relation, + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": None, + "weight": 1.0, + } + ) for sheet_name in wb.sheetnames: ws = wb[sheet_name] @@ -461,18 +595,28 @@ def _edge(src: str, tgt: str, relation: str) -> None: if ref: try: from openpyxl.utils import range_boundaries + min_col, min_row, max_col, _ = range_boundaries(ref) - header_row = list(ws.iter_rows(min_row=min_row, max_row=min_row, - min_col=min_col, max_col=max_col, - values_only=True)) + header_row = list( + ws.iter_rows( + min_row=min_row, + max_row=min_row, + min_col=min_col, + max_col=max_col, + values_only=True, + ) + ) if header_row: for col_name in header_row[0]: if col_name: col_nid = _nid(stem, tbl.name, str(col_name)) _add(col_nid, str(col_name)) _edge(tbl_nid, col_nid, "contains") - except Exception: - pass + except Exception as exc: + print( + f"[graphify] warning: could not read spreadsheet table columns: {exc}", + file=sys.stderr, + ) else: # Fallback: first non-empty row as column headers for row in ws.iter_rows(max_row=1, values_only=True): @@ -485,8 +629,8 @@ def _edge(src: str, tgt: str, relation: str) -> None: try: wb.close() - except Exception: - pass + except Exception as exc: + print(f"[graphify] warning: could not close workbook {path}: {exc}", file=sys.stderr) return {"nodes": nodes, "edges": edges} @@ -511,6 +655,7 @@ def convert_office_file(path: Path, out_dir: Path) -> Path | None: out_dir.mkdir(parents=True, exist_ok=True) # Use a stable name derived from the original path to avoid collisions import hashlib + name_hash = hashlib.sha256(str(path.resolve()).encode()).hexdigest()[:8] out_path = out_dir / f"{path.stem}_{name_hash}.md" out_path.write_text( @@ -536,33 +681,64 @@ def count_words(path: Path) -> int: # Directory names to always skip - venvs, caches, build artifacts, deps _SKIP_DIRS = { - "venv", ".venv", "env", ".env", - "node_modules", "__pycache__", ".git", - "dist", "build", "target", "out", - "site-packages", "lib64", - ".pytest_cache", ".mypy_cache", ".ruff_cache", - ".tox", ".eggs", "*.egg-info", + "venv", + ".venv", + "env", + ".env", + "node_modules", + "__pycache__", + ".git", + "dist", + "build", + "target", + "out", + "site-packages", + "lib64", + ".pytest_cache", + ".mypy_cache", + ".ruff_cache", + ".tox", + ".eggs", + "*.egg-info", "graphify-out", # never treat own output as source input (#524) # Coverage/test-artefact dirs — generated, never architecturally meaningful - "coverage", "lcov-report", # Vitest/Istanbul/nyc HTML reports (#870) - "visual-tests", "visual-test", # Playwright/visual-regression bundles (#869) - "__snapshots__", "snapshots", # Jest/Vitest snapshot dirs - "storybook-static", # Storybook production build output - "dist-protected", # Protected dist variants (same noise as dist) + "coverage", + "lcov-report", # Vitest/Istanbul/nyc HTML reports (#870) + "visual-tests", + "visual-test", # Playwright/visual-regression bundles (#869) + "__snapshots__", + "snapshots", # Jest/Vitest snapshot dirs + "storybook-static", # Storybook production build output + "dist-protected", # Protected dist variants (same noise as dist) # Framework cache/build dirs — generated, never architecturally meaningful (#873) - ".next", ".nuxt", ".turbo", ".angular", - ".idea", ".cache", ".parcel-cache", ".svelte-kit", ".terraform", ".serverless", + ".next", + ".nuxt", + ".turbo", + ".angular", + ".idea", + ".cache", + ".parcel-cache", + ".svelte-kit", + ".terraform", + ".serverless", ".graphify", # graphify's own extraction cache — never index self-generated data ".worktrees", # git worktree convention (#947) — sibling checkouts, always redundant } # Large generated files that are never useful to extract _SKIP_FILES = { - "package-lock.json", "yarn.lock", "pnpm-lock.yaml", - "Cargo.lock", "poetry.lock", "Gemfile.lock", - "composer.lock", "go.sum", "go.work.sum", + "package-lock.json", + "yarn.lock", + "pnpm-lock.yaml", + "Cargo.lock", + "poetry.lock", + "Gemfile.lock", + "composer.lock", + "go.sum", + "go.work.sum", } + def _is_noise_dir(part: str, parent: "Path | None" = None) -> bool: """Return True if this directory name looks like a venv, cache, or dep dir.""" if part in _SKIP_DIRS: @@ -671,6 +847,7 @@ def _is_ignored(path: Path, root: Path, patterns: list[tuple[Path, str]]) -> boo def _eval(target: Path) -> bool: """Apply last-match-wins to a single target path.""" + def _matches(rel: str, p: str, anchored: bool) -> bool: if anchored: return fnmatch.fnmatch(rel, p) @@ -682,7 +859,7 @@ def _matches(rel: str, p: str, anchored: bool) -> bool: for i, part in enumerate(parts): if fnmatch.fnmatch(part, p): return True - if fnmatch.fnmatch("/".join(parts[:i + 1]), p): + if fnmatch.fnmatch("/".join(parts[: i + 1]), p): return True return False @@ -782,7 +959,7 @@ def _matches(rel: str, p: str, anchored: bool) -> bool: for i, part in enumerate(parts): if fnmatch.fnmatch(part, p): return True - if fnmatch.fnmatch("/".join(parts[:i + 1]), p): + if fnmatch.fnmatch("/".join(parts[: i + 1]), p): return True return False @@ -866,7 +1043,13 @@ def _auto_follow_symlinks(root: Path) -> bool: return False -def detect(root: Path, *, follow_symlinks: bool | None = None, google_workspace: bool | None = None, extra_excludes: list[str] | None = None) -> dict: +def detect( + root: Path, + *, + follow_symlinks: bool | None = None, + google_workspace: bool | None = None, + extra_excludes: list[str] | None = None, +) -> dict: root = root.resolve() if follow_symlinks is None: follow_symlinks = _auto_follow_symlinks(root) @@ -889,8 +1072,6 @@ def detect(root: Path, *, follow_symlinks: bool | None = None, google_workspace: line = _parse_gitignore_line(pat) if line: ignore_patterns.append((root, line)) - include_patterns = _load_graphifyinclude(root) - # Always include graphify-out/memory/ - query results filed back into the graph memory_dir = root / "graphify-out" / "memory" scan_paths = [root] @@ -918,7 +1099,8 @@ def detect(root: Path, *, follow_symlinks: bool | None = None, google_workspace: # pruning so negated files inside can still be reached. has_negation = any(p.startswith("!") for _, p in ignore_patterns) dirnames[:] = [ - d for d in dirnames + d + for d in dirnames if not _is_noise_dir(d, dp) and (has_negation or not _is_ignored(dp / d, root, ignore_patterns)) ] @@ -949,13 +1131,14 @@ def detect(root: Path, *, follow_symlinks: bool | None = None, google_workspace: if p.suffix.lower() in GOOGLE_WORKSPACE_EXTENSIONS: if not google_workspace: skipped_sensitive.append( - str(p) - + " [Google Workspace shortcut skipped - pass --google-workspace " + str(p) + " [Google Workspace shortcut skipped - pass --google-workspace " "or set GRAPHIFY_GOOGLE_WORKSPACE=1]" ) continue try: - md_path = convert_google_workspace_file(p, converted_dir, xlsx_to_markdown=xlsx_to_markdown) + md_path = convert_google_workspace_file( + p, converted_dir, xlsx_to_markdown=xlsx_to_markdown + ) except Exception as exc: skipped_sensitive.append(str(p) + f" [Google Workspace export failed: {exc}]") continue @@ -965,7 +1148,9 @@ def detect(root: Path, *, follow_symlinks: bool | None = None, google_workspace: files[ftype].append(str(md_path)) total_words += count_words(md_path) else: - skipped_sensitive.append(str(p) + " [Google Workspace export produced no readable text]") + skipped_sensitive.append( + str(p) + " [Google Workspace export produced no readable text]" + ) continue # Office files: convert to markdown sidecar so subagents can read them if p.suffix.lower() in OFFICE_EXTENSIONS: @@ -977,7 +1162,9 @@ def detect(root: Path, *, follow_symlinks: bool | None = None, google_workspace: total_words += count_words(md_path) else: # Conversion failed (library not installed) - skip with note - skipped_sensitive.append(str(p) + " [office conversion failed - pip install graphifyy[office]]") + skipped_sensitive.append( + str(p) + " [office conversion failed - pip install graphifyy[office]]" + ) continue files[ftype].append(str(p)) if ftype != FileType.VIDEO: @@ -1015,6 +1202,7 @@ def detect(root: Path, *, follow_symlinks: bool | None = None, google_workspace: def _md5_file(path: Path) -> str: """MD5 of file contents streamed in 64KB chunks — for change detection only.""" import hashlib as _hl + h = _hl.md5(usedforsecurity=False) try: with path.open("rb") as f: @@ -1092,7 +1280,9 @@ def _normalise_entry(entry): entry["semantic_hash"] = h else: # Preserve semantic_hash only when content is unchanged - entry["semantic_hash"] = prev.get("semantic_hash", "") if h == prev.get("ast_hash", "") else "" + entry["semantic_hash"] = ( + prev.get("semantic_hash", "") if h == prev.get("ast_hash", "") else "" + ) manifest[f] = entry Path(manifest_path).parent.mkdir(parents=True, exist_ok=True) Path(manifest_path).write_text(json.dumps(manifest, indent=2), encoding="utf-8") @@ -1130,7 +1320,12 @@ def detect_incremental( incremental runs. ``None`` (default) means auto-detect: ``True`` when ``root`` contains at least one direct symlinked child, ``False`` otherwise. """ - full = detect(root, follow_symlinks=follow_symlinks, google_workspace=google_workspace, extra_excludes=extra_excludes) + full = detect( + root, + follow_symlinks=follow_symlinks, + google_workspace=google_workspace, + extra_excludes=extra_excludes, + ) manifest = load_manifest(manifest_path) if not manifest: @@ -1158,7 +1353,11 @@ def detect_incremental( elif isinstance(stored, dict): # Normalise legacy {mtime, hash} to new schema if "hash" in stored and "ast_hash" not in stored: - stored = {"mtime": stored.get("mtime", 0), "ast_hash": stored["hash"], "semantic_hash": ""} + stored = { + "mtime": stored.get("mtime", 0), + "ast_hash": stored["hash"], + "semantic_hash": "", + } hash_key = "semantic_hash" if kind == "semantic" else "ast_hash" stored_hash = stored.get(hash_key, "") # Missing semantic_hash means update ran but extract hasn't — always re-extract diff --git a/graphify/edge_identity.py b/graphify/edge_identity.py new file mode 100644 index 000000000..cfdece534 --- /dev/null +++ b/graphify/edge_identity.py @@ -0,0 +1,91 @@ +"""Stable edge identity helpers and schema constants for MultiDiGraph support. + +The node-link ``"key"`` field is reserved schema — it identifies a parallel edge +and must never be stored as an ordinary edge attribute. All callers that build or +load graphs should use :func:`strip_schema_key` before passing attrs to +``G.add_edge`` so the ``key`` kwarg is never duplicated. +""" + +from __future__ import annotations + +import hashlib +import json as _json +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import networkx as nx + +SCHEMA_KEY_FIELD = "key" + + +def make_stable_key( + relation: str | None, + source_file: str | None, + source_location: str | None, +) -> str: + """Return a collision-safe deterministic edge key from semantic identity fields. + + Uses SHA-256 over a canonical JSON payload with explicit field names so that + delimiter characters in field values cannot produce false collisions. The key + format is ``"edge:v1:"``. + + Two edges with the same relation, file, and location always produce the same + key; any difference in those three fields produces a different key. + """ + payload = _json.dumps( + { + "relation": relation, + "source_file": source_file, + "source_location": source_location, + }, + sort_keys=True, + ) + digest = hashlib.sha256(payload.encode()).hexdigest() + return f"edge:v1:{digest}" + + +def strip_schema_key(attrs: dict) -> tuple[object | None, dict]: + """Separate the ``"key"`` schema field from edge attribute kwargs. + + Returns ``(key_value, cleaned_attrs)`` where ``cleaned_attrs`` is a new dict + with ``SCHEMA_KEY_FIELD`` removed. The original *attrs* dict is not mutated. + + The return type is ``object | None`` rather than ``str | None`` because the + field may carry any JSON-decodable value at this layer; callers narrow to + ``str`` after explicit validation (see the multigraph loader/build paths). + + Use before ``G.add_edge(u, v, key=key_value, **cleaned_attrs)`` to avoid + passing ``key`` twice (once as the positional schema arg and once inside attrs). + """ + key_val = attrs.get(SCHEMA_KEY_FIELD) + cleaned = {k: v for k, v in attrs.items() if k != SCHEMA_KEY_FIELD} + return key_val, cleaned + + +def remove_all_parallel_edges( + G: "nx.Graph", + u: object, + v: object, +) -> int: + """Remove ALL edges between u and v, regardless of key count. + + On MultiDiGraph, ``G.remove_edge(u, v)`` removes only one edge (first key). + This helper explicitly iterates keys to remove all parallel edges. + + Returns the number of edges removed. Does not raise if no edges exist + between u and v (returns 0). + """ + import networkx as nx + + if isinstance(G, (nx.MultiGraph, nx.MultiDiGraph)): + if not G.has_node(u) or not G.has_node(v): + return 0 + keys = list(G[u][v].keys()) if G.has_edge(u, v) else [] + for key in keys: + G.remove_edge(u, v, key=key) + return len(keys) + else: + if G.has_edge(u, v): + G.remove_edge(u, v) + return 1 + return 0 diff --git a/graphify/export.py b/graphify/export.py index ff127c0b2..e8344c698 100644 --- a/graphify/export.py +++ b/graphify/export.py @@ -7,14 +7,23 @@ import os import re import shutil +import sys from collections import Counter from datetime import date from pathlib import Path +from typing import Any, cast import networkx as nx from networkx.readwrite import json_graph from graphify.security import sanitize_label from graphify.analyze import _node_community_map -from graphify.build import edge_data +from graphify.build import edge_data, edge_datas +from graphify.edge_identity import make_stable_key +from graphify.graph_loader import GRAPHIFY_PROFILE_KEY +from graphify.projections import ( + DEFAULT_RELATIONSHIP_CAP, + format_relationship_envelope, + relationship_envelope, +) # Artifacts worth preserving across rebuilds (non-regenerable without LLM or curation). @@ -54,13 +63,18 @@ def backup_if_protected(out_dir: Path) -> "Path | None": try: labels = json.loads(labels_file.read_text(encoding="utf-8")) is_curated = any(v != f"Community {k}" for k, v in labels.items()) - except Exception: - pass + except Exception as exc: + print( + f"[graphify] warning: could not read community labels for backup check: {exc}", + file=sys.stderr, + ) if not is_semantic and not is_curated: return None - reason = "+".join(filter(None, ["semantic" if is_semantic else "", "curated" if is_curated else ""])) + reason = "+".join( + filter(None, ["semantic" if is_semantic else "", "curated" if is_curated else ""]) + ) today = date.today().isoformat() backup_dir = out / today graph_src = out / "graph.json" @@ -83,16 +97,19 @@ def backup_if_protected(out_dir: Path) -> "Path | None": try: shutil.copy2(src, backup_dir / name) copied += 1 - except Exception: - pass + except Exception as exc: + print(f"[graphify] warning: could not back up {src}: {exc}", file=sys.stderr) if copied: print(f"[graphify] backed up {reason} graph ({copied} files) -> {backup_dir.name}/") return backup_dir except Exception as exc: - import sys - print(f"[graphify] warning: backup failed ({exc}) - continuing with overwrite", file=sys.stderr) + print( + f"[graphify] warning: backup failed ({exc}) - continuing with overwrite", + file=sys.stderr, + ) return None + def _obsidian_tag(name: str) -> str: """Sanitize a community name for use as an Obsidian tag. @@ -104,6 +121,7 @@ def _obsidian_tag(name: str) -> str: def _strip_diacritics(text: str) -> str: import unicodedata + nfkd = unicodedata.normalize("NFKD", text) return "".join(c for c in nfkd if not unicodedata.combining(c)) @@ -147,8 +165,16 @@ def _yaml_str(s: str) -> str: COMMUNITY_COLORS = [ - "#4E79A7", "#F28E2B", "#E15759", "#76B7B2", "#59A14F", - "#EDC948", "#B07AA1", "#FF9DA7", "#9C755F", "#BAB0AC", + "#4E79A7", + "#F28E2B", + "#E15759", + "#76B7B2", + "#59A14F", + "#EDC948", + "#B07AA1", + "#FF9DA7", + "#9C755F", + "#BAB0AC", ] MAX_NODES_FOR_VIZ = 5_000 @@ -161,6 +187,7 @@ def _viz_node_limit() -> int: Set to 0 to disable HTML viz unconditionally (useful for CI runners). """ import os + raw = os.environ.get("GRAPHIFY_VIZ_NODE_LIMIT") if raw is None or not raw.strip(): return MAX_NODES_FOR_VIZ @@ -472,41 +499,133 @@ def attach_hyperedges(G: nx.Graph, hyperedges: list) -> None: def _git_head() -> str | None: """Return the current git HEAD commit hash, or None if not in a git repo.""" import subprocess as _sp + try: - r = _sp.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True, timeout=3) + r = _sp.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True, timeout=3) # nosec B603 B607 return r.stdout.strip() if r.returncode == 0 else None except Exception: return None -def to_json(G: nx.Graph, communities: dict[int, list[str]], output_path: str, *, force: bool = False, built_at_commit: str | None = None) -> bool: +def _graph_type_for_instance(G: nx.Graph) -> str: + """Return the graphify ``graph_type`` token for a live NetworkX instance. + + The instance is authoritative: we classify from ``is_multigraph()`` / + ``is_directed()`` rather than from any stored profile, mirroring the + ``multigraph``/``directed`` flag logic in :func:`graphify.graph_loader.load_graph`. + The vocabulary is kept byte-identical to the loader's + :func:`~graphify.graph_loader._set_graph_profile` (``"simple"`` / + ``"digraph"`` / ``"multidigraph"``) so a save/load round-trip is stable. + + graphify only ever produces directed multigraphs (``MultiDiGraph``), and the + loader normalizes any ``multigraph: true`` payload to ``MultiDiGraph``, so an + undirected ``MultiGraph`` instance is still labelled ``"multidigraph"`` for + consistency with what a reload would reconstruct. + """ + if G.is_multigraph(): + return "multidigraph" + if G.is_directed(): + return "digraph" + return "simple" + + +def _ensure_graph_profile(G: nx.Graph) -> None: + """Stamp ``G.graph[GRAPHIFY_PROFILE_KEY]`` so the profile persists in graph.json. + + A freshly *built* graph (from :func:`graphify.build.build_from_json`) has no + ``graphify_profile`` — that key is only set on *load*. Without it the saved + JSON would not carry the simple-vs-multidigraph profile that downstream PR 7 + cache-invalidation / watch profile-mismatch detection relies on. + + Existing profile fields (e.g. from a loaded graph) are preserved, but + ``graph_type`` is always overwritten to match the actual instance — the + instance is authoritative, so a stale serialized ``graph_type`` can never + mislabel the graph we are about to write. This mirrors the overwrite in + :func:`graphify.graph_loader._set_graph_profile`. + """ + existing = G.graph.get(GRAPHIFY_PROFILE_KEY) + profile = dict(existing) if isinstance(existing, dict) else {} + profile["graph_type"] = _graph_type_for_instance(G) + G.graph[GRAPHIFY_PROFILE_KEY] = profile + + +def to_json( + G: nx.Graph, + communities: dict[int, list[str]], + output_path: str, + *, + force: bool = False, + built_at_commit: str | None = None, +) -> bool: # Safety check: refuse to silently shrink an existing graph (#479) existing_path = Path(output_path) + # Empty-merge floor (RISK 4): refuse to overwrite a populated graph.json + # (>0 nodes) with an EMPTY (0-node) graph. A 0-node write over a populated + # graph is a failed/aborted extraction, never a real result, so this floor + # engages REGARDLESS of force — no legitimate caller writes 0 nodes over a + # populated graph, and force=True is exactly the bug enabler. Read the + # existing node count defensively: any error (missing/corrupt file) is + # treated as 0 nodes so a corrupt existing file cannot crash the write + # (the floor then stays inert, which is acceptable — there is no verified + # populated graph to protect). + if existing_path.exists() and G.number_of_nodes() == 0: + try: + existing_data = json.loads(existing_path.read_text(encoding="utf-8")) + existing_n = len(existing_data.get("nodes", [])) + except Exception: + existing_n = 0 + if existing_n > 0: + print( + f"[graphify] ERROR: refusing to overwrite a populated graph.json " + f"({existing_n} nodes) with an EMPTY (0-node) graph - this is a " + f"failed/aborted extraction, not a real result. The previous " + f"graph is preserved.", + file=sys.stderr, + ) + return False if not force and existing_path.exists(): try: from graphify.security import check_graph_file_size_cap + check_graph_file_size_cap(existing_path) existing_data = json.loads(existing_path.read_text(encoding="utf-8")) existing_n = len(existing_data.get("nodes", [])) new_n = G.number_of_nodes() if new_n < existing_n: - import sys as _sys print( f"[graphify] WARNING: new graph has {new_n} nodes but existing " f"graph.json has {existing_n}. Refusing to overwrite — you may be " f"missing chunk files from a previous session. " f"Pass force=True to override.", - file=_sys.stderr, + file=sys.stderr, ) return False - except Exception: - pass # unreadable existing file — proceed with write + except Exception as exc: + print( + f"[graphify] warning: could not inspect existing graph before write: {exc}", + file=sys.stderr, + ) + + # Persist the graph profile so a later load can detect a simple-vs- + # multidigraph mismatch (PR 7 cache invalidation / watch). The profile is + # derived from the live instance and written onto G.graph, which + # node_link_data surfaces under the top-level "graph" key. + _ensure_graph_profile(G) node_community = _node_community_map(communities) try: data = json_graph.node_link_data(G, edges="links") except TypeError: data = json_graph.node_link_data(G) + # Defensively guarantee the profile is present under data["graph"] even if a + # NetworkX build did not surface G.graph (it normally does). The NetworkX + # "multigraph"/"directed" boolean flags are emitted by node_link_data itself. + graph_meta = data.get("graph") + if not isinstance(graph_meta, dict): + graph_meta = {} + data["graph"] = graph_meta + if GRAPHIFY_PROFILE_KEY not in graph_meta: + graph_meta[GRAPHIFY_PROFILE_KEY] = dict(G.graph[GRAPHIFY_PROFILE_KEY]) for node in data["nodes"]: node["community"] = node_community.get(node["id"]) node["norm_label"] = _strip_diacritics(node.get("label", "")).lower() @@ -541,8 +660,7 @@ def prune_dangling_edges(graph_data: dict) -> tuple[dict, int]: links_key = "links" if "links" in graph_data else "edges" before = len(graph_data[links_key]) graph_data[links_key] = [ - e for e in graph_data[links_key] - if e["source"] in node_ids and e["target"] in node_ids + e for e in graph_data[links_key] if e["source"] in node_ids and e["target"] in node_ids ] return graph_data, before - len(graph_data[links_key]) @@ -566,12 +684,7 @@ def _cypher_escape(s: str) -> str: """ # First normalise: drop NUL and other C0 control chars except tab. s = "".join(ch for ch in s if ch >= " " or ch == "\t") - return ( - s.replace("\\", "\\\\") - .replace("'", "\\'") - .replace("\n", "\\n") - .replace("\r", "\\r") - ) + return s.replace("\\", "\\\\").replace("'", "\\'").replace("\n", "\\n").replace("\r", "\\r") # Restrict identifier-position values (labels and relationship types are NOT @@ -592,6 +705,69 @@ def _cypher_label(raw: str, fallback: str) -> str: return cleaned +def _edge_distinguishing_key(data: dict, explicit_key: object | None = None) -> str: + """Return a stable per-edge key that distinguishes parallel edges. + + MultiDiGraph keyed edges carry their key as the positional ``key`` of + ``G.edges(keys=True, data=True)`` rather than inside the attribute dict, so + callers that already hold the positional key pass it as ``explicit_key``. + NetworkX guarantees that positional key is UNIQUE within a ``(u, v)`` pair — + which is exactly the scope Neo4j MERGE deduplicates over — and it may be an + INTEGER (0, 1, 2…) when no explicit string key was set. We therefore accept + any non-None positional key and stringify it; narrowing to ``str`` would + silently drop integer keys and let two parallel edges with identical + (relation, source_file, source_location) collapse to the same edge_key. + + When no positional key is available (simple graphs — one edge per pair, or a + stray ``key`` left in attrs), derive a deterministic ``edge:v1:`` key + from the edge's semantic identity fields via :func:`make_stable_key`. + """ + if explicit_key is not None: + # int or str positional key — unique per (u, v), which is the MERGE scope. + return str(explicit_key) + in_attrs = data.get("key") + if isinstance(in_attrs, str) and in_attrs: + return in_attrs + return make_stable_key( + data.get("relation"), + data.get("source_file"), + data.get("source_location"), + ) + + +def _canvas_edge_id( + source: object, + target: object, + suffix: object, + used_ids: set[str], +) -> str: + """Return a deterministic, globally unique Canvas edge id. + + The readable legacy shape, ``e_{source}_{target}_{suffix}``, can collide when + node ids themselves contain underscores (``a_b -> c`` vs ``a -> b_c``). Keep + that readable id when it is unique, but fall back to a short digest of the + structured tuple when a collision is detected. + """ + readable = f"e_{source}_{target}_{suffix}" + if readable not in used_ids: + used_ids.add(readable) + return readable + + payload = json.dumps( + [str(source), str(target), str(suffix)], + ensure_ascii=True, + separators=(",", ":"), + ) + digest = hashlib.sha256(payload.encode("utf-8")).hexdigest()[:12] + candidate = f"{readable}_{digest}" + counter = 1 + while candidate in used_ids: + counter += 1 + candidate = f"{readable}_{digest}_{counter}" + used_ids.add(candidate) + return candidate + + def to_cypher(G: nx.Graph, output_path: str) -> None: lines = ["// Neo4j Cypher import - generated by /graphify", ""] for node_id, data in G.nodes(data=True): @@ -603,17 +779,33 @@ def to_cypher(G: nx.Graph, output_path: str) -> None: ) lines.append(f"MERGE (n:{ftype} {{id: '{node_id_esc}', label: '{label}'}});") lines.append("") - for u, v, data in G.edges(data=True): + # Preserve EVERY parallel edge (PR 6 go/no-go gate). Neo4j MERGE deduplicates + # on the relationship pattern, so two parallel edges between the same (a, b) + # with the same relation type would collapse to one unless we give each a + # distinguishing property inside the MERGE pattern. We emit a stable + # `edge_key` (the MultiDiGraph positional key when present, else a derived + # make_stable_key) so distinct keys -> distinct relationships. For simple + # graphs this adds one `edge_key` property to the existing single MERGE per + # edge — required for correctness, harmless for re-runs (MERGE is idempotent + # on the now-richer pattern). All values flow through `_cypher_escape`. + is_multi = isinstance(G, (nx.MultiGraph, nx.MultiDiGraph)) + edge_iter = ( + G.edges(keys=True, data=True) + if is_multi + else ((u, v, None, data) for u, v, data in G.edges(data=True)) + ) + for u, v, ekey, data in edge_iter: rel = _cypher_label( (data.get("relation", "RELATES_TO") or "RELATES_TO").upper(), "RELATES_TO", ) conf = _cypher_escape(data.get("confidence", "EXTRACTED")) + edge_key = _cypher_escape(_edge_distinguishing_key(data, ekey)) u_esc = _cypher_escape(u) v_esc = _cypher_escape(v) lines.append( f"MATCH (a {{id: '{u_esc}'}}), (b {{id: '{v_esc}'}}) " - f"MERGE (a)-[:{rel} {{confidence: '{conf}'}}]->(b);" + f"MERGE (a)-[:{rel} {{edge_key: '{edge_key}', confidence: '{conf}'}}]->(b);" ) with open(output_path, "w", encoding="utf-8") as f: # nosec f.write("\n".join(lines)) @@ -645,8 +837,13 @@ def to_html( # Build aggregated community meta-graph from collections import Counter as _Counter import networkx as _nx - print(f"Graph has {G.number_of_nodes()} nodes (above {limit} limit). Building aggregated community view...") - node_to_community = {nid: cid for cid, members in communities.items() for nid in members} + + print( + f"Graph has {G.number_of_nodes()} nodes (above {limit} limit). Building aggregated community view..." + ) + node_to_community = { + nid: cid for cid, members in communities.items() for nid in members + } meta = _nx.Graph() for cid, members in communities.items(): meta.add_node(str(cid), label=(community_labels or {}).get(cid, f"Community {cid}")) @@ -656,8 +853,13 @@ def to_html( if cu is not None and cv is not None and cu != cv: edge_counts[(min(cu, cv), max(cu, cv))] += 1 for (cu, cv), w in edge_counts.items(): - meta.add_edge(str(cu), str(cv), weight=w, - relation=f"{w} cross-community edges", confidence="AGGREGATED") + meta.add_edge( + str(cu), + str(cv), + weight=w, + relation=f"{w} cross-community edges", + confidence="AGGREGATED", + ) if meta.number_of_nodes() <= 1: print("Single community - aggregated view not useful. Skipping graph.html.") return @@ -666,10 +868,11 @@ def to_html( # Remap hyperedges from semantic node IDs to community IDs raw_hyperedges = G.graph.get("hyperedges", []) if raw_hyperedges: - remapped = [] + remapped: list[dict[str, Any]] = [] for he in raw_hyperedges: he_members = he.get("nodes") or he.get("members") or [] - comm_ids, seen = [], set() + comm_ids: list[str] = [] + seen: set[str] = set() for nid in he_members: c = node_to_community.get(nid) if c is None: @@ -681,15 +884,24 @@ def to_html( comm_ids.append(s) if len(comm_ids) < 2: continue - remapped.append({ - "id": he.get("id", ""), - "label": he.get("label") or he.get("relation", "").replace("_", " "), - "nodes": comm_ids, - }) + remapped.append( + { + "id": he.get("id", ""), + "label": he.get("label") or he.get("relation", "").replace("_", " "), + "nodes": comm_ids, + } + ) meta.graph["hyperedges"] = remapped - to_html(meta, meta_communities, output_path, - community_labels=community_labels, member_counts=mc) - print(f"graph.html written (aggregated: {meta.number_of_nodes()} community nodes, {meta.number_of_edges()} cross-community edges)") + to_html( + meta, + meta_communities, + output_path, + community_labels=community_labels, + member_counts=mc, + ) + print( + f"graph.html written (aggregated: {meta.number_of_nodes()} community nodes, {meta.number_of_edges()} cross-community edges)" + ) print("Tip: run with --obsidian for full node-level detail.") return raise ValueError( @@ -718,47 +930,93 @@ def to_html( size = 10 + 30 * (deg / max_deg) # Only show label for high-degree nodes by default; others show on hover font_size = 12 if deg >= max_deg * 0.15 else 0 - vis_nodes.append({ - "id": node_id, - "label": label, - "color": {"background": color, "border": color, "highlight": {"background": "#ffffff", "border": color}}, - "size": round(size, 1), - "font": {"size": font_size, "color": "#ffffff"}, - "title": _html.escape(label), - "community": cid, - "community_name": sanitize_label((community_labels or {}).get(cid, f"Community {cid}")), - "source_file": sanitize_label(str(data.get("source_file") or "")), - "file_type": data.get("file_type", ""), - "degree": deg, - }) + vis_nodes.append( + { + "id": node_id, + "label": label, + "color": { + "background": color, + "border": color, + "highlight": {"background": "#ffffff", "border": color}, + }, + "size": round(size, 1), + "font": {"size": font_size, "color": "#ffffff"}, + "title": _html.escape(label), + "community": cid, + "community_name": sanitize_label( + (community_labels or {}).get(cid, f"Community {cid}") + ), + "source_file": sanitize_label(str(data.get("source_file") or "")), + "file_type": data.get("file_type", ""), + "degree": deg, + } + ) # Build edges list. Restore original edge direction from _src/_tgt # (stashed by build.py for exactly this reason): undirected NetworkX # canonicalizes endpoint order, which would otherwise flip the arrow # for `calls` and `rationale_for` in the rendered graph (#563). + # + # Visual-noise cap (PR 6): at most DEFAULT_RELATIONSHIP_CAP parallel edges + # are drawn per (u, v) pair; any overflow is collapsed into ONE summary edge + # labelled "(+K more, N total)" from the relationship envelope. This is an + # intentional, documented summarization — every parallel edge is still + # preserved losslessly by to_json / to_graphml. Simple graphs (one edge per + # pair) are unaffected: shown == the single edge, no summary edge added. vis_edges = [] - for u, v, data in G.edges(data=True): - confidence = data.get("confidence", "EXTRACTED") - relation = data.get("relation", "") - true_src = data.get("_src", u) - true_tgt = data.get("_tgt", v) - vis_edges.append({ - "from": true_src, - "to": true_tgt, - "label": relation, - "title": _html.escape(f"{relation} [{confidence}]"), - "dashes": confidence != "EXTRACTED", - "width": 2 if confidence == "EXTRACTED" else 1, - "color": {"opacity": 0.7 if confidence == "EXTRACTED" else 0.35}, - "confidence": confidence, - }) + cap = DEFAULT_RELATIONSHIP_CAP + seen_pairs: set[tuple[Any, Any]] = set() + for u, v in G.edges(): + if (u, v) in seen_pairs: + continue # edge_datas returns all parallels for the pair at once + seen_pairs.add((u, v)) + records = edge_datas(G, u, v) + shown = records[:cap] + for data in shown: + confidence = data.get("confidence", "EXTRACTED") + relation = data.get("relation", "") + true_src = data.get("_src", u) + true_tgt = data.get("_tgt", v) + vis_edges.append( + { + "from": true_src, + "to": true_tgt, + "label": relation, + "title": _html.escape(f"{relation} [{confidence}]"), + "dashes": confidence != "EXTRACTED", + "width": 2 if confidence == "EXTRACTED" else 1, + "color": {"opacity": 0.7 if confidence == "EXTRACTED" else 0.35}, + "confidence": confidence, + } + ) + if len(records) > cap: + summary = format_relationship_envelope(G, u, v, cap=cap, directed_only=True) + rep = shown[0] if shown else (records[0] if records else {}) + true_src = rep.get("_src", u) + true_tgt = rep.get("_tgt", v) + vis_edges.append( + { + "from": true_src, + "to": true_tgt, + "label": summary, + "title": _html.escape(summary), + "dashes": True, + "width": 1, + "color": {"opacity": 0.35}, + "confidence": "SUMMARY", + } + ) # Build community legend data legend_data = [] for cid in sorted((community_labels or {}).keys()): color = COMMUNITY_COLORS[cid % len(COMMUNITY_COLORS)] lbl = _html.escape(sanitize_label((community_labels or {}).get(cid, f"Community {cid}"))) - n = member_counts.get(cid, len(communities.get(cid, []))) if member_counts else len(communities.get(cid, [])) + n = ( + member_counts.get(cid, len(communities.get(cid, []))) + if member_counts + else len(communities.get(cid, [])) + ) legend_data.append({"cid": cid, "color": color, "label": lbl, "count": n}) # Escape sequences so embedded JSON cannot break out of the script tag @@ -837,7 +1095,11 @@ def to_obsidian( # Map node_id → safe filename so wikilinks stay consistent. # Deduplicate: if two nodes produce the same filename, append a numeric suffix. def safe_name(label: str) -> str: - cleaned = re.sub(r'[\\/*?:"<>|#^[\]]', "", label.replace("\r\n", " ").replace("\r", " ").replace("\n", " ")).strip() + cleaned = re.sub( + r'[\\/*?:"<>|#^[\]]', + "", + label.replace("\r\n", " ").replace("\r", " ").replace("\n", " "), + ).strip() # Strip trailing .md/.mdx/.markdown so "CLAUDE.md" doesn't become "CLAUDE.md.md" cleaned = re.sub(r"\.(md|mdx|qmd|markdown)$", "", cleaned, flags=re.IGNORECASE) return cleaned or "unnamed" @@ -907,16 +1169,27 @@ def _dominant_confidence(node_id: str) -> str: lines.append(f" - {tag}") lines += ["---", "", f"# {label}", ""] - # Outgoing edges as wikilinks + # Outgoing edges as wikilinks. Render the FULL bundled relation summary + # per neighbor (PR 6 gate + PR 5 read-surface consistency) instead of + # only the first parallel edge. Gate on unique-relation count exactly + # like PR 5: a single relation keeps the historical byte-stable + # `` `{relation}` [{confidence}] `` form (so simple-graph vaults are + # unchanged), while multiple relations render the capped envelope + # bundle (e.g. "calls, imports, contains" or "... (+K more, N total)"). neighbors = list(G.neighbors(node_id)) if neighbors: lines.append("## Connections") for neighbor in sorted(neighbors, key=lambda n: G.nodes[n].get("label", n)): - edata = edge_data(G, node_id, neighbor) neighbor_label = node_filename[neighbor] - relation = edata.get("relation", "") - confidence = edata.get("confidence", "EXTRACTED") - lines.append(f"- [[{neighbor_label}]] - `{relation}` [{confidence}]") + envelope = relationship_envelope(G, node_id, neighbor, directed_only=True) + if len(envelope["relations"]) <= 1: + edata = edge_data(G, node_id, neighbor) + relation = edata.get("relation", "") + confidence = edata.get("confidence", "EXTRACTED") + lines.append(f"- [[{neighbor_label}]] - `{relation}` [{confidence}]") + else: + summary = format_relationship_envelope(G, node_id, neighbor, directed_only=True) + lines.append(f"- [[{neighbor_label}]] - {summary}") lines.append("") # Inline tags at bottom of note body (for Obsidian tag panel) @@ -975,8 +1248,10 @@ def _community_reach(node_id: str) -> int: # Cohesion + member count summary if coh_value is not None: cohesion_desc = ( - "tightly connected" if coh_value >= 0.7 - else "moderately connected" if coh_value >= 0.4 + "tightly connected" + if coh_value >= 0.7 + else "moderately connected" + if coh_value >= 0.4 else "loosely connected" ) lines.append(f"**Cohesion:** {coh_value:.2f} - {cohesion_desc}") @@ -1019,7 +1294,9 @@ def _community_reach(node_id: str) -> int: else f"Community {other_cid}" ) other_safe = safe_name(other_name) - lines.append(f"- {edge_count} edge{'s' if edge_count != 1 else ''} to [[_COMMUNITY_{other_safe}]]") + lines.append( + f"- {edge_count} edge{'s' if edge_count != 1 else ''} to [[_COMMUNITY_{other_safe}]]" + ) lines.append("") # Top bridge nodes - highest degree nodes that connect to other communities @@ -1051,7 +1328,10 @@ def _community_reach(node_id: str) -> int: "colorGroups": [ { "query": f"tag:#community/{label.replace(' ', '_')}", - "color": {"a": 1, "rgb": int(COMMUNITY_COLORS[cid % len(COMMUNITY_COLORS)].lstrip('#'), 16)} + "color": { + "a": 1, + "rgb": int(COMMUNITY_COLORS[cid % len(COMMUNITY_COLORS)].lstrip("#"), 16), + }, } for cid, label in sorted((community_labels or {}).items()) ] @@ -1078,7 +1358,11 @@ def to_canvas( CANVAS_COLORS = ["1", "2", "3", "4", "5", "6"] # red, orange, yellow, green, cyan, purple def safe_name(label: str) -> str: - cleaned = re.sub(r'[\\/*?:"<>|#^[\]]', "", label.replace("\r\n", " ").replace("\r", " ").replace("\n", " ")).strip() + cleaned = re.sub( + r'[\\/*?:"<>|#^[\]]', + "", + label.replace("\r\n", " ").replace("\r", " ").replace("\n", " "), + ).strip() cleaned = re.sub(r"\.(md|mdx|qmd|markdown)$", "", cleaned, flags=re.IGNORECASE) return cleaned or "unnamed" @@ -1104,8 +1388,6 @@ def safe_name(label: str) -> str: # Lay out communities in a grid gap = 80 - group_x_offsets: list[int] = [] - group_y_offsets: list[int] = [] # Precompute group sizes so we can calculate offsets sorted_cids = sorted(communities.keys()) @@ -1168,16 +1450,18 @@ def safe_name(label: str) -> str: canvas_color = CANVAS_COLORS[idx % len(CANVAS_COLORS)] # Group node - canvas_nodes.append({ - "id": f"g{cid}", - "type": "group", - "label": community_name, - "x": gx, - "y": gy, - "width": gw, - "height": gh, - "color": canvas_color, - }) + canvas_nodes.append( + { + "id": f"g{cid}", + "type": "group", + "label": community_name, + "x": gx, + "y": gy, + "width": gw, + "height": gh, + "color": canvas_color, + } + ) # Node cards inside the group - rows of 3 sorted_members = sorted(members, key=lambda n: G.nodes[n].get("label", n)) @@ -1187,34 +1471,90 @@ def safe_name(label: str) -> str: nx_x = gx + 20 + col * (180 + 20) nx_y = gy + 80 + row * (60 + 20) fname = node_filenames.get(node_id, safe_name(G.nodes[node_id].get("label", node_id))) - canvas_nodes.append({ - "id": f"n_{node_id}", - "type": "file", - "file": f"{fname}.md", - "x": nx_x, - "y": nx_y, - "width": 180, - "height": 60, - }) - - # Generate edges - only between nodes both in canvas, cap at 200 highest-weight - all_edges_weighted: list[tuple[float, str, str, str]] = [] - for u, v, edata in G.edges(data=True): - if u in all_canvas_nodes and v in all_canvas_nodes: + canvas_nodes.append( + { + "id": f"n_{node_id}", + "type": "file", + "file": f"{fname}.md", + "x": nx_x, + "y": nx_y, + "width": 180, + "height": 60, + } + ) + + # Generate edges - only between nodes both in canvas, cap at 200 highest-weight. + # + # Obsidian Canvas requires GLOBALLY UNIQUE edge ids; the previous endpoint-only + # `e_{u}_{v}` id silently collapsed parallel edges to one. We now emit a unique + # `e_{u}_{v}_{idx}` per drawn parallel edge. To bound visual noise (PR 6 + # requirement) we draw at most DEFAULT_RELATIONSHIP_CAP parallel edges per + # (u, v) pair; when more exist we draw the capped set PLUS one summary edge + # labelled "(+K more, N total)" via the relationship envelope. This is an + # intentional, documented summarization — the full edge set still survives + # losslessly in to_json / to_graphml. + pair_records: dict[tuple[str, str], list[dict]] = {} + for u, v in G.edges(): + if u not in all_canvas_nodes or v not in all_canvas_nodes: + continue + if (u, v) in pair_records: + continue # edge_datas returns all parallels for the pair at once + pair_records[(u, v)] = edge_datas(G, u, v) + + cap = DEFAULT_RELATIONSHIP_CAP + # Two-phase selection so synthetic summary edges are strictly ADDITIVE and + # never displace real edges from the 200-edge global cap: + # 1. Build the REAL drawn edges (at most `cap` parallels per pair), sort by + # weight desc, and truncate to the top 200. This preserves the original + # "200 highest-weight real edges" contract exactly. + # 2. AFTER truncation, append one overflow summary edge for each (u, v) pair + # that (a) had > cap parallels AND (b) still has at least one real edge in + # the surviving top-200 set. Summaries describe already-counted overflow, + # so they must not consume a real-edge slot; a previously-displaced real + # edge could otherwise be evicted by a `float("inf")` summary (the bug + # this replaces). Summaries are not weight-ranked and are not subject to + # the 200-cap themselves. + real_weighted: list[tuple[float, str, str, int, str]] = [] + overflow_pairs: dict[tuple[str, str], int] = {} + for (u, v), records in sorted( + pair_records.items(), key=lambda kv: (str(kv[0][0]), str(kv[0][1])) + ): + for idx, edata in enumerate(records[:cap]): weight = edata.get("weight", 1.0) relation = edata.get("relation", "") conf = edata.get("confidence", "EXTRACTED") label = f"{relation} [{conf}]" if relation else f"[{conf}]" - all_edges_weighted.append((weight, u, v, label)) + real_weighted.append((weight, u, v, idx, label)) + if len(records) > cap: + overflow_pairs[(u, v)] = len(records) + + real_weighted.sort(key=lambda x: (-x[0], x[1], x[2], x[3])) + surviving_real = real_weighted[:200] + used_edge_ids: set[str] = set() + for weight, u, v, idx, label in surviving_real: + canvas_edges.append( + { + "id": _canvas_edge_id(u, v, idx, used_edge_ids), + "fromNode": f"n_{u}", + "toNode": f"n_{v}", + "label": label, + } + ) - all_edges_weighted.sort(key=lambda x: -x[0]) - for weight, u, v, label in all_edges_weighted[:200]: - canvas_edges.append({ - "id": f"e_{u}_{v}", - "fromNode": f"n_{u}", - "toNode": f"n_{v}", - "label": label, - }) + # Append summary edges only for overflow pairs that survived the 200-cap. + surviving_pairs = {(u, v) for _w, u, v, _idx, _lbl in surviving_real} + for u, v in sorted(overflow_pairs, key=lambda p: (str(p[0]), str(p[1]))): + if (u, v) not in surviving_pairs: + continue # pair fully displaced by the 200-cap — no summary needed + summary_label = format_relationship_envelope(G, u, v, cap=cap, directed_only=True) + canvas_edges.append( + { + "id": _canvas_edge_id(u, v, "summary", used_edge_ids), + "fromNode": f"n_{u}", + "toNode": f"n_{v}", + "label": summary_label, + } + ) canvas_data = {"nodes": canvas_nodes, "edges": canvas_edges} Path(output_path).write_text(json.dumps(canvas_data, indent=2), encoding="utf-8") # nosec @@ -1237,14 +1577,15 @@ def push_to_neo4j( try: from neo4j import GraphDatabase except ImportError as e: - raise ImportError( - "neo4j driver not installed. Run: pip install neo4j" - ) from e + raise ImportError("neo4j driver not installed. Run: pip install neo4j") from e node_community = _node_community_map(communities) if communities else {} def _safe_rel(relation: str) -> str: - return re.sub(r"[^A-Z0-9_]", "_", relation.upper().replace(" ", "_").replace("-", "_")) or "RELATED_TO" + return ( + re.sub(r"[^A-Z0-9_]", "_", relation.upper().replace(" ", "_").replace("-", "_")) + or "RELATED_TO" + ) def _safe_label(label: str) -> str: """Sanitize a Neo4j node label to prevent Cypher injection.""" @@ -1256,6 +1597,7 @@ def _safe_label(label: str) -> str: edges_pushed = 0 with driver.session() as session: + session_any = cast(Any, session) for node_id, data in G.nodes(data=True): props = {k: v for k, v in data.items() if isinstance(v, (str, int, float, bool))} props["id"] = node_id @@ -1263,7 +1605,7 @@ def _safe_label(label: str) -> str: if cid is not None: props["community"] = cid ftype = _safe_label(data.get("file_type", "Entity").capitalize()) - session.run( + session_any.run( f"MERGE (n:{ftype} {{id: $id}}) SET n += $props", id=node_id, props=props, @@ -1273,7 +1615,7 @@ def _safe_label(label: str) -> str: for u, v, data in G.edges(data=True): rel = _safe_rel(data.get("relation", "RELATED_TO")) props = {k: v for k, v in data.items() if isinstance(v, (str, int, float, bool))} - session.run( + session_any.run( f"MATCH (a {{id: $src}}), (b {{id: $tgt}}) " f"MERGE (a)-[r:{rel}]->(b) SET r += $props", src=u, @@ -1300,6 +1642,16 @@ def to_graphml( node_community = _node_community_map(communities) for node_id in H.nodes(): H.nodes[node_id]["community"] = node_community.get(node_id, -1) + # GraphML only serializes scalar (str/int/float/bool) data values. The + # multigraph build path stashes a `graphify_multigraph_diagnostics` dict on + # G.graph, which would raise "GraphML does not support type " + # and abort the write (losing ALL edges, parallel ones included). Drop any + # non-scalar graph-level attrs so multigraph exports succeed losslessly; + # simple graphs carry no such attrs and are unaffected (byte-stable). + for attr_name in [ + name for name, value in H.graph.items() if not isinstance(value, (str, int, float, bool)) + ]: + del H.graph[attr_name] nx.write_graphml(H, output_path) @@ -1319,6 +1671,7 @@ def to_svg( """ try: import matplotlib + matplotlib.use("Agg") import matplotlib.pyplot as plt import matplotlib.patches as mpatches @@ -1336,24 +1689,66 @@ def to_svg( degree = dict(G.degree()) max_deg = max(degree.values(), default=1) or 1 - node_colors = [COMMUNITY_COLORS[node_community.get(n, 0) % len(COMMUNITY_COLORS)] for n in G.nodes()] + node_colors = [ + COMMUNITY_COLORS[node_community.get(n, 0) % len(COMMUNITY_COLORS)] for n in G.nodes() + ] node_sizes = [300 + 1200 * (degree.get(n, 1) / max_deg) for n in G.nodes()] - # Draw edges - dashed for non-EXTRACTED - for u, v, data in G.edges(data=True): - conf = data.get("confidence", "EXTRACTED") - style = "solid" if conf == "EXTRACTED" else "dashed" - alpha = 0.6 if conf == "EXTRACTED" else 0.3 - x0, y0 = pos[u] - x1, y1 = pos[v] - ax.plot([x0, x1], [y0, y1], color="#aaaaaa", linewidth=0.8, - linestyle=style, alpha=alpha, zorder=1) - - nx.draw_networkx_nodes(G, pos, ax=ax, node_color=node_colors, - node_size=node_sizes, alpha=0.9) - nx.draw_networkx_labels(G, pos, ax=ax, - labels={n: G.nodes[n].get("label", n) for n in G.nodes()}, - font_size=7, font_color="white") + # Draw edges - dashed for non-EXTRACTED. + # + # Visual-noise cap (PR 6): parallel edges between the same pair overlap + # exactly on the spring layout, so drawing all of them is pure clutter. We + # draw at most DEFAULT_RELATIONSHIP_CAP per (u, v) pair and, when more exist, + # add ONE summary text label "(+K more, N total)" at the edge midpoint from + # the relationship envelope. Intentional, documented summarization — the full + # edge set still survives losslessly in to_json / to_graphml. Simple graphs + # (one edge per pair) draw exactly as before with no summary label. + cap = DEFAULT_RELATIONSHIP_CAP + seen_pairs: set[tuple[Any, Any]] = set() + for u, v in G.edges(): + if (u, v) in seen_pairs: + continue # edge_datas returns all parallels for the pair at once + seen_pairs.add((u, v)) + records = edge_datas(G, u, v) + for data in records[:cap]: + conf = data.get("confidence", "EXTRACTED") + style = "solid" if conf == "EXTRACTED" else "dashed" + alpha = 0.6 if conf == "EXTRACTED" else 0.3 + x0, y0 = pos[u] + x1, y1 = pos[v] + ax.plot( + [x0, x1], + [y0, y1], + color="#aaaaaa", + linewidth=0.8, + linestyle=style, + alpha=alpha, + zorder=1, + ) + if len(records) > cap: + x0, y0 = pos[u] + x1, y1 = pos[v] + summary = format_relationship_envelope(G, u, v, cap=cap, directed_only=True) + ax.text( + (x0 + x1) / 2, + (y0 + y1) / 2, + summary, + color="#cccccc", + fontsize=6, + ha="center", + va="center", + zorder=2, + ) + + nx.draw_networkx_nodes(G, pos, ax=ax, node_color=node_colors, node_size=node_sizes, alpha=0.9) + nx.draw_networkx_labels( + G, + pos, + ax=ax, + labels={n: G.nodes[n].get("label", n) for n in G.nodes()}, + font_size=7, + font_color="white", + ) # Legend if community_labels: @@ -1364,10 +1759,15 @@ def to_svg( ) for cid, label in sorted(community_labels.items()) ] - ax.legend(handles=patches, loc="upper left", framealpha=0.7, - facecolor="#2a2a4e", labelcolor="white", fontsize=8) + ax.legend( + handles=patches, + loc="upper left", + framealpha=0.7, + facecolor="#2a2a4e", + labelcolor="white", + fontsize=8, + ) plt.tight_layout() - plt.savefig(output_path, format="svg", bbox_inches="tight", - facecolor=fig.get_facecolor()) + plt.savefig(output_path, format="svg", bbox_inches="tight", facecolor=fig.get_facecolor()) plt.close(fig) diff --git a/graphify/extract.py b/graphify/extract.py index 1d396a697..0f4ad60f0 100644 --- a/graphify/extract.py +++ b/graphify/extract.py @@ -1,44 +1,123 @@ """Deterministic structural extraction from source code using tree-sitter. Outputs nodes+edges dicts.""" + from __future__ import annotations import importlib import json +import logging import os import re import sys import unicodedata from dataclasses import dataclass, field from pathlib import Path -from typing import Callable, Any +from collections.abc import Iterator +from typing import Any, Callable from .cache import load_cached, save_cached from .mcp_ingest import extract_mcp_config, is_mcp_config_path _RECURSION_LIMIT = 10_000 +_LOG = logging.getLogger(__name__) # Language built-in globals that AST may classify as call targets when used as # constructors or coercion functions (e.g. String(x), Number(x), Boolean(x)). # Without this filter they become god-nodes accumulating spurious edges from # every call site. Filter applied at same-file and cross-file resolution. # See issue #726. -_LANGUAGE_BUILTIN_GLOBALS: frozenset[str] = frozenset({ - # JavaScript / TypeScript ECMAScript built-ins - "String", "Number", "Boolean", "Object", "Array", "Symbol", "BigInt", - "Date", "RegExp", "Error", "TypeError", "RangeError", "SyntaxError", - "ReferenceError", "EvalError", "URIError", - "Promise", "Map", "Set", "WeakMap", "WeakSet", "JSON", "Math", - "Reflect", "Proxy", "Intl", - "parseInt", "parseFloat", "isNaN", "isFinite", - "encodeURIComponent", "decodeURIComponent", "encodeURI", "decodeURI", - # Browser / Node common globals - "URL", "URLSearchParams", "FormData", "Blob", "File", - "Headers", "Request", "Response", "AbortController", "AbortSignal", - "TextEncoder", "TextDecoder", "console", - # Python built-in callables - "str", "int", "float", "bool", "list", "dict", "set", "tuple", "bytes", - "len", "range", "enumerate", "zip", "map", "filter", "sum", "min", "max", - "print", "open", "isinstance", "type", "super", "sorted", "reversed", - "any", "all", "abs", "round", "next", "iter", "hash", "id", "repr", - "callable", "getattr", "setattr", "hasattr", "delattr", "vars", "dir", -}) +_LANGUAGE_BUILTIN_GLOBALS: frozenset[str] = frozenset( + { + # JavaScript / TypeScript ECMAScript built-ins + "String", + "Number", + "Boolean", + "Object", + "Array", + "Symbol", + "BigInt", + "Date", + "RegExp", + "Error", + "TypeError", + "RangeError", + "SyntaxError", + "ReferenceError", + "EvalError", + "URIError", + "Promise", + "Map", + "Set", + "WeakMap", + "WeakSet", + "JSON", + "Math", + "Reflect", + "Proxy", + "Intl", + "parseInt", + "parseFloat", + "isNaN", + "isFinite", + "encodeURIComponent", + "decodeURIComponent", + "encodeURI", + "decodeURI", + # Browser / Node common globals + "URL", + "URLSearchParams", + "FormData", + "Blob", + "File", + "Headers", + "Request", + "Response", + "AbortController", + "AbortSignal", + "TextEncoder", + "TextDecoder", + "console", + # Python built-in callables + "str", + "int", + "float", + "bool", + "list", + "dict", + "set", + "tuple", + "bytes", + "len", + "range", + "enumerate", + "zip", + "map", + "filter", + "sum", + "min", + "max", + "print", + "open", + "isinstance", + "type", + "super", + "sorted", + "reversed", + "any", + "all", + "abs", + "round", + "next", + "iter", + "hash", + "id", + "repr", + "callable", + "getattr", + "setattr", + "hasattr", + "delattr", + "vars", + "dir", + } +) def _raise_recursion_limit() -> None: @@ -55,6 +134,7 @@ def _safe_extract(extractor: Callable, path: Path) -> dict: except Exception as e: if os.environ.get("GRAPHIFY_DEBUG"): import traceback + traceback.print_exc(file=sys.stderr) print(f" warning: skipped {path} ({type(e).__name__}: {e})", file=sys.stderr, flush=True) return {"nodes": [], "edges": [], "error": f"{type(e).__name__}: {e}"} @@ -92,14 +172,33 @@ def _file_stem(path: Path) -> str: _JS_INDEX_FILES = ("index.ts", "index.tsx", "index.svelte", "index.js", "index.jsx", "index.mjs") -SEMANTIC_RELATIONS = frozenset({ - "inherits", "implements", "mixes_in", "embeds", "references", - "calls", "imports", "imports_from", "re_exports", "contains", "method", -}) +SEMANTIC_RELATIONS = frozenset( + { + "inherits", + "implements", + "mixes_in", + "embeds", + "references", + "calls", + "imports", + "imports_from", + "re_exports", + "contains", + "method", + } +) -REFERENCE_CONTEXTS = frozenset({ - "field", "parameter_type", "return_type", "generic_arg", "attribute", "value", "type", -}) +REFERENCE_CONTEXTS = frozenset( + { + "field", + "parameter_type", + "return_type", + "generic_arg", + "attribute", + "value", + "type", + } +) def _source_location(line: int | str | None) -> str | None: @@ -173,9 +272,9 @@ def _strip_jsonc(text: str) -> str: """ # Remove block and line comments while leaving string literals untouched. pattern = re.compile( - r'"(?:\\.|[^"\\])*"' # double-quoted string (with escapes) - r"|/\*.*?\*/" # /* block comment */ - r"|//[^\n]*", # // line comment + r'"(?:\\.|[^"\\])*"' # double-quoted string (with escapes) + r"|/\*.*?\*/" # /* block comment */ + r"|//[^\n]*", # // line comment re.DOTALL, ) @@ -205,7 +304,11 @@ def _read_tsconfig_aliases(tsconfig: Path, base_dir: Path, seen: set) -> dict[st try: raw = tsconfig.read_text(encoding="utf-8") except Exception as e: - print(f" warning: could not read {tsconfig} ({type(e).__name__}: {e})", file=sys.stderr, flush=True) + print( + f" warning: could not read {tsconfig} ({type(e).__name__}: {e})", + file=sys.stderr, + flush=True, + ) return {} try: data = json.loads(raw) @@ -213,10 +316,18 @@ def _read_tsconfig_aliases(tsconfig: Path, base_dir: Path, seen: set) -> dict[st try: data = json.loads(_strip_jsonc(raw)) except json.JSONDecodeError as e: - print(f" warning: failed to parse {tsconfig} as JSON/JSONC ({e.msg} at line {e.lineno} col {e.colno})", file=sys.stderr, flush=True) + print( + f" warning: failed to parse {tsconfig} as JSON/JSONC ({e.msg} at line {e.lineno} col {e.colno})", + file=sys.stderr, + flush=True, + ) return {} except Exception as e: - print(f" warning: failed to parse {tsconfig} ({type(e).__name__}: {e})", file=sys.stderr, flush=True) + print( + f" warning: failed to parse {tsconfig} ({type(e).__name__}: {e})", + file=sys.stderr, + flush=True, + ) return {} aliases: dict[str, str] = {} @@ -317,7 +428,8 @@ def _load_workspace_packages(start_dir: Path) -> dict[str, Path]: continue try: data = json.loads(manifest.read_text(encoding="utf-8")) - except Exception: + except Exception as exc: + _LOG.debug("could not read package manifest %s: %s", manifest, exc) continue name = data.get("name") if isinstance(name, str) and name: @@ -331,8 +443,8 @@ def _package_entry_candidates(package_dir: Path, subpath: str) -> list[Path]: manifest_data: dict[str, Any] = {} try: manifest_data = json.loads(manifest.read_text(encoding="utf-8")) - except Exception: - pass + except Exception as exc: + _LOG.debug("could not read package manifest %s: %s", manifest, exc) if subpath: return [package_dir / subpath] @@ -366,7 +478,7 @@ def _resolve_workspace_import(raw: str, start_dir: Path) -> Path | None: if raw == package_name: subpath = "" elif raw.startswith(package_name + "/"): - subpath = raw[len(package_name) + 1:] + subpath = raw[len(package_name) + 1 :] else: continue for candidate in _package_entry_candidates(package_dir, subpath): @@ -394,7 +506,7 @@ def _resolve_js_module_path(raw: str | Path, start_dir: Path | None = None) -> P aliases = _load_tsconfig_aliases(start_dir) for alias_prefix, alias_base in aliases.items(): if raw == alias_prefix or raw.startswith(alias_prefix + "/"): - rest = raw[len(alias_prefix):].lstrip("/") + rest = raw[len(alias_prefix) :].lstrip("/") return _resolve_js_import_path(Path(os.path.normpath(Path(alias_base) / rest))) return _resolve_workspace_import(raw, start_dir) @@ -402,10 +514,11 @@ def _resolve_js_module_path(raw: str | Path, start_dir: Path | None = None) -> P # ── LanguageConfig dataclass ───────────────────────────────────────────────── + @dataclass class LanguageConfig: - ts_module: str # e.g. "tree_sitter_python" - ts_language_fn: str = "language" # attr to call: e.g. tslang.language() + ts_module: str # e.g. "tree_sitter_python" + ts_language_fn: str = "language" # attr to call: e.g. tslang.language() class_types: frozenset = frozenset() function_types: frozenset = frozenset() @@ -422,12 +535,12 @@ class LanguageConfig: # Body detection body_field: str = "body" - body_fallback_child_types: tuple = () # e.g. ("declaration_list", "compound_statement") + body_fallback_child_types: tuple = () # e.g. ("declaration_list", "compound_statement") # Call name extraction - call_function_field: str = "function" # field on call node for callee + call_function_field: str = "function" # field on call node for callee call_accessor_node_types: frozenset = frozenset() # member/attribute nodes - call_accessor_field: str = "attribute" # field on accessor for method name + call_accessor_field: str = "attribute" # field on accessor for method name # Stop recursion at these types in walk_calls function_boundary_types: frozenset = frozenset() @@ -447,22 +560,57 @@ class LanguageConfig: # ── Generic helpers ─────────────────────────────────────────────────────────── -def _read_text(node, source: bytes) -> str: - return source[node.start_byte:node.end_byte].decode("utf-8", errors="replace") - -_PYTHON_TYPE_CONTAINERS = frozenset({ - "list", "dict", "set", "tuple", "frozenset", "type", - "List", "Dict", "Set", "Tuple", "FrozenSet", "Type", - "Optional", "Union", "Sequence", "Iterable", "Mapping", "MutableMapping", - "Iterator", "Callable", "Awaitable", "AsyncIterable", "AsyncIterator", "Coroutine", - "Generator", "AsyncGenerator", "ContextManager", "AsyncContextManager", - "Annotated", "ClassVar", "Final", "Literal", "Concatenate", "ParamSpec", "TypeVar", - "None", "Ellipsis", -}) +def _read_text(node, source: bytes) -> str: + return source[node.start_byte : node.end_byte].decode("utf-8", errors="replace") + + +_PYTHON_TYPE_CONTAINERS = frozenset( + { + "list", + "dict", + "set", + "tuple", + "frozenset", + "type", + "List", + "Dict", + "Set", + "Tuple", + "FrozenSet", + "Type", + "Optional", + "Union", + "Sequence", + "Iterable", + "Mapping", + "MutableMapping", + "Iterator", + "Callable", + "Awaitable", + "AsyncIterable", + "AsyncIterator", + "Coroutine", + "Generator", + "AsyncGenerator", + "ContextManager", + "AsyncContextManager", + "Annotated", + "ClassVar", + "Final", + "Literal", + "Concatenate", + "ParamSpec", + "TypeVar", + "None", + "Ellipsis", + } +) -def _python_collect_type_refs(node, source: bytes, generic: bool, out: list[tuple[str, str]]) -> None: +def _python_collect_type_refs( + node, source: bytes, generic: bool, out: list[tuple[str, str]] +) -> None: """Walk a Python type annotation; append (name, role) where role is 'type' or 'generic_arg'. Builtin/typing containers (list, dict, Optional, Union, …) are not emitted as refs themselves, @@ -537,7 +685,9 @@ def _csharp_classify_base(name: str, interface_names: set[str]) -> str: return "inherits" -def _csharp_collect_type_refs(node, source: bytes, generic: bool, out: list[tuple[str, str]]) -> None: +def _csharp_collect_type_refs( + node, source: bytes, generic: bool, out: list[tuple[str, str]] +) -> None: """Walk a C# type expression; append (name, role) tuples (role is 'type' or 'generic_arg').""" if node is None: return @@ -671,11 +821,32 @@ def _java_method_annotation_names(method_node, source: bytes) -> list[str]: return names -_GO_PREDECLARED_TYPES = frozenset({ - "bool", "byte", "complex64", "complex128", "error", "float32", "float64", - "int", "int8", "int16", "int32", "int64", "rune", "string", - "uint", "uint8", "uint16", "uint32", "uint64", "uintptr", "any", "comparable", -}) +_GO_PREDECLARED_TYPES = frozenset( + { + "bool", + "byte", + "complex64", + "complex128", + "error", + "float32", + "float64", + "int", + "int8", + "int16", + "int32", + "int64", + "rune", + "string", + "uint", + "uint8", + "uint16", + "uint32", + "uint64", + "uintptr", + "any", + "comparable", + } +) def _go_collect_type_refs(node, source: bytes, generic: bool, out: list[tuple[str, str]]) -> None: @@ -705,8 +876,14 @@ def _go_collect_type_refs(node, source: bytes, generic: bool, out: list[tuple[st if arg.is_named: _go_collect_type_refs(arg, source, True, out) return - if t in ("pointer_type", "slice_type", "array_type", "map_type", - "channel_type", "parenthesized_type"): + if t in ( + "pointer_type", + "slice_type", + "array_type", + "map_type", + "channel_type", + "parenthesized_type", + ): for c in node.children: if c.is_named: _go_collect_type_refs(c, source, generic, out) @@ -808,8 +985,14 @@ def _php_method_return_type_node(method_node): saw_params = True continue if saw_params and c.is_named and c.type not in ("compound_statement",): - if c.type in ("named_type", "primitive_type", "nullable_type", - "union_type", "intersection_type", "optional_type"): + if c.type in ( + "named_type", + "primitive_type", + "nullable_type", + "union_type", + "intersection_type", + "optional_type", + ): return c return None @@ -833,7 +1016,9 @@ def _kotlin_user_type_name(user_type_node, source: bytes) -> str | None: return None -def _kotlin_collect_type_refs(node, source: bytes, generic: bool, out: list[tuple[str, str]]) -> None: +def _kotlin_collect_type_refs( + node, source: bytes, generic: bool, out: list[tuple[str, str]] +) -> None: """Walk a Kotlin type expression; append (name, role) tuples.""" if node is None: return @@ -948,8 +1133,9 @@ def _swift_pre_scan(root_node, source: bytes) -> tuple[set[str], set[str]]: return protocols, classes -def _swift_classify_base(name: str, kind: str | None, is_first: bool, - protocols: set[str], classes: set[str]) -> str: +def _swift_classify_base( + name: str, kind: str | None, is_first: bool, protocols: set[str], classes: set[str] +) -> str: """Classify a Swift inheritance_specifier entry as `inherits` or `implements`.""" if name in protocols: return "implements" @@ -973,7 +1159,9 @@ def _swift_user_type_name(user_type_node, source: bytes) -> str | None: return None -def _swift_collect_type_refs(node, source: bytes, generic: bool, out: list[tuple[str, str]]) -> None: +def _swift_collect_type_refs( + node, source: bytes, generic: bool, out: list[tuple[str, str]] +) -> None: """Walk a Swift type expression; append (name, role) tuples (role 'type' or 'generic_arg').""" if node is None: return @@ -1001,8 +1189,13 @@ def _swift_collect_type_refs(node, source: bytes, generic: bool, out: list[tuple if text: out.append((text, "generic_arg" if generic else "type")) return - if t in ("optional_type", "implicitly_unwrapped_optional_type", "array_type", - "dictionary_type", "tuple_type"): + if t in ( + "optional_type", + "implicitly_unwrapped_optional_type", + "array_type", + "dictionary_type", + "tuple_type", + ): for c in node.children: if c.is_named: _swift_collect_type_refs(c, source, generic, out) @@ -1023,9 +1216,14 @@ def _swift_property_type_node(property_node): # ── C / C++ type-ref helpers ───────────────────────────────────────────────── -_C_PRIMITIVE_TYPE_NODES = frozenset({ - "primitive_type", "sized_type_specifier", "auto", "placeholder_type_specifier", -}) +_C_PRIMITIVE_TYPE_NODES = frozenset( + { + "primitive_type", + "sized_type_specifier", + "auto", + "placeholder_type_specifier", + } +) def _c_collect_type_refs(node, source: bytes, generic: bool, out: list[tuple[str, str]]) -> None: @@ -1039,9 +1237,16 @@ def _c_collect_type_refs(node, source: bytes, generic: bool, out: list[tuple[str if text: out.append((text, "generic_arg" if generic else "type")) return - if t in ("pointer_declarator", "reference_declarator", "array_declarator", - "type_qualifier", "type_descriptor", "abstract_pointer_declarator", - "abstract_reference_declarator", "abstract_array_declarator"): + if t in ( + "pointer_declarator", + "reference_declarator", + "array_declarator", + "type_qualifier", + "type_descriptor", + "abstract_pointer_declarator", + "abstract_reference_declarator", + "abstract_array_declarator", + ): for c in node.children: if c.is_named: _c_collect_type_refs(c, source, generic, out) @@ -1076,9 +1281,16 @@ def _cpp_collect_type_refs(node, source: bytes, generic: bool, out: list[tuple[s if c.is_named: _cpp_collect_type_refs(c, source, True, out) return - if t in ("type_descriptor", "pointer_declarator", "reference_declarator", - "array_declarator", "type_qualifier", "abstract_pointer_declarator", - "abstract_reference_declarator", "abstract_array_declarator"): + if t in ( + "type_descriptor", + "pointer_declarator", + "reference_declarator", + "array_declarator", + "type_qualifier", + "abstract_pointer_declarator", + "abstract_reference_declarator", + "abstract_array_declarator", + ): for c in node.children: if c.is_named: _cpp_collect_type_refs(c, source, generic, out) @@ -1086,7 +1298,10 @@ def _cpp_collect_type_refs(node, source: bytes, generic: bool, out: list[tuple[s # ── Scala type-ref helpers ─────────────────────────────────────────────────── -def _scala_collect_type_refs(node, source: bytes, generic: bool, out: list[tuple[str, str]]) -> None: + +def _scala_collect_type_refs( + node, source: bytes, generic: bool, out: list[tuple[str, str]] +) -> None: """Walk a Scala type expression; append (name, role) tuples. Handles type_identifier, generic_type (List[T]), and common type wrappers.""" if node is None: @@ -1114,8 +1329,14 @@ def _scala_collect_type_refs(node, source: bytes, generic: bool, out: list[tuple if arg.is_named: _scala_collect_type_refs(arg, source, True, out) return - if t in ("compound_type", "infix_type", "function_type", "tuple_type", - "annotated_type", "projected_type"): + if t in ( + "compound_type", + "infix_type", + "function_type", + "tuple_type", + "annotated_type", + "projected_type", + ): for c in node.children: if c.is_named: _scala_collect_type_refs(c, source, generic, out) @@ -1160,7 +1381,10 @@ def _find_body(node, config: LanguageConfig): # ── Import handlers ─────────────────────────────────────────────────────────── -def _import_python(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: + +def _import_python( + node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str +) -> None: t = node.type if t == "import_statement": for child in node.children: @@ -1168,16 +1392,18 @@ def _import_python(node, source: bytes, file_nid: str, stem: str, edges: list, s raw = _read_text(child, source) module_name = raw.split(" as ")[0].strip().lstrip(".") tgt_nid = _make_id(module_name) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) elif t == "import_from_statement": module_node = node.child_by_field_name("module_name") if module_node: @@ -1193,16 +1419,18 @@ def _import_python(node, source: bytes, file_nid: str, stem: str, edges: list, s tgt_nid = _make_id(str(base / rel)) else: tgt_nid = _make_id(raw) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports_from", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports_from", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) def _resolve_js_import_target(raw: str, str_path: str) -> "tuple[str, Path | None] | None": @@ -1228,7 +1456,11 @@ def _import_js(node, source: bytes, file_nid: str, stem: str, edges: list, str_p # Only handle export_statement if it has a `from` clause (re-export). # Pure exports like `export const x = 1` or `export { localVar }` have no source module. if is_reexport: - has_from = any(child.type == "from" or (_read_text(child, source) == "from") for child in node.children if child.type in ("from", "identifier")) + has_from = any( + child.type == "from" or (_read_text(child, source) == "from") + for child in node.children + if child.type in ("from", "identifier") + ) if not has_from: # Check for string child (source path) as a more reliable indicator has_from = any(child.type == "string" for child in node.children) @@ -1243,16 +1475,18 @@ def _import_js(node, source: bytes, file_nid: str, stem: str, edges: list, str_p if resolved is None: break tgt_nid, resolved_path = resolved - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports_from", - "context": "re-export" if is_reexport else "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports_from", + "context": "re-export" if is_reexport else "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) break # Emit symbol-level edges for named imports/re-exports from local/aliased files. @@ -1277,16 +1511,18 @@ def _import_js(node, source: bytes, file_nid: str, stem: str, edges: list, str_p sym = _read_text(name_node, source) if sym == "default": continue # skip default re-exports for ID matching - edges.append({ - "source": file_nid, - "target": _make_id(target_stem, sym), - "relation": "re_exports", - "context": "re-export", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": _make_id(target_stem, sym), + "relation": "re_exports", + "context": "re-export", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) else: # Handle: import { Foo, type Bar } from './bar' for child in node.children: @@ -1298,20 +1534,23 @@ def _import_js(node, source: bytes, file_nid: str, stem: str, edges: list, str_p name_node = spec.child_by_field_name("name") if name_node: sym = _read_text(name_node, source) - edges.append({ - "source": file_nid, - "target": _make_id(target_stem, sym), - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) - - -def _dynamic_import_js(node, source: bytes, caller_nid: str, str_path: str, edges: list, - seen_dyn_pairs: set) -> bool: + edges.append( + { + "source": file_nid, + "target": _make_id(target_stem, sym), + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) + + +def _dynamic_import_js( + node, source: bytes, caller_nid: str, str_path: str, edges: list, seen_dyn_pairs: set +) -> bool: """Detect dynamic import() calls in JS/TS and emit imports_from edges. Handles patterns like: @@ -1358,16 +1597,18 @@ def _dynamic_import_js(node, source: bytes, caller_nid: str, str_path: str, edge pair = (caller_nid, tgt_nid) if pair not in seen_dyn_pairs: seen_dyn_pairs.add(pair) - edges.append({ - "source": caller_nid, - "target": tgt_nid, - "relation": "imports_from", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": caller_nid, + "target": tgt_nid, + "relation": "imports_from", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) break return True @@ -1398,16 +1639,18 @@ def _walk_scoped(n) -> str: ) if module_name: tgt_nid = _make_id(module_name) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) break @@ -1435,7 +1678,24 @@ def _import_c(node, source: bytes, file_nid: str, stem: str, edges: list, str_pa resolved = _resolve_c_include_path(raw, str_path) if resolved is not None: tgt_nid = _make_id(str(resolved)) - edges.append({ + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) + break + module_name = raw.split("/")[-1].split(".")[0] + if module_name: + tgt_nid = _make_id(module_name) + edges.append( + { "source": file_nid, "target": tgt_nid, "relation": "imports", @@ -1444,97 +1704,98 @@ def _import_c(node, source: bytes, file_nid: str, stem: str, edges: list, str_pa "source_file": str_path, "source_location": f"L{node.start_point[0] + 1}", "weight": 1.0, - }) - break - module_name = raw.split("/")[-1].split(".")[0] - if module_name: - tgt_nid = _make_id(module_name) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + } + ) break -def _import_csharp(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: +def _import_csharp( + node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str +) -> None: for child in node.children: if child.type in ("qualified_name", "identifier", "name_equals"): raw = _read_text(child, source) module_name = raw.split(".")[-1].strip() if module_name: tgt_nid = _make_id(module_name) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) break -def _import_kotlin(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: +def _import_kotlin( + node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str +) -> None: path_node = node.child_by_field_name("path") if path_node: raw = _read_text(path_node, source) module_name = raw.split(".")[-1].strip() if module_name: tgt_nid = _make_id(module_name) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) return # Fallback: find identifier child for child in node.children: if child.type == "identifier": raw = _read_text(child, source) tgt_nid = _make_id(raw) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) break -def _import_scala(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: +def _import_scala( + node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str +) -> None: for child in node.children: if child.type in ("stable_id", "identifier"): raw = _read_text(child, source) module_name = raw.split(".")[-1].strip("{} ") if module_name and module_name != "_": tgt_nid = _make_id(module_name) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) break @@ -1545,21 +1806,24 @@ def _import_php(node, source: bytes, file_nid: str, stem: str, edges: list, str_ module_name = raw.split("\\")[-1].strip() if module_name: tgt_nid = _make_id(module_name) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) break # ── C/C++ function name helpers ─────────────────────────────────────────────── + def _get_c_func_name(node, source: bytes) -> str | None: """Recursively unwrap declarator to find the innermost identifier (C).""" if node.type == "identifier": @@ -1594,6 +1858,7 @@ def _get_cpp_func_name(node, source: bytes) -> str | None: # ── JS/TS extra walk for arrow functions ────────────────────────────────────── + def _find_require_call(value_node): """Return the call_expression node if `value_node` is a `require(...)` call or `require(...).x` member access. Otherwise None.""" @@ -1609,7 +1874,9 @@ def _find_require_call(value_node): return None -def _require_imports_js(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> bool: +def _require_imports_js( + node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str +) -> bool: """Detect CommonJS require imports inside lexical_declaration / variable_declaration. Handles three patterns: @@ -1647,16 +1914,18 @@ def _require_imports_js(node, source: bytes, file_nid: str, stem: str, edges: li continue tgt_nid, resolved_path = resolved line = node.start_point[0] + 1 - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports_from", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports_from", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) found = True # Symbol-level edges for destructured / accessor binders. @@ -1679,22 +1948,35 @@ def _require_imports_js(node, source: bytes, file_nid: str, stem: str, edges: li sym_names.append(_read_text(prop, source)) if target_stem is not None: for sym in sym_names: - edges.append({ - "source": file_nid, - "target": _make_id(target_stem, sym), - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": _make_id(target_stem, sym), + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) return found -def _js_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: str, - nodes: list, edges: list, seen_ids: set, function_bodies: list, - parent_class_nid: str | None, add_node_fn, add_edge_fn) -> bool: +def _js_extra_walk( + node, + source: bytes, + file_nid: str, + stem: str, + str_path: str, + nodes: list, + edges: list, + seen_ids: set, + function_bodies: list, + parent_class_nid: str | None, + add_node_fn, + add_edge_fn, +) -> bool: """Handle lexical_declaration (arrow functions, CJS requires, module-level const literals) for JS/TS. Returns True if handled.""" if node.type in ("lexical_declaration", "variable_declaration"): # CJS require imports — emit edges, do not block other lexical_declaration handling @@ -1709,9 +1991,11 @@ def _js_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: str, parent = node.parent is_module_level = parent is not None and ( parent.type == "program" - or (parent.type == "export_statement" + or ( + parent.type == "export_statement" and parent.parent is not None - and parent.parent.type == "program") + and parent.parent.type == "program" + ) ) # Arrow function declarations and module-level const literals (lexical_declaration only) @@ -1734,7 +2018,11 @@ def _js_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: str, function_bodies.append((func_nid, body)) arrow_found = True elif value and value.type in ( - "object", "array", "as_expression", "call_expression", "new_expression", + "object", + "array", + "as_expression", + "call_expression", + "new_expression", ): # Module-level const with literal/object/array/factory value name_node = child.child_by_field_name("name") @@ -1756,10 +2044,22 @@ def _js_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: str, # ── C# extra walk for namespace declarations ────────────────────────────────── -def _csharp_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: str, - nodes: list, edges: list, seen_ids: set, function_bodies: list, - parent_class_nid: str | None, add_node_fn, add_edge_fn, - walk_fn) -> bool: + +def _csharp_extra_walk( + node, + source: bytes, + file_nid: str, + stem: str, + str_path: str, + nodes: list, + edges: list, + seen_ids: set, + function_bodies: list, + parent_class_nid: str | None, + add_node_fn, + add_edge_fn, + walk_fn, +) -> bool: """Handle namespace_declaration for C#. Returns True if handled.""" if node.type == "namespace_declaration": name_node = node.child_by_field_name("name") @@ -1779,9 +2079,21 @@ def _csharp_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: # ── Swift extra walk for enum cases ────────────────────────────────────────── -def _swift_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: str, - nodes: list, edges: list, seen_ids: set, function_bodies: list, - parent_class_nid: str | None, add_node_fn, add_edge_fn) -> bool: + +def _swift_extra_walk( + node, + source: bytes, + file_nid: str, + stem: str, + str_path: str, + nodes: list, + edges: list, + seen_ids: set, + function_bodies: list, + parent_class_nid: str | None, + add_node_fn, + add_edge_fn, +) -> bool: """Handle enum_entry for Swift. Returns True if handled.""" if node.type == "enum_entry" and parent_class_nid: for child in node.children: @@ -1819,27 +2131,33 @@ def _swift_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: s call_function_field="function", call_accessor_node_types=frozenset({"member_expression"}), call_accessor_field="property", - function_boundary_types=frozenset({"function_declaration", "arrow_function", "method_definition"}), + function_boundary_types=frozenset( + {"function_declaration", "arrow_function", "method_definition"} + ), import_handler=_import_js, ) _TS_CONFIG = LanguageConfig( ts_module="tree_sitter_typescript", ts_language_fn="language_typescript", - class_types=frozenset({ - "class_declaration", - "abstract_class_declaration", # TS abstract class - "interface_declaration", # parity with Java/C# - "enum_declaration", # named enums - "type_alias_declaration", # named type aliases - }), + class_types=frozenset( + { + "class_declaration", + "abstract_class_declaration", # TS abstract class + "interface_declaration", # parity with Java/C# + "enum_declaration", # named enums + "type_alias_declaration", # named type aliases + } + ), function_types=frozenset({"function_declaration", "method_definition"}), import_types=frozenset({"import_statement", "export_statement"}), call_types=frozenset({"call_expression", "new_expression"}), call_function_field="function", call_accessor_node_types=frozenset({"member_expression"}), call_accessor_field="property", - function_boundary_types=frozenset({"function_declaration", "arrow_function", "method_definition"}), + function_boundary_types=frozenset( + {"function_declaration", "arrow_function", "method_definition"} + ), import_handler=_import_js, ) @@ -1980,7 +2298,14 @@ def _swift_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: s class_types=frozenset({"class_declaration"}), function_types=frozenset({"function_definition", "method_declaration"}), import_types=frozenset({"namespace_use_clause"}), - call_types=frozenset({"function_call_expression", "member_call_expression", "scoped_call_expression", "class_constant_access_expression"}), + call_types=frozenset( + { + "function_call_expression", + "member_call_expression", + "scoped_call_expression", + "class_constant_access_expression", + } + ), static_prop_types=frozenset({"scoped_property_access_expression"}), helper_fn_names=frozenset({"config"}), container_bind_methods=frozenset({"bind", "singleton", "scoped", "instance"}), @@ -2037,23 +2362,26 @@ def _import_lua(node, source: bytes, file_nid: str, stem: str, edges: list, str_ """Extract require('module') from Lua variable_declaration nodes.""" text = _read_text(node, source) import re + m = re.search(r"""require\s*[\('"]\s*['"]?([^'")\s]+)""", text) if m: raw_module = m.group(1) if raw_module: tgt_nid = _resolve_lua_import_target(raw_module, str_path) if tgt_nid: - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "confidence_score": 1.0, - "source_file": str_path, - "source_location": str(node.start_point[0] + 1), - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": str_path, + "source_location": str(node.start_point[0] + 1), + "weight": 1.0, + } + ) _LUA_CONFIG = LanguageConfig( @@ -2073,21 +2401,25 @@ def _import_lua(node, source: bytes, file_nid: str, stem: str, edges: list, str_ ) -def _import_swift(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: +def _import_swift( + node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str +) -> None: for child in node.children: if child.type == "identifier": raw = _read_text(child, source) tgt_nid = _make_id(raw) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) break @@ -2115,7 +2447,9 @@ def _read_csharp_type_name(node, source: bytes) -> str | None: _SWIFT_CONFIG = LanguageConfig( ts_module="tree_sitter_swift", class_types=frozenset({"class_declaration", "protocol_declaration"}), - function_types=frozenset({"function_declaration", "init_declaration", "deinit_declaration", "subscript_declaration"}), + function_types=frozenset( + {"function_declaration", "init_declaration", "deinit_declaration", "subscript_declaration"} + ), import_types=frozenset({"import_declaration"}), call_types=frozenset({"call_expression"}), call_function_field="", @@ -2123,23 +2457,31 @@ def _read_csharp_type_name(node, source: bytes) -> str | None: call_accessor_field="", name_fallback_child_types=("simple_identifier", "type_identifier", "user_type"), body_fallback_child_types=("class_body", "protocol_body", "function_body", "enum_class_body"), - function_boundary_types=frozenset({"function_declaration", "init_declaration", "deinit_declaration", "subscript_declaration"}), + function_boundary_types=frozenset( + {"function_declaration", "init_declaration", "deinit_declaration", "subscript_declaration"} + ), import_handler=_import_swift, ) # ── Generic extractor ───────────────────────────────────────────────────────── + def _extract_generic(path: Path, config: LanguageConfig) -> dict: """Generic AST extractor driven by LanguageConfig.""" try: mod = importlib.import_module(config.ts_module) from tree_sitter import Language, Parser + lang_fn = getattr(mod, config.ts_language_fn, None) if lang_fn is None: # Fallback for PHP: try "language_php" then "language" lang_fn = getattr(mod, "language", None) if lang_fn is None: - return {"nodes": [], "edges": [], "error": f"No language function in {config.ts_module}"} + return { + "nodes": [], + "edges": [], + "error": f"No language function in {config.ts_module}", + } language = Language(lang_fn()) except ImportError: return {"nodes": [], "edges": [], "error": f"{config.ts_module} not installed"} @@ -2188,17 +2530,25 @@ def _extract_generic(path: Path, config: LanguageConfig) -> dict: def add_node(nid: str, label: str, line: int) -> None: if nid not in seen_ids: seen_ids.add(nid) - nodes.append({ - "id": nid, - "label": label, - "file_type": "code", - "source_file": str_path, - "source_location": f"L{line}", - }) + nodes.append( + { + "id": nid, + "label": label, + "file_type": "code", + "source_file": str_path, + "source_location": f"L{line}", + } + ) - def add_edge(src: str, tgt: str, relation: str, line: int, - confidence: str = "EXTRACTED", weight: float = 1.0, - context: str | None = None) -> None: + def add_edge( + src: str, + tgt: str, + relation: str, + line: int, + confidence: str = "EXTRACTED", + weight: float = 1.0, + context: str | None = None, + ) -> None: edge = { "source": src, "target": tgt, @@ -2274,19 +2624,23 @@ def walk(node, parent_class_nid: str | None = None) -> None: if base_nid not in seen_ids: base_nid = _make_id(base) if base_nid not in seen_ids: - nodes.append({ - "id": base_nid, - "label": base, - "file_type": "code", - "source_file": "", - "source_location": "", - }) + nodes.append( + { + "id": base_nid, + "label": base, + "file_type": "code", + "source_file": "", + "source_location": "", + } + ) seen_ids.add(base_nid) add_edge(class_nid, base_nid, "inherits", line) # Swift-specific: conformance / inheritance if config.ts_module == "tree_sitter_swift": - swift_kind = _swift_declaration_keyword(node) if t == "class_declaration" else "protocol" + swift_kind = ( + _swift_declaration_keyword(node) if t == "class_declaration" else "protocol" + ) seen_swift_base = False for child in node.children: if child.type != "inheritance_specifier": @@ -2307,20 +2661,25 @@ def walk(node, parent_class_nid: str | None = None) -> None: if base_nid not in seen_ids: base_nid = _make_id(base_name) if base_nid not in seen_ids: - nodes.append({ - "id": base_nid, - "label": base_name, - "file_type": "code", - "source_file": "", - "source_location": "", - }) + nodes.append( + { + "id": base_nid, + "label": base_name, + "file_type": "code", + "source_file": "", + "source_location": "", + } + ) seen_ids.add(base_nid) if t == "protocol_declaration": relation = "inherits" else: relation = _swift_classify_base( - base_name, swift_kind, not seen_swift_base, - swift_protocol_names, swift_class_names, + base_name, + swift_kind, + not seen_swift_base, + swift_protocol_names, + swift_class_names, ) seen_swift_base = True add_edge(class_nid, base_nid, relation, line) @@ -2335,11 +2694,13 @@ def walk(node, parent_class_nid: str | None = None) -> None: _swift_collect_type_refs(arg, source, True, refs) for ref_name, _role in refs: target = ensure_named_node(ref_name, line) - add_edge(class_nid, target, "references", line, - context="generic_arg") + add_edge( + class_nid, target, "references", line, context="generic_arg" + ) # PHP-specific: extends → inherits, implements → implements, use → mixes_in if config.ts_module == "tree_sitter_php": + def _php_emit_base(base_name: str, rel: str, at_line: int) -> None: if not base_name: return @@ -2347,13 +2708,15 @@ def _php_emit_base(base_name: str, rel: str, at_line: int) -> None: if base_nid not in seen_ids: base_nid = _make_id(base_name) if base_nid not in seen_ids: - nodes.append({ - "id": base_nid, - "label": base_name, - "file_type": "code", - "source_file": "", - "source_location": "", - }) + nodes.append( + { + "id": base_nid, + "label": base_name, + "file_type": "code", + "source_file": "", + "source_location": "", + } + ) seen_ids.add(base_nid) add_edge(class_nid, base_nid, rel, at_line) @@ -2361,13 +2724,19 @@ def _php_emit_base(base_name: str, rel: str, at_line: int) -> None: if child.type == "base_clause": for sub in child.children: if sub.type in ("name", "qualified_name"): - _php_emit_base(_php_name_text(sub, source) or "", - "inherits", child.start_point[0] + 1) + _php_emit_base( + _php_name_text(sub, source) or "", + "inherits", + child.start_point[0] + 1, + ) elif child.type == "class_interface_clause": for sub in child.children: if sub.type in ("name", "qualified_name"): - _php_emit_base(_php_name_text(sub, source) or "", - "implements", child.start_point[0] + 1) + _php_emit_base( + _php_name_text(sub, source) or "", + "implements", + child.start_point[0] + 1, + ) body = node.child_by_field_name("body") if body is None: for c in node.children: @@ -2380,8 +2749,11 @@ def _php_emit_base(base_name: str, rel: str, at_line: int) -> None: continue for sub in member.children: if sub.type in ("name", "qualified_name"): - _php_emit_base(_php_name_text(sub, source) or "", - "mixes_in", member.start_point[0] + 1) + _php_emit_base( + _php_name_text(sub, source) or "", + "mixes_in", + member.start_point[0] + 1, + ) # Kotlin-specific: delegation_specifiers → inherits (constructor_invocation) / implements (user_type) if config.ts_module == "tree_sitter_kotlin": @@ -2413,13 +2785,15 @@ def _php_emit_base(base_name: str, rel: str, at_line: int) -> None: if base_nid not in seen_ids: base_nid = _make_id(base) if base_nid not in seen_ids: - nodes.append({ - "id": base_nid, - "label": base, - "file_type": "code", - "source_file": "", - "source_location": "", - }) + nodes.append( + { + "id": base_nid, + "label": base, + "file_type": "code", + "source_file": "", + "source_location": "", + } + ) seen_ids.add(base_nid) add_edge(class_nid, base_nid, relation, line) for arg_child in user_type_node.children: @@ -2434,8 +2808,13 @@ def _php_emit_base(base_name: str, rel: str, at_line: int) -> None: _kotlin_collect_type_refs(inner, source, True, refs) for ref_name, _role in refs: target = ensure_named_node(ref_name, line) - add_edge(class_nid, target, "references", line, - context="generic_arg") + add_edge( + class_nid, + target, + "references", + line, + context="generic_arg", + ) # C#-specific: inheritance / interface implementation via base_list if config.ts_module == "tree_sitter_c_sharp": @@ -2448,7 +2827,8 @@ def _php_emit_base(base_name: str, rel: str, at_line: int) -> None: if sub.type == "generic_name": name_child = sub.child_by_field_name("name") base = ( - _read_text(name_child, source) if name_child + _read_text(name_child, source) + if name_child else _read_text(sub.children[0], source) ) elif sub.type == "qualified_name": @@ -2461,13 +2841,15 @@ def _php_emit_base(base_name: str, rel: str, at_line: int) -> None: if base_nid not in seen_ids: base_nid = _make_id(base) if base_nid not in seen_ids: - nodes.append({ - "id": base_nid, - "label": base, - "file_type": "code", - "source_file": "", - "source_location": "", - }) + nodes.append( + { + "id": base_nid, + "label": base, + "file_type": "code", + "source_file": "", + "source_location": "", + } + ) seen_ids.add(base_nid) relation = _csharp_classify_base(base, csharp_interface_names) add_edge(class_nid, base_nid, relation, line) @@ -2482,11 +2864,17 @@ def _php_emit_base(base_name: str, rel: str, at_line: int) -> None: _csharp_collect_type_refs(arg, source, True, refs) for ref_name, _role in refs: target = ensure_named_node(ref_name, line) - add_edge(class_nid, target, "references", line, - context="generic_arg") + add_edge( + class_nid, + target, + "references", + line, + context="generic_arg", + ) # Java-specific: extends (superclass) / implements (interfaces) / interface-extends if config.ts_module == "tree_sitter_java": + def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: if not base_name: return @@ -2494,13 +2882,15 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: if base_nid not in seen_ids: base_nid = _make_id(base_name) if base_nid not in seen_ids: - nodes.append({ - "id": base_nid, - "label": base_name, - "file_type": "code", - "source_file": "", - "source_location": "", - }) + nodes.append( + { + "id": base_nid, + "label": base_name, + "file_type": "code", + "source_file": "", + "source_location": "", + } + ) seen_ids.add(base_nid) add_edge(class_nid, base_nid, rel, at_line) @@ -2526,7 +2916,9 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: if sub.type == "type_list": for tid in sub.children: if tid.type == "type_identifier": - _emit_java_parent(_read_text(tid, source), "inherits", line) + _emit_java_parent( + _read_text(tid, source), "inherits", line + ) # Scala: extends_clause carries `extends Base with Trait1 with Trait2`. # The first base after `extends` is `inherits`; each subsequent @@ -2574,8 +2966,7 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: ctx = "generic_arg" if role == "generic_arg" else "field" target_nid = ensure_named_node(ref_name, cp_line) if target_nid != class_nid: - add_edge(class_nid, target_nid, "references", - cp_line, context=ctx) + add_edge(class_nid, target_nid, "references", cp_line, context=ctx) # C++-specific: inheritance via base_class_clause (class and struct). # tree-sitter-cpp shape: @@ -2612,13 +3003,15 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: if base_nid not in seen_ids: base_nid = _make_id(base) if base_nid not in seen_ids: - nodes.append({ - "id": base_nid, - "label": base, - "file_type": "code", - "source_file": "", - "source_location": "", - }) + nodes.append( + { + "id": base_nid, + "label": base, + "file_type": "code", + "source_file": "", + "source_location": "", + } + ) seen_ids.add(base_nid) add_edge(class_nid, base_nid, "inherits", line) @@ -2630,9 +3023,7 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: return # Event listener property arrays: $listen = [Event::class => [Listener::class]] - if (t == "property_declaration" - and parent_class_nid - and config.event_listener_properties): + if t == "property_declaration" and parent_class_nid and config.event_listener_properties: handled_event_listener = False for element in node.children: if element.type != "property_element": @@ -2647,9 +3038,11 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: break elif c.type == "array_creation_expression": array_node = c - if (prop_name is None - or prop_name not in config.event_listener_properties - or array_node is None): + if ( + prop_name is None + or prop_name not in config.event_listener_properties + or array_node is None + ): continue handled_event_listener = True for entry in array_node.children: @@ -2683,9 +3076,11 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: if handled_event_listener: return - if (config.ts_module == "tree_sitter_c_sharp" - and t == "field_declaration" - and parent_class_nid): + if ( + config.ts_module == "tree_sitter_c_sharp" + and t == "field_declaration" + and parent_class_nid + ): type_node = node.child_by_field_name("type") if type_node is None: for child in node.children: @@ -2696,16 +3091,29 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: type_name = _read_csharp_type_name(type_node, source) if type_name: line = node.start_point[0] + 1 - add_edge(parent_class_nid, ensure_named_node(type_name, line), - "references", line, context="field") + add_edge( + parent_class_nid, + ensure_named_node(type_name, line), + "references", + line, + context="field", + ) return - if (config.ts_module == "tree_sitter_php" - and t == "property_declaration" - and parent_class_nid): + if ( + config.ts_module == "tree_sitter_php" + and t == "property_declaration" + and parent_class_nid + ): for c in node.children: - if c.type not in ("named_type", "primitive_type", "nullable_type", - "union_type", "intersection_type", "optional_type"): + if c.type not in ( + "named_type", + "primitive_type", + "nullable_type", + "union_type", + "intersection_type", + "optional_type", + ): continue line = node.start_point[0] + 1 refs: list[tuple[str, str]] = [] @@ -2718,9 +3126,11 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: break return - if (config.ts_module == "tree_sitter_kotlin" - and t == "property_declaration" - and parent_class_nid): + if ( + config.ts_module == "tree_sitter_kotlin" + and t == "property_declaration" + and parent_class_nid + ): type_node = _kotlin_property_type_node(node) if type_node is not None: line = node.start_point[0] + 1 @@ -2733,9 +3143,11 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: add_edge(parent_class_nid, target_nid, "references", line, context=ctx) return - if (config.ts_module == "tree_sitter_swift" - and t == "property_declaration" - and parent_class_nid): + if ( + config.ts_module == "tree_sitter_swift" + and t == "property_declaration" + and parent_class_nid + ): type_anno = _swift_property_type_node(node) if type_anno is not None: line = node.start_point[0] + 1 @@ -2748,9 +3160,7 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: add_edge(parent_class_nid, target_nid, "references", line, context=ctx) return - if (config.ts_module == "tree_sitter_scala" - and t == "val_definition" - and parent_class_nid): + if config.ts_module == "tree_sitter_scala" and t == "val_definition" and parent_class_nid: type_node = node.child_by_field_name("type") if type_node is not None: line = node.start_point[0] + 1 @@ -2760,20 +3170,19 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: ctx = "generic_arg" if role == "generic_arg" else "field" target_nid = ensure_named_node(ref_name, line) if target_nid != parent_class_nid: - add_edge(parent_class_nid, target_nid, "references", - line, context=ctx) + add_edge(parent_class_nid, target_nid, "references", line, context=ctx) # fall through so any call expressions in the initializer get walked - if (config.ts_module == "tree_sitter_cpp" - and t == "field_declaration" - and parent_class_nid): + if config.ts_module == "tree_sitter_cpp" and t == "field_declaration" and parent_class_nid: # Skip method prototypes (field_declaration with a function_declarator # is a member-function declaration, not a data member). decls = list(node.children_by_field_name("declarator")) is_method = any( d.type == "function_declarator" - or (d.type in ("pointer_declarator", "reference_declarator") - and any(c.type == "function_declarator" for c in d.children)) + or ( + d.type in ("pointer_declarator", "reference_declarator") + and any(c.type == "function_declarator" for c in d.children) + ) for d in decls ) if not is_method: @@ -2786,8 +3195,7 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: ctx = "generic_arg" if role == "generic_arg" else "field" target_nid = ensure_named_node(ref_name, line) if target_nid != parent_class_nid: - add_edge(parent_class_nid, target_nid, "references", - line, context=ctx) + add_edge(parent_class_nid, target_nid, "references", line, context=ctx) # Emit a node for each data member. Use children_by_field_name so we # only visit declarator children, not the type node (which would give # us the type name, not the field name). Handles int x, y; via @@ -2926,8 +3334,14 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: continue type_node = None for sub in p.children: - if sub.type in ("named_type", "primitive_type", "nullable_type", - "union_type", "intersection_type", "optional_type"): + if sub.type in ( + "named_type", + "primitive_type", + "nullable_type", + "union_type", + "intersection_type", + "optional_type", + ): type_node = sub break refs: list[tuple[str, str]] = [] @@ -3002,8 +3416,11 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: add_edge(func_nid, target_nid, "references", line, context=ctx) if config.ts_module in ("tree_sitter_c", "tree_sitter_cpp"): - collect = (_cpp_collect_type_refs if config.ts_module == "tree_sitter_cpp" - else _c_collect_type_refs) + collect = ( + _cpp_collect_type_refs + if config.ts_module == "tree_sitter_cpp" + else _c_collect_type_refs + ) return_node = node.child_by_field_name("type") if return_node is not None: refs: list[tuple[str, str]] = [] @@ -3016,7 +3433,9 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: # function_declarator may be wrapped in pointer/reference declarators decl = node.child_by_field_name("declarator") while decl is not None and decl.type in ( - "pointer_declarator", "reference_declarator"): + "pointer_declarator", + "reference_declarator", + ): decl = decl.child_by_field_name("declarator") if decl is not None and decl.type == "function_declarator": params_node = decl.child_by_field_name("parameters") @@ -3033,8 +3452,7 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: ctx = "generic_arg" if role == "generic_arg" else "parameter_type" target_nid = ensure_named_node(ref_name, line) if target_nid != func_nid: - add_edge(func_nid, target_nid, "references", - line, context=ctx) + add_edge(func_nid, target_nid, "references", line, context=ctx) if config.ts_module == "tree_sitter_scala": params_node = None @@ -3055,8 +3473,7 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: ctx = "generic_arg" if role == "generic_arg" else "parameter_type" target_nid = ensure_named_node(ref_name, line) if target_nid != func_nid: - add_edge(func_nid, target_nid, "references", - line, context=ctx) + add_edge(func_nid, target_nid, "references", line, context=ctx) return_node = node.child_by_field_name("return_type") if return_node is not None: refs = [] @@ -3065,8 +3482,7 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: ctx = "generic_arg" if role == "generic_arg" else "return_type" target_nid = ensure_named_node(ref_name, line) if target_nid != func_nid: - add_edge(func_nid, target_nid, "references", - line, context=ctx) + add_edge(func_nid, target_nid, "references", line, context=ctx) body = _find_body(node, config) if body: @@ -3075,21 +3491,55 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: # JS/TS arrow functions and C# namespaces — language-specific extra handling if config.ts_module in ("tree_sitter_javascript", "tree_sitter_typescript"): - if _js_extra_walk(node, source, file_nid, stem, str_path, - nodes, edges, seen_ids, function_bodies, - parent_class_nid, add_node, add_edge): + if _js_extra_walk( + node, + source, + file_nid, + stem, + str_path, + nodes, + edges, + seen_ids, + function_bodies, + parent_class_nid, + add_node, + add_edge, + ): return if config.ts_module == "tree_sitter_c_sharp": - if _csharp_extra_walk(node, source, file_nid, stem, str_path, - nodes, edges, seen_ids, function_bodies, - parent_class_nid, add_node, add_edge, walk): + if _csharp_extra_walk( + node, + source, + file_nid, + stem, + str_path, + nodes, + edges, + seen_ids, + function_bodies, + parent_class_nid, + add_node, + add_edge, + walk, + ): return if config.ts_module == "tree_sitter_swift": - if _swift_extra_walk(node, source, file_nid, stem, str_path, - nodes, edges, seen_ids, function_bodies, - parent_class_nid, add_node, add_edge): + if _swift_extra_walk( + node, + source, + file_nid, + stem, + str_path, + nodes, + edges, + seen_ids, + function_bodies, + parent_class_nid, + add_node, + add_edge, + ): return # Python's `@property` / `@staticmethod` / `@classmethod` wrap the @@ -3113,7 +3563,7 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: walk(root) # ── Call-graph pass ─────────────────────────────────────────────────────── - label_to_nid: dict[str, str] = {} # case-sensitive (Ruby, C#, Java, Kotlin, etc.) + label_to_nid: dict[str, str] = {} # case-sensitive (Ruby, C#, Java, Kotlin, etc.) label_to_nid_ci: dict[str, str] = {} # case-insensitive (PHP functions/classes) for n in nodes: raw = n["label"] @@ -3146,8 +3596,9 @@ def walk_calls(node, caller_nid: str) -> None: if node.type in config.call_types: # JS/TS dynamic imports: await import('./foo.js') if config.ts_module in ("tree_sitter_javascript", "tree_sitter_typescript"): - if _dynamic_import_js(node, source, caller_nid, str_path, - edges, seen_dyn_import_pairs): + if _dynamic_import_js( + node, source, caller_nid, str_path, edges, seen_dyn_import_pairs + ): # Still recurse into children (import().then(...) may have calls) for child in node.children: walk_calls(child, caller_nid) @@ -3236,18 +3687,28 @@ def walk_calls(node, caller_nid: str) -> None: callee_name = _read_text(name_node, source) elif config.ts_module == "tree_sitter_cpp": # C++: function field, then field_expression/qualified_identifier - func_node = node.child_by_field_name(config.call_function_field) if config.call_function_field else None + func_node = ( + node.child_by_field_name(config.call_function_field) + if config.call_function_field + else None + ) if func_node: if func_node.type == "identifier": callee_name = _read_text(func_node, source) elif func_node.type in ("field_expression", "qualified_identifier"): is_member_call = True - name = func_node.child_by_field_name("field") or func_node.child_by_field_name("name") + name = func_node.child_by_field_name( + "field" + ) or func_node.child_by_field_name("name") if name: callee_name = _read_text(name, source) else: # Generic: get callee from call_function_field - func_node = node.child_by_field_name(config.call_function_field) if config.call_function_field else None + func_node = ( + node.child_by_field_name(config.call_function_field) + if config.call_function_field + else None + ) if func_node: if func_node.type == "identifier": callee_name = _read_text(func_node, source) @@ -3268,28 +3729,32 @@ def walk_calls(node, caller_nid: str) -> None: if pair not in seen_call_pairs: seen_call_pairs.add(pair) line = node.start_point[0] + 1 - edges.append({ - "source": caller_nid, - "target": tgt_nid, - "relation": "calls", - "context": "call", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + edges.append( + { + "source": caller_nid, + "target": tgt_nid, + "relation": "calls", + "context": "call", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) elif callee_name and not tgt_nid: # Callee not in this file — save for cross-file resolution in extract() - raw_calls.append({ - "caller_nid": caller_nid, - "callee": callee_name, - "is_member_call": is_member_call, - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - }) + raw_calls.append( + { + "caller_nid": caller_nid, + "callee": callee_name, + "is_member_call": is_member_call, + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + } + ) # Helper function calls: config('foo.bar') → uses_config edge to "foo" - if (callee_name and callee_name in config.helper_fn_names): + if callee_name and callee_name in config.helper_fn_names: args_node = node.child_by_field_name("arguments") first_key: str | None = None if args_node: @@ -3307,29 +3772,34 @@ def walk_calls(node, caller_nid: str) -> None: break if first_key: segment = first_key.split(".")[0] - tgt_nid = (label_to_nid_ci.get(segment.lower()) - or label_to_nid_ci.get(f"{segment}.php".lower())) + tgt_nid = label_to_nid_ci.get(segment.lower()) or label_to_nid_ci.get( + f"{segment}.php".lower() + ) if tgt_nid and tgt_nid != caller_nid: relation = f"uses_{callee_name}" pair3 = (caller_nid, tgt_nid, relation) if pair3 not in seen_helper_ref_pairs: seen_helper_ref_pairs.add(pair3) line = node.start_point[0] + 1 - edges.append({ - "source": caller_nid, - "target": tgt_nid, - "relation": relation, - "confidence": "EXTRACTED", - "confidence_score": 1.0, - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + edges.append( + { + "source": caller_nid, + "target": tgt_nid, + "relation": relation, + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) # Service container bindings: $this->app->bind(Foo::class, Bar::class) - if (node.type == "member_call_expression" - and callee_name - and callee_name in config.container_bind_methods): + if ( + node.type == "member_call_expression" + and callee_name + and callee_name in config.container_bind_methods + ): args_node = node.child_by_field_name("arguments") class_args: list[str] = [] if args_node: @@ -3353,16 +3823,18 @@ def walk_calls(node, caller_nid: str) -> None: if pair3 not in seen_bind_pairs: seen_bind_pairs.add(pair3) line = node.start_point[0] + 1 - edges.append({ - "source": contract_nid, - "target": impl_nid, - "relation": "bound_to", - "confidence": "EXTRACTED", - "confidence_score": 1.0, - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + edges.append( + { + "source": contract_nid, + "target": impl_nid, + "relation": "bound_to", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) # Static property access: Foo::$bar → uses_static_prop edge if node.type in config.static_prop_types: @@ -3380,19 +3852,24 @@ def walk_calls(node, caller_nid: str) -> None: if pair3 not in seen_static_ref_pairs: seen_static_ref_pairs.add(pair3) line = node.start_point[0] + 1 - edges.append({ - "source": caller_nid, - "target": tgt_nid, - "relation": "uses_static_prop", - "confidence": "EXTRACTED", - "confidence_score": 1.0, - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + edges.append( + { + "source": caller_nid, + "target": tgt_nid, + "relation": "uses_static_prop", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) # PHP class constant access: Foo::BAR → references_constant edge - if config.ts_module == "tree_sitter_php" and node.type == "class_constant_access_expression": + if ( + config.ts_module == "tree_sitter_php" + and node.type == "class_constant_access_expression" + ): class_name = _php_class_const_scope(node) if class_name: tgt_nid = label_to_nid_ci.get(class_name.lower()) @@ -3401,16 +3878,18 @@ def walk_calls(node, caller_nid: str) -> None: if pair3 not in seen_static_ref_pairs: seen_static_ref_pairs.add(pair3) line = node.start_point[0] + 1 - edges.append({ - "source": caller_nid, - "target": tgt_nid, - "relation": "references_constant", - "confidence": "EXTRACTED", - "confidence_score": 1.0, - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + edges.append( + { + "source": caller_nid, + "target": tgt_nid, + "relation": "references_constant", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) for child in node.children: walk_calls(child, caller_nid) @@ -3429,23 +3908,27 @@ def walk_calls(node, caller_nid: str) -> None: if pair2 in seen_listen_pairs: continue seen_listen_pairs.add(pair2) - edges.append({ - "source": event_nid, - "target": listener_nid, - "relation": "listened_by", - "confidence": "EXTRACTED", - "confidence_score": 1.0, - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + edges.append( + { + "source": event_nid, + "target": listener_nid, + "relation": "listened_by", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) # ── Clean edges ─────────────────────────────────────────────────────────── valid_ids = seen_ids clean_edges = [] for edge in edges: src, tgt = edge["source"], edge["target"] - if src in valid_ids and (tgt in valid_ids or edge["relation"] in ("imports", "imports_from", "re_exports")): + if src in valid_ids and ( + tgt in valid_ids or edge["relation"] in ("imports", "imports_from", "re_exports") + ): clean_edges.append(edge) result = {"nodes": nodes, "edges": clean_edges, "raw_calls": raw_calls} @@ -3456,7 +3939,15 @@ def walk_calls(node, caller_nid: str) -> None: # ── Python rationale extraction ─────────────────────────────────────────────── -_RATIONALE_PREFIXES = ("# NOTE:", "# IMPORTANT:", "# HACK:", "# WHY:", "# RATIONALE:", "# TODO:", "# FIXME:") +_RATIONALE_PREFIXES = ( + "# NOTE:", + "# IMPORTANT:", + "# HACK:", + "# WHY:", + "# RATIONALE:", + "# TODO:", + "# FIXME:", +) def _is_autogenerated_python(source: bytes) -> bool: @@ -3470,9 +3961,11 @@ def _is_autogenerated_python(source: bytes) -> bool: if any(m in head for m in ("DO NOT EDIT", "@generated", "Generated by the protocol buffer")): return True # Alembic / Flask-Migrate revision files - if (re.search(r"^revision\s*[:=]", head, re.MULTILINE) - and "def upgrade(" in head - and "down_revision" in head): + if ( + re.search(r"^revision\s*[:=]", head, re.MULTILINE) + and "def upgrade(" in head + and "down_revision" in head + ): return True # Django migrations if "class Migration(migrations.Migration)" in head and "operations" in head: @@ -3487,6 +3980,7 @@ def _extract_python_rationale(path: Path, result: dict) -> None: try: import tree_sitter_python as tspython from tree_sitter import Language, Parser + language = Language(tspython.language()) parser = Parser(language) source = path.read_bytes() @@ -3509,7 +4003,9 @@ def _get_docstring(body_node) -> tuple[str, int] | None: if child.type == "expression_statement": for sub in child.children: if sub.type in ("string", "concatenated_string"): - text = source[sub.start_byte:sub.end_byte].decode("utf-8", errors="replace") + text = source[sub.start_byte : sub.end_byte].decode( + "utf-8", errors="replace" + ) text = text.strip("\"'").strip('"""').strip("'''").strip() if len(text) > 20: return text, child.start_point[0] + 1 @@ -3521,22 +4017,26 @@ def _add_rationale(text: str, line: int, parent_nid: str) -> None: rid = _make_id(stem, "rationale", str(line)) if rid not in seen_ids: seen_ids.add(rid) - nodes.append({ - "id": rid, - "label": label, - "file_type": "rationale", + nodes.append( + { + "id": rid, + "label": label, + "file_type": "rationale", + "source_file": str_path, + "source_location": f"L{line}", + } + ) + edges.append( + { + "source": rid, + "target": parent_nid, + "relation": "rationale_for", + "confidence": "EXTRACTED", "source_file": str_path, "source_location": f"L{line}", - }) - edges.append({ - "source": rid, - "target": parent_nid, - "relation": "rationale_for", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + "weight": 1.0, + } + ) # Module-level docstring — skip for auto-generated files (Alembic, Django # migrations, protobuf stubs, etc.) whose module docstrings are revision @@ -3553,7 +4053,9 @@ def walk_docstrings(node, parent_nid: str) -> None: name_node = node.child_by_field_name("name") body = node.child_by_field_name("body") if name_node and body: - class_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") + class_name = source[name_node.start_byte : name_node.end_byte].decode( + "utf-8", errors="replace" + ) nid = _make_id(stem, class_name) ds = _get_docstring(body) if ds: @@ -3565,8 +4067,14 @@ def walk_docstrings(node, parent_nid: str) -> None: name_node = node.child_by_field_name("name") body = node.child_by_field_name("body") if name_node and body: - func_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - nid = _make_id(parent_nid, func_name) if parent_nid != file_nid else _make_id(stem, func_name) + func_name = source[name_node.start_byte : name_node.end_byte].decode( + "utf-8", errors="replace" + ) + nid = ( + _make_id(parent_nid, func_name) + if parent_nid != file_nid + else _make_id(stem, func_name) + ) ds = _get_docstring(body) if ds: _add_rationale(ds[0], ds[1], nid) @@ -3586,6 +4094,7 @@ def walk_docstrings(node, parent_nid: str) -> None: # ── Public API ──────────────────────────────────────────────────────────────── + def extract_python(path: Path) -> dict: """Extract classes, functions, and imports from a .py file via tree-sitter AST.""" result = _extract_generic(path, _PYTHON_CONFIG) @@ -3615,6 +4124,7 @@ def extract_svelte(path: Path) -> dict: result = _extract_generic(path, _JS_CONFIG) try: import re as _re + src = path.read_text(encoding="utf-8", errors="replace") existing_ids = {n["id"] for n in result.get("nodes", [])} # Source file node ID must match the one _extract_generic creates: @@ -3642,7 +4152,7 @@ def extract_svelte(path: Path) -> dict: resolved_alias = None for alias_prefix, alias_base in aliases.items(): if raw == alias_prefix or raw.startswith(alias_prefix + "/"): - rest = raw[len(alias_prefix):].lstrip("/") + rest = raw[len(alias_prefix) :].lstrip("/") resolved_alias = Path(os.path.normpath(Path(alias_base) / rest)) break if resolved_alias is not None: @@ -3659,34 +4169,42 @@ def extract_svelte(path: Path) -> dict: stub_source_file = raw if node_id in existing_ids: # Edge target already a real node - just add the edge, don't add a node. - result.setdefault("edges", []).append({ - "source": file_node_id, "target": node_id, - "relation": "dynamic_import", "confidence": "EXTRACTED", - "source_file": str(path), - }) + result.setdefault("edges", []).append( + { + "source": file_node_id, + "target": node_id, + "relation": "dynamic_import", + "confidence": "EXTRACTED", + "source_file": str(path), + } + ) continue - result.setdefault("nodes", []).append({ - "id": node_id, "label": raw, - "file_type": "code", "source_file": stub_source_file, - "confidence": "EXTRACTED", - }) - result.setdefault("edges", []).append({ - "source": file_node_id, "target": node_id, - "relation": "dynamic_import", "confidence": "EXTRACTED", - "source_file": str(path), - }) + result.setdefault("nodes", []).append( + { + "id": node_id, + "label": raw, + "file_type": "code", + "source_file": stub_source_file, + "confidence": "EXTRACTED", + } + ) + result.setdefault("edges", []).append( + { + "source": file_node_id, + "target": node_id, + "relation": "dynamic_import", + "confidence": "EXTRACTED", + "source_file": str(path), + } + ) existing_ids.add(node_id) # Static imports inside ", "source_file": "src/evil.py", "file_type": "code", "community": 1}, + { + "id": "api", + "label": "ApiClient", + "source_file": "src/api.py", + "file_type": "code", + "community": 0, + }, + { + "id": "run", + "label": "run()", + "source_file": "src/main.py", + "file_type": "code", + "community": 0, + }, + { + "id": "export", + "label": "write_html()", + "source_file": "src/export.py", + "file_type": "code", + "community": 1, + }, + { + "id": "evil", + "label": "", + "source_file": "src/evil.py", + "file_type": "code", + "community": 1, + }, ], "links": [ - {"source": "run", "target": "api", "relation": "calls", "confidence": "EXTRACTED", "confidence_score": 1.0}, - {"source": "api", "target": "export", "relation": "uses", "confidence": "EXTRACTED", "confidence_score": 1.0}, - {"source": "export", "target": "evil", "relation": "calls", "confidence": "EXTRACTED", "confidence_score": 1.0}, + { + "source": "run", + "target": "api", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + }, + { + "source": "api", + "target": "export", + "relation": "uses", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + }, + { + "source": "export", + "target": "evil", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + }, ], "hyperedges": [], "built_at_commit": "abcdef123456", @@ -103,16 +145,36 @@ def test_export_callflow_html_cli_accepts_positional_graph_path(tmp_path): "multigraph": False, "graph": {}, "nodes": [ - {"id": "external", "label": "ExternalOnly", "source_file": "src/external.py", "file_type": "code", "community": 0}, - {"id": "writer", "label": "write_external()", "source_file": "src/writer.py", "file_type": "code", "community": 1}, + { + "id": "external", + "label": "ExternalOnly", + "source_file": "src/external.py", + "file_type": "code", + "community": 0, + }, + { + "id": "writer", + "label": "write_external()", + "source_file": "src/writer.py", + "file_type": "code", + "community": 1, + }, ], "links": [ - {"source": "external", "target": "writer", "relation": "calls", "confidence": "EXTRACTED", "confidence_score": 1.0}, + { + "source": "external", + "target": "writer", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + }, ], "hyperedges": [], } (external_out / "graph.json").write_text(json.dumps(graph), encoding="utf-8") - (external_out / ".graphify_labels.json").write_text(json.dumps({"0": "External Runtime", "1": "External Export"}), encoding="utf-8") + (external_out / ".graphify_labels.json").write_text( + json.dumps({"0": "External Runtime", "1": "External Export"}), encoding="utf-8" + ) (external_out / "GRAPH_REPORT.md").write_text( "\n".join( [ @@ -156,10 +218,25 @@ def test_export_callflow_html_cli_accepts_positional_graph_path(tmp_path): def test_derive_sections_groups_by_architecture_keywords(): nodes = [ - {"id": "extract_py", "label": "extract_python", "source_file": "graphify/extract.py", "community": 0}, - {"id": "extract_js", "label": "extract_js", "source_file": "graphify/extract.py", "community": 0}, + { + "id": "extract_py", + "label": "extract_python", + "source_file": "graphify/extract.py", + "community": 0, + }, + { + "id": "extract_js", + "label": "extract_js", + "source_file": "graphify/extract.py", + "community": 0, + }, {"id": "to_html", "label": "to_html", "source_file": "graphify/export.py", "community": 1}, - {"id": "test_html", "label": "test_export_html", "source_file": "tests/test_export.py", "community": 2}, + { + "id": "test_html", + "label": "test_export_html", + "source_file": "tests/test_export.py", + "community": 2, + }, ] sections = derive_sections_from_communities(nodes, {}, "en", 6) diff --git a/tests/test_charmap_encoding.py b/tests/test_charmap_encoding.py index f255dbbb6..f9d24c1a8 100644 --- a/tests/test_charmap_encoding.py +++ b/tests/test_charmap_encoding.py @@ -11,15 +11,12 @@ b) Assert that extract_corpus_parallel reports loud failure (non-zero exit or summary block) when ≥1 chunk fails. """ + from __future__ import annotations import json -import sys -from io import StringIO -from pathlib import Path -from unittest.mock import MagicMock, call, patch +from unittest.mock import MagicMock, patch -import pytest from graphify import llm @@ -31,14 +28,15 @@ "type": "result", "subtype": "success", "is_error": False, - "result": json.dumps({ - "nodes": [{"id": "n1", "label": "N1", "file_type": "document", - "source_file": "u.md"}], - "edges": [], - "hyperedges": [], - "input_tokens": 0, - "output_tokens": 0, - }), + "result": json.dumps( + { + "nodes": [{"id": "n1", "label": "N1", "file_type": "document", "source_file": "u.md"}], + "edges": [], + "hyperedges": [], + "input_tokens": 0, + "output_tokens": 0, + } + ), "stop_reason": "end_turn", "usage": { "input_tokens": 1, @@ -54,6 +52,7 @@ # ── Test A: subprocess encoding ─────────────────────────────────────────────── + class TestSubprocessEncoding: """_call_claude_cli must pass encoding="utf-8" to subprocess.run so that non-ASCII content in chunk messages does not raise UnicodeEncodeError on @@ -68,8 +67,10 @@ def test_subprocess_called_with_utf8_encoding(self, monkeypatch): """subprocess.run must be invoked with encoding='utf-8'.""" completed = self._make_completed() monkeypatch.setattr(llm, "_response_is_hollow", lambda raw, parsed: False) - with patch("shutil.which", return_value="/fake/bin/claude"), \ - patch("subprocess.run", return_value=completed) as mock_run: + with ( + patch("shutil.which", return_value="/fake/bin/claude"), + patch("subprocess.run", return_value=completed) as mock_run, + ): llm._call_claude_cli(_UNICODE_CONTENT, max_tokens=8192) _args, kwargs = mock_run.call_args assert kwargs.get("encoding") == "utf-8", ( @@ -85,8 +86,10 @@ def test_subprocess_does_not_use_text_true_without_encoding(self, monkeypatch): """ completed = self._make_completed() monkeypatch.setattr(llm, "_response_is_hollow", lambda raw, parsed: False) - with patch("shutil.which", return_value="/fake/bin/claude"), \ - patch("subprocess.run", return_value=completed) as mock_run: + with ( + patch("shutil.which", return_value="/fake/bin/claude"), + patch("subprocess.run", return_value=completed) as mock_run, + ): llm._call_claude_cli(_UNICODE_CONTENT, max_tokens=8192) _args, kwargs = mock_run.call_args # If text=True is present, encoding must also be set to 'utf-8'. @@ -111,12 +114,12 @@ def test_unicode_chars_survive_subprocess_roundtrip(self, monkeypatch, tmp_path) completed = self._make_completed() monkeypatch.setattr(llm, "_response_is_hollow", lambda raw, parsed: False) - with patch("shutil.which", return_value="/fake/bin/claude"), \ - patch("subprocess.run", return_value=completed): + with ( + patch("shutil.which", return_value="/fake/bin/claude"), + patch("subprocess.run", return_value=completed), + ): # Should not raise - result = llm.extract_files_direct( - files=[f], backend="claude-cli", root=tmp_path - ) + result = llm.extract_files_direct(files=[f], backend="claude-cli", root=tmp_path) assert len(result["nodes"]) >= 1 def test_call_llm_claude_cli_subprocess_encoding(self, monkeypatch): @@ -126,8 +129,10 @@ def test_call_llm_claude_cli_subprocess_encoding(self, monkeypatch): stdout=json.dumps({"result": "ok", "stop_reason": "end_turn"}), stderr="", ) - with patch("shutil.which", return_value="/fake/bin/claude"), \ - patch("subprocess.run", return_value=completed) as mock_run: + with ( + patch("shutil.which", return_value="/fake/bin/claude"), + patch("subprocess.run", return_value=completed) as mock_run, + ): llm._call_llm(_UNICODE_CONTENT, backend="claude-cli", max_tokens=200) _args, kwargs = mock_run.call_args assert kwargs.get("encoding") == "utf-8", ( @@ -138,6 +143,7 @@ def test_call_llm_claude_cli_subprocess_encoding(self, monkeypatch): # ── Test B: loud failure on chunk error ──────────────────────────────────────── + class TestLoudChunkFailure: """extract_corpus_parallel must surface chunk failures loudly — either via non-zero exit (exception raised from the function) or a printed summary @@ -196,10 +202,11 @@ def test_no_false_alarm_when_all_chunks_succeed(self, monkeypatch, tmp_path, cap f.write_text("z = 1\n", encoding="utf-8") good_result = { - "nodes": [{"id": "n1", "label": "N1", "file_type": "code", - "source_file": str(f)}], - "edges": [], "hyperedges": [], - "input_tokens": 1, "output_tokens": 1, + "nodes": [{"id": "n1", "label": "N1", "file_type": "code", "source_file": str(f)}], + "edges": [], + "hyperedges": [], + "input_tokens": 1, + "output_tokens": 1, "elapsed_seconds": 0.1, } monkeypatch.setattr( @@ -217,6 +224,7 @@ def test_no_false_alarm_when_all_chunks_succeed(self, monkeypatch, tmp_path, cap # ── Substitution validation (rsl-siege-manager path via Python) ──────────────── + class TestSubstitutionValidation: """Exercises the same code path as the rsl-siege-manager reproduction without requiring the `claude` CLI or its auth. @@ -262,9 +270,7 @@ def test_cp1252_would_fail_but_utf8_succeeds(self, tmp_path): try: prompt.encode("utf-8") except UnicodeEncodeError as e: - raise AssertionError( - f"UTF-8 encode must succeed but failed: {e}" - ) from e + raise AssertionError(f"UTF-8 encode must succeed but failed: {e}") from e # cp1252 must fail (confirming these chars are the failing surface) try: @@ -279,9 +285,7 @@ def test_cp1252_would_fail_but_utf8_succeeds(self, tmp_path): except UnicodeEncodeError: pass # Expected — confirms these chars hit the pre-fix failure surface - def test_subprocess_encoding_kwarg_in_extract_files_direct( - self, monkeypatch, tmp_path - ): + def test_subprocess_encoding_kwarg_in_extract_files_direct(self, monkeypatch, tmp_path): """End-to-end path: write unicode file → extract_files_direct → subprocess. Subprocess must receive encoding='utf-8', not the locale default. @@ -290,39 +294,49 @@ def test_subprocess_encoding_kwarg_in_extract_files_direct( f.write_text(self._UNICODE_CHARS, encoding="utf-8") _ENVELOPE_SIMPLE = { - "type": "result", "subtype": "success", "is_error": False, - "result": json.dumps({ - "nodes": [{"id": "u_chunk", "label": "Unicode Chunk", - "file_type": "document", - "source_file": "unicode_chunk.md"}], - "edges": [], "hyperedges": [], - "input_tokens": 1, "output_tokens": 1, - }), + "type": "result", + "subtype": "success", + "is_error": False, + "result": json.dumps( + { + "nodes": [ + { + "id": "u_chunk", + "label": "Unicode Chunk", + "file_type": "document", + "source_file": "unicode_chunk.md", + } + ], + "edges": [], + "hyperedges": [], + "input_tokens": 1, + "output_tokens": 1, + } + ), "stop_reason": "end_turn", "usage": { - "input_tokens": 1, "output_tokens": 1, - "cache_read_input_tokens": 0, "cache_creation_input_tokens": 0, + "input_tokens": 1, + "output_tokens": 1, + "cache_read_input_tokens": 0, + "cache_creation_input_tokens": 0, }, "modelUsage": { "claude-opus-4-7": {"inputTokens": 1, "outputTokens": 1}, }, } - completed = MagicMock( - returncode=0, stdout=json.dumps(_ENVELOPE_SIMPLE), stderr="" - ) + completed = MagicMock(returncode=0, stdout=json.dumps(_ENVELOPE_SIMPLE), stderr="") monkeypatch.setattr(llm, "_response_is_hollow", lambda raw, parsed: False) - with patch("shutil.which", return_value="/fake/bin/claude"), \ - patch("subprocess.run", return_value=completed) as mock_run: - result = llm.extract_files_direct( - files=[f], backend="claude-cli", root=tmp_path - ) + with ( + patch("shutil.which", return_value="/fake/bin/claude"), + patch("subprocess.run", return_value=completed) as mock_run, + ): + result = llm.extract_files_direct(files=[f], backend="claude-cli", root=tmp_path) assert mock_run.called _args, kwargs = mock_run.call_args assert kwargs.get("encoding") == "utf-8", ( - "subprocess.run must be called with encoding='utf-8'; " - f"got {kwargs.get('encoding')!r}" + f"subprocess.run must be called with encoding='utf-8'; got {kwargs.get('encoding')!r}" ) # Confirm the unicode content was in the input (not truncated/replaced) inp = kwargs.get("input", "") diff --git a/tests/test_chunking.py b/tests/test_chunking.py index 087464ab8..f037349cf 100644 --- a/tests/test_chunking.py +++ b/tests/test_chunking.py @@ -1,6 +1,6 @@ """Tests for token-aware chunking and parallel chunk execution in graphify.llm.""" + import time -from pathlib import Path from unittest.mock import patch import pytest @@ -13,12 +13,14 @@ def no_tokenizer(): compresses repeated/synthetic content heavily, which would make pack-size assertions tied to specific input sizes flaky.""" from graphify import llm + with patch.object(llm, "_TOKENIZER", None): yield # ---- Token-aware packing ----------------------------------------------------- + def test_pack_chunks_packs_small_files_together(tmp_path): """Many small files should land in a single chunk, not one chunk per file.""" from graphify.llm import _pack_chunks_by_tokens @@ -64,10 +66,14 @@ def test_pack_chunks_groups_by_directory(tmp_path): dir_a.mkdir() dir_b.mkdir() - a1 = dir_a / "x.py"; a1.write_text("a") - a2 = dir_a / "y.py"; a2.write_text("a") - b1 = dir_b / "x.py"; b1.write_text("b") - b2 = dir_b / "y.py"; b2.write_text("b") + a1 = dir_a / "x.py" + a1.write_text("a") + a2 = dir_a / "y.py" + a2.write_text("a") + b1 = dir_b / "x.py" + b1.write_text("b") + b2 = dir_b / "y.py" + b2.write_text("b") # Big budget — everything fits in one chunk in principle, but the order # within the chunk should keep dir_a's files contiguous and dir_b's @@ -87,8 +93,10 @@ def test_pack_chunks_oversized_file_gets_its_own_chunk(tmp_path, no_tokenizer): """A file larger than the budget can't be split — it goes alone in a chunk.""" from graphify.llm import _pack_chunks_by_tokens - big = tmp_path / "big.py"; big.write_text("x" * 200_000) # ~50k tokens (cap-bound) - small = tmp_path / "small.py"; small.write_text("x") + big = tmp_path / "big.py" + big.write_text("x" * 200_000) # ~50k tokens (cap-bound) + small = tmp_path / "small.py" + small.write_text("x") chunks = _pack_chunks_by_tokens([big, small], token_budget=1_000) sizes = [len(c) for c in chunks] @@ -100,13 +108,15 @@ def test_pack_chunks_oversized_file_gets_its_own_chunk(tmp_path, no_tokenizer): def test_pack_chunks_rejects_non_positive_budget(tmp_path): from graphify.llm import _pack_chunks_by_tokens - f = tmp_path / "x.py"; f.write_text("a") + f = tmp_path / "x.py" + f.write_text("a") with pytest.raises(ValueError): _pack_chunks_by_tokens([f], token_budget=0) # ---- Tokenizer fallback ------------------------------------------------------ + def test_estimate_file_tokens_uses_tiktoken_when_available(tmp_path): """When tiktoken is installed, the estimator should call into it for accurate counts rather than the chars/4 heuristic.""" @@ -139,6 +149,7 @@ def test_estimate_file_tokens_falls_back_to_chars_when_no_tokenizer(tmp_path): # ---- Parallel execution ------------------------------------------------------ + def _stub_chunk_result(file_count: int, idx: int) -> dict: """Build a deterministic fake extraction result for a chunk.""" return { @@ -157,7 +168,8 @@ def test_corpus_parallel_runs_chunks_concurrently(tmp_path): files = [] for i in range(8): - f = tmp_path / f"f{i}.py"; f.write_text("x") + f = tmp_path / f"f{i}.py" + f.write_text("x") files.append(f) def slow_extract(chunk, **kwargs): @@ -183,7 +195,8 @@ def test_corpus_parallel_sequential_when_max_concurrency_is_one(tmp_path): files = [] for i in range(3): - f = tmp_path / f"f{i}.py"; f.write_text("x") + f = tmp_path / f"f{i}.py" + f.write_text("x") files.append(f) call_order = [] @@ -208,7 +221,8 @@ def test_corpus_parallel_continues_after_chunk_failure(tmp_path, capsys): files = [] for i in range(4): - f = tmp_path / f"f{i}.py"; f.write_text("x") + f = tmp_path / f"f{i}.py" + f.write_text("x") files.append(f) call_count = {"n": 0} @@ -236,7 +250,8 @@ def test_corpus_parallel_legacy_mode_when_token_budget_is_none(tmp_path): files = [] for i in range(45): - f = tmp_path / f"f{i}.py"; f.write_text("x") + f = tmp_path / f"f{i}.py" + f.write_text("x") files.append(f) chunks_seen = [] @@ -260,7 +275,8 @@ def test_corpus_parallel_token_budget_default_packs_files(tmp_path): files = [] for i in range(50): - f = tmp_path / f"f{i}.py"; f.write_text("x = 1\n") + f = tmp_path / f"f{i}.py" + f.write_text("x = 1\n") files.append(f) chunks_seen = [] @@ -279,6 +295,7 @@ def record(chunk, **kwargs): # ---- Adaptive retry on truncation ------------------------------------------- + def _stub_with_finish(file_count: int, finish_reason: str = "stop") -> dict: """Build a stub extraction result with a controllable finish_reason.""" return { @@ -398,7 +415,8 @@ def test_adaptive_retry_single_file_truncation_does_not_recurse(tmp_path, capsys warning and return what we got. No infinite loop.""" from graphify.llm import _extract_with_adaptive_retry - f = tmp_path / "huge.py"; f.write_text("x") + f = tmp_path / "huge.py" + f.write_text("x") calls = [] diff --git a/tests/test_claude_cli_backend.py b/tests/test_claude_cli_backend.py index eeb6fd27b..e5c38fab9 100644 --- a/tests/test_claude_cli_backend.py +++ b/tests/test_claude_cli_backend.py @@ -3,6 +3,7 @@ Mocks subprocess.run + shutil.which so the suite runs on CI without the `claude` binary or a live network call. """ + from __future__ import annotations import json @@ -16,19 +17,31 @@ "type": "result", "subtype": "success", "is_error": False, - "result": json.dumps({ - "nodes": [ - {"id": "foo_module", "label": "Foo", "file_type": "document", "source_file": "foo.md"}, - {"id": "foo_greet", "label": "greet", "file_type": "code", "source_file": "foo.md"}, - ], - "edges": [ - {"source": "foo_module", "target": "foo_greet", - "relation": "references", "confidence": "EXTRACTED", "confidence_score": 1.0}, - ], - "hyperedges": [], - "input_tokens": 0, - "output_tokens": 0, - }), + "result": json.dumps( + { + "nodes": [ + { + "id": "foo_module", + "label": "Foo", + "file_type": "document", + "source_file": "foo.md", + }, + {"id": "foo_greet", "label": "greet", "file_type": "code", "source_file": "foo.md"}, + ], + "edges": [ + { + "source": "foo_module", + "target": "foo_greet", + "relation": "references", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + }, + ], + "hyperedges": [], + "input_tokens": 0, + "output_tokens": 0, + } + ), "stop_reason": "end_turn", "usage": { "input_tokens": 6, @@ -44,8 +57,10 @@ def fake_claude(monkeypatch): completed = MagicMock(returncode=0, stdout=json.dumps(_ENVELOPE), stderr="") monkeypatch.setattr(llm, "_response_is_hollow", lambda raw, parsed: False) - with patch("shutil.which", return_value="/fake/bin/claude"), \ - patch("subprocess.run", return_value=completed) as run: + with ( + patch("shutil.which", return_value="/fake/bin/claude"), + patch("subprocess.run", return_value=completed) as run, + ): yield run @@ -67,8 +82,10 @@ def test_finish_reason_length_on_max_tokens(monkeypatch): envelope = dict(_ENVELOPE, stop_reason="max_tokens") completed = MagicMock(returncode=0, stdout=json.dumps(envelope), stderr="") monkeypatch.setattr(llm, "_response_is_hollow", lambda raw, parsed: False) - with patch("shutil.which", return_value="/fake/bin/claude"), \ - patch("subprocess.run", return_value=completed): + with ( + patch("shutil.which", return_value="/fake/bin/claude"), + patch("subprocess.run", return_value=completed), + ): result = llm._call_claude_cli("dummy", max_tokens=8192) assert result["finish_reason"] == "length" @@ -81,16 +98,20 @@ def test_raises_when_cli_missing(): def test_raises_on_nonzero_exit(): completed = MagicMock(returncode=2, stdout="", stderr="auth failed") - with patch("shutil.which", return_value="/fake/bin/claude"), \ - patch("subprocess.run", return_value=completed): + with ( + patch("shutil.which", return_value="/fake/bin/claude"), + patch("subprocess.run", return_value=completed), + ): with pytest.raises(RuntimeError, match="exited 2"): llm._call_claude_cli("dummy", max_tokens=8192) def test_raises_on_garbage_envelope(): completed = MagicMock(returncode=0, stdout="not json", stderr="") - with patch("shutil.which", return_value="/fake/bin/claude"), \ - patch("subprocess.run", return_value=completed): + with ( + patch("shutil.which", return_value="/fake/bin/claude"), + patch("subprocess.run", return_value=completed), + ): with pytest.raises(RuntimeError, match="unparseable JSON envelope"): llm._call_claude_cli("dummy", max_tokens=8192) @@ -137,9 +158,11 @@ def fake_which(name): "claude.cmd": r"C:\Users\u\AppData\Roaming\npm\claude.cmd", }.get(name) - with patch("platform.system", return_value="Windows"), \ - patch("shutil.which", side_effect=fake_which), \ - patch("subprocess.run", return_value=completed) as run: + with ( + patch("platform.system", return_value="Windows"), + patch("shutil.which", side_effect=fake_which), + patch("subprocess.run", return_value=completed) as run, + ): llm._call_claude_cli("dummy", max_tokens=8192) argv = run.call_args.args[0] @@ -162,9 +185,11 @@ def fake_which(name): return "/usr/local/bin/claude" return None - with patch("platform.system", return_value="Windows"), \ - patch("shutil.which", side_effect=fake_which), \ - patch("subprocess.run", return_value=completed) as run: + with ( + patch("platform.system", return_value="Windows"), + patch("shutil.which", side_effect=fake_which), + patch("subprocess.run", return_value=completed) as run, + ): llm._call_claude_cli("dummy", max_tokens=8192) argv = run.call_args.args[0] @@ -174,8 +199,7 @@ def fake_which(name): def test_windows_raises_when_neither_cmd_nor_bare_claude_present(): """If neither `claude.cmd` nor `claude` are on PATH on Windows, raise the standard not-found error.""" - with patch("platform.system", return_value="Windows"), \ - patch("shutil.which", return_value=None): + with patch("platform.system", return_value="Windows"), patch("shutil.which", return_value=None): with pytest.raises(RuntimeError, match="Claude Code CLI not found"): llm._call_claude_cli("dummy", max_tokens=8192) @@ -186,9 +210,11 @@ def test_non_windows_uses_bare_claude(monkeypatch): completed = MagicMock(returncode=0, stdout=json.dumps(_ENVELOPE), stderr="") monkeypatch.setattr(llm, "_response_is_hollow", lambda raw, parsed: False) - with patch("platform.system", return_value="Linux"), \ - patch("shutil.which", return_value="/usr/local/bin/claude"), \ - patch("subprocess.run", return_value=completed) as run: + with ( + patch("platform.system", return_value="Linux"), + patch("shutil.which", return_value="/usr/local/bin/claude"), + patch("subprocess.run", return_value=completed) as run, + ): llm._call_claude_cli("dummy", max_tokens=8192) argv = run.call_args.args[0] diff --git a/tests/test_claude_md.py b/tests/test_claude_md.py index f81f10dd3..3198b797e 100644 --- a/tests/test_claude_md.py +++ b/tests/test_claude_md.py @@ -1,13 +1,13 @@ """Tests for graphify claude install / uninstall commands.""" -from pathlib import Path -import pytest -from graphify.__main__ import claude_install, claude_uninstall, _CLAUDE_MD_MARKER, _CLAUDE_MD_SECTION + +from graphify.__main__ import claude_install, claude_uninstall, _CLAUDE_MD_MARKER # --------------------------------------------------------------------------- # install # --------------------------------------------------------------------------- + def test_install_creates_claude_md(tmp_path): """Creates CLAUDE.md when none exists.""" claude_install(tmp_path) @@ -58,6 +58,7 @@ def test_install_idempotent_message(tmp_path, capsys): # uninstall # --------------------------------------------------------------------------- + def test_uninstall_removes_section(tmp_path): """Removes the graphify section after it was installed.""" claude_install(tmp_path) @@ -101,9 +102,11 @@ def test_uninstall_no_op_when_no_file(tmp_path, capsys): # settings.json PreToolUse hook # --------------------------------------------------------------------------- + def test_install_creates_settings_json(tmp_path): """claude_install also writes .claude/settings.json with PreToolUse hook.""" import json + claude_install(tmp_path) settings_path = tmp_path / ".claude" / "settings.json" assert settings_path.exists() @@ -115,6 +118,7 @@ def test_install_creates_settings_json(tmp_path): def test_install_settings_json_idempotent(tmp_path): """Running claude_install twice does not duplicate the PreToolUse hook.""" import json + claude_install(tmp_path) claude_install(tmp_path) settings_path = tmp_path / ".claude" / "settings.json" @@ -127,6 +131,7 @@ def test_install_settings_json_idempotent(tmp_path): def test_uninstall_removes_settings_hook(tmp_path): """claude_uninstall removes the PreToolUse hook from settings.json.""" import json + claude_install(tmp_path) claude_uninstall(tmp_path) settings_path = tmp_path / ".claude" / "settings.json" diff --git a/tests/test_cli_export.py b/tests/test_cli_export.py index 942dbcf26..8909b8ef0 100644 --- a/tests/test_cli_export.py +++ b/tests/test_cli_export.py @@ -3,6 +3,7 @@ Each test builds a minimal graph in a temp dir, runs the CLI command as a subprocess, and asserts the expected output file exists and is non-empty / valid. """ + from __future__ import annotations import json import os @@ -10,13 +11,14 @@ import sys from pathlib import Path -import pytest PYTHON = sys.executable FIXTURES = Path(__file__).parent / "fixtures" -def _run(args: list[str], cwd: Path, env: dict[str, str] | None = None) -> subprocess.CompletedProcess: +def _run( + args: list[str], cwd: Path, env: dict[str, str] | None = None +) -> subprocess.CompletedProcess: return subprocess.run( [PYTHON, "-m", "graphify"] + args, cwd=cwd, @@ -53,14 +55,13 @@ def _make_graph(tmp_path: Path) -> Path: "surprises": surprises, } (out / ".graphify_analysis.json").write_text(json.dumps(analysis)) - (out / ".graphify_labels.json").write_text( - json.dumps({str(k): v for k, v in labels.items()}) - ) + (out / ".graphify_labels.json").write_text(json.dumps({str(k): v for k, v in labels.items()})) return out # ── graphify export html ───────────────────────────────────────────────────── + def test_export_html_creates_file(tmp_path): _make_graph(tmp_path) r = _run(["export", "html"], tmp_path) @@ -83,8 +84,24 @@ def test_export_html_error_without_graph(tmp_path): assert r.returncode != 0 +def test_update_accepts_no_viz_and_removes_stale_html(tmp_path): + (tmp_path / "app.py").write_text("def alpha():\n return 1\n", encoding="utf-8") + out = tmp_path / "graphify-out" + out.mkdir() + stale_html = out / "graph.html" + stale_html.write_text("", encoding="utf-8") + + env = os.environ | {"GRAPHIFY_NO_TIPS": "1"} + r = _run(["update", ".", "--force", "--no-viz"], tmp_path, env=env) + + assert r.returncode == 0, r.stderr + assert not stale_html.exists() + assert "Skipped graph.html" not in r.stdout + + # ── graphify export obsidian ───────────────────────────────────────────────── + def test_export_obsidian_creates_vault(tmp_path): _make_graph(tmp_path) r = _run(["export", "obsidian"], tmp_path) @@ -106,6 +123,7 @@ def test_export_obsidian_custom_dir(tmp_path): # ── graphify export wiki ───────────────────────────────────────────────────── + def test_export_wiki_creates_articles(tmp_path): _make_graph(tmp_path) r = _run(["export", "wiki"], tmp_path) @@ -130,6 +148,7 @@ def test_export_wiki_accepts_edges_only_graph_json(tmp_path): # ── graphify export graphml ────────────────────────────────────────────────── + def test_export_graphml_creates_file(tmp_path): _make_graph(tmp_path) r = _run(["export", "graphml"], tmp_path) @@ -143,6 +162,7 @@ def test_export_graphml_creates_file(tmp_path): # ── graphify export neo4j (cypher) ─────────────────────────────────────────── + def test_export_neo4j_creates_cypher(tmp_path): _make_graph(tmp_path) r = _run(["export", "neo4j"], tmp_path) @@ -156,6 +176,7 @@ def test_export_neo4j_creates_cypher(tmp_path): # ── graphify query ─────────────────────────────────────────────────────────── + def test_query_returns_output(tmp_path): _make_graph(tmp_path) r = _run(["query", "test"], tmp_path) @@ -195,6 +216,7 @@ def test_query_uses_graphify_out_env(tmp_path): # ── graphify path ──────────────────────────────────────────────────────────── + def test_path_runs_without_error(tmp_path): _make_graph(tmp_path) r = _run(["path", "Transformer", "LayerNorm"], tmp_path) @@ -221,6 +243,7 @@ def test_path_uses_graphify_out_env(tmp_path): # ── graphify explain ───────────────────────────────────────────────────────── + def test_explain_runs_without_error(tmp_path): _make_graph(tmp_path) r = _run(["explain", "test"], tmp_path) @@ -246,6 +269,7 @@ def test_explain_uses_graphify_out_env(tmp_path): # ── graphify export unknown format ─────────────────────────────────────────── + def test_export_unknown_format_fails(tmp_path): r = _run(["export", "pdf"], tmp_path) assert r.returncode != 0 @@ -267,6 +291,7 @@ def test_update_no_cluster_writes_raw_graph(tmp_path): # Regression test for #934 - cluster-only crashes when graphify-out/ doesn't exist + def test_cluster_only_creates_output_dir_when_missing(tmp_path): """cluster-only must not crash with FileNotFoundError when graphify-out/ is absent (#934).""" # Build graph.json somewhere other than the default graphify-out/ location @@ -278,6 +303,7 @@ def test_cluster_only_creates_output_dir_when_missing(tmp_path): graph_json = out_dir / "graph.json" # Simulate user archiving the output dir before re-clustering import shutil + shutil.copy(graph_json, graph_src) shutil.rmtree(out_dir) @@ -290,6 +316,7 @@ def test_cluster_only_creates_output_dir_when_missing(tmp_path): # Regression test for #1027 - cluster-only must remap labels via node overlap + def test_cluster_only_remaps_labels_to_previous_cids(tmp_path): """cluster-only must invoke remap_communities_to_previous so the existing .graphify_labels.json keeps tracking the same conceptual communities after @@ -350,6 +377,7 @@ def test_cluster_only_remaps_labels_to_previous_cids(tmp_path): # silently bails or generates a degraded artifact whenever the sidecar is # missing, even though the data is right there. + def test_export_html_falls_back_to_node_community_attribute(tmp_path): """When .graphify_analysis.json is absent, export html should reconstruct communities from the per-node attribute in graph.json rather than bailing @@ -385,8 +413,7 @@ def test_export_html_fallback_recovers_multiple_communities(tmp_path): # And the count we'd reconstruct from graph.json's node attributes graph = json.loads((out / "graph.json").read_text(encoding="utf-8")) reconstructed_cids = { - n["community"] for n in graph.get("nodes", []) - if n.get("community") is not None + n["community"] for n in graph.get("nodes", []) if n.get("community") is not None } assert len(reconstructed_cids) == expected_count, ( f"reconstruction would lose communities: sidecar={expected_count} vs " diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 21fd2ca3a..305eb9345 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -1,5 +1,4 @@ import json -import sys import networkx as nx from pathlib import Path from graphify.build import build_from_json @@ -7,38 +6,45 @@ FIXTURES = Path(__file__).parent / "fixtures" + def make_graph(): return build_from_json(json.loads((FIXTURES / "extraction.json").read_text())) + def test_cluster_returns_dict(): G = make_graph() communities = cluster(G) assert isinstance(communities, dict) + def test_cluster_covers_all_nodes(): G = make_graph() communities = cluster(G) all_nodes = {n for nodes in communities.values() for n in nodes} assert all_nodes == set(G.nodes) + def test_cohesion_score_complete_graph(): G = nx.complete_graph(4) G = nx.relabel_nodes(G, {i: str(i) for i in G.nodes}) score = cohesion_score(G, list(G.nodes)) assert score == 1.0 + def test_cohesion_score_single_node(): G = nx.Graph() G.add_node("a") score = cohesion_score(G, ["a"]) assert score == 1.0 + def test_cohesion_score_disconnected(): G = nx.Graph() G.add_nodes_from(["a", "b", "c"]) score = cohesion_score(G, ["a", "b", "c"]) assert score == 0.0 + def test_cohesion_score_range(): G = make_graph() communities = cluster(G) @@ -46,6 +52,7 @@ def test_cohesion_score_range(): score = cohesion_score(G, nodes) assert 0.0 <= score <= 1.0 + def test_score_all_keys_match_communities(): G = make_graph() communities = cluster(G) @@ -98,3 +105,84 @@ def test_remap_communities_to_previous_assigns_deterministic_new_ids(): assert list(remapped.keys()) == [0, 1] assert remapped[0] == ["x", "y", "z"] assert remapped[1] == ["m"] + + +# --- MultiDiGraph safety tests (PR 4B) --- + + +def _make_multigraph_triangle(): + """MultiDiGraph with nodes {a, b, c}: 5 parallel edges a->b, 3 parallel edges b->c.""" + G = nx.MultiDiGraph() + G.add_nodes_from(["a", "b", "c"]) + for i in range(5): + G.add_edge("a", "b", key=f"ab-{i}", relation=f"rel-{i}") + for i in range(3): + G.add_edge("b", "c", key=f"bc-{i}", relation=f"rel-{i}") + return G + + +def test_cohesion_multigraph_stays_bounded(): + """Cohesion must be <= 1.0 even when parallel edges outnumber unique pairs.""" + G = _make_multigraph_triangle() + # 3 nodes, 8 total edge records, but only 2 unique pairs -> must not exceed 1.0 + score = cohesion_score(G, ["a", "b", "c"]) + assert score <= 1.0, f"cohesion {score} exceeds 1.0 on multigraph" + assert score >= 0.0 + + +def test_cohesion_multigraph_equals_simple_graph_cohesion(): + """Cohesion on a multigraph should equal cohesion on the equivalent simple graph.""" + # Build a MultiDiGraph: a-b, b-c, a-c each with 3 parallel edges + MG = nx.MultiDiGraph() + MG.add_nodes_from(["a", "b", "c"]) + for pair in [("a", "b"), ("b", "c"), ("a", "c")]: + for i in range(3): + MG.add_edge(pair[0], pair[1], key=f"{pair[0]}{pair[1]}-{i}") + + # Build equivalent simple graph: a-b, b-c, a-c (1 edge each) + SG = nx.Graph() + SG.add_nodes_from(["a", "b", "c"]) + SG.add_edge("a", "b") + SG.add_edge("b", "c") + SG.add_edge("a", "c") + + multi_score = cohesion_score(MG, ["a", "b", "c"]) + simple_score = cohesion_score(SG, ["a", "b", "c"]) + assert multi_score == simple_score, f"multi={multi_score} != simple={simple_score}" + + +def test_cluster_multigraph_produces_valid_communities(): + """cluster() on a MultiDiGraph with clear community structure should detect communities.""" + G = nx.MultiDiGraph() + # Two triangles connected by a weak bridge, with parallel edges and + # confidence data so projected weights are non-zero (avoids graspologic + # zero-weight panic in some versions). + for pair in [("a", "b"), ("b", "c"), ("a", "c")]: + for k in range(3): + G.add_edge(pair[0], pair[1], key=f"{pair[0]}{pair[1]}-{k}", confidence="EXTRACTED") + for pair in [("d", "e"), ("e", "f"), ("d", "f")]: + for k in range(3): + G.add_edge(pair[0], pair[1], key=f"{pair[0]}{pair[1]}-{k}", confidence="EXTRACTED") + G.add_edge("c", "d", key="bridge", confidence="AMBIGUOUS") + + communities = cluster(G) + assert isinstance(communities, dict) + assert len(communities) > 0 + all_nodes = {n for nodes in communities.values() for n in nodes} + assert all_nodes == set(G.nodes), "Not all nodes assigned to communities" + + +def test_cluster_multigraph_does_not_crash(): + """Smoke test: cluster() on a MultiDiGraph with parallel edges must not raise.""" + G = nx.MultiDiGraph() + nodes = ["a", "b", "c", "d", "e"] + G.add_nodes_from(nodes) + for i in range(len(nodes)): + for j in range(i + 1, min(i + 3, len(nodes))): + for k in range(4): + G.add_edge( + nodes[i], nodes[j], key=f"{nodes[i]}-{nodes[j]}-{k}", confidence="EXTRACTED" + ) + # Must not raise + communities = cluster(G) + assert isinstance(communities, dict) diff --git a/tests/test_confidence.py b/tests/test_confidence.py index 299548aca..4e6af31f1 100644 --- a/tests/test_confidence.py +++ b/tests/test_confidence.py @@ -1,9 +1,9 @@ """Tests for confidence_score on edges.""" + import json import tempfile from pathlib import Path -import networkx as nx from graphify.build import build_from_json from graphify.cluster import cluster, score_all @@ -24,12 +24,33 @@ def _make_extraction(**edge_overrides): {"id": "n_d", "label": "D", "file_type": "document", "source_file": "d.md"}, ], "edges": [ - {"source": "n_a", "target": "n_b", "relation": "calls", "confidence": "EXTRACTED", - "confidence_score": 1.0, "source_file": "a.py", "weight": 1.0}, - {"source": "n_b", "target": "n_c", "relation": "implements", "confidence": "INFERRED", - "confidence_score": 0.75, "source_file": "b.py", "weight": 0.8}, - {"source": "n_c", "target": "n_d", "relation": "references", "confidence": "AMBIGUOUS", - "confidence_score": 0.2, "source_file": "c.md", "weight": 0.5}, + { + "source": "n_a", + "target": "n_b", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "a.py", + "weight": 1.0, + }, + { + "source": "n_b", + "target": "n_c", + "relation": "implements", + "confidence": "INFERRED", + "confidence_score": 0.75, + "source_file": "b.py", + "weight": 0.8, + }, + { + "source": "n_c", + "target": "n_d", + "relation": "references", + "confidence": "AMBIGUOUS", + "confidence_score": 0.2, + "source_file": "c.md", + "weight": 0.5, + }, ], "input_tokens": 100, "output_tokens": 50, @@ -108,10 +129,22 @@ def test_to_json_defaults_missing_confidence_score(): ], "edges": [ # No confidence_score field on any of these - {"source": "n_x", "target": "n_y", "relation": "calls", - "confidence": "EXTRACTED", "source_file": "x.py", "weight": 1.0}, - {"source": "n_y", "target": "n_z", "relation": "depends_on", - "confidence": "INFERRED", "source_file": "y.py", "weight": 1.0}, + { + "source": "n_x", + "target": "n_y", + "relation": "calls", + "confidence": "EXTRACTED", + "source_file": "x.py", + "weight": 1.0, + }, + { + "source": "n_y", + "target": "n_z", + "relation": "depends_on", + "confidence": "INFERRED", + "source_file": "y.py", + "weight": 1.0, + }, ], "input_tokens": 0, "output_tokens": 0, @@ -148,7 +181,7 @@ def test_report_shows_avg_confidence_for_inferred(): report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, ".") assert "avg confidence" in report, "Report should show avg confidence for INFERRED edges" # The fixture has one INFERRED edge with score 0.75, so avg should be 0.75 - assert "0.75" in report, f"Expected avg confidence 0.75 in report" + assert "0.75" in report, "Expected avg confidence 0.75 in report" def test_report_inferred_tag_with_score(): @@ -160,9 +193,15 @@ def test_report_inferred_tag_with_score(): {"id": "n_q", "label": "Renderer", "file_type": "code", "source_file": "renderer.py"}, ], "edges": [ - {"source": "n_p", "target": "n_q", "relation": "feeds", - "confidence": "INFERRED", "confidence_score": 0.82, - "source_file": "parser.py", "weight": 1.0}, + { + "source": "n_p", + "target": "n_q", + "relation": "feeds", + "confidence": "INFERRED", + "confidence_score": 0.82, + "source_file": "parser.py", + "weight": 1.0, + }, ], "input_tokens": 0, "output_tokens": 0, diff --git a/tests/test_dedup.py b/tests/test_dedup.py index 293d2a8fe..b5f90b307 100644 --- a/tests/test_dedup.py +++ b/tests/test_dedup.py @@ -1,29 +1,34 @@ """Tests for graphify/dedup.py entity deduplication pipeline.""" + from __future__ import annotations -import pytest from graphify.dedup import deduplicate_entities, _entropy, _shingles # ── entropy gate ───────────────────────────────────────────────────────────── + def test_entropy_short_label_low(): assert _entropy("AI") < 2.5 + def test_entropy_normal_label_high(): assert _entropy("AuthenticationManager") >= 2.5 + def test_entropy_empty_string(): assert _entropy("") == 0.0 # ── shingles ───────────────────────────────────────────────────────────────── + def test_shingles_produces_trigrams(): s = _shingles("hello") assert "hel" in s assert "ell" in s assert "llo" in s + def test_shingles_short_string(): # strings shorter than 3 chars return single shingle of the string itself assert _shingles("ab") == {"ab"} @@ -31,8 +36,13 @@ def test_shingles_short_string(): # ── full pipeline ───────────────────────────────────────────────────────────── + def _make_nodes(*labels): - return [{"id": label.lower().replace(" ", "_"), "label": label, "source_file": "test.md"} for label in labels] + return [ + {"id": label.lower().replace(" ", "_"), "label": label, "source_file": "test.md"} + for label in labels + ] + def _make_edges(src, tgt, relation="relates_to"): return [{"source": src, "target": tgt, "relation": relation}] @@ -122,9 +132,11 @@ def test_dedup_llm_flag_accepted(): # ── build integration ───────────────────────────────────────────────────────── + def test_build_calls_dedup(): """build() should deduplicate near-identical nodes across extractions.""" from graphify.build import build + chunk1 = { "nodes": [{"id": "graphextractor", "label": "GraphExtractor", "source_file": "a.py"}], "edges": [], @@ -139,6 +151,7 @@ def test_build_calls_dedup(): # --- #878: fuzzy dedup false merges on short/variant labels --- + def test_dedup_does_not_merge_numeric_variants(tmp_path): """Chip SKU variants (ASR1603 vs ASR1605) must not be merged (#878).""" nodes = _make_nodes("ASR1603", "ASR1605") @@ -164,6 +177,7 @@ def test_dedup_still_merges_real_typos(): """Genuine same-length single-char typos should still merge (#878 non-regression).""" from graphify.dedup import _is_variant_pair, _short_label_blocked from rapidfuzz.distance import JaroWinkler + a, b = "graphextractor", "graphextractar" score = JaroWinkler.normalized_similarity(a, b) * 100 assert not _is_variant_pair(a, b), "not a variant pair" @@ -173,6 +187,7 @@ def test_dedup_still_merges_real_typos(): def test_variant_pair_helper(): """_is_variant_pair correctly identifies chip-model variant pairs (#878).""" from graphify.dedup import _is_variant_pair + assert _is_variant_pair("asr1603", "asr1605") assert _is_variant_pair("cortex a55", "cortex a55x") assert not _is_variant_pair("graphextractor", "graphextracter") diff --git a/tests/test_dedup_remap.py b/tests/test_dedup_remap.py new file mode 100644 index 000000000..467f6c0a8 --- /dev/null +++ b/tests/test_dedup_remap.py @@ -0,0 +1,542 @@ +"""Tests for PR 4A: dedup remap contract — parallel edge preservation, +self-loop counting, exact duplicate collapse, build integration, and the +remove_all_parallel_edges helper. + +Groups A–C and E will fail until the production dedup/build changes land +(diagnostics parameter, remap counters). Group D tests the helper +implemented in edge_identity.py and should pass immediately. +""" + +from __future__ import annotations + +import networkx as nx + +from graphify.dedup import deduplicate_entities +from graphify.build import build +from graphify.edge_identity import remove_all_parallel_edges + + +# ── helpers (mirrors test_dedup.py patterns) ───────────────────────────────── + + +def _make_nodes(*labels, source_file="test.md"): + return [ + {"id": label.lower().replace(" ", "_"), "label": label, "source_file": source_file} + for label in labels + ] + + +def _make_edge(src, tgt, relation="relates_to", source_file="test.py", **extra): + edge = {"source": src, "target": tgt, "relation": relation, "source_file": source_file} + edge.update(extra) + return edge + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Group A: Post-remap parallel edge preservation +# ═══════════════════════════════════════════════════════════════════════════════ + + +def test_remap_preserves_parallel_edges_different_relation(): + """Two edges A->C (calls) and A->C (imports) survive when B is merged into C. + + Setup: nodes A, B, C where B and C are exact duplicates (same normalized label). + Edges: A->B (calls), A->C (imports). After dedup, B merges into C (winner). + Expected: two edges A->C with different relations survive. + """ + # B and C share the normalized label "dataloader" so they are exact duplicates. + nodes = [ + {"id": "a", "label": "Caller", "source_file": "a.py"}, + {"id": "b", "label": "DataLoader", "source_file": "a.py"}, + {"id": "c", "label": "dataloader", "source_file": "a.py"}, + ] + edges = [ + _make_edge("a", "b", relation="calls", source_file="a.py"), + _make_edge("a", "c", relation="imports", source_file="a.py"), + ] + diagnostics: dict = {} + result_nodes, result_edges = deduplicate_entities( + nodes, + edges, + communities={}, + diagnostics=diagnostics, + ) + # B merged into winner; both edges now point to the winner + assert len(result_nodes) == 2 + # Two distinct relations -> both edges survive + relations = {e["relation"] for e in result_edges} + assert "calls" in relations + assert "imports" in relations + assert len(result_edges) == 2 + + +def test_remap_preserves_parallel_edges_incoming_and_outgoing(): + """Edges B->X and Y->B survive as C->X and Y->C when B merges into C.""" + nodes = [ + {"id": "x", "label": "NodeX", "source_file": "a.py"}, + {"id": "y", "label": "NodeY", "source_file": "a.py"}, + {"id": "b", "label": "DataLoader", "source_file": "a.py"}, + {"id": "c", "label": "dataloader", "source_file": "a.py"}, + ] + edges = [ + _make_edge("b", "x", relation="calls", source_file="a.py"), + _make_edge("y", "b", relation="imports", source_file="a.py"), + ] + diagnostics: dict = {} + result_nodes, result_edges = deduplicate_entities( + nodes, + edges, + communities={}, + diagnostics=diagnostics, + ) + assert len(result_nodes) == 3 # x, y, winner(b/c) + # Both edges survive: one outgoing, one incoming + assert len(result_edges) == 2 + # The loser ID should be remapped to the winner + winner_id = next(n["id"] for n in result_nodes if n["label"] in ("DataLoader", "dataloader")) + sources = {e["source"] for e in result_edges} + targets = {e["target"] for e in result_edges} + assert winner_id in sources or winner_id in targets + + +def test_remap_preserves_edges_with_different_source_location(): + """Two edges A->B with same relation but different source_location survive remap.""" + nodes = [ + {"id": "a", "label": "Caller", "source_file": "a.py"}, + {"id": "b", "label": "DataLoader", "source_file": "a.py"}, + {"id": "c", "label": "dataloader", "source_file": "a.py"}, + ] + edges = [ + _make_edge("a", "b", relation="calls", source_file="a.py", source_location="L10"), + _make_edge("a", "c", relation="calls", source_file="a.py", source_location="L20"), + ] + diagnostics: dict = {} + result_nodes, result_edges = deduplicate_entities( + nodes, + edges, + communities={}, + diagnostics=diagnostics, + ) + assert len(result_nodes) == 2 + # Same relation but different source_location -> both survive (not exact duplicates) + assert len(result_edges) == 2 + locations = {e.get("source_location") for e in result_edges} + assert locations == {"L10", "L20"} + + +def test_remap_preserves_key_field_through_dict_copy(): + """If edge dicts carry a pre-existing 'key' field, remap preserves it verbatim.""" + nodes = [ + {"id": "a", "label": "Caller", "source_file": "a.py"}, + {"id": "b", "label": "DataLoader", "source_file": "a.py"}, + {"id": "c", "label": "dataloader", "source_file": "a.py"}, + ] + edges = [ + _make_edge("a", "b", relation="calls", source_file="a.py", key="user-key-1"), + ] + diagnostics: dict = {} + result_nodes, result_edges = deduplicate_entities( + nodes, + edges, + communities={}, + diagnostics=diagnostics, + ) + assert len(result_edges) == 1 + assert result_edges[0].get("key") == "user-key-1" + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Group B: Self-loop counting + exact duplicate collapse +# ═══════════════════════════════════════════════════════════════════════════════ + + +def test_remap_counts_self_loop_drops(): + """Self-loop drops counted in diagnostics dict, broken down by relation and source_file. + + Setup: nodes A, B that are exact duplicates. Edge A->B (calls, from a.py). + After remap: B merges into A, edge becomes A->A = self-loop, dropped. + Assert diagnostics: remap_self_loop_drops=1, by_relation={'calls':1}, by_source={'a.py':1} + """ + nodes = [ + {"id": "a", "label": "DataLoader", "source_file": "a.py"}, + {"id": "b", "label": "dataloader", "source_file": "a.py"}, + ] + edges = [ + _make_edge("a", "b", relation="calls", source_file="a.py"), + ] + diagnostics: dict = {} + result_nodes, result_edges = deduplicate_entities( + nodes, + edges, + communities={}, + diagnostics=diagnostics, + ) + assert len(result_nodes) == 1 + assert result_edges == [] # self-loop dropped + assert diagnostics.get("remap_self_loop_drops") == 1 + assert diagnostics.get("remap_self_loop_drops_by_relation", {}).get("calls") == 1 + assert diagnostics.get("remap_self_loop_drops_by_source", {}).get("a.py") == 1 + + +def test_remap_preserves_preexisting_self_loop_on_remapped_node(): + """A real self-loop survives when its node is remapped to the canonical winner.""" + nodes = [ + {"id": "winner", "label": "DataLoader", "source_file": "a.py"}, + {"id": "loser_long", "label": "dataloader", "source_file": "a.py"}, + ] + edges = [ + _make_edge("loser_long", "loser_long", relation="calls", source_file="a.py"), + ] + diagnostics: dict = {} + + result_nodes, result_edges = deduplicate_entities( + nodes, + edges, + communities={}, + diagnostics=diagnostics, + ) + + assert len(result_nodes) == 1 + assert result_edges == [ + {"source": "winner", "target": "winner", "relation": "calls", "source_file": "a.py"} + ] + assert diagnostics.get("remap_self_loop_drops") == 0 + + +def test_remap_collapses_exact_duplicates_after_remap(): + """Two edges that become identical after remap collapse to one. + + Setup: nodes A, B, C where B merges into C. + Two edges: A->B (calls, from x.py, line 10) and A->C (calls, from x.py, line 10). + After remap both become A->C with identical attrs -> collapse to one. + Assert diagnostics: remap_exact_duplicate_collapses=1, by_relation={'calls':1} + """ + nodes = [ + {"id": "a", "label": "Caller", "source_file": "x.py"}, + {"id": "b", "label": "DataLoader", "source_file": "x.py"}, + {"id": "c", "label": "dataloader", "source_file": "x.py"}, + ] + edges = [ + _make_edge("a", "b", relation="calls", source_file="x.py", source_location="L10"), + _make_edge("a", "c", relation="calls", source_file="x.py", source_location="L10"), + ] + diagnostics: dict = {} + result_nodes, result_edges = deduplicate_entities( + nodes, + edges, + communities={}, + diagnostics=diagnostics, + ) + assert len(result_nodes) == 2 + # After remap, both edges are identical (A->winner, calls, x.py, L10) -> collapse to 1 + assert len(result_edges) == 1 + assert diagnostics.get("remap_exact_duplicate_collapses") == 1 + assert diagnostics.get("remap_exact_duplicate_collapses_by_relation", {}).get("calls") == 1 + + +def test_remap_does_not_collapse_non_exact_duplicates(): + """Two edges with same source/target after remap but different attrs both survive. + + Setup: nodes A, B, C where B merges into C. + Edges: A->B (calls, line 10), A->C (calls, line 20). + After remap: A->C (calls, line 10) and A->C (calls, line 20) — both survive. + """ + nodes = [ + {"id": "a", "label": "Caller", "source_file": "x.py"}, + {"id": "b", "label": "DataLoader", "source_file": "x.py"}, + {"id": "c", "label": "dataloader", "source_file": "x.py"}, + ] + edges = [ + _make_edge("a", "b", relation="calls", source_file="x.py", source_location="L10"), + _make_edge("a", "c", relation="calls", source_file="x.py", source_location="L20"), + ] + diagnostics: dict = {} + result_nodes, result_edges = deduplicate_entities( + nodes, + edges, + communities={}, + diagnostics=diagnostics, + ) + assert len(result_nodes) == 2 + assert len(result_edges) == 2 + locations = {e.get("source_location") for e in result_edges} + assert locations == {"L10", "L20"} + + +def test_remap_returns_diagnostics_when_dict_provided(): + """diagnostics dict is populated with all counter keys even when counts are zero.""" + nodes = [ + {"id": "a", "label": "Caller", "source_file": "a.py"}, + {"id": "b", "label": "Target", "source_file": "a.py"}, + ] + edges = [_make_edge("a", "b", relation="calls")] + diagnostics: dict = {} + deduplicate_entities(nodes, edges, communities={}, diagnostics=diagnostics) + # When no merges happen, diagnostics should still have the counter keys at 0 + assert "remap_self_loop_drops" in diagnostics + assert "remap_exact_duplicate_collapses" in diagnostics + assert diagnostics["remap_self_loop_drops"] == 0 + assert diagnostics["remap_exact_duplicate_collapses"] == 0 + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Group C: Build integration +# ═══════════════════════════════════════════════════════════════════════════════ + + +def test_build_with_dedup_and_multigraph_preserves_parallel_edges(): + """build(extractions, dedup=True, multigraph=True) preserves non-duplicate parallel edges. + + Create 2 extraction chunks with overlapping nodes but different edges. + Assert the built MultiDiGraph has the expected parallel edges. + """ + ext1 = { + "nodes": [ + {"id": "caller", "label": "Caller", "file_type": "code", "source_file": "a.py"}, + {"id": "dataloader", "label": "DataLoader", "file_type": "code", "source_file": "b.py"}, + ], + "edges": [ + { + "source": "caller", + "target": "dataloader", + "relation": "calls", + "confidence": "EXTRACTED", + "source_file": "a.py", + "source_location": "L10", + }, + ], + } + ext2 = { + "nodes": [ + {"id": "caller", "label": "Caller", "file_type": "code", "source_file": "a.py"}, + {"id": "dataloader", "label": "DataLoader", "file_type": "code", "source_file": "b.py"}, + ], + "edges": [ + { + "source": "caller", + "target": "dataloader", + "relation": "imports", + "confidence": "EXTRACTED", + "source_file": "a.py", + "source_location": "L2", + }, + ], + } + G = build([ext1, ext2], dedup=True, multigraph=True, directed=True) + assert isinstance(G, nx.MultiDiGraph) + # Two edges with different relations should both survive + assert G.number_of_edges("caller", "dataloader") == 2 + relations = {data["relation"] for data in G["caller"]["dataloader"].values()} + assert relations == {"calls", "imports"} + + +def test_build_with_dedup_and_multigraph_reports_diagnostics(): + """G.graph['graphify_multigraph_diagnostics'] contains remap_ prefixed counters after build.""" + ext1 = { + "nodes": [ + {"id": "a", "label": "Caller", "file_type": "code", "source_file": "a.py"}, + {"id": "b", "label": "DataLoader", "file_type": "code", "source_file": "a.py"}, + {"id": "c", "label": "dataloader", "file_type": "code", "source_file": "a.py"}, + ], + "edges": [ + { + "source": "a", + "target": "b", + "relation": "calls", + "confidence": "EXTRACTED", + "source_file": "a.py", + "source_location": "L10", + }, + ], + } + G = build([ext1], dedup=True, multigraph=True, directed=True) + diag = G.graph.get("graphify_multigraph_diagnostics", {}) + # Should contain remap_ prefixed counters from the dedup pass + remap_keys = [k for k in diag if k.startswith("remap_")] + assert len(remap_keys) > 0, f"Expected remap_ keys in diagnostics, got: {diag}" + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Group D: Safe remove-all-parallel helper +# ═══════════════════════════════════════════════════════════════════════════════ + + +def test_remove_all_parallel_edges_removes_all_keys(): + """On MultiDiGraph with 3 edges between u,v (different keys), removes all 3.""" + G = nx.MultiDiGraph() + G.add_edge("a", "b", key="k1", relation="calls") + G.add_edge("a", "b", key="k2", relation="imports") + G.add_edge("a", "b", key="k3", relation="references") + assert G.number_of_edges("a", "b") == 3 + + removed = remove_all_parallel_edges(G, "a", "b") + + assert removed == 3 + assert G.number_of_edges("a", "b") == 0 + + +def test_remove_all_parallel_edges_no_edges_noop(): + """No edges between u,v -> returns 0, no raise.""" + G = nx.MultiDiGraph() + G.add_node("a") + G.add_node("b") + + removed = remove_all_parallel_edges(G, "a", "b") + + assert removed == 0 + + +def test_remove_all_parallel_edges_simple_digraph(): + """On simple DiGraph, removes the single edge, returns 1.""" + G = nx.DiGraph() + G.add_edge("a", "b", relation="calls") + + removed = remove_all_parallel_edges(G, "a", "b") + + assert removed == 1 + assert not G.has_edge("a", "b") + + +def test_remove_all_parallel_edges_does_not_use_two_tuple_semantics(): + """Verify the helper works correctly even if NetworkX's remove_edges_from + would only remove one. + + Create MultiDiGraph with 3 keyed edges between (a,b). Call helper. + Assert all 3 removed. + """ + G = nx.MultiDiGraph() + G.add_edge("a", "b", key="k1", relation="calls") + G.add_edge("a", "b", key="k2", relation="imports") + G.add_edge("a", "b", key="k3", relation="references") + + # NetworkX's remove_edges_from with 2-tuple only removes first key: + # G.remove_edges_from([("a", "b")]) would leave 2 edges. + # Our helper must remove all 3. + removed = remove_all_parallel_edges(G, "a", "b") + + assert removed == 3 + assert G.number_of_edges() == 0 + assert G.has_node("a") # nodes preserved + assert G.has_node("b") + + +def test_remove_all_parallel_edges_missing_node(): + """If either node doesn't exist in the graph, returns 0 without raising.""" + G = nx.MultiDiGraph() + G.add_node("a") + + assert remove_all_parallel_edges(G, "a", "nonexistent") == 0 + assert remove_all_parallel_edges(G, "nonexistent", "a") == 0 + + +def test_remove_all_parallel_edges_simple_graph_no_edge(): + """On simple Graph with no edge between u,v, returns 0.""" + G = nx.Graph() + G.add_node("a") + G.add_node("b") + + assert remove_all_parallel_edges(G, "a", "b") == 0 + + +def test_remove_all_parallel_edges_multigraph_undirected(): + """On undirected MultiGraph, removes all parallel edges.""" + G = nx.MultiGraph() + G.add_edge("a", "b", key="k1", relation="calls") + G.add_edge("a", "b", key="k2", relation="imports") + + removed = remove_all_parallel_edges(G, "a", "b") + + assert removed == 2 + assert G.number_of_edges() == 0 + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Group E: Simple-graph regression +# ═══════════════════════════════════════════════════════════════════════════════ + + +def test_simple_graph_dedup_output_unchanged(): + """Default simple-graph build+dedup on a fixed fixture produces identical output. + + This is the go/no-go regression: if this test fails, PR 4A broke the default path. + """ + extraction = { + "nodes": [ + { + "id": "graphextractor", + "label": "GraphExtractor", + "file_type": "code", + "source_file": "a.py", + }, + { + "id": "graph_extractor", + "label": "graph_extractor", + "file_type": "code", + "source_file": "a.py", + }, + {"id": "dataloader", "label": "DataLoader", "file_type": "code", "source_file": "b.py"}, + { + "id": "networkanalyzer", + "label": "NetworkAnalyzer", + "file_type": "code", + "source_file": "c.py", + }, + ], + "edges": [ + { + "source": "graphextractor", + "target": "dataloader", + "relation": "calls", + "confidence": "EXTRACTED", + "source_file": "a.py", + "source_location": "L5", + }, + { + "source": "graph_extractor", + "target": "dataloader", + "relation": "imports", + "confidence": "EXTRACTED", + "source_file": "a.py", + "source_location": "L1", + }, + { + "source": "dataloader", + "target": "networkanalyzer", + "relation": "calls", + "confidence": "EXTRACTED", + "source_file": "b.py", + "source_location": "L10", + }, + ], + } + + # Default simple-graph build with dedup + G = build([extraction], dedup=True, directed=True) + assert "graphify_multigraph_diagnostics" not in G.graph + + # "GraphExtractor" and "graph_extractor" are near-duplicates — dedup merges them. + # _pick_winner prefers shorter ID, no chunk suffix -> "graphextractor" wins + # (both are same length=14; tiebreak is by sort, so the first is picked). + # Alternatively graph_extractor (15 chars) vs graphextractor (14 chars) -> graphextractor wins. + winner_candidates = {"graphextractor", "graph_extractor"} + surviving_nodes = set(G.nodes()) + + # After dedup: 3 nodes survive (one of the two graph-extractor variants + dataloader + networkanalyzer) + assert G.number_of_nodes() == 3, ( + f"Expected 3 nodes after dedup, got {G.number_of_nodes()}: {sorted(surviving_nodes)}" + ) + + # The winner is the one that survived + winner = winner_candidates & surviving_nodes + assert len(winner) == 1, f"Expected exactly one winner from {winner_candidates}, got {winner}" + + assert "dataloader" in surviving_nodes + assert "networkanalyzer" in surviving_nodes + + # Edges: after remap, both edges pointing to dataloader should survive + # (different relations: calls vs imports), plus the dataloader->networkanalyzer edge. + # The self-loop case doesn't apply here since edges go from graph_extractor -> dataloader. + assert G.number_of_edges() >= 2, f"Expected at least 2 edges, got {G.number_of_edges()}" + + # The dataloader->networkanalyzer edge must survive unchanged + assert G.has_edge("dataloader", "networkanalyzer") diff --git a/tests/test_detect.py b/tests/test_detect.py index 900851802..cdc0708f9 100644 --- a/tests/test_detect.py +++ b/tests/test_detect.py @@ -1,52 +1,74 @@ from pathlib import Path -from graphify.detect import classify_file, count_words, detect, detect_incremental, save_manifest, FileType, _looks_like_paper, _is_ignored, _load_graphifyignore, _is_sensitive +from graphify.detect import ( + classify_file, + count_words, + detect, + detect_incremental, + save_manifest, + FileType, + _is_ignored, + _load_graphifyignore, + _is_sensitive, +) FIXTURES = Path(__file__).parent / "fixtures" + def test_classify_python(): assert classify_file(Path("foo.py")) == FileType.CODE + def test_classify_typescript(): assert classify_file(Path("bar.ts")) == FileType.CODE + def test_classify_markdown(): assert classify_file(Path("README.md")) == FileType.DOCUMENT + def test_classify_pdf(): assert classify_file(Path("paper.pdf")) == FileType.PAPER + def test_classify_pdf_in_xcassets_skipped(): # PDFs inside Xcode asset catalogs are vector icons, not papers asset_pdf = Path("MyApp/Images.xcassets/icon.imageset/icon.pdf") assert classify_file(asset_pdf) is None + def test_classify_pdf_in_xcassets_root_skipped(): asset_pdf = Path("Pods/HXPHPicker/Assets.xcassets/photo.pdf") assert classify_file(asset_pdf) is None + def test_classify_unknown_returns_none(): assert classify_file(Path("archive.zip")) is None + def test_classify_image(): assert classify_file(Path("screenshot.png")) == FileType.IMAGE assert classify_file(Path("design.jpg")) == FileType.IMAGE assert classify_file(Path("diagram.webp")) == FileType.IMAGE + def test_count_words_sample_md(): words = count_words(FIXTURES / "sample.md") assert words > 5 + def test_detect_finds_fixtures(): result = detect(FIXTURES) assert result["total_files"] >= 2 assert "code" in result["files"] assert "document" in result["files"] + def test_detect_warns_small_corpus(): result = detect(FIXTURES) assert result["needs_graph"] is False assert result["warning"] is not None + def test_detect_skips_noise_dot_dirs(): """Noise dot dirs (.next, .nuxt, .graphify cache, …) are skipped (#873). Non-noise dot dirs (.github, .claude, …) are now allowed through.""" @@ -301,6 +323,7 @@ def test_detect_incremental_propagates_follow_symlinks(tmp_path, monkeypatch): def test_classify_video_extensions(): """Video and audio file extensions should classify as VIDEO.""" from graphify.detect import FileType + assert classify_file(Path("lecture.mp4")) == FileType.VIDEO assert classify_file(Path("podcast.mp3")) == FileType.VIDEO assert classify_file(Path("talk.mov")) == FileType.VIDEO @@ -399,7 +422,9 @@ def test_detect_skips_visual_tests_dir(tmp_path): def test_detect_skips_snapshots_dir(tmp_path): """__snapshots__/ and snapshots/ are jest/vitest artefacts — must be excluded.""" (tmp_path / "__snapshots__").mkdir() - (tmp_path / "__snapshots__" / "app.test.ts.snap").write_text("// Jest Snapshot\nexports[`test 1`] = `
`") + (tmp_path / "__snapshots__" / "app.test.ts.snap").write_text( + "// Jest Snapshot\nexports[`test 1`] = `
`" + ) (tmp_path / "app.ts").write_text("export function greet() { return 'hi'; }") result = detect(tmp_path) all_files = [f for files in result["files"].values() for f in files] @@ -422,6 +447,7 @@ def test_detect_skips_storybook_static_dir(tmp_path): # --- #873: dot dirs allowed, framework caches blocked --- + def test_detect_allows_github_dir(tmp_path): """Files inside .github/ (workflows etc.) are now indexed (#873).""" gh = tmp_path / ".github" / "workflows" @@ -430,7 +456,9 @@ def test_detect_allows_github_dir(tmp_path): (tmp_path / "main.py").write_text("def run(): pass") result = detect(tmp_path) all_files = [f for files in result["files"].values() for f in files] - assert any(".github" in f for f in all_files), "expected .github/workflows/ci.yml to be detected" + assert any(".github" in f for f in all_files), ( + "expected .github/workflows/ci.yml to be detected" + ) def test_detect_skips_next_cache(tmp_path): @@ -461,9 +489,9 @@ def test_detect_skips_graphify_own_cache(tmp_path): # --- #882: gitignore parent-exclusion rule for ! re-includes --- + def test_negation_cannot_rescue_file_under_excluded_dir(tmp_path): """A ! re-include cannot un-ignore a file whose parent dir is excluded (#882).""" - from graphify.detect import _is_ignored, _load_graphifyignore android = tmp_path / "android" / "app" / "src" android.mkdir(parents=True) victim = android / "Main.kt" @@ -478,7 +506,6 @@ def test_negation_cannot_rescue_file_under_excluded_dir(tmp_path): def test_negation_works_when_no_ancestor_excluded(tmp_path): """A ! re-include must still un-ignore a file when no ancestor is excluded (#882).""" - from graphify.detect import _is_ignored, _load_graphifyignore src = tmp_path / "src" src.mkdir() keep = src / "keep.py" @@ -492,7 +519,6 @@ def test_negation_works_when_no_ancestor_excluded(tmp_path): def test_negation_ancestor_itself_reincluded(tmp_path): """If the ancestor dir itself is re-included, its children should not be blocked (#882).""" - from graphify.detect import _is_ignored, _load_graphifyignore vendor = tmp_path / "vendor" / "lib" vendor.mkdir(parents=True) f = vendor / "utils.py" @@ -505,9 +531,11 @@ def test_negation_ancestor_itself_reincluded(tmp_path): # Regression tests for #1087 - anchored patterns must not match basename deep in tree + def test_anchored_dir_not_matched_at_depth(tmp_path): """/inbox/ must not match src/inbox/ — only inbox/ at the anchor root.""" from graphify.detect import _is_ignored, _load_graphifyignore + src_inbox = tmp_path / "src" / "inbox" src_inbox.mkdir(parents=True) f = src_inbox / "main.rs" @@ -525,35 +553,32 @@ def test_anchored_dir_not_matched_at_depth(tmp_path): def test_anchored_dir_matches_at_root(tmp_path): """/inbox/ must still match inbox/ at the anchor root (positive case).""" from graphify.detect import _is_ignored, _load_graphifyignore + inbox = tmp_path / "inbox" inbox.mkdir() f = inbox / "data.json" f.write_text("{}") (tmp_path / ".graphifyignore").write_text("/inbox/\n") patterns = _load_graphifyignore(tmp_path) - assert _is_ignored(f, tmp_path, patterns), ( - "inbox/data.json must be ignored by /inbox/" - ) - assert _is_ignored(inbox, tmp_path, patterns), ( - "inbox/ must be ignored by /inbox/" - ) + assert _is_ignored(f, tmp_path, patterns), "inbox/data.json must be ignored by /inbox/" + assert _is_ignored(inbox, tmp_path, patterns), "inbox/ must be ignored by /inbox/" def test_anchored_file_not_matched_at_depth(tmp_path): """/build must not match src/build.""" from graphify.detect import _is_ignored, _load_graphifyignore + src_build = tmp_path / "src" / "build" src_build.mkdir(parents=True) (tmp_path / ".graphifyignore").write_text("/build\n") patterns = _load_graphifyignore(tmp_path) - assert not _is_ignored(src_build, tmp_path, patterns), ( - "src/build must NOT be ignored by /build" - ) + assert not _is_ignored(src_build, tmp_path, patterns), "src/build must NOT be ignored by /build" def test_unanchored_dir_still_matches_at_depth(tmp_path): """inbox/ (no leading /) must still match src/inbox/ anywhere in the tree.""" from graphify.detect import _is_ignored, _load_graphifyignore + src_inbox = tmp_path / "src" / "inbox" src_inbox.mkdir(parents=True) f = src_inbox / "main.rs" @@ -568,6 +593,7 @@ def test_unanchored_dir_still_matches_at_depth(tmp_path): def test_anchored_multi_segment_pattern(tmp_path): """/src/inbox/ must match src/inbox/ but not x/src/inbox/.""" from graphify.detect import _is_ignored, _load_graphifyignore + (tmp_path / "src" / "inbox").mkdir(parents=True) (tmp_path / "x" / "src" / "inbox").mkdir(parents=True) target_ok = tmp_path / "src" / "inbox" / "a.py" @@ -588,34 +614,44 @@ def test_anchored_multi_segment_pattern(tmp_path): def test_sensitive_flags_api_token_txt(): assert _is_sensitive(Path("api_token.txt")) + def test_sensitive_flags_oauth_token_json(): assert _is_sensitive(Path("oauth_token.json")) + def test_sensitive_flags_underscore_secret(): assert _is_sensitive(Path("app_secret.yaml")) + def test_sensitive_does_not_flag_tokenizer_py(): assert not _is_sensitive(Path("tokenizer.py")) + def test_sensitive_does_not_flag_tokenize_py(): assert not _is_sensitive(Path("tokenize.py")) + def test_sensitive_flags_passwords_py(): # passwords.py is just as likely a secret store as passwords.txt — code ext is no excuse assert _is_sensitive(Path("passwords.py")) + def test_sensitive_flags_ssh_dir(): assert _is_sensitive(Path("/home/user/.ssh/id_rsa")) + def test_sensitive_flags_secrets_dir(): assert _is_sensitive(Path("config/secrets/db.json")) + def test_sensitive_flags_token_txt(): assert _is_sensitive(Path("token.txt")) + def test_sensitive_flags_credentials_json(): assert _is_sensitive(Path("credentials.json")) + def test_sensitive_does_not_flag_root_file_named_credentials(): # A root-level file called "credentials" (no parent dir named credentials) # must NOT be flagged by Stage 1; Stage 2 name-pattern check catches it instead. @@ -628,11 +664,13 @@ def test_sensitive_does_not_flag_root_file_named_credentials(): # Verify the whole function still returns True (via name pattern, not dir check). assert _is_sensitive(p) + def test_sensitive_secret_handler_txt(): # Both patterns now use (?![a-zA-Z]) so underscore after keyword is allowed. # "secret_handler.txt": "secret" followed by "_" (not alpha) → flagged. assert _is_sensitive(Path("secret_handler.txt")) + def test_sensitive_token_config_yaml(): # "token_config.yaml": "token" followed by "_" (not alpha) → flagged. assert _is_sensitive(Path("token_config.yaml")) @@ -640,6 +678,7 @@ def test_sensitive_token_config_yaml(): # ── Issue #933: failed-chunk files must not be frozen in manifest ───────────── + def test_save_manifest_skips_semantic_hash_for_files_without_cache(tmp_path): """Files in failed chunks have no semantic cache entry; save_manifest must leave their semantic_hash empty so detect_incremental re-queues them (#933).""" @@ -653,7 +692,12 @@ def test_save_manifest_skips_semantic_hash_for_files_without_cache(tmp_path): doc2.write_text("# B\n\ncontent b") # Simulate: doc1's chunk succeeded (has a cache entry), doc2's chunk failed (no entry). - save_cached(doc1, {"nodes": [{"id": "a", "source_file": str(doc1)}], "edges": [], "hyperedges": []}, root=tmp_path, kind="semantic") + save_cached( + doc1, + {"nodes": [{"id": "a", "source_file": str(doc1)}], "edges": [], "hyperedges": []}, + root=tmp_path, + kind="semantic", + ) # doc2: no cache entry written files = {"document": [str(doc1), str(doc2)]} @@ -674,7 +718,6 @@ def test_save_manifest_skips_semantic_hash_for_files_without_cache(tmp_path): assert str(doc2) not in manifest, "failed-chunk file must be absent from manifest" - def test_save_manifest_without_filter_unchanged_for_code(tmp_path): """Code files must be stamped in the manifest regardless of semantic cache.""" import json @@ -689,8 +732,11 @@ def test_save_manifest_without_filter_unchanged_for_code(tmp_path): manifest = json.loads(Path(manifest_path).read_text()) assert str(py) in manifest assert manifest[str(py)]["ast_hash"] != "" + + # Regression tests for #945 - .gitignore fallback when no .graphifyignore exists + def test_gitignore_fallback_when_no_graphifyignore(tmp_path): """When no .graphifyignore exists, .gitignore patterns are honored (#945).""" (tmp_path / ".git").mkdir() @@ -719,12 +765,13 @@ def test_graphifyignore_takes_precedence_over_gitignore(tmp_path): result = detect(tmp_path) code = result["files"]["code"] - assert any("main.py" in f for f in code) # gitignore NOT applied + assert any("main.py" in f for f in code) # gitignore NOT applied assert not any("other.py" in f for f in code) # graphifyignore IS applied # Regression tests for #947 - .worktrees/ skipped and --exclude flag + def test_detect_skips_worktrees_dir(tmp_path): """Files inside .worktrees/ are never indexed (#947).""" wt = tmp_path / ".worktrees" / "feature-branch" @@ -770,9 +817,11 @@ def test_detect_extra_excludes_pattern(tmp_path): # Shebang interpreter parsing # --------------------------------------------------------------------------- + def test_shebang_interpreter_plain(tmp_path): """Plain shebang returns the interpreter basename.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "plain" script.write_bytes(b"#!/usr/bin/python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -781,6 +830,7 @@ def test_shebang_interpreter_plain(tmp_path): def test_shebang_interpreter_env_single_arg(tmp_path): """`#!/usr/bin/env python3` returns the interpreter, not 'env'.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_single" script.write_bytes(b"#!/usr/bin/env python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -789,6 +839,7 @@ def test_shebang_interpreter_env_single_arg(tmp_path): def test_shebang_interpreter_env_dash_s(tmp_path): """`#!/usr/bin/env -S python3 -u` (-S split-args form) recovers the interpreter.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_dashs" script.write_bytes(b"#!/usr/bin/env -S python3 -u\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -797,6 +848,7 @@ def test_shebang_interpreter_env_dash_s(tmp_path): def test_shebang_interpreter_env_with_flags(tmp_path): """`#!/usr/bin/env -i bash` skips env flags and resolves to the interpreter.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_flags" script.write_bytes(b"#!/usr/bin/env -i bash\necho hi\n") assert _shebang_interpreter(script) == "bash" @@ -805,6 +857,7 @@ def test_shebang_interpreter_env_with_flags(tmp_path): def test_shebang_interpreter_env_with_assignment(tmp_path): """`#!/usr/bin/env DEBUG=1 python3` skips var=value assignments.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_assign" script.write_bytes(b"#!/usr/bin/env DEBUG=1 python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -813,6 +866,7 @@ def test_shebang_interpreter_env_with_assignment(tmp_path): def test_shebang_interpreter_no_shebang(tmp_path): """File without shebang returns None.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "no_shebang" script.write_bytes(b"print('x')\n") assert _shebang_interpreter(script) is None @@ -821,6 +875,7 @@ def test_shebang_interpreter_no_shebang(tmp_path): def test_shebang_interpreter_quoted_path(tmp_path): """Quoted interpreter path with spaces parses correctly via shlex.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "quoted" # Note: actual `#!` on disk wouldn't permit a quoted path on most kernels, # but shlex must not crash and should produce a reasonable answer @@ -839,6 +894,7 @@ def test_shebang_file_type_classifies_via_interpreter(tmp_path): def test_shebang_interpreter_unreadable_returns_none(tmp_path): """Unreadable / nonexistent files return None, never raise.""" from graphify.detect import _shebang_interpreter + missing = tmp_path / "does_not_exist" assert _shebang_interpreter(missing) is None @@ -846,6 +902,7 @@ def test_shebang_interpreter_unreadable_returns_none(tmp_path): def test_shebang_interpreter_env_unset_with_operand(tmp_path): """`env -u VAR python3` skips both -u and its required operand.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_unset" script.write_bytes(b"#!/usr/bin/env -u PYTHONPATH python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -855,6 +912,7 @@ def test_shebang_interpreter_env_unset_with_operand(tmp_path): def test_shebang_interpreter_env_chdir_with_operand(tmp_path): """`env -C /tmp python3` skips both -C and its workdir operand.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_chdir" script.write_bytes(b"#!/usr/bin/env -C /tmp python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -864,6 +922,7 @@ def test_shebang_interpreter_env_chdir_with_operand(tmp_path): def test_shebang_interpreter_env_path_with_operand(tmp_path): """`env -P /bin python3` skips both -P and its utilpath operand.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_path" script.write_bytes(b"#!/usr/bin/env -P /bin python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -873,6 +932,7 @@ def test_shebang_interpreter_env_path_with_operand(tmp_path): def test_shebang_interpreter_env_dash_s_after_flag(tmp_path): """`env -i -S "python3 -u"` handles -S after another env flag.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_flag_dash_s" script.write_bytes(b'#!/usr/bin/env -i -S "python3 -u"\nprint("x")\n') assert _shebang_interpreter(script) == "python3" @@ -882,6 +942,7 @@ def test_shebang_interpreter_env_dash_s_after_flag(tmp_path): def test_shebang_interpreter_env_clumped_u_operand(tmp_path): """Clumped `-uPYTHONPATH` form (no space between flag and operand) is one arg.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_clumped" script.write_bytes(b"#!/usr/bin/env -uPYTHONPATH python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -891,6 +952,7 @@ def test_shebang_interpreter_env_clumped_u_operand(tmp_path): def test_shebang_interpreter_env_missing_operand_returns_none(tmp_path): """`env -u` with no operand → not a valid command, return None.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_missing_op" script.write_bytes(b"#!/usr/bin/env -u\n") assert _shebang_interpreter(script) is None @@ -899,6 +961,7 @@ def test_shebang_interpreter_env_missing_operand_returns_none(tmp_path): def test_shebang_interpreter_env_gnu_split_string_equals(tmp_path): """GNU `--split-string='python3 -u'` (with `=` operand) → python3.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_split_eq" script.write_bytes(b"#!/usr/bin/env --split-string='python3 -u'\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -908,6 +971,7 @@ def test_shebang_interpreter_env_gnu_split_string_equals(tmp_path): def test_shebang_interpreter_env_gnu_split_string_separate(tmp_path): """GNU `--split-string "python3 -u"` (separate operand) → python3.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_split_sep" script.write_bytes(b'#!/usr/bin/env --split-string "python3 -u"\nprint("x")\n') assert _shebang_interpreter(script) == "python3" @@ -917,6 +981,7 @@ def test_shebang_interpreter_env_gnu_split_string_separate(tmp_path): def test_shebang_interpreter_env_gnu_argv0_operand(tmp_path): """GNU `-a alias python3` skips both -a and its argv0 operand.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_argv0" script.write_bytes(b"#!/usr/bin/env -a alias python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -926,6 +991,7 @@ def test_shebang_interpreter_env_gnu_argv0_operand(tmp_path): def test_shebang_interpreter_env_compact_dash_s(tmp_path): """Compact `-Spython3 -u` form (no space between -S and packed string).""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_compact_dash_s" script.write_bytes(b"#!/usr/bin/env -Spython3 -u\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -935,6 +1001,7 @@ def test_shebang_interpreter_env_compact_dash_s(tmp_path): def test_shebang_interpreter_env_compact_v_then_s(tmp_path): """Compact `-vSpython3` (-v plus compact -S).""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_compact_vs" script.write_bytes(b"#!/usr/bin/env -vSpython3 -u\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -944,6 +1011,7 @@ def test_shebang_interpreter_env_compact_v_then_s(tmp_path): def test_shebang_interpreter_env_long_unset_separate_operand(tmp_path): """GNU `--unset PYTHONPATH python3` (separate operand).""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_long_unset" script.write_bytes(b"#!/usr/bin/env --unset PYTHONPATH python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -953,6 +1021,7 @@ def test_shebang_interpreter_env_long_unset_separate_operand(tmp_path): def test_shebang_interpreter_env_long_unset_equals(tmp_path): """GNU `--unset=PYTHONPATH python3` (`=` operand form).""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_long_unset_eq" script.write_bytes(b"#!/usr/bin/env --unset=PYTHONPATH python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -962,6 +1031,7 @@ def test_shebang_interpreter_env_long_unset_equals(tmp_path): def test_shebang_interpreter_env_long_chdir_separate_operand(tmp_path): """GNU `--chdir /tmp python3` (separate operand).""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_long_chdir" script.write_bytes(b"#!/usr/bin/env --chdir /tmp python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -971,6 +1041,7 @@ def test_shebang_interpreter_env_long_chdir_separate_operand(tmp_path): def test_shebang_interpreter_env_long_chdir_equals(tmp_path): """GNU `--chdir=/tmp python3` (`=` operand form).""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_long_chdir_eq" script.write_bytes(b"#!/usr/bin/env --chdir=/tmp python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -980,6 +1051,7 @@ def test_shebang_interpreter_env_long_chdir_equals(tmp_path): def test_shebang_interpreter_env_signal_flags(tmp_path): """GNU signal-handling flags skip transparently.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_signal" script.write_bytes(b"#!/usr/bin/env --default-signal=TERM --ignore-signal=PIPE python3\n") assert _shebang_interpreter(script) == "python3" @@ -989,6 +1061,7 @@ def test_shebang_interpreter_env_signal_flags(tmp_path): def test_shebang_interpreter_env_unknown_option_returns_none(tmp_path): """Unknown hyphen-prefixed env option → return None rather than guessing.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_unknown" script.write_bytes(b"#!/usr/bin/env --no-such-flag python3\n") # Must refuse to guess: if we can't classify the option, we can't trust @@ -999,10 +1072,10 @@ def test_shebang_interpreter_env_unknown_option_returns_none(tmp_path): def test_shebang_interpreter_env_dash_s_assignment_before_interpreter(tmp_path): """`-S` payload may carry NAME=value assignments before the interpreter.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_s_assignment" script.write_bytes( - b"#!/usr/bin/env -S PYTHONPATH=/opt/custom:${PYTHONPATH} python3\n" - b"print('x')\n" + b"#!/usr/bin/env -S PYTHONPATH=/opt/custom:${PYTHONPATH} python3\nprint('x')\n" ) assert _shebang_interpreter(script) == "python3" assert classify_file(script) == FileType.CODE @@ -1011,6 +1084,7 @@ def test_shebang_interpreter_env_dash_s_assignment_before_interpreter(tmp_path): def test_shebang_interpreter_env_dash_s_flag_before_interpreter(tmp_path): """`-S` payload may carry env flags (e.g. -i) before the interpreter.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_s_flag" script.write_bytes(b"#!/usr/bin/env -S -i OLDUSER=${USER} python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -1020,6 +1094,7 @@ def test_shebang_interpreter_env_dash_s_flag_before_interpreter(tmp_path): def test_shebang_interpreter_env_long_split_assignment_before_interpreter(tmp_path): """`--split-string=` payload may carry assignments before the interpreter.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_long_split_assignment" script.write_bytes( b"#!/usr/bin/env --split-string='PYTHONPATH=/opt/custom:${PYTHONPATH} python3 -u'\n" @@ -1032,6 +1107,7 @@ def test_shebang_interpreter_env_long_split_assignment_before_interpreter(tmp_pa def test_shebang_interpreter_env_long_split_flag_before_interpreter(tmp_path): """`--split-string=` payload may carry env flags before the interpreter.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_long_split_flag" script.write_bytes(b"#!/usr/bin/env --split-string='-i python3 -u'\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -1043,6 +1119,7 @@ def test_shebang_interpreter_env_nested_split_string_rejected(tmp_path): on the recursive call bounds the recursion depth at one). Without this guard, a malicious or strange shebang could spin the parser indefinitely.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_nested_split" # Outer -S splits into ["-S", "python3", "-u"]; inner -S is treated as an # unknown option in the recursed pass, so we get None (refuse to guess). @@ -1053,6 +1130,7 @@ def test_shebang_interpreter_env_nested_split_string_rejected(tmp_path): def test_shebang_interpreter_env_vs_assignment_before_interpreter(tmp_path): """`-vS` packed payload also re-parses for leading assignments.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_vs_assignment" script.write_bytes(b"#!/usr/bin/env -vS DEBUG=1 python3 -u\nprint('x')\n") assert _shebang_interpreter(script) == "python3" diff --git a/tests/test_devin.py b/tests/test_devin.py index a3bca5d05..1f94bb58d 100644 --- a/tests/test_devin.py +++ b/tests/test_devin.py @@ -1,24 +1,28 @@ """Tests for graphify devin install / uninstall commands.""" + from pathlib import Path import sys from unittest.mock import patch -import pytest # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _devin_install_user(tmp_path): from graphify.__main__ import install + old_cwd = Path.cwd() try: import os + os.chdir(tmp_path) with patch("graphify.__main__.Path.home", return_value=tmp_path): install(platform="devin") finally: import os + os.chdir(old_cwd) @@ -38,6 +42,7 @@ def _rules_path(project_dir): # User-scope install (graphify install --platform devin / graphify devin install) # --------------------------------------------------------------------------- + def test_devin_install_user_creates_skill_file(tmp_path): """User-scope install copies skill to ~/.config/devin/skills/graphify/SKILL.md.""" _devin_install_user(tmp_path) @@ -71,9 +76,11 @@ def test_devin_install_user_does_not_write_rules(tmp_path): # Project-scope install (graphify devin install --project) # --------------------------------------------------------------------------- + def test_devin_install_project_creates_skill_file(tmp_path, monkeypatch): """Project-scope install copies skill to .devin/skills/graphify/SKILL.md.""" from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -88,6 +95,7 @@ def test_devin_install_project_creates_skill_file(tmp_path, monkeypatch): def test_devin_install_project_creates_rules_file(tmp_path, monkeypatch): """Project-scope install writes .windsurf/rules/graphify.md.""" from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -104,6 +112,7 @@ def test_devin_install_project_creates_rules_file(tmp_path, monkeypatch): def test_devin_rules_content_recommends_graphify_query(tmp_path): """The rules file installed by devin must use query-first policy.""" from graphify.__main__ import _devin_rules_install + _devin_rules_install(tmp_path) content = _rules_path(tmp_path).read_text() assert "graphify query" in content @@ -112,6 +121,7 @@ def test_devin_rules_content_recommends_graphify_query(tmp_path): def test_devin_rules_install_idempotent(tmp_path, capsys): """Installing rules twice does not change content and prints 'no change'.""" from graphify.__main__ import _devin_rules_install + _devin_rules_install(tmp_path) content_first = _rules_path(tmp_path).read_text() _devin_rules_install(tmp_path) @@ -123,6 +133,7 @@ def test_devin_rules_install_idempotent(tmp_path, capsys): def test_devin_install_project_hints_git_add(tmp_path, monkeypatch, capsys): """Project-scope install prints a git add hint covering .devin/ and .windsurf/.""" from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -138,6 +149,7 @@ def test_devin_install_project_hints_git_add(tmp_path, monkeypatch, capsys): # Uninstall — user scope # --------------------------------------------------------------------------- + def test_devin_uninstall_user_removes_skill_file(tmp_path): """User-scope uninstall removes the skill file.""" _devin_install_user(tmp_path) @@ -145,6 +157,7 @@ def test_devin_uninstall_user_removes_skill_file(tmp_path): assert skill.exists() from graphify.__main__ import _remove_skill_file + with patch("graphify.__main__.Path.home", return_value=tmp_path): _remove_skill_file("devin") assert not skill.exists() @@ -154,6 +167,7 @@ def test_devin_uninstall_user_noop_when_not_installed(tmp_path, capsys): """User-scope uninstall prints an appropriate message when nothing is installed.""" from graphify.__main__ import main import os + old_cwd = Path.cwd() try: os.chdir(tmp_path) @@ -170,9 +184,11 @@ def test_devin_uninstall_user_noop_when_not_installed(tmp_path, capsys): # Uninstall — project scope # --------------------------------------------------------------------------- + def test_devin_uninstall_project_removes_skill_file(tmp_path, monkeypatch): """Project-scope uninstall removes .devin/skills/graphify/SKILL.md.""" from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -188,6 +204,7 @@ def test_devin_uninstall_project_removes_skill_file(tmp_path, monkeypatch): def test_devin_uninstall_project_removes_rules_file(tmp_path, monkeypatch): """Project-scope uninstall removes .windsurf/rules/graphify.md.""" from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -203,6 +220,7 @@ def test_devin_uninstall_project_removes_rules_file(tmp_path, monkeypatch): def test_devin_uninstall_project_does_not_touch_user_scope(tmp_path, monkeypatch): """Project-scope uninstall must not remove the user-scope skill file.""" from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -222,6 +240,7 @@ def test_devin_uninstall_project_does_not_touch_user_scope(tmp_path, monkeypatch def test_devin_rules_uninstall_noop_when_not_installed(tmp_path): """_devin_rules_uninstall does nothing if the rules file was never written.""" from graphify.__main__ import _devin_rules_uninstall + _devin_rules_uninstall(tmp_path) # should not raise @@ -229,9 +248,11 @@ def test_devin_rules_uninstall_noop_when_not_installed(tmp_path): # Skill file content # --------------------------------------------------------------------------- + def test_devin_skill_file_exists_in_package(): """skill-devin.md must be present in the installed package.""" import graphify + skill = Path(graphify.__file__).parent / "skill-devin.md" assert skill.exists(), "skill-devin.md missing from package" @@ -244,6 +265,7 @@ def test_devin_skill_file_uses_python_c_syntax(): ``python -c "..."`` so they work in pipx / venv environments. """ import graphify + skill = (Path(graphify.__file__).parent / "skill-devin.md").read_text() assert '.graphify_python) -c "' in skill, ( "skill-devin.md must use the interpreter-detection pattern " @@ -255,6 +277,7 @@ def test_devin_skill_file_uses_python_c_syntax(): def test_devin_skill_file_frontmatter_has_triggers(): """Devin skill frontmatter must list triggers for model-invocable activation.""" import graphify + skill = (Path(graphify.__file__).parent / "skill-devin.md").read_text() assert "triggers:" in skill assert "model" in skill @@ -264,9 +287,11 @@ def test_devin_skill_file_frontmatter_has_triggers(): # Platform config sanity # --------------------------------------------------------------------------- + def test_devin_in_platform_config(): """devin must be registered in _PLATFORM_CONFIG.""" from graphify.__main__ import _PLATFORM_CONFIG + assert "devin" in _PLATFORM_CONFIG assert _PLATFORM_CONFIG["devin"]["skill_file"] == "skill-devin.md" assert _PLATFORM_CONFIG["devin"]["claude_md"] is False @@ -275,6 +300,7 @@ def test_devin_in_platform_config(): def test_devin_platform_skill_destination_user_scope(tmp_path): """User-scope destination must be ~/.config/devin/skills/graphify/SKILL.md.""" from graphify.__main__ import _platform_skill_destination + with patch("graphify.__main__.Path.home", return_value=tmp_path): dst = _platform_skill_destination("devin", project=False) assert dst == tmp_path / ".config" / "devin" / "skills" / "graphify" / "SKILL.md" @@ -283,6 +309,7 @@ def test_devin_platform_skill_destination_user_scope(tmp_path): def test_devin_in_main_help_text(capsys, monkeypatch): """`graphify --help` must list devin in the platform list and in the per-platform section.""" from graphify.__main__ import main + monkeypatch.setattr(sys, "argv", ["graphify", "--help"]) main() captured = capsys.readouterr().out @@ -305,5 +332,6 @@ def test_devin_in_main_help_text(capsys, monkeypatch): def test_devin_platform_skill_destination_project_scope(tmp_path): """Project-scope destination must be /.devin/skills/graphify/SKILL.md.""" from graphify.__main__ import _platform_skill_destination + dst = _platform_skill_destination("devin", project=True, project_dir=tmp_path) assert dst == tmp_path / ".devin" / "skills" / "graphify" / "SKILL.md" diff --git a/tests/test_dotnet.py b/tests/test_dotnet.py index 17a146073..4bc6faf30 100644 --- a/tests/test_dotnet.py +++ b/tests/test_dotnet.py @@ -1,7 +1,7 @@ """Tests for .NET project file extraction (.sln, .csproj, .razor).""" + from pathlib import Path import tempfile -import pytest from graphify.extract import extract_sln, extract_csproj, extract_razor FIXTURES = Path(__file__).parent / "fixtures" @@ -17,6 +17,7 @@ def _relations(r): # ── .sln ───────────────────────────────────────────────────────────────────── + def test_sln_extracts_projects(): r = extract_sln(FIXTURES / "sample.sln") assert "error" not in r @@ -39,13 +40,14 @@ def test_sln_project_dependency(): # ── .csproj ────────────────────────────────────────────────────────────────── + def test_csproj_packages(): r = extract_csproj(FIXTURES / "sample.csproj") assert "error" not in r labels = _labels(r) - assert any("MediatR" in l for l in labels) - assert any("FluentValidation" in l for l in labels) - assert any("Swashbuckle" in l for l in labels) + assert any("MediatR" in label for label in labels) + assert any("FluentValidation" in label for label in labels) + assert any("Swashbuckle" in label for label in labels) def test_csproj_project_references(): @@ -74,6 +76,7 @@ def test_csproj_invalid_xml(): # ── .razor ─────────────────────────────────────────────────────────────────── + def test_razor_using_and_inject(): r = extract_razor(FIXTURES / "sample.razor") assert "error" not in r @@ -91,7 +94,7 @@ def test_razor_components(): def test_razor_page_route(): r = extract_razor(FIXTURES / "sample.razor") - assert any("/counter" in l for l in _labels(r)) + assert any("/counter" in label for label in _labels(r)) def test_razor_inherits(): @@ -113,13 +116,16 @@ def test_razor_missing_file(): # ── dispatch & detect integration ──────────────────────────────────────────── + def test_dispatch_table(): from graphify.extract import _get_extractor + for ext in (".sln", ".csproj", ".fsproj", ".vbproj", ".razor", ".cshtml"): assert _get_extractor(Path(f"foo{ext}")) is not None, f"{ext} not in dispatch" def test_code_extensions(): from graphify.detect import CODE_EXTENSIONS + for ext in (".sln", ".csproj", ".fsproj", ".vbproj", ".razor", ".cshtml"): assert ext in CODE_EXTENSIONS, f"{ext} missing" diff --git a/tests/test_edge_identity.py b/tests/test_edge_identity.py new file mode 100644 index 000000000..f80efb345 --- /dev/null +++ b/tests/test_edge_identity.py @@ -0,0 +1,85 @@ +"""Tests for graphify.edge_identity — schema constants and stable key helpers.""" + +from __future__ import annotations + +from graphify.edge_identity import SCHEMA_KEY_FIELD, make_stable_key, strip_schema_key + + +def test_schema_key_field_constant(): + assert SCHEMA_KEY_FIELD == "key" + + +def test_make_stable_key_deterministic(): + k1 = make_stable_key("calls", "src/a.py", "L10") + k2 = make_stable_key("calls", "src/a.py", "L10") + assert k1 == k2 + assert isinstance(k1, str) + assert k1 # non-empty + + +def test_make_stable_key_all_none(): + k = make_stable_key(None, None, None) + assert isinstance(k, str) + assert k # non-empty — never crashes or returns empty string + + +def test_make_stable_key_differs_by_source_location(): + k1 = make_stable_key("calls", "src/a.py", "L10") + k2 = make_stable_key("calls", "src/a.py", "L20") + assert k1 != k2 + + +def test_make_stable_key_identical_fields_match(): + k1 = make_stable_key("imports", "graphify/build.py", "L42") + k2 = make_stable_key("imports", "graphify/build.py", "L42") + assert k1 == k2 + + +def test_strip_schema_key_removes_key_field(): + attrs = {"key": "calls:a.py:L1", "relation": "calls", "confidence": "EXTRACTED"} + key_val, cleaned = strip_schema_key(attrs) + assert key_val == "calls:a.py:L1" + assert "key" not in cleaned + assert cleaned["relation"] == "calls" + assert cleaned["confidence"] == "EXTRACTED" + + +def test_strip_schema_key_no_key_present(): + attrs = {"relation": "imports", "confidence_score": 1.0} + key_val, cleaned = strip_schema_key(attrs) + assert key_val is None + assert cleaned == attrs + assert "key" not in cleaned + + +def test_strip_schema_key_does_not_mutate_input(): + attrs = {"key": "k1", "relation": "calls"} + original = dict(attrs) + strip_schema_key(attrs) + assert attrs == original + + +# --------------------------------------------------------------------------- +# Blocker 1: delimiter-collision safety +# --------------------------------------------------------------------------- + + +def test_make_stable_key_no_delimiter_collision(): + # "a:b","c","d" must not hash the same as "a","b:c","d" + k1 = make_stable_key("a:b", "c", "d") + k2 = make_stable_key("a", "b:c", "d") + assert k1 != k2 + + +def test_make_stable_key_format_is_versioned(): + k = make_stable_key("calls", "a.py", "L1") + assert k.startswith("edge:v1:") + + +def test_make_stable_key_none_differs_from_empty_and_unknown(): + # make_stable_key(None, None, None) must not collide with + # make_stable_key("unknown", "", "") — None must serialize as JSON null, + # not be normalised to "unknown"/"" before hashing. + k_none = make_stable_key(None, None, None) + k_defaults = make_stable_key("unknown", "", "") + assert k_none != k_defaults diff --git a/tests/test_explain_cli.py b/tests/test_explain_cli.py index 1d00955f0..e84390fb7 100644 --- a/tests/test_explain_cli.py +++ b/tests/test_explain_cli.py @@ -1,4 +1,5 @@ """Regression tests for `graphify explain` arrow direction (#853).""" + from __future__ import annotations import json import graphify.__main__ as mainmod @@ -6,24 +7,54 @@ def _write_graph(tmp_path): graph_data = { - "directed": False, "multigraph": False, "graph": {}, + "directed": False, + "multigraph": False, + "graph": {}, "nodes": [ - {"id": "validate", "label": "validateSanitySession()", - "source_file": "server/sanity-validate-session.ts", "community": 0}, - {"id": "create_patch", "label": "createPatchHandler()", - "source_file": "server/create-patch-handler.ts", "community": 0}, - {"id": "create_edit", "label": "createEditHandler()", - "source_file": "server/create-edit-handler.ts", "community": 0}, - {"id": "stable_stringify", "label": "stableStringify()", - "source_file": "shared/stringify.ts", "community": 0}, + { + "id": "validate", + "label": "validateSanitySession()", + "source_file": "server/sanity-validate-session.ts", + "community": 0, + }, + { + "id": "create_patch", + "label": "createPatchHandler()", + "source_file": "server/create-patch-handler.ts", + "community": 0, + }, + { + "id": "create_edit", + "label": "createEditHandler()", + "source_file": "server/create-edit-handler.ts", + "community": 0, + }, + { + "id": "stable_stringify", + "label": "stableStringify()", + "source_file": "shared/stringify.ts", + "community": 0, + }, ], "links": [ - {"source": "create_patch", "target": "validate", - "relation": "calls", "confidence": "EXTRACTED"}, - {"source": "create_edit", "target": "validate", - "relation": "calls", "confidence": "EXTRACTED"}, - {"source": "validate", "target": "stable_stringify", - "relation": "calls", "confidence": "EXTRACTED"}, + { + "source": "create_patch", + "target": "validate", + "relation": "calls", + "confidence": "EXTRACTED", + }, + { + "source": "create_edit", + "target": "validate", + "relation": "calls", + "confidence": "EXTRACTED", + }, + { + "source": "validate", + "target": "stable_stringify", + "relation": "calls", + "confidence": "EXTRACTED", + }, ], } p = tmp_path / "graph.json" @@ -33,8 +64,9 @@ def _write_graph(tmp_path): def _run(monkeypatch, graph_path, label, capsys): monkeypatch.setattr(mainmod, "_check_skill_version", lambda _: None) - monkeypatch.setattr(mainmod.sys, "argv", - ["graphify", "explain", label, "--graph", str(graph_path)]) + monkeypatch.setattr( + mainmod.sys, "argv", ["graphify", "explain", label, "--graph", str(graph_path)] + ) mainmod.main() return capsys.readouterr().out @@ -54,3 +86,105 @@ def test_caller_shows_callee_as_outbound(monkeypatch, tmp_path, capsys): out = _run(monkeypatch, p, "createPatchHandler", capsys) assert "--> validateSanitySession() [calls]" in out assert "<-- " not in out + + +def _write_multigraph(tmp_path, relations): + """Node 'a' with `relations` parallel edges to neighbor 'b' (MultiDiGraph).""" + links = [ + {"source": "a", "target": "b", "relation": rel, "key": idx} + for idx, rel in enumerate(relations) + ] + graph_data = { + "directed": True, + "multigraph": True, + "graph": {}, + "nodes": [ + {"id": "a", "label": "alpha()", "source_file": "a.py", "community": 0}, + {"id": "b", "label": "beta()", "source_file": "b.py", "community": 0}, + ], + "links": links, + } + p = tmp_path / "graph.json" + p.write_text(json.dumps(graph_data)) + return p + + +def test_explain_multigraph_neighbor_bundles_relations(monkeypatch, tmp_path, capsys): + """PR5 gate: a neighbor reached by 4 parallel edges shows the bundle, not one.""" + p = _write_multigraph(tmp_path, ["calls", "imports", "contains", "reads"]) + out = _run(monkeypatch, p, "alpha()", capsys) + # 4 unique relations exceeds the default cap (3), so a capped bundle renders + # the bundle for that neighbor rather than a single first-edge relation. + assert "--> beta() [calls, contains, imports (+1 more, 4 total)]" in out + # First-edge-only regression guard: a lone "[calls] [...]" block must NOT appear. + assert "--> beta() [calls] [" not in out + + +def test_explain_multigraph_capped_summary(monkeypatch, tmp_path, capsys): + """A neighbor pair with >3 unique relations renders the capped (+K more, N total) form.""" + p = _write_multigraph(tmp_path, ["gamma", "alpha", "epsilon", "beta", "delta"]) + out = _run(monkeypatch, p, "alpha()", capsys) + # sorted unique: alpha, beta, delta, epsilon, gamma -> first 3 + capped suffix. + assert "--> beta() [alpha, beta, delta (+2 more, 5 total)]" in out + + +def test_explain_simple_graph_output_regression(monkeypatch, tmp_path, capsys): + """Simple DiGraph explain output is unchanged: '[rel] [conf]' per neighbor.""" + p = _write_graph(tmp_path) + out = _run(monkeypatch, p, "validateSanitySession", capsys) + # Byte-stable bracketed form, matching test_callee_shows_callers_as_inbound. + assert "<-- createPatchHandler() [calls]" in out + assert "<-- createEditHandler() [calls]" in out + assert "--> stableStringify() [calls]" in out + + +def _write_bidirectional_multigraph(tmp_path): + """A<->B with different relations each way: A->B 'calls', B->A 'imports'.""" + graph_data = { + "directed": True, + "multigraph": True, + "graph": {}, + "nodes": [ + {"id": "a", "label": "alpha()", "source_file": "a.py", "community": 0}, + {"id": "b", "label": "beta()", "source_file": "b.py", "community": 0}, + ], + "links": [ + { + "source": "a", + "target": "b", + "relation": "calls", + "confidence": "EXTRACTED", + "key": 0, + }, + { + "source": "b", + "target": "a", + "relation": "imports", + "confidence": "EXTRACTED", + "key": 0, + }, + ], + } + p = tmp_path / "graph.json" + p.write_text(json.dumps(graph_data)) + return p + + +def test_explain_directional_isolation(monkeypatch, tmp_path, capsys): + """Out and in connections to the same neighbor stay isolated by direction. + + Regression for the directed_only fix: relationship_envelope merges both + directions by default, which would wrongly show 'calls, imports' on both + the out (-->) and in (<--) arrows. directed_only=True isolates each + connection's own stored direction. + """ + p = _write_bidirectional_multigraph(tmp_path) + out = _run(monkeypatch, p, "alpha()", capsys) + # Outgoing A->B shows ONLY 'calls'; incoming B->A shows ONLY 'imports'. + assert "--> beta() [calls] [EXTRACTED]" in out + assert "<-- beta() [imports] [EXTRACTED]" in out + # Neither arrow may merge the opposite direction's relation. + assert "--> beta() [calls, imports" not in out + assert "<-- beta() [calls, imports" not in out + assert "--> beta() [imports" not in out + assert "<-- beta() [calls]" not in out diff --git a/tests/test_export.py b/tests/test_export.py index 65964d24e..7e1269c2e 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -1,15 +1,22 @@ import json import tempfile from pathlib import Path + +import networkx as nx +from networkx.readwrite import json_graph + from graphify.build import build_from_json from graphify.cluster import cluster from graphify.export import to_json, to_cypher, to_graphml, to_html, to_canvas +from graphify.graph_loader import GRAPHIFY_PROFILE_KEY, load_graph FIXTURES = Path(__file__).parent / "fixtures" + def make_graph(): return build_from_json(json.loads((FIXTURES / "extraction.json").read_text())) + def test_to_json_creates_file(): G = make_graph() communities = cluster(G) @@ -18,6 +25,7 @@ def test_to_json_creates_file(): to_json(G, communities, str(out)) assert out.exists() + def test_to_json_valid_json(): G = make_graph() communities = cluster(G) @@ -28,6 +36,7 @@ def test_to_json_valid_json(): assert "nodes" in data assert "links" in data + def test_to_json_nodes_have_community(): G = make_graph() communities = cluster(G) @@ -38,6 +47,7 @@ def test_to_json_nodes_have_community(): for node in data["nodes"]: assert "community" in node + def test_to_cypher_creates_file(): G = make_graph() with tempfile.TemporaryDirectory() as tmp: @@ -45,6 +55,7 @@ def test_to_cypher_creates_file(): to_cypher(G, str(out)) assert out.exists() + def test_to_cypher_contains_merge_statements(): G = make_graph() with tempfile.TemporaryDirectory() as tmp: @@ -53,6 +64,7 @@ def test_to_cypher_contains_merge_statements(): content = out.read_text() assert "MERGE" in content + def test_to_graphml_creates_file(): G = make_graph() communities = cluster(G) @@ -61,6 +73,7 @@ def test_to_graphml_creates_file(): to_graphml(G, communities, str(out)) assert out.exists() + def test_to_graphml_valid_xml(): G = make_graph() communities = cluster(G) @@ -71,6 +84,7 @@ def test_to_graphml_valid_xml(): assert " load_graph reconstructs the same graph_type for every type, + proving the profile survives a save/load cycle.""" + cases = [ + (build_from_json(_build_extraction()), "simple", nx.Graph), + (build_from_json(_build_extraction(), directed=True), "digraph", nx.DiGraph), + (build_from_json(_build_extraction(), multigraph=True), "multidigraph", nx.MultiDiGraph), + ] + for G, expected_type, expected_cls in cases: + communities = {0: list(G.nodes)} + with tempfile.TemporaryDirectory() as tmp: + out = Path(tmp) / "graph.json" + to_json(G, communities, str(out), force=True) + data = json.loads(out.read_text()) + reloaded = load_graph(data, require_capabilities=False) + assert isinstance(reloaded, expected_cls) + assert reloaded.graph[GRAPHIFY_PROFILE_KEY]["graph_type"] == expected_type + # node_link_graph (the lower-level loader) also sees G.graph metadata. + nlg = json_graph.node_link_graph(data, edges="links") + assert nlg.graph[GRAPHIFY_PROFILE_KEY]["graph_type"] == expected_type + + +def test_to_json_simple_graph_regression(): + """Simple-graph output is unchanged except for the added graphify_profile. + + The "graph" metadata object gains exactly one key (graphify_profile); it was + empty ({}) before. Stripping that key leaves the pre-PR7 empty object, and + every other structural key (nodes/links/directed/multigraph/hyperedges) is + unaffected. + """ + G = build_from_json(_build_extraction()) + communities = cluster(G) + with tempfile.TemporaryDirectory() as tmp: + out = Path(tmp) / "graph.json" + to_json(G, communities, str(out)) + data = json.loads(out.read_text()) + + # The only added graph-metadata content is the profile. + assert data["graph"] == {GRAPHIFY_PROFILE_KEY: {"graph_type": "simple"}} + # Removing the profile yields the pre-change empty "graph" object — nothing + # else leaked into the graph-level metadata. + data["graph"].pop(GRAPHIFY_PROFILE_KEY) + assert data["graph"] == {} + # Core structural keys remain present and well-formed. + assert isinstance(data["nodes"], list) and data["nodes"] + assert isinstance(data["links"], list) + assert data["directed"] is False + assert data["multigraph"] is False + for node in data["nodes"]: + assert "id" in node and "community" in node + + +# ── RISK 4: empty-merge floor in to_json (Guard 1) ─────────────────────────── +# +# to_json must refuse to overwrite a populated on-disk graph.json (>0 nodes) +# with an EMPTY (0-node) graph — a 0-node write over a populated graph is a +# failed/aborted extraction, never a real result. This floor engages +# REGARDLESS of force=True (force is the bug enabler here), and only when the +# *new* graph has 0 nodes AND the existing file is populated. It must NOT block +# a fresh empty write (no existing file), a non-zero dedup shrink, or a +# 0-over-0 write (nothing populated to protect). + + +def test_to_json_floor_blocks_zero_over_populated_even_with_force(tmp_path): + """Existing populated graph.json + a 0-node graph with force=True must be + refused (return False) and leave the on-disk graph untouched. This is the + RED-before-fix case: without the floor, force=True wipes 4 nodes to 0.""" + out = tmp_path / "graph.json" + + # Seed a populated graph.json (4 nodes) via the real write path. + populated = build_from_json(_build_extraction()) + assert populated.number_of_nodes() == 4 + assert to_json(populated, cluster(populated), str(out), force=True) is True + assert len(json.loads(out.read_text())["nodes"]) == 4 + + # Attempt to overwrite with a 0-node graph, force=True. + empty = nx.Graph() + assert empty.number_of_nodes() == 0 + assert to_json(empty, {}, str(out), force=True) is False + + # The previous populated graph is preserved on disk. + assert len(json.loads(out.read_text())["nodes"]) == 4 + + +def test_to_json_floor_blocks_zero_over_populated_without_force(tmp_path, capsys): + """Guard 1 (not the pre-existing shrink guard) fires for force=False + 0-node + over populated. Pre-fix the shrink guard fired and emitted a WARNING; Guard 1 + emits a distinct ERROR message. Asserting the exact Guard-1 text makes this + test red-before-fix / green-after-fix, eliminating the vacuousness identified + by the bug-hunter.""" + out = tmp_path / "graph.json" + + populated = build_from_json(_build_extraction()) + assert populated.number_of_nodes() == 4 + assert to_json(populated, cluster(populated), str(out), force=True) is True + assert len(json.loads(out.read_text())["nodes"]) == 4 + + empty = nx.Graph() + result = to_json(empty, {}, str(out), force=False) + + # Guard 1 must have fired: return False and preserve the on-disk graph. + assert result is False + assert len(json.loads(out.read_text())["nodes"]) == 4 + + # The exact Guard-1 ERROR message must appear on stderr. Pre-fix the shrink + # guard fires instead and emits a WARNING with different text, making the + # assertion below fail on unfixed code. + captured = capsys.readouterr() + assert ( + "[graphify] ERROR: refusing to overwrite a populated graph.json " + "(4 nodes) with an EMPTY (0-node) graph - this is a " + "failed/aborted extraction, not a real result. The previous " + "graph is preserved." + ) in captured.err + + +def test_to_json_allows_fresh_empty_no_existing_file(tmp_path): + """A7: no existing file + 0-node graph + force=True is allowed — the floor + must NOT engage when existing_path.exists() is False. Writes a valid + 0-node graph.json.""" + out = tmp_path / "graph.json" + assert not out.exists() + + empty = nx.Graph() + assert to_json(empty, {}, str(out), force=True) is True + + data = json.loads(out.read_text()) + assert data["nodes"] == [] + + +def test_to_json_allows_nonzero_dedup_shrink_with_force(tmp_path): + """A10: existing 4 nodes, new 2-node graph, force=True is allowed — only a + new graph with 0 nodes trips the floor. A non-zero dedup/shrink under force + is a legitimate result.""" + out = tmp_path / "graph.json" + + populated = build_from_json(_build_extraction()) + assert populated.number_of_nodes() == 4 + assert to_json(populated, cluster(populated), str(out), force=True) is True + assert len(json.loads(out.read_text())["nodes"]) == 4 + + smaller = nx.Graph() + smaller.add_node("a") + smaller.add_node("b") + assert smaller.number_of_nodes() == 2 + assert to_json(smaller, {}, str(out), force=True) is True + + assert len(json.loads(out.read_text())["nodes"]) == 2 + + +def test_to_json_allows_zero_over_empty_existing(tmp_path): + """An existing file with 0 nodes + a new 0-node graph is allowed — there is + nothing populated to protect, so the floor must NOT engage.""" + out = tmp_path / "graph.json" + + # Seed a 0-node graph.json (no existing file → floor inert on first write). + first_empty = nx.Graph() + assert to_json(first_empty, {}, str(out), force=True) is True + assert json.loads(out.read_text())["nodes"] == [] + + # Overwrite 0-over-0: allowed. + second_empty = nx.Graph() + assert to_json(second_empty, {}, str(out), force=True) is True + assert json.loads(out.read_text())["nodes"] == [] diff --git a/tests/test_export_multigraph.py b/tests/test_export_multigraph.py new file mode 100644 index 000000000..30b9e25cc --- /dev/null +++ b/tests/test_export_multigraph.py @@ -0,0 +1,548 @@ +"""Export round-trip and parallel-edge fidelity tests for MultiDiGraph (PR 6). + +PR 6 go/no-go gate: "Every export either preserves every parallel edge OR +documents and tests an intentional projection/summarization." + +These tests exercise the four fixed exporters (``to_cypher``, ``to_obsidian``, +``to_canvas``, ``to_html``/``to_svg``) plus the natively-lossless ``to_json`` / +``to_graphml`` round-trips, and pin the simple-graph regression strings so the +single-relation path stays byte-stable against the pre-PR6 output. + +Fixture style mirrors ``tests/test_export.py`` (tempfile + ``build_from_json``). +""" + +import json +import re +import tempfile +from pathlib import Path + +import networkx as nx +from networkx.readwrite import json_graph + +from graphify.build import build_from_json +from graphify.edge_identity import make_stable_key +from graphify.export import ( + to_canvas, + to_cypher, + to_graphml, + to_html, + to_json, + to_obsidian, + to_svg, +) +from graphify.projections import DEFAULT_RELATIONSHIP_CAP + +# Relations on the A->B pair (3 parallel edges, distinct source_location). +AB_RELATIONS = ["calls", "imports", "contains"] +# Relations on the C->D pair (5 parallel edges, above DEFAULT_RELATIONSHIP_CAP). +CD_RELATIONS = ["calls", "imports", "contains", "extends", "uses"] + + +def make_multigraph() -> nx.MultiDiGraph: + """Build a MultiDiGraph with three pairs: + + - ``A->B``: 3 parallel edges (calls/imports/contains), distinct locations. + - ``C->D``: 5 parallel edges (> cap), distinct locations. + - ``E->F``: a single-edge simple-graph control inside the multigraph. + """ + nodes = [ + { + "id": n, + "label": n.upper(), + "file_type": "code", + "source_file": f"{n}.py", + "source_location": "L1", + } + for n in ("a", "b", "c", "d", "e", "f") + ] + edges = ( + [ + { + "source": "a", + "target": "b", + "relation": rel, + "confidence": "EXTRACTED", + "source_file": "a.py", + "source_location": f"L{i}", + } + for i, rel in enumerate(AB_RELATIONS) + ] + + [ + { + "source": "c", + "target": "d", + "relation": rel, + "confidence": "EXTRACTED", + "source_file": "c.py", + "source_location": f"L{i}", + } + for i, rel in enumerate(CD_RELATIONS) + ] + + [ + { + "source": "e", + "target": "f", + "relation": "calls", + "confidence": "EXTRACTED", + "source_file": "e.py", + "source_location": "L1", + } + ] + ) + G = build_from_json({"nodes": nodes, "edges": edges}, multigraph=True) + assert isinstance(G, nx.MultiDiGraph) + # Sanity: 3 + 5 + 1 = 9 parallel edges preserved at build time. + assert G.number_of_edges() == 9 + return G + + +def make_simple_digraph() -> nx.DiGraph: + """Single-relation directed control graph for byte-stability regression.""" + extraction = { + "nodes": [ + { + "id": "A", + "label": "Alpha", + "file_type": "code", + "source_file": "a.py", + "source_location": "L1", + }, + { + "id": "B", + "label": "Beta", + "file_type": "code", + "source_file": "b.py", + "source_location": "L2", + }, + ], + "edges": [ + { + "source": "A", + "target": "B", + "relation": "calls", + "confidence": "EXTRACTED", + "source_file": "a.py", + "source_location": "L1", + } + ], + } + G = build_from_json(extraction, directed=True) + assert isinstance(G, nx.DiGraph) + return G + + +COMMUNITIES = {0: ["a", "b", "c", "d", "e", "f"]} + + +# ── Lossless round-trips (preserve every parallel edge) ────────────────────── + + +def test_json_roundtrip_preserves_all_parallel_edges(): + """to_json -> node_link_graph reconstructs every parallel edge.""" + G = make_multigraph() + original = G.number_of_edges() + with tempfile.TemporaryDirectory() as tmp: + out = Path(tmp) / "graph.json" + to_json(G, COMMUNITIES, str(out), force=True) + data = json.loads(out.read_text()) + # node_link_data stamps multigraph/directed flags so the loader + # reconstructs a MultiDiGraph automatically. + assert data.get("multigraph") is True + assert data.get("directed") is True + G2 = json_graph.node_link_graph(data, edges="links") + assert isinstance(G2, nx.MultiDiGraph) + assert G2.number_of_edges() == original == 9 + + +def test_graphml_roundtrip_preserves_parallel_edges(): + """write_graphml -> read_graphml preserves the parallel edge count.""" + G = make_multigraph() + original = G.number_of_edges() + with tempfile.TemporaryDirectory() as tmp: + out = Path(tmp) / "graph.graphml" + # Must not raise on the multigraph diagnostics graph-attr (dict value). + to_graphml(G, COMMUNITIES, str(out)) + G2 = nx.read_graphml(out) + assert G2.is_multigraph() + assert G2.number_of_edges() == original == 9 + + +# ── Cypher: one distinct relationship per parallel edge ────────────────────── + + +def test_cypher_emits_distinct_edge_per_parallel(): + """Each parallel edge produces its own MERGE with a distinct edge_key.""" + G = make_multigraph() + with tempfile.TemporaryDirectory() as tmp: + out = Path(tmp) / "cypher.txt" + to_cypher(G, str(out)) + content = out.read_text() + + merge_lines = [ln for ln in content.splitlines() if ln.startswith("MATCH")] + # One MERGE per parallel edge — no Neo4j-side collapse. + assert len(merge_lines) == G.number_of_edges() == 9 + + edge_keys = re.findall(r"edge_key: '([^']+)'", content) + assert len(edge_keys) == 9 + # Every emitted relationship carries a globally distinct distinguishing key. + assert len(set(edge_keys)) == 9 + + # The three A->B parallel edges all sit between the same endpoints but keep + # distinct keys, so MERGE treats them as three relationships, not one. + ab_lines = [ln for ln in merge_lines if "{id: 'a'}" in ln and "{id: 'b'}" in ln] + assert len(ab_lines) == 3 + ab_keys = set() + for ln in ab_lines: + m = re.search(r"edge_key: '([^']+)'", ln) + assert m is not None + ab_keys.add(m.group(1)) + assert len(ab_keys) == 3 + + +# ── Canvas: globally unique edge ids + visual cap summary ──────────────────── + + +def test_canvas_edge_ids_unique(): + """Every canvas edge id is unique (no parallel-edge id collisions).""" + G = make_multigraph() + with tempfile.TemporaryDirectory() as tmp: + out = Path(tmp) / "graph.canvas" + to_canvas(G, COMMUNITIES, str(out)) + data = json.loads(out.read_text()) + + edge_ids = [e["id"] for e in data["edges"]] + assert edge_ids, "canvas should contain edges" + assert len(edge_ids) == len(set(edge_ids)), "canvas edge ids must be unique" + + # Golden / deterministic ordering for the A->B trio (3 <= cap, all drawn). + ab_ids = sorted( + e["id"] for e in data["edges"] if e["fromNode"] == "n_a" and e["toNode"] == "n_b" + ) + assert ab_ids == ["e_a_b_0", "e_a_b_1", "e_a_b_2"] + + +def test_canvas_edge_ids_unique_when_node_ids_contain_underscores(): + """Tuple-concatenated ids must not collide for ambiguous underscore splits.""" + G = nx.MultiDiGraph() + for node_id in ["a_b", "c", "a", "b_c"]: + G.add_node(node_id, label=node_id) + G.add_edge("a_b", "c", relation="r", confidence="EXTRACTED", weight=1.0) + G.add_edge("a", "b_c", relation="s", confidence="EXTRACTED", weight=1.0) + + with tempfile.TemporaryDirectory() as tmp: + out = Path(tmp) / "graph.canvas" + to_canvas(G, {0: list(G.nodes)}, str(out)) + data = json.loads(out.read_text()) + + edge_ids = [edge["id"] for edge in data["edges"]] + assert len(edge_ids) == 2 + assert len(edge_ids) == len(set(edge_ids)) + assert all(edge_id.startswith("e_a_b_c_0") for edge_id in edge_ids) + + +def test_canvas_visual_cap_summary(): + """A >cap pair draws at most cap+1 canvas edges with an overflow summary.""" + G = make_multigraph() + cap = DEFAULT_RELATIONSHIP_CAP + with tempfile.TemporaryDirectory() as tmp: + out = Path(tmp) / "graph.canvas" + to_canvas(G, COMMUNITIES, str(out)) + data = json.loads(out.read_text()) + + cd_edges = [e for e in data["edges"] if e["fromNode"] == "n_c" and e["toNode"] == "n_d"] + # 5 parallel edges -> cap drawn + 1 summary edge. + assert len(cd_edges) == cap + 1 + cd_ids = sorted(e["id"] for e in cd_edges) + assert cd_ids == ["e_c_d_0", "e_c_d_1", "e_c_d_2", "e_c_d_summary"] + + summary = next(e for e in cd_edges if e["id"] == "e_c_d_summary") + # Envelope overflow text: "(+K more, N total)". + assert "more" in summary["label"] + assert "5 total" in summary["label"] + + +# ── Obsidian: all relations per neighbor (capped when > cap) ───────────────── + + +def test_obsidian_shows_all_relations(): + """to_obsidian lists every relation to a neighbor, capped when above cap.""" + G = make_multigraph() + with tempfile.TemporaryDirectory() as tmp: + out = Path(tmp) + to_obsidian(G, COMMUNITIES, str(out)) + a_note = (out / "A.md").read_text() + c_note = (out / "C.md").read_text() + + a_conn = [ln for ln in a_note.splitlines() if ln.startswith("- [[")] + assert len(a_conn) == 1 + # All three A->B relations are listed, not just the first edge. Assert on the + # SET of relations present (and the wikilink prefix) rather than a pinned + # joined order, so a future envelope ordering change does not false-positive. + assert a_conn[0].startswith("- [[B]] - ") + for rel in AB_RELATIONS: + assert rel in a_conn[0] + # No overflow marker — 3 relations is within DEFAULT_RELATIONSHIP_CAP. + assert "more" not in a_conn[0] + + # The 5-relation C->D bundle renders the capped envelope form. + c_conn = [ln for ln in c_note.splitlines() if ln.startswith("- [[")] + assert len(c_conn) == 1 + assert "more" in c_conn[0] + assert "5 total" in c_conn[0] + + +# ── HTML / SVG: visual cap + summary label ─────────────────────────────────── + + +def test_html_svg_visual_cap(): + """HTML and SVG cap parallel edges and surface an overflow summary label.""" + G = make_multigraph() + cap = DEFAULT_RELATIONSHIP_CAP + with tempfile.TemporaryDirectory() as tmp: + html_out = Path(tmp) / "graph.html" + to_html(G, COMMUNITIES, str(html_out)) + html = html_out.read_text() + + svg_out = Path(tmp) / "graph.svg" + to_svg(G, COMMUNITIES, str(svg_out), community_labels={0: "Group 0"}) + svg = svg_out.read_text() + + # Summary label for the 5-parallel C->D pair appears in both surfaces. + assert "5 total" in html + assert f"+{len(CD_RELATIONS) - cap} more" in html + assert "5 total" in svg + + # The HTML edge dataset draws at most cap "real" C->D edges plus one summary. + # Parse RAW_EDGES out of the embedded script to count C->D draws precisely. + m = re.search(r"const RAW_EDGES = (\[.*?\]);", html, re.DOTALL) + assert m, "RAW_EDGES array must be embedded in the HTML" + raw_edges = json.loads(m.group(1)) + cd_real = [ + e + for e in raw_edges + if e.get("from") == "c" and e.get("to") == "d" and e.get("confidence") != "SUMMARY" + ] + cd_summary = [ + e + for e in raw_edges + if e.get("from") == "c" and e.get("to") == "d" and e.get("confidence") == "SUMMARY" + ] + assert len(cd_real) == cap + assert len(cd_summary) == 1 + + +# ── Regression: canvas summary edges must not evict real edges (BLOCK 1) ────── + + +def test_canvas_summary_does_not_displace_real_edges_over_cap(): + """With > 200 real edges, the 200-cap keeps the highest-weight REAL edges and + summary edges are strictly additive (never evict a real edge). + + Reproduces the priority-inversion bug: summary edges were pushed into the + weighted top-200 selection with ``float("inf")`` weight, sorting to the FRONT + and displacing the 201st-highest-weight real edge. A graph with 210 ascending- + weight single-edge pairs PLUS one low-weight 5-parallel overflow pair must: + - emit exactly 200 real edges (no summary stealing a real slot), + - retain the highest-weight real edge, + - drop the lowest-weight real edge (legitimately over the 200-cap). + """ + G = nx.MultiDiGraph() + members: list[str] = [] + for i in range(210): + a, b = f"a{i}", f"b{i}" + G.add_node(a, label=a) + G.add_node(b, label=b) + # Ascending weights 1..210 so ordering is unambiguous. + G.add_edge( + a, + b, + relation="calls", + confidence="EXTRACTED", + source_file="f.py", + source_location=f"L{i}", + weight=float(i + 1), + ) + members += [a, b] + # Low-weight overflow pair (5 parallels) — its reals are below the cap line. + G.add_node("X", label="X") + G.add_node("Y", label="Y") + for j in range(5): + G.add_edge( + "X", + "Y", + relation=f"r{j}", + confidence="EXTRACTED", + source_file="x.py", + source_location=f"LX{j}", + weight=0.1, + ) + members += ["X", "Y"] + + with tempfile.TemporaryDirectory() as tmp: + out = Path(tmp) / "graph.canvas" + to_canvas(G, {0: members}, str(out)) + edges = json.loads(out.read_text())["edges"] + + real = [e for e in edges if not e["id"].endswith("_summary")] + # Real edges are capped at EXACTLY 200 — a summary never consumed a real slot. + assert len(real) == 200 + # Highest-weight real edge survives; lowest-weight one is legitimately dropped. + assert any(e["fromNode"] == "n_a209" and e["toNode"] == "n_b209" for e in real) + assert not any(e["fromNode"] == "n_a0" and e["toNode"] == "n_b0" for e in real) + # All ids remain globally unique. + ids = [e["id"] for e in edges] + assert len(ids) == len(set(ids)) + + +def test_canvas_summary_additive_when_overflow_pair_survives(): + """When a high-weight overflow pair survives the 200-cap, its summary edge is + ADDED on top of the 200 real edges (total > 200), not in place of one.""" + G = nx.MultiDiGraph() + members: list[str] = [] + for i in range(199): + a, b = f"a{i}", f"b{i}" + G.add_node(a, label=a) + G.add_node(b, label=b) + G.add_edge( + a, + b, + relation="calls", + confidence="EXTRACTED", + source_file="f.py", + source_location=f"L{i}", + weight=1.0, + ) + members += [a, b] + G.add_node("X", label="X") + G.add_node("Y", label="Y") + for j in range(5): + G.add_edge( + "X", + "Y", + relation=f"r{j}", + confidence="EXTRACTED", + source_file="x.py", + source_location=f"LX{j}", + weight=100.0, # high weight -> overflow pair's reals survive the cap + ) + members += ["X", "Y"] + + with tempfile.TemporaryDirectory() as tmp: + out = Path(tmp) / "graph.canvas" + to_canvas(G, {0: members}, str(out)) + edges = json.loads(out.read_text())["edges"] + + real = [e for e in edges if not e["id"].endswith("_summary")] + summary = [e for e in edges if e["id"].endswith("_summary")] + assert len(real) == 200 # real edges still capped at 200 + assert len(summary) == 1 # summary additive (201 total) + xy_summary = [e for e in summary if e["fromNode"] == "n_X" and e["toNode"] == "n_Y"] + assert len(xy_summary) == 1 + assert "5 total" in xy_summary[0]["label"] + ids = [e["id"] for e in edges] + assert len(ids) == len(set(ids)) + + +# ── Regression: integer positional keys distinguish parallels (BLOCK 2) ─────── + + +def test_cypher_distinguishes_parallels_with_identical_identity_fields(): + """Parallel edges that share IDENTICAL relation/source_file/source_location + still get DISTINCT edge_keys, so Neo4j MERGE preserves all of them. + + Reproduces the integer-key drop bug: a directly-constructed MultiDiGraph + yields INTEGER positional keys (0, 1, 2…). The old ``isinstance(key, str)`` + guard discarded them and fell back to make_stable_key(relation, file, + location) — identical for every edge here — collapsing them to ONE edge_key + and letting MERGE dedup the parallels. The fix accepts any non-None + positional key (stringified), which NetworkX guarantees unique per (u, v). + """ + G = nx.MultiDiGraph() + G.add_node("A", label="Alpha", file_type="code") + G.add_node("B", label="Beta", file_type="code") + # Three parallel edges, byte-identical semantic identity fields. + for _ in range(3): + G.add_edge( + "A", + "B", + relation="calls", + confidence="EXTRACTED", + source_file="a.py", + source_location="L1", + ) + # Positional keys are integers (NetworkX default). + positional_keys = [k for _u, _v, k in G.edges(keys=True)] + assert positional_keys == [0, 1, 2] + assert all(isinstance(k, int) for k in positional_keys) + + with tempfile.TemporaryDirectory() as tmp: + out = Path(tmp) / "cypher.txt" + to_cypher(G, str(out)) + content = out.read_text() + + merge_lines = [ln for ln in content.splitlines() if ln.startswith("MATCH")] + assert len(merge_lines) == G.number_of_edges() == 3 + edge_keys = re.findall(r"edge_key: '([^']+)'", content) + assert len(edge_keys) == 3 + # The crux: distinct edge_key per parallel edge despite identical identity + # fields — count distinct == parallel count, so MERGE keeps all three. + assert len(set(edge_keys)) == 3 + + +# ── Simple-graph regression: byte-stable single-relation output ────────────── + + +def test_export_simple_graph_regression(): + """Single-relation DiGraph output is pinned exactly (pre-PR6 stability). + + The Cypher line gains a documented `edge_key` property (required so Neo4j + MERGE never collapses parallel edges); the canvas id gains a `_0` parallel + suffix. Obsidian's single-relation Connections line is byte-identical to the + pre-PR6 ``- [[label]] - `relation` [confidence]`` form. + """ + G = make_simple_digraph() + comm = {0: ["A", "B"]} + expected_key = make_stable_key("calls", "a.py", "L1") + + with tempfile.TemporaryDirectory() as tmp: + # Cypher — exact line including the new edge_key property. + cypher_out = Path(tmp) / "cypher.txt" + to_cypher(G, str(cypher_out)) + cypher_lines = [ln for ln in cypher_out.read_text().splitlines() if ln.startswith("MATCH")] + assert cypher_lines == [ + "MATCH (a {id: 'A'}), (b {id: 'B'}) " + f"MERGE (a)-[:CALLS {{edge_key: '{expected_key}', confidence: 'EXTRACTED'}}]->(b);" + ] + + # Canvas — single edge keeps deterministic `_0` parallel suffix. + canvas_out = Path(tmp) / "graph.canvas" + to_canvas(G, comm, str(canvas_out)) + canvas_edges = json.loads(canvas_out.read_text())["edges"] + assert canvas_edges == [ + { + "id": "e_A_B_0", + "fromNode": "n_A", + "toNode": "n_B", + "label": "calls [EXTRACTED]", + } + ] + + # Obsidian — byte-identical to the historical single-relation form. + obs_out = Path(tmp) / "vault" + to_obsidian(G, comm, str(obs_out)) + conn_lines = [ + ln for ln in (obs_out / "Alpha.md").read_text().splitlines() if ln.startswith("- [[") + ] + assert conn_lines == ["- [[Beta]] - `calls` [EXTRACTED]"] + + # HTML — single edge, no summary edge injected. + html_out = Path(tmp) / "graph.html" + to_html(G, comm, str(html_out)) + html = html_out.read_text() + m = re.search(r"const RAW_EDGES = (\[.*?\]);", html, re.DOTALL) + assert m + raw_edges = json.loads(m.group(1)) + assert len(raw_edges) == 1 + assert raw_edges[0]["from"] == "A" + assert raw_edges[0]["to"] == "B" + assert raw_edges[0]["confidence"] == "EXTRACTED" diff --git a/tests/test_extract.py b/tests/test_extract.py index 0d5db2c5a..2712a6ba3 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -1,5 +1,13 @@ from pathlib import Path -from graphify.extract import extract_python, extract, collect_files, _make_id, extract_bash, extract_json, _DISPATCH +from graphify.extract import ( + extract_python, + extract, + collect_files, + _make_id, + extract_bash, + extract_json, + _DISPATCH, +) FIXTURES = Path(__file__).parent / "fixtures" @@ -29,7 +37,7 @@ def test_extract_python_finds_class(): def test_extract_python_finds_methods(): result = extract_python(FIXTURES / "sample.py") labels = [n["label"] for n in result["nodes"]] - assert any("__init__" in l or "forward" in l for l in labels) + assert any("__init__" in label or "forward" in label for label in labels) def test_extract_python_no_dangling_edges(): @@ -66,7 +74,8 @@ def test_extract_disambiguates_duplicate_symbol_ids_by_source_path(tmp_path): result = extract([first, second], cache_root=tmp_path) program_nodes = [ - node for node in result["nodes"] + node + for node in result["nodes"] if node["label"] == "Program" and node.get("source_file", "").endswith("Program.cs") ] @@ -76,14 +85,13 @@ def test_extract_disambiguates_duplicate_symbol_ids_by_source_path(tmp_path): node_ids = {node["id"] for node in result["nodes"]} program_by_source = {node["source_file"]: node["id"] for node in program_nodes} file_nodes_by_source = { - node["source_file"]: node["id"] - for node in result["nodes"] - if node["label"] == "Program.cs" + node["source_file"]: node["id"] for node in result["nodes"] if node["label"] == "Program.cs" } assert set(program_by_source) == set(file_nodes_by_source) contains_edges = [ - edge for edge in result["edges"] + edge + for edge in result["edges"] if edge["relation"] == "contains" and edge["source_file"] in program_by_source ] assert len(contains_edges) == 2 @@ -129,14 +137,16 @@ def test_extract_rewires_unique_inheritance_stub_to_real_definition(tmp_path): inherits_edges = [edge for edge in result["edges"] if edge["relation"] == "inherits"] matching = [ - edge for edge in inherits_edges + edge + for edge in inherits_edges if node_by_id[edge["source"]]["label"] == "SqliteBookStore" and node_by_id[edge["target"]]["label"] == "BookStore" ] assert matching assert matching[0]["target"] == next( - node["id"] for node in result["nodes"] + node["id"] + for node in result["nodes"] if node["label"] == "BookStore" and node.get("source_file") == "interfaces.py" ) assert all( @@ -158,7 +168,8 @@ def test_extract_keeps_stub_when_multiple_real_definitions_match(tmp_path): result = extract([first, second, implementation], cache_root=tmp_path) stubs = [ - node for node in result["nodes"] + node + for node in result["nodes"] if node["label"] == "BookStore" and not node.get("source_file") ] @@ -177,8 +188,7 @@ def test_extract_does_not_rewire_inheritance_stub_to_same_named_function(tmp_pat inherits_edges = [edge for edge in result["edges"] if edge["relation"] == "inherits"] assert any( - node["label"] == "BookStore" and not node.get("source_file") - for node in result["nodes"] + node["label"] == "BookStore" and not node.get("source_file") for node in result["nodes"] ) assert not any( node_by_id[edge["source"]]["label"] == "SqliteBookStore" @@ -190,27 +200,20 @@ def test_extract_does_not_rewire_inheritance_stub_to_same_named_function(tmp_pat def test_extract_does_not_rewire_constructor_method_to_same_named_class(tmp_path): source = tmp_path / "Sample.java" source.write_text( - "class DataProcessor {\n" - " public DataProcessor() {}\n" - "}\n", + "class DataProcessor {\n public DataProcessor() {}\n}\n", encoding="utf-8", ) result = extract([source], cache_root=tmp_path) - constructor_nodes = [ - node for node in result["nodes"] - if node["label"] == ".DataProcessor()" - ] + constructor_nodes = [node for node in result["nodes"] if node["label"] == ".DataProcessor()"] assert constructor_nodes - assert not any( - edge["source"] == edge["target"] - for edge in result["edges"] - ) + assert not any(edge["source"] == edge["target"] for edge in result["edges"]) def test_collect_files_from_dir(): from graphify.extract import _DISPATCH + files = collect_files(FIXTURES) supported = set(_DISPATCH.keys()) assert all(f.suffix in supported for f in files) @@ -339,8 +342,7 @@ def test_cross_file_calls_skip_ambiguous_duplicate_labels(tmp_path): result = extract([caller, helper_a, helper_b], cache_root=tmp_path) nodes = {n["id"]: n for n in result["nodes"]} calls = [ - e for e in result["edges"] - if e["relation"] == "calls" and e["confidence"] == "INFERRED" + e for e in result["edges"] if e["relation"] == "calls" and e["confidence"] == "INFERRED" ] assert not any( @@ -362,15 +364,17 @@ def test_extract_generic_surfaces_tree_sitter_version_mismatch_hint(monkeypatch) # this is exactly what users see when an older tree-sitter is paired # with a newer language binding. fake_ts = types.ModuleType("tree_sitter") + def _raise(*args, **kwargs): raise TypeError("missing 1 required positional argument: 'name'") - fake_ts.Language = _raise - fake_ts.Parser = None + + setattr(fake_ts, "Language", _raise) + setattr(fake_ts, "Parser", None) monkeypatch.setitem(sys.modules, "tree_sitter", fake_ts) # Stub the language module so import_module returns something with .language fake_lang_mod = types.ModuleType("fake_ts_lang") - fake_lang_mod.language = lambda: object() + setattr(fake_lang_mod, "language", lambda: object()) monkeypatch.setitem(sys.modules, "fake_ts_lang", fake_lang_mod) config = LanguageConfig(ts_module="fake_ts_lang", ts_language_fn="language") @@ -384,6 +388,7 @@ def _raise(*args, **kwargs): def test_extract_js_destructured_require_imports_from(): """`const { foo } = require('./mod')` must emit imports_from to the resolved module path.""" from graphify.extract import extract_js + result = extract_js(FIXTURES / "cjs_require.js") imports_from = [e for e in result["edges"] if e["relation"] == "imports_from"] targets = [e["target"] for e in imports_from] @@ -398,6 +403,7 @@ def test_extract_js_destructured_require_imports_from(): def test_extract_js_destructured_require_named_symbols(): """Destructured CJS requires must emit symbol-level `imports` edges per binder.""" from graphify.extract import extract_js, _make_id, _file_stem + result = extract_js(FIXTURES / "cjs_require.js") sym_targets = [e["target"] for e in result["edges"] if e["relation"] == "imports"] foundation_stem = _file_stem(FIXTURES / "foundation.js") @@ -408,6 +414,7 @@ def test_extract_js_destructured_require_named_symbols(): def test_extract_js_member_require_emits_property_symbol(): """`const x = require('./m').y` must emit symbol edge for `y`.""" from graphify.extract import extract_js, _make_id, _file_stem + result = extract_js(FIXTURES / "cjs_require.js") sym_targets = [e["target"] for e in result["edges"] if e["relation"] == "imports"] helpers_stem = _file_stem(FIXTURES / "helpers.js") @@ -417,6 +424,7 @@ def test_extract_js_member_require_emits_property_symbol(): def test_extract_js_arrow_function_still_extracted(): """Regression: arrow functions in lexical_declaration must still produce nodes.""" from graphify.extract import extract_js + arrow_fixture = FIXTURES / "_arrow_only.js" arrow_fixture.write_text("const greet = () => console.log('hi');\n") try: @@ -432,18 +440,13 @@ def test_cross_file_call_promoted_to_extracted_with_import_evidence(tmp_path): an `imports` or `imports_from` edge linking it to the callee.""" caller = tmp_path / "caller.js" callee = tmp_path / "lib.js" - caller.write_text( - "const { doWork } = require('./lib');\n" - "function run() { doWork(); }\n" - ) - callee.write_text( - "function doWork() { return 1; }\n" - "module.exports = { doWork };\n" - ) + caller.write_text("const { doWork } = require('./lib');\nfunction run() { doWork(); }\n") + callee.write_text("function doWork() { return 1; }\nmodule.exports = { doWork };\n") result = extract([caller, callee], cache_root=tmp_path) nodes = {n["id"]: n for n in result["nodes"]} call_edges = [ - e for e in result["edges"] + e + for e in result["edges"] if e["relation"] == "calls" and nodes[e["source"]]["label"] == "run()" and nodes[e["target"]]["label"] == "doWork()" @@ -460,14 +463,12 @@ def test_cross_file_call_remains_inferred_without_import_evidence(tmp_path): callee = tmp_path / "lib.js" # Caller does NOT require lib — same-name function happens to exist elsewhere caller.write_text("function run() { doUnique(); }\n") - callee.write_text( - "function doUnique() { return 1; }\n" - "module.exports = { doUnique };\n" - ) + callee.write_text("function doUnique() { return 1; }\nmodule.exports = { doUnique };\n") result = extract([caller, callee], cache_root=tmp_path) nodes = {n["id"]: n for n in result["nodes"]} call_edges = [ - e for e in result["edges"] + e + for e in result["edges"] if e["relation"] == "calls" and nodes[e["source"]]["label"] == "run()" and nodes[e["target"]]["label"] == "doUnique()" @@ -481,14 +482,16 @@ def test_cross_file_call_remains_inferred_without_import_evidence(tmp_path): # `language_typescript` grammar. Parsing JSX with the wrong grammar produces # silent ERROR nodes and drops every function/call inside JSX trees. + def test_extract_tsx_finds_helpers_and_component(): """Functions defined alongside a JSX-returning component must be captured.""" from graphify.extract import extract_js + result = extract_js(FIXTURES / "sample.tsx") labels = [n["label"] for n in result["nodes"]] - assert any("fmtDate" in l for l in labels), f"fmtDate missing from {labels}" - assert any("fmtCount" in l for l in labels), f"fmtCount missing from {labels}" - assert any("App" in l for l in labels), f"App missing from {labels}" + assert any("fmtDate" in label for label in labels), f"fmtDate missing from {labels}" + assert any("fmtCount" in label for label in labels), f"fmtCount missing from {labels}" + assert any("App" in label for label in labels), f"App missing from {labels}" def test_extract_tsx_jsx_expression_calls_resolve(): @@ -498,6 +501,7 @@ def test_extract_tsx_jsx_expression_calls_resolve(): JSX is parsed as ERROR nodes and these call_expressions disappear. """ from graphify.extract import extract_js + result = extract_js(FIXTURES / "sample.tsx") nodes_by_id = {n["id"]: n for n in result["nodes"]} call_targets = { @@ -516,6 +520,7 @@ def test_extract_tsx_jsx_expression_calls_resolve(): def test_extract_tsx_uses_tsx_grammar(): """Wiring check: the .tsx config must use tree-sitter's `language_tsx`.""" from graphify.extract import _TSX_CONFIG, _TS_CONFIG + assert _TSX_CONFIG.ts_language_fn == "language_tsx" assert _TS_CONFIG.ts_language_fn == "language_typescript" @@ -526,6 +531,7 @@ def test_extract_tsx_uses_tsx_grammar(): # detect this, warn, and fall back to sequential extraction rather than # propagating a 290-line traceback. + def test_extract_falls_back_to_sequential_when_parallel_returns_false(tmp_path, monkeypatch): """extract() must run sequential when _extract_parallel signals failure (returns False).""" from graphify import extract as extract_mod @@ -561,15 +567,19 @@ def test_extract_parallel_returns_false_on_broken_pool(tmp_path, monkeypatch, ca from graphify import extract as extract_mod class FakePool: - def __init__(self, *a, **kw): pass - def __enter__(self): return self - def __exit__(self, *a): return False + def __init__(self, *a, **kw): + pass + + def __enter__(self): + return self + + def __exit__(self, *a): + return False + def submit(self, *a, **kw): raise BrokenProcessPool("simulated spawn failure") - monkeypatch.setattr( - concurrent.futures, "ProcessPoolExecutor", lambda *a, **kw: FakePool() - ) + monkeypatch.setattr(concurrent.futures, "ProcessPoolExecutor", lambda *a, **kw: FakePool()) uncached = [(0, FIXTURES / "sample.py")] per_file: list = [None] @@ -580,10 +590,68 @@ def submit(self, *a, **kw): assert "__main__" in out, "warning must hint at the Windows __main__ guard idiom" +def test_extract_parallel_returns_false_when_pool_unavailable(tmp_path, monkeypatch, capsys): + """ProcessPoolExecutor setup OSErrors must fall back to sequential extraction.""" + import concurrent.futures + from graphify import extract as extract_mod + + def raise_permission_error(*args, **kwargs): + raise PermissionError("semaphore probe denied") + + monkeypatch.setattr(concurrent.futures, "ProcessPoolExecutor", raise_permission_error) + + uncached = [(0, FIXTURES / "sample.py")] + per_file: list = [None] + + ok = extract_mod._extract_parallel(uncached, per_file, tmp_path, 2, 1) + + assert ok is False + out = capsys.readouterr().out + assert "parallel extraction unavailable" in out + assert "PermissionError" in out + assert "falling back to sequential" in out + + +def test_extract_parallel_worker_warning_handles_sparse_file_indexes(tmp_path, monkeypatch, capsys): + """Worker-failure warnings must not index work_items by original file index.""" + import concurrent.futures + from graphify import extract as extract_mod + + class FakePool: + def __init__(self, *a, **kw): + pass + + def __enter__(self): + return self + + def __exit__(self, *a): + return False + + def submit(self, fn, item): + future: concurrent.futures.Future = concurrent.futures.Future() + future.set_exception(RuntimeError("simulated worker failure")) + return future + + monkeypatch.setattr(concurrent.futures, "ProcessPoolExecutor", lambda *a, **kw: FakePool()) + + source = tmp_path / "late.py" + source.write_text("x = 1\n", encoding="utf-8") + uncached = [(3, source)] + per_file: list = [None, None, None, None] + + ok = extract_mod._extract_parallel(uncached, per_file, tmp_path, 2, 4) + + assert ok is True + err = capsys.readouterr().err + assert "late.py" in err + assert "simulated worker failure" in err + + # --------------------------------------------------------------------------- # Bash extractor tests (#866) # --------------------------------------------------------------------------- + def test_dispatch_includes_sh_and_json(): assert ".sh" in _DISPATCH assert ".bash" in _DISPATCH @@ -626,7 +694,7 @@ def test_extract_bash_emits_source_imports_from(tmp_path): helpers = tmp_path / "helpers.sh" helpers.write_text("# helper\n") script = tmp_path / "deploy.sh" - script.write_text(f"#!/bin/bash\nsource ./helpers.sh\nfoo() {{ echo hi; }}\n") + script.write_text("#!/bin/bash\nsource ./helpers.sh\nfoo() { echo hi; }\n") result = extract_bash(script) import_edges = [e for e in result["edges"] if e["relation"] == "imports_from"] assert len(import_edges) >= 1 @@ -661,6 +729,7 @@ def test_extract_bash_missing_grammar_returns_error(): """extract_bash returns error dict when tree-sitter-bash not installed (mocked).""" import unittest.mock as mock import builtins + real_import = builtins.__import__ def patched(name, *args, **kwargs): @@ -677,11 +746,7 @@ def patched(name, *args, **kwargs): def test_extract_bash_rejects_command_substitution_as_call(tmp_path): """`$(build)` must not be recorded as a call edge to build().""" script = tmp_path / "command_substitution.sh" - script.write_text( - "#!/usr/bin/env bash\n" - "build() { echo build; }\n" - "$(build)\n" - ) + script.write_text("#!/usr/bin/env bash\nbuild() { echo build; }\n$(build)\n") result = extract_bash(script) labels = {n["id"]: n["label"] for n in result["nodes"]} call_pairs = [ @@ -695,11 +760,7 @@ def test_extract_bash_rejects_command_substitution_as_call(tmp_path): def test_extract_bash_process_substitution_not_recorded(tmp_path): """`<(helper)` (process substitution) must not be recorded as a call edge.""" script = tmp_path / "process_substitution.sh" - script.write_text( - "#!/usr/bin/env bash\n" - "helper() { echo h; }\n" - "diff <(helper) <(helper)\n" - ) + script.write_text("#!/usr/bin/env bash\nhelper() { echo h; }\ndiff <(helper) <(helper)\n") result = extract_bash(script) labels = {n["id"]: n["label"] for n in result["nodes"]} call_pairs = [ @@ -713,11 +774,7 @@ def test_extract_bash_process_substitution_not_recorded(tmp_path): def test_extract_bash_shadowing_function_is_recorded(tmp_path): """User-defined function shadowing an external command (install/find/etc.) must still produce a call edge.""" script = tmp_path / "shadowing.sh" - script.write_text( - "#!/usr/bin/env bash\n" - "install() { echo install; }\n" - "deploy() { install; }\n" - ) + script.write_text("#!/usr/bin/env bash\ninstall() { echo install; }\ndeploy() { install; }\n") result = extract_bash(script) labels = {n["id"]: n["label"] for n in result["nodes"]} call_pairs = [ @@ -739,10 +796,15 @@ def test_extract_bash_creates_entrypoint_node(tmp_path): assert "bash_entrypoint" in kinds, f"No bash_entrypoint node; kinds={kinds}" assert "file" in kinds, f"No file node; kinds={kinds}" file_node = next(n for n in result["nodes"] if n.get("metadata", {}).get("kind") == "file") - entry_node = next(n for n in result["nodes"] if n.get("metadata", {}).get("kind") == "bash_entrypoint") + entry_node = next( + n for n in result["nodes"] if n.get("metadata", {}).get("kind") == "bash_entrypoint" + ) contains_edges = [ - e for e in result["edges"] - if e["relation"] == "contains" and e["source"] == file_node["id"] and e["target"] == entry_node["id"] + e + for e in result["edges"] + if e["relation"] == "contains" + and e["source"] == file_node["id"] + and e["target"] == entry_node["id"] ] assert contains_edges, "Missing contains edge from file → bash_entrypoint" @@ -750,23 +812,19 @@ def test_extract_bash_creates_entrypoint_node(tmp_path): def test_extract_bash_top_level_call_attributes_to_entrypoint(tmp_path): """Top-level function call attaches to the entrypoint node, not orphaned.""" script = tmp_path / "top_level_call.sh" - script.write_text( - "#!/usr/bin/env bash\n" - "build() { echo build; }\n" - "build\n" - ) + script.write_text("#!/usr/bin/env bash\nbuild() { echo build; }\nbuild\n") result = extract_bash(script) entry_node = next( (n for n in result["nodes"] if n.get("metadata", {}).get("kind") == "bash_entrypoint"), None, ) assert entry_node is not None, "No entrypoint node created" - call_pairs = [ - (e["source"], e["target"]) - for e in result["edges"] - if e["relation"] == "calls" - ] - target_ids = {tgt for _, tgt in call_pairs if any(n["id"] == tgt and n["label"] == "build()" for n in result["nodes"])} + call_pairs = [(e["source"], e["target"]) for e in result["edges"] if e["relation"] == "calls"] + target_ids = { + tgt + for _, tgt in call_pairs + if any(n["id"] == tgt and n["label"] == "build()" for n in result["nodes"]) + } source_ids_to_build = {src for src, tgt in call_pairs if tgt in target_ids} assert entry_node["id"] in source_ids_to_build, ( f"Top-level call to build not attributed to entrypoint; calls={call_pairs}" @@ -788,8 +846,12 @@ def test_extract_bash_entrypoint_no_collision_with_function_named_script(tmp_pat script = tmp_path / "deploy.sh" script.write_text("#!/usr/bin/env bash\nfunction script() { echo hi; }\n") result = extract_bash(script) - entry_nodes = [n for n in result["nodes"] if n.get("metadata", {}).get("kind") == "bash_entrypoint"] - func_nodes = [n for n in result["nodes"] if n.get("metadata", {}).get("kind") == "bash_function"] + entry_nodes = [ + n for n in result["nodes"] if n.get("metadata", {}).get("kind") == "bash_entrypoint" + ] + func_nodes = [ + n for n in result["nodes"] if n.get("metadata", {}).get("kind") == "bash_function" + ] assert entry_nodes, "Must have a bash_entrypoint node" assert func_nodes, "Must have a bash_function node for 'script'" entry_id = entry_nodes[0]["id"] @@ -814,8 +876,12 @@ def test_extract_bash_nested_function_calls_recorded(tmp_path): ) result = extract_bash(script) node_id_by_label = {n["label"].rstrip("()"): n["id"] for n in result["nodes"]} - assert "inner" in node_id_by_label, f"inner function must be discovered; labels={list(node_id_by_label)}" - assert "do_work" in node_id_by_label, f"do_work function must be discovered; labels={list(node_id_by_label)}" + assert "inner" in node_id_by_label, ( + f"inner function must be discovered; labels={list(node_id_by_label)}" + ) + assert "do_work" in node_id_by_label, ( + f"do_work function must be discovered; labels={list(node_id_by_label)}" + ) calls = {(e["source"], e["target"]) for e in result["edges"] if e.get("relation") == "calls"} inner_id = node_id_by_label["inner"] do_work_id = node_id_by_label["do_work"] @@ -832,9 +898,7 @@ def test_extract_bash_source_user_defined_emits_calls_not_imports_from(tmp_path) helpers.write_text("#!/bin/bash\n") script = tmp_path / "run.sh" script.write_text( - "#!/usr/bin/env bash\n" - "function source() { echo 'custom source'; }\n" - "source ./helpers.sh\n" + "#!/usr/bin/env bash\nfunction source() { echo 'custom source'; }\nsource ./helpers.sh\n" ) result = extract_bash(script) import_edges = [e for e in result["edges"] if e.get("relation") == "imports_from"] @@ -847,6 +911,7 @@ def test_extract_bash_source_user_defined_emits_calls_not_imports_from(tmp_path) # JSON extractor tests (#866) # --------------------------------------------------------------------------- + def test_extract_json_top_level_keys(): result = extract_json(FIXTURES / "sample.json") assert "error" not in result @@ -907,12 +972,14 @@ def test_extract_json_no_self_loops(): def test_extract_bash_via_dispatch(): from graphify.extract import _get_extractor + assert _get_extractor(Path("foo.sh")) is extract_bash assert _get_extractor(Path("foo.bash")) is extract_bash def test_extract_json_via_dispatch(): from graphify.extract import _get_extractor + assert _get_extractor(Path("foo.json")) is extract_json @@ -938,6 +1005,7 @@ def test_extract_bash_node_metadata_is_sanitized(): def test_barrel_reexport_emits_re_exports_edges(): """export { X } from './mod' must emit re_exports edges for each named specifier.""" from graphify.extract import extract_js + result = extract_js(FIXTURES / "barrel_reexport.ts") reexports = [e for e in result["edges"] if e["relation"] == "re_exports"] targets = [e["target"] for e in reexports] @@ -952,6 +1020,7 @@ def test_barrel_reexport_emits_re_exports_edges(): def test_barrel_reexport_emits_imports_from(): """Barrel file must emit file-level imports_from edges to source modules.""" from graphify.extract import extract_js + result = extract_js(FIXTURES / "barrel_reexport.ts") imports_from = [e for e in result["edges"] if e["relation"] == "imports_from"] targets = [e["target"] for e in imports_from] @@ -963,6 +1032,7 @@ def test_barrel_reexport_emits_imports_from(): def test_barrel_reexport_context_tagged(): """re_exports edges should have context='re-export'.""" from graphify.extract import extract_js + result = extract_js(FIXTURES / "barrel_reexport.ts") reexports = [e for e in result["edges"] if e["relation"] == "re_exports"] for e in reexports: @@ -972,6 +1042,7 @@ def test_barrel_reexport_context_tagged(): def test_barrel_local_exports_still_extracted(): """export function/const in a barrel file must still create nodes.""" from graphify.extract import extract_js + result = extract_js(FIXTURES / "barrel_reexport.ts") labels = [n["label"] for n in result["nodes"]] assert "localHelper()" in labels or "localHelper" in labels @@ -982,6 +1053,7 @@ def test_barrel_local_exports_still_extracted(): def test_barrel_reexport_confidence_extracted(): """All re_exports edges should have confidence=EXTRACTED.""" from graphify.extract import extract_js + result = extract_js(FIXTURES / "barrel_reexport.ts") reexports = [e for e in result["edges"] if e["relation"] == "re_exports"] for e in reexports: @@ -1015,6 +1087,7 @@ def test_pure_export_no_from_not_treated_as_reexport(): """export { localVar } without 'from' should NOT create re_exports edges.""" from graphify.extract import extract_js import tempfile + code = b"const x = 1;\nexport { x };\n" with tempfile.NamedTemporaryFile(suffix=".ts", delete=False) as f: f.write(code) @@ -1035,8 +1108,8 @@ def test_dart_child_node_ids_are_stem_based(tmp_path): result = extract_dart(src_file) stem = _file_stem(src_file) # -> "mydir.sample" - expected_class_nid = _make_id(stem, "MyClass") # -> "mydir_sample_myclass" - expected_func_nid = _make_id(stem, "myFunc") # -> "mydir_sample_myfunc" + expected_class_nid = _make_id(stem, "MyClass") # -> "mydir_sample_myclass" + expected_func_nid = _make_id(stem, "myFunc") # -> "mydir_sample_myfunc" node_ids = {n["id"] for n in result["nodes"]} @@ -1054,9 +1127,9 @@ def test_dart_child_node_ids_are_stem_based(tmp_path): for node in result["nodes"]: if node["id"] == file_nid: continue - assert "_" + stem.replace(".", "_") in node["id"] or node["id"].startswith(stem.replace(".", "_")), ( + assert "_" + stem.replace(".", "_") in node["id"] or node["id"].startswith( + stem.replace(".", "_") + ), ( f"Child node ID '{node['id']}' does not start with the expected stem prefix '{stem}'. " "This suggests an absolute path is still leaking into the ID." ) - - diff --git a/tests/test_extract_cli.py b/tests/test_extract_cli.py index 6998bcce9..4d4be4087 100644 --- a/tests/test_extract_cli.py +++ b/tests/test_extract_cli.py @@ -1,9 +1,369 @@ """Tests for `graphify extract` CLI dispatch path in graphify.__main__.""" + from __future__ import annotations +import json +import os +import subprocess +import sys +from pathlib import Path + +import networkx as nx import pytest import graphify.__main__ as mainmod +from graphify.build import build_from_json +from graphify.export import to_json +from graphify.graph_loader import load_graph +from graphify.llm import BACKENDS, _backend_env_keys + + +PYTHON = sys.executable + + +def _clean_env() -> dict: + """Return os.environ with every backend API key stripped out. + + Mirrors tests/test_incremental._clean_env so subprocess runs do not pick up + a real key from the developer's shell and accidentally hit a live LLM. + """ + env = dict(os.environ) + for backend in BACKENDS: + for env_key in _backend_env_keys(backend): + env.pop(env_key, None) + for extra in ( + "AWS_PROFILE", + "AWS_REGION", + "AWS_DEFAULT_REGION", + "OLLAMA_BASE_URL", + "OLLAMA_API_KEY", + ): + env.pop(extra, None) + return env + + +def _run(args: list[str], cwd: Path, *, env: dict | None = None) -> subprocess.CompletedProcess: + """Run `python -m graphify ` as a sanitized subprocess.""" + return subprocess.run( + [PYTHON, "-m", "graphify"] + args, + cwd=cwd, + capture_output=True, + text=True, + env=env if env is not None else _clean_env(), + ) + + +def _make_code_corpus(tmp_path: Path) -> Path: + """A tiny AST-only code corpus — no docs, so semantic/LLM extraction never runs. + + The functions reference each other so AST extraction produces real edges. + """ + corpus = tmp_path / "corpus" + corpus.mkdir() + (corpus / "app.py").write_text( + "def helper():\n return 1\n\n\n" + "def main():\n return helper()\n\n\n" + "def extra():\n return main()\n", + encoding="utf-8", + ) + return corpus + + +def _write_multidigraph_graph_json(corpus: Path) -> Path: + """Seed corpus/graphify-out/graph.json as a multidigraph with parallel edges. + + Built exactly the way the pipeline persists it (build_from_json multigraph=True + -> export.to_json), so the file carries the top-level ``multigraph: true`` flag + and ``graphify_profile.graph_type == multidigraph``. Two parallel main->helper + edges (different relations) prove parallels survive a sticky re-extract. + """ + nodes = [ + { + "id": n, + "label": f"{n}()", + "file_type": "code", + "source_file": "app.py", + "source_location": "L1", + } + for n in ("main", "helper") + ] + edges = [ + { + "source": "main", + "target": "helper", + "relation": rel, + "confidence": "EXTRACTED", + "source_file": "app.py", + "source_location": f"L{i}", + } + for i, rel in enumerate(["calls", "imports"]) + ] + G = build_from_json({"nodes": nodes, "edges": edges}, multigraph=True) + assert isinstance(G, nx.MultiDiGraph) + assert G.number_of_edges("main", "helper") == 2 + out = corpus / "graphify-out" + out.mkdir(exist_ok=True) + graph_json = out / "graph.json" + to_json(G, {0: ["main", "helper"]}, str(graph_json), force=True) + # Persist the scan root so a later `update` (no path arg) can recover it. + (out / ".graphify_root").write_text(str(corpus), encoding="utf-8") + return graph_json + + +def _graph_type(graph_data: dict) -> str | None: + return graph_data.get("graph", {}).get("graphify_profile", {}).get("graph_type") + + +def _parallel_edges(graph_data: dict, src: str, tgt: str) -> list[dict]: + links = graph_data.get("links", graph_data.get("edges", [])) + return [e for e in links if e.get("source") == src and e.get("target") == tgt] + + +# ───────────────────────────── PR 9: public --multigraph / --simple ───────────── +# +# extract exposes the MultiDiGraph build publicly. Default is STICKY: a default +# re-extract inherits the existing graph.json profile (a multigraph stays a +# multigraph). --multigraph forces a keyed MultiDiGraph; --simple is the explicit, +# warned, lossy downgrade. Capability failures surface as a clean CLI error. + + +def test_extract_simple_default(tmp_path): + """No flag on a fresh corpus → a simple graph (historical behavior). + + A fresh corpus has no existing graph.json to inherit, so the sticky default + collapses to the historical simple build: multigraph:false / graph_type simple. + """ + corpus = _make_code_corpus(tmp_path) + env = _clean_env() + env["ANTHROPIC_API_KEY"] = "sk-test-fake-key" # code-only corpus never calls the LLM + r = _run(["extract", str(corpus), "--backend", "claude"], tmp_path, env=env) + assert r.returncode == 0, f"fresh simple extract should succeed: {r.stderr}" + + graph_json = corpus / "graphify-out" / "graph.json" + assert graph_json.exists(), f"graph.json must be written: {r.stderr}" + data = json.loads(graph_json.read_text(encoding="utf-8")) + assert data.get("multigraph") is False, "default fresh build must be a simple graph" + assert _graph_type(data) == "simple" + + +def test_extract_multigraph_flag(tmp_path): + """`extract --multigraph` → graph.json is a keyed MultiDiGraph. + + Real end-to-end CLI subprocess: multigraph:true + graphify_profile.graph_type + == "multidigraph", and it reloads as an actual nx.MultiDiGraph. + """ + corpus = _make_code_corpus(tmp_path) + env = _clean_env() + env["ANTHROPIC_API_KEY"] = "sk-test-fake-key" + r = _run(["extract", str(corpus), "--backend", "claude", "--multigraph"], tmp_path, env=env) + assert r.returncode == 0, f"extract --multigraph should succeed: {r.stderr}" + + graph_json = corpus / "graphify-out" / "graph.json" + data = json.loads(graph_json.read_text(encoding="utf-8")) + assert data.get("multigraph") is True, "--multigraph must produce a multigraph graph.json" + assert data.get("directed") is True, "a MultiDiGraph is always directed" + assert _graph_type(data) == "multidigraph" + # Reloads as a real MultiDiGraph. + G = load_graph(data) + assert G.is_multigraph(), "graph.json must reload as a MultiDiGraph" + + +def test_extract_multigraph_then_update_sticky(tmp_path): + """`extract --multigraph`, then default re-extract/update STAYS multigraph. + + The second build is run WITHOUT any flag 3 times in a row; the profile must + stay multidigraph each time (idempotence-under-repeat), with the keyed + parallel-edge capability intact — never a silent collapse to simple. + """ + corpus = _make_code_corpus(tmp_path) + env = _clean_env() + env["ANTHROPIC_API_KEY"] = "sk-test-fake-key" + + r0 = _run(["extract", str(corpus), "--backend", "claude", "--multigraph"], tmp_path, env=env) + assert r0.returncode == 0, f"initial --multigraph extract failed: {r0.stderr}" + graph_json = corpus / "graphify-out" / "graph.json" + + # Seed two parallel main->helper edges so we can prove parallels persist. + _write_multidigraph_graph_json(corpus) + seeded = json.loads(graph_json.read_text(encoding="utf-8")) + assert seeded.get("multigraph") is True + assert len(_parallel_edges(seeded, "main", "helper")) == 2 + + # Default re-extract (NO flag) 3×; sticky must keep it multigraph every time. + for attempt in range(1, 4): + r = _run(["extract", str(corpus), "--backend", "claude"], tmp_path, env=env) + assert r.returncode == 0, f"sticky re-extract #{attempt} failed: {r.stderr}" + data = json.loads(graph_json.read_text(encoding="utf-8")) + assert data.get("multigraph") is True, ( + f"re-extract #{attempt} must STAY multigraph (sticky), " + f"got multigraph={data.get('multigraph')!r}" + ) + assert _graph_type(data) == "multidigraph", f"re-extract #{attempt} profile drifted" + # Parallel edges are not collapsed away by the sticky rebuild. + par = _parallel_edges(data, "main", "helper") + assert len(par) == 2, f"re-extract #{attempt} must preserve keyed parallel edges" + assert sorted(e["relation"] for e in par) == ["calls", "imports"] + # Reloads as a MultiDiGraph with the parallels intact. + G = load_graph(data) + assert G.is_multigraph() + assert G.number_of_edges("main", "helper") == 2 + + # A default `update` (the watch entrypoint) also stays multigraph. + ru = _run(["update", str(corpus)], tmp_path, env=env) + assert ru.returncode == 0, f"sticky update failed: {ru.stderr}" + after_update = json.loads(graph_json.read_text(encoding="utf-8")) + assert after_update.get("multigraph") is True, "update must inherit the multigraph profile" + assert _graph_type(after_update) == "multidigraph" + + +def test_extract_multigraph_no_cluster_sticky_idempotent(tmp_path): + """`--no-cluster` still preserves a sticky multigraph across no-op re-runs. + + A no-cluster incremental scan with no changed files produces an empty fresh + extraction. The command must merge that empty delta with the saved graph, + not overwrite graph.json with zero nodes/edges. + """ + corpus = _make_code_corpus(tmp_path) + env = _clean_env() + env["ANTHROPIC_API_KEY"] = "sk-test-fake-key" + + r0 = _run( + ["extract", str(corpus), "--backend", "claude", "--multigraph", "--no-cluster"], + tmp_path, + env=env, + ) + assert r0.returncode == 0, f"initial no-cluster --multigraph failed: {r0.stderr}" + + graph_json = corpus / "graphify-out" / "graph.json" + first = json.loads(graph_json.read_text(encoding="utf-8")) + first_nodes = len(first.get("nodes", [])) + first_edges = len(first.get("links", first.get("edges", []))) + assert first.get("multigraph") is True + assert _graph_type(first) == "multidigraph" + assert first_nodes > 0 + assert first_edges > 0 + + for attempt in range(1, 4): + r = _run( + ["extract", str(corpus), "--backend", "claude", "--no-cluster"], + tmp_path, + env=env, + ) + assert r.returncode == 0, f"sticky no-cluster re-extract #{attempt} failed: {r.stderr}" + data = json.loads(graph_json.read_text(encoding="utf-8")) + assert data.get("multigraph") is True + assert _graph_type(data) == "multidigraph" + assert len(data.get("nodes", [])) == first_nodes + assert len(data.get("links", data.get("edges", []))) == first_edges + + +def test_extract_explicit_simple_downgrade_warns(tmp_path): + """Existing multigraph graph.json + `extract --simple` → builds simple AND warns. + + The downgrade collapses parallel edges, so it requires explicit intent and a + loud lossy-collapse WARNING — never a silent collapse. A manifest is seeded so + the run takes the incremental (preserve+merge) path, where the existing + multigraph's parallel edges are loaded and then collapsed under the simple + target — the real lossy projection we want to prove. + """ + from graphify.detect import save_manifest + + corpus = _make_code_corpus(tmp_path) + graph_json = _write_multidigraph_graph_json(corpus) + out = corpus / "graphify-out" + save_manifest( + {"code": [str(corpus / "app.py")]}, + manifest_path=str(out / "manifest.json"), + kind="both", + ) + before = json.loads(graph_json.read_text(encoding="utf-8")) + assert before.get("multigraph") is True + assert len(_parallel_edges(before, "main", "helper")) == 2 + + env = _clean_env() + env["ANTHROPIC_API_KEY"] = "sk-test-fake-key" + r = _run(["extract", str(corpus), "--backend", "claude", "--simple"], tmp_path, env=env) + assert r.returncode == 0, f"--simple downgrade should succeed: {r.stderr}" + # Lossy-collapse WARNING must be printed (explicit, audible downgrade). + assert "WARNING" in r.stderr and "--simple" in r.stderr, ( + f"explicit --simple downgrade must warn about lossy collapse, got: {r.stderr}" + ) + assert "collaps" in r.stderr.lower() + + after = json.loads(graph_json.read_text(encoding="utf-8")) + assert after.get("multigraph") is False, "--simple must produce a non-multigraph graph" + assert _graph_type(after) != "multidigraph" + # The two parallel edges from the seeded multigraph collapse onto a single + # main->helper edge (the lossy projection — one survivor, not two parallels). + assert len(_parallel_edges(after, "main", "helper")) == 1, ( + "explicit --simple must collapse the existing parallel edges onto one" + ) + + +def test_extract_multigraph_capability_failure_message(monkeypatch, tmp_path, capsys): + """A MultiDiGraph capability failure surfaces as a clean CLI error, exit 1. + + The probe RuntimeError must be caught and printed (no traceback), and no + graph.json may be written. Run in-process so we can monkeypatch the probe. + """ + corpus = _make_code_corpus(tmp_path) + monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-test-fake-key") + + def _boom(): + raise RuntimeError( + "error: --multigraph requires NetworkX keyed MultiDiGraph node-link " + "round-trip support. Simulated capability failure." + ) + + # Patch where the extract handler imports it from. + monkeypatch.setattr("graphify.multigraph_compat.require_multigraph_capabilities", _boom) + monkeypatch.setattr(mainmod, "_check_skill_version", lambda _: None) + monkeypatch.setattr( + mainmod.sys, + "argv", + ["graphify", "extract", str(corpus), "--backend", "claude", "--multigraph"], + ) + + with pytest.raises(SystemExit) as exc_info: + mainmod.main() + assert exc_info.value.code == 1, f"capability failure must exit 1, got {exc_info.value.code}" + + err = capsys.readouterr().err + assert "--multigraph requires" in err, f"clean capability message expected, got: {err}" + assert "Traceback" not in err, "capability failure must not leak a traceback" + assert not (corpus / "graphify-out" / "graph.json").exists(), ( + "no graph.json may be written when the capability gate fails" + ) + + +def test_extract_multigraph_query_roundtrip(tmp_path, capsys, monkeypatch): + """End-to-end public workflow: a multigraph corpus with same-endpoint different + relations exposes the parallel relationships through the public query/path path. + + Builds the multigraph graph.json the way `extract --multigraph` persists it, + then runs `graphify path` (a public query surface) and asserts BOTH parallel + relations show — the parallel relationships are visible, not collapsed. + """ + corpus = _make_code_corpus(tmp_path) + graph_json = _write_multidigraph_graph_json(corpus) + + # Sanity: the persisted file is a multidigraph with both parallel relations. + data = json.loads(graph_json.read_text(encoding="utf-8")) + assert data.get("multigraph") is True + G = load_graph(data) + assert G.is_multigraph() and G.number_of_edges("main", "helper") == 2 + + # Public query surface: `graphify path main helper` bundles all relations. + monkeypatch.setattr(mainmod, "_check_skill_version", lambda _: None) + monkeypatch.setattr( + mainmod.sys, + "argv", + ["graphify", "path", "main", "helper", "--graph", str(graph_json)], + ) + mainmod.main() + out = capsys.readouterr().out + assert "calls" in out, f"parallel 'calls' relation must appear in path output: {out}" + assert "imports" in out, f"parallel 'imports' relation must appear in path output: {out}" def _make_corpus(tmp_path): @@ -17,9 +377,7 @@ def _make_corpus(tmp_path): return tmp_path -def test_extract_exits_nonzero_when_all_semantic_chunks_fail( - monkeypatch, tmp_path, capsys -): +def test_extract_exits_nonzero_when_all_semantic_chunks_fail(monkeypatch, tmp_path, capsys): """When every semantic chunk errors (e.g. backend SDK not installed), the CLI must exit non-zero instead of silently writing an AST-only graph. @@ -48,23 +406,19 @@ def _all_chunks_failed(paths, **kwargs): "output_tokens": 0, } - monkeypatch.setattr( - "graphify.llm.extract_corpus_parallel", _all_chunks_failed - ) + monkeypatch.setattr("graphify.llm.extract_corpus_parallel", _all_chunks_failed) monkeypatch.setattr(mainmod, "_check_skill_version", lambda _: None) monkeypatch.setattr( mainmod.sys, "argv", - ["graphify", "extract", str(corpus), "--backend", "claude", - "--out", str(out_dir)], + ["graphify", "extract", str(corpus), "--backend", "claude", "--out", str(out_dir)], ) with pytest.raises(SystemExit) as exc_info: mainmod.main() assert exc_info.value.code == 1, ( - f"expected exit code 1 when all semantic chunks fail, " - f"got {exc_info.value.code}" + f"expected exit code 1 when all semantic chunks fail, got {exc_info.value.code}" ) stderr = capsys.readouterr().err @@ -78,9 +432,7 @@ def _all_chunks_failed(paths, **kwargs): ) -def test_extract_succeeds_when_at_least_one_chunk_completes( - monkeypatch, tmp_path -): +def test_extract_succeeds_when_at_least_one_chunk_completes(monkeypatch, tmp_path): """Sanity counter-test: a successful chunk run keeps exit 0. Confirms the new guard only fires on the all-failed path, not on every extract.""" corpus = _make_corpus(tmp_path) @@ -99,15 +451,12 @@ def _one_chunk_succeeded(paths, **kwargs): "output_tokens": 50, } - monkeypatch.setattr( - "graphify.llm.extract_corpus_parallel", _one_chunk_succeeded - ) + monkeypatch.setattr("graphify.llm.extract_corpus_parallel", _one_chunk_succeeded) monkeypatch.setattr(mainmod, "_check_skill_version", lambda _: None) monkeypatch.setattr( mainmod.sys, "argv", - ["graphify", "extract", str(corpus), "--backend", "claude", - "--out", str(out_dir)], + ["graphify", "extract", str(corpus), "--backend", "claude", "--out", str(out_dir)], ) # extract may still raise SystemExit at the end (clean exit code 0) @@ -121,3 +470,211 @@ def _one_chunk_succeeded(paths, **kwargs): assert (out_dir / "graphify-out" / "graph.json").exists(), ( "graph.json must be written on the happy path" ) + + +def test_extract_no_cluster_refuses_to_zero_populated_graph(monkeypatch, tmp_path, capsys): + """RISK 4 — Guard 3: the non-incremental no-cluster simple path must NOT wipe a + populated graph.json with a 0-node extraction. + + The bug: with an existing populated (simple) graph.json but NO manifest.json + (so the run is non-incremental) the ``--no-cluster`` branch falls to the raw + ``graph_json_path.write_text(json.dumps(merged, ...))`` ``else`` case. That raw + write bypasses both existing empty-merge guards (``export.to_json`` / + ``watch._check_shrink``). When AST extraction aborts (returns 0 nodes) the raw + write overwrites the saved graph with an EMPTY one — a failed extraction + silently destroys real data. The clustered sibling already guards this with + ``if G.number_of_nodes() == 0: ... sys.exit(1)``; the no-cluster simple path + must do the same. The command must instead exit non-zero, print the byte- + identical guard message, and leave the populated graph.json untouched. + """ + corpus = _make_code_corpus(tmp_path) + out = corpus / "graphify-out" + out.mkdir(exist_ok=True) + graph_json = out / "graph.json" + + # Seed a POPULATED *simple* graph.json the way the pipeline persists it + # (build_from_json default-simple -> to_json). Simple (not multigraph) so the + # sticky profile resolves to non-multigraph and the run takes the raw-write + # ``else`` branch — exactly the unguarded site. NO manifest.json is written, + # so the run is non-incremental (the path the incremental build_merge floor + # never protects). + seed_nodes = [ + { + "id": n, + "label": f"{n}()", + "file_type": "code", + "source_file": "app.py", + "source_location": "L1", + } + for n in ("main", "helper", "extra") + ] + seed_edges = [ + { + "source": "main", + "target": "helper", + "relation": "calls", + "confidence": "EXTRACTED", + "source_file": "app.py", + "source_location": "L5", + } + ] + G_seed = build_from_json({"nodes": seed_nodes, "edges": seed_edges}) + assert not G_seed.is_multigraph(), "seed must be a simple graph (non-multigraph)" + to_json(G_seed, {0: ["main", "helper", "extra"]}, str(graph_json), force=True) + before = json.loads(graph_json.read_text(encoding="utf-8")) + seeded_n = len(before.get("nodes", [])) + assert seeded_n == 3, "seed graph.json must start populated with 3 nodes" + assert before.get("multigraph") is False, "seed graph.json must be simple" + assert not (out / "manifest.json").exists(), "no manifest → non-incremental run" + + # Force the AST extraction to abort so the merged extraction yields 0 nodes. + # This mirrors the real trigger (a parser/extractor blowing up): the extract + # handler's ``except`` resets ast_result to an empty dict, and a code-only + # corpus has no semantic pass, so ``merged`` collapses to 0 nodes. The extract + # handler imports ``extract`` from graphify.extract at call time, so patching + # the source symbol is picked up. + def _ast_boom(paths, **kwargs): + raise RuntimeError("simulated AST extractor failure (parser crash)") + + import graphify.extract as _extract_mod + + monkeypatch.setattr(_extract_mod, "extract", _ast_boom) + monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-test-fake-key") # code-only: LLM never called + monkeypatch.setattr(mainmod, "_check_skill_version", lambda _: None) + monkeypatch.setattr( + mainmod.sys, + "argv", + ["graphify", "extract", str(corpus), "--backend", "claude", "--no-cluster"], + ) + + with pytest.raises(SystemExit) as exc_info: + mainmod.main() + assert exc_info.value.code == 1, ( + f"a 0-node no-cluster extraction over a populated graph must exit 1, " + f"got {exc_info.value.code}" + ) + + err = capsys.readouterr().err + # Byte-identical to the Guard 1 / Guard 2 message. + assert ( + f"[graphify] ERROR: refusing to overwrite a populated graph.json " + f"({seeded_n} nodes) with an EMPTY (0-node) graph - this is a " + f"failed/aborted extraction, not a real result. The previous graph " + f"is preserved." in err + ), f"guard message must match Guards 1/2 byte-for-byte, got: {err!r}" + + # The populated graph.json must be PRESERVED — not wiped to an empty graph. + after = json.loads(graph_json.read_text(encoding="utf-8")) + assert len(after.get("nodes", [])) == seeded_n, ( + "the populated graph.json must NOT be overwritten with a 0-node graph" + ) + + +def test_extract_no_cluster_incremental_zero_merge_exits_nonzero_and_preserves_graph( + monkeypatch, tmp_path, capsys +): + """RISK 4 — Guard 1 signaling gap: the INCREMENTAL no-cluster path must SIGNAL + failure (exit non-zero, no false-success line) when the merge yields 0 nodes. + + The incremental no-cluster branch writes through + ``to_json(_nc_graph, {}, ..., force=True)`` (Guard 1). When ``build_merge`` + collapses to a 0-node graph over a populated graph.json, Guard 1's empty-merge + floor correctly *returns False and PRESERVES the data* — but the caller ignored + that return value: it fell through, printed the success line + ``[graphify extract] wrote ... graph.json — 0 nodes, 0 edges (no clustering)`` + and exited 0. The data was safe, but a failed/aborted extraction reported a + misleading false success (wrong exit code + message). + + The fix captures Guard 1's ``False`` return at the no-cluster incremental write + site and, on refusal only, emits an aborted-extraction stderr note and exits 1 + — never the bogus "wrote ... 0 nodes" success line. A populated graph.json plus + a manifest.json makes the run incremental; ``build_merge`` is forced to yield an + empty graph to model the aborted/pruned-to-empty merge. The legitimate sticky + no-cluster case (``test_extract_multigraph_no_cluster_sticky_idempotent``) keeps + exit 0 because ``build_merge`` preserves the existing nodes there (True return). + """ + from graphify.detect import save_manifest + + corpus = _make_code_corpus(tmp_path) + out = corpus / "graphify-out" + out.mkdir(exist_ok=True) + graph_json = out / "graph.json" + + # Seed a POPULATED *simple* graph.json the way the pipeline persists it. + seed_nodes = [ + { + "id": n, + "label": f"{n}()", + "file_type": "code", + "source_file": "app.py", + "source_location": "L1", + } + for n in ("main", "helper", "extra") + ] + seed_edges = [ + { + "source": "main", + "target": "helper", + "relation": "calls", + "confidence": "EXTRACTED", + "source_file": "app.py", + "source_location": "L5", + } + ] + G_seed = build_from_json({"nodes": seed_nodes, "edges": seed_edges}) + assert not G_seed.is_multigraph(), "seed must be a simple graph (non-multigraph)" + to_json(G_seed, {0: ["main", "helper", "extra"]}, str(graph_json), force=True) + + # A manifest.json alongside the populated graph.json makes the run INCREMENTAL + # (incremental_mode = manifest.exists() and graph.json.exists()), so the write + # routes through the incremental ``to_json(..., force=True)`` site, not the + # raw-write else-branch the Guard 3 sibling covers. + save_manifest( + {"code": [str(corpus / "app.py")]}, + manifest_path=str(out / "manifest.json"), + kind="both", + ) + + before = json.loads(graph_json.read_text(encoding="utf-8")) + seeded_n = len(before.get("nodes", [])) + assert seeded_n == 3, "seed graph.json must start populated with 3 nodes" + assert before.get("multigraph") is False, "seed graph.json must be simple" + assert (out / "manifest.json").exists(), "manifest → incremental run" + + # Force the incremental merge to yield a 0-node graph (aborted / pruned-to-empty + # extraction). The no-cluster incremental branch imports build_merge from + # graphify.build at call time, so patching the source symbol is picked up. + def _empty_merge(*args, **kwargs): + return build_from_json({"nodes": [], "edges": []}) + + import graphify.build as _build_mod + + monkeypatch.setattr(_build_mod, "build_merge", _empty_merge) + monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-test-fake-key") # code-only: LLM never called + monkeypatch.setattr(mainmod, "_check_skill_version", lambda _: None) + monkeypatch.setattr( + mainmod.sys, + "argv", + ["graphify", "extract", str(corpus), "--backend", "claude", "--no-cluster"], + ) + + with pytest.raises(SystemExit) as exc_info: + mainmod.main() + assert exc_info.value.code == 1, ( + f"a 0-node incremental no-cluster merge over a populated graph must exit 1, " + f"got {exc_info.value.code}" + ) + + captured = capsys.readouterr() + # The misleading false-success line must NOT be printed. + assert "0 nodes, 0 edges" not in captured.out, ( + f"a 0-node aborted merge must NOT print the 'wrote ... 0 nodes' success " + f"line, got stdout: {captured.out!r}" + ) + + # The populated graph.json must be PRESERVED — not wiped to an empty graph. + after = json.loads(graph_json.read_text(encoding="utf-8")) + assert len(after.get("nodes", [])) == seeded_n, ( + "the populated graph.json must NOT be overwritten with a 0-node graph" + ) + assert after == before, "graph.json must be byte-for-byte unchanged after the refused write" diff --git a/tests/test_global_graph.py b/tests/test_global_graph.py index f40d9c6d5..ee8176946 100644 --- a/tests/test_global_graph.py +++ b/tests/test_global_graph.py @@ -1,23 +1,29 @@ """Tests for the global graph infrastructure (graphify/global_graph.py), prefix/prune helpers in graphify/build.py, and the cross-repo guard in graphify/dedup.py.""" + from __future__ import annotations import json +from contextlib import contextmanager + import pytest import networkx as nx from unittest.mock import patch +import graphify.__main__ as mainmod + # ── helpers ────────────────────────────────────────────────────────────────── + def _make_graph(nodes, edges=None): """Build a simple nx.Graph from node dicts.""" G = nx.Graph() for n in nodes: nid = n["id"] G.add_node(nid, **{k: v for k, v in n.items() if k != "id"}) - for e in (edges or []): + for e in edges or []: G.add_edge( e["source"], e["target"], @@ -28,6 +34,7 @@ def _make_graph(nodes, edges=None): def _graph_to_json(G, path): from networkx.readwrite import json_graph as jg + try: data = jg.node_link_data(G, edges="links") except TypeError: @@ -35,10 +42,51 @@ def _graph_to_json(G, path): path.write_text(json.dumps(data), encoding="utf-8") +def _make_multidigraph(nodes, edges): + """Build an nx.MultiDiGraph from node dicts and keyed edge dicts. + + Each edge dict must carry a ``key`` so parallel edges between the same + (source, target) survive the build and the node_link round-trip. + """ + G = nx.MultiDiGraph() + for n in nodes: + nid = n["id"] + G.add_node(nid, **{k: v for k, v in n.items() if k != "id"}) + for e in edges: + G.add_edge( + e["source"], + e["target"], + key=e["key"], + **{k: v for k, v in e.items() if k not in ("source", "target", "key")}, + ) + return G + + +@contextmanager +def _patch_global(global_dir): + """Single context manager that points global_graph at a temp dir. + + Patches ``_GLOBAL_DIR`` / ``_GLOBAL_GRAPH`` / ``_GLOBAL_MANIFEST`` for the + duration of the ``with`` block, mirroring the inline triple-patch the older + tests use, so the PR 8 tests can ``with _patch_global(tmp / ".graphify"):``. + """ + with ( + patch("graphify.global_graph._GLOBAL_DIR", global_dir), + patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), + patch( + "graphify.global_graph._GLOBAL_MANIFEST", + global_dir / "global-manifest.json", + ), + ): + yield + + # ── build.py helpers ────────────────────────────────────────────────────────── + def test_prefix_graph_preserves_label(): from graphify.build import prefix_graph_for_global + G = _make_graph([{"id": "userservice", "label": "UserService", "source_file": "src/user.py"}]) H = prefix_graph_for_global(G, "repoA") assert "repoA::userservice" in H.nodes @@ -48,6 +96,7 @@ def test_prefix_graph_preserves_label(): def test_prefix_graph_sets_repo_and_local_id(): from graphify.build import prefix_graph_for_global + G = _make_graph([{"id": "userservice", "label": "UserService"}]) H = prefix_graph_for_global(G, "repoA") data = H.nodes["repoA::userservice"] @@ -57,6 +106,7 @@ def test_prefix_graph_sets_repo_and_local_id(): def test_prefix_graph_rewrites_edges(): from graphify.build import prefix_graph_for_global + G = _make_graph( [{"id": "a", "label": "A"}, {"id": "b", "label": "B"}], [{"source": "a", "target": "b"}], @@ -68,6 +118,7 @@ def test_prefix_graph_rewrites_edges(): def test_prune_repo_removes_correct_nodes(): from graphify.build import prune_repo_from_graph + G = nx.Graph() G.add_node("repoA::userservice", repo="repoA", label="UserService") G.add_node("repoB::userservice", repo="repoB", label="UserService") @@ -81,6 +132,7 @@ def test_prune_repo_removes_correct_nodes(): def test_prune_repo_returns_zero_if_not_present(): from graphify.build import prune_repo_from_graph + G = nx.Graph() G.add_node("repoA::x", repo="repoA") removed = prune_repo_from_graph(G, "repoB") @@ -90,16 +142,20 @@ def test_prune_repo_returns_zero_if_not_present(): # ── global_graph.py ─────────────────────────────────────────────────────────── + def test_global_add_creates_global_graph(tmp_path): src_graph = tmp_path / "graph.json" G = _make_graph([{"id": "userservice", "label": "UserService", "source_file": "src/user.py"}]) _graph_to_json(G, src_graph) global_dir = tmp_path / ".graphify" - with patch("graphify.global_graph._GLOBAL_DIR", global_dir), \ - patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), \ - patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"): + with ( + patch("graphify.global_graph._GLOBAL_DIR", global_dir), + patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), + patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"), + ): from graphify.global_graph import global_add + result = global_add(src_graph, "repoA") assert result["skipped"] is False @@ -116,10 +172,13 @@ def test_global_add_skip_on_unchanged_hash(tmp_path): _graph_to_json(G, src_graph) global_dir = tmp_path / ".graphify" - with patch("graphify.global_graph._GLOBAL_DIR", global_dir), \ - patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), \ - patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"): + with ( + patch("graphify.global_graph._GLOBAL_DIR", global_dir), + patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), + patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"), + ): from graphify.global_graph import global_add + global_add(src_graph, "repoA") result2 = global_add(src_graph, "repoA") @@ -137,10 +196,13 @@ def test_global_add_two_repos_no_collision(tmp_path): global_dir = tmp_path / ".graphify" global_graph_path = global_dir / "global-graph.json" global_manifest_path = global_dir / "global-manifest.json" - with patch("graphify.global_graph._GLOBAL_DIR", global_dir), \ - patch("graphify.global_graph._GLOBAL_GRAPH", global_graph_path), \ - patch("graphify.global_graph._GLOBAL_MANIFEST", global_manifest_path): + with ( + patch("graphify.global_graph._GLOBAL_DIR", global_dir), + patch("graphify.global_graph._GLOBAL_GRAPH", global_graph_path), + patch("graphify.global_graph._GLOBAL_MANIFEST", global_manifest_path), + ): from graphify.global_graph import global_add, _load_global_graph + global_add(g1, "repoA") global_add(g2, "repoB") G = _load_global_graph() @@ -156,30 +218,64 @@ def test_global_remove(tmp_path): _graph_to_json(G, src_graph) global_dir = tmp_path / ".graphify" - with patch("graphify.global_graph._GLOBAL_DIR", global_dir), \ - patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), \ - patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"): + with ( + patch("graphify.global_graph._GLOBAL_DIR", global_dir), + patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), + patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"), + ): from graphify.global_graph import global_add, global_remove + global_add(src_graph, "repoA") removed = global_remove("repoA") assert removed > 0 # manifest should no longer list repoA - need to re-patch for list call global_dir2 = global_dir # same dir - with patch("graphify.global_graph._GLOBAL_DIR", global_dir2), \ - patch("graphify.global_graph._GLOBAL_GRAPH", global_dir2 / "global-graph.json"), \ - patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir2 / "global-manifest.json"): + with ( + patch("graphify.global_graph._GLOBAL_DIR", global_dir2), + patch("graphify.global_graph._GLOBAL_GRAPH", global_dir2 / "global-graph.json"), + patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir2 / "global-manifest.json"), + ): from graphify.global_graph import global_list + repos = global_list() assert "repoA" not in repos +def test_global_remove_backs_up_before_overwrite(tmp_path): + """Removing a repo mutates global-graph.json, so recovery policy requires a backup.""" + src_graph = tmp_path / "graph.json" + G = _make_graph([{"id": "userservice", "label": "UserService", "source_file": "src/user.py"}]) + _graph_to_json(G, src_graph) + + global_dir = tmp_path / ".graphify" + global_graph_path = global_dir / "global-graph.json" + with ( + patch("graphify.global_graph._GLOBAL_DIR", global_dir), + patch("graphify.global_graph._GLOBAL_GRAPH", global_graph_path), + patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"), + ): + from graphify.global_graph import global_add, global_remove + + global_add(src_graph, "repoA") + before_remove = global_graph_path.read_bytes() + removed = global_remove("repoA") + + assert removed > 0 + backups = list(global_dir.glob("global-graph.*.bak")) + assert len(backups) == 1 + assert backups[0].read_bytes() == before_remove + + def test_global_remove_unknown_tag_raises(tmp_path): global_dir = tmp_path / ".graphify" - with patch("graphify.global_graph._GLOBAL_DIR", global_dir), \ - patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), \ - patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"): + with ( + patch("graphify.global_graph._GLOBAL_DIR", global_dir), + patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), + patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"), + ): from graphify.global_graph import global_remove + with pytest.raises(KeyError): global_remove("nonexistent") @@ -192,10 +288,13 @@ def test_global_add_collision_warning(tmp_path, capsys): _graph_to_json(G, g2) global_dir = tmp_path / ".graphify" - with patch("graphify.global_graph._GLOBAL_DIR", global_dir), \ - patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), \ - patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"): + with ( + patch("graphify.global_graph._GLOBAL_DIR", global_dir), + patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), + patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"), + ): from graphify.global_graph import global_add + global_add(g1, "myrepo") global_add(g2, "myrepo") # different source path, same tag @@ -205,8 +304,10 @@ def test_global_add_collision_warning(tmp_path, capsys): # ── dedup guard ─────────────────────────────────────────────────────────────── + def test_dedup_raises_on_cross_repo_nodes(): from graphify.dedup import deduplicate_entities + nodes = [ {"id": "repoA::userservice", "label": "UserService", "repo": "repoA"}, {"id": "repoB::userservice", "label": "UserService", "repo": "repoB"}, @@ -217,6 +318,7 @@ def test_dedup_raises_on_cross_repo_nodes(): def test_dedup_ok_with_single_repo(): from graphify.dedup import deduplicate_entities + nodes = [ {"id": "repoA::userservice", "label": "UserService", "repo": "repoA"}, {"id": "repoA::auth", "label": "Auth", "repo": "repoA"}, @@ -227,6 +329,7 @@ def test_dedup_ok_with_single_repo(): def test_dedup_ok_with_no_repo_attr(): from graphify.dedup import deduplicate_entities + nodes = [ {"id": "userservice", "label": "UserService"}, {"id": "auth", "label": "Auth"}, @@ -237,6 +340,7 @@ def test_dedup_ok_with_no_repo_attr(): # ── merge-graphs prefix ─────────────────────────────────────────────────────── + def test_merge_graphs_prefixes_ids(tmp_path): """merge-graphs should prefix node IDs with repo name to avoid silent collision.""" from graphify.build import prefix_graph_for_global @@ -290,9 +394,883 @@ def test_global_add_rejects_oversized_source_graph(monkeypatch, tmp_path): global_dir = tmp_path / ".graphify" monkeypatch.setattr("graphify.security._MAX_GRAPH_FILE_BYTES", 8) - with patch("graphify.global_graph._GLOBAL_DIR", global_dir), \ - patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), \ - patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"): + with ( + patch("graphify.global_graph._GLOBAL_DIR", global_dir), + patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), + patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"), + ): from graphify.global_graph import global_add + with pytest.raises(ValueError, match="exceeds"): global_add(src_graph, "repoA") + + +# ── PR 8: keyed/class-normalized composition, recovery, backup ───────────────── + + +def _PARALLEL_NODES(): + return [ + {"id": "a", "label": "A", "source_file": "src/a.py"}, + {"id": "b", "label": "B", "source_file": "src/b.py"}, + ] + + +def _PARALLEL_EDGES(): + return [ + {"source": "a", "target": "b", "key": "calls:L1", "relation": "calls"}, + {"source": "a", "target": "b", "key": "imports:L2", "relation": "imports"}, + ] + + +def test_global_add_multidigraph_preserves_parallel_edges(tmp_path): + """A MultiDiGraph source with parallel edges keeps every keyed edge in the + global graph, which reloads as a MultiDiGraph (no keyless collapse).""" + src_graph = tmp_path / "graph.json" + M = _make_multidigraph(_PARALLEL_NODES(), _PARALLEL_EDGES()) + _graph_to_json(M, src_graph) + + global_dir = tmp_path / ".graphify" + with _patch_global(global_dir): + from graphify.global_graph import global_add, _load_global_graph + + result = global_add(src_graph, "repoA") + G = _load_global_graph() + + assert result["skipped"] is False + assert isinstance(G, nx.MultiDiGraph) + assert G.number_of_edges("repoA::a", "repoA::b") == 2 + assert sorted(G["repoA::a"]["repoA::b"].keys()) == ["calls:L1", "imports:L2"] + assert G.graph["graphify_profile"]["graph_type"] == "multidigraph" + + +def test_global_add_multidigraph_idempotent_under_repeat(tmp_path): + """THE PR-7-lesson test: running global_add of the SAME multigraph repo 3 + times keeps the parallel-edge count, edge keys, and stored profile identical + after every run — no duplication, no drift, no re-collapse. + + To prove the KEYED COMPOSE itself is idempotent (not merely the hash-skip + short-circuit), the second repo's source is mutated to a fresh hash on every + iteration while the FIRST repo (repoA) survives the prune and is re-composed + through the keyed edge path each time. repoA's parallel edges must stay + rock-stable across all three forced re-composes.""" + repo_a_src = tmp_path / "a.json" + M_a = _make_multidigraph(_PARALLEL_NODES(), _PARALLEL_EDGES()) + _graph_to_json(M_a, repo_a_src) + + global_dir = tmp_path / ".graphify" + with _patch_global(global_dir): + from graphify.global_graph import global_add, _load_global_graph + + global_add(repo_a_src, "repoA") + + observed = [] + for i in range(3): + # Distinct second repo each iteration → forces a real re-compose that + # re-runs the keyed edge loop over the surviving repoA subgraph. + churn_src = tmp_path / f"churn_{i}.json" + M_b = _make_multidigraph( + [ + {"id": f"c{i}", "label": f"C{i}", "source_file": "src/c.py"}, + {"id": f"d{i}", "label": f"D{i}", "source_file": "src/d.py"}, + ], + [ + {"source": f"c{i}", "target": f"d{i}", "key": "j1", "relation": "calls"}, + {"source": f"c{i}", "target": f"d{i}", "key": "j2", "relation": "uses"}, + ], + ) + _graph_to_json(M_b, churn_src) + global_add(churn_src, "repoB") + + G = _load_global_graph() + observed.append( + ( + G.number_of_edges("repoA::a", "repoA::b"), + tuple(sorted(G["repoA::a"]["repoA::b"].keys())), + G.graph["graphify_profile"]["graph_type"], + ) + ) + + # IDEMPOTENCE ASSERTION: parallel-edge count (2), keys, and profile identical + # after each of the three repeated global_add re-composes — no drift. + assert observed == [ + (2, ("calls:L1", "imports:L2"), "multidigraph"), + (2, ("calls:L1", "imports:L2"), "multidigraph"), + (2, ("calls:L1", "imports:L2"), "multidigraph"), + ] + + +def test_global_add_mixed_simple_and_multi_no_crash(tmp_path): + """One simple repo + one multi repo must not crash through a NetworkX class + mismatch; the global target upgrades to multidigraph and both repos' edges + are present (the multi repo keyed).""" + simple_src = tmp_path / "simple.json" + S = _make_graph( + [ + {"id": "x", "label": "X", "source_file": "src/x.py"}, + {"id": "y", "label": "Y", "source_file": "src/y.py"}, + ], + [{"source": "x", "target": "y", "relation": "calls"}], + ) + _graph_to_json(S, simple_src) + + multi_src = tmp_path / "multi.json" + M = _make_multidigraph(_PARALLEL_NODES(), _PARALLEL_EDGES()) + _graph_to_json(M, multi_src) + + global_dir = tmp_path / ".graphify" + with _patch_global(global_dir): + from graphify.global_graph import global_add, _load_global_graph + + global_add(simple_src, "repoSimple") + # Composing a multi repo into the existing simple global graph must not + # raise "All graphs must be directed or undirected." + global_add(multi_src, "repoMulti") + G = _load_global_graph() + + assert isinstance(G, nx.MultiDiGraph) + assert G.graph["graphify_profile"]["graph_type"] == "multidigraph" + # simple repo's single edge survives (folded into the multigraph) + assert G.has_edge("repoSimple::x", "repoSimple::y") + # multi repo's parallel edges survive distinctly, keyed + assert G.number_of_edges("repoMulti::a", "repoMulti::b") == 2 + assert sorted(G["repoMulti::a"]["repoMulti::b"].keys()) == ["calls:L1", "imports:L2"] + + +def test_global_add_simple_only_regression(tmp_path): + """Pure simple inputs produce a simple global graph whose output is unchanged + apart from the new graphify_profile metadata. Repeating twice is identical.""" + g1 = tmp_path / "g1.json" + g2 = tmp_path / "g2.json" + _graph_to_json( + _make_graph( + [ + {"id": "u", "label": "U", "source_file": "src/u.py"}, + {"id": "v", "label": "V", "source_file": "src/v.py"}, + ], + [{"source": "u", "target": "v", "relation": "calls"}], + ), + g1, + ) + _graph_to_json( + _make_graph([{"id": "w", "label": "W", "source_file": "src/w.py"}]), + g2, + ) + + global_dir = tmp_path / ".graphify" + global_graph_path = global_dir / "global-graph.json" + with _patch_global(global_dir): + from graphify.global_graph import global_add, _load_global_graph + + global_add(g1, "repoA") + global_add(g2, "repoB") + G = _load_global_graph() + first_bytes = global_graph_path.read_text(encoding="utf-8") + + # Repeat the same two adds (hash-skip path) → byte-identical output. + global_add(g1, "repoA") + global_add(g2, "repoB") + second_bytes = global_graph_path.read_text(encoding="utf-8") + + # Simple-only stays a simple Graph (not upgraded), profile is "simple". + assert isinstance(G, nx.Graph) + assert not G.is_multigraph() + assert not G.is_directed() + assert G.graph["graphify_profile"]["graph_type"] == "simple" + assert G.has_edge("repoA::u", "repoA::v") + assert "repoB::w" in G.nodes + # Byte-stable across repeated adds (idempotent simple output). + assert first_bytes == second_bytes + + +def test_normalize_graphs_for_global_infers_target(recwarn): + """Mixed inputs infer multidigraph; an explicit simple target on a multi + input warns and projects the multigraph down to simple.""" + from graphify.global_graph import normalize_graphs_for_global + + simple = _make_graph([{"id": "x"}, {"id": "y"}], [{"source": "x", "target": "y"}]) + multi = _make_multidigraph( + [{"id": "a"}, {"id": "b"}], + [ + {"source": "a", "target": "b", "key": "k1"}, + {"source": "a", "target": "b", "key": "k2"}, + ], + ) + + # Inference: any multi input → multidigraph target, no warning, no collapse. + normalized, target = normalize_graphs_for_global([simple, multi]) + assert target == "multidigraph" + assert all(isinstance(g, nx.MultiDiGraph) for g in normalized) + assert normalized[1].number_of_edges("a", "b") == 2 + assert len(recwarn.list) == 0 + + # Explicit simple target with a multi input → WARNING + projection to simple. + with pytest.warns(UserWarning, match="collaps"): + normalized2, target2 = normalize_graphs_for_global([simple, multi], target_type="simple") + assert target2 == "simple" + assert all(type(g) is nx.Graph for g in normalized2) + # Parallel edges collapse to a single (a, b) pair on the simple projection. + assert normalized2[1].number_of_edges() == 1 + + # Unknown target token is rejected. + with pytest.raises(ValueError, match="target_type"): + normalize_graphs_for_global([simple], target_type="bogus") + + +def test_detect_pre_profile_global_graph(): + """A JSON without graphify_profile and without multigraph/directed flags is + detected as pre-profile; any of those markers clears the flag.""" + from graphify.global_graph import detect_pre_profile + + assert detect_pre_profile({"nodes": [{"id": "a"}], "links": []}) is True + # Top-level profile present → not pre-profile. + assert detect_pre_profile({"nodes": [], "links": [], "graphify_profile": {}}) is False + # Nested profile under "graph" → not pre-profile. + assert ( + detect_pre_profile( + {"nodes": [], "links": [], "graph": {"graphify_profile": {"graph_type": "simple"}}} + ) + is False + ) + # Explicit class flags → writer knew the class → not pre-profile. + assert detect_pre_profile({"nodes": [], "links": [], "multigraph": False}) is False + assert detect_pre_profile({"nodes": [], "links": [], "directed": True}) is False + assert detect_pre_profile("not a dict") is False + + +def test_pre_profile_upgrade_refused_with_recovery_message(tmp_path): + """Upgrading a pre-profile global graph to multidigraph refuses with a clear + recovery message and does NOT mutate/destroy the existing global-graph.json.""" + from graphify.global_graph import ( + GlobalGraphRecoveryError, + refuse_pre_profile_upgrade, + ) + + # Direct helper contract: pre-profile + multidigraph target → raises. + pre_profile = {"nodes": [{"id": "a"}], "links": []} + with pytest.raises(GlobalGraphRecoveryError, match="rebuild|remove|backup|pre-profile"): + refuse_pre_profile_upgrade(pre_profile, "multidigraph") + # Non-upgrade targets are allowed (no raise). + refuse_pre_profile_upgrade(pre_profile, "simple") + refuse_pre_profile_upgrade(pre_profile, "digraph") + + # End-to-end through global_add: seed a pre-profile global graph (no profile, + # no flags), then add a multigraph repo → upgrade refused, file untouched. + global_dir = tmp_path / ".graphify" + global_dir.mkdir(parents=True) + global_graph_path = global_dir / "global-graph.json" + pre_profile_disk = { + "nodes": [{"id": "legacy::old", "repo": "legacy", "label": "Old"}], + "links": [], + } + original = json.dumps(pre_profile_disk, indent=2) + global_graph_path.write_text(original, encoding="utf-8") + + multi_src = tmp_path / "multi.json" + _graph_to_json(_make_multidigraph(_PARALLEL_NODES(), _PARALLEL_EDGES()), multi_src) + + with _patch_global(global_dir): + from graphify.global_graph import global_add + + with pytest.raises(GlobalGraphRecoveryError): + global_add(multi_src, "repoMulti") + + # The original pre-profile graph.json must be intact (not overwritten). + assert global_graph_path.read_text(encoding="utf-8") == original + # A recovery backup may have been taken alongside it; that is allowed. + + +def test_global_add_backs_up_before_overwrite(tmp_path): + """A backup snapshot of the prior global-graph.json is created before an + overwrite, and the original content is recoverable from the backup.""" + global_dir = tmp_path / ".graphify" + + g1 = tmp_path / "g1.json" + _graph_to_json(_make_graph([{"id": "u", "label": "U", "source_file": "src/u.py"}]), g1) + g2 = tmp_path / "g2.json" + _graph_to_json(_make_graph([{"id": "w", "label": "W", "source_file": "src/w.py"}]), g2) + + with _patch_global(global_dir): + from graphify.global_graph import global_add + + global_add(g1, "repoA") + first = (global_dir / "global-graph.json").read_text(encoding="utf-8") + # Second add (different repo, different hash) overwrites → backup taken. + global_add(g2, "repoB") + + backups = list(global_dir.glob("global-graph.*.bak")) + assert backups, "expected a dated .bak snapshot before overwrite" + # The backup holds the pre-overwrite (first-add) state, recoverable verbatim. + assert backups[0].read_text(encoding="utf-8") == first + + +def test_backup_global_graph_idempotent(tmp_path): + """Repeated backup_global_graph() calls in the same run do not error and do + not corrupt the snapshot (one dated backup, byte-stable).""" + global_dir = tmp_path / ".graphify" + global_dir.mkdir(parents=True) + global_graph_path = global_dir / "global-graph.json" + content = json.dumps({"nodes": [{"id": "a"}], "links": []}, indent=2) + global_graph_path.write_text(content, encoding="utf-8") + + with _patch_global(global_dir): + from graphify.global_graph import backup_global_graph + + p1 = backup_global_graph() + p2 = backup_global_graph() + p3 = backup_global_graph() + + assert p1 is not None + assert p1 == p2 == p3 # same dated backup path + assert p1.read_text(encoding="utf-8") == content + # Exactly one backup file (no proliferation across repeated calls). + assert len(list(global_dir.glob("global-graph.*.bak"))) == 1 + + +def test_backup_global_graph_none_when_absent(tmp_path): + """backup_global_graph() returns None when there is no global graph to back + up (nothing to snapshot, never raises).""" + global_dir = tmp_path / ".graphify" + with _patch_global(global_dir): + from graphify.global_graph import backup_global_graph + + assert backup_global_graph() is None + + +# ── merge-driver / merge-graphs class normalization (PR 8) ───────────────────── +# +# Both commands run in-process through ``graphify.__main__.main`` with argv +# monkeypatched (env-isolated, mirroring test_extract_cli / test_query_cli). The +# go/no-go gate for PR 8: mixed graph inputs never crash through a NetworkX class +# mismatch AND never silently collapse multigraph input without an explicit +# simple target. Merge is STATEFUL, so every path is also asserted under REPEATED +# application (run 2-3×) to prove idempotence — no duplicated edges, no key drift, +# no profile drift, no re-collapse. + + +def _reload_graph(path): + """Rehydrate a graph.json written by a merge command (handles edges/links).""" + from networkx.readwrite import json_graph as jg + + data = json.loads(path.read_text(encoding="utf-8")) + if "links" not in data and "edges" in data: + data = dict(data, links=data["edges"]) + try: + return jg.node_link_graph(data, edges="links"), data + except TypeError: + return jg.node_link_graph(data), data + + +def _edge_keys(G): + """Stable, comparable edge identity: keyed triples for multigraphs, else pairs.""" + if G.is_multigraph(): + return sorted((u, v, k) for u, v, k in G.edges(keys=True)) + return sorted(G.edges()) + + +def _run_merge_driver(monkeypatch, base_p, current_p, other_p): + """Invoke `graphify merge-driver` in-process; return the exit code (0 on ok).""" + monkeypatch.setattr(mainmod, "_check_skill_version", lambda _: None) + monkeypatch.setattr( + mainmod.sys, + "argv", + ["graphify", "merge-driver", str(base_p), str(current_p), str(other_p)], + ) + try: + mainmod.main() + return 0 + except SystemExit as exc: + return exc.code if isinstance(exc.code, int) else 1 + + +def _run_merge_graphs(monkeypatch, paths, out_path, *flags): + """Invoke `graphify merge-graphs` in-process; return the exit code (0 on ok).""" + monkeypatch.setattr(mainmod, "_check_skill_version", lambda _: None) + argv = ["graphify", "merge-graphs", *[str(p) for p in paths], "--out", str(out_path), *flags] + monkeypatch.setattr(mainmod.sys, "argv", argv) + try: + mainmod.main() + return 0 + except SystemExit as exc: + return exc.code if isinstance(exc.code, int) else 1 + + +def _repo_graph(root, repo, G): + """Write *G* to //graphify-out/graph.json (merge-graphs layout).""" + out_dir = root / repo / "graphify-out" + out_dir.mkdir(parents=True) + gp = out_dir / "graph.json" + _graph_to_json(G, gp) + return gp + + +def test_merge_driver_mixed_classes_no_crash(monkeypatch, tmp_path): + """merge-driver: simple `current` + MultiDiGraph `other` must NOT crash through + a NetworkX class mismatch; the result is a keyed multidigraph that preserves + both sides' edges. This is the core go/no-go gate.""" + current = _make_graph( + [ + {"id": "a", "label": "A", "source_file": "a.py"}, + {"id": "b", "label": "B", "source_file": "b.py"}, + ], + [{"source": "a", "target": "b", "relation": "calls", "confidence": "EXTRACTED"}], + ) + other = _make_multidigraph( + [ + {"id": "a", "label": "A", "source_file": "a.py"}, + {"id": "c", "label": "C", "source_file": "c.py"}, + ], + [ + {"source": "a", "target": "c", "key": 0, "relation": "imports"}, + {"source": "a", "target": "c", "key": 1, "relation": "calls"}, + ], + ) + base_p = tmp_path / "base.json" + current_p = tmp_path / "current.json" + other_p = tmp_path / "other.json" + _graph_to_json(_make_graph([]), base_p) + _graph_to_json(current, current_p) + _graph_to_json(other, other_p) + + code = _run_merge_driver(monkeypatch, base_p, current_p, other_p) + assert code == 0 # no class-mismatch crash → clean exit, not a surfaced conflict + + merged, data = _reload_graph(current_p) + assert merged.is_multigraph() # upgraded to the multi target, not collapsed + assert data["graph"]["graphify_profile"]["graph_type"] == "multidigraph" + # Both the simple edge and BOTH parallel multi edges survive. + assert ("a", "b", 0) in _edge_keys(merged) + assert ("a", "c", 0) in _edge_keys(merged) + assert ("a", "c", 1) in _edge_keys(merged) + assert merged.number_of_edges() == 3 + + +def test_merge_driver_idempotent_under_repeat(monkeypatch, tmp_path): + """STATEFUL idempotence: running merge-driver on the SAME inputs 3× must keep + the edge count, edge KEYS and stored profile identical every time — no + duplicated edges, no key drift, no re-collapse. The merge-driver writes back + to `current`, so each rerun re-loads its own multidigraph output as `current`; + the keyed compose must overwrite the same (u, v, key) slots, not accumulate.""" + current = _make_graph( + [ + {"id": "a", "label": "A", "source_file": "a.py"}, + {"id": "b", "label": "B", "source_file": "b.py"}, + ], + [{"source": "a", "target": "b", "relation": "calls", "confidence": "EXTRACTED"}], + ) + other = _make_multidigraph( + [ + {"id": "a", "label": "A", "source_file": "a.py"}, + {"id": "c", "label": "C", "source_file": "c.py"}, + ], + [ + {"source": "a", "target": "c", "key": 0, "relation": "imports"}, + {"source": "a", "target": "c", "key": 1, "relation": "calls"}, + ], + ) + base_p = tmp_path / "base.json" + current_p = tmp_path / "current.json" + other_p = tmp_path / "other.json" + _graph_to_json(_make_graph([]), base_p) + _graph_to_json(current, current_p) + _graph_to_json(other, other_p) + + snapshots = [] + for _ in range(3): + assert _run_merge_driver(monkeypatch, base_p, current_p, other_p) == 0 + merged, data = _reload_graph(current_p) + snapshots.append( + ( + merged.number_of_edges(), + _edge_keys(merged), + data["graph"]["graphify_profile"]["graph_type"], + ) + ) + + # The exact stability assertion: edges + keys + profile identical across all 3 runs. + assert snapshots[0] == snapshots[1] == snapshots[2] + assert snapshots[0][0] == 3 # 1 simple + 2 parallel, never duplicated + assert snapshots[0][2] == "multidigraph" + + +def test_merge_graphs_multidigraph_preserves_parallel_edges(monkeypatch, tmp_path): + """merge-graphs over a multigraph + a simple input keeps parallel edges with + distinct keys (resolved target inferred as multidigraph, not collapsed).""" + multi = _make_multidigraph( + [ + {"id": "x", "label": "X", "source_file": "x.py"}, + {"id": "y", "label": "Y", "source_file": "y.py"}, + ], + [ + {"source": "x", "target": "y", "key": 0, "relation": "calls"}, + {"source": "x", "target": "y", "key": 1, "relation": "imports"}, + ], + ) + simple = _make_graph( + [ + {"id": "z", "label": "Z", "source_file": "z.py"}, + {"id": "w", "label": "W", "source_file": "w.py"}, + ], + [{"source": "z", "target": "w", "relation": "uses"}], + ) + g1 = _repo_graph(tmp_path, "repo1", multi) + g2 = _repo_graph(tmp_path, "repo2", simple) + out_p = tmp_path / "merged.json" + + assert _run_merge_graphs(monkeypatch, [g1, g2], out_p) == 0 + merged, data = _reload_graph(out_p) + assert merged.is_multigraph() + assert data["graph"]["graphify_profile"]["graph_type"] == "multidigraph" + # Both parallel edges survive (prefixed by repo tag), distinct keys retained. + keys = _edge_keys(merged) + assert ("repo1::x", "repo1::y", 0) in keys + assert ("repo1::x", "repo1::y", 1) in keys + assert merged.number_of_edges() == 3 # 2 parallel + 1 simple + + +def test_merge_graphs_idempotent_under_repeat(monkeypatch, tmp_path): + """STATEFUL idempotence: the SAME merge-graphs run repeated 3× yields a stable + output — edge count, keys and profile unchanged (no duplicated parallel edges, + no key drift, no re-collapse). Inputs are read fresh each run; only the output + is overwritten, so stability proves the keyed compose is deterministic.""" + multi = _make_multidigraph( + [ + {"id": "x", "label": "X", "source_file": "x.py"}, + {"id": "y", "label": "Y", "source_file": "y.py"}, + ], + [ + {"source": "x", "target": "y", "key": 0, "relation": "calls"}, + {"source": "x", "target": "y", "key": 1, "relation": "imports"}, + ], + ) + simple = _make_graph( + [{"id": "z", "label": "Z", "source_file": "z.py"}], + [], + ) + g1 = _repo_graph(tmp_path, "repo1", multi) + g2 = _repo_graph(tmp_path, "repo2", simple) + out_p = tmp_path / "merged.json" + + snapshots = [] + for _ in range(3): + assert _run_merge_graphs(monkeypatch, [g1, g2], out_p) == 0 + merged, data = _reload_graph(out_p) + snapshots.append( + ( + merged.number_of_edges(), + _edge_keys(merged), + data["graph"]["graphify_profile"]["graph_type"], + ) + ) + assert snapshots[0] == snapshots[1] == snapshots[2] + assert snapshots[0][0] == 2 # 2 parallel edges, never duplicated to 4 + assert snapshots[0][2] == "multidigraph" + + +def test_merge_graphs_explicit_simple_target_warns_on_multi(monkeypatch, tmp_path, capsys): + """An EXPLICIT --simple target over a multigraph input projects DOWN to simple + WITH a warning (intentional, audible collapse) — never a silent collapse.""" + multi = _make_multidigraph( + [ + {"id": "x", "label": "X", "source_file": "x.py"}, + {"id": "y", "label": "Y", "source_file": "y.py"}, + ], + [ + {"source": "x", "target": "y", "key": 0, "relation": "calls"}, + {"source": "x", "target": "y", "key": 1, "relation": "imports"}, + ], + ) + simple = _make_graph( + [{"id": "z", "label": "Z", "source_file": "z.py"}], + [], + ) + g1 = _repo_graph(tmp_path, "repo1", multi) + g2 = _repo_graph(tmp_path, "repo2", simple) + out_p = tmp_path / "merged.json" + + with pytest.warns(UserWarning, match="multigraph"): + assert _run_merge_graphs(monkeypatch, [g1, g2], out_p, "--simple") == 0 + + # Loud collapse: a WARNING is also emitted on stderr, and the result is simple. + err = capsys.readouterr().err + assert "WARNING" in err and "multigraph" in err + merged, data = _reload_graph(out_p) + assert not merged.is_multigraph() and not merged.is_directed() + assert data["graph"]["graphify_profile"]["graph_type"] == "simple" + # Parallel edges folded onto a single (x, y) pair (the explicit, warned choice). + assert merged.number_of_edges() == 1 + + +def test_merge_simple_only_regression(monkeypatch, tmp_path): + """Pure simple inputs → simple output, byte-stable across repeated runs (the + default no-flag path must not upgrade or perturb a simple-only merge).""" + s1 = _make_graph( + [ + {"id": "a", "label": "A", "source_file": "a.py"}, + {"id": "b", "label": "B", "source_file": "b.py"}, + ], + [{"source": "a", "target": "b", "relation": "calls", "confidence": "EXTRACTED"}], + ) + s2 = _make_graph( + [ + {"id": "c", "label": "C", "source_file": "c.py"}, + {"id": "d", "label": "D", "source_file": "d.py"}, + ], + [{"source": "c", "target": "d", "relation": "uses", "confidence": "EXTRACTED"}], + ) + g1 = _repo_graph(tmp_path, "repo1", s1) + g2 = _repo_graph(tmp_path, "repo2", s2) + out_p = tmp_path / "merged.json" + + assert _run_merge_graphs(monkeypatch, [g1, g2], out_p) == 0 + first = out_p.read_text(encoding="utf-8") + assert _run_merge_graphs(monkeypatch, [g1, g2], out_p) == 0 + second = out_p.read_text(encoding="utf-8") + + assert first == second # byte-stable under repeat + merged, data = _reload_graph(out_p) + assert not merged.is_multigraph() and not merged.is_directed() # stays simple + assert data["graph"]["graphify_profile"]["graph_type"] == "simple" + assert merged.number_of_edges() == 2 # no silent multi upgrade + + +def test_merge_backs_up_before_overwrite(monkeypatch, tmp_path): + """An overwriting merge writes a dated .bak sibling of the pre-merge target + first, so the previous state is recoverable.""" + s1 = _make_graph([{"id": "a", "label": "A", "source_file": "a.py"}], []) + s2 = _make_graph([{"id": "b", "label": "B", "source_file": "b.py"}], []) + g1 = _repo_graph(tmp_path, "repo1", s1) + g2 = _repo_graph(tmp_path, "repo2", s2) + out_p = tmp_path / "merged.json" + + # Pre-seed an existing output so the merge OVERWRITES it (triggers backup). + sentinel = _make_graph([{"id": "old", "label": "OLD", "source_file": "old.py"}], []) + _graph_to_json(sentinel, out_p) + sentinel_bytes = out_p.read_bytes() + + monkeypatch.delenv("GRAPHIFY_NO_BACKUP", raising=False) + assert _run_merge_graphs(monkeypatch, [g1, g2], out_p) == 0 + + backups = list(tmp_path.glob("merged.*.bak")) + assert len(backups) == 1 # exactly one dated backup sibling, no proliferation + assert backups[0].read_bytes() == sentinel_bytes # holds the PRE-overwrite state + + +def test_merge_pre_profile_refused(monkeypatch, tmp_path): + """Merging that would UPGRADE a pre-profile graph (no graphify_profile / + multigraph / directed markers) to a multidigraph target is refused with a + recovery message, leaving the target file unmutated — its lost parallel edges + cannot be reconstructed by an in-place upgrade.""" + from networkx.readwrite import json_graph as jg + + # A pre-profile `current`: strip the multigraph/directed flags AND any profile + # so detect_pre_profile() classifies it as predating class tracking. + current = _make_graph( + [ + {"id": "a", "label": "A", "source_file": "a.py"}, + {"id": "b", "label": "B", "source_file": "b.py"}, + ], + [{"source": "a", "target": "b", "relation": "calls"}], + ) + current_p = tmp_path / "current.json" + raw = jg.node_link_data(current, edges="links") + raw.pop("multigraph", None) + raw.pop("directed", None) + raw.pop("graph", None) # no graphify_profile anywhere → pre-profile + current_p.write_text(json.dumps(raw), encoding="utf-8") + pre_bytes = current_p.read_bytes() + + # `other` is a multigraph → the merge would upgrade `current` to multidigraph. + other = _make_multidigraph( + [ + {"id": "a", "label": "A", "source_file": "a.py"}, + {"id": "c", "label": "C", "source_file": "c.py"}, + ], + [ + {"source": "a", "target": "c", "key": 0, "relation": "imports"}, + {"source": "a", "target": "c", "key": 1, "relation": "calls"}, + ], + ) + base_p = tmp_path / "base.json" + other_p = tmp_path / "other.json" + _graph_to_json(_make_graph([]), base_p) + _graph_to_json(other, other_p) + + code = _run_merge_driver(monkeypatch, base_p, current_p, other_p) + assert code == 1 # refused → surfaced as a conflict, not silently upgraded + assert current_p.read_bytes() == pre_bytes # target left unmutated + + +def _write_pre_profile_graph(path, nodes, edges): + """Write a LEGACY pre-profile graph.json: bare nodes + links, NO graphify_profile + and NO multigraph/directed flags, so detect_pre_profile() treats it as predating + class tracking (it may already be a silently-collapsed simple graph).""" + from networkx.readwrite import json_graph as jg + + G = _make_graph(nodes, edges) + raw = jg.node_link_data(G, edges="links") + raw.pop("multigraph", None) + raw.pop("directed", None) + raw.pop("graph", None) + path.write_text(json.dumps(raw), encoding="utf-8") + + +def test_merge_driver_pre_profile_other_does_not_block(monkeypatch, tmp_path): + """REGRESSION for the narrowed pre-profile refusal: when `current` is a real + MultiDiGraph and `other` is a LEGACY pre-profile simple graph, the merge must + SUCCEED — `other` is read-only (merged in, never rewritten), so its pre-profile + status implies no unreconstructable in-place loss. Before the fix the refusal + loop also inspected `other`, false-positive-blocking this valid merge with a + misleading 'global-graph.json / rebuild from source' recovery message.""" + # `current`: a genuine MultiDiGraph (carries multigraph:true + parallel edges). + current = _make_multidigraph( + [ + {"id": "a", "label": "A", "source_file": "a.py"}, + {"id": "b", "label": "B", "source_file": "b.py"}, + ], + [ + {"source": "a", "target": "b", "key": 0, "relation": "calls"}, + {"source": "a", "target": "b", "key": 1, "relation": "imports"}, + ], + ) + current_p = tmp_path / "current.json" + _graph_to_json(current, current_p) + + # `other`: a legacy pre-profile simple graph (no profile / multigraph flags). + other_p = tmp_path / "other.json" + _write_pre_profile_graph( + other_p, + [ + {"id": "a", "label": "A", "source_file": "a.py"}, + {"id": "c", "label": "C", "source_file": "c.py"}, + ], + [{"source": "a", "target": "c", "relation": "uses"}], + ) + + base_p = tmp_path / "base.json" + _graph_to_json(_make_graph([]), base_p) + + code = _run_merge_driver(monkeypatch, base_p, current_p, other_p) + assert code == 0 # NOT refused — a pre-profile `other` must not block the merge + + merged, data = _reload_graph(current_p) + assert merged.is_multigraph() # current stays multidigraph, not collapsed + assert data["graph"]["graphify_profile"]["graph_type"] == "multidigraph" + keys = _edge_keys(merged) + # current's parallel edges preserved... + assert ("a", "b", 0) in keys + assert ("a", "b", 1) in keys + # ...and other's edge is merged in (keyed onto the multi target). + assert ("a", "c", 0) in keys + assert merged.number_of_edges() == 3 + + +def test_merge_driver_pre_profile_current_still_refused(monkeypatch, tmp_path): + """Confirm the guard STILL fires for the legitimate case after the fix narrowed + its scope: `current` is a pre-profile simple graph and `other` is a MultiDiGraph, + so the inferred target is multidigraph and the merge would upgrade the + OVERWRITTEN current in place (its lost parallels unreconstructable). merge-driver + must REFUSE (exit 1, recovery message) and leave current unmutated — proving the + fix did not disable the real protection.""" + current_p = tmp_path / "current.json" + _write_pre_profile_graph( + current_p, + [ + {"id": "a", "label": "A", "source_file": "a.py"}, + {"id": "b", "label": "B", "source_file": "b.py"}, + ], + [{"source": "a", "target": "b", "relation": "calls"}], + ) + pre_bytes = current_p.read_bytes() + + other = _make_multidigraph( + [ + {"id": "a", "label": "A", "source_file": "a.py"}, + {"id": "c", "label": "C", "source_file": "c.py"}, + ], + [ + {"source": "a", "target": "c", "key": 0, "relation": "imports"}, + {"source": "a", "target": "c", "key": 1, "relation": "calls"}, + ], + ) + base_p = tmp_path / "base.json" + other_p = tmp_path / "other.json" + _graph_to_json(_make_graph([]), base_p) + _graph_to_json(other, other_p) + + monkeypatch.setattr(mainmod, "_check_skill_version", lambda _: None) + monkeypatch.setattr( + mainmod.sys, + "argv", + ["graphify", "merge-driver", str(base_p), str(current_p), str(other_p)], + ) + import pytest as _pytest + + with _pytest.raises(SystemExit) as exc_info: + mainmod.main() + assert exc_info.value.code == 1 # the real protection still fires + assert current_p.read_bytes() == pre_bytes # current left unmutated + + +def test_merge_driver_pre_profile_current_refusal_message(monkeypatch, tmp_path, capsys): + """Companion to the refusal test: the refusal prints the recovery message + (rebuild-from-source guidance), not a silent failure.""" + current_p = tmp_path / "current.json" + _write_pre_profile_graph( + current_p, + [{"id": "a", "label": "A", "source_file": "a.py"}], + [], + ) + other = _make_multidigraph( + [ + {"id": "a", "label": "A", "source_file": "a.py"}, + {"id": "c", "label": "C", "source_file": "c.py"}, + ], + [ + {"source": "a", "target": "c", "key": 0, "relation": "imports"}, + {"source": "a", "target": "c", "key": 1, "relation": "calls"}, + ], + ) + base_p = tmp_path / "base.json" + other_p = tmp_path / "other.json" + _graph_to_json(_make_graph([]), base_p) + _graph_to_json(other, other_p) + + code = _run_merge_driver(monkeypatch, base_p, current_p, other_p) + assert code == 1 + err = capsys.readouterr().err + assert "pre-profile" in err + assert "multidigraph" in err + assert str(current_p) in err + assert "regenerate" in err or "recreate" in err + assert "global-graph.json" not in err + assert "graphify global remove" not in err + + +def test_merge_backup_suppressed_by_env(monkeypatch, tmp_path): + """`_backup_merge_target` honors the GRAPHIFY_NO_BACKUP env var: with it set, an + overwriting merge writes NO .bak; without it, the dated .bak sibling is created. + Confirms the env-suppression path Copilot flagged as only indirectly exercised.""" + s1 = _make_graph([{"id": "a", "label": "A", "source_file": "a.py"}], []) + s2 = _make_graph([{"id": "b", "label": "B", "source_file": "b.py"}], []) + g1 = _repo_graph(tmp_path, "repo1", s1) + g2 = _repo_graph(tmp_path, "repo2", s2) + + # --- with GRAPHIFY_NO_BACKUP=1: overwrite an existing target, expect NO .bak --- + out_suppressed = tmp_path / "merged_suppressed.json" + _graph_to_json( + _make_graph([{"id": "old", "label": "OLD", "source_file": "old.py"}], []), out_suppressed + ) + monkeypatch.setenv("GRAPHIFY_NO_BACKUP", "1") + assert _run_merge_graphs(monkeypatch, [g1, g2], out_suppressed) == 0 + assert list(tmp_path.glob("merged_suppressed.*.bak")) == [] # suppressed → no backup + + # --- without it: overwrite an existing target, expect the .bak to appear --- + out_enabled = tmp_path / "merged_enabled.json" + sentinel = _make_graph([{"id": "old", "label": "OLD", "source_file": "old.py"}], []) + _graph_to_json(sentinel, out_enabled) + sentinel_bytes = out_enabled.read_bytes() + monkeypatch.delenv("GRAPHIFY_NO_BACKUP", raising=False) + assert _run_merge_graphs(monkeypatch, [g1, g2], out_enabled) == 0 + backups = list(tmp_path.glob("merged_enabled.*.bak")) + assert len(backups) == 1 # backup created when env is unset + assert backups[0].read_bytes() == sentinel_bytes # holds the PRE-overwrite state diff --git a/tests/test_google_workspace.py b/tests/test_google_workspace.py index 9d8cbfa4b..c23913304 100644 --- a/tests/test_google_workspace.py +++ b/tests/test_google_workspace.py @@ -1,4 +1,3 @@ -from pathlib import Path import json import graphify.google_workspace as gw diff --git a/tests/test_graph_loader.py b/tests/test_graph_loader.py new file mode 100644 index 000000000..4fcd3ce45 --- /dev/null +++ b/tests/test_graph_loader.py @@ -0,0 +1,557 @@ +"""Tests for graphify.graph_loader — schema-aware graph loading. + +Seven required PR 2 scenarios from the Wave 3 handoff guardrails. +""" + +from __future__ import annotations + +from unittest.mock import patch + +import networkx as nx +import pytest + +from graphify.graph_loader import GRAPHIFY_PROFILE_KEY, load_graph + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +_NODES = [ + {"id": "a", "label": "A", "file_type": "code", "source_file": "a.py"}, + {"id": "b", "label": "B", "file_type": "code", "source_file": "b.py"}, +] + +_SIMPLE_EDGE = { + "source": "a", + "target": "b", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "a.py", + "weight": 1.0, +} + +_KEYED_EDGE = {**_SIMPLE_EDGE, "key": "calls:a.py:L1"} + +_KEYED_EDGE_2 = { + "source": "a", + "target": "b", + "relation": "imports", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "a.py", + "key": "imports:a.py:L5", + "weight": 1.0, +} + + +def _simple_links() -> dict: + """Legacy simple JSON using 'links' key.""" + return {"nodes": _NODES, "links": [_SIMPLE_EDGE]} + + +def _simple_edges() -> dict: + """Modern simple JSON using 'edges' key.""" + return {"nodes": _NODES, "edges": [_SIMPLE_EDGE]} + + +def _multigraph_data() -> dict: + """Valid multigraph node-link JSON with two keyed parallel edges.""" + return { + "multigraph": True, + "nodes": _NODES, + "links": [_KEYED_EDGE, _KEYED_EDGE_2], + } + + +def _multigraph_missing_keys() -> dict: + """Multigraph JSON where edges lack 'key' fields.""" + edge_no_key = {k: v for k, v in _SIMPLE_EDGE.items() if k != "key"} + return {"multigraph": True, "nodes": _NODES, "links": [edge_no_key]} + + +# --------------------------------------------------------------------------- +# Scenario 1: legacy 'links' loads as nx.Graph +# --------------------------------------------------------------------------- + + +def test_load_graph_rejects_non_object_payload(): + with pytest.raises(TypeError, match="serialized graph data must be a JSON object"): + load_graph([]) + + +def test_load_graph_rejects_non_list_nodes(): + data = {**_simple_links(), "nodes": 123} + + with pytest.raises(TypeError, match="'nodes' must be a list, got int"): + load_graph(data) + + +def test_legacy_links_loads_as_simple_graph(): + G = load_graph(_simple_links()) + assert type(G) is nx.Graph + assert not G.is_multigraph() + assert G.number_of_nodes() == 2 + assert G.number_of_edges() == 1 + + +# --------------------------------------------------------------------------- +# Scenario 2: modern 'edges' loads as nx.Graph +# --------------------------------------------------------------------------- + + +def test_modern_edges_loads_as_simple_graph(): + G = load_graph(_simple_edges()) + assert type(G) is nx.Graph + assert not G.is_multigraph() + assert G.number_of_edges() == 1 + + +# --------------------------------------------------------------------------- +# Scenario 3: valid multigraph JSON with keyed parallel edges → nx.MultiDiGraph +# --------------------------------------------------------------------------- + + +def test_valid_multigraph_loads_as_multidigraph(): + G = load_graph(_multigraph_data()) + assert type(G) is nx.MultiDiGraph + assert G.is_multigraph() + assert G.number_of_nodes() == 2 + assert G.number_of_edges() == 2 # both parallel edges preserved + + +# --------------------------------------------------------------------------- +# Scenario 4: malformed multigraph (missing keys) repairs explicitly, not silently +# --------------------------------------------------------------------------- + + +def test_malformed_multigraph_missing_keys_repairs_explicitly(capsys): + G = load_graph(_multigraph_missing_keys()) + # Must produce a MultiDiGraph (not silently fall back to simple) + assert type(G) is nx.MultiDiGraph + assert G.number_of_edges() == 1 + # Must warn to stderr + captured = capsys.readouterr() + assert "missing" in captured.err.lower() or "key" in captured.err.lower() + + +# --------------------------------------------------------------------------- +# Scenario 5: edge 'key' is stripped from attrs — not stored as an edge attribute +# --------------------------------------------------------------------------- + + +def test_schema_key_stripped_from_edge_attrs(): + G = load_graph(_multigraph_data()) + assert isinstance(G, nx.MultiDiGraph) + for u, v, k, data in G.edges(keys=True, data=True): + assert "key" not in data, ( + f"Edge ({u},{v},key={k!r}) must not store 'key' inside its attrs dict" + ) + + +# --------------------------------------------------------------------------- +# Scenario 6: G.graph["graphify_profile"] is present after load +# --------------------------------------------------------------------------- + + +def test_graph_profile_metadata_round_trips(): + G = load_graph(_simple_links()) + assert GRAPHIFY_PROFILE_KEY in G.graph + profile = G.graph[GRAPHIFY_PROFILE_KEY] + assert isinstance(profile, dict) + assert "graph_type" in profile + + +def test_graph_profile_type_for_multidigraph(): + G = load_graph(_multigraph_data()) + assert G.graph[GRAPHIFY_PROFILE_KEY]["graph_type"] == "multidigraph" + + +def test_graph_profile_type_for_simple(): + G = load_graph(_simple_links()) + assert G.graph[GRAPHIFY_PROFILE_KEY]["graph_type"] == "simple" + + +# --------------------------------------------------------------------------- +# Scenario 7: capability probe failure raises clearly; simple loading unaffected +# --------------------------------------------------------------------------- + + +def test_capability_probe_failure_raises_clear_error(): + with patch( + "graphify.graph_loader.require_multigraph_capabilities", + side_effect=RuntimeError("MultiDiGraph not supported: simulated failure"), + ): + with pytest.raises(RuntimeError, match="MultiDiGraph not supported"): + load_graph(_multigraph_data(), require_capabilities=True) + + +def test_capability_probe_failure_does_not_affect_simple_load(): + with patch( + "graphify.graph_loader.require_multigraph_capabilities", + side_effect=RuntimeError("should not be called"), + ): + # Simple JSON must not trigger the capability probe at all + G = load_graph(_simple_links(), require_capabilities=True) + assert type(G) is nx.Graph + + +# --------------------------------------------------------------------------- +# Blocker 2: missing-key repair must preserve distinct parallel edges +# --------------------------------------------------------------------------- + + +def _two_missing_key_parallel_edges() -> dict: + """Multigraph with two missing-key edges sharing relation/file but different attrs.""" + return { + "multigraph": True, + "nodes": _NODES, + "links": [ + { + "source": "a", + "target": "b", + "relation": "calls", + "source_file": "a.py", + "confidence": "EXTRACTED", + "weight": 1.0, + "context": "one", + }, + { + "source": "a", + "target": "b", + "relation": "calls", + "source_file": "a.py", + "confidence": "EXTRACTED", + "weight": 1.0, + "context": "two", + }, + ], + } + + +def test_missing_key_repair_preserves_distinct_parallel_edges(capsys): + G = load_graph(_two_missing_key_parallel_edges()) + assert type(G) is nx.MultiDiGraph + assert G.number_of_edges() == 2, ( + f"Both missing-key parallel edges must survive repair; got {G.number_of_edges()}" + ) + captured = capsys.readouterr() + assert "missing" in captured.err.lower() or "key" in captured.err.lower() + + +# --------------------------------------------------------------------------- +# Blocker 3: simple loader must respect serialized directedness +# --------------------------------------------------------------------------- + + +def test_directed_true_loads_as_digraph(): + data = { + "directed": True, + "multigraph": False, + "nodes": _NODES, + "edges": [_SIMPLE_EDGE], + } + G = load_graph(data) + assert type(G) is nx.DiGraph + + +def test_directed_false_explicitly_loads_as_graph(): + data = { + "directed": False, + "multigraph": False, + "nodes": _NODES, + "edges": [_SIMPLE_EDGE], + } + G = load_graph(data) + assert type(G) is nx.Graph + + +def test_directed_true_profile_graph_type(): + data = { + "directed": True, + "multigraph": False, + "nodes": _NODES, + "edges": [_SIMPLE_EDGE], + } + G = load_graph(data) + assert G.graph[GRAPHIFY_PROFILE_KEY]["graph_type"] == "digraph" + + +# --------------------------------------------------------------------------- +# Blocker 4: malformed JSON must fail cleanly or skip under documented policy +# --------------------------------------------------------------------------- + + +def test_non_dict_edge_entries_are_skipped(): + data = {"nodes": _NODES, "edges": ["not-a-dict", None, 42]} + G = load_graph(data) + assert G.number_of_edges() == 0 + + +def test_edges_value_not_a_list_raises(): + data = {"nodes": _NODES, "edges": "not-a-list"} + with pytest.raises((TypeError, ValueError)): + load_graph(data) + + +def test_non_dict_graphify_profile_is_ignored(): + data = { + "nodes": _NODES, + "edges": [_SIMPLE_EDGE], + GRAPHIFY_PROFILE_KEY: "bad-profile", + } + G = load_graph(data) + assert isinstance(G.graph[GRAPHIFY_PROFILE_KEY], dict) + assert "graph_type" in G.graph[GRAPHIFY_PROFILE_KEY] + + +def test_edge_missing_source_or_target_skipped(): + data = { + "nodes": _NODES, + "edges": [ + {"target": "b", "relation": "calls"}, + {"source": "a", "relation": "calls"}, + ], + } + G = load_graph(data) + assert G.number_of_edges() == 0 + + +# --------------------------------------------------------------------------- +# Non-string multigraph key values must raise before NetworkX sees them +# --------------------------------------------------------------------------- + + +def _multigraph_with_key(key_value: object) -> dict: + return { + "multigraph": True, + "nodes": _NODES, + "links": [{**_SIMPLE_EDGE, "key": key_value}], + } + + +def test_multigraph_list_key_raises(): + with pytest.raises((TypeError, ValueError)): + load_graph(_multigraph_with_key(["bad"])) + + +def test_multigraph_dict_key_raises(): + with pytest.raises((TypeError, ValueError)): + load_graph(_multigraph_with_key({"bad": 1})) + + +def test_multigraph_int_key_raises(): + with pytest.raises((TypeError, ValueError)): + load_graph(_multigraph_with_key(123)) + + +def test_load_simple_edge_with_empty_string_source_not_shadowed_by_from(): + # An edge with source="" AND from="a" must not silently use "from" as the + # source — an explicitly-set empty source means the edge is invalid. + data = { + "nodes": _NODES, + "links": [{"source": "", "from": "a", "target": "b", "relation": "calls"}], + } + G = load_graph(data) + assert G.number_of_edges() == 0 + + +def test_load_simple_edge_with_from_key_loaded(): + # Edges using legacy "from"/"to" keys should load correctly as long as + # the IDs are non-empty and present in the node set. + data = { + "nodes": _NODES, + "links": [{"from": "a", "to": "b", "relation": "calls"}], + } + G = load_graph(data) + assert G.number_of_edges() == 1 + + +def test_load_simple_preserves_falsy_hashable_ids(): + """Falsy-but-hashable node IDs like 0 or False must survive the loader.""" + data = { + "directed": False, + "multigraph": False, + "nodes": [{"id": 0}, {"id": ""}, {"id": "x"}], + "links": [ + {"source": 0, "target": "", "relation": "calls"}, + {"source": 0, "target": "x", "relation": "imports"}, + ], + } + G = load_graph(data) + assert G.number_of_nodes() == 3 + assert G.number_of_edges() == 2 + assert G.has_edge(0, "") + assert G.has_edge(0, "x") + + +def test_load_directed_preserves_falsy_hashable_ids(): + data = { + "directed": True, + "multigraph": False, + "nodes": [{"id": 0}, {"id": "y"}], + "links": [{"source": 0, "target": "y", "relation": "calls"}], + } + G = load_graph(data) + assert G.number_of_edges() == 1 + assert G.has_edge(0, "y") + + +def test_load_multigraph_preserves_falsy_hashable_ids(): + data = { + "directed": True, + "multigraph": True, + "nodes": [{"id": 0}, {"id": 1}], + "links": [ + {"source": 0, "target": 1, "key": "k1", "relation": "calls"}, + {"source": 0, "target": 1, "key": "k2", "relation": "imports"}, + ], + } + G = load_graph(data) + assert G.number_of_edges() == 2 + assert G.has_edge(0, 1) + + +def test_graph_attributes_round_trip_through_node_link_data(): + """G.graph[...] attrs must survive node_link_data → load_graph round-trip. + + NetworkX serializes graph-level metadata under data["graph"]; the loader + must read from there, not only from top-level keys. + """ + import networkx as nx + from networkx.readwrite import json_graph + + G_out = nx.DiGraph() + G_out.add_node("a") + G_out.add_node("b") + G_out.add_edge("a", "b", relation="calls") + G_out.graph["graphify_profile"] = {"graph_type": "digraph", "extra": "value"} + G_out.graph["hyperedges"] = [{"members": ["a", "b"]}] + G_out.graph["graphify_multigraph_diagnostics"] = {"collapsed": 0} + + data = json_graph.node_link_data(G_out, edges="links") + G_in = load_graph(data) + + assert G_in.graph["graphify_profile"]["extra"] == "value" + assert G_in.graph["hyperedges"] == [{"members": ["a", "b"]}] + assert G_in.graph["graphify_multigraph_diagnostics"] == {"collapsed": 0} + + +def test_graph_attributes_round_trip_through_multigraph_node_link_data(): + """Same round-trip guarantee for multigraph exports.""" + import networkx as nx + from networkx.readwrite import json_graph + + G_out = nx.MultiDiGraph() + G_out.add_node("a") + G_out.add_node("b") + G_out.add_edge("a", "b", key="k1", relation="calls") + G_out.add_edge("a", "b", key="k2", relation="imports") + G_out.graph["graphify_profile"] = {"graph_type": "multidigraph"} + G_out.graph["graphify_multigraph_diagnostics"] = {"exact_duplicates": 0} + + data = json_graph.node_link_data(G_out, edges="links") + G_in = load_graph(data, require_capabilities=False) + + assert G_in.graph["graphify_profile"]["graph_type"] == "multidigraph" + assert G_in.graph["graphify_multigraph_diagnostics"] == {"exact_duplicates": 0} + assert G_in.number_of_edges() == 2 + + +def test_load_skips_unhashable_node_ids(capsys): + """Corrupted graph.json with unhashable node ids must not crash; skip + warn.""" + data = { + "directed": True, + "multigraph": False, + "nodes": [{"id": "ok"}, {"id": ["unhashable"]}, {"id": {"also": "unhashable"}}], + "links": [{"source": "ok", "target": "ok", "relation": "self"}], + } + G = load_graph(data) + captured = capsys.readouterr() + assert G.number_of_nodes() == 1 + assert "unhashable" in captured.err.lower() + + +def test_load_skips_edges_with_unhashable_endpoints(): + """Edges with unhashable source/target must be skipped, not raise TypeError.""" + data = { + "directed": True, + "multigraph": False, + "nodes": [{"id": "a"}, {"id": "b"}], + "links": [ + {"source": "a", "target": "b", "relation": "calls"}, + {"source": ["unhashable"], "target": "b", "relation": "bogus"}, + {"source": "a", "target": {"also": "unhashable"}, "relation": "bogus"}, + ], + } + G = load_graph(data) + assert G.number_of_edges() == 1 + + +def test_load_multigraph_skips_unhashable_endpoints(): + """Same protection in the multigraph loader.""" + data = { + "directed": True, + "multigraph": True, + "nodes": [{"id": "a"}, {"id": "b"}], + "links": [ + {"source": "a", "target": "b", "key": "k1", "relation": "calls"}, + {"source": ["bad"], "target": "b", "key": "k2", "relation": "calls"}, + ], + } + G = load_graph(data, require_capabilities=False) + assert G.number_of_edges() == 1 + + +def test_load_multigraph_duplicate_keys_repaired_not_overwritten(capsys): + """Two parallel edges with same (src, tgt, key) but different attrs must both survive.""" + data = { + "directed": True, + "multigraph": True, + "nodes": [{"id": "a"}, {"id": "b"}], + "links": [ + {"source": "a", "target": "b", "key": "same", "relation": "calls", "context": "one"}, + {"source": "a", "target": "b", "key": "same", "relation": "calls", "context": "two"}, + ], + } + G = load_graph(data, require_capabilities=False) + assert G.number_of_edges() == 2, "duplicate-key edges must both be preserved via repair keys" + captured = capsys.readouterr() + assert "duplicate" in captured.err.lower() + + +def test_load_graph_rejects_non_bool_multigraph_field(): + """String 'false' or other non-bool 'multigraph' must be rejected, not coerced.""" + data = {**_simple_links(), "multigraph": "false"} + with pytest.raises(TypeError, match="'multigraph' must be a boolean"): + load_graph(data) + + +def test_load_graph_rejects_non_bool_directed_field(): + data = {**_simple_links(), "directed": "true"} + with pytest.raises(TypeError, match="'directed' must be a boolean"): + load_graph(data) + + +def test_load_multigraph_with_omitted_directed_does_not_warn(capsys): + """Missing 'directed' alongside 'multigraph: true' must not trigger the false warning.""" + data = { + "multigraph": True, + "nodes": [{"id": "a"}, {"id": "b"}], + "links": [{"source": "a", "target": "b", "key": "k", "relation": "calls"}], + } + load_graph(data, require_capabilities=False) + captured = capsys.readouterr() + assert "multigraph=true but directed=false" not in captured.err + + +def test_load_graph_overwrites_stale_graph_type_in_profile(): + """Stale graph_type from serialized profile must not survive when loading.""" + data = { + "multigraph": True, + "nodes": [{"id": "a"}, {"id": "b"}], + "links": [{"source": "a", "target": "b", "key": "k", "relation": "calls"}], + "graph": {"graphify_profile": {"graph_type": "simple"}}, + } + G = load_graph(data, require_capabilities=False) + assert G.graph[GRAPHIFY_PROFILE_KEY]["graph_type"] == "multidigraph" diff --git a/tests/test_hooks.py b/tests/test_hooks.py index 873b2028c..7f43b2ab1 100644 --- a/tests/test_hooks.py +++ b/tests/test_hooks.py @@ -1,4 +1,5 @@ """Tests for hooks.py - git hook install/uninstall.""" + import os import subprocess from types import SimpleNamespace @@ -120,7 +121,6 @@ def test_status_shows_both_hooks(tmp_path): assert result.count("installed") >= 2 - def test_hooks_dir_resolves_relative_git_hooks_path(tmp_path, monkeypatch): repo = _make_git_repo(tmp_path) @@ -155,15 +155,18 @@ def fake_run(*args, **kwargs): assert _hooks_dir(repo) == hooks.resolve() + def test_hook_skips_head_on_exe(): """Hook script must skip shebang extraction for .exe binaries (Windows).""" from graphify.hooks import _PYTHON_DETECT - assert "*.exe) _SHEBANG=" in _PYTHON_DETECT or '*.exe)' in _PYTHON_DETECT + + assert "*.exe) _SHEBANG=" in _PYTHON_DETECT or "*.exe)" in _PYTHON_DETECT def test_hook_check_no_additionalContext(tmp_path): """graphify hook-check must not emit additionalContext — Codex Desktop rejects it.""" import sys + out = tmp_path / "graphify-out" out.mkdir() (out / "graph.json").write_text("{}", encoding="utf-8") diff --git a/tests/test_hypergraph.py b/tests/test_hypergraph.py index dda8ac793..f82d36816 100644 --- a/tests/test_hypergraph.py +++ b/tests/test_hypergraph.py @@ -1,11 +1,11 @@ """Tests for hyperedge support in graphify.""" + from __future__ import annotations import json import tempfile from pathlib import Path import networkx as nx -import pytest from graphify.build import build_from_json from graphify.export import attach_hyperedges, to_json @@ -22,10 +22,22 @@ {"id": "DigestAuth", "label": "DigestAuth", "file_type": "code", "source_file": "auth.py"}, {"id": "Request", "label": "Request", "file_type": "code", "source_file": "http.py"}, {"id": "Response", "label": "Response", "file_type": "code", "source_file": "http.py"}, - {"id": "BaseClient", "label": "BaseClient", "file_type": "code", "source_file": "client.py"}, + { + "id": "BaseClient", + "label": "BaseClient", + "file_type": "code", + "source_file": "client.py", + }, ], "edges": [ - {"source": "BasicAuth", "target": "Request", "relation": "uses", "confidence": "EXTRACTED", "confidence_score": 1.0, "source_file": "auth.py"}, + { + "source": "BasicAuth", + "target": "Request", + "relation": "uses", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "auth.py", + }, ], "hyperedges": [ { @@ -55,6 +67,7 @@ # 1. Hyperedges survive build_from_json round-trip # --------------------------------------------------------------------------- + def test_build_from_json_stores_hyperedges(): G = build_from_json(SAMPLE_EXTRACTION) assert "hyperedges" in G.graph @@ -78,6 +91,7 @@ def test_build_from_json_missing_hyperedges_key(): # 2. attach_hyperedges deduplicates by id # --------------------------------------------------------------------------- + def test_attach_hyperedges_adds_new(): G = nx.Graph() attach_hyperedges(G, [{"id": "auth_flow", "label": "Auth Flow", "nodes": ["A", "B", "C"]}]) @@ -94,10 +108,13 @@ def test_attach_hyperedges_deduplicates(): def test_attach_hyperedges_multiple_different_ids(): G = nx.Graph() - attach_hyperedges(G, [ - {"id": "flow_a", "label": "Flow A", "nodes": ["A", "B", "C"]}, - {"id": "flow_b", "label": "Flow B", "nodes": ["D", "E", "F"]}, - ]) + attach_hyperedges( + G, + [ + {"id": "flow_a", "label": "Flow A", "nodes": ["A", "B", "C"]}, + {"id": "flow_b", "label": "Flow B", "nodes": ["D", "E", "F"]}, + ], + ) assert len(G.graph["hyperedges"]) == 2 @@ -111,6 +128,7 @@ def test_attach_hyperedges_skips_entry_without_id(): # 3. to_json includes hyperedges key # --------------------------------------------------------------------------- + def test_to_json_includes_hyperedges(): G = build_from_json(SAMPLE_EXTRACTION) communities = {0: list(G.nodes())} @@ -139,6 +157,7 @@ def test_to_json_hyperedges_empty_when_none(): # 4. Hyperedges loaded from graph.json via build_from_json # --------------------------------------------------------------------------- + def test_hyperedges_roundtrip_via_json_file(): """Write graph.json then reload it - hyperedges must survive.""" G = build_from_json(SAMPLE_EXTRACTION) @@ -149,11 +168,22 @@ def test_hyperedges_roundtrip_via_json_file(): # Reload the JSON as if build_from_json were called on it data = json.loads(Path(path).read_text()) - G2 = build_from_json({ - "nodes": [{"id": n["id"], **{k: v for k, v in n.items() if k != "id"}} for n in data["nodes"]], - "edges": [{"source": e["source"], "target": e["target"], **{k: v for k, v in e.items() if k not in ("source", "target")}} for e in data.get("links", [])], - "hyperedges": data.get("hyperedges", []), - }) + G2 = build_from_json( + { + "nodes": [ + {"id": n["id"], **{k: v for k, v in n.items() if k != "id"}} for n in data["nodes"] + ], + "edges": [ + { + "source": e["source"], + "target": e["target"], + **{k: v for k, v in e.items() if k not in ("source", "target")}, + } + for e in data.get("links", []) + ], + "hyperedges": data.get("hyperedges", []), + } + ) assert G2.graph.get("hyperedges", []) != [] assert G2.graph["hyperedges"][0]["id"] == "auth_flow" @@ -162,13 +192,24 @@ def test_hyperedges_roundtrip_via_json_file(): # 5. Report includes hyperedges section when hyperedges present # --------------------------------------------------------------------------- + def _make_report(G): communities = {0: list(G.nodes())} cohesion = {0: 1.0} labels = {0: "All"} gods = [{"label": "BasicAuth", "degree": 2}] surprises = [] - return generate(G, communities, cohesion, labels, gods, surprises, SAMPLE_DETECTION, {"input": 10, "output": 5}, ".") + return generate( + G, + communities, + cohesion, + labels, + gods, + surprises, + SAMPLE_DETECTION, + {"input": 10, "output": 5}, + ".", + ) def test_report_includes_hyperedges_section(): @@ -191,6 +232,7 @@ def test_report_includes_hyperedge_node_list(): # 6. Report skips hyperedges section when none present # --------------------------------------------------------------------------- + def test_report_skips_hyperedges_section_when_empty(): extraction = {**SAMPLE_EXTRACTION, "hyperedges": []} G = build_from_json(extraction) diff --git a/tests/test_import_extension_resolution.py b/tests/test_import_extension_resolution.py index 0d1222c0a..930c2fcaa 100644 --- a/tests/test_import_extension_resolution.py +++ b/tests/test_import_extension_resolution.py @@ -25,8 +25,11 @@ def _write(path: Path, body: str) -> Path: def _import_targets(result: dict) -> set[str]: - return {str(e.get("target") or "") for e in result["edges"] - if e.get("relation") in ("imports", "imports_from")} + return { + str(e.get("target") or "") + for e in result["edges"] + if e.get("relation") in ("imports", "imports_from") + } # ── _resolve_js_module_path unit tests ────────────────────────────────────── @@ -94,13 +97,10 @@ def test_resolve_svelte_to_svelte_ts_for_rune_files(tmp_path): """Svelte 5: `from './foo.svelte'` may actually point at `foo.svelte.ts` (a rune-only TypeScript file with no .svelte file). The resolver must APPEND .ts to the full filename, not swap suffixes.""" - target = _write(tmp_path / "is-mobile.svelte.ts", - "export const isMobile = () => true") + target = _write(tmp_path / "is-mobile.svelte.ts", "export const isMobile = () => true") written_as = tmp_path / "is-mobile.svelte" resolved = _resolve_js_module_path(written_as) - assert resolved == target, ( - f"Expected resolution to is-mobile.svelte.ts; got {resolved}" - ) + assert resolved == target, f"Expected resolution to is-mobile.svelte.ts; got {resolved}" def test_resolve_svelte_to_svelte_js_for_javascript_rune_files(tmp_path): @@ -111,8 +111,7 @@ def test_resolve_svelte_to_svelte_js_for_javascript_rune_files(tmp_path): Same code path as the .svelte.ts case — the generalized resolver tries every extension in priority order, so JS-only and TS-only projects both work without special-casing.""" - target = _write(tmp_path / "store.svelte.js", - "export const count = $state(0)") + target = _write(tmp_path / "store.svelte.js", "export const count = $state(0)") written_as = tmp_path / "store.svelte" resolved = _resolve_js_module_path(written_as) assert resolved == target @@ -128,10 +127,8 @@ def test_resolve_svelte_prefers_svelte_ts_over_svelte_js(tmp_path): expect tooling to read the `.svelte.ts` source. graphify is a source- code tool, not a runtime resolver, so source-first ordering is correct for our use case.""" - ts_target = _write(tmp_path / "store.svelte.ts", - "export const count = $state(0)") - _write(tmp_path / "store.svelte.js", - "export const count = 0 // build artifact") + ts_target = _write(tmp_path / "store.svelte.ts", "export const count = $state(0)") + _write(tmp_path / "store.svelte.js", "export const count = 0 // build artifact") written_as = tmp_path / "store.svelte" resolved = _resolve_js_module_path(written_as) assert resolved == ts_target @@ -142,8 +139,10 @@ def test_resolve_real_svelte_file_wins_over_svelte_ts_sibling(tmp_path): must resolve to that — not get hijacked to a sibling `foo.svelte.ts` rune file. The existence-check short-circuits before any append.""" real = _write(tmp_path / "Card.svelte", "
card markup
") - _write(tmp_path / "Card.svelte.ts", - "export const helpers = {} // rune sibling, not the import target") + _write( + tmp_path / "Card.svelte.ts", + "export const helpers = {} // rune sibling, not the import target", + ) resolved = _resolve_js_module_path(real) assert resolved == real @@ -180,10 +179,8 @@ def test_resolve_real_js_stays_js_when_ts_does_not_exist(tmp_path): def test_bare_path_import_resolves_in_ts_file(tmp_path): """The #716 reproducer: TS file imports a sibling without an extension.""" - target = _write(tmp_path / "type-helpers.ts", - "export type GetNestedType = T") - importer = _write(tmp_path / "page.ts", - "import type { GetNestedType } from './type-helpers'\n") + target = _write(tmp_path / "type-helpers.ts", "export type GetNestedType = T") + importer = _write(tmp_path / "page.ts", "import type { GetNestedType } from './type-helpers'\n") result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( @@ -194,10 +191,8 @@ def test_bare_path_import_resolves_in_ts_file(tmp_path): def test_directory_import_resolves_to_index_ts(tmp_path): """`from './queue'` must resolve to `./queue/index.ts`.""" - target = _write(tmp_path / "queue" / "index.ts", - "export const enqueue = () => {}") - importer = _write(tmp_path / "page.ts", - "import { enqueue } from './queue'\n") + target = _write(tmp_path / "queue" / "index.ts", "export const enqueue = () => {}") + importer = _write(tmp_path / "page.ts", "import { enqueue } from './queue'\n") result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( @@ -211,10 +206,8 @@ def test_directory_import_resolves_to_index_ts(tmp_path): def test_dot_svelte_import_resolves_to_dot_svelte_ts(tmp_path): """Svelte 5 rune file: import written as .svelte, real file is .svelte.ts.""" - target = _write(tmp_path / "is-mobile.svelte.ts", - "export const isMobile = () => true") - importer = _write(tmp_path / "page.ts", - "import { isMobile } from './is-mobile.svelte'\n") + target = _write(tmp_path / "is-mobile.svelte.ts", "export const isMobile = () => true") + importer = _write(tmp_path / "page.ts", "import { isMobile } from './is-mobile.svelte'\n") result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( @@ -230,8 +223,7 @@ def test_explicit_ts_import_still_works(tmp_path): """The most common case — import with explicit .ts extension — must continue to work after the resolver change.""" target = _write(tmp_path / "foo.ts", "export const x = 1") - importer = _write(tmp_path / "page.ts", - "import { x } from './foo.ts'\n") + importer = _write(tmp_path / "page.ts", "import { x } from './foo.ts'\n") result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( @@ -244,8 +236,7 @@ def test_explicit_svelte_import_still_works(tmp_path): """Real .svelte file imports must still resolve when the .svelte file exists (i.e. don't accidentally redirect to a non-existent .svelte.ts).""" target = _write(tmp_path / "Card.svelte", "
") - importer = _write(tmp_path / "page.ts", - "import Card from './Card.svelte'\n") + importer = _write(tmp_path / "page.ts", "import Card from './Card.svelte'\n") result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( @@ -259,14 +250,12 @@ def test_external_module_unchanged(tmp_path): """Bare module specifiers (no leading dot, no alias match) must still fall through to the external/last-segment path — don't accidentally treat 'lodash' as a relative path.""" - importer = _write(tmp_path / "page.ts", - "import _ from 'lodash-es'\n") + importer = _write(tmp_path / "page.ts", "import _ from 'lodash-es'\n") result = extract_js(importer) targets = _import_targets(result) # The target should be the bare module name, not a resolved file path assert "lodash_es" in targets or any("lodash" in t for t in targets), ( - f"External module specifier should still produce an external " - f"reference; got {targets}" + f"External module specifier should still produce an external reference; got {targets}" ) @@ -276,19 +265,17 @@ def test_external_module_unchanged(tmp_path): def test_alias_import_with_bare_path_resolves(tmp_path): """`$lib/foo` (alias + bare path) — both layers must work together.""" src = tmp_path / "src" - target = _write(src / "lib" / "type-helpers.ts", - "export type X = string") - _write(tmp_path / "tsconfig.json", - '{"compilerOptions":{"paths":{"$lib":["./src/lib"],' - '"$lib/*":["./src/lib/*"]}}}') + target = _write(src / "lib" / "type-helpers.ts", "export type X = string") + _write( + tmp_path / "tsconfig.json", + '{"compilerOptions":{"paths":{"$lib":["./src/lib"],"$lib/*":["./src/lib/*"]}}}', + ) importer_dir = src / "routes" - importer = _write(importer_dir / "page.ts", - "import type { X } from '$lib/type-helpers'\n") + importer = _write(importer_dir / "page.ts", "import type { X } from '$lib/type-helpers'\n") result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( - f"Alias + bare-path resolution failed; " - f"expected {expected}; got {_import_targets(result)}" + f"Alias + bare-path resolution failed; expected {expected}; got {_import_targets(result)}" ) @@ -299,10 +286,8 @@ def test_type_only_import_with_bare_path_resolves(tmp_path): """`import type { X } from './foo'` — type-only imports must go through the same resolution path as regular imports. Common in TS codebases that separate types into their own module.""" - target = _write(tmp_path / "type-helpers.ts", - "export type GetNestedType = T") - importer = _write(tmp_path / "page.ts", - "import type { GetNestedType } from './type-helpers'\n") + target = _write(tmp_path / "type-helpers.ts", "export type GetNestedType = T") + importer = _write(tmp_path / "page.ts", "import type { GetNestedType } from './type-helpers'\n") result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( @@ -317,8 +302,7 @@ def test_named_imports_emit_symbol_edges_after_resolution(tmp_path): `imports_from`. The symbol-edge target_stem comes from _file_stem(resolved), which depends on resolution succeeding first.""" _write(tmp_path / "utils.ts", "export const foo = 1\nexport const bar = 2") - importer = _write(tmp_path / "page.ts", - "import { foo, bar } from './utils'\n") + importer = _write(tmp_path / "page.ts", "import { foo, bar } from './utils'\n") result = extract_js(importer) sym_edges = [e for e in result["edges"] if e.get("relation") == "imports"] targets = {str(e.get("target") or "") for e in sym_edges} @@ -334,18 +318,16 @@ def test_named_imports_emit_symbol_edges_after_resolution(tmp_path): def test_alias_directory_import_resolves_to_index_ts(tmp_path): """`from '$lib/queue'` where queue/ is a directory under src/lib/.""" src = tmp_path / "src" - target = _write(src / "lib" / "queue" / "index.ts", - "export const enqueue = () => {}") - _write(tmp_path / "tsconfig.json", - '{"compilerOptions":{"paths":{"$lib":["./src/lib"],' - '"$lib/*":["./src/lib/*"]}}}') - importer = _write(src / "routes" / "page.ts", - "import { enqueue } from '$lib/queue'\n") + target = _write(src / "lib" / "queue" / "index.ts", "export const enqueue = () => {}") + _write( + tmp_path / "tsconfig.json", + '{"compilerOptions":{"paths":{"$lib":["./src/lib"],"$lib/*":["./src/lib/*"]}}}', + ) + importer = _write(src / "routes" / "page.ts", "import { enqueue } from '$lib/queue'\n") result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( - f"Alias + directory resolution failed; " - f"expected {expected}; got {_import_targets(result)}" + f"Alias + directory resolution failed; expected {expected}; got {_import_targets(result)}" ) @@ -358,9 +340,7 @@ def test_resolve_does_not_match_partial_directory_name(tmp_path): bare = tmp_path / "foo" resolved = _resolve_js_module_path(bare) # Not a real file → nothing matches → returns input unchanged - assert resolved == bare, ( - f"Partial-name match must not happen; got {resolved}" - ) + assert resolved == bare, f"Partial-name match must not happen; got {resolved}" def test_resolve_directory_without_index_returns_unchanged(tmp_path): @@ -369,16 +349,13 @@ def test_resolve_directory_without_index_returns_unchanged(tmp_path): pkg = tmp_path / "pkg" _write(pkg / "not-index.ts", "export const x = 1") resolved = _resolve_js_module_path(pkg) - assert resolved == pkg, ( - f"Directory without index.* must return unchanged; got {resolved}" - ) + assert resolved == pkg, f"Directory without index.* must return unchanged; got {resolved}" def test_resolve_handles_subpath_into_directory_with_index(tmp_path): """`./foo/sub` where ./foo/sub/index.ts exists — nested subpath. Common pattern for sub-modules inside a package.""" - target = _write(tmp_path / "foo" / "sub" / "index.ts", - "export const x = 1") + target = _write(tmp_path / "foo" / "sub" / "index.ts", "export const x = 1") sub = tmp_path / "foo" / "sub" assert _resolve_js_module_path(sub) == target @@ -387,8 +364,7 @@ def test_resolve_does_not_treat_dotfile_as_extension(tmp_path): """Edge case: `.eslintrc` and similar dotfiles. Path('.eslintrc').suffix returns '' on Python 3.x for files starting with `.`. Make sure we don't accidentally treat a real file as bare and try to append .ts.""" - target = _write(tmp_path / ".env-types.ts", - "export const x = 1") + target = _write(tmp_path / ".env-types.ts", "export const x = 1") # Path('.env-types.ts').suffix is '.ts' — not a problem assert _resolve_js_module_path(target) == target @@ -401,8 +377,7 @@ def test_resolve_multi_dot_helper_file(tmp_path): Before this rule, .suffix was '.shared' so neither the bare-path branch nor the .js/.jsx branches matched, and the import dropped to a phantom.""" - target = _write(tmp_path / "tag-action.shared.ts", - "export const apply = () => {}") + target = _write(tmp_path / "tag-action.shared.ts", "export const apply = () => {}") written_as = tmp_path / "tag-action.shared" assert _resolve_js_module_path(written_as) == target @@ -423,15 +398,12 @@ def test_resolve_ambient_d_ts_via_bare_path(tmp_path): def test_end_to_end_multi_dot_import_resolves(tmp_path): """End-to-end sanity for the multi-dot pattern via the import handler.""" - target = _write(tmp_path / "tag-action.shared.ts", - "export const apply = () => {}") - importer = _write(tmp_path / "page.ts", - "import { apply } from './tag-action.shared'\n") + target = _write(tmp_path / "tag-action.shared.ts", "export const apply = () => {}") + importer = _write(tmp_path / "page.ts", "import { apply } from './tag-action.shared'\n") result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( - f"Multi-dot import failed end-to-end; " - f"expected {expected}; got {_import_targets(result)}" + f"Multi-dot import failed end-to-end; expected {expected}; got {_import_targets(result)}" ) @@ -440,13 +412,16 @@ def test_resolve_chain_alias_and_extension_compose(tmp_path): compose correctly: tsconfig alias maps `$lib/...` to a real dir, then extension resolution finds the actual file.""" src = tmp_path / "src" - target = _write(src / "lib" / "hooks" / "is-mobile.svelte.ts", - "export const isMobile = () => true") - _write(tmp_path / "tsconfig.json", - '{"compilerOptions":{"paths":{"$lib":["./src/lib"],' - '"$lib/*":["./src/lib/*"]}}}') - importer = _write(src / "routes" / "page.ts", - "import { isMobile } from '$lib/hooks/is-mobile.svelte'\n") + target = _write( + src / "lib" / "hooks" / "is-mobile.svelte.ts", "export const isMobile = () => true" + ) + _write( + tmp_path / "tsconfig.json", + '{"compilerOptions":{"paths":{"$lib":["./src/lib"],"$lib/*":["./src/lib/*"]}}}', + ) + importer = _write( + src / "routes" / "page.ts", "import { isMobile } from '$lib/hooks/is-mobile.svelte'\n" + ) result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( @@ -464,21 +439,25 @@ def test_ts_dynamic_import_bare_path_resolves(tmp_path): has its own copy of the resolution logic — distinct from the static-import handler and from the Svelte regex pass — and was missing the bare-path extension append, silently dropping every such edge.""" - target = _write(tmp_path / "profanity.ts", - "export const hasProfanity = (s: string) => false") - importer = _write(tmp_path / "auth-validators.ts", """\ + target = _write(tmp_path / "profanity.ts", "export const hasProfanity = (s: string) => false") + importer = _write( + tmp_path / "auth-validators.ts", + """\ export async function validate(name: string) { const { hasProfanity } = await import('./profanity') return hasProfanity(name) } -""") +""", + ) result = extract_js(importer) expected = _make_id(str(target)) - targets = {str(e.get("target") or "") for e in result["edges"] - if e.get("relation") in ("imports", "imports_from")} + targets = { + str(e.get("target") or "") + for e in result["edges"] + if e.get("relation") in ("imports", "imports_from") + } assert expected in targets, ( - f"Bare-path TS dynamic import failed to resolve; " - f"expected {expected}; got {targets}" + f"Bare-path TS dynamic import failed to resolve; expected {expected}; got {targets}" ) @@ -488,22 +467,28 @@ def test_ts_dynamic_import_alias_with_bare_path_resolves(tmp_path): `$lib/foo.ts` after both alias substitution and extension append.""" src = tmp_path / "src" target = _write(src / "lib" / "lazy-module.ts", "export const x = 1") - _write(tmp_path / "tsconfig.json", - '{"compilerOptions":{"paths":{"$lib":["./src/lib"],' - '"$lib/*":["./src/lib/*"]}}}') - importer = _write(src / "routes" / "page.ts", """\ + _write( + tmp_path / "tsconfig.json", + '{"compilerOptions":{"paths":{"$lib":["./src/lib"],"$lib/*":["./src/lib/*"]}}}', + ) + importer = _write( + src / "routes" / "page.ts", + """\ export async function load() { const m = await import('$lib/lazy-module') return m.x } -""") +""", + ) result = extract_js(importer) expected = _make_id(str(target)) - targets = {str(e.get("target") or "") for e in result["edges"] - if e.get("relation") in ("imports", "imports_from")} + targets = { + str(e.get("target") or "") + for e in result["edges"] + if e.get("relation") in ("imports", "imports_from") + } assert expected in targets, ( - f"Alias + bare-path dynamic import failed to resolve; " - f"expected {expected}; got {targets}" + f"Alias + bare-path dynamic import failed to resolve; expected {expected}; got {targets}" ) @@ -511,16 +496,19 @@ def test_dynamic_import_bare_path_resolves(tmp_path): """The regex pass for `import('...')` in .svelte files must also use the new resolver — otherwise dynamic imports of bare paths still produce phantom edges.""" - target = _write(tmp_path / "Heavy.svelte.ts", - "export const heavy = () => 1") - importer = _write(tmp_path / "page.svelte", """\ + target = _write(tmp_path / "Heavy.svelte.ts", "export const heavy = () => 1") + importer = _write( + tmp_path / "page.svelte", + """\ -""") +""", + ) result = extract_svelte(importer) - dyn_targets = {str(e.get("target") or "") for e in result["edges"] - if e.get("relation") == "dynamic_import"} + dyn_targets = { + str(e.get("target") or "") for e in result["edges"] if e.get("relation") == "dynamic_import" + } expected = _make_id(str(target)) assert expected in dyn_targets, ( f"dynamic_import of .svelte that's actually .svelte.ts must " diff --git a/tests/test_incremental.py b/tests/test_incremental.py index 2e2e902e0..a602cf3fd 100644 --- a/tests/test_incremental.py +++ b/tests/test_incremental.py @@ -1,21 +1,42 @@ """Integration tests for incremental graphify extract behavior.""" + from __future__ import annotations import json +import os import subprocess import sys from pathlib import Path -import pytest +from graphify.llm import BACKENDS, _backend_env_keys + PYTHON = sys.executable +def _clean_env() -> dict: + """Return os.environ with every backend API key stripped out.""" + env = dict(os.environ) + for backend in BACKENDS: + for env_key in _backend_env_keys(backend): + env.pop(env_key, None) + for extra in ( + "AWS_PROFILE", + "AWS_REGION", + "AWS_DEFAULT_REGION", + "OLLAMA_BASE_URL", + "OLLAMA_API_KEY", + ): + env.pop(extra, None) + return env + + def _run(args: list[str], cwd: Path) -> subprocess.CompletedProcess: return subprocess.run( [PYTHON, "-m", "graphify"] + args, cwd=cwd, capture_output=True, text=True, + env=_clean_env(), ) @@ -59,3 +80,172 @@ def test_no_incremental_without_manifest(tmp_path): # which pytest derives from the test name and contains "incremental". assert "incremental update" not in r.stdout.lower() assert "incremental scan" not in r.stdout.lower() + + +# ── PR 7: `graphify update` preserves the multidigraph profile (no silent fallback) ── +# +# watch._rebuild_code inherits the saved graph.json profile: it reads the on-disk +# `multigraph` flag and rebuilds via build_from_json(multigraph=...), re-stamping +# multigraph/directed + graphify_profile on write. So `graphify update` on a +# multidigraph round-trips it as a MultiDiGraph with keyed parallel edges intact — +# never silently collapsed to a simple graph. These tests prove that end-to-end by +# actually running `update` as a subprocess and reloading the rewritten graph.json. + + +def _make_code_corpus(tmp_path: Path) -> Path: + """A tiny real code corpus so `graphify update` has AST-extractable files. + + Includes ``extra()`` so a rebuild ADDS AST nodes the seeded multidigraph + graph.json lacks (file node + login()/helper()/extra()). That guarantees a + real topology change, so `update` hits the graph.json REWRITE path rather + than the no-change early return — the rewrite is what must preserve the + multigraph profile, so the test would be vacuous without forcing it. + """ + corpus = tmp_path / "corpus" + corpus.mkdir() + (corpus / "auth.py").write_text( + "def login():\n return helper()\n\n\ndef helper():\n return 1\n\n\n" + "def extra():\n return login()\n", + encoding="utf-8", + ) + return corpus + + +def _write_multidigraph_graph_json(corpus: Path) -> Path: + """Seed corpus/graphify-out/graph.json as a multidigraph with parallel edges. + + Built and serialized exactly the way Phase A persists it (build_from_json + multigraph=True -> export.to_json), so the saved file carries the top-level + ``multigraph: true`` flag and ``graphify_profile.graph_type == multidigraph``. + The two parallel ``login -> helper`` edges (ids absent from AST output) are + preserved by `_rebuild_code` across the rebuild, proving parallels survive. + """ + import networkx as nx + from graphify.build import build_from_json + from graphify.export import to_json + + nodes = [ + { + "id": n, + "label": n, + "file_type": "code", + "source_file": "auth.py", + "source_location": "L1", + } + for n in ("login", "helper") + ] + # Two parallel edges between the same (login -> helper) pair. + edges = [ + { + "source": "login", + "target": "helper", + "relation": rel, + "confidence": "EXTRACTED", + "source_file": "auth.py", + "source_location": f"L{i}", + } + for i, rel in enumerate(["calls", "imports"]) + ] + G = build_from_json({"nodes": nodes, "edges": edges}, multigraph=True) + assert isinstance(G, nx.MultiDiGraph) + assert G.number_of_edges() == 2 + out = corpus / "graphify-out" + out.mkdir(exist_ok=True) + graph_json = out / "graph.json" + to_json(G, {0: ["login", "helper"]}, str(graph_json), force=True) + # Persist the scan root so `graphify update` (no path arg) can recover it. + (out / ".graphify_root").write_text(str(corpus), encoding="utf-8") + return graph_json + + +def _parallel_login_helper_edges(graph_data: dict) -> list[dict]: + """Return the parallel ``login -> helper`` edge records from a graph.json dict.""" + links = graph_data.get("links", graph_data.get("edges", [])) + return [e for e in links if e.get("source") == "login" and e.get("target") == "helper"] + + +def test_update_preserves_multigraph_profile(tmp_path): + """`graphify update` on a multidigraph graph.json preserves the profile and + its parallel edges end-to-end: the rewritten file stays multigraph=true / + graph_type=multidigraph and reloads via load_graph as a MultiDiGraph with the + parallel edges intact.""" + from graphify.graph_loader import load_graph + + corpus = _make_code_corpus(tmp_path) + graph_json = _write_multidigraph_graph_json(corpus) + + before = json.loads(graph_json.read_text(encoding="utf-8")) + assert before.get("multigraph") is True + assert len(_parallel_login_helper_edges(before)) == 2 # both parallel edges present + + r = _run(["update", str(corpus)], tmp_path) + assert r.returncode == 0, f"update on multidigraph should succeed, got: {r.stderr}" + assert "multidigraph" not in r.stderr # no refusal message + + after = json.loads(graph_json.read_text(encoding="utf-8")) + # Profile preserved (no silent collapse to simple). + assert after.get("multigraph") is True, "multigraph flag must be preserved" + assert after.get("graph", {}).get("graphify_profile", {}).get("graph_type") == "multidigraph" + # Prove the REWRITE path ran (rebuild added AST nodes the seed lacked), not a + # no-change early return that would trivially leave the seed file untouched. + assert any(n.get("label") == "extra()" for n in after.get("nodes", [])), ( + "rebuild should have added AST nodes — rewrite path must have executed" + ) + # Parallel edges survive the rewrite. + par = _parallel_login_helper_edges(after) + assert len(par) == 2, "keyed parallel edges must be preserved across update" + assert sorted(e["relation"] for e in par) == ["calls", "imports"] + # Reloads as a MultiDiGraph with the parallels intact. + G2 = load_graph(after) + assert G2.is_multigraph(), "rewritten graph.json must reload as a MultiDiGraph" + assert G2.number_of_edges("login", "helper") == 2 + + +def test_update_simple_graph_unchanged_regression(tmp_path): + """A simple graph.json updated in simple mode behaves exactly as before: + `graphify update` succeeds and the graph stays a simple graph.""" + corpus = _make_code_corpus(tmp_path) + + # First run on a fresh corpus builds the simple graph via the normal path. + r1 = _run(["update", str(corpus)], tmp_path) + assert r1.returncode == 0, f"initial simple update failed: {r1.stderr}" + graph_json = corpus / "graphify-out" / "graph.json" + assert graph_json.exists() + data1 = json.loads(graph_json.read_text(encoding="utf-8")) + assert data1.get("multigraph") is False + assert data1.get("graph", {}).get("graphify_profile", {}).get("graph_type") == "simple" + assert any(n.get("label") == "login()" for n in data1.get("nodes", [])) + + # Re-running update on the now-simple graph must still succeed (no refusal, + # no profile change) — the pre-PR7 behavior is preserved. + r2 = _run(["update", str(corpus)], tmp_path) + assert r2.returncode == 0, f"re-run simple update failed: {r2.stderr}" + assert "multidigraph" not in r2.stderr + data2 = json.loads(graph_json.read_text(encoding="utf-8")) + assert data2.get("multigraph") is False + assert data2.get("graph", {}).get("graphify_profile", {}).get("graph_type") == "simple" + + +def test_update_profile_mismatch_no_silent_fallback(tmp_path): + """Go/no-go gate: `graphify update` on a multidigraph must NOT silently fall + back to simple-graph behavior. The gate is satisfied by PRESERVATION — the + result is still a multidigraph with parallel edges, never a collapsed simple + graph (and never a spurious refusal now that the pipeline preserves).""" + corpus = _make_code_corpus(tmp_path) + graph_json = _write_multidigraph_graph_json(corpus) + + r = _run(["update", str(corpus)], tmp_path) + after = json.loads(graph_json.read_text(encoding="utf-8")) + + # The invariant: never a silent simple-graph result. + assert after.get("multigraph") is True, ( + "no silent fallback: a multidigraph update must remain a multidigraph, " + f"got multigraph={after.get('multigraph')!r}" + ) + assert after.get("graph", {}).get("graphify_profile", {}).get("graph_type") == "multidigraph" + # Parallel edges are not collapsed away. + assert len(_parallel_login_helper_edges(after)) == 2, ( + "parallel edges must survive — collapsing to one edge is a silent fallback" + ) + # Preservation, not refusal: the command succeeds normally. + assert r.returncode == 0, f"update should preserve (succeed), not refuse: {r.stderr}" diff --git a/tests/test_ingest.py b/tests/test_ingest.py index 41128eee2..2d6774258 100644 --- a/tests/test_ingest.py +++ b/tests/test_ingest.py @@ -1,8 +1,6 @@ """Tests for graphify.ingest.save_query_result""" + from __future__ import annotations -import re -from pathlib import Path -import pytest from graphify.ingest import save_query_result @@ -49,7 +47,7 @@ def test_source_nodes_capped_at_10(tmp_path): out = save_query_result("q", "a", mem, source_nodes=nodes) content = out.read_text() # Only first 10 should appear in frontmatter source_nodes line - fm_line = [l for l in content.splitlines() if l.startswith("source_nodes:")][0] + fm_line = [label for label in content.splitlines() if label.startswith("source_nodes:")][0] assert fm_line.count('"Node') == 10 diff --git a/tests/test_install.py b/tests/test_install.py index 5b464e8d9..94616255a 100644 --- a/tests/test_install.py +++ b/tests/test_install.py @@ -1,4 +1,5 @@ """Tests for graphify install --platform routing.""" + import os from pathlib import Path import sys @@ -20,6 +21,7 @@ def _install(tmp_path, platform): from graphify.__main__ import install + old_cwd = Path.cwd() try: os.chdir(tmp_path) @@ -46,6 +48,7 @@ def test_install_opencode(tmp_path): def test_install_positional_platform_opencode(tmp_path, monkeypatch): from graphify.__main__ import main + monkeypatch.chdir(tmp_path) monkeypatch.setattr(sys, "argv", ["graphify", "install", "opencode"]) with patch("graphify.__main__.Path.home", return_value=tmp_path): @@ -56,6 +59,7 @@ def test_install_positional_platform_opencode(tmp_path, monkeypatch): def test_install_project_claude_writes_project_scope(tmp_path, monkeypatch, capsys): from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -67,12 +71,15 @@ def test_install_project_claude_writes_project_scope(tmp_path, monkeypatch, caps assert (project / ".claude" / "CLAUDE.md").exists() assert not (home / ".claude" / "skills" / "graphify" / "SKILL.md").exists() assert ".claude/skills/graphify/SKILL.md" in (project / ".claude" / "CLAUDE.md").read_text() - assert "~/.claude/skills/graphify/SKILL.md" not in (project / ".claude" / "CLAUDE.md").read_text() + assert ( + "~/.claude/skills/graphify/SKILL.md" not in (project / ".claude" / "CLAUDE.md").read_text() + ) assert "git add .claude/" in capsys.readouterr().out def test_install_project_codex_writes_skill_and_agents(tmp_path, monkeypatch): from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -88,6 +95,7 @@ def test_install_project_codex_writes_skill_and_agents(tmp_path, monkeypatch): def test_claude_subcommand_project_install_and_uninstall_are_project_scoped(tmp_path, monkeypatch): from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -114,6 +122,7 @@ def test_claude_subcommand_project_install_and_uninstall_are_project_scoped(tmp_ def test_codex_subcommand_project_install_and_uninstall_are_project_scoped(tmp_path, monkeypatch): from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -142,6 +151,7 @@ def test_codex_subcommand_project_install_and_uninstall_are_project_scoped(tmp_p def test_antigravity_install_project_writes_project_skill(tmp_path, monkeypatch): from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -155,6 +165,7 @@ def test_antigravity_install_project_writes_project_skill(tmp_path, monkeypatch) def test_install_help_does_not_install_default(tmp_path, monkeypatch, capsys): from graphify.__main__ import main + monkeypatch.chdir(tmp_path) monkeypatch.setattr(sys, "argv", ["graphify", "install", "opencode", "--help"]) with patch("graphify.__main__.Path.home", return_value=tmp_path): @@ -199,6 +210,7 @@ def test_install_unknown_platform_exits(tmp_path): def test_codex_skill_contains_spawn_agent(): """Codex skill file must reference spawn_agent.""" import graphify + skill = (Path(graphify.__file__).parent / "skill-codex.md").read_text() assert "spawn_agent" in skill @@ -206,6 +218,7 @@ def test_codex_skill_contains_spawn_agent(): def test_codex_skill_uses_graphify_with_dirty_graph_output(): """Codex skill must keep graph-first orientation even when graph output is dirty.""" import graphify + skill = (Path(graphify.__file__).parent / "skill-codex.md").read_text() assert "Dirty `graphify-out/` artifacts are expected" in skill assert "not a reason to skip Graphify" in skill @@ -224,6 +237,7 @@ def test_codex_agents_install_mentions_dirty_graph_output(tmp_path): def test_opencode_skill_contains_mention(): """OpenCode skill file must reference @mention.""" import graphify + skill = (Path(graphify.__file__).parent / "skill-opencode.md").read_text() assert "@mention" in skill @@ -231,6 +245,7 @@ def test_opencode_skill_contains_mention(): def test_opencode_skill_uses_opencode_agent_guidance(): """OpenCode skill must not reference Codex/Claude agent type names.""" import graphify + skill = (Path(graphify.__file__).parent / "skill-opencode.md").read_text() assert "general-purpose" not in skill assert 'subagent_type="general-purpose"' not in skill @@ -245,6 +260,7 @@ def test_opencode_skill_uses_opencode_agent_guidance(): def test_claw_skill_is_sequential(): """OpenClaw skill file must describe sequential extraction.""" import graphify + skill = (Path(graphify.__file__).parent / "skill-claw.md").read_text() assert "sequential" in skill.lower() assert "spawn_agent" not in skill @@ -254,8 +270,17 @@ def test_claw_skill_is_sequential(): def test_all_skill_files_exist_in_package(): """All installable platform skill files must be present in the installed package.""" import graphify + pkg = Path(graphify.__file__).parent - for name in ("skill.md", "skill-codex.md", "skill-opencode.md", "skill-claw.md", "skill-windows.md", "skill-droid.md", "skill-trae.md"): + for name in ( + "skill.md", + "skill-codex.md", + "skill-opencode.md", + "skill-claw.md", + "skill-windows.md", + "skill-droid.md", + "skill-trae.md", + ): assert (pkg / name).exists(), f"Missing: {name}" @@ -272,6 +297,7 @@ def test_codex_install_does_not_write_claude_md(tmp_path): def test_uninstall_project_removes_project_skill_only(tmp_path, monkeypatch): from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -280,9 +306,13 @@ def test_uninstall_project_removes_project_skill_only(tmp_path, monkeypatch): user_skill.write_text("user skill") monkeypatch.chdir(project) with patch("graphify.__main__.Path.home", return_value=home): - monkeypatch.setattr(sys, "argv", ["graphify", "install", "--project", "--platform", "codex"]) + monkeypatch.setattr( + sys, "argv", ["graphify", "install", "--project", "--platform", "codex"] + ) main() - monkeypatch.setattr(sys, "argv", ["graphify", "uninstall", "--project", "--platform", "codex"]) + monkeypatch.setattr( + sys, "argv", ["graphify", "uninstall", "--project", "--platform", "codex"] + ) main() assert user_skill.exists() assert not (project / ".agents" / "skills" / "graphify" / "SKILL.md").exists() @@ -291,6 +321,7 @@ def test_uninstall_project_removes_project_skill_only(tmp_path, monkeypatch): def test_uninstall_project_without_platform_removes_project_installs(tmp_path, monkeypatch): from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -310,6 +341,7 @@ def test_uninstall_project_without_platform_removes_project_installs(tmp_path, m def test_antigravity_uninstall_project_removes_project_skill_only(tmp_path, monkeypatch): from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -330,6 +362,7 @@ def test_antigravity_uninstall_project_removes_project_skill_only(tmp_path, monk def test_antigravity_global_install_writes_gemini_config_skills(tmp_path, monkeypatch): """Global `graphify antigravity install` must write to ~/.gemini/config/skills/ (#1079).""" from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -349,6 +382,7 @@ def test_antigravity_global_install_writes_gemini_config_skills(tmp_path, monkey def test_antigravity_global_uninstall_removes_gemini_config_skill(tmp_path, monkeypatch): """Global `graphify antigravity uninstall` must remove from ~/.gemini/config/skills/ (#1079).""" from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -368,13 +402,16 @@ def test_antigravity_global_uninstall_removes_gemini_config_skill(tmp_path, monk # --- always-on AGENTS.md install/uninstall tests --- + def _agents_install(tmp_path, platform): from graphify.__main__ import _agents_install as _install_fn + _install_fn(tmp_path, platform) def _agents_uninstall(tmp_path, platform=""): from graphify.__main__ import _agents_uninstall as _uninstall_fn + _uninstall_fn(tmp_path, platform=platform) @@ -442,6 +479,7 @@ def test_agents_uninstall_no_op_when_not_installed(tmp_path, capsys): # --- OpenCode plugin tests --- + def test_opencode_agents_install_writes_plugin(tmp_path): """opencode install writes .opencode/plugins/graphify.js.""" _agents_install(tmp_path, "opencode") @@ -456,6 +494,7 @@ def test_opencode_agents_install_registers_plugin_in_config(tmp_path): config_file = tmp_path / ".opencode" / "opencode.json" assert config_file.exists() import json as _json + config = _json.loads(config_file.read_text()) assert any("graphify.js" in p for p in config.get("plugin", [])) @@ -463,6 +502,7 @@ def test_opencode_agents_install_registers_plugin_in_config(tmp_path): def test_opencode_agents_install_merges_existing_config(tmp_path): """opencode install preserves existing .opencode/opencode.json keys.""" import json as _json + config_file = tmp_path / ".opencode" / "opencode.json" config_file.parent.mkdir(parents=True, exist_ok=True) config_file.write_text(_json.dumps({"model": "claude-opus-4-5", "plugin": []})) @@ -475,6 +515,7 @@ def test_opencode_agents_install_merges_existing_config(tmp_path): def test_opencode_agents_uninstall_removes_plugin(tmp_path): """opencode uninstall removes the plugin file and deregisters from opencode.json.""" import json as _json + _agents_install(tmp_path, "opencode") _agents_uninstall(tmp_path, platform="opencode") plugin = tmp_path / ".opencode" / "plugins" / "graphify.js" @@ -487,9 +528,11 @@ def test_opencode_agents_uninstall_removes_plugin(tmp_path): # ── Cursor ──────────────────────────────────────────────────────────────────── + def test_cursor_install_writes_rule(tmp_path): """cursor install writes .cursor/rules/graphify.mdc.""" from graphify.__main__ import _cursor_install + _cursor_install(tmp_path) rule = tmp_path / ".cursor" / "rules" / "graphify.mdc" assert rule.exists() @@ -501,6 +544,7 @@ def test_cursor_install_writes_rule(tmp_path): def test_cursor_install_idempotent(tmp_path): """cursor install does not overwrite an existing rule file.""" from graphify.__main__ import _cursor_install + _cursor_install(tmp_path) rule = tmp_path / ".cursor" / "rules" / "graphify.mdc" original = rule.read_text() @@ -511,6 +555,7 @@ def test_cursor_install_idempotent(tmp_path): def test_cursor_uninstall_removes_rule(tmp_path): """cursor uninstall removes the rule file.""" from graphify.__main__ import _cursor_install, _cursor_uninstall + _cursor_install(tmp_path) _cursor_uninstall(tmp_path) rule = tmp_path / ".cursor" / "rules" / "graphify.mdc" @@ -520,51 +565,64 @@ def test_cursor_uninstall_removes_rule(tmp_path): def test_cursor_uninstall_noop_if_not_installed(tmp_path): """cursor uninstall does nothing if rule was never written.""" from graphify.__main__ import _cursor_uninstall + _cursor_uninstall(tmp_path) # should not raise # ── Gemini CLI ──────────────────────────────────────────────────────────────── + def test_gemini_install_writes_gemini_md(tmp_path): from graphify.__main__ import gemini_install + gemini_install(tmp_path) md = tmp_path / "GEMINI.md" assert md.exists() assert "graphify-out/GRAPH_REPORT.md" in md.read_text() + def test_gemini_install_writes_hook(tmp_path): import json as _json from graphify.__main__ import gemini_install + gemini_install(tmp_path) settings = _json.loads((tmp_path / ".gemini" / "settings.json").read_text()) hooks = settings["hooks"]["BeforeTool"] assert any("graphify" in str(h) for h in hooks) + def test_gemini_install_idempotent(tmp_path): from graphify.__main__ import gemini_install + gemini_install(tmp_path) gemini_install(tmp_path) md = tmp_path / "GEMINI.md" assert md.read_text().count("## graphify") == 1 + def test_gemini_install_merges_existing_gemini_md(tmp_path): from graphify.__main__ import gemini_install + (tmp_path / "GEMINI.md").write_text("# My project rules\n") gemini_install(tmp_path) content = (tmp_path / "GEMINI.md").read_text() assert "# My project rules" in content assert "graphify-out/GRAPH_REPORT.md" in content + def test_gemini_uninstall_removes_section(tmp_path): from graphify.__main__ import gemini_install, gemini_uninstall + gemini_install(tmp_path) gemini_uninstall(tmp_path) md = tmp_path / "GEMINI.md" assert not md.exists() + def test_gemini_uninstall_removes_hook(tmp_path): import json as _json from graphify.__main__ import gemini_install, gemini_uninstall + gemini_install(tmp_path) gemini_uninstall(tmp_path) settings_path = tmp_path / ".gemini" / "settings.json" @@ -573,6 +631,8 @@ def test_gemini_uninstall_removes_hook(tmp_path): hooks = settings.get("hooks", {}).get("BeforeTool", []) assert not any("graphify" in str(h) for h in hooks) + def test_gemini_uninstall_noop_if_not_installed(tmp_path): from graphify.__main__ import gemini_uninstall + gemini_uninstall(tmp_path) # should not raise diff --git a/tests/test_install_strings.py b/tests/test_install_strings.py index 5e0037089..93cb2a7d8 100644 --- a/tests/test_install_strings.py +++ b/tests/test_install_strings.py @@ -8,6 +8,7 @@ (issue #580). This file locks in the query-first policy so a future revert or partial change is caught by CI. """ + from __future__ import annotations import json @@ -73,6 +74,7 @@ def test_no_install_surface_demands_reading_the_full_report_first(): are legitimate platform metadata, not the bug. """ import re + banned = [ # "read ... GRAPH_REPORT.md ... before" re.compile(r"read[^.\n]{0,80}GRAPH_REPORT\.md[^.\n]{0,80}before", re.IGNORECASE), @@ -87,10 +89,7 @@ def test_no_install_surface_demands_reading_the_full_report_first(): m = pattern.search(text) if m: hits.append((name, m.group(0))) - assert not hits, ( - f"banned report-first phrasing reappeared: {hits}. " - f"This regresses issue #580." - ) + assert not hits, f"banned report-first phrasing reappeared: {hits}. This regresses issue #580." def test_report_is_still_referenced_as_fallback(): @@ -127,6 +126,7 @@ def test_agents_section_does_not_skip_dirty_graph_output(): def test_how_it_works_clarifies_code_only_semantic_extraction(): from pathlib import Path + doc = (Path(__file__).parent.parent / "docs" / "how-it-works.md").read_text(encoding="utf-8") assert "Code files are not sent to the LLM semantic extractor" in doc assert "code files, Pass 3 is skipped entirely" in doc diff --git a/tests/test_install_upgrade.py b/tests/test_install_upgrade.py index 09ee3d81e..fa085dbdc 100644 --- a/tests/test_install_upgrade.py +++ b/tests/test_install_upgrade.py @@ -9,6 +9,7 @@ section, run the installer, and assert that the on-disk file now contains the new query-first wording and does not contain the old report-first text. """ + from __future__ import annotations import json from pathlib import Path @@ -85,9 +86,7 @@ def _assert_no_report_first(text: str, ctx: str) -> None: def _assert_query_first(text: str, ctx: str) -> None: - assert "graphify query" in text, ( - f"{ctx}: new 'graphify query' guidance missing after upgrade" - ) + assert "graphify query" in text, f"{ctx}: new 'graphify query' guidance missing after upgrade" def test_claude_install_upgrades_stale_section(tmp_path, monkeypatch): @@ -95,7 +94,9 @@ def test_claude_install_upgrades_stale_section(tmp_path, monkeypatch): `graphify claude install` again after upgrading to a fixed package.""" monkeypatch.chdir(tmp_path) claude_md = tmp_path / "CLAUDE.md" - claude_md.write_text("# My Project\n\nSome description.\n\n" + _OLD_CLAUDE_SECTION, encoding="utf-8") + claude_md.write_text( + "# My Project\n\nSome description.\n\n" + _OLD_CLAUDE_SECTION, encoding="utf-8" + ) monkeypatch.setattr(mainmod, "_check_skill_version", lambda _: None) mainmod.claude_install(tmp_path) @@ -125,11 +126,7 @@ def test_claude_install_upgrades_stale_hook_payload(tmp_path, monkeypatch): "hooks": [ { "type": "command", - "command": ( - "case x in *) " - + _OLD_HOOK_PAYLOAD_SNIPPET - + " esac" - ), + "command": ("case x in *) " + _OLD_HOOK_PAYLOAD_SNIPPET + " esac"), } ], } @@ -142,9 +139,7 @@ def test_claude_install_upgrades_stale_hook_payload(tmp_path, monkeypatch): mainmod.claude_install(tmp_path) new_settings_text = settings.read_text(encoding="utf-8") - assert _OLD_HOOK_PAYLOAD_SNIPPET not in new_settings_text, ( - "stale hook payload survived upgrade" - ) + assert _OLD_HOOK_PAYLOAD_SNIPPET not in new_settings_text, "stale hook payload survived upgrade" assert "graphify query" in new_settings_text, ( "new hook payload should route to `graphify query`" ) diff --git a/tests/test_js_import_resolution.py b/tests/test_js_import_resolution.py index 414cf8d87..a1ac2ff7e 100644 --- a/tests/test_js_import_resolution.py +++ b/tests/test_js_import_resolution.py @@ -18,10 +18,7 @@ def _extract_for(paths: list[Path], root: Path): def _has_edge(result: dict, source: str, target: str, relation: str = "imports_from") -> bool: expected = (_make_id(source), _make_id(target), relation) - actual = { - (edge["source"], edge["target"], edge["relation"]) - for edge in result["edges"] - } + actual = {(edge["source"], edge["target"], edge["relation"]) for edge in result["edges"]} return expected in actual @@ -33,10 +30,7 @@ def _has_symbol_edge( relation: str = "imports", ) -> bool: expected = (_make_id(source), _make_id(_file_stem(Path(target_file)), symbol), relation) - actual = { - (edge["source"], edge["target"], edge["relation"]) - for edge in result["edges"] - } + actual = {(edge["source"], edge["target"], edge["relation"]) for edge in result["edges"]} return expected in actual @@ -53,10 +47,7 @@ def _has_symbol_to_symbol_edge( _make_id(_file_stem(Path(target_file)), target_symbol), relation, ) - actual = { - (edge["source"], edge["target"], edge["relation"]) - for edge in result["edges"] - } + actual = {(edge["source"], edge["target"], edge["relation"]) for edge in result["edges"]} return expected in actual @@ -198,7 +189,9 @@ def test_ts_reexported_type_alias_resolves_imported_symbol_to_origin(tmp_path: P def test_ts_reexported_abstract_class_resolves_imported_symbol_to_origin(tmp_path: Path): - target = _write(tmp_path / "src/lib/foo.ts", "export abstract class Foo { abstract run(): void }\n") + target = _write( + tmp_path / "src/lib/foo.ts", "export abstract class Foo { abstract run(): void }\n" + ) barrel = _write(tmp_path / "src/lib/index.ts", "export { Foo } from './foo'\n") consumer = _write( tmp_path / "src/routes/page.ts", @@ -228,7 +221,9 @@ def test_ts_const_alias_reexport_resolves_imported_symbol_to_origin(tmp_path: Pa assert _has_symbol_edge(result, "src/routes/page.ts", "src/lib/foo.ts", "Foo") -def test_ts_local_const_alias_then_named_reexport_resolves_imported_symbol_to_origin(tmp_path: Path): +def test_ts_local_const_alias_then_named_reexport_resolves_imported_symbol_to_origin( + tmp_path: Path, +): target = _write(tmp_path / "src/lib/foo.ts", "export function makeFoo() { return {} }\n") barrel = _write( tmp_path / "src/lib/index.ts", @@ -307,7 +302,9 @@ def test_ts_import_alias_call_from_same_named_local_symbol_targets_origin(tmp_pa def test_svelte_rune_import_resolves_svelte_ts_file(tmp_path: Path): - target = _write(tmp_path / "src/lib/hooks/is-mobile.svelte.ts", "export const isMobile = true\n") + target = _write( + tmp_path / "src/lib/hooks/is-mobile.svelte.ts", "export const isMobile = true\n" + ) importer = _write( tmp_path / "src/routes/page.ts", "import { isMobile } from '../lib/hooks/is-mobile.svelte'\nconsole.log(isMobile)\n", @@ -482,8 +479,12 @@ def _norm(label: str) -> str: if edge.get("relation") == "references" } - assert _has_symbol_to_symbol_edge(result, "src/lib/impl.ts", "DataProcessor", "src/lib/base.ts", "BaseProcessor", "inherits") - assert _has_symbol_to_symbol_edge(result, "src/lib/impl.ts", "DataProcessor", "src/lib/base.ts", "IProcessor", "implements") + assert _has_symbol_to_symbol_edge( + result, "src/lib/impl.ts", "DataProcessor", "src/lib/base.ts", "BaseProcessor", "inherits" + ) + assert _has_symbol_to_symbol_edge( + result, "src/lib/impl.ts", "DataProcessor", "src/lib/base.ts", "IProcessor", "implements" + ) assert ("run", "Payload", "parameter_type") in reference_contexts assert ("run", "Result", "return_type") in reference_contexts assert ("run", "Payload", "generic_arg") in reference_contexts diff --git a/tests/test_languages.py b/tests/test_languages.py index fefc13807..72d6aab4d 100644 --- a/tests/test_languages.py +++ b/tests/test_languages.py @@ -1,14 +1,30 @@ """Tests for language extractors: Java, C, C++, Ruby, C#, Kotlin, Scala, PHP, Swift, Go, Julia, Fortran, JS/TS, .NET project files.""" + from __future__ import annotations from pathlib import Path -import pytest from graphify.extract import ( - extract_java, extract_c, extract_cpp, extract_ruby, - extract_csharp, extract_kotlin, extract_scala, extract_php, - extract_swift, extract_go, extract_julia, extract_js, extract_fortran, - extract_groovy, extract_sln, extract_csproj, extract_razor, - extract_dm, extract_dmi, extract_dmm, extract_dmf, + extract_java, + extract_c, + extract_cpp, + extract_csproj, + extract_dm, + extract_dmf, + extract_dmi, + extract_dmm, + extract_fortran, + extract_go, + extract_groovy, + extract_js, + extract_julia, + extract_kotlin, + extract_php, extract_powershell, + extract_razor, + extract_ruby, + extract_csharp, + extract_scala, + extract_sln, + extract_swift, ) FIXTURES = Path(__file__).parent / "fixtures" @@ -17,14 +33,17 @@ def _labels(r): return [n["label"] for n in r["nodes"]] + def _relations(r): return {e["relation"] for e in r["edges"]} + def _calls(r): node_by_id = {n["id"]: n["label"] for n in r["nodes"]} return { (node_by_id.get(e["source"], e["source"]), node_by_id.get(e["target"], e["target"])) - for e in r["edges"] if e["relation"] == "calls" + for e in r["edges"] + if e["relation"] == "calls" } @@ -36,7 +55,8 @@ def _references(r): node_by_id.get(e["target"], e["target"]), e, ) - for e in r["edges"] if e["relation"] == "references" + for e in r["edges"] + if e["relation"] == "references" ] @@ -63,29 +83,36 @@ def _edge_labels(result: dict, relation: str, context: str | None = None) -> set continue if context is not None and edge.get("context") != context: continue - pairs.add((labels.get(edge["source"], edge["source"]), labels.get(edge["target"], edge["target"]))) + pairs.add( + (labels.get(edge["source"], edge["source"]), labels.get(edge["target"], edge["target"])) + ) return pairs # ── Java ────────────────────────────────────────────────────────────────────── + def test_java_no_error(): r = extract_java(FIXTURES / "sample.java") assert "error" not in r + def test_java_finds_class(): r = extract_java(FIXTURES / "sample.java") - assert any("DataProcessor" in l for l in _labels(r)) + assert any("DataProcessor" in label for label in _labels(r)) + def test_java_finds_interface(): r = extract_java(FIXTURES / "sample.java") - assert any("Processor" in l for l in _labels(r)) + assert any("Processor" in label for label in _labels(r)) + def test_java_finds_methods(): r = extract_java(FIXTURES / "sample.java") labels = _labels(r) - assert any("addItem" in l for l in labels) - assert any("process" in l for l in labels) + assert any("addItem" in label for label in labels) + assert any("process" in label for label in labels) + def test_java_finds_imports(): r = extract_java(FIXTURES / "sample.java") @@ -98,6 +125,7 @@ def test_java_import_edges_have_import_context(): assert import_edges assert all(e.get("context") == "import" for e in import_edges) + def test_java_no_dangling_edges(): r = extract_java(FIXTURES / "sample.java") node_ids = {n["id"] for n in r["nodes"]} @@ -107,24 +135,29 @@ def test_java_no_dangling_edges(): # ── C ──────────────────────────────────────────────────────────────────────── + def test_c_no_error(): r = extract_c(FIXTURES / "sample.c") assert "error" not in r + def test_c_finds_functions(): r = extract_c(FIXTURES / "sample.c") labels = _labels(r) - assert any("process" in l for l in labels) - assert any("main" in l for l in labels) + assert any("process" in label for label in labels) + assert any("main" in label for label in labels) + def test_c_finds_includes(): r = extract_c(FIXTURES / "sample.c") assert "imports" in _relations(r) + def test_c_emits_calls(): r = extract_c(FIXTURES / "sample.c") assert any(e["relation"] == "calls" for e in r["edges"]) + def test_c_calls_are_extracted(): r = extract_c(FIXTURES / "sample.c") for e in r["edges"]: @@ -154,19 +187,23 @@ def test_c_call_edges_have_call_context(): # ── C++ ─────────────────────────────────────────────────────────────────────── + def test_cpp_no_error(): r = extract_cpp(FIXTURES / "sample.cpp") assert "error" not in r + def test_cpp_finds_class(): r = extract_cpp(FIXTURES / "sample.cpp") - assert any("HttpClient" in l for l in _labels(r)) + assert any("HttpClient" in label for label in _labels(r)) + def test_cpp_finds_methods(): r = extract_cpp(FIXTURES / "sample.cpp") labels = _labels(r) # C++ extractor captures the constructor and public-visible methods - assert any("HttpClient" in l for l in labels) + assert any("HttpClient" in label for label in labels) + def test_cpp_finds_includes(): r = extract_cpp(FIXTURES / "sample.cpp") @@ -200,7 +237,8 @@ def test_cpp_class_inherits_edge(): found = any( "AuthedHttpClient" in node_by_id.get(e["source"], "") and "HttpClient" in node_by_id.get(e["target"], "") - for e in r["edges"] if e["relation"] == "inherits" + for e in r["edges"] + if e["relation"] == "inherits" ) assert found, "AuthedHttpClient should have inherits edge to HttpClient" @@ -212,67 +250,80 @@ def test_cpp_struct_inherits_edge(): found = any( "RetryingHttpClient" in node_by_id.get(e["source"], "") and "HttpClient" in node_by_id.get(e["target"], "") - for e in r["edges"] if e["relation"] == "inherits" + for e in r["edges"] + if e["relation"] == "inherits" ) assert found, "RetryingHttpClient (struct) should have inherits edge to HttpClient" # ── Ruby ───────────────────────────────────────────────────────────────────── + def test_ruby_no_error(): r = extract_ruby(FIXTURES / "sample.rb") assert "error" not in r + def test_ruby_finds_class(): r = extract_ruby(FIXTURES / "sample.rb") - assert any("ApiClient" in l for l in _labels(r)) + assert any("ApiClient" in label for label in _labels(r)) + def test_ruby_finds_methods(): r = extract_ruby(FIXTURES / "sample.rb") labels = _labels(r) - assert any("get" in l for l in labels) - assert any("post" in l for l in labels) + assert any("get" in label for label in labels) + assert any("post" in label for label in labels) + def test_ruby_finds_function(): r = extract_ruby(FIXTURES / "sample.rb") - assert any("parse_response" in l for l in _labels(r)) + assert any("parse_response" in label for label in _labels(r)) # ── C# ─────────────────────────────────────────────────────────────────────── + def test_csharp_no_error(): r = extract_csharp(FIXTURES / "sample.cs") assert "error" not in r + def test_csharp_finds_class(): r = extract_csharp(FIXTURES / "sample.cs") - assert any("DataProcessor" in l for l in _labels(r)) + assert any("DataProcessor" in label for label in _labels(r)) + def test_csharp_finds_interface(): r = extract_csharp(FIXTURES / "sample.cs") - assert any("IProcessor" in l for l in _labels(r)) + assert any("IProcessor" in label for label in _labels(r)) + def test_csharp_finds_methods(): r = extract_csharp(FIXTURES / "sample.cs") labels = _labels(r) - assert any("Process" in l for l in labels) + assert any("Process" in label for label in labels) + def test_csharp_finds_usings(): r = extract_csharp(FIXTURES / "sample.cs") assert "imports" in _relations(r) + def test_csharp_inherits_edge(): r = extract_csharp(FIXTURES / "sample.cs") inherits = [e for e in r["edges"] if e["relation"] == "inherits"] assert len(inherits) >= 1 + def test_csharp_implements_iprocessor(): r = extract_csharp(FIXTURES / "sample.cs") node_by_id = {n["id"]: n["label"] for n in r["nodes"]} found = any( - "DataProcessor" in node_by_id.get(e["source"], "") and - "IProcessor" in node_by_id.get(e["target"], "") - for e in r["edges"] if e["relation"] == "implements" + "DataProcessor" in node_by_id.get(e["source"], "") + and "IProcessor" in node_by_id.get(e["target"], "") + for e in r["edges"] + if e["relation"] == "implements" ) assert found, "DataProcessor should have implements edge to IProcessor" @@ -320,7 +371,8 @@ def test_csharp_call_edges_have_call_context(): "Process" in node_by_id.get(e["source"], "") and "Validate" in node_by_id.get(e["target"], "") and e.get("context") == "call" - for e in r["edges"] if e["relation"] == "calls" + for e in r["edges"] + if e["relation"] == "calls" ), "C# call edges should retain call context" @@ -333,27 +385,33 @@ def test_csharp_import_edges_have_import_context(): # ── Kotlin ─────────────────────────────────────────────────────────────────── + def test_kotlin_no_error(): r = extract_kotlin(FIXTURES / "sample.kt") assert "error" not in r + def test_kotlin_finds_class(): r = extract_kotlin(FIXTURES / "sample.kt") - assert any("HttpClient" in l for l in _labels(r)) + assert any("HttpClient" in label for label in _labels(r)) + def test_kotlin_finds_data_class(): r = extract_kotlin(FIXTURES / "sample.kt") - assert any("Config" in l for l in _labels(r)) + assert any("Config" in label for label in _labels(r)) + def test_kotlin_finds_methods(): r = extract_kotlin(FIXTURES / "sample.kt") labels = _labels(r) - assert any("get" in l for l in labels) - assert any("post" in l for l in labels) + assert any("get" in label for label in labels) + assert any("post" in label for label in labels) + def test_kotlin_finds_function(): r = extract_kotlin(FIXTURES / "sample.kt") - assert any("createClient" in l for l in _labels(r)) + assert any("createClient" in label for label in _labels(r)) + def test_kotlin_emits_in_file_calls(): """Regression test for the call-walker `simple_identifier` / @@ -384,23 +442,27 @@ def test_kotlin_parameter_return_generic_and_field_contexts(): # ── Scala ───────────────────────────────────────────────────────────────────── + def test_scala_no_error(): r = extract_scala(FIXTURES / "sample.scala") assert "error" not in r + def test_scala_finds_class(): r = extract_scala(FIXTURES / "sample.scala") - assert any("HttpClient" in l for l in _labels(r)) + assert any("HttpClient" in label for label in _labels(r)) + def test_scala_finds_object(): r = extract_scala(FIXTURES / "sample.scala") - assert any("HttpClientFactory" in l for l in _labels(r)) + assert any("HttpClientFactory" in label for label in _labels(r)) + def test_scala_finds_methods(): r = extract_scala(FIXTURES / "sample.scala") labels = _labels(r) - assert any("get" in l for l in labels) - assert any("post" in l for l in labels) + assert any("get" in label for label in labels) + assert any("post" in label for label in labels) def test_scala_import_edges_have_import_context(): @@ -440,23 +502,28 @@ def test_scala_call_edges_have_call_context(): # ── PHP ─────────────────────────────────────────────────────────────────────── + def test_php_no_error(): r = extract_php(FIXTURES / "sample.php") assert "error" not in r + def test_php_finds_class(): r = extract_php(FIXTURES / "sample.php") - assert any("ApiClient" in l for l in _labels(r)) + assert any("ApiClient" in label for label in _labels(r)) + def test_php_finds_methods(): r = extract_php(FIXTURES / "sample.php") labels = _labels(r) - assert any("get" in l for l in labels) - assert any("post" in l for l in labels) + assert any("get" in label for label in labels) + assert any("post" in label for label in labels) + def test_php_finds_function(): r = extract_php(FIXTURES / "sample.php") - assert any("parseResponse" in l for l in _labels(r)) + assert any("parseResponse" in label for label in _labels(r)) + def test_php_finds_imports(): r = extract_php(FIXTURES / "sample.php") @@ -476,55 +543,67 @@ def test_php_call_edges_have_call_context(): assert call_edges assert all(e.get("context") == "call" for e in call_edges) + def test_php_finds_static_property_access(): r = extract_php(FIXTURES / "sample_php_static_prop.php") assert "uses_static_prop" in _relations(r) + def test_php_static_prop_target_is_holding_class(): r = extract_php(FIXTURES / "sample_php_static_prop.php") node_by_id = {n["id"]: n["label"] for n in r["nodes"]} uses_prop = [ (node_by_id.get(e["source"], e["source"]), node_by_id.get(e["target"], e["target"])) - for e in r["edges"] if e["relation"] == "uses_static_prop" + for e in r["edges"] + if e["relation"] == "uses_static_prop" ] assert any("DefaultPalette" in tgt for _, tgt in uses_prop) + def test_php_finds_config_helper_call(): r = extract_php(FIXTURES / "sample_php_config.php") assert "uses_config" in _relations(r) + def test_php_config_helper_target_matches_first_segment(): r = extract_php(FIXTURES / "sample_php_config.php") node_by_id = {n["id"]: n["label"] for n in r["nodes"]} uses_cfg = [ (node_by_id.get(e["source"], e["source"]), node_by_id.get(e["target"], e["target"])) - for e in r["edges"] if e["relation"] == "uses_config" + for e in r["edges"] + if e["relation"] == "uses_config" ] assert any("Throttle" in tgt for _, tgt in uses_cfg) + def test_php_finds_container_bind(): r = extract_php(FIXTURES / "sample_php_container.php") assert "bound_to" in _relations(r) + def test_php_container_bind_links_contract_to_implementation(): r = extract_php(FIXTURES / "sample_php_container.php") node_by_id = {n["id"]: n["label"] for n in r["nodes"]} bound = [ (node_by_id.get(e["source"], e["source"]), node_by_id.get(e["target"], e["target"])) - for e in r["edges"] if e["relation"] == "bound_to" + for e in r["edges"] + if e["relation"] == "bound_to" ] assert any("PaymentGateway" in src and "StripeGateway" in tgt for src, tgt in bound) + def test_php_finds_event_listeners(): r = extract_php(FIXTURES / "sample_php_listen.php") assert "listened_by" in _relations(r) + def test_php_event_listener_links_event_to_listener(): r = extract_php(FIXTURES / "sample_php_listen.php") node_by_id = {n["id"]: n["label"] for n in r["nodes"]} listened = [ (node_by_id.get(e["source"], e["source"]), node_by_id.get(e["target"], e["target"])) - for e in r["edges"] if e["relation"] == "listened_by" + for e in r["edges"] + if e["relation"] == "listened_by" ] assert any("UserRegistered" in src and "SendWelcomeEmail" in tgt for src, tgt in listened) @@ -545,31 +624,38 @@ def test_php_property_parameter_and_return_contexts(): # ── Swift ──────────────────────────────────────────────────────────────────── + def test_swift_no_error(): r = extract_swift(FIXTURES / "sample.swift") assert "error" not in r + def test_swift_finds_class(): r = extract_swift(FIXTURES / "sample.swift") - assert any("DataProcessor" in l for l in _labels(r)) + assert any("DataProcessor" in label for label in _labels(r)) + def test_swift_finds_protocol(): r = extract_swift(FIXTURES / "sample.swift") - assert any("Processor" in l for l in _labels(r)) + assert any("Processor" in label for label in _labels(r)) + def test_swift_finds_struct(): r = extract_swift(FIXTURES / "sample.swift") - assert any("Config" in l for l in _labels(r)) + assert any("Config" in label for label in _labels(r)) + def test_swift_finds_methods(): r = extract_swift(FIXTURES / "sample.swift") labels = _labels(r) - assert any("addItem" in l for l in labels) - assert any("process" in l for l in labels) + assert any("addItem" in label for label in labels) + assert any("process" in label for label in labels) + def test_swift_finds_function(): r = extract_swift(FIXTURES / "sample.swift") - assert any("createProcessor" in l for l in _labels(r)) + assert any("createProcessor" in label for label in _labels(r)) + def test_swift_finds_imports(): r = extract_swift(FIXTURES / "sample.swift") @@ -582,42 +668,51 @@ def test_swift_import_edges_have_import_context(): assert import_edges assert all(e.get("context") == "import" for e in import_edges) + def test_swift_no_dangling_edges(): r = extract_swift(FIXTURES / "sample.swift") node_ids = {n["id"] for n in r["nodes"]} for e in r["edges"]: assert e["source"] in node_ids + def test_swift_finds_actor(): r = extract_swift(FIXTURES / "sample.swift") - assert any("CacheManager" in l for l in _labels(r)) + assert any("CacheManager" in label for label in _labels(r)) + def test_swift_finds_enum(): r = extract_swift(FIXTURES / "sample.swift") - assert any("NetworkError" in l for l in _labels(r)) + assert any("NetworkError" in label for label in _labels(r)) + def test_swift_finds_enum_methods(): r = extract_swift(FIXTURES / "sample.swift") - assert any("describe" in l for l in _labels(r)) + assert any("describe" in label for label in _labels(r)) + def test_swift_finds_enum_cases(): r = extract_swift(FIXTURES / "sample.swift") labels = _labels(r) - assert any("timeout" in l for l in labels) - assert any("connectionFailed" in l for l in labels) + assert any("timeout" in label for label in labels) + assert any("connectionFailed" in label for label in labels) + def test_swift_enum_cases_have_case_of_edge(): r = extract_swift(FIXTURES / "sample.swift") case_edges = [e for e in r["edges"] if e["relation"] == "case_of"] assert len(case_edges) >= 2 + def test_swift_finds_deinit(): r = extract_swift(FIXTURES / "sample.swift") - assert any("deinit" in l for l in _labels(r)) + assert any("deinit" in label for label in _labels(r)) + def test_swift_finds_subscript(): r = extract_swift(FIXTURES / "sample.swift") - assert any("subscript" in l for l in _labels(r)) + assert any("subscript" in label for label in _labels(r)) + def test_swift_extension_methods_attach_to_type(): r = extract_swift(FIXTURES / "sample.swift") @@ -632,11 +727,13 @@ def test_swift_extension_methods_attach_to_type(): break assert found, "extension method isValid should attach to Config" + def test_swift_extension_does_not_duplicate_type_node(): r = extract_swift(FIXTURES / "sample.swift") config_nodes = [n for n in r["nodes"] if n["label"] == "Config"] assert len(config_nodes) == 1, f"Config should appear once, got {len(config_nodes)}" + def test_swift_protocol_conformance_emits_implements(): r = extract_swift(FIXTURES / "sample.swift") assert ("DataProcessor", "Processor") in _edge_labels(r, "implements") @@ -660,11 +757,13 @@ def test_swift_parameter_return_generic_and_field_contexts(): assert ("run", "DataProcessor") in _edge_labels(r, "references", "generic_arg") assert ("DataProcessor", "Result") in _edge_labels(r, "references", "field") + def test_swift_emits_calls(): r = extract_swift(FIXTURES / "sample.swift") calls = _calls(r) assert any("process" in src and "validate" in tgt for src, tgt in calls) + def test_swift_call_edges_have_call_context(): r = extract_swift(FIXTURES / "sample.swift") call_edges = _edges_with_relation(r, "calls") @@ -678,36 +777,45 @@ def test_swift_extension_across_files_merges_into_canonical_type(): node ids carry the file stem, so without a corpus-level merge each file would emit its own Foo.""" from graphify.extract import extract + paths = sorted((FIXTURES / "swift_cross_file").glob("*.swift")) r = extract(paths, cache_root=Path("/tmp/graphify-test-no-cache")) foo_nodes = [n for n in r["nodes"] if n["label"] == "Foo"] - assert len(foo_nodes) == 1, f"Foo should appear once, got {len(foo_nodes)}: {[n['id'] for n in foo_nodes]}" + assert len(foo_nodes) == 1, ( + f"Foo should appear once, got {len(foo_nodes)}: {[n['id'] for n in foo_nodes]}" + ) foo_id = foo_nodes[0]["id"] method_targets = { - e["target"] for e in r["edges"] - if e["relation"] == "method" and e["source"] == foo_id + e["target"] for e in r["edges"] if e["relation"] == "method" and e["source"] == foo_id } method_labels = {n["label"] for n in r["nodes"] if n["id"] in method_targets} - assert any("one" in l for l in method_labels), f"one() should attach to Foo, got {method_labels}" - assert any("two" in l for l in method_labels), f"extension method two() should attach to Foo, got {method_labels}" + assert any("one" in label for label in method_labels), ( + f"one() should attach to Foo, got {method_labels}" + ) + assert any("two" in label for label in method_labels), ( + f"extension method two() should attach to Foo, got {method_labels}" + ) # ── Elixir ──────────────────────────────────────────────────────────────────── -from graphify.extract import extract_elixir +from graphify.extract import extract_elixir # noqa: E402 + def test_elixir_finds_module(): r = extract_elixir(FIXTURES / "sample.ex") assert "error" not in r labels = [n["label"] for n in r["nodes"]] - assert any("MyApp.Accounts.User" in l for l in labels) + assert any("MyApp.Accounts.User" in label for label in labels) + def test_elixir_finds_functions(): r = extract_elixir(FIXTURES / "sample.ex") labels = [n["label"] for n in r["nodes"]] - assert any("create" in l for l in labels) - assert any("find" in l for l in labels) - assert any("validate" in l for l in labels) + assert any("create" in label for label in labels) + assert any("find" in label for label in labels) + assert any("validate" in label for label in labels) + def test_elixir_finds_imports(): r = extract_elixir(FIXTURES / "sample.ex") @@ -721,11 +829,14 @@ def test_elixir_import_edges_have_import_context(): assert import_edges assert all(e.get("context") == "import" for e in import_edges) + def test_elixir_finds_calls(): r = extract_elixir(FIXTURES / "sample.ex") calls = {(e["source"], e["target"]) for e in r["edges"] if e["relation"] == "calls"} labels = {n["id"]: n["label"] for n in r["nodes"]} - assert any("create" in labels.get(src, "") and "validate" in labels.get(tgt, "") for src, tgt in calls) + assert any( + "create" in labels.get(src, "") and "validate" in labels.get(tgt, "") for src, tgt in calls + ) def test_elixir_call_edges_have_call_context(): @@ -734,6 +845,7 @@ def test_elixir_call_edges_have_call_context(): assert call_edges assert all(e.get("context") == "call" for e in call_edges) + def test_elixir_method_edges(): r = extract_elixir(FIXTURES / "sample.ex") methods = [e for e in r["edges"] if e["relation"] == "method"] @@ -741,7 +853,7 @@ def test_elixir_method_edges(): # ── Objective-C ────────────────────────────────────────────────────────────── -from graphify.extract import extract_objc +from graphify.extract import extract_objc # noqa: E402 def test_objc_finds_interface(): @@ -759,7 +871,7 @@ def test_objc_finds_subclass(): def test_objc_finds_methods(): r = extract_objc(FIXTURES / "sample.m") labels = [n["label"] for n in r["nodes"]] - assert any("speak" in l or "fetch" in l or "initWithName" in l for l in labels) + assert any("speak" in label or "fetch" in label or "initWithName" in label for label in labels) def test_objc_finds_imports(): @@ -804,6 +916,7 @@ def test_objc_no_dangling_edges(): # Go # --------------------------------------------------------------------------- + def test_go_receiver_methods_share_type_node(): """Methods on the same receiver type must share one canonical type node.""" r = extract_go(FIXTURES / "sample.go") @@ -811,6 +924,7 @@ def test_go_receiver_methods_share_type_node(): # Both Start() and Stop() are on *Server — should produce exactly one Server node assert len(server_nodes) == 1 + def test_go_receiver_uses_pkg_scope(): """Type node id should be scoped to directory, not file stem.""" r = extract_go(FIXTURES / "sample.go") @@ -824,6 +938,7 @@ def test_go_receiver_uses_pkg_scope(): # Julia # --------------------------------------------------------------------------- + def test_julia_finds_module(): r = extract_julia(FIXTURES / "sample.jl") labels = [n["label"] for n in r["nodes"]] @@ -846,14 +961,14 @@ def test_julia_finds_abstract_type(): def test_julia_finds_functions(): r = extract_julia(FIXTURES / "sample.jl") labels = [n["label"] for n in r["nodes"]] - assert any("area" in l for l in labels) - assert any("distance" in l for l in labels) + assert any("area" in label for label in labels) + assert any("distance" in label for label in labels) def test_julia_finds_short_function(): r = extract_julia(FIXTURES / "sample.jl") labels = [n["label"] for n in r["nodes"]] - assert any("perimeter" in l for l in labels) + assert any("perimeter" in label for label in labels) def test_julia_finds_imports(): @@ -910,6 +1025,7 @@ def test_julia_no_dangling_edges(): # ── Fortran extractor ──────────────────────────────────────────────────────── + def test_fortran_finds_module(): r = extract_fortran(FIXTURES / "sample.f90") assert "error" not in r @@ -920,14 +1036,14 @@ def test_fortran_finds_module(): def test_fortran_finds_subroutines(): r = extract_fortran(FIXTURES / "sample.f90") labels = [n["label"] for n in r["nodes"]] - assert any("circle_area" in l for l in labels) - assert any("print_area" in l for l in labels) + assert any("circle_area" in label for label in labels) + assert any("print_area" in label for label in labels) def test_fortran_finds_function(): r = extract_fortran(FIXTURES / "sample.f90") labels = [n["label"] for n in r["nodes"]] - assert any("distance" in l for l in labels) + assert any("distance" in label for label in labels) def test_fortran_finds_program(): @@ -957,7 +1073,11 @@ def test_fortran_finds_calls(): def test_fortran_case_insensitive_names(): r = extract_fortran(FIXTURES / "sample.f90") labels = [n["label"] for n in r["nodes"]] - assert all(l == l.lower() or "(" in l for l in labels if l.endswith(("()", "")) and not "." in l) + assert all( + label == label.lower() or "(" in label + for label in labels + if label.endswith(("()", "")) and "." not in label + ) assert "geometry" in labels assert "main" in labels @@ -986,11 +1106,12 @@ def test_fortran_capital_F_parses_preprocessed(): assert "error" not in r labels = [n["label"] for n in r["nodes"]] assert "shapes" in labels - assert any("compute_volume" in l for l in labels) + assert any("compute_volume" in label for label in labels) # ── PowerShell ─────────────────────────────────────────────────────────────── + def test_powershell_no_error(): r = extract_powershell(FIXTURES / "sample.ps1") assert "error" not in r @@ -1000,7 +1121,7 @@ def test_powershell_finds_class_and_method(): r = extract_powershell(FIXTURES / "sample.ps1") labels = [n["label"] for n in r["nodes"]] assert "DataProcessor" in labels - assert any("Transform" in l for l in labels) + assert any("Transform" in label for label in labels) def test_powershell_property_field_type_context(): @@ -1017,10 +1138,12 @@ def test_powershell_method_parameter_and_return_type_contexts(): # ── TypeScript dynamic imports ─────────────────────────────────────────────── + def test_ts_dynamic_import_no_error(): r = extract_js(FIXTURES / "dynamic_import.ts") assert "error" not in r + def test_ts_dynamic_import_extracts_edges(): """Dynamic import() calls inside functions should produce imports_from edges.""" r = extract_js(FIXTURES / "dynamic_import.ts") @@ -1028,56 +1151,71 @@ def test_ts_dynamic_import_extracts_edges(): targets = {e["target"] for e in dyn_edges} # Should find: static ./logger, dynamic ./mayaEngine.js, dynamic ./queue.js assert any("logger" in t for t in targets), f"Missing static import of logger: {targets}" - assert any("mayaengine" in t.lower() for t in targets), f"Missing dynamic import of mayaEngine: {targets}" + assert any("mayaengine" in t.lower() for t in targets), ( + f"Missing dynamic import of mayaEngine: {targets}" + ) assert any("queue" in t.lower() for t in targets), f"Missing dynamic import of queue: {targets}" + def test_ts_dynamic_import_confidence(): """Dynamic imports should have EXTRACTED confidence (they are deterministic string literals).""" r = extract_js(FIXTURES / "dynamic_import.ts") - dyn_edges = [e for e in r["edges"] - if e["relation"] == "imports_from" - and "mayaengine" in e["target"].lower()] + dyn_edges = [ + e + for e in r["edges"] + if e["relation"] == "imports_from" and "mayaengine" in e["target"].lower() + ] assert len(dyn_edges) >= 1 assert dyn_edges[0]["confidence"] == "EXTRACTED" + def test_ts_dynamic_import_source_is_function(): """Dynamic import edge source should be the enclosing function, not the file.""" r = extract_js(FIXTURES / "dynamic_import.ts") node_labels = {n["id"]: n["label"] for n in r["nodes"]} - dyn_edges = [e for e in r["edges"] - if e["relation"] == "imports_from" - and "mayaengine" in e["target"].lower()] + dyn_edges = [ + e + for e in r["edges"] + if e["relation"] == "imports_from" and "mayaengine" in e["target"].lower() + ] assert len(dyn_edges) >= 1 src_label = node_labels.get(dyn_edges[0]["source"], "") assert "processInbound" in src_label, f"Expected processInbound as source, got {src_label}" + def test_ts_no_dynamic_import_in_sync_fn(): """Functions without dynamic imports should not get spurious imports_from edges.""" r = extract_js(FIXTURES / "dynamic_import.ts") node_ids = {n["label"]: n["id"] for n in r["nodes"]} sync_nid = node_ids.get("syncOnly()") if sync_nid: - sync_imports = [e for e in r["edges"] - if e["source"] == sync_nid and e["relation"] == "imports_from"] + sync_imports = [ + e for e in r["edges"] if e["source"] == sync_nid and e["relation"] == "imports_from" + ] assert len(sync_imports) == 0 + def test_ts_dynamic_template_literal_skipped(): """Dynamic template literals (with ${}) must not produce an imports_from edge.""" r = extract_js(FIXTURES / "dynamic_import.ts") targets = {e["target"] for e in r["edges"] if e["relation"] == "imports_from"} # loadHandler uses `./handlers/${handlerName}` — no static path, must be absent - assert not any("handler" in t.lower() and "$" in t for t in targets), \ + assert not any("handler" in t.lower() and "$" in t for t in targets), ( f"Garbage edge from dynamic template literal found: {targets}" + ) # More robust: no target should contain a brace character - assert not any("{" in t or "}" in t for t in targets), \ + assert not any("{" in t or "}" in t for t in targets), ( f"Target contains unresolved template expression: {targets}" + ) + def test_ts_static_template_literal_resolved(): """Static template literals (no ${}) should resolve the same as a plain string.""" r = extract_js(FIXTURES / "dynamic_import.ts") targets = {e["target"] for e in r["edges"] if e["relation"] == "imports_from"} - assert any("statichelper" in t.lower() for t in targets), \ + assert any("statichelper" in t.lower() for t in targets), ( f"Static template literal import not resolved: {targets}" + ) def test_js_local_const_does_not_emit_phantom_node(tmp_path): @@ -1115,19 +1253,16 @@ def test_js_module_level_arrow_produces_node_and_call_edges(tmp_path): The scope guard must not accidentally suppress top-level arrow functions. """ - src = ( - "function helper() { return 1; }\n" - "const handler = () => {\n" - " helper();\n" - "};\n" - ) + src = "function helper() { return 1; }\nconst handler = () => {\n helper();\n};\n" f = tmp_path / "arrows.js" f.write_text(src) r = extract_js(f) labels = _labels(r) relations = _relations(r) - assert any("handler" in l for l in labels), f"module-level arrow 'handler' missing: {labels}" + assert any("handler" in label for label in labels), ( + f"module-level arrow 'handler' missing: {labels}" + ) assert "calls" in relations, f"expected 'calls' edge from handler->helper: {relations}" @@ -1151,25 +1286,29 @@ def test_ts_local_const_does_not_emit_phantom_node(tmp_path): # ── Markdown ───────────────────────────────────────────────────────────────── -from graphify.extract import extract_markdown +from graphify.extract import extract_markdown # noqa: E402 + def test_markdown_no_error(): r = extract_markdown(FIXTURES / "deploy_guide.md") assert "error" not in r + def test_markdown_finds_headings(): r = extract_markdown(FIXTURES / "deploy_guide.md") labels = _labels(r) - assert any("Deploy Guide" in l for l in labels) - assert any("Prerequisites" in l for l in labels) - assert any("Full Deploy" in l for l in labels) - assert any("Rollback" in l for l in labels) + assert any("Deploy Guide" in label for label in labels) + assert any("Prerequisites" in label for label in labels) + assert any("Full Deploy" in label for label in labels) + assert any("Rollback" in label for label in labels) + def test_markdown_finds_nested_heading(): """### Database Migration is nested under ## Full Deploy.""" r = extract_markdown(FIXTURES / "deploy_guide.md") labels = _labels(r) - assert any("Database Migration" in l for l in labels) + assert any("Database Migration" in label for label in labels) + def test_markdown_skips_fenced_code_blocks(): """Fenced code blocks should NOT emit nodes (#1077). @@ -1181,8 +1320,10 @@ def test_markdown_skips_fenced_code_blocks(): """ r = extract_markdown(FIXTURES / "deploy_guide.md") labels = _labels(r) - assert not any(l.startswith("code:") for l in labels), \ - f"Expected no code:* nodes after #1077 fix, got: {[l for l in labels if l.startswith('code:')]}" + assert not any(label.startswith("code:") for label in labels), ( + f"Expected no code:* nodes after #1077 fix, got: {[label for label in labels if label.startswith('code:')]}" + ) + def test_markdown_contains_edges(): """Headings should be connected via 'contains' edges (file->h, h->h).""" @@ -1200,16 +1341,11 @@ def test_markdown_fenced_heading_not_parsed(): The fence-toggle skips over fenced contents so interior markdown syntax is not misread as document structure. """ - import tempfile, os + import os + import tempfile + src = ( - "# Real Heading\n" - "\n" - "```bash\n" - "## Not A Heading\n" - "echo hello\n" - "```\n" - "\n" - "## Another Real Heading\n" + "# Real Heading\n\n```bash\n## Not A Heading\necho hello\n```\n\n## Another Real Heading\n" ) with tempfile.NamedTemporaryFile(suffix=".md", mode="w", delete=False) as fh: fh.write(src) @@ -1220,10 +1356,14 @@ def test_markdown_fenced_heading_not_parsed(): finally: os.unlink(fpath) - assert any("Real Heading" in l for l in labels), f"'Real Heading' missing: {labels}" - assert any("Another Real Heading" in l for l in labels), f"'Another Real Heading' missing: {labels}" - assert not any("Not A Heading" in l for l in labels), \ + assert any("Real Heading" in label for label in labels), f"'Real Heading' missing: {labels}" + assert any("Another Real Heading" in label for label in labels), ( + f"'Another Real Heading' missing: {labels}" + ) + assert not any("Not A Heading" in label for label in labels), ( f"fenced '## Not A Heading' was incorrectly parsed as a node: {labels}" + ) + def test_markdown_no_dangling_edges(): r = extract_markdown(FIXTURES / "deploy_guide.md") @@ -1242,14 +1382,14 @@ def test_groovy_no_error(): def test_groovy_finds_class(): r = extract_groovy(FIXTURES / "sample.groovy") - assert any("SampleService" in l for l in _labels(r)) + assert any("SampleService" in label for label in _labels(r)) def test_groovy_finds_methods(): r = extract_groovy(FIXTURES / "sample.groovy") labels = _labels(r) - assert any("process" in l for l in labels) - assert any("reset" in l for l in labels) + assert any("process" in label for label in labels) + assert any("reset" in label for label in labels) def test_groovy_finds_imports(): @@ -1273,18 +1413,18 @@ def test_groovy_no_dangling_edges(): def test_groovy_spock_finds_class(): r = extract_groovy(FIXTURES / "sample_spock.groovy") - assert any("SampleSpec" in l for l in _labels(r)) + assert any("SampleSpec" in label for label in _labels(r)) def test_groovy_spock_finds_feature_methods(): r = extract_groovy(FIXTURES / "sample_spock.groovy") - feature_labels = [l for l in _labels(r) if l.startswith('"')] + feature_labels = [label for label in _labels(r) if label.startswith('"')] assert len(feature_labels) >= 2 def test_groovy_spock_finds_method_with_apostrophe(): r = extract_groovy(FIXTURES / "sample_spock.groovy") - assert any("it's" in l for l in _labels(r)) + assert any("it's" in label for label in _labels(r)) def test_groovy_spock_preserves_import_edges(): @@ -1301,15 +1441,18 @@ def test_groovy_spock_no_dangling_edges(): # ── DM (BYOND DreamMaker) ──────────────────────────────────────────────────── + def test_dm_no_error(): r = extract_dm(FIXTURES / "sample.dm") assert "error" not in r + def test_dm_finds_global_proc(): r = extract_dm(FIXTURES / "sample.dm") labels = _labels(r) - assert any(l == "log_event()" for l in labels) - assert any(l == "RunTest()" for l in labels) + assert any(label == "log_event()" for label in labels) + assert any(label == "RunTest()" for label in labels) + def test_dm_finds_type_definition(): r = extract_dm(FIXTURES / "sample.dm") @@ -1317,22 +1460,26 @@ def test_dm_finds_type_definition(): assert "/datum/weapon" in labels assert "/datum/weapon/sword" in labels + def test_dm_qualifies_proc_with_type_path(): r = extract_dm(FIXTURES / "sample.dm") labels = _labels(r) assert "/datum/weapon/attack()" in labels assert "/datum/weapon/sword/attack()" in labels + def test_dm_finds_path_form_proc_definition(): r = extract_dm(FIXTURES / "sample.dm") assert "/datum/weapon/sword/sharpen()" in _labels(r) + def test_dm_emits_include_edge(): r = extract_dm(FIXTURES / "sample.dm") import_edges = _edges_with_relation(r, "imports", "imports_from") assert import_edges assert all(e.get("context") == "import" for e in import_edges) + def test_dm_unresolved_include_flagged_external(): r = extract_dm(FIXTURES / "sample.dm") import_edges = _edges_with_relation(r, "imports", "imports_from") @@ -1340,39 +1487,47 @@ def test_dm_unresolved_include_flagged_external(): assert helpers assert all(e.get("external") is True for e in helpers) + def test_dm_resolves_in_file_calls(): r = extract_dm(FIXTURES / "sample.dm") calls = _calls(r) assert any(callee == "log_event()" for _, callee in calls) assert ("/datum/weapon/sword/attack()", "/datum/weapon/sword/sharpen()") in calls + def test_dm_ambiguous_member_call_left_unresolved(): r = extract_dm(FIXTURES / "sample.dm") calls = _calls(r) - runtest_to_attack = [c for s, c in calls - if s == "RunTest()" and "attack" in c] + runtest_to_attack = [c for s, c in calls if s == "RunTest()" and "attack" in c] assert not runtest_to_attack assert any(rc["callee"] == "attack" for rc in r.get("raw_calls", [])) + def test_dm_emits_new_as_instantiates(): r = extract_dm(FIXTURES / "sample.dm") node_by_id = {n["id"]: n["label"] for n in r["nodes"]} - inst = [(node_by_id.get(e["source"]), node_by_id.get(e["target"])) - for e in r["edges"] if e["relation"] == "instantiates"] + inst = [ + (node_by_id.get(e["source"]), node_by_id.get(e["target"])) + for e in r["edges"] + if e["relation"] == "instantiates" + ] assert ("RunTest()", "/datum/weapon/sword") in inst + def test_dm_call_edges_have_call_context(): r = extract_dm(FIXTURES / "sample.dm") call_edges = _edges_with_relation(r, "calls", "instantiates") assert call_edges assert all(e.get("context") == "call" for e in call_edges) + def test_dm_no_dangling_edges(): r = extract_dm(FIXTURES / "sample.dm") node_ids = {n["id"] for n in r["nodes"]} for e in r["edges"]: assert e["source"] in node_ids + def test_dm_super_call_not_emitted(): r = extract_dm(FIXTURES / "sample.dm") calls = _calls(r) @@ -1382,29 +1537,37 @@ def test_dm_super_call_not_emitted(): # ── DMI (BYOND icon sheets) ────────────────────────────────────────────────── + def test_dmi_no_error(): r = extract_dmi(FIXTURES / "sample.dmi") assert "error" not in r + def test_dmi_emits_state_nodes(): r = extract_dmi(FIXTURES / "sample.dmi") labels = _labels(r) - assert any(l == '"mob"' for l in labels) + assert any(label == '"mob"' for label in labels) + def test_dmi_state_contained_by_file(): r = extract_dmi(FIXTURES / "sample.dmi") node_by_id = {n["id"]: n["label"] for n in r["nodes"]} - contains = [(node_by_id.get(e["source"]), node_by_id.get(e["target"])) - for e in r["edges"] if e["relation"] == "contains"] + contains = [ + (node_by_id.get(e["source"]), node_by_id.get(e["target"])) + for e in r["edges"] + if e["relation"] == "contains" + ] assert ("sample.dmi", '"mob"') in contains # ── DMM (BYOND map files) ──────────────────────────────────────────────────── + def test_dmm_no_error(): r = extract_dmm(FIXTURES / "sample.dmm") assert "error" not in r + def test_dmm_extracts_type_paths_as_uses_edges(): r = extract_dmm(FIXTURES / "sample.dmm") targets = {e["target"] for e in r["edges"] if e["relation"] == "uses"} @@ -1412,17 +1575,20 @@ def test_dmm_extracts_type_paths_as_uses_edges(): assert "obj_structure_table" in targets assert "obj_item_weapon_sword" in targets + def test_dmm_strips_var_overrides(): r = extract_dmm(FIXTURES / "sample.dmm") targets = {e["target"] for e in r["edges"] if e["relation"] == "uses"} assert not any("{" in t for t in targets) assert "obj_item_weapon_sword" in targets + def test_dmm_handles_multiline_tile_definition(): r = extract_dmm(FIXTURES / "sample.dmm") targets = {e["target"] for e in r["edges"] if e["relation"] == "uses"} assert "area_station_maintenance" in targets + def test_dmm_skips_grid_section(): r = extract_dmm(FIXTURES / "sample.dmm") targets = {e["target"] for e in r["edges"] if e["relation"] == "uses"} @@ -1431,28 +1597,36 @@ def test_dmm_skips_grid_section(): # ── DMF (BYOND interface forms) ────────────────────────────────────────────── + def test_dmf_no_error(): r = extract_dmf(FIXTURES / "sample.dmf") assert "error" not in r + def test_dmf_extracts_windows(): r = extract_dmf(FIXTURES / "sample.dmf") labels = _labels(r) assert 'window "mapwindow"' in labels assert 'window "infowindow"' in labels + def test_dmf_elem_labels_carry_control_type(): r = extract_dmf(FIXTURES / "sample.dmf") labels = _labels(r) assert 'elem "map" [MAP]' in labels + def test_dmf_elem_under_window(): r = extract_dmf(FIXTURES / "sample.dmf") node_by_id = {n["id"]: n["label"] for n in r["nodes"]} - contains = [(node_by_id.get(e["source"]), node_by_id.get(e["target"])) - for e in r["edges"] if e["relation"] == "contains"] + contains = [ + (node_by_id.get(e["source"]), node_by_id.get(e["target"])) + for e in r["edges"] + if e["relation"] == "contains" + ] assert ('window "mapwindow"', 'elem "map" [MAP]') in contains + def test_dmf_no_dangling_edges(): r = extract_dmf(FIXTURES / "sample.dmf") node_ids = {n["id"] for n in r["nodes"]} @@ -1463,68 +1637,83 @@ def test_dmf_no_dangling_edges(): # -- .NET project files (.sln, .csproj, .razor) ------------------------------- + def test_sln_no_error(): r = extract_sln(FIXTURES / "sample.sln") assert "error" not in r + def test_sln_finds_projects(): r = extract_sln(FIXTURES / "sample.sln") labels = _labels(r) - assert any("WebApi" in l for l in labels) - assert any("Domain" in l for l in labels) + assert any("WebApi" in label for label in labels) + assert any("Domain" in label for label in labels) + def test_sln_contains_edges(): r = extract_sln(FIXTURES / "sample.sln") assert "contains" in _relations(r) + def test_sln_project_dependency_edges(): r = extract_sln(FIXTURES / "sample.sln") assert "imports" in _relations(r) + def test_csproj_no_error(): r = extract_csproj(FIXTURES / "sample.csproj") assert "error" not in r + def test_csproj_finds_packages(): r = extract_csproj(FIXTURES / "sample.csproj") labels = _labels(r) - assert any("MediatR" in l for l in labels) - assert any("FluentValidation" in l for l in labels) + assert any("MediatR" in label for label in labels) + assert any("FluentValidation" in label for label in labels) + def test_csproj_finds_project_references(): r = extract_csproj(FIXTURES / "sample.csproj") labels = _labels(r) - assert any("Domain.csproj" in l for l in labels) + assert any("Domain.csproj" in label for label in labels) + def test_csproj_finds_target_framework(): r = extract_csproj(FIXTURES / "sample.csproj") - assert any("net8.0" in l for l in _labels(r)) + assert any("net8.0" in label for label in _labels(r)) + def test_csproj_finds_sdk(): r = extract_csproj(FIXTURES / "sample.csproj") - assert any("Microsoft.NET.Sdk.Web" in l for l in _labels(r)) + assert any("Microsoft.NET.Sdk.Web" in label for label in _labels(r)) + def test_razor_no_error(): r = extract_razor(FIXTURES / "sample.razor") assert "error" not in r + def test_razor_finds_using_directives(): r = extract_razor(FIXTURES / "sample.razor") assert "imports" in _relations(r) + def test_razor_finds_component_references(): r = extract_razor(FIXTURES / "sample.razor") assert "calls" in _relations(r) + def test_razor_finds_inherits(): r = extract_razor(FIXTURES / "sample.razor") assert "inherits" in _relations(r) + def test_razor_finds_code_block_methods(): r = extract_razor(FIXTURES / "sample.razor") labels = _labels(r) - assert any("IncrementCount" in l for l in labels) - assert any("LoadData" in l for l in labels) + assert any("IncrementCount" in label for label in labels) + assert any("LoadData" in label for label in labels) + def test_razor_no_dangling_edges(): r = extract_razor(FIXTURES / "sample.razor") diff --git a/tests/test_llm_backends.py b/tests/test_llm_backends.py index d2ee058c4..7cd670e83 100644 --- a/tests/test_llm_backends.py +++ b/tests/test_llm_backends.py @@ -224,6 +224,7 @@ def test_response_is_hollow_accepts_real_extraction(): def _fake_openai_response(content, *, finish_reason="stop", prompt_tokens=100, completion_tokens=0): """Build a minimal stand-in for an `openai` SDK ChatCompletion response.""" + class _Usage: def __init__(self): self.prompt_tokens = prompt_tokens @@ -257,11 +258,12 @@ class _FakeOpenAI: def __init__(self, *_, **__): self.chat = self self.completions = self + def create(self, **__): return fake_resp fake_module = types.ModuleType("openai") - fake_module.OpenAI = _FakeOpenAI + setattr(fake_module, "OpenAI", _FakeOpenAI) monkeypatch.setitem(sys.modules, "openai", fake_module) @@ -274,8 +276,13 @@ def test_call_openai_compat_relabels_empty_content_as_length(monkeypatch): _install_fake_openai(monkeypatch, fake_resp) result = llm._call_openai_compat( - "http://localhost:11434/v1", "ollama", "qwen2.5-coder:7b", - "user msg", temperature=0, max_completion_tokens=8192, backend="ollama", + "http://localhost:11434/v1", + "ollama", + "qwen2.5-coder:7b", + "user msg", + temperature=0, + max_completion_tokens=8192, + backend="ollama", ) assert result["finish_reason"] == "length", ( "empty content from a 'successful' call must be re-labelled so the " @@ -288,8 +295,13 @@ def test_call_openai_compat_relabels_none_content_as_length(monkeypatch): _install_fake_openai(monkeypatch, fake_resp) result = llm._call_openai_compat( - "http://localhost:11434/v1", "ollama", "qwen2.5-coder:7b", - "u", temperature=0, max_completion_tokens=8192, backend="ollama", + "http://localhost:11434/v1", + "ollama", + "qwen2.5-coder:7b", + "u", + temperature=0, + max_completion_tokens=8192, + backend="ollama", ) assert result["finish_reason"] == "length" @@ -298,12 +310,19 @@ def test_call_openai_compat_relabels_unparseable_json_as_length(monkeypatch): # A half-generated response: `{"nodes": [{"id":` parses to {} (empty # fragment) via _parse_llm_json's JSONDecodeError fallback. That is also # hollow and must trigger bisection. - fake_resp = _fake_openai_response('{"nodes": [{"id":', finish_reason="stop", completion_tokens=20) + fake_resp = _fake_openai_response( + '{"nodes": [{"id":', finish_reason="stop", completion_tokens=20 + ) _install_fake_openai(monkeypatch, fake_resp) result = llm._call_openai_compat( - "http://localhost:11434/v1", "ollama", "qwen2.5-coder:7b", - "u", temperature=0, max_completion_tokens=8192, backend="ollama", + "http://localhost:11434/v1", + "ollama", + "qwen2.5-coder:7b", + "u", + temperature=0, + max_completion_tokens=8192, + backend="ollama", ) assert result["finish_reason"] == "length" @@ -318,8 +337,13 @@ def test_call_openai_compat_preserves_real_finish_reason(monkeypatch): _install_fake_openai(monkeypatch, fake_resp) result = llm._call_openai_compat( - "http://localhost:11434/v1", "k", "m", - "u", temperature=0, max_completion_tokens=8192, backend="kimi", + "http://localhost:11434/v1", + "k", + "m", + "u", + temperature=0, + max_completion_tokens=8192, + backend="kimi", ) assert result["finish_reason"] == "stop" assert result["nodes"] == [{"id": "a"}] @@ -352,7 +376,7 @@ def create(self, **kwargs): ) fake_module = types.ModuleType("openai") - fake_module.OpenAI = _FakeOpenAI + setattr(fake_module, "OpenAI", _FakeOpenAI) monkeypatch.setitem(sys.modules, "openai", fake_module) return captured @@ -363,8 +387,13 @@ def test_ollama_extra_body_sets_num_ctx_and_keep_alive(monkeypatch): monkeypatch.delenv("GRAPHIFY_OLLAMA_KEEP_ALIVE", raising=False) llm._call_openai_compat( - "http://localhost:11434/v1", "ollama", "qwen2.5-coder:7b", - "user msg", temperature=0, max_completion_tokens=8192, backend="ollama", + "http://localhost:11434/v1", + "ollama", + "qwen2.5-coder:7b", + "user msg", + temperature=0, + max_completion_tokens=8192, + backend="ollama", ) assert "extra_body" in captured, "extra_body must be sent to Ollama" @@ -387,8 +416,13 @@ def test_ollama_num_ctx_scales_with_small_token_budget(monkeypatch): small_chunk_msg = "x" * 32_000 llm._call_openai_compat( - "http://localhost:11434/v1", "ollama", "qwen2.5-coder:7b", - small_chunk_msg, temperature=0, max_completion_tokens=16384, backend="ollama", + "http://localhost:11434/v1", + "ollama", + "qwen2.5-coder:7b", + small_chunk_msg, + temperature=0, + max_completion_tokens=16384, + backend="ollama", ) num_ctx = captured["extra_body"]["options"]["num_ctx"] @@ -407,8 +441,13 @@ def test_ollama_num_ctx_env_override(monkeypatch): monkeypatch.delenv("GRAPHIFY_OLLAMA_KEEP_ALIVE", raising=False) llm._call_openai_compat( - "http://localhost:11434/v1", "ollama", "qwen2.5-coder:7b", - "u", temperature=0, max_completion_tokens=8192, backend="ollama", + "http://localhost:11434/v1", + "ollama", + "qwen2.5-coder:7b", + "u", + temperature=0, + max_completion_tokens=8192, + backend="ollama", ) assert captured["extra_body"]["options"]["num_ctx"] == 65536 @@ -418,8 +457,13 @@ def test_non_ollama_backend_gets_no_num_ctx_extra_body(monkeypatch): captured = _install_capturing_openai(monkeypatch) llm._call_openai_compat( - "https://api.openai.com/v1", "sk-test", "gpt-4.1-mini", - "u", temperature=0, max_completion_tokens=8192, backend="openai", + "https://api.openai.com/v1", + "sk-test", + "gpt-4.1-mini", + "u", + temperature=0, + max_completion_tokens=8192, + backend="openai", ) eb = captured.get("extra_body") @@ -445,8 +489,14 @@ def fake_extract(chunk, *_, **__): with patch("graphify.llm.extract_files_direct", side_effect=fake_extract): with patch("graphify.llm.ThreadPoolExecutor") as mock_pool: result = llm.extract_corpus_parallel( - files, backend="ollama", api_key="ollama", model="qwen2.5-coder:7b", - root=tmp_path, token_budget=None, chunk_size=2, max_concurrency=4, + files, + backend="ollama", + api_key="ollama", + model="qwen2.5-coder:7b", + root=tmp_path, + token_budget=None, + chunk_size=2, + max_concurrency=4, ) mock_pool.assert_not_called() @@ -469,8 +519,14 @@ def test_extract_corpus_parallel_ollama_parallel_env_restores_concurrency(tmp_pa )() try: llm.extract_corpus_parallel( - files, backend="ollama", api_key="ollama", model="m", - root=tmp_path, token_budget=None, chunk_size=2, max_concurrency=4, + files, + backend="ollama", + api_key="ollama", + model="m", + root=tmp_path, + token_budget=None, + chunk_size=2, + max_concurrency=4, ) except Exception: pass # mock scaffolding may not be complete; we only care about the call @@ -496,16 +552,24 @@ def fake_extract(chunk, *_, **__): # Hollow response: looks successful, finish_reason already # rewritten to "length" by _call_openai_compat. return { - "nodes": [], "edges": [], "hyperedges": [], - "input_tokens": 100, "output_tokens": 0, - "model": "m", "finish_reason": "length", + "nodes": [], + "edges": [], + "hyperedges": [], + "input_tokens": 100, + "output_tokens": 0, + "model": "m", + "finish_reason": "length", } return _ok(nodes=[{"id": f.stem} for f in chunk]) with patch("graphify.llm.extract_files_direct", side_effect=fake_extract): result = llm._extract_with_adaptive_retry( - files, backend="ollama", api_key="ollama", model="qwen2.5-coder:7b", - root=tmp_path, max_depth=3, + files, + backend="ollama", + api_key="ollama", + model="qwen2.5-coder:7b", + root=tmp_path, + max_depth=3, ) assert len(result["nodes"]) == 4, ( diff --git a/tests/test_llm_parser.py b/tests/test_llm_parser.py index c643807b9..6956a1c84 100644 --- a/tests/test_llm_parser.py +++ b/tests/test_llm_parser.py @@ -6,13 +6,12 @@ - The switch from --append-system-prompt to --system-prompt - The GRAPHIFY_CLAUDE_CLI_MODEL env-var passthrough """ + from __future__ import annotations import json from unittest.mock import patch -import pytest - from graphify import llm @@ -25,10 +24,7 @@ def test_preamble_then_fence_is_parsed(): so any preamble caused json.loads to fail and the chunk to be dropped as a hollow response. The robust parser handles fences anywhere in the text.""" - raw = ( - "Here are the extracted entities:\n\n" - '```json\n{"nodes": [{"id": "a"}], "edges": []}\n```' - ) + raw = 'Here are the extracted entities:\n\n```json\n{"nodes": [{"id": "a"}], "edges": []}\n```' result = llm._parse_llm_json(raw) assert result["nodes"] == [{"id": "a"}] assert result["edges"] == [] @@ -37,10 +33,7 @@ def test_preamble_then_fence_is_parsed(): def test_prose_wrapped_json_without_fence_is_parsed(): """Some models return prose around bare JSON with no markdown fence. The balanced-brace fallback extracts the first complete object.""" - raw = ( - 'The extracted graph is {"nodes": [{"id": "b"}], "edges": []}. ' - "Hope this helps!" - ) + raw = 'The extracted graph is {"nodes": [{"id": "b"}], "edges": []}. Hope this helps!' result = llm._parse_llm_json(raw) assert result["nodes"] == [{"id": "b"}] @@ -87,16 +80,22 @@ def test_empty_response_returns_empty_fragment(): def _make_envelope(result_obj: dict) -> str: - return json.dumps({ - "type": "result", - "subtype": "success", - "is_error": False, - "result": json.dumps(result_obj), - "usage": {"input_tokens": 1, "output_tokens": 1, - "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0}, - "modelUsage": {"claude-opus-4-7": {}}, - "stop_reason": "end_turn", - }) + return json.dumps( + { + "type": "result", + "subtype": "success", + "is_error": False, + "result": json.dumps(result_obj), + "usage": { + "input_tokens": 1, + "output_tokens": 1, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + }, + "modelUsage": {"claude-opus-4-7": {}}, + "stop_reason": "end_turn", + } + ) @patch("shutil.which", return_value="/usr/local/bin/claude") diff --git a/tests/test_mcp_ingest.py b/tests/test_mcp_ingest.py index 8e46b170d..2f447ac54 100644 --- a/tests/test_mcp_ingest.py +++ b/tests/test_mcp_ingest.py @@ -1,10 +1,10 @@ """Tests for graphify.mcp_ingest — MCP config file extraction.""" + from __future__ import annotations import json from pathlib import Path -import pytest from graphify.mcp_ingest import ( MCP_CONFIG_FILENAMES, @@ -29,11 +29,7 @@ def _relations(result): def _label_by_kind(result, kind): - return [ - n["label"] - for n in result["nodes"] - if n.get("metadata", {}).get("mcp_kind") == kind - ] + return [n["label"] for n in result["nodes"] if n.get("metadata", {}).get("mcp_kind") == kind] def _write(tmp_path: Path, name: str, payload) -> Path: @@ -167,13 +163,21 @@ def test_every_edge_has_confidence_score(): def test_same_command_collapses_to_one_node_across_configs(tmp_path): # Two configs both use "npx". The mcp_command node should be shared. - config_a = _write(tmp_path, ".mcp.json", { - "mcpServers": {"a": {"command": "npx", "args": ["@scope/server-a"]}}, - }) + config_a = _write( + tmp_path, + ".mcp.json", + { + "mcpServers": {"a": {"command": "npx", "args": ["@scope/server-a"]}}, + }, + ) (tmp_path / "subdir").mkdir() - config_b = _write(tmp_path / "subdir", "claude_desktop_config.json", { - "mcpServers": {"b": {"command": "npx", "args": ["@scope/server-b"]}}, - }) + config_b = _write( + tmp_path / "subdir", + "claude_desktop_config.json", + { + "mcpServers": {"b": {"command": "npx", "args": ["@scope/server-b"]}}, + }, + ) r_a = extract_mcp_config(config_a) r_b = extract_mcp_config(config_b) cmd_id_a = next(n["id"] for n in r_a["nodes"] if n["metadata"]["mcp_kind"] == "mcp_command") @@ -183,17 +187,25 @@ def test_same_command_collapses_to_one_node_across_configs(tmp_path): def test_same_env_var_collapses_to_one_node_across_configs(tmp_path): # Two configs both require OPENAI_API_KEY. The env_var node ID must be identical. - a = _write(tmp_path, ".mcp.json", { - "mcpServers": { - "x": {"command": "npx", "args": ["@scope/x"], "env": {"OPENAI_API_KEY": "v1"}}, + a = _write( + tmp_path, + ".mcp.json", + { + "mcpServers": { + "x": {"command": "npx", "args": ["@scope/x"], "env": {"OPENAI_API_KEY": "v1"}}, + }, }, - }) + ) (tmp_path / "sub").mkdir() - b = _write(tmp_path / "sub", "claude_desktop_config.json", { - "mcpServers": { - "y": {"command": "uvx", "args": ["mcp-server-y"], "env": {"OPENAI_API_KEY": "v2"}}, + b = _write( + tmp_path / "sub", + "claude_desktop_config.json", + { + "mcpServers": { + "y": {"command": "uvx", "args": ["mcp-server-y"], "env": {"OPENAI_API_KEY": "v2"}}, + }, }, - }) + ) r_a = extract_mcp_config(a) r_b = extract_mcp_config(b) env_id_a = next(n["id"] for n in r_a["nodes"] if n["metadata"]["mcp_kind"] == "env_var") @@ -206,12 +218,20 @@ def test_same_server_name_in_different_dirs_does_not_collide(tmp_path): # The server nodes should NOT collide (stem-scoped via parent dir). (tmp_path / "proj_a").mkdir() (tmp_path / "proj_b").mkdir() - a = _write(tmp_path / "proj_a", ".mcp.json", { - "mcpServers": {"filesystem": {"command": "npx", "args": ["@scope/a"]}}, - }) - b = _write(tmp_path / "proj_b", ".mcp.json", { - "mcpServers": {"filesystem": {"command": "npx", "args": ["@scope/b"]}}, - }) + a = _write( + tmp_path / "proj_a", + ".mcp.json", + { + "mcpServers": {"filesystem": {"command": "npx", "args": ["@scope/a"]}}, + }, + ) + b = _write( + tmp_path / "proj_b", + ".mcp.json", + { + "mcpServers": {"filesystem": {"command": "npx", "args": ["@scope/b"]}}, + }, + ) r_a = extract_mcp_config(a) r_b = extract_mcp_config(b) srv_a = next(n["id"] for n in r_a["nodes"] if n["metadata"]["mcp_kind"] == "mcp_server") @@ -232,9 +252,13 @@ def test_missing_mcp_servers_key(tmp_path): def test_nested_mcp_servers_shape(tmp_path): # Some tools wrap the map: {"mcp": {"servers": {...}}} - p = _write(tmp_path, ".mcp.json", { - "mcp": {"servers": {"x": {"command": "node", "args": ["dist/index.js"]}}}, - }) + p = _write( + tmp_path, + ".mcp.json", + { + "mcp": {"servers": {"x": {"command": "node", "args": ["dist/index.js"]}}}, + }, + ) r = extract_mcp_config(p) assert "error" not in r assert "x" in _label_by_kind(r, "mcp_server") @@ -266,12 +290,16 @@ def test_root_not_an_object(tmp_path): def test_non_dict_server_entry_skipped(tmp_path): - p = _write(tmp_path, ".mcp.json", { - "mcpServers": { - "valid": {"command": "npx", "args": ["@scope/pkg"]}, - "broken": ["this", "is", "not", "an", "object"], + p = _write( + tmp_path, + ".mcp.json", + { + "mcpServers": { + "valid": {"command": "npx", "args": ["@scope/pkg"]}, + "broken": ["this", "is", "not", "an", "object"], + }, }, - }) + ) r = extract_mcp_config(p) server_labels = _label_by_kind(r, "mcp_server") assert "valid" in server_labels @@ -283,26 +311,38 @@ def test_non_dict_server_entry_skipped(tmp_path): def test_package_detection_skips_flags(tmp_path): # First arg is -y (flag); second is the package. Detection should skip the flag. - p = _write(tmp_path, ".mcp.json", { - "mcpServers": {"x": {"command": "npx", "args": ["-y", "@scope/server-x"]}}, - }) + p = _write( + tmp_path, + ".mcp.json", + { + "mcpServers": {"x": {"command": "npx", "args": ["-y", "@scope/server-x"]}}, + }, + ) r = extract_mcp_config(p) assert "@scope/server-x" in _label_by_kind(r, "mcp_package") def test_no_package_detected_for_unknown_arg_shape(tmp_path): # Args don't look like any known package pattern => no package node. - p = _write(tmp_path, ".mcp.json", { - "mcpServers": {"x": {"command": "node", "args": ["./local-script.js", "--verbose"]}}, - }) + p = _write( + tmp_path, + ".mcp.json", + { + "mcpServers": {"x": {"command": "node", "args": ["./local-script.js", "--verbose"]}}, + }, + ) r = extract_mcp_config(p) assert _label_by_kind(r, "mcp_package") == [] def test_server_without_command_still_emits_server_node(tmp_path): - p = _write(tmp_path, ".mcp.json", { - "mcpServers": {"x": {"args": ["@scope/server-x"]}}, - }) + p = _write( + tmp_path, + ".mcp.json", + { + "mcpServers": {"x": {"args": ["@scope/server-x"]}}, + }, + ) r = extract_mcp_config(p) assert "x" in _label_by_kind(r, "mcp_server") assert _label_by_kind(r, "mcp_command") == [] @@ -316,9 +356,13 @@ def test_dispatch_routes_mcp_filename_to_mcp_extractor(tmp_path): # extract_mcp_config, NOT extract_json. from graphify.extract import _get_extractor - p = _write(tmp_path, ".mcp.json", { - "mcpServers": {"x": {"command": "npx", "args": ["@scope/server-x"]}}, - }) + p = _write( + tmp_path, + ".mcp.json", + { + "mcpServers": {"x": {"command": "npx", "args": ["@scope/server-x"]}}, + }, + ) extractor = _get_extractor(p) assert extractor is extract_mcp_config diff --git a/tests/test_multigraph_diagnostics.py b/tests/test_multigraph_diagnostics.py index 8c39b8e23..6c49c58cf 100644 --- a/tests/test_multigraph_diagnostics.py +++ b/tests/test_multigraph_diagnostics.py @@ -147,7 +147,10 @@ def test_diagnose_extraction_handles_malformed_shapes_without_crashing() -> None assert summary["missing_endpoint_edges"] == 1 assert summary["dangling_endpoint_edges"] == 2 assert summary["valid_candidate_edges"] == 1 - assert summary["post_build_error"].startswith("TypeError:") + assert summary["post_build_graph_type"] == "DiGraph" + assert summary["post_build_node_count"] == 2 + assert summary["post_build_edge_count"] == 1 + assert summary["post_build_error"] == "" def test_diagnose_extraction_handles_non_list_nodes_and_edges() -> None: @@ -228,7 +231,8 @@ def test_format_diagnostic_report_includes_build_and_suppression_errors( report = format_diagnostic_report(summary) - assert "post_build_error: TypeError:" in report + assert "post_build_error:" not in report + assert "post_build_graph_type: DiGraph" in report assert "producer_suppression_error: file not found" in report diff --git a/tests/test_multilang.py b/tests/test_multilang.py index c30b9e10c..7bdeb8752 100644 --- a/tests/test_multilang.py +++ b/tests/test_multilang.py @@ -1,6 +1,6 @@ """Tests for multi-language AST extraction: JS/TS, Go, Rust, SQL.""" + from __future__ import annotations -import shutil from pathlib import Path import pytest from graphify.extract import extract_js, extract_go, extract_rust, extract, extract_sql @@ -10,16 +10,20 @@ # ── helpers ────────────────────────────────────────────────────────────────── + def _labels(result): return [n["label"] for n in result["nodes"]] + def _call_pairs(result): node_by_id = {n["id"]: n["label"] for n in result["nodes"]} return { (node_by_id.get(e["source"], e["source"]), node_by_id.get(e["target"], e["target"])) - for e in result["edges"] if e["relation"] == "calls" + for e in result["edges"] + if e["relation"] == "calls" } + def _confidences(result): return {e["confidence"] for e in result["edges"]} @@ -46,20 +50,24 @@ def _edge_labels(result, relation, context=None): # ── TypeScript ──────────────────────────────────────────────────────────────── + def test_ts_finds_class(): r = extract_js(FIXTURES / "sample.ts") assert "error" not in r assert "HttpClient" in _labels(r) + def test_ts_finds_methods(): r = extract_js(FIXTURES / "sample.ts") labels = _labels(r) - assert any("get" in l for l in labels) - assert any("post" in l for l in labels) + assert any("get" in label for label in labels) + assert any("post" in label for label in labels) + def test_ts_finds_function(): r = extract_js(FIXTURES / "sample.ts") - assert any("buildHeaders" in l for l in _labels(r)) + assert any("buildHeaders" in label for label in _labels(r)) + def test_ts_emits_calls(): r = extract_js(FIXTURES / "sample.ts") @@ -67,6 +75,7 @@ def test_ts_emits_calls(): # .post() calls .get() assert any("post" in src and "get" in tgt for src, tgt in calls) + def test_ts_calls_are_extracted(): r = extract_js(FIXTURES / "sample.ts") for e in r["edges"]: @@ -87,6 +96,7 @@ def test_ts_call_edges_have_call_context(): assert call_edges assert all(e.get("context") == "call" for e in call_edges) + def test_ts_no_dangling_edges(): r = extract_js(FIXTURES / "sample.ts") node_ids = {n["id"] for n in r["nodes"]} @@ -97,26 +107,31 @@ def test_ts_no_dangling_edges(): # ── Go ──────────────────────────────────────────────────────────────────────── + def test_go_finds_struct(): r = extract_go(FIXTURES / "sample.go") assert "error" not in r assert "Server" in _labels(r) + def test_go_finds_methods(): r = extract_go(FIXTURES / "sample.go") labels = _labels(r) - assert any("Start" in l for l in labels) - assert any("Stop" in l for l in labels) + assert any("Start" in label for label in labels) + assert any("Stop" in label for label in labels) + def test_go_finds_constructor(): r = extract_go(FIXTURES / "sample.go") - assert any("NewServer" in l for l in _labels(r)) + assert any("NewServer" in label for label in _labels(r)) + def test_go_emits_calls(): r = extract_go(FIXTURES / "sample.go") # main() calls NewServer and Start assert len(_call_pairs(r)) > 0 + def test_go_has_extracted_calls(): r = extract_go(FIXTURES / "sample.go") assert "EXTRACTED" in _confidences(r) @@ -135,6 +150,7 @@ def test_go_call_edges_have_call_context(): assert call_edges assert all(e.get("context") == "call" for e in call_edges) + def test_go_no_dangling_edges(): r = extract_go(FIXTURES / "sample.go") node_ids = {n["id"] for n in r["nodes"]} @@ -181,13 +197,15 @@ def test_go_method_declaration_emits_refs_only_when_name_present(): def _find_branch(root: ast.AST, type_literal: str) -> ast.If | None: """Return the `if t == '':` branch inside the walk function.""" for child in ast.walk(root): - if (isinstance(child, ast.If) - and isinstance(child.test, ast.Compare) - and isinstance(child.test.left, ast.Name) - and child.test.left.id == "t" - and len(child.test.comparators) == 1 - and isinstance(child.test.comparators[0], ast.Constant) - and child.test.comparators[0].value == type_literal): + if ( + isinstance(child, ast.If) + and isinstance(child.test, ast.Compare) + and isinstance(child.test.left, ast.Name) + and child.test.left.id == "t" + and len(child.test.comparators) == 1 + and isinstance(child.test.comparators[0], ast.Constant) + and child.test.comparators[0].value == type_literal + ): return child return None @@ -242,10 +260,12 @@ def _is_guarded(use: ast.AST) -> bool: for stmt, siblings in _stmt_chain(use): parent = parents.get(id(stmt)) # Case 1: lexically nested under `if name_node:` body - if (isinstance(parent, ast.If) - and isinstance(parent.test, ast.Name) - and parent.test.id == "name_node" - and stmt in parent.body): + if ( + isinstance(parent, ast.If) + and isinstance(parent.test, ast.Name) + and parent.test.id == "name_node" + and stmt in parent.body + ): return True # Case 2: a preceding sibling is `if not name_node: return` idx = siblings.index(stmt) @@ -288,26 +308,31 @@ def _is_guarded(use: ast.AST) -> bool: # ── Rust ────────────────────────────────────────────────────────────────────── + def test_rust_finds_struct(): r = extract_rust(FIXTURES / "sample.rs") assert "error" not in r assert "Graph" in _labels(r) + def test_rust_finds_impl_methods(): r = extract_rust(FIXTURES / "sample.rs") labels = _labels(r) - assert any("add_node" in l for l in labels) - assert any("add_edge" in l for l in labels) + assert any("add_node" in label for label in labels) + assert any("add_edge" in label for label in labels) + def test_rust_finds_function(): r = extract_rust(FIXTURES / "sample.rs") - assert any("build_graph" in l for l in _labels(r)) + assert any("build_graph" in label for label in _labels(r)) + def test_rust_emits_calls(): r = extract_rust(FIXTURES / "sample.rs") calls = _call_pairs(r) assert any("build_graph" in src for src, _ in calls) + def test_rust_calls_are_extracted(): r = extract_rust(FIXTURES / "sample.rs") for e in r["edges"]: @@ -328,6 +353,7 @@ def test_rust_call_edges_have_call_context(): assert call_edges assert all(e.get("context") == "call" for e in call_edges) + def test_rust_no_dangling_edges(): r = extract_rust(FIXTURES / "sample.rs") node_ids = {n["id"] for n in r["nodes"]} @@ -363,6 +389,7 @@ def test_rust_no_cross_crate_spurious_edges(): """Scoped calls (Type::method) and blocklisted names must not produce INFERRED cross-crate calls edges (#908).""" from graphify.extract import extract + crate_a = FIXTURES / "crate_a" / "src" / "lib.rs" crate_b = FIXTURES / "crate_b" / "src" / "lib.rs" r = extract([crate_a, crate_b]) @@ -370,18 +397,16 @@ def test_rust_no_cross_crate_spurious_edges(): node_ids_b = {n["id"] for n in r["nodes"] if "crate_b" in (n.get("source_file") or "")} # No calls edge should cross from crate_b into crate_a cross_crate_calls = [ - e for e in r["edges"] - if e["relation"] == "calls" - and e["source"] in node_ids_b - and e["target"] in node_ids_a + e + for e in r["edges"] + if e["relation"] == "calls" and e["source"] in node_ids_b and e["target"] in node_ids_a ] - assert cross_crate_calls == [], ( - f"Spurious cross-crate edges: {cross_crate_calls}" - ) + assert cross_crate_calls == [], f"Spurious cross-crate edges: {cross_crate_calls}" # ── extract() dispatch ──────────────────────────────────────────────────────── + def test_extract_dispatches_all_languages(): files = [ FIXTURES / "sample.py", @@ -400,6 +425,7 @@ def test_extract_dispatches_all_languages(): # ── Cache ───────────────────────────────────────────────────────────────────── + def test_cache_hit_returns_same_result(tmp_path): src = FIXTURES / "sample.py" dst = tmp_path / "sample.py" @@ -410,20 +436,22 @@ def test_cache_hit_returns_same_result(tmp_path): assert len(r1["nodes"]) == len(r2["nodes"]) assert len(r1["edges"]) == len(r2["edges"]) + def test_cache_miss_after_file_change(tmp_path): dst = tmp_path / "a.py" dst.write_text("def foo(): pass\n") - r1 = extract([dst]) + extract([dst]) dst.write_text("def foo(): pass\ndef bar(): pass\n") r2 = extract([dst]) # bar() should appear in the second result labels2 = [n["label"] for n in r2["nodes"]] - assert any("bar" in l for l in labels2) + assert any("bar" in label for label in labels2) # ── SQL ─────────────────────────────────────────────────────────────────────── + def _extract_sql_or_skip(fixture: str = "sample.sql"): pytest.importorskip("tree_sitter_sql") return extract_sql(FIXTURES / fixture) @@ -432,35 +460,41 @@ def _extract_sql_or_skip(fixture: str = "sample.sql"): def test_sql_finds_tables(): r = _extract_sql_or_skip() labels = [n["label"] for n in r["nodes"]] - assert any("users" in l for l in labels) - assert any("organizations" in l for l in labels) + assert any("users" in label for label in labels) + assert any("organizations" in label for label in labels) + def test_sql_finds_view(): r = _extract_sql_or_skip() labels = [n["label"] for n in r["nodes"]] - assert any("active_users" in l for l in labels) + assert any("active_users" in label for label in labels) + def test_sql_finds_function(): r = _extract_sql_or_skip() labels = [n["label"] for n in r["nodes"]] - assert any("get_user" in l for l in labels) + assert any("get_user" in label for label in labels) + def test_sql_emits_foreign_key_edge(): r = _extract_sql_or_skip() relations = {e["relation"] for e in r["edges"]} assert "references" in relations + def test_sql_emits_reads_from_edge(): r = _extract_sql_or_skip() relations = {e["relation"] for e in r["edges"]} assert "reads_from" in relations + def test_sql_no_dangling_edges(): r = _extract_sql_or_skip() node_ids = {n["id"] for n in r["nodes"]} for e in r["edges"]: assert e["source"] in node_ids, f"dangling source: {e['source']}" + def test_sql_alter_table_fk_edge(): """ALTER TABLE ... FOREIGN KEY ... REFERENCES produces a references edge.""" r = _extract_sql_or_skip("sample_alter_fk.sql") @@ -471,12 +505,14 @@ def test_sql_alter_table_fk_edge(): assert e["source"] in node_ids, f"dangling source: {e['source']}" assert e["target"] in node_ids, f"dangling target: {e['target']}" + def test_sql_schema_qualified_names(): """Schema-qualified table names (Schema.Table) are preserved.""" r = _extract_sql_or_skip("sample_schema_qualified.sql") labels = [n["label"] for n in r["nodes"]] - assert any("Sales.Customer" in l for l in labels) - assert any("Sales.SalesOrder" in l for l in labels) + assert any("Sales.Customer" in label for label in labels) + assert any("Sales.SalesOrder" in label for label in labels) + def test_sql_schema_qualified_alter_fk(): """ALTER TABLE with schema-qualified names produces correct edges.""" diff --git a/tests/test_ollama.py b/tests/test_ollama.py index a7af29a64..e6e412002 100644 --- a/tests/test_ollama.py +++ b/tests/test_ollama.py @@ -1,7 +1,27 @@ """Tests for the Ollama backend additions in graphify/llm.py.""" + from __future__ import annotations -from graphify.llm import detect_backend, BACKENDS +import pytest + +from graphify.llm import detect_backend, BACKENDS, _backend_env_keys + + +@pytest.fixture(autouse=True) +def _isolate_backend_env(monkeypatch): + """Strip every ambient backend API key so detect_backend() tests are hermetic.""" + for backend in BACKENDS: + for env_key in _backend_env_keys(backend): + monkeypatch.delenv(env_key, raising=False) + for extra in ( + "AWS_PROFILE", + "AWS_REGION", + "AWS_DEFAULT_REGION", + "OLLAMA_BASE_URL", + "OLLAMA_API_KEY", + ): + monkeypatch.delenv(extra, raising=False) + yield def test_ollama_in_backends(): @@ -60,6 +80,7 @@ def test_ollama_api_key_sentinel(monkeypatch): } with patch("graphify.llm._call_openai_compat", return_value=fake_result) as mock_call: from graphify.llm import extract_files_direct + with tempfile.NamedTemporaryFile(suffix=".py", mode="w", delete=False) as f: f.write("x = 1\n") tmp = Path(f.name) @@ -68,7 +89,9 @@ def test_ollama_api_key_sentinel(monkeypatch): # Should have called _call_openai_compat with api_key="ollama" assert mock_call.called call_kwargs = mock_call.call_args - api_key_used = call_kwargs.args[1] if call_kwargs.args else call_kwargs.kwargs.get("api_key", "") + api_key_used = ( + call_kwargs.args[1] if call_kwargs.args else call_kwargs.kwargs.get("api_key", "") + ) assert api_key_used == "ollama" finally: tmp.unlink(missing_ok=True) diff --git a/tests/test_pascal.py b/tests/test_pascal.py index 36c1b8747..50b0c7412 100644 --- a/tests/test_pascal.py +++ b/tests/test_pascal.py @@ -1,4 +1,5 @@ """Tests for the Pascal/Delphi extractor.""" + from __future__ import annotations from pathlib import Path @@ -19,48 +20,55 @@ def _edges_with_relation(r, *relations): def test_pascal_no_error(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") assert "error" not in r def test_pascal_finds_unit(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") - assert any("SampleUnit" in l for l in _labels(r)) + assert any("SampleUnit" in label for label in _labels(r)) def test_pascal_finds_classes(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") labels = _labels(r) - assert any("TBaseProcessor" in l for l in labels) - assert any("TDataProcessor" in l for l in labels) + assert any("TBaseProcessor" in label for label in labels) + assert any("TDataProcessor" in label for label in labels) def test_pascal_finds_interface(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") - assert any("IProcessor" in l for l in _labels(r)) + assert any("IProcessor" in label for label in _labels(r)) def test_pascal_finds_methods(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") labels = _labels(r) - assert any("Process" in l for l in labels) - assert any("Initialize" in l for l in labels) - assert any("GetCount" in l for l in labels) - assert any("Reset" in l for l in labels) + assert any("Process" in label for label in labels) + assert any("Initialize" in label for label in labels) + assert any("GetCount" in label for label in labels) + assert any("Reset" in label for label in labels) def test_pascal_finds_imports(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") assert "imports" in _relations(r) def test_pascal_import_edges_have_import_context(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") import_edges = _edges_with_relation(r, "imports") assert import_edges @@ -69,30 +77,31 @@ def test_pascal_import_edges_have_import_context(): def test_pascal_finds_inherits(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") assert "inherits" in _relations(r) def test_pascal_inherits_from_base(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") node_by_id = {n["id"]: n["label"] for n in r["nodes"]} inherits = [e for e in r["edges"] if e["relation"] == "inherits"] - found = any( - "TDataProcessor" in node_by_id.get(e["source"], "") - for e in inherits - ) + found = any("TDataProcessor" in node_by_id.get(e["source"], "") for e in inherits) assert found, "TDataProcessor should have at least one inherits edge" def test_pascal_finds_calls(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") assert "calls" in _relations(r) def test_pascal_call_edges_have_call_context(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") call_edges = _edges_with_relation(r, "calls") assert call_edges @@ -101,6 +110,7 @@ def test_pascal_call_edges_have_call_context(): def test_pascal_all_edges_extracted(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") structural = {"contains", "method", "inherits", "imports"} for e in r["edges"]: @@ -110,6 +120,7 @@ def test_pascal_all_edges_extracted(): def test_pascal_no_dangling_edges(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") node_ids = {n["id"] for n in r["nodes"]} # imports edges are cross-file by design; only check within-file edge targets @@ -122,6 +133,7 @@ def test_pascal_no_dangling_edges(): def test_pascal_dispatch_registered(): from graphify.extract import _DISPATCH + assert ".pas" in _DISPATCH assert ".pp" in _DISPATCH assert ".dpr" in _DISPATCH @@ -134,6 +146,7 @@ def test_pascal_dispatch_registered(): def test_pascal_detect_extensions_registered(): from graphify.detect import CODE_EXTENSIONS + assert ".pas" in CODE_EXTENSIONS assert ".pp" in CODE_EXTENSIONS assert ".dpr" in CODE_EXTENSIONS @@ -144,38 +157,44 @@ def test_pascal_detect_extensions_registered(): # ── Lazarus Form (.lfm) ─────────────────────────────────────────────────────── + def test_lfm_no_error(): from graphify.extract import extract_lazarus_form + r = extract_lazarus_form(FIXTURES / "sample.lfm") assert "error" not in r def test_lfm_finds_root_form_class(): from graphify.extract import extract_lazarus_form + r = extract_lazarus_form(FIXTURES / "sample.lfm") - assert any("TSampleForm" in l for l in _labels(r)) + assert any("TSampleForm" in label for label in _labels(r)) def test_lfm_finds_component_classes(): from graphify.extract import extract_lazarus_form + r = extract_lazarus_form(FIXTURES / "sample.lfm") labels = _labels(r) - assert any("TPanel" in l for l in labels) - assert any("TButton" in l for l in labels) - assert any("TLabel" in l for l in labels) - assert any("TTimer" in l for l in labels) + assert any("TPanel" in label for label in labels) + assert any("TButton" in label for label in labels) + assert any("TLabel" in label for label in labels) + assert any("TTimer" in label for label in labels) def test_lfm_finds_event_handlers(): from graphify.extract import extract_lazarus_form + r = extract_lazarus_form(FIXTURES / "sample.lfm") labels = _labels(r) - assert any("ButtonOKClick" in l for l in labels) - assert any("TimerRefreshTimer" in l for l in labels) + assert any("ButtonOKClick" in label for label in labels) + assert any("TimerRefreshTimer" in label for label in labels) def test_lfm_event_edges_have_event_context(): from graphify.extract import extract_lazarus_form + r = extract_lazarus_form(FIXTURES / "sample.lfm") ref_edges = [e for e in r["edges"] if e["relation"] == "references"] assert ref_edges @@ -184,12 +203,14 @@ def test_lfm_event_edges_have_event_context(): def test_lfm_contains_edges_form_hierarchy(): from graphify.extract import extract_lazarus_form + r = extract_lazarus_form(FIXTURES / "sample.lfm") assert "contains" in _relations(r) def test_lfm_no_dangling_edges(): from graphify.extract import extract_lazarus_form + r = extract_lazarus_form(FIXTURES / "sample.lfm") node_ids = {n["id"] for n in r["nodes"]} for e in r["edges"]: @@ -198,28 +219,33 @@ def test_lfm_no_dangling_edges(): # ── Lazarus Package (.lpk) ─────────────────────────────────────────────────── + def test_lpk_no_error(): from graphify.extract import extract_lazarus_package + r = extract_lazarus_package(FIXTURES / "sample.lpk") assert "error" not in r def test_lpk_finds_package_name(): from graphify.extract import extract_lazarus_package + r = extract_lazarus_package(FIXTURES / "sample.lpk") - assert any("SamplePackage" in l for l in _labels(r)) + assert any("SamplePackage" in label for label in _labels(r)) def test_lpk_finds_required_packages(): from graphify.extract import extract_lazarus_package + r = extract_lazarus_package(FIXTURES / "sample.lpk") labels = _labels(r) - assert any("FCL" in l for l in labels) - assert any("LCL" in l for l in labels) + assert any("FCL" in label for label in labels) + assert any("LCL" in label for label in labels) def test_lpk_imports_edges_have_import_context(): from graphify.extract import extract_lazarus_package + r = extract_lazarus_package(FIXTURES / "sample.lpk") import_edges = _edges_with_relation(r, "imports") assert import_edges @@ -228,14 +254,16 @@ def test_lpk_imports_edges_have_import_context(): def test_lpk_contains_listed_units(): from graphify.extract import extract_lazarus_package + r = extract_lazarus_package(FIXTURES / "sample.lpk") labels = _labels(r) - assert any("sample" in l.lower() for l in labels) - assert any("sampleutils" in l.lower() for l in labels) + assert any("sample" in label.lower() for label in labels) + assert any("sampleutils" in label.lower() for label in labels) def test_lpk_no_dangling_edges(): from graphify.extract import extract_lazarus_package + r = extract_lazarus_package(FIXTURES / "sample.lpk") node_ids = {n["id"] for n in r["nodes"]} for e in r["edges"]: @@ -244,38 +272,44 @@ def test_lpk_no_dangling_edges(): # ── Delphi Form (.dfm) ─────────────────────────────────────────────────────── + def test_dfm_no_error(): from graphify.extract import extract_delphi_form + r = extract_delphi_form(FIXTURES / "sample.dfm") assert "error" not in r def test_dfm_finds_root_form_class(): from graphify.extract import extract_delphi_form + r = extract_delphi_form(FIXTURES / "sample.dfm") - assert any("TMainForm" in l for l in _labels(r)) + assert any("TMainForm" in label for label in _labels(r)) def test_dfm_finds_component_classes(): from graphify.extract import extract_delphi_form + r = extract_delphi_form(FIXTURES / "sample.dfm") labels = _labels(r) - assert any("TPanel" in l for l in labels) - assert any("TButton" in l for l in labels) - assert any("TMemo" in l for l in labels) - assert any("TStatusBar" in l for l in labels) + assert any("TPanel" in label for label in labels) + assert any("TButton" in label for label in labels) + assert any("TMemo" in label for label in labels) + assert any("TStatusBar" in label for label in labels) def test_dfm_finds_event_handlers(): from graphify.extract import extract_delphi_form + r = extract_delphi_form(FIXTURES / "sample.dfm") labels = _labels(r) - assert any("FormCreate" in l for l in labels) - assert any("ButtonOKClick" in l for l in labels) + assert any("FormCreate" in label for label in labels) + assert any("ButtonOKClick" in label for label in labels) def test_dfm_event_edges_have_event_context(): from graphify.extract import extract_delphi_form + r = extract_delphi_form(FIXTURES / "sample.dfm") ref_edges = [e for e in r["edges"] if e["relation"] == "references"] assert ref_edges @@ -284,12 +318,14 @@ def test_dfm_event_edges_have_event_context(): def test_dfm_contains_edges_form_hierarchy(): from graphify.extract import extract_delphi_form + r = extract_delphi_form(FIXTURES / "sample.dfm") assert "contains" in _relations(r) def test_dfm_no_dangling_edges(): from graphify.extract import extract_delphi_form + r = extract_delphi_form(FIXTURES / "sample.dfm") node_ids = {n["id"] for n in r["nodes"]} for e in r["edges"]: @@ -298,7 +334,9 @@ def test_dfm_no_dangling_edges(): def test_dfm_binary_returns_empty_not_crash(): from graphify.extract import extract_delphi_form - import tempfile, pathlib + import tempfile + import pathlib + # Write a fake binary DFM (FF 0A magic header) with tempfile.NamedTemporaryFile(suffix=".dfm", delete=False) as f: f.write(b"\xff\x0a\x00\x00some binary data") @@ -314,9 +352,11 @@ def test_dfm_binary_returns_empty_not_crash(): def test_dfm_dispatch_registered(): from graphify.extract import _DISPATCH + assert ".dfm" in _DISPATCH def test_dfm_detect_extension_registered(): from graphify.detect import CODE_EXTENSIONS + assert ".dfm" in CODE_EXTENSIONS diff --git a/tests/test_path_cli.py b/tests/test_path_cli.py index de7e8837f..f8e7770e5 100644 --- a/tests/test_path_cli.py +++ b/tests/test_path_cli.py @@ -1,23 +1,36 @@ """Regression tests for `graphify path` arrow direction (#849).""" + from __future__ import annotations import json -import networkx as nx -from networkx.readwrite import json_graph import graphify.__main__ as mainmod def _write_graph(tmp_path): graph_data = { - "directed": False, "multigraph": False, "graph": {}, + "directed": False, + "multigraph": False, + "graph": {}, "nodes": [ - {"id": "create_patch", "label": "createPatchHandler()", - "source_file": "server/create-patch-handler.ts", "community": 0}, - {"id": "validate", "label": "validateSanitySession()", - "source_file": "server/sanity-validate-session.ts", "community": 0}, + { + "id": "create_patch", + "label": "createPatchHandler()", + "source_file": "server/create-patch-handler.ts", + "community": 0, + }, + { + "id": "validate", + "label": "validateSanitySession()", + "source_file": "server/sanity-validate-session.ts", + "community": 0, + }, ], "links": [ - {"source": "create_patch", "target": "validate", - "relation": "calls", "confidence": "EXTRACTED"}, + { + "source": "create_patch", + "target": "validate", + "relation": "calls", + "confidence": "EXTRACTED", + }, ], } p = tmp_path / "graph.json" @@ -27,8 +40,9 @@ def _write_graph(tmp_path): def _run(monkeypatch, graph_path, src, tgt, capsys): monkeypatch.setattr(mainmod, "_check_skill_version", lambda _: None) - monkeypatch.setattr(mainmod.sys, "argv", - ["graphify", "path", src, tgt, "--graph", str(graph_path)]) + monkeypatch.setattr( + mainmod.sys, "argv", ["graphify", "path", src, tgt, "--graph", str(graph_path)] + ) mainmod.main() return capsys.readouterr().out @@ -46,3 +60,128 @@ def test_reverse_arrow(monkeypatch, tmp_path, capsys): assert "Shortest path (1 hops):" in out assert "validateSanitySession() <--calls [EXTRACTED]-- createPatchHandler()" in out assert "validateSanitySession() --calls [EXTRACTED]--> createPatchHandler()" not in out + + +def _write_multigraph(tmp_path): + """A->B with 3 parallel relations, B->C with a single relation.""" + graph_data = { + "directed": True, + "multigraph": True, + "graph": {}, + "nodes": [ + {"id": "a", "label": "alpha()", "source_file": "a.py", "community": 0}, + {"id": "b", "label": "beta()", "source_file": "b.py", "community": 0}, + {"id": "c", "label": "gamma()", "source_file": "c.py", "community": 0}, + ], + "links": [ + { + "source": "a", + "target": "b", + "relation": "calls", + "confidence": "EXTRACTED", + "key": 0, + }, + { + "source": "a", + "target": "b", + "relation": "imports", + "confidence": "EXTRACTED", + "key": 1, + }, + { + "source": "a", + "target": "b", + "relation": "contains", + "confidence": "EXTRACTED", + "key": 2, + }, + { + "source": "b", + "target": "c", + "relation": "returns", + "confidence": "INFERRED", + "key": 0, + }, + ], + } + p = tmp_path / "graph.json" + p.write_text(json.dumps(graph_data)) + return p + + +def test_path_multigraph_hop_shows_all_relations(monkeypatch, tmp_path, capsys): + """PR5 gate: a MultiDiGraph hop bundles all parallel relations, never first-only.""" + p = _write_multigraph(tmp_path) + out = _run(monkeypatch, p, "alpha", "gamma", capsys) + assert "Shortest path (2 hops):" in out + # The A->B hop carries 3 parallel relations: all must appear (sorted, unique). + assert "--calls, contains, imports--> beta()" in out + # First-edge-only regression guard: the lone "calls" hop form must NOT appear. + assert "--calls [EXTRACTED]--> beta()" not in out + # The single-relation B->C hop stays byte-stable. + assert "--returns [INFERRED]--> gamma()" in out + + +def test_path_simple_graph_output_regression(monkeypatch, tmp_path, capsys): + """Simple DiGraph path output is unchanged: single relation per hop.""" + p = _write_graph(tmp_path) + out = _run(monkeypatch, p, "createPatchHandler", "validateSanitySession", capsys) + # Byte-stable single-relation form, matching test_forward_arrow exactly. + assert "createPatchHandler() --calls [EXTRACTED]--> validateSanitySession()" in out + + +def _write_bidirectional_multigraph(tmp_path): + """A->B 'calls', B->A 'imports' (opposite relations), plus B->C so the + shortest A->C path renders the A->B hop in its stored forward direction.""" + graph_data = { + "directed": True, + "multigraph": True, + "graph": {}, + "nodes": [ + {"id": "a", "label": "alpha()", "source_file": "a.py", "community": 0}, + {"id": "b", "label": "beta()", "source_file": "b.py", "community": 0}, + {"id": "c", "label": "gamma()", "source_file": "c.py", "community": 0}, + ], + "links": [ + { + "source": "a", + "target": "b", + "relation": "calls", + "confidence": "EXTRACTED", + "key": 0, + }, + { + "source": "b", + "target": "a", + "relation": "imports", + "confidence": "EXTRACTED", + "key": 0, + }, + { + "source": "b", + "target": "c", + "relation": "returns", + "confidence": "INFERRED", + "key": 0, + }, + ], + } + p = tmp_path / "graph.json" + p.write_text(json.dumps(graph_data)) + return p + + +def test_path_directional_isolation(monkeypatch, tmp_path, capsys): + """A->B hop renders only the forward 'calls' relation, never the reverse 'imports'. + + Regression for the directed_only fix: relationship_envelope merges both + directions by default, which would wrongly bundle B->A 'imports' onto the + A-->B arrow. directed_only=True must isolate the stored hop direction. + """ + p = _write_bidirectional_multigraph(tmp_path) + out = _run(monkeypatch, p, "alpha", "gamma", capsys) + assert "Shortest path (2 hops):" in out + # Forward hop shows ONLY 'calls' (byte-stable single-relation form). + assert "alpha() --calls [EXTRACTED]--> beta()" in out + # The reverse-direction 'imports' must NOT bleed into the forward arrow. + assert "imports" not in out diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index ce6055d8b..de977ad91 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -3,14 +3,13 @@ Uses the existing test fixtures (code + markdown). No LLM calls - AST extraction only. Catches regressions in how modules connect, not just individual module behaviour. """ + import json -import tempfile from pathlib import Path -import pytest from graphify.detect import detect -from graphify.extract import collect_files, extract +from graphify.extract import extract from graphify.build import build_from_json from graphify.cluster import cluster, score_all from graphify.analyze import god_nodes, surprising_connections, suggest_questions @@ -62,7 +61,18 @@ def run_pipeline(tmp_path: Path) -> dict: # Step 6: report tokens = {"input": 0, "output": 0} - report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, str(FIXTURES), suggested_questions=questions) + report = generate( + G, + communities, + cohesion, + labels, + gods, + surprises, + detection, + tokens, + str(FIXTURES), + suggested_questions=questions, + ) assert "God Nodes" in report assert "Communities" in report assert len(report) > 100 @@ -85,7 +95,9 @@ def run_pipeline(tmp_path: Path) -> dict: # Step 9: export - Obsidian vault vault_path = tmp_path / "obsidian" - n_notes = to_obsidian(G, communities, str(vault_path), community_labels=labels, cohesion=cohesion) + n_notes = to_obsidian( + G, communities, str(vault_path), community_labels=labels, cohesion=cohesion + ) assert n_notes > 0 assert (vault_path / ".obsidian" / "graph.json").exists() md_files = list(vault_path.glob("*.md")) diff --git a/tests/test_projections.py b/tests/test_projections.py new file mode 100644 index 000000000..305f5627b --- /dev/null +++ b/tests/test_projections.py @@ -0,0 +1,422 @@ +from __future__ import annotations + +import networkx as nx +import pytest +from typing import Any, cast + +from graphify.projections import ( + DEFAULT_RELATIONSHIP_CAP, + distinct_neighbor_degree, + edge_records_between, + edge_summary_between, + format_relationship_envelope, + normalize_to_multidigraph, + project_for_callflow, + project_for_community, + project_for_context, + project_for_path, + relationship_envelope, +) + + +def _parallel_graph() -> nx.MultiDiGraph: + graph = nx.MultiDiGraph() + graph.graph["graphify_profile"] = "test-profile" + graph.add_node("a", label="A") + graph.add_node("b", label="B") + graph.add_node("c", label="C") + graph.add_edge( + "a", + "b", + key="calls-low", + relation="calls", + confidence="INFERRED", + confidence_score=0.4, + source_file="src/a.py", + source_location="L10", + context="code", + ) + graph.add_edge( + "a", + "b", + key="imports-high", + relation="imports", + confidence="EXTRACTED", + confidence_score=0.9, + source_file="src/a.py", + source_location="L2", + context="code", + ) + graph.add_edge( + "b", + "a", + key="returns", + relation="returns", + confidence="AMBIGUOUS", + confidence_score=0.2, + source_file="src/b.py", + source_location="L5", + context="runtime", + ) + graph.add_edge( + "b", + "c", + key="calls-c", + relation="calls", + confidence="EXTRACTED", + confidence_score=1.0, + source_file="src/b.py", + source_location="L7", + context="code", + ) + graph.add_edge("c", "c", key="self", relation="calls", confidence="EXTRACTED") + return graph + + +def test_project_for_community_returns_simple_weighted_copy() -> None: + projected = project_for_community(_parallel_graph(), weight_mode="count") + + assert type(projected) is nx.Graph + assert projected.graph["graphify_profile"] == "test-profile" + assert set(projected.nodes) == {"a", "b", "c"} + assert not projected.has_edge("c", "c") + assert projected["a"]["b"]["weight"] == 3.0 + assert projected["a"]["b"]["parallel_edge_count"] == 3 + assert projected["b"]["c"]["weight"] == 1.0 + + +def test_project_for_community_supports_confidence_and_sum_weight_modes() -> None: + graph = _parallel_graph() + + by_confidence = project_for_community(graph, weight_mode="confidence") + by_sum = project_for_community(graph, weight_mode="sum") + + assert by_confidence["a"]["b"]["weight"] == 0.9 + assert by_confidence["a"]["b"]["relation"] == "imports" + assert by_sum["a"]["b"]["weight"] == pytest.approx(1.5) + with pytest.raises(ValueError, match="weight_mode"): + project_for_community(graph, weight_mode=cast(Any, "unknown")) + + +def test_project_for_path_uses_simple_graph_not_multigraph_view() -> None: + projected = project_for_path(_parallel_graph()) + + assert type(projected) is nx.Graph + assert not projected.is_multigraph() + assert projected.number_of_edges("a", "b") == 1 + assert nx.shortest_path(projected, "a", "c") == ["a", "b", "c"] + + +def test_project_for_callflow_preserves_direction_and_filters_relations() -> None: + projected = project_for_callflow(_parallel_graph(), relations=frozenset({"calls"})) + + assert type(projected) is nx.DiGraph + assert set(projected.edges()) == {("a", "b"), ("b", "c")} + assert projected["a"]["b"]["relation"] == "calls" + + +def test_project_for_callflow_recovers_src_tgt_from_undirected_edges() -> None: + graph = nx.Graph() + graph.add_node("display_a") + graph.add_node("display_b") + graph.add_edge("display_a", "display_b", _src="real_src", _tgt="real_tgt", relation="calls") + + projected = project_for_callflow(graph) + + assert set(projected.edges()) == {("real_src", "real_tgt")} + assert projected["real_src"]["real_tgt"]["relation"] == "calls" + + +def test_project_for_context_preserves_multigraph_type_keys_and_metadata() -> None: + projected = project_for_context(_parallel_graph(), contexts=["code"]) + + assert isinstance(projected, nx.MultiDiGraph) + assert projected.graph["graphify_profile"] == "test-profile" + assert set(projected["a"]["b"]) == {"calls-low", "imports-high"} + assert "returns" not in projected.get_edge_data("b", "a", default={}) + + +def test_project_for_context_none_returns_copy_not_original() -> None: + graph = _parallel_graph() + + projected = project_for_context(graph) + + assert projected is not graph + assert isinstance(projected, nx.MultiDiGraph) + assert projected.number_of_edges() == graph.number_of_edges() + + +def test_project_for_context_empty_filter_is_noop_copy() -> None: + graph = _parallel_graph() + + projected = project_for_context(graph, contexts=[]) + + assert projected is not graph + assert projected.number_of_edges() == graph.number_of_edges() + + +def test_edge_records_between_returns_copies_from_both_directions() -> None: + graph = _parallel_graph() + + records = edge_records_between(graph, "a", "b") + + assert [record["relation"] for record in records] == ["imports", "calls", "returns"] + records[0]["relation"] = "mutated" + assert graph["a"]["b"]["imports-high"]["relation"] == "imports" + + +def test_edge_summary_between_counts_and_picks_representative() -> None: + summary = edge_summary_between(_parallel_graph(), "a", "b") + + assert summary["count"] == 3 + assert summary["relations"] == ["calls", "imports", "returns"] + assert summary["confidences"] == ["AMBIGUOUS", "EXTRACTED", "INFERRED"] + assert summary["representative"]["relation"] == "imports" + + +def test_distinct_neighbor_degree_does_not_count_parallel_edges() -> None: + graph = _parallel_graph() + + assert graph.degree("a") == 3 + assert distinct_neighbor_degree(graph, "a") == 1 + assert distinct_neighbor_degree(graph, "missing") == 0 + + +def test_normalize_to_multidigraph_preserves_parallel_keys_and_simple_edges() -> None: + graph = nx.MultiGraph() + graph.graph["name"] = "mixed" + graph.add_node("a", label="A") + graph.add_node("b", label="B") + graph.add_edge("a", "b", key="one", relation="calls") + graph.add_edge("a", "b", key="two", relation="imports") + + normalized = normalize_to_multidigraph(graph) + + assert isinstance(normalized, nx.MultiDiGraph) + assert normalized.graph["name"] == "mixed" + assert set(normalized["a"]["b"]) == {"one", "two"} + + simple = nx.Graph() + simple.add_edge("x", "y", relation="uses") + simple_normalized = normalize_to_multidigraph(simple) + + assert isinstance(simple_normalized, nx.MultiDiGraph) + assert simple_normalized.number_of_edges("x", "y") == 1 + assert next(iter(simple_normalized["x"]["y"].values()))["relation"] == "uses" + + +# --------------------------------------------------------------------------- +# relationship_envelope / format_relationship_envelope +# --------------------------------------------------------------------------- + + +def _multidigraph_with_parallel_relations( + relations: list[str], *, confidence: str | None = None +) -> nx.MultiDiGraph: + """Build A->B with one parallel edge per supplied relation.""" + graph = nx.MultiDiGraph() + graph.add_node("a", label="A") + graph.add_node("b", label="B") + for index, relation in enumerate(relations): + attrs: dict[str, Any] = {"relation": relation} + if confidence is not None: + attrs["confidence"] = confidence + graph.add_edge("a", "b", key=f"{relation}-{index}", **attrs) + return graph + + +def test_relationship_envelope_single_edge() -> None: + graph = nx.DiGraph() + graph.add_edge("a", "b", relation="calls", confidence="EXTRACTED") + + envelope = relationship_envelope(graph, "a", "b") + + assert envelope["count"] == 1 + assert len(envelope["shown"]) == 1 + assert envelope["shown"][0]["relation"] == "calls" + assert envelope["truncated"] == 0 + assert envelope["relations"] == ["calls"] + assert envelope["confidences"] == ["EXTRACTED"] + + +def test_relationship_envelope_multidigraph_bundles_all() -> None: + graph = _multidigraph_with_parallel_relations(["calls", "imports", "contains"]) + + envelope = relationship_envelope(graph, "a", "b") + + assert envelope["count"] == 3 + assert envelope["relations"] == ["calls", "contains", "imports"] + assert len(envelope["shown"]) == 3 # default cap == 3 fits all + assert envelope["truncated"] == 0 + # shown records mirror edge_records_between ordering + assert envelope["shown"] == edge_records_between(graph, "a", "b") + + +def test_relationship_envelope_caps_shown() -> None: + graph = _multidigraph_with_parallel_relations(["r1", "r2", "r3", "r4", "r5"]) + + envelope = relationship_envelope(graph, "a", "b", cap=3) + + assert envelope["count"] == 5 + assert len(envelope["shown"]) == 3 + assert envelope["truncated"] == 2 + assert envelope["relations"] == ["r1", "r2", "r3", "r4", "r5"] + # shown is the leading slice of the full sorted record list + assert envelope["shown"] == edge_records_between(graph, "a", "b")[:3] + + +def test_relationship_envelope_cap_zero_or_negative() -> None: + graph = _multidigraph_with_parallel_relations( + ["calls", "imports", "contains"], confidence="EXTRACTED" + ) + + zero = relationship_envelope(graph, "a", "b", cap=0) + assert zero["shown"] == [] + assert zero["truncated"] == zero["count"] == 3 + assert zero["relations"] == ["calls", "contains", "imports"] + assert zero["confidences"] == ["EXTRACTED"] + + negative = relationship_envelope(graph, "a", "b", cap=-1) + assert negative["shown"] == [] + assert negative["truncated"] == negative["count"] == 3 + assert negative["relations"] == ["calls", "contains", "imports"] + + +def test_relationship_envelope_directed_both_directions() -> None: + graph = nx.DiGraph() + graph.add_edge("a", "b", relation="calls", confidence="EXTRACTED") + graph.add_edge("b", "a", relation="returns", confidence="INFERRED") + + envelope = relationship_envelope(graph, "a", "b") + + assert envelope["count"] == 2 + assert envelope["relations"] == ["calls", "returns"] + assert envelope["confidences"] == ["EXTRACTED", "INFERRED"] + assert envelope["shown"] == edge_records_between(graph, "a", "b") + + +def test_relationship_envelope_no_edge() -> None: + graph = nx.DiGraph() + graph.add_node("a") + graph.add_node("b") + + envelope = relationship_envelope(graph, "a", "b") + + assert envelope["count"] == 0 + assert envelope["shown"] == [] + assert envelope["truncated"] == 0 + assert envelope["relations"] == [] + assert envelope["confidences"] == [] + + +def test_format_relationship_envelope_single() -> None: + without_confidence = nx.DiGraph() + without_confidence.add_edge("a", "b", relation="calls") + assert format_relationship_envelope(without_confidence, "a", "b") == "calls" + + with_confidence = nx.DiGraph() + with_confidence.add_edge("a", "b", relation="calls", confidence="EXTRACTED") + assert format_relationship_envelope(with_confidence, "a", "b") == "calls (EXTRACTED)" + + +def test_format_relationship_envelope_multiple_within_cap() -> None: + graph = _multidigraph_with_parallel_relations( + ["imports", "calls", "contains"], confidence="EXTRACTED" + ) + + # 3 unique relations within the default cap; confidence omitted for multi-relation lines + assert format_relationship_envelope(graph, "a", "b") == "calls, contains, imports" + + +def test_format_relationship_envelope_capped() -> None: + graph = _multidigraph_with_parallel_relations(["gamma", "alpha", "epsilon", "beta", "delta"]) + + # sorted unique relations: alpha, beta, delta, epsilon, gamma -> first 3 shown + assert ( + format_relationship_envelope(graph, "a", "b", cap=3) + == "alpha, beta, delta (+2 more, 5 total)" + ) + + +def test_format_relationship_envelope_empty() -> None: + graph = nx.DiGraph() + graph.add_node("a") + graph.add_node("b") + + assert format_relationship_envelope(graph, "a", "b") == "" + + +def test_relationship_envelope_simple_graph_regression() -> None: + graph = nx.DiGraph() + graph.add_edge("a", "b", relation="calls") + graph.add_edge("a", "c", relation="imports") + + # Plain DiGraph: no parallel edges, so the envelope between a single pair + # reflects exactly the one edge and shown == all records (cap unreached). + assert DEFAULT_RELATIONSHIP_CAP == 3 + envelope = relationship_envelope(graph, "a", "b") + assert envelope["count"] == graph.number_of_edges("a", "b") == 1 + assert envelope["shown"] == edge_records_between(graph, "a", "b") + assert envelope["truncated"] == 0 + + +def _bidirectional_digraph() -> nx.DiGraph: + """Directed A->B (calls) plus the reverse B->A (imports).""" + graph = nx.DiGraph() + graph.add_edge("a", "b", relation="calls", confidence="EXTRACTED") + graph.add_edge("b", "a", relation="imports", confidence="INFERRED") + return graph + + +def test_edge_records_between_directed_only_excludes_reverse() -> None: + graph = _bidirectional_digraph() + + both = edge_records_between(graph, "a", "b") + assert len(both) == 2 + assert {record["relation"] for record in both} == {"calls", "imports"} + + forward = edge_records_between(graph, "a", "b", directed_only=True) + assert len(forward) == 1 + assert forward[0]["relation"] == "calls" + + +def test_relationship_envelope_directed_only() -> None: + graph = _bidirectional_digraph() + + envelope = relationship_envelope(graph, "a", "b", directed_only=True) + + assert envelope["count"] == 1 + assert envelope["relations"] == ["calls"] + assert "imports" not in envelope["relations"] + assert [record["relation"] for record in envelope["shown"]] == ["calls"] + + +def test_format_relationship_envelope_directed_only() -> None: + graph = _bidirectional_digraph() + + # Single forward relation with confidence present -> "calls (EXTRACTED)". + rendered = format_relationship_envelope(graph, "a", "b", directed_only=True) + assert rendered == "calls (EXTRACTED)" + assert "imports" not in rendered + + # Without confidence the single forward relation renders bare. + plain = nx.DiGraph() + plain.add_edge("a", "b", relation="calls") + plain.add_edge("b", "a", relation="imports") + assert format_relationship_envelope(plain, "a", "b", directed_only=True) == "calls" + + +def test_directed_only_noop_on_undirected() -> None: + graph = nx.Graph() + graph.add_edge("a", "b", relation="calls", confidence="EXTRACTED") + graph.add_edge("a", "b", relation="imports") # simple graph: overwrites attrs, single edge + + assert edge_records_between(graph, "a", "b", directed_only=True) == edge_records_between( + graph, "a", "b" + ) + assert relationship_envelope(graph, "a", "b", directed_only=True) == relationship_envelope( + graph, "a", "b" + ) + assert format_relationship_envelope( + graph, "a", "b", directed_only=True + ) == format_relationship_envelope(graph, "a", "b") diff --git a/tests/test_provider_registry.py b/tests/test_provider_registry.py index bbf082ca8..f2a8b2fe4 100644 --- a/tests/test_provider_registry.py +++ b/tests/test_provider_registry.py @@ -1,6 +1,4 @@ import json -import pytest -from pathlib import Path def test_custom_provider_add_list_show_remove(tmp_path, monkeypatch): @@ -9,18 +7,28 @@ def test_custom_provider_add_list_show_remove(tmp_path, monkeypatch): providers_file.write_text("{}", encoding="utf-8") from graphify import llm - monkeypatch.setattr(llm, "_custom_providers_path", lambda global_=True: providers_file if global_ else tmp_path / "local.json") + + monkeypatch.setattr( + llm, + "_custom_providers_path", + lambda global_=True: providers_file if global_ else tmp_path / "local.json", + ) monkeypatch.setattr(llm, "BACKENDS", {**llm.BACKENDS}) - providers_file.write_text(json.dumps({ - "nvidia": { - "base_url": "https://integrate.api.nvidia.com/v1", - "default_model": "minimaxai/minimax-m2.7", - "env_key": "NVIDIA_API_KEY", - "pricing": {"input": 0.0, "output": 0.0}, - "temperature": 0, - } - }), encoding="utf-8") + providers_file.write_text( + json.dumps( + { + "nvidia": { + "base_url": "https://integrate.api.nvidia.com/v1", + "default_model": "minimaxai/minimax-m2.7", + "env_key": "NVIDIA_API_KEY", + "pricing": {"input": 0.0, "output": 0.0}, + "temperature": 0, + } + } + ), + encoding="utf-8", + ) loaded = llm._load_custom_providers() assert "nvidia" in loaded @@ -30,19 +38,27 @@ def test_custom_provider_add_list_show_remove(tmp_path, monkeypatch): def test_custom_provider_pricing_defaults_to_zero(tmp_path): """Missing pricing field defaults to zero so estimate_cost doesn't blow up.""" providers_file = tmp_path / "providers.json" - providers_file.write_text(json.dumps({ - "mymodel": { - "base_url": "http://localhost:8080/v1", - "default_model": "llama3", - "env_key": "MY_API_KEY", - } - }), encoding="utf-8") + providers_file.write_text( + json.dumps( + { + "mymodel": { + "base_url": "http://localhost:8080/v1", + "default_model": "llama3", + "env_key": "MY_API_KEY", + } + } + ), + encoding="utf-8", + ) from graphify import llm - import importlib from unittest.mock import patch - with patch.object(llm, "_custom_providers_path", side_effect=lambda global_=True: providers_file if global_ else tmp_path / "local.json"): + with patch.object( + llm, + "_custom_providers_path", + side_effect=lambda global_=True: providers_file if global_ else tmp_path / "local.json", + ): loaded = llm._load_custom_providers() assert "mymodel" in loaded @@ -52,18 +68,27 @@ def test_custom_provider_pricing_defaults_to_zero(tmp_path): def test_custom_provider_cannot_shadow_builtin(tmp_path): """Built-in provider names are protected from being overridden.""" providers_file = tmp_path / "providers.json" - providers_file.write_text(json.dumps({ - "claude": { - "base_url": "http://evil.example.com/v1", - "default_model": "evil-model", - "env_key": "EVIL_KEY", - } - }), encoding="utf-8") + providers_file.write_text( + json.dumps( + { + "claude": { + "base_url": "http://evil.example.com/v1", + "default_model": "evil-model", + "env_key": "EVIL_KEY", + } + } + ), + encoding="utf-8", + ) from graphify import llm from unittest.mock import patch - with patch.object(llm, "_custom_providers_path", side_effect=lambda global_=True: providers_file if global_ else tmp_path / "local.json"): + with patch.object( + llm, + "_custom_providers_path", + side_effect=lambda global_=True: providers_file if global_ else tmp_path / "local.json", + ): loaded = llm._load_custom_providers() assert "claude" not in loaded @@ -73,19 +98,30 @@ def test_detect_backend_custom_provider_after_builtins(monkeypatch): """Custom providers appear after all built-ins in detect_backend() priority.""" from graphify import llm - monkeypatch.setattr(llm, "BACKENDS", { - **llm.BACKENDS, - "myprovider": { - "base_url": "http://example.com/v1", - "default_model": "mymodel", - "env_key": "MY_CUSTOM_KEY", - "pricing": {"input": 0.0, "output": 0.0}, - "temperature": 0, - } - }) + monkeypatch.setattr( + llm, + "BACKENDS", + { + **llm.BACKENDS, + "myprovider": { + "base_url": "http://example.com/v1", + "default_model": "mymodel", + "env_key": "MY_CUSTOM_KEY", + "pricing": {"input": 0.0, "output": 0.0}, + "temperature": 0, + }, + }, + ) monkeypatch.setenv("MY_CUSTOM_KEY", "test-key") - for key in ("GEMINI_API_KEY", "GOOGLE_API_KEY", "MOONSHOT_API_KEY", "ANTHROPIC_API_KEY", - "OPENAI_API_KEY", "DEEPSEEK_API_KEY", "OLLAMA_BASE_URL"): + for key in ( + "GEMINI_API_KEY", + "GOOGLE_API_KEY", + "MOONSHOT_API_KEY", + "ANTHROPIC_API_KEY", + "OPENAI_API_KEY", + "DEEPSEEK_API_KEY", + "OLLAMA_BASE_URL", + ): monkeypatch.delenv(key, raising=False) monkeypatch.delenv("AWS_PROFILE", raising=False) monkeypatch.delenv("AWS_REGION", raising=False) diff --git a/tests/test_prs.py b/tests/test_prs.py index 61ebc140e..1f7ef1fba 100644 --- a/tests/test_prs.py +++ b/tests/test_prs.py @@ -1,4 +1,5 @@ """Tests for graphify/prs.py.""" + from __future__ import annotations import subprocess @@ -6,7 +7,6 @@ from unittest.mock import patch, MagicMock import networkx as nx -import pytest from graphify.prs import ( PRInfo, @@ -23,6 +23,7 @@ # ── Helpers ─────────────────────────────────────────────────────────────────── + def make_pr( number: int = 1, title: str = "Test PR", @@ -54,6 +55,7 @@ def make_pr( # ── _classify ───────────────────────────────────────────────────────────────── + class TestClassify: def test_ready(self): pr = make_pr(ci_status="SUCCESS", review_decision="", is_draft=False) @@ -94,6 +96,7 @@ def test_wrong_base(self): # ── _parse_ci ───────────────────────────────────────────────────────────────── + class TestParseCi: def test_empty_rollup_returns_none(self): assert _parse_ci([]) == "NONE" @@ -128,6 +131,7 @@ def test_mixed_success_and_failure_is_failure(self): # ── _path_match ─────────────────────────────────────────────────────────────── + class TestPathMatch: def test_exact_match(self): assert _path_match("src/auth/api.py", "src/auth/api.py") is True @@ -150,6 +154,7 @@ def test_both_directions_work(self): # ── compute_pr_impact ───────────────────────────────────────────────────────── + class TestComputePrImpact: def _make_graph(self) -> nx.Graph: """3 nodes across 2 communities, 2 distinct source files.""" @@ -167,9 +172,7 @@ def test_matching_files_returns_correct_communities_and_count(self): def test_matching_both_files(self): G = self._make_graph() - comms, nodes = compute_pr_impact( - ["src/auth/api.py", "src/utils/helpers.py"], G - ) + comms, nodes = compute_pr_impact(["src/auth/api.py", "src/utils/helpers.py"], G) assert comms == [0, 1] assert nodes == 3 @@ -208,6 +211,7 @@ def test_no_double_counting_same_graph_file_matched_by_two_pr_files(self): # ── fetch_worktrees ─────────────────────────────────────────────────────────── + class TestFetchWorktrees: def test_normal_case_maps_branch_to_path(self): porcelain = ( @@ -279,6 +283,7 @@ def test_subprocess_failure_returns_empty_dict(self): # ── format_prs_text ─────────────────────────────────────────────────────────── + class TestFormatPrsText: def test_contains_pr_metadata_and_count_header(self): prs = [ @@ -330,6 +335,7 @@ def test_empty_pr_list(self): # ── _detect_default_branch ──────────────────────────────────────────────────── + class TestDetectDefaultBranch: def test_gh_returns_main(self): with patch( @@ -342,8 +348,9 @@ def test_falls_back_to_git_symbolic_ref(self): mock_result = MagicMock() mock_result.returncode = 0 mock_result.stdout = "refs/remotes/origin/develop\n" - with patch("graphify.prs._gh", return_value=None), patch( - "graphify.prs.subprocess.run", return_value=mock_result + with ( + patch("graphify.prs._gh", return_value=None), + patch("graphify.prs.subprocess.run", return_value=mock_result), ): assert _detect_default_branch() == "develop" @@ -351,8 +358,9 @@ def test_both_fail_returns_main(self): mock_result = MagicMock() mock_result.returncode = 1 mock_result.stdout = "" - with patch("graphify.prs._gh", return_value=None), patch( - "graphify.prs.subprocess.run", return_value=mock_result + with ( + patch("graphify.prs._gh", return_value=None), + patch("graphify.prs.subprocess.run", return_value=mock_result), ): assert _detect_default_branch() == "main" @@ -361,27 +369,32 @@ def test_gh_returns_empty_dict_falls_back(self): mock_result = MagicMock() mock_result.returncode = 0 mock_result.stdout = "refs/remotes/origin/trunk\n" - with patch("graphify.prs._gh", return_value={}), patch( - "graphify.prs.subprocess.run", return_value=mock_result + with ( + patch("graphify.prs._gh", return_value={}), + patch("graphify.prs.subprocess.run", return_value=mock_result), ): assert _detect_default_branch() == "trunk" def test_git_timeout_returns_main(self): - with patch("graphify.prs._gh", return_value=None), patch( - "graphify.prs.subprocess.run", - side_effect=subprocess.TimeoutExpired("git", 5), + with ( + patch("graphify.prs._gh", return_value=None), + patch( + "graphify.prs.subprocess.run", + side_effect=subprocess.TimeoutExpired("git", 5), + ), ): assert _detect_default_branch() == "main" # ── build_community_labels ───────────────────────────────────────────────────── + class TestBuildCommunityLabels: def test_basic_grouping(self): data = { "nodes": [ {"id": "a", "label": "Alpha", "community": 0}, - {"id": "b", "label": "Beta", "community": 0}, + {"id": "b", "label": "Beta", "community": 0}, {"id": "c", "label": "Gamma", "community": 1}, ] } diff --git a/tests/test_python_import_resolution.py b/tests/test_python_import_resolution.py index 2a517aaea..fb333eac4 100644 --- a/tests/test_python_import_resolution.py +++ b/tests/test_python_import_resolution.py @@ -23,9 +23,7 @@ def _node_id(result: dict, label: str, source_file: str) -> str: def _has_edge(result: dict, source: str, target: str, relation: str) -> bool: return any( - edge["source"] == source - and edge["target"] == target - and edge["relation"] == relation + edge["source"] == source and edge["target"] == target and edge["relation"] == relation for edge in result["edges"] ) @@ -35,9 +33,7 @@ def test_python_package_reexport_resolves_import_and_call_to_origin_symbol(tmp_p barrel = _write(tmp_path / "pkg/__init__.py", "from .foo import Foo as PublicFoo\n") consumer = _write( tmp_path / "app.py", - "from pkg import PublicFoo\n\n" - "def X():\n" - " return PublicFoo()\n", + "from pkg import PublicFoo\n\ndef X():\n return PublicFoo()\n", ) result = extract([origin, barrel, consumer], cache_root=tmp_path) @@ -57,10 +53,7 @@ def test_python_parameter_return_and_generic_contexts(tmp_path: Path): model = tmp_path / "pkg" / "model.py" model.parent.mkdir(parents=True) model.write_text( - "class Payload:\n" - " pass\n\n" - "class Result:\n" - " pass\n", + "class Payload:\n pass\n\nclass Result:\n pass\n", encoding="utf-8", ) service = tmp_path / "pkg" / "service.py" @@ -77,7 +70,11 @@ def test_python_parameter_return_and_generic_contexts(tmp_path: Path): labels = {node["id"]: node["label"] for node in result["nodes"]} edges = [edge for edge in result["edges"] if edge.get("relation") == "references"] pairs = { - (labels.get(e["source"], e["source"]), labels.get(e["target"], e["target"]), e.get("context")) + ( + labels.get(e["source"], e["source"]), + labels.get(e["target"], e["target"]), + e.get("context"), + ) for e in edges } diff --git a/tests/test_query_cli.py b/tests/test_query_cli.py index cf8eb6e56..ef3ebbe46 100644 --- a/tests/test_query_cli.py +++ b/tests/test_query_cli.py @@ -1,4 +1,5 @@ """Tests for graphify query CLI context filtering.""" + from __future__ import annotations import json diff --git a/tests/test_rationale.py b/tests/test_rationale.py index b52aa3909..8ca6e463d 100644 --- a/tests/test_rationale.py +++ b/tests/test_rationale.py @@ -1,7 +1,7 @@ """Tests for rationale/docstring extraction in extract.py.""" + import textwrap from pathlib import Path -import pytest from graphify.extract import extract_python from graphify.build import build_from_json @@ -13,10 +13,13 @@ def _write_py(tmp_path: Path, code: str) -> Path: def test_module_docstring_extracted(tmp_path): - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' """This module handles authentication because legacy sessions were insecure.""" def login(): pass - ''') + ''', + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] assert len(rationale) >= 1 @@ -24,45 +27,57 @@ def login(): pass def test_function_docstring_extracted(tmp_path): - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' def process(): """We use chunked processing here because the full dataset exceeds RAM.""" pass - ''') + ''', + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] assert any("chunked" in n["label"] for n in rationale) def test_class_docstring_extracted(tmp_path): - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' class Cache: """Chosen over Redis because we need zero external dependencies in the test env.""" pass - ''') + ''', + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] assert any("Redis" in n["label"] for n in rationale) def test_rationale_comment_extracted(tmp_path): - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + """ def build(): # NOTE: must run before compile() or linker will fail pass - ''') + """, + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] assert any("NOTE" in n["label"] for n in rationale) def test_rationale_for_edges_present(tmp_path): - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' """Module docstring explaining the why.""" def foo(): """Function docstring with rationale.""" pass - ''') + ''', + ) result = extract_python(path) rationale_edges = [e for e in result["edges"] if e.get("relation") == "rationale_for"] assert len(rationale_edges) >= 1 @@ -70,28 +85,36 @@ def foo(): def test_short_docstring_ignored(tmp_path): """Trivial docstrings under 20 chars should not become rationale nodes.""" - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' def foo(): """Constructor.""" pass - ''') + ''', + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] assert len(rationale) == 0 def test_rationale_confidence_is_extracted(tmp_path): - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' """This module exists because we needed a standalone parser.""" def parse(): pass - ''') + ''', + ) result = extract_python(path) rationale_edges = [e for e in result["edges"] if e.get("relation") == "rationale_for"] assert all(e.get("confidence") == "EXTRACTED" for e in rationale_edges) def test_alembic_module_docstring_suppressed(tmp_path): - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' """initial schema Revision ID: 0001abcd @@ -107,7 +130,8 @@ def upgrade(): def downgrade(): pass - ''') + ''', + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] assert not any("Revision ID" in n["label"] for n in rationale) @@ -115,7 +139,9 @@ def downgrade(): def test_alembic_function_docstrings_still_extracted(tmp_path): """Function docstrings inside upgrade/downgrade should still be captured.""" - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' """Revision ID: 0002 Revises: 0001""" revision = "0002" down_revision = "0001" @@ -126,7 +152,8 @@ def upgrade(): def downgrade(): pass - ''') + ''', + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] # module docstring suppressed @@ -137,39 +164,48 @@ def downgrade(): def test_non_migration_revision_var_not_suppressed(tmp_path): """A file with a `revision` variable but no Alembic markers keeps its docstring.""" - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' """This module tracks document revisions because we need audit history.""" revision = 42 def get_revision(): pass - ''') + ''', + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] assert any("audit history" in n["label"] for n in rationale) def test_django_migration_module_docstring_suppressed(tmp_path): - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' """Add post_priority_config table.""" from django.db import migrations class Migration(migrations.Migration): dependencies = [("myapp", "0001_initial")] operations = [] - ''') + ''', + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] assert not any("post_priority" in n["label"] for n in rationale) def test_generated_file_module_docstring_suppressed(tmp_path): - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' """Generated by the protocol buffer compiler. DO NOT EDIT!""" from google.protobuf import descriptor as _descriptor class UserMessage: pass - ''') + ''', + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] assert not any("protocol buffer" in n["label"].lower() for n in rationale) @@ -182,7 +218,9 @@ def test_decorated_method_node_id_is_class_qualified(tmp_path): docstring's edge target. The mismatch caused ``build_from_json`` to drop the rationale_for edge as dangling, orphaning the docstring node. """ - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' class Bar: @property def baz(self) -> int: @@ -202,13 +240,13 @@ def factory(cls) -> "Bar": def normal(self) -> int: """A normal instance method documented for comparison.""" return 3 - ''') + ''', + ) result = extract_python(path) nodes_by_id = {n["id"]: n for n in result["nodes"]} # The plain method's id is the baseline: stem + class + name. - normal_ids = [nid for nid, n in nodes_by_id.items() - if n.get("label") == ".normal()"] + normal_ids = [nid for nid, n in nodes_by_id.items() if n.get("label") == ".normal()"] assert len(normal_ids) == 1, "expected exactly one ``.normal()`` method node" normal_id = normal_ids[0] assert normal_id.endswith("_bar_normal"), normal_id @@ -216,16 +254,16 @@ def normal(self) -> int: # Each decorated method must share the same class-qualified id shape so the # rationale_for edge target matches the method node id. for decorated_name in ("baz", "helper", "factory"): - matches = [nid for nid, n in nodes_by_id.items() - if n.get("label") == f".{decorated_name}()"] + matches = [ + nid for nid, n in nodes_by_id.items() if n.get("label") == f".{decorated_name}()" + ] assert len(matches) == 1, ( f"expected exactly one ``.{decorated_name}()`` method node, got {matches}" ) method_id = matches[0] assert method_id.endswith(f"_bar_{decorated_name}"), method_id # Unqualified id (the buggy form) must NOT also be present. - unqualified_buggy_id = method_id.replace(f"_bar_{decorated_name}", - f"_{decorated_name}") + unqualified_buggy_id = method_id.replace(f"_bar_{decorated_name}", f"_{decorated_name}") assert unqualified_buggy_id not in nodes_by_id, ( f"buggy unqualified id {unqualified_buggy_id} should not exist alongside " f"the class-qualified id" @@ -245,16 +283,11 @@ def normal(self) -> int: g = build_from_json(result) for decorated_name in ("baz", "helper", "factory", "normal"): method_id = next( - nid for nid, n in nodes_by_id.items() - if n.get("label") == f".{decorated_name}()" + nid for nid, n in nodes_by_id.items() if n.get("label") == f".{decorated_name}()" ) # Find rationale node attached to this method. - attached_rationale = [ - e["source"] for e in rationale_edges if e["target"] == method_id - ] - assert attached_rationale, ( - f"no rationale_for edge found for ``.{decorated_name}()`` method" - ) + attached_rationale = [e["source"] for e in rationale_edges if e["target"] == method_id] + assert attached_rationale, f"no rationale_for edge found for ``.{decorated_name}()`` method" for r_id in attached_rationale: assert r_id in g.nodes, f"rationale node {r_id} missing from graph" assert g.degree(r_id) > 0, ( diff --git a/tests/test_report.py b/tests/test_report.py index a5b3916a1..20265259c 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -1,17 +1,18 @@ import json from pathlib import Path +import networkx as nx from graphify.build import build_from_json -from graphify.cluster import cluster, score_all from graphify.analyze import god_nodes, surprising_connections from graphify.report import generate FIXTURES = Path(__file__).parent / "fixtures" + def make_inputs(): extraction = json.loads((FIXTURES / "extraction.json").read_text()) G = build_from_json(extraction) - communities = cluster(G) - cohesion = score_all(G, communities) + communities = {0: list(G.nodes())} + cohesion = {0: 0.5} labels = {cid: f"Community {cid}" for cid in communities} gods = god_nodes(G) surprises = surprising_connections(G) @@ -19,45 +20,228 @@ def make_inputs(): tokens = {"input": extraction["input_tokens"], "output": extraction["output_tokens"]} return G, communities, cohesion, labels, gods, surprises, detection, tokens + def test_report_contains_header(): G, communities, cohesion, labels, gods, surprises, detection, tokens = make_inputs() - report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project") + report = generate( + G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project" + ) assert "# Graph Report" in report + def test_report_contains_corpus_check(): G, communities, cohesion, labels, gods, surprises, detection, tokens = make_inputs() - report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project") + report = generate( + G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project" + ) assert "## Corpus Check" in report + def test_report_contains_god_nodes(): G, communities, cohesion, labels, gods, surprises, detection, tokens = make_inputs() - report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project") + report = generate( + G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project" + ) assert "## God Nodes" in report + def test_report_contains_surprising_connections(): G, communities, cohesion, labels, gods, surprises, detection, tokens = make_inputs() - report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project") + report = generate( + G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project" + ) assert "## Surprising Connections" in report + def test_report_contains_communities(): G, communities, cohesion, labels, gods, surprises, detection, tokens = make_inputs() - report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project") + report = generate( + G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project" + ) assert "## Communities" in report + def test_report_contains_ambiguous_section(): G, communities, cohesion, labels, gods, surprises, detection, tokens = make_inputs() - report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project") + report = generate( + G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project" + ) assert "## Ambiguous Edges" in report + def test_report_shows_token_cost(): G, communities, cohesion, labels, gods, surprises, detection, tokens = make_inputs() - report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project") + report = generate( + G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project" + ) assert "Token cost" in report assert "1,200" in report + def test_report_shows_raw_cohesion_scores(): G, communities, cohesion, labels, gods, surprises, detection, tokens = make_inputs() - report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project", min_community_size=1) + report = generate( + G, + communities, + cohesion, + labels, + gods, + surprises, + detection, + tokens, + "./project", + min_community_size=1, + ) assert "Cohesion:" in report assert "✓" not in report assert "⚠" not in report + + +# --- Helpers for edge-count tests --- + + +def _minimal_report(G): + """Generate a report from a graph with minimal scaffolding.""" + communities = {0: list(G.nodes())} + cohesion = {0: 0.5} + labels = {0: "Test Community"} + god_list = [{"id": n, "label": n, "degree": G.degree(n)} for n in list(G.nodes())[:3]] + surprise_list = [] + detection = {"total_files": 1, "total_words": 100, "needs_graph": True, "warning": None} + tokens = {"input": 100, "output": 50} + return generate( + G, + communities, + cohesion, + labels, + god_list, + surprise_list, + detection, + tokens, + "./test", + min_community_size=1, + ) + + +# --- PR 4B: Edge count reporting tests --- + + +def test_report_multigraph_edge_count_distinguishes_pairs(): + """MultiDiGraph with parallel edges: report must show both total and unique pair count.""" + G = nx.MultiDiGraph() + G.add_nodes_from(["A", "B", "C", "D"], label="x", type="entity") + # 3 unique pairs, 8 total edges + G.add_edge("A", "B", relation="calls", confidence="EXTRACTED") + G.add_edge("A", "B", relation="imports", confidence="EXTRACTED") + G.add_edge("A", "B", relation="uses", confidence="EXTRACTED") + G.add_edge("B", "C", relation="calls", confidence="EXTRACTED") + G.add_edge("B", "C", relation="imports", confidence="EXTRACTED") + G.add_edge("B", "C", relation="uses", confidence="EXTRACTED") + G.add_edge("C", "D", relation="calls", confidence="EXTRACTED") + G.add_edge("C", "D", relation="imports", confidence="EXTRACTED") + assert G.number_of_edges() == 8 + report = _minimal_report(G) + assert "8 edges (3 unique pairs)" in report + + +def test_report_simple_graph_edge_count_unchanged(): + """Simple DiGraph: report must show just 'X edges' without unique-pairs qualifier.""" + G = nx.DiGraph() + G.add_nodes_from(["A", "B", "C"], label="x", type="entity") + G.add_edge("A", "B", relation="calls", confidence="EXTRACTED") + G.add_edge("B", "C", relation="calls", confidence="EXTRACTED") + G.add_edge("A", "C", relation="calls", confidence="EXTRACTED") + report = _minimal_report(G) + assert "3 edges" in report + assert "unique pairs" not in report + + +def test_report_multigraph_no_parallel_just_shows_total(): + """MultiDiGraph with no actual parallel edges: show just 'X edges', no redundant qualifier.""" + G = nx.MultiDiGraph() + G.add_nodes_from(["A", "B", "C"], label="x", type="entity") + G.add_edge("A", "B", relation="calls", confidence="EXTRACTED") + G.add_edge("B", "C", relation="calls", confidence="EXTRACTED") + G.add_edge("A", "C", relation="calls", confidence="EXTRACTED") + assert G.number_of_edges() == 3 + report = _minimal_report(G) + assert "3 edges" in report + assert "unique pairs" not in report + + +def test_report_god_node_degree_not_inflated(): + """God-node degree should reflect unique neighbors, not parallel edge count. + + analyze.god_nodes() already uses distinct_neighbor_degree(), so the degree + value in the report should equal the neighbor count, not the edge count. + """ + G = nx.MultiDiGraph() + # Nodes need source_file with an extension to avoid being filtered as concept nodes + attrs = {"label": "hub", "type": "entity", "source_file": "test.py"} + G.add_node("hub", **attrs) + for name in ["A", "B", "C"]: + G.add_node(name, label=name, type="entity", source_file="test.py") + # hub -> A: 4 parallel edges, hub -> B: 3, hub -> C: 3 = 10 total, 3 unique neighbors + for _ in range(4): + G.add_edge("hub", "A", relation="calls", confidence="EXTRACTED") + for _ in range(3): + G.add_edge("hub", "B", relation="calls", confidence="EXTRACTED") + for _ in range(3): + G.add_edge("hub", "C", relation="calls", confidence="EXTRACTED") + assert G.number_of_edges() == 10 + gods = god_nodes(G) + hub_entry = next(g for g in gods if g["label"] == "hub") + assert hub_entry["degree"] == 3, f"Expected 3 unique neighbors, got {hub_entry['degree']}" + + +# --- PR 6: parallel-edge preservation in per-edge report surfaces --- + + +def test_report_preserves_parallel_inferred_edges(): + """MultiDiGraph with parallel edges between one pair: every per-edge report + surface must preserve ALL parallel edges, not collapse to one (PR 6). + + report.py iterates ``G.edges(data=True)`` for confidence stats, the INFERRED + count/avg, and the ambiguous-edges list — on a MultiDiGraph that yields every + parallel edge. This confirms the no-collapse contract: + - INFERRED count reflects all 3 parallel inferred edges (not 1) + - INFERRED avg confidence averages all 3 distinct scores (0.5/0.7/0.9 -> 0.7) + - the ambiguous-edges section emits one line per parallel ambiguous edge + """ + G = nx.MultiDiGraph() + G.add_node("A", label="alpha", type="entity", source_file="a.py") + G.add_node("B", label="beta", type="entity", source_file="b.py") + # 3 parallel INFERRED edges between the SAME pair, distinct scores. + G.add_edge("A", "B", relation="calls", confidence="INFERRED", confidence_score=0.5) + G.add_edge("A", "B", relation="imports", confidence="INFERRED", confidence_score=0.7) + G.add_edge("A", "B", relation="contains", confidence="INFERRED", confidence_score=0.9) + # 2 parallel AMBIGUOUS edges between the SAME pair, distinct relations. + G.add_edge("A", "B", relation="maybe_uses", confidence="AMBIGUOUS") + G.add_edge("A", "B", relation="maybe_refs", confidence="AMBIGUOUS") + assert G.number_of_edges() == 5 + report = _minimal_report(G) + # All 3 parallel inferred edges are counted (collapse would show "1 edges"). + assert "INFERRED: 3 edges" in report + # Average over all 3 parallel scores, proving each edge contributed. + assert "avg confidence: 0.7" in report + # The ambiguous-edges list preserves one line per parallel ambiguous edge. + assert "relation: maybe_uses" in report + assert "relation: maybe_refs" in report + + +def test_report_multigraph_edge_count_unchanged_semantics(): + """PR4B total-vs-unique-pairs edge-count line still renders correctly on a + multigraph with parallel edges (regression for PR 6).""" + G = nx.MultiDiGraph() + G.add_node("A", label="alpha", type="entity", source_file="a.py") + G.add_node("B", label="beta", type="entity", source_file="b.py") + G.add_node("C", label="gamma", type="entity", source_file="c.py") + # 2 unique pairs, 5 total edges (3 parallel A->B, 2 parallel B->C). + G.add_edge("A", "B", relation="calls", confidence="EXTRACTED") + G.add_edge("A", "B", relation="imports", confidence="EXTRACTED") + G.add_edge("A", "B", relation="contains", confidence="EXTRACTED") + G.add_edge("B", "C", relation="calls", confidence="EXTRACTED") + G.add_edge("B", "C", relation="imports", confidence="EXTRACTED") + assert G.number_of_edges() == 5 + report = _minimal_report(G) + assert "5 edges (2 unique pairs)" in report diff --git a/tests/test_security.py b/tests/test_security.py index c547ab842..32e160865 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -1,7 +1,7 @@ """Tests for graphify/security.py - URL validation, safe fetch, path guards, label sanitisation.""" + from __future__ import annotations -import json import urllib.error from pathlib import Path from typing import Any @@ -17,9 +17,7 @@ safe_fetch_text, validate_graph_path, validate_url, - _MAX_FETCH_BYTES, _MAX_GRAPH_FILE_BYTES, - _MAX_TEXT_BYTES, _METADATA_MAX_LIST_ITEMS, _METADATA_MAX_VALUE_LEN, _sanitize_metadata_string, @@ -31,24 +29,30 @@ # validate_url # --------------------------------------------------------------------------- + def test_validate_url_accepts_http(): assert validate_url("http://example.com/page") == "http://example.com/page" + def test_validate_url_accepts_https(): assert validate_url("https://arxiv.org/abs/1706.03762") == "https://arxiv.org/abs/1706.03762" + def test_validate_url_rejects_file(): with pytest.raises(ValueError, match="file"): validate_url("file:///etc/passwd") + def test_validate_url_rejects_ftp(): with pytest.raises(ValueError, match="ftp"): validate_url("ftp://files.example.com/data.zip") + def test_validate_url_rejects_data(): with pytest.raises(ValueError, match="data"): validate_url("data:text/html,") + def test_validate_url_rejects_empty_scheme(): with pytest.raises(ValueError): validate_url("//no-scheme.example.com") @@ -58,13 +62,14 @@ def test_validate_url_rejects_empty_scheme(): # safe_fetch - scheme and redirect guards (mocked network) # --------------------------------------------------------------------------- + def _make_mock_response(content: bytes, status: int = 200): mock = MagicMock() mock.__enter__ = lambda s: s mock.__exit__ = MagicMock(return_value=False) mock.status = status mock.code = status - chunks = [content[i:i+65536] for i in range(0, len(content), 65536)] + [b""] + chunks = [content[i : i + 65536] for i in range(0, len(content), 65536)] + [b""] mock.read.side_effect = chunks return mock @@ -73,10 +78,12 @@ def test_safe_fetch_rejects_file_url(): with pytest.raises(ValueError, match="file"): safe_fetch("file:///etc/passwd") + def test_safe_fetch_rejects_ftp_url(): with pytest.raises(ValueError, match="ftp"): safe_fetch("ftp://example.com/file.zip") + def test_safe_fetch_returns_bytes(tmp_path): mock_resp = _make_mock_response(b"hello world") with patch("graphify.security._build_opener") as mock_opener_fn: @@ -86,6 +93,7 @@ def test_safe_fetch_returns_bytes(tmp_path): result = safe_fetch("https://example.com/") assert result == b"hello world" + def test_safe_fetch_raises_on_non_2xx(): mock_resp = _make_mock_response(b"Not Found", status=404) with patch("graphify.security._build_opener") as mock_opener_fn: @@ -95,6 +103,7 @@ def test_safe_fetch_raises_on_non_2xx(): with pytest.raises(urllib.error.HTTPError): safe_fetch("https://example.com/missing") + def test_safe_fetch_raises_on_size_exceeded(): # Build a response larger than max_bytes big_chunk = b"x" * 65_537 @@ -118,6 +127,7 @@ def test_safe_fetch_raises_on_size_exceeded(): # safe_fetch_text # --------------------------------------------------------------------------- + def test_safe_fetch_text_decodes_utf8(): content = "héllo wörld".encode("utf-8") mock_resp = _make_mock_response(content) @@ -128,6 +138,7 @@ def test_safe_fetch_text_decodes_utf8(): result = safe_fetch_text("https://example.com/") assert result == "héllo wörld" + def test_safe_fetch_text_replaces_bad_bytes(): bad = b"hello \xff world" mock_resp = _make_mock_response(bad) @@ -145,6 +156,7 @@ def test_safe_fetch_text_replaces_bad_bytes(): # validate_graph_path # --------------------------------------------------------------------------- + def test_validate_graph_path_allows_inside_base(tmp_path): base = tmp_path / "graphify-out" base.mkdir() @@ -153,6 +165,7 @@ def test_validate_graph_path_allows_inside_base(tmp_path): result = validate_graph_path(str(graph), base=base) assert result == graph.resolve() + def test_validate_graph_path_blocks_traversal(tmp_path): base = tmp_path / "graphify-out" base.mkdir() @@ -160,11 +173,13 @@ def test_validate_graph_path_blocks_traversal(tmp_path): with pytest.raises(ValueError, match="escapes"): validate_graph_path(str(evil), base=base) + def test_validate_graph_path_requires_base_exists(tmp_path): base = tmp_path / "graphify-out" # not created with pytest.raises(ValueError, match="does not exist"): validate_graph_path(str(base / "graph.json"), base=base) + def test_validate_graph_path_raises_if_file_missing(tmp_path): base = tmp_path / "graphify-out" base.mkdir() @@ -176,22 +191,26 @@ def test_validate_graph_path_raises_if_file_missing(tmp_path): # sanitize_label # --------------------------------------------------------------------------- + def test_sanitize_label_passthrough_html_chars(): # sanitize_label does NOT HTML-escape — callers that inject into HTML must # wrap with html.escape() themselves (e.g. the title in to_html()) assert sanitize_label("