diff --git a/code_review_graph/parser.py b/code_review_graph/parser.py index f94519db..8d0ff569 100644 --- a/code_review_graph/parser.py +++ b/code_review_graph/parser.py @@ -187,7 +187,7 @@ class EdgeInfo: "typescript": ["class_declaration", "class"], "tsx": ["class_declaration", "class"], "go": ["type_declaration"], - "rust": ["struct_item", "enum_item", "impl_item"], + "rust": ["struct_item", "enum_item", "impl_item", "trait_item"], "java": ["class_declaration", "interface_declaration", "enum_declaration"], "c": ["struct_specifier", "type_definition"], "cpp": ["class_specifier", "struct_specifier"], @@ -5306,6 +5306,14 @@ def _collect_import_names( if child.type == "import_clause": self._collect_js_import_names(child, module, import_map) + elif language == "rust": + # Walk use_declaration children to find the inner import node + for child in node.children: + if child.type in ("scoped_identifier", "scoped_use_list", "use_as_clause", "use_wildcard", "identifier"): + for local_name, full_path in self._parse_rust_use_node(child): + if local_name != "*": + import_map[local_name] = full_path + def _collect_js_import_names( self, clause_node, module: str, import_map: dict[str, str], ) -> None: @@ -5481,8 +5489,149 @@ def _do_resolve_module( break current = current.parent + elif language == "rust": + # Standardize path separator to :: + module_path = module.replace("::", "/") + segments = [seg for seg in module_path.split("/") if seg] + if not segments: + return None + + # Find crate root / source root and collect all dependencies from Cargo.toml files found along the way + crate_root = None + dependencies = {} + current = caller_dir + for _ in range(20): + cargo_toml = current / "Cargo.toml" + if cargo_toml.is_file(): + if not crate_root: + crate_root = current + dependencies.update(self._find_rust_local_dependencies(cargo_toml)) + if current.parent == current: + break + current = current.parent + + if crate_root: + src_root = crate_root / "src" + if not src_root.is_dir(): + src_root = crate_root + else: + src_root = caller_dir + + # Determine starting directory and segments to resolve + start_dir = None + resolved_file = None + + if segments[0] == "crate": + start_dir = src_root + resolved_file = src_root / "lib.rs" if (src_root / "lib.rs").is_file() else (src_root / "main.rs" if (src_root / "main.rs").is_file() else None) + resolve_segments = segments[1:] + elif segments[0] == "super": + super_count = 0 + for seg in segments: + if seg == "super": + super_count += 1 + else: + break + + curr_file = Path(file_path) + for _ in range(super_count): + parent_mod = self._find_rust_parent_module(curr_file) + if parent_mod: + curr_file = parent_mod + else: + curr_file = curr_file.parent + + resolved_file = curr_file + if resolved_file.is_dir(): + start_dir = resolved_file + resolved_file = None + else: + start_dir = resolved_file.parent + + resolve_segments = segments[super_count:] + elif segments[0] == "self": + resolved_file = Path(file_path) + start_dir = resolved_file.parent + resolve_segments = segments[1:] + elif segments[0] in dependencies: + dep_path = dependencies[segments[0]] + dep_src_root = dep_path / "src" if (dep_path / "src").is_dir() else dep_path + start_dir = dep_src_root + resolved_file = dep_src_root / "lib.rs" if (dep_src_root / "lib.rs").is_file() else (dep_src_root / "main.rs" if (dep_src_root / "main.rs").is_file() else None) + resolve_segments = segments[1:] + else: + start_dir = caller_dir + resolve_segments = segments + + # Walk resolve segments + current_dir = start_dir + for seg in resolve_segments: + file_candidate = current_dir / f"{seg}.rs" + dir_candidate = current_dir / seg / "mod.rs" + + if file_candidate.is_file(): + resolved_file = file_candidate + current_dir = current_dir / seg + elif dir_candidate.is_file(): + resolved_file = dir_candidate + current_dir = current_dir / seg + else: + # Last segment might be a struct or function, so we keep the previous resolved file + break + + if resolved_file: + return str(resolved_file.resolve()) + return None + def _find_rust_parent_module(self, path: Path) -> Optional[Path]: + parent_dir = path.parent + name = path.name + + if name in ("mod.rs", "lib.rs", "main.rs"): + gp = parent_dir.parent + cand = gp / f"{parent_dir.name}.rs" + if cand.is_file(): + return cand + for name_cand in ("mod.rs", "lib.rs", "main.rs"): + cand = gp / name_cand + if cand.is_file(): + return cand + else: + cand = parent_dir / "mod.rs" + if cand.is_file(): + return cand + gp = parent_dir.parent + cand = gp / f"{parent_dir.name}.rs" + if cand.is_file(): + return cand + for name_cand in ("lib.rs", "main.rs"): + cand = parent_dir / name_cand + if cand.is_file(): + return cand + return None + + def _find_rust_local_dependencies(self, cargo_toml_path: Path) -> dict[str, Path]: + deps = {} + if not cargo_toml_path.is_file(): + return deps + try: + text = cargo_toml_path.read_text(encoding="utf-8", errors="replace") + except OSError: + return deps + # Match name = { path = "..." } + pattern = re.compile(r'^([\w-]+)\s*=\s*\{[^}]*path\s*=\s*"([^"]+)"', re.MULTILINE) + for match in pattern.finditer(text): + name = match.group(1).strip() + path_val = match.group(2).strip() + try: + resolved_path = (cargo_toml_path.parent / path_val).resolve() + deps[name] = resolved_path + deps[name.replace("-", "_")] = resolved_path + except (OSError, ValueError): + pass + return deps + def _find_dart_pubspec_root( self, start: Path, pkg_name: str, ) -> Optional[Path]: @@ -5514,6 +5663,48 @@ def _find_dart_pubspec_root( self._dart_pubspec_cache[cache_key] = None return None + def _resolve_rust_scoped_call( + self, + call_name: str, + file_path: str, + import_map: dict[str, str], + defined_names: set[str], + ) -> Optional[str]: + if "::" not in call_name: + return None + + parts = call_name.split("::") + method = parts[-1] + type_prefix_parts = parts[:-1] + type_prefix = "::".join(type_prefix_parts) + + # Check if type_prefix is defined in the same file + if type_prefix in defined_names: + qualified_type = self._qualify(type_prefix, file_path, None) + return f"{qualified_type}.{method}" + + # Check if type_prefix is imported + if type_prefix in import_map: + resolved_type = self._resolve_imported_symbol( + type_prefix, import_map[type_prefix], file_path, "rust", + ) + if resolved_type: + return f"{resolved_type}.{method}" + + # What if type_prefix itself is a path, like `crate::db::InMemoryRepo`? + if type_prefix_parts[0] in ("crate", "super", "self"): + if len(type_prefix_parts) >= 2: + module_part = "::".join(type_prefix_parts[:-1]) + type_name = type_prefix_parts[-1] + resolved_module_file = self._resolve_module_to_file( + module_part, file_path, "rust" + ) + if resolved_module_file: + qualified_type = self._qualify(type_name, resolved_module_file, None) + return f"{qualified_type}.{method}" + + return None + def _resolve_call_target( self, call_name: str, @@ -5523,6 +5714,13 @@ def _resolve_call_target( defined_names: set[str], ) -> str: """Resolve a bare call name to a qualified target, with fallback.""" + if language == "rust" and "::" in call_name: + resolved = self._resolve_rust_scoped_call( + call_name, file_path, import_map, defined_names + ) + if resolved: + return resolved + if call_name in defined_names: return self._qualify(call_name, file_path, None) if call_name in import_map: @@ -5663,6 +5861,27 @@ def _qualify(self, name: str, file_path: str, enclosing_class: Optional[str]) -> def _get_name(self, node, language: str, kind: str) -> Optional[str]: """Extract the name from a class/function definition node.""" + if language == "rust" and node.type == "impl_item": + # Find if there is a 'for' child + for_index = -1 + type_identifiers = [] + for i, child in enumerate(node.children): + if child.type == "for": + for_index = i + elif child.type in ("type_identifier", "generic_type", "scoped_type_identifier"): + type_identifiers.append((i, child)) + + if for_index != -1: + # Implementing a trait: the target struct/type is after 'for'. + for idx, child in type_identifiers: + if idx > for_index: + return self._get_rust_base_type(child) + else: + # Inherent impl: the target struct/type is the first type identifier + if type_identifiers: + return self._get_rust_base_type(type_identifiers[0][1]) + return None + # Dart: function_signature has a return-type node before the identifier; # search only for 'identifier' to avoid returning the return type name. if language == "dart" and node.type == "function_signature": @@ -5940,6 +6159,61 @@ def _leaf_name(qi): return self._get_name(child, language, kind) return None + def _get_rust_base_type(self, node) -> str: + if node.type in ("type_identifier", "identifier"): + return node.text.decode("utf-8", errors="replace") + if node.type == "generic_type": + for child in node.children: + if child.type == "type_identifier": + return self._get_rust_base_type(child) + if node.type == "scoped_type_identifier": + for child in reversed(node.children): + if child.type == "type_identifier": + return self._get_rust_base_type(child) + return node.text.decode("utf-8", errors="replace") + + def _parse_rust_use_node(self, node, prefix="") -> list[tuple[str, str]]: + results = [] + if node.type == "identifier": + name = node.text.decode("utf-8", errors="replace") + full_path = f"{prefix}::{name}" if prefix else name + results.append((name, full_path)) + elif node.type == "scoped_identifier": + full_path = node.text.decode("utf-8", errors="replace") + name = node.children[-1].text.decode("utf-8", errors="replace") + if prefix: + full_path = f"{prefix}::{full_path}" + results.append((name, full_path)) + elif node.type == "scoped_use_list": + prefix_node = node.children[0] + prefix_str = prefix_node.text.decode("utf-8", errors="replace") + if prefix: + prefix_str = f"{prefix}::{prefix_str}" + use_list = node.children[-1] + if use_list.type == "use_list": + for child in use_list.children: + if child.type in ("identifier", "scoped_identifier", "scoped_use_list", "use_as_clause", "use_wildcard"): + results.extend(self._parse_rust_use_node(child, prefix_str)) + elif node.type == "use_list": + for child in node.children: + if child.type in ("identifier", "scoped_identifier", "scoped_use_list", "use_as_clause", "use_wildcard"): + results.extend(self._parse_rust_use_node(child, prefix)) + elif node.type == "use_as_clause": + path_node = node.children[0] + alias_node = node.children[-1] + path_str = path_node.text.decode("utf-8", errors="replace") + if prefix: + path_str = f"{prefix}::{path_str}" + alias_str = alias_node.text.decode("utf-8", errors="replace") + results.append((alias_str, path_str)) + elif node.type == "use_wildcard": + prefix_node = node.children[0] + prefix_str = prefix_node.text.decode("utf-8", errors="replace") + if prefix: + prefix_str = f"{prefix}::{prefix_str}" + results.append(("*", prefix_str)) + return results + def _get_go_receiver_type(self, node) -> Optional[str]: """Extract the receiver type from a Go method_declaration. @@ -6151,6 +6425,21 @@ def _get_bases(self, node, language: str, source: bytes) -> list[str]: bases.append( idents[0].text.decode("utf-8", errors="replace"), ) + elif language == "rust": + if node.type == "impl_item": + # Check if there's a 'for' child + for_index = -1 + type_identifiers = [] + for i, child in enumerate(node.children): + if child.type == "for": + for_index = i + elif child.type in ("type_identifier", "generic_type", "scoped_type_identifier"): + type_identifiers.append((i, child)) + + if for_index != -1 and type_identifiers: + # The trait is the first type_identifier (before the 'for' keyword) + trait_node = type_identifiers[0][1] + bases.append(self._get_rust_base_type(trait_node)) return bases def _extract_import(self, node, language: str, source: bytes) -> list[str]: @@ -6190,8 +6479,14 @@ def _extract_import(self, node, language: str, source: bytes) -> list[str]: val = s.text.decode("utf-8", errors="replace") imports.append(val.strip('"')) elif language == "rust": - # use crate::module::item - imports.append(text.replace("use ", "").rstrip(";").strip()) + # Walk use_declaration children to find the inner import node + for child in node.children: + if child.type in ("scoped_identifier", "scoped_use_list", "use_as_clause", "use_wildcard", "identifier"): + for _, full_path in self._parse_rust_use_node(child): + if full_path.endswith("::*"): + imports.append(full_path[:-3]) + else: + imports.append(full_path) elif language in ("c", "cpp"): # #include
or #include "header" for child in node.children: diff --git a/tests/test_multilang.py b/tests/test_multilang.py index afda355e..994ccfa5 100644 --- a/tests/test_multilang.py +++ b/tests/test_multilang.py @@ -96,6 +96,7 @@ def test_finds_structs_and_traits(self): names = {c.name for c in classes} assert "User" in names assert "InMemoryRepo" in names + assert "Repository" in names def test_finds_functions(self): funcs = [n for n in self.nodes if n.kind == "Function"] @@ -108,11 +109,27 @@ def test_finds_functions(self): def test_finds_imports(self): imports = [e for e in self.edges if e.kind == "IMPORTS_FROM"] assert len(imports) >= 1 + targets = {e.target for e in imports} + assert "std::collections::HashMap" in targets def test_finds_calls(self): calls = [e for e in self.edges if e.kind == "CALLS"] assert len(calls) >= 3 + def test_finds_inheritance(self): + inherits = [e for e in self.edges if e.kind == "INHERITS"] + assert len(inherits) >= 1 + pairs = {(e.source.split("::")[-1], e.target) for e in inherits} + assert ("InMemoryRepo", "Repository") in pairs + + def test_resolves_scoped_calls(self): + calls = [e for e in self.edges if e.kind == "CALLS"] + # Find the call on line 54 + line_54_calls = [e for e in calls if e.line == 54] + assert len(line_54_calls) == 1 + call = line_54_calls[0] + assert call.target.endswith("::InMemoryRepo.new") + def test_detects_test_attribute(self): tests = [n for n in self.nodes if n.kind == "Test"] names = {t.name for t in tests} @@ -134,6 +151,73 @@ def test_non_test_functions_not_misclassified(self): assert not n.is_test +class TestRustImportResolution: + def test_resolves_rust_project_import(self, tmp_path): + src = tmp_path / "src" + src.mkdir(parents=True) + (src / "db.rs").write_text("pub struct InMemoryRepo {}\n") + (src / "lib.rs").write_text( + "pub mod db;\n" + "use crate::db::InMemoryRepo;\n" + ) + (tmp_path / "Cargo.toml").write_text("[package]\nname = \"test_pkg\"\n") + + parser = CodeParser() + _, edges = parser.parse_file(src / "lib.rs") + imports = [e for e in edges if e.kind == "IMPORTS_FROM"] + assert len(imports) == 1 + assert imports[0].target == str((src / "db.rs").resolve()) + + def test_resolves_rust_relative_import(self, tmp_path): + src = tmp_path / "src" + db_dir = src / "db" + db_dir.mkdir(parents=True) + (src / "lib.rs").write_text( + "pub mod db;\n" + "pub fn top_func() {}\n" + ) + (db_dir / "mod.rs").write_text("pub mod tests;\n") + (db_dir / "tests.rs").write_text("use super::super::top_func;\n") + (tmp_path / "Cargo.toml").write_text("[package]\nname = \"test_pkg\"\n") + + parser = CodeParser() + _, edges = parser.parse_file(db_dir / "tests.rs") + imports = [e for e in edges if e.kind == "IMPORTS_FROM"] + assert len(imports) == 1 + assert imports[0].target == str((src / "lib.rs").resolve()) + + def test_resolves_rust_workspace_import(self, tmp_path): + # Create workspace root Cargo.toml + (tmp_path / "Cargo.toml").write_text( + "[workspace]\n" + "members = [\"crates/*\"]\n" + "[dependencies]\n" + "dep-crate = { path = \"./crates/dep_crate\" }\n" + ) + + # Create dep_crate + dep_dir = tmp_path / "crates" / "dep_crate" + dep_dir.mkdir(parents=True) + (dep_dir / "Cargo.toml").write_text("[package]\nname = \"dep-crate\"\n") + dep_src = dep_dir / "src" + dep_src.mkdir() + (dep_src / "lib.rs").write_text("pub struct Helper;\n") + + # Create app_crate + app_dir = tmp_path / "crates" / "app_crate" + app_dir.mkdir(parents=True) + (app_dir / "Cargo.toml").write_text("[package]\nname = \"app_crate\"\n") + app_src = app_dir / "src" + app_src.mkdir() + (app_src / "main.rs").write_text("use dep_crate::Helper;\n") + + parser = CodeParser() + _, edges = parser.parse_file(app_src / "main.rs") + imports = [e for e in edges if e.kind == "IMPORTS_FROM"] + assert len(imports) == 1 + assert imports[0].target == str((dep_src / "lib.rs").resolve()) + + class TestJavaParsing: def setup_method(self): self.parser = CodeParser()