Skip to content

Commit f542ca8

Browse files
committed
Validate _non_child_args
1 parent 4a43787 commit f542ca8

File tree

2 files changed

+65
-11
lines changed

2 files changed

+65
-11
lines changed

ci/check_cudf_polars_ir.py

Lines changed: 64 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -18,17 +18,24 @@
1818
import argparse
1919
import ast
2020
import sys
21+
import typing
2122

22-
ErrorRecord = dict[
23-
str, str | int
24-
] # Keys: "class", "arg", "error", "lineno", "filename"
23+
24+
class ErrorRecord(typing.TypedDict):
    """A single violation found while checking a class definition."""

    # Name of the class in which the violation was found.
    cls: str
    # Name of the offending argument, or a pseudo-name such as "kwonly"
    # or "_non_child_args" when the violation is not tied to one parameter.
    arg: str
    # Human-readable description of the violation.
    error: str
    # Line number reported for the violation.
    lineno: int
    # Name of the file that was analyzed.
    filename: str
2530

2631

2732
def extract_tuple_from_node(node: ast.AST) -> tuple[str, ...] | None:
2833
"""Extract a tuple of strings from an AST node."""
2934
if isinstance(node, ast.Tuple):
3035
return tuple(
31-
elt.value for elt in node.elts if isinstance(elt, ast.Constant)
36+
str(elt.value)
37+
for elt in node.elts
38+
if isinstance(elt, ast.Constant)
3239
)
3340
return None
3441

@@ -58,6 +65,36 @@ def get_do_evaluate_node(class_node: ast.ClassDef) -> ast.FunctionDef | None:
5865
return None
5966

6067

68+
def get_init_node(class_node: ast.ClassDef) -> ast.FunctionDef | None:
69+
"""Get the __init__ method node from a class definition."""
70+
for item in class_node.body:
71+
if isinstance(item, ast.FunctionDef) and item.name == "__init__":
72+
return item
73+
return None
74+
75+
76+
def get_non_child_args_length(init_node: ast.FunctionDef) -> int | None:
77+
"""
78+
Get the length of the tuple assigned to self._non_child_args in __init__.
79+
Returns None if the assignment is not found or is not a tuple.
80+
"""
81+
for stmt in ast.walk(init_node):
82+
# Look for assignments: self._non_child_args = (...)
83+
if isinstance(stmt, ast.Assign):
84+
# Check if target is self._non_child_args
85+
for target in stmt.targets:
86+
if (
87+
isinstance(target, ast.Attribute)
88+
and isinstance(target.value, ast.Name)
89+
and target.value.id == "self"
90+
and target.attr == "_non_child_args"
91+
):
92+
# Check if the value is a tuple
93+
if isinstance(stmt.value, ast.Tuple):
94+
return len(stmt.value.elts)
95+
return None
96+
97+
6198
def get_do_evaluate_params(method_node: ast.FunctionDef) -> list[str]:
6299
"""Get parameter names from do_evaluate method."""
63100
params = []
@@ -124,7 +161,7 @@ def analyze_content(content: str, filename: str) -> list[ErrorRecord]:
124161
if nc not in do_evaluate_params:
125162
records.append(
126163
{
127-
"class": class_name,
164+
"cls": class_name,
128165
"arg": nc,
129166
"error": "Missing",
130167
"lineno": method_node.lineno,
@@ -134,7 +171,7 @@ def analyze_content(content: str, filename: str) -> list[ErrorRecord]:
134171
elif do_evaluate_params.index(nc) != i:
135172
records.append(
136173
{
137-
"class": class_name,
174+
"cls": class_name,
138175
"arg": nc,
139176
"error": "Wrong position",
140177
"lineno": method_node.lineno,
@@ -156,7 +193,7 @@ def analyze_content(content: str, filename: str) -> list[ErrorRecord]:
156193
if type_name != "DataFrame":
157194
records.append(
158195
{
159-
"class": class_name,
196+
"cls": class_name,
160197
"arg": arg.arg,
161198
"error": f"Wrong type annotation '{type_name}' (expected 'DataFrame')",
162199
"lineno": method_node.lineno,
@@ -169,7 +206,7 @@ def analyze_content(content: str, filename: str) -> list[ErrorRecord]:
169206
if len(kwonly_args) != 1:
170207
records.append(
171208
{
172-
"class": class_name,
209+
"cls": class_name,
173210
"arg": "kwonly",
174211
"error": f"Expected 1 keyword-only argument, found {len(kwonly_args)}",
175212
"lineno": method_node.lineno,
@@ -179,7 +216,7 @@ def analyze_content(content: str, filename: str) -> list[ErrorRecord]:
179216
elif kwonly_args[0].arg != "context":
180217
records.append(
181218
{
182-
"class": class_name,
219+
"cls": class_name,
183220
"arg": kwonly_args[0].arg,
184221
"error": "Keyword-only argument should be named 'context'",
185222
"lineno": method_node.lineno,
@@ -192,14 +229,30 @@ def analyze_content(content: str, filename: str) -> list[ErrorRecord]:
192229
if type_name != "IRExecutionContext":
193230
records.append(
194231
{
195-
"class": class_name,
232+
"cls": class_name,
196233
"arg": "context",
197234
"error": f"Wrong type annotation '{type_name}' (expected 'IRExecutionContext')",
198235
"lineno": method_node.lineno,
199236
"filename": filename,
200237
}
201238
)
202239

240+
# Check that __init__ assigns self._non_child_args with matching length
241+
init_node = get_init_node(node)
242+
if init_node is not None:
243+
non_child_args_length = get_non_child_args_length(init_node)
244+
if non_child_args_length is not None:
245+
if non_child_args_length != len(non_child):
246+
records.append(
247+
{
248+
"cls": class_name,
249+
"arg": "_non_child_args",
250+
"error": "Mismatch between 'self._non_child_args' and 'cls._non_child'",
251+
"lineno": init_node.lineno,
252+
"filename": filename,
253+
}
254+
)
255+
203256
return records
204257

205258

@@ -236,7 +289,7 @@ def main() -> int:
236289
for record in all_records:
237290
filename = record["filename"]
238291
lineno = record["lineno"]
239-
class_name = record["class"]
292+
class_name = record["cls"]
240293
error = record["error"]
241294
arg = record["arg"]
242295
print(

python/cudf_polars/cudf_polars/dsl/ir.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1617,6 +1617,7 @@ def __init__(
16171617
self.zlice = zlice
16181618
self.children = (df,)
16191619
self._non_child_args = (
1620+
schema,
16201621
index,
16211622
index_dtype,
16221623
preceding_ordinal,

0 commit comments

Comments
 (0)