@@ -159,6 +159,7 @@ def __init__(self):
159159 self .check_stmts = []
160160 self .given_stmts = []
161161 self .previous_stmts = []
162+ self .import_stmts = []
162163 self .prev_stmt_type = PrevStmtType .StmtExpr
163164 # the line number of test statement
164165 self .lineno = 0
@@ -174,10 +175,18 @@ def __init__(self):
174175 self .devices = None
175176 self .globs = {}
176177
178+ def write_imports (self ):
179+ import_str = ""
180+ for n in self .import_stmts :
181+ import_str += ExtractInlineTest .node_to_source_code (n ) + "\n "
182+ return import_str
183+
177184 def to_test (self ):
185+ prefix = "\n "
186+
178187 if self .prev_stmt_type == PrevStmtType .CondExpr :
179188 if self .assume_stmts == []:
180- return " \n " .join (
189+ return prefix .join (
181190 [ExtractInlineTest .node_to_source_code (n ) for n in self .given_stmts ]
182191 + [ExtractInlineTest .node_to_source_code (n ) for n in self .check_stmts ]
183192 )
@@ -187,11 +196,11 @@ def to_test(self):
187196 )
188197 assume_statement = self .assume_stmts [0 ]
189198 assume_node = self .build_assume_node (assume_statement , body_nodes )
190- return " \n " .join (ExtractInlineTest .node_to_source_code (assume_node ))
199+ return prefix .join (ExtractInlineTest .node_to_source_code (assume_node ))
191200
192201 else :
193202 if self .assume_stmts is None or self .assume_stmts == []:
194- return " \n " .join (
203+ return prefix .join (
195204 [ExtractInlineTest .node_to_source_code (n ) for n in self .given_stmts ]
196205 + [ExtractInlineTest .node_to_source_code (n ) for n in self .previous_stmts ]
197206 + [ExtractInlineTest .node_to_source_code (n ) for n in self .check_stmts ]
@@ -202,7 +211,7 @@ def to_test(self):
202211 )
203212 assume_statement = self .assume_stmts [0 ]
204213 assume_node = self .build_assume_node (assume_statement , body_nodes )
205- return " \n " .join ([ExtractInlineTest .node_to_source_code (assume_node )])
214+ return prefix .join ([ExtractInlineTest .node_to_source_code (assume_node )])
206215
207216 def build_assume_node (self , assumption_node , body_nodes ):
208217 return ast .If (assumption_node , body_nodes , [])
@@ -296,6 +305,11 @@ class ExtractInlineTest(ast.NodeTransformer):
296305 arg_timeout_str = "timeout"
297306
298307 assume = "assume"
308+
309+ import_str = "import"
310+ from_str = "from"
311+ as_str = "as"
312+
299313 inline_module_imported = False
300314
301315 def __init__ (self ):
@@ -360,6 +374,23 @@ def collect_inline_test_calls(self, node, inline_test_calls: List[ast.Call]):
360374 inline_test_calls .append (node )
361375 self .collect_inline_test_calls (node .func , inline_test_calls )
362376
377+ def collect_import_calls (self , node , import_calls : List [ast .Import ], import_from_calls : List [ast .ImportFrom ]):
378+ """
379+ collect all import calls in the node (should be done first)
380+ """
381+
382+ while not isinstance (node , ast .Module ) and node .parent != None :
383+ node = node .parent
384+
385+ if not isinstance (node , ast .Module ):
386+ return
387+
388+ for child in node .children :
389+ if isinstance (child , ast .Import ):
390+ import_calls .append (child )
391+ elif isinstance (child , ast .ImportFrom ):
392+ import_from_calls .append (child )
393+
363394 def parse_constructor (self , node ):
364395 """
365396 Parse a constructor call.
@@ -931,8 +962,13 @@ def parse_parameterized_test(self):
931962 parameterized_test .test_name = self .cur_inline_test .test_name + "_" + str (index )
932963
933964 def parse_inline_test (self , node ):
934- inline_test_calls = []
965+ import_calls = []
966+ import_from_calls = []
967+ inline_test_calls = []
968+
935969 self .collect_inline_test_calls (node , inline_test_calls )
970+ self .collect_import_calls (node , import_calls , import_from_calls )
971+
936972 inline_test_calls .reverse ()
937973
938974 if len (inline_test_calls ) <= 1 :
@@ -953,14 +989,20 @@ def parse_inline_test(self, node):
953989 self .parse_assume (call )
954990 inline_test_call_index += 1
955991
956- # "given(a, 1)"
957992 for call in inline_test_calls [inline_test_call_index :]:
958- if isinstance (call .func , ast .Attribute ) and call .func .attr == self .given_str :
959- self .parse_given (call )
960- inline_test_call_index += 1
993+ if isinstance (call .func , ast .Attribute ):
994+ if call .func .attr == self .given_str :
995+ self .parse_given (call )
996+ inline_test_call_index += 1
961997 else :
962998 break
963999
1000+ for import_stmt in import_calls :
1001+ self .cur_inline_test .import_stmts .append (import_stmt )
1002+ for import_stmt in import_from_calls :
1003+ self .cur_inline_test .import_stmts .append (import_stmt )
1004+
1005+
9641006 # "check_eq" or "check_true" or "check_false" or "check_neq"
9651007 for call in inline_test_calls [inline_test_call_index :]:
9661008 # "check_eq(a, 1)"
0 commit comments