1818import argparse
1919import ast
2020import sys
21+ import typing
2122
22- ErrorRecord = dict [
23- str , str | int
24- ] # Keys: "class", "arg", "error", "lineno", "filename"
23+
24+ class ErrorRecord (typing .TypedDict ):
25+ cls : str
26+ arg : str
27+ error : str
28+ lineno : int
29+ filename : str
2530
2631
2732def extract_tuple_from_node (node : ast .AST ) -> tuple [str , ...] | None :
2833 """Extract a tuple of strings from an AST node."""
2934 if isinstance (node , ast .Tuple ):
3035 return tuple (
31- elt .value for elt in node .elts if isinstance (elt , ast .Constant )
36+ str (elt .value )
37+ for elt in node .elts
38+ if isinstance (elt , ast .Constant )
3239 )
3340 return None
3441
@@ -58,6 +65,36 @@ def get_do_evaluate_node(class_node: ast.ClassDef) -> ast.FunctionDef | None:
5865 return None
5966
6067
68+ def get_init_node (class_node : ast .ClassDef ) -> ast .FunctionDef | None :
69+ """Get the __init__ method node from a class definition."""
70+ for item in class_node .body :
71+ if isinstance (item , ast .FunctionDef ) and item .name == "__init__" :
72+ return item
73+ return None
74+
75+
76+ def get_non_child_args_length (init_node : ast .FunctionDef ) -> int | None :
77+ """
78+ Get the length of the tuple assigned to self._non_child_args in __init__.
79+ Returns None if the assignment is not found or is not a tuple.
80+ """
81+ for stmt in ast .walk (init_node ):
82+ # Look for assignments: self._non_child_args = (...)
83+ if isinstance (stmt , ast .Assign ):
84+ # Check if target is self._non_child_args
85+ for target in stmt .targets :
86+ if (
87+ isinstance (target , ast .Attribute )
88+ and isinstance (target .value , ast .Name )
89+ and target .value .id == "self"
90+ and target .attr == "_non_child_args"
91+ ):
92+ # Check if the value is a tuple
93+ if isinstance (stmt .value , ast .Tuple ):
94+ return len (stmt .value .elts )
95+ return None
96+
97+
6198def get_do_evaluate_params (method_node : ast .FunctionDef ) -> list [str ]:
6299 """Get parameter names from do_evaluate method."""
63100 params = []
@@ -124,7 +161,7 @@ def analyze_content(content: str, filename: str) -> list[ErrorRecord]:
124161 if nc not in do_evaluate_params :
125162 records .append (
126163 {
127- "class " : class_name ,
164+ "cls " : class_name ,
128165 "arg" : nc ,
129166 "error" : "Missing" ,
130167 "lineno" : method_node .lineno ,
@@ -134,7 +171,7 @@ def analyze_content(content: str, filename: str) -> list[ErrorRecord]:
134171 elif do_evaluate_params .index (nc ) != i :
135172 records .append (
136173 {
137- "class " : class_name ,
174+ "cls " : class_name ,
138175 "arg" : nc ,
139176 "error" : "Wrong position" ,
140177 "lineno" : method_node .lineno ,
@@ -156,7 +193,7 @@ def analyze_content(content: str, filename: str) -> list[ErrorRecord]:
156193 if type_name != "DataFrame" :
157194 records .append (
158195 {
159- "class " : class_name ,
196+ "cls " : class_name ,
160197 "arg" : arg .arg ,
161198 "error" : f"Wrong type annotation '{ type_name } ' (expected 'DataFrame')" ,
162199 "lineno" : method_node .lineno ,
@@ -169,7 +206,7 @@ def analyze_content(content: str, filename: str) -> list[ErrorRecord]:
169206 if len (kwonly_args ) != 1 :
170207 records .append (
171208 {
172- "class " : class_name ,
209+ "cls " : class_name ,
173210 "arg" : "kwonly" ,
174211 "error" : f"Expected 1 keyword-only argument, found { len (kwonly_args )} " ,
175212 "lineno" : method_node .lineno ,
@@ -179,7 +216,7 @@ def analyze_content(content: str, filename: str) -> list[ErrorRecord]:
179216 elif kwonly_args [0 ].arg != "context" :
180217 records .append (
181218 {
182- "class " : class_name ,
219+ "cls " : class_name ,
183220 "arg" : kwonly_args [0 ].arg ,
184221 "error" : "Keyword-only argument should be named 'context'" ,
185222 "lineno" : method_node .lineno ,
@@ -192,14 +229,30 @@ def analyze_content(content: str, filename: str) -> list[ErrorRecord]:
192229 if type_name != "IRExecutionContext" :
193230 records .append (
194231 {
195- "class " : class_name ,
232+ "cls " : class_name ,
196233 "arg" : "context" ,
197234 "error" : f"Wrong type annotation '{ type_name } ' (expected 'IRExecutionContext')" ,
198235 "lineno" : method_node .lineno ,
199236 "filename" : filename ,
200237 }
201238 )
202239
240+ # Check that __init__ assigns self._non_child_args with matching length
241+ init_node = get_init_node (node )
242+ if init_node is not None :
243+ non_child_args_length = get_non_child_args_length (init_node )
244+ if non_child_args_length is not None :
245+ if non_child_args_length != len (non_child ):
246+ records .append (
247+ {
248+ "cls" : class_name ,
249+ "arg" : "_non_child_args" ,
250+ "error" : "Mismatch between 'self._non_child_args' and 'cls._non_child'" ,
251+ "lineno" : init_node .lineno ,
252+ "filename" : filename ,
253+ }
254+ )
255+
203256 return records
204257
205258
@@ -236,7 +289,7 @@ def main() -> int:
236289 for record in all_records :
237290 filename = record ["filename" ]
238291 lineno = record ["lineno" ]
239- class_name = record ["class " ]
292+ class_name = record ["cls " ]
240293 error = record ["error" ]
241294 arg = record ["arg" ]
242295 print (
0 commit comments