diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 5672bf0df..a33ec7cea 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -26,6 +26,7 @@ Each stage is a single function in its own module. They communicate through plai | `cache.py` | `check_semantic_cache / save_semantic_cache` | files → (cached, uncached) split | | `security.py` | validation helpers | URL / path / label → validated or raises | | `validate.py` | `validate_extraction(data)` | extraction dict → raises on schema errors | +| `storage.py` | `init_db / ingest_extraction / ingest_communities` | extraction dict → NeuG `graph.db` (optional, requires `neug`) | | `serve.py` | `start_server(graph_path)` | graph file path → MCP stdio server | | `watch.py` | `watch(root, flag_path)` | directory → writes flag file on change | | `benchmark.py` | `run_benchmark(graph_path)` | graph file → corpus vs subgraph token comparison | diff --git a/README.md b/README.md index 1e4467e74..2280a1144 100644 --- a/README.md +++ b/README.md @@ -159,6 +159,7 @@ Install only what you need: | `video` | Video/audio transcription (faster-whisper + yt-dlp) | `uv tool install "graphifyy[video]"` | | `mcp` | MCP stdio server | `uv tool install "graphifyy[mcp]"` | | `neo4j` | Neo4j push support | `uv tool install "graphifyy[neo4j]"` | +| `neug` | [NeuG](https://github.com/alibaba/neug) embedded graph database — Cypher queries on your graph | `uv tool install "graphifyy[neug]"` | | `svg` | SVG graph export | `uv tool install "graphifyy[svg]"` | | `leiden` | Leiden community detection (Python < 3.13 only) | `uv tool install "graphifyy[leiden]"` | | `ollama` | Ollama local inference | `uv tool install "graphifyy[ollama]"` | @@ -449,6 +450,9 @@ graphify install # overwrites the skill file /graphify ./raw --graphml # export for Gephi / yEd /graphify ./raw --neo4j # generate cypher.txt for Neo4j /graphify ./raw --neo4j-push bolt://localhost:7687 + +graphify cypher "MATCH (n) RETURN n LIMIT 10" # query graph.db with Cypher (requires neug) +graphify cypher "MATCH (n:code)-[e]->(m) RETURN n.id, e, m.id LIMIT 10" --db path/to/graph.db # default: graphify-out/graph.db /graphify ./raw --watch # auto-sync as files change /graphify ./raw --mcp # start MCP stdio server diff --git a/graphify/__main__.py b/graphify/__main__.py index 8ac1f9f85..ad39a2ddc 100644 --- a/graphify/__main__.py +++ b/graphify/__main__.py @@ -1761,6 +1761,8 @@ def main() -> None: print(" --backend= backend to use for community naming (default: auto-detect)") print(" label (re)name communities with the configured LLM backend, regenerate report") print(" --backend= backend to use (default: auto-detect from API keys)") + print(" cypher \"MATCH ...\" execute a Cypher query against graph.db (requires neug)") + print(" --db path to graph.db (default graphify-out/graph.db)") print(" query \"\" BFS traversal of graph.json for a question") print(" --dfs use depth-first instead of breadth-first") print(" --context C explicit edge-context filter (repeatable)") @@ -2240,6 +2242,31 @@ def main() -> None: else: print("Usage: graphify hook [install|uninstall|status]", file=sys.stderr) sys.exit(1) + elif cmd == "cypher": + if len(sys.argv) < 3: + print('Usage: graphify cypher "MATCH ..." [--db path]', file=sys.stderr) + sys.exit(1) + query_str = sys.argv[2] + db_path = str(Path(_GRAPHIFY_OUT) / "graph.db") + args = sys.argv[3:] + for i, a in enumerate(args): + if a == "--db" and i + 1 < len(args): + db_path = args[i + 1] + try: + from graphify.storage import init_db, execute_cypher, close_db + except ImportError: + print("error: neug is not installed. Run: pip install neug", file=sys.stderr) + sys.exit(1) + if not Path(db_path).exists(): + print(f"error: database not found: {db_path}", file=sys.stderr) + sys.exit(1) + db, conn = init_db(db_path) + try: + results = execute_cypher(conn, query_str) + for row in results: + print("\t".join(str(v) for v in row)) + finally: + close_db(db, conn) elif cmd == "query": if len(sys.argv) < 3: print("Usage: graphify query \"\" [--dfs] [--context C] [--budget N] [--graph path]", file=sys.stderr) @@ -3829,6 +3856,21 @@ def _progress(idx: int, total: int, _result: dict) -> None: graph_json_path.write_text( json.dumps(merged, indent=2), encoding="utf-8" ) + try: + from graphify.storage import init_db as _init_db, ensure_schema as _ensure_schema, ingest_extraction as _ingest, close_db as _close_db + _db_path = str(graphify_out / "graph.db") + _is_inc = Path(_db_path).exists() + _db, _conn = _init_db(_db_path) + _known = _ensure_schema(_conn, create_tables=not _is_inc) + _ingest(_conn, merged, incremental=_is_inc, + prune_sources=deleted_files or None, root=target, + known_tables=_known) + _close_db(_db, _conn) + print("[graphify extract] graph.db written (powered by NeuG)") + except ImportError: + pass + except Exception as _exc: + print(f"[graphify extract] warning: NeuG write failed: {_exc}", file=sys.stderr) cost = _estimate_cost( backend, merged["input_tokens"], merged["output_tokens"] ) @@ -3906,6 +3948,22 @@ def _progress(idx: int, total: int, _result: dict) -> None: from graphify.export import backup_if_protected as _backup _backup(graphify_out) _to_json(G, communities, str(graph_json_path), force=True) + try: + from graphify.storage import init_db as _init_db, ensure_schema as _ensure_schema, ingest_extraction as _ingest, ingest_communities as _ingest_comm, close_db as _close_db + _db_path = str(graphify_out / "graph.db") + _is_inc = Path(_db_path).exists() + _db, _conn = _init_db(_db_path) + _known = _ensure_schema(_conn, create_tables=not _is_inc) + _ntypes = _ingest(_conn, merged, incremental=_is_inc, + prune_sources=deleted_files or None, root=target, + known_tables=_known) + _ingest_comm(_conn, communities, node_types=_ntypes) + _close_db(_db, _conn) + print("[graphify extract] graph.db written (powered by NeuG)") + except ImportError: + pass + except Exception as _exc: + print(f"[graphify extract] warning: NeuG write failed: {_exc}", file=sys.stderr) if merged.get("output_tokens", 0) > 0: (graphify_out / ".graphify_semantic_marker").write_text( json.dumps({"output_tokens": merged["output_tokens"]}), encoding="utf-8" diff --git a/graphify/serve.py b/graphify/serve.py index 6e5d4a1f6..fae867081 100644 --- a/graphify/serve.py +++ b/graphify/serve.py @@ -487,6 +487,20 @@ def serve(graph_path: str = "graphify-out/graph.json") -> None: G = _load_graph(graph_path) communities = _communities_from_graph(G) + _neug_conn = None + _neug_db = None + _neug_execute = None + try: + from graphify.storage import init_db as _neug_init, execute_cypher as _neug_exec, close_db as _neug_close + _neug_db_path = str(Path(graph_path).parent / "graph.db") + if Path(_neug_db_path).exists(): + _neug_db, _neug_conn = _neug_init(_neug_db_path) + _neug_execute = _neug_exec + except ImportError: + pass + except Exception: + pass + # Hot-reload state: mtime+size key lets us detect graph.json changes without # polling. Initialised from the file stat at startup so the first tool call # never triggers a redundant reload. @@ -646,6 +660,20 @@ async def list_tools() -> list[types.Tool]: }, }, ), + types.Tool( + name="cypher_query", + description=( + "Execute a Cypher query against the NeuG graph database. " + "Returns tabular results. Requires neug to be installed and graph.db to exist." + ), + inputSchema={ + "type": "object", + "properties": { + "query": {"type": "string", "description": "Cypher query string"}, + }, + "required": ["query"], + }, + ), ] def _tool_query_graph(arguments: dict) -> str: @@ -882,6 +910,22 @@ def _tool_triage_prs(arguments: dict) -> str: ) return "\n\n".join(lines) + def _tool_cypher_query(arguments: dict) -> str: + if _neug_conn is None: + return "NeuG not available (not installed or graph.db not found)." + query = arguments["query"] + from graphify.storage import execute_cypher as _exec_cypher + try: + results = _exec_cypher(_neug_conn, query) + except RuntimeError as exc: + return f"Cypher error: {exc}" + if not results: + return "No results." + lines = [] + for row in results: + lines.append("\t".join(str(v) for v in row)) + return "\n".join(lines) + _handlers = { "query_graph": _tool_query_graph, "get_node": _tool_get_node, @@ -893,6 +937,7 @@ def _tool_triage_prs(arguments: dict) -> str: "list_prs": _tool_list_prs, "get_pr_impact": _tool_get_pr_impact, "triage_prs": _tool_triage_prs, + "cypher_query": _tool_cypher_query, } def _load_community_labels() -> dict[int, str]: diff --git a/graphify/storage.py b/graphify/storage.py new file mode 100644 index 000000000..ce55a8539 --- /dev/null +++ b/graphify/storage.py @@ -0,0 +1,320 @@ +"""NeuG graph database adapter for graphify. + +Provides an optional parallel storage engine alongside NetworkX. +NeuG is lazily imported — when not installed, callers should catch +ImportError at the call site and skip silently. + +All property values interpolated into Cypher statements use NeuG's native +parameterised queries ($param syntax) to prevent injection. Table/label +names (which come from a fixed internal set, not user input) are still +interpolated as identifiers. +""" +from __future__ import annotations + +import re +from pathlib import Path + +from .build import _FILE_TYPE_SYNONYMS, _normalize_id, _norm_source_file +from .validate import VALID_FILE_TYPES + +# --------------------------------------------------------------------------- +# Node tables (one per file_type) +# --------------------------------------------------------------------------- + +_NODE_TABLES = { + "code": """CREATE NODE TABLE IF NOT EXISTS code ( + id STRING PRIMARY KEY, label STRING, + source_file STRING, source_location STRING, community INT64)""", + "document": """CREATE NODE TABLE IF NOT EXISTS document ( + id STRING PRIMARY KEY, label STRING, + source_file STRING, community INT64)""", + "paper": """CREATE NODE TABLE IF NOT EXISTS paper ( + id STRING PRIMARY KEY, label STRING, + source_file STRING, community INT64)""", + "image": """CREATE NODE TABLE IF NOT EXISTS image ( + id STRING PRIMARY KEY, label STRING, + source_file STRING, community INT64)""", + "concept": """CREATE NODE TABLE IF NOT EXISTS concept ( + id STRING PRIMARY KEY, label STRING, + source_file STRING, community INT64)""", + "rationale": """CREATE NODE TABLE IF NOT EXISTS rationale ( + id STRING PRIMARY KEY, label STRING, + source_file STRING, community INT64)""", +} + +# --------------------------------------------------------------------------- +# Edge tables — split by (src_type, tgt_type, relation). +# --------------------------------------------------------------------------- + +_EDGE_DDL_TEMPLATE = """CREATE REL TABLE IF NOT EXISTS {tbl}( + FROM {src} TO {tgt}, + relation STRING, confidence STRING, + confidence_score DOUBLE, source_file STRING, weight DOUBLE)""" + +# Known relation types per (src, tgt) pair — pre-built at init time. +_KNOWN_RELATIONS: dict[tuple[str, str], list[str]] = { + ("code", "code"): [ + "calls", "contains", "method", "uses", "inherits", "defines", + "references", "imports", "imports_from", "listened_by", "case_of", + "references_constant", "bound_to", "uses_static_prop", "uses_config", + ], + ("rationale", "code"): ["rationale_for"], +} + + +def _sanitize_rel_name(relation: str) -> str: + """Normalize a relation string into a safe table-name suffix.""" + r = relation.lower().strip() + r = re.sub(r"[^a-z0-9_]", "_", r) + r = re.sub(r"_+", "_", r).strip("_") + return r or "rel" + + +def _edge_table_name(src_type: str, tgt_type: str, relation: str) -> str: + return f"edge_{src_type}_{tgt_type}_{_sanitize_rel_name(relation)}" + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def init_db(db_path: str) -> tuple: + """Open (or create) a NeuG database and connect. + + Returns (db, conn). Raises ImportError if neug is not installed. + """ + import neug + db = neug.Database(db_path) + conn = db.connect() + return db, conn + + +def ensure_schema(conn: object, *, create_tables: bool = True) -> set[str]: + """Populate known table registry; optionally execute DDL. + + create_tables=True (first build): run CREATE TABLE statements. + create_tables=False (incremental): only build the registry set + so _ensure_rel_table() knows what exists. + + Returns the set of known rel table names (per-connection registry). + """ + created: set[str] = set() + + if create_tables: + for ddl in _NODE_TABLES.values(): + conn.execute(ddl) + + for (src, tgt), rels in _KNOWN_RELATIONS.items(): + for rel in rels: + tbl = _edge_table_name(src, tgt, rel) + if create_tables: + conn.execute(_EDGE_DDL_TEMPLATE.format(tbl=tbl, src=src, tgt=tgt)) + created.add(tbl) + + return created + + +def _ensure_rel_table( + conn: object, src_type: str, tgt_type: str, relation: str, + known: set[str], +) -> str: + """Resolve edge table name, creating on-the-fly if needed. Returns table name.""" + tbl = _edge_table_name(src_type, tgt_type, relation) + if tbl in known: + return tbl + conn.execute(_EDGE_DDL_TEMPLATE.format(tbl=tbl, src=src_type, tgt=tgt_type)) + known.add(tbl) + return tbl + + +def _fix_file_type(ft: str | None) -> str: + """Canonicalize file_type, matching build.py:138-146 logic.""" + if not ft or ft not in VALID_FILE_TYPES: + return _FILE_TYPE_SYNONYMS.get(ft, "concept") if ft else "concept" + return ft + + +def ingest_extraction( + conn: object, + extraction: dict, + *, + incremental: bool = False, + prune_sources: list[str] | None = None, + root: str | Path | None = None, + known_tables: set[str] | None = None, +) -> dict[str, str]: + """Write an extraction dict into NeuG. + + incremental=False: first build — uses CREATE (faster). + incremental=True: update — uses MERGE (upsert). + + Returns node_types dict (id -> file_type) for use by ingest_communities. + """ + _root = str(Path(root).resolve()) if root else None + _known = known_tables if known_tables is not None else set() + + # --- prune deleted/changed files first --- + if prune_sources: + for sf in prune_sources: + sf_norm = _norm_source_file(sf, _root) or sf + for tbl in _NODE_TABLES: + conn.execute( + f"MATCH (n:{tbl}) WHERE n.source_file = $sf DETACH DELETE n", + parameters={"sf": sf_norm}, + ) + + # --- build node lookup: id -> file_type --- + node_types: dict[str, str] = {} + nodes = extraction.get("nodes") or [] + edges = extraction.get("edges") or [] + + # --- write nodes --- + _written_ids: set[str] = set() + _n_errors = 0 + for node in nodes: + nid = _normalize_id(node.get("id", "")) + if not nid: + continue + ft = _fix_file_type(node.get("file_type")) + label = node.get("label", "") + sf = _norm_source_file(node.get("source_file"), _root) or "" + sl = node.get("source_location") or "" + node_types[nid] = ft + if nid in _written_ids: + continue + _written_ids.add(nid) + + try: + if incremental: + if ft == "code": + conn.execute( + f"MERGE (n:code {{id: $nid}}) " + f"ON CREATE SET n.label = $label, " + f"n.source_file = $sf, n.source_location = $sl " + f"ON MATCH SET n.label = $label, " + f"n.source_file = $sf, n.source_location = $sl", + parameters={"nid": nid, "label": label, "sf": sf, "sl": sl}, + ) + else: + conn.execute( + f"MERGE (n:{ft} {{id: $nid}}) " + f"ON CREATE SET n.label = $label, n.source_file = $sf " + f"ON MATCH SET n.label = $label, n.source_file = $sf", + parameters={"nid": nid, "label": label, "sf": sf}, + ) + else: + if ft == "code": + conn.execute( + f"CREATE (n:code {{id: $nid, label: $label, " + f"source_file: $sf, source_location: $sl}})", + parameters={"nid": nid, "label": label, "sf": sf, "sl": sl}, + ) + else: + conn.execute( + f"CREATE (n:{ft} {{id: $nid, label: $label, " + f"source_file: $sf}})", + parameters={"nid": nid, "label": label, "sf": sf}, + ) + except RuntimeError: + _n_errors += 1 + + # --- write edges --- + _e_errors = 0 + for edge in edges: + src_key = edge.get("source") or edge.get("from", "") + tgt_key = edge.get("target") or edge.get("to", "") + src_id = _normalize_id(src_key) + tgt_id = _normalize_id(tgt_key) + if not src_id or not tgt_id: + continue + + src_ft = node_types.get(src_id) + tgt_ft = node_types.get(tgt_id) + if not src_ft or not tgt_ft: + continue + + rel_raw = edge.get("relation", "") + conf_raw = edge.get("confidence", "") + tbl = _ensure_rel_table(conn, src_ft, tgt_ft, rel_raw, _known) + conf_score = float(edge.get("confidence_score", 0.0)) + e_sf = _norm_source_file(edge.get("source_file"), _root) or "" + weight = float(edge.get("weight", 1.0)) + + try: + conn.execute( + f"MATCH (a:{src_ft} {{id: $src_id}}), " + f"(b:{tgt_ft} {{id: $tgt_id}}) " + f"CREATE (a)-[:{tbl} {{relation: $rel, confidence: $conf, " + f"confidence_score: $conf_score, source_file: $e_sf, " + f"weight: $weight}}]->(b)", + parameters={ + "src_id": src_id, "tgt_id": tgt_id, + "rel": rel_raw, "conf": conf_raw, + "conf_score": conf_score, "e_sf": e_sf, + "weight": weight, + }, + ) + except RuntimeError: + _e_errors += 1 + + if _n_errors or _e_errors: + import logging + logging.getLogger(__name__).warning( + "NeuG ingest: %d node(s) and %d edge(s) skipped due to errors", + _n_errors, _e_errors, + ) + + return node_types + + +def ingest_communities( + conn: object, + communities: dict[int, list[str]], + community_labels: dict[int, str] | None = None, + node_types: dict[str, str] | None = None, +) -> None: + """Write community assignments into NeuG node properties. + + If node_types is provided (id -> file_type mapping from ingest_extraction), + each node is looked up in its specific table directly. Otherwise falls + back to probing all 6 tables (slower). + + Note: NeuG does not support parameterised SET for non-string values, + so community ID is interpolated as an integer literal. The id value + uses a parameterised query. + """ + for cid, node_ids in communities.items(): + cid_int = int(cid) + for nid in node_ids: + nid_norm = _normalize_id(nid) + if not nid_norm: + continue + if node_types and nid_norm in node_types: + tbl = node_types[nid_norm] + conn.execute( + f"MATCH (n:{tbl}) WHERE n.id = $nid " + f"SET n.community = {cid_int}", + parameters={"nid": nid_norm}, + ) + else: + for tbl in _NODE_TABLES: + conn.execute( + f"MATCH (n:{tbl}) WHERE n.id = $nid " + f"SET n.community = {cid_int}", + parameters={"nid": nid_norm}, + ) + + +def execute_cypher(conn: object, query: str) -> list[list]: + """Execute a Cypher query and return results as list of lists.""" + try: + return list(conn.execute(query)) + except RuntimeError as exc: + raise RuntimeError(f"Cypher query failed: {exc}") from exc + + +def close_db(db: object, conn: object) -> None: + """Close the NeuG connection and database.""" + conn.close() + db.close() diff --git a/pyproject.toml b/pyproject.toml index 8c4c071d5..a7d2d496a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,11 +64,12 @@ gemini = ["openai", "tiktoken"] openai = ["openai", "tiktoken"] chinese = ["jieba"] sql = ["tree-sitter-sql"] +neug = ["neug>=0.1.2,<0.2"] # tree-sitter-dm (BYOND DreamMaker) ships only a Windows wheel, so on Linux/Mac it # must compile from source (needs a C toolchain + python3-dev). Keeping it optional # avoids breaking the default `uv tool install graphifyy` for everyone (#1104). dm = ["tree-sitter-dm"] -all = ["mcp", "neo4j", "pypdf", "markdownify", "watchdog", "graspologic; python_version < '3.13'", "python-docx", "openpyxl", "faster-whisper; python_version >= '3.11'", "yt-dlp", "matplotlib", "openai", "tiktoken", "boto3", "tree-sitter-sql", "jieba", "tree-sitter-dm"] +all = ["mcp", "neo4j", "neug>=0.1.2,<0.2", "pypdf", "markdownify", "watchdog", "graspologic; python_version < '3.13'", "python-docx", "openpyxl", "faster-whisper; python_version >= '3.11'", "yt-dlp", "matplotlib", "openai", "tiktoken", "boto3", "tree-sitter-sql", "jieba", "tree-sitter-dm"] [project.scripts] graphify = "graphify.__main__:main" diff --git a/tests/test_cypher_cli.py b/tests/test_cypher_cli.py new file mode 100644 index 000000000..8e18a3125 --- /dev/null +++ b/tests/test_cypher_cli.py @@ -0,0 +1,59 @@ +"""Tests for the `graphify cypher` CLI command.""" +import json +import subprocess +import sys +import tempfile +from pathlib import Path + +import pytest + +try: + import neug + _has_neug = True +except ImportError: + _has_neug = False + +pytestmark = pytest.mark.skipif(not _has_neug, reason="neug not installed") + +FIXTURES = Path(__file__).parent / "fixtures" +EXTRACTION_JSON = FIXTURES / "extraction.json" + + +def _build_db(tmp_path) -> str: + from graphify.storage import init_db, ensure_schema, ingest_extraction, close_db + db_path = str(tmp_path / "graph.db") + ext = json.loads(EXTRACTION_JSON.read_text()) + db, conn = init_db(db_path) + known = ensure_schema(conn) + ingest_extraction(conn, ext, incremental=False, known_tables=known) + close_db(db, conn) + return db_path + + +def test_cypher_command_basic(tmp_path): + db_path = _build_db(tmp_path) + result = subprocess.run( + [sys.executable, "-m", "graphify", "cypher", + "MATCH (n:code) RETURN count(n)", "--db", db_path], + capture_output=True, text=True, timeout=30, + ) + assert result.returncode == 0 + assert "3" in result.stdout + + +def test_cypher_command_db_not_found(tmp_path): + result = subprocess.run( + [sys.executable, "-m", "graphify", "cypher", + "MATCH (n) RETURN n", "--db", str(tmp_path / "nonexistent.db")], + capture_output=True, text=True, timeout=30, + ) + assert result.returncode != 0 + assert "not found" in result.stderr.lower() or "error" in result.stderr.lower() + + +def test_cypher_command_no_query(): + result = subprocess.run( + [sys.executable, "-m", "graphify", "cypher"], + capture_output=True, text=True, timeout=30, + ) + assert result.returncode != 0 diff --git a/tests/test_storage.py b/tests/test_storage.py new file mode 100644 index 000000000..ad6bcfd02 --- /dev/null +++ b/tests/test_storage.py @@ -0,0 +1,180 @@ +"""Tests for graphify.storage — NeuG adapter layer.""" +import json +import shutil +import tempfile +from pathlib import Path + +import pytest + +try: + import neug + _has_neug = True +except ImportError: + _has_neug = False + +pytestmark = pytest.mark.skipif(not _has_neug, reason="neug not installed") + +FIXTURES = Path(__file__).parent / "fixtures" +EXTRACTION_JSON = FIXTURES / "extraction.json" + + +def _load_extraction() -> dict: + return json.loads(EXTRACTION_JSON.read_text()) + + +@pytest.fixture() +def tmp_db(tmp_path): + db_path = str(tmp_path / "test.db") + yield db_path + + +def _init(db_path): + from graphify.storage import init_db, ensure_schema + db, conn = init_db(db_path) + ensure_schema(conn) + return db, conn + + +def _close(db, conn): + from graphify.storage import close_db + close_db(db, conn) + + +def _query(conn, cypher): + from graphify.storage import execute_cypher + return execute_cypher(conn, cypher) + + +# --- init_db --- + +def test_init_db_creates_tables(tmp_db): + db, conn = _init(tmp_db) + for tbl in ("code", "document", "paper", "image", "concept", "rationale"): + rows = _query(conn, f"MATCH (n:{tbl}) RETURN count(n)") + assert rows == [[0]] + _close(db, conn) + + +# --- ingest_extraction: CREATE mode --- + +def test_ingest_extraction_create_mode(tmp_db): + from graphify.storage import ingest_extraction + db, conn = _init(tmp_db) + ext = _load_extraction() + ingest_extraction(conn, ext, incremental=False) + rows = _query(conn, "MATCH (n:code) RETURN n.id ORDER BY n.id") + ids = sorted([r[0] for r in rows]) + assert "n_attention" in ids + assert "n_transformer" in ids + assert "n_layernorm" in ids + edge_rows = _query(conn, "MATCH (a:code)-[e:edge_code_code_contains]->(b:code) RETURN count(e)") + assert edge_rows[0][0] == 2 + _close(db, conn) + + +# --- ingest_extraction: MERGE mode --- + +def test_ingest_extraction_merge_mode(tmp_db): + from graphify.storage import ingest_extraction + db, conn = _init(tmp_db) + ext = _load_extraction() + ingest_extraction(conn, ext, incremental=False) + ext["nodes"][0]["label"] = "TransformerV2" + ingest_extraction(conn, ext, incremental=True) + rows = _query(conn, "MATCH (n:code) WHERE n.id = 'n_transformer' RETURN n.label") + assert rows[0][0] == "TransformerV2" + count = _query(conn, "MATCH (n:code) RETURN count(n)") + assert count[0][0] == 3 + _close(db, conn) + + +# --- file_type routing --- + +def test_ingest_extraction_file_type_routing(tmp_db): + from graphify.storage import ingest_extraction + db, conn = _init(tmp_db) + ext = _load_extraction() + ingest_extraction(conn, ext, incremental=False) + doc_rows = _query(conn, "MATCH (n:document) RETURN n.id") + assert len(doc_rows) == 1 + assert doc_rows[0][0] == "n_concept_attn" + _close(db, conn) + + +# --- prune_sources --- + +def test_ingest_extraction_prune(tmp_db): + from graphify.storage import ingest_extraction + db, conn = _init(tmp_db) + ext = _load_extraction() + ingest_extraction(conn, ext, incremental=False) + before = _query(conn, "MATCH (n:code) RETURN count(n)")[0][0] + assert before == 3 + ingest_extraction(conn, ext, incremental=True, prune_sources=["model.py"]) + after_prune = _query(conn, "MATCH (n:code) RETURN count(n)")[0][0] + assert after_prune == 3 + _close(db, conn) + + +# --- fallback rel table --- + +def test_fallback_rel_table(tmp_db): + from graphify.storage import _ensure_rel_table, ensure_schema + db, conn = _init(tmp_db) + known = ensure_schema(conn) + tbl = _ensure_rel_table(conn, "paper", "document", "cites", known) + assert tbl == "edge_paper_document_cites" + assert tbl in known + _close(db, conn) + + +# --- communities --- + +def test_ingest_communities(tmp_db): + from graphify.storage import ingest_extraction, ingest_communities + db, conn = _init(tmp_db) + ext = _load_extraction() + ingest_extraction(conn, ext, incremental=False) + communities = {0: ["n_transformer", "n_attention"], 1: ["n_layernorm"]} + ingest_communities(conn, communities) + rows = _query(conn, "MATCH (n:code) WHERE n.id = 'n_transformer' RETURN n.community") + assert rows[0][0] == 0 + rows = _query(conn, "MATCH (n:code) WHERE n.id = 'n_layernorm' RETURN n.community") + assert rows[0][0] == 1 + _close(db, conn) + + +# --- execute_cypher --- + +def test_execute_cypher(tmp_db): + from graphify.storage import ingest_extraction + db, conn = _init(tmp_db) + ext = _load_extraction() + ingest_extraction(conn, ext, incremental=False) + rows = _query(conn, "MATCH (n:code) RETURN n.label ORDER BY n.id") + labels = [r[0] for r in rows] + assert "MultiHeadAttention" in labels + assert "Transformer" in labels + _close(db, conn) + + +def test_execute_cypher_bad_query(tmp_db): + db, conn = _init(tmp_db) + with pytest.raises(RuntimeError): + _query(conn, "THIS IS NOT VALID CYPHER") + _close(db, conn) + + +# --- roundtrip consistency --- + +def test_roundtrip_node_count(tmp_db): + from graphify.storage import ingest_extraction + db, conn = _init(tmp_db) + ext = _load_extraction() + ingest_extraction(conn, ext, incremental=False) + total = 0 + for tbl in ("code", "document", "paper", "image", "concept", "rationale"): + rows = _query(conn, f"MATCH (n:{tbl}) RETURN count(n)") + total += rows[0][0] + assert total == len(ext["nodes"]) + _close(db, conn)