From 13bfa12e50121170289f99bd934188d89e4ba674 Mon Sep 17 00:00:00 2001 From: hypnwtykvmpr Date: Fri, 22 May 2026 13:21:09 -0500 Subject: [PATCH 01/21] feat(multigraph): projections + schema-aware loader + stable edge identity + internal keyed build path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Internal, opt-in MultiDiGraph foundation. No default user-visible behavior changes. Part 1: Projection API (graphify/projections.py) Explicit projection helpers so consumers can declare what graph semantics they want and what they intentionally lose: - project_for_community (configurable weight mode: confidence or count) - project_for_path (simple undirected topology for shortest-path) - project_for_callflow (directed projection with optional relation filter) - project_for_context (filter by context value) - edge_records_between (raw edge record iteration) - edge_summary_between (bundle/summary formatting) - distinct_neighbor_degree (for god-node and hub thresholds) - normalize_to_multidigraph (simple-to-multidigraph lift) Tests in tests/test_projections.py cover MultiDiGraph fixtures plus property-style invariants (bundle counts equal total edge records; weighted projection weight equals multiplicity). Part 2: Schema-aware loader + stable edge identity Centralizes graph loading and reserves the edge "key" field as schema, not as an ordinary attribute. - graphify/edge_identity.py: SCHEMA_KEY_FIELD constant, make_stable_key(relation, source_file, source_location), strip_schema_key(attrs) - graphify/graph_loader.py: load_graph and load_graph_file handling legacy simple JSON with "links", legacy simple JSON with "edges", valid multigraph node-link JSON (multigraph: true) with keyed parallel edges, malformed multigraph JSON with missing/non-string keys (deterministic repair via full-attribute payload hash, never silent downgrade), conflicting schema markers. - Profile metadata stored in G.graph["graphify_profile"] (preserved by node-link serialization). - Multigraph loads gated behind require_multigraph_capabilities() from PR #956. Tests in tests/test_graph_loader.py and tests/test_edge_identity.py cover all seven loader scenarios plus the schema-key reservation contract. Part 3: Internal keyed MultiDiGraph build path Opt-in nx.MultiDiGraph build support in build_from_json and build. - multigraph: bool = False parameter - Stable edge keys generated after dedup/remap and source normalization. - Serialized edge attrs cannot pass duplicate key= kwargs into G.add_edge. - Exact duplicates collapse only with diagnostics; non-exact key collisions fire deterministic bounded repair (full-payload hash, not identity-field-only). - Node-link JSON written with explicit edges="links" compatibility. - Default simple-graph output is unchanged. Adversarial-input resilience (verified against malformed extraction inputs): - Hashable non-string node IDs and edge endpoints are preserved. - Unhashable node IDs and endpoints do not crash validation or build. - Non-dict node entries and nodes missing "id" are skipped safely after validation warns. - Non-dict edge entries are skipped safely after validation warns. - Explicit empty-string schema keys are preserved. - Collision-repair keys are deterministic and do not overwrite explicit keys. - Exact duplicate detection remains O(n) within a (source, target, key) group. Out of scope: - No public --multigraph CLI flag (planned for a later slice; only programmatic activation here). - No watch/cache/global-graph/MCP/export surface changes. - No producer widening. - No dedup/remap MultiDiGraph contract changes (separate concern, separate review). Test coverage: pytest tests/test_projections.py tests/test_graph_loader.py tests/test_edge_identity.py \ tests/test_multigraph_diagnostics.py tests/test_build.py tests/test_validate.py → 130 passed. This is a collapse of an earlier 12-commit stack on wave3-pr3-internal-build into a single commit so that every commit in origin history passes Copilot review individually. The pre-collapse stack is preserved as the tag archive/2026-05-22-wave3-pr3-internal-build. --- graphify/build.py | 428 ++++++++++++-- graphify/edge_identity.py | 58 ++ graphify/graph_loader.py | 301 ++++++++++ graphify/projections.py | 214 +++++++ graphify/symbol_resolution.py | 4 +- graphify/validate.py | 44 +- tests/test_build.py | 838 ++++++++++++++++++++++++++- tests/test_edge_identity.py | 85 +++ tests/test_graph_loader.py | 557 ++++++++++++++++++ tests/test_multigraph_diagnostics.py | 8 +- tests/test_projections.py | 202 +++++++ tests/test_validate.py | 95 ++- 12 files changed, 2727 insertions(+), 107 deletions(-) create mode 100644 graphify/edge_identity.py create mode 100644 graphify/graph_loader.py create mode 100644 graphify/projections.py create mode 100644 tests/test_edge_identity.py create mode 100644 tests/test_graph_loader.py create mode 100644 tests/test_projections.py diff --git a/graphify/build.py b/graphify/build.py index 07fbb0340..386560005 100644 --- a/graphify/build.py +++ b/graphify/build.py @@ -22,18 +22,32 @@ # from __future__ import annotations import json +import hashlib import os import re import sys import unicodedata +from collections.abc import Hashable from pathlib import Path import networkx as nx -from .validate import validate_extraction +from .edge_identity import make_stable_key, strip_schema_key +from .validate import is_hashable, validate_extraction # Synonym mapper for known invalid file_type values that LLM subagents commonly # emit. Keeps semantic intent close (markdown→document, tool→code) and falls # back to "concept" for any other invalid value (see #840). +_LANG_FAMILY: dict[str, str] = { + ".py": "py", ".pyi": "py", + ".js": "js", ".mjs": "js", ".cjs": "js", ".jsx": "js", + ".ts": "js", ".tsx": "js", + ".go": "go", ".rs": "rs", + ".java": "jvm", ".kt": "jvm", ".scala": "jvm", ".groovy": "jvm", + ".c": "c", ".h": "c", ".cc": "cpp", ".cpp": "cpp", ".hpp": "cpp", + ".rb": "rb", ".php": "php", ".cs": "cs", ".swift": "swift", ".lua": "lua", +} + + _FILE_TYPE_SYNONYMS = { "markdown": "document", "text": "document", @@ -83,7 +97,56 @@ def _norm_source_file(p: str | None, root: str | None = None) -> str | None: return p -def edge_data(G: nx.Graph, u: str, v: str) -> dict: +def _stable_identity_component(value: object) -> str | None: + """Normalize malformed edge identity values before stable-key hashing.""" + if value is None: + return None + if isinstance(value, str): + return value + if isinstance(value, os.PathLike): + # os.fspath can return bytes for bytes-flavored PathLike; coerce to str + # so downstream json.dumps / hashing always sees text. + fs_value = os.fspath(value) + return fs_value.decode("utf-8", errors="replace") if isinstance(fs_value, bytes) else fs_value + if isinstance(value, (set, frozenset)): + return json.dumps(sorted(str(item) for item in value), ensure_ascii=False) + try: + return json.dumps(value, sort_keys=True, ensure_ascii=False, default=str) + except (TypeError, ValueError): + return str(value) + + +def _make_collision_key(base_key: str, attrs: dict, *, salt: int = 0) -> str: + payload = { + "base_key": base_key, + "attrs": attrs, + } + if salt: + payload["salt"] = salt + repair_payload = json.dumps(payload, sort_keys=True, ensure_ascii=False, default=str) + repair_digest = hashlib.sha256(repair_payload.encode()).hexdigest() + return f"{base_key}:alt:{repair_digest}" + + +def _list_field(data: dict, key: str) -> list: + """Return ``data[key]`` if it is a list; otherwise warn to stderr and return ``[]``. + + Extraction dicts come from LLM subagents and can contain malformed shapes; + matching the rest of build_from_json's skip+warn policy keeps a single bad + field from crashing the whole build. + """ + value = data.get(key, []) + if isinstance(value, list): + return value + print( + f"[graphify] WARNING: extraction field '{key}' must be a list, " + f"got {type(value).__name__}; treating as empty.", + file=sys.stderr, + ) + return [] + + +def edge_data(G: nx.Graph, u: Hashable, v: Hashable) -> dict: """Return one edge attribute dict for (u, v), tolerating MultiGraph. For MultiGraph/MultiDiGraph there can be multiple parallel edges; @@ -96,7 +159,7 @@ def edge_data(G: nx.Graph, u: str, v: str) -> dict: return raw -def edge_datas(G: nx.Graph, u: str, v: str) -> list[dict]: +def edge_datas(G: nx.Graph, u: Hashable, v: Hashable) -> list[dict]: """Return every edge attribute dict for (u, v); always a list.""" raw = G[u][v] if isinstance(G, (nx.MultiGraph, nx.MultiDiGraph)): @@ -104,29 +167,47 @@ def edge_datas(G: nx.Graph, u: str, v: str) -> list[dict]: return [raw] -def build_from_json(extraction: dict, *, directed: bool = False, root: str | Path | None = None) -> nx.Graph: +def build_from_json( + extraction: dict, + *, + directed: bool = False, + root: str | Path | None = None, + multigraph: bool = False, +) -> nx.Graph | nx.DiGraph | nx.MultiDiGraph: """Build a NetworkX graph from an extraction dict. directed=True produces a DiGraph that preserves edge direction (source→target). directed=False (default) produces an undirected Graph for backward compatibility. + multigraph=True produces a directed MultiDiGraph with keyed parallel edges for + internal tests/callers; public CLI exposure is intentionally deferred. + In this mode, directed is ignored because MultiDiGraph is always directed. root: if given, absolute source_file paths from semantic subagents are made relative to root so all nodes share a consistent path key (#932). """ + if not isinstance(extraction, dict): + raise TypeError("extraction must be a JSON object") + _root = str(Path(root).resolve()) if root else None # NetworkX <= 3.1 serialised edges as "links"; remap to "edges" for compatibility. if "edges" not in extraction and "links" in extraction: extraction = dict(extraction, edges=extraction["links"]) + nodes = _list_field(extraction, "nodes") + edges = _list_field(extraction, "edges") + extraction = dict(extraction, nodes=nodes, edges=edges) + # Canonicalize legacy node/edge schema before validation. - for node in extraction.get("nodes", []): + for node in nodes: if not isinstance(node, dict): continue if "source" in node and "source_file" not in node: # Count edges that reference this node so the warning is actionable (#479) node_id = node.get("id", "?") affected_edges = sum( - 1 for e in extraction.get("edges", []) - if e.get("source") == node_id or e.get("target") == node_id + 1 + for e in edges + if isinstance(e, dict) + and (e.get("source") == node_id or e.get("target") == node_id) ) print( f"[graphify] WARNING: node '{node_id}' uses field 'source' instead of " @@ -149,29 +230,78 @@ def build_from_json(extraction: dict, *, directed: bool = False, root: str | Pat # Dangling edges (stdlib/external imports) are expected - only warn about real schema errors. real_errors = [e for e in errors if "does not match any node id" not in e] if real_errors: - print(f"[graphify] Extraction warning ({len(real_errors)} issues): {real_errors[0]}", file=sys.stderr) - G: nx.Graph = nx.DiGraph() if directed else nx.Graph() - for node in extraction.get("nodes", []): + print( + f"[graphify] Extraction warning ({len(real_errors)} issues): {real_errors[0]}", + file=sys.stderr, + ) + if multigraph: + from .multigraph_compat import require_multigraph_capabilities + + require_multigraph_capabilities() + G: nx.Graph = nx.MultiDiGraph() if multigraph else nx.DiGraph() if directed else nx.Graph() + for node in nodes: + if not isinstance(node, dict) or "id" not in node: + continue + node_id = node["id"] + if not is_hashable(node_id): + continue if "source_file" in node: - node["source_file"] = _norm_source_file(node["source_file"], _root) - G.add_node(node["id"], **{k: v for k, v in node.items() if k != "id"}) + node["source_file"] = _norm_source_file( + _stable_identity_component(node["source_file"]), _root + ) + node_attrs = {k: v for k, v in node.items() if k != "id"} + # Reject node ids that JSON-serialize but won't round-trip to the same + # hashable type. Tuples serialize as JSON arrays and come back as lists + # (unhashable), so they cannot be used as NetworkX node ids after a + # save/load cycle even though json.dumps would accept them. + if isinstance(node_id, (list, tuple, set, frozenset, dict)): + print( + f"[graphify] WARNING: node id {node_id!r} ({type(node_id).__name__}) " + f"would not round-trip through JSON as the same hashable type; skipping.", + file=sys.stderr, + ) + continue + # Check id AND attrs are JSON-serializable. NetworkX allows hashable but + # non-JSON-safe ids (e.g., custom objects); accepting them here would + # break later node_link_data + json.dump. + try: + json.dumps({"id": node_id, **node_attrs}, ensure_ascii=False) + except (TypeError, ValueError) as exc: + print( + f"[graphify] WARNING: node {node_id!r} has non-JSON-serializable " + f"id or attrs ({exc}); skipping.", + file=sys.stderr, + ) + continue + G.add_node(node_id, **node_attrs) node_set = set(G.nodes()) # Normalized ID map: lets edges survive when the LLM generates IDs with # slightly different casing or punctuation than the AST extractor. # e.g. "Session_ValidateToken" maps to "session_validatetoken". - norm_to_id: dict[str, str] = {_normalize_id(nid): nid for nid in node_set} + norm_to_id: dict[str, Hashable] = { + _normalize_id(nid): nid for nid in node_set if isinstance(nid, str) + } + multigraph_groups: dict[tuple[Hashable, Hashable, str], list[dict]] = {} + multigraph_explicit_keys: set[tuple[Hashable, Hashable, str]] = set() + multigraph_diagnostics = {"exact_duplicate_edges": 0, "key_collision_edges": 0} # Iterate edges in a deterministic order. The graph is undirected and stores # direction in _src/_tgt; when two edges collapse onto the same node pair the # last write wins, so an unstable iteration order flips _src/_tgt run-to-run - # and makes the serialized graph churn. Sorting fixes the last-write outcome. - for edge in sorted( - extraction.get("edges", []), - key=lambda e: ( - str(e.get("source", e.get("from", ""))), - str(e.get("target", e.get("to", ""))), - str(e.get("relation", "")), - ), - ): + # and makes the serialized graph churn. Sorting also stabilizes multigraph + # key-collision grouping before keyed emission. + def _edge_sort_key(edge: object) -> tuple[str, str, str, str]: + if not isinstance(edge, dict): + return ("", "", "", repr(edge)) + return ( + str(edge.get("source", edge.get("from", ""))), + str(edge.get("target", edge.get("to", ""))), + str(edge.get("relation", "")), + json.dumps(edge, sort_keys=True, ensure_ascii=False, default=str), + ) + + for edge in sorted(edges, key=_edge_sort_key): + if not isinstance(edge, dict): + continue if "source" not in edge and "from" in edge: edge["source"] = edge["from"] if "target" not in edge and "to" in edge: @@ -179,29 +309,38 @@ def build_from_json(extraction: dict, *, directed: bool = False, root: str | Pat if "source" not in edge or "target" not in edge: continue src, tgt = edge["source"], edge["target"] + srcis_hashable = is_hashable(src) + tgtis_hashable = is_hashable(tgt) + if not srcis_hashable or not tgtis_hashable: + endpoint = "source" if not srcis_hashable else "target" + endpoint_value = src if not srcis_hashable else tgt + print( + "[graphify] WARNING: skipped edge with unhashable " + f"{endpoint} endpoint ({type(endpoint_value).__name__})", + file=sys.stderr, + ) + continue # Remap mismatched IDs via normalization before dropping the edge. - if src not in node_set: + if isinstance(src, str) and src not in node_set: src = norm_to_id.get(_normalize_id(src), src) - if tgt not in node_set: + if isinstance(tgt, str) and tgt not in node_set: tgt = norm_to_id.get(_normalize_id(tgt), tgt) if src not in node_set or tgt not in node_set: continue # skip edges to external/stdlib nodes - expected, not an error - attrs = {k: v for k, v in edge.items() if k not in ("source", "target")} + # Exclude legacy from/to alongside source/target so they don't survive + # as ordinary edge attrs after legacy-shape remap above. + base_attrs = { + k: v for k, v in edge.items() if k not in ("source", "target", "from", "to") + } + raw_key, attrs = strip_schema_key(base_attrs) if "source_file" in attrs: - attrs["source_file"] = _norm_source_file(attrs["source_file"], _root) + attrs["source_file"] = _norm_source_file( + _stable_identity_component(attrs["source_file"]), _root + ) # Drop cross-language INFERRED `calls` edges — same short names (render, # parse, etc.) appear across language boundaries in multi-language chunks, # producing phantom edges that don't represent real call relationships. if attrs.get("relation") == "calls" and attrs.get("confidence") == "INFERRED": - _LANG_FAMILY: dict[str, str] = { - ".py": "py", ".pyi": "py", - ".js": "js", ".mjs": "js", ".cjs": "js", ".jsx": "js", - ".ts": "js", ".tsx": "js", - ".go": "go", ".rs": "rs", - ".java": "jvm", ".kt": "jvm", ".scala": "jvm", ".groovy": "jvm", - ".c": "c", ".h": "c", ".cc": "cpp", ".cpp": "cpp", ".hpp": "cpp", - ".rb": "rb", ".php": "php", ".cs": "cs", ".swift": "swift", ".lua": "lua", - } src_ext = Path(G.nodes[src].get("source_file") or "").suffix.lower() tgt_ext = Path(G.nodes[tgt].get("source_file") or "").suffix.lower() if src_ext and tgt_ext and _LANG_FAMILY.get(src_ext) != _LANG_FAMILY.get(tgt_ext): @@ -210,23 +349,131 @@ def build_from_json(extraction: dict, *, directed: bool = False, root: str | Pat # causing display functions to show edges backwards. attrs["_src"] = src attrs["_tgt"] = tgt - # When the graph is undirected and the same node pair appears twice with - # the same relation but opposite directions (e.g. a `calls` b and b `calls` a), - # nx.Graph collapses them into one edge. The deterministic sort above means - # the lexicographically-later direction would systematically overwrite the - # earlier one's _src/_tgt, silently flipping the surviving edge's caller - # and callee. First-seen direction wins instead — drop the redundant - # reverse-direction duplicate so the original direction is preserved (#1061). - if not G.is_directed() and G.has_edge(src, tgt): - existing = edge_data(G, src, tgt) - if existing.get("relation") == attrs.get("relation") and ( - existing.get("_src") == tgt and existing.get("_tgt") == src - ): - continue - G.add_edge(src, tgt, **attrs) + # Refuse to store any edge whose attrs cannot round-trip through JSON. + # Mutating attrs in place would silently change the user's stored value; + # skipping with a warning matches the rest of the build's defensive policy + # and prevents later json.dump crashes during export, identically in + # simple-graph and multigraph modes. + try: + json.dumps(attrs, ensure_ascii=False) + except (TypeError, ValueError) as exc: + print( + f"[graphify] WARNING: edge ({src}->{tgt}) has non-JSON-serializable " + f"attrs ({exc}); skipping.", + file=sys.stderr, + ) + continue + if multigraph: + if raw_key is not None and not isinstance(raw_key, str): + raise TypeError( + f"multigraph edge 'key' must be a string, got " + f"{type(raw_key).__name__} ({raw_key!r})" + ) + base_key = ( + raw_key + if raw_key is not None + else make_stable_key( + _stable_identity_component(attrs.get("relation")), + _stable_identity_component(attrs.get("source_file")), + _stable_identity_component(attrs.get("source_location")), + ) + ) + if raw_key is not None: + multigraph_explicit_keys.add((src, tgt, base_key)) + multigraph_groups.setdefault((src, tgt, base_key), []).append(dict(attrs)) + else: + # When the graph is undirected and the same node pair appears twice with + # the same relation but opposite directions (e.g. a `calls` b and b `calls` a), + # nx.Graph collapses them into one edge. The deterministic sort above means + # the lexicographically-later direction would systematically overwrite the + # earlier one's _src/_tgt, silently flipping the surviving edge's caller + # and callee. First-seen direction wins instead — drop the redundant + # reverse-direction duplicate so the original direction is preserved (#1061). + if not G.is_directed() and G.has_edge(src, tgt): + existing = edge_data(G, src, tgt) + if existing.get("relation") == attrs.get("relation") and ( + existing.get("_src") == tgt and existing.get("_tgt") == src + ): + continue + G.add_edge(src, tgt, **attrs) + if multigraph: + singleton_groups: list[tuple[Hashable, Hashable, str, dict]] = [] + multi_groups: list[tuple[Hashable, Hashable, str, list[dict]]] = [] + used_keys_by_pair: dict[tuple[Hashable, Hashable], set[str]] = {} + for (src, tgt, base_key), group_attrs in multigraph_groups.items(): + unique_attrs: list[dict] = [] + seen_attr_fingerprints: set[str] = set() + for attrs in group_attrs: + attr_fingerprint = json.dumps( + attrs, sort_keys=True, ensure_ascii=False, default=str + ) + if attr_fingerprint in seen_attr_fingerprints: + multigraph_diagnostics["exact_duplicate_edges"] += 1 + else: + seen_attr_fingerprints.add(attr_fingerprint) + unique_attrs.append(attrs) + if len(unique_attrs) > 1: + multigraph_diagnostics["key_collision_edges"] += 1 + unique_attrs.sort( + key=lambda attrs: json.dumps( + attrs, sort_keys=True, ensure_ascii=False, default=str + ) + ) + multi_groups.append((src, tgt, base_key, unique_attrs)) + elif unique_attrs: + # Reserve the singleton's base_key so any later multi-attr + # collision-repair on the same (src, tgt) avoids it. + used_keys_by_pair.setdefault((src, tgt), set()).add(base_key) + singleton_groups.append((src, tgt, base_key, unique_attrs[0])) + # Sort both lists deterministically. + singleton_groups.sort( + key=lambda item: ( + repr(item[0]), + repr(item[1]), + item[2], + json.dumps(item[3], sort_keys=True, ensure_ascii=False, default=str), + ) + ) + multi_groups.sort( + key=lambda item: ( + repr(item[0]), + repr(item[1]), + item[2], + json.dumps(item[3], sort_keys=True, ensure_ascii=False, default=str), + ) + ) + # Emit singletons first: they use base_key directly and were reserved + # in the pre-loop above, so collision-repair from multi groups will + # see those reservations and salt around them. + for src, tgt, base_key, attrs in singleton_groups: + G.add_edge(src, tgt, key=base_key, **attrs) + # Then emit multi-attr groups with collision-repair salting against + # both reserved singleton base_keys and earlier multi-group repair + # keys on the same (src, tgt) pair. + for src, tgt, base_key, unique_attrs in multi_groups: + used_keys = used_keys_by_pair.setdefault((src, tgt), set()) + preserve_explicit = (src, tgt, base_key) in multigraph_explicit_keys + for index, attrs in enumerate(unique_attrs): + # When the user passed an explicit `key` shared across multiple + # distinct edges, preserve it on the first emit so at least one + # edge per group keeps the canonical user-supplied key. + # Derived base_keys (from make_stable_key) always go through + # collision-repair so emission stays order-independent. + if preserve_explicit and index == 0 and base_key not in used_keys: + key = base_key + else: + key = _make_collision_key(base_key, attrs) + salt = 0 + while key in used_keys: + salt += 1 + key = _make_collision_key(base_key, attrs, salt=salt) + used_keys.add(key) + G.add_edge(src, tgt, key=key, **attrs) hyperedges = extraction.get("hyperedges", []) if hyperedges: G.graph["hyperedges"] = hyperedges + if multigraph: + G.graph["graphify_multigraph_diagnostics"] = multigraph_diagnostics return G @@ -237,7 +484,8 @@ def build( dedup: bool = True, dedup_llm_backend: str | None = None, root: str | Path | None = None, -) -> nx.Graph: + multigraph: bool = False, +) -> nx.Graph | nx.DiGraph | nx.MultiDiGraph: """Merge multiple extraction results into one graph. directed=True produces a DiGraph that preserves edge direction (source→target). @@ -253,7 +501,14 @@ def build( reverse the order if you prefer AST source_location precision to win. """ from graphify.dedup import deduplicate_entities - combined: dict = {"nodes": [], "edges": [], "hyperedges": [], "input_tokens": 0, "output_tokens": 0} + + combined: dict = { + "nodes": [], + "edges": [], + "hyperedges": [], + "input_tokens": 0, + "output_tokens": 0, + } for ext in extractions: combined["nodes"].extend(ext.get("nodes", [])) combined["edges"].extend(ext.get("edges", [])) @@ -262,10 +517,12 @@ def build( combined["output_tokens"] += ext.get("output_tokens", 0) if dedup and combined["nodes"]: combined["nodes"], combined["edges"] = deduplicate_entities( - combined["nodes"], combined["edges"], communities={}, + combined["nodes"], + combined["edges"], + communities={}, dedup_llm_backend=dedup_llm_backend, ) - return build_from_json(combined, directed=directed, root=root) + return build_from_json(combined, directed=directed, root=root, multigraph=multigraph) def _norm_label(label: str) -> str: @@ -282,7 +539,7 @@ def deduplicate_by_label(nodes: list[dict], edges: list[dict]) -> tuple[list[dic """ _CHUNK_SUFFIX = re.compile(r"_c\d+$") canonical: dict[str, dict] = {} # norm_label -> surviving node - remap: dict[str, str] = {} # old_id -> surviving_id + remap: dict[str, str] = {} # old_id -> surviving_id for node in nodes: key = _norm_label(node.get("label", node.get("id", ""))) @@ -325,16 +582,23 @@ def build_merge( graph_path: str | Path = "graphify-out/graph.json", prune_sources: list[str] | None = None, *, - directed: bool = False, + directed: bool | None = None, dedup: bool = True, dedup_llm_backend: str | None = None, root: str | Path | None = None, -) -> nx.Graph: - """Load existing graph.json, merge new chunks into it, and save back. +) -> nx.Graph | nx.DiGraph: + """Load existing graph.json, merge new chunks into it, and return the merged graph. + + Persistence is the caller's responsibility (e.g., via ``export.to_json``); + this function does not write back to disk. Never replaces - only grows (or prunes deleted-file nodes via prune_sources). Safe to call repeatedly: existing nodes and edges are preserved. root: if given, absolute source_file paths in new_chunks are made relative (#932). + + ``directed`` defaults to inheriting the saved graph's flag when an + existing graph.json is present, so updating a directed simple graph with + default args no longer silently downgrades it to undirected. """ graph_path = Path(graph_path) if graph_path.exists(): @@ -346,18 +610,62 @@ def build_merge( # attrs are popped before saving in export.py, so going through the # NetworkX round-trip loses direction permanently (#760). from graphify.security import check_graph_file_size_cap + check_graph_file_size_cap(graph_path) data = json.loads(graph_path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + raise TypeError( + f"saved graph.json at {graph_path} must be a JSON object, " + f"got {type(data).__name__}" + ) + # Refuse to silently collapse a saved multigraph. build() runs in + # simple mode here, which would drop parallel edges; stateful + # multigraph update paths are out of scope for the internal keyed + # build path (watch/cache/global-graph land in later slices). + saved_multigraph = data.get("multigraph", False) + if saved_multigraph is True: + raise NotImplementedError( + f"build_merge cannot update a multigraph graph.json. " + f"Found multigraph=true in {graph_path}. Rebuild from extraction " + f"or use a simple-graph build target." + ) + if saved_multigraph is not False: + raise TypeError( + f"'multigraph' in {graph_path} must be a boolean, " + f"got {type(saved_multigraph).__name__} ({saved_multigraph!r})" + ) + # Honor the saved graph's `directed` flag unless the caller explicitly + # overrides. Without this, an update with default args on a directed + # graph silently downgrades it and loses edge direction on next export. + saved_directed_raw = data.get("directed", False) + if saved_directed_raw is not True and saved_directed_raw is not False: + raise TypeError( + f"'directed' in {graph_path} must be a boolean, " + f"got {type(saved_directed_raw).__name__} ({saved_directed_raw!r})" + ) + saved_directed = saved_directed_raw + if directed is None: + directed = saved_directed + elif directed != saved_directed: + print( + f"[graphify] WARNING: build_merge directed={directed} overrides " + f"saved graph.json directed={saved_directed}", + file=sys.stderr, + ) links_key = "links" if "links" in data else "edges" existing_nodes = list(data.get("nodes", [])) existing_edges = list(data.get(links_key, [])) base = [{"nodes": existing_nodes, "edges": existing_edges}] else: + if directed is None: + directed = False existing_nodes = [] base = [] all_chunks = base + list(new_chunks) - G = build(all_chunks, directed=directed, dedup=dedup, dedup_llm_backend=dedup_llm_backend, root=root) + G = build( + all_chunks, directed=directed, dedup=dedup, dedup_llm_backend=dedup_llm_backend, root=root + ) # Prune nodes and edges from deleted source files if prune_sources: @@ -390,8 +698,7 @@ def build_merge( ) edges_to_remove = [ - (u, v) for u, v, d in G.edges(data=True) - if d.get("source_file") in prune_set + (u, v) for u, v, d in G.edges(data=True) if d.get("source_file") in prune_set ] if edges_to_remove: G.remove_edges_from(edges_to_remove) @@ -418,6 +725,7 @@ def build_merge( f"Pass prune_sources explicitly if you intend to remove nodes." ) + # No write to graph_path here; persistence is the caller's responsibility. return G diff --git a/graphify/edge_identity.py b/graphify/edge_identity.py new file mode 100644 index 000000000..f1802bca4 --- /dev/null +++ b/graphify/edge_identity.py @@ -0,0 +1,58 @@ +"""Stable edge identity helpers and schema constants for MultiDiGraph support. + +The node-link ``"key"`` field is reserved schema — it identifies a parallel edge +and must never be stored as an ordinary edge attribute. All callers that build or +load graphs should use :func:`strip_schema_key` before passing attrs to +``G.add_edge`` so the ``key`` kwarg is never duplicated. +""" + +from __future__ import annotations + +import hashlib +import json as _json + +SCHEMA_KEY_FIELD = "key" + + +def make_stable_key( + relation: str | None, + source_file: str | None, + source_location: str | None, +) -> str: + """Return a collision-safe deterministic edge key from semantic identity fields. + + Uses SHA-256 over a canonical JSON payload with explicit field names so that + delimiter characters in field values cannot produce false collisions. The key + format is ``"edge:v1:"``. + + Two edges with the same relation, file, and location always produce the same + key; any difference in those three fields produces a different key. + """ + payload = _json.dumps( + { + "relation": relation, + "source_file": source_file, + "source_location": source_location, + }, + sort_keys=True, + ) + digest = hashlib.sha256(payload.encode()).hexdigest() + return f"edge:v1:{digest}" + + +def strip_schema_key(attrs: dict) -> tuple[object | None, dict]: + """Separate the ``"key"`` schema field from edge attribute kwargs. + + Returns ``(key_value, cleaned_attrs)`` where ``cleaned_attrs`` is a new dict + with ``SCHEMA_KEY_FIELD`` removed. The original *attrs* dict is not mutated. + + The return type is ``object | None`` rather than ``str | None`` because the + field may carry any JSON-decodable value at this layer; callers narrow to + ``str`` after explicit validation (see the multigraph loader/build paths). + + Use before ``G.add_edge(u, v, key=key_value, **cleaned_attrs)`` to avoid + passing ``key`` twice (once as the positional schema arg and once inside attrs). + """ + key_val = attrs.get(SCHEMA_KEY_FIELD) + cleaned = {k: v for k, v in attrs.items() if k != SCHEMA_KEY_FIELD} + return key_val, cleaned diff --git a/graphify/graph_loader.py b/graphify/graph_loader.py new file mode 100644 index 000000000..437c4024a --- /dev/null +++ b/graphify/graph_loader.py @@ -0,0 +1,301 @@ +"""Schema-aware graph loader for saved graphify node-link JSON. + +This module loads *serialized graph files* (graph.json / node-link format). +It is distinct from :func:`graphify.build.build_from_json`, which assembles +graphs from raw extraction dicts produced by AST and semantic passes. + +The two are complementary: + - extraction dict → ``build_from_json`` + - saved graph.json → ``load_graph`` / ``load_graph_file`` +""" + +from __future__ import annotations + +import hashlib +import json +import sys +from pathlib import Path + +import networkx as nx + +from .edge_identity import strip_schema_key +from .multigraph_compat import require_multigraph_capabilities +from .validate import is_hashable + +GRAPHIFY_PROFILE_KEY = "graphify_profile" + + +def load_graph( + data: object, + *, + require_capabilities: bool = True, +) -> nx.Graph | nx.DiGraph | nx.MultiDiGraph: + """Load a serialized node-link graph dict into the appropriate NetworkX type. + + Detects graph type from ``multigraph`` and ``directed`` flags in *data*: + + - ``multigraph: true`` → :class:`nx.MultiDiGraph` + - ``multigraph: false, directed: true`` → :class:`nx.DiGraph` + - ``multigraph: false, directed: false`` → :class:`nx.Graph` + + All paths set ``G.graph[GRAPHIFY_PROFILE_KEY]`` with at minimum + ``{"graph_type": "simple" | "digraph" | "multidigraph"}``. + + ``require_capabilities`` (default ``True``) gates multigraph loading behind + :func:`~graphify.multigraph_compat.require_multigraph_capabilities`. Pass + ``False`` to skip the probe entirely — used in unit tests and when the + caller has already verified capabilities externally. + """ + if not isinstance(data, dict): + raise TypeError("serialized graph data must be a JSON object") + + multigraph_flag = _require_bool_field(data, "multigraph", default=False) + directed_flag = _require_bool_field(data, "directed", default=False, allow_none=True) + directed_present = "directed" in data + + if multigraph_flag is True: + # Only warn when ``directed`` was *explicitly* set to false; an omitted + # flag does not contradict ``multigraph: true``. + if directed_present and directed_flag is False: + print( + "[graphify] WARNING: multigraph=true but directed=false; " + "normalizing to MultiDiGraph (graphify uses directed graphs).", + file=sys.stderr, + ) + if require_capabilities: + require_multigraph_capabilities() + return _load_multigraph(data) + if directed_flag is True: + return _load_directed_simple(data) + return _load_simple(data) + + +def _require_bool_field( + data: dict, field: str, *, default: bool, allow_none: bool = False +) -> bool | None: + """Read a strict-boolean field from serialized graph JSON. + + Rejects non-boolean values (e.g., the string ``"false"``) so corrupted JSON + cannot be misclassified by Python's truthiness rules. + """ + if field not in data: + return default + value = data[field] + if value is True or value is False: + return value + if allow_none and value is None: + return None + raise TypeError( + f"'{field}' must be a boolean, got {type(value).__name__} ({value!r})" + ) + + +def load_graph_file( + path: str | Path, + *, + require_capabilities: bool = True, +) -> nx.Graph | nx.DiGraph | nx.MultiDiGraph: + """Load a graph.json file produced by graphify. + + Applies the 512 MiB size cap before parsing. + """ + from .security import check_graph_file_size_cap + + path = Path(path) + check_graph_file_size_cap(path) + data = json.loads(path.read_text(encoding="utf-8")) + return load_graph(data, require_capabilities=require_capabilities) + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _get_edges(data: dict) -> list[dict]: + """Return the edge list, accepting both ``"edges"`` and legacy ``"links"``.""" + for key in ("edges", "links"): + if key in data: + val = data[key] + if not isinstance(val, list): + raise TypeError(f"'{key}' must be a list, got {type(val).__name__}") + return [e for e in val if isinstance(e, dict)] + return [] + + +def _set_graph_profile(G: nx.Graph, data: dict, *, graph_type: str) -> None: + """Store Graphify profile metadata in ``G.graph[GRAPHIFY_PROFILE_KEY]``. + + NetworkX ``node_link_data`` serializes ``G.graph[...]`` attributes under + ``data["graph"]``; some graphify writers also promote ``graphify_profile`` + to the top level. Read both so round-trips do not silently drop metadata. + """ + nested = data.get("graph", {}) + if isinstance(nested, dict): + for key, value in nested.items(): + G.graph[key] = value + # Prefer the top-level profile when it is a usable dict; fall through to + # the nested copy when the top-level value is absent OR malformed. + raw = data.get(GRAPHIFY_PROFILE_KEY) + if not isinstance(raw, dict) and isinstance(nested, dict): + raw = nested.get(GRAPHIFY_PROFILE_KEY) + profile = dict(raw) if isinstance(raw, dict) else {} + # Overwrite graph_type with the value derived from the multigraph/directed + # flags on this load; a stale graph_type in a serialized profile must not + # mislabel the actual NetworkX type we just constructed. + profile["graph_type"] = graph_type + G.graph[GRAPHIFY_PROFILE_KEY] = profile + + +def _add_nodes(G: nx.Graph, data: dict) -> set: + """Add valid nodes from *data* to *G*; return the resulting node ID set.""" + nodes = data.get("nodes", []) + if not isinstance(nodes, list): + raise TypeError(f"'nodes' must be a list, got {type(nodes).__name__}") + + skipped_unhashable = 0 + for node in nodes: + if not isinstance(node, dict) or "id" not in node: + continue + node_id = node["id"] + if not is_hashable(node_id): + skipped_unhashable += 1 + continue + G.add_node(node_id, **{k: v for k, v in node.items() if k != "id"}) + if skipped_unhashable: + print( + f"[graphify] WARNING: skipped {skipped_unhashable} node(s) with unhashable id", + file=sys.stderr, + ) + return set(G.nodes()) + + +def _load_simple(data: dict) -> nx.Graph: + """Build an undirected :class:`nx.Graph` from node-link data.""" + G = nx.Graph() + _set_graph_profile(G, data, graph_type="simple") + node_set = _add_nodes(G, data) + for edge in _get_edges(data): + if not isinstance(edge, dict): + continue + src = edge["source"] if "source" in edge else edge.get("from") + tgt = edge["target"] if "target" in edge else edge.get("to") + # `is None` (not falsy) so valid hashable IDs like 0 or False survive; + # the unhashable guard prevents `in node_set` from raising TypeError on + # corrupt input like {"source": ["bad"]}. + if src is None or tgt is None: + continue + if not is_hashable(src) or not is_hashable(tgt): + continue + if src not in node_set or tgt not in node_set: + continue + attrs = {k: v for k, v in edge.items() if k not in ("source", "target", "from", "to")} + _, attrs = strip_schema_key(attrs) + G.add_edge(src, tgt, **attrs) + return G + + +def _load_directed_simple(data: dict) -> nx.DiGraph: + """Build a directed :class:`nx.DiGraph` from node-link data.""" + G = nx.DiGraph() + _set_graph_profile(G, data, graph_type="digraph") + node_set = _add_nodes(G, data) + for edge in _get_edges(data): + if not isinstance(edge, dict): + continue + src = edge["source"] if "source" in edge else edge.get("from") + tgt = edge["target"] if "target" in edge else edge.get("to") + # `is None` (not falsy) so valid hashable IDs like 0 or False survive; + # the unhashable guard prevents `in node_set` from raising TypeError on + # corrupt input like {"source": ["bad"]}. + if src is None or tgt is None: + continue + if not is_hashable(src) or not is_hashable(tgt): + continue + if src not in node_set or tgt not in node_set: + continue + attrs = {k: v for k, v in edge.items() if k not in ("source", "target", "from", "to")} + _, attrs = strip_schema_key(attrs) + G.add_edge(src, tgt, **attrs) + return G + + +def _load_multigraph(data: dict) -> nx.MultiDiGraph: + """Build a :class:`nx.MultiDiGraph` with preserved edge keys. + + Missing-key repair: when a serialized edge has no ``"key"`` field, a + deterministic repair key is generated from the full edge attribute payload + (not just the 3 identity fields) so parallel edges with different metadata + are never silently overwritten. + """ + G = nx.MultiDiGraph() + _set_graph_profile(G, data, graph_type="multidigraph") + node_set = _add_nodes(G, data) + missing_key_count = 0 + duplicate_key_count = 0 + used_keys_by_pair: dict[tuple[object, object], set[str]] = {} + # Sort edges by a stable fingerprint so duplicate-key repair is + # input-order-independent: the same malformed graph.json with edges in any + # order produces the same final (src, tgt, key) layout. + sorted_edges = sorted( + _get_edges(data), + key=lambda e: json.dumps(e, sort_keys=True, default=str), + ) + for edge in sorted_edges: + if not isinstance(edge, dict): + continue + src = edge["source"] if "source" in edge else edge.get("from") + tgt = edge["target"] if "target" in edge else edge.get("to") + # `is None` (not falsy) so valid hashable IDs like 0 or False survive; + # the unhashable guard prevents `in node_set` from raising TypeError on + # corrupt input like {"source": ["bad"]}. + if src is None or tgt is None: + continue + if not is_hashable(src) or not is_hashable(tgt): + continue + if src not in node_set or tgt not in node_set: + continue + attrs = {k: v for k, v in edge.items() if k not in ("source", "target", "from", "to")} + key, attrs = strip_schema_key(attrs) + if key is not None and not isinstance(key, str): + raise TypeError( + f"multigraph edge 'key' must be a string, got " + f"{type(key).__name__} ({key!r})" + ) + if key is None: + missing_key_count += 1 + # Hash the full payload so edges with different metadata get different + # keys and both survive (identity-field-only hashing collapses distinct + # parallel edges that share relation/source_file/source_location). + repair_payload = json.dumps(attrs, sort_keys=True, default=str) + repair_digest = hashlib.sha256(repair_payload.encode()).hexdigest() + key = f"edge:v1:{repair_digest}" + # Detect duplicate (src, tgt, key) tuples. add_edge would otherwise + # silently overwrite a previously loaded parallel edge. + used = used_keys_by_pair.setdefault((src, tgt), set()) + if key in used: + duplicate_key_count += 1 + repair_payload = json.dumps(attrs, sort_keys=True, default=str) + salt = 0 + candidate = f"{key}:dup:{hashlib.sha256(repair_payload.encode()).hexdigest()}" + while candidate in used: + salt += 1 + candidate = ( + f"{key}:dup:{hashlib.sha256((repair_payload + str(salt)).encode()).hexdigest()}" + ) + key = candidate + used.add(key) + G.add_edge(src, tgt, key=key, **attrs) + if missing_key_count: + print( + f"[graphify] WARNING: {missing_key_count} multigraph edge(s) were missing " + f"'key' — generated repair keys from full edge payload.", + file=sys.stderr, + ) + if duplicate_key_count: + print( + f"[graphify] WARNING: {duplicate_key_count} multigraph edge(s) had duplicate " + f"(source, target, key) tuples — generated repair keys to preserve all edges.", + file=sys.stderr, + ) + return G diff --git a/graphify/projections.py b/graphify/projections.py new file mode 100644 index 000000000..a591538bd --- /dev/null +++ b/graphify/projections.py @@ -0,0 +1,214 @@ +"""Projection helpers for graph consumers that need explicit edge semantics.""" + +from __future__ import annotations + +from collections.abc import Hashable, Iterable +from typing import Any, Literal, cast + +import networkx as nx + +WeightMode = Literal["confidence", "count", "sum"] + +_CONFIDENCE_SCORE = { + "EXTRACTED": 1.0, + "INFERRED": 0.5, + "AMBIGUOUS": 0.2, +} + + +def _confidence_score(data: dict[str, Any]) -> float: + raw_score = data.get("confidence_score") + if isinstance(raw_score, int | float) and not isinstance(raw_score, bool): # Python 3.10+ + return float(raw_score) + raw_confidence = data.get("confidence") + if isinstance(raw_confidence, str): + return _CONFIDENCE_SCORE.get(raw_confidence.upper(), 0.0) + return 0.0 + + +def _edge_sort_key(data: dict[str, Any]) -> tuple: + return ( + -_confidence_score(data), + str(data.get("relation", "")), + str(data.get("source_file", "")), + str(data.get("source_location", "")), + str(data.get("context", "")), + repr(sorted((str(key), repr(value)) for key, value in data.items())), + ) + + +def _iter_edge_data(G: nx.Graph) -> Iterable[tuple[Any, Any, Any, dict[str, Any]]]: + if isinstance(G, nx.MultiGraph | nx.MultiDiGraph): # Python 3.10+ + yield from G.edges(keys=True, data=True) + return + for u, v, data in G.edges(data=True): + yield u, v, None, data + + +def _copy_graph_skeleton(G: nx.Graph, graph_type: type[nx.Graph]) -> nx.Graph: + H = graph_type() + H.graph.update(G.graph) + H.add_nodes_from((node, attrs.copy()) for node, attrs in G.nodes(data=True)) + return H + + +def _unordered_pair(u: Any, v: Any) -> tuple[Any, Any]: + if repr(u) <= repr(v): + return u, v + return v, u + + +def _merged_edge_attrs(records: list[dict[str, Any]], weight_mode: WeightMode) -> dict[str, Any]: + if weight_mode not in ("confidence", "count", "sum"): + raise ValueError("weight_mode must be one of: confidence, count, sum") + sorted_records = sorted(records, key=_edge_sort_key) + representative = sorted_records[0].copy() + scores = [_confidence_score(record) for record in records] + if weight_mode == "confidence": + weight = max(scores, default=0.0) + elif weight_mode == "count": + weight = float(len(records)) + else: + weight = float(sum(scores)) + representative["weight"] = weight + representative["parallel_edge_count"] = len(records) + return representative + + +def project_for_community(G: nx.Graph, *, weight_mode: WeightMode = "confidence") -> nx.Graph: + """Return a simple undirected projection for clustering and community metrics.""" + groups: dict[tuple[Any, Any], list[dict[str, Any]]] = {} + for u, v, _key, data in _iter_edge_data(G): + if u == v: + continue + pair = _unordered_pair(u, v) + groups.setdefault(pair, []).append(dict(data)) + + H = _copy_graph_skeleton(G, nx.Graph) + for (u, v), records in sorted( + groups.items(), key=lambda item: (repr(item[0][0]), repr(item[0][1])) + ): + H.add_edge(u, v, **_merged_edge_attrs(records, weight_mode)) + return H + + +def project_for_path(G: nx.Graph) -> nx.Graph: + """Return a simple undirected topology projection for path search.""" + return project_for_community(G, weight_mode="count") + + +def project_for_callflow( + G: nx.Graph, + *, + relations: frozenset[str] | set[str] | None = None, +) -> nx.DiGraph: + """Return a simple directed projection for callflow-style consumers.""" + relation_filter = set(relations) if relations is not None else None + groups: dict[tuple[Any, Any], list[dict[str, Any]]] = {} + for u, v, _key, data in _iter_edge_data(G): + relation = data.get("relation") + # Guard against non-string `relation`; relation_filter is set[str], and + # an unhashable relation would TypeError on the `in` membership test. + if relation_filter is not None and ( + not isinstance(relation, str) or relation not in relation_filter + ): + continue + src = data.get("_src", u) + tgt = data.get("_tgt", v) + if src == tgt: + continue + groups.setdefault((src, tgt), []).append(dict(data)) + + H = cast(nx.DiGraph, _copy_graph_skeleton(G, nx.DiGraph)) + for (src, tgt), records in sorted( + groups.items(), key=lambda item: (repr(item[0][0]), repr(item[0][1])) + ): + if src not in H: + H.add_node(src) + if tgt not in H: + H.add_node(tgt) + H.add_edge(src, tgt, **_merged_edge_attrs(records, "confidence")) + return H + + +def _normalize_contexts(contexts: Iterable[str] | str | None) -> set[str] | None: + if contexts is None: + return None + raw_contexts = [contexts] if isinstance(contexts, str) else contexts + normalized = {str(context).strip().lower() for context in raw_contexts if str(context).strip()} + return normalized or None + + +def project_for_context(G: nx.Graph, *, contexts: Iterable[str] | str | None = None) -> nx.Graph: + """Return a graph copy containing only edges whose context matches the filter.""" + filters = _normalize_contexts(contexts) + H = _copy_graph_skeleton(G, G.__class__) + for u, v, key, data in _iter_edge_data(G): + if filters is not None and str(data.get("context", "")).strip().lower() not in filters: + continue + if isinstance(H, nx.MultiGraph | nx.MultiDiGraph): # Python 3.10+ + H.add_edge(u, v, key=key, **data) + else: + H.add_edge(u, v, **data) + return H + + +def edge_records_between(G: nx.Graph, u: Hashable, v: Hashable) -> list[dict[str, Any]]: + """Return shallow copies of all edge records connecting two nodes.""" + records: list[dict[str, Any]] = [] + + def collect(src: Hashable, tgt: Hashable) -> None: + if not G.has_edge(src, tgt): + return + raw = G.get_edge_data(src, tgt) + if not isinstance(raw, dict): + return + if isinstance(G, nx.MultiGraph | nx.MultiDiGraph): # Python 3.10+ + records.extend(dict(data) for data in raw.values() if isinstance(data, dict)) + else: + records.append(dict(raw)) + + collect(u, v) + if G.is_directed() and u != v: + collect(v, u) + return sorted(records, key=_edge_sort_key) + + +def edge_summary_between(G: nx.Graph, u: Hashable, v: Hashable) -> dict[str, Any]: + """Summarize all relationships between two nodes for display consumers.""" + records = edge_records_between(G, u, v) + representative = records[0].copy() if records else {} + return { + "count": len(records), + "relations": sorted( + {str(record.get("relation")) for record in records if record.get("relation")} + ), + "confidences": sorted( + {str(record.get("confidence")) for record in records if record.get("confidence")} + ), + "representative": representative, + } + + +def distinct_neighbor_degree(G: nx.Graph, node: Hashable) -> int: + """Count unique adjacent nodes without inflating parallel edges.""" + if node not in G: + return 0 + if G.is_directed(): + directed = cast(nx.DiGraph, G) + return len(set(directed.predecessors(node)) | set(directed.successors(node))) + return len(set(G.neighbors(node))) + + +def normalize_to_multidigraph(G: nx.Graph) -> nx.MultiDiGraph: + """Return a MultiDiGraph copy, preserving parallel keys when present.""" + H = nx.MultiDiGraph() + H.graph.update(G.graph) + H.add_nodes_from((node, attrs.copy()) for node, attrs in G.nodes(data=True)) + if isinstance(G, nx.MultiGraph | nx.MultiDiGraph): # Python 3.10+ + for u, v, key, data in G.edges(keys=True, data=True): + H.add_edge(u, v, key=key, **data) + else: + for u, v, data in G.edges(data=True): + H.add_edge(u, v, **data) + return H diff --git a/graphify/symbol_resolution.py b/graphify/symbol_resolution.py index 7bc68093a..5cb0dad15 100644 --- a/graphify/symbol_resolution.py +++ b/graphify/symbol_resolution.py @@ -243,7 +243,7 @@ def resolve_python_import_guided_calls( if path.suffix != ".py": continue slot: Any = per_file[index] if index < len(per_file) else None - result_by_file[str(path)] = slot if isinstance(slot, dict) else {"nodes": [], "edges": []} + result_by_file[str(path)] = slot if isinstance(slot, dict) else {"nodes": [], "edges": []} # empty fragment for missing/non-dict slots resolved_edges: list[dict[str, Any]] = [] for path in paths: @@ -256,7 +256,7 @@ def resolve_python_import_guided_calls( file_result = result_by_file.get(source_file, {"raw_calls": []}) raw_calls = file_result.get("raw_calls", []) if not isinstance(raw_calls, list): - continue + continue # raw_calls must be a list; skip malformed fragments for raw_call in raw_calls: if not isinstance(raw_call, dict): continue diff --git a/graphify/validate.py b/graphify/validate.py index 5f6bad364..4b63d1af4 100644 --- a/graphify/validate.py +++ b/graphify/validate.py @@ -7,6 +7,14 @@ REQUIRED_EDGE_FIELDS = {"source", "target", "relation", "confidence", "source_file"} +def is_hashable(value: object) -> bool: + try: + hash(value) + except TypeError: + return False + return True + + def validate_extraction(data: dict) -> list[str]: """ Validate an extraction JSON dict against the graphify schema. @@ -29,7 +37,11 @@ def validate_extraction(data: dict) -> list[str]: continue for field in REQUIRED_NODE_FIELDS: if field not in node: - errors.append(f"Node {i} (id={node.get('id', '?')!r}) missing required field '{field}'") + errors.append( + f"Node {i} (id={node.get('id', '?')!r}) missing required field '{field}'" + ) + if "id" in node and not is_hashable(node["id"]): + errors.append(f"Node {i} id is unhashable and cannot be used as a node id") if "file_type" in node and node["file_type"] not in VALID_FILE_TYPES: errors.append( f"Node {i} (id={node.get('id', '?')!r}) has invalid file_type " @@ -43,7 +55,17 @@ def validate_extraction(data: dict) -> list[str]: elif not isinstance(edge_list, list): errors.append("'edges' must be a list") else: - node_ids = {n["id"] for n in data.get("nodes", []) if isinstance(n, dict) and "id" in n} + # Guard against non-list `nodes` (the earlier branch only records the + # error and falls through to here); iterating a non-list would otherwise + # raise an incidental TypeError instead of yielding a clean validation + # message. + raw_nodes = data.get("nodes", []) + nodes_iter = raw_nodes if isinstance(raw_nodes, list) else [] + node_ids = { + n["id"] + for n in nodes_iter + if isinstance(n, dict) and "id" in n and is_hashable(n["id"]) + } for i, edge in enumerate(edge_list): if not isinstance(edge, dict): errors.append(f"Edge {i} must be an object") @@ -56,10 +78,16 @@ def validate_extraction(data: dict) -> list[str]: f"Edge {i} has invalid confidence '{edge['confidence']}' " f"- must be one of {sorted(VALID_CONFIDENCES)}" ) - if "source" in edge and node_ids and edge["source"] not in node_ids: - errors.append(f"Edge {i} source '{edge['source']}' does not match any node id") - if "target" in edge and node_ids and edge["target"] not in node_ids: - errors.append(f"Edge {i} target '{edge['target']}' does not match any node id") + if "source" in edge: + if not is_hashable(edge["source"]): + errors.append(f"Edge {i} source is unhashable and cannot match any node id") + elif node_ids and edge["source"] not in node_ids: + errors.append(f"Edge {i} source '{edge['source']}' does not match any node id") + if "target" in edge: + if not is_hashable(edge["target"]): + errors.append(f"Edge {i} target is unhashable and cannot match any node id") + elif node_ids and edge["target"] not in node_ids: + errors.append(f"Edge {i} target '{edge['target']}' does not match any node id") return errors @@ -68,5 +96,7 @@ def assert_valid(data: dict) -> None: """Raise ValueError with all errors if extraction is invalid.""" errors = validate_extraction(data) if errors: - msg = f"Extraction JSON has {len(errors)} error(s):\n" + "\n".join(f" • {e}" for e in errors) + msg = f"Extraction JSON has {len(errors)} error(s):\n" + "\n".join( + f" • {e}" for e in errors + ) raise ValueError(msg) diff --git a/tests/test_build.py b/tests/test_build.py index 9be6c1289..19bd65199 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -1,40 +1,61 @@ import json from pathlib import Path +from typing import cast import networkx as nx +import pytest from networkx.readwrite import json_graph -from graphify.build import build_from_json, build, build_merge, edge_data, edge_datas +from graphify.build import ( + _make_collision_key, + build_from_json, + build, + build_merge, + edge_data, + edge_datas, +) +from graphify.edge_identity import make_stable_key FIXTURES = Path(__file__).parent / "fixtures" + def load_extraction(): return json.loads((FIXTURES / "extraction.json").read_text()) + def test_build_from_json_node_count(): G = build_from_json(load_extraction()) assert G.number_of_nodes() == 4 + def test_build_from_json_edge_count(): G = build_from_json(load_extraction()) assert G.number_of_edges() == 4 + def test_nodes_have_label(): G = build_from_json(load_extraction()) assert G.nodes["n_transformer"]["label"] == "Transformer" + def test_edges_have_confidence(): - G = build_from_json(load_extraction()) + G = cast(nx.Graph, build_from_json(load_extraction())) data = G.edges["n_attention", "n_concept_attn"] assert data["confidence"] == "INFERRED" + def test_ambiguous_edge_preserved(): - G = build_from_json(load_extraction()) + G = cast(nx.Graph, build_from_json(load_extraction())) data = G.edges["n_layernorm", "n_concept_attn"] assert data["confidence"] == "AMBIGUOUS" + def test_legacy_node_source_canonicalized(): """Legacy 'source' key on nodes is renamed to 'source_file' before graph build.""" - ext = {"nodes": [{"id": "n1", "label": "A", "file_type": "code", "source": "a.py"}], - "edges": [], "input_tokens": 0, "output_tokens": 0} + ext = { + "nodes": [{"id": "n1", "label": "A", "file_type": "code", "source": "a.py"}], + "edges": [], + "input_tokens": 0, + "output_tokens": 0, + } G = build_from_json(ext) assert "source_file" in G.nodes["n1"] assert G.nodes["n1"]["source_file"] == "a.py" @@ -43,11 +64,24 @@ def test_legacy_node_source_canonicalized(): def test_legacy_edge_from_to_canonicalized(): """Legacy 'from'/'to' keys on edges are accepted alongside 'source'/'target'.""" - ext = {"nodes": [{"id": "n1", "label": "A", "file_type": "code", "source_file": "a.py"}, - {"id": "n2", "label": "B", "file_type": "code", "source_file": "b.py"}], - "edges": [{"from": "n1", "to": "n2", "relation": "calls", - "confidence": "EXTRACTED", "source_file": "a.py", "weight": 1.0}], - "input_tokens": 0, "output_tokens": 0} + ext = { + "nodes": [ + {"id": "n1", "label": "A", "file_type": "code", "source_file": "a.py"}, + {"id": "n2", "label": "B", "file_type": "code", "source_file": "b.py"}, + ], + "edges": [ + { + "from": "n1", + "to": "n2", + "relation": "calls", + "confidence": "EXTRACTED", + "source_file": "a.py", + "weight": 1.0, + } + ], + "input_tokens": 0, + "output_tokens": 0, + } G = build_from_json(ext) assert G.number_of_edges() == 1 @@ -56,11 +90,22 @@ def test_source_file_backslash_normalized(): """Windows backslash paths and POSIX paths for the same file must produce one node.""" extraction = { "nodes": [ - {"id": "n1", "label": "A", "file_type": "code", "source_file": "src\\middleware\\auth.py"}, - {"id": "n2", "label": "B", "file_type": "code", "source_file": "src/middleware/auth.py"}, + { + "id": "n1", + "label": "A", + "file_type": "code", + "source_file": "src\\middleware\\auth.py", + }, + { + "id": "n2", + "label": "B", + "file_type": "code", + "source_file": "src/middleware/auth.py", + }, ], "edges": [], - "input_tokens": 0, "output_tokens": 0, + "input_tokens": 0, + "output_tokens": 0, } G = build_from_json(extraction) sources = {G.nodes[n]["source_file"] for n in G.nodes()} @@ -68,12 +113,27 @@ def test_source_file_backslash_normalized(): def test_build_merges_multiple_extractions(): - ext1 = {"nodes": [{"id": "n1", "label": "A", "file_type": "code", "source_file": "a.py"}], - "edges": [], "input_tokens": 0, "output_tokens": 0} - ext2 = {"nodes": [{"id": "n2", "label": "B", "file_type": "document", "source_file": "b.md"}], - "edges": [{"source": "n1", "target": "n2", "relation": "references", - "confidence": "INFERRED", "source_file": "b.md", "weight": 1.0}], - "input_tokens": 0, "output_tokens": 0} + ext1 = { + "nodes": [{"id": "n1", "label": "A", "file_type": "code", "source_file": "a.py"}], + "edges": [], + "input_tokens": 0, + "output_tokens": 0, + } + ext2 = { + "nodes": [{"id": "n2", "label": "B", "file_type": "document", "source_file": "b.md"}], + "edges": [ + { + "source": "n1", + "target": "n2", + "relation": "references", + "confidence": "INFERRED", + "source_file": "b.md", + "weight": 1.0, + } + ], + "input_tokens": 0, + "output_tokens": 0, + } G = build([ext1, ext2]) assert G.number_of_nodes() == 2 assert G.number_of_edges() == 1 @@ -191,8 +251,9 @@ def test_build_merge_preserves_call_edge_direction(tmp_path): # Verify direction is correct in the freshly written JSON. saved = json.loads(graph_path.read_text()) - saved_calls = [e for e in saved.get("links", saved.get("edges", [])) - if e.get("relation") == "calls"] + saved_calls = [ + e for e in saved.get("links", saved.get("edges", [])) if e.get("relation") == "calls" + ] assert len(saved_calls) == 1 assert saved_calls[0]["source"] == truth_src assert saved_calls[0]["target"] == truth_tgt @@ -203,8 +264,9 @@ def test_build_merge_preserves_call_edge_direction(tmp_path): # The calls edge must still go a -> b, not b -> a. reloaded = json.loads(graph_path.read_text()) - reloaded_calls = [e for e in reloaded.get("links", reloaded.get("edges", [])) - if e.get("relation") == "calls"] + reloaded_calls = [ + e for e in reloaded.get("links", reloaded.get("edges", [])) if e.get("relation") == "calls" + ] assert len(reloaded_calls) == 1 assert reloaded_calls[0]["source"] == truth_src, ( f"calls edge source flipped after build_merge round-trip: " @@ -280,6 +342,7 @@ def test_build_from_json_preserves_first_direction_on_bidirectional_pair(tmp_pat # whenever the loaded JSON has multigraph: true. Plain G.edges[u, v] crashes # on those with `ValueError: not enough values to unpack (expected 3, got 2)`. + def test_edge_data_simple_graph(): G = nx.Graph() G.add_edge("a", "b", relation="calls", confidence="EXTRACTED") @@ -367,12 +430,22 @@ def test_build_from_json_relativizes_absolute_source_file(tmp_path): abs_path = str(root / "docs" / "overview.md") extraction = { "nodes": [ - {"id": "overview_intro", "label": "Intro", "source_file": abs_path, "file_type": "document"}, + { + "id": "overview_intro", + "label": "Intro", + "source_file": abs_path, + "file_type": "document", + }, ], "edges": [ - {"source": "overview_intro", "target": "overview_intro", - "relation": "self", "confidence": "EXTRACTED", "confidence_score": 1.0, - "source_file": abs_path}, + { + "source": "overview_intro", + "target": "overview_intro", + "relation": "self", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": abs_path, + }, ], } G = build_from_json(extraction, root=root) @@ -398,7 +471,9 @@ def test_build_relativizes_absolute_source_file(tmp_path): def test_build_from_json_relative_source_file_unchanged(tmp_path): """Already-relative source_file paths must not be modified.""" extraction = { - "nodes": [{"id": "foo_bar", "label": "bar", "source_file": "src/foo.py", "file_type": "code"}], + "nodes": [ + {"id": "foo_bar", "label": "bar", "source_file": "src/foo.py", "file_type": "code"} + ], "edges": [], } G = build_from_json(extraction, root=tmp_path) @@ -468,3 +543,710 @@ def test_build_merge_rejects_oversized_existing_graph(monkeypatch, tmp_path): monkeypatch.setattr("graphify.security._MAX_GRAPH_FILE_BYTES", 8) with pytest.raises(ValueError, match="exceeds"): build_merge([], graph_path, dedup=False) + + +def _parallel_edge_extraction() -> dict: + return { + "nodes": [ + {"id": "a", "label": "A", "file_type": "code", "source_file": "src/a.py"}, + {"id": "b", "label": "B", "file_type": "code", "source_file": "src/b.py"}, + ], + "edges": [ + { + "source": "a", + "target": "b", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "src/a.py", + "source_location": "L10", + }, + { + "source": "a", + "target": "b", + "relation": "imports", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "src/a.py", + "source_location": "L2", + }, + ], + } + + +def test_default_build_stays_simple_when_parallel_edges_exist(): + G = build_from_json(_parallel_edge_extraction()) + + assert type(G) is nx.Graph + assert not G.is_multigraph() + assert G.number_of_edges("a", "b") == 1 + + +def test_multigraph_build_preserves_same_endpoint_different_relations(): + G = build_from_json(_parallel_edge_extraction(), multigraph=True) + + assert type(G) is nx.MultiDiGraph + assert G.number_of_edges("a", "b") == 2 + edge_records = list(G["a"]["b"].items()) + assert {data["relation"] for _key, data in edge_records} == {"calls", "imports"} + assert all(str(key).startswith("edge:v1:") for key, _data in edge_records) + assert all("key" not in data for _key, data in edge_records) + + +def test_multigraph_build_preserves_same_identity_except_source_location(): + extraction = _parallel_edge_extraction() + extraction["edges"][1].update( + { + "relation": "calls", + "source_location": "L20", + } + ) + + G = build_from_json(extraction, multigraph=True) + + assert G.number_of_edges("a", "b") == 2 + assert {data["source_location"] for data in G["a"]["b"].values()} == {"L10", "L20"} + + +def test_multigraph_build_collapses_exact_duplicates_with_diagnostic(): + extraction = _parallel_edge_extraction() + extraction["edges"].append(dict(extraction["edges"][0])) + + G = build_from_json(extraction, multigraph=True) + + assert G.number_of_edges("a", "b") == 2 + assert G.graph["graphify_multigraph_diagnostics"]["exact_duplicate_edges"] == 1 + + +def test_multigraph_build_preserves_non_exact_key_collisions_with_diagnostic(): + extraction = { + "nodes": [ + {"id": "a", "label": "A", "file_type": "code", "source_file": "src/a.py"}, + {"id": "b", "label": "B", "file_type": "code", "source_file": "src/b.py"}, + ], + "edges": [ + { + "source": "a", + "target": "b", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "src/a.py", + "source_location": "L10", + "context": "static", + }, + { + "source": "a", + "target": "b", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "src/a.py", + "source_location": "L10", + "context": "runtime", + }, + ], + } + + G = build_from_json(extraction, multigraph=True) + + assert G.number_of_edges("a", "b") == 2 + assert {data["context"] for data in G["a"]["b"].values()} == { + "static", + "runtime", + } + assert G.graph["graphify_multigraph_diagnostics"]["key_collision_edges"] == 1 + + +def test_multigraph_build_collapses_duplicates_after_collision_repair(): + extraction = { + "nodes": [ + {"id": "a", "label": "A", "file_type": "code", "source_file": "src/a.py"}, + {"id": "b", "label": "B", "file_type": "code", "source_file": "src/b.py"}, + ], + "edges": [ + { + "source": "a", + "target": "b", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "src/a.py", + "source_location": "L10", + "context": "static", + }, + { + "source": "a", + "target": "b", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "src/a.py", + "source_location": "L10", + "context": "runtime", + }, + { + "source": "a", + "target": "b", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "src/a.py", + "source_location": "L10", + "context": "runtime", + }, + ], + } + + G = build_from_json(extraction, multigraph=True) + + assert G.number_of_edges("a", "b") == 2 + assert {data["context"] for data in G["a"]["b"].values()} == { + "static", + "runtime", + } + assert G.graph["graphify_multigraph_diagnostics"] == { + "exact_duplicate_edges": 1, + "key_collision_edges": 1, + } + + +def test_multigraph_build_preserves_empty_string_schema_key(): + extraction = _parallel_edge_extraction() + extraction["edges"] = [dict(extraction["edges"][0], key="")] + + G = build_from_json(extraction, multigraph=True) + + assert list(G["a"]["b"].keys()) == [""] + + +def test_multigraph_build_normalizes_path_identity_fields_for_stable_key(tmp_path): + """Path objects survive coercion via the JSON 'default=str' path of json.dumps.""" + extraction = _parallel_edge_extraction() + absolute_source = tmp_path / "src" / "a.py" + extraction["edges"] = [ + { + **extraction["edges"][0], + "source_file": absolute_source, + "source_location": {"line": 10}, + } + ] + + G = build_from_json(extraction, root=tmp_path, multigraph=True) + + assert G.number_of_edges("a", "b") == 1 + assert next(iter(G["a"]["b"].keys())).startswith("edge:v1:") + assert next(iter(G["a"]["b"].values()))["source_file"] == "src/a.py" + + +def test_multigraph_build_skips_edge_with_non_json_serializable_attrs(capsys): + """Edges whose attrs cannot round-trip through JSON are skipped with a warning. + + Mutating attrs in place would silently change the user's stored value; + skipping with a warning preserves data integrity for surviving edges and + prevents later json.dump crashes during export. + """ + extraction = _parallel_edge_extraction() + extraction["edges"] = [ + { + **extraction["edges"][0], + "relation": {"calls", "uses"}, + } + ] + + G = build_from_json(extraction, multigraph=True) + + assert G.number_of_edges("a", "b") == 0 + captured = capsys.readouterr() + assert "non-JSON-serializable" in captured.err + + +@pytest.mark.parametrize("field", ["nodes", "edges"]) +def test_build_from_json_treats_non_list_node_or_edge_field_as_empty(field, capsys): + extraction = _parallel_edge_extraction() + extraction[field] = 123 + + G = build_from_json(extraction, multigraph=True) + + assert G.number_of_edges() == 0 + captured = capsys.readouterr() + assert f"extraction field '{field}' must be a list" in captured.err + + +def test_multigraph_collision_repair_keys_do_not_depend_on_edge_order(): + base_edges = [ + { + "source": "a", + "target": "b", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "src/a.py", + "source_location": "L10", + "context": "static", + }, + { + "source": "a", + "target": "b", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "src/a.py", + "source_location": "L10", + "context": "runtime", + }, + ] + + def keys_by_context(edges: list[dict]) -> dict[str, str]: + extraction = { + "nodes": [ + {"id": "a", "label": "A", "file_type": "code", "source_file": "src/a.py"}, + {"id": "b", "label": "B", "file_type": "code", "source_file": "src/b.py"}, + ], + "edges": edges, + } + G = build_from_json(extraction, multigraph=True) + return {data["context"]: key for key, data in G["a"]["b"].items()} + + forward = keys_by_context(base_edges) + reverse = keys_by_context(list(reversed(base_edges))) + + assert forward == reverse + assert all(":alt:" in key for key in forward.values()) + + +def test_multigraph_collision_repair_does_not_overwrite_explicit_key(): + runtime_attrs = { + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "src/a.py", + "source_location": "L10", + "context": "runtime", + "_src": "a", + "_tgt": "b", + } + base_key = make_stable_key("calls", "src/a.py", "L10") + explicit_conflict_key = _make_collision_key(base_key, runtime_attrs) + edges = [ + { + "source": "a", + "target": "b", + "key": explicit_conflict_key, + "relation": "imports", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "src/a.py", + "source_location": "L2", + "context": "explicit", + }, + { + "source": "a", + "target": "b", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "src/a.py", + "source_location": "L10", + "context": "static", + }, + { + "source": "a", + "target": "b", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "src/a.py", + "source_location": "L10", + "context": "runtime", + }, + ] + + def contexts_by_key(edge_order: list[dict]) -> dict[str, str]: + extraction = { + "nodes": [ + {"id": "a", "label": "A", "file_type": "code", "source_file": "src/a.py"}, + {"id": "b", "label": "B", "file_type": "code", "source_file": "src/b.py"}, + ], + "edges": edge_order, + } + G = build_from_json(extraction, multigraph=True) + assert G.number_of_edges("a", "b") == 3 + return {key: data["context"] for key, data in G["a"]["b"].items()} + + forward = contexts_by_key(edges) + reverse = contexts_by_key(list(reversed(edges))) + + assert forward == reverse + assert forward[explicit_conflict_key] == "explicit" + runtime_keys = [key for key, context in forward.items() if context == "runtime"] + assert len(runtime_keys) == 1 + assert runtime_keys[0] != explicit_conflict_key + + +def test_multigraph_build_roundtrips_through_json_loader(tmp_path): + from graphify.export import to_json + from graphify.graph_loader import load_graph_file + + G = build_from_json(_parallel_edge_extraction(), multigraph=True) + graph_path = tmp_path / "graph.json" + + assert to_json(G, {}, str(graph_path), force=True) + data = json.loads(graph_path.read_text()) + loaded = load_graph_file(graph_path) + + assert data["multigraph"] is True + assert data["directed"] is True + assert len(data["links"]) == 2 + assert all("key" in link for link in data["links"]) + assert type(loaded) is nx.MultiDiGraph + assert loaded.number_of_edges("a", "b") == 2 + assert set(loaded["a"]["b"]) == {link["key"] for link in data["links"]} + + +def test_build_multigraph_merges_extractions_without_collapsing_parallel_edges(): + extraction = _parallel_edge_extraction() + + G = build( + [ + {"nodes": extraction["nodes"], "edges": [extraction["edges"][0]]}, + {"nodes": [], "edges": [extraction["edges"][1]]}, + ], + dedup=False, + multigraph=True, + ) + + assert type(G) is nx.MultiDiGraph + assert G.number_of_edges("a", "b") == 2 + + +def test_build_preserves_hashable_non_string_edge_endpoints(): + extraction = { + "nodes": [ + {"id": 1, "label": "A", "file_type": "code", "source_file": "src/a.py"}, + {"id": 2, "label": "B", "file_type": "code", "source_file": "src/b.py"}, + ], + "edges": [ + { + "source": 1, + "target": 2, + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "src/a.py", + "source_location": "L10", + }, + ], + } + + G = build_from_json(extraction) + + assert G.has_edge(1, 2) + + +def test_build_skips_unhashable_edge_endpoints_without_crashing(capsys): + extraction = { + "nodes": [ + {"id": "a", "label": "A", "file_type": "code", "source_file": "src/a.py"}, + {"id": "b", "label": "B", "file_type": "code", "source_file": "src/b.py"}, + ], + "edges": [ + { + "source": "a", + "target": {"bad": "target"}, + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "src/a.py", + "source_location": "L10", + }, + ], + } + + G = build_from_json(extraction) + captured = capsys.readouterr() + + assert G.number_of_edges() == 0 + assert "unhashable" in captured.err + + +def test_build_skips_unhashable_node_ids_without_crashing(capsys): + extraction = { + "nodes": [ + {"id": ["bad"], "label": "Bad", "file_type": "code", "source_file": "src/bad.py"}, + {"id": "ok", "label": "OK", "file_type": "code", "source_file": "src/ok.py"}, + ], + "edges": [], + } + + G = build_from_json(extraction) + captured = capsys.readouterr() + + assert list(G.nodes()) == ["ok"] + assert "Node 0 id is unhashable" in captured.err + + +def test_build_skips_malformed_nodes_without_crashing(capsys): + extraction = { + "nodes": [ + "bad-node", + {"label": "Missing ID", "file_type": "code", "source_file": "src/missing.py"}, + {"id": "ok", "label": "OK", "file_type": "code", "source_file": "src/ok.py"}, + ], + "edges": [], + } + + G = build_from_json(extraction) + captured = capsys.readouterr() + + assert list(G.nodes()) == ["ok"] + assert "Node 0 must be an object" in captured.err + + +def test_build_warns_when_skipping_unhashable_endpoint_without_node_ids(capsys): + extraction = { + "nodes": [], + "edges": [ + { + "source": ["bad"], + "target": "b", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "src/a.py", + "source_location": "L10", + }, + ], + } + + G = build_from_json(extraction) + captured = capsys.readouterr() + + assert G.number_of_edges() == 0 + assert "skipped edge with unhashable source endpoint" in captured.err + + +def test_build_skips_malformed_edges_without_crashing(capsys): + extraction = { + "nodes": [ + {"id": "a", "label": "A", "file_type": "code", "source_file": "src/a.py"}, + {"id": "b", "label": "B", "file_type": "code", "source_file": "src/b.py"}, + ], + "edges": [ + 7, + { + "source": "a", + "target": "b", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "src/a.py", + "source_location": "L10", + }, + ], + } + + G = build_from_json(extraction) + captured = capsys.readouterr() + + assert G.number_of_edges() == 1 + assert "Edge 0 must be an object" in captured.err + + +def test_build_merge_rejects_multigraph_graph_json(tmp_path): + """build_merge must refuse a multigraph input rather than silently collapse parallel edges.""" + import json as _json + + graph_path = tmp_path / "graph.json" + graph_path.write_text( + _json.dumps( + { + "directed": True, + "multigraph": True, + "nodes": [{"id": "a"}, {"id": "b"}], + "links": [ + {"source": "a", "target": "b", "key": "k1", "relation": "calls"}, + {"source": "a", "target": "b", "key": "k2", "relation": "imports"}, + ], + } + ) + ) + + with pytest.raises(NotImplementedError, match="multigraph"): + build_merge([], graph_path=graph_path) + + +def test_build_merge_inherits_directed_from_saved_graph_json(tmp_path): + """build_merge with default args must preserve direction of a directed saved graph.""" + import json as _json + + graph_path = tmp_path / "graph.json" + graph_path.write_text( + _json.dumps( + { + "directed": True, + "multigraph": False, + "nodes": [ + {"id": "caller", "file_type": "code", "source_file": "a.py"}, + {"id": "callee", "file_type": "code", "source_file": "b.py"}, + ], + "links": [ + { + "source": "caller", + "target": "callee", + "relation": "calls", + "source_file": "a.py", + "_src": "caller", + "_tgt": "callee", + } + ], + } + ) + ) + + # No `directed=` arg passed — must inherit True from the saved file. + G = build_merge([], graph_path=graph_path) + assert G.is_directed(), "build_merge default-args must inherit directed=True from saved graph" + assert G.has_edge("caller", "callee") + assert not G.has_edge("callee", "caller") + + +def test_build_merge_directed_override_warns(tmp_path, capsys): + """Explicit directed=False against a directed saved graph emits a warning.""" + import json as _json + + graph_path = tmp_path / "graph.json" + graph_path.write_text( + _json.dumps( + { + "directed": True, + "multigraph": False, + "nodes": [{"id": "a"}, {"id": "b"}], + "links": [{"source": "a", "target": "b", "relation": "calls"}], + } + ) + ) + + G = build_merge([], graph_path=graph_path, directed=False) + captured = capsys.readouterr() + assert "overrides saved" in captured.err.lower() + assert not G.is_directed() + + +def test_build_merge_rejects_non_bool_multigraph_in_saved_graph(tmp_path): + """A saved graph.json with a non-bool 'multigraph' value must be rejected.""" + import json as _json + graph_path = tmp_path / "graph.json" + graph_path.write_text(_json.dumps({ + "directed": True, "multigraph": "false", + "nodes": [{"id": "a"}, {"id": "b"}], + "links": [{"source": "a", "target": "b", "relation": "calls"}], + })) + with pytest.raises(TypeError, match="'multigraph' in .* must be a boolean"): + build_merge([], graph_path=graph_path) + + +def test_build_merge_rejects_non_bool_directed_in_saved_graph(tmp_path): + import json as _json + graph_path = tmp_path / "graph.json" + graph_path.write_text(_json.dumps({ + "directed": "true", "multigraph": False, + "nodes": [{"id": "a"}, {"id": "b"}], + "links": [{"source": "a", "target": "b", "relation": "calls"}], + })) + with pytest.raises(TypeError, match="'directed' in .* must be a boolean"): + build_merge([], graph_path=graph_path) + + +def test_simple_build_skips_edge_with_non_json_serializable_attrs(capsys): + """Same skip-and-warn policy applies to simple-graph builds.""" + extraction = _parallel_edge_extraction() + extraction["edges"] = [ + { + **extraction["edges"][0], + "relation": {"calls", "uses"}, + } + ] + G = build_from_json(extraction, multigraph=False) + assert G.number_of_edges("a", "b") == 0 + captured = capsys.readouterr() + assert "non-JSON-serializable" in captured.err + + +def test_build_skips_node_with_non_json_serializable_attrs(capsys): + """Nodes with non-JSON-serializable attrs are skipped with a warning.""" + extraction = { + "nodes": [ + {"id": "ok", "label": "OK", "file_type": "code", "source_file": "a.py"}, + { + "id": "bad", + "label": "Bad", + "file_type": "code", + "source_file": "b.py", + "tags": {"unhashable", "set"}, + }, + ], + "edges": [], + "input_tokens": 0, + "output_tokens": 0, + } + G = build_from_json(extraction) + assert "ok" in G.nodes + assert "bad" not in G.nodes + captured = capsys.readouterr() + assert "non-JSON-serializable" in captured.err + + +def test_build_strips_legacy_from_to_from_edge_attrs(): + """Legacy from/to keys must not survive into stored edge attrs after remap.""" + ext = { + "nodes": [ + {"id": "a", "label": "A", "file_type": "code", "source_file": "a.py"}, + {"id": "b", "label": "B", "file_type": "code", "source_file": "b.py"}, + ], + "edges": [ + { + "from": "a", + "to": "b", + "relation": "calls", + "confidence": "EXTRACTED", + "source_file": "a.py", + "weight": 1.0, + } + ], + "input_tokens": 0, + "output_tokens": 0, + } + G = cast(nx.Graph, build_from_json(ext)) + data = G.edges["a", "b"] + assert "from" not in data + assert "to" not in data + + +def test_multigraph_preserves_first_explicit_key_in_collision_group(): + """When multiple edges share an explicit user key, the first one preserves it.""" + extraction = { + "nodes": [ + {"id": "a", "label": "A", "file_type": "code", "source_file": "a.py"}, + {"id": "b", "label": "B", "file_type": "code", "source_file": "b.py"}, + ], + "edges": [ + { + "source": "a", "target": "b", + "key": "user-key", + "relation": "calls", "confidence": "EXTRACTED", + "source_file": "a.py", "context": "first", + }, + { + "source": "a", "target": "b", + "key": "user-key", + "relation": "calls", "confidence": "EXTRACTED", + "source_file": "a.py", "context": "second", + }, + ], + "input_tokens": 0, "output_tokens": 0, + } + G = build_from_json(extraction, multigraph=True) + keys = set(G["a"]["b"].keys()) + assert "user-key" in keys, "First edge must retain the explicit user-supplied key" + assert len(keys) == 2, "Both edges must survive; second gets a repair key" diff --git a/tests/test_edge_identity.py b/tests/test_edge_identity.py new file mode 100644 index 000000000..f80efb345 --- /dev/null +++ b/tests/test_edge_identity.py @@ -0,0 +1,85 @@ +"""Tests for graphify.edge_identity — schema constants and stable key helpers.""" + +from __future__ import annotations + +from graphify.edge_identity import SCHEMA_KEY_FIELD, make_stable_key, strip_schema_key + + +def test_schema_key_field_constant(): + assert SCHEMA_KEY_FIELD == "key" + + +def test_make_stable_key_deterministic(): + k1 = make_stable_key("calls", "src/a.py", "L10") + k2 = make_stable_key("calls", "src/a.py", "L10") + assert k1 == k2 + assert isinstance(k1, str) + assert k1 # non-empty + + +def test_make_stable_key_all_none(): + k = make_stable_key(None, None, None) + assert isinstance(k, str) + assert k # non-empty — never crashes or returns empty string + + +def test_make_stable_key_differs_by_source_location(): + k1 = make_stable_key("calls", "src/a.py", "L10") + k2 = make_stable_key("calls", "src/a.py", "L20") + assert k1 != k2 + + +def test_make_stable_key_identical_fields_match(): + k1 = make_stable_key("imports", "graphify/build.py", "L42") + k2 = make_stable_key("imports", "graphify/build.py", "L42") + assert k1 == k2 + + +def test_strip_schema_key_removes_key_field(): + attrs = {"key": "calls:a.py:L1", "relation": "calls", "confidence": "EXTRACTED"} + key_val, cleaned = strip_schema_key(attrs) + assert key_val == "calls:a.py:L1" + assert "key" not in cleaned + assert cleaned["relation"] == "calls" + assert cleaned["confidence"] == "EXTRACTED" + + +def test_strip_schema_key_no_key_present(): + attrs = {"relation": "imports", "confidence_score": 1.0} + key_val, cleaned = strip_schema_key(attrs) + assert key_val is None + assert cleaned == attrs + assert "key" not in cleaned + + +def test_strip_schema_key_does_not_mutate_input(): + attrs = {"key": "k1", "relation": "calls"} + original = dict(attrs) + strip_schema_key(attrs) + assert attrs == original + + +# --------------------------------------------------------------------------- +# Blocker 1: delimiter-collision safety +# --------------------------------------------------------------------------- + + +def test_make_stable_key_no_delimiter_collision(): + # "a:b","c","d" must not hash the same as "a","b:c","d" + k1 = make_stable_key("a:b", "c", "d") + k2 = make_stable_key("a", "b:c", "d") + assert k1 != k2 + + +def test_make_stable_key_format_is_versioned(): + k = make_stable_key("calls", "a.py", "L1") + assert k.startswith("edge:v1:") + + +def test_make_stable_key_none_differs_from_empty_and_unknown(): + # make_stable_key(None, None, None) must not collide with + # make_stable_key("unknown", "", "") — None must serialize as JSON null, + # not be normalised to "unknown"/"" before hashing. + k_none = make_stable_key(None, None, None) + k_defaults = make_stable_key("unknown", "", "") + assert k_none != k_defaults diff --git a/tests/test_graph_loader.py b/tests/test_graph_loader.py new file mode 100644 index 000000000..4fcd3ce45 --- /dev/null +++ b/tests/test_graph_loader.py @@ -0,0 +1,557 @@ +"""Tests for graphify.graph_loader — schema-aware graph loading. + +Seven required PR 2 scenarios from the Wave 3 handoff guardrails. +""" + +from __future__ import annotations + +from unittest.mock import patch + +import networkx as nx +import pytest + +from graphify.graph_loader import GRAPHIFY_PROFILE_KEY, load_graph + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +_NODES = [ + {"id": "a", "label": "A", "file_type": "code", "source_file": "a.py"}, + {"id": "b", "label": "B", "file_type": "code", "source_file": "b.py"}, +] + +_SIMPLE_EDGE = { + "source": "a", + "target": "b", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "a.py", + "weight": 1.0, +} + +_KEYED_EDGE = {**_SIMPLE_EDGE, "key": "calls:a.py:L1"} + +_KEYED_EDGE_2 = { + "source": "a", + "target": "b", + "relation": "imports", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "a.py", + "key": "imports:a.py:L5", + "weight": 1.0, +} + + +def _simple_links() -> dict: + """Legacy simple JSON using 'links' key.""" + return {"nodes": _NODES, "links": [_SIMPLE_EDGE]} + + +def _simple_edges() -> dict: + """Modern simple JSON using 'edges' key.""" + return {"nodes": _NODES, "edges": [_SIMPLE_EDGE]} + + +def _multigraph_data() -> dict: + """Valid multigraph node-link JSON with two keyed parallel edges.""" + return { + "multigraph": True, + "nodes": _NODES, + "links": [_KEYED_EDGE, _KEYED_EDGE_2], + } + + +def _multigraph_missing_keys() -> dict: + """Multigraph JSON where edges lack 'key' fields.""" + edge_no_key = {k: v for k, v in _SIMPLE_EDGE.items() if k != "key"} + return {"multigraph": True, "nodes": _NODES, "links": [edge_no_key]} + + +# --------------------------------------------------------------------------- +# Scenario 1: legacy 'links' loads as nx.Graph +# --------------------------------------------------------------------------- + + +def test_load_graph_rejects_non_object_payload(): + with pytest.raises(TypeError, match="serialized graph data must be a JSON object"): + load_graph([]) + + +def test_load_graph_rejects_non_list_nodes(): + data = {**_simple_links(), "nodes": 123} + + with pytest.raises(TypeError, match="'nodes' must be a list, got int"): + load_graph(data) + + +def test_legacy_links_loads_as_simple_graph(): + G = load_graph(_simple_links()) + assert type(G) is nx.Graph + assert not G.is_multigraph() + assert G.number_of_nodes() == 2 + assert G.number_of_edges() == 1 + + +# --------------------------------------------------------------------------- +# Scenario 2: modern 'edges' loads as nx.Graph +# --------------------------------------------------------------------------- + + +def test_modern_edges_loads_as_simple_graph(): + G = load_graph(_simple_edges()) + assert type(G) is nx.Graph + assert not G.is_multigraph() + assert G.number_of_edges() == 1 + + +# --------------------------------------------------------------------------- +# Scenario 3: valid multigraph JSON with keyed parallel edges → nx.MultiDiGraph +# --------------------------------------------------------------------------- + + +def test_valid_multigraph_loads_as_multidigraph(): + G = load_graph(_multigraph_data()) + assert type(G) is nx.MultiDiGraph + assert G.is_multigraph() + assert G.number_of_nodes() == 2 + assert G.number_of_edges() == 2 # both parallel edges preserved + + +# --------------------------------------------------------------------------- +# Scenario 4: malformed multigraph (missing keys) repairs explicitly, not silently +# --------------------------------------------------------------------------- + + +def test_malformed_multigraph_missing_keys_repairs_explicitly(capsys): + G = load_graph(_multigraph_missing_keys()) + # Must produce a MultiDiGraph (not silently fall back to simple) + assert type(G) is nx.MultiDiGraph + assert G.number_of_edges() == 1 + # Must warn to stderr + captured = capsys.readouterr() + assert "missing" in captured.err.lower() or "key" in captured.err.lower() + + +# --------------------------------------------------------------------------- +# Scenario 5: edge 'key' is stripped from attrs — not stored as an edge attribute +# --------------------------------------------------------------------------- + + +def test_schema_key_stripped_from_edge_attrs(): + G = load_graph(_multigraph_data()) + assert isinstance(G, nx.MultiDiGraph) + for u, v, k, data in G.edges(keys=True, data=True): + assert "key" not in data, ( + f"Edge ({u},{v},key={k!r}) must not store 'key' inside its attrs dict" + ) + + +# --------------------------------------------------------------------------- +# Scenario 6: G.graph["graphify_profile"] is present after load +# --------------------------------------------------------------------------- + + +def test_graph_profile_metadata_round_trips(): + G = load_graph(_simple_links()) + assert GRAPHIFY_PROFILE_KEY in G.graph + profile = G.graph[GRAPHIFY_PROFILE_KEY] + assert isinstance(profile, dict) + assert "graph_type" in profile + + +def test_graph_profile_type_for_multidigraph(): + G = load_graph(_multigraph_data()) + assert G.graph[GRAPHIFY_PROFILE_KEY]["graph_type"] == "multidigraph" + + +def test_graph_profile_type_for_simple(): + G = load_graph(_simple_links()) + assert G.graph[GRAPHIFY_PROFILE_KEY]["graph_type"] == "simple" + + +# --------------------------------------------------------------------------- +# Scenario 7: capability probe failure raises clearly; simple loading unaffected +# --------------------------------------------------------------------------- + + +def test_capability_probe_failure_raises_clear_error(): + with patch( + "graphify.graph_loader.require_multigraph_capabilities", + side_effect=RuntimeError("MultiDiGraph not supported: simulated failure"), + ): + with pytest.raises(RuntimeError, match="MultiDiGraph not supported"): + load_graph(_multigraph_data(), require_capabilities=True) + + +def test_capability_probe_failure_does_not_affect_simple_load(): + with patch( + "graphify.graph_loader.require_multigraph_capabilities", + side_effect=RuntimeError("should not be called"), + ): + # Simple JSON must not trigger the capability probe at all + G = load_graph(_simple_links(), require_capabilities=True) + assert type(G) is nx.Graph + + +# --------------------------------------------------------------------------- +# Blocker 2: missing-key repair must preserve distinct parallel edges +# --------------------------------------------------------------------------- + + +def _two_missing_key_parallel_edges() -> dict: + """Multigraph with two missing-key edges sharing relation/file but different attrs.""" + return { + "multigraph": True, + "nodes": _NODES, + "links": [ + { + "source": "a", + "target": "b", + "relation": "calls", + "source_file": "a.py", + "confidence": "EXTRACTED", + "weight": 1.0, + "context": "one", + }, + { + "source": "a", + "target": "b", + "relation": "calls", + "source_file": "a.py", + "confidence": "EXTRACTED", + "weight": 1.0, + "context": "two", + }, + ], + } + + +def test_missing_key_repair_preserves_distinct_parallel_edges(capsys): + G = load_graph(_two_missing_key_parallel_edges()) + assert type(G) is nx.MultiDiGraph + assert G.number_of_edges() == 2, ( + f"Both missing-key parallel edges must survive repair; got {G.number_of_edges()}" + ) + captured = capsys.readouterr() + assert "missing" in captured.err.lower() or "key" in captured.err.lower() + + +# --------------------------------------------------------------------------- +# Blocker 3: simple loader must respect serialized directedness +# --------------------------------------------------------------------------- + + +def test_directed_true_loads_as_digraph(): + data = { + "directed": True, + "multigraph": False, + "nodes": _NODES, + "edges": [_SIMPLE_EDGE], + } + G = load_graph(data) + assert type(G) is nx.DiGraph + + +def test_directed_false_explicitly_loads_as_graph(): + data = { + "directed": False, + "multigraph": False, + "nodes": _NODES, + "edges": [_SIMPLE_EDGE], + } + G = load_graph(data) + assert type(G) is nx.Graph + + +def test_directed_true_profile_graph_type(): + data = { + "directed": True, + "multigraph": False, + "nodes": _NODES, + "edges": [_SIMPLE_EDGE], + } + G = load_graph(data) + assert G.graph[GRAPHIFY_PROFILE_KEY]["graph_type"] == "digraph" + + +# --------------------------------------------------------------------------- +# Blocker 4: malformed JSON must fail cleanly or skip under documented policy +# --------------------------------------------------------------------------- + + +def test_non_dict_edge_entries_are_skipped(): + data = {"nodes": _NODES, "edges": ["not-a-dict", None, 42]} + G = load_graph(data) + assert G.number_of_edges() == 0 + + +def test_edges_value_not_a_list_raises(): + data = {"nodes": _NODES, "edges": "not-a-list"} + with pytest.raises((TypeError, ValueError)): + load_graph(data) + + +def test_non_dict_graphify_profile_is_ignored(): + data = { + "nodes": _NODES, + "edges": [_SIMPLE_EDGE], + GRAPHIFY_PROFILE_KEY: "bad-profile", + } + G = load_graph(data) + assert isinstance(G.graph[GRAPHIFY_PROFILE_KEY], dict) + assert "graph_type" in G.graph[GRAPHIFY_PROFILE_KEY] + + +def test_edge_missing_source_or_target_skipped(): + data = { + "nodes": _NODES, + "edges": [ + {"target": "b", "relation": "calls"}, + {"source": "a", "relation": "calls"}, + ], + } + G = load_graph(data) + assert G.number_of_edges() == 0 + + +# --------------------------------------------------------------------------- +# Non-string multigraph key values must raise before NetworkX sees them +# --------------------------------------------------------------------------- + + +def _multigraph_with_key(key_value: object) -> dict: + return { + "multigraph": True, + "nodes": _NODES, + "links": [{**_SIMPLE_EDGE, "key": key_value}], + } + + +def test_multigraph_list_key_raises(): + with pytest.raises((TypeError, ValueError)): + load_graph(_multigraph_with_key(["bad"])) + + +def test_multigraph_dict_key_raises(): + with pytest.raises((TypeError, ValueError)): + load_graph(_multigraph_with_key({"bad": 1})) + + +def test_multigraph_int_key_raises(): + with pytest.raises((TypeError, ValueError)): + load_graph(_multigraph_with_key(123)) + + +def test_load_simple_edge_with_empty_string_source_not_shadowed_by_from(): + # An edge with source="" AND from="a" must not silently use "from" as the + # source — an explicitly-set empty source means the edge is invalid. + data = { + "nodes": _NODES, + "links": [{"source": "", "from": "a", "target": "b", "relation": "calls"}], + } + G = load_graph(data) + assert G.number_of_edges() == 0 + + +def test_load_simple_edge_with_from_key_loaded(): + # Edges using legacy "from"/"to" keys should load correctly as long as + # the IDs are non-empty and present in the node set. + data = { + "nodes": _NODES, + "links": [{"from": "a", "to": "b", "relation": "calls"}], + } + G = load_graph(data) + assert G.number_of_edges() == 1 + + +def test_load_simple_preserves_falsy_hashable_ids(): + """Falsy-but-hashable node IDs like 0 or False must survive the loader.""" + data = { + "directed": False, + "multigraph": False, + "nodes": [{"id": 0}, {"id": ""}, {"id": "x"}], + "links": [ + {"source": 0, "target": "", "relation": "calls"}, + {"source": 0, "target": "x", "relation": "imports"}, + ], + } + G = load_graph(data) + assert G.number_of_nodes() == 3 + assert G.number_of_edges() == 2 + assert G.has_edge(0, "") + assert G.has_edge(0, "x") + + +def test_load_directed_preserves_falsy_hashable_ids(): + data = { + "directed": True, + "multigraph": False, + "nodes": [{"id": 0}, {"id": "y"}], + "links": [{"source": 0, "target": "y", "relation": "calls"}], + } + G = load_graph(data) + assert G.number_of_edges() == 1 + assert G.has_edge(0, "y") + + +def test_load_multigraph_preserves_falsy_hashable_ids(): + data = { + "directed": True, + "multigraph": True, + "nodes": [{"id": 0}, {"id": 1}], + "links": [ + {"source": 0, "target": 1, "key": "k1", "relation": "calls"}, + {"source": 0, "target": 1, "key": "k2", "relation": "imports"}, + ], + } + G = load_graph(data) + assert G.number_of_edges() == 2 + assert G.has_edge(0, 1) + + +def test_graph_attributes_round_trip_through_node_link_data(): + """G.graph[...] attrs must survive node_link_data → load_graph round-trip. + + NetworkX serializes graph-level metadata under data["graph"]; the loader + must read from there, not only from top-level keys. + """ + import networkx as nx + from networkx.readwrite import json_graph + + G_out = nx.DiGraph() + G_out.add_node("a") + G_out.add_node("b") + G_out.add_edge("a", "b", relation="calls") + G_out.graph["graphify_profile"] = {"graph_type": "digraph", "extra": "value"} + G_out.graph["hyperedges"] = [{"members": ["a", "b"]}] + G_out.graph["graphify_multigraph_diagnostics"] = {"collapsed": 0} + + data = json_graph.node_link_data(G_out, edges="links") + G_in = load_graph(data) + + assert G_in.graph["graphify_profile"]["extra"] == "value" + assert G_in.graph["hyperedges"] == [{"members": ["a", "b"]}] + assert G_in.graph["graphify_multigraph_diagnostics"] == {"collapsed": 0} + + +def test_graph_attributes_round_trip_through_multigraph_node_link_data(): + """Same round-trip guarantee for multigraph exports.""" + import networkx as nx + from networkx.readwrite import json_graph + + G_out = nx.MultiDiGraph() + G_out.add_node("a") + G_out.add_node("b") + G_out.add_edge("a", "b", key="k1", relation="calls") + G_out.add_edge("a", "b", key="k2", relation="imports") + G_out.graph["graphify_profile"] = {"graph_type": "multidigraph"} + G_out.graph["graphify_multigraph_diagnostics"] = {"exact_duplicates": 0} + + data = json_graph.node_link_data(G_out, edges="links") + G_in = load_graph(data, require_capabilities=False) + + assert G_in.graph["graphify_profile"]["graph_type"] == "multidigraph" + assert G_in.graph["graphify_multigraph_diagnostics"] == {"exact_duplicates": 0} + assert G_in.number_of_edges() == 2 + + +def test_load_skips_unhashable_node_ids(capsys): + """Corrupted graph.json with unhashable node ids must not crash; skip + warn.""" + data = { + "directed": True, + "multigraph": False, + "nodes": [{"id": "ok"}, {"id": ["unhashable"]}, {"id": {"also": "unhashable"}}], + "links": [{"source": "ok", "target": "ok", "relation": "self"}], + } + G = load_graph(data) + captured = capsys.readouterr() + assert G.number_of_nodes() == 1 + assert "unhashable" in captured.err.lower() + + +def test_load_skips_edges_with_unhashable_endpoints(): + """Edges with unhashable source/target must be skipped, not raise TypeError.""" + data = { + "directed": True, + "multigraph": False, + "nodes": [{"id": "a"}, {"id": "b"}], + "links": [ + {"source": "a", "target": "b", "relation": "calls"}, + {"source": ["unhashable"], "target": "b", "relation": "bogus"}, + {"source": "a", "target": {"also": "unhashable"}, "relation": "bogus"}, + ], + } + G = load_graph(data) + assert G.number_of_edges() == 1 + + +def test_load_multigraph_skips_unhashable_endpoints(): + """Same protection in the multigraph loader.""" + data = { + "directed": True, + "multigraph": True, + "nodes": [{"id": "a"}, {"id": "b"}], + "links": [ + {"source": "a", "target": "b", "key": "k1", "relation": "calls"}, + {"source": ["bad"], "target": "b", "key": "k2", "relation": "calls"}, + ], + } + G = load_graph(data, require_capabilities=False) + assert G.number_of_edges() == 1 + + +def test_load_multigraph_duplicate_keys_repaired_not_overwritten(capsys): + """Two parallel edges with same (src, tgt, key) but different attrs must both survive.""" + data = { + "directed": True, + "multigraph": True, + "nodes": [{"id": "a"}, {"id": "b"}], + "links": [ + {"source": "a", "target": "b", "key": "same", "relation": "calls", "context": "one"}, + {"source": "a", "target": "b", "key": "same", "relation": "calls", "context": "two"}, + ], + } + G = load_graph(data, require_capabilities=False) + assert G.number_of_edges() == 2, "duplicate-key edges must both be preserved via repair keys" + captured = capsys.readouterr() + assert "duplicate" in captured.err.lower() + + +def test_load_graph_rejects_non_bool_multigraph_field(): + """String 'false' or other non-bool 'multigraph' must be rejected, not coerced.""" + data = {**_simple_links(), "multigraph": "false"} + with pytest.raises(TypeError, match="'multigraph' must be a boolean"): + load_graph(data) + + +def test_load_graph_rejects_non_bool_directed_field(): + data = {**_simple_links(), "directed": "true"} + with pytest.raises(TypeError, match="'directed' must be a boolean"): + load_graph(data) + + +def test_load_multigraph_with_omitted_directed_does_not_warn(capsys): + """Missing 'directed' alongside 'multigraph: true' must not trigger the false warning.""" + data = { + "multigraph": True, + "nodes": [{"id": "a"}, {"id": "b"}], + "links": [{"source": "a", "target": "b", "key": "k", "relation": "calls"}], + } + load_graph(data, require_capabilities=False) + captured = capsys.readouterr() + assert "multigraph=true but directed=false" not in captured.err + + +def test_load_graph_overwrites_stale_graph_type_in_profile(): + """Stale graph_type from serialized profile must not survive when loading.""" + data = { + "multigraph": True, + "nodes": [{"id": "a"}, {"id": "b"}], + "links": [{"source": "a", "target": "b", "key": "k", "relation": "calls"}], + "graph": {"graphify_profile": {"graph_type": "simple"}}, + } + G = load_graph(data, require_capabilities=False) + assert G.graph[GRAPHIFY_PROFILE_KEY]["graph_type"] == "multidigraph" diff --git a/tests/test_multigraph_diagnostics.py b/tests/test_multigraph_diagnostics.py index 8c39b8e23..6c49c58cf 100644 --- a/tests/test_multigraph_diagnostics.py +++ b/tests/test_multigraph_diagnostics.py @@ -147,7 +147,10 @@ def test_diagnose_extraction_handles_malformed_shapes_without_crashing() -> None assert summary["missing_endpoint_edges"] == 1 assert summary["dangling_endpoint_edges"] == 2 assert summary["valid_candidate_edges"] == 1 - assert summary["post_build_error"].startswith("TypeError:") + assert summary["post_build_graph_type"] == "DiGraph" + assert summary["post_build_node_count"] == 2 + assert summary["post_build_edge_count"] == 1 + assert summary["post_build_error"] == "" def test_diagnose_extraction_handles_non_list_nodes_and_edges() -> None: @@ -228,7 +231,8 @@ def test_format_diagnostic_report_includes_build_and_suppression_errors( report = format_diagnostic_report(summary) - assert "post_build_error: TypeError:" in report + assert "post_build_error:" not in report + assert "post_build_graph_type: DiGraph" in report assert "producer_suppression_error: file not found" in report diff --git a/tests/test_projections.py b/tests/test_projections.py new file mode 100644 index 000000000..140274eb5 --- /dev/null +++ b/tests/test_projections.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +import networkx as nx +import pytest +from typing import Any, cast + +from graphify.projections import ( + distinct_neighbor_degree, + edge_records_between, + edge_summary_between, + normalize_to_multidigraph, + project_for_callflow, + project_for_community, + project_for_context, + project_for_path, +) + + +def _parallel_graph() -> nx.MultiDiGraph: + graph = nx.MultiDiGraph() + graph.graph["graphify_profile"] = "test-profile" + graph.add_node("a", label="A") + graph.add_node("b", label="B") + graph.add_node("c", label="C") + graph.add_edge( + "a", + "b", + key="calls-low", + relation="calls", + confidence="INFERRED", + confidence_score=0.4, + source_file="src/a.py", + source_location="L10", + context="code", + ) + graph.add_edge( + "a", + "b", + key="imports-high", + relation="imports", + confidence="EXTRACTED", + confidence_score=0.9, + source_file="src/a.py", + source_location="L2", + context="code", + ) + graph.add_edge( + "b", + "a", + key="returns", + relation="returns", + confidence="AMBIGUOUS", + confidence_score=0.2, + source_file="src/b.py", + source_location="L5", + context="runtime", + ) + graph.add_edge( + "b", + "c", + key="calls-c", + relation="calls", + confidence="EXTRACTED", + confidence_score=1.0, + source_file="src/b.py", + source_location="L7", + context="code", + ) + graph.add_edge("c", "c", key="self", relation="calls", confidence="EXTRACTED") + return graph + + +def test_project_for_community_returns_simple_weighted_copy() -> None: + projected = project_for_community(_parallel_graph(), weight_mode="count") + + assert type(projected) is nx.Graph + assert projected.graph["graphify_profile"] == "test-profile" + assert set(projected.nodes) == {"a", "b", "c"} + assert not projected.has_edge("c", "c") + assert projected["a"]["b"]["weight"] == 3.0 + assert projected["a"]["b"]["parallel_edge_count"] == 3 + assert projected["b"]["c"]["weight"] == 1.0 + + +def test_project_for_community_supports_confidence_and_sum_weight_modes() -> None: + graph = _parallel_graph() + + by_confidence = project_for_community(graph, weight_mode="confidence") + by_sum = project_for_community(graph, weight_mode="sum") + + assert by_confidence["a"]["b"]["weight"] == 0.9 + assert by_confidence["a"]["b"]["relation"] == "imports" + assert by_sum["a"]["b"]["weight"] == pytest.approx(1.5) + with pytest.raises(ValueError, match="weight_mode"): + project_for_community(graph, weight_mode=cast(Any, "unknown")) + + +def test_project_for_path_uses_simple_graph_not_multigraph_view() -> None: + projected = project_for_path(_parallel_graph()) + + assert type(projected) is nx.Graph + assert not projected.is_multigraph() + assert projected.number_of_edges("a", "b") == 1 + assert nx.shortest_path(projected, "a", "c") == ["a", "b", "c"] + + +def test_project_for_callflow_preserves_direction_and_filters_relations() -> None: + projected = project_for_callflow(_parallel_graph(), relations=frozenset({"calls"})) + + assert type(projected) is nx.DiGraph + assert set(projected.edges()) == {("a", "b"), ("b", "c")} + assert projected["a"]["b"]["relation"] == "calls" + + +def test_project_for_callflow_recovers_src_tgt_from_undirected_edges() -> None: + graph = nx.Graph() + graph.add_node("display_a") + graph.add_node("display_b") + graph.add_edge("display_a", "display_b", _src="real_src", _tgt="real_tgt", relation="calls") + + projected = project_for_callflow(graph) + + assert set(projected.edges()) == {("real_src", "real_tgt")} + assert projected["real_src"]["real_tgt"]["relation"] == "calls" + + +def test_project_for_context_preserves_multigraph_type_keys_and_metadata() -> None: + projected = project_for_context(_parallel_graph(), contexts=["code"]) + + assert isinstance(projected, nx.MultiDiGraph) + assert projected.graph["graphify_profile"] == "test-profile" + assert set(projected["a"]["b"]) == {"calls-low", "imports-high"} + assert "returns" not in projected.get_edge_data("b", "a", default={}) + + +def test_project_for_context_none_returns_copy_not_original() -> None: + graph = _parallel_graph() + + projected = project_for_context(graph) + + assert projected is not graph + assert isinstance(projected, nx.MultiDiGraph) + assert projected.number_of_edges() == graph.number_of_edges() + + +def test_project_for_context_empty_filter_is_noop_copy() -> None: + graph = _parallel_graph() + + projected = project_for_context(graph, contexts=[]) + + assert projected is not graph + assert projected.number_of_edges() == graph.number_of_edges() + + +def test_edge_records_between_returns_copies_from_both_directions() -> None: + graph = _parallel_graph() + + records = edge_records_between(graph, "a", "b") + + assert [record["relation"] for record in records] == ["imports", "calls", "returns"] + records[0]["relation"] = "mutated" + assert graph["a"]["b"]["imports-high"]["relation"] == "imports" + + +def test_edge_summary_between_counts_and_picks_representative() -> None: + summary = edge_summary_between(_parallel_graph(), "a", "b") + + assert summary["count"] == 3 + assert summary["relations"] == ["calls", "imports", "returns"] + assert summary["confidences"] == ["AMBIGUOUS", "EXTRACTED", "INFERRED"] + assert summary["representative"]["relation"] == "imports" + + +def test_distinct_neighbor_degree_does_not_count_parallel_edges() -> None: + graph = _parallel_graph() + + assert graph.degree("a") == 3 + assert distinct_neighbor_degree(graph, "a") == 1 + assert distinct_neighbor_degree(graph, "missing") == 0 + + +def test_normalize_to_multidigraph_preserves_parallel_keys_and_simple_edges() -> None: + graph = nx.MultiGraph() + graph.graph["name"] = "mixed" + graph.add_node("a", label="A") + graph.add_node("b", label="B") + graph.add_edge("a", "b", key="one", relation="calls") + graph.add_edge("a", "b", key="two", relation="imports") + + normalized = normalize_to_multidigraph(graph) + + assert isinstance(normalized, nx.MultiDiGraph) + assert normalized.graph["name"] == "mixed" + assert set(normalized["a"]["b"]) == {"one", "two"} + + simple = nx.Graph() + simple.add_edge("x", "y", relation="uses") + simple_normalized = normalize_to_multidigraph(simple) + + assert isinstance(simple_normalized, nx.MultiDiGraph) + assert simple_normalized.number_of_edges("x", "y") == 1 + assert next(iter(simple_normalized["x"]["y"].values()))["relation"] == "uses" diff --git a/tests/test_validate.py b/tests/test_validate.py index 396e90c8c..e5f9cd50f 100644 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -7,26 +7,37 @@ {"id": "n2", "label": "Bar", "file_type": "document", "source_file": "bar.md"}, ], "edges": [ - {"source": "n1", "target": "n2", "relation": "references", - "confidence": "EXTRACTED", "source_file": "foo.py", "weight": 1.0}, + { + "source": "n1", + "target": "n2", + "relation": "references", + "confidence": "EXTRACTED", + "source_file": "foo.py", + "weight": 1.0, + }, ], } + def test_valid_passes(): assert validate_extraction(VALID) == [] + def test_missing_nodes_key(): errors = validate_extraction({"edges": []}) assert any("nodes" in e for e in errors) + def test_missing_edges_key(): errors = validate_extraction({"nodes": []}) assert any("edges" in e for e in errors) + def test_not_a_dict(): errors = validate_extraction([]) assert len(errors) == 1 + def test_invalid_file_type(): data = { "nodes": [{"id": "n1", "label": "X", "file_type": "video", "source_file": "x.mp4"}], @@ -35,6 +46,7 @@ def test_invalid_file_type(): errors = validate_extraction(data) assert any("file_type" in e for e in errors) + def test_invalid_confidence(): data = { "nodes": [ @@ -42,35 +54,87 @@ def test_invalid_confidence(): {"id": "n2", "label": "B", "file_type": "code", "source_file": "b.py"}, ], "edges": [ - {"source": "n1", "target": "n2", "relation": "calls", - "confidence": "CERTAIN", "source_file": "a.py"}, + { + "source": "n1", + "target": "n2", + "relation": "calls", + "confidence": "CERTAIN", + "source_file": "a.py", + }, ], } errors = validate_extraction(data) assert any("confidence" in e for e in errors) + def test_dangling_edge_source(): data = { "nodes": [{"id": "n1", "label": "A", "file_type": "code", "source_file": "a.py"}], "edges": [ - {"source": "missing_id", "target": "n1", "relation": "calls", - "confidence": "EXTRACTED", "source_file": "a.py"}, + { + "source": "missing_id", + "target": "n1", + "relation": "calls", + "confidence": "EXTRACTED", + "source_file": "a.py", + }, ], } errors = validate_extraction(data) assert any("source" in e and "missing_id" in e for e in errors) + def test_dangling_edge_target(): data = { "nodes": [{"id": "n1", "label": "A", "file_type": "code", "source_file": "a.py"}], "edges": [ - {"source": "n1", "target": "ghost", "relation": "calls", - "confidence": "EXTRACTED", "source_file": "a.py"}, + { + "source": "n1", + "target": "ghost", + "relation": "calls", + "confidence": "EXTRACTED", + "source_file": "a.py", + }, ], } errors = validate_extraction(data) assert any("target" in e and "ghost" in e for e in errors) + +def test_unhashable_edge_source_reported_without_node_ids(): + data = { + "nodes": [], + "edges": [ + { + "source": ["bad"], + "target": "n1", + "relation": "calls", + "confidence": "EXTRACTED", + "source_file": "a.py", + }, + ], + } + errors = validate_extraction(data) + assert any("source is unhashable" in e for e in errors) + + +def test_unhashable_edge_target_reported_without_node_ids(): + data = { + "nodes": [], + "edges": [ + { + "source": "n1", + "target": {"bad": "target"}, + "relation": "calls", + "confidence": "EXTRACTED", + "source_file": "a.py", + }, + ], + } + errors = validate_extraction(data) + assert any("target is unhashable" in e for e in errors) + + def test_missing_node_field(): data = { "nodes": [{"id": "n1", "label": "A", "source_file": "a.py"}], # missing file_type @@ -79,9 +143,24 @@ def test_missing_node_field(): errors = validate_extraction(data) assert any("file_type" in e for e in errors) + def test_assert_valid_raises_on_errors(): with pytest.raises(ValueError, match="error"): assert_valid({"nodes": [], "edges": [], "oops": True, **{"nodes": "bad"}}) + def test_assert_valid_passes_silently(): assert_valid(VALID) # should not raise + + +def test_validate_extraction_does_not_typeerror_on_non_list_nodes(): + """validate_extraction must report 'nodes must be a list' without raising TypeError.""" + from graphify.validate import validate_extraction + errors = validate_extraction({"nodes": 123, "edges": []}) + assert any("'nodes' must be a list" in e for e in errors) + + +def test_validate_extraction_does_not_typeerror_on_non_list_edges(): + from graphify.validate import validate_extraction + errors = validate_extraction({"nodes": [], "edges": 42}) + assert any("'edges' must be a list" in e for e in errors) From b7e05f75257b3178e9eeaf51e17cc6b50c0d08c1 Mon Sep 17 00:00:00 2001 From: hypnwtykvmpr Date: Wed, 27 May 2026 15:33:28 -0500 Subject: [PATCH 02/21] chore: no-waiver lint/type/security cleanup + upstream v8 rebase Rebased onto upstream/v8 (740382a). Conflict in graphify/extract.py resolved preserving both upstream (TypeScript abstract class, C# base-list, Java interface inheritance, defusedxml) and local multigraph behavior. Full-repo ruff/pyright/security pass: 0 errors, 0 warnings, all .AUDIT gates clean. Added --no-viz support to graphify update. Raised AUDIT_COPILOT_MAX_DIFF_BYTES default from 120KB to 2MB. Updated AGENTS.md (no-waiver rule, conflict rule, removed stale memory block) and added CLAUDE.md with durable project policy. 1507 passed, ruff clean, pyright clean. gost --- .AUDIT/copilot-local-review.sh | 252 ++ CLAUDE.md | 32 + graphify/__init__.py | 1 + graphify/__main__.py | 957 ++++-- graphify/affected.py | 7 +- graphify/analyze.py | 297 +- graphify/benchmark.py | 30 +- graphify/cache.py | 10 +- graphify/callflow_html.py | 616 +++- graphify/cluster.py | 17 +- graphify/detect.py | 389 ++- graphify/export.py | 323 +- graphify/extract.py | 3591 ++++++++++++++------- graphify/global_graph.py | 7 +- graphify/google_workspace.py | 27 +- graphify/graph_loader.py | 7 +- graphify/hooks.py | 32 +- graphify/ingest.py | 55 +- graphify/llm.py | 76 +- graphify/mcp_ingest.py | 63 +- graphify/prs.py | 271 +- graphify/report.py | 61 +- graphify/security.py | 48 +- graphify/semantic_cleanup.py | 4 +- graphify/serve.py | 281 +- graphify/symbol_resolution.py | 25 +- graphify/transcribe.py | 33 +- graphify/tree_html.py | 43 +- graphify/watch.py | 208 +- graphify/wiki.py | 26 +- pyproject.toml | 4 +- tests/bench_extract.py | 12 +- tests/test_affected_cli.py | 16 +- tests/test_analyze.py | 473 ++- tests/test_astro_extraction.py | 1 + tests/test_benchmark.py | 54 +- tests/test_build.py | 95 +- tests/test_cache.py | 11 +- tests/test_callflow_html.py | 105 +- tests/test_charmap_encoding.py | 122 +- tests/test_chunking.py | 46 +- tests/test_claude_cli_backend.py | 63 +- tests/test_claude_md.py | 11 +- tests/test_cli_export.py | 41 +- tests/test_cluster.py | 9 +- tests/test_confidence.py | 69 +- tests/test_dedup.py | 19 +- tests/test_detect.py | 100 +- tests/test_devin.py | 30 +- tests/test_dotnet.py | 16 +- tests/test_explain_cli.py | 66 +- tests/test_export.py | 38 +- tests/test_extract.py | 237 +- tests/test_extract_cli.py | 26 +- tests/test_global_graph.py | 89 +- tests/test_google_workspace.py | 1 - tests/test_hooks.py | 7 +- tests/test_hypergraph.py | 68 +- tests/test_import_extension_resolution.py | 202 +- tests/test_incremental.py | 2 +- tests/test_ingest.py | 6 +- tests/test_install.py | 66 +- tests/test_install_strings.py | 8 +- tests/test_install_upgrade.py | 19 +- tests/test_js_import_resolution.py | 35 +- tests/test_languages.py | 392 ++- tests/test_llm_backends.py | 120 +- tests/test_llm_parser.py | 2 - tests/test_mcp_ingest.py | 136 +- tests/test_multilang.py | 84 +- tests/test_ollama.py | 6 +- tests/test_pascal.py | 104 +- tests/test_path_cli.py | 36 +- tests/test_pipeline.py | 22 +- tests/test_prs.py | 41 +- tests/test_python_import_resolution.py | 19 +- tests/test_query_cli.py | 1 + tests/test_rationale.py | 86 +- tests/test_report.py | 50 +- tests/test_security.py | 34 +- tests/test_semantic_similarity.py | 84 +- tests/test_serve.py | 74 +- tests/test_transcribe.py | 10 +- tests/test_validate.py | 4 +- tests/test_watch.py | 65 +- tests/test_wiki.py | 23 +- uv.lock | 10 +- 87 files changed, 8051 insertions(+), 3308 deletions(-) create mode 100755 .AUDIT/copilot-local-review.sh create mode 100644 CLAUDE.md diff --git a/.AUDIT/copilot-local-review.sh b/.AUDIT/copilot-local-review.sh new file mode 100755 index 000000000..46613fbac --- /dev/null +++ b/.AUDIT/copilot-local-review.sh @@ -0,0 +1,252 @@ +#!/usr/bin/env bash +set -uo pipefail +# Owner: Codex +# +# Private local gate: run GitHub Copilot CLI against the staged diff before a +# local commit. This is an early-warning review, not a replacement for the +# origin/upstream PR review gate. + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +cd "$ROOT" || exit 2 + +if [[ -f "$ROOT/.venv/bin/activate" ]]; then + # shellcheck disable=SC1091 + source "$ROOT/.venv/bin/activate" +fi + +if [[ "${AUDIT_PRIVATE_GUARD:-1}" != "0" && -x "$SCRIPT_DIR/private-guard.sh" ]]; then + "$SCRIPT_DIR/private-guard.sh" --quiet || exit "$?" +fi + +usage() { + cat <<'EOF' +Usage: .AUDIT/copilot-local-review.sh [--cached|--worktree|--base ] [--advisory] [--max-diff-bytes ] + +Runs GitHub Copilot CLI against a local diff and blocks when Copilot reports +actionable findings. + +Modes: + --cached review staged changes; default and intended for pre-commit + --worktree review unstaged worktree changes + --base review changes from merge-base(, HEAD) to HEAD + --advisory always exit 0 after saving/reporting review output + +Environment: + AUDIT_COPILOT_MAX_DIFF_BYTES default 120000; local review blocks above this + AUDIT_COPILOT_REPORT_DIR default .AUDIT/reports + +Output contract: + Copilot must emit exactly one decision line: + LOCAL_COPILOT_REVIEW_DECISION: PASS + LOCAL_COPILOT_REVIEW_DECISION: BLOCK + +PASS means no actionable correctness/security/regression/test issue was found. +BLOCK or ambiguous output stops the commit gate. +EOF +} + +mode="cached" +base_ref="" +advisory=0 +max_diff_bytes="${AUDIT_COPILOT_MAX_DIFF_BYTES:-2000000}" + +while [[ $# -gt 0 ]]; do + case "$1" in + --cached) + mode="cached" + ;; + --worktree) + mode="worktree" + ;; + --base) + if [[ $# -lt 2 ]]; then + echo "[copilot-local-review] USAGE_ERROR: --base requires a ref" >&2 + exit 2 + fi + mode="base" + base_ref="$2" + shift + ;; + --base=*) + mode="base" + base_ref="${1#--base=}" + ;; + --advisory) + advisory=1 + ;; + --max-diff-bytes) + if [[ $# -lt 2 ]]; then + echo "[copilot-local-review] USAGE_ERROR: --max-diff-bytes requires a number" >&2 + exit 2 + fi + max_diff_bytes="$2" + shift + ;; + --max-diff-bytes=*) + max_diff_bytes="${1#--max-diff-bytes=}" + ;; + --help|-h) + usage + exit 0 + ;; + *) + echo "[copilot-local-review] USAGE_ERROR: unknown option '$1'" >&2 + usage >&2 + exit 2 + ;; + esac + shift +done + +case "$max_diff_bytes" in + ''|*[!0-9]*) + echo "[copilot-local-review] USAGE_ERROR: max diff bytes must be a non-negative integer" >&2 + exit 2 + ;; +esac + +if ! command -v copilot >/dev/null 2>&1; then + echo "[copilot-local-review] BLOCKED: GitHub Copilot CLI is not installed or not on PATH" >&2 + echo "[copilot-local-review] Install/authenticate Copilot CLI or set AUDIT_SKIP_LOCAL_COPILOT=1 for an explicit bypass." >&2 + exit 1 +fi + +diff_file="$(mktemp)" +stat_file="$(mktemp)" +trap 'rm -f "$diff_file" "$stat_file"' EXIT + +if [[ "$mode" == "cached" ]]; then + git diff --cached --stat >"$stat_file" + git diff --cached --no-ext-diff --binary --unified=80 >"$diff_file" +elif [[ "$mode" == "worktree" ]]; then + git diff --stat >"$stat_file" + git diff --no-ext-diff --binary --unified=80 >"$diff_file" +else + merge_base="$(git merge-base HEAD "$base_ref")" || { + echo "[copilot-local-review] GIT_CONTEXT_ERROR: could not merge-base HEAD and $base_ref" >&2 + exit 2 + } + git diff --stat "$merge_base..HEAD" >"$stat_file" + git diff --no-ext-diff --binary --unified=80 "$merge_base..HEAD" >"$diff_file" +fi + +if [[ ! -s "$diff_file" ]]; then + echo "[copilot-local-review] clean: no diff to review for mode=$mode" + exit 0 +fi + +if grep -Eq '^(Binary files |GIT binary patch)' "$diff_file"; then + echo "[copilot-local-review] BLOCKED: binary diff present; Copilot local text review cannot inspect it reliably" >&2 + exit 1 +fi + +diff_bytes="$(wc -c <"$diff_file" | tr -d '[:space:]')" +if (( max_diff_bytes > 0 && diff_bytes > max_diff_bytes )); then + echo "[copilot-local-review] BLOCKED: diff is ${diff_bytes} bytes, above local review limit ${max_diff_bytes}" >&2 + echo "[copilot-local-review] Split the commit or use origin PR review as the authoritative review surface." >&2 + exit 1 +fi + +report_dir="${AUDIT_COPILOT_REPORT_DIR:-$SCRIPT_DIR/reports}" +mkdir -p "$report_dir" +report_file="$report_dir/$(date +%Y%m%d-%H%M%S)-copilot-local-review.md" + +prompt_file="$(mktemp)" +trap 'rm -f "$diff_file" "$stat_file" "$prompt_file"' EXIT + +{ + cat <<'EOF' +You are GitHub Copilot reviewing a local staged diff before commit. + +Review only the supplied diff. Do not edit files. Do not run tools. Do not ask +questions. Focus on correctness, security, data loss, regression risk, broken +tests, missing tests for changed behavior, and user-visible behavior. Ignore +pure style unless it can create functional risk. + +Your response MUST include exactly one decision line: + +LOCAL_COPILOT_REVIEW_DECISION: PASS + +or: + +LOCAL_COPILOT_REVIEW_DECISION: BLOCK + +Use PASS only if you find no actionable issue. Use BLOCK if there is any +actionable issue or if the diff is too incomplete to review safely. + +After the decision line, provide concise findings with file/path references +when blocking. If passing, provide a brief explanation of the risk areas you +checked. + +Diff stat: +EOF + cat "$stat_file" + printf '\nDiff:\n```diff\n' + cat "$diff_file" + printf '\n```\n' +} >"$prompt_file" + +echo "[copilot-local-review] invoking Copilot CLI mode=$mode diff_bytes=$diff_bytes" +set +e +copilot_output="$( + copilot \ + -p "$(cat "$prompt_file")" \ + --disable-builtin-mcps \ + --disallow-temp-dir \ + --no-color \ + --output-format text 2>&1 +)" +copilot_rc=$? +set -e + +{ + echo "# Local Copilot Review" + echo + echo "- Mode: \`$mode\`" + [[ -n "$base_ref" ]] && echo "- Base: \`$base_ref\`" + echo "- Diff bytes: \`$diff_bytes\`" + echo "- Copilot exit code: \`$copilot_rc\`" + echo "- Started: \`$(date -u +%Y-%m-%dT%H:%M:%SZ)\`" + echo + echo "## Diff Stat" + echo + echo '```text' + cat "$stat_file" + echo '```' + echo + echo "## Copilot Output" + echo + echo '```text' + printf '%s\n' "$copilot_output" + echo '```' +} >"$report_file" + +printf '%s\n' "$copilot_output" +echo "[copilot-local-review] report=$report_file" + +if (( advisory == 1 )); then + echo "[copilot-local-review] advisory mode: not blocking" + exit 0 +fi + +if (( copilot_rc != 0 )); then + echo "[copilot-local-review] BLOCKED: Copilot CLI exited $copilot_rc" >&2 + exit 1 +fi + +pass_count="$(grep -c '^LOCAL_COPILOT_REVIEW_DECISION: PASS$' "$report_file" || true)" +block_count="$(grep -c '^LOCAL_COPILOT_REVIEW_DECISION: BLOCK$' "$report_file" || true)" + +if [[ "$pass_count" == "1" && "$block_count" == "0" ]]; then + echo "[copilot-local-review] clean: Copilot returned PASS" + exit 0 +fi + +if [[ "$block_count" != "0" ]]; then + echo "[copilot-local-review] BLOCKED: Copilot returned BLOCK" >&2 + exit 1 +fi + +echo "[copilot-local-review] BLOCKED: Copilot output did not contain the required PASS decision line" >&2 +exit 1 diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..6f0ce17fe --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,32 @@ +## graphify + +This project has a graphify knowledge graph at graphify-out/. + +Rules: +- Before answering architecture or codebase questions, read graphify-out/GRAPH_REPORT.md for god nodes and community structure +- If graphify-out/wiki/index.md exists, navigate it instead of reading raw files +- After modifying code files in this session, run `graphify update .` to keep the graph current (AST-only, no API cost) + +## Remote and Rebase Policy + +Rules: +- `origin` is the active development target unless the user explicitly says otherwise. +- `upstream` is read-only/reference by default. Do not open, update, or describe work as ready for upstream contribution unless the user explicitly reopens that path. +- Still follow upstream Graphify closely: before planning implementation slices or rebases, fetch `origin` and `upstream`, verify current branch/HEAD/ahead-behind state, and compare live upstream changes that touch the same files. +- Prefer rebasing/pulling useful upstream changes into the local/origin development branch when doing so preserves the Graphify direction and reduces future drift. +- Upstream synchronization is preauthorized only for this Graphify/vampyre checkout (`/Users/jonathandirks/Development/vampyre`). Do not generalize this permission to any other repo or project. In this checkout, do not ask for human approval before stashing local work, fetching `origin`/`upstream`, comparing upstream changes, rebasing onto useful upstream updates, resolving conflicts by comparison, reapplying the stash, and rerunning verification. +- If upstream changes are harmful, irrelevant, or incompatible with the local Graphify direction, skip them only after comparing the change and recording the reason in the handoff. +- Helper scripts for stash/fetch/rebase workflows may be created only in local-only paths that are excluded from git, such as `.agent-local/`; do not add those helper scripts to GitHub-bound history. +- Absolute conflict rule: never resolve merge, rebase, cherry-pick, or generated-artifact conflicts by blindly choosing ours/theirs or by assuming the local branch is always correct. Always inspect and compare both sides, identify the intended behavior on each side, preserve useful upstream behavior unless it is deliberately incompatible with the local plan, and document the resolution in the handoff or commit notes. +- If conflict behavior is unclear after comparing both sides and the relevant tests/docs, stop and ask before resolving. +- For this checkout only, publishing a verified local/origin development branch is part of closing an upstream sync. Before pushing, run the full local verification stack, including local Copilot review, tests, lint/type/security/warning gates, pre-commit/pre-push gates, and graph refresh when applicable. +- After verification passes, fetch `origin` and compare. If local `HEAD` contains `origin` or deliberately supersedes it because origin-only commits are already integrated, duplicated by rebase, or intentionally rejected with the reason recorded, update `origin` with a normal push for fast-forwards or `--force-with-lease` for rewritten history. +- Do not leave `origin` behind a verified local stack merely because the local branch was rebased. If the lease fails, or if origin contains new valuable or unclear work that is not in local, stop, fetch, compare, and integrate or ask before publishing. +- This does not authorize publishing to `upstream`, changing remotes, or generalizing the Graphify/vampyre publication rule to any other repo. + +## Verification and Failure Policy + +Rules: +- Do not waive test failures, skips, warnings, linter findings, gate failures, graphify update failures, or audit findings as "pre-existing." If an issue is reproducible in the current workspace, it becomes the current agent's active task until resolved or until the user explicitly redirects the work. +- Do not mark a slice, PR, branch, or handoff as ready while any reproduced failure, skip, warning, or gate finding remains unresolved. +- If another agent reports an issue as pre-existing, independently reproduce it, root-cause it, fix it when it is in scope, and record the invalid waiver in the handoff or audit notes. diff --git a/graphify/__init__.py b/graphify/__init__.py index e34c938ef..e0f698b6e 100644 --- a/graphify/__init__.py +++ b/graphify/__init__.py @@ -22,6 +22,7 @@ def __getattr__(name): } if name in _map: import importlib + mod_name, attr = _map[name] mod = importlib.import_module(mod_name) return getattr(mod, attr) diff --git a/graphify/__main__.py b/graphify/__main__.py index 380572890..2d916b911 100644 --- a/graphify/__main__.py +++ b/graphify/__main__.py @@ -1,4 +1,5 @@ """graphify CLI - `graphify install` sets up the Claude Code skill.""" + from __future__ import annotations import json import os @@ -10,6 +11,7 @@ try: from importlib.metadata import version as _pkg_version + __version__ = _pkg_version("graphifyy") except Exception: __version__ = "unknown" @@ -35,6 +37,7 @@ def _enforce_graph_size_cap_or_exit(gp: Path) -> None: and let the ``ValueError`` propagate. """ from graphify.security import check_graph_file_size_cap + try: check_graph_file_size_cap(gp) except ValueError as exc: @@ -48,11 +51,16 @@ def _check_skill_version(skill_dst: Path) -> None: if not version_file.exists(): return if not skill_dst.exists(): - print(" warning: skill dir exists but SKILL.md is missing. Run 'graphify install' to repair.") + print( + " warning: skill dir exists but SKILL.md is missing. Run 'graphify install' to repair." + ) return installed = version_file.read_text(encoding="utf-8").strip() if installed != __version__: - print(f" warning: skill is from graphify {installed}, package is {__version__}. Run 'graphify install' to update.", file=sys.stderr) + print( + f" warning: skill is from graphify {installed}, package is {__version__}. Run 'graphify install' to update.", + file=sys.stderr, + ) def _refresh_all_version_stamps() -> None: @@ -68,7 +76,9 @@ def _refresh_all_version_stamps() -> None: vf.write_text(__version__, encoding="utf-8") -def _platform_skill_destination(platform_name: str, *, project: bool = False, project_dir: Path | None = None) -> Path: +def _platform_skill_destination( + platform_name: str, *, project: bool = False, project_dir: Path | None = None +) -> Path: """Return the skill destination for a platform and scope.""" if platform_name == "gemini": if project: @@ -102,9 +112,13 @@ def _platform_skill_destination(platform_name: str, *, project: bool = False, pr return Path.home() / cfg["skill_dst"] -def _copy_skill_file(platform_name: str, *, project: bool = False, project_dir: Path | None = None) -> Path: +def _copy_skill_file( + platform_name: str, *, project: bool = False, project_dir: Path | None = None +) -> Path: """Copy a packaged skill file and write its version stamp.""" - skill_file = "skill.md" if platform_name == "gemini" else _PLATFORM_CONFIG[platform_name]["skill_file"] + skill_file = ( + "skill.md" if platform_name == "gemini" else _PLATFORM_CONFIG[platform_name]["skill_file"] + ) skill_src = Path(__file__).parent / skill_file if not skill_src.exists(): print(f"error: {skill_file} not found in package - reinstall graphify", file=sys.stderr) @@ -127,7 +141,9 @@ def _copy_skill_file(platform_name: str, *, project: bool = False, project_dir: return skill_dst -def _remove_skill_file(platform_name: str, *, project: bool = False, project_dir: Path | None = None) -> bool: +def _remove_skill_file( + platform_name: str, *, project: bool = False, project_dir: Path | None = None +) -> bool: """Remove a platform skill file and its version stamp without touching other scopes.""" skill_dst = _platform_skill_destination(platform_name, project=project, project_dir=project_dir) removed = False @@ -187,6 +203,7 @@ def _print_project_git_add_hint(paths: list[Path]) -> None: print("Project-scoped install. Add to version control:") print(f" git add {' '.join(unique)}") + _SETTINGS_HOOK = { # Claude Code v2.1.117+ removed dedicated Grep/Glob tools; searches now go through Bash. # We match on Bash and inspect the command string to avoid firing on every shell call. @@ -195,10 +212,10 @@ def _print_project_git_add_hint(paths: list[Path]) -> None: { "type": "command", "command": ( - "CMD=$(python3 -c \"" + 'CMD=$(python3 -c "' "import json,sys; d=json.load(sys.stdin); " "print(d.get('tool_input',d).get('command',''))\" 2>/dev/null || true); " - "case \"$CMD\" in " + 'case "$CMD" in ' r"*grep*|*rg\ *|*ripgrep*|*find\ *|*fd\ *|*ack\ *|*ag\ *) " " [ -f graphify-out/graph.json ] && " r""" echo '{"hookSpecificOutput":{"hookEventName":"PreToolUse","additionalContext":"graphify: knowledge graph at graphify-out/. For focused questions, run `graphify query \"\"` (scoped subgraph, usually much smaller than GRAPH_REPORT.md) instead of grepping raw files. Read GRAPH_REPORT.md only for broad architecture context."}}' """ @@ -209,13 +226,14 @@ def _print_project_git_add_hint(paths: list[Path]) -> None: ], } + def _skill_registration(skill_path: str = "~/.claude/skills/graphify/SKILL.md") -> str: return ( "\n# graphify\n" f"- **graphify** (`{skill_path}`) " "- any input to knowledge graph. Trigger: `/graphify`\n" "When the user types `/graphify`, invoke the Skill tool " - "with `skill: \"graphify\"` before doing anything else.\n" + 'with `skill: "graphify"` before doing anything else.\n' ) @@ -360,7 +378,9 @@ def _replace_or_append_section(content: str, marker: str, new_section: str) -> s return out -def install(platform: str = "claude", *, project: bool = False, project_dir: Path | None = None) -> None: +def install( + platform: str = "claude", *, project: bool = False, project_dir: Path | None = None +) -> None: if platform == "gemini": gemini_install(project_dir=project_dir, project=project) return @@ -383,12 +403,18 @@ def install(platform: str = "claude", *, project: bool = False, project_dir: Pat if cfg["claude_md"]: # Register in the matching Claude Code scope. - claude_md = (project_dir / ".claude" / "CLAUDE.md") if project else Path.home() / ".claude" / "CLAUDE.md" - registration = _skill_registration(".claude/skills/graphify/SKILL.md" if project else "~/.claude/skills/graphify/SKILL.md") + claude_md = ( + (project_dir / ".claude" / "CLAUDE.md") + if project + else Path.home() / ".claude" / "CLAUDE.md" + ) + registration = _skill_registration( + ".claude/skills/graphify/SKILL.md" if project else "~/.claude/skills/graphify/SKILL.md" + ) if claude_md.exists(): content = claude_md.read_text(encoding="utf-8") if "graphify" in content: - print(f" CLAUDE.md -> already registered (no change)") + print(" CLAUDE.md -> already registered (no change)") else: claude_md.write_text(content.rstrip() + registration, encoding="utf-8") print(f" CLAUDE.md -> skill registered in {claude_md}") @@ -495,9 +521,7 @@ def gemini_install(project_dir: Path | None = None, *, project: bool = False) -> if target.exists(): content = target.read_text(encoding="utf-8") - new_content = _replace_or_append_section( - content, _GEMINI_MD_MARKER, _GEMINI_MD_SECTION - ) + new_content = _replace_or_append_section(content, _GEMINI_MD_MARKER, _GEMINI_MD_SECTION) else: new_content = _GEMINI_MD_SECTION @@ -511,7 +535,13 @@ def gemini_install(project_dir: Path | None = None, *, project: bool = False) -> # wording) is replaced on upgrade. _install_gemini_hook(project_dir) if project: - _print_project_git_add_hint([_project_scope_root(skill_dst, project_dir), project_dir / "GEMINI.md", project_dir / ".gemini"]) + _print_project_git_add_hint( + [ + _project_scope_root(skill_dst, project_dir), + project_dir / "GEMINI.md", + project_dir / ".gemini", + ] + ) print() print("Gemini CLI will now check the knowledge graph before answering") print("codebase questions and rebuild it after code changes.") @@ -521,7 +551,9 @@ def _install_gemini_hook(project_dir: Path) -> None: settings_path = project_dir / ".gemini" / "settings.json" settings_path.parent.mkdir(parents=True, exist_ok=True) try: - settings = json.loads(settings_path.read_text(encoding="utf-8")) if settings_path.exists() else {} + settings = ( + json.loads(settings_path.read_text(encoding="utf-8")) if settings_path.exists() else {} + ) except json.JSONDecodeError: settings = {} before_tool = settings.setdefault("hooks", {}).setdefault("BeforeTool", []) @@ -615,7 +647,9 @@ def vscode_install(project_dir: Path | None = None) -> None: print(f" {instructions} -> already configured (no change)") else: instructions.write_text(new_content, encoding="utf-8") - print(f" {instructions} -> graphify section {'updated' if _VSCODE_INSTRUCTIONS_MARKER in content else 'added'}") + print( + f" {instructions} -> graphify section {'updated' if _VSCODE_INSTRUCTIONS_MARKER in content else 'added'}" + ) else: instructions.write_text(_VSCODE_INSTRUCTIONS_SECTION, encoding="utf-8") print(f" {instructions} -> created") @@ -720,7 +754,7 @@ def _kiro_install(project_dir: Path) -> None: steering_dir.mkdir(parents=True, exist_ok=True) steering_dst = steering_dir / "graphify.md" if steering_dst.exists() and steering_dst.read_text(encoding="utf-8") == _KIRO_STEERING: - print(f" .kiro/steering/graphify.md -> already configured (no change)") + print(" .kiro/steering/graphify.md -> already configured (no change)") else: # File is wholly graphify-owned. Overwrite on upgrade so older # report-first wording does not silently linger (issue #580). @@ -801,11 +835,15 @@ def _antigravity_install(project_dir: Path) -> None: print("Antigravity will now check the knowledge graph before answering") print("codebase questions. Run /graphify first to build the graph.") print() - print("To enable full MCP architecture navigation, add this to ~/.gemini/antigravity/mcp_config.json:") + print( + "To enable full MCP architecture navigation, add this to ~/.gemini/antigravity/mcp_config.json:" + ) print(' "graphify": {') print(' "command": "uv",') - print(' "args": ["run", "--with", "graphifyy", "--with", "mcp", "-m", "graphify.serve", "${workspace.path}/graphify-out/graph.json"]') - print(' }') + print( + ' "args": ["run", "--with", "graphifyy", "--with", "mcp", "-m", "graphify.serve", "${workspace.path}/graphify-out/graph.json"]' + ) + print(" }") def _antigravity_uninstall(project_dir: Path, *, project: bool = False) -> None: @@ -1029,6 +1067,7 @@ def _resolve_graphify_exe() -> str: not on PATH (e.g. VS Code Codex extension on Windows). """ import shutil + found = shutil.which("graphify") if found: return found @@ -1086,7 +1125,7 @@ def _uninstall_codex_hook(project_dir: Path) -> None: filtered = [h for h in pre_tool if "graphify" not in str(h)] existing["hooks"]["PreToolUse"] = filtered hooks_path.write_text(json.dumps(existing, indent=2), encoding="utf-8") - print(f" .codex/hooks.json -> PreToolUse hook removed") + print(" .codex/hooks.json -> PreToolUse hook removed") def _agents_install(project_dir: Path, platform: str) -> None: @@ -1095,9 +1134,7 @@ def _agents_install(project_dir: Path, platform: str) -> None: if target.exists(): content = target.read_text(encoding="utf-8") - new_content = _replace_or_append_section( - content, _AGENTS_MD_MARKER, _AGENTS_MD_SECTION - ) + new_content = _replace_or_append_section(content, _AGENTS_MD_MARKER, _AGENTS_MD_SECTION) else: new_content = _AGENTS_MD_SECTION @@ -1136,7 +1173,17 @@ def _project_install(platform_name: str, project_dir: Path | None = None) -> Non elif platform_name == "kiro": _kiro_install(project_dir) _print_project_git_add_hint([project_dir / ".kiro"]) - elif platform_name in ("aider", "amp", "codex", "opencode", "claw", "droid", "trae", "trae-cn", "hermes"): + elif platform_name in ( + "aider", + "amp", + "codex", + "opencode", + "claw", + "droid", + "trae", + "trae-cn", + "hermes", + ): skill_dst = _copy_skill_file(platform_name, project=True, project_dir=project_dir) _agents_install(project_dir, platform_name) hint_paths = [_project_scope_root(skill_dst, project_dir), project_dir / "AGENTS.md"] @@ -1148,7 +1195,9 @@ def _project_install(platform_name: str, project_dir: Path | None = None) -> Non elif platform_name == "devin": skill_dst = _copy_skill_file("devin", project=True, project_dir=project_dir) _devin_rules_install(project_dir) - _print_project_git_add_hint([_project_scope_root(skill_dst, project_dir), project_dir / ".windsurf"]) + _print_project_git_add_hint( + [_project_scope_root(skill_dst, project_dir), project_dir / ".windsurf"] + ) elif platform_name in ("copilot", "pi", "antigravity", "kimi"): skill_dst = _copy_skill_file(platform_name, project=True, project_dir=project_dir) _print_project_git_add_hint([_project_scope_root(skill_dst, project_dir)]) @@ -1169,7 +1218,17 @@ def _project_uninstall(platform_name: str, project_dir: Path | None = None) -> N _cursor_uninstall(project_dir) elif platform_name == "kiro": _kiro_uninstall(project_dir) - elif platform_name in ("aider", "amp", "codex", "opencode", "claw", "droid", "trae", "trae-cn", "hermes"): + elif platform_name in ( + "aider", + "amp", + "codex", + "opencode", + "claw", + "droid", + "trae", + "trae-cn", + "hermes", + ): _remove_skill_file(platform_name, project=True, project_dir=project_dir) _agents_uninstall(project_dir, platform=platform_name) if platform_name == "codex": @@ -1236,9 +1295,7 @@ def claude_install(project_dir: Path | None = None) -> None: if target.exists(): content = target.read_text(encoding="utf-8") - new_content = _replace_or_append_section( - content, _CLAUDE_MD_MARKER, _CLAUDE_MD_SECTION - ) + new_content = _replace_or_append_section(content, _CLAUDE_MD_MARKER, _CLAUDE_MD_SECTION) else: new_content = _CLAUDE_MD_SECTION @@ -1273,10 +1330,14 @@ def _install_claude_hook(project_dir: Path) -> None: hooks = settings.setdefault("hooks", {}) pre_tool = hooks.setdefault("PreToolUse", []) - hooks["PreToolUse"] = [h for h in pre_tool if not (h.get("matcher") in ("Glob|Grep", "Bash") and "graphify" in str(h))] + hooks["PreToolUse"] = [ + h + for h in pre_tool + if not (h.get("matcher") in ("Glob|Grep", "Bash") and "graphify" in str(h)) + ] hooks["PreToolUse"].append(_SETTINGS_HOOK) settings_path.write_text(json.dumps(settings, indent=2), encoding="utf-8") - print(f" .claude/settings.json -> PreToolUse hook registered") + print(" .claude/settings.json -> PreToolUse hook registered") def _uninstall_claude_hook(project_dir: Path) -> None: @@ -1289,12 +1350,16 @@ def _uninstall_claude_hook(project_dir: Path) -> None: except json.JSONDecodeError: return pre_tool = settings.get("hooks", {}).get("PreToolUse", []) - filtered = [h for h in pre_tool if not (h.get("matcher") in ("Glob|Grep", "Bash") and "graphify" in str(h))] + filtered = [ + h + for h in pre_tool + if not (h.get("matcher") in ("Glob|Grep", "Bash") and "graphify" in str(h)) + ] if len(filtered) == len(pre_tool): return settings["hooks"]["PreToolUse"] = filtered settings_path.write_text(json.dumps(settings, indent=2), encoding="utf-8") - print(f" .claude/settings.json -> PreToolUse hook removed") + print(" .claude/settings.json -> PreToolUse hook removed") def uninstall_all(project_dir: Path | None = None, purge: bool = False) -> None: @@ -1317,18 +1382,20 @@ def uninstall_all(project_dir: Path | None = None, purge: bool = False) -> None: # Git hook try: from graphify.hooks import uninstall as hook_uninstall + result = hook_uninstall(pd) if result: print(result) - except Exception: - pass + except Exception as exc: + print(f"[graphify] warning: could not uninstall git hook: {exc}", file=sys.stderr) if purge: import shutil as _shutil + out = pd / "graphify-out" if out.exists(): _shutil.rmtree(out) - print(f"\n graphify-out/ -> deleted (--purge)") + print("\n graphify-out/ -> deleted (--purge)") else: print("\n graphify-out/ -> not found (nothing to purge)") @@ -1403,7 +1470,7 @@ def _clone_repo(url: str, branch: str | None = None, out_dir: Path | None = None cmd = ["git", "-C", str(dest), "pull"] if branch: cmd += ["origin", "--", branch] - result = _sp.run(cmd, capture_output=True, text=True) + result = _sp.run(cmd, capture_output=True, text=True) # nosec B603 if result.returncode != 0: print(f"warning: git pull failed:\n{result.stderr}", file=sys.stderr) else: @@ -1413,7 +1480,7 @@ def _clone_repo(url: str, branch: str | None = None, out_dir: Path | None = None if branch: cmd += ["--branch", branch] cmd += ["--", git_url, str(dest)] - result = _sp.run(cmd, capture_output=True, text=True) + result = _sp.run(cmd, capture_output=True, text=True) # nosec B603 if result.returncode != 0: print(f"error: git clone failed:\n{result.stderr}", file=sys.stderr) sys.exit(1) @@ -1446,12 +1513,14 @@ def main() -> None: print("Usage: graphify ") print() print("Commands:") - print(" install [--platform P] copy skill to platform config dir (claude|windows|codex|opencode|aider|claw|droid|trae|trae-cn|gemini|cursor|antigravity|hermes|kiro|pi|devin)") + print( + " install [--platform P] copy skill to platform config dir (claude|windows|codex|opencode|aider|claw|droid|trae|trae-cn|gemini|cursor|antigravity|hermes|kiro|pi|devin)" + ) print(" uninstall remove graphify from all detected platforms in one shot") print(" --purge also delete graphify-out/ directory") - print(" path \"A\" \"B\" shortest path between two nodes in graph.json") + print(' path "A" "B" shortest path between two nodes in graph.json') print(" --graph path to graph.json (default graphify-out/graph.json)") - print(" explain \"X\" plain-language explanation of a node and its neighbors") + print(' explain "X" plain-language explanation of a node and its neighbors') print(" --graph path to graph.json (default graphify-out/graph.json)") print(" diagnose multigraph report same-endpoint edge collapse risk in graph.json") print(" --graph path to graph/extraction JSON") @@ -1463,40 +1532,64 @@ def main() -> None: print(" (default follows JSON directed flag;") print(" raw extraction with no flag defaults directed)") print(" --extract-path PATH extractor source for suppression scan") - print(" clone clone a GitHub repo locally and print its path for /graphify") - print(" merge-driver git merge driver: union-merge two graph.json files (set up via hook install)") - print(" merge-graphs merge two or more graph.json files into one cross-repo graph") + print( + " clone clone a GitHub repo locally and print its path for /graphify" + ) + print( + " merge-driver git merge driver: union-merge two graph.json files (set up via hook install)" + ) + print( + " merge-graphs merge two or more graph.json files into one cross-repo graph" + ) print(" --out output path (default: graphify-out/merged-graph.json)") print(" --branch checkout a specific branch (default: repo default)") - print(" --out clone to a custom directory (default: ~/.graphify/repos//)") + print( + " --out clone to a custom directory (default: ~/.graphify/repos//)" + ) print(" add fetch a URL and save it to ./raw, then update the graph") - print(" --author \"Name\" tag the author of the content") - print(" --contributor \"Name\" tag who added it to the corpus") + print(' --author "Name" tag the author of the content') + print(' --contributor "Name" tag who added it to the corpus') print(" --dir target directory (default: ./raw)") print(" watch watch a folder and rebuild the graph on code changes") - print(" update re-extract code files and update the graph (no LLM needed)") - print(" --force overwrite graph.json even if the rebuild has fewer nodes") - print(" (also: GRAPHIFY_FORCE=1 env var; use after refactors that delete code)") + print( + " update re-extract code files and update the graph (no LLM needed)" + ) + print( + " --force overwrite graph.json even if the rebuild has fewer nodes" + ) + print( + " (also: GRAPHIFY_FORCE=1 env var; use after refactors that delete code)" + ) print(" --no-cluster skip clustering, write raw extraction only") - print(" cluster-only rerun clustering on an existing graph.json and regenerate report") - print(" --no-viz skip graph.html generation (useful for >5000 node graphs / CI)") - print(" --graph path to graph.json (default /graphify-out/graph.json)") - print(" query \"\" BFS traversal of graph.json for a question") + print( + " cluster-only rerun clustering on an existing graph.json and regenerate report" + ) + print( + " --no-viz skip graph.html generation (useful for >5000 node graphs / CI)" + ) + print( + " --graph path to graph.json (default /graphify-out/graph.json)" + ) + print(' query "" BFS traversal of graph.json for a question') print(" --dfs use depth-first instead of breadth-first") print(" --context C explicit edge-context filter (repeatable)") print(" --budget N cap output at N tokens (default 2000)") print(" --graph path to graph.json (default graphify-out/graph.json)") - print(" affected \"X\" reverse traversal to find nodes impacted by X") + print(' affected "X" reverse traversal to find nodes impacted by X') print(" --relation R edge relation to traverse in reverse (repeatable)") print(" --depth N reverse traversal depth (default 2)") print(" --graph path to graph.json (default graphify-out/graph.json)") - print(" save-result save a Q&A result to graphify-out/memory/ for graph feedback loop") + print( + " save-result save a Q&A result to graphify-out/memory/ for graph feedback loop" + ) print(" --question Q the question asked") print(" --answer A the answer to save") print(" --type T query type: query|path_query|explain (default: query)") print(" --nodes N1 N2 ... source node labels cited in the answer") print(" --memory-dir DIR memory directory (default: graphify-out/memory)") - print(" check-update check needs_update flag and notify if semantic re-extraction is pending (cron-safe)") + print( + " check-update check needs_update flag and notify if semantic re-extraction is pending (cron-safe)" + ) print(" tree emit a D3 v7 collapsible-tree HTML for graph.json") print(" --graph PATH path to graph.json (default graphify-out/graph.json)") print(" --output HTML output path (default graphify-out/GRAPH_TREE.html)") @@ -1504,44 +1597,70 @@ def main() -> None: print(" --max-children N cap children per node (default 200)") print(" --top-k-edges N per-symbol outbound edges in inspector (default 12)") print(" --label NAME project label in header") - print(" extract headless full extraction (AST + semantic LLM) for CI/scripts") - print(" --backend B gemini|kimi|claude|openai|deepseek|ollama (default: whichever API key is set)") + print( + " extract headless full extraction (AST + semantic LLM) for CI/scripts" + ) + print( + " --backend B gemini|kimi|claude|openai|deepseek|ollama (default: whichever API key is set)" + ) print(" --model M override backend default model") print(" --mode deep aggressive INFERRED-edge semantic extraction") print(" --max-workers N AST extraction subprocess count (default: cpu_count)") - print(" --token-budget N per-chunk token cap for semantic extraction (default: 60000)") - print(" --max-concurrency N parallel semantic chunks in flight (default: 4; set 1 for local LLMs)") - print(" --api-timeout S per-request timeout in seconds for the LLM client (default: 600)") - print(" --out DIR output dir (default: ); writes /graphify-out/") - print(" --google-workspace export .gdoc/.gsheet/.gslides shortcuts via gws before extraction") + print( + " --token-budget N per-chunk token cap for semantic extraction (default: 60000)" + ) + print( + " --max-concurrency N parallel semantic chunks in flight (default: 4; set 1 for local LLMs)" + ) + print( + " --api-timeout S per-request timeout in seconds for the LLM client (default: 600)" + ) + print( + " --out DIR output dir (default: ); writes /graphify-out/" + ) + print( + " --google-workspace export .gdoc/.gsheet/.gslides shortcuts via gws before extraction" + ) print(" --no-cluster skip clustering, write raw extraction only") print(" --global also merge the resulting graph into the global graph") print(" --as repo tag for --global (default: target directory name)") - print(" global add add/update a project graph in the global graph (~/.graphify/global-graph.json)") + print( + " global add add/update a project graph in the global graph (~/.graphify/global-graph.json)" + ) print(" --as repo tag (default: parent directory name)") print(" global remove remove a repo's nodes from the global graph") print(" global list list repos in the global graph") print(" global path print path to the global graph file") print(" benchmark [graph.json] measure token reduction vs naive full-corpus approach") print(" export callflow-html emit Mermaid-based architecture/call-flow HTML") - print(" hook install install post-commit/post-checkout git hooks (all platforms)") + print( + " hook install install post-commit/post-checkout git hooks (all platforms)" + ) print(" hook uninstall remove git hooks") print(" hook status check if git hooks are installed") print(" gemini install write GEMINI.md section + BeforeTool hook (Gemini CLI)") print(" gemini uninstall remove GEMINI.md section + BeforeTool hook") print(" cursor install write .cursor/rules/graphify.mdc (Cursor)") print(" cursor uninstall remove .cursor/rules/graphify.mdc") - print(" claude install write graphify section to CLAUDE.md + PreToolUse hook (Claude Code)") + print( + " claude install write graphify section to CLAUDE.md + PreToolUse hook (Claude Code)" + ) print(" claude uninstall remove graphify section from CLAUDE.md + PreToolUse hook") print(" codex install write graphify section to AGENTS.md (Codex)") print(" codex uninstall remove graphify section from AGENTS.md") - print(" opencode install write graphify section to AGENTS.md + tool.execute.before plugin (OpenCode)") + print( + " opencode install write graphify section to AGENTS.md + tool.execute.before plugin (OpenCode)" + ) print(" opencode uninstall remove graphify section from AGENTS.md + plugin") print(" aider install write graphify section to AGENTS.md (Aider)") print(" aider uninstall remove graphify section from AGENTS.md") - print(" copilot install copy graphify skill to ~/.copilot/skills (GitHub Copilot CLI)") + print( + " copilot install copy graphify skill to ~/.copilot/skills (GitHub Copilot CLI)" + ) print(" copilot uninstall remove graphify skill from ~/.copilot/skills") - print(" vscode install configure VS Code Copilot Chat (skill + .github/copilot-instructions.md)") + print( + " vscode install configure VS Code Copilot Chat (skill + .github/copilot-instructions.md)" + ) print(" vscode uninstall remove VS Code Copilot Chat configuration") print(" claw install write graphify section to AGENTS.md (OpenClaw)") print(" claw uninstall remove graphify section from AGENTS.md") @@ -1551,15 +1670,23 @@ def main() -> None: print(" trae uninstall remove graphify section from AGENTS.md") print(" trae-cn install write graphify section to AGENTS.md (Trae CN)") print(" trae-cn uninstall remove graphify section from AGENTS.md") - print(" antigravity install write .agents/rules + .agents/workflows + skill (Google Antigravity)") + print( + " antigravity install write .agents/rules + .agents/workflows + skill (Google Antigravity)" + ) print(" antigravity uninstall remove .agents/rules, .agents/workflows, and skill") print(" hermes install write skill to ~/.hermes/skills/graphify/ (Hermes)") print(" hermes uninstall remove skill from ~/.hermes/skills/graphify/") - print(" kiro install write skill to .kiro/skills/graphify/ + steering file (Kiro IDE/CLI)") + print( + " kiro install write skill to .kiro/skills/graphify/ + steering file (Kiro IDE/CLI)" + ) print(" kiro uninstall remove skill + steering file") - print(" pi install write skill to ~/.pi/agent/skills/graphify/ (Pi coding agent)") + print( + " pi install write skill to ~/.pi/agent/skills/graphify/ (Pi coding agent)" + ) print(" pi uninstall remove skill from ~/.pi/agent/skills/graphify/") - print(" devin install write skill to ~/.config/devin/skills/graphify/ (Devin CLI)") + print( + " devin install write skill to ~/.config/devin/skills/graphify/ (Devin CLI)" + ) print(" devin uninstall remove skill from ~/.config/devin/skills/graphify/") print() return @@ -1573,7 +1700,7 @@ def main() -> None: # "install"/"uninstall" which have their own per-subcommand help handlers. _FREE_TEXT_CMDS = {"query", "explain", "path", "save-result", "install", "uninstall"} if cmd not in _FREE_TEXT_CMDS and any(a in {"-h", "--help", "-?"} for a in sys.argv[2:]): - print(f"Run 'graphify --help' for full usage.") + print("Run 'graphify --help' for full usage.") return if cmd == "install": @@ -1899,9 +2026,15 @@ def main() -> None: sys.exit(1) elif cmd == "prs": from graphify.prs import cmd_prs + cmd_prs(sys.argv[2:]) elif cmd == "hook": - from graphify.hooks import install as hook_install, uninstall as hook_uninstall, status as hook_status + from graphify.hooks import ( + install as hook_install, + uninstall as hook_uninstall, + status as hook_status, + ) + subcmd = sys.argv[2] if len(sys.argv) > 2 else "" if subcmd == "install": print(hook_install(Path("."))) @@ -1914,11 +2047,14 @@ def main() -> None: sys.exit(1) elif cmd == "query": if len(sys.argv) < 3: - print("Usage: graphify query \"\" [--dfs] [--context C] [--budget N] [--graph path]", file=sys.stderr) + print( + 'Usage: graphify query "" [--dfs] [--context C] [--budget N] [--graph path]', + file=sys.stderr, + ) sys.exit(1) from graphify.serve import _query_graph_text - from graphify.security import sanitize_label from networkx.readwrite import json_graph + question = sys.argv[2] use_dfs = "--dfs" in sys.argv budget = 2000 @@ -1931,14 +2067,14 @@ def main() -> None: try: budget = int(args[i + 1]) except ValueError: - print(f"error: --budget must be an integer", file=sys.stderr) + print("error: --budget must be an integer", file=sys.stderr) sys.exit(1) i += 2 elif args[i].startswith("--budget="): try: budget = int(args[i].split("=", 1)[1]) except ValueError: - print(f"error: --budget must be an integer", file=sys.stderr) + print("error: --budget must be an integer", file=sys.stderr) sys.exit(1) i += 1 elif args[i] == "--context" and i + 1 < len(args): @@ -1948,7 +2084,8 @@ def main() -> None: context_filters.append(args[i].split("=", 1)[1]) i += 1 elif args[i] == "--graph" and i + 1 < len(args): - graph_path = args[i + 1]; i += 2 + graph_path = args[i + 1] + i += 2 else: i += 1 gp = Path(graph_path).resolve() @@ -1956,12 +2093,13 @@ def main() -> None: print(f"error: graph file not found: {gp}", file=sys.stderr) sys.exit(1) if not gp.suffix == ".json": - print(f"error: graph file must be a .json file", file=sys.stderr) + print("error: graph file must be a .json file", file=sys.stderr) sys.exit(1) _enforce_graph_size_cap_or_exit(gp) try: import json as _json import networkx as _nx + _raw = _json.loads(gp.read_text(encoding="utf-8")) if "links" not in _raw and "edges" in _raw: _raw = dict(_raw, links=_raw["edges"]) @@ -1984,9 +2122,13 @@ def main() -> None: ) elif cmd == "affected": if len(sys.argv) < 3: - print("Usage: graphify affected \"\" [--relation R] [--depth N] [--graph path]", file=sys.stderr) + print( + 'Usage: graphify affected "" [--relation R] [--depth N] [--graph path]', + file=sys.stderr, + ) sys.exit(1) from graphify.affected import DEFAULT_AFFECTED_RELATIONS, format_affected, load_graph + query = sys.argv[2] graph_path = "graphify-out/graph.json" depth = 2 @@ -2045,6 +2187,7 @@ def main() -> None: elif cmd == "save-result": # graphify save-result --question Q --answer A --type T [--nodes N1 N2 ...] import argparse as _ap + p = _ap.ArgumentParser(prog="graphify save-result") p.add_argument("--question", required=True) p.add_argument("--answer", required=True) @@ -2053,6 +2196,7 @@ def main() -> None: p.add_argument("--memory-dir", default="graphify-out/memory") opts = p.parse_args(sys.argv[2:]) from graphify.ingest import save_query_result as _sqr + out = _sqr( question=opts.question, answer=opts.answer, @@ -2063,11 +2207,12 @@ def main() -> None: print(f"Saved to {out}") elif cmd == "path": if len(sys.argv) < 4: - print("Usage: graphify path \"\" \"\" [--graph path]", file=sys.stderr) + print('Usage: graphify path "" "" [--graph path]', file=sys.stderr) sys.exit(1) from graphify.serve import _score_nodes from networkx.readwrite import json_graph import networkx as _nx + source_label = sys.argv[2] target_label = sys.argv[3] graph_path = _default_graph_path() @@ -2125,6 +2270,7 @@ def main() -> None: hops = len(path_nodes) - 1 segments = [] from graphify.build import edge_data + for i in range(len(path_nodes) - 1): u, v = path_nodes[i], path_nodes[i + 1] # Check which direction the stored edge points. @@ -2147,10 +2293,11 @@ def main() -> None: elif cmd == "explain": if len(sys.argv) < 3: - print("Usage: graphify explain \"\" [--graph path]", file=sys.stderr) + print('Usage: graphify explain "" [--graph path]', file=sys.stderr) sys.exit(1) from graphify.serve import _find_node from networkx.readwrite import json_graph + label = sys.argv[2] graph_path = _default_graph_path() args = sys.argv[3:] @@ -2184,6 +2331,7 @@ def main() -> None: print(f" Community: {d.get('community', '')}") print(f" Degree: {G.degree(nid)}") from graphify.build import edge_data + connections: list[tuple[str, str, dict]] = [] # (direction, neighbor_id, edge_data) for nb in G.successors(nid): connections.append(("out", nb, edge_data(G, nid, nb))) @@ -2296,9 +2444,13 @@ def main() -> None: elif cmd == "add": if len(sys.argv) < 3: - print("Usage: graphify add [--author Name] [--contributor Name] [--dir ./raw]", file=sys.stderr) + print( + "Usage: graphify add [--author Name] [--contributor Name] [--dir ./raw]", + file=sys.stderr, + ) sys.exit(1) from graphify.ingest import ingest as _ingest + url = sys.argv[2] author: str | None = None contributor: str | None = None @@ -2307,11 +2459,14 @@ def main() -> None: i = 0 while i < len(args): if args[i] == "--author" and i + 1 < len(args): - author = args[i + 1]; i += 2 + author = args[i + 1] + i += 2 elif args[i] == "--contributor" and i + 1 < len(args): - contributor = args[i + 1]; i += 2 + contributor = args[i + 1] + i += 2 elif args[i] == "--dir" and i + 1 < len(args): - target_dir = Path(args[i + 1]); i += 2 + target_dir = Path(args[i + 1]) + i += 2 else: i += 1 try: @@ -2328,6 +2483,7 @@ def main() -> None: print(f"error: path not found: {watch_path}", file=sys.stderr) sys.exit(1) from graphify.watch import watch as _watch + try: _watch(watch_path) except ImportError as exc: @@ -2349,26 +2505,36 @@ def main() -> None: while i_arg < len(args): a = args[i_arg] if a == "--graph" and i_arg + 1 < len(args): - graph_override = Path(args[i_arg + 1]); i_arg += 2 + graph_override = Path(args[i_arg + 1]) + i_arg += 2 elif a == "--resolution" and i_arg + 1 < len(args): - co_resolution = float(args[i_arg + 1]); i_arg += 2 + co_resolution = float(args[i_arg + 1]) + i_arg += 2 elif a.startswith("--resolution="): - co_resolution = float(a.split("=", 1)[1]); i_arg += 1 + co_resolution = float(a.split("=", 1)[1]) + i_arg += 1 elif a == "--exclude-hubs" and i_arg + 1 < len(args): - co_exclude_hubs = float(args[i_arg + 1]); i_arg += 2 + co_exclude_hubs = float(args[i_arg + 1]) + i_arg += 2 elif a.startswith("--exclude-hubs="): - co_exclude_hubs = float(a.split("=", 1)[1]); i_arg += 1 + co_exclude_hubs = float(a.split("=", 1)[1]) + i_arg += 1 elif a == "--no-viz" or a.startswith("--min-community-size="): i_arg += 1 elif a.startswith("--"): i_arg += 1 elif watch_path is None: - watch_path = Path(a); i_arg += 1 + watch_path = Path(a) + i_arg += 1 else: i_arg += 1 if watch_path is None: watch_path = Path(".") - graph_json = graph_override if graph_override is not None else watch_path / "graphify-out" / "graph.json" + graph_json = ( + graph_override + if graph_override is not None + else watch_path / "graphify-out" / "graph.json" + ) if not graph_json.exists(): print(f"error: no graph found at {graph_json} - run /graphify first", file=sys.stderr) sys.exit(1) @@ -2378,6 +2544,7 @@ def main() -> None: from graphify.analyze import god_nodes, surprising_connections, suggest_questions from graphify.report import generate from graphify.export import to_json, to_html + print("Loading existing graph...") _enforce_graph_size_cap_or_exit(graph_json) _raw = json.loads(graph_json.read_text(encoding="utf-8")) @@ -2406,7 +2573,10 @@ def main() -> None: labels_path = out / ".graphify_labels.json" if labels_path.exists(): try: - labels = {int(k): v for k, v in json.loads(labels_path.read_text(encoding="utf-8")).items()} + labels = { + int(k): v + for k, v in json.loads(labels_path.read_text(encoding="utf-8")).items() + } except Exception: labels = {cid: f"Community {cid}" for cid in communities} else: @@ -2414,16 +2584,30 @@ def main() -> None: questions = suggest_questions(G, communities, labels) tokens = {"input": 0, "output": 0} from graphify.export import _git_head as _gh + _commit = _gh() - report = generate(G, communities, cohesion, labels, gods, surprises, - {"warning": "cluster-only mode — file stats not available"}, - tokens, str(watch_path), suggested_questions=questions, - min_community_size=min_community_size, built_at_commit=_commit) + report = generate( + G, + communities, + cohesion, + labels, + gods, + surprises, + {"warning": "cluster-only mode — file stats not available"}, + tokens, + str(watch_path), + suggested_questions=questions, + min_community_size=min_community_size, + built_at_commit=_commit, + ) (out / "GRAPH_REPORT.md").write_text(report, encoding="utf-8") from graphify.export import backup_if_protected as _backup + _backup(out) to_json(G, communities, str(out / "graph.json")) - labels_path.write_text(json.dumps({str(k): v for k, v in labels.items()}, ensure_ascii=False), encoding="utf-8") + labels_path.write_text( + json.dumps({str(k): v for k, v in labels.items()}, ensure_ascii=False), encoding="utf-8" + ) # Mirror watch.py pattern: gate to_html so core outputs (graph.json + # GRAPH_REPORT.md) always land. Honor --no-viz explicitly; otherwise @@ -2433,20 +2617,27 @@ def main() -> None: if no_viz: if html_target.exists(): html_target.unlink() - print(f"Done - {len(communities)} communities. GRAPH_REPORT.md and graph.json updated (--no-viz; graph.html removed).") + print( + f"Done - {len(communities)} communities. GRAPH_REPORT.md and graph.json updated (--no-viz; graph.html removed)." + ) else: try: to_html(G, communities, str(html_target), community_labels=labels or None) - print(f"Done - {len(communities)} communities. GRAPH_REPORT.md, graph.json and graph.html updated.") + print( + f"Done - {len(communities)} communities. GRAPH_REPORT.md, graph.json and graph.html updated." + ) except ValueError as viz_err: if html_target.exists(): html_target.unlink() print(f"Skipped graph.html: {viz_err}") - print(f"Done - {len(communities)} communities. GRAPH_REPORT.md and graph.json updated.") + print( + f"Done - {len(communities)} communities. GRAPH_REPORT.md and graph.json updated." + ) elif cmd == "update": force = os.environ.get("GRAPHIFY_FORCE", "").lower() in ("1", "true", "yes") no_cluster = False + no_viz = False args = sys.argv[2:] watch_arg: str | None = None for a in args: @@ -2456,6 +2647,9 @@ def main() -> None: if a == "--no-cluster": no_cluster = True continue + if a == "--no-viz": + no_viz = True + continue if a.startswith("-"): print(f"error: unknown update option: {a}", file=sys.stderr) sys.exit(2) @@ -2477,13 +2671,22 @@ def main() -> None: print(f"error: path not found: {watch_path}", file=sys.stderr) sys.exit(1) from graphify.watch import _rebuild_code + print(f"Re-extracting code files in {watch_path} (no LLM needed)...") # Interactive CLI: block on the per-repo lock rather than skip, so the # user sees their explicit `graphify update` complete instead of # exiting silently when a hook-driven rebuild happens to be running. - ok = _rebuild_code(watch_path, force=force, no_cluster=no_cluster, block_on_lock=True) + ok = _rebuild_code( + watch_path, + force=force, + no_cluster=no_cluster, + no_viz=no_viz, + block_on_lock=True, + ) if ok: - print("Code graph updated. For doc/paper/image changes run /graphify --update in your AI assistant.") + print( + "Code graph updated. For doc/paper/image changes run /graphify --update in your AI assistant." + ) if not ( os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY") @@ -2491,7 +2694,9 @@ def main() -> None: or os.environ.get("DEEPSEEK_API_KEY") or os.environ.get("GRAPHIFY_NO_TIPS") ): - print("Tip: set GEMINI_API_KEY or GOOGLE_API_KEY to use Gemini for semantic extraction.") + print( + "Tip: set GEMINI_API_KEY or GOOGLE_API_KEY to use Gemini for semantic extraction." + ) else: print("Nothing to update or rebuild failed - check output above.", file=sys.stderr) sys.exit(1) @@ -2506,6 +2711,7 @@ def main() -> None: print("Usage: graphify check-update ", file=sys.stderr) sys.exit(1) from graphify.watch import check_update + check_update(Path(sys.argv[2]).resolve()) sys.exit(0) elif cmd == "tree": @@ -2516,6 +2722,7 @@ def main() -> None: # showing top-K outbound edges per symbol. from typing import Optional as _Opt from graphify.tree_html import write_tree_html, DEFAULT_MAX_CHILDREN + graph_path = Path(_GRAPHIFY_OUT) / "graph.json" output_path: "_Opt[Path]" = None root: "_Opt[str]" = None @@ -2527,24 +2734,34 @@ def main() -> None: while i_arg < len(args): a = args[i_arg] if a == "--graph" and i_arg + 1 < len(args): - graph_path = Path(args[i_arg + 1]); i_arg += 2 + graph_path = Path(args[i_arg + 1]) + i_arg += 2 elif a == "--output" and i_arg + 1 < len(args): - output_path = Path(args[i_arg + 1]); i_arg += 2 + output_path = Path(args[i_arg + 1]) + i_arg += 2 elif a == "--root" and i_arg + 1 < len(args): - root = args[i_arg + 1]; i_arg += 2 + root = args[i_arg + 1] + i_arg += 2 elif a == "--max-children" and i_arg + 1 < len(args): - max_children = int(args[i_arg + 1]); i_arg += 2 + max_children = int(args[i_arg + 1]) + i_arg += 2 elif a == "--top-k-edges" and i_arg + 1 < len(args): - top_k_edges = int(args[i_arg + 1]); i_arg += 2 + top_k_edges = int(args[i_arg + 1]) + i_arg += 2 elif a == "--label" and i_arg + 1 < len(args): - project_label = args[i_arg + 1]; i_arg += 2 + project_label = args[i_arg + 1] + i_arg += 2 elif a in ("-h", "--help"): print("Usage: graphify tree [--graph PATH] [--output HTML]") print(" --graph PATH path to graph.json (default graphify-out/graph.json)") print(" --output HTML output path (default graphify-out/GRAPH_TREE.html)") - print(" --root PATH filesystem root (default: longest common dir of all source_files)") + print( + " --root PATH filesystem root (default: longest common dir of all source_files)" + ) print(" --max-children N cap visible children per node (default 200)") - print(" --top-k-edges N pre-compute top-K outbound edges per symbol (default 12)") + print( + " --top-k-edges N pre-compute top-K outbound edges per symbol (default 12)" + ) print(" --label NAME project label shown in the page header") return else: @@ -2556,9 +2773,12 @@ def main() -> None: if output_path is None: output_path = graph_path.parent / "GRAPH_TREE.html" out = write_tree_html( - graph_path=graph_path, output_path=output_path, - root=root, max_children=max_children, - top_k_edges=top_k_edges, project_label=project_label, + graph_path=graph_path, + output_path=output_path, + root=root, + max_children=max_children, + top_k_edges=top_k_edges, + project_label=project_label, ) size_kb = out.stat().st_size / 1024 print(f"wrote {out} ({size_kb:.1f} KB)") @@ -2583,6 +2803,7 @@ def main() -> None: _MERGE_MAX_NODES = 100_000 import networkx as _nx from networkx.readwrite import json_graph as _jg + def _load_graph(p: str): path_obj = Path(p) try: @@ -2598,24 +2819,25 @@ def _load_graph(p: str): return _jg.node_link_graph(data, edges="links"), data except TypeError: return _jg.node_link_graph(data), data + try: - G_cur, _ = _load_graph(_current_path) - G_oth, _ = _load_graph(_other_path) + current_graph, _ = _load_graph(_current_path) + other_graph, _ = _load_graph(_other_path) except Exception as exc: print(f"[graphify merge-driver] error loading graphs: {exc}", file=sys.stderr) sys.exit(1) # surface the conflict so git doesn't accept a corrupt merge - merged = _nx.compose(G_cur, G_oth) - if merged.number_of_nodes() > _MERGE_MAX_NODES: + merged_graph = _nx.compose(current_graph, other_graph) + if merged_graph.number_of_nodes() > _MERGE_MAX_NODES: print( - f"[graphify merge-driver] merged graph has {merged.number_of_nodes()} nodes, " + f"[graphify merge-driver] merged graph has {merged_graph.number_of_nodes()} nodes, " f"exceeds {_MERGE_MAX_NODES}-node cap; aborting merge.", file=sys.stderr, ) sys.exit(1) try: - out_data = _jg.node_link_data(merged, edges="links") + out_data = _jg.node_link_data(merged_graph, edges="links") except TypeError: - out_data = _jg.node_link_data(merged) + out_data = _jg.node_link_data(merged_graph) Path(_current_path).write_text(json.dumps(out_data, indent=2), encoding="utf-8") sys.exit(0) @@ -2627,16 +2849,22 @@ def _load_graph(p: str): i = 0 while i < len(args): if args[i] == "--out" and i + 1 < len(args): - out_path = Path(args[i + 1]); i += 2 + out_path = Path(args[i + 1]) + i += 2 else: - graph_paths.append(Path(args[i])); i += 1 + graph_paths.append(Path(args[i])) + i += 1 if len(graph_paths) < 2: - print("Usage: graphify merge-graphs [...] [--out merged.json]", file=sys.stderr) + print( + "Usage: graphify merge-graphs [...] [--out merged.json]", + file=sys.stderr, + ) sys.exit(1) import networkx as _nx from networkx.readwrite import json_graph as _jg from graphify.build import prefix_graph_for_global as _prefix - graphs = [] + + loaded_graphs = [] for gp in graph_paths: if not gp.exists(): print(f"error: not found: {gp}", file=sys.stderr) @@ -2648,27 +2876,32 @@ def _load_graph(p: str): if "links" not in data and "edges" in data: data = dict(data, links=data["edges"]) try: - G = _jg.node_link_graph(data, edges="links") + input_graph = _jg.node_link_graph(data, edges="links") except TypeError: - G = _jg.node_link_graph(data) - graphs.append(G) - merged = _nx.Graph() - for G, gp in zip(graphs, graph_paths): + input_graph = _jg.node_link_graph(data) + loaded_graphs.append(input_graph) + merged_graph = _nx.Graph() + for input_graph, gp in zip(loaded_graphs, graph_paths): repo_tag = gp.parent.parent.name # graphify-out/../ → repo dir name - prefixed = _prefix(G, repo_tag) - merged = _nx.compose(merged, prefixed) + prefixed = _prefix(input_graph, repo_tag) + merged_graph = _nx.compose(merged_graph, prefixed) try: - out_data = _jg.node_link_data(merged, edges="links") + out_data = _jg.node_link_data(merged_graph, edges="links") except TypeError: - out_data = _jg.node_link_data(merged) + out_data = _jg.node_link_data(merged_graph) out_path.parent.mkdir(parents=True, exist_ok=True) out_path.write_text(json.dumps(out_data, indent=2), encoding="utf-8") - print(f"Merged {len(graphs)} graphs -> {merged.number_of_nodes()} nodes, {merged.number_of_edges()} edges") + print( + f"Merged {len(loaded_graphs)} graphs -> {merged_graph.number_of_nodes()} nodes, {merged_graph.number_of_edges()} edges" + ) print(f"Written to: {out_path}") elif cmd == "clone": if len(sys.argv) < 3: - print("Usage: graphify clone [--branch ] [--out ]", file=sys.stderr) + print( + "Usage: graphify clone [--branch ] [--out ]", + file=sys.stderr, + ) sys.exit(1) url = sys.argv[2] branch: str | None = None @@ -2677,9 +2910,11 @@ def _load_graph(p: str): i = 0 while i < len(args): if args[i] == "--branch" and i + 1 < len(args): - branch = args[i + 1]; i += 2 + branch = args[i + 1] + i += 2 elif args[i] == "--out" and i + 1 < len(args): - out_dir = Path(args[i + 1]); i += 2 + out_dir = Path(args[i + 1]) + i += 2 else: i += 1 local_path = _clone_repo(url, branch=branch, out_dir=out_dir) @@ -2689,15 +2924,29 @@ def _load_graph(p: str): subcmd = sys.argv[2] if len(sys.argv) > 2 else "" if subcmd not in ("html", "callflow-html", "obsidian", "wiki", "svg", "graphml", "neo4j"): print("Usage: graphify export ", file=sys.stderr) - print(" html [--graph PATH] [--labels PATH] [--node-limit N] [--no-viz]", file=sys.stderr) - print(" callflow-html [GRAPH|DIR] [--graph PATH] [--labels PATH] [--report PATH] [--sections PATH] [--output HTML]", file=sys.stderr) - print(" [--lang auto|zh-CN|en] [--max-sections N] [--diagram-scale N]", file=sys.stderr) + print( + " html [--graph PATH] [--labels PATH] [--node-limit N] [--no-viz]", + file=sys.stderr, + ) + print( + " callflow-html [GRAPH|DIR] [--graph PATH] [--labels PATH] [--report PATH] [--sections PATH] [--output HTML]", + file=sys.stderr, + ) + print( + " [--lang auto|zh-CN|en] [--max-sections N] [--diagram-scale N]", + file=sys.stderr, + ) print(" obsidian [--graph PATH] [--labels PATH] [--dir PATH]", file=sys.stderr) print(" wiki [--graph PATH] [--labels PATH]", file=sys.stderr) print(" svg [--graph PATH] [--labels PATH]", file=sys.stderr) print(" graphml [--graph PATH]", file=sys.stderr) - print(" neo4j [--graph PATH] [--push URI] [--user U] [--password P]", file=sys.stderr) - print(" (or set NEO4J_PASSWORD instead of --password to keep it off argv)", file=sys.stderr) + print( + " neo4j [--graph PATH] [--push URI] [--user U] [--password P]", file=sys.stderr + ) + print( + " (or set NEO4J_PASSWORD instead of --password to keep it off argv)", + file=sys.stderr, + ) sys.exit(1) # Parse shared args @@ -2741,27 +2990,37 @@ def _load_graph(p: str): report_path_explicit = True i += 2 elif a == "--sections" and i + 1 < len(args): - sections_path = Path(args[i + 1]); i += 2 + sections_path = Path(args[i + 1]) + i += 2 elif a == "--output" and i + 1 < len(args): callflow_output = Path(args[i + 1]).expanduser() if not callflow_output.is_absolute(): callflow_output = Path.cwd() / callflow_output i += 2 elif a == "--lang" and i + 1 < len(args): - callflow_lang = args[i + 1]; i += 2 + callflow_lang = args[i + 1] + i += 2 elif a == "--max-sections" and i + 1 < len(args): - callflow_max_sections = int(args[i + 1]); i += 2 + callflow_max_sections = int(args[i + 1]) + i += 2 elif a == "--diagram-scale" and i + 1 < len(args): - callflow_diagram_scale = float(args[i + 1]); i += 2 + callflow_diagram_scale = float(args[i + 1]) + i += 2 elif a == "--max-diagram-nodes" and i + 1 < len(args): - callflow_max_diagram_nodes = int(args[i + 1]); i += 2 + callflow_max_diagram_nodes = int(args[i + 1]) + i += 2 elif a == "--max-diagram-edges" and i + 1 < len(args): - callflow_max_diagram_edges = int(args[i + 1]); i += 2 + callflow_max_diagram_edges = int(args[i + 1]) + i += 2 elif a in ("-h", "--help") and subcmd == "callflow-html": - print("Usage: graphify export callflow-html [GRAPH|DIR] [--graph PATH] [--labels PATH]") + print( + "Usage: graphify export callflow-html [GRAPH|DIR] [--graph PATH] [--labels PATH]" + ) print(" --report PATH path to GRAPH_REPORT.md") print(" --sections PATH JSON section definitions") - print(" --output HTML output path (default graphify-out/-callflow.html)") + print( + " --output HTML output path (default graphify-out/-callflow.html)" + ) print(" --lang LANG auto, zh-CN, en, etc. (default auto)") print(" --max-sections N maximum auto-derived sections (default 15)") print(" --diagram-scale N Mermaid diagram scale (default 1.0)") @@ -2769,17 +3028,23 @@ def _load_graph(p: str): print(" --max-diagram-edges N representative edges per section (default 24)") sys.exit(0) elif a == "--node-limit" and i + 1 < len(args): - node_limit = int(args[i + 1]); i += 2 + node_limit = int(args[i + 1]) + i += 2 elif a == "--no-viz": - no_viz = True; i += 1 + no_viz = True + i += 1 elif a == "--dir" and i + 1 < len(args): - obsidian_dir = Path(args[i + 1]); i += 2 + obsidian_dir = Path(args[i + 1]) + i += 2 elif a == "--push" and i + 1 < len(args): - neo4j_uri = args[i + 1]; i += 2 + neo4j_uri = args[i + 1] + i += 2 elif a == "--user" and i + 1 < len(args): - neo4j_user = args[i + 1]; i += 2 + neo4j_user = args[i + 1] + i += 2 elif a == "--password" and i + 1 < len(args): - neo4j_password = args[i + 1]; i += 2 + neo4j_password = args[i + 1] + i += 2 elif subcmd == "callflow-html" and not a.startswith("-") and not graph_path_explicit: candidate = Path(a) if candidate.name == "graph.json" or candidate.suffix.lower() == ".json": @@ -2804,11 +3069,15 @@ def _load_graph(p: str): report_path = report_path.expanduser() if not graph_path.exists(): - print(f"error: graph not found: {graph_path}. Run /graphify first.", file=sys.stderr) + print( + f"error: graph not found: {graph_path}. Run /graphify first.", + file=sys.stderr, + ) sys.exit(1) if subcmd == "callflow-html": from graphify.callflow_html import write_callflow_html as _write_callflow_html + out = _write_callflow_html( graph=graph_path, report=report_path, @@ -2826,7 +3095,6 @@ def _load_graph(p: str): sys.exit(0) from networkx.readwrite import json_graph as _jg - from graphify.build import build_from_json as _bfj _enforce_graph_size_cap_or_exit(graph_path) _raw = json.loads(graph_path.read_text(encoding="utf-8")) @@ -2873,36 +3141,52 @@ def _load_graph(p: str): labels: dict[int, str] = {} if labels_path.exists(): - labels = {int(k): v for k, v in json.loads(labels_path.read_text(encoding="utf-8")).items()} + labels = { + int(k): v for k, v in json.loads(labels_path.read_text(encoding="utf-8")).items() + } out_dir = graph_path.parent if subcmd == "html": from graphify.export import to_html as _to_html + if no_viz: html_target = out_dir / "graph.html" if html_target.exists(): html_target.unlink() print("--no-viz: skipped graph.html") else: - _to_html(G, communities, str(out_dir / "graph.html"), - community_labels=labels or None, node_limit=node_limit) + _to_html( + G, + communities, + str(out_dir / "graph.html"), + community_labels=labels or None, + node_limit=node_limit, + ) if G.number_of_nodes() <= node_limit: - print(f"graph.html written - open in any browser, no server needed") + print("graph.html written - open in any browser, no server needed") elif subcmd == "obsidian": from graphify.export import to_obsidian as _to_obsidian, to_canvas as _to_canvas - n = _to_obsidian(G, communities, str(obsidian_dir), - community_labels=labels or None, cohesion=cohesion or None) + + n = _to_obsidian( + G, + communities, + str(obsidian_dir), + community_labels=labels or None, + cohesion=cohesion or None, + ) print(f"Obsidian vault: {n} notes in {obsidian_dir}/") - _to_canvas(G, communities, str(obsidian_dir / "graph.canvas"), - community_labels=labels or None) + _to_canvas( + G, communities, str(obsidian_dir / "graph.canvas"), community_labels=labels or None + ) print(f"Canvas: {obsidian_dir}/graph.canvas") print(f"Open {obsidian_dir}/ as a vault in Obsidian.") elif subcmd == "wiki": from graphify.wiki import to_wiki as _to_wiki from graphify.analyze import god_nodes as _god_nodes + if not communities: print( "error: .graphify_analysis.json is missing or empty — refusing to export wiki to prevent data loss.\n" @@ -2912,39 +3196,53 @@ def _load_graph(p: str): sys.exit(1) if not gods_data: gods_data = _god_nodes(G) - n = _to_wiki(G, communities, str(out_dir / "wiki"), - community_labels=labels or None, cohesion=cohesion or None, - god_nodes_data=gods_data) + n = _to_wiki( + G, + communities, + str(out_dir / "wiki"), + community_labels=labels or None, + cohesion=cohesion or None, + god_nodes_data=gods_data, + ) print(f"Wiki: {n} articles written to {out_dir}/wiki/") print(f" {out_dir}/wiki/index.md -> agent entry point") elif subcmd == "svg": from graphify.export import to_svg as _to_svg - _to_svg(G, communities, str(out_dir / "graph.svg"), - community_labels=labels or None) - print(f"graph.svg written - embeds in Obsidian, Notion, GitHub READMEs") + + _to_svg(G, communities, str(out_dir / "graph.svg"), community_labels=labels or None) + print("graph.svg written - embeds in Obsidian, Notion, GitHub READMEs") elif subcmd == "graphml": from graphify.export import to_graphml as _to_graphml + _to_graphml(G, communities, str(out_dir / "graph.graphml")) - print(f"graph.graphml written - open in Gephi, yEd, or any GraphML tool") + print("graph.graphml written - open in Gephi, yEd, or any GraphML tool") elif subcmd == "neo4j": if neo4j_uri: from graphify.export import push_to_neo4j as _push + if neo4j_password is None: print("error: --password required for --push", file=sys.stderr) sys.exit(1) - result = _push(G, uri=neo4j_uri, user=neo4j_user, - password=neo4j_password, communities=communities) + result = _push( + G, + uri=neo4j_uri, + user=neo4j_user, + password=neo4j_password, + communities=communities, + ) print(f"Pushed to Neo4j: {result['nodes']} nodes, {result['edges']} edges") else: from graphify.export import to_cypher as _to_cypher + _to_cypher(G, str(out_dir / "cypher.txt")) print(f"cypher.txt written - import with: cypher-shell < {out_dir}/cypher.txt") elif cmd == "benchmark": from graphify.benchmark import run_benchmark, print_benchmark + graph_path = sys.argv[2] if len(sys.argv) > 2 else "graphify-out/graph.json" _enforce_graph_size_cap_or_exit(Path(graph_path)) # Try to load corpus_words from detect output @@ -2954,8 +3252,11 @@ def _load_graph(p: str): try: detect_data = json.loads(detect_path.read_text(encoding="utf-8")) corpus_words = detect_data.get("total_words") - except Exception: - pass + except Exception as exc: + print( + f"[graphify] warning: could not read .graphify_detect.json: {exc}", + file=sys.stderr, + ) result = run_benchmark(graph_path, corpus_words=corpus_words) print_benchmark(result) @@ -2967,6 +3268,7 @@ def _load_graph(p: str): global_list as _global_list, global_path as _global_path, ) + if subcmd == "add": # graphify global add [--as ] args = sys.argv[3:] @@ -2975,9 +3277,11 @@ def _load_graph(p: str): i = 0 while i < len(args): if args[i] == "--as" and i + 1 < len(args): - tag = args[i + 1]; i += 2 + tag = args[i + 1] + i += 2 elif not source: - source = Path(args[i]); i += 1 + source = Path(args[i]) + i += 1 else: i += 1 if not source: @@ -2989,19 +3293,24 @@ def _load_graph(p: str): if result["skipped"]: print(f"'{tag}' unchanged since last add - global graph not modified.") else: - print(f"Added '{tag}' to global graph: +{result['nodes_added']} nodes, " - f"-{result['nodes_removed']} pruned. Global: {_global_path()}") + print( + f"Added '{tag}' to global graph: +{result['nodes_added']} nodes, " + f"-{result['nodes_removed']} pruned. Global: {_global_path()}" + ) except Exception as exc: - print(f"error: {exc}", file=sys.stderr); sys.exit(1) + print(f"error: {exc}", file=sys.stderr) + sys.exit(1) elif subcmd == "remove": tag = sys.argv[3] if len(sys.argv) > 3 else "" if not tag: - print("Usage: graphify global remove ", file=sys.stderr); sys.exit(1) + print("Usage: graphify global remove ", file=sys.stderr) + sys.exit(1) try: removed = _global_remove(tag) print(f"Removed '{tag}' from global graph ({removed} nodes pruned).") except KeyError as exc: - print(f"error: {exc}", file=sys.stderr); sys.exit(1) + print(f"error: {exc}", file=sys.stderr) + sys.exit(1) elif subcmd == "list": repos = _global_list() if not repos: @@ -3009,12 +3318,14 @@ def _load_graph(p: str): else: print(f"Global graph: {_global_path()}") for tag, info in repos.items(): - print(f" {tag}: {info.get('node_count', '?')} nodes, added {info.get('added_at', '?')[:10]}") + print( + f" {tag}: {info.get('node_count', '?')} nodes, added {info.get('added_at', '?')[:10]}" + ) elif subcmd == "path": print(_global_path()) else: - print("Usage: graphify global [add|remove|list|path]", file=sys.stderr); sys.exit(1) - + print("Usage: graphify global [add|remove|list|path]", file=sys.stderr) + sys.exit(1) elif cmd == "extract": # Headless full-pipeline extraction for CI / scripts (#698). # Runs detect -> AST extraction on code -> semantic LLM extraction on @@ -3083,59 +3394,86 @@ def _parse_float(name: str, raw: str) -> float: while i < len(args): a = args[i] if a == "--backend" and i + 1 < len(args): - backend = args[i + 1]; i += 2 + backend = args[i + 1] + i += 2 elif a.startswith("--backend="): - backend = a.split("=", 1)[1]; i += 1 + backend = a.split("=", 1)[1] + i += 1 elif a == "--model" and i + 1 < len(args): - model = args[i + 1]; i += 2 + model = args[i + 1] + i += 2 elif a.startswith("--model="): - model = a.split("=", 1)[1]; i += 1 + model = a.split("=", 1)[1] + i += 1 elif a == "--mode" and i + 1 < len(args): - extract_mode = args[i + 1]; i += 2 + extract_mode = args[i + 1] + i += 2 elif a.startswith("--mode="): - extract_mode = a.split("=", 1)[1]; i += 1 + extract_mode = a.split("=", 1)[1] + i += 1 elif a == "--out" and i + 1 < len(args): - out_dir = Path(args[i + 1]); i += 2 + out_dir = Path(args[i + 1]) + i += 2 elif a.startswith("--out="): - out_dir = Path(a.split("=", 1)[1]); i += 1 + out_dir = Path(a.split("=", 1)[1]) + i += 1 elif a == "--no-cluster": - no_cluster = True; i += 1 + no_cluster = True + i += 1 elif a == "--dedup-llm": - dedup_llm = True; i += 1 + dedup_llm = True + i += 1 elif a == "--google-workspace": - google_workspace = True; i += 1 + google_workspace = True + i += 1 elif a == "--global": - global_merge = True; i += 1 + global_merge = True + i += 1 elif a == "--as" and i + 1 < len(args): - global_repo_tag = args[i + 1]; i += 2 + global_repo_tag = args[i + 1] + i += 2 elif a == "--max-workers" and i + 1 < len(args): - cli_max_workers = _parse_int("--max-workers", args[i + 1]); i += 2 + cli_max_workers = _parse_int("--max-workers", args[i + 1]) + i += 2 elif a.startswith("--max-workers="): - cli_max_workers = _parse_int("--max-workers", a.split("=", 1)[1]); i += 1 + cli_max_workers = _parse_int("--max-workers", a.split("=", 1)[1]) + i += 1 elif a == "--token-budget" and i + 1 < len(args): - cli_token_budget = _parse_int("--token-budget", args[i + 1]); i += 2 + cli_token_budget = _parse_int("--token-budget", args[i + 1]) + i += 2 elif a.startswith("--token-budget="): - cli_token_budget = _parse_int("--token-budget", a.split("=", 1)[1]); i += 1 + cli_token_budget = _parse_int("--token-budget", a.split("=", 1)[1]) + i += 1 elif a == "--max-concurrency" and i + 1 < len(args): - cli_max_concurrency = _parse_int("--max-concurrency", args[i + 1]); i += 2 + cli_max_concurrency = _parse_int("--max-concurrency", args[i + 1]) + i += 2 elif a.startswith("--max-concurrency="): - cli_max_concurrency = _parse_int("--max-concurrency", a.split("=", 1)[1]); i += 1 + cli_max_concurrency = _parse_int("--max-concurrency", a.split("=", 1)[1]) + i += 1 elif a == "--api-timeout" and i + 1 < len(args): - cli_api_timeout = _parse_float("--api-timeout", args[i + 1]); i += 2 + cli_api_timeout = _parse_float("--api-timeout", args[i + 1]) + i += 2 elif a.startswith("--api-timeout="): - cli_api_timeout = _parse_float("--api-timeout", a.split("=", 1)[1]); i += 1 + cli_api_timeout = _parse_float("--api-timeout", a.split("=", 1)[1]) + i += 1 elif a == "--resolution" and i + 1 < len(args): - cli_resolution = _parse_float("--resolution", args[i + 1]); i += 2 + cli_resolution = _parse_float("--resolution", args[i + 1]) + i += 2 elif a.startswith("--resolution="): - cli_resolution = _parse_float("--resolution", a.split("=", 1)[1]); i += 1 + cli_resolution = _parse_float("--resolution", a.split("=", 1)[1]) + i += 1 elif a == "--exclude-hubs" and i + 1 < len(args): - cli_exclude_hubs = float(args[i + 1]); i += 2 + cli_exclude_hubs = float(args[i + 1]) + i += 2 elif a.startswith("--exclude-hubs="): - cli_exclude_hubs = float(a.split("=", 1)[1]); i += 1 + cli_exclude_hubs = float(a.split("=", 1)[1]) + i += 1 elif a == "--exclude" and i + 1 < len(args): - cli_excludes.append(args[i + 1]); i += 2 + cli_excludes.append(args[i + 1]) + i += 2 elif a.startswith("--exclude="): - cli_excludes.append(a.split("=", 1)[1]); i += 1 + cli_excludes.append(a.split("=", 1)[1]) + i += 1 else: i += 1 @@ -3170,6 +3508,7 @@ def _parse_float(name: str, raw: str) -> float: _format_backend_env_keys, _get_backend_api_key, ) + if backend is None: backend = _detect_backend() if backend is None: @@ -3183,8 +3522,7 @@ def _parse_float(name: str, raw: str) -> float: sys.exit(1) if backend not in _BACKENDS: print( - f"error: unknown backend '{backend}'. " - f"Available: {', '.join(sorted(_BACKENDS))}", + f"error: unknown backend '{backend}'. Available: {', '.join(sorted(_BACKENDS))}", file=sys.stderr, ) sys.exit(1) @@ -3196,18 +3534,18 @@ def _parse_float(name: str, raw: str) -> float: allow_no_key = False if backend == "ollama": from urllib.parse import urlparse - ollama_url = os.environ.get( - "OLLAMA_BASE_URL", - _BACKENDS["ollama"].get("base_url", ""), + + ollama_url = str( + os.environ.get( + "OLLAMA_BASE_URL", + _BACKENDS["ollama"].get("base_url", ""), + ) ) try: host = (urlparse(ollama_url).hostname or "").lower() except Exception: host = "" - allow_no_key = ( - host in ("localhost", "127.0.0.1", "::1") - or host.startswith("127.") - ) + allow_no_key = host in ("localhost", "127.0.0.1", "::1") or host.startswith("127.") elif backend == "bedrock": allow_no_key = bool( os.environ.get("AWS_PROFILE") @@ -3217,6 +3555,7 @@ def _parse_float(name: str, raw: str) -> float: ) elif backend == "claude-cli": import shutil as _shutil + allow_no_key = _shutil.which("claude") is not None if not allow_no_key: print( @@ -3235,7 +3574,7 @@ def _parse_float(name: str, raw: str) -> float: # Resolve output dir. The user-facing contract is "/graphify-out/" # so a fresh checkout writes graphify-out/ at the project root, matching # the skill.md pipeline. - out_root = (out_dir.resolve() if out_dir else target) + out_root = out_dir.resolve() if out_dir else target graphify_out = out_root / "graphify-out" graphify_out.mkdir(parents=True, exist_ok=True) @@ -3244,6 +3583,7 @@ def _parse_float(name: str, raw: str) -> float: detect_incremental as _detect_incremental, save_manifest as _save_manifest, ) + manifest_path = graphify_out / "manifest.json" existing_graph_path = graphify_out / "graph.json" incremental_mode = manifest_path.exists() and existing_graph_path.exists() @@ -3258,7 +3598,11 @@ def _parse_float(name: str, raw: str) -> float: ) else: print(f"[graphify extract] scanning {target}") - detection = _detect(target, google_workspace=google_workspace or None, extra_excludes=cli_excludes or None) + detection = _detect( + target, + google_workspace=google_workspace or None, + extra_excludes=cli_excludes or None, + ) files_by_type = detection.get("files", {}) if incremental_mode: @@ -3296,6 +3640,7 @@ def _parse_float(name: str, raw: str) -> float: ast_result: dict = {"nodes": [], "edges": [], "input_tokens": 0, "output_tokens": 0} if code_files: from graphify.extract import extract as _ast_extract + ast_kwargs: dict = {"cache_root": target} if cli_max_workers is not None: ast_kwargs["max_workers"] = cli_max_workers @@ -3311,16 +3656,20 @@ def _parse_float(name: str, raw: str) -> float: check_semantic_cache as _check_semantic_cache, save_semantic_cache as _save_semantic_cache, ) + sem_result: dict = { - "nodes": [], "edges": [], "hyperedges": [], - "input_tokens": 0, "output_tokens": 0, + "nodes": [], + "edges": [], + "hyperedges": [], + "input_tokens": 0, + "output_tokens": 0, } sem_cache_hits = 0 sem_cache_misses = 0 if semantic_files: sem_paths_str = [str(p) for p in semantic_files] - cached_nodes, cached_edges, cached_hyperedges, uncached_paths = ( - _check_semantic_cache(sem_paths_str, root=target) + cached_nodes, cached_edges, cached_hyperedges, uncached_paths = _check_semantic_cache( + sem_paths_str, root=target ) sem_cache_hits = len(semantic_files) - len(uncached_paths) sem_cache_misses = len(uncached_paths) @@ -3328,10 +3677,14 @@ def _parse_float(name: str, raw: str) -> float: sem_result["edges"].extend(cached_edges) sem_result["hyperedges"].extend(cached_hyperedges) if sem_cache_hits: - print(f"[graphify extract] semantic cache: {sem_cache_hits} hit / {sem_cache_misses} miss") + print( + f"[graphify extract] semantic cache: {sem_cache_hits} hit / {sem_cache_misses} miss" + ) if uncached_paths: - print(f"[graphify extract] semantic extraction on {len(uncached_paths)} files via {backend}...") + print( + f"[graphify extract] semantic extraction on {len(uncached_paths)} files via {backend}..." + ) corpus_kwargs: dict = { "backend": backend, "model": model, @@ -3349,6 +3702,7 @@ def _parse_float(name: str, raw: str) -> float: # Also track per-chunk success so we can fail loudly when # every chunk errors (e.g. missing backend SDK package). _chunk_stats = {"total": 0, "succeeded": 0} + def _progress(idx: int, total: int, _result: dict) -> None: _chunk_stats["total"] = total _chunk_stats["succeeded"] += 1 @@ -3356,6 +3710,7 @@ def _progress(idx: int, total: int, _result: dict) -> None: f"[graphify extract] chunk {idx + 1}/{total} done", flush=True, ) + corpus_kwargs["on_chunk_done"] = _progress try: @@ -3371,7 +3726,13 @@ def _progress(idx: int, total: int, _result: dict) -> None: f"[graphify extract] semantic extraction failed: {exc}", file=sys.stderr, ) - fresh = {"nodes": [], "edges": [], "hyperedges": [], "input_tokens": 0, "output_tokens": 0} + fresh = { + "nodes": [], + "edges": [], + "hyperedges": [], + "input_tokens": 0, + "output_tokens": 0, + } # on_chunk_done only fires after a chunk succeeds. If fresh # semantic extraction was requested and no chunks completed, @@ -3393,7 +3754,10 @@ def _progress(idx: int, total: int, _result: dict) -> None: root=target, ) except Exception as exc: - print(f"[graphify extract] warning: could not write semantic cache: {exc}", file=sys.stderr) + print( + f"[graphify extract] warning: could not write semantic cache: {exc}", + file=sys.stderr, + ) sem_result["nodes"].extend(fresh.get("nodes", [])) sem_result["edges"].extend(fresh.get("edges", [])) sem_result["hyperedges"].extend(fresh.get("hyperedges", [])) @@ -3409,7 +3773,8 @@ def _progress(idx: int, total: int, _result: dict) -> None: "edges": list(ast_result.get("edges", [])) + list(sem_result.get("edges", [])), "hyperedges": list(sem_result.get("hyperedges", [])), "input_tokens": ast_result.get("input_tokens", 0) + sem_result.get("input_tokens", 0), - "output_tokens": ast_result.get("output_tokens", 0) + sem_result.get("output_tokens", 0), + "output_tokens": ast_result.get("output_tokens", 0) + + sem_result.get("output_tokens", 0), } graph_json_path = graphify_out / "graph.json" @@ -3421,9 +3786,7 @@ def _progress(idx: int, total: int, _result: dict) -> None: # their semantic_hash empty so detect_incremental re-queues them (#933). _sem_extracted: set[str] = { n.get("source_file", "") for n in sem_result.get("nodes", []) - } | { - e.get("source_file", "") for e in sem_result.get("edges", []) - } + } | {e.get("source_file", "") for e in sem_result.get("edges", [])} _sem_extracted.discard("") _sem_types = {"document", "paper", "image"} _manifest_files = { @@ -3435,13 +3798,10 @@ def _progress(idx: int, total: int, _result: dict) -> None: # --no-cluster: dump the raw merged extraction as graph.json. # No NetworkX, no community detection, no analysis sidecar. from graphify.export import backup_if_protected as _backup + _backup(graphify_out) - graph_json_path.write_text( - json.dumps(merged, indent=2), encoding="utf-8" - ) - cost = _estimate_cost( - backend, merged["input_tokens"], merged["output_tokens"] - ) + graph_json_path.write_text(json.dumps(merged, indent=2), encoding="utf-8") + cost = _estimate_cost(backend, merged["input_tokens"], merged["output_tokens"]) print( f"[graphify extract] wrote {graph_json_path} — " f"{len(merged['nodes'])} nodes, {len(merged['edges'])} edges " @@ -3457,30 +3817,38 @@ def _progress(idx: int, total: int, _result: dict) -> None: try: _save_manifest(_manifest_files, manifest_path=str(manifest_path), kind="both") except Exception as exc: - print(f"[graphify extract] warning: could not write manifest: {exc}", file=sys.stderr) + print( + f"[graphify extract] warning: could not write manifest: {exc}", file=sys.stderr + ) if global_merge: from graphify.global_graph import global_add as _global_add + _tag = global_repo_tag or target.name try: result = _global_add(graphify_out / "graph.json", _tag) if result["skipped"]: print(f"[graphify global] '{_tag}' unchanged since last add - skipped.") else: - print(f"[graphify global] '{_tag}' merged into global graph " - f"(+{result['nodes_added']} nodes, -{result['nodes_removed']} pruned).") + print( + f"[graphify global] '{_tag}' merged into global graph " + f"(+{result['nodes_added']} nodes, -{result['nodes_removed']} pruned)." + ) except Exception as exc: - print(f"[graphify global] warning: failed to merge into global graph: {exc}", file=sys.stderr) + print( + f"[graphify global] warning: failed to merge into global graph: {exc}", + file=sys.stderr, + ) sys.exit(0) # Build graph + cluster + score + write. from graphify.build import ( build as _build, - build_from_json as _build_from_json, build_merge as _build_merge, ) from graphify.cluster import cluster as _cluster, score_all as _score_all from graphify.export import to_json as _to_json from graphify.analyze import god_nodes as _god_nodes, surprising_connections as _surprising + dedup_backend = backend if dedup_llm else None if incremental_mode: G = _build_merge( @@ -3502,7 +3870,9 @@ def _progress(idx: int, total: int, _result: dict) -> None: ) sys.exit(1) - communities = _cluster(G, resolution=cli_resolution, exclude_hubs_percentile=cli_exclude_hubs) + communities = _cluster( + G, resolution=cli_resolution, exclude_hubs_percentile=cli_exclude_hubs + ) cohesion = _score_all(G, communities) try: gods = _god_nodes(G) @@ -3514,6 +3884,7 @@ def _progress(idx: int, total: int, _result: dict) -> None: surprises = [] from graphify.export import backup_if_protected as _backup + _backup(graphify_out) _to_json(G, communities, str(graph_json_path), force=True) if merged.get("output_tokens", 0) > 0: @@ -3522,16 +3893,22 @@ def _progress(idx: int, total: int, _result: dict) -> None: ) if global_merge: from graphify.global_graph import global_add as _global_add + _tag = global_repo_tag or target.name try: result = _global_add(graphify_out / "graph.json", _tag) if result["skipped"]: print(f"[graphify global] '{_tag}' unchanged since last add - skipped.") else: - print(f"[graphify global] '{_tag}' merged into global graph " - f"(+{result['nodes_added']} nodes, -{result['nodes_removed']} pruned).") + print( + f"[graphify global] '{_tag}' merged into global graph " + f"(+{result['nodes_added']} nodes, -{result['nodes_removed']} pruned)." + ) except Exception as exc: - print(f"[graphify global] warning: failed to merge into global graph: {exc}", file=sys.stderr) + print( + f"[graphify global] warning: failed to merge into global graph: {exc}", + file=sys.stderr, + ) analysis = { "communities": {str(k): v for k, v in communities.items()}, "cohesion": {str(k): v for k, v in cohesion.items()}, @@ -3563,7 +3940,9 @@ def _progress(idx: int, total: int, _result: dict) -> None: f"{len(deleted_files)} deleted" ) elif sem_cache_hits: - print(f"[graphify extract] semantic cache: {sem_cache_hits} cached, {sem_cache_misses} re-extracted") + print( + f"[graphify extract] semantic cache: {sem_cache_hits} cached, {sem_cache_misses} re-extracted" + ) if merged["input_tokens"] or merged["output_tokens"]: print( f"[graphify extract] tokens: " @@ -3580,26 +3959,31 @@ def _progress(idx: int, total: int, _result: dict) -> None: # graphify-out/.graphify_uncached.txt — paths that need extraction # Stdout: "Cache: N hit, M miss" from graphify.cache import check_semantic_cache + if len(sys.argv) < 3: print("Usage: graphify cache-check [--root ]", file=sys.stderr) sys.exit(1) files_from = Path(sys.argv[2]) - root = Path(".") + cache_root = Path(".") i = 3 while i < len(sys.argv): if sys.argv[i] == "--root" and i + 1 < len(sys.argv): - root = Path(sys.argv[i + 1]) + cache_root = Path(sys.argv[i + 1]) i += 2 else: i += 1 files = [f for f in files_from.read_text(encoding="utf-8").splitlines() if f.strip()] - cached_nodes, cached_edges, cached_hyperedges, uncached = check_semantic_cache(files, root) - out = root / "graphify-out" + cached_nodes, cached_edges, cached_hyperedges, uncached = check_semantic_cache( + files, cache_root + ) + out = cache_root / "graphify-out" out.mkdir(parents=True, exist_ok=True) if cached_nodes or cached_edges or cached_hyperedges: (out / ".graphify_cached.json").write_text( - json.dumps({"nodes": cached_nodes, "edges": cached_edges, "hyperedges": cached_hyperedges}, - ensure_ascii=False), + json.dumps( + {"nodes": cached_nodes, "edges": cached_edges, "hyperedges": cached_hyperedges}, + ensure_ascii=False, + ), encoding="utf-8", ) (out / ".graphify_uncached.txt").write_text("\n".join(uncached), encoding="utf-8") @@ -3610,6 +3994,7 @@ def _progress(idx: int, total: int, _result: dict) -> None: # Concatenates .graphify_chunk_*.json files written by semantic subagents. # Deduplicates nodes by id (first writer wins). Sums token counts. import glob as _glob + if len(sys.argv) < 3: print("Usage: graphify merge-chunks --out ", file=sys.stderr) sys.exit(1) @@ -3630,7 +4015,13 @@ def _progress(idx: int, total: int, _result: dict) -> None: for arg in chunk_args: expanded = _glob.glob(arg) chunk_files.extend(sorted(expanded) if expanded else [arg]) - merged: dict = {"nodes": [], "edges": [], "hyperedges": [], "input_tokens": 0, "output_tokens": 0} + merged: dict = { + "nodes": [], + "edges": [], + "hyperedges": [], + "input_tokens": 0, + "output_tokens": 0, + } seen_ids: set[str] = set() for cf in chunk_files: try: @@ -3658,7 +4049,10 @@ def _progress(idx: int, total: int, _result: dict) -> None: # Merges cached semantic results with freshly-extracted chunk results. # Deduplicates nodes by id (cached entries take priority over new ones). if len(sys.argv) < 3: - print("Usage: graphify merge-semantic --cached --new --out ", file=sys.stderr) + print( + "Usage: graphify merge-semantic --cached --new --out ", + file=sys.stderr, + ) sys.exit(1) cached_path: Path | None = None new_path: Path | None = None @@ -3666,19 +4060,30 @@ def _progress(idx: int, total: int, _result: dict) -> None: i = 2 while i < len(sys.argv): if sys.argv[i] == "--cached" and i + 1 < len(sys.argv): - cached_path = Path(sys.argv[i + 1]); i += 2 + cached_path = Path(sys.argv[i + 1]) + i += 2 elif sys.argv[i] == "--new" and i + 1 < len(sys.argv): - new_path = Path(sys.argv[i + 1]); i += 2 + new_path = Path(sys.argv[i + 1]) + i += 2 elif sys.argv[i] == "--out" and i + 1 < len(sys.argv): - out_path2 = Path(sys.argv[i + 1]); i += 2 + out_path2 = Path(sys.argv[i + 1]) + i += 2 else: i += 1 if not out_path2: print("error: --out required", file=sys.stderr) sys.exit(1) empty: dict = {"nodes": [], "edges": [], "hyperedges": []} - cached_data = json.loads(cached_path.read_text(encoding="utf-8")) if cached_path and cached_path.exists() else empty - new_data = json.loads(new_path.read_text(encoding="utf-8")) if new_path and new_path.exists() else empty + cached_data = ( + json.loads(cached_path.read_text(encoding="utf-8")) + if cached_path and cached_path.exists() + else empty + ) + new_data = ( + json.loads(new_path.read_text(encoding="utf-8")) + if new_path and new_path.exists() + else empty + ) seen_ids2: set[str] = set() all_nodes: list[dict] = [] for n in cached_data.get("nodes", []) + new_data.get("nodes", []): diff --git a/graphify/affected.py b/graphify/affected.py index 109eaa95e..0d81e6eda 100644 --- a/graphify/affected.py +++ b/graphify/affected.py @@ -3,7 +3,7 @@ from collections import deque from dataclasses import dataclass from pathlib import Path -from typing import Iterable +from typing import Any, Iterable, cast import networkx as nx @@ -87,8 +87,9 @@ def affected_nodes( current, current_depth = queue.popleft() if current_depth >= depth: continue - if hasattr(graph, "in_edges"): - incoming = graph.in_edges(current, data=True) + graph_any = cast(Any, graph) + if hasattr(graph_any, "in_edges"): + incoming = graph_any.in_edges(current, data=True) else: incoming = ( (source, target, data) diff --git a/graphify/analyze.py b/graphify/analyze.py index f3e08103d..0812d528d 100644 --- a/graphify/analyze.py +++ b/graphify/analyze.py @@ -1,9 +1,11 @@ """Graph analysis: god nodes (most connected), surprising connections (cross-community), suggested questions.""" + from __future__ import annotations from pathlib import Path import networkx as nx from graphify.build import edge_data +from graphify.detect import CODE_EXTENSIONS, IMAGE_EXTENSIONS, PAPER_EXTENSIONS # Language families — extensions sharing a runtime can legitimately call each other _LANG_FAMILY: dict[str, str] = { @@ -53,6 +55,7 @@ def _is_file_node(G: nx.Graph, node_id: str) -> bool: source_file = attrs.get("source_file", "") if source_file: from pathlib import Path as _Path + if label == _Path(source_file).name: return True # Method stub: AST extractor labels methods as '.method_name()' @@ -65,12 +68,29 @@ def _is_file_node(G: nx.Graph, node_id: str) -> bool: return False -_JSON_NOISE_LABELS: frozenset[str] = frozenset({ - "start", "end", "name", "id", "type", "properties", - "value", "key", "data", "items", "title", "description", "version", - "dependencies", "devdependencies", "peerdependencies", - "optionaldependencies", "bundleddependencies", "bundledependencies", -}) +_JSON_NOISE_LABELS: frozenset[str] = frozenset( + { + "start", + "end", + "name", + "id", + "type", + "properties", + "value", + "key", + "data", + "items", + "title", + "description", + "version", + "dependencies", + "devdependencies", + "peerdependencies", + "optionaldependencies", + "bundleddependencies", + "bundledependencies", + } +) def _is_json_key_node(G: nx.Graph, node_id: str) -> bool: @@ -92,13 +112,19 @@ def god_nodes(G: nx.Graph, top_n: int = 10) -> list[dict]: sorted_nodes = sorted(degree.items(), key=lambda x: x[1], reverse=True) result = [] for node_id, deg in sorted_nodes: - if _is_file_node(G, node_id) or _is_concept_node(G, node_id) or _is_json_key_node(G, node_id): + if ( + _is_file_node(G, node_id) + or _is_concept_node(G, node_id) + or _is_json_key_node(G, node_id) + ): continue - result.append({ - "id": node_id, - "label": G.nodes[node_id].get("label", node_id), - "degree": deg, - }) + result.append( + { + "id": node_id, + "label": G.nodes[node_id].get("label", node_id), + "degree": deg, + } + ) if len(result) >= top_n: break return result @@ -124,9 +150,7 @@ def surprising_connections( """ # Identify unique source files (ignore empty/null source_file) source_files = { - data.get("source_file", "") - for _, data in G.nodes(data=True) - if data.get("source_file", "") + data.get("source_file", "") for _, data in G.nodes(data=True) if data.get("source_file", "") } is_multi_source = len(source_files) > 1 @@ -155,9 +179,6 @@ def _is_concept_node(G: nx.Graph, node_id: str) -> bool: return False -from graphify.detect import CODE_EXTENSIONS, DOC_EXTENSIONS, PAPER_EXTENSIONS, IMAGE_EXTENSIONS - - def _file_category(path: str) -> str: ext = ("." + path.rsplit(".", 1)[-1].lower()) if "." in path else "" if ext in CODE_EXTENSIONS: @@ -288,18 +309,20 @@ def _cross_file_surprises(G: nx.Graph, communities: dict[int, list[str]], top_n: tgt_id = data.get("_tgt", v) if tgt_id not in G.nodes: tgt_id = v - candidates.append({ - "_score": score, - "source": G.nodes[src_id].get("label", src_id), - "target": G.nodes[tgt_id].get("label", tgt_id), - "source_files": [ - G.nodes[src_id].get("source_file", ""), - G.nodes[tgt_id].get("source_file", ""), - ], - "confidence": data.get("confidence", "EXTRACTED"), - "relation": relation, - "why": "; ".join(reasons) if reasons else "cross-file semantic connection", - }) + candidates.append( + { + "_score": score, + "source": G.nodes[src_id].get("label", src_id), + "target": G.nodes[tgt_id].get("label", tgt_id), + "source_files": [ + G.nodes[src_id].get("source_file", ""), + G.nodes[tgt_id].get("source_file", ""), + ], + "confidence": data.get("confidence", "EXTRACTED"), + "relation": relation, + "why": "; ".join(reasons) if reasons else "cross-file semantic connection", + } + ) candidates.sort(key=lambda x: x["_score"], reverse=True) for c in candidates: @@ -334,17 +357,19 @@ def _cross_community_surprises( result = [] for (u, v), score in top_edges: data = edge_data(G, u, v) - result.append({ - "source": G.nodes[u].get("label", u), - "target": G.nodes[v].get("label", v), - "source_files": [ - G.nodes[u].get("source_file", ""), - G.nodes[v].get("source_file", ""), - ], - "confidence": data.get("confidence", "EXTRACTED"), - "relation": data.get("relation", ""), - "note": f"Bridges graph structure (betweenness={score:.3f})", - }) + result.append( + { + "source": G.nodes[u].get("label", u), + "target": G.nodes[v].get("label", v), + "source_files": [ + G.nodes[u].get("source_file", ""), + G.nodes[v].get("source_file", ""), + ], + "confidence": data.get("confidence", "EXTRACTED"), + "relation": data.get("relation", ""), + "note": f"Bridges graph structure (betweenness={score:.3f})", + } + ) return result # Build node → community map @@ -370,18 +395,20 @@ def _cross_community_surprises( tgt_id = data.get("_tgt", v) if tgt_id not in G.nodes: tgt_id = v - surprises.append({ - "source": G.nodes[src_id].get("label", src_id), - "target": G.nodes[tgt_id].get("label", tgt_id), - "source_files": [ - G.nodes[src_id].get("source_file", ""), - G.nodes[tgt_id].get("source_file", ""), - ], - "confidence": confidence, - "relation": relation, - "note": f"Bridges community {cid_u} → community {cid_v}", - "_pair": tuple(sorted([cid_u, cid_v])), - }) + surprises.append( + { + "source": G.nodes[src_id].get("label", src_id), + "target": G.nodes[tgt_id].get("label", tgt_id), + "source_files": [ + G.nodes[src_id].get("source_file", ""), + G.nodes[tgt_id].get("source_file", ""), + ], + "confidence": confidence, + "relation": relation, + "note": f"Bridges community {cid_u} → community {cid_v}", + "_pair": tuple(sorted([cid_u, cid_v])), + } + ) # Sort: AMBIGUOUS first, then INFERRED, then EXTRACTED order = {"AMBIGUOUS": 0, "INFERRED": 1, "EXTRACTED": 2} @@ -411,7 +438,9 @@ def suggest_questions( Each question has a 'type', 'question', and 'why' field. """ if community_labels: - community_labels = {int(k) if isinstance(k, str) else k: v for k, v in community_labels.items()} + community_labels = { + int(k) if isinstance(k, str) else k: v for k, v in community_labels.items() + } questions = [] node_community = _node_community_map(communities) @@ -422,11 +451,13 @@ def suggest_questions( ul = G.nodes[u].get("label", u) vl = G.nodes[v].get("label", v) relation = data.get("relation", "related to") - questions.append({ - "type": "ambiguous_edge", - "question": f"What is the exact relationship between `{ul}` and `{vl}`?", - "why": f"Edge tagged AMBIGUOUS (relation: {relation}) - confidence is low.", - }) + questions.append( + { + "type": "ambiguous_edge", + "question": f"What is the exact relationship between `{ul}` and `{vl}`?", + "why": f"Edge tagged AMBIGUOUS (relation: {relation}) - confidence is low.", + } + ) # 2. Bridge nodes (high betweenness) → cross-cutting concern questions if G.number_of_edges() > 0: @@ -434,24 +465,35 @@ def suggest_questions( betweenness = nx.betweenness_centrality(G, k=k, seed=42) # Top bridge nodes that are NOT file-level hubs bridges = sorted( - [(n, s) for n, s in betweenness.items() - if not _is_file_node(G, n) and not _is_concept_node(G, n) and s > 0], + [ + (n, s) + for n, s in betweenness.items() + if not _is_file_node(G, n) and not _is_concept_node(G, n) and s > 0 + ], key=lambda x: x[1], reverse=True, )[:3] for node_id, score in bridges: label = G.nodes[node_id].get("label", node_id) cid = node_community.get(node_id) - comm_label = community_labels.get(cid, f"Community {cid}") if cid is not None else "unknown" + comm_label = ( + community_labels.get(cid, f"Community {cid}") if cid is not None else "unknown" + ) neighbors = list(G.neighbors(node_id)) - neighbor_comms = {node_community.get(n) for n in neighbors if node_community.get(n) != cid} + neighbor_comms = { + other_cid + for n in neighbors + if (other_cid := node_community.get(n)) is not None and other_cid != cid + } if neighbor_comms: other_labels = [community_labels.get(c, f"Community {c}") for c in neighbor_comms] - questions.append({ - "type": "bridge_node", - "question": f"Why does `{label}` connect `{comm_label}` to {', '.join(f'`{l}`' for l in other_labels)}?", - "why": f"High betweenness centrality ({score:.3f}) - this node is a cross-community bridge.", - }) + questions.append( + { + "type": "bridge_node", + "question": f"Why does `{label}` connect `{comm_label}` to {', '.join(f'`{label}`' for label in other_labels)}?", + "why": f"High betweenness centrality ({score:.3f}) - this node is a cross-community bridge.", + } + ) # 3. God nodes with many INFERRED edges → verification questions degree = dict(G.degree()) @@ -462,7 +504,8 @@ def suggest_questions( )[:5] for node_id, _ in top_nodes: inferred = [ - (u, v, d) for u, v, d in G.edges(node_id, data=True) + (u, v, d) + for u, v, d in G.edges(node_id, data=True) if d.get("confidence") == "INFERRED" ] if len(inferred) >= 2: @@ -478,48 +521,58 @@ def suggest_questions( tgt_id = v other_id = tgt_id if src_id == node_id else src_id others.append(G.nodes[other_id].get("label", other_id)) - questions.append({ - "type": "verify_inferred", - "question": f"Are the {len(inferred)} inferred relationships involving `{label}` (e.g. with `{others[0]}` and `{others[1]}`) actually correct?", - "why": f"`{label}` has {len(inferred)} INFERRED edges - model-reasoned connections that need verification.", - }) + questions.append( + { + "type": "verify_inferred", + "question": f"Are the {len(inferred)} inferred relationships involving `{label}` (e.g. with `{others[0]}` and `{others[1]}`) actually correct?", + "why": f"`{label}` has {len(inferred)} INFERRED edges - model-reasoned connections that need verification.", + } + ) # 4. Isolated or weakly-connected nodes → exploration questions isolated = [ - n for n in G.nodes() + n + for n in G.nodes() if G.degree(n) <= 1 and not _is_file_node(G, n) and not _is_concept_node(G, n) ] if isolated: labels = [G.nodes[n].get("label", n) for n in isolated[:3]] - questions.append({ - "type": "isolated_nodes", - "question": f"What connects {', '.join(f'`{l}`' for l in labels)} to the rest of the system?", - "why": f"{len(isolated)} weakly-connected nodes found - possible documentation gaps or missing edges.", - }) + questions.append( + { + "type": "isolated_nodes", + "question": f"What connects {', '.join(f'`{label}`' for label in labels)} to the rest of the system?", + "why": f"{len(isolated)} weakly-connected nodes found - possible documentation gaps or missing edges.", + } + ) # 5. Low-cohesion communities → structural questions from .cluster import cohesion_score + for cid, nodes in communities.items(): score = cohesion_score(G, nodes) if score < 0.15 and len(nodes) >= 5: label = community_labels.get(cid, f"Community {cid}") - questions.append({ - "type": "low_cohesion", - "question": f"Should `{label}` be split into smaller, more focused modules?", - "why": f"Cohesion score {score} - nodes in this community are weakly interconnected.", - }) + questions.append( + { + "type": "low_cohesion", + "question": f"Should `{label}` be split into smaller, more focused modules?", + "why": f"Cohesion score {score} - nodes in this community are weakly interconnected.", + } + ) if not questions: - return [{ - "type": "no_signal", - "question": None, - "why": ( - "Not enough signal to generate questions. " - "This usually means the corpus has no AMBIGUOUS edges, no bridge nodes, " - "no INFERRED relationships, and all communities are tightly cohesive. " - "Add more files or run with --mode deep to extract richer edges." - ), - }] + return [ + { + "type": "no_signal", + "question": None, + "why": ( + "Not enough signal to generate questions. " + "This usually means the corpus has no AMBIGUOUS edges, no bridge nodes, " + "no INFERRED relationships, and all communities are tightly cohesive. " + "Add more files or run with --mode deep to extract richer edges." + ), + } + ] return questions[:top_n] @@ -542,13 +595,9 @@ def graph_diff(G_old: nx.Graph, G_new: nx.Graph) -> dict: added_node_ids = new_nodes - old_nodes removed_node_ids = old_nodes - new_nodes - new_nodes_list = [ - {"id": n, "label": G_new.nodes[n].get("label", n)} - for n in added_node_ids - ] + new_nodes_list = [{"id": n, "label": G_new.nodes[n].get("label", n)} for n in added_node_ids] removed_nodes_list = [ - {"id": n, "label": G_old.nodes[n].get("label", n)} - for n in removed_node_ids + {"id": n, "label": G_old.nodes[n].get("label", n)} for n in removed_node_ids ] def edge_key(G: nx.Graph, u: str, v: str, data: dict) -> tuple: @@ -556,14 +605,8 @@ def edge_key(G: nx.Graph, u: str, v: str, data: dict) -> tuple: return (u, v, data.get("relation", "")) return (min(u, v), max(u, v), data.get("relation", "")) - old_edge_keys = { - edge_key(G_old, u, v, d) - for u, v, d in G_old.edges(data=True) - } - new_edge_keys = { - edge_key(G_new, u, v, d) - for u, v, d in G_new.edges(data=True) - } + old_edge_keys = {edge_key(G_old, u, v, d) for u, v, d in G_old.edges(data=True)} + new_edge_keys = {edge_key(G_new, u, v, d) for u, v, d in G_new.edges(data=True)} added_edge_keys = new_edge_keys - old_edge_keys removed_edge_keys = old_edge_keys - new_edge_keys @@ -571,22 +614,26 @@ def edge_key(G: nx.Graph, u: str, v: str, data: dict) -> tuple: new_edges_list = [] for u, v, d in G_new.edges(data=True): if edge_key(G_new, u, v, d) in added_edge_keys: - new_edges_list.append({ - "source": u, - "target": v, - "relation": d.get("relation", ""), - "confidence": d.get("confidence", ""), - }) + new_edges_list.append( + { + "source": u, + "target": v, + "relation": d.get("relation", ""), + "confidence": d.get("confidence", ""), + } + ) removed_edges_list = [] for u, v, d in G_old.edges(data=True): if edge_key(G_old, u, v, d) in removed_edge_keys: - removed_edges_list.append({ - "source": u, - "target": v, - "relation": d.get("relation", ""), - "confidence": d.get("confidence", ""), - }) + removed_edges_list.append( + { + "source": u, + "target": v, + "relation": d.get("relation", ""), + "confidence": d.get("confidence", ""), + } + ) parts = [] if new_nodes_list: @@ -594,9 +641,13 @@ def edge_key(G: nx.Graph, u: str, v: str, data: dict) -> tuple: if new_edges_list: parts.append(f"{len(new_edges_list)} new edge{'s' if len(new_edges_list) != 1 else ''}") if removed_nodes_list: - parts.append(f"{len(removed_nodes_list)} node{'s' if len(removed_nodes_list) != 1 else ''} removed") + parts.append( + f"{len(removed_nodes_list)} node{'s' if len(removed_nodes_list) != 1 else ''} removed" + ) if removed_edges_list: - parts.append(f"{len(removed_edges_list)} edge{'s' if len(removed_edges_list) != 1 else ''} removed") + parts.append( + f"{len(removed_edges_list)} edge{'s' if len(removed_edges_list) != 1 else ''} removed" + ) summary = ", ".join(parts) if parts else "no changes" return { diff --git a/graphify/benchmark.py b/graphify/benchmark.py index eabade292..cf2ab4666 100644 --- a/graphify/benchmark.py +++ b/graphify/benchmark.py @@ -1,4 +1,5 @@ """Token-reduction benchmark - measures how much context graphify saves vs naive full-corpus approach.""" + from __future__ import annotations import json import sys @@ -66,11 +67,15 @@ def _query_subgraph_tokens(G: nx.Graph, question: str, depth: int = 3) -> int: lines = [] for nid in visited: d = G.nodes[nid] - lines.append(f"NODE {d.get('label', nid)} src={d.get('source_file', '')} loc={d.get('source_location', '')}") + lines.append( + f"NODE {d.get('label', nid)} src={d.get('source_file', '')} loc={d.get('source_location', '')}" + ) for u, v in edges_seen: if u in visited and v in visited: d = edge_data(G, u, v) - lines.append(f"EDGE {G.nodes[u].get('label', u)} --{d.get('relation', '')}--> {G.nodes[v].get('label', v)}") + lines.append( + f"EDGE {G.nodes[u].get('label', u)} --{d.get('relation', '')}--> {G.nodes[v].get('label', v)}" + ) return _estimate_tokens("\n".join(lines)) @@ -99,6 +104,7 @@ def run_benchmark( Returns dict with: corpus_tokens, avg_query_tokens, reduction_ratio, per_question """ from graphify.security import check_graph_file_size_cap + check_graph_file_size_cap(Path(graph_path)) data = json.loads(Path(graph_path).read_text(encoding="utf-8")) try: @@ -108,16 +114,20 @@ def run_benchmark( if corpus_words is None: # Rough estimate: each node label is ~3 words, plus source context - corpus_words = G.number_of_nodes() * 50 + estimated_corpus_words = G.number_of_nodes() * 50 + else: + estimated_corpus_words = corpus_words - corpus_tokens = corpus_words * 100 // 75 # words → tokens (100 words ≈ 133 tokens) + corpus_tokens = estimated_corpus_words * 100 // 75 # words → tokens (100 words ≈ 133 tokens) qs = questions or _SAMPLE_QUESTIONS per_question = [] for q in qs: qt = _query_subgraph_tokens(G, q) if qt > 0: - per_question.append({"question": q, "query_tokens": qt, "reduction": round(corpus_tokens / qt, 1)}) + per_question.append( + {"question": q, "query_tokens": qt, "reduction": round(corpus_tokens / qt, 1)} + ) if not per_question: return {"error": "No matching nodes found for sample questions. Build the graph first."} @@ -127,7 +137,7 @@ def run_benchmark( return { "corpus_tokens": corpus_tokens, - "corpus_words": corpus_words, + "corpus_words": estimated_corpus_words, "nodes": G.number_of_nodes(), "edges": G.number_of_edges(), "avg_query_tokens": avg_query_tokens, @@ -142,14 +152,16 @@ def print_benchmark(result: dict) -> None: print(f"Benchmark error: {result['error']}") return - print(f"\ngraphify token reduction benchmark") + print("\ngraphify token reduction benchmark") print(_hr(50)) arrow = _safe("→", "->") - print(f" Corpus: {result['corpus_words']:,} words {arrow} ~{result['corpus_tokens']:,} tokens (naive)") + print( + f" Corpus: {result['corpus_words']:,} words {arrow} ~{result['corpus_tokens']:,} tokens (naive)" + ) print(f" Graph: {result['nodes']:,} nodes, {result['edges']:,} edges") print(f" Avg query cost: ~{result['avg_query_tokens']:,} tokens") print(f" Reduction: {result['reduction_ratio']}x fewer tokens per query") - print(f"\n Per question:") + print("\n Per question:") for p in result["per_question"]: print(f" [{p['reduction']}x] {p['question'][:55]}") print() diff --git a/graphify/cache.py b/graphify/cache.py index 2052cf7aa..73fff35e0 100644 --- a/graphify/cache.py +++ b/graphify/cache.py @@ -20,7 +20,7 @@ def _body_content(content: bytes) -> bytes: if text.startswith("---"): end = text.find("\n---", 3) if end != -1: - return text[end + 4:].encode() + return text[end + 4 :].encode() return content @@ -86,6 +86,7 @@ def _flush_stat_index() -> None: def _normalize_path(path: Path) -> Path: """Normalize path for consistent cache keys across Windows path spellings.""" import sys + if sys.platform != "win32": return path s = str(path) @@ -120,9 +121,7 @@ def file_hash(path: Path, root: Path = Path(".")) -> str: try: st = p.stat() entry = _stat_index.get(abs_key) - if (entry - and entry.get("size") == st.st_size - and entry.get("mtime_ns") == st.st_mtime_ns): + if entry and entry.get("size") == st.st_size and entry.get("mtime_ns") == st.st_mtime_ns: return entry["hash"] except OSError: pass @@ -216,6 +215,7 @@ def save_cached(path: Path, result: dict, root: Path = Path("."), kind: str = "a # Windows: os.replace can fail with WinError 5 if the target is # briefly locked. Fall back to copy-then-delete. import shutil + shutil.copy2(tmp_path, entry) os.unlink(tmp_path) except Exception: @@ -313,7 +313,7 @@ def save_semantic_cache( src = e.get("source_file", "") if src: by_file[src]["edges"].append(e) - for h in (hyperedges or []): + for h in hyperedges or []: src = h.get("source_file", "") if src: by_file[src]["hyperedges"].append(h) diff --git a/graphify/callflow_html.py b/graphify/callflow_html.py index 6195adb98..9f67a9c07 100644 --- a/graphify/callflow_html.py +++ b/graphify/callflow_html.py @@ -28,6 +28,7 @@ from pathlib import Path from collections import Counter, defaultdict from datetime import datetime, timezone +from typing import Any, cast from html import escape @@ -89,7 +90,8 @@ # 2. Data loading and normalization helpers # ────────────────────────────────────────────── -def read_json(path: str | Path, default=None): + +def read_json(path: str | Path | None, default: Any = None) -> Any: """Read JSON with a useful error message.""" if not path: return default @@ -204,7 +206,9 @@ def normalize_edge(raw: dict, index: int) -> dict | None: if not source or not target: return None - relation = first_present(edge, "relation", "type", "kind", "label", "predicate", default="relates") + relation = first_present( + edge, "relation", "type", "kind", "label", "predicate", default="relates" + ) confidence = first_present(edge, "confidence", "evidence", "provenance", default="EXTRACTED") score = first_present(edge, "confidence_score", "score", "weight", "probability", default=1.0) @@ -254,6 +258,7 @@ def load_graph(path: str | Path) -> tuple: """Load graph.json. Returns normalized (nodes, edges, hyperedges, metadata).""" if path: from graphify.security import check_graph_file_size_cap + try: check_graph_file_size_cap(Path(path)) except ValueError as exc: @@ -262,16 +267,32 @@ def load_graph(path: str | Path) -> tuple: if not isinstance(data, dict): raise SystemExit(f"ERROR: graph file must contain a JSON object: {path}") - graph_block = data.get("graph") if isinstance(data.get("graph"), dict) else {} - meta_block = data.get("metadata") if isinstance(data.get("metadata"), dict) else {} + graph_block: dict[str, Any] = ( + cast(dict[str, Any], data.get("graph")) if isinstance(data.get("graph"), dict) else {} + ) + meta_block: dict[str, Any] = ( + cast(dict[str, Any], data.get("metadata")) if isinstance(data.get("metadata"), dict) else {} + ) node_link = _node_link_payload(data) if node_link: raw_nodes, raw_edges = node_link else: - raw_nodes = first_list(data.get("nodes"), data.get("vertices"), graph_block.get("nodes"), graph_block.get("vertices")) - raw_edges = first_list(data.get("links"), data.get("edges"), graph_block.get("links"), graph_block.get("edges")) - hyperedges = first_list(data.get("hyperedges"), graph_block.get("hyperedges"), data.get("groups"), graph_block.get("groups")) + raw_nodes = first_list( + data.get("nodes"), + data.get("vertices"), + graph_block.get("nodes"), + graph_block.get("vertices"), + ) + raw_edges = first_list( + data.get("links"), data.get("edges"), graph_block.get("links"), graph_block.get("edges") + ) + hyperedges = first_list( + data.get("hyperedges"), + graph_block.get("hyperedges"), + data.get("groups"), + graph_block.get("groups"), + ) nodes = [normalize_node(n, i) for i, n in enumerate(raw_nodes) if isinstance(n, dict)] edges = [] @@ -282,9 +303,16 @@ def load_graph(path: str | Path) -> tuple: if edge: edges.append(edge) - meta = dict(graph_block) + meta: dict[str, Any] = dict(graph_block) meta.update(meta_block) - for key in ("built_at_commit", "commit", "project_name", "repo", "repository", "language_breakdown"): + for key in ( + "built_at_commit", + "commit", + "project_name", + "repo", + "repository", + "language_breakdown", + ): if data.get(key) and not meta.get(key): meta[key] = data.get(key) if meta.get("commit") and not meta.get("built_at_commit"): @@ -331,6 +359,7 @@ def load_report(path: str | Path | None) -> str: # 3. Mermaid-safe label helpers # ────────────────────────────────────────────── + def safe_mermaid_text(text: str) -> str: """Sanitize text for use inside a Mermaid node label. @@ -344,10 +373,10 @@ def safe_mermaid_text(text: str) -> str: """ text = str(text or "") text = text.replace('"', "'") - text = text.replace('`', '') - text = text.replace('#', '') - text = text.replace('|', ' ') - text = text.replace('{', '').replace('}', '') + text = text.replace("`", "") + text = text.replace("#", "") + text = text.replace("|", " ") + text = text.replace("{", "").replace("}", "") text = text.replace("->>", " to ").replace("-->", " to ").replace("->", " to ") text = " ".join(text.split()) return escape(text, quote=False) @@ -424,7 +453,9 @@ def resolve_graphify_paths(args) -> dict: project_root = graphify_out.parent if graphify_out.name == "graphify-out" else base graph = Path(args.graph).expanduser() if args.graph else graphify_out / "graph.json" report = Path(args.report).expanduser() if args.report else graphify_out / "GRAPH_REPORT.md" - labels = Path(args.labels).expanduser() if args.labels else graphify_out / ".graphify_labels.json" + labels = ( + Path(args.labels).expanduser() if args.labels else graphify_out / ".graphify_labels.json" + ) sections = Path(args.sections).expanduser() if args.sections else None return { "base": project_root, @@ -509,13 +540,39 @@ def node_kind(node: dict) -> str: if any(word in label for word in ("async", "await", "stream", "sse")): return "async" raw_label = str(node.get("label") or "") - hook_like = raw_label.startswith("use") and len(raw_label) > 3 and (raw_label[3].isupper() or raw_label[3] in "_-") - if any(word in label for word in ("component", "props", "hook", "store")) or hook_like or source_file.endswith((".tsx", ".jsx", ".vue", ".svelte")): + hook_like = ( + raw_label.startswith("use") + and len(raw_label) > 3 + and (raw_label[3].isupper() or raw_label[3] in "_-") + ) + if ( + any(word in label for word in ("component", "props", "hook", "store")) + or hook_like + or source_file.endswith((".tsx", ".jsx", ".vue", ".svelte")) + ): return "ui" raw = raw_label if raw[:1].isupper() and not raw.endswith("()"): return "klass" - if raw.endswith((".py", ".ts", ".tsx", ".js", ".jsx", ".go", ".rs", ".java", ".kt", ".rb", ".php", ".cs", ".swift", ".vue", ".svelte")): + if raw.endswith( + ( + ".py", + ".ts", + ".tsx", + ".js", + ".jsx", + ".go", + ".rs", + ".java", + ".kt", + ".rb", + ".php", + ".cs", + ".swift", + ".vue", + ".svelte", + ) + ): return "module" return "function" @@ -632,6 +689,7 @@ def mermaid_class_defs() -> list: # 4. Community and section indexing # ────────────────────────────────────────────── + def build_community_index(nodes: list) -> dict: """Map community_id (str) -> list of nodes.""" idx = defaultdict(list) @@ -652,7 +710,9 @@ def html_anchor_id(raw: str, fallback: str, used: set) -> str: base = base[:48].strip("-") or "section" candidate = base if candidate in used: - candidate = f"{base}-{hashlib.sha1(raw.encode('utf-8'), usedforsecurity=False).hexdigest()[:6]}" + candidate = ( + f"{base}-{hashlib.sha1(raw.encode('utf-8'), usedforsecurity=False).hexdigest()[:6]}" + ) suffix = 2 while candidate in used: candidate = f"{base}-{suffix}" @@ -688,11 +748,13 @@ def normalize_sections(sections: list, lang: str) -> list: continue sid = html_anchor_id(raw_id, f"section-{index}", used) - normalized.append({ - "id": sid, - "name": raw_name, - "communities": normalize_communities(raw.get("communities", raw.get("community"))), - }) + normalized.append( + { + "id": sid, + "name": raw_name, + "communities": normalize_communities(raw.get("communities", raw.get("community"))), + } + ) return normalized @@ -712,9 +774,22 @@ def label_for_community(cid: str, labels: dict, nodes: list, lang: str) -> str: "提取管线", "Extraction Pipeline", { - "extract", "extractor", "tree", "sitter", "parser", "language", - "python", "javascript", "typescript", "rust", "java", "go", - "ast", "calls", "imports", "multilang", + "extract", + "extractor", + "tree", + "sitter", + "parser", + "language", + "python", + "javascript", + "typescript", + "rust", + "java", + "go", + "ast", + "calls", + "imports", + "multilang", }, ), ( @@ -722,8 +797,17 @@ def label_for_community(cid: str, labels: dict, nodes: list, lang: str) -> str: "图谱构建", "Graph Build", { - "build", "graph", "merge", "dedup", "node", "edge", "hyperedge", - "json", "schema", "normalize", "confidence", + "build", + "graph", + "merge", + "dedup", + "node", + "edge", + "hyperedge", + "json", + "schema", + "normalize", + "confidence", }, ), ( @@ -731,8 +815,18 @@ def label_for_community(cid: str, labels: dict, nodes: list, lang: str) -> str: "分析聚类", "Analysis & Clustering", { - "cluster", "community", "leiden", "cohesion", "analyze", "god", - "surprise", "question", "query", "path", "explain", "benchmark", + "cluster", + "community", + "leiden", + "cohesion", + "analyze", + "god", + "surprise", + "question", + "query", + "path", + "explain", + "benchmark", }, ), ( @@ -740,8 +834,18 @@ def label_for_community(cid: str, labels: dict, nodes: list, lang: str) -> str: "输出文档", "Outputs & Docs", { - "export", "html", "wiki", "obsidian", "canvas", "svg", "graphml", - "report", "callflow", "mermaid", "tree", "documentation", + "export", + "html", + "wiki", + "obsidian", + "canvas", + "svg", + "graphml", + "report", + "callflow", + "mermaid", + "tree", + "documentation", }, ), ( @@ -749,9 +853,20 @@ def label_for_community(cid: str, labels: dict, nodes: list, lang: str) -> str: "CLI 与技能安装", "CLI & Skill Installers", { - "main", "install", "uninstall", "skill", "agent", "claude", - "codex", "opencode", "aider", "copilot", "kiro", "vscode", - "hook", "command", + "main", + "install", + "uninstall", + "skill", + "agent", + "claude", + "codex", + "opencode", + "aider", + "copilot", + "kiro", + "vscode", + "hook", + "command", }, ), ( @@ -759,9 +874,21 @@ def label_for_community(cid: str, labels: dict, nodes: list, lang: str) -> str: "摄取与增量更新", "Ingestion & Updates", { - "ingest", "fetch", "download", "url", "html", "markdown", - "cache", "manifest", "watch", "update", "incremental", - "transcribe", "video", "audio", "google", + "ingest", + "fetch", + "download", + "url", + "html", + "markdown", + "cache", + "manifest", + "watch", + "update", + "incremental", + "transcribe", + "video", + "audio", + "google", }, ), ( @@ -769,8 +896,17 @@ def label_for_community(cid: str, labels: dict, nodes: list, lang: str) -> str: "服务 API", "Serving API", { - "serve", "api", "request", "response", "endpoint", "router", - "handle", "upload", "search", "delete", "enrich", + "serve", + "api", + "request", + "response", + "endpoint", + "router", + "handle", + "upload", + "search", + "delete", + "enrich", }, ), ( @@ -778,8 +914,17 @@ def label_for_community(cid: str, labels: dict, nodes: list, lang: str) -> str: "安全与全局图", "Security & Global Graph", { - "security", "safe", "ssrf", "xss", "path", "traversal", - "global", "prefix", "prune", "repo", "clone", + "security", + "safe", + "ssrf", + "xss", + "path", + "traversal", + "global", + "prefix", + "prune", + "repo", + "clone", }, ), ( @@ -787,8 +932,14 @@ def label_for_community(cid: str, labels: dict, nodes: list, lang: str) -> str: "测试与样例", "Tests & Fixtures", { - "test", "tests", "fixture", "fixtures", "sample", "assert", - "pytest", "mock", + "test", + "tests", + "fixture", + "fixtures", + "sample", + "assert", + "pytest", + "mock", }, ), ] @@ -826,14 +977,24 @@ def _rank_grouped_sections(grouped: dict, max_sections: int) -> tuple[list, list return selected, overflow_communities -def derive_sections_from_communities(nodes: list, labels: dict, lang: str, max_sections: int) -> list: +def derive_sections_from_communities( + nodes: list, labels: dict, lang: str, max_sections: int +) -> list: """Derive architecture-oriented sections when no sections JSON is supplied.""" comm_idx = build_community_index(nodes) - sections = [{"id": "overview", "name": pick_text(lang, "架构总览", "Architecture Overview"), "communities": []}] + sections = [ + { + "id": "overview", + "name": pick_text(lang, "架构总览", "Architecture Overview"), + "communities": [], + } + ] grouped = {} unassigned = [] - for cid, community_nodes in sorted(comm_idx.items(), key=lambda item: (-len(item[1]), str(item[0]))): + for cid, community_nodes in sorted( + comm_idx.items(), key=lambda item: (-len(item[1]), str(item[0])) + ): label = label_for_community(cid, labels, community_nodes, lang) text = _community_text(community_nodes, label) best = None @@ -861,7 +1022,9 @@ def derive_sections_from_communities(nodes: list, labels: dict, lang: str, max_s else: unassigned.append((cid, community_nodes, label)) - selected, overflow_communities = _rank_grouped_sections(grouped, max(1, int(max_sections or 15)) - 1) + selected, overflow_communities = _rank_grouped_sections( + grouped, max(1, int(max_sections or 15)) - 1 + ) sections.extend( {"id": sec["id"], "name": sec["name"], "communities": sec["communities"]} for sec in selected @@ -869,15 +1032,19 @@ def derive_sections_from_communities(nodes: list, labels: dict, lang: str, max_s remaining_slots = max(0, int(max_sections or 15) - (len(sections) - 1) - 1) for cid, community_nodes, label in unassigned[:remaining_slots]: - sections.append({"id": str(label or f"community-{cid}"), "name": label, "communities": [cid]}) + sections.append( + {"id": str(label or f"community-{cid}"), "name": label, "communities": [cid]} + ) other_communities = overflow_communities + [cid for cid, _, _ in unassigned[remaining_slots:]] if other_communities: - sections.append({ - "id": "other", - "name": pick_text(lang, "其他", "Other"), - "communities": other_communities, - }) + sections.append( + { + "id": "other", + "name": pick_text(lang, "其他", "Other"), + "communities": other_communities, + } + ) return sections @@ -905,6 +1072,7 @@ def node_in_section(node_id: str, section_node_ids: set) -> bool: # 5. Edge analysis # ────────────────────────────────────────────── + def classify_edges(edges: list, section_nodes_map: dict) -> dict: """Classify edges as intra-section or inter-section. @@ -958,14 +1126,15 @@ def should_include_edge(edge: dict) -> bool: # 6. Mermaid diagram generators # ────────────────────────────────────────────── -def node_degree_scores(edges: list) -> Counter: + +def node_degree_scores(edges: list) -> dict[str, float]: """Score nodes by useful edge participation.""" - scores = Counter() + scores: defaultdict[str, float] = defaultdict(float) for edge in edges: score = edge_score(edge) - scores[edge.get("source", "")] += score - scores[edge.get("target", "")] += score - return scores + scores[str(edge.get("source", ""))] += score + scores[str(edge.get("target", ""))] += score + return dict(scores) def node_importance(node: dict) -> float: @@ -1037,7 +1206,10 @@ def fallback_key(node: dict) -> tuple: def node_label(node: dict) -> str: """Build a readable Mermaid node label.""" - label = humanize_label(node.get("label") or node.get("id"), node.get("source_file", "")) + label = humanize_label( + str(node.get("label") or node.get("id") or ""), + node.get("source_file", ""), + ) source_file = safe_file_path(node.get("source_file", "")) if source_file and not label.endswith(Path(source_file).name): return f"{safe_mermaid_text(label)}
{safe_mermaid_text(source_file)}" @@ -1056,7 +1228,9 @@ def group_nodes_by_file(nodes: list) -> dict: def section_edge_summary(classified_edges: dict) -> dict: """Aggregate inter-section edge counts and relation names.""" node_section = classified_edges.get("node_section", {}) - summary = defaultdict(lambda: {"count": 0, "relations": Counter()}) + summary: defaultdict[tuple[Any, Any], dict[str, Any]] = defaultdict( + lambda: {"count": 0, "relations": Counter()} + ) for edge in classified_edges.get("inter", []): if not should_include_edge(edge): continue @@ -1070,9 +1244,14 @@ def section_edge_summary(classified_edges: dict) -> dict: return summary -def generate_overview_graph(sections: list, section_nodes_map: dict, - classified_edges: dict, labels: dict, lang: str, - diagram_scale: float) -> str: +def generate_overview_graph( + sections: list, + section_nodes_map: dict, + classified_edges: dict, + labels: dict, + lang: str, + diagram_scale: float, +) -> str: """Generate a readable section-level architecture overview.""" lines = [mermaid_init(diagram_scale, "LR")] section_defs = [sec for sec in sections if sec["id"] != "overview"] @@ -1088,7 +1267,9 @@ def generate_overview_graph(sections: list, section_nodes_map: dict, lines.append(f" class {sid} module;") aggregated = section_edge_summary(classified_edges) - for (src, tgt), data in sorted(aggregated.items(), key=lambda item: item[1]["count"], reverse=True)[:12]: + for (src, tgt), data in sorted( + aggregated.items(), key=lambda item: item[1]["count"], reverse=True + )[:12]: src_id = mermaid_section_id(src) tgt_id = mermaid_section_id(tgt) relation, _ = data["relations"].most_common(1)[0] @@ -1099,19 +1280,29 @@ def generate_overview_graph(sections: list, section_nodes_map: dict, if not aggregated and len(section_defs) > 1: for prev, cur in zip(section_defs, section_defs[1:]): - lines.append(f" {mermaid_section_id(prev['id'])} -.-> {mermaid_section_id(cur['id'])}") + lines.append( + f" {mermaid_section_id(prev['id'])} -.-> {mermaid_section_id(cur['id'])}" + ) lines.extend(mermaid_class_defs()) return "\n".join(lines) -def generate_section_flowchart(section_id: str, section_name: str, - nodes: list, edges: list, lang: str, - diagram_scale: float, max_nodes: int, - max_edges: int) -> str: +def generate_section_flowchart( + section_id: str, + section_name: str, + nodes: list, + edges: list, + lang: str, + diagram_scale: float, + max_nodes: int, + max_edges: int, +) -> str: """Generate a compact, human-readable call-flow chart for a section.""" lines = [mermaid_init(diagram_scale, "LR")] - lines.append(f" %% Section: {safe_mermaid_text(section_name)} ({len(nodes)} nodes, {len(edges)} edges)") + lines.append( + f" %% Section: {safe_mermaid_text(section_name)} ({len(nodes)} nodes, {len(edges)} edges)" + ) if not nodes: empty_label = pick_text(lang, f"{section_name} - 无节点", f"{section_name} - no nodes") @@ -1122,12 +1313,14 @@ def generate_section_flowchart(section_id: str, section_name: str, selected_nodes = select_diagram_nodes(nodes, edges, max_nodes) selected_ids = {node.get("id") for node in selected_nodes} visible_edges = [ - edge for edge in preferred_edges(edges, allow_structure=False) + edge + for edge in preferred_edges(edges, allow_structure=False) if edge.get("source") in selected_ids and edge.get("target") in selected_ids ] if not visible_edges: visible_edges = [ - edge for edge in preferred_edges(edges, allow_structure=True) + edge + for edge in preferred_edges(edges, allow_structure=True) if edge.get("source") in selected_ids and edge.get("target") in selected_ids ] @@ -1160,7 +1353,9 @@ def generate_section_flowchart(section_id: str, section_name: str, omitted_nodes = max(0, len(nodes) - len(selected_nodes)) omitted_edges = max(0, len(visible_edges) - included) if omitted_nodes or omitted_edges: - lines.append(f" %% Omitted for readability: {omitted_nodes} nodes, {omitted_edges} edges") + lines.append( + f" %% Omitted for readability: {omitted_nodes} nodes, {omitted_edges} edges" + ) lines.extend(class_lines) lines.extend(mermaid_class_defs()) return "\n".join(lines) @@ -1170,6 +1365,7 @@ def generate_section_flowchart(section_id: str, section_name: str, # 7. HTML generators # ────────────────────────────────────────────── + def generate_nav(sections: list) -> str: """Generate the sticky navigation bar.""" links = [] @@ -1186,21 +1382,33 @@ def node_display_name(node: dict | None, fallback: str = "") -> str: return humanize_label(label, node.get("source_file", "")) -def format_node_refs(node_ids: set, node_by_id: dict, lang: str, empty_text: str, limit: int = 3) -> str: +def format_node_refs( + node_ids: set, node_by_id: dict, lang: str, empty_text: str, limit: int = 3 +) -> str: """Render node references as readable labels instead of internal IDs.""" if not node_ids: return escape(empty_text) parts = [] - for nid in sorted(node_ids, key=lambda item: node_display_name(node_by_id.get(item), item).lower())[:limit]: + for nid in sorted( + node_ids, key=lambda item: node_display_name(node_by_id.get(item), item).lower() + )[:limit]: node = node_by_id.get(nid) label = node_display_name(node, nid) source = safe_file_path((node or {}).get("source_file", "")) if source: - parts.append(f"{escape(label)}
{escape(source)}") + parts.append( + f'{escape(label)}
{escape(source)}' + ) else: parts.append(f"{escape(label)}") if len(node_ids) > limit: - parts.append(escape(pick_text(lang, f"+{len(node_ids) - limit} 个更多", f"+{len(node_ids) - limit} more"))) + parts.append( + escape( + pick_text( + lang, f"+{len(node_ids) - limit} 个更多", f"+{len(node_ids) - limit} more" + ) + ) + ) return "
".join(parts) @@ -1297,31 +1505,71 @@ def _describe_node(label: str, source_file: str, file_type: str, lang: str) -> s if file_type == "rationale": return pick_text(lang, f"设计说明:{label}", f"Design note for {label}.") if file_type == "document": - return pick_text(lang, f"文档入口,描述 {label} 相关能力。", f"Documentation node describing {label}.") + return pick_text( + lang, f"文档入口,描述 {label} 相关能力。", f"Documentation node describing {label}." + ) if label.endswith(".py") or label.endswith(".tsx") or label.endswith(".ts"): - return pick_text(lang, f"{source} 中的模块文件,承载该层主要实现。", f"Module file in {source}.") + return pick_text( + lang, f"{source} 中的模块文件,承载该层主要实现。", f"Module file in {source}." + ) if "config" in lower: - return pick_text(lang, "读取、解析或持久化项目配置。", "Reads, resolves, or persists project configuration.") + return pick_text( + lang, + "读取、解析或持久化项目配置。", + "Reads, resolves, or persists project configuration.", + ) if "scan" in lower: - return pick_text(lang, "触发项目扫描或处理扫描状态。", "Starts scanning or handles scan status.") + return pick_text( + lang, "触发项目扫描或处理扫描状态。", "Starts scanning or handles scan status." + ) if "ingest" in lower or "clone" in lower or "git" in lower: - return pick_text(lang, "把本地目录或远程仓库转换为分析上下文。", "Turns a local path or remote repository into analysis context.") + return pick_text( + lang, + "把本地目录或远程仓库转换为分析上下文。", + "Turns a local path or remote repository into analysis context.", + ) if "prompt" in lower: - return pick_text(lang, "构造发送给 LLM 的结构化提示。", "Builds structured prompts for model calls.") + return pick_text( + lang, "构造发送给 LLM 的结构化提示。", "Builds structured prompts for model calls." + ) if "analy" in lower: - return pick_text(lang, "编排分析流程并产出结构化文档数据。", "Orchestrates analysis and returns structured documentation data.") + return pick_text( + lang, + "编排分析流程并产出结构化文档数据。", + "Orchestrates analysis and returns structured documentation data.", + ) if "graph" in lower or "dependency" in lower: - return pick_text(lang, "构建依赖关系并提供排序或图形化数据。", "Builds dependency relationships and graph data.") + return pick_text( + lang, + "构建依赖关系并提供排序或图形化数据。", + "Builds dependency relationships and graph data.", + ) if "export" in lower or "markdown" in lower or "html" in lower: - return pick_text(lang, "将文档数据导出为目标格式。", "Exports documentation data to a target format.") + return pick_text( + lang, "将文档数据导出为目标格式。", "Exports documentation data to a target format." + ) if "chat" in lower or "rag" in lower or "retrieve" in lower: - return pick_text(lang, "支撑检索增强问答或流式聊天。", "Supports retrieval-augmented Q&A or streaming chat.") + return pick_text( + lang, + "支撑检索增强问答或流式聊天。", + "Supports retrieval-augmented Q&A or streaming chat.", + ) if "wiki" in lower or "page" in lower or "sidebar" in lower: - return pick_text(lang, "组织文档页面、侧边栏或内容读取。", "Organizes documentation pages, navigation, or content lookup.") + return pick_text( + lang, + "组织文档页面、侧边栏或内容读取。", + "Organizes documentation pages, navigation, or content lookup.", + ) if "cache" in lower or "hash" in lower: - return pick_text(lang, "缓存分析结果或生成缓存键。", "Caches analysis results or computes cache keys.") + return pick_text( + lang, "缓存分析结果或生成缓存键。", "Caches analysis results or computes cache keys." + ) if "test" in lower: - return pick_text(lang, "验证导入、入口点或版本等基础行为。", "Verifies imports, entry points, or version behavior.") + return pick_text( + lang, + "验证导入、入口点或版本等基础行为。", + "Verifies imports, entry points, or version behavior.", + ) return pick_text(lang, f"{source} 中的 {label} 节点。", f"{label} node in {source}.") @@ -1370,7 +1618,9 @@ def derive_flow_chain(sections: list, classified_edges: dict) -> str: seen = {start} current = start while len(chain) < min(7, len(order)): - candidates = [(count, tgt) for tgt, count in outgoing.get(current, {}).items() if tgt not in seen] + candidates = [ + (count, tgt) for tgt, count in outgoing.get(current, {}).items() if tgt not in seen + ] if candidates: _, nxt = max(candidates) else: @@ -1384,9 +1634,14 @@ def derive_flow_chain(sections: list, classified_edges: dict) -> str: return " -> ".join(section_names.get(sid, sid) for sid in chain) -def generate_overview_cards(meta: dict, report_text: str, sections: list, - section_nodes_map: dict, classified_edges: dict, - lang: str) -> str: +def generate_overview_cards( + meta: dict, + report_text: str, + sections: list, + section_nodes_map: dict, + classified_edges: dict, + lang: str, +) -> str: """Generate generic overview cards.""" rows = [] for sec in sections: @@ -1400,14 +1655,18 @@ def generate_overview_cards(meta: dict, report_text: str, sections: list, flow = derive_flow_chain(sections, classified_edges) layer_title = pick_text(lang, "架构层次", "Architecture Layers") - layer_cols = pick_text(lang, "层节点社区", "LayerNodesCommunities") + layer_cols = pick_text( + lang, + "层节点社区", + "LayerNodesCommunities", + ) flow_title = pick_text(lang, "核心数据流", "Core Flow") return f"""

{layer_title}

{layer_cols} - {''.join(rows)} + {"".join(rows)}
@@ -1421,12 +1680,40 @@ def section_keywords(nodes: list, limit: int = 5) -> list: """Pick representative words from labels and file names.""" counts = Counter() stopwords = { - "the", "and", "for", "with", "from", "this", "that", "class", "function", - "method", "file", "src", "lib", "core", "index", "main", "init", "py", - "ts", "tsx", "js", "jsx", "go", "rs", "java", "html", "css", + "the", + "and", + "for", + "with", + "from", + "this", + "that", + "class", + "function", + "method", + "file", + "src", + "lib", + "core", + "index", + "main", + "init", + "py", + "ts", + "tsx", + "js", + "jsx", + "go", + "rs", + "java", + "html", + "css", } for node in nodes: - text = f"{node.get('label', '')} {node.get('source_file', '')}".replace("/", " ").replace("_", " ").replace("-", " ") + text = ( + f"{node.get('label', '')} {node.get('source_file', '')}".replace("/", " ") + .replace("_", " ") + .replace("-", " ") + ) for raw in text.split(): word = "".join(ch for ch in raw.lower() if ch.isalnum()) if len(word) < 3 or word in stopwords: @@ -1475,10 +1762,16 @@ def generate_section_cards(sec: dict, nodes: list, section_edges: list, lang: st else: file_rows = f'{escape(pick_text(lang, "无源文件映射", "No source file mapping"))}' - relation_counts = Counter(edge.get("relation", "relates") for edge in section_edges if should_include_edge(edge)) - relation_text = ", ".join(f"{relation_label(rel, lang)} x{count}" for rel, count in relation_counts.most_common(4)) + relation_counts = Counter( + edge.get("relation", "relates") for edge in section_edges if should_include_edge(edge) + ) + relation_text = ", ".join( + f"{relation_label(rel, lang)} x{count}" for rel, count in relation_counts.most_common(4) + ) if not relation_text: - relation_text = pick_text(lang, "未检测到高置信调用边", "No high-confidence call edges detected") + relation_text = pick_text( + lang, "未检测到高置信调用边", "No high-confidence call edges detected" + ) note = pick_text( lang, f"本节由 graphify 社区聚类生成。关系概况:{relation_text}。图表优先展示高置信、跨节点调用或使用关系,完整节点清单位于表格中。", @@ -1506,6 +1799,7 @@ def generate_section_cards(sec: dict, nodes: list, section_edges: list, lang: st # 8. Main entry point # ────────────────────────────────────────────── + class CallflowOptions: """Options for call-flow architecture HTML generation.""" @@ -1615,27 +1909,39 @@ def write_callflow_html( # Load data nodes, edges, hyperedges, meta = load_graph(paths["graph"]) - labels = load_labels(paths["labels"]) - lang = detect_lang(args.lang, nodes, labels) + loaded_labels = load_labels(paths["labels"]) + lang = detect_lang(args.lang, nodes, loaded_labels) if paths["sections"]: - sections = load_sections(paths["sections"]) + flow_sections = load_sections(paths["sections"]) else: - sections = derive_sections_from_communities(nodes, labels, lang, args.max_sections) - sections = normalize_sections(sections, lang) + flow_sections = derive_sections_from_communities( + nodes, loaded_labels, lang, args.max_sections + ) + flow_sections = normalize_sections(flow_sections, lang) report_text = load_report(paths["report"]) if not nodes: raise ValueError("graph.json contains 0 nodes") - if len(sections) <= 1: + if len(flow_sections) <= 1: raise ValueError("no sections defined") if verbose and len(nodes) >= 5000: - print("WARNING: Large graph -- Mermaid rendering may be slow. Consider --max-sections 5.", file=sys.stderr) + print( + "WARNING: Large graph -- Mermaid rendering may be slow. Consider --max-sections 5.", + file=sys.stderr, + ) node_ids = {node.get("id") for node in nodes} - missing_endpoint_edges = [edge for edge in edges if edge.get("source") not in node_ids or edge.get("target") not in node_ids] + missing_endpoint_edges = [ + edge + for edge in edges + if edge.get("source") not in node_ids or edge.get("target") not in node_ids + ] if verbose and missing_endpoint_edges: - print(f"WARNING: {len(missing_endpoint_edges)} edges reference nodes not present in graph.json.", file=sys.stderr) + print( + f"WARNING: {len(missing_endpoint_edges)} edges reference nodes not present in graph.json.", + file=sys.stderr, + ) meta["project_name"] = infer_project_name(str(paths["graph"]), meta) meta["node_count"] = len(nodes) @@ -1650,13 +1956,13 @@ def write_callflow_html( output_path = paths["graphify_out"] / f"{safe_filename(meta['project_name'])}-callflow.html" if verbose: - print(f"Loaded: {len(nodes)} nodes, {len(edges)} edges, {len(sections)} sections") + print(f"Loaded: {len(nodes)} nodes, {len(edges)} edges, {len(flow_sections)} sections") print(f"Graph: {paths['graph']}") # Build index comm_idx = build_community_index(nodes) meta["community_count"] = len(comm_idx) - section_nodes_map = build_section_node_map(sections, comm_idx) + section_nodes_map = build_section_node_map(flow_sections, comm_idx) classified = classify_edges(edges, section_nodes_map) # Build HTML @@ -1684,19 +1990,31 @@ def write_callflow_html( """) # Header + nav - html.append(generate_header(sections, meta, lang)) + html.append(generate_header(flow_sections, meta, lang)) # ── Architecture Overview (Section "overview") ── - overview_name = sections[0].get("name", "Architecture Overview") if sections else "Architecture Overview" + overview_name = ( + flow_sections[0].get("name", "Architecture Overview") + if flow_sections + else "Architecture Overview" + ) html.append(f"""

1. {escape(str(overview_name))}

""") - html.append(generate_overview_graph(sections, section_nodes_map, classified, labels, lang, args.diagram_scale)) + html.append( + generate_overview_graph( + flow_sections, section_nodes_map, classified, loaded_labels, lang, args.diagram_scale + ) + ) html.append("""
""") - html.append(generate_overview_cards(meta, report_text, sections, section_nodes_map, classified, lang)) + html.append( + generate_overview_cards( + meta, report_text, flow_sections, section_nodes_map, classified, lang + ) + ) report_card = _report_highlights(report_text, lang) if report_card: html.append(f'
\n {report_card}\n
') @@ -1704,7 +2022,7 @@ def write_callflow_html( # ── Per-section content ── section_num = 1 # overview was #1 - for sec in sections: + for sec in flow_sections: if sec["id"] == "overview": continue section_num += 1 @@ -1769,7 +2087,7 @@ def write_callflow_html( html.append("
\n
") # ── Section: Statistics ── - total_sections = sum(1 for s in sections if s["id"] != "overview") + total_sections = sum(1 for s in flow_sections if s["id"] != "overview") html.append(f"""

Project Statistics

@@ -1786,9 +2104,9 @@ def write_callflow_html(

Edge Confidence

- - - + + +
EXTRACTED{sum(1 for e in edges if e.get('confidence') == 'EXTRACTED')}
INFERRED{sum(1 for e in edges if e.get('confidence') == 'INFERRED')}
AMBIGUOUS{sum(1 for e in edges if e.get('confidence') == 'AMBIGUOUS')}
EXTRACTED{sum(1 for e in edges if e.get("confidence") == "EXTRACTED")}
INFERRED{sum(1 for e in edges if e.get("confidence") == "INFERRED")}
AMBIGUOUS{sum(1 for e in edges if e.get("confidence") == "AMBIGUOUS")}
@@ -1796,8 +2114,8 @@ def write_callflow_html( # ── Footer ── html.append(f"""
-

{escape(str(meta.get('project_name', 'Project')))} — Architecture Documentation

-

Generated: {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M UTC')} · graphify callflow-html

+

{escape(str(meta.get("project_name", "Project")))} — Architecture Documentation

+

Generated: {datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")} · graphify callflow-html

""") @@ -1967,11 +2285,13 @@ def write_callflow_html( # Summary mermaid_count = output.count('
') table_count = output.count('') - section_count = output.count('

bool: _SHEBANG_CODE_INTERPRETERS = { - "python", "python3", "python2", - "ruby", "perl", "node", "nodejs", - "bash", "sh", "dash", "zsh", "fish", "ksh", "tcsh", - "lua", "php", "julia", "Rscript", + "python", + "python3", + "python2", + "ruby", + "perl", + "node", + "nodejs", + "bash", + "sh", + "dash", + "zsh", + "fish", + "ksh", + "tcsh", + "lua", + "php", + "julia", + "Rscript", } @@ -152,7 +257,7 @@ def _env_command_args(args: list[str], *, allow_split: bool = True) -> list[str] arg = args[i] if arg == "--": - return args[i + 1:] + return args[i + 1 :] # Split-string forms: tokenize the packed payload, then re-parse it # as env args (so leading assignments/flags inside the payload are @@ -162,36 +267,36 @@ def _env_command_args(args: list[str], *, allow_split: bool = True) -> list[str] if i + 1 >= len(args): return [] return _env_command_args( - _split_env_s(" ".join(args[i + 1:]), []), + _split_env_s(" ".join(args[i + 1 :]), []), allow_split=False, ) if arg.startswith("-S") and len(arg) > 2: return _env_command_args( - _split_env_s(arg[2:], args[i + 1:]), + _split_env_s(arg[2:], args[i + 1 :]), allow_split=False, ) if arg == "-vS": if i + 1 >= len(args): return [] return _env_command_args( - _split_env_s(" ".join(args[i + 1:]), []), + _split_env_s(" ".join(args[i + 1 :]), []), allow_split=False, ) if arg.startswith("-vS") and len(arg) > 3: return _env_command_args( - _split_env_s(arg[3:], args[i + 1:]), + _split_env_s(arg[3:], args[i + 1 :]), allow_split=False, ) if arg.startswith("--split-string="): return _env_command_args( - _split_env_s(arg.split("=", 1)[1], args[i + 1:]), + _split_env_s(arg.split("=", 1)[1], args[i + 1 :]), allow_split=False, ) if arg == "--split-string": if i + 1 >= len(args): return [] return _env_command_args( - _split_env_s(args[i + 1], args[i + 2:]), + _split_env_s(args[i + 1], args[i + 2 :]), allow_split=False, ) @@ -203,11 +308,7 @@ def _env_command_args(args: list[str], *, allow_split: bool = True) -> list[str] continue # Clumped short option + operand - if ( - arg.startswith(("-u", "-C", "-P", "-a")) - and len(arg) > 2 - and not arg.startswith("--") - ): + if arg.startswith(("-u", "-C", "-P", "-a")) and len(arg) > 2 and not arg.startswith("--"): i += 1 continue @@ -217,8 +318,16 @@ def _env_command_args(args: list[str], *, allow_split: bool = True) -> list[str] continue # No-operand flags - if arg in {"-", "-i", "-0", "-v", "--ignore-environment", "--null", - "--debug", "--list-signal-handling"}: + if arg in { + "-", + "-i", + "-0", + "-v", + "--ignore-environment", + "--null", + "--debug", + "--list-signal-handling", + }: i += 1 continue @@ -320,6 +429,7 @@ def extract_pdf_text(path: Path) -> str: """Extract plain text from a PDF file using pypdf.""" try: from pypdf import PdfReader + reader = PdfReader(str(path)) pages = [] for page in reader.pages: @@ -335,11 +445,11 @@ def docx_to_markdown(path: Path) -> str: """Convert a .docx file to markdown text using python-docx.""" try: from docx import Document - from docx.oxml.ns import qn + doc = Document(str(path)) lines = [] for para in doc.paragraphs: - style = para.style.name if para.style else "" + style = str(para.style.name or "") if para.style else "" text = para.text.strip() if not text: lines.append("") @@ -375,6 +485,7 @@ def xlsx_to_markdown(path: Path) -> str: """Convert an .xlsx file to markdown text using openpyxl.""" try: import openpyxl + wb = openpyxl.load_workbook(str(path), read_only=True, data_only=True) sections = [] for sheet_name in wb.sheetnames: @@ -407,6 +518,7 @@ def xlsx_extract_structure(path: Path) -> dict: Returns a nodes/edges dict compatible with the graphify extract pipeline. Used in addition to xlsx_to_markdown so Claude sees both structure and content. """ + def _nid(*parts: str) -> str: return re.sub(r"[^a-z0-9_]", "_", "_".join(p.lower() for p in parts).strip("_")) @@ -428,21 +540,43 @@ def _nid(*parts: str) -> str: stem = re.sub(r"[^a-z0-9]", "_", path.stem.lower()) str_path = str(path) file_nid = _nid(str_path) - nodes: list[dict] = [{"id": file_nid, "label": path.name, "file_type": "document", - "source_file": str_path, "source_location": None}] + nodes: list[dict] = [ + { + "id": file_nid, + "label": path.name, + "file_type": "document", + "source_file": str_path, + "source_location": None, + } + ] edges: list[dict] = [] seen: set[str] = {file_nid} def _add(nid: str, label: str) -> None: if nid not in seen: seen.add(nid) - nodes.append({"id": nid, "label": label, "file_type": "document", - "source_file": str_path, "source_location": None}) + nodes.append( + { + "id": nid, + "label": label, + "file_type": "document", + "source_file": str_path, + "source_location": None, + } + ) def _edge(src: str, tgt: str, relation: str) -> None: - edges.append({"source": src, "target": tgt, "relation": relation, - "confidence": "EXTRACTED", "source_file": str_path, - "source_location": None, "weight": 1.0}) + edges.append( + { + "source": src, + "target": tgt, + "relation": relation, + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": None, + "weight": 1.0, + } + ) for sheet_name in wb.sheetnames: ws = wb[sheet_name] @@ -461,18 +595,28 @@ def _edge(src: str, tgt: str, relation: str) -> None: if ref: try: from openpyxl.utils import range_boundaries + min_col, min_row, max_col, _ = range_boundaries(ref) - header_row = list(ws.iter_rows(min_row=min_row, max_row=min_row, - min_col=min_col, max_col=max_col, - values_only=True)) + header_row = list( + ws.iter_rows( + min_row=min_row, + max_row=min_row, + min_col=min_col, + max_col=max_col, + values_only=True, + ) + ) if header_row: for col_name in header_row[0]: if col_name: col_nid = _nid(stem, tbl.name, str(col_name)) _add(col_nid, str(col_name)) _edge(tbl_nid, col_nid, "contains") - except Exception: - pass + except Exception as exc: + print( + f"[graphify] warning: could not read spreadsheet table columns: {exc}", + file=sys.stderr, + ) else: # Fallback: first non-empty row as column headers for row in ws.iter_rows(max_row=1, values_only=True): @@ -485,8 +629,8 @@ def _edge(src: str, tgt: str, relation: str) -> None: try: wb.close() - except Exception: - pass + except Exception as exc: + print(f"[graphify] warning: could not close workbook {path}: {exc}", file=sys.stderr) return {"nodes": nodes, "edges": edges} @@ -511,6 +655,7 @@ def convert_office_file(path: Path, out_dir: Path) -> Path | None: out_dir.mkdir(parents=True, exist_ok=True) # Use a stable name derived from the original path to avoid collisions import hashlib + name_hash = hashlib.sha256(str(path.resolve()).encode()).hexdigest()[:8] out_path = out_dir / f"{path.stem}_{name_hash}.md" out_path.write_text( @@ -536,33 +681,64 @@ def count_words(path: Path) -> int: # Directory names to always skip - venvs, caches, build artifacts, deps _SKIP_DIRS = { - "venv", ".venv", "env", ".env", - "node_modules", "__pycache__", ".git", - "dist", "build", "target", "out", - "site-packages", "lib64", - ".pytest_cache", ".mypy_cache", ".ruff_cache", - ".tox", ".eggs", "*.egg-info", + "venv", + ".venv", + "env", + ".env", + "node_modules", + "__pycache__", + ".git", + "dist", + "build", + "target", + "out", + "site-packages", + "lib64", + ".pytest_cache", + ".mypy_cache", + ".ruff_cache", + ".tox", + ".eggs", + "*.egg-info", "graphify-out", # never treat own output as source input (#524) # Coverage/test-artefact dirs — generated, never architecturally meaningful - "coverage", "lcov-report", # Vitest/Istanbul/nyc HTML reports (#870) - "visual-tests", "visual-test", # Playwright/visual-regression bundles (#869) - "__snapshots__", "snapshots", # Jest/Vitest snapshot dirs - "storybook-static", # Storybook production build output - "dist-protected", # Protected dist variants (same noise as dist) + "coverage", + "lcov-report", # Vitest/Istanbul/nyc HTML reports (#870) + "visual-tests", + "visual-test", # Playwright/visual-regression bundles (#869) + "__snapshots__", + "snapshots", # Jest/Vitest snapshot dirs + "storybook-static", # Storybook production build output + "dist-protected", # Protected dist variants (same noise as dist) # Framework cache/build dirs — generated, never architecturally meaningful (#873) - ".next", ".nuxt", ".turbo", ".angular", - ".idea", ".cache", ".parcel-cache", ".svelte-kit", ".terraform", ".serverless", + ".next", + ".nuxt", + ".turbo", + ".angular", + ".idea", + ".cache", + ".parcel-cache", + ".svelte-kit", + ".terraform", + ".serverless", ".graphify", # graphify's own extraction cache — never index self-generated data ".worktrees", # git worktree convention (#947) — sibling checkouts, always redundant } # Large generated files that are never useful to extract _SKIP_FILES = { - "package-lock.json", "yarn.lock", "pnpm-lock.yaml", - "Cargo.lock", "poetry.lock", "Gemfile.lock", - "composer.lock", "go.sum", "go.work.sum", + "package-lock.json", + "yarn.lock", + "pnpm-lock.yaml", + "Cargo.lock", + "poetry.lock", + "Gemfile.lock", + "composer.lock", + "go.sum", + "go.work.sum", } + def _is_noise_dir(part: str, parent: "Path | None" = None) -> bool: """Return True if this directory name looks like a venv, cache, or dep dir.""" if part in _SKIP_DIRS: @@ -671,6 +847,7 @@ def _is_ignored(path: Path, root: Path, patterns: list[tuple[Path, str]]) -> boo def _eval(target: Path) -> bool: """Apply last-match-wins to a single target path.""" + def _matches(rel: str, p: str, anchored: bool) -> bool: if anchored: return fnmatch.fnmatch(rel, p) @@ -682,7 +859,7 @@ def _matches(rel: str, p: str, anchored: bool) -> bool: for i, part in enumerate(parts): if fnmatch.fnmatch(part, p): return True - if fnmatch.fnmatch("/".join(parts[:i + 1]), p): + if fnmatch.fnmatch("/".join(parts[: i + 1]), p): return True return False @@ -782,7 +959,7 @@ def _matches(rel: str, p: str, anchored: bool) -> bool: for i, part in enumerate(parts): if fnmatch.fnmatch(part, p): return True - if fnmatch.fnmatch("/".join(parts[:i + 1]), p): + if fnmatch.fnmatch("/".join(parts[: i + 1]), p): return True return False @@ -866,7 +1043,13 @@ def _auto_follow_symlinks(root: Path) -> bool: return False -def detect(root: Path, *, follow_symlinks: bool | None = None, google_workspace: bool | None = None, extra_excludes: list[str] | None = None) -> dict: +def detect( + root: Path, + *, + follow_symlinks: bool | None = None, + google_workspace: bool | None = None, + extra_excludes: list[str] | None = None, +) -> dict: root = root.resolve() if follow_symlinks is None: follow_symlinks = _auto_follow_symlinks(root) @@ -889,8 +1072,6 @@ def detect(root: Path, *, follow_symlinks: bool | None = None, google_workspace: line = _parse_gitignore_line(pat) if line: ignore_patterns.append((root, line)) - include_patterns = _load_graphifyinclude(root) - # Always include graphify-out/memory/ - query results filed back into the graph memory_dir = root / "graphify-out" / "memory" scan_paths = [root] @@ -918,7 +1099,8 @@ def detect(root: Path, *, follow_symlinks: bool | None = None, google_workspace: # pruning so negated files inside can still be reached. has_negation = any(p.startswith("!") for _, p in ignore_patterns) dirnames[:] = [ - d for d in dirnames + d + for d in dirnames if not _is_noise_dir(d, dp) and (has_negation or not _is_ignored(dp / d, root, ignore_patterns)) ] @@ -949,13 +1131,14 @@ def detect(root: Path, *, follow_symlinks: bool | None = None, google_workspace: if p.suffix.lower() in GOOGLE_WORKSPACE_EXTENSIONS: if not google_workspace: skipped_sensitive.append( - str(p) - + " [Google Workspace shortcut skipped - pass --google-workspace " + str(p) + " [Google Workspace shortcut skipped - pass --google-workspace " "or set GRAPHIFY_GOOGLE_WORKSPACE=1]" ) continue try: - md_path = convert_google_workspace_file(p, converted_dir, xlsx_to_markdown=xlsx_to_markdown) + md_path = convert_google_workspace_file( + p, converted_dir, xlsx_to_markdown=xlsx_to_markdown + ) except Exception as exc: skipped_sensitive.append(str(p) + f" [Google Workspace export failed: {exc}]") continue @@ -965,7 +1148,9 @@ def detect(root: Path, *, follow_symlinks: bool | None = None, google_workspace: files[ftype].append(str(md_path)) total_words += count_words(md_path) else: - skipped_sensitive.append(str(p) + " [Google Workspace export produced no readable text]") + skipped_sensitive.append( + str(p) + " [Google Workspace export produced no readable text]" + ) continue # Office files: convert to markdown sidecar so subagents can read them if p.suffix.lower() in OFFICE_EXTENSIONS: @@ -977,7 +1162,9 @@ def detect(root: Path, *, follow_symlinks: bool | None = None, google_workspace: total_words += count_words(md_path) else: # Conversion failed (library not installed) - skip with note - skipped_sensitive.append(str(p) + " [office conversion failed - pip install graphifyy[office]]") + skipped_sensitive.append( + str(p) + " [office conversion failed - pip install graphifyy[office]]" + ) continue files[ftype].append(str(p)) if ftype != FileType.VIDEO: @@ -1015,6 +1202,7 @@ def detect(root: Path, *, follow_symlinks: bool | None = None, google_workspace: def _md5_file(path: Path) -> str: """MD5 of file contents streamed in 64KB chunks — for change detection only.""" import hashlib as _hl + h = _hl.md5(usedforsecurity=False) try: with path.open("rb") as f: @@ -1092,7 +1280,9 @@ def _normalise_entry(entry): entry["semantic_hash"] = h else: # Preserve semantic_hash only when content is unchanged - entry["semantic_hash"] = prev.get("semantic_hash", "") if h == prev.get("ast_hash", "") else "" + entry["semantic_hash"] = ( + prev.get("semantic_hash", "") if h == prev.get("ast_hash", "") else "" + ) manifest[f] = entry Path(manifest_path).parent.mkdir(parents=True, exist_ok=True) Path(manifest_path).write_text(json.dumps(manifest, indent=2), encoding="utf-8") @@ -1130,7 +1320,12 @@ def detect_incremental( incremental runs. ``None`` (default) means auto-detect: ``True`` when ``root`` contains at least one direct symlinked child, ``False`` otherwise. """ - full = detect(root, follow_symlinks=follow_symlinks, google_workspace=google_workspace, extra_excludes=extra_excludes) + full = detect( + root, + follow_symlinks=follow_symlinks, + google_workspace=google_workspace, + extra_excludes=extra_excludes, + ) manifest = load_manifest(manifest_path) if not manifest: @@ -1158,7 +1353,11 @@ def detect_incremental( elif isinstance(stored, dict): # Normalise legacy {mtime, hash} to new schema if "hash" in stored and "ast_hash" not in stored: - stored = {"mtime": stored.get("mtime", 0), "ast_hash": stored["hash"], "semantic_hash": ""} + stored = { + "mtime": stored.get("mtime", 0), + "ast_hash": stored["hash"], + "semantic_hash": "", + } hash_key = "semantic_hash" if kind == "semantic" else "ast_hash" stored_hash = stored.get(hash_key, "") # Missing semantic_hash means update ran but extract hasn't — always re-extract diff --git a/graphify/export.py b/graphify/export.py index ff127c0b2..b8726092e 100644 --- a/graphify/export.py +++ b/graphify/export.py @@ -7,9 +7,11 @@ import os import re import shutil +import sys from collections import Counter from datetime import date from pathlib import Path +from typing import Any, cast import networkx as nx from networkx.readwrite import json_graph from graphify.security import sanitize_label @@ -54,13 +56,18 @@ def backup_if_protected(out_dir: Path) -> "Path | None": try: labels = json.loads(labels_file.read_text(encoding="utf-8")) is_curated = any(v != f"Community {k}" for k, v in labels.items()) - except Exception: - pass + except Exception as exc: + print( + f"[graphify] warning: could not read community labels for backup check: {exc}", + file=sys.stderr, + ) if not is_semantic and not is_curated: return None - reason = "+".join(filter(None, ["semantic" if is_semantic else "", "curated" if is_curated else ""])) + reason = "+".join( + filter(None, ["semantic" if is_semantic else "", "curated" if is_curated else ""]) + ) today = date.today().isoformat() backup_dir = out / today graph_src = out / "graph.json" @@ -83,16 +90,19 @@ def backup_if_protected(out_dir: Path) -> "Path | None": try: shutil.copy2(src, backup_dir / name) copied += 1 - except Exception: - pass + except Exception as exc: + print(f"[graphify] warning: could not back up {src}: {exc}", file=sys.stderr) if copied: print(f"[graphify] backed up {reason} graph ({copied} files) -> {backup_dir.name}/") return backup_dir except Exception as exc: - import sys - print(f"[graphify] warning: backup failed ({exc}) - continuing with overwrite", file=sys.stderr) + print( + f"[graphify] warning: backup failed ({exc}) - continuing with overwrite", + file=sys.stderr, + ) return None + def _obsidian_tag(name: str) -> str: """Sanitize a community name for use as an Obsidian tag. @@ -104,6 +114,7 @@ def _obsidian_tag(name: str) -> str: def _strip_diacritics(text: str) -> str: import unicodedata + nfkd = unicodedata.normalize("NFKD", text) return "".join(c for c in nfkd if not unicodedata.combining(c)) @@ -147,8 +158,16 @@ def _yaml_str(s: str) -> str: COMMUNITY_COLORS = [ - "#4E79A7", "#F28E2B", "#E15759", "#76B7B2", "#59A14F", - "#EDC948", "#B07AA1", "#FF9DA7", "#9C755F", "#BAB0AC", + "#4E79A7", + "#F28E2B", + "#E15759", + "#76B7B2", + "#59A14F", + "#EDC948", + "#B07AA1", + "#FF9DA7", + "#9C755F", + "#BAB0AC", ] MAX_NODES_FOR_VIZ = 5_000 @@ -161,6 +180,7 @@ def _viz_node_limit() -> int: Set to 0 to disable HTML viz unconditionally (useful for CI runners). """ import os + raw = os.environ.get("GRAPHIFY_VIZ_NODE_LIMIT") if raw is None or not raw.strip(): return MAX_NODES_FOR_VIZ @@ -472,35 +492,46 @@ def attach_hyperedges(G: nx.Graph, hyperedges: list) -> None: def _git_head() -> str | None: """Return the current git HEAD commit hash, or None if not in a git repo.""" import subprocess as _sp + try: - r = _sp.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True, timeout=3) + r = _sp.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True, timeout=3) # nosec B603 B607 return r.stdout.strip() if r.returncode == 0 else None except Exception: return None -def to_json(G: nx.Graph, communities: dict[int, list[str]], output_path: str, *, force: bool = False, built_at_commit: str | None = None) -> bool: +def to_json( + G: nx.Graph, + communities: dict[int, list[str]], + output_path: str, + *, + force: bool = False, + built_at_commit: str | None = None, +) -> bool: # Safety check: refuse to silently shrink an existing graph (#479) existing_path = Path(output_path) if not force and existing_path.exists(): try: from graphify.security import check_graph_file_size_cap + check_graph_file_size_cap(existing_path) existing_data = json.loads(existing_path.read_text(encoding="utf-8")) existing_n = len(existing_data.get("nodes", [])) new_n = G.number_of_nodes() if new_n < existing_n: - import sys as _sys print( f"[graphify] WARNING: new graph has {new_n} nodes but existing " f"graph.json has {existing_n}. Refusing to overwrite — you may be " f"missing chunk files from a previous session. " f"Pass force=True to override.", - file=_sys.stderr, + file=sys.stderr, ) return False - except Exception: - pass # unreadable existing file — proceed with write + except Exception as exc: + print( + f"[graphify] warning: could not inspect existing graph before write: {exc}", + file=sys.stderr, + ) node_community = _node_community_map(communities) try: @@ -541,8 +572,7 @@ def prune_dangling_edges(graph_data: dict) -> tuple[dict, int]: links_key = "links" if "links" in graph_data else "edges" before = len(graph_data[links_key]) graph_data[links_key] = [ - e for e in graph_data[links_key] - if e["source"] in node_ids and e["target"] in node_ids + e for e in graph_data[links_key] if e["source"] in node_ids and e["target"] in node_ids ] return graph_data, before - len(graph_data[links_key]) @@ -566,12 +596,7 @@ def _cypher_escape(s: str) -> str: """ # First normalise: drop NUL and other C0 control chars except tab. s = "".join(ch for ch in s if ch >= " " or ch == "\t") - return ( - s.replace("\\", "\\\\") - .replace("'", "\\'") - .replace("\n", "\\n") - .replace("\r", "\\r") - ) + return s.replace("\\", "\\\\").replace("'", "\\'").replace("\n", "\\n").replace("\r", "\\r") # Restrict identifier-position values (labels and relationship types are NOT @@ -645,8 +670,13 @@ def to_html( # Build aggregated community meta-graph from collections import Counter as _Counter import networkx as _nx - print(f"Graph has {G.number_of_nodes()} nodes (above {limit} limit). Building aggregated community view...") - node_to_community = {nid: cid for cid, members in communities.items() for nid in members} + + print( + f"Graph has {G.number_of_nodes()} nodes (above {limit} limit). Building aggregated community view..." + ) + node_to_community = { + nid: cid for cid, members in communities.items() for nid in members + } meta = _nx.Graph() for cid, members in communities.items(): meta.add_node(str(cid), label=(community_labels or {}).get(cid, f"Community {cid}")) @@ -656,8 +686,13 @@ def to_html( if cu is not None and cv is not None and cu != cv: edge_counts[(min(cu, cv), max(cu, cv))] += 1 for (cu, cv), w in edge_counts.items(): - meta.add_edge(str(cu), str(cv), weight=w, - relation=f"{w} cross-community edges", confidence="AGGREGATED") + meta.add_edge( + str(cu), + str(cv), + weight=w, + relation=f"{w} cross-community edges", + confidence="AGGREGATED", + ) if meta.number_of_nodes() <= 1: print("Single community - aggregated view not useful. Skipping graph.html.") return @@ -666,10 +701,11 @@ def to_html( # Remap hyperedges from semantic node IDs to community IDs raw_hyperedges = G.graph.get("hyperedges", []) if raw_hyperedges: - remapped = [] + remapped: list[dict[str, Any]] = [] for he in raw_hyperedges: he_members = he.get("nodes") or he.get("members") or [] - comm_ids, seen = [], set() + comm_ids: list[str] = [] + seen: set[str] = set() for nid in he_members: c = node_to_community.get(nid) if c is None: @@ -681,15 +717,24 @@ def to_html( comm_ids.append(s) if len(comm_ids) < 2: continue - remapped.append({ - "id": he.get("id", ""), - "label": he.get("label") or he.get("relation", "").replace("_", " "), - "nodes": comm_ids, - }) + remapped.append( + { + "id": he.get("id", ""), + "label": he.get("label") or he.get("relation", "").replace("_", " "), + "nodes": comm_ids, + } + ) meta.graph["hyperedges"] = remapped - to_html(meta, meta_communities, output_path, - community_labels=community_labels, member_counts=mc) - print(f"graph.html written (aggregated: {meta.number_of_nodes()} community nodes, {meta.number_of_edges()} cross-community edges)") + to_html( + meta, + meta_communities, + output_path, + community_labels=community_labels, + member_counts=mc, + ) + print( + f"graph.html written (aggregated: {meta.number_of_nodes()} community nodes, {meta.number_of_edges()} cross-community edges)" + ) print("Tip: run with --obsidian for full node-level detail.") return raise ValueError( @@ -718,19 +763,27 @@ def to_html( size = 10 + 30 * (deg / max_deg) # Only show label for high-degree nodes by default; others show on hover font_size = 12 if deg >= max_deg * 0.15 else 0 - vis_nodes.append({ - "id": node_id, - "label": label, - "color": {"background": color, "border": color, "highlight": {"background": "#ffffff", "border": color}}, - "size": round(size, 1), - "font": {"size": font_size, "color": "#ffffff"}, - "title": _html.escape(label), - "community": cid, - "community_name": sanitize_label((community_labels or {}).get(cid, f"Community {cid}")), - "source_file": sanitize_label(str(data.get("source_file") or "")), - "file_type": data.get("file_type", ""), - "degree": deg, - }) + vis_nodes.append( + { + "id": node_id, + "label": label, + "color": { + "background": color, + "border": color, + "highlight": {"background": "#ffffff", "border": color}, + }, + "size": round(size, 1), + "font": {"size": font_size, "color": "#ffffff"}, + "title": _html.escape(label), + "community": cid, + "community_name": sanitize_label( + (community_labels or {}).get(cid, f"Community {cid}") + ), + "source_file": sanitize_label(str(data.get("source_file") or "")), + "file_type": data.get("file_type", ""), + "degree": deg, + } + ) # Build edges list. Restore original edge direction from _src/_tgt # (stashed by build.py for exactly this reason): undirected NetworkX @@ -742,23 +795,29 @@ def to_html( relation = data.get("relation", "") true_src = data.get("_src", u) true_tgt = data.get("_tgt", v) - vis_edges.append({ - "from": true_src, - "to": true_tgt, - "label": relation, - "title": _html.escape(f"{relation} [{confidence}]"), - "dashes": confidence != "EXTRACTED", - "width": 2 if confidence == "EXTRACTED" else 1, - "color": {"opacity": 0.7 if confidence == "EXTRACTED" else 0.35}, - "confidence": confidence, - }) + vis_edges.append( + { + "from": true_src, + "to": true_tgt, + "label": relation, + "title": _html.escape(f"{relation} [{confidence}]"), + "dashes": confidence != "EXTRACTED", + "width": 2 if confidence == "EXTRACTED" else 1, + "color": {"opacity": 0.7 if confidence == "EXTRACTED" else 0.35}, + "confidence": confidence, + } + ) # Build community legend data legend_data = [] for cid in sorted((community_labels or {}).keys()): color = COMMUNITY_COLORS[cid % len(COMMUNITY_COLORS)] lbl = _html.escape(sanitize_label((community_labels or {}).get(cid, f"Community {cid}"))) - n = member_counts.get(cid, len(communities.get(cid, []))) if member_counts else len(communities.get(cid, [])) + n = ( + member_counts.get(cid, len(communities.get(cid, []))) + if member_counts + else len(communities.get(cid, [])) + ) legend_data.append({"cid": cid, "color": color, "label": lbl, "count": n}) # Escape sequences so embedded JSON cannot break out of the script tag @@ -837,7 +896,11 @@ def to_obsidian( # Map node_id → safe filename so wikilinks stay consistent. # Deduplicate: if two nodes produce the same filename, append a numeric suffix. def safe_name(label: str) -> str: - cleaned = re.sub(r'[\\/*?:"<>|#^[\]]', "", label.replace("\r\n", " ").replace("\r", " ").replace("\n", " ")).strip() + cleaned = re.sub( + r'[\\/*?:"<>|#^[\]]', + "", + label.replace("\r\n", " ").replace("\r", " ").replace("\n", " "), + ).strip() # Strip trailing .md/.mdx/.markdown so "CLAUDE.md" doesn't become "CLAUDE.md.md" cleaned = re.sub(r"\.(md|mdx|qmd|markdown)$", "", cleaned, flags=re.IGNORECASE) return cleaned or "unnamed" @@ -975,8 +1038,10 @@ def _community_reach(node_id: str) -> int: # Cohesion + member count summary if coh_value is not None: cohesion_desc = ( - "tightly connected" if coh_value >= 0.7 - else "moderately connected" if coh_value >= 0.4 + "tightly connected" + if coh_value >= 0.7 + else "moderately connected" + if coh_value >= 0.4 else "loosely connected" ) lines.append(f"**Cohesion:** {coh_value:.2f} - {cohesion_desc}") @@ -1019,7 +1084,9 @@ def _community_reach(node_id: str) -> int: else f"Community {other_cid}" ) other_safe = safe_name(other_name) - lines.append(f"- {edge_count} edge{'s' if edge_count != 1 else ''} to [[_COMMUNITY_{other_safe}]]") + lines.append( + f"- {edge_count} edge{'s' if edge_count != 1 else ''} to [[_COMMUNITY_{other_safe}]]" + ) lines.append("") # Top bridge nodes - highest degree nodes that connect to other communities @@ -1051,7 +1118,10 @@ def _community_reach(node_id: str) -> int: "colorGroups": [ { "query": f"tag:#community/{label.replace(' ', '_')}", - "color": {"a": 1, "rgb": int(COMMUNITY_COLORS[cid % len(COMMUNITY_COLORS)].lstrip('#'), 16)} + "color": { + "a": 1, + "rgb": int(COMMUNITY_COLORS[cid % len(COMMUNITY_COLORS)].lstrip("#"), 16), + }, } for cid, label in sorted((community_labels or {}).items()) ] @@ -1078,7 +1148,11 @@ def to_canvas( CANVAS_COLORS = ["1", "2", "3", "4", "5", "6"] # red, orange, yellow, green, cyan, purple def safe_name(label: str) -> str: - cleaned = re.sub(r'[\\/*?:"<>|#^[\]]', "", label.replace("\r\n", " ").replace("\r", " ").replace("\n", " ")).strip() + cleaned = re.sub( + r'[\\/*?:"<>|#^[\]]', + "", + label.replace("\r\n", " ").replace("\r", " ").replace("\n", " "), + ).strip() cleaned = re.sub(r"\.(md|mdx|qmd|markdown)$", "", cleaned, flags=re.IGNORECASE) return cleaned or "unnamed" @@ -1104,8 +1178,6 @@ def safe_name(label: str) -> str: # Lay out communities in a grid gap = 80 - group_x_offsets: list[int] = [] - group_y_offsets: list[int] = [] # Precompute group sizes so we can calculate offsets sorted_cids = sorted(communities.keys()) @@ -1168,16 +1240,18 @@ def safe_name(label: str) -> str: canvas_color = CANVAS_COLORS[idx % len(CANVAS_COLORS)] # Group node - canvas_nodes.append({ - "id": f"g{cid}", - "type": "group", - "label": community_name, - "x": gx, - "y": gy, - "width": gw, - "height": gh, - "color": canvas_color, - }) + canvas_nodes.append( + { + "id": f"g{cid}", + "type": "group", + "label": community_name, + "x": gx, + "y": gy, + "width": gw, + "height": gh, + "color": canvas_color, + } + ) # Node cards inside the group - rows of 3 sorted_members = sorted(members, key=lambda n: G.nodes[n].get("label", n)) @@ -1187,15 +1261,17 @@ def safe_name(label: str) -> str: nx_x = gx + 20 + col * (180 + 20) nx_y = gy + 80 + row * (60 + 20) fname = node_filenames.get(node_id, safe_name(G.nodes[node_id].get("label", node_id))) - canvas_nodes.append({ - "id": f"n_{node_id}", - "type": "file", - "file": f"{fname}.md", - "x": nx_x, - "y": nx_y, - "width": 180, - "height": 60, - }) + canvas_nodes.append( + { + "id": f"n_{node_id}", + "type": "file", + "file": f"{fname}.md", + "x": nx_x, + "y": nx_y, + "width": 180, + "height": 60, + } + ) # Generate edges - only between nodes both in canvas, cap at 200 highest-weight all_edges_weighted: list[tuple[float, str, str, str]] = [] @@ -1209,12 +1285,14 @@ def safe_name(label: str) -> str: all_edges_weighted.sort(key=lambda x: -x[0]) for weight, u, v, label in all_edges_weighted[:200]: - canvas_edges.append({ - "id": f"e_{u}_{v}", - "fromNode": f"n_{u}", - "toNode": f"n_{v}", - "label": label, - }) + canvas_edges.append( + { + "id": f"e_{u}_{v}", + "fromNode": f"n_{u}", + "toNode": f"n_{v}", + "label": label, + } + ) canvas_data = {"nodes": canvas_nodes, "edges": canvas_edges} Path(output_path).write_text(json.dumps(canvas_data, indent=2), encoding="utf-8") # nosec @@ -1237,14 +1315,15 @@ def push_to_neo4j( try: from neo4j import GraphDatabase except ImportError as e: - raise ImportError( - "neo4j driver not installed. Run: pip install neo4j" - ) from e + raise ImportError("neo4j driver not installed. Run: pip install neo4j") from e node_community = _node_community_map(communities) if communities else {} def _safe_rel(relation: str) -> str: - return re.sub(r"[^A-Z0-9_]", "_", relation.upper().replace(" ", "_").replace("-", "_")) or "RELATED_TO" + return ( + re.sub(r"[^A-Z0-9_]", "_", relation.upper().replace(" ", "_").replace("-", "_")) + or "RELATED_TO" + ) def _safe_label(label: str) -> str: """Sanitize a Neo4j node label to prevent Cypher injection.""" @@ -1256,6 +1335,7 @@ def _safe_label(label: str) -> str: edges_pushed = 0 with driver.session() as session: + session_any = cast(Any, session) for node_id, data in G.nodes(data=True): props = {k: v for k, v in data.items() if isinstance(v, (str, int, float, bool))} props["id"] = node_id @@ -1263,7 +1343,7 @@ def _safe_label(label: str) -> str: if cid is not None: props["community"] = cid ftype = _safe_label(data.get("file_type", "Entity").capitalize()) - session.run( + session_any.run( f"MERGE (n:{ftype} {{id: $id}}) SET n += $props", id=node_id, props=props, @@ -1273,7 +1353,7 @@ def _safe_label(label: str) -> str: for u, v, data in G.edges(data=True): rel = _safe_rel(data.get("relation", "RELATED_TO")) props = {k: v for k, v in data.items() if isinstance(v, (str, int, float, bool))} - session.run( + session_any.run( f"MATCH (a {{id: $src}}), (b {{id: $tgt}}) " f"MERGE (a)-[r:{rel}]->(b) SET r += $props", src=u, @@ -1319,6 +1399,7 @@ def to_svg( """ try: import matplotlib + matplotlib.use("Agg") import matplotlib.pyplot as plt import matplotlib.patches as mpatches @@ -1336,7 +1417,9 @@ def to_svg( degree = dict(G.degree()) max_deg = max(degree.values(), default=1) or 1 - node_colors = [COMMUNITY_COLORS[node_community.get(n, 0) % len(COMMUNITY_COLORS)] for n in G.nodes()] + node_colors = [ + COMMUNITY_COLORS[node_community.get(n, 0) % len(COMMUNITY_COLORS)] for n in G.nodes() + ] node_sizes = [300 + 1200 * (degree.get(n, 1) / max_deg) for n in G.nodes()] # Draw edges - dashed for non-EXTRACTED @@ -1346,14 +1429,25 @@ def to_svg( alpha = 0.6 if conf == "EXTRACTED" else 0.3 x0, y0 = pos[u] x1, y1 = pos[v] - ax.plot([x0, x1], [y0, y1], color="#aaaaaa", linewidth=0.8, - linestyle=style, alpha=alpha, zorder=1) + ax.plot( + [x0, x1], + [y0, y1], + color="#aaaaaa", + linewidth=0.8, + linestyle=style, + alpha=alpha, + zorder=1, + ) - nx.draw_networkx_nodes(G, pos, ax=ax, node_color=node_colors, - node_size=node_sizes, alpha=0.9) - nx.draw_networkx_labels(G, pos, ax=ax, - labels={n: G.nodes[n].get("label", n) for n in G.nodes()}, - font_size=7, font_color="white") + nx.draw_networkx_nodes(G, pos, ax=ax, node_color=node_colors, node_size=node_sizes, alpha=0.9) + nx.draw_networkx_labels( + G, + pos, + ax=ax, + labels={n: G.nodes[n].get("label", n) for n in G.nodes()}, + font_size=7, + font_color="white", + ) # Legend if community_labels: @@ -1364,10 +1458,15 @@ def to_svg( ) for cid, label in sorted(community_labels.items()) ] - ax.legend(handles=patches, loc="upper left", framealpha=0.7, - facecolor="#2a2a4e", labelcolor="white", fontsize=8) + ax.legend( + handles=patches, + loc="upper left", + framealpha=0.7, + facecolor="#2a2a4e", + labelcolor="white", + fontsize=8, + ) plt.tight_layout() - plt.savefig(output_path, format="svg", bbox_inches="tight", - facecolor=fig.get_facecolor()) + plt.savefig(output_path, format="svg", bbox_inches="tight", facecolor=fig.get_facecolor()) plt.close(fig) diff --git a/graphify/extract.py b/graphify/extract.py index 1d396a697..3a7222b66 100644 --- a/graphify/extract.py +++ b/graphify/extract.py @@ -1,44 +1,123 @@ """Deterministic structural extraction from source code using tree-sitter. Outputs nodes+edges dicts.""" + from __future__ import annotations import importlib import json +import logging import os import re import sys import unicodedata from dataclasses import dataclass, field from pathlib import Path -from typing import Callable, Any +from collections.abc import Iterator +from typing import Any, Callable from .cache import load_cached, save_cached from .mcp_ingest import extract_mcp_config, is_mcp_config_path _RECURSION_LIMIT = 10_000 +_LOG = logging.getLogger(__name__) # Language built-in globals that AST may classify as call targets when used as # constructors or coercion functions (e.g. String(x), Number(x), Boolean(x)). # Without this filter they become god-nodes accumulating spurious edges from # every call site. Filter applied at same-file and cross-file resolution. # See issue #726. -_LANGUAGE_BUILTIN_GLOBALS: frozenset[str] = frozenset({ - # JavaScript / TypeScript ECMAScript built-ins - "String", "Number", "Boolean", "Object", "Array", "Symbol", "BigInt", - "Date", "RegExp", "Error", "TypeError", "RangeError", "SyntaxError", - "ReferenceError", "EvalError", "URIError", - "Promise", "Map", "Set", "WeakMap", "WeakSet", "JSON", "Math", - "Reflect", "Proxy", "Intl", - "parseInt", "parseFloat", "isNaN", "isFinite", - "encodeURIComponent", "decodeURIComponent", "encodeURI", "decodeURI", - # Browser / Node common globals - "URL", "URLSearchParams", "FormData", "Blob", "File", - "Headers", "Request", "Response", "AbortController", "AbortSignal", - "TextEncoder", "TextDecoder", "console", - # Python built-in callables - "str", "int", "float", "bool", "list", "dict", "set", "tuple", "bytes", - "len", "range", "enumerate", "zip", "map", "filter", "sum", "min", "max", - "print", "open", "isinstance", "type", "super", "sorted", "reversed", - "any", "all", "abs", "round", "next", "iter", "hash", "id", "repr", - "callable", "getattr", "setattr", "hasattr", "delattr", "vars", "dir", -}) +_LANGUAGE_BUILTIN_GLOBALS: frozenset[str] = frozenset( + { + # JavaScript / TypeScript ECMAScript built-ins + "String", + "Number", + "Boolean", + "Object", + "Array", + "Symbol", + "BigInt", + "Date", + "RegExp", + "Error", + "TypeError", + "RangeError", + "SyntaxError", + "ReferenceError", + "EvalError", + "URIError", + "Promise", + "Map", + "Set", + "WeakMap", + "WeakSet", + "JSON", + "Math", + "Reflect", + "Proxy", + "Intl", + "parseInt", + "parseFloat", + "isNaN", + "isFinite", + "encodeURIComponent", + "decodeURIComponent", + "encodeURI", + "decodeURI", + # Browser / Node common globals + "URL", + "URLSearchParams", + "FormData", + "Blob", + "File", + "Headers", + "Request", + "Response", + "AbortController", + "AbortSignal", + "TextEncoder", + "TextDecoder", + "console", + # Python built-in callables + "str", + "int", + "float", + "bool", + "list", + "dict", + "set", + "tuple", + "bytes", + "len", + "range", + "enumerate", + "zip", + "map", + "filter", + "sum", + "min", + "max", + "print", + "open", + "isinstance", + "type", + "super", + "sorted", + "reversed", + "any", + "all", + "abs", + "round", + "next", + "iter", + "hash", + "id", + "repr", + "callable", + "getattr", + "setattr", + "hasattr", + "delattr", + "vars", + "dir", + } +) def _raise_recursion_limit() -> None: @@ -92,14 +171,33 @@ def _file_stem(path: Path) -> str: _JS_INDEX_FILES = ("index.ts", "index.tsx", "index.svelte", "index.js", "index.jsx", "index.mjs") -SEMANTIC_RELATIONS = frozenset({ - "inherits", "implements", "mixes_in", "embeds", "references", - "calls", "imports", "imports_from", "re_exports", "contains", "method", -}) +SEMANTIC_RELATIONS = frozenset( + { + "inherits", + "implements", + "mixes_in", + "embeds", + "references", + "calls", + "imports", + "imports_from", + "re_exports", + "contains", + "method", + } +) -REFERENCE_CONTEXTS = frozenset({ - "field", "parameter_type", "return_type", "generic_arg", "attribute", "value", "type", -}) +REFERENCE_CONTEXTS = frozenset( + { + "field", + "parameter_type", + "return_type", + "generic_arg", + "attribute", + "value", + "type", + } +) def _source_location(line: int | str | None) -> str | None: @@ -173,9 +271,9 @@ def _strip_jsonc(text: str) -> str: """ # Remove block and line comments while leaving string literals untouched. pattern = re.compile( - r'"(?:\\.|[^"\\])*"' # double-quoted string (with escapes) - r"|/\*.*?\*/" # /* block comment */ - r"|//[^\n]*", # // line comment + r'"(?:\\.|[^"\\])*"' # double-quoted string (with escapes) + r"|/\*.*?\*/" # /* block comment */ + r"|//[^\n]*", # // line comment re.DOTALL, ) @@ -205,7 +303,11 @@ def _read_tsconfig_aliases(tsconfig: Path, base_dir: Path, seen: set) -> dict[st try: raw = tsconfig.read_text(encoding="utf-8") except Exception as e: - print(f" warning: could not read {tsconfig} ({type(e).__name__}: {e})", file=sys.stderr, flush=True) + print( + f" warning: could not read {tsconfig} ({type(e).__name__}: {e})", + file=sys.stderr, + flush=True, + ) return {} try: data = json.loads(raw) @@ -213,10 +315,18 @@ def _read_tsconfig_aliases(tsconfig: Path, base_dir: Path, seen: set) -> dict[st try: data = json.loads(_strip_jsonc(raw)) except json.JSONDecodeError as e: - print(f" warning: failed to parse {tsconfig} as JSON/JSONC ({e.msg} at line {e.lineno} col {e.colno})", file=sys.stderr, flush=True) + print( + f" warning: failed to parse {tsconfig} as JSON/JSONC ({e.msg} at line {e.lineno} col {e.colno})", + file=sys.stderr, + flush=True, + ) return {} except Exception as e: - print(f" warning: failed to parse {tsconfig} ({type(e).__name__}: {e})", file=sys.stderr, flush=True) + print( + f" warning: failed to parse {tsconfig} ({type(e).__name__}: {e})", + file=sys.stderr, + flush=True, + ) return {} aliases: dict[str, str] = {} @@ -317,7 +427,8 @@ def _load_workspace_packages(start_dir: Path) -> dict[str, Path]: continue try: data = json.loads(manifest.read_text(encoding="utf-8")) - except Exception: + except Exception as exc: + _LOG.debug("could not read package manifest %s: %s", manifest, exc) continue name = data.get("name") if isinstance(name, str) and name: @@ -331,8 +442,8 @@ def _package_entry_candidates(package_dir: Path, subpath: str) -> list[Path]: manifest_data: dict[str, Any] = {} try: manifest_data = json.loads(manifest.read_text(encoding="utf-8")) - except Exception: - pass + except Exception as exc: + _LOG.debug("could not read package manifest %s: %s", manifest, exc) if subpath: return [package_dir / subpath] @@ -366,7 +477,7 @@ def _resolve_workspace_import(raw: str, start_dir: Path) -> Path | None: if raw == package_name: subpath = "" elif raw.startswith(package_name + "/"): - subpath = raw[len(package_name) + 1:] + subpath = raw[len(package_name) + 1 :] else: continue for candidate in _package_entry_candidates(package_dir, subpath): @@ -394,7 +505,7 @@ def _resolve_js_module_path(raw: str | Path, start_dir: Path | None = None) -> P aliases = _load_tsconfig_aliases(start_dir) for alias_prefix, alias_base in aliases.items(): if raw == alias_prefix or raw.startswith(alias_prefix + "/"): - rest = raw[len(alias_prefix):].lstrip("/") + rest = raw[len(alias_prefix) :].lstrip("/") return _resolve_js_import_path(Path(os.path.normpath(Path(alias_base) / rest))) return _resolve_workspace_import(raw, start_dir) @@ -402,10 +513,11 @@ def _resolve_js_module_path(raw: str | Path, start_dir: Path | None = None) -> P # ── LanguageConfig dataclass ───────────────────────────────────────────────── + @dataclass class LanguageConfig: - ts_module: str # e.g. "tree_sitter_python" - ts_language_fn: str = "language" # attr to call: e.g. tslang.language() + ts_module: str # e.g. "tree_sitter_python" + ts_language_fn: str = "language" # attr to call: e.g. tslang.language() class_types: frozenset = frozenset() function_types: frozenset = frozenset() @@ -422,12 +534,12 @@ class LanguageConfig: # Body detection body_field: str = "body" - body_fallback_child_types: tuple = () # e.g. ("declaration_list", "compound_statement") + body_fallback_child_types: tuple = () # e.g. ("declaration_list", "compound_statement") # Call name extraction - call_function_field: str = "function" # field on call node for callee + call_function_field: str = "function" # field on call node for callee call_accessor_node_types: frozenset = frozenset() # member/attribute nodes - call_accessor_field: str = "attribute" # field on accessor for method name + call_accessor_field: str = "attribute" # field on accessor for method name # Stop recursion at these types in walk_calls function_boundary_types: frozenset = frozenset() @@ -447,22 +559,57 @@ class LanguageConfig: # ── Generic helpers ─────────────────────────────────────────────────────────── -def _read_text(node, source: bytes) -> str: - return source[node.start_byte:node.end_byte].decode("utf-8", errors="replace") - -_PYTHON_TYPE_CONTAINERS = frozenset({ - "list", "dict", "set", "tuple", "frozenset", "type", - "List", "Dict", "Set", "Tuple", "FrozenSet", "Type", - "Optional", "Union", "Sequence", "Iterable", "Mapping", "MutableMapping", - "Iterator", "Callable", "Awaitable", "AsyncIterable", "AsyncIterator", "Coroutine", - "Generator", "AsyncGenerator", "ContextManager", "AsyncContextManager", - "Annotated", "ClassVar", "Final", "Literal", "Concatenate", "ParamSpec", "TypeVar", - "None", "Ellipsis", -}) +def _read_text(node, source: bytes) -> str: + return source[node.start_byte : node.end_byte].decode("utf-8", errors="replace") + + +_PYTHON_TYPE_CONTAINERS = frozenset( + { + "list", + "dict", + "set", + "tuple", + "frozenset", + "type", + "List", + "Dict", + "Set", + "Tuple", + "FrozenSet", + "Type", + "Optional", + "Union", + "Sequence", + "Iterable", + "Mapping", + "MutableMapping", + "Iterator", + "Callable", + "Awaitable", + "AsyncIterable", + "AsyncIterator", + "Coroutine", + "Generator", + "AsyncGenerator", + "ContextManager", + "AsyncContextManager", + "Annotated", + "ClassVar", + "Final", + "Literal", + "Concatenate", + "ParamSpec", + "TypeVar", + "None", + "Ellipsis", + } +) -def _python_collect_type_refs(node, source: bytes, generic: bool, out: list[tuple[str, str]]) -> None: +def _python_collect_type_refs( + node, source: bytes, generic: bool, out: list[tuple[str, str]] +) -> None: """Walk a Python type annotation; append (name, role) where role is 'type' or 'generic_arg'. Builtin/typing containers (list, dict, Optional, Union, …) are not emitted as refs themselves, @@ -537,7 +684,9 @@ def _csharp_classify_base(name: str, interface_names: set[str]) -> str: return "inherits" -def _csharp_collect_type_refs(node, source: bytes, generic: bool, out: list[tuple[str, str]]) -> None: +def _csharp_collect_type_refs( + node, source: bytes, generic: bool, out: list[tuple[str, str]] +) -> None: """Walk a C# type expression; append (name, role) tuples (role is 'type' or 'generic_arg').""" if node is None: return @@ -1160,7 +1309,10 @@ def _find_body(node, config: LanguageConfig): # ── Import handlers ─────────────────────────────────────────────────────────── -def _import_python(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: + +def _import_python( + node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str +) -> None: t = node.type if t == "import_statement": for child in node.children: @@ -1168,16 +1320,18 @@ def _import_python(node, source: bytes, file_nid: str, stem: str, edges: list, s raw = _read_text(child, source) module_name = raw.split(" as ")[0].strip().lstrip(".") tgt_nid = _make_id(module_name) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) elif t == "import_from_statement": module_node = node.child_by_field_name("module_name") if module_node: @@ -1193,16 +1347,18 @@ def _import_python(node, source: bytes, file_nid: str, stem: str, edges: list, s tgt_nid = _make_id(str(base / rel)) else: tgt_nid = _make_id(raw) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports_from", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports_from", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) def _resolve_js_import_target(raw: str, str_path: str) -> "tuple[str, Path | None] | None": @@ -1228,7 +1384,11 @@ def _import_js(node, source: bytes, file_nid: str, stem: str, edges: list, str_p # Only handle export_statement if it has a `from` clause (re-export). # Pure exports like `export const x = 1` or `export { localVar }` have no source module. if is_reexport: - has_from = any(child.type == "from" or (_read_text(child, source) == "from") for child in node.children if child.type in ("from", "identifier")) + has_from = any( + child.type == "from" or (_read_text(child, source) == "from") + for child in node.children + if child.type in ("from", "identifier") + ) if not has_from: # Check for string child (source path) as a more reliable indicator has_from = any(child.type == "string" for child in node.children) @@ -1243,16 +1403,18 @@ def _import_js(node, source: bytes, file_nid: str, stem: str, edges: list, str_p if resolved is None: break tgt_nid, resolved_path = resolved - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports_from", - "context": "re-export" if is_reexport else "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports_from", + "context": "re-export" if is_reexport else "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) break # Emit symbol-level edges for named imports/re-exports from local/aliased files. @@ -1277,16 +1439,18 @@ def _import_js(node, source: bytes, file_nid: str, stem: str, edges: list, str_p sym = _read_text(name_node, source) if sym == "default": continue # skip default re-exports for ID matching - edges.append({ - "source": file_nid, - "target": _make_id(target_stem, sym), - "relation": "re_exports", - "context": "re-export", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": _make_id(target_stem, sym), + "relation": "re_exports", + "context": "re-export", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) else: # Handle: import { Foo, type Bar } from './bar' for child in node.children: @@ -1298,20 +1462,23 @@ def _import_js(node, source: bytes, file_nid: str, stem: str, edges: list, str_p name_node = spec.child_by_field_name("name") if name_node: sym = _read_text(name_node, source) - edges.append({ - "source": file_nid, - "target": _make_id(target_stem, sym), - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) - - -def _dynamic_import_js(node, source: bytes, caller_nid: str, str_path: str, edges: list, - seen_dyn_pairs: set) -> bool: + edges.append( + { + "source": file_nid, + "target": _make_id(target_stem, sym), + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) + + +def _dynamic_import_js( + node, source: bytes, caller_nid: str, str_path: str, edges: list, seen_dyn_pairs: set +) -> bool: """Detect dynamic import() calls in JS/TS and emit imports_from edges. Handles patterns like: @@ -1358,16 +1525,18 @@ def _dynamic_import_js(node, source: bytes, caller_nid: str, str_path: str, edge pair = (caller_nid, tgt_nid) if pair not in seen_dyn_pairs: seen_dyn_pairs.add(pair) - edges.append({ - "source": caller_nid, - "target": tgt_nid, - "relation": "imports_from", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": caller_nid, + "target": tgt_nid, + "relation": "imports_from", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) break return True @@ -1398,16 +1567,18 @@ def _walk_scoped(n) -> str: ) if module_name: tgt_nid = _make_id(module_name) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) break @@ -1435,7 +1606,24 @@ def _import_c(node, source: bytes, file_nid: str, stem: str, edges: list, str_pa resolved = _resolve_c_include_path(raw, str_path) if resolved is not None: tgt_nid = _make_id(str(resolved)) - edges.append({ + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) + break + module_name = raw.split("/")[-1].split(".")[0] + if module_name: + tgt_nid = _make_id(module_name) + edges.append( + { "source": file_nid, "target": tgt_nid, "relation": "imports", @@ -1444,97 +1632,98 @@ def _import_c(node, source: bytes, file_nid: str, stem: str, edges: list, str_pa "source_file": str_path, "source_location": f"L{node.start_point[0] + 1}", "weight": 1.0, - }) - break - module_name = raw.split("/")[-1].split(".")[0] - if module_name: - tgt_nid = _make_id(module_name) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + } + ) break -def _import_csharp(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: +def _import_csharp( + node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str +) -> None: for child in node.children: if child.type in ("qualified_name", "identifier", "name_equals"): raw = _read_text(child, source) module_name = raw.split(".")[-1].strip() if module_name: tgt_nid = _make_id(module_name) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) break -def _import_kotlin(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: +def _import_kotlin( + node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str +) -> None: path_node = node.child_by_field_name("path") if path_node: raw = _read_text(path_node, source) module_name = raw.split(".")[-1].strip() if module_name: tgt_nid = _make_id(module_name) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) return # Fallback: find identifier child for child in node.children: if child.type == "identifier": raw = _read_text(child, source) tgt_nid = _make_id(raw) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) break -def _import_scala(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: +def _import_scala( + node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str +) -> None: for child in node.children: if child.type in ("stable_id", "identifier"): raw = _read_text(child, source) module_name = raw.split(".")[-1].strip("{} ") if module_name and module_name != "_": tgt_nid = _make_id(module_name) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) break @@ -1545,21 +1734,24 @@ def _import_php(node, source: bytes, file_nid: str, stem: str, edges: list, str_ module_name = raw.split("\\")[-1].strip() if module_name: tgt_nid = _make_id(module_name) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) break # ── C/C++ function name helpers ─────────────────────────────────────────────── + def _get_c_func_name(node, source: bytes) -> str | None: """Recursively unwrap declarator to find the innermost identifier (C).""" if node.type == "identifier": @@ -1594,6 +1786,7 @@ def _get_cpp_func_name(node, source: bytes) -> str | None: # ── JS/TS extra walk for arrow functions ────────────────────────────────────── + def _find_require_call(value_node): """Return the call_expression node if `value_node` is a `require(...)` call or `require(...).x` member access. Otherwise None.""" @@ -1609,7 +1802,9 @@ def _find_require_call(value_node): return None -def _require_imports_js(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> bool: +def _require_imports_js( + node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str +) -> bool: """Detect CommonJS require imports inside lexical_declaration / variable_declaration. Handles three patterns: @@ -1647,16 +1842,18 @@ def _require_imports_js(node, source: bytes, file_nid: str, stem: str, edges: li continue tgt_nid, resolved_path = resolved line = node.start_point[0] + 1 - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports_from", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports_from", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) found = True # Symbol-level edges for destructured / accessor binders. @@ -1679,22 +1876,35 @@ def _require_imports_js(node, source: bytes, file_nid: str, stem: str, edges: li sym_names.append(_read_text(prop, source)) if target_stem is not None: for sym in sym_names: - edges.append({ - "source": file_nid, - "target": _make_id(target_stem, sym), - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": _make_id(target_stem, sym), + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) return found -def _js_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: str, - nodes: list, edges: list, seen_ids: set, function_bodies: list, - parent_class_nid: str | None, add_node_fn, add_edge_fn) -> bool: +def _js_extra_walk( + node, + source: bytes, + file_nid: str, + stem: str, + str_path: str, + nodes: list, + edges: list, + seen_ids: set, + function_bodies: list, + parent_class_nid: str | None, + add_node_fn, + add_edge_fn, +) -> bool: """Handle lexical_declaration (arrow functions, CJS requires, module-level const literals) for JS/TS. Returns True if handled.""" if node.type in ("lexical_declaration", "variable_declaration"): # CJS require imports — emit edges, do not block other lexical_declaration handling @@ -1734,7 +1944,11 @@ def _js_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: str, function_bodies.append((func_nid, body)) arrow_found = True elif value and value.type in ( - "object", "array", "as_expression", "call_expression", "new_expression", + "object", + "array", + "as_expression", + "call_expression", + "new_expression", ): # Module-level const with literal/object/array/factory value name_node = child.child_by_field_name("name") @@ -1756,10 +1970,22 @@ def _js_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: str, # ── C# extra walk for namespace declarations ────────────────────────────────── -def _csharp_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: str, - nodes: list, edges: list, seen_ids: set, function_bodies: list, - parent_class_nid: str | None, add_node_fn, add_edge_fn, - walk_fn) -> bool: + +def _csharp_extra_walk( + node, + source: bytes, + file_nid: str, + stem: str, + str_path: str, + nodes: list, + edges: list, + seen_ids: set, + function_bodies: list, + parent_class_nid: str | None, + add_node_fn, + add_edge_fn, + walk_fn, +) -> bool: """Handle namespace_declaration for C#. Returns True if handled.""" if node.type == "namespace_declaration": name_node = node.child_by_field_name("name") @@ -1779,9 +2005,21 @@ def _csharp_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: # ── Swift extra walk for enum cases ────────────────────────────────────────── -def _swift_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: str, - nodes: list, edges: list, seen_ids: set, function_bodies: list, - parent_class_nid: str | None, add_node_fn, add_edge_fn) -> bool: + +def _swift_extra_walk( + node, + source: bytes, + file_nid: str, + stem: str, + str_path: str, + nodes: list, + edges: list, + seen_ids: set, + function_bodies: list, + parent_class_nid: str | None, + add_node_fn, + add_edge_fn, +) -> bool: """Handle enum_entry for Swift. Returns True if handled.""" if node.type == "enum_entry" and parent_class_nid: for child in node.children: @@ -1819,27 +2057,33 @@ def _swift_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: s call_function_field="function", call_accessor_node_types=frozenset({"member_expression"}), call_accessor_field="property", - function_boundary_types=frozenset({"function_declaration", "arrow_function", "method_definition"}), + function_boundary_types=frozenset( + {"function_declaration", "arrow_function", "method_definition"} + ), import_handler=_import_js, ) _TS_CONFIG = LanguageConfig( ts_module="tree_sitter_typescript", ts_language_fn="language_typescript", - class_types=frozenset({ - "class_declaration", - "abstract_class_declaration", # TS abstract class - "interface_declaration", # parity with Java/C# - "enum_declaration", # named enums - "type_alias_declaration", # named type aliases - }), + class_types=frozenset( + { + "class_declaration", + "abstract_class_declaration", # TS abstract class + "interface_declaration", # parity with Java/C# + "enum_declaration", # named enums + "type_alias_declaration", # named type aliases + } + ), function_types=frozenset({"function_declaration", "method_definition"}), import_types=frozenset({"import_statement", "export_statement"}), call_types=frozenset({"call_expression", "new_expression"}), call_function_field="function", call_accessor_node_types=frozenset({"member_expression"}), call_accessor_field="property", - function_boundary_types=frozenset({"function_declaration", "arrow_function", "method_definition"}), + function_boundary_types=frozenset( + {"function_declaration", "arrow_function", "method_definition"} + ), import_handler=_import_js, ) @@ -1980,7 +2224,14 @@ def _swift_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: s class_types=frozenset({"class_declaration"}), function_types=frozenset({"function_definition", "method_declaration"}), import_types=frozenset({"namespace_use_clause"}), - call_types=frozenset({"function_call_expression", "member_call_expression", "scoped_call_expression", "class_constant_access_expression"}), + call_types=frozenset( + { + "function_call_expression", + "member_call_expression", + "scoped_call_expression", + "class_constant_access_expression", + } + ), static_prop_types=frozenset({"scoped_property_access_expression"}), helper_fn_names=frozenset({"config"}), container_bind_methods=frozenset({"bind", "singleton", "scoped", "instance"}), @@ -2037,6 +2288,7 @@ def _import_lua(node, source: bytes, file_nid: str, stem: str, edges: list, str_ """Extract require('module') from Lua variable_declaration nodes.""" text = _read_text(node, source) import re + m = re.search(r"""require\s*[\('"]\s*['"]?([^'")\s]+)""", text) if m: raw_module = m.group(1) @@ -2073,21 +2325,25 @@ def _import_lua(node, source: bytes, file_nid: str, stem: str, edges: list, str_ ) -def _import_swift(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: +def _import_swift( + node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str +) -> None: for child in node.children: if child.type == "identifier": raw = _read_text(child, source) tgt_nid = _make_id(raw) - edges.append({ - "source": file_nid, - "target": tgt_nid, - "relation": "imports", - "context": "import", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - "weight": 1.0, - }) + edges.append( + { + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "context": "import", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + } + ) break @@ -2115,7 +2371,9 @@ def _read_csharp_type_name(node, source: bytes) -> str | None: _SWIFT_CONFIG = LanguageConfig( ts_module="tree_sitter_swift", class_types=frozenset({"class_declaration", "protocol_declaration"}), - function_types=frozenset({"function_declaration", "init_declaration", "deinit_declaration", "subscript_declaration"}), + function_types=frozenset( + {"function_declaration", "init_declaration", "deinit_declaration", "subscript_declaration"} + ), import_types=frozenset({"import_declaration"}), call_types=frozenset({"call_expression"}), call_function_field="", @@ -2123,23 +2381,31 @@ def _read_csharp_type_name(node, source: bytes) -> str | None: call_accessor_field="", name_fallback_child_types=("simple_identifier", "type_identifier", "user_type"), body_fallback_child_types=("class_body", "protocol_body", "function_body", "enum_class_body"), - function_boundary_types=frozenset({"function_declaration", "init_declaration", "deinit_declaration", "subscript_declaration"}), + function_boundary_types=frozenset( + {"function_declaration", "init_declaration", "deinit_declaration", "subscript_declaration"} + ), import_handler=_import_swift, ) # ── Generic extractor ───────────────────────────────────────────────────────── + def _extract_generic(path: Path, config: LanguageConfig) -> dict: """Generic AST extractor driven by LanguageConfig.""" try: mod = importlib.import_module(config.ts_module) from tree_sitter import Language, Parser + lang_fn = getattr(mod, config.ts_language_fn, None) if lang_fn is None: # Fallback for PHP: try "language_php" then "language" lang_fn = getattr(mod, "language", None) if lang_fn is None: - return {"nodes": [], "edges": [], "error": f"No language function in {config.ts_module}"} + return { + "nodes": [], + "edges": [], + "error": f"No language function in {config.ts_module}", + } language = Language(lang_fn()) except ImportError: return {"nodes": [], "edges": [], "error": f"{config.ts_module} not installed"} @@ -2188,17 +2454,25 @@ def _extract_generic(path: Path, config: LanguageConfig) -> dict: def add_node(nid: str, label: str, line: int) -> None: if nid not in seen_ids: seen_ids.add(nid) - nodes.append({ - "id": nid, - "label": label, - "file_type": "code", - "source_file": str_path, - "source_location": f"L{line}", - }) + nodes.append( + { + "id": nid, + "label": label, + "file_type": "code", + "source_file": str_path, + "source_location": f"L{line}", + } + ) - def add_edge(src: str, tgt: str, relation: str, line: int, - confidence: str = "EXTRACTED", weight: float = 1.0, - context: str | None = None) -> None: + def add_edge( + src: str, + tgt: str, + relation: str, + line: int, + confidence: str = "EXTRACTED", + weight: float = 1.0, + context: str | None = None, + ) -> None: edge = { "source": src, "target": tgt, @@ -2274,13 +2548,15 @@ def walk(node, parent_class_nid: str | None = None) -> None: if base_nid not in seen_ids: base_nid = _make_id(base) if base_nid not in seen_ids: - nodes.append({ - "id": base_nid, - "label": base, - "file_type": "code", - "source_file": "", - "source_location": "", - }) + nodes.append( + { + "id": base_nid, + "label": base, + "file_type": "code", + "source_file": "", + "source_location": "", + } + ) seen_ids.add(base_nid) add_edge(class_nid, base_nid, "inherits", line) @@ -2448,7 +2724,8 @@ def _php_emit_base(base_name: str, rel: str, at_line: int) -> None: if sub.type == "generic_name": name_child = sub.child_by_field_name("name") base = ( - _read_text(name_child, source) if name_child + _read_text(name_child, source) + if name_child else _read_text(sub.children[0], source) ) elif sub.type == "qualified_name": @@ -2461,13 +2738,15 @@ def _php_emit_base(base_name: str, rel: str, at_line: int) -> None: if base_nid not in seen_ids: base_nid = _make_id(base) if base_nid not in seen_ids: - nodes.append({ - "id": base_nid, - "label": base, - "file_type": "code", - "source_file": "", - "source_location": "", - }) + nodes.append( + { + "id": base_nid, + "label": base, + "file_type": "code", + "source_file": "", + "source_location": "", + } + ) seen_ids.add(base_nid) relation = _csharp_classify_base(base, csharp_interface_names) add_edge(class_nid, base_nid, relation, line) @@ -2482,11 +2761,17 @@ def _php_emit_base(base_name: str, rel: str, at_line: int) -> None: _csharp_collect_type_refs(arg, source, True, refs) for ref_name, _role in refs: target = ensure_named_node(ref_name, line) - add_edge(class_nid, target, "references", line, - context="generic_arg") + add_edge( + class_nid, + target, + "references", + line, + context="generic_arg", + ) # Java-specific: extends (superclass) / implements (interfaces) / interface-extends if config.ts_module == "tree_sitter_java": + def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: if not base_name: return @@ -2494,13 +2779,15 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: if base_nid not in seen_ids: base_nid = _make_id(base_name) if base_nid not in seen_ids: - nodes.append({ - "id": base_nid, - "label": base_name, - "file_type": "code", - "source_file": "", - "source_location": "", - }) + nodes.append( + { + "id": base_nid, + "label": base_name, + "file_type": "code", + "source_file": "", + "source_location": "", + } + ) seen_ids.add(base_nid) add_edge(class_nid, base_nid, rel, at_line) @@ -2526,7 +2813,9 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: if sub.type == "type_list": for tid in sub.children: if tid.type == "type_identifier": - _emit_java_parent(_read_text(tid, source), "inherits", line) + _emit_java_parent( + _read_text(tid, source), "inherits", line + ) # Scala: extends_clause carries `extends Base with Trait1 with Trait2`. # The first base after `extends` is `inherits`; each subsequent @@ -2612,13 +2901,15 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: if base_nid not in seen_ids: base_nid = _make_id(base) if base_nid not in seen_ids: - nodes.append({ - "id": base_nid, - "label": base, - "file_type": "code", - "source_file": "", - "source_location": "", - }) + nodes.append( + { + "id": base_nid, + "label": base, + "file_type": "code", + "source_file": "", + "source_location": "", + } + ) seen_ids.add(base_nid) add_edge(class_nid, base_nid, "inherits", line) @@ -2647,9 +2938,11 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: break elif c.type == "array_creation_expression": array_node = c - if (prop_name is None - or prop_name not in config.event_listener_properties - or array_node is None): + if ( + prop_name is None + or prop_name not in config.event_listener_properties + or array_node is None + ): continue handled_event_listener = True for entry in array_node.children: @@ -2683,9 +2976,11 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: if handled_event_listener: return - if (config.ts_module == "tree_sitter_c_sharp" - and t == "field_declaration" - and parent_class_nid): + if ( + config.ts_module == "tree_sitter_c_sharp" + and t == "field_declaration" + and parent_class_nid + ): type_node = node.child_by_field_name("type") if type_node is None: for child in node.children: @@ -2696,8 +2991,13 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: type_name = _read_csharp_type_name(type_node, source) if type_name: line = node.start_point[0] + 1 - add_edge(parent_class_nid, ensure_named_node(type_name, line), - "references", line, context="field") + add_edge( + parent_class_nid, + ensure_named_node(type_name, line), + "references", + line, + context="field", + ) return if (config.ts_module == "tree_sitter_php" @@ -3075,21 +3375,55 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: # JS/TS arrow functions and C# namespaces — language-specific extra handling if config.ts_module in ("tree_sitter_javascript", "tree_sitter_typescript"): - if _js_extra_walk(node, source, file_nid, stem, str_path, - nodes, edges, seen_ids, function_bodies, - parent_class_nid, add_node, add_edge): + if _js_extra_walk( + node, + source, + file_nid, + stem, + str_path, + nodes, + edges, + seen_ids, + function_bodies, + parent_class_nid, + add_node, + add_edge, + ): return if config.ts_module == "tree_sitter_c_sharp": - if _csharp_extra_walk(node, source, file_nid, stem, str_path, - nodes, edges, seen_ids, function_bodies, - parent_class_nid, add_node, add_edge, walk): + if _csharp_extra_walk( + node, + source, + file_nid, + stem, + str_path, + nodes, + edges, + seen_ids, + function_bodies, + parent_class_nid, + add_node, + add_edge, + walk, + ): return if config.ts_module == "tree_sitter_swift": - if _swift_extra_walk(node, source, file_nid, stem, str_path, - nodes, edges, seen_ids, function_bodies, - parent_class_nid, add_node, add_edge): + if _swift_extra_walk( + node, + source, + file_nid, + stem, + str_path, + nodes, + edges, + seen_ids, + function_bodies, + parent_class_nid, + add_node, + add_edge, + ): return # Python's `@property` / `@staticmethod` / `@classmethod` wrap the @@ -3113,7 +3447,7 @@ def _emit_java_parent(base_name: str, rel: str, at_line: int) -> None: walk(root) # ── Call-graph pass ─────────────────────────────────────────────────────── - label_to_nid: dict[str, str] = {} # case-sensitive (Ruby, C#, Java, Kotlin, etc.) + label_to_nid: dict[str, str] = {} # case-sensitive (Ruby, C#, Java, Kotlin, etc.) label_to_nid_ci: dict[str, str] = {} # case-insensitive (PHP functions/classes) for n in nodes: raw = n["label"] @@ -3146,8 +3480,9 @@ def walk_calls(node, caller_nid: str) -> None: if node.type in config.call_types: # JS/TS dynamic imports: await import('./foo.js') if config.ts_module in ("tree_sitter_javascript", "tree_sitter_typescript"): - if _dynamic_import_js(node, source, caller_nid, str_path, - edges, seen_dyn_import_pairs): + if _dynamic_import_js( + node, source, caller_nid, str_path, edges, seen_dyn_import_pairs + ): # Still recurse into children (import().then(...) may have calls) for child in node.children: walk_calls(child, caller_nid) @@ -3236,18 +3571,28 @@ def walk_calls(node, caller_nid: str) -> None: callee_name = _read_text(name_node, source) elif config.ts_module == "tree_sitter_cpp": # C++: function field, then field_expression/qualified_identifier - func_node = node.child_by_field_name(config.call_function_field) if config.call_function_field else None + func_node = ( + node.child_by_field_name(config.call_function_field) + if config.call_function_field + else None + ) if func_node: if func_node.type == "identifier": callee_name = _read_text(func_node, source) elif func_node.type in ("field_expression", "qualified_identifier"): is_member_call = True - name = func_node.child_by_field_name("field") or func_node.child_by_field_name("name") + name = func_node.child_by_field_name( + "field" + ) or func_node.child_by_field_name("name") if name: callee_name = _read_text(name, source) else: # Generic: get callee from call_function_field - func_node = node.child_by_field_name(config.call_function_field) if config.call_function_field else None + func_node = ( + node.child_by_field_name(config.call_function_field) + if config.call_function_field + else None + ) if func_node: if func_node.type == "identifier": callee_name = _read_text(func_node, source) @@ -3268,28 +3613,32 @@ def walk_calls(node, caller_nid: str) -> None: if pair not in seen_call_pairs: seen_call_pairs.add(pair) line = node.start_point[0] + 1 - edges.append({ - "source": caller_nid, - "target": tgt_nid, - "relation": "calls", - "context": "call", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + edges.append( + { + "source": caller_nid, + "target": tgt_nid, + "relation": "calls", + "context": "call", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) elif callee_name and not tgt_nid: # Callee not in this file — save for cross-file resolution in extract() - raw_calls.append({ - "caller_nid": caller_nid, - "callee": callee_name, - "is_member_call": is_member_call, - "source_file": str_path, - "source_location": f"L{node.start_point[0] + 1}", - }) + raw_calls.append( + { + "caller_nid": caller_nid, + "callee": callee_name, + "is_member_call": is_member_call, + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + } + ) # Helper function calls: config('foo.bar') → uses_config edge to "foo" - if (callee_name and callee_name in config.helper_fn_names): + if callee_name and callee_name in config.helper_fn_names: args_node = node.child_by_field_name("arguments") first_key: str | None = None if args_node: @@ -3307,29 +3656,34 @@ def walk_calls(node, caller_nid: str) -> None: break if first_key: segment = first_key.split(".")[0] - tgt_nid = (label_to_nid_ci.get(segment.lower()) - or label_to_nid_ci.get(f"{segment}.php".lower())) + tgt_nid = label_to_nid_ci.get(segment.lower()) or label_to_nid_ci.get( + f"{segment}.php".lower() + ) if tgt_nid and tgt_nid != caller_nid: relation = f"uses_{callee_name}" pair3 = (caller_nid, tgt_nid, relation) if pair3 not in seen_helper_ref_pairs: seen_helper_ref_pairs.add(pair3) line = node.start_point[0] + 1 - edges.append({ - "source": caller_nid, - "target": tgt_nid, - "relation": relation, - "confidence": "EXTRACTED", - "confidence_score": 1.0, - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + edges.append( + { + "source": caller_nid, + "target": tgt_nid, + "relation": relation, + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) # Service container bindings: $this->app->bind(Foo::class, Bar::class) - if (node.type == "member_call_expression" - and callee_name - and callee_name in config.container_bind_methods): + if ( + node.type == "member_call_expression" + and callee_name + and callee_name in config.container_bind_methods + ): args_node = node.child_by_field_name("arguments") class_args: list[str] = [] if args_node: @@ -3353,16 +3707,18 @@ def walk_calls(node, caller_nid: str) -> None: if pair3 not in seen_bind_pairs: seen_bind_pairs.add(pair3) line = node.start_point[0] + 1 - edges.append({ - "source": contract_nid, - "target": impl_nid, - "relation": "bound_to", - "confidence": "EXTRACTED", - "confidence_score": 1.0, - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + edges.append( + { + "source": contract_nid, + "target": impl_nid, + "relation": "bound_to", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) # Static property access: Foo::$bar → uses_static_prop edge if node.type in config.static_prop_types: @@ -3380,19 +3736,24 @@ def walk_calls(node, caller_nid: str) -> None: if pair3 not in seen_static_ref_pairs: seen_static_ref_pairs.add(pair3) line = node.start_point[0] + 1 - edges.append({ - "source": caller_nid, - "target": tgt_nid, - "relation": "uses_static_prop", - "confidence": "EXTRACTED", - "confidence_score": 1.0, - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + edges.append( + { + "source": caller_nid, + "target": tgt_nid, + "relation": "uses_static_prop", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) # PHP class constant access: Foo::BAR → references_constant edge - if config.ts_module == "tree_sitter_php" and node.type == "class_constant_access_expression": + if ( + config.ts_module == "tree_sitter_php" + and node.type == "class_constant_access_expression" + ): class_name = _php_class_const_scope(node) if class_name: tgt_nid = label_to_nid_ci.get(class_name.lower()) @@ -3401,16 +3762,18 @@ def walk_calls(node, caller_nid: str) -> None: if pair3 not in seen_static_ref_pairs: seen_static_ref_pairs.add(pair3) line = node.start_point[0] + 1 - edges.append({ - "source": caller_nid, - "target": tgt_nid, - "relation": "references_constant", - "confidence": "EXTRACTED", - "confidence_score": 1.0, - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + edges.append( + { + "source": caller_nid, + "target": tgt_nid, + "relation": "references_constant", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) for child in node.children: walk_calls(child, caller_nid) @@ -3429,23 +3792,27 @@ def walk_calls(node, caller_nid: str) -> None: if pair2 in seen_listen_pairs: continue seen_listen_pairs.add(pair2) - edges.append({ - "source": event_nid, - "target": listener_nid, - "relation": "listened_by", - "confidence": "EXTRACTED", - "confidence_score": 1.0, - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + edges.append( + { + "source": event_nid, + "target": listener_nid, + "relation": "listened_by", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + } + ) # ── Clean edges ─────────────────────────────────────────────────────────── valid_ids = seen_ids clean_edges = [] for edge in edges: src, tgt = edge["source"], edge["target"] - if src in valid_ids and (tgt in valid_ids or edge["relation"] in ("imports", "imports_from", "re_exports")): + if src in valid_ids and ( + tgt in valid_ids or edge["relation"] in ("imports", "imports_from", "re_exports") + ): clean_edges.append(edge) result = {"nodes": nodes, "edges": clean_edges, "raw_calls": raw_calls} @@ -3456,7 +3823,15 @@ def walk_calls(node, caller_nid: str) -> None: # ── Python rationale extraction ─────────────────────────────────────────────── -_RATIONALE_PREFIXES = ("# NOTE:", "# IMPORTANT:", "# HACK:", "# WHY:", "# RATIONALE:", "# TODO:", "# FIXME:") +_RATIONALE_PREFIXES = ( + "# NOTE:", + "# IMPORTANT:", + "# HACK:", + "# WHY:", + "# RATIONALE:", + "# TODO:", + "# FIXME:", +) def _is_autogenerated_python(source: bytes) -> bool: @@ -3470,9 +3845,11 @@ def _is_autogenerated_python(source: bytes) -> bool: if any(m in head for m in ("DO NOT EDIT", "@generated", "Generated by the protocol buffer")): return True # Alembic / Flask-Migrate revision files - if (re.search(r"^revision\s*[:=]", head, re.MULTILINE) - and "def upgrade(" in head - and "down_revision" in head): + if ( + re.search(r"^revision\s*[:=]", head, re.MULTILINE) + and "def upgrade(" in head + and "down_revision" in head + ): return True # Django migrations if "class Migration(migrations.Migration)" in head and "operations" in head: @@ -3487,6 +3864,7 @@ def _extract_python_rationale(path: Path, result: dict) -> None: try: import tree_sitter_python as tspython from tree_sitter import Language, Parser + language = Language(tspython.language()) parser = Parser(language) source = path.read_bytes() @@ -3509,7 +3887,9 @@ def _get_docstring(body_node) -> tuple[str, int] | None: if child.type == "expression_statement": for sub in child.children: if sub.type in ("string", "concatenated_string"): - text = source[sub.start_byte:sub.end_byte].decode("utf-8", errors="replace") + text = source[sub.start_byte : sub.end_byte].decode( + "utf-8", errors="replace" + ) text = text.strip("\"'").strip('"""').strip("'''").strip() if len(text) > 20: return text, child.start_point[0] + 1 @@ -3521,22 +3901,26 @@ def _add_rationale(text: str, line: int, parent_nid: str) -> None: rid = _make_id(stem, "rationale", str(line)) if rid not in seen_ids: seen_ids.add(rid) - nodes.append({ - "id": rid, - "label": label, - "file_type": "rationale", + nodes.append( + { + "id": rid, + "label": label, + "file_type": "rationale", + "source_file": str_path, + "source_location": f"L{line}", + } + ) + edges.append( + { + "source": rid, + "target": parent_nid, + "relation": "rationale_for", + "confidence": "EXTRACTED", "source_file": str_path, "source_location": f"L{line}", - }) - edges.append({ - "source": rid, - "target": parent_nid, - "relation": "rationale_for", - "confidence": "EXTRACTED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 1.0, - }) + "weight": 1.0, + } + ) # Module-level docstring — skip for auto-generated files (Alembic, Django # migrations, protobuf stubs, etc.) whose module docstrings are revision @@ -3553,7 +3937,9 @@ def walk_docstrings(node, parent_nid: str) -> None: name_node = node.child_by_field_name("name") body = node.child_by_field_name("body") if name_node and body: - class_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") + class_name = source[name_node.start_byte : name_node.end_byte].decode( + "utf-8", errors="replace" + ) nid = _make_id(stem, class_name) ds = _get_docstring(body) if ds: @@ -3565,8 +3951,14 @@ def walk_docstrings(node, parent_nid: str) -> None: name_node = node.child_by_field_name("name") body = node.child_by_field_name("body") if name_node and body: - func_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - nid = _make_id(parent_nid, func_name) if parent_nid != file_nid else _make_id(stem, func_name) + func_name = source[name_node.start_byte : name_node.end_byte].decode( + "utf-8", errors="replace" + ) + nid = ( + _make_id(parent_nid, func_name) + if parent_nid != file_nid + else _make_id(stem, func_name) + ) ds = _get_docstring(body) if ds: _add_rationale(ds[0], ds[1], nid) @@ -3586,6 +3978,7 @@ def walk_docstrings(node, parent_nid: str) -> None: # ── Public API ──────────────────────────────────────────────────────────────── + def extract_python(path: Path) -> dict: """Extract classes, functions, and imports from a .py file via tree-sitter AST.""" result = _extract_generic(path, _PYTHON_CONFIG) @@ -3615,6 +4008,7 @@ def extract_svelte(path: Path) -> dict: result = _extract_generic(path, _JS_CONFIG) try: import re as _re + src = path.read_text(encoding="utf-8", errors="replace") existing_ids = {n["id"] for n in result.get("nodes", [])} # Source file node ID must match the one _extract_generic creates: @@ -3642,7 +4036,7 @@ def extract_svelte(path: Path) -> dict: resolved_alias = None for alias_prefix, alias_base in aliases.items(): if raw == alias_prefix or raw.startswith(alias_prefix + "/"): - rest = raw[len(alias_prefix):].lstrip("/") + rest = raw[len(alias_prefix) :].lstrip("/") resolved_alias = Path(os.path.normpath(Path(alias_base) / rest)) break if resolved_alias is not None: @@ -3659,34 +4053,42 @@ def extract_svelte(path: Path) -> dict: stub_source_file = raw if node_id in existing_ids: # Edge target already a real node - just add the edge, don't add a node. - result.setdefault("edges", []).append({ - "source": file_node_id, "target": node_id, - "relation": "dynamic_import", "confidence": "EXTRACTED", - "source_file": str(path), - }) + result.setdefault("edges", []).append( + { + "source": file_node_id, + "target": node_id, + "relation": "dynamic_import", + "confidence": "EXTRACTED", + "source_file": str(path), + } + ) continue - result.setdefault("nodes", []).append({ - "id": node_id, "label": raw, - "file_type": "code", "source_file": stub_source_file, - "confidence": "EXTRACTED", - }) - result.setdefault("edges", []).append({ - "source": file_node_id, "target": node_id, - "relation": "dynamic_import", "confidence": "EXTRACTED", - "source_file": str(path), - }) + result.setdefault("nodes", []).append( + { + "id": node_id, + "label": raw, + "file_type": "code", + "source_file": stub_source_file, + "confidence": "EXTRACTED", + } + ) + result.setdefault("edges", []).append( + { + "source": file_node_id, + "target": node_id, + "relation": "dynamic_import", + "confidence": "EXTRACTED", + "source_file": str(path), + } + ) existing_ids.add(node_id) # Static imports inside ", "source_file": "src/evil.py", "file_type": "code", "community": 1}, + { + "id": "api", + "label": "ApiClient", + "source_file": "src/api.py", + "file_type": "code", + "community": 0, + }, + { + "id": "run", + "label": "run()", + "source_file": "src/main.py", + "file_type": "code", + "community": 0, + }, + { + "id": "export", + "label": "write_html()", + "source_file": "src/export.py", + "file_type": "code", + "community": 1, + }, + { + "id": "evil", + "label": "", + "source_file": "src/evil.py", + "file_type": "code", + "community": 1, + }, ], "links": [ - {"source": "run", "target": "api", "relation": "calls", "confidence": "EXTRACTED", "confidence_score": 1.0}, - {"source": "api", "target": "export", "relation": "uses", "confidence": "EXTRACTED", "confidence_score": 1.0}, - {"source": "export", "target": "evil", "relation": "calls", "confidence": "EXTRACTED", "confidence_score": 1.0}, + { + "source": "run", + "target": "api", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + }, + { + "source": "api", + "target": "export", + "relation": "uses", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + }, + { + "source": "export", + "target": "evil", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + }, ], "hyperedges": [], "built_at_commit": "abcdef123456", @@ -103,16 +145,36 @@ def test_export_callflow_html_cli_accepts_positional_graph_path(tmp_path): "multigraph": False, "graph": {}, "nodes": [ - {"id": "external", "label": "ExternalOnly", "source_file": "src/external.py", "file_type": "code", "community": 0}, - {"id": "writer", "label": "write_external()", "source_file": "src/writer.py", "file_type": "code", "community": 1}, + { + "id": "external", + "label": "ExternalOnly", + "source_file": "src/external.py", + "file_type": "code", + "community": 0, + }, + { + "id": "writer", + "label": "write_external()", + "source_file": "src/writer.py", + "file_type": "code", + "community": 1, + }, ], "links": [ - {"source": "external", "target": "writer", "relation": "calls", "confidence": "EXTRACTED", "confidence_score": 1.0}, + { + "source": "external", + "target": "writer", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + }, ], "hyperedges": [], } (external_out / "graph.json").write_text(json.dumps(graph), encoding="utf-8") - (external_out / ".graphify_labels.json").write_text(json.dumps({"0": "External Runtime", "1": "External Export"}), encoding="utf-8") + (external_out / ".graphify_labels.json").write_text( + json.dumps({"0": "External Runtime", "1": "External Export"}), encoding="utf-8" + ) (external_out / "GRAPH_REPORT.md").write_text( "\n".join( [ @@ -156,10 +218,25 @@ def test_export_callflow_html_cli_accepts_positional_graph_path(tmp_path): def test_derive_sections_groups_by_architecture_keywords(): nodes = [ - {"id": "extract_py", "label": "extract_python", "source_file": "graphify/extract.py", "community": 0}, - {"id": "extract_js", "label": "extract_js", "source_file": "graphify/extract.py", "community": 0}, + { + "id": "extract_py", + "label": "extract_python", + "source_file": "graphify/extract.py", + "community": 0, + }, + { + "id": "extract_js", + "label": "extract_js", + "source_file": "graphify/extract.py", + "community": 0, + }, {"id": "to_html", "label": "to_html", "source_file": "graphify/export.py", "community": 1}, - {"id": "test_html", "label": "test_export_html", "source_file": "tests/test_export.py", "community": 2}, + { + "id": "test_html", + "label": "test_export_html", + "source_file": "tests/test_export.py", + "community": 2, + }, ] sections = derive_sections_from_communities(nodes, {}, "en", 6) diff --git a/tests/test_charmap_encoding.py b/tests/test_charmap_encoding.py index f255dbbb6..f9d24c1a8 100644 --- a/tests/test_charmap_encoding.py +++ b/tests/test_charmap_encoding.py @@ -11,15 +11,12 @@ b) Assert that extract_corpus_parallel reports loud failure (non-zero exit or summary block) when ≥1 chunk fails. """ + from __future__ import annotations import json -import sys -from io import StringIO -from pathlib import Path -from unittest.mock import MagicMock, call, patch +from unittest.mock import MagicMock, patch -import pytest from graphify import llm @@ -31,14 +28,15 @@ "type": "result", "subtype": "success", "is_error": False, - "result": json.dumps({ - "nodes": [{"id": "n1", "label": "N1", "file_type": "document", - "source_file": "u.md"}], - "edges": [], - "hyperedges": [], - "input_tokens": 0, - "output_tokens": 0, - }), + "result": json.dumps( + { + "nodes": [{"id": "n1", "label": "N1", "file_type": "document", "source_file": "u.md"}], + "edges": [], + "hyperedges": [], + "input_tokens": 0, + "output_tokens": 0, + } + ), "stop_reason": "end_turn", "usage": { "input_tokens": 1, @@ -54,6 +52,7 @@ # ── Test A: subprocess encoding ─────────────────────────────────────────────── + class TestSubprocessEncoding: """_call_claude_cli must pass encoding="utf-8" to subprocess.run so that non-ASCII content in chunk messages does not raise UnicodeEncodeError on @@ -68,8 +67,10 @@ def test_subprocess_called_with_utf8_encoding(self, monkeypatch): """subprocess.run must be invoked with encoding='utf-8'.""" completed = self._make_completed() monkeypatch.setattr(llm, "_response_is_hollow", lambda raw, parsed: False) - with patch("shutil.which", return_value="/fake/bin/claude"), \ - patch("subprocess.run", return_value=completed) as mock_run: + with ( + patch("shutil.which", return_value="/fake/bin/claude"), + patch("subprocess.run", return_value=completed) as mock_run, + ): llm._call_claude_cli(_UNICODE_CONTENT, max_tokens=8192) _args, kwargs = mock_run.call_args assert kwargs.get("encoding") == "utf-8", ( @@ -85,8 +86,10 @@ def test_subprocess_does_not_use_text_true_without_encoding(self, monkeypatch): """ completed = self._make_completed() monkeypatch.setattr(llm, "_response_is_hollow", lambda raw, parsed: False) - with patch("shutil.which", return_value="/fake/bin/claude"), \ - patch("subprocess.run", return_value=completed) as mock_run: + with ( + patch("shutil.which", return_value="/fake/bin/claude"), + patch("subprocess.run", return_value=completed) as mock_run, + ): llm._call_claude_cli(_UNICODE_CONTENT, max_tokens=8192) _args, kwargs = mock_run.call_args # If text=True is present, encoding must also be set to 'utf-8'. @@ -111,12 +114,12 @@ def test_unicode_chars_survive_subprocess_roundtrip(self, monkeypatch, tmp_path) completed = self._make_completed() monkeypatch.setattr(llm, "_response_is_hollow", lambda raw, parsed: False) - with patch("shutil.which", return_value="/fake/bin/claude"), \ - patch("subprocess.run", return_value=completed): + with ( + patch("shutil.which", return_value="/fake/bin/claude"), + patch("subprocess.run", return_value=completed), + ): # Should not raise - result = llm.extract_files_direct( - files=[f], backend="claude-cli", root=tmp_path - ) + result = llm.extract_files_direct(files=[f], backend="claude-cli", root=tmp_path) assert len(result["nodes"]) >= 1 def test_call_llm_claude_cli_subprocess_encoding(self, monkeypatch): @@ -126,8 +129,10 @@ def test_call_llm_claude_cli_subprocess_encoding(self, monkeypatch): stdout=json.dumps({"result": "ok", "stop_reason": "end_turn"}), stderr="", ) - with patch("shutil.which", return_value="/fake/bin/claude"), \ - patch("subprocess.run", return_value=completed) as mock_run: + with ( + patch("shutil.which", return_value="/fake/bin/claude"), + patch("subprocess.run", return_value=completed) as mock_run, + ): llm._call_llm(_UNICODE_CONTENT, backend="claude-cli", max_tokens=200) _args, kwargs = mock_run.call_args assert kwargs.get("encoding") == "utf-8", ( @@ -138,6 +143,7 @@ def test_call_llm_claude_cli_subprocess_encoding(self, monkeypatch): # ── Test B: loud failure on chunk error ──────────────────────────────────────── + class TestLoudChunkFailure: """extract_corpus_parallel must surface chunk failures loudly — either via non-zero exit (exception raised from the function) or a printed summary @@ -196,10 +202,11 @@ def test_no_false_alarm_when_all_chunks_succeed(self, monkeypatch, tmp_path, cap f.write_text("z = 1\n", encoding="utf-8") good_result = { - "nodes": [{"id": "n1", "label": "N1", "file_type": "code", - "source_file": str(f)}], - "edges": [], "hyperedges": [], - "input_tokens": 1, "output_tokens": 1, + "nodes": [{"id": "n1", "label": "N1", "file_type": "code", "source_file": str(f)}], + "edges": [], + "hyperedges": [], + "input_tokens": 1, + "output_tokens": 1, "elapsed_seconds": 0.1, } monkeypatch.setattr( @@ -217,6 +224,7 @@ def test_no_false_alarm_when_all_chunks_succeed(self, monkeypatch, tmp_path, cap # ── Substitution validation (rsl-siege-manager path via Python) ──────────────── + class TestSubstitutionValidation: """Exercises the same code path as the rsl-siege-manager reproduction without requiring the `claude` CLI or its auth. @@ -262,9 +270,7 @@ def test_cp1252_would_fail_but_utf8_succeeds(self, tmp_path): try: prompt.encode("utf-8") except UnicodeEncodeError as e: - raise AssertionError( - f"UTF-8 encode must succeed but failed: {e}" - ) from e + raise AssertionError(f"UTF-8 encode must succeed but failed: {e}") from e # cp1252 must fail (confirming these chars are the failing surface) try: @@ -279,9 +285,7 @@ def test_cp1252_would_fail_but_utf8_succeeds(self, tmp_path): except UnicodeEncodeError: pass # Expected — confirms these chars hit the pre-fix failure surface - def test_subprocess_encoding_kwarg_in_extract_files_direct( - self, monkeypatch, tmp_path - ): + def test_subprocess_encoding_kwarg_in_extract_files_direct(self, monkeypatch, tmp_path): """End-to-end path: write unicode file → extract_files_direct → subprocess. Subprocess must receive encoding='utf-8', not the locale default. @@ -290,39 +294,49 @@ def test_subprocess_encoding_kwarg_in_extract_files_direct( f.write_text(self._UNICODE_CHARS, encoding="utf-8") _ENVELOPE_SIMPLE = { - "type": "result", "subtype": "success", "is_error": False, - "result": json.dumps({ - "nodes": [{"id": "u_chunk", "label": "Unicode Chunk", - "file_type": "document", - "source_file": "unicode_chunk.md"}], - "edges": [], "hyperedges": [], - "input_tokens": 1, "output_tokens": 1, - }), + "type": "result", + "subtype": "success", + "is_error": False, + "result": json.dumps( + { + "nodes": [ + { + "id": "u_chunk", + "label": "Unicode Chunk", + "file_type": "document", + "source_file": "unicode_chunk.md", + } + ], + "edges": [], + "hyperedges": [], + "input_tokens": 1, + "output_tokens": 1, + } + ), "stop_reason": "end_turn", "usage": { - "input_tokens": 1, "output_tokens": 1, - "cache_read_input_tokens": 0, "cache_creation_input_tokens": 0, + "input_tokens": 1, + "output_tokens": 1, + "cache_read_input_tokens": 0, + "cache_creation_input_tokens": 0, }, "modelUsage": { "claude-opus-4-7": {"inputTokens": 1, "outputTokens": 1}, }, } - completed = MagicMock( - returncode=0, stdout=json.dumps(_ENVELOPE_SIMPLE), stderr="" - ) + completed = MagicMock(returncode=0, stdout=json.dumps(_ENVELOPE_SIMPLE), stderr="") monkeypatch.setattr(llm, "_response_is_hollow", lambda raw, parsed: False) - with patch("shutil.which", return_value="/fake/bin/claude"), \ - patch("subprocess.run", return_value=completed) as mock_run: - result = llm.extract_files_direct( - files=[f], backend="claude-cli", root=tmp_path - ) + with ( + patch("shutil.which", return_value="/fake/bin/claude"), + patch("subprocess.run", return_value=completed) as mock_run, + ): + result = llm.extract_files_direct(files=[f], backend="claude-cli", root=tmp_path) assert mock_run.called _args, kwargs = mock_run.call_args assert kwargs.get("encoding") == "utf-8", ( - "subprocess.run must be called with encoding='utf-8'; " - f"got {kwargs.get('encoding')!r}" + f"subprocess.run must be called with encoding='utf-8'; got {kwargs.get('encoding')!r}" ) # Confirm the unicode content was in the input (not truncated/replaced) inp = kwargs.get("input", "") diff --git a/tests/test_chunking.py b/tests/test_chunking.py index 087464ab8..f037349cf 100644 --- a/tests/test_chunking.py +++ b/tests/test_chunking.py @@ -1,6 +1,6 @@ """Tests for token-aware chunking and parallel chunk execution in graphify.llm.""" + import time -from pathlib import Path from unittest.mock import patch import pytest @@ -13,12 +13,14 @@ def no_tokenizer(): compresses repeated/synthetic content heavily, which would make pack-size assertions tied to specific input sizes flaky.""" from graphify import llm + with patch.object(llm, "_TOKENIZER", None): yield # ---- Token-aware packing ----------------------------------------------------- + def test_pack_chunks_packs_small_files_together(tmp_path): """Many small files should land in a single chunk, not one chunk per file.""" from graphify.llm import _pack_chunks_by_tokens @@ -64,10 +66,14 @@ def test_pack_chunks_groups_by_directory(tmp_path): dir_a.mkdir() dir_b.mkdir() - a1 = dir_a / "x.py"; a1.write_text("a") - a2 = dir_a / "y.py"; a2.write_text("a") - b1 = dir_b / "x.py"; b1.write_text("b") - b2 = dir_b / "y.py"; b2.write_text("b") + a1 = dir_a / "x.py" + a1.write_text("a") + a2 = dir_a / "y.py" + a2.write_text("a") + b1 = dir_b / "x.py" + b1.write_text("b") + b2 = dir_b / "y.py" + b2.write_text("b") # Big budget — everything fits in one chunk in principle, but the order # within the chunk should keep dir_a's files contiguous and dir_b's @@ -87,8 +93,10 @@ def test_pack_chunks_oversized_file_gets_its_own_chunk(tmp_path, no_tokenizer): """A file larger than the budget can't be split — it goes alone in a chunk.""" from graphify.llm import _pack_chunks_by_tokens - big = tmp_path / "big.py"; big.write_text("x" * 200_000) # ~50k tokens (cap-bound) - small = tmp_path / "small.py"; small.write_text("x") + big = tmp_path / "big.py" + big.write_text("x" * 200_000) # ~50k tokens (cap-bound) + small = tmp_path / "small.py" + small.write_text("x") chunks = _pack_chunks_by_tokens([big, small], token_budget=1_000) sizes = [len(c) for c in chunks] @@ -100,13 +108,15 @@ def test_pack_chunks_oversized_file_gets_its_own_chunk(tmp_path, no_tokenizer): def test_pack_chunks_rejects_non_positive_budget(tmp_path): from graphify.llm import _pack_chunks_by_tokens - f = tmp_path / "x.py"; f.write_text("a") + f = tmp_path / "x.py" + f.write_text("a") with pytest.raises(ValueError): _pack_chunks_by_tokens([f], token_budget=0) # ---- Tokenizer fallback ------------------------------------------------------ + def test_estimate_file_tokens_uses_tiktoken_when_available(tmp_path): """When tiktoken is installed, the estimator should call into it for accurate counts rather than the chars/4 heuristic.""" @@ -139,6 +149,7 @@ def test_estimate_file_tokens_falls_back_to_chars_when_no_tokenizer(tmp_path): # ---- Parallel execution ------------------------------------------------------ + def _stub_chunk_result(file_count: int, idx: int) -> dict: """Build a deterministic fake extraction result for a chunk.""" return { @@ -157,7 +168,8 @@ def test_corpus_parallel_runs_chunks_concurrently(tmp_path): files = [] for i in range(8): - f = tmp_path / f"f{i}.py"; f.write_text("x") + f = tmp_path / f"f{i}.py" + f.write_text("x") files.append(f) def slow_extract(chunk, **kwargs): @@ -183,7 +195,8 @@ def test_corpus_parallel_sequential_when_max_concurrency_is_one(tmp_path): files = [] for i in range(3): - f = tmp_path / f"f{i}.py"; f.write_text("x") + f = tmp_path / f"f{i}.py" + f.write_text("x") files.append(f) call_order = [] @@ -208,7 +221,8 @@ def test_corpus_parallel_continues_after_chunk_failure(tmp_path, capsys): files = [] for i in range(4): - f = tmp_path / f"f{i}.py"; f.write_text("x") + f = tmp_path / f"f{i}.py" + f.write_text("x") files.append(f) call_count = {"n": 0} @@ -236,7 +250,8 @@ def test_corpus_parallel_legacy_mode_when_token_budget_is_none(tmp_path): files = [] for i in range(45): - f = tmp_path / f"f{i}.py"; f.write_text("x") + f = tmp_path / f"f{i}.py" + f.write_text("x") files.append(f) chunks_seen = [] @@ -260,7 +275,8 @@ def test_corpus_parallel_token_budget_default_packs_files(tmp_path): files = [] for i in range(50): - f = tmp_path / f"f{i}.py"; f.write_text("x = 1\n") + f = tmp_path / f"f{i}.py" + f.write_text("x = 1\n") files.append(f) chunks_seen = [] @@ -279,6 +295,7 @@ def record(chunk, **kwargs): # ---- Adaptive retry on truncation ------------------------------------------- + def _stub_with_finish(file_count: int, finish_reason: str = "stop") -> dict: """Build a stub extraction result with a controllable finish_reason.""" return { @@ -398,7 +415,8 @@ def test_adaptive_retry_single_file_truncation_does_not_recurse(tmp_path, capsys warning and return what we got. No infinite loop.""" from graphify.llm import _extract_with_adaptive_retry - f = tmp_path / "huge.py"; f.write_text("x") + f = tmp_path / "huge.py" + f.write_text("x") calls = [] diff --git a/tests/test_claude_cli_backend.py b/tests/test_claude_cli_backend.py index eeb6fd27b..1b2c1fa28 100644 --- a/tests/test_claude_cli_backend.py +++ b/tests/test_claude_cli_backend.py @@ -3,6 +3,7 @@ Mocks subprocess.run + shutil.which so the suite runs on CI without the `claude` binary or a live network call. """ + from __future__ import annotations import json @@ -16,19 +17,31 @@ "type": "result", "subtype": "success", "is_error": False, - "result": json.dumps({ - "nodes": [ - {"id": "foo_module", "label": "Foo", "file_type": "document", "source_file": "foo.md"}, - {"id": "foo_greet", "label": "greet", "file_type": "code", "source_file": "foo.md"}, - ], - "edges": [ - {"source": "foo_module", "target": "foo_greet", - "relation": "references", "confidence": "EXTRACTED", "confidence_score": 1.0}, - ], - "hyperedges": [], - "input_tokens": 0, - "output_tokens": 0, - }), + "result": json.dumps( + { + "nodes": [ + { + "id": "foo_module", + "label": "Foo", + "file_type": "document", + "source_file": "foo.md", + }, + {"id": "foo_greet", "label": "greet", "file_type": "code", "source_file": "foo.md"}, + ], + "edges": [ + { + "source": "foo_module", + "target": "foo_greet", + "relation": "references", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + }, + ], + "hyperedges": [], + "input_tokens": 0, + "output_tokens": 0, + } + ), "stop_reason": "end_turn", "usage": { "input_tokens": 6, @@ -44,8 +57,10 @@ def fake_claude(monkeypatch): completed = MagicMock(returncode=0, stdout=json.dumps(_ENVELOPE), stderr="") monkeypatch.setattr(llm, "_response_is_hollow", lambda raw, parsed: False) - with patch("shutil.which", return_value="/fake/bin/claude"), \ - patch("subprocess.run", return_value=completed) as run: + with ( + patch("shutil.which", return_value="/fake/bin/claude"), + patch("subprocess.run", return_value=completed) as run, + ): yield run @@ -67,8 +82,10 @@ def test_finish_reason_length_on_max_tokens(monkeypatch): envelope = dict(_ENVELOPE, stop_reason="max_tokens") completed = MagicMock(returncode=0, stdout=json.dumps(envelope), stderr="") monkeypatch.setattr(llm, "_response_is_hollow", lambda raw, parsed: False) - with patch("shutil.which", return_value="/fake/bin/claude"), \ - patch("subprocess.run", return_value=completed): + with ( + patch("shutil.which", return_value="/fake/bin/claude"), + patch("subprocess.run", return_value=completed), + ): result = llm._call_claude_cli("dummy", max_tokens=8192) assert result["finish_reason"] == "length" @@ -81,16 +98,20 @@ def test_raises_when_cli_missing(): def test_raises_on_nonzero_exit(): completed = MagicMock(returncode=2, stdout="", stderr="auth failed") - with patch("shutil.which", return_value="/fake/bin/claude"), \ - patch("subprocess.run", return_value=completed): + with ( + patch("shutil.which", return_value="/fake/bin/claude"), + patch("subprocess.run", return_value=completed), + ): with pytest.raises(RuntimeError, match="exited 2"): llm._call_claude_cli("dummy", max_tokens=8192) def test_raises_on_garbage_envelope(): completed = MagicMock(returncode=0, stdout="not json", stderr="") - with patch("shutil.which", return_value="/fake/bin/claude"), \ - patch("subprocess.run", return_value=completed): + with ( + patch("shutil.which", return_value="/fake/bin/claude"), + patch("subprocess.run", return_value=completed), + ): with pytest.raises(RuntimeError, match="unparseable JSON envelope"): llm._call_claude_cli("dummy", max_tokens=8192) diff --git a/tests/test_claude_md.py b/tests/test_claude_md.py index f81f10dd3..3198b797e 100644 --- a/tests/test_claude_md.py +++ b/tests/test_claude_md.py @@ -1,13 +1,13 @@ """Tests for graphify claude install / uninstall commands.""" -from pathlib import Path -import pytest -from graphify.__main__ import claude_install, claude_uninstall, _CLAUDE_MD_MARKER, _CLAUDE_MD_SECTION + +from graphify.__main__ import claude_install, claude_uninstall, _CLAUDE_MD_MARKER # --------------------------------------------------------------------------- # install # --------------------------------------------------------------------------- + def test_install_creates_claude_md(tmp_path): """Creates CLAUDE.md when none exists.""" claude_install(tmp_path) @@ -58,6 +58,7 @@ def test_install_idempotent_message(tmp_path, capsys): # uninstall # --------------------------------------------------------------------------- + def test_uninstall_removes_section(tmp_path): """Removes the graphify section after it was installed.""" claude_install(tmp_path) @@ -101,9 +102,11 @@ def test_uninstall_no_op_when_no_file(tmp_path, capsys): # settings.json PreToolUse hook # --------------------------------------------------------------------------- + def test_install_creates_settings_json(tmp_path): """claude_install also writes .claude/settings.json with PreToolUse hook.""" import json + claude_install(tmp_path) settings_path = tmp_path / ".claude" / "settings.json" assert settings_path.exists() @@ -115,6 +118,7 @@ def test_install_creates_settings_json(tmp_path): def test_install_settings_json_idempotent(tmp_path): """Running claude_install twice does not duplicate the PreToolUse hook.""" import json + claude_install(tmp_path) claude_install(tmp_path) settings_path = tmp_path / ".claude" / "settings.json" @@ -127,6 +131,7 @@ def test_install_settings_json_idempotent(tmp_path): def test_uninstall_removes_settings_hook(tmp_path): """claude_uninstall removes the PreToolUse hook from settings.json.""" import json + claude_install(tmp_path) claude_uninstall(tmp_path) settings_path = tmp_path / ".claude" / "settings.json" diff --git a/tests/test_cli_export.py b/tests/test_cli_export.py index 942dbcf26..8909b8ef0 100644 --- a/tests/test_cli_export.py +++ b/tests/test_cli_export.py @@ -3,6 +3,7 @@ Each test builds a minimal graph in a temp dir, runs the CLI command as a subprocess, and asserts the expected output file exists and is non-empty / valid. """ + from __future__ import annotations import json import os @@ -10,13 +11,14 @@ import sys from pathlib import Path -import pytest PYTHON = sys.executable FIXTURES = Path(__file__).parent / "fixtures" -def _run(args: list[str], cwd: Path, env: dict[str, str] | None = None) -> subprocess.CompletedProcess: +def _run( + args: list[str], cwd: Path, env: dict[str, str] | None = None +) -> subprocess.CompletedProcess: return subprocess.run( [PYTHON, "-m", "graphify"] + args, cwd=cwd, @@ -53,14 +55,13 @@ def _make_graph(tmp_path: Path) -> Path: "surprises": surprises, } (out / ".graphify_analysis.json").write_text(json.dumps(analysis)) - (out / ".graphify_labels.json").write_text( - json.dumps({str(k): v for k, v in labels.items()}) - ) + (out / ".graphify_labels.json").write_text(json.dumps({str(k): v for k, v in labels.items()})) return out # ── graphify export html ───────────────────────────────────────────────────── + def test_export_html_creates_file(tmp_path): _make_graph(tmp_path) r = _run(["export", "html"], tmp_path) @@ -83,8 +84,24 @@ def test_export_html_error_without_graph(tmp_path): assert r.returncode != 0 +def test_update_accepts_no_viz_and_removes_stale_html(tmp_path): + (tmp_path / "app.py").write_text("def alpha():\n return 1\n", encoding="utf-8") + out = tmp_path / "graphify-out" + out.mkdir() + stale_html = out / "graph.html" + stale_html.write_text("", encoding="utf-8") + + env = os.environ | {"GRAPHIFY_NO_TIPS": "1"} + r = _run(["update", ".", "--force", "--no-viz"], tmp_path, env=env) + + assert r.returncode == 0, r.stderr + assert not stale_html.exists() + assert "Skipped graph.html" not in r.stdout + + # ── graphify export obsidian ───────────────────────────────────────────────── + def test_export_obsidian_creates_vault(tmp_path): _make_graph(tmp_path) r = _run(["export", "obsidian"], tmp_path) @@ -106,6 +123,7 @@ def test_export_obsidian_custom_dir(tmp_path): # ── graphify export wiki ───────────────────────────────────────────────────── + def test_export_wiki_creates_articles(tmp_path): _make_graph(tmp_path) r = _run(["export", "wiki"], tmp_path) @@ -130,6 +148,7 @@ def test_export_wiki_accepts_edges_only_graph_json(tmp_path): # ── graphify export graphml ────────────────────────────────────────────────── + def test_export_graphml_creates_file(tmp_path): _make_graph(tmp_path) r = _run(["export", "graphml"], tmp_path) @@ -143,6 +162,7 @@ def test_export_graphml_creates_file(tmp_path): # ── graphify export neo4j (cypher) ─────────────────────────────────────────── + def test_export_neo4j_creates_cypher(tmp_path): _make_graph(tmp_path) r = _run(["export", "neo4j"], tmp_path) @@ -156,6 +176,7 @@ def test_export_neo4j_creates_cypher(tmp_path): # ── graphify query ─────────────────────────────────────────────────────────── + def test_query_returns_output(tmp_path): _make_graph(tmp_path) r = _run(["query", "test"], tmp_path) @@ -195,6 +216,7 @@ def test_query_uses_graphify_out_env(tmp_path): # ── graphify path ──────────────────────────────────────────────────────────── + def test_path_runs_without_error(tmp_path): _make_graph(tmp_path) r = _run(["path", "Transformer", "LayerNorm"], tmp_path) @@ -221,6 +243,7 @@ def test_path_uses_graphify_out_env(tmp_path): # ── graphify explain ───────────────────────────────────────────────────────── + def test_explain_runs_without_error(tmp_path): _make_graph(tmp_path) r = _run(["explain", "test"], tmp_path) @@ -246,6 +269,7 @@ def test_explain_uses_graphify_out_env(tmp_path): # ── graphify export unknown format ─────────────────────────────────────────── + def test_export_unknown_format_fails(tmp_path): r = _run(["export", "pdf"], tmp_path) assert r.returncode != 0 @@ -267,6 +291,7 @@ def test_update_no_cluster_writes_raw_graph(tmp_path): # Regression test for #934 - cluster-only crashes when graphify-out/ doesn't exist + def test_cluster_only_creates_output_dir_when_missing(tmp_path): """cluster-only must not crash with FileNotFoundError when graphify-out/ is absent (#934).""" # Build graph.json somewhere other than the default graphify-out/ location @@ -278,6 +303,7 @@ def test_cluster_only_creates_output_dir_when_missing(tmp_path): graph_json = out_dir / "graph.json" # Simulate user archiving the output dir before re-clustering import shutil + shutil.copy(graph_json, graph_src) shutil.rmtree(out_dir) @@ -290,6 +316,7 @@ def test_cluster_only_creates_output_dir_when_missing(tmp_path): # Regression test for #1027 - cluster-only must remap labels via node overlap + def test_cluster_only_remaps_labels_to_previous_cids(tmp_path): """cluster-only must invoke remap_communities_to_previous so the existing .graphify_labels.json keeps tracking the same conceptual communities after @@ -350,6 +377,7 @@ def test_cluster_only_remaps_labels_to_previous_cids(tmp_path): # silently bails or generates a degraded artifact whenever the sidecar is # missing, even though the data is right there. + def test_export_html_falls_back_to_node_community_attribute(tmp_path): """When .graphify_analysis.json is absent, export html should reconstruct communities from the per-node attribute in graph.json rather than bailing @@ -385,8 +413,7 @@ def test_export_html_fallback_recovers_multiple_communities(tmp_path): # And the count we'd reconstruct from graph.json's node attributes graph = json.loads((out / "graph.json").read_text(encoding="utf-8")) reconstructed_cids = { - n["community"] for n in graph.get("nodes", []) - if n.get("community") is not None + n["community"] for n in graph.get("nodes", []) if n.get("community") is not None } assert len(reconstructed_cids) == expected_count, ( f"reconstruction would lose communities: sidecar={expected_count} vs " diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 21fd2ca3a..514e4aea0 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -1,5 +1,4 @@ import json -import sys import networkx as nx from pathlib import Path from graphify.build import build_from_json @@ -7,38 +6,45 @@ FIXTURES = Path(__file__).parent / "fixtures" + def make_graph(): return build_from_json(json.loads((FIXTURES / "extraction.json").read_text())) + def test_cluster_returns_dict(): G = make_graph() communities = cluster(G) assert isinstance(communities, dict) + def test_cluster_covers_all_nodes(): G = make_graph() communities = cluster(G) all_nodes = {n for nodes in communities.values() for n in nodes} assert all_nodes == set(G.nodes) + def test_cohesion_score_complete_graph(): G = nx.complete_graph(4) G = nx.relabel_nodes(G, {i: str(i) for i in G.nodes}) score = cohesion_score(G, list(G.nodes)) assert score == 1.0 + def test_cohesion_score_single_node(): G = nx.Graph() G.add_node("a") score = cohesion_score(G, ["a"]) assert score == 1.0 + def test_cohesion_score_disconnected(): G = nx.Graph() G.add_nodes_from(["a", "b", "c"]) score = cohesion_score(G, ["a", "b", "c"]) assert score == 0.0 + def test_cohesion_score_range(): G = make_graph() communities = cluster(G) @@ -46,6 +52,7 @@ def test_cohesion_score_range(): score = cohesion_score(G, nodes) assert 0.0 <= score <= 1.0 + def test_score_all_keys_match_communities(): G = make_graph() communities = cluster(G) diff --git a/tests/test_confidence.py b/tests/test_confidence.py index 299548aca..4e6af31f1 100644 --- a/tests/test_confidence.py +++ b/tests/test_confidence.py @@ -1,9 +1,9 @@ """Tests for confidence_score on edges.""" + import json import tempfile from pathlib import Path -import networkx as nx from graphify.build import build_from_json from graphify.cluster import cluster, score_all @@ -24,12 +24,33 @@ def _make_extraction(**edge_overrides): {"id": "n_d", "label": "D", "file_type": "document", "source_file": "d.md"}, ], "edges": [ - {"source": "n_a", "target": "n_b", "relation": "calls", "confidence": "EXTRACTED", - "confidence_score": 1.0, "source_file": "a.py", "weight": 1.0}, - {"source": "n_b", "target": "n_c", "relation": "implements", "confidence": "INFERRED", - "confidence_score": 0.75, "source_file": "b.py", "weight": 0.8}, - {"source": "n_c", "target": "n_d", "relation": "references", "confidence": "AMBIGUOUS", - "confidence_score": 0.2, "source_file": "c.md", "weight": 0.5}, + { + "source": "n_a", + "target": "n_b", + "relation": "calls", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "a.py", + "weight": 1.0, + }, + { + "source": "n_b", + "target": "n_c", + "relation": "implements", + "confidence": "INFERRED", + "confidence_score": 0.75, + "source_file": "b.py", + "weight": 0.8, + }, + { + "source": "n_c", + "target": "n_d", + "relation": "references", + "confidence": "AMBIGUOUS", + "confidence_score": 0.2, + "source_file": "c.md", + "weight": 0.5, + }, ], "input_tokens": 100, "output_tokens": 50, @@ -108,10 +129,22 @@ def test_to_json_defaults_missing_confidence_score(): ], "edges": [ # No confidence_score field on any of these - {"source": "n_x", "target": "n_y", "relation": "calls", - "confidence": "EXTRACTED", "source_file": "x.py", "weight": 1.0}, - {"source": "n_y", "target": "n_z", "relation": "depends_on", - "confidence": "INFERRED", "source_file": "y.py", "weight": 1.0}, + { + "source": "n_x", + "target": "n_y", + "relation": "calls", + "confidence": "EXTRACTED", + "source_file": "x.py", + "weight": 1.0, + }, + { + "source": "n_y", + "target": "n_z", + "relation": "depends_on", + "confidence": "INFERRED", + "source_file": "y.py", + "weight": 1.0, + }, ], "input_tokens": 0, "output_tokens": 0, @@ -148,7 +181,7 @@ def test_report_shows_avg_confidence_for_inferred(): report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, ".") assert "avg confidence" in report, "Report should show avg confidence for INFERRED edges" # The fixture has one INFERRED edge with score 0.75, so avg should be 0.75 - assert "0.75" in report, f"Expected avg confidence 0.75 in report" + assert "0.75" in report, "Expected avg confidence 0.75 in report" def test_report_inferred_tag_with_score(): @@ -160,9 +193,15 @@ def test_report_inferred_tag_with_score(): {"id": "n_q", "label": "Renderer", "file_type": "code", "source_file": "renderer.py"}, ], "edges": [ - {"source": "n_p", "target": "n_q", "relation": "feeds", - "confidence": "INFERRED", "confidence_score": 0.82, - "source_file": "parser.py", "weight": 1.0}, + { + "source": "n_p", + "target": "n_q", + "relation": "feeds", + "confidence": "INFERRED", + "confidence_score": 0.82, + "source_file": "parser.py", + "weight": 1.0, + }, ], "input_tokens": 0, "output_tokens": 0, diff --git a/tests/test_dedup.py b/tests/test_dedup.py index 293d2a8fe..b5f90b307 100644 --- a/tests/test_dedup.py +++ b/tests/test_dedup.py @@ -1,29 +1,34 @@ """Tests for graphify/dedup.py entity deduplication pipeline.""" + from __future__ import annotations -import pytest from graphify.dedup import deduplicate_entities, _entropy, _shingles # ── entropy gate ───────────────────────────────────────────────────────────── + def test_entropy_short_label_low(): assert _entropy("AI") < 2.5 + def test_entropy_normal_label_high(): assert _entropy("AuthenticationManager") >= 2.5 + def test_entropy_empty_string(): assert _entropy("") == 0.0 # ── shingles ───────────────────────────────────────────────────────────────── + def test_shingles_produces_trigrams(): s = _shingles("hello") assert "hel" in s assert "ell" in s assert "llo" in s + def test_shingles_short_string(): # strings shorter than 3 chars return single shingle of the string itself assert _shingles("ab") == {"ab"} @@ -31,8 +36,13 @@ def test_shingles_short_string(): # ── full pipeline ───────────────────────────────────────────────────────────── + def _make_nodes(*labels): - return [{"id": label.lower().replace(" ", "_"), "label": label, "source_file": "test.md"} for label in labels] + return [ + {"id": label.lower().replace(" ", "_"), "label": label, "source_file": "test.md"} + for label in labels + ] + def _make_edges(src, tgt, relation="relates_to"): return [{"source": src, "target": tgt, "relation": relation}] @@ -122,9 +132,11 @@ def test_dedup_llm_flag_accepted(): # ── build integration ───────────────────────────────────────────────────────── + def test_build_calls_dedup(): """build() should deduplicate near-identical nodes across extractions.""" from graphify.build import build + chunk1 = { "nodes": [{"id": "graphextractor", "label": "GraphExtractor", "source_file": "a.py"}], "edges": [], @@ -139,6 +151,7 @@ def test_build_calls_dedup(): # --- #878: fuzzy dedup false merges on short/variant labels --- + def test_dedup_does_not_merge_numeric_variants(tmp_path): """Chip SKU variants (ASR1603 vs ASR1605) must not be merged (#878).""" nodes = _make_nodes("ASR1603", "ASR1605") @@ -164,6 +177,7 @@ def test_dedup_still_merges_real_typos(): """Genuine same-length single-char typos should still merge (#878 non-regression).""" from graphify.dedup import _is_variant_pair, _short_label_blocked from rapidfuzz.distance import JaroWinkler + a, b = "graphextractor", "graphextractar" score = JaroWinkler.normalized_similarity(a, b) * 100 assert not _is_variant_pair(a, b), "not a variant pair" @@ -173,6 +187,7 @@ def test_dedup_still_merges_real_typos(): def test_variant_pair_helper(): """_is_variant_pair correctly identifies chip-model variant pairs (#878).""" from graphify.dedup import _is_variant_pair + assert _is_variant_pair("asr1603", "asr1605") assert _is_variant_pair("cortex a55", "cortex a55x") assert not _is_variant_pair("graphextractor", "graphextracter") diff --git a/tests/test_detect.py b/tests/test_detect.py index 900851802..0357bab57 100644 --- a/tests/test_detect.py +++ b/tests/test_detect.py @@ -1,52 +1,74 @@ from pathlib import Path -from graphify.detect import classify_file, count_words, detect, detect_incremental, save_manifest, FileType, _looks_like_paper, _is_ignored, _load_graphifyignore, _is_sensitive +from graphify.detect import ( + classify_file, + count_words, + detect, + detect_incremental, + save_manifest, + FileType, + _is_ignored, + _load_graphifyignore, + _is_sensitive, +) FIXTURES = Path(__file__).parent / "fixtures" + def test_classify_python(): assert classify_file(Path("foo.py")) == FileType.CODE + def test_classify_typescript(): assert classify_file(Path("bar.ts")) == FileType.CODE + def test_classify_markdown(): assert classify_file(Path("README.md")) == FileType.DOCUMENT + def test_classify_pdf(): assert classify_file(Path("paper.pdf")) == FileType.PAPER + def test_classify_pdf_in_xcassets_skipped(): # PDFs inside Xcode asset catalogs are vector icons, not papers asset_pdf = Path("MyApp/Images.xcassets/icon.imageset/icon.pdf") assert classify_file(asset_pdf) is None + def test_classify_pdf_in_xcassets_root_skipped(): asset_pdf = Path("Pods/HXPHPicker/Assets.xcassets/photo.pdf") assert classify_file(asset_pdf) is None + def test_classify_unknown_returns_none(): assert classify_file(Path("archive.zip")) is None + def test_classify_image(): assert classify_file(Path("screenshot.png")) == FileType.IMAGE assert classify_file(Path("design.jpg")) == FileType.IMAGE assert classify_file(Path("diagram.webp")) == FileType.IMAGE + def test_count_words_sample_md(): words = count_words(FIXTURES / "sample.md") assert words > 5 + def test_detect_finds_fixtures(): result = detect(FIXTURES) assert result["total_files"] >= 2 assert "code" in result["files"] assert "document" in result["files"] + def test_detect_warns_small_corpus(): result = detect(FIXTURES) assert result["needs_graph"] is False assert result["warning"] is not None + def test_detect_skips_noise_dot_dirs(): """Noise dot dirs (.next, .nuxt, .graphify cache, …) are skipped (#873). Non-noise dot dirs (.github, .claude, …) are now allowed through.""" @@ -301,6 +323,7 @@ def test_detect_incremental_propagates_follow_symlinks(tmp_path, monkeypatch): def test_classify_video_extensions(): """Video and audio file extensions should classify as VIDEO.""" from graphify.detect import FileType + assert classify_file(Path("lecture.mp4")) == FileType.VIDEO assert classify_file(Path("podcast.mp3")) == FileType.VIDEO assert classify_file(Path("talk.mov")) == FileType.VIDEO @@ -399,7 +422,9 @@ def test_detect_skips_visual_tests_dir(tmp_path): def test_detect_skips_snapshots_dir(tmp_path): """__snapshots__/ and snapshots/ are jest/vitest artefacts — must be excluded.""" (tmp_path / "__snapshots__").mkdir() - (tmp_path / "__snapshots__" / "app.test.ts.snap").write_text("// Jest Snapshot\nexports[`test 1`] = `
`") + (tmp_path / "__snapshots__" / "app.test.ts.snap").write_text( + "// Jest Snapshot\nexports[`test 1`] = `
`" + ) (tmp_path / "app.ts").write_text("export function greet() { return 'hi'; }") result = detect(tmp_path) all_files = [f for files in result["files"].values() for f in files] @@ -422,6 +447,7 @@ def test_detect_skips_storybook_static_dir(tmp_path): # --- #873: dot dirs allowed, framework caches blocked --- + def test_detect_allows_github_dir(tmp_path): """Files inside .github/ (workflows etc.) are now indexed (#873).""" gh = tmp_path / ".github" / "workflows" @@ -430,7 +456,9 @@ def test_detect_allows_github_dir(tmp_path): (tmp_path / "main.py").write_text("def run(): pass") result = detect(tmp_path) all_files = [f for files in result["files"].values() for f in files] - assert any(".github" in f for f in all_files), "expected .github/workflows/ci.yml to be detected" + assert any(".github" in f for f in all_files), ( + "expected .github/workflows/ci.yml to be detected" + ) def test_detect_skips_next_cache(tmp_path): @@ -461,9 +489,9 @@ def test_detect_skips_graphify_own_cache(tmp_path): # --- #882: gitignore parent-exclusion rule for ! re-includes --- + def test_negation_cannot_rescue_file_under_excluded_dir(tmp_path): """A ! re-include cannot un-ignore a file whose parent dir is excluded (#882).""" - from graphify.detect import _is_ignored, _load_graphifyignore android = tmp_path / "android" / "app" / "src" android.mkdir(parents=True) victim = android / "Main.kt" @@ -478,7 +506,6 @@ def test_negation_cannot_rescue_file_under_excluded_dir(tmp_path): def test_negation_works_when_no_ancestor_excluded(tmp_path): """A ! re-include must still un-ignore a file when no ancestor is excluded (#882).""" - from graphify.detect import _is_ignored, _load_graphifyignore src = tmp_path / "src" src.mkdir() keep = src / "keep.py" @@ -492,7 +519,6 @@ def test_negation_works_when_no_ancestor_excluded(tmp_path): def test_negation_ancestor_itself_reincluded(tmp_path): """If the ancestor dir itself is re-included, its children should not be blocked (#882).""" - from graphify.detect import _is_ignored, _load_graphifyignore vendor = tmp_path / "vendor" / "lib" vendor.mkdir(parents=True) f = vendor / "utils.py" @@ -588,34 +614,44 @@ def test_anchored_multi_segment_pattern(tmp_path): def test_sensitive_flags_api_token_txt(): assert _is_sensitive(Path("api_token.txt")) + def test_sensitive_flags_oauth_token_json(): assert _is_sensitive(Path("oauth_token.json")) + def test_sensitive_flags_underscore_secret(): assert _is_sensitive(Path("app_secret.yaml")) + def test_sensitive_does_not_flag_tokenizer_py(): assert not _is_sensitive(Path("tokenizer.py")) + def test_sensitive_does_not_flag_tokenize_py(): assert not _is_sensitive(Path("tokenize.py")) + def test_sensitive_flags_passwords_py(): # passwords.py is just as likely a secret store as passwords.txt — code ext is no excuse assert _is_sensitive(Path("passwords.py")) + def test_sensitive_flags_ssh_dir(): assert _is_sensitive(Path("/home/user/.ssh/id_rsa")) + def test_sensitive_flags_secrets_dir(): assert _is_sensitive(Path("config/secrets/db.json")) + def test_sensitive_flags_token_txt(): assert _is_sensitive(Path("token.txt")) + def test_sensitive_flags_credentials_json(): assert _is_sensitive(Path("credentials.json")) + def test_sensitive_does_not_flag_root_file_named_credentials(): # A root-level file called "credentials" (no parent dir named credentials) # must NOT be flagged by Stage 1; Stage 2 name-pattern check catches it instead. @@ -628,11 +664,13 @@ def test_sensitive_does_not_flag_root_file_named_credentials(): # Verify the whole function still returns True (via name pattern, not dir check). assert _is_sensitive(p) + def test_sensitive_secret_handler_txt(): # Both patterns now use (?![a-zA-Z]) so underscore after keyword is allowed. # "secret_handler.txt": "secret" followed by "_" (not alpha) → flagged. assert _is_sensitive(Path("secret_handler.txt")) + def test_sensitive_token_config_yaml(): # "token_config.yaml": "token" followed by "_" (not alpha) → flagged. assert _is_sensitive(Path("token_config.yaml")) @@ -640,6 +678,7 @@ def test_sensitive_token_config_yaml(): # ── Issue #933: failed-chunk files must not be frozen in manifest ───────────── + def test_save_manifest_skips_semantic_hash_for_files_without_cache(tmp_path): """Files in failed chunks have no semantic cache entry; save_manifest must leave their semantic_hash empty so detect_incremental re-queues them (#933).""" @@ -653,7 +692,12 @@ def test_save_manifest_skips_semantic_hash_for_files_without_cache(tmp_path): doc2.write_text("# B\n\ncontent b") # Simulate: doc1's chunk succeeded (has a cache entry), doc2's chunk failed (no entry). - save_cached(doc1, {"nodes": [{"id": "a", "source_file": str(doc1)}], "edges": [], "hyperedges": []}, root=tmp_path, kind="semantic") + save_cached( + doc1, + {"nodes": [{"id": "a", "source_file": str(doc1)}], "edges": [], "hyperedges": []}, + root=tmp_path, + kind="semantic", + ) # doc2: no cache entry written files = {"document": [str(doc1), str(doc2)]} @@ -674,7 +718,6 @@ def test_save_manifest_skips_semantic_hash_for_files_without_cache(tmp_path): assert str(doc2) not in manifest, "failed-chunk file must be absent from manifest" - def test_save_manifest_without_filter_unchanged_for_code(tmp_path): """Code files must be stamped in the manifest regardless of semantic cache.""" import json @@ -689,8 +732,11 @@ def test_save_manifest_without_filter_unchanged_for_code(tmp_path): manifest = json.loads(Path(manifest_path).read_text()) assert str(py) in manifest assert manifest[str(py)]["ast_hash"] != "" + + # Regression tests for #945 - .gitignore fallback when no .graphifyignore exists + def test_gitignore_fallback_when_no_graphifyignore(tmp_path): """When no .graphifyignore exists, .gitignore patterns are honored (#945).""" (tmp_path / ".git").mkdir() @@ -719,12 +765,13 @@ def test_graphifyignore_takes_precedence_over_gitignore(tmp_path): result = detect(tmp_path) code = result["files"]["code"] - assert any("main.py" in f for f in code) # gitignore NOT applied + assert any("main.py" in f for f in code) # gitignore NOT applied assert not any("other.py" in f for f in code) # graphifyignore IS applied # Regression tests for #947 - .worktrees/ skipped and --exclude flag + def test_detect_skips_worktrees_dir(tmp_path): """Files inside .worktrees/ are never indexed (#947).""" wt = tmp_path / ".worktrees" / "feature-branch" @@ -770,9 +817,11 @@ def test_detect_extra_excludes_pattern(tmp_path): # Shebang interpreter parsing # --------------------------------------------------------------------------- + def test_shebang_interpreter_plain(tmp_path): """Plain shebang returns the interpreter basename.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "plain" script.write_bytes(b"#!/usr/bin/python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -781,6 +830,7 @@ def test_shebang_interpreter_plain(tmp_path): def test_shebang_interpreter_env_single_arg(tmp_path): """`#!/usr/bin/env python3` returns the interpreter, not 'env'.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_single" script.write_bytes(b"#!/usr/bin/env python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -789,6 +839,7 @@ def test_shebang_interpreter_env_single_arg(tmp_path): def test_shebang_interpreter_env_dash_s(tmp_path): """`#!/usr/bin/env -S python3 -u` (-S split-args form) recovers the interpreter.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_dashs" script.write_bytes(b"#!/usr/bin/env -S python3 -u\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -797,6 +848,7 @@ def test_shebang_interpreter_env_dash_s(tmp_path): def test_shebang_interpreter_env_with_flags(tmp_path): """`#!/usr/bin/env -i bash` skips env flags and resolves to the interpreter.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_flags" script.write_bytes(b"#!/usr/bin/env -i bash\necho hi\n") assert _shebang_interpreter(script) == "bash" @@ -805,6 +857,7 @@ def test_shebang_interpreter_env_with_flags(tmp_path): def test_shebang_interpreter_env_with_assignment(tmp_path): """`#!/usr/bin/env DEBUG=1 python3` skips var=value assignments.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_assign" script.write_bytes(b"#!/usr/bin/env DEBUG=1 python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -813,6 +866,7 @@ def test_shebang_interpreter_env_with_assignment(tmp_path): def test_shebang_interpreter_no_shebang(tmp_path): """File without shebang returns None.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "no_shebang" script.write_bytes(b"print('x')\n") assert _shebang_interpreter(script) is None @@ -821,6 +875,7 @@ def test_shebang_interpreter_no_shebang(tmp_path): def test_shebang_interpreter_quoted_path(tmp_path): """Quoted interpreter path with spaces parses correctly via shlex.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "quoted" # Note: actual `#!` on disk wouldn't permit a quoted path on most kernels, # but shlex must not crash and should produce a reasonable answer @@ -839,6 +894,7 @@ def test_shebang_file_type_classifies_via_interpreter(tmp_path): def test_shebang_interpreter_unreadable_returns_none(tmp_path): """Unreadable / nonexistent files return None, never raise.""" from graphify.detect import _shebang_interpreter + missing = tmp_path / "does_not_exist" assert _shebang_interpreter(missing) is None @@ -846,6 +902,7 @@ def test_shebang_interpreter_unreadable_returns_none(tmp_path): def test_shebang_interpreter_env_unset_with_operand(tmp_path): """`env -u VAR python3` skips both -u and its required operand.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_unset" script.write_bytes(b"#!/usr/bin/env -u PYTHONPATH python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -855,6 +912,7 @@ def test_shebang_interpreter_env_unset_with_operand(tmp_path): def test_shebang_interpreter_env_chdir_with_operand(tmp_path): """`env -C /tmp python3` skips both -C and its workdir operand.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_chdir" script.write_bytes(b"#!/usr/bin/env -C /tmp python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -864,6 +922,7 @@ def test_shebang_interpreter_env_chdir_with_operand(tmp_path): def test_shebang_interpreter_env_path_with_operand(tmp_path): """`env -P /bin python3` skips both -P and its utilpath operand.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_path" script.write_bytes(b"#!/usr/bin/env -P /bin python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -873,6 +932,7 @@ def test_shebang_interpreter_env_path_with_operand(tmp_path): def test_shebang_interpreter_env_dash_s_after_flag(tmp_path): """`env -i -S "python3 -u"` handles -S after another env flag.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_flag_dash_s" script.write_bytes(b'#!/usr/bin/env -i -S "python3 -u"\nprint("x")\n') assert _shebang_interpreter(script) == "python3" @@ -882,6 +942,7 @@ def test_shebang_interpreter_env_dash_s_after_flag(tmp_path): def test_shebang_interpreter_env_clumped_u_operand(tmp_path): """Clumped `-uPYTHONPATH` form (no space between flag and operand) is one arg.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_clumped" script.write_bytes(b"#!/usr/bin/env -uPYTHONPATH python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -891,6 +952,7 @@ def test_shebang_interpreter_env_clumped_u_operand(tmp_path): def test_shebang_interpreter_env_missing_operand_returns_none(tmp_path): """`env -u` with no operand → not a valid command, return None.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_missing_op" script.write_bytes(b"#!/usr/bin/env -u\n") assert _shebang_interpreter(script) is None @@ -899,6 +961,7 @@ def test_shebang_interpreter_env_missing_operand_returns_none(tmp_path): def test_shebang_interpreter_env_gnu_split_string_equals(tmp_path): """GNU `--split-string='python3 -u'` (with `=` operand) → python3.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_split_eq" script.write_bytes(b"#!/usr/bin/env --split-string='python3 -u'\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -908,6 +971,7 @@ def test_shebang_interpreter_env_gnu_split_string_equals(tmp_path): def test_shebang_interpreter_env_gnu_split_string_separate(tmp_path): """GNU `--split-string "python3 -u"` (separate operand) → python3.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_split_sep" script.write_bytes(b'#!/usr/bin/env --split-string "python3 -u"\nprint("x")\n') assert _shebang_interpreter(script) == "python3" @@ -917,6 +981,7 @@ def test_shebang_interpreter_env_gnu_split_string_separate(tmp_path): def test_shebang_interpreter_env_gnu_argv0_operand(tmp_path): """GNU `-a alias python3` skips both -a and its argv0 operand.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_argv0" script.write_bytes(b"#!/usr/bin/env -a alias python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -926,6 +991,7 @@ def test_shebang_interpreter_env_gnu_argv0_operand(tmp_path): def test_shebang_interpreter_env_compact_dash_s(tmp_path): """Compact `-Spython3 -u` form (no space between -S and packed string).""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_compact_dash_s" script.write_bytes(b"#!/usr/bin/env -Spython3 -u\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -935,6 +1001,7 @@ def test_shebang_interpreter_env_compact_dash_s(tmp_path): def test_shebang_interpreter_env_compact_v_then_s(tmp_path): """Compact `-vSpython3` (-v plus compact -S).""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_compact_vs" script.write_bytes(b"#!/usr/bin/env -vSpython3 -u\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -944,6 +1011,7 @@ def test_shebang_interpreter_env_compact_v_then_s(tmp_path): def test_shebang_interpreter_env_long_unset_separate_operand(tmp_path): """GNU `--unset PYTHONPATH python3` (separate operand).""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_long_unset" script.write_bytes(b"#!/usr/bin/env --unset PYTHONPATH python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -953,6 +1021,7 @@ def test_shebang_interpreter_env_long_unset_separate_operand(tmp_path): def test_shebang_interpreter_env_long_unset_equals(tmp_path): """GNU `--unset=PYTHONPATH python3` (`=` operand form).""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_long_unset_eq" script.write_bytes(b"#!/usr/bin/env --unset=PYTHONPATH python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -962,6 +1031,7 @@ def test_shebang_interpreter_env_long_unset_equals(tmp_path): def test_shebang_interpreter_env_long_chdir_separate_operand(tmp_path): """GNU `--chdir /tmp python3` (separate operand).""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_long_chdir" script.write_bytes(b"#!/usr/bin/env --chdir /tmp python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -971,6 +1041,7 @@ def test_shebang_interpreter_env_long_chdir_separate_operand(tmp_path): def test_shebang_interpreter_env_long_chdir_equals(tmp_path): """GNU `--chdir=/tmp python3` (`=` operand form).""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_long_chdir_eq" script.write_bytes(b"#!/usr/bin/env --chdir=/tmp python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -980,6 +1051,7 @@ def test_shebang_interpreter_env_long_chdir_equals(tmp_path): def test_shebang_interpreter_env_signal_flags(tmp_path): """GNU signal-handling flags skip transparently.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_signal" script.write_bytes(b"#!/usr/bin/env --default-signal=TERM --ignore-signal=PIPE python3\n") assert _shebang_interpreter(script) == "python3" @@ -989,6 +1061,7 @@ def test_shebang_interpreter_env_signal_flags(tmp_path): def test_shebang_interpreter_env_unknown_option_returns_none(tmp_path): """Unknown hyphen-prefixed env option → return None rather than guessing.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_unknown" script.write_bytes(b"#!/usr/bin/env --no-such-flag python3\n") # Must refuse to guess: if we can't classify the option, we can't trust @@ -999,10 +1072,10 @@ def test_shebang_interpreter_env_unknown_option_returns_none(tmp_path): def test_shebang_interpreter_env_dash_s_assignment_before_interpreter(tmp_path): """`-S` payload may carry NAME=value assignments before the interpreter.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_s_assignment" script.write_bytes( - b"#!/usr/bin/env -S PYTHONPATH=/opt/custom:${PYTHONPATH} python3\n" - b"print('x')\n" + b"#!/usr/bin/env -S PYTHONPATH=/opt/custom:${PYTHONPATH} python3\nprint('x')\n" ) assert _shebang_interpreter(script) == "python3" assert classify_file(script) == FileType.CODE @@ -1011,6 +1084,7 @@ def test_shebang_interpreter_env_dash_s_assignment_before_interpreter(tmp_path): def test_shebang_interpreter_env_dash_s_flag_before_interpreter(tmp_path): """`-S` payload may carry env flags (e.g. -i) before the interpreter.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_s_flag" script.write_bytes(b"#!/usr/bin/env -S -i OLDUSER=${USER} python3\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -1020,6 +1094,7 @@ def test_shebang_interpreter_env_dash_s_flag_before_interpreter(tmp_path): def test_shebang_interpreter_env_long_split_assignment_before_interpreter(tmp_path): """`--split-string=` payload may carry assignments before the interpreter.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_long_split_assignment" script.write_bytes( b"#!/usr/bin/env --split-string='PYTHONPATH=/opt/custom:${PYTHONPATH} python3 -u'\n" @@ -1032,6 +1107,7 @@ def test_shebang_interpreter_env_long_split_assignment_before_interpreter(tmp_pa def test_shebang_interpreter_env_long_split_flag_before_interpreter(tmp_path): """`--split-string=` payload may carry env flags before the interpreter.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_long_split_flag" script.write_bytes(b"#!/usr/bin/env --split-string='-i python3 -u'\nprint('x')\n") assert _shebang_interpreter(script) == "python3" @@ -1043,6 +1119,7 @@ def test_shebang_interpreter_env_nested_split_string_rejected(tmp_path): on the recursive call bounds the recursion depth at one). Without this guard, a malicious or strange shebang could spin the parser indefinitely.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_nested_split" # Outer -S splits into ["-S", "python3", "-u"]; inner -S is treated as an # unknown option in the recursed pass, so we get None (refuse to guess). @@ -1053,6 +1130,7 @@ def test_shebang_interpreter_env_nested_split_string_rejected(tmp_path): def test_shebang_interpreter_env_vs_assignment_before_interpreter(tmp_path): """`-vS` packed payload also re-parses for leading assignments.""" from graphify.detect import _shebang_interpreter + script = tmp_path / "env_vs_assignment" script.write_bytes(b"#!/usr/bin/env -vS DEBUG=1 python3 -u\nprint('x')\n") assert _shebang_interpreter(script) == "python3" diff --git a/tests/test_devin.py b/tests/test_devin.py index a3bca5d05..1f94bb58d 100644 --- a/tests/test_devin.py +++ b/tests/test_devin.py @@ -1,24 +1,28 @@ """Tests for graphify devin install / uninstall commands.""" + from pathlib import Path import sys from unittest.mock import patch -import pytest # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _devin_install_user(tmp_path): from graphify.__main__ import install + old_cwd = Path.cwd() try: import os + os.chdir(tmp_path) with patch("graphify.__main__.Path.home", return_value=tmp_path): install(platform="devin") finally: import os + os.chdir(old_cwd) @@ -38,6 +42,7 @@ def _rules_path(project_dir): # User-scope install (graphify install --platform devin / graphify devin install) # --------------------------------------------------------------------------- + def test_devin_install_user_creates_skill_file(tmp_path): """User-scope install copies skill to ~/.config/devin/skills/graphify/SKILL.md.""" _devin_install_user(tmp_path) @@ -71,9 +76,11 @@ def test_devin_install_user_does_not_write_rules(tmp_path): # Project-scope install (graphify devin install --project) # --------------------------------------------------------------------------- + def test_devin_install_project_creates_skill_file(tmp_path, monkeypatch): """Project-scope install copies skill to .devin/skills/graphify/SKILL.md.""" from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -88,6 +95,7 @@ def test_devin_install_project_creates_skill_file(tmp_path, monkeypatch): def test_devin_install_project_creates_rules_file(tmp_path, monkeypatch): """Project-scope install writes .windsurf/rules/graphify.md.""" from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -104,6 +112,7 @@ def test_devin_install_project_creates_rules_file(tmp_path, monkeypatch): def test_devin_rules_content_recommends_graphify_query(tmp_path): """The rules file installed by devin must use query-first policy.""" from graphify.__main__ import _devin_rules_install + _devin_rules_install(tmp_path) content = _rules_path(tmp_path).read_text() assert "graphify query" in content @@ -112,6 +121,7 @@ def test_devin_rules_content_recommends_graphify_query(tmp_path): def test_devin_rules_install_idempotent(tmp_path, capsys): """Installing rules twice does not change content and prints 'no change'.""" from graphify.__main__ import _devin_rules_install + _devin_rules_install(tmp_path) content_first = _rules_path(tmp_path).read_text() _devin_rules_install(tmp_path) @@ -123,6 +133,7 @@ def test_devin_rules_install_idempotent(tmp_path, capsys): def test_devin_install_project_hints_git_add(tmp_path, monkeypatch, capsys): """Project-scope install prints a git add hint covering .devin/ and .windsurf/.""" from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -138,6 +149,7 @@ def test_devin_install_project_hints_git_add(tmp_path, monkeypatch, capsys): # Uninstall — user scope # --------------------------------------------------------------------------- + def test_devin_uninstall_user_removes_skill_file(tmp_path): """User-scope uninstall removes the skill file.""" _devin_install_user(tmp_path) @@ -145,6 +157,7 @@ def test_devin_uninstall_user_removes_skill_file(tmp_path): assert skill.exists() from graphify.__main__ import _remove_skill_file + with patch("graphify.__main__.Path.home", return_value=tmp_path): _remove_skill_file("devin") assert not skill.exists() @@ -154,6 +167,7 @@ def test_devin_uninstall_user_noop_when_not_installed(tmp_path, capsys): """User-scope uninstall prints an appropriate message when nothing is installed.""" from graphify.__main__ import main import os + old_cwd = Path.cwd() try: os.chdir(tmp_path) @@ -170,9 +184,11 @@ def test_devin_uninstall_user_noop_when_not_installed(tmp_path, capsys): # Uninstall — project scope # --------------------------------------------------------------------------- + def test_devin_uninstall_project_removes_skill_file(tmp_path, monkeypatch): """Project-scope uninstall removes .devin/skills/graphify/SKILL.md.""" from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -188,6 +204,7 @@ def test_devin_uninstall_project_removes_skill_file(tmp_path, monkeypatch): def test_devin_uninstall_project_removes_rules_file(tmp_path, monkeypatch): """Project-scope uninstall removes .windsurf/rules/graphify.md.""" from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -203,6 +220,7 @@ def test_devin_uninstall_project_removes_rules_file(tmp_path, monkeypatch): def test_devin_uninstall_project_does_not_touch_user_scope(tmp_path, monkeypatch): """Project-scope uninstall must not remove the user-scope skill file.""" from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -222,6 +240,7 @@ def test_devin_uninstall_project_does_not_touch_user_scope(tmp_path, monkeypatch def test_devin_rules_uninstall_noop_when_not_installed(tmp_path): """_devin_rules_uninstall does nothing if the rules file was never written.""" from graphify.__main__ import _devin_rules_uninstall + _devin_rules_uninstall(tmp_path) # should not raise @@ -229,9 +248,11 @@ def test_devin_rules_uninstall_noop_when_not_installed(tmp_path): # Skill file content # --------------------------------------------------------------------------- + def test_devin_skill_file_exists_in_package(): """skill-devin.md must be present in the installed package.""" import graphify + skill = Path(graphify.__file__).parent / "skill-devin.md" assert skill.exists(), "skill-devin.md missing from package" @@ -244,6 +265,7 @@ def test_devin_skill_file_uses_python_c_syntax(): ``python -c "..."`` so they work in pipx / venv environments. """ import graphify + skill = (Path(graphify.__file__).parent / "skill-devin.md").read_text() assert '.graphify_python) -c "' in skill, ( "skill-devin.md must use the interpreter-detection pattern " @@ -255,6 +277,7 @@ def test_devin_skill_file_uses_python_c_syntax(): def test_devin_skill_file_frontmatter_has_triggers(): """Devin skill frontmatter must list triggers for model-invocable activation.""" import graphify + skill = (Path(graphify.__file__).parent / "skill-devin.md").read_text() assert "triggers:" in skill assert "model" in skill @@ -264,9 +287,11 @@ def test_devin_skill_file_frontmatter_has_triggers(): # Platform config sanity # --------------------------------------------------------------------------- + def test_devin_in_platform_config(): """devin must be registered in _PLATFORM_CONFIG.""" from graphify.__main__ import _PLATFORM_CONFIG + assert "devin" in _PLATFORM_CONFIG assert _PLATFORM_CONFIG["devin"]["skill_file"] == "skill-devin.md" assert _PLATFORM_CONFIG["devin"]["claude_md"] is False @@ -275,6 +300,7 @@ def test_devin_in_platform_config(): def test_devin_platform_skill_destination_user_scope(tmp_path): """User-scope destination must be ~/.config/devin/skills/graphify/SKILL.md.""" from graphify.__main__ import _platform_skill_destination + with patch("graphify.__main__.Path.home", return_value=tmp_path): dst = _platform_skill_destination("devin", project=False) assert dst == tmp_path / ".config" / "devin" / "skills" / "graphify" / "SKILL.md" @@ -283,6 +309,7 @@ def test_devin_platform_skill_destination_user_scope(tmp_path): def test_devin_in_main_help_text(capsys, monkeypatch): """`graphify --help` must list devin in the platform list and in the per-platform section.""" from graphify.__main__ import main + monkeypatch.setattr(sys, "argv", ["graphify", "--help"]) main() captured = capsys.readouterr().out @@ -305,5 +332,6 @@ def test_devin_in_main_help_text(capsys, monkeypatch): def test_devin_platform_skill_destination_project_scope(tmp_path): """Project-scope destination must be /.devin/skills/graphify/SKILL.md.""" from graphify.__main__ import _platform_skill_destination + dst = _platform_skill_destination("devin", project=True, project_dir=tmp_path) assert dst == tmp_path / ".devin" / "skills" / "graphify" / "SKILL.md" diff --git a/tests/test_dotnet.py b/tests/test_dotnet.py index 17a146073..4bc6faf30 100644 --- a/tests/test_dotnet.py +++ b/tests/test_dotnet.py @@ -1,7 +1,7 @@ """Tests for .NET project file extraction (.sln, .csproj, .razor).""" + from pathlib import Path import tempfile -import pytest from graphify.extract import extract_sln, extract_csproj, extract_razor FIXTURES = Path(__file__).parent / "fixtures" @@ -17,6 +17,7 @@ def _relations(r): # ── .sln ───────────────────────────────────────────────────────────────────── + def test_sln_extracts_projects(): r = extract_sln(FIXTURES / "sample.sln") assert "error" not in r @@ -39,13 +40,14 @@ def test_sln_project_dependency(): # ── .csproj ────────────────────────────────────────────────────────────────── + def test_csproj_packages(): r = extract_csproj(FIXTURES / "sample.csproj") assert "error" not in r labels = _labels(r) - assert any("MediatR" in l for l in labels) - assert any("FluentValidation" in l for l in labels) - assert any("Swashbuckle" in l for l in labels) + assert any("MediatR" in label for label in labels) + assert any("FluentValidation" in label for label in labels) + assert any("Swashbuckle" in label for label in labels) def test_csproj_project_references(): @@ -74,6 +76,7 @@ def test_csproj_invalid_xml(): # ── .razor ─────────────────────────────────────────────────────────────────── + def test_razor_using_and_inject(): r = extract_razor(FIXTURES / "sample.razor") assert "error" not in r @@ -91,7 +94,7 @@ def test_razor_components(): def test_razor_page_route(): r = extract_razor(FIXTURES / "sample.razor") - assert any("/counter" in l for l in _labels(r)) + assert any("/counter" in label for label in _labels(r)) def test_razor_inherits(): @@ -113,13 +116,16 @@ def test_razor_missing_file(): # ── dispatch & detect integration ──────────────────────────────────────────── + def test_dispatch_table(): from graphify.extract import _get_extractor + for ext in (".sln", ".csproj", ".fsproj", ".vbproj", ".razor", ".cshtml"): assert _get_extractor(Path(f"foo{ext}")) is not None, f"{ext} not in dispatch" def test_code_extensions(): from graphify.detect import CODE_EXTENSIONS + for ext in (".sln", ".csproj", ".fsproj", ".vbproj", ".razor", ".cshtml"): assert ext in CODE_EXTENSIONS, f"{ext} missing" diff --git a/tests/test_explain_cli.py b/tests/test_explain_cli.py index 1d00955f0..f96b896cd 100644 --- a/tests/test_explain_cli.py +++ b/tests/test_explain_cli.py @@ -1,4 +1,5 @@ """Regression tests for `graphify explain` arrow direction (#853).""" + from __future__ import annotations import json import graphify.__main__ as mainmod @@ -6,24 +7,54 @@ def _write_graph(tmp_path): graph_data = { - "directed": False, "multigraph": False, "graph": {}, + "directed": False, + "multigraph": False, + "graph": {}, "nodes": [ - {"id": "validate", "label": "validateSanitySession()", - "source_file": "server/sanity-validate-session.ts", "community": 0}, - {"id": "create_patch", "label": "createPatchHandler()", - "source_file": "server/create-patch-handler.ts", "community": 0}, - {"id": "create_edit", "label": "createEditHandler()", - "source_file": "server/create-edit-handler.ts", "community": 0}, - {"id": "stable_stringify", "label": "stableStringify()", - "source_file": "shared/stringify.ts", "community": 0}, + { + "id": "validate", + "label": "validateSanitySession()", + "source_file": "server/sanity-validate-session.ts", + "community": 0, + }, + { + "id": "create_patch", + "label": "createPatchHandler()", + "source_file": "server/create-patch-handler.ts", + "community": 0, + }, + { + "id": "create_edit", + "label": "createEditHandler()", + "source_file": "server/create-edit-handler.ts", + "community": 0, + }, + { + "id": "stable_stringify", + "label": "stableStringify()", + "source_file": "shared/stringify.ts", + "community": 0, + }, ], "links": [ - {"source": "create_patch", "target": "validate", - "relation": "calls", "confidence": "EXTRACTED"}, - {"source": "create_edit", "target": "validate", - "relation": "calls", "confidence": "EXTRACTED"}, - {"source": "validate", "target": "stable_stringify", - "relation": "calls", "confidence": "EXTRACTED"}, + { + "source": "create_patch", + "target": "validate", + "relation": "calls", + "confidence": "EXTRACTED", + }, + { + "source": "create_edit", + "target": "validate", + "relation": "calls", + "confidence": "EXTRACTED", + }, + { + "source": "validate", + "target": "stable_stringify", + "relation": "calls", + "confidence": "EXTRACTED", + }, ], } p = tmp_path / "graph.json" @@ -33,8 +64,9 @@ def _write_graph(tmp_path): def _run(monkeypatch, graph_path, label, capsys): monkeypatch.setattr(mainmod, "_check_skill_version", lambda _: None) - monkeypatch.setattr(mainmod.sys, "argv", - ["graphify", "explain", label, "--graph", str(graph_path)]) + monkeypatch.setattr( + mainmod.sys, "argv", ["graphify", "explain", label, "--graph", str(graph_path)] + ) mainmod.main() return capsys.readouterr().out diff --git a/tests/test_export.py b/tests/test_export.py index 65964d24e..78a1a36c3 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -7,9 +7,11 @@ FIXTURES = Path(__file__).parent / "fixtures" + def make_graph(): return build_from_json(json.loads((FIXTURES / "extraction.json").read_text())) + def test_to_json_creates_file(): G = make_graph() communities = cluster(G) @@ -18,6 +20,7 @@ def test_to_json_creates_file(): to_json(G, communities, str(out)) assert out.exists() + def test_to_json_valid_json(): G = make_graph() communities = cluster(G) @@ -28,6 +31,7 @@ def test_to_json_valid_json(): assert "nodes" in data assert "links" in data + def test_to_json_nodes_have_community(): G = make_graph() communities = cluster(G) @@ -38,6 +42,7 @@ def test_to_json_nodes_have_community(): for node in data["nodes"]: assert "community" in node + def test_to_cypher_creates_file(): G = make_graph() with tempfile.TemporaryDirectory() as tmp: @@ -45,6 +50,7 @@ def test_to_cypher_creates_file(): to_cypher(G, str(out)) assert out.exists() + def test_to_cypher_contains_merge_statements(): G = make_graph() with tempfile.TemporaryDirectory() as tmp: @@ -53,6 +59,7 @@ def test_to_cypher_contains_merge_statements(): content = out.read_text() assert "MERGE" in content + def test_to_graphml_creates_file(): G = make_graph() communities = cluster(G) @@ -61,6 +68,7 @@ def test_to_graphml_creates_file(): to_graphml(G, communities, str(out)) assert out.exists() + def test_to_graphml_valid_xml(): G = make_graph() communities = cluster(G) @@ -71,6 +79,7 @@ def test_to_graphml_valid_xml(): assert " console.log('hi');\n") try: @@ -432,18 +440,13 @@ def test_cross_file_call_promoted_to_extracted_with_import_evidence(tmp_path): an `imports` or `imports_from` edge linking it to the callee.""" caller = tmp_path / "caller.js" callee = tmp_path / "lib.js" - caller.write_text( - "const { doWork } = require('./lib');\n" - "function run() { doWork(); }\n" - ) - callee.write_text( - "function doWork() { return 1; }\n" - "module.exports = { doWork };\n" - ) + caller.write_text("const { doWork } = require('./lib');\nfunction run() { doWork(); }\n") + callee.write_text("function doWork() { return 1; }\nmodule.exports = { doWork };\n") result = extract([caller, callee], cache_root=tmp_path) nodes = {n["id"]: n for n in result["nodes"]} call_edges = [ - e for e in result["edges"] + e + for e in result["edges"] if e["relation"] == "calls" and nodes[e["source"]]["label"] == "run()" and nodes[e["target"]]["label"] == "doWork()" @@ -460,14 +463,12 @@ def test_cross_file_call_remains_inferred_without_import_evidence(tmp_path): callee = tmp_path / "lib.js" # Caller does NOT require lib — same-name function happens to exist elsewhere caller.write_text("function run() { doUnique(); }\n") - callee.write_text( - "function doUnique() { return 1; }\n" - "module.exports = { doUnique };\n" - ) + callee.write_text("function doUnique() { return 1; }\nmodule.exports = { doUnique };\n") result = extract([caller, callee], cache_root=tmp_path) nodes = {n["id"]: n for n in result["nodes"]} call_edges = [ - e for e in result["edges"] + e + for e in result["edges"] if e["relation"] == "calls" and nodes[e["source"]]["label"] == "run()" and nodes[e["target"]]["label"] == "doUnique()" @@ -481,14 +482,16 @@ def test_cross_file_call_remains_inferred_without_import_evidence(tmp_path): # `language_typescript` grammar. Parsing JSX with the wrong grammar produces # silent ERROR nodes and drops every function/call inside JSX trees. + def test_extract_tsx_finds_helpers_and_component(): """Functions defined alongside a JSX-returning component must be captured.""" from graphify.extract import extract_js + result = extract_js(FIXTURES / "sample.tsx") labels = [n["label"] for n in result["nodes"]] - assert any("fmtDate" in l for l in labels), f"fmtDate missing from {labels}" - assert any("fmtCount" in l for l in labels), f"fmtCount missing from {labels}" - assert any("App" in l for l in labels), f"App missing from {labels}" + assert any("fmtDate" in label for label in labels), f"fmtDate missing from {labels}" + assert any("fmtCount" in label for label in labels), f"fmtCount missing from {labels}" + assert any("App" in label for label in labels), f"App missing from {labels}" def test_extract_tsx_jsx_expression_calls_resolve(): @@ -498,6 +501,7 @@ def test_extract_tsx_jsx_expression_calls_resolve(): JSX is parsed as ERROR nodes and these call_expressions disappear. """ from graphify.extract import extract_js + result = extract_js(FIXTURES / "sample.tsx") nodes_by_id = {n["id"]: n for n in result["nodes"]} call_targets = { @@ -516,6 +520,7 @@ def test_extract_tsx_jsx_expression_calls_resolve(): def test_extract_tsx_uses_tsx_grammar(): """Wiring check: the .tsx config must use tree-sitter's `language_tsx`.""" from graphify.extract import _TSX_CONFIG, _TS_CONFIG + assert _TSX_CONFIG.ts_language_fn == "language_tsx" assert _TS_CONFIG.ts_language_fn == "language_typescript" @@ -526,6 +531,7 @@ def test_extract_tsx_uses_tsx_grammar(): # detect this, warn, and fall back to sequential extraction rather than # propagating a 290-line traceback. + def test_extract_falls_back_to_sequential_when_parallel_returns_false(tmp_path, monkeypatch): """extract() must run sequential when _extract_parallel signals failure (returns False).""" from graphify import extract as extract_mod @@ -561,15 +567,19 @@ def test_extract_parallel_returns_false_on_broken_pool(tmp_path, monkeypatch, ca from graphify import extract as extract_mod class FakePool: - def __init__(self, *a, **kw): pass - def __enter__(self): return self - def __exit__(self, *a): return False + def __init__(self, *a, **kw): + pass + + def __enter__(self): + return self + + def __exit__(self, *a): + return False + def submit(self, *a, **kw): raise BrokenProcessPool("simulated spawn failure") - monkeypatch.setattr( - concurrent.futures, "ProcessPoolExecutor", lambda *a, **kw: FakePool() - ) + monkeypatch.setattr(concurrent.futures, "ProcessPoolExecutor", lambda *a, **kw: FakePool()) uncached = [(0, FIXTURES / "sample.py")] per_file: list = [None] @@ -580,10 +590,46 @@ def submit(self, *a, **kw): assert "__main__" in out, "warning must hint at the Windows __main__ guard idiom" +def test_extract_parallel_worker_warning_handles_sparse_file_indexes(tmp_path, monkeypatch, capsys): + """Worker-failure warnings must not index work_items by original file index.""" + import concurrent.futures + from graphify import extract as extract_mod + + class FakePool: + def __init__(self, *a, **kw): + pass + + def __enter__(self): + return self + + def __exit__(self, *a): + return False + + def submit(self, fn, item): + future: concurrent.futures.Future = concurrent.futures.Future() + future.set_exception(RuntimeError("simulated worker failure")) + return future + + monkeypatch.setattr(concurrent.futures, "ProcessPoolExecutor", lambda *a, **kw: FakePool()) + + source = tmp_path / "late.py" + source.write_text("x = 1\n", encoding="utf-8") + uncached = [(3, source)] + per_file: list = [None, None, None, None] + + ok = extract_mod._extract_parallel(uncached, per_file, tmp_path, 2, 4) + + assert ok is True + err = capsys.readouterr().err + assert "late.py" in err + assert "simulated worker failure" in err + + # --------------------------------------------------------------------------- # Bash extractor tests (#866) # --------------------------------------------------------------------------- + def test_dispatch_includes_sh_and_json(): assert ".sh" in _DISPATCH assert ".bash" in _DISPATCH @@ -626,7 +672,7 @@ def test_extract_bash_emits_source_imports_from(tmp_path): helpers = tmp_path / "helpers.sh" helpers.write_text("# helper\n") script = tmp_path / "deploy.sh" - script.write_text(f"#!/bin/bash\nsource ./helpers.sh\nfoo() {{ echo hi; }}\n") + script.write_text("#!/bin/bash\nsource ./helpers.sh\nfoo() { echo hi; }\n") result = extract_bash(script) import_edges = [e for e in result["edges"] if e["relation"] == "imports_from"] assert len(import_edges) >= 1 @@ -661,6 +707,7 @@ def test_extract_bash_missing_grammar_returns_error(): """extract_bash returns error dict when tree-sitter-bash not installed (mocked).""" import unittest.mock as mock import builtins + real_import = builtins.__import__ def patched(name, *args, **kwargs): @@ -677,11 +724,7 @@ def patched(name, *args, **kwargs): def test_extract_bash_rejects_command_substitution_as_call(tmp_path): """`$(build)` must not be recorded as a call edge to build().""" script = tmp_path / "command_substitution.sh" - script.write_text( - "#!/usr/bin/env bash\n" - "build() { echo build; }\n" - "$(build)\n" - ) + script.write_text("#!/usr/bin/env bash\nbuild() { echo build; }\n$(build)\n") result = extract_bash(script) labels = {n["id"]: n["label"] for n in result["nodes"]} call_pairs = [ @@ -695,11 +738,7 @@ def test_extract_bash_rejects_command_substitution_as_call(tmp_path): def test_extract_bash_process_substitution_not_recorded(tmp_path): """`<(helper)` (process substitution) must not be recorded as a call edge.""" script = tmp_path / "process_substitution.sh" - script.write_text( - "#!/usr/bin/env bash\n" - "helper() { echo h; }\n" - "diff <(helper) <(helper)\n" - ) + script.write_text("#!/usr/bin/env bash\nhelper() { echo h; }\ndiff <(helper) <(helper)\n") result = extract_bash(script) labels = {n["id"]: n["label"] for n in result["nodes"]} call_pairs = [ @@ -713,11 +752,7 @@ def test_extract_bash_process_substitution_not_recorded(tmp_path): def test_extract_bash_shadowing_function_is_recorded(tmp_path): """User-defined function shadowing an external command (install/find/etc.) must still produce a call edge.""" script = tmp_path / "shadowing.sh" - script.write_text( - "#!/usr/bin/env bash\n" - "install() { echo install; }\n" - "deploy() { install; }\n" - ) + script.write_text("#!/usr/bin/env bash\ninstall() { echo install; }\ndeploy() { install; }\n") result = extract_bash(script) labels = {n["id"]: n["label"] for n in result["nodes"]} call_pairs = [ @@ -739,10 +774,15 @@ def test_extract_bash_creates_entrypoint_node(tmp_path): assert "bash_entrypoint" in kinds, f"No bash_entrypoint node; kinds={kinds}" assert "file" in kinds, f"No file node; kinds={kinds}" file_node = next(n for n in result["nodes"] if n.get("metadata", {}).get("kind") == "file") - entry_node = next(n for n in result["nodes"] if n.get("metadata", {}).get("kind") == "bash_entrypoint") + entry_node = next( + n for n in result["nodes"] if n.get("metadata", {}).get("kind") == "bash_entrypoint" + ) contains_edges = [ - e for e in result["edges"] - if e["relation"] == "contains" and e["source"] == file_node["id"] and e["target"] == entry_node["id"] + e + for e in result["edges"] + if e["relation"] == "contains" + and e["source"] == file_node["id"] + and e["target"] == entry_node["id"] ] assert contains_edges, "Missing contains edge from file → bash_entrypoint" @@ -750,23 +790,19 @@ def test_extract_bash_creates_entrypoint_node(tmp_path): def test_extract_bash_top_level_call_attributes_to_entrypoint(tmp_path): """Top-level function call attaches to the entrypoint node, not orphaned.""" script = tmp_path / "top_level_call.sh" - script.write_text( - "#!/usr/bin/env bash\n" - "build() { echo build; }\n" - "build\n" - ) + script.write_text("#!/usr/bin/env bash\nbuild() { echo build; }\nbuild\n") result = extract_bash(script) entry_node = next( (n for n in result["nodes"] if n.get("metadata", {}).get("kind") == "bash_entrypoint"), None, ) assert entry_node is not None, "No entrypoint node created" - call_pairs = [ - (e["source"], e["target"]) - for e in result["edges"] - if e["relation"] == "calls" - ] - target_ids = {tgt for _, tgt in call_pairs if any(n["id"] == tgt and n["label"] == "build()" for n in result["nodes"])} + call_pairs = [(e["source"], e["target"]) for e in result["edges"] if e["relation"] == "calls"] + target_ids = { + tgt + for _, tgt in call_pairs + if any(n["id"] == tgt and n["label"] == "build()" for n in result["nodes"]) + } source_ids_to_build = {src for src, tgt in call_pairs if tgt in target_ids} assert entry_node["id"] in source_ids_to_build, ( f"Top-level call to build not attributed to entrypoint; calls={call_pairs}" @@ -788,8 +824,12 @@ def test_extract_bash_entrypoint_no_collision_with_function_named_script(tmp_pat script = tmp_path / "deploy.sh" script.write_text("#!/usr/bin/env bash\nfunction script() { echo hi; }\n") result = extract_bash(script) - entry_nodes = [n for n in result["nodes"] if n.get("metadata", {}).get("kind") == "bash_entrypoint"] - func_nodes = [n for n in result["nodes"] if n.get("metadata", {}).get("kind") == "bash_function"] + entry_nodes = [ + n for n in result["nodes"] if n.get("metadata", {}).get("kind") == "bash_entrypoint" + ] + func_nodes = [ + n for n in result["nodes"] if n.get("metadata", {}).get("kind") == "bash_function" + ] assert entry_nodes, "Must have a bash_entrypoint node" assert func_nodes, "Must have a bash_function node for 'script'" entry_id = entry_nodes[0]["id"] @@ -814,8 +854,12 @@ def test_extract_bash_nested_function_calls_recorded(tmp_path): ) result = extract_bash(script) node_id_by_label = {n["label"].rstrip("()"): n["id"] for n in result["nodes"]} - assert "inner" in node_id_by_label, f"inner function must be discovered; labels={list(node_id_by_label)}" - assert "do_work" in node_id_by_label, f"do_work function must be discovered; labels={list(node_id_by_label)}" + assert "inner" in node_id_by_label, ( + f"inner function must be discovered; labels={list(node_id_by_label)}" + ) + assert "do_work" in node_id_by_label, ( + f"do_work function must be discovered; labels={list(node_id_by_label)}" + ) calls = {(e["source"], e["target"]) for e in result["edges"] if e.get("relation") == "calls"} inner_id = node_id_by_label["inner"] do_work_id = node_id_by_label["do_work"] @@ -832,9 +876,7 @@ def test_extract_bash_source_user_defined_emits_calls_not_imports_from(tmp_path) helpers.write_text("#!/bin/bash\n") script = tmp_path / "run.sh" script.write_text( - "#!/usr/bin/env bash\n" - "function source() { echo 'custom source'; }\n" - "source ./helpers.sh\n" + "#!/usr/bin/env bash\nfunction source() { echo 'custom source'; }\nsource ./helpers.sh\n" ) result = extract_bash(script) import_edges = [e for e in result["edges"] if e.get("relation") == "imports_from"] @@ -847,6 +889,7 @@ def test_extract_bash_source_user_defined_emits_calls_not_imports_from(tmp_path) # JSON extractor tests (#866) # --------------------------------------------------------------------------- + def test_extract_json_top_level_keys(): result = extract_json(FIXTURES / "sample.json") assert "error" not in result @@ -907,12 +950,14 @@ def test_extract_json_no_self_loops(): def test_extract_bash_via_dispatch(): from graphify.extract import _get_extractor + assert _get_extractor(Path("foo.sh")) is extract_bash assert _get_extractor(Path("foo.bash")) is extract_bash def test_extract_json_via_dispatch(): from graphify.extract import _get_extractor + assert _get_extractor(Path("foo.json")) is extract_json @@ -938,6 +983,7 @@ def test_extract_bash_node_metadata_is_sanitized(): def test_barrel_reexport_emits_re_exports_edges(): """export { X } from './mod' must emit re_exports edges for each named specifier.""" from graphify.extract import extract_js + result = extract_js(FIXTURES / "barrel_reexport.ts") reexports = [e for e in result["edges"] if e["relation"] == "re_exports"] targets = [e["target"] for e in reexports] @@ -952,6 +998,7 @@ def test_barrel_reexport_emits_re_exports_edges(): def test_barrel_reexport_emits_imports_from(): """Barrel file must emit file-level imports_from edges to source modules.""" from graphify.extract import extract_js + result = extract_js(FIXTURES / "barrel_reexport.ts") imports_from = [e for e in result["edges"] if e["relation"] == "imports_from"] targets = [e["target"] for e in imports_from] @@ -963,6 +1010,7 @@ def test_barrel_reexport_emits_imports_from(): def test_barrel_reexport_context_tagged(): """re_exports edges should have context='re-export'.""" from graphify.extract import extract_js + result = extract_js(FIXTURES / "barrel_reexport.ts") reexports = [e for e in result["edges"] if e["relation"] == "re_exports"] for e in reexports: @@ -972,6 +1020,7 @@ def test_barrel_reexport_context_tagged(): def test_barrel_local_exports_still_extracted(): """export function/const in a barrel file must still create nodes.""" from graphify.extract import extract_js + result = extract_js(FIXTURES / "barrel_reexport.ts") labels = [n["label"] for n in result["nodes"]] assert "localHelper()" in labels or "localHelper" in labels @@ -982,6 +1031,7 @@ def test_barrel_local_exports_still_extracted(): def test_barrel_reexport_confidence_extracted(): """All re_exports edges should have confidence=EXTRACTED.""" from graphify.extract import extract_js + result = extract_js(FIXTURES / "barrel_reexport.ts") reexports = [e for e in result["edges"] if e["relation"] == "re_exports"] for e in reexports: @@ -1015,6 +1065,7 @@ def test_pure_export_no_from_not_treated_as_reexport(): """export { localVar } without 'from' should NOT create re_exports edges.""" from graphify.extract import extract_js import tempfile + code = b"const x = 1;\nexport { x };\n" with tempfile.NamedTemporaryFile(suffix=".ts", delete=False) as f: f.write(code) @@ -1035,8 +1086,8 @@ def test_dart_child_node_ids_are_stem_based(tmp_path): result = extract_dart(src_file) stem = _file_stem(src_file) # -> "mydir.sample" - expected_class_nid = _make_id(stem, "MyClass") # -> "mydir_sample_myclass" - expected_func_nid = _make_id(stem, "myFunc") # -> "mydir_sample_myfunc" + expected_class_nid = _make_id(stem, "MyClass") # -> "mydir_sample_myclass" + expected_func_nid = _make_id(stem, "myFunc") # -> "mydir_sample_myfunc" node_ids = {n["id"] for n in result["nodes"]} @@ -1054,9 +1105,9 @@ def test_dart_child_node_ids_are_stem_based(tmp_path): for node in result["nodes"]: if node["id"] == file_nid: continue - assert "_" + stem.replace(".", "_") in node["id"] or node["id"].startswith(stem.replace(".", "_")), ( + assert "_" + stem.replace(".", "_") in node["id"] or node["id"].startswith( + stem.replace(".", "_") + ), ( f"Child node ID '{node['id']}' does not start with the expected stem prefix '{stem}'. " "This suggests an absolute path is still leaking into the ID." ) - - diff --git a/tests/test_extract_cli.py b/tests/test_extract_cli.py index 6998bcce9..a3ca1e4ed 100644 --- a/tests/test_extract_cli.py +++ b/tests/test_extract_cli.py @@ -1,4 +1,5 @@ """Tests for `graphify extract` CLI dispatch path in graphify.__main__.""" + from __future__ import annotations import pytest @@ -17,9 +18,7 @@ def _make_corpus(tmp_path): return tmp_path -def test_extract_exits_nonzero_when_all_semantic_chunks_fail( - monkeypatch, tmp_path, capsys -): +def test_extract_exits_nonzero_when_all_semantic_chunks_fail(monkeypatch, tmp_path, capsys): """When every semantic chunk errors (e.g. backend SDK not installed), the CLI must exit non-zero instead of silently writing an AST-only graph. @@ -48,23 +47,19 @@ def _all_chunks_failed(paths, **kwargs): "output_tokens": 0, } - monkeypatch.setattr( - "graphify.llm.extract_corpus_parallel", _all_chunks_failed - ) + monkeypatch.setattr("graphify.llm.extract_corpus_parallel", _all_chunks_failed) monkeypatch.setattr(mainmod, "_check_skill_version", lambda _: None) monkeypatch.setattr( mainmod.sys, "argv", - ["graphify", "extract", str(corpus), "--backend", "claude", - "--out", str(out_dir)], + ["graphify", "extract", str(corpus), "--backend", "claude", "--out", str(out_dir)], ) with pytest.raises(SystemExit) as exc_info: mainmod.main() assert exc_info.value.code == 1, ( - f"expected exit code 1 when all semantic chunks fail, " - f"got {exc_info.value.code}" + f"expected exit code 1 when all semantic chunks fail, got {exc_info.value.code}" ) stderr = capsys.readouterr().err @@ -78,9 +73,7 @@ def _all_chunks_failed(paths, **kwargs): ) -def test_extract_succeeds_when_at_least_one_chunk_completes( - monkeypatch, tmp_path -): +def test_extract_succeeds_when_at_least_one_chunk_completes(monkeypatch, tmp_path): """Sanity counter-test: a successful chunk run keeps exit 0. Confirms the new guard only fires on the all-failed path, not on every extract.""" corpus = _make_corpus(tmp_path) @@ -99,15 +92,12 @@ def _one_chunk_succeeded(paths, **kwargs): "output_tokens": 50, } - monkeypatch.setattr( - "graphify.llm.extract_corpus_parallel", _one_chunk_succeeded - ) + monkeypatch.setattr("graphify.llm.extract_corpus_parallel", _one_chunk_succeeded) monkeypatch.setattr(mainmod, "_check_skill_version", lambda _: None) monkeypatch.setattr( mainmod.sys, "argv", - ["graphify", "extract", str(corpus), "--backend", "claude", - "--out", str(out_dir)], + ["graphify", "extract", str(corpus), "--backend", "claude", "--out", str(out_dir)], ) # extract may still raise SystemExit at the end (clean exit code 0) diff --git a/tests/test_global_graph.py b/tests/test_global_graph.py index f40d9c6d5..c09b457f1 100644 --- a/tests/test_global_graph.py +++ b/tests/test_global_graph.py @@ -1,6 +1,7 @@ """Tests for the global graph infrastructure (graphify/global_graph.py), prefix/prune helpers in graphify/build.py, and the cross-repo guard in graphify/dedup.py.""" + from __future__ import annotations import json @@ -11,13 +12,14 @@ # ── helpers ────────────────────────────────────────────────────────────────── + def _make_graph(nodes, edges=None): """Build a simple nx.Graph from node dicts.""" G = nx.Graph() for n in nodes: nid = n["id"] G.add_node(nid, **{k: v for k, v in n.items() if k != "id"}) - for e in (edges or []): + for e in edges or []: G.add_edge( e["source"], e["target"], @@ -28,6 +30,7 @@ def _make_graph(nodes, edges=None): def _graph_to_json(G, path): from networkx.readwrite import json_graph as jg + try: data = jg.node_link_data(G, edges="links") except TypeError: @@ -37,8 +40,10 @@ def _graph_to_json(G, path): # ── build.py helpers ────────────────────────────────────────────────────────── + def test_prefix_graph_preserves_label(): from graphify.build import prefix_graph_for_global + G = _make_graph([{"id": "userservice", "label": "UserService", "source_file": "src/user.py"}]) H = prefix_graph_for_global(G, "repoA") assert "repoA::userservice" in H.nodes @@ -48,6 +53,7 @@ def test_prefix_graph_preserves_label(): def test_prefix_graph_sets_repo_and_local_id(): from graphify.build import prefix_graph_for_global + G = _make_graph([{"id": "userservice", "label": "UserService"}]) H = prefix_graph_for_global(G, "repoA") data = H.nodes["repoA::userservice"] @@ -57,6 +63,7 @@ def test_prefix_graph_sets_repo_and_local_id(): def test_prefix_graph_rewrites_edges(): from graphify.build import prefix_graph_for_global + G = _make_graph( [{"id": "a", "label": "A"}, {"id": "b", "label": "B"}], [{"source": "a", "target": "b"}], @@ -68,6 +75,7 @@ def test_prefix_graph_rewrites_edges(): def test_prune_repo_removes_correct_nodes(): from graphify.build import prune_repo_from_graph + G = nx.Graph() G.add_node("repoA::userservice", repo="repoA", label="UserService") G.add_node("repoB::userservice", repo="repoB", label="UserService") @@ -81,6 +89,7 @@ def test_prune_repo_removes_correct_nodes(): def test_prune_repo_returns_zero_if_not_present(): from graphify.build import prune_repo_from_graph + G = nx.Graph() G.add_node("repoA::x", repo="repoA") removed = prune_repo_from_graph(G, "repoB") @@ -90,16 +99,20 @@ def test_prune_repo_returns_zero_if_not_present(): # ── global_graph.py ─────────────────────────────────────────────────────────── + def test_global_add_creates_global_graph(tmp_path): src_graph = tmp_path / "graph.json" G = _make_graph([{"id": "userservice", "label": "UserService", "source_file": "src/user.py"}]) _graph_to_json(G, src_graph) global_dir = tmp_path / ".graphify" - with patch("graphify.global_graph._GLOBAL_DIR", global_dir), \ - patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), \ - patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"): + with ( + patch("graphify.global_graph._GLOBAL_DIR", global_dir), + patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), + patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"), + ): from graphify.global_graph import global_add + result = global_add(src_graph, "repoA") assert result["skipped"] is False @@ -116,10 +129,13 @@ def test_global_add_skip_on_unchanged_hash(tmp_path): _graph_to_json(G, src_graph) global_dir = tmp_path / ".graphify" - with patch("graphify.global_graph._GLOBAL_DIR", global_dir), \ - patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), \ - patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"): + with ( + patch("graphify.global_graph._GLOBAL_DIR", global_dir), + patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), + patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"), + ): from graphify.global_graph import global_add + global_add(src_graph, "repoA") result2 = global_add(src_graph, "repoA") @@ -137,10 +153,13 @@ def test_global_add_two_repos_no_collision(tmp_path): global_dir = tmp_path / ".graphify" global_graph_path = global_dir / "global-graph.json" global_manifest_path = global_dir / "global-manifest.json" - with patch("graphify.global_graph._GLOBAL_DIR", global_dir), \ - patch("graphify.global_graph._GLOBAL_GRAPH", global_graph_path), \ - patch("graphify.global_graph._GLOBAL_MANIFEST", global_manifest_path): + with ( + patch("graphify.global_graph._GLOBAL_DIR", global_dir), + patch("graphify.global_graph._GLOBAL_GRAPH", global_graph_path), + patch("graphify.global_graph._GLOBAL_MANIFEST", global_manifest_path), + ): from graphify.global_graph import global_add, _load_global_graph + global_add(g1, "repoA") global_add(g2, "repoB") G = _load_global_graph() @@ -156,30 +175,39 @@ def test_global_remove(tmp_path): _graph_to_json(G, src_graph) global_dir = tmp_path / ".graphify" - with patch("graphify.global_graph._GLOBAL_DIR", global_dir), \ - patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), \ - patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"): + with ( + patch("graphify.global_graph._GLOBAL_DIR", global_dir), + patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), + patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"), + ): from graphify.global_graph import global_add, global_remove + global_add(src_graph, "repoA") removed = global_remove("repoA") assert removed > 0 # manifest should no longer list repoA - need to re-patch for list call global_dir2 = global_dir # same dir - with patch("graphify.global_graph._GLOBAL_DIR", global_dir2), \ - patch("graphify.global_graph._GLOBAL_GRAPH", global_dir2 / "global-graph.json"), \ - patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir2 / "global-manifest.json"): + with ( + patch("graphify.global_graph._GLOBAL_DIR", global_dir2), + patch("graphify.global_graph._GLOBAL_GRAPH", global_dir2 / "global-graph.json"), + patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir2 / "global-manifest.json"), + ): from graphify.global_graph import global_list + repos = global_list() assert "repoA" not in repos def test_global_remove_unknown_tag_raises(tmp_path): global_dir = tmp_path / ".graphify" - with patch("graphify.global_graph._GLOBAL_DIR", global_dir), \ - patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), \ - patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"): + with ( + patch("graphify.global_graph._GLOBAL_DIR", global_dir), + patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), + patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"), + ): from graphify.global_graph import global_remove + with pytest.raises(KeyError): global_remove("nonexistent") @@ -192,10 +220,13 @@ def test_global_add_collision_warning(tmp_path, capsys): _graph_to_json(G, g2) global_dir = tmp_path / ".graphify" - with patch("graphify.global_graph._GLOBAL_DIR", global_dir), \ - patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), \ - patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"): + with ( + patch("graphify.global_graph._GLOBAL_DIR", global_dir), + patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), + patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"), + ): from graphify.global_graph import global_add + global_add(g1, "myrepo") global_add(g2, "myrepo") # different source path, same tag @@ -205,8 +236,10 @@ def test_global_add_collision_warning(tmp_path, capsys): # ── dedup guard ─────────────────────────────────────────────────────────────── + def test_dedup_raises_on_cross_repo_nodes(): from graphify.dedup import deduplicate_entities + nodes = [ {"id": "repoA::userservice", "label": "UserService", "repo": "repoA"}, {"id": "repoB::userservice", "label": "UserService", "repo": "repoB"}, @@ -217,6 +250,7 @@ def test_dedup_raises_on_cross_repo_nodes(): def test_dedup_ok_with_single_repo(): from graphify.dedup import deduplicate_entities + nodes = [ {"id": "repoA::userservice", "label": "UserService", "repo": "repoA"}, {"id": "repoA::auth", "label": "Auth", "repo": "repoA"}, @@ -227,6 +261,7 @@ def test_dedup_ok_with_single_repo(): def test_dedup_ok_with_no_repo_attr(): from graphify.dedup import deduplicate_entities + nodes = [ {"id": "userservice", "label": "UserService"}, {"id": "auth", "label": "Auth"}, @@ -237,6 +272,7 @@ def test_dedup_ok_with_no_repo_attr(): # ── merge-graphs prefix ─────────────────────────────────────────────────────── + def test_merge_graphs_prefixes_ids(tmp_path): """merge-graphs should prefix node IDs with repo name to avoid silent collision.""" from graphify.build import prefix_graph_for_global @@ -290,9 +326,12 @@ def test_global_add_rejects_oversized_source_graph(monkeypatch, tmp_path): global_dir = tmp_path / ".graphify" monkeypatch.setattr("graphify.security._MAX_GRAPH_FILE_BYTES", 8) - with patch("graphify.global_graph._GLOBAL_DIR", global_dir), \ - patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), \ - patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"): + with ( + patch("graphify.global_graph._GLOBAL_DIR", global_dir), + patch("graphify.global_graph._GLOBAL_GRAPH", global_dir / "global-graph.json"), + patch("graphify.global_graph._GLOBAL_MANIFEST", global_dir / "global-manifest.json"), + ): from graphify.global_graph import global_add + with pytest.raises(ValueError, match="exceeds"): global_add(src_graph, "repoA") diff --git a/tests/test_google_workspace.py b/tests/test_google_workspace.py index 9d8cbfa4b..c23913304 100644 --- a/tests/test_google_workspace.py +++ b/tests/test_google_workspace.py @@ -1,4 +1,3 @@ -from pathlib import Path import json import graphify.google_workspace as gw diff --git a/tests/test_hooks.py b/tests/test_hooks.py index 873b2028c..7f43b2ab1 100644 --- a/tests/test_hooks.py +++ b/tests/test_hooks.py @@ -1,4 +1,5 @@ """Tests for hooks.py - git hook install/uninstall.""" + import os import subprocess from types import SimpleNamespace @@ -120,7 +121,6 @@ def test_status_shows_both_hooks(tmp_path): assert result.count("installed") >= 2 - def test_hooks_dir_resolves_relative_git_hooks_path(tmp_path, monkeypatch): repo = _make_git_repo(tmp_path) @@ -155,15 +155,18 @@ def fake_run(*args, **kwargs): assert _hooks_dir(repo) == hooks.resolve() + def test_hook_skips_head_on_exe(): """Hook script must skip shebang extraction for .exe binaries (Windows).""" from graphify.hooks import _PYTHON_DETECT - assert "*.exe) _SHEBANG=" in _PYTHON_DETECT or '*.exe)' in _PYTHON_DETECT + + assert "*.exe) _SHEBANG=" in _PYTHON_DETECT or "*.exe)" in _PYTHON_DETECT def test_hook_check_no_additionalContext(tmp_path): """graphify hook-check must not emit additionalContext — Codex Desktop rejects it.""" import sys + out = tmp_path / "graphify-out" out.mkdir() (out / "graph.json").write_text("{}", encoding="utf-8") diff --git a/tests/test_hypergraph.py b/tests/test_hypergraph.py index dda8ac793..f82d36816 100644 --- a/tests/test_hypergraph.py +++ b/tests/test_hypergraph.py @@ -1,11 +1,11 @@ """Tests for hyperedge support in graphify.""" + from __future__ import annotations import json import tempfile from pathlib import Path import networkx as nx -import pytest from graphify.build import build_from_json from graphify.export import attach_hyperedges, to_json @@ -22,10 +22,22 @@ {"id": "DigestAuth", "label": "DigestAuth", "file_type": "code", "source_file": "auth.py"}, {"id": "Request", "label": "Request", "file_type": "code", "source_file": "http.py"}, {"id": "Response", "label": "Response", "file_type": "code", "source_file": "http.py"}, - {"id": "BaseClient", "label": "BaseClient", "file_type": "code", "source_file": "client.py"}, + { + "id": "BaseClient", + "label": "BaseClient", + "file_type": "code", + "source_file": "client.py", + }, ], "edges": [ - {"source": "BasicAuth", "target": "Request", "relation": "uses", "confidence": "EXTRACTED", "confidence_score": 1.0, "source_file": "auth.py"}, + { + "source": "BasicAuth", + "target": "Request", + "relation": "uses", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": "auth.py", + }, ], "hyperedges": [ { @@ -55,6 +67,7 @@ # 1. Hyperedges survive build_from_json round-trip # --------------------------------------------------------------------------- + def test_build_from_json_stores_hyperedges(): G = build_from_json(SAMPLE_EXTRACTION) assert "hyperedges" in G.graph @@ -78,6 +91,7 @@ def test_build_from_json_missing_hyperedges_key(): # 2. attach_hyperedges deduplicates by id # --------------------------------------------------------------------------- + def test_attach_hyperedges_adds_new(): G = nx.Graph() attach_hyperedges(G, [{"id": "auth_flow", "label": "Auth Flow", "nodes": ["A", "B", "C"]}]) @@ -94,10 +108,13 @@ def test_attach_hyperedges_deduplicates(): def test_attach_hyperedges_multiple_different_ids(): G = nx.Graph() - attach_hyperedges(G, [ - {"id": "flow_a", "label": "Flow A", "nodes": ["A", "B", "C"]}, - {"id": "flow_b", "label": "Flow B", "nodes": ["D", "E", "F"]}, - ]) + attach_hyperedges( + G, + [ + {"id": "flow_a", "label": "Flow A", "nodes": ["A", "B", "C"]}, + {"id": "flow_b", "label": "Flow B", "nodes": ["D", "E", "F"]}, + ], + ) assert len(G.graph["hyperedges"]) == 2 @@ -111,6 +128,7 @@ def test_attach_hyperedges_skips_entry_without_id(): # 3. to_json includes hyperedges key # --------------------------------------------------------------------------- + def test_to_json_includes_hyperedges(): G = build_from_json(SAMPLE_EXTRACTION) communities = {0: list(G.nodes())} @@ -139,6 +157,7 @@ def test_to_json_hyperedges_empty_when_none(): # 4. Hyperedges loaded from graph.json via build_from_json # --------------------------------------------------------------------------- + def test_hyperedges_roundtrip_via_json_file(): """Write graph.json then reload it - hyperedges must survive.""" G = build_from_json(SAMPLE_EXTRACTION) @@ -149,11 +168,22 @@ def test_hyperedges_roundtrip_via_json_file(): # Reload the JSON as if build_from_json were called on it data = json.loads(Path(path).read_text()) - G2 = build_from_json({ - "nodes": [{"id": n["id"], **{k: v for k, v in n.items() if k != "id"}} for n in data["nodes"]], - "edges": [{"source": e["source"], "target": e["target"], **{k: v for k, v in e.items() if k not in ("source", "target")}} for e in data.get("links", [])], - "hyperedges": data.get("hyperedges", []), - }) + G2 = build_from_json( + { + "nodes": [ + {"id": n["id"], **{k: v for k, v in n.items() if k != "id"}} for n in data["nodes"] + ], + "edges": [ + { + "source": e["source"], + "target": e["target"], + **{k: v for k, v in e.items() if k not in ("source", "target")}, + } + for e in data.get("links", []) + ], + "hyperedges": data.get("hyperedges", []), + } + ) assert G2.graph.get("hyperedges", []) != [] assert G2.graph["hyperedges"][0]["id"] == "auth_flow" @@ -162,13 +192,24 @@ def test_hyperedges_roundtrip_via_json_file(): # 5. Report includes hyperedges section when hyperedges present # --------------------------------------------------------------------------- + def _make_report(G): communities = {0: list(G.nodes())} cohesion = {0: 1.0} labels = {0: "All"} gods = [{"label": "BasicAuth", "degree": 2}] surprises = [] - return generate(G, communities, cohesion, labels, gods, surprises, SAMPLE_DETECTION, {"input": 10, "output": 5}, ".") + return generate( + G, + communities, + cohesion, + labels, + gods, + surprises, + SAMPLE_DETECTION, + {"input": 10, "output": 5}, + ".", + ) def test_report_includes_hyperedges_section(): @@ -191,6 +232,7 @@ def test_report_includes_hyperedge_node_list(): # 6. Report skips hyperedges section when none present # --------------------------------------------------------------------------- + def test_report_skips_hyperedges_section_when_empty(): extraction = {**SAMPLE_EXTRACTION, "hyperedges": []} G = build_from_json(extraction) diff --git a/tests/test_import_extension_resolution.py b/tests/test_import_extension_resolution.py index 0d1222c0a..930c2fcaa 100644 --- a/tests/test_import_extension_resolution.py +++ b/tests/test_import_extension_resolution.py @@ -25,8 +25,11 @@ def _write(path: Path, body: str) -> Path: def _import_targets(result: dict) -> set[str]: - return {str(e.get("target") or "") for e in result["edges"] - if e.get("relation") in ("imports", "imports_from")} + return { + str(e.get("target") or "") + for e in result["edges"] + if e.get("relation") in ("imports", "imports_from") + } # ── _resolve_js_module_path unit tests ────────────────────────────────────── @@ -94,13 +97,10 @@ def test_resolve_svelte_to_svelte_ts_for_rune_files(tmp_path): """Svelte 5: `from './foo.svelte'` may actually point at `foo.svelte.ts` (a rune-only TypeScript file with no .svelte file). The resolver must APPEND .ts to the full filename, not swap suffixes.""" - target = _write(tmp_path / "is-mobile.svelte.ts", - "export const isMobile = () => true") + target = _write(tmp_path / "is-mobile.svelte.ts", "export const isMobile = () => true") written_as = tmp_path / "is-mobile.svelte" resolved = _resolve_js_module_path(written_as) - assert resolved == target, ( - f"Expected resolution to is-mobile.svelte.ts; got {resolved}" - ) + assert resolved == target, f"Expected resolution to is-mobile.svelte.ts; got {resolved}" def test_resolve_svelte_to_svelte_js_for_javascript_rune_files(tmp_path): @@ -111,8 +111,7 @@ def test_resolve_svelte_to_svelte_js_for_javascript_rune_files(tmp_path): Same code path as the .svelte.ts case — the generalized resolver tries every extension in priority order, so JS-only and TS-only projects both work without special-casing.""" - target = _write(tmp_path / "store.svelte.js", - "export const count = $state(0)") + target = _write(tmp_path / "store.svelte.js", "export const count = $state(0)") written_as = tmp_path / "store.svelte" resolved = _resolve_js_module_path(written_as) assert resolved == target @@ -128,10 +127,8 @@ def test_resolve_svelte_prefers_svelte_ts_over_svelte_js(tmp_path): expect tooling to read the `.svelte.ts` source. graphify is a source- code tool, not a runtime resolver, so source-first ordering is correct for our use case.""" - ts_target = _write(tmp_path / "store.svelte.ts", - "export const count = $state(0)") - _write(tmp_path / "store.svelte.js", - "export const count = 0 // build artifact") + ts_target = _write(tmp_path / "store.svelte.ts", "export const count = $state(0)") + _write(tmp_path / "store.svelte.js", "export const count = 0 // build artifact") written_as = tmp_path / "store.svelte" resolved = _resolve_js_module_path(written_as) assert resolved == ts_target @@ -142,8 +139,10 @@ def test_resolve_real_svelte_file_wins_over_svelte_ts_sibling(tmp_path): must resolve to that — not get hijacked to a sibling `foo.svelte.ts` rune file. The existence-check short-circuits before any append.""" real = _write(tmp_path / "Card.svelte", "
card markup
") - _write(tmp_path / "Card.svelte.ts", - "export const helpers = {} // rune sibling, not the import target") + _write( + tmp_path / "Card.svelte.ts", + "export const helpers = {} // rune sibling, not the import target", + ) resolved = _resolve_js_module_path(real) assert resolved == real @@ -180,10 +179,8 @@ def test_resolve_real_js_stays_js_when_ts_does_not_exist(tmp_path): def test_bare_path_import_resolves_in_ts_file(tmp_path): """The #716 reproducer: TS file imports a sibling without an extension.""" - target = _write(tmp_path / "type-helpers.ts", - "export type GetNestedType = T") - importer = _write(tmp_path / "page.ts", - "import type { GetNestedType } from './type-helpers'\n") + target = _write(tmp_path / "type-helpers.ts", "export type GetNestedType = T") + importer = _write(tmp_path / "page.ts", "import type { GetNestedType } from './type-helpers'\n") result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( @@ -194,10 +191,8 @@ def test_bare_path_import_resolves_in_ts_file(tmp_path): def test_directory_import_resolves_to_index_ts(tmp_path): """`from './queue'` must resolve to `./queue/index.ts`.""" - target = _write(tmp_path / "queue" / "index.ts", - "export const enqueue = () => {}") - importer = _write(tmp_path / "page.ts", - "import { enqueue } from './queue'\n") + target = _write(tmp_path / "queue" / "index.ts", "export const enqueue = () => {}") + importer = _write(tmp_path / "page.ts", "import { enqueue } from './queue'\n") result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( @@ -211,10 +206,8 @@ def test_directory_import_resolves_to_index_ts(tmp_path): def test_dot_svelte_import_resolves_to_dot_svelte_ts(tmp_path): """Svelte 5 rune file: import written as .svelte, real file is .svelte.ts.""" - target = _write(tmp_path / "is-mobile.svelte.ts", - "export const isMobile = () => true") - importer = _write(tmp_path / "page.ts", - "import { isMobile } from './is-mobile.svelte'\n") + target = _write(tmp_path / "is-mobile.svelte.ts", "export const isMobile = () => true") + importer = _write(tmp_path / "page.ts", "import { isMobile } from './is-mobile.svelte'\n") result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( @@ -230,8 +223,7 @@ def test_explicit_ts_import_still_works(tmp_path): """The most common case — import with explicit .ts extension — must continue to work after the resolver change.""" target = _write(tmp_path / "foo.ts", "export const x = 1") - importer = _write(tmp_path / "page.ts", - "import { x } from './foo.ts'\n") + importer = _write(tmp_path / "page.ts", "import { x } from './foo.ts'\n") result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( @@ -244,8 +236,7 @@ def test_explicit_svelte_import_still_works(tmp_path): """Real .svelte file imports must still resolve when the .svelte file exists (i.e. don't accidentally redirect to a non-existent .svelte.ts).""" target = _write(tmp_path / "Card.svelte", "
") - importer = _write(tmp_path / "page.ts", - "import Card from './Card.svelte'\n") + importer = _write(tmp_path / "page.ts", "import Card from './Card.svelte'\n") result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( @@ -259,14 +250,12 @@ def test_external_module_unchanged(tmp_path): """Bare module specifiers (no leading dot, no alias match) must still fall through to the external/last-segment path — don't accidentally treat 'lodash' as a relative path.""" - importer = _write(tmp_path / "page.ts", - "import _ from 'lodash-es'\n") + importer = _write(tmp_path / "page.ts", "import _ from 'lodash-es'\n") result = extract_js(importer) targets = _import_targets(result) # The target should be the bare module name, not a resolved file path assert "lodash_es" in targets or any("lodash" in t for t in targets), ( - f"External module specifier should still produce an external " - f"reference; got {targets}" + f"External module specifier should still produce an external reference; got {targets}" ) @@ -276,19 +265,17 @@ def test_external_module_unchanged(tmp_path): def test_alias_import_with_bare_path_resolves(tmp_path): """`$lib/foo` (alias + bare path) — both layers must work together.""" src = tmp_path / "src" - target = _write(src / "lib" / "type-helpers.ts", - "export type X = string") - _write(tmp_path / "tsconfig.json", - '{"compilerOptions":{"paths":{"$lib":["./src/lib"],' - '"$lib/*":["./src/lib/*"]}}}') + target = _write(src / "lib" / "type-helpers.ts", "export type X = string") + _write( + tmp_path / "tsconfig.json", + '{"compilerOptions":{"paths":{"$lib":["./src/lib"],"$lib/*":["./src/lib/*"]}}}', + ) importer_dir = src / "routes" - importer = _write(importer_dir / "page.ts", - "import type { X } from '$lib/type-helpers'\n") + importer = _write(importer_dir / "page.ts", "import type { X } from '$lib/type-helpers'\n") result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( - f"Alias + bare-path resolution failed; " - f"expected {expected}; got {_import_targets(result)}" + f"Alias + bare-path resolution failed; expected {expected}; got {_import_targets(result)}" ) @@ -299,10 +286,8 @@ def test_type_only_import_with_bare_path_resolves(tmp_path): """`import type { X } from './foo'` — type-only imports must go through the same resolution path as regular imports. Common in TS codebases that separate types into their own module.""" - target = _write(tmp_path / "type-helpers.ts", - "export type GetNestedType = T") - importer = _write(tmp_path / "page.ts", - "import type { GetNestedType } from './type-helpers'\n") + target = _write(tmp_path / "type-helpers.ts", "export type GetNestedType = T") + importer = _write(tmp_path / "page.ts", "import type { GetNestedType } from './type-helpers'\n") result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( @@ -317,8 +302,7 @@ def test_named_imports_emit_symbol_edges_after_resolution(tmp_path): `imports_from`. The symbol-edge target_stem comes from _file_stem(resolved), which depends on resolution succeeding first.""" _write(tmp_path / "utils.ts", "export const foo = 1\nexport const bar = 2") - importer = _write(tmp_path / "page.ts", - "import { foo, bar } from './utils'\n") + importer = _write(tmp_path / "page.ts", "import { foo, bar } from './utils'\n") result = extract_js(importer) sym_edges = [e for e in result["edges"] if e.get("relation") == "imports"] targets = {str(e.get("target") or "") for e in sym_edges} @@ -334,18 +318,16 @@ def test_named_imports_emit_symbol_edges_after_resolution(tmp_path): def test_alias_directory_import_resolves_to_index_ts(tmp_path): """`from '$lib/queue'` where queue/ is a directory under src/lib/.""" src = tmp_path / "src" - target = _write(src / "lib" / "queue" / "index.ts", - "export const enqueue = () => {}") - _write(tmp_path / "tsconfig.json", - '{"compilerOptions":{"paths":{"$lib":["./src/lib"],' - '"$lib/*":["./src/lib/*"]}}}') - importer = _write(src / "routes" / "page.ts", - "import { enqueue } from '$lib/queue'\n") + target = _write(src / "lib" / "queue" / "index.ts", "export const enqueue = () => {}") + _write( + tmp_path / "tsconfig.json", + '{"compilerOptions":{"paths":{"$lib":["./src/lib"],"$lib/*":["./src/lib/*"]}}}', + ) + importer = _write(src / "routes" / "page.ts", "import { enqueue } from '$lib/queue'\n") result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( - f"Alias + directory resolution failed; " - f"expected {expected}; got {_import_targets(result)}" + f"Alias + directory resolution failed; expected {expected}; got {_import_targets(result)}" ) @@ -358,9 +340,7 @@ def test_resolve_does_not_match_partial_directory_name(tmp_path): bare = tmp_path / "foo" resolved = _resolve_js_module_path(bare) # Not a real file → nothing matches → returns input unchanged - assert resolved == bare, ( - f"Partial-name match must not happen; got {resolved}" - ) + assert resolved == bare, f"Partial-name match must not happen; got {resolved}" def test_resolve_directory_without_index_returns_unchanged(tmp_path): @@ -369,16 +349,13 @@ def test_resolve_directory_without_index_returns_unchanged(tmp_path): pkg = tmp_path / "pkg" _write(pkg / "not-index.ts", "export const x = 1") resolved = _resolve_js_module_path(pkg) - assert resolved == pkg, ( - f"Directory without index.* must return unchanged; got {resolved}" - ) + assert resolved == pkg, f"Directory without index.* must return unchanged; got {resolved}" def test_resolve_handles_subpath_into_directory_with_index(tmp_path): """`./foo/sub` where ./foo/sub/index.ts exists — nested subpath. Common pattern for sub-modules inside a package.""" - target = _write(tmp_path / "foo" / "sub" / "index.ts", - "export const x = 1") + target = _write(tmp_path / "foo" / "sub" / "index.ts", "export const x = 1") sub = tmp_path / "foo" / "sub" assert _resolve_js_module_path(sub) == target @@ -387,8 +364,7 @@ def test_resolve_does_not_treat_dotfile_as_extension(tmp_path): """Edge case: `.eslintrc` and similar dotfiles. Path('.eslintrc').suffix returns '' on Python 3.x for files starting with `.`. Make sure we don't accidentally treat a real file as bare and try to append .ts.""" - target = _write(tmp_path / ".env-types.ts", - "export const x = 1") + target = _write(tmp_path / ".env-types.ts", "export const x = 1") # Path('.env-types.ts').suffix is '.ts' — not a problem assert _resolve_js_module_path(target) == target @@ -401,8 +377,7 @@ def test_resolve_multi_dot_helper_file(tmp_path): Before this rule, .suffix was '.shared' so neither the bare-path branch nor the .js/.jsx branches matched, and the import dropped to a phantom.""" - target = _write(tmp_path / "tag-action.shared.ts", - "export const apply = () => {}") + target = _write(tmp_path / "tag-action.shared.ts", "export const apply = () => {}") written_as = tmp_path / "tag-action.shared" assert _resolve_js_module_path(written_as) == target @@ -423,15 +398,12 @@ def test_resolve_ambient_d_ts_via_bare_path(tmp_path): def test_end_to_end_multi_dot_import_resolves(tmp_path): """End-to-end sanity for the multi-dot pattern via the import handler.""" - target = _write(tmp_path / "tag-action.shared.ts", - "export const apply = () => {}") - importer = _write(tmp_path / "page.ts", - "import { apply } from './tag-action.shared'\n") + target = _write(tmp_path / "tag-action.shared.ts", "export const apply = () => {}") + importer = _write(tmp_path / "page.ts", "import { apply } from './tag-action.shared'\n") result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( - f"Multi-dot import failed end-to-end; " - f"expected {expected}; got {_import_targets(result)}" + f"Multi-dot import failed end-to-end; expected {expected}; got {_import_targets(result)}" ) @@ -440,13 +412,16 @@ def test_resolve_chain_alias_and_extension_compose(tmp_path): compose correctly: tsconfig alias maps `$lib/...` to a real dir, then extension resolution finds the actual file.""" src = tmp_path / "src" - target = _write(src / "lib" / "hooks" / "is-mobile.svelte.ts", - "export const isMobile = () => true") - _write(tmp_path / "tsconfig.json", - '{"compilerOptions":{"paths":{"$lib":["./src/lib"],' - '"$lib/*":["./src/lib/*"]}}}') - importer = _write(src / "routes" / "page.ts", - "import { isMobile } from '$lib/hooks/is-mobile.svelte'\n") + target = _write( + src / "lib" / "hooks" / "is-mobile.svelte.ts", "export const isMobile = () => true" + ) + _write( + tmp_path / "tsconfig.json", + '{"compilerOptions":{"paths":{"$lib":["./src/lib"],"$lib/*":["./src/lib/*"]}}}', + ) + importer = _write( + src / "routes" / "page.ts", "import { isMobile } from '$lib/hooks/is-mobile.svelte'\n" + ) result = extract_js(importer) expected = _make_id(str(target)) assert expected in _import_targets(result), ( @@ -464,21 +439,25 @@ def test_ts_dynamic_import_bare_path_resolves(tmp_path): has its own copy of the resolution logic — distinct from the static-import handler and from the Svelte regex pass — and was missing the bare-path extension append, silently dropping every such edge.""" - target = _write(tmp_path / "profanity.ts", - "export const hasProfanity = (s: string) => false") - importer = _write(tmp_path / "auth-validators.ts", """\ + target = _write(tmp_path / "profanity.ts", "export const hasProfanity = (s: string) => false") + importer = _write( + tmp_path / "auth-validators.ts", + """\ export async function validate(name: string) { const { hasProfanity } = await import('./profanity') return hasProfanity(name) } -""") +""", + ) result = extract_js(importer) expected = _make_id(str(target)) - targets = {str(e.get("target") or "") for e in result["edges"] - if e.get("relation") in ("imports", "imports_from")} + targets = { + str(e.get("target") or "") + for e in result["edges"] + if e.get("relation") in ("imports", "imports_from") + } assert expected in targets, ( - f"Bare-path TS dynamic import failed to resolve; " - f"expected {expected}; got {targets}" + f"Bare-path TS dynamic import failed to resolve; expected {expected}; got {targets}" ) @@ -488,22 +467,28 @@ def test_ts_dynamic_import_alias_with_bare_path_resolves(tmp_path): `$lib/foo.ts` after both alias substitution and extension append.""" src = tmp_path / "src" target = _write(src / "lib" / "lazy-module.ts", "export const x = 1") - _write(tmp_path / "tsconfig.json", - '{"compilerOptions":{"paths":{"$lib":["./src/lib"],' - '"$lib/*":["./src/lib/*"]}}}') - importer = _write(src / "routes" / "page.ts", """\ + _write( + tmp_path / "tsconfig.json", + '{"compilerOptions":{"paths":{"$lib":["./src/lib"],"$lib/*":["./src/lib/*"]}}}', + ) + importer = _write( + src / "routes" / "page.ts", + """\ export async function load() { const m = await import('$lib/lazy-module') return m.x } -""") +""", + ) result = extract_js(importer) expected = _make_id(str(target)) - targets = {str(e.get("target") or "") for e in result["edges"] - if e.get("relation") in ("imports", "imports_from")} + targets = { + str(e.get("target") or "") + for e in result["edges"] + if e.get("relation") in ("imports", "imports_from") + } assert expected in targets, ( - f"Alias + bare-path dynamic import failed to resolve; " - f"expected {expected}; got {targets}" + f"Alias + bare-path dynamic import failed to resolve; expected {expected}; got {targets}" ) @@ -511,16 +496,19 @@ def test_dynamic_import_bare_path_resolves(tmp_path): """The regex pass for `import('...')` in .svelte files must also use the new resolver — otherwise dynamic imports of bare paths still produce phantom edges.""" - target = _write(tmp_path / "Heavy.svelte.ts", - "export const heavy = () => 1") - importer = _write(tmp_path / "page.svelte", """\ + target = _write(tmp_path / "Heavy.svelte.ts", "export const heavy = () => 1") + importer = _write( + tmp_path / "page.svelte", + """\ -""") +""", + ) result = extract_svelte(importer) - dyn_targets = {str(e.get("target") or "") for e in result["edges"] - if e.get("relation") == "dynamic_import"} + dyn_targets = { + str(e.get("target") or "") for e in result["edges"] if e.get("relation") == "dynamic_import" + } expected = _make_id(str(target)) assert expected in dyn_targets, ( f"dynamic_import of .svelte that's actually .svelte.ts must " diff --git a/tests/test_incremental.py b/tests/test_incremental.py index 2e2e902e0..47006c695 100644 --- a/tests/test_incremental.py +++ b/tests/test_incremental.py @@ -1,11 +1,11 @@ """Integration tests for incremental graphify extract behavior.""" + from __future__ import annotations import json import subprocess import sys from pathlib import Path -import pytest PYTHON = sys.executable diff --git a/tests/test_ingest.py b/tests/test_ingest.py index 41128eee2..2d6774258 100644 --- a/tests/test_ingest.py +++ b/tests/test_ingest.py @@ -1,8 +1,6 @@ """Tests for graphify.ingest.save_query_result""" + from __future__ import annotations -import re -from pathlib import Path -import pytest from graphify.ingest import save_query_result @@ -49,7 +47,7 @@ def test_source_nodes_capped_at_10(tmp_path): out = save_query_result("q", "a", mem, source_nodes=nodes) content = out.read_text() # Only first 10 should appear in frontmatter source_nodes line - fm_line = [l for l in content.splitlines() if l.startswith("source_nodes:")][0] + fm_line = [label for label in content.splitlines() if label.startswith("source_nodes:")][0] assert fm_line.count('"Node') == 10 diff --git a/tests/test_install.py b/tests/test_install.py index 5b464e8d9..23a4309e5 100644 --- a/tests/test_install.py +++ b/tests/test_install.py @@ -1,4 +1,5 @@ """Tests for graphify install --platform routing.""" + import os from pathlib import Path import sys @@ -20,6 +21,7 @@ def _install(tmp_path, platform): from graphify.__main__ import install + old_cwd = Path.cwd() try: os.chdir(tmp_path) @@ -46,6 +48,7 @@ def test_install_opencode(tmp_path): def test_install_positional_platform_opencode(tmp_path, monkeypatch): from graphify.__main__ import main + monkeypatch.chdir(tmp_path) monkeypatch.setattr(sys, "argv", ["graphify", "install", "opencode"]) with patch("graphify.__main__.Path.home", return_value=tmp_path): @@ -56,6 +59,7 @@ def test_install_positional_platform_opencode(tmp_path, monkeypatch): def test_install_project_claude_writes_project_scope(tmp_path, monkeypatch, capsys): from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -67,12 +71,15 @@ def test_install_project_claude_writes_project_scope(tmp_path, monkeypatch, caps assert (project / ".claude" / "CLAUDE.md").exists() assert not (home / ".claude" / "skills" / "graphify" / "SKILL.md").exists() assert ".claude/skills/graphify/SKILL.md" in (project / ".claude" / "CLAUDE.md").read_text() - assert "~/.claude/skills/graphify/SKILL.md" not in (project / ".claude" / "CLAUDE.md").read_text() + assert ( + "~/.claude/skills/graphify/SKILL.md" not in (project / ".claude" / "CLAUDE.md").read_text() + ) assert "git add .claude/" in capsys.readouterr().out def test_install_project_codex_writes_skill_and_agents(tmp_path, monkeypatch): from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -88,6 +95,7 @@ def test_install_project_codex_writes_skill_and_agents(tmp_path, monkeypatch): def test_claude_subcommand_project_install_and_uninstall_are_project_scoped(tmp_path, monkeypatch): from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -114,6 +122,7 @@ def test_claude_subcommand_project_install_and_uninstall_are_project_scoped(tmp_ def test_codex_subcommand_project_install_and_uninstall_are_project_scoped(tmp_path, monkeypatch): from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -142,6 +151,7 @@ def test_codex_subcommand_project_install_and_uninstall_are_project_scoped(tmp_p def test_antigravity_install_project_writes_project_skill(tmp_path, monkeypatch): from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -155,6 +165,7 @@ def test_antigravity_install_project_writes_project_skill(tmp_path, monkeypatch) def test_install_help_does_not_install_default(tmp_path, monkeypatch, capsys): from graphify.__main__ import main + monkeypatch.chdir(tmp_path) monkeypatch.setattr(sys, "argv", ["graphify", "install", "opencode", "--help"]) with patch("graphify.__main__.Path.home", return_value=tmp_path): @@ -199,6 +210,7 @@ def test_install_unknown_platform_exits(tmp_path): def test_codex_skill_contains_spawn_agent(): """Codex skill file must reference spawn_agent.""" import graphify + skill = (Path(graphify.__file__).parent / "skill-codex.md").read_text() assert "spawn_agent" in skill @@ -206,6 +218,7 @@ def test_codex_skill_contains_spawn_agent(): def test_codex_skill_uses_graphify_with_dirty_graph_output(): """Codex skill must keep graph-first orientation even when graph output is dirty.""" import graphify + skill = (Path(graphify.__file__).parent / "skill-codex.md").read_text() assert "Dirty `graphify-out/` artifacts are expected" in skill assert "not a reason to skip Graphify" in skill @@ -224,6 +237,7 @@ def test_codex_agents_install_mentions_dirty_graph_output(tmp_path): def test_opencode_skill_contains_mention(): """OpenCode skill file must reference @mention.""" import graphify + skill = (Path(graphify.__file__).parent / "skill-opencode.md").read_text() assert "@mention" in skill @@ -231,6 +245,7 @@ def test_opencode_skill_contains_mention(): def test_opencode_skill_uses_opencode_agent_guidance(): """OpenCode skill must not reference Codex/Claude agent type names.""" import graphify + skill = (Path(graphify.__file__).parent / "skill-opencode.md").read_text() assert "general-purpose" not in skill assert 'subagent_type="general-purpose"' not in skill @@ -245,6 +260,7 @@ def test_opencode_skill_uses_opencode_agent_guidance(): def test_claw_skill_is_sequential(): """OpenClaw skill file must describe sequential extraction.""" import graphify + skill = (Path(graphify.__file__).parent / "skill-claw.md").read_text() assert "sequential" in skill.lower() assert "spawn_agent" not in skill @@ -254,8 +270,17 @@ def test_claw_skill_is_sequential(): def test_all_skill_files_exist_in_package(): """All installable platform skill files must be present in the installed package.""" import graphify + pkg = Path(graphify.__file__).parent - for name in ("skill.md", "skill-codex.md", "skill-opencode.md", "skill-claw.md", "skill-windows.md", "skill-droid.md", "skill-trae.md"): + for name in ( + "skill.md", + "skill-codex.md", + "skill-opencode.md", + "skill-claw.md", + "skill-windows.md", + "skill-droid.md", + "skill-trae.md", + ): assert (pkg / name).exists(), f"Missing: {name}" @@ -272,6 +297,7 @@ def test_codex_install_does_not_write_claude_md(tmp_path): def test_uninstall_project_removes_project_skill_only(tmp_path, monkeypatch): from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -280,9 +306,13 @@ def test_uninstall_project_removes_project_skill_only(tmp_path, monkeypatch): user_skill.write_text("user skill") monkeypatch.chdir(project) with patch("graphify.__main__.Path.home", return_value=home): - monkeypatch.setattr(sys, "argv", ["graphify", "install", "--project", "--platform", "codex"]) + monkeypatch.setattr( + sys, "argv", ["graphify", "install", "--project", "--platform", "codex"] + ) main() - monkeypatch.setattr(sys, "argv", ["graphify", "uninstall", "--project", "--platform", "codex"]) + monkeypatch.setattr( + sys, "argv", ["graphify", "uninstall", "--project", "--platform", "codex"] + ) main() assert user_skill.exists() assert not (project / ".agents" / "skills" / "graphify" / "SKILL.md").exists() @@ -291,6 +321,7 @@ def test_uninstall_project_removes_project_skill_only(tmp_path, monkeypatch): def test_uninstall_project_without_platform_removes_project_installs(tmp_path, monkeypatch): from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -310,6 +341,7 @@ def test_uninstall_project_without_platform_removes_project_installs(tmp_path, m def test_antigravity_uninstall_project_removes_project_skill_only(tmp_path, monkeypatch): from graphify.__main__ import main + home = tmp_path / "home" project = tmp_path / "project" project.mkdir() @@ -368,13 +400,16 @@ def test_antigravity_global_uninstall_removes_gemini_config_skill(tmp_path, monk # --- always-on AGENTS.md install/uninstall tests --- + def _agents_install(tmp_path, platform): from graphify.__main__ import _agents_install as _install_fn + _install_fn(tmp_path, platform) def _agents_uninstall(tmp_path, platform=""): from graphify.__main__ import _agents_uninstall as _uninstall_fn + _uninstall_fn(tmp_path, platform=platform) @@ -442,6 +477,7 @@ def test_agents_uninstall_no_op_when_not_installed(tmp_path, capsys): # --- OpenCode plugin tests --- + def test_opencode_agents_install_writes_plugin(tmp_path): """opencode install writes .opencode/plugins/graphify.js.""" _agents_install(tmp_path, "opencode") @@ -456,6 +492,7 @@ def test_opencode_agents_install_registers_plugin_in_config(tmp_path): config_file = tmp_path / ".opencode" / "opencode.json" assert config_file.exists() import json as _json + config = _json.loads(config_file.read_text()) assert any("graphify.js" in p for p in config.get("plugin", [])) @@ -463,6 +500,7 @@ def test_opencode_agents_install_registers_plugin_in_config(tmp_path): def test_opencode_agents_install_merges_existing_config(tmp_path): """opencode install preserves existing .opencode/opencode.json keys.""" import json as _json + config_file = tmp_path / ".opencode" / "opencode.json" config_file.parent.mkdir(parents=True, exist_ok=True) config_file.write_text(_json.dumps({"model": "claude-opus-4-5", "plugin": []})) @@ -475,6 +513,7 @@ def test_opencode_agents_install_merges_existing_config(tmp_path): def test_opencode_agents_uninstall_removes_plugin(tmp_path): """opencode uninstall removes the plugin file and deregisters from opencode.json.""" import json as _json + _agents_install(tmp_path, "opencode") _agents_uninstall(tmp_path, platform="opencode") plugin = tmp_path / ".opencode" / "plugins" / "graphify.js" @@ -487,9 +526,11 @@ def test_opencode_agents_uninstall_removes_plugin(tmp_path): # ── Cursor ──────────────────────────────────────────────────────────────────── + def test_cursor_install_writes_rule(tmp_path): """cursor install writes .cursor/rules/graphify.mdc.""" from graphify.__main__ import _cursor_install + _cursor_install(tmp_path) rule = tmp_path / ".cursor" / "rules" / "graphify.mdc" assert rule.exists() @@ -501,6 +542,7 @@ def test_cursor_install_writes_rule(tmp_path): def test_cursor_install_idempotent(tmp_path): """cursor install does not overwrite an existing rule file.""" from graphify.__main__ import _cursor_install + _cursor_install(tmp_path) rule = tmp_path / ".cursor" / "rules" / "graphify.mdc" original = rule.read_text() @@ -511,6 +553,7 @@ def test_cursor_install_idempotent(tmp_path): def test_cursor_uninstall_removes_rule(tmp_path): """cursor uninstall removes the rule file.""" from graphify.__main__ import _cursor_install, _cursor_uninstall + _cursor_install(tmp_path) _cursor_uninstall(tmp_path) rule = tmp_path / ".cursor" / "rules" / "graphify.mdc" @@ -520,51 +563,64 @@ def test_cursor_uninstall_removes_rule(tmp_path): def test_cursor_uninstall_noop_if_not_installed(tmp_path): """cursor uninstall does nothing if rule was never written.""" from graphify.__main__ import _cursor_uninstall + _cursor_uninstall(tmp_path) # should not raise # ── Gemini CLI ──────────────────────────────────────────────────────────────── + def test_gemini_install_writes_gemini_md(tmp_path): from graphify.__main__ import gemini_install + gemini_install(tmp_path) md = tmp_path / "GEMINI.md" assert md.exists() assert "graphify-out/GRAPH_REPORT.md" in md.read_text() + def test_gemini_install_writes_hook(tmp_path): import json as _json from graphify.__main__ import gemini_install + gemini_install(tmp_path) settings = _json.loads((tmp_path / ".gemini" / "settings.json").read_text()) hooks = settings["hooks"]["BeforeTool"] assert any("graphify" in str(h) for h in hooks) + def test_gemini_install_idempotent(tmp_path): from graphify.__main__ import gemini_install + gemini_install(tmp_path) gemini_install(tmp_path) md = tmp_path / "GEMINI.md" assert md.read_text().count("## graphify") == 1 + def test_gemini_install_merges_existing_gemini_md(tmp_path): from graphify.__main__ import gemini_install + (tmp_path / "GEMINI.md").write_text("# My project rules\n") gemini_install(tmp_path) content = (tmp_path / "GEMINI.md").read_text() assert "# My project rules" in content assert "graphify-out/GRAPH_REPORT.md" in content + def test_gemini_uninstall_removes_section(tmp_path): from graphify.__main__ import gemini_install, gemini_uninstall + gemini_install(tmp_path) gemini_uninstall(tmp_path) md = tmp_path / "GEMINI.md" assert not md.exists() + def test_gemini_uninstall_removes_hook(tmp_path): import json as _json from graphify.__main__ import gemini_install, gemini_uninstall + gemini_install(tmp_path) gemini_uninstall(tmp_path) settings_path = tmp_path / ".gemini" / "settings.json" @@ -573,6 +629,8 @@ def test_gemini_uninstall_removes_hook(tmp_path): hooks = settings.get("hooks", {}).get("BeforeTool", []) assert not any("graphify" in str(h) for h in hooks) + def test_gemini_uninstall_noop_if_not_installed(tmp_path): from graphify.__main__ import gemini_uninstall + gemini_uninstall(tmp_path) # should not raise diff --git a/tests/test_install_strings.py b/tests/test_install_strings.py index 5e0037089..93cb2a7d8 100644 --- a/tests/test_install_strings.py +++ b/tests/test_install_strings.py @@ -8,6 +8,7 @@ (issue #580). This file locks in the query-first policy so a future revert or partial change is caught by CI. """ + from __future__ import annotations import json @@ -73,6 +74,7 @@ def test_no_install_surface_demands_reading_the_full_report_first(): are legitimate platform metadata, not the bug. """ import re + banned = [ # "read ... GRAPH_REPORT.md ... before" re.compile(r"read[^.\n]{0,80}GRAPH_REPORT\.md[^.\n]{0,80}before", re.IGNORECASE), @@ -87,10 +89,7 @@ def test_no_install_surface_demands_reading_the_full_report_first(): m = pattern.search(text) if m: hits.append((name, m.group(0))) - assert not hits, ( - f"banned report-first phrasing reappeared: {hits}. " - f"This regresses issue #580." - ) + assert not hits, f"banned report-first phrasing reappeared: {hits}. This regresses issue #580." def test_report_is_still_referenced_as_fallback(): @@ -127,6 +126,7 @@ def test_agents_section_does_not_skip_dirty_graph_output(): def test_how_it_works_clarifies_code_only_semantic_extraction(): from pathlib import Path + doc = (Path(__file__).parent.parent / "docs" / "how-it-works.md").read_text(encoding="utf-8") assert "Code files are not sent to the LLM semantic extractor" in doc assert "code files, Pass 3 is skipped entirely" in doc diff --git a/tests/test_install_upgrade.py b/tests/test_install_upgrade.py index 09ee3d81e..fa085dbdc 100644 --- a/tests/test_install_upgrade.py +++ b/tests/test_install_upgrade.py @@ -9,6 +9,7 @@ section, run the installer, and assert that the on-disk file now contains the new query-first wording and does not contain the old report-first text. """ + from __future__ import annotations import json from pathlib import Path @@ -85,9 +86,7 @@ def _assert_no_report_first(text: str, ctx: str) -> None: def _assert_query_first(text: str, ctx: str) -> None: - assert "graphify query" in text, ( - f"{ctx}: new 'graphify query' guidance missing after upgrade" - ) + assert "graphify query" in text, f"{ctx}: new 'graphify query' guidance missing after upgrade" def test_claude_install_upgrades_stale_section(tmp_path, monkeypatch): @@ -95,7 +94,9 @@ def test_claude_install_upgrades_stale_section(tmp_path, monkeypatch): `graphify claude install` again after upgrading to a fixed package.""" monkeypatch.chdir(tmp_path) claude_md = tmp_path / "CLAUDE.md" - claude_md.write_text("# My Project\n\nSome description.\n\n" + _OLD_CLAUDE_SECTION, encoding="utf-8") + claude_md.write_text( + "# My Project\n\nSome description.\n\n" + _OLD_CLAUDE_SECTION, encoding="utf-8" + ) monkeypatch.setattr(mainmod, "_check_skill_version", lambda _: None) mainmod.claude_install(tmp_path) @@ -125,11 +126,7 @@ def test_claude_install_upgrades_stale_hook_payload(tmp_path, monkeypatch): "hooks": [ { "type": "command", - "command": ( - "case x in *) " - + _OLD_HOOK_PAYLOAD_SNIPPET - + " esac" - ), + "command": ("case x in *) " + _OLD_HOOK_PAYLOAD_SNIPPET + " esac"), } ], } @@ -142,9 +139,7 @@ def test_claude_install_upgrades_stale_hook_payload(tmp_path, monkeypatch): mainmod.claude_install(tmp_path) new_settings_text = settings.read_text(encoding="utf-8") - assert _OLD_HOOK_PAYLOAD_SNIPPET not in new_settings_text, ( - "stale hook payload survived upgrade" - ) + assert _OLD_HOOK_PAYLOAD_SNIPPET not in new_settings_text, "stale hook payload survived upgrade" assert "graphify query" in new_settings_text, ( "new hook payload should route to `graphify query`" ) diff --git a/tests/test_js_import_resolution.py b/tests/test_js_import_resolution.py index 414cf8d87..a1ac2ff7e 100644 --- a/tests/test_js_import_resolution.py +++ b/tests/test_js_import_resolution.py @@ -18,10 +18,7 @@ def _extract_for(paths: list[Path], root: Path): def _has_edge(result: dict, source: str, target: str, relation: str = "imports_from") -> bool: expected = (_make_id(source), _make_id(target), relation) - actual = { - (edge["source"], edge["target"], edge["relation"]) - for edge in result["edges"] - } + actual = {(edge["source"], edge["target"], edge["relation"]) for edge in result["edges"]} return expected in actual @@ -33,10 +30,7 @@ def _has_symbol_edge( relation: str = "imports", ) -> bool: expected = (_make_id(source), _make_id(_file_stem(Path(target_file)), symbol), relation) - actual = { - (edge["source"], edge["target"], edge["relation"]) - for edge in result["edges"] - } + actual = {(edge["source"], edge["target"], edge["relation"]) for edge in result["edges"]} return expected in actual @@ -53,10 +47,7 @@ def _has_symbol_to_symbol_edge( _make_id(_file_stem(Path(target_file)), target_symbol), relation, ) - actual = { - (edge["source"], edge["target"], edge["relation"]) - for edge in result["edges"] - } + actual = {(edge["source"], edge["target"], edge["relation"]) for edge in result["edges"]} return expected in actual @@ -198,7 +189,9 @@ def test_ts_reexported_type_alias_resolves_imported_symbol_to_origin(tmp_path: P def test_ts_reexported_abstract_class_resolves_imported_symbol_to_origin(tmp_path: Path): - target = _write(tmp_path / "src/lib/foo.ts", "export abstract class Foo { abstract run(): void }\n") + target = _write( + tmp_path / "src/lib/foo.ts", "export abstract class Foo { abstract run(): void }\n" + ) barrel = _write(tmp_path / "src/lib/index.ts", "export { Foo } from './foo'\n") consumer = _write( tmp_path / "src/routes/page.ts", @@ -228,7 +221,9 @@ def test_ts_const_alias_reexport_resolves_imported_symbol_to_origin(tmp_path: Pa assert _has_symbol_edge(result, "src/routes/page.ts", "src/lib/foo.ts", "Foo") -def test_ts_local_const_alias_then_named_reexport_resolves_imported_symbol_to_origin(tmp_path: Path): +def test_ts_local_const_alias_then_named_reexport_resolves_imported_symbol_to_origin( + tmp_path: Path, +): target = _write(tmp_path / "src/lib/foo.ts", "export function makeFoo() { return {} }\n") barrel = _write( tmp_path / "src/lib/index.ts", @@ -307,7 +302,9 @@ def test_ts_import_alias_call_from_same_named_local_symbol_targets_origin(tmp_pa def test_svelte_rune_import_resolves_svelte_ts_file(tmp_path: Path): - target = _write(tmp_path / "src/lib/hooks/is-mobile.svelte.ts", "export const isMobile = true\n") + target = _write( + tmp_path / "src/lib/hooks/is-mobile.svelte.ts", "export const isMobile = true\n" + ) importer = _write( tmp_path / "src/routes/page.ts", "import { isMobile } from '../lib/hooks/is-mobile.svelte'\nconsole.log(isMobile)\n", @@ -482,8 +479,12 @@ def _norm(label: str) -> str: if edge.get("relation") == "references" } - assert _has_symbol_to_symbol_edge(result, "src/lib/impl.ts", "DataProcessor", "src/lib/base.ts", "BaseProcessor", "inherits") - assert _has_symbol_to_symbol_edge(result, "src/lib/impl.ts", "DataProcessor", "src/lib/base.ts", "IProcessor", "implements") + assert _has_symbol_to_symbol_edge( + result, "src/lib/impl.ts", "DataProcessor", "src/lib/base.ts", "BaseProcessor", "inherits" + ) + assert _has_symbol_to_symbol_edge( + result, "src/lib/impl.ts", "DataProcessor", "src/lib/base.ts", "IProcessor", "implements" + ) assert ("run", "Payload", "parameter_type") in reference_contexts assert ("run", "Result", "return_type") in reference_contexts assert ("run", "Payload", "generic_arg") in reference_contexts diff --git a/tests/test_languages.py b/tests/test_languages.py index fefc13807..1c38a9845 100644 --- a/tests/test_languages.py +++ b/tests/test_languages.py @@ -1,14 +1,30 @@ """Tests for language extractors: Java, C, C++, Ruby, C#, Kotlin, Scala, PHP, Swift, Go, Julia, Fortran, JS/TS, .NET project files.""" + from __future__ import annotations from pathlib import Path -import pytest from graphify.extract import ( - extract_java, extract_c, extract_cpp, extract_ruby, - extract_csharp, extract_kotlin, extract_scala, extract_php, - extract_swift, extract_go, extract_julia, extract_js, extract_fortran, - extract_groovy, extract_sln, extract_csproj, extract_razor, - extract_dm, extract_dmi, extract_dmm, extract_dmf, + extract_java, + extract_c, + extract_cpp, + extract_csproj, + extract_dm, + extract_dmf, + extract_dmi, + extract_dmm, + extract_fortran, + extract_go, + extract_groovy, + extract_js, + extract_julia, + extract_kotlin, + extract_php, extract_powershell, + extract_razor, + extract_ruby, + extract_csharp, + extract_scala, + extract_sln, + extract_swift, ) FIXTURES = Path(__file__).parent / "fixtures" @@ -17,14 +33,17 @@ def _labels(r): return [n["label"] for n in r["nodes"]] + def _relations(r): return {e["relation"] for e in r["edges"]} + def _calls(r): node_by_id = {n["id"]: n["label"] for n in r["nodes"]} return { (node_by_id.get(e["source"], e["source"]), node_by_id.get(e["target"], e["target"])) - for e in r["edges"] if e["relation"] == "calls" + for e in r["edges"] + if e["relation"] == "calls" } @@ -36,7 +55,8 @@ def _references(r): node_by_id.get(e["target"], e["target"]), e, ) - for e in r["edges"] if e["relation"] == "references" + for e in r["edges"] + if e["relation"] == "references" ] @@ -63,29 +83,36 @@ def _edge_labels(result: dict, relation: str, context: str | None = None) -> set continue if context is not None and edge.get("context") != context: continue - pairs.add((labels.get(edge["source"], edge["source"]), labels.get(edge["target"], edge["target"]))) + pairs.add( + (labels.get(edge["source"], edge["source"]), labels.get(edge["target"], edge["target"])) + ) return pairs # ── Java ────────────────────────────────────────────────────────────────────── + def test_java_no_error(): r = extract_java(FIXTURES / "sample.java") assert "error" not in r + def test_java_finds_class(): r = extract_java(FIXTURES / "sample.java") - assert any("DataProcessor" in l for l in _labels(r)) + assert any("DataProcessor" in label for label in _labels(r)) + def test_java_finds_interface(): r = extract_java(FIXTURES / "sample.java") - assert any("Processor" in l for l in _labels(r)) + assert any("Processor" in label for label in _labels(r)) + def test_java_finds_methods(): r = extract_java(FIXTURES / "sample.java") labels = _labels(r) - assert any("addItem" in l for l in labels) - assert any("process" in l for l in labels) + assert any("addItem" in label for label in labels) + assert any("process" in label for label in labels) + def test_java_finds_imports(): r = extract_java(FIXTURES / "sample.java") @@ -98,6 +125,7 @@ def test_java_import_edges_have_import_context(): assert import_edges assert all(e.get("context") == "import" for e in import_edges) + def test_java_no_dangling_edges(): r = extract_java(FIXTURES / "sample.java") node_ids = {n["id"] for n in r["nodes"]} @@ -107,24 +135,29 @@ def test_java_no_dangling_edges(): # ── C ──────────────────────────────────────────────────────────────────────── + def test_c_no_error(): r = extract_c(FIXTURES / "sample.c") assert "error" not in r + def test_c_finds_functions(): r = extract_c(FIXTURES / "sample.c") labels = _labels(r) - assert any("process" in l for l in labels) - assert any("main" in l for l in labels) + assert any("process" in label for label in labels) + assert any("main" in label for label in labels) + def test_c_finds_includes(): r = extract_c(FIXTURES / "sample.c") assert "imports" in _relations(r) + def test_c_emits_calls(): r = extract_c(FIXTURES / "sample.c") assert any(e["relation"] == "calls" for e in r["edges"]) + def test_c_calls_are_extracted(): r = extract_c(FIXTURES / "sample.c") for e in r["edges"]: @@ -154,19 +187,23 @@ def test_c_call_edges_have_call_context(): # ── C++ ─────────────────────────────────────────────────────────────────────── + def test_cpp_no_error(): r = extract_cpp(FIXTURES / "sample.cpp") assert "error" not in r + def test_cpp_finds_class(): r = extract_cpp(FIXTURES / "sample.cpp") - assert any("HttpClient" in l for l in _labels(r)) + assert any("HttpClient" in label for label in _labels(r)) + def test_cpp_finds_methods(): r = extract_cpp(FIXTURES / "sample.cpp") labels = _labels(r) # C++ extractor captures the constructor and public-visible methods - assert any("HttpClient" in l for l in labels) + assert any("HttpClient" in label for label in labels) + def test_cpp_finds_includes(): r = extract_cpp(FIXTURES / "sample.cpp") @@ -200,7 +237,8 @@ def test_cpp_class_inherits_edge(): found = any( "AuthedHttpClient" in node_by_id.get(e["source"], "") and "HttpClient" in node_by_id.get(e["target"], "") - for e in r["edges"] if e["relation"] == "inherits" + for e in r["edges"] + if e["relation"] == "inherits" ) assert found, "AuthedHttpClient should have inherits edge to HttpClient" @@ -212,67 +250,80 @@ def test_cpp_struct_inherits_edge(): found = any( "RetryingHttpClient" in node_by_id.get(e["source"], "") and "HttpClient" in node_by_id.get(e["target"], "") - for e in r["edges"] if e["relation"] == "inherits" + for e in r["edges"] + if e["relation"] == "inherits" ) assert found, "RetryingHttpClient (struct) should have inherits edge to HttpClient" # ── Ruby ───────────────────────────────────────────────────────────────────── + def test_ruby_no_error(): r = extract_ruby(FIXTURES / "sample.rb") assert "error" not in r + def test_ruby_finds_class(): r = extract_ruby(FIXTURES / "sample.rb") - assert any("ApiClient" in l for l in _labels(r)) + assert any("ApiClient" in label for label in _labels(r)) + def test_ruby_finds_methods(): r = extract_ruby(FIXTURES / "sample.rb") labels = _labels(r) - assert any("get" in l for l in labels) - assert any("post" in l for l in labels) + assert any("get" in label for label in labels) + assert any("post" in label for label in labels) + def test_ruby_finds_function(): r = extract_ruby(FIXTURES / "sample.rb") - assert any("parse_response" in l for l in _labels(r)) + assert any("parse_response" in label for label in _labels(r)) # ── C# ─────────────────────────────────────────────────────────────────────── + def test_csharp_no_error(): r = extract_csharp(FIXTURES / "sample.cs") assert "error" not in r + def test_csharp_finds_class(): r = extract_csharp(FIXTURES / "sample.cs") - assert any("DataProcessor" in l for l in _labels(r)) + assert any("DataProcessor" in label for label in _labels(r)) + def test_csharp_finds_interface(): r = extract_csharp(FIXTURES / "sample.cs") - assert any("IProcessor" in l for l in _labels(r)) + assert any("IProcessor" in label for label in _labels(r)) + def test_csharp_finds_methods(): r = extract_csharp(FIXTURES / "sample.cs") labels = _labels(r) - assert any("Process" in l for l in labels) + assert any("Process" in label for label in labels) + def test_csharp_finds_usings(): r = extract_csharp(FIXTURES / "sample.cs") assert "imports" in _relations(r) + def test_csharp_inherits_edge(): r = extract_csharp(FIXTURES / "sample.cs") inherits = [e for e in r["edges"] if e["relation"] == "inherits"] assert len(inherits) >= 1 + def test_csharp_implements_iprocessor(): r = extract_csharp(FIXTURES / "sample.cs") node_by_id = {n["id"]: n["label"] for n in r["nodes"]} found = any( - "DataProcessor" in node_by_id.get(e["source"], "") and - "IProcessor" in node_by_id.get(e["target"], "") - for e in r["edges"] if e["relation"] == "implements" + "DataProcessor" in node_by_id.get(e["source"], "") + and "IProcessor" in node_by_id.get(e["target"], "") + for e in r["edges"] + if e["relation"] == "implements" ) assert found, "DataProcessor should have implements edge to IProcessor" @@ -320,7 +371,8 @@ def test_csharp_call_edges_have_call_context(): "Process" in node_by_id.get(e["source"], "") and "Validate" in node_by_id.get(e["target"], "") and e.get("context") == "call" - for e in r["edges"] if e["relation"] == "calls" + for e in r["edges"] + if e["relation"] == "calls" ), "C# call edges should retain call context" @@ -333,27 +385,33 @@ def test_csharp_import_edges_have_import_context(): # ── Kotlin ─────────────────────────────────────────────────────────────────── + def test_kotlin_no_error(): r = extract_kotlin(FIXTURES / "sample.kt") assert "error" not in r + def test_kotlin_finds_class(): r = extract_kotlin(FIXTURES / "sample.kt") - assert any("HttpClient" in l for l in _labels(r)) + assert any("HttpClient" in label for label in _labels(r)) + def test_kotlin_finds_data_class(): r = extract_kotlin(FIXTURES / "sample.kt") - assert any("Config" in l for l in _labels(r)) + assert any("Config" in label for label in _labels(r)) + def test_kotlin_finds_methods(): r = extract_kotlin(FIXTURES / "sample.kt") labels = _labels(r) - assert any("get" in l for l in labels) - assert any("post" in l for l in labels) + assert any("get" in label for label in labels) + assert any("post" in label for label in labels) + def test_kotlin_finds_function(): r = extract_kotlin(FIXTURES / "sample.kt") - assert any("createClient" in l for l in _labels(r)) + assert any("createClient" in label for label in _labels(r)) + def test_kotlin_emits_in_file_calls(): """Regression test for the call-walker `simple_identifier` / @@ -384,23 +442,27 @@ def test_kotlin_parameter_return_generic_and_field_contexts(): # ── Scala ───────────────────────────────────────────────────────────────────── + def test_scala_no_error(): r = extract_scala(FIXTURES / "sample.scala") assert "error" not in r + def test_scala_finds_class(): r = extract_scala(FIXTURES / "sample.scala") - assert any("HttpClient" in l for l in _labels(r)) + assert any("HttpClient" in label for label in _labels(r)) + def test_scala_finds_object(): r = extract_scala(FIXTURES / "sample.scala") - assert any("HttpClientFactory" in l for l in _labels(r)) + assert any("HttpClientFactory" in label for label in _labels(r)) + def test_scala_finds_methods(): r = extract_scala(FIXTURES / "sample.scala") labels = _labels(r) - assert any("get" in l for l in labels) - assert any("post" in l for l in labels) + assert any("get" in label for label in labels) + assert any("post" in label for label in labels) def test_scala_import_edges_have_import_context(): @@ -440,23 +502,28 @@ def test_scala_call_edges_have_call_context(): # ── PHP ─────────────────────────────────────────────────────────────────────── + def test_php_no_error(): r = extract_php(FIXTURES / "sample.php") assert "error" not in r + def test_php_finds_class(): r = extract_php(FIXTURES / "sample.php") - assert any("ApiClient" in l for l in _labels(r)) + assert any("ApiClient" in label for label in _labels(r)) + def test_php_finds_methods(): r = extract_php(FIXTURES / "sample.php") labels = _labels(r) - assert any("get" in l for l in labels) - assert any("post" in l for l in labels) + assert any("get" in label for label in labels) + assert any("post" in label for label in labels) + def test_php_finds_function(): r = extract_php(FIXTURES / "sample.php") - assert any("parseResponse" in l for l in _labels(r)) + assert any("parseResponse" in label for label in _labels(r)) + def test_php_finds_imports(): r = extract_php(FIXTURES / "sample.php") @@ -476,55 +543,67 @@ def test_php_call_edges_have_call_context(): assert call_edges assert all(e.get("context") == "call" for e in call_edges) + def test_php_finds_static_property_access(): r = extract_php(FIXTURES / "sample_php_static_prop.php") assert "uses_static_prop" in _relations(r) + def test_php_static_prop_target_is_holding_class(): r = extract_php(FIXTURES / "sample_php_static_prop.php") node_by_id = {n["id"]: n["label"] for n in r["nodes"]} uses_prop = [ (node_by_id.get(e["source"], e["source"]), node_by_id.get(e["target"], e["target"])) - for e in r["edges"] if e["relation"] == "uses_static_prop" + for e in r["edges"] + if e["relation"] == "uses_static_prop" ] assert any("DefaultPalette" in tgt for _, tgt in uses_prop) + def test_php_finds_config_helper_call(): r = extract_php(FIXTURES / "sample_php_config.php") assert "uses_config" in _relations(r) + def test_php_config_helper_target_matches_first_segment(): r = extract_php(FIXTURES / "sample_php_config.php") node_by_id = {n["id"]: n["label"] for n in r["nodes"]} uses_cfg = [ (node_by_id.get(e["source"], e["source"]), node_by_id.get(e["target"], e["target"])) - for e in r["edges"] if e["relation"] == "uses_config" + for e in r["edges"] + if e["relation"] == "uses_config" ] assert any("Throttle" in tgt for _, tgt in uses_cfg) + def test_php_finds_container_bind(): r = extract_php(FIXTURES / "sample_php_container.php") assert "bound_to" in _relations(r) + def test_php_container_bind_links_contract_to_implementation(): r = extract_php(FIXTURES / "sample_php_container.php") node_by_id = {n["id"]: n["label"] for n in r["nodes"]} bound = [ (node_by_id.get(e["source"], e["source"]), node_by_id.get(e["target"], e["target"])) - for e in r["edges"] if e["relation"] == "bound_to" + for e in r["edges"] + if e["relation"] == "bound_to" ] assert any("PaymentGateway" in src and "StripeGateway" in tgt for src, tgt in bound) + def test_php_finds_event_listeners(): r = extract_php(FIXTURES / "sample_php_listen.php") assert "listened_by" in _relations(r) + def test_php_event_listener_links_event_to_listener(): r = extract_php(FIXTURES / "sample_php_listen.php") node_by_id = {n["id"]: n["label"] for n in r["nodes"]} listened = [ (node_by_id.get(e["source"], e["source"]), node_by_id.get(e["target"], e["target"])) - for e in r["edges"] if e["relation"] == "listened_by" + for e in r["edges"] + if e["relation"] == "listened_by" ] assert any("UserRegistered" in src and "SendWelcomeEmail" in tgt for src, tgt in listened) @@ -545,31 +624,38 @@ def test_php_property_parameter_and_return_contexts(): # ── Swift ──────────────────────────────────────────────────────────────────── + def test_swift_no_error(): r = extract_swift(FIXTURES / "sample.swift") assert "error" not in r + def test_swift_finds_class(): r = extract_swift(FIXTURES / "sample.swift") - assert any("DataProcessor" in l for l in _labels(r)) + assert any("DataProcessor" in label for label in _labels(r)) + def test_swift_finds_protocol(): r = extract_swift(FIXTURES / "sample.swift") - assert any("Processor" in l for l in _labels(r)) + assert any("Processor" in label for label in _labels(r)) + def test_swift_finds_struct(): r = extract_swift(FIXTURES / "sample.swift") - assert any("Config" in l for l in _labels(r)) + assert any("Config" in label for label in _labels(r)) + def test_swift_finds_methods(): r = extract_swift(FIXTURES / "sample.swift") labels = _labels(r) - assert any("addItem" in l for l in labels) - assert any("process" in l for l in labels) + assert any("addItem" in label for label in labels) + assert any("process" in label for label in labels) + def test_swift_finds_function(): r = extract_swift(FIXTURES / "sample.swift") - assert any("createProcessor" in l for l in _labels(r)) + assert any("createProcessor" in label for label in _labels(r)) + def test_swift_finds_imports(): r = extract_swift(FIXTURES / "sample.swift") @@ -582,42 +668,51 @@ def test_swift_import_edges_have_import_context(): assert import_edges assert all(e.get("context") == "import" for e in import_edges) + def test_swift_no_dangling_edges(): r = extract_swift(FIXTURES / "sample.swift") node_ids = {n["id"] for n in r["nodes"]} for e in r["edges"]: assert e["source"] in node_ids + def test_swift_finds_actor(): r = extract_swift(FIXTURES / "sample.swift") - assert any("CacheManager" in l for l in _labels(r)) + assert any("CacheManager" in label for label in _labels(r)) + def test_swift_finds_enum(): r = extract_swift(FIXTURES / "sample.swift") - assert any("NetworkError" in l for l in _labels(r)) + assert any("NetworkError" in label for label in _labels(r)) + def test_swift_finds_enum_methods(): r = extract_swift(FIXTURES / "sample.swift") - assert any("describe" in l for l in _labels(r)) + assert any("describe" in label for label in _labels(r)) + def test_swift_finds_enum_cases(): r = extract_swift(FIXTURES / "sample.swift") labels = _labels(r) - assert any("timeout" in l for l in labels) - assert any("connectionFailed" in l for l in labels) + assert any("timeout" in label for label in labels) + assert any("connectionFailed" in label for label in labels) + def test_swift_enum_cases_have_case_of_edge(): r = extract_swift(FIXTURES / "sample.swift") case_edges = [e for e in r["edges"] if e["relation"] == "case_of"] assert len(case_edges) >= 2 + def test_swift_finds_deinit(): r = extract_swift(FIXTURES / "sample.swift") - assert any("deinit" in l for l in _labels(r)) + assert any("deinit" in label for label in _labels(r)) + def test_swift_finds_subscript(): r = extract_swift(FIXTURES / "sample.swift") - assert any("subscript" in l for l in _labels(r)) + assert any("subscript" in label for label in _labels(r)) + def test_swift_extension_methods_attach_to_type(): r = extract_swift(FIXTURES / "sample.swift") @@ -632,6 +727,7 @@ def test_swift_extension_methods_attach_to_type(): break assert found, "extension method isValid should attach to Config" + def test_swift_extension_does_not_duplicate_type_node(): r = extract_swift(FIXTURES / "sample.swift") config_nodes = [n for n in r["nodes"] if n["label"] == "Config"] @@ -660,11 +756,13 @@ def test_swift_parameter_return_generic_and_field_contexts(): assert ("run", "DataProcessor") in _edge_labels(r, "references", "generic_arg") assert ("DataProcessor", "Result") in _edge_labels(r, "references", "field") + def test_swift_emits_calls(): r = extract_swift(FIXTURES / "sample.swift") calls = _calls(r) assert any("process" in src and "validate" in tgt for src, tgt in calls) + def test_swift_call_edges_have_call_context(): r = extract_swift(FIXTURES / "sample.swift") call_edges = _edges_with_relation(r, "calls") @@ -678,36 +776,45 @@ def test_swift_extension_across_files_merges_into_canonical_type(): node ids carry the file stem, so without a corpus-level merge each file would emit its own Foo.""" from graphify.extract import extract + paths = sorted((FIXTURES / "swift_cross_file").glob("*.swift")) r = extract(paths, cache_root=Path("/tmp/graphify-test-no-cache")) foo_nodes = [n for n in r["nodes"] if n["label"] == "Foo"] - assert len(foo_nodes) == 1, f"Foo should appear once, got {len(foo_nodes)}: {[n['id'] for n in foo_nodes]}" + assert len(foo_nodes) == 1, ( + f"Foo should appear once, got {len(foo_nodes)}: {[n['id'] for n in foo_nodes]}" + ) foo_id = foo_nodes[0]["id"] method_targets = { - e["target"] for e in r["edges"] - if e["relation"] == "method" and e["source"] == foo_id + e["target"] for e in r["edges"] if e["relation"] == "method" and e["source"] == foo_id } method_labels = {n["label"] for n in r["nodes"] if n["id"] in method_targets} - assert any("one" in l for l in method_labels), f"one() should attach to Foo, got {method_labels}" - assert any("two" in l for l in method_labels), f"extension method two() should attach to Foo, got {method_labels}" + assert any("one" in label for label in method_labels), ( + f"one() should attach to Foo, got {method_labels}" + ) + assert any("two" in label for label in method_labels), ( + f"extension method two() should attach to Foo, got {method_labels}" + ) # ── Elixir ──────────────────────────────────────────────────────────────────── -from graphify.extract import extract_elixir +from graphify.extract import extract_elixir # noqa: E402 + def test_elixir_finds_module(): r = extract_elixir(FIXTURES / "sample.ex") assert "error" not in r labels = [n["label"] for n in r["nodes"]] - assert any("MyApp.Accounts.User" in l for l in labels) + assert any("MyApp.Accounts.User" in label for label in labels) + def test_elixir_finds_functions(): r = extract_elixir(FIXTURES / "sample.ex") labels = [n["label"] for n in r["nodes"]] - assert any("create" in l for l in labels) - assert any("find" in l for l in labels) - assert any("validate" in l for l in labels) + assert any("create" in label for label in labels) + assert any("find" in label for label in labels) + assert any("validate" in label for label in labels) + def test_elixir_finds_imports(): r = extract_elixir(FIXTURES / "sample.ex") @@ -721,11 +828,14 @@ def test_elixir_import_edges_have_import_context(): assert import_edges assert all(e.get("context") == "import" for e in import_edges) + def test_elixir_finds_calls(): r = extract_elixir(FIXTURES / "sample.ex") calls = {(e["source"], e["target"]) for e in r["edges"] if e["relation"] == "calls"} labels = {n["id"]: n["label"] for n in r["nodes"]} - assert any("create" in labels.get(src, "") and "validate" in labels.get(tgt, "") for src, tgt in calls) + assert any( + "create" in labels.get(src, "") and "validate" in labels.get(tgt, "") for src, tgt in calls + ) def test_elixir_call_edges_have_call_context(): @@ -734,6 +844,7 @@ def test_elixir_call_edges_have_call_context(): assert call_edges assert all(e.get("context") == "call" for e in call_edges) + def test_elixir_method_edges(): r = extract_elixir(FIXTURES / "sample.ex") methods = [e for e in r["edges"] if e["relation"] == "method"] @@ -741,7 +852,7 @@ def test_elixir_method_edges(): # ── Objective-C ────────────────────────────────────────────────────────────── -from graphify.extract import extract_objc +from graphify.extract import extract_objc # noqa: E402 def test_objc_finds_interface(): @@ -759,7 +870,7 @@ def test_objc_finds_subclass(): def test_objc_finds_methods(): r = extract_objc(FIXTURES / "sample.m") labels = [n["label"] for n in r["nodes"]] - assert any("speak" in l or "fetch" in l or "initWithName" in l for l in labels) + assert any("speak" in label or "fetch" in label or "initWithName" in label for label in labels) def test_objc_finds_imports(): @@ -804,6 +915,7 @@ def test_objc_no_dangling_edges(): # Go # --------------------------------------------------------------------------- + def test_go_receiver_methods_share_type_node(): """Methods on the same receiver type must share one canonical type node.""" r = extract_go(FIXTURES / "sample.go") @@ -811,6 +923,7 @@ def test_go_receiver_methods_share_type_node(): # Both Start() and Stop() are on *Server — should produce exactly one Server node assert len(server_nodes) == 1 + def test_go_receiver_uses_pkg_scope(): """Type node id should be scoped to directory, not file stem.""" r = extract_go(FIXTURES / "sample.go") @@ -824,6 +937,7 @@ def test_go_receiver_uses_pkg_scope(): # Julia # --------------------------------------------------------------------------- + def test_julia_finds_module(): r = extract_julia(FIXTURES / "sample.jl") labels = [n["label"] for n in r["nodes"]] @@ -846,14 +960,14 @@ def test_julia_finds_abstract_type(): def test_julia_finds_functions(): r = extract_julia(FIXTURES / "sample.jl") labels = [n["label"] for n in r["nodes"]] - assert any("area" in l for l in labels) - assert any("distance" in l for l in labels) + assert any("area" in label for label in labels) + assert any("distance" in label for label in labels) def test_julia_finds_short_function(): r = extract_julia(FIXTURES / "sample.jl") labels = [n["label"] for n in r["nodes"]] - assert any("perimeter" in l for l in labels) + assert any("perimeter" in label for label in labels) def test_julia_finds_imports(): @@ -910,6 +1024,7 @@ def test_julia_no_dangling_edges(): # ── Fortran extractor ──────────────────────────────────────────────────────── + def test_fortran_finds_module(): r = extract_fortran(FIXTURES / "sample.f90") assert "error" not in r @@ -920,14 +1035,14 @@ def test_fortran_finds_module(): def test_fortran_finds_subroutines(): r = extract_fortran(FIXTURES / "sample.f90") labels = [n["label"] for n in r["nodes"]] - assert any("circle_area" in l for l in labels) - assert any("print_area" in l for l in labels) + assert any("circle_area" in label for label in labels) + assert any("print_area" in label for label in labels) def test_fortran_finds_function(): r = extract_fortran(FIXTURES / "sample.f90") labels = [n["label"] for n in r["nodes"]] - assert any("distance" in l for l in labels) + assert any("distance" in label for label in labels) def test_fortran_finds_program(): @@ -957,7 +1072,11 @@ def test_fortran_finds_calls(): def test_fortran_case_insensitive_names(): r = extract_fortran(FIXTURES / "sample.f90") labels = [n["label"] for n in r["nodes"]] - assert all(l == l.lower() or "(" in l for l in labels if l.endswith(("()", "")) and not "." in l) + assert all( + label == label.lower() or "(" in label + for label in labels + if label.endswith(("()", "")) and "." not in label + ) assert "geometry" in labels assert "main" in labels @@ -986,7 +1105,7 @@ def test_fortran_capital_F_parses_preprocessed(): assert "error" not in r labels = [n["label"] for n in r["nodes"]] assert "shapes" in labels - assert any("compute_volume" in l for l in labels) + assert any("compute_volume" in label for label in labels) # ── PowerShell ─────────────────────────────────────────────────────────────── @@ -1000,7 +1119,7 @@ def test_powershell_finds_class_and_method(): r = extract_powershell(FIXTURES / "sample.ps1") labels = [n["label"] for n in r["nodes"]] assert "DataProcessor" in labels - assert any("Transform" in l for l in labels) + assert any("Transform" in label for label in labels) def test_powershell_property_field_type_context(): @@ -1017,10 +1136,12 @@ def test_powershell_method_parameter_and_return_type_contexts(): # ── TypeScript dynamic imports ─────────────────────────────────────────────── + def test_ts_dynamic_import_no_error(): r = extract_js(FIXTURES / "dynamic_import.ts") assert "error" not in r + def test_ts_dynamic_import_extracts_edges(): """Dynamic import() calls inside functions should produce imports_from edges.""" r = extract_js(FIXTURES / "dynamic_import.ts") @@ -1028,56 +1149,71 @@ def test_ts_dynamic_import_extracts_edges(): targets = {e["target"] for e in dyn_edges} # Should find: static ./logger, dynamic ./mayaEngine.js, dynamic ./queue.js assert any("logger" in t for t in targets), f"Missing static import of logger: {targets}" - assert any("mayaengine" in t.lower() for t in targets), f"Missing dynamic import of mayaEngine: {targets}" + assert any("mayaengine" in t.lower() for t in targets), ( + f"Missing dynamic import of mayaEngine: {targets}" + ) assert any("queue" in t.lower() for t in targets), f"Missing dynamic import of queue: {targets}" + def test_ts_dynamic_import_confidence(): """Dynamic imports should have EXTRACTED confidence (they are deterministic string literals).""" r = extract_js(FIXTURES / "dynamic_import.ts") - dyn_edges = [e for e in r["edges"] - if e["relation"] == "imports_from" - and "mayaengine" in e["target"].lower()] + dyn_edges = [ + e + for e in r["edges"] + if e["relation"] == "imports_from" and "mayaengine" in e["target"].lower() + ] assert len(dyn_edges) >= 1 assert dyn_edges[0]["confidence"] == "EXTRACTED" + def test_ts_dynamic_import_source_is_function(): """Dynamic import edge source should be the enclosing function, not the file.""" r = extract_js(FIXTURES / "dynamic_import.ts") node_labels = {n["id"]: n["label"] for n in r["nodes"]} - dyn_edges = [e for e in r["edges"] - if e["relation"] == "imports_from" - and "mayaengine" in e["target"].lower()] + dyn_edges = [ + e + for e in r["edges"] + if e["relation"] == "imports_from" and "mayaengine" in e["target"].lower() + ] assert len(dyn_edges) >= 1 src_label = node_labels.get(dyn_edges[0]["source"], "") assert "processInbound" in src_label, f"Expected processInbound as source, got {src_label}" + def test_ts_no_dynamic_import_in_sync_fn(): """Functions without dynamic imports should not get spurious imports_from edges.""" r = extract_js(FIXTURES / "dynamic_import.ts") node_ids = {n["label"]: n["id"] for n in r["nodes"]} sync_nid = node_ids.get("syncOnly()") if sync_nid: - sync_imports = [e for e in r["edges"] - if e["source"] == sync_nid and e["relation"] == "imports_from"] + sync_imports = [ + e for e in r["edges"] if e["source"] == sync_nid and e["relation"] == "imports_from" + ] assert len(sync_imports) == 0 + def test_ts_dynamic_template_literal_skipped(): """Dynamic template literals (with ${}) must not produce an imports_from edge.""" r = extract_js(FIXTURES / "dynamic_import.ts") targets = {e["target"] for e in r["edges"] if e["relation"] == "imports_from"} # loadHandler uses `./handlers/${handlerName}` — no static path, must be absent - assert not any("handler" in t.lower() and "$" in t for t in targets), \ + assert not any("handler" in t.lower() and "$" in t for t in targets), ( f"Garbage edge from dynamic template literal found: {targets}" + ) # More robust: no target should contain a brace character - assert not any("{" in t or "}" in t for t in targets), \ + assert not any("{" in t or "}" in t for t in targets), ( f"Target contains unresolved template expression: {targets}" + ) + def test_ts_static_template_literal_resolved(): """Static template literals (no ${}) should resolve the same as a plain string.""" r = extract_js(FIXTURES / "dynamic_import.ts") targets = {e["target"] for e in r["edges"] if e["relation"] == "imports_from"} - assert any("statichelper" in t.lower() for t in targets), \ + assert any("statichelper" in t.lower() for t in targets), ( f"Static template literal import not resolved: {targets}" + ) def test_js_local_const_does_not_emit_phantom_node(tmp_path): @@ -1151,25 +1287,29 @@ def test_ts_local_const_does_not_emit_phantom_node(tmp_path): # ── Markdown ───────────────────────────────────────────────────────────────── -from graphify.extract import extract_markdown +from graphify.extract import extract_markdown # noqa: E402 + def test_markdown_no_error(): r = extract_markdown(FIXTURES / "deploy_guide.md") assert "error" not in r + def test_markdown_finds_headings(): r = extract_markdown(FIXTURES / "deploy_guide.md") labels = _labels(r) - assert any("Deploy Guide" in l for l in labels) - assert any("Prerequisites" in l for l in labels) - assert any("Full Deploy" in l for l in labels) - assert any("Rollback" in l for l in labels) + assert any("Deploy Guide" in label for label in labels) + assert any("Prerequisites" in label for label in labels) + assert any("Full Deploy" in label for label in labels) + assert any("Rollback" in label for label in labels) + def test_markdown_finds_nested_heading(): """### Database Migration is nested under ## Full Deploy.""" r = extract_markdown(FIXTURES / "deploy_guide.md") labels = _labels(r) - assert any("Database Migration" in l for l in labels) + assert any("Database Migration" in label for label in labels) + def test_markdown_skips_fenced_code_blocks(): """Fenced code blocks should NOT emit nodes (#1077). @@ -1225,6 +1365,7 @@ def test_markdown_fenced_heading_not_parsed(): assert not any("Not A Heading" in l for l in labels), \ f"fenced '## Not A Heading' was incorrectly parsed as a node: {labels}" + def test_markdown_no_dangling_edges(): r = extract_markdown(FIXTURES / "deploy_guide.md") node_ids = {n["id"] for n in r["nodes"]} @@ -1242,14 +1383,14 @@ def test_groovy_no_error(): def test_groovy_finds_class(): r = extract_groovy(FIXTURES / "sample.groovy") - assert any("SampleService" in l for l in _labels(r)) + assert any("SampleService" in label for label in _labels(r)) def test_groovy_finds_methods(): r = extract_groovy(FIXTURES / "sample.groovy") labels = _labels(r) - assert any("process" in l for l in labels) - assert any("reset" in l for l in labels) + assert any("process" in label for label in labels) + assert any("reset" in label for label in labels) def test_groovy_finds_imports(): @@ -1273,18 +1414,18 @@ def test_groovy_no_dangling_edges(): def test_groovy_spock_finds_class(): r = extract_groovy(FIXTURES / "sample_spock.groovy") - assert any("SampleSpec" in l for l in _labels(r)) + assert any("SampleSpec" in label for label in _labels(r)) def test_groovy_spock_finds_feature_methods(): r = extract_groovy(FIXTURES / "sample_spock.groovy") - feature_labels = [l for l in _labels(r) if l.startswith('"')] + feature_labels = [label for label in _labels(r) if label.startswith('"')] assert len(feature_labels) >= 2 def test_groovy_spock_finds_method_with_apostrophe(): r = extract_groovy(FIXTURES / "sample_spock.groovy") - assert any("it's" in l for l in _labels(r)) + assert any("it's" in label for label in _labels(r)) def test_groovy_spock_preserves_import_edges(): @@ -1308,8 +1449,8 @@ def test_dm_no_error(): def test_dm_finds_global_proc(): r = extract_dm(FIXTURES / "sample.dm") labels = _labels(r) - assert any(l == "log_event()" for l in labels) - assert any(l == "RunTest()" for l in labels) + assert any(label == "log_event()" for label in labels) + assert any(label == "RunTest()" for label in labels) def test_dm_finds_type_definition(): r = extract_dm(FIXTURES / "sample.dm") @@ -1389,7 +1530,7 @@ def test_dmi_no_error(): def test_dmi_emits_state_nodes(): r = extract_dmi(FIXTURES / "sample.dmi") labels = _labels(r) - assert any(l == '"mob"' for l in labels) + assert any(label == '"mob"' for label in labels) def test_dmi_state_contained_by_file(): r = extract_dmi(FIXTURES / "sample.dmi") @@ -1463,68 +1604,83 @@ def test_dmf_no_dangling_edges(): # -- .NET project files (.sln, .csproj, .razor) ------------------------------- + def test_sln_no_error(): r = extract_sln(FIXTURES / "sample.sln") assert "error" not in r + def test_sln_finds_projects(): r = extract_sln(FIXTURES / "sample.sln") labels = _labels(r) - assert any("WebApi" in l for l in labels) - assert any("Domain" in l for l in labels) + assert any("WebApi" in label for label in labels) + assert any("Domain" in label for label in labels) + def test_sln_contains_edges(): r = extract_sln(FIXTURES / "sample.sln") assert "contains" in _relations(r) + def test_sln_project_dependency_edges(): r = extract_sln(FIXTURES / "sample.sln") assert "imports" in _relations(r) + def test_csproj_no_error(): r = extract_csproj(FIXTURES / "sample.csproj") assert "error" not in r + def test_csproj_finds_packages(): r = extract_csproj(FIXTURES / "sample.csproj") labels = _labels(r) - assert any("MediatR" in l for l in labels) - assert any("FluentValidation" in l for l in labels) + assert any("MediatR" in label for label in labels) + assert any("FluentValidation" in label for label in labels) + def test_csproj_finds_project_references(): r = extract_csproj(FIXTURES / "sample.csproj") labels = _labels(r) - assert any("Domain.csproj" in l for l in labels) + assert any("Domain.csproj" in label for label in labels) + def test_csproj_finds_target_framework(): r = extract_csproj(FIXTURES / "sample.csproj") - assert any("net8.0" in l for l in _labels(r)) + assert any("net8.0" in label for label in _labels(r)) + def test_csproj_finds_sdk(): r = extract_csproj(FIXTURES / "sample.csproj") - assert any("Microsoft.NET.Sdk.Web" in l for l in _labels(r)) + assert any("Microsoft.NET.Sdk.Web" in label for label in _labels(r)) + def test_razor_no_error(): r = extract_razor(FIXTURES / "sample.razor") assert "error" not in r + def test_razor_finds_using_directives(): r = extract_razor(FIXTURES / "sample.razor") assert "imports" in _relations(r) + def test_razor_finds_component_references(): r = extract_razor(FIXTURES / "sample.razor") assert "calls" in _relations(r) + def test_razor_finds_inherits(): r = extract_razor(FIXTURES / "sample.razor") assert "inherits" in _relations(r) + def test_razor_finds_code_block_methods(): r = extract_razor(FIXTURES / "sample.razor") labels = _labels(r) - assert any("IncrementCount" in l for l in labels) - assert any("LoadData" in l for l in labels) + assert any("IncrementCount" in label for label in labels) + assert any("LoadData" in label for label in labels) + def test_razor_no_dangling_edges(): r = extract_razor(FIXTURES / "sample.razor") diff --git a/tests/test_llm_backends.py b/tests/test_llm_backends.py index d2ee058c4..7cd670e83 100644 --- a/tests/test_llm_backends.py +++ b/tests/test_llm_backends.py @@ -224,6 +224,7 @@ def test_response_is_hollow_accepts_real_extraction(): def _fake_openai_response(content, *, finish_reason="stop", prompt_tokens=100, completion_tokens=0): """Build a minimal stand-in for an `openai` SDK ChatCompletion response.""" + class _Usage: def __init__(self): self.prompt_tokens = prompt_tokens @@ -257,11 +258,12 @@ class _FakeOpenAI: def __init__(self, *_, **__): self.chat = self self.completions = self + def create(self, **__): return fake_resp fake_module = types.ModuleType("openai") - fake_module.OpenAI = _FakeOpenAI + setattr(fake_module, "OpenAI", _FakeOpenAI) monkeypatch.setitem(sys.modules, "openai", fake_module) @@ -274,8 +276,13 @@ def test_call_openai_compat_relabels_empty_content_as_length(monkeypatch): _install_fake_openai(monkeypatch, fake_resp) result = llm._call_openai_compat( - "http://localhost:11434/v1", "ollama", "qwen2.5-coder:7b", - "user msg", temperature=0, max_completion_tokens=8192, backend="ollama", + "http://localhost:11434/v1", + "ollama", + "qwen2.5-coder:7b", + "user msg", + temperature=0, + max_completion_tokens=8192, + backend="ollama", ) assert result["finish_reason"] == "length", ( "empty content from a 'successful' call must be re-labelled so the " @@ -288,8 +295,13 @@ def test_call_openai_compat_relabels_none_content_as_length(monkeypatch): _install_fake_openai(monkeypatch, fake_resp) result = llm._call_openai_compat( - "http://localhost:11434/v1", "ollama", "qwen2.5-coder:7b", - "u", temperature=0, max_completion_tokens=8192, backend="ollama", + "http://localhost:11434/v1", + "ollama", + "qwen2.5-coder:7b", + "u", + temperature=0, + max_completion_tokens=8192, + backend="ollama", ) assert result["finish_reason"] == "length" @@ -298,12 +310,19 @@ def test_call_openai_compat_relabels_unparseable_json_as_length(monkeypatch): # A half-generated response: `{"nodes": [{"id":` parses to {} (empty # fragment) via _parse_llm_json's JSONDecodeError fallback. That is also # hollow and must trigger bisection. - fake_resp = _fake_openai_response('{"nodes": [{"id":', finish_reason="stop", completion_tokens=20) + fake_resp = _fake_openai_response( + '{"nodes": [{"id":', finish_reason="stop", completion_tokens=20 + ) _install_fake_openai(monkeypatch, fake_resp) result = llm._call_openai_compat( - "http://localhost:11434/v1", "ollama", "qwen2.5-coder:7b", - "u", temperature=0, max_completion_tokens=8192, backend="ollama", + "http://localhost:11434/v1", + "ollama", + "qwen2.5-coder:7b", + "u", + temperature=0, + max_completion_tokens=8192, + backend="ollama", ) assert result["finish_reason"] == "length" @@ -318,8 +337,13 @@ def test_call_openai_compat_preserves_real_finish_reason(monkeypatch): _install_fake_openai(monkeypatch, fake_resp) result = llm._call_openai_compat( - "http://localhost:11434/v1", "k", "m", - "u", temperature=0, max_completion_tokens=8192, backend="kimi", + "http://localhost:11434/v1", + "k", + "m", + "u", + temperature=0, + max_completion_tokens=8192, + backend="kimi", ) assert result["finish_reason"] == "stop" assert result["nodes"] == [{"id": "a"}] @@ -352,7 +376,7 @@ def create(self, **kwargs): ) fake_module = types.ModuleType("openai") - fake_module.OpenAI = _FakeOpenAI + setattr(fake_module, "OpenAI", _FakeOpenAI) monkeypatch.setitem(sys.modules, "openai", fake_module) return captured @@ -363,8 +387,13 @@ def test_ollama_extra_body_sets_num_ctx_and_keep_alive(monkeypatch): monkeypatch.delenv("GRAPHIFY_OLLAMA_KEEP_ALIVE", raising=False) llm._call_openai_compat( - "http://localhost:11434/v1", "ollama", "qwen2.5-coder:7b", - "user msg", temperature=0, max_completion_tokens=8192, backend="ollama", + "http://localhost:11434/v1", + "ollama", + "qwen2.5-coder:7b", + "user msg", + temperature=0, + max_completion_tokens=8192, + backend="ollama", ) assert "extra_body" in captured, "extra_body must be sent to Ollama" @@ -387,8 +416,13 @@ def test_ollama_num_ctx_scales_with_small_token_budget(monkeypatch): small_chunk_msg = "x" * 32_000 llm._call_openai_compat( - "http://localhost:11434/v1", "ollama", "qwen2.5-coder:7b", - small_chunk_msg, temperature=0, max_completion_tokens=16384, backend="ollama", + "http://localhost:11434/v1", + "ollama", + "qwen2.5-coder:7b", + small_chunk_msg, + temperature=0, + max_completion_tokens=16384, + backend="ollama", ) num_ctx = captured["extra_body"]["options"]["num_ctx"] @@ -407,8 +441,13 @@ def test_ollama_num_ctx_env_override(monkeypatch): monkeypatch.delenv("GRAPHIFY_OLLAMA_KEEP_ALIVE", raising=False) llm._call_openai_compat( - "http://localhost:11434/v1", "ollama", "qwen2.5-coder:7b", - "u", temperature=0, max_completion_tokens=8192, backend="ollama", + "http://localhost:11434/v1", + "ollama", + "qwen2.5-coder:7b", + "u", + temperature=0, + max_completion_tokens=8192, + backend="ollama", ) assert captured["extra_body"]["options"]["num_ctx"] == 65536 @@ -418,8 +457,13 @@ def test_non_ollama_backend_gets_no_num_ctx_extra_body(monkeypatch): captured = _install_capturing_openai(monkeypatch) llm._call_openai_compat( - "https://api.openai.com/v1", "sk-test", "gpt-4.1-mini", - "u", temperature=0, max_completion_tokens=8192, backend="openai", + "https://api.openai.com/v1", + "sk-test", + "gpt-4.1-mini", + "u", + temperature=0, + max_completion_tokens=8192, + backend="openai", ) eb = captured.get("extra_body") @@ -445,8 +489,14 @@ def fake_extract(chunk, *_, **__): with patch("graphify.llm.extract_files_direct", side_effect=fake_extract): with patch("graphify.llm.ThreadPoolExecutor") as mock_pool: result = llm.extract_corpus_parallel( - files, backend="ollama", api_key="ollama", model="qwen2.5-coder:7b", - root=tmp_path, token_budget=None, chunk_size=2, max_concurrency=4, + files, + backend="ollama", + api_key="ollama", + model="qwen2.5-coder:7b", + root=tmp_path, + token_budget=None, + chunk_size=2, + max_concurrency=4, ) mock_pool.assert_not_called() @@ -469,8 +519,14 @@ def test_extract_corpus_parallel_ollama_parallel_env_restores_concurrency(tmp_pa )() try: llm.extract_corpus_parallel( - files, backend="ollama", api_key="ollama", model="m", - root=tmp_path, token_budget=None, chunk_size=2, max_concurrency=4, + files, + backend="ollama", + api_key="ollama", + model="m", + root=tmp_path, + token_budget=None, + chunk_size=2, + max_concurrency=4, ) except Exception: pass # mock scaffolding may not be complete; we only care about the call @@ -496,16 +552,24 @@ def fake_extract(chunk, *_, **__): # Hollow response: looks successful, finish_reason already # rewritten to "length" by _call_openai_compat. return { - "nodes": [], "edges": [], "hyperedges": [], - "input_tokens": 100, "output_tokens": 0, - "model": "m", "finish_reason": "length", + "nodes": [], + "edges": [], + "hyperedges": [], + "input_tokens": 100, + "output_tokens": 0, + "model": "m", + "finish_reason": "length", } return _ok(nodes=[{"id": f.stem} for f in chunk]) with patch("graphify.llm.extract_files_direct", side_effect=fake_extract): result = llm._extract_with_adaptive_retry( - files, backend="ollama", api_key="ollama", model="qwen2.5-coder:7b", - root=tmp_path, max_depth=3, + files, + backend="ollama", + api_key="ollama", + model="qwen2.5-coder:7b", + root=tmp_path, + max_depth=3, ) assert len(result["nodes"]) == 4, ( diff --git a/tests/test_llm_parser.py b/tests/test_llm_parser.py index c643807b9..ad8bd07dd 100644 --- a/tests/test_llm_parser.py +++ b/tests/test_llm_parser.py @@ -11,8 +11,6 @@ import json from unittest.mock import patch -import pytest - from graphify import llm diff --git a/tests/test_mcp_ingest.py b/tests/test_mcp_ingest.py index 8e46b170d..2f447ac54 100644 --- a/tests/test_mcp_ingest.py +++ b/tests/test_mcp_ingest.py @@ -1,10 +1,10 @@ """Tests for graphify.mcp_ingest — MCP config file extraction.""" + from __future__ import annotations import json from pathlib import Path -import pytest from graphify.mcp_ingest import ( MCP_CONFIG_FILENAMES, @@ -29,11 +29,7 @@ def _relations(result): def _label_by_kind(result, kind): - return [ - n["label"] - for n in result["nodes"] - if n.get("metadata", {}).get("mcp_kind") == kind - ] + return [n["label"] for n in result["nodes"] if n.get("metadata", {}).get("mcp_kind") == kind] def _write(tmp_path: Path, name: str, payload) -> Path: @@ -167,13 +163,21 @@ def test_every_edge_has_confidence_score(): def test_same_command_collapses_to_one_node_across_configs(tmp_path): # Two configs both use "npx". The mcp_command node should be shared. - config_a = _write(tmp_path, ".mcp.json", { - "mcpServers": {"a": {"command": "npx", "args": ["@scope/server-a"]}}, - }) + config_a = _write( + tmp_path, + ".mcp.json", + { + "mcpServers": {"a": {"command": "npx", "args": ["@scope/server-a"]}}, + }, + ) (tmp_path / "subdir").mkdir() - config_b = _write(tmp_path / "subdir", "claude_desktop_config.json", { - "mcpServers": {"b": {"command": "npx", "args": ["@scope/server-b"]}}, - }) + config_b = _write( + tmp_path / "subdir", + "claude_desktop_config.json", + { + "mcpServers": {"b": {"command": "npx", "args": ["@scope/server-b"]}}, + }, + ) r_a = extract_mcp_config(config_a) r_b = extract_mcp_config(config_b) cmd_id_a = next(n["id"] for n in r_a["nodes"] if n["metadata"]["mcp_kind"] == "mcp_command") @@ -183,17 +187,25 @@ def test_same_command_collapses_to_one_node_across_configs(tmp_path): def test_same_env_var_collapses_to_one_node_across_configs(tmp_path): # Two configs both require OPENAI_API_KEY. The env_var node ID must be identical. - a = _write(tmp_path, ".mcp.json", { - "mcpServers": { - "x": {"command": "npx", "args": ["@scope/x"], "env": {"OPENAI_API_KEY": "v1"}}, + a = _write( + tmp_path, + ".mcp.json", + { + "mcpServers": { + "x": {"command": "npx", "args": ["@scope/x"], "env": {"OPENAI_API_KEY": "v1"}}, + }, }, - }) + ) (tmp_path / "sub").mkdir() - b = _write(tmp_path / "sub", "claude_desktop_config.json", { - "mcpServers": { - "y": {"command": "uvx", "args": ["mcp-server-y"], "env": {"OPENAI_API_KEY": "v2"}}, + b = _write( + tmp_path / "sub", + "claude_desktop_config.json", + { + "mcpServers": { + "y": {"command": "uvx", "args": ["mcp-server-y"], "env": {"OPENAI_API_KEY": "v2"}}, + }, }, - }) + ) r_a = extract_mcp_config(a) r_b = extract_mcp_config(b) env_id_a = next(n["id"] for n in r_a["nodes"] if n["metadata"]["mcp_kind"] == "env_var") @@ -206,12 +218,20 @@ def test_same_server_name_in_different_dirs_does_not_collide(tmp_path): # The server nodes should NOT collide (stem-scoped via parent dir). (tmp_path / "proj_a").mkdir() (tmp_path / "proj_b").mkdir() - a = _write(tmp_path / "proj_a", ".mcp.json", { - "mcpServers": {"filesystem": {"command": "npx", "args": ["@scope/a"]}}, - }) - b = _write(tmp_path / "proj_b", ".mcp.json", { - "mcpServers": {"filesystem": {"command": "npx", "args": ["@scope/b"]}}, - }) + a = _write( + tmp_path / "proj_a", + ".mcp.json", + { + "mcpServers": {"filesystem": {"command": "npx", "args": ["@scope/a"]}}, + }, + ) + b = _write( + tmp_path / "proj_b", + ".mcp.json", + { + "mcpServers": {"filesystem": {"command": "npx", "args": ["@scope/b"]}}, + }, + ) r_a = extract_mcp_config(a) r_b = extract_mcp_config(b) srv_a = next(n["id"] for n in r_a["nodes"] if n["metadata"]["mcp_kind"] == "mcp_server") @@ -232,9 +252,13 @@ def test_missing_mcp_servers_key(tmp_path): def test_nested_mcp_servers_shape(tmp_path): # Some tools wrap the map: {"mcp": {"servers": {...}}} - p = _write(tmp_path, ".mcp.json", { - "mcp": {"servers": {"x": {"command": "node", "args": ["dist/index.js"]}}}, - }) + p = _write( + tmp_path, + ".mcp.json", + { + "mcp": {"servers": {"x": {"command": "node", "args": ["dist/index.js"]}}}, + }, + ) r = extract_mcp_config(p) assert "error" not in r assert "x" in _label_by_kind(r, "mcp_server") @@ -266,12 +290,16 @@ def test_root_not_an_object(tmp_path): def test_non_dict_server_entry_skipped(tmp_path): - p = _write(tmp_path, ".mcp.json", { - "mcpServers": { - "valid": {"command": "npx", "args": ["@scope/pkg"]}, - "broken": ["this", "is", "not", "an", "object"], + p = _write( + tmp_path, + ".mcp.json", + { + "mcpServers": { + "valid": {"command": "npx", "args": ["@scope/pkg"]}, + "broken": ["this", "is", "not", "an", "object"], + }, }, - }) + ) r = extract_mcp_config(p) server_labels = _label_by_kind(r, "mcp_server") assert "valid" in server_labels @@ -283,26 +311,38 @@ def test_non_dict_server_entry_skipped(tmp_path): def test_package_detection_skips_flags(tmp_path): # First arg is -y (flag); second is the package. Detection should skip the flag. - p = _write(tmp_path, ".mcp.json", { - "mcpServers": {"x": {"command": "npx", "args": ["-y", "@scope/server-x"]}}, - }) + p = _write( + tmp_path, + ".mcp.json", + { + "mcpServers": {"x": {"command": "npx", "args": ["-y", "@scope/server-x"]}}, + }, + ) r = extract_mcp_config(p) assert "@scope/server-x" in _label_by_kind(r, "mcp_package") def test_no_package_detected_for_unknown_arg_shape(tmp_path): # Args don't look like any known package pattern => no package node. - p = _write(tmp_path, ".mcp.json", { - "mcpServers": {"x": {"command": "node", "args": ["./local-script.js", "--verbose"]}}, - }) + p = _write( + tmp_path, + ".mcp.json", + { + "mcpServers": {"x": {"command": "node", "args": ["./local-script.js", "--verbose"]}}, + }, + ) r = extract_mcp_config(p) assert _label_by_kind(r, "mcp_package") == [] def test_server_without_command_still_emits_server_node(tmp_path): - p = _write(tmp_path, ".mcp.json", { - "mcpServers": {"x": {"args": ["@scope/server-x"]}}, - }) + p = _write( + tmp_path, + ".mcp.json", + { + "mcpServers": {"x": {"args": ["@scope/server-x"]}}, + }, + ) r = extract_mcp_config(p) assert "x" in _label_by_kind(r, "mcp_server") assert _label_by_kind(r, "mcp_command") == [] @@ -316,9 +356,13 @@ def test_dispatch_routes_mcp_filename_to_mcp_extractor(tmp_path): # extract_mcp_config, NOT extract_json. from graphify.extract import _get_extractor - p = _write(tmp_path, ".mcp.json", { - "mcpServers": {"x": {"command": "npx", "args": ["@scope/server-x"]}}, - }) + p = _write( + tmp_path, + ".mcp.json", + { + "mcpServers": {"x": {"command": "npx", "args": ["@scope/server-x"]}}, + }, + ) extractor = _get_extractor(p) assert extractor is extract_mcp_config diff --git a/tests/test_multilang.py b/tests/test_multilang.py index c30b9e10c..9496a9664 100644 --- a/tests/test_multilang.py +++ b/tests/test_multilang.py @@ -1,6 +1,6 @@ """Tests for multi-language AST extraction: JS/TS, Go, Rust, SQL.""" + from __future__ import annotations -import shutil from pathlib import Path import pytest from graphify.extract import extract_js, extract_go, extract_rust, extract, extract_sql @@ -10,16 +10,20 @@ # ── helpers ────────────────────────────────────────────────────────────────── + def _labels(result): return [n["label"] for n in result["nodes"]] + def _call_pairs(result): node_by_id = {n["id"]: n["label"] for n in result["nodes"]} return { (node_by_id.get(e["source"], e["source"]), node_by_id.get(e["target"], e["target"])) - for e in result["edges"] if e["relation"] == "calls" + for e in result["edges"] + if e["relation"] == "calls" } + def _confidences(result): return {e["confidence"] for e in result["edges"]} @@ -46,20 +50,24 @@ def _edge_labels(result, relation, context=None): # ── TypeScript ──────────────────────────────────────────────────────────────── + def test_ts_finds_class(): r = extract_js(FIXTURES / "sample.ts") assert "error" not in r assert "HttpClient" in _labels(r) + def test_ts_finds_methods(): r = extract_js(FIXTURES / "sample.ts") labels = _labels(r) - assert any("get" in l for l in labels) - assert any("post" in l for l in labels) + assert any("get" in label for label in labels) + assert any("post" in label for label in labels) + def test_ts_finds_function(): r = extract_js(FIXTURES / "sample.ts") - assert any("buildHeaders" in l for l in _labels(r)) + assert any("buildHeaders" in label for label in _labels(r)) + def test_ts_emits_calls(): r = extract_js(FIXTURES / "sample.ts") @@ -67,6 +75,7 @@ def test_ts_emits_calls(): # .post() calls .get() assert any("post" in src and "get" in tgt for src, tgt in calls) + def test_ts_calls_are_extracted(): r = extract_js(FIXTURES / "sample.ts") for e in r["edges"]: @@ -87,6 +96,7 @@ def test_ts_call_edges_have_call_context(): assert call_edges assert all(e.get("context") == "call" for e in call_edges) + def test_ts_no_dangling_edges(): r = extract_js(FIXTURES / "sample.ts") node_ids = {n["id"] for n in r["nodes"]} @@ -97,26 +107,31 @@ def test_ts_no_dangling_edges(): # ── Go ──────────────────────────────────────────────────────────────────────── + def test_go_finds_struct(): r = extract_go(FIXTURES / "sample.go") assert "error" not in r assert "Server" in _labels(r) + def test_go_finds_methods(): r = extract_go(FIXTURES / "sample.go") labels = _labels(r) - assert any("Start" in l for l in labels) - assert any("Stop" in l for l in labels) + assert any("Start" in label for label in labels) + assert any("Stop" in label for label in labels) + def test_go_finds_constructor(): r = extract_go(FIXTURES / "sample.go") - assert any("NewServer" in l for l in _labels(r)) + assert any("NewServer" in label for label in _labels(r)) + def test_go_emits_calls(): r = extract_go(FIXTURES / "sample.go") # main() calls NewServer and Start assert len(_call_pairs(r)) > 0 + def test_go_has_extracted_calls(): r = extract_go(FIXTURES / "sample.go") assert "EXTRACTED" in _confidences(r) @@ -135,6 +150,7 @@ def test_go_call_edges_have_call_context(): assert call_edges assert all(e.get("context") == "call" for e in call_edges) + def test_go_no_dangling_edges(): r = extract_go(FIXTURES / "sample.go") node_ids = {n["id"] for n in r["nodes"]} @@ -288,26 +304,31 @@ def _is_guarded(use: ast.AST) -> bool: # ── Rust ────────────────────────────────────────────────────────────────────── + def test_rust_finds_struct(): r = extract_rust(FIXTURES / "sample.rs") assert "error" not in r assert "Graph" in _labels(r) + def test_rust_finds_impl_methods(): r = extract_rust(FIXTURES / "sample.rs") labels = _labels(r) - assert any("add_node" in l for l in labels) - assert any("add_edge" in l for l in labels) + assert any("add_node" in label for label in labels) + assert any("add_edge" in label for label in labels) + def test_rust_finds_function(): r = extract_rust(FIXTURES / "sample.rs") - assert any("build_graph" in l for l in _labels(r)) + assert any("build_graph" in label for label in _labels(r)) + def test_rust_emits_calls(): r = extract_rust(FIXTURES / "sample.rs") calls = _call_pairs(r) assert any("build_graph" in src for src, _ in calls) + def test_rust_calls_are_extracted(): r = extract_rust(FIXTURES / "sample.rs") for e in r["edges"]: @@ -328,6 +349,7 @@ def test_rust_call_edges_have_call_context(): assert call_edges assert all(e.get("context") == "call" for e in call_edges) + def test_rust_no_dangling_edges(): r = extract_rust(FIXTURES / "sample.rs") node_ids = {n["id"] for n in r["nodes"]} @@ -363,6 +385,7 @@ def test_rust_no_cross_crate_spurious_edges(): """Scoped calls (Type::method) and blocklisted names must not produce INFERRED cross-crate calls edges (#908).""" from graphify.extract import extract + crate_a = FIXTURES / "crate_a" / "src" / "lib.rs" crate_b = FIXTURES / "crate_b" / "src" / "lib.rs" r = extract([crate_a, crate_b]) @@ -370,18 +393,16 @@ def test_rust_no_cross_crate_spurious_edges(): node_ids_b = {n["id"] for n in r["nodes"] if "crate_b" in (n.get("source_file") or "")} # No calls edge should cross from crate_b into crate_a cross_crate_calls = [ - e for e in r["edges"] - if e["relation"] == "calls" - and e["source"] in node_ids_b - and e["target"] in node_ids_a + e + for e in r["edges"] + if e["relation"] == "calls" and e["source"] in node_ids_b and e["target"] in node_ids_a ] - assert cross_crate_calls == [], ( - f"Spurious cross-crate edges: {cross_crate_calls}" - ) + assert cross_crate_calls == [], f"Spurious cross-crate edges: {cross_crate_calls}" # ── extract() dispatch ──────────────────────────────────────────────────────── + def test_extract_dispatches_all_languages(): files = [ FIXTURES / "sample.py", @@ -400,6 +421,7 @@ def test_extract_dispatches_all_languages(): # ── Cache ───────────────────────────────────────────────────────────────────── + def test_cache_hit_returns_same_result(tmp_path): src = FIXTURES / "sample.py" dst = tmp_path / "sample.py" @@ -410,20 +432,22 @@ def test_cache_hit_returns_same_result(tmp_path): assert len(r1["nodes"]) == len(r2["nodes"]) assert len(r1["edges"]) == len(r2["edges"]) + def test_cache_miss_after_file_change(tmp_path): dst = tmp_path / "a.py" dst.write_text("def foo(): pass\n") - r1 = extract([dst]) + extract([dst]) dst.write_text("def foo(): pass\ndef bar(): pass\n") r2 = extract([dst]) # bar() should appear in the second result labels2 = [n["label"] for n in r2["nodes"]] - assert any("bar" in l for l in labels2) + assert any("bar" in label for label in labels2) # ── SQL ─────────────────────────────────────────────────────────────────────── + def _extract_sql_or_skip(fixture: str = "sample.sql"): pytest.importorskip("tree_sitter_sql") return extract_sql(FIXTURES / fixture) @@ -432,35 +456,41 @@ def _extract_sql_or_skip(fixture: str = "sample.sql"): def test_sql_finds_tables(): r = _extract_sql_or_skip() labels = [n["label"] for n in r["nodes"]] - assert any("users" in l for l in labels) - assert any("organizations" in l for l in labels) + assert any("users" in label for label in labels) + assert any("organizations" in label for label in labels) + def test_sql_finds_view(): r = _extract_sql_or_skip() labels = [n["label"] for n in r["nodes"]] - assert any("active_users" in l for l in labels) + assert any("active_users" in label for label in labels) + def test_sql_finds_function(): r = _extract_sql_or_skip() labels = [n["label"] for n in r["nodes"]] - assert any("get_user" in l for l in labels) + assert any("get_user" in label for label in labels) + def test_sql_emits_foreign_key_edge(): r = _extract_sql_or_skip() relations = {e["relation"] for e in r["edges"]} assert "references" in relations + def test_sql_emits_reads_from_edge(): r = _extract_sql_or_skip() relations = {e["relation"] for e in r["edges"]} assert "reads_from" in relations + def test_sql_no_dangling_edges(): r = _extract_sql_or_skip() node_ids = {n["id"] for n in r["nodes"]} for e in r["edges"]: assert e["source"] in node_ids, f"dangling source: {e['source']}" + def test_sql_alter_table_fk_edge(): """ALTER TABLE ... FOREIGN KEY ... REFERENCES produces a references edge.""" r = _extract_sql_or_skip("sample_alter_fk.sql") @@ -471,12 +501,14 @@ def test_sql_alter_table_fk_edge(): assert e["source"] in node_ids, f"dangling source: {e['source']}" assert e["target"] in node_ids, f"dangling target: {e['target']}" + def test_sql_schema_qualified_names(): """Schema-qualified table names (Schema.Table) are preserved.""" r = _extract_sql_or_skip("sample_schema_qualified.sql") labels = [n["label"] for n in r["nodes"]] - assert any("Sales.Customer" in l for l in labels) - assert any("Sales.SalesOrder" in l for l in labels) + assert any("Sales.Customer" in label for label in labels) + assert any("Sales.SalesOrder" in label for label in labels) + def test_sql_schema_qualified_alter_fk(): """ALTER TABLE with schema-qualified names produces correct edges.""" diff --git a/tests/test_ollama.py b/tests/test_ollama.py index a7af29a64..7336dd5fe 100644 --- a/tests/test_ollama.py +++ b/tests/test_ollama.py @@ -1,4 +1,5 @@ """Tests for the Ollama backend additions in graphify/llm.py.""" + from __future__ import annotations from graphify.llm import detect_backend, BACKENDS @@ -60,6 +61,7 @@ def test_ollama_api_key_sentinel(monkeypatch): } with patch("graphify.llm._call_openai_compat", return_value=fake_result) as mock_call: from graphify.llm import extract_files_direct + with tempfile.NamedTemporaryFile(suffix=".py", mode="w", delete=False) as f: f.write("x = 1\n") tmp = Path(f.name) @@ -68,7 +70,9 @@ def test_ollama_api_key_sentinel(monkeypatch): # Should have called _call_openai_compat with api_key="ollama" assert mock_call.called call_kwargs = mock_call.call_args - api_key_used = call_kwargs.args[1] if call_kwargs.args else call_kwargs.kwargs.get("api_key", "") + api_key_used = ( + call_kwargs.args[1] if call_kwargs.args else call_kwargs.kwargs.get("api_key", "") + ) assert api_key_used == "ollama" finally: tmp.unlink(missing_ok=True) diff --git a/tests/test_pascal.py b/tests/test_pascal.py index 36c1b8747..50b0c7412 100644 --- a/tests/test_pascal.py +++ b/tests/test_pascal.py @@ -1,4 +1,5 @@ """Tests for the Pascal/Delphi extractor.""" + from __future__ import annotations from pathlib import Path @@ -19,48 +20,55 @@ def _edges_with_relation(r, *relations): def test_pascal_no_error(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") assert "error" not in r def test_pascal_finds_unit(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") - assert any("SampleUnit" in l for l in _labels(r)) + assert any("SampleUnit" in label for label in _labels(r)) def test_pascal_finds_classes(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") labels = _labels(r) - assert any("TBaseProcessor" in l for l in labels) - assert any("TDataProcessor" in l for l in labels) + assert any("TBaseProcessor" in label for label in labels) + assert any("TDataProcessor" in label for label in labels) def test_pascal_finds_interface(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") - assert any("IProcessor" in l for l in _labels(r)) + assert any("IProcessor" in label for label in _labels(r)) def test_pascal_finds_methods(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") labels = _labels(r) - assert any("Process" in l for l in labels) - assert any("Initialize" in l for l in labels) - assert any("GetCount" in l for l in labels) - assert any("Reset" in l for l in labels) + assert any("Process" in label for label in labels) + assert any("Initialize" in label for label in labels) + assert any("GetCount" in label for label in labels) + assert any("Reset" in label for label in labels) def test_pascal_finds_imports(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") assert "imports" in _relations(r) def test_pascal_import_edges_have_import_context(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") import_edges = _edges_with_relation(r, "imports") assert import_edges @@ -69,30 +77,31 @@ def test_pascal_import_edges_have_import_context(): def test_pascal_finds_inherits(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") assert "inherits" in _relations(r) def test_pascal_inherits_from_base(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") node_by_id = {n["id"]: n["label"] for n in r["nodes"]} inherits = [e for e in r["edges"] if e["relation"] == "inherits"] - found = any( - "TDataProcessor" in node_by_id.get(e["source"], "") - for e in inherits - ) + found = any("TDataProcessor" in node_by_id.get(e["source"], "") for e in inherits) assert found, "TDataProcessor should have at least one inherits edge" def test_pascal_finds_calls(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") assert "calls" in _relations(r) def test_pascal_call_edges_have_call_context(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") call_edges = _edges_with_relation(r, "calls") assert call_edges @@ -101,6 +110,7 @@ def test_pascal_call_edges_have_call_context(): def test_pascal_all_edges_extracted(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") structural = {"contains", "method", "inherits", "imports"} for e in r["edges"]: @@ -110,6 +120,7 @@ def test_pascal_all_edges_extracted(): def test_pascal_no_dangling_edges(): from graphify.extract import extract_pascal + r = extract_pascal(FIXTURES / "sample.pas") node_ids = {n["id"] for n in r["nodes"]} # imports edges are cross-file by design; only check within-file edge targets @@ -122,6 +133,7 @@ def test_pascal_no_dangling_edges(): def test_pascal_dispatch_registered(): from graphify.extract import _DISPATCH + assert ".pas" in _DISPATCH assert ".pp" in _DISPATCH assert ".dpr" in _DISPATCH @@ -134,6 +146,7 @@ def test_pascal_dispatch_registered(): def test_pascal_detect_extensions_registered(): from graphify.detect import CODE_EXTENSIONS + assert ".pas" in CODE_EXTENSIONS assert ".pp" in CODE_EXTENSIONS assert ".dpr" in CODE_EXTENSIONS @@ -144,38 +157,44 @@ def test_pascal_detect_extensions_registered(): # ── Lazarus Form (.lfm) ─────────────────────────────────────────────────────── + def test_lfm_no_error(): from graphify.extract import extract_lazarus_form + r = extract_lazarus_form(FIXTURES / "sample.lfm") assert "error" not in r def test_lfm_finds_root_form_class(): from graphify.extract import extract_lazarus_form + r = extract_lazarus_form(FIXTURES / "sample.lfm") - assert any("TSampleForm" in l for l in _labels(r)) + assert any("TSampleForm" in label for label in _labels(r)) def test_lfm_finds_component_classes(): from graphify.extract import extract_lazarus_form + r = extract_lazarus_form(FIXTURES / "sample.lfm") labels = _labels(r) - assert any("TPanel" in l for l in labels) - assert any("TButton" in l for l in labels) - assert any("TLabel" in l for l in labels) - assert any("TTimer" in l for l in labels) + assert any("TPanel" in label for label in labels) + assert any("TButton" in label for label in labels) + assert any("TLabel" in label for label in labels) + assert any("TTimer" in label for label in labels) def test_lfm_finds_event_handlers(): from graphify.extract import extract_lazarus_form + r = extract_lazarus_form(FIXTURES / "sample.lfm") labels = _labels(r) - assert any("ButtonOKClick" in l for l in labels) - assert any("TimerRefreshTimer" in l for l in labels) + assert any("ButtonOKClick" in label for label in labels) + assert any("TimerRefreshTimer" in label for label in labels) def test_lfm_event_edges_have_event_context(): from graphify.extract import extract_lazarus_form + r = extract_lazarus_form(FIXTURES / "sample.lfm") ref_edges = [e for e in r["edges"] if e["relation"] == "references"] assert ref_edges @@ -184,12 +203,14 @@ def test_lfm_event_edges_have_event_context(): def test_lfm_contains_edges_form_hierarchy(): from graphify.extract import extract_lazarus_form + r = extract_lazarus_form(FIXTURES / "sample.lfm") assert "contains" in _relations(r) def test_lfm_no_dangling_edges(): from graphify.extract import extract_lazarus_form + r = extract_lazarus_form(FIXTURES / "sample.lfm") node_ids = {n["id"] for n in r["nodes"]} for e in r["edges"]: @@ -198,28 +219,33 @@ def test_lfm_no_dangling_edges(): # ── Lazarus Package (.lpk) ─────────────────────────────────────────────────── + def test_lpk_no_error(): from graphify.extract import extract_lazarus_package + r = extract_lazarus_package(FIXTURES / "sample.lpk") assert "error" not in r def test_lpk_finds_package_name(): from graphify.extract import extract_lazarus_package + r = extract_lazarus_package(FIXTURES / "sample.lpk") - assert any("SamplePackage" in l for l in _labels(r)) + assert any("SamplePackage" in label for label in _labels(r)) def test_lpk_finds_required_packages(): from graphify.extract import extract_lazarus_package + r = extract_lazarus_package(FIXTURES / "sample.lpk") labels = _labels(r) - assert any("FCL" in l for l in labels) - assert any("LCL" in l for l in labels) + assert any("FCL" in label for label in labels) + assert any("LCL" in label for label in labels) def test_lpk_imports_edges_have_import_context(): from graphify.extract import extract_lazarus_package + r = extract_lazarus_package(FIXTURES / "sample.lpk") import_edges = _edges_with_relation(r, "imports") assert import_edges @@ -228,14 +254,16 @@ def test_lpk_imports_edges_have_import_context(): def test_lpk_contains_listed_units(): from graphify.extract import extract_lazarus_package + r = extract_lazarus_package(FIXTURES / "sample.lpk") labels = _labels(r) - assert any("sample" in l.lower() for l in labels) - assert any("sampleutils" in l.lower() for l in labels) + assert any("sample" in label.lower() for label in labels) + assert any("sampleutils" in label.lower() for label in labels) def test_lpk_no_dangling_edges(): from graphify.extract import extract_lazarus_package + r = extract_lazarus_package(FIXTURES / "sample.lpk") node_ids = {n["id"] for n in r["nodes"]} for e in r["edges"]: @@ -244,38 +272,44 @@ def test_lpk_no_dangling_edges(): # ── Delphi Form (.dfm) ─────────────────────────────────────────────────────── + def test_dfm_no_error(): from graphify.extract import extract_delphi_form + r = extract_delphi_form(FIXTURES / "sample.dfm") assert "error" not in r def test_dfm_finds_root_form_class(): from graphify.extract import extract_delphi_form + r = extract_delphi_form(FIXTURES / "sample.dfm") - assert any("TMainForm" in l for l in _labels(r)) + assert any("TMainForm" in label for label in _labels(r)) def test_dfm_finds_component_classes(): from graphify.extract import extract_delphi_form + r = extract_delphi_form(FIXTURES / "sample.dfm") labels = _labels(r) - assert any("TPanel" in l for l in labels) - assert any("TButton" in l for l in labels) - assert any("TMemo" in l for l in labels) - assert any("TStatusBar" in l for l in labels) + assert any("TPanel" in label for label in labels) + assert any("TButton" in label for label in labels) + assert any("TMemo" in label for label in labels) + assert any("TStatusBar" in label for label in labels) def test_dfm_finds_event_handlers(): from graphify.extract import extract_delphi_form + r = extract_delphi_form(FIXTURES / "sample.dfm") labels = _labels(r) - assert any("FormCreate" in l for l in labels) - assert any("ButtonOKClick" in l for l in labels) + assert any("FormCreate" in label for label in labels) + assert any("ButtonOKClick" in label for label in labels) def test_dfm_event_edges_have_event_context(): from graphify.extract import extract_delphi_form + r = extract_delphi_form(FIXTURES / "sample.dfm") ref_edges = [e for e in r["edges"] if e["relation"] == "references"] assert ref_edges @@ -284,12 +318,14 @@ def test_dfm_event_edges_have_event_context(): def test_dfm_contains_edges_form_hierarchy(): from graphify.extract import extract_delphi_form + r = extract_delphi_form(FIXTURES / "sample.dfm") assert "contains" in _relations(r) def test_dfm_no_dangling_edges(): from graphify.extract import extract_delphi_form + r = extract_delphi_form(FIXTURES / "sample.dfm") node_ids = {n["id"] for n in r["nodes"]} for e in r["edges"]: @@ -298,7 +334,9 @@ def test_dfm_no_dangling_edges(): def test_dfm_binary_returns_empty_not_crash(): from graphify.extract import extract_delphi_form - import tempfile, pathlib + import tempfile + import pathlib + # Write a fake binary DFM (FF 0A magic header) with tempfile.NamedTemporaryFile(suffix=".dfm", delete=False) as f: f.write(b"\xff\x0a\x00\x00some binary data") @@ -314,9 +352,11 @@ def test_dfm_binary_returns_empty_not_crash(): def test_dfm_dispatch_registered(): from graphify.extract import _DISPATCH + assert ".dfm" in _DISPATCH def test_dfm_detect_extension_registered(): from graphify.detect import CODE_EXTENSIONS + assert ".dfm" in CODE_EXTENSIONS diff --git a/tests/test_path_cli.py b/tests/test_path_cli.py index de7e8837f..57ae3ebdd 100644 --- a/tests/test_path_cli.py +++ b/tests/test_path_cli.py @@ -1,23 +1,36 @@ """Regression tests for `graphify path` arrow direction (#849).""" + from __future__ import annotations import json -import networkx as nx -from networkx.readwrite import json_graph import graphify.__main__ as mainmod def _write_graph(tmp_path): graph_data = { - "directed": False, "multigraph": False, "graph": {}, + "directed": False, + "multigraph": False, + "graph": {}, "nodes": [ - {"id": "create_patch", "label": "createPatchHandler()", - "source_file": "server/create-patch-handler.ts", "community": 0}, - {"id": "validate", "label": "validateSanitySession()", - "source_file": "server/sanity-validate-session.ts", "community": 0}, + { + "id": "create_patch", + "label": "createPatchHandler()", + "source_file": "server/create-patch-handler.ts", + "community": 0, + }, + { + "id": "validate", + "label": "validateSanitySession()", + "source_file": "server/sanity-validate-session.ts", + "community": 0, + }, ], "links": [ - {"source": "create_patch", "target": "validate", - "relation": "calls", "confidence": "EXTRACTED"}, + { + "source": "create_patch", + "target": "validate", + "relation": "calls", + "confidence": "EXTRACTED", + }, ], } p = tmp_path / "graph.json" @@ -27,8 +40,9 @@ def _write_graph(tmp_path): def _run(monkeypatch, graph_path, src, tgt, capsys): monkeypatch.setattr(mainmod, "_check_skill_version", lambda _: None) - monkeypatch.setattr(mainmod.sys, "argv", - ["graphify", "path", src, tgt, "--graph", str(graph_path)]) + monkeypatch.setattr( + mainmod.sys, "argv", ["graphify", "path", src, tgt, "--graph", str(graph_path)] + ) mainmod.main() return capsys.readouterr().out diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index ce6055d8b..de977ad91 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -3,14 +3,13 @@ Uses the existing test fixtures (code + markdown). No LLM calls - AST extraction only. Catches regressions in how modules connect, not just individual module behaviour. """ + import json -import tempfile from pathlib import Path -import pytest from graphify.detect import detect -from graphify.extract import collect_files, extract +from graphify.extract import extract from graphify.build import build_from_json from graphify.cluster import cluster, score_all from graphify.analyze import god_nodes, surprising_connections, suggest_questions @@ -62,7 +61,18 @@ def run_pipeline(tmp_path: Path) -> dict: # Step 6: report tokens = {"input": 0, "output": 0} - report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, str(FIXTURES), suggested_questions=questions) + report = generate( + G, + communities, + cohesion, + labels, + gods, + surprises, + detection, + tokens, + str(FIXTURES), + suggested_questions=questions, + ) assert "God Nodes" in report assert "Communities" in report assert len(report) > 100 @@ -85,7 +95,9 @@ def run_pipeline(tmp_path: Path) -> dict: # Step 9: export - Obsidian vault vault_path = tmp_path / "obsidian" - n_notes = to_obsidian(G, communities, str(vault_path), community_labels=labels, cohesion=cohesion) + n_notes = to_obsidian( + G, communities, str(vault_path), community_labels=labels, cohesion=cohesion + ) assert n_notes > 0 assert (vault_path / ".obsidian" / "graph.json").exists() md_files = list(vault_path.glob("*.md")) diff --git a/tests/test_prs.py b/tests/test_prs.py index 61ebc140e..1f7ef1fba 100644 --- a/tests/test_prs.py +++ b/tests/test_prs.py @@ -1,4 +1,5 @@ """Tests for graphify/prs.py.""" + from __future__ import annotations import subprocess @@ -6,7 +7,6 @@ from unittest.mock import patch, MagicMock import networkx as nx -import pytest from graphify.prs import ( PRInfo, @@ -23,6 +23,7 @@ # ── Helpers ─────────────────────────────────────────────────────────────────── + def make_pr( number: int = 1, title: str = "Test PR", @@ -54,6 +55,7 @@ def make_pr( # ── _classify ───────────────────────────────────────────────────────────────── + class TestClassify: def test_ready(self): pr = make_pr(ci_status="SUCCESS", review_decision="", is_draft=False) @@ -94,6 +96,7 @@ def test_wrong_base(self): # ── _parse_ci ───────────────────────────────────────────────────────────────── + class TestParseCi: def test_empty_rollup_returns_none(self): assert _parse_ci([]) == "NONE" @@ -128,6 +131,7 @@ def test_mixed_success_and_failure_is_failure(self): # ── _path_match ─────────────────────────────────────────────────────────────── + class TestPathMatch: def test_exact_match(self): assert _path_match("src/auth/api.py", "src/auth/api.py") is True @@ -150,6 +154,7 @@ def test_both_directions_work(self): # ── compute_pr_impact ───────────────────────────────────────────────────────── + class TestComputePrImpact: def _make_graph(self) -> nx.Graph: """3 nodes across 2 communities, 2 distinct source files.""" @@ -167,9 +172,7 @@ def test_matching_files_returns_correct_communities_and_count(self): def test_matching_both_files(self): G = self._make_graph() - comms, nodes = compute_pr_impact( - ["src/auth/api.py", "src/utils/helpers.py"], G - ) + comms, nodes = compute_pr_impact(["src/auth/api.py", "src/utils/helpers.py"], G) assert comms == [0, 1] assert nodes == 3 @@ -208,6 +211,7 @@ def test_no_double_counting_same_graph_file_matched_by_two_pr_files(self): # ── fetch_worktrees ─────────────────────────────────────────────────────────── + class TestFetchWorktrees: def test_normal_case_maps_branch_to_path(self): porcelain = ( @@ -279,6 +283,7 @@ def test_subprocess_failure_returns_empty_dict(self): # ── format_prs_text ─────────────────────────────────────────────────────────── + class TestFormatPrsText: def test_contains_pr_metadata_and_count_header(self): prs = [ @@ -330,6 +335,7 @@ def test_empty_pr_list(self): # ── _detect_default_branch ──────────────────────────────────────────────────── + class TestDetectDefaultBranch: def test_gh_returns_main(self): with patch( @@ -342,8 +348,9 @@ def test_falls_back_to_git_symbolic_ref(self): mock_result = MagicMock() mock_result.returncode = 0 mock_result.stdout = "refs/remotes/origin/develop\n" - with patch("graphify.prs._gh", return_value=None), patch( - "graphify.prs.subprocess.run", return_value=mock_result + with ( + patch("graphify.prs._gh", return_value=None), + patch("graphify.prs.subprocess.run", return_value=mock_result), ): assert _detect_default_branch() == "develop" @@ -351,8 +358,9 @@ def test_both_fail_returns_main(self): mock_result = MagicMock() mock_result.returncode = 1 mock_result.stdout = "" - with patch("graphify.prs._gh", return_value=None), patch( - "graphify.prs.subprocess.run", return_value=mock_result + with ( + patch("graphify.prs._gh", return_value=None), + patch("graphify.prs.subprocess.run", return_value=mock_result), ): assert _detect_default_branch() == "main" @@ -361,27 +369,32 @@ def test_gh_returns_empty_dict_falls_back(self): mock_result = MagicMock() mock_result.returncode = 0 mock_result.stdout = "refs/remotes/origin/trunk\n" - with patch("graphify.prs._gh", return_value={}), patch( - "graphify.prs.subprocess.run", return_value=mock_result + with ( + patch("graphify.prs._gh", return_value={}), + patch("graphify.prs.subprocess.run", return_value=mock_result), ): assert _detect_default_branch() == "trunk" def test_git_timeout_returns_main(self): - with patch("graphify.prs._gh", return_value=None), patch( - "graphify.prs.subprocess.run", - side_effect=subprocess.TimeoutExpired("git", 5), + with ( + patch("graphify.prs._gh", return_value=None), + patch( + "graphify.prs.subprocess.run", + side_effect=subprocess.TimeoutExpired("git", 5), + ), ): assert _detect_default_branch() == "main" # ── build_community_labels ───────────────────────────────────────────────────── + class TestBuildCommunityLabels: def test_basic_grouping(self): data = { "nodes": [ {"id": "a", "label": "Alpha", "community": 0}, - {"id": "b", "label": "Beta", "community": 0}, + {"id": "b", "label": "Beta", "community": 0}, {"id": "c", "label": "Gamma", "community": 1}, ] } diff --git a/tests/test_python_import_resolution.py b/tests/test_python_import_resolution.py index 2a517aaea..fb333eac4 100644 --- a/tests/test_python_import_resolution.py +++ b/tests/test_python_import_resolution.py @@ -23,9 +23,7 @@ def _node_id(result: dict, label: str, source_file: str) -> str: def _has_edge(result: dict, source: str, target: str, relation: str) -> bool: return any( - edge["source"] == source - and edge["target"] == target - and edge["relation"] == relation + edge["source"] == source and edge["target"] == target and edge["relation"] == relation for edge in result["edges"] ) @@ -35,9 +33,7 @@ def test_python_package_reexport_resolves_import_and_call_to_origin_symbol(tmp_p barrel = _write(tmp_path / "pkg/__init__.py", "from .foo import Foo as PublicFoo\n") consumer = _write( tmp_path / "app.py", - "from pkg import PublicFoo\n\n" - "def X():\n" - " return PublicFoo()\n", + "from pkg import PublicFoo\n\ndef X():\n return PublicFoo()\n", ) result = extract([origin, barrel, consumer], cache_root=tmp_path) @@ -57,10 +53,7 @@ def test_python_parameter_return_and_generic_contexts(tmp_path: Path): model = tmp_path / "pkg" / "model.py" model.parent.mkdir(parents=True) model.write_text( - "class Payload:\n" - " pass\n\n" - "class Result:\n" - " pass\n", + "class Payload:\n pass\n\nclass Result:\n pass\n", encoding="utf-8", ) service = tmp_path / "pkg" / "service.py" @@ -77,7 +70,11 @@ def test_python_parameter_return_and_generic_contexts(tmp_path: Path): labels = {node["id"]: node["label"] for node in result["nodes"]} edges = [edge for edge in result["edges"] if edge.get("relation") == "references"] pairs = { - (labels.get(e["source"], e["source"]), labels.get(e["target"], e["target"]), e.get("context")) + ( + labels.get(e["source"], e["source"]), + labels.get(e["target"], e["target"]), + e.get("context"), + ) for e in edges } diff --git a/tests/test_query_cli.py b/tests/test_query_cli.py index cf8eb6e56..ef3ebbe46 100644 --- a/tests/test_query_cli.py +++ b/tests/test_query_cli.py @@ -1,4 +1,5 @@ """Tests for graphify query CLI context filtering.""" + from __future__ import annotations import json diff --git a/tests/test_rationale.py b/tests/test_rationale.py index b52aa3909..8ab29d157 100644 --- a/tests/test_rationale.py +++ b/tests/test_rationale.py @@ -1,7 +1,7 @@ """Tests for rationale/docstring extraction in extract.py.""" + import textwrap from pathlib import Path -import pytest from graphify.extract import extract_python from graphify.build import build_from_json @@ -13,10 +13,13 @@ def _write_py(tmp_path: Path, code: str) -> Path: def test_module_docstring_extracted(tmp_path): - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' """This module handles authentication because legacy sessions were insecure.""" def login(): pass - ''') + ''', + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] assert len(rationale) >= 1 @@ -24,45 +27,57 @@ def login(): pass def test_function_docstring_extracted(tmp_path): - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' def process(): """We use chunked processing here because the full dataset exceeds RAM.""" pass - ''') + ''', + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] assert any("chunked" in n["label"] for n in rationale) def test_class_docstring_extracted(tmp_path): - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' class Cache: """Chosen over Redis because we need zero external dependencies in the test env.""" pass - ''') + ''', + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] assert any("Redis" in n["label"] for n in rationale) def test_rationale_comment_extracted(tmp_path): - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + """ def build(): # NOTE: must run before compile() or linker will fail pass - ''') + """, + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] assert any("NOTE" in n["label"] for n in rationale) def test_rationale_for_edges_present(tmp_path): - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' """Module docstring explaining the why.""" def foo(): """Function docstring with rationale.""" pass - ''') + ''', + ) result = extract_python(path) rationale_edges = [e for e in result["edges"] if e.get("relation") == "rationale_for"] assert len(rationale_edges) >= 1 @@ -70,28 +85,36 @@ def foo(): def test_short_docstring_ignored(tmp_path): """Trivial docstrings under 20 chars should not become rationale nodes.""" - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' def foo(): """Constructor.""" pass - ''') + ''', + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] assert len(rationale) == 0 def test_rationale_confidence_is_extracted(tmp_path): - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' """This module exists because we needed a standalone parser.""" def parse(): pass - ''') + ''', + ) result = extract_python(path) rationale_edges = [e for e in result["edges"] if e.get("relation") == "rationale_for"] assert all(e.get("confidence") == "EXTRACTED" for e in rationale_edges) def test_alembic_module_docstring_suppressed(tmp_path): - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' """initial schema Revision ID: 0001abcd @@ -107,7 +130,8 @@ def upgrade(): def downgrade(): pass - ''') + ''', + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] assert not any("Revision ID" in n["label"] for n in rationale) @@ -115,7 +139,9 @@ def downgrade(): def test_alembic_function_docstrings_still_extracted(tmp_path): """Function docstrings inside upgrade/downgrade should still be captured.""" - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' """Revision ID: 0002 Revises: 0001""" revision = "0002" down_revision = "0001" @@ -126,7 +152,8 @@ def upgrade(): def downgrade(): pass - ''') + ''', + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] # module docstring suppressed @@ -137,39 +164,48 @@ def downgrade(): def test_non_migration_revision_var_not_suppressed(tmp_path): """A file with a `revision` variable but no Alembic markers keeps its docstring.""" - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' """This module tracks document revisions because we need audit history.""" revision = 42 def get_revision(): pass - ''') + ''', + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] assert any("audit history" in n["label"] for n in rationale) def test_django_migration_module_docstring_suppressed(tmp_path): - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' """Add post_priority_config table.""" from django.db import migrations class Migration(migrations.Migration): dependencies = [("myapp", "0001_initial")] operations = [] - ''') + ''', + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] assert not any("post_priority" in n["label"] for n in rationale) def test_generated_file_module_docstring_suppressed(tmp_path): - path = _write_py(tmp_path, ''' + path = _write_py( + tmp_path, + ''' """Generated by the protocol buffer compiler. DO NOT EDIT!""" from google.protobuf import descriptor as _descriptor class UserMessage: pass - ''') + ''', + ) result = extract_python(path) rationale = [n for n in result["nodes"] if n.get("file_type") == "rationale"] assert not any("protocol buffer" in n["label"].lower() for n in rationale) diff --git a/tests/test_report.py b/tests/test_report.py index a5b3916a1..d9b9253d6 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -7,6 +7,7 @@ FIXTURES = Path(__file__).parent / "fixtures" + def make_inputs(): extraction = json.loads((FIXTURES / "extraction.json").read_text()) G = build_from_json(extraction) @@ -19,45 +20,78 @@ def make_inputs(): tokens = {"input": extraction["input_tokens"], "output": extraction["output_tokens"]} return G, communities, cohesion, labels, gods, surprises, detection, tokens + def test_report_contains_header(): G, communities, cohesion, labels, gods, surprises, detection, tokens = make_inputs() - report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project") + report = generate( + G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project" + ) assert "# Graph Report" in report + def test_report_contains_corpus_check(): G, communities, cohesion, labels, gods, surprises, detection, tokens = make_inputs() - report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project") + report = generate( + G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project" + ) assert "## Corpus Check" in report + def test_report_contains_god_nodes(): G, communities, cohesion, labels, gods, surprises, detection, tokens = make_inputs() - report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project") + report = generate( + G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project" + ) assert "## God Nodes" in report + def test_report_contains_surprising_connections(): G, communities, cohesion, labels, gods, surprises, detection, tokens = make_inputs() - report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project") + report = generate( + G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project" + ) assert "## Surprising Connections" in report + def test_report_contains_communities(): G, communities, cohesion, labels, gods, surprises, detection, tokens = make_inputs() - report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project") + report = generate( + G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project" + ) assert "## Communities" in report + def test_report_contains_ambiguous_section(): G, communities, cohesion, labels, gods, surprises, detection, tokens = make_inputs() - report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project") + report = generate( + G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project" + ) assert "## Ambiguous Edges" in report + def test_report_shows_token_cost(): G, communities, cohesion, labels, gods, surprises, detection, tokens = make_inputs() - report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project") + report = generate( + G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project" + ) assert "Token cost" in report assert "1,200" in report + def test_report_shows_raw_cohesion_scores(): G, communities, cohesion, labels, gods, surprises, detection, tokens = make_inputs() - report = generate(G, communities, cohesion, labels, gods, surprises, detection, tokens, "./project", min_community_size=1) + report = generate( + G, + communities, + cohesion, + labels, + gods, + surprises, + detection, + tokens, + "./project", + min_community_size=1, + ) assert "Cohesion:" in report assert "✓" not in report assert "⚠" not in report diff --git a/tests/test_security.py b/tests/test_security.py index c547ab842..32e160865 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -1,7 +1,7 @@ """Tests for graphify/security.py - URL validation, safe fetch, path guards, label sanitisation.""" + from __future__ import annotations -import json import urllib.error from pathlib import Path from typing import Any @@ -17,9 +17,7 @@ safe_fetch_text, validate_graph_path, validate_url, - _MAX_FETCH_BYTES, _MAX_GRAPH_FILE_BYTES, - _MAX_TEXT_BYTES, _METADATA_MAX_LIST_ITEMS, _METADATA_MAX_VALUE_LEN, _sanitize_metadata_string, @@ -31,24 +29,30 @@ # validate_url # --------------------------------------------------------------------------- + def test_validate_url_accepts_http(): assert validate_url("http://example.com/page") == "http://example.com/page" + def test_validate_url_accepts_https(): assert validate_url("https://arxiv.org/abs/1706.03762") == "https://arxiv.org/abs/1706.03762" + def test_validate_url_rejects_file(): with pytest.raises(ValueError, match="file"): validate_url("file:///etc/passwd") + def test_validate_url_rejects_ftp(): with pytest.raises(ValueError, match="ftp"): validate_url("ftp://files.example.com/data.zip") + def test_validate_url_rejects_data(): with pytest.raises(ValueError, match="data"): validate_url("data:text/html,") + def test_validate_url_rejects_empty_scheme(): with pytest.raises(ValueError): validate_url("//no-scheme.example.com") @@ -58,13 +62,14 @@ def test_validate_url_rejects_empty_scheme(): # safe_fetch - scheme and redirect guards (mocked network) # --------------------------------------------------------------------------- + def _make_mock_response(content: bytes, status: int = 200): mock = MagicMock() mock.__enter__ = lambda s: s mock.__exit__ = MagicMock(return_value=False) mock.status = status mock.code = status - chunks = [content[i:i+65536] for i in range(0, len(content), 65536)] + [b""] + chunks = [content[i : i + 65536] for i in range(0, len(content), 65536)] + [b""] mock.read.side_effect = chunks return mock @@ -73,10 +78,12 @@ def test_safe_fetch_rejects_file_url(): with pytest.raises(ValueError, match="file"): safe_fetch("file:///etc/passwd") + def test_safe_fetch_rejects_ftp_url(): with pytest.raises(ValueError, match="ftp"): safe_fetch("ftp://example.com/file.zip") + def test_safe_fetch_returns_bytes(tmp_path): mock_resp = _make_mock_response(b"hello world") with patch("graphify.security._build_opener") as mock_opener_fn: @@ -86,6 +93,7 @@ def test_safe_fetch_returns_bytes(tmp_path): result = safe_fetch("https://example.com/") assert result == b"hello world" + def test_safe_fetch_raises_on_non_2xx(): mock_resp = _make_mock_response(b"Not Found", status=404) with patch("graphify.security._build_opener") as mock_opener_fn: @@ -95,6 +103,7 @@ def test_safe_fetch_raises_on_non_2xx(): with pytest.raises(urllib.error.HTTPError): safe_fetch("https://example.com/missing") + def test_safe_fetch_raises_on_size_exceeded(): # Build a response larger than max_bytes big_chunk = b"x" * 65_537 @@ -118,6 +127,7 @@ def test_safe_fetch_raises_on_size_exceeded(): # safe_fetch_text # --------------------------------------------------------------------------- + def test_safe_fetch_text_decodes_utf8(): content = "héllo wörld".encode("utf-8") mock_resp = _make_mock_response(content) @@ -128,6 +138,7 @@ def test_safe_fetch_text_decodes_utf8(): result = safe_fetch_text("https://example.com/") assert result == "héllo wörld" + def test_safe_fetch_text_replaces_bad_bytes(): bad = b"hello \xff world" mock_resp = _make_mock_response(bad) @@ -145,6 +156,7 @@ def test_safe_fetch_text_replaces_bad_bytes(): # validate_graph_path # --------------------------------------------------------------------------- + def test_validate_graph_path_allows_inside_base(tmp_path): base = tmp_path / "graphify-out" base.mkdir() @@ -153,6 +165,7 @@ def test_validate_graph_path_allows_inside_base(tmp_path): result = validate_graph_path(str(graph), base=base) assert result == graph.resolve() + def test_validate_graph_path_blocks_traversal(tmp_path): base = tmp_path / "graphify-out" base.mkdir() @@ -160,11 +173,13 @@ def test_validate_graph_path_blocks_traversal(tmp_path): with pytest.raises(ValueError, match="escapes"): validate_graph_path(str(evil), base=base) + def test_validate_graph_path_requires_base_exists(tmp_path): base = tmp_path / "graphify-out" # not created with pytest.raises(ValueError, match="does not exist"): validate_graph_path(str(base / "graph.json"), base=base) + def test_validate_graph_path_raises_if_file_missing(tmp_path): base = tmp_path / "graphify-out" base.mkdir() @@ -176,22 +191,26 @@ def test_validate_graph_path_raises_if_file_missing(tmp_path): # sanitize_label # --------------------------------------------------------------------------- + def test_sanitize_label_passthrough_html_chars(): # sanitize_label does NOT HTML-escape — callers that inject into HTML must # wrap with html.escape() themselves (e.g. the title in to_html()) assert sanitize_label("