Skip to content

Commit dafbd04

Browse files
committed
ir checker script
1 parent de0e12f commit dafbd04

File tree

2 files changed

+252
-1
lines changed

2 files changed

+252
-1
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ repos:
102102
- id: cudf-polars-ir-signatures
103103
name: cudf-polars-ir-signatures
104104
description: 'Validate cudf-polars IR.do_evaluate signatures.'
105-
entry: ./ci/ir_check.py
105+
entry: ./ci/check_cudf_polars_ir.py
106106
language: python
107107
files: ^python/cudf_polars/cudf_polars/(dsl/ir|experimental/(shuffle|io|sort))\.py$
108108
pass_filenames: true

ci/check_cudf_polars_ir.py

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
#!/usr/bin/env python3
2+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
Check IR node consistency in cudf_polars.
7+
8+
Verifies that the `do_evaluate` method signatures in IR subclasses
9+
10+
- Are a classmethod
11+
- Accept `*_non_child` positional arguments, followed by
12+
- `*children` positional arguments, followed by
13+
- A keyword-only `context` argument
14+
"""
15+
16+
from __future__ import annotations
17+
18+
import argparse
19+
import ast
20+
import sys
21+
22+
ErrorRecord = dict[
23+
str, str | int
24+
] # Keys: "class", "arg", "error", "lineno", "filename"
25+
26+
27+
def extract_tuple_from_node(node: ast.AST) -> tuple[str, ...] | None:
28+
"""Extract a tuple of strings from an AST node."""
29+
if isinstance(node, ast.Tuple):
30+
return tuple(
31+
elt.value for elt in node.elts if isinstance(elt, ast.Constant)
32+
)
33+
return None
34+
35+
36+
def get_non_child(class_node: ast.ClassDef) -> tuple[str, ...] | None:
37+
"""Get _non_child attribute from a class definition."""
38+
for item in class_node.body:
39+
# Handle annotated assignment: _non_child: ClassVar[...] = (...)
40+
if isinstance(item, ast.AnnAssign) and isinstance(
41+
item.target, ast.Name
42+
):
43+
if item.target.id == "_non_child" and item.value:
44+
return extract_tuple_from_node(item.value)
45+
# Handle regular assignment: _non_child = (...)
46+
elif isinstance(item, ast.Assign):
47+
for target in item.targets:
48+
if isinstance(target, ast.Name) and target.id == "_non_child":
49+
return extract_tuple_from_node(item.value)
50+
return None
51+
52+
53+
def get_do_evaluate_node(class_node: ast.ClassDef) -> ast.FunctionDef | None:
54+
"""Get the do_evaluate method node from a class definition."""
55+
for item in class_node.body:
56+
if isinstance(item, ast.FunctionDef) and item.name == "do_evaluate":
57+
return item
58+
return None
59+
60+
61+
def get_do_evaluate_params(method_node: ast.FunctionDef) -> list[str]:
62+
"""Get parameter names from do_evaluate method."""
63+
params = []
64+
for arg in method_node.args.args:
65+
# Skip 'cls' and 'self' parameters
66+
if arg.arg not in ("cls", "self"):
67+
params.append(arg.arg)
68+
# Don't include keyword-only args like 'context'
69+
return params
70+
71+
72+
def get_type_annotation_name(annotation: ast.expr | None) -> str | None:
73+
"""Extract the name from a type annotation."""
74+
if annotation is None:
75+
return None
76+
if isinstance(annotation, ast.Name):
77+
return annotation.id
78+
# Handle complex types like list[str], dict[str, Any], etc.
79+
# We just want the base type name
80+
if isinstance(annotation, ast.Subscript):
81+
if isinstance(annotation.value, ast.Name):
82+
return annotation.value.id
83+
return None
84+
85+
86+
def is_ir_subclass(class_node: ast.ClassDef) -> bool:
87+
"""Check if a class is a subclass of IR."""
88+
if not class_node.bases:
89+
return False
90+
91+
for base in class_node.bases:
92+
if isinstance(base, ast.Name) and base.id == "IR":
93+
return True
94+
return False
95+
96+
97+
def analyze_content(content: str, filename: str) -> list[ErrorRecord]:
98+
"""Analyze the Python file content for IR node consistency."""
99+
tree = ast.parse(content, filename=filename)
100+
101+
records: list[ErrorRecord] = []
102+
103+
# Find all class definitions
104+
for node in ast.walk(tree):
105+
if isinstance(node, ast.ClassDef):
106+
if not is_ir_subclass(node):
107+
continue
108+
109+
class_name = node.name
110+
111+
non_child = get_non_child(node)
112+
if non_child is None:
113+
continue
114+
115+
method_node = get_do_evaluate_node(node)
116+
if method_node is None:
117+
# Some nodes (e.g. ErrorNode) don't have a do_evaluate method
118+
continue
119+
120+
do_evaluate_params = get_do_evaluate_params(method_node)
121+
122+
# Check each non_child element
123+
for i, nc in enumerate(non_child):
124+
if nc not in do_evaluate_params:
125+
records.append(
126+
{
127+
"class": class_name,
128+
"arg": nc,
129+
"error": "Missing",
130+
"lineno": method_node.lineno,
131+
"filename": filename,
132+
}
133+
)
134+
elif do_evaluate_params.index(nc) != i:
135+
records.append(
136+
{
137+
"class": class_name,
138+
"arg": nc,
139+
"error": "Wrong position",
140+
"lineno": method_node.lineno,
141+
"filename": filename,
142+
}
143+
)
144+
145+
# Check that all *remaining* args in do_evaluate are 'DataFrame' type
146+
# Skip 'cls' or 'self' parameter
147+
regular_args = [
148+
arg
149+
for arg in method_node.args.args
150+
if arg.arg not in ("cls", "self")
151+
]
152+
153+
# Check args after _non_child parameters
154+
for arg in regular_args[len(non_child) :]:
155+
type_name = get_type_annotation_name(arg.annotation)
156+
if type_name != "DataFrame":
157+
records.append(
158+
{
159+
"class": class_name,
160+
"arg": arg.arg,
161+
"error": f"Wrong type annotation '{type_name}' (expected 'DataFrame')",
162+
"lineno": method_node.lineno,
163+
"filename": filename,
164+
}
165+
)
166+
167+
# Check that the only kw-only argument is 'context' with type 'IRExecutionContext'
168+
kwonly_args = method_node.args.kwonlyargs
169+
if len(kwonly_args) != 1:
170+
records.append(
171+
{
172+
"class": class_name,
173+
"arg": "kwonly",
174+
"error": f"Expected 1 keyword-only argument, found {len(kwonly_args)}",
175+
"lineno": method_node.lineno,
176+
"filename": filename,
177+
}
178+
)
179+
elif kwonly_args[0].arg != "context":
180+
records.append(
181+
{
182+
"class": class_name,
183+
"arg": kwonly_args[0].arg,
184+
"error": "Keyword-only argument should be named 'context'",
185+
"lineno": method_node.lineno,
186+
"filename": filename,
187+
}
188+
)
189+
else:
190+
# Check type annotation
191+
type_name = get_type_annotation_name(kwonly_args[0].annotation)
192+
if type_name != "IRExecutionContext":
193+
records.append(
194+
{
195+
"class": class_name,
196+
"arg": "context",
197+
"error": f"Wrong type annotation '{type_name}' (expected 'IRExecutionContext')",
198+
"lineno": method_node.lineno,
199+
"filename": filename,
200+
}
201+
)
202+
203+
return records
204+
205+
206+
def main() -> int:
207+
"""Main entry point for the CLI."""
208+
parser = argparse.ArgumentParser(
209+
description="Check IR node do_evaluate signatures match _non_child declarations"
210+
)
211+
parser.add_argument(
212+
"files",
213+
nargs="+",
214+
type=argparse.FileType("r"),
215+
help="Path(s) to Python file(s) to check (use '-' for stdin)",
216+
)
217+
218+
args = parser.parse_args()
219+
220+
all_records: list[ErrorRecord] = []
221+
222+
try:
223+
for file in args.files:
224+
content = file.read()
225+
filename = file.name
226+
file.close()
227+
228+
records = analyze_content(content, filename)
229+
all_records.extend(records)
230+
except Exception as e:
231+
print(f"Error: {e}", file=sys.stderr)
232+
return 1
233+
234+
if all_records:
235+
print("Found errors in IR node signatures:", end="\n\n")
236+
for record in all_records:
237+
filename = record["filename"]
238+
lineno = record["lineno"]
239+
class_name = record["class"]
240+
error = record["error"]
241+
arg = record["arg"]
242+
print(
243+
f" {filename}:{lineno}: {class_name}: {error} argument '{arg}'"
244+
)
245+
return 1
246+
else:
247+
return 0
248+
249+
250+
if __name__ == "__main__":
251+
sys.exit(main())

0 commit comments

Comments
 (0)