Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 298 additions & 3 deletions code_review_graph/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class EdgeInfo:
"typescript": ["class_declaration", "class"],
"tsx": ["class_declaration", "class"],
"go": ["type_declaration"],
"rust": ["struct_item", "enum_item", "impl_item"],
"rust": ["struct_item", "enum_item", "impl_item", "trait_item"],
"java": ["class_declaration", "interface_declaration", "enum_declaration"],
"c": ["struct_specifier", "type_definition"],
"cpp": ["class_specifier", "struct_specifier"],
Expand Down Expand Up @@ -5306,6 +5306,14 @@ def _collect_import_names(
if child.type == "import_clause":
self._collect_js_import_names(child, module, import_map)

elif language == "rust":
# Walk use_declaration children to find the inner import node
for child in node.children:
if child.type in ("scoped_identifier", "scoped_use_list", "use_as_clause", "use_wildcard", "identifier"):
for local_name, full_path in self._parse_rust_use_node(child):
if local_name != "*":
import_map[local_name] = full_path

def _collect_js_import_names(
self, clause_node, module: str, import_map: dict[str, str],
) -> None:
Expand Down Expand Up @@ -5481,8 +5489,149 @@ def _do_resolve_module(
break
current = current.parent

elif language == "rust":
# Normalize `::` separators to `/` so both spellings split into the same segments
module_path = module.replace("::", "/")
segments = [seg for seg in module_path.split("/") if seg]
if not segments:
return None

# Find crate root / source root and collect all dependencies from Cargo.toml files found along the way
crate_root = None
dependencies = {}
current = caller_dir
for _ in range(20):
cargo_toml = current / "Cargo.toml"
if cargo_toml.is_file():
if not crate_root:
crate_root = current
dependencies.update(self._find_rust_local_dependencies(cargo_toml))
if current.parent == current:
break
current = current.parent

if crate_root:
src_root = crate_root / "src"
if not src_root.is_dir():
src_root = crate_root
else:
src_root = caller_dir

# Determine starting directory and segments to resolve
start_dir = None
resolved_file = None

if segments[0] == "crate":
start_dir = src_root
resolved_file = src_root / "lib.rs" if (src_root / "lib.rs").is_file() else (src_root / "main.rs" if (src_root / "main.rs").is_file() else None)
resolve_segments = segments[1:]
elif segments[0] == "super":
super_count = 0
for seg in segments:
if seg == "super":
super_count += 1
else:
break

curr_file = Path(file_path)
for _ in range(super_count):
parent_mod = self._find_rust_parent_module(curr_file)
if parent_mod:
curr_file = parent_mod
else:
curr_file = curr_file.parent

resolved_file = curr_file
if resolved_file.is_dir():
start_dir = resolved_file
resolved_file = None
else:
start_dir = resolved_file.parent

resolve_segments = segments[super_count:]
elif segments[0] == "self":
resolved_file = Path(file_path)
start_dir = resolved_file.parent
resolve_segments = segments[1:]
elif segments[0] in dependencies:
dep_path = dependencies[segments[0]]
dep_src_root = dep_path / "src" if (dep_path / "src").is_dir() else dep_path
start_dir = dep_src_root
resolved_file = dep_src_root / "lib.rs" if (dep_src_root / "lib.rs").is_file() else (dep_src_root / "main.rs" if (dep_src_root / "main.rs").is_file() else None)
resolve_segments = segments[1:]
else:
start_dir = caller_dir
resolve_segments = segments

# Walk resolve segments
current_dir = start_dir
for seg in resolve_segments:
file_candidate = current_dir / f"{seg}.rs"
dir_candidate = current_dir / seg / "mod.rs"

if file_candidate.is_file():
resolved_file = file_candidate
current_dir = current_dir / seg
elif dir_candidate.is_file():
resolved_file = dir_candidate
current_dir = current_dir / seg
else:
# Last segment might be a struct or function, so we keep the previous resolved file
break

if resolved_file:
return str(resolved_file.resolve())

return None

def _find_rust_parent_module(self, path: Path) -> Optional[Path]:
parent_dir = path.parent
name = path.name

if name in ("mod.rs", "lib.rs", "main.rs"):
gp = parent_dir.parent
cand = gp / f"{parent_dir.name}.rs"
if cand.is_file():
return cand
for name_cand in ("mod.rs", "lib.rs", "main.rs"):
cand = gp / name_cand
if cand.is_file():
return cand
else:
cand = parent_dir / "mod.rs"
if cand.is_file():
return cand
gp = parent_dir.parent
cand = gp / f"{parent_dir.name}.rs"
if cand.is_file():
return cand
for name_cand in ("lib.rs", "main.rs"):
cand = parent_dir / name_cand
if cand.is_file():
return cand
return None

def _find_rust_local_dependencies(self, cargo_toml_path: Path) -> dict[str, Path]:
deps = {}
if not cargo_toml_path.is_file():
return deps
try:
text = cargo_toml_path.read_text(encoding="utf-8", errors="replace")
except OSError:
return deps
# Match name = { path = "..." }
pattern = re.compile(r'^([\w-]+)\s*=\s*\{[^}]*path\s*=\s*"([^"]+)"', re.MULTILINE)
for match in pattern.finditer(text):
name = match.group(1).strip()
path_val = match.group(2).strip()
try:
resolved_path = (cargo_toml_path.parent / path_val).resolve()
deps[name] = resolved_path
deps[name.replace("-", "_")] = resolved_path
except (OSError, ValueError):
pass
return deps

def _find_dart_pubspec_root(
self, start: Path, pkg_name: str,
) -> Optional[Path]:
Expand Down Expand Up @@ -5514,6 +5663,48 @@ def _find_dart_pubspec_root(
self._dart_pubspec_cache[cache_key] = None
return None

def _resolve_rust_scoped_call(
self,
call_name: str,
file_path: str,
import_map: dict[str, str],
defined_names: set[str],
) -> Optional[str]:
if "::" not in call_name:
return None

parts = call_name.split("::")
method = parts[-1]
type_prefix_parts = parts[:-1]
type_prefix = "::".join(type_prefix_parts)

# Check if type_prefix is defined in the same file
if type_prefix in defined_names:
qualified_type = self._qualify(type_prefix, file_path, None)
return f"{qualified_type}.{method}"

# Check if type_prefix is imported
if type_prefix in import_map:
resolved_type = self._resolve_imported_symbol(
type_prefix, import_map[type_prefix], file_path, "rust",
)
if resolved_type:
return f"{resolved_type}.{method}"

# What if type_prefix itself is a path, like `crate::db::InMemoryRepo`?
if type_prefix_parts[0] in ("crate", "super", "self"):
if len(type_prefix_parts) >= 2:
module_part = "::".join(type_prefix_parts[:-1])
type_name = type_prefix_parts[-1]
resolved_module_file = self._resolve_module_to_file(
module_part, file_path, "rust"
)
if resolved_module_file:
qualified_type = self._qualify(type_name, resolved_module_file, None)
return f"{qualified_type}.{method}"

return None

def _resolve_call_target(
self,
call_name: str,
Expand All @@ -5523,6 +5714,13 @@ def _resolve_call_target(
defined_names: set[str],
) -> str:
"""Resolve a bare call name to a qualified target, with fallback."""
if language == "rust" and "::" in call_name:
resolved = self._resolve_rust_scoped_call(
call_name, file_path, import_map, defined_names
)
if resolved:
return resolved

if call_name in defined_names:
return self._qualify(call_name, file_path, None)
if call_name in import_map:
Expand Down Expand Up @@ -5663,6 +5861,27 @@ def _qualify(self, name: str, file_path: str, enclosing_class: Optional[str]) ->

def _get_name(self, node, language: str, kind: str) -> Optional[str]:
"""Extract the name from a class/function definition node."""
if language == "rust" and node.type == "impl_item":
# Find if there is a 'for' child
for_index = -1
type_identifiers = []
for i, child in enumerate(node.children):
if child.type == "for":
for_index = i
elif child.type in ("type_identifier", "generic_type", "scoped_type_identifier"):
type_identifiers.append((i, child))

if for_index != -1:
# Implementing a trait: the target struct/type is after 'for'.
for idx, child in type_identifiers:
if idx > for_index:
return self._get_rust_base_type(child)
else:
# Inherent impl: the target struct/type is the first type identifier
if type_identifiers:
return self._get_rust_base_type(type_identifiers[0][1])
return None

# Dart: function_signature has a return-type node before the identifier;
# search only for 'identifier' to avoid returning the return type name.
if language == "dart" and node.type == "function_signature":
Expand Down Expand Up @@ -5940,6 +6159,61 @@ def _leaf_name(qi):
return self._get_name(child, language, kind)
return None

def _get_rust_base_type(self, node) -> str:
if node.type in ("type_identifier", "identifier"):
return node.text.decode("utf-8", errors="replace")
if node.type == "generic_type":
for child in node.children:
if child.type == "type_identifier":
return self._get_rust_base_type(child)
if node.type == "scoped_type_identifier":
for child in reversed(node.children):
if child.type == "type_identifier":
return self._get_rust_base_type(child)
return node.text.decode("utf-8", errors="replace")

def _parse_rust_use_node(self, node, prefix="") -> list[tuple[str, str]]:
results = []
if node.type == "identifier":
name = node.text.decode("utf-8", errors="replace")
full_path = f"{prefix}::{name}" if prefix else name
results.append((name, full_path))
elif node.type == "scoped_identifier":
full_path = node.text.decode("utf-8", errors="replace")
name = node.children[-1].text.decode("utf-8", errors="replace")
if prefix:
full_path = f"{prefix}::{full_path}"
results.append((name, full_path))
elif node.type == "scoped_use_list":
prefix_node = node.children[0]
prefix_str = prefix_node.text.decode("utf-8", errors="replace")
if prefix:
prefix_str = f"{prefix}::{prefix_str}"
use_list = node.children[-1]
if use_list.type == "use_list":
for child in use_list.children:
if child.type in ("identifier", "scoped_identifier", "scoped_use_list", "use_as_clause", "use_wildcard"):
results.extend(self._parse_rust_use_node(child, prefix_str))
elif node.type == "use_list":
for child in node.children:
if child.type in ("identifier", "scoped_identifier", "scoped_use_list", "use_as_clause", "use_wildcard"):
results.extend(self._parse_rust_use_node(child, prefix))
elif node.type == "use_as_clause":
path_node = node.children[0]
alias_node = node.children[-1]
path_str = path_node.text.decode("utf-8", errors="replace")
if prefix:
path_str = f"{prefix}::{path_str}"
alias_str = alias_node.text.decode("utf-8", errors="replace")
results.append((alias_str, path_str))
elif node.type == "use_wildcard":
prefix_node = node.children[0]
prefix_str = prefix_node.text.decode("utf-8", errors="replace")
if prefix:
prefix_str = f"{prefix}::{prefix_str}"
results.append(("*", prefix_str))
return results

def _get_go_receiver_type(self, node) -> Optional[str]:
"""Extract the receiver type from a Go method_declaration.

Expand Down Expand Up @@ -6151,6 +6425,21 @@ def _get_bases(self, node, language: str, source: bytes) -> list[str]:
bases.append(
idents[0].text.decode("utf-8", errors="replace"),
)
elif language == "rust":
if node.type == "impl_item":
# Check if there's a 'for' child
for_index = -1
type_identifiers = []
for i, child in enumerate(node.children):
if child.type == "for":
for_index = i
elif child.type in ("type_identifier", "generic_type", "scoped_type_identifier"):
type_identifiers.append((i, child))

if for_index != -1 and type_identifiers:
# The trait is the first type_identifier (before the 'for' keyword)
trait_node = type_identifiers[0][1]
bases.append(self._get_rust_base_type(trait_node))
return bases

def _extract_import(self, node, language: str, source: bytes) -> list[str]:
Expand Down Expand Up @@ -6190,8 +6479,14 @@ def _extract_import(self, node, language: str, source: bytes) -> list[str]:
val = s.text.decode("utf-8", errors="replace")
imports.append(val.strip('"'))
elif language == "rust":
# use crate::module::item
imports.append(text.replace("use ", "").rstrip(";").strip())
# Walk use_declaration children to find the inner import node
for child in node.children:
if child.type in ("scoped_identifier", "scoped_use_list", "use_as_clause", "use_wildcard", "identifier"):
for _, full_path in self._parse_rust_use_node(child):
if full_path.endswith("::*"):
imports.append(full_path[:-3])
else:
imports.append(full_path)
elif language in ("c", "cpp"):
# #include <header> or #include "header"
for child in node.children:
Expand Down
Loading