@@ -294,8 +294,7 @@ class ExtractInlineTest(ast.NodeTransformer):
294294 arg_tag_str = "tag"
295295 arg_disabled_str = "disabled"
296296 arg_timeout_str = "timeout"
297- arg_devices_str = "devices"
298- diff_test_str = "diff_test"
297+
299298 assume = "assume"
300299 inline_module_imported = False
301300
@@ -384,10 +383,9 @@ def parse_constructor(self, node):
384383 self .arg_tag_str : 3 ,
385384 self .arg_disabled_str : 4 ,
386385 self .arg_timeout_str : 5 ,
387- self .arg_devices_str : 6
388386 }
389387
390- NUM_OF_ARGUMENTS = 7
388+ NUM_OF_ARGUMENTS = 6
391389 if len (node .args ) + len (node .keywords ) <= NUM_OF_ARGUMENTS :
392390 # positional arguments
393391 self .parse_constructor_args (node .args )
@@ -418,7 +416,6 @@ class ConstrArgs(enum.Enum):
418416 TAG_STR = 3
419417 DISABLED = 4
420418 TIMEOUT = 5
421- DEVICES = 6
422419
423420 property_names = {
424421 ConstrArgs .TEST_NAME : "test_name" ,
@@ -427,7 +424,6 @@ class ConstrArgs(enum.Enum):
427424 ConstrArgs .TAG_STR : "tag" ,
428425 ConstrArgs .DISABLED : "disabled" ,
429426 ConstrArgs .TIMEOUT : "timeout" ,
430- ConstrArgs .DEVICES : "devices"
431427 }
432428
433429 pre_38_val_names = {
@@ -437,7 +433,6 @@ class ConstrArgs(enum.Enum):
437433 ConstrArgs .TAG_STR : "s" ,
438434 ConstrArgs .DISABLED : "value" ,
439435 ConstrArgs .TIMEOUT : "n" ,
440- ConstrArgs .DEVICES : ""
441436 }
442437
443438 pre_38_expec_ast_arg_type = {
@@ -465,27 +460,17 @@ class ConstrArgs(enum.Enum):
465460 ConstrArgs .TAG_STR : [None ],
466461 ConstrArgs .DISABLED : [bool ],
467462 ConstrArgs .TIMEOUT : [float , int ],
468- ConstrArgs .DEVICES : [str ]
469463 }
470464
471- NUM_OF_ARGUMENTS = 7
465+ NUM_OF_ARGUMENTS = 6
472466
473467 # Arguments organized by expected ast type, value type, and index in that order
474468 for index , arg in enumerate (args ):
475469 # Skips over null arguments; needed for keywords
476470 if arg == None :
477471 continue
478472
479- # Devices are not referenced in versions before 3.8; all other arguments can be from any version
480- if index == ConstrArgs .DEVICES and isinstance (arg , ast .List ):
481- devices = []
482- for elt in arg .elts :
483- if not (isinstance (elt , ast .Constant ) and isinstance (elt .value , str )):
484- raise MalformedException ("devices can only be List of string" )
485- if elt .value not in {"cpu" , "cuda" , "mps" }:
486- raise MalformedException (f"Invalid device: { elt .value } . Must be one of ['cpu', 'cuda', 'mps']" )
487- devices .append (elt .value )
488- self .cur_inline_test .devices = devices
473+
489474 # Assumes version is past 3.8, no explicit references to ast.Constant before 3.8
490475 else :
491476 corr_arg_type = False
@@ -898,231 +883,6 @@ def parse_check_not_same(self, node):
898883 else :
899884 raise MalformedException ("inline test: invalid check_not_same(), expected 2 args" )
900885
901- def parse_diff_test (self , node ):
902- if not self .cur_inline_test .devices :
903- raise MalformedException ("diff_test can only be used with the 'devices' parameter." )
904-
905- if len (node .args ) != 1 :
906- raise MalformedException ("diff_test() requires exactly 1 argument." )
907-
908- output_node = self .parse_group (node .args [0 ])
909-
910- # Get the original operation
911- original_op = None
912- for stmt in self .cur_inline_test .previous_stmts :
913- if isinstance (stmt , ast .Assign ) and stmt .targets [0 ].id == output_node .id :
914- original_op = stmt .value
915- break
916-
917- if not original_op :
918- raise MalformedException ("Could not find original operation for diff_test" )
919-
920- # Create our new statements
921- new_statements = []
922- device_outputs = []
923-
924- # Import necessary modules for seed setting - Always add these
925- # Import random module
926- import_random = ast .ImportFrom (
927- module = 'random' ,
928- names = [ast .alias (name = 'seed' , asname = None )],
929- level = 0
930- )
931- new_statements .append (import_random )
932-
933- # Import numpy.random
934- import_np = ast .ImportFrom (
935- module = 'numpy' ,
936- names = [ast .alias (name = 'random' , asname = 'np_random' )],
937- level = 0
938- )
939- new_statements .append (import_np )
940-
941- # Create seed function - Always add this
942- seed_func_def = ast .FunctionDef (
943- name = 'set_random_seed' ,
944- args = ast .arguments (
945- posonlyargs = [],
946- args = [ast .arg (arg = 'seed_value' , annotation = None )],
947- kwonlyargs = [],
948- kw_defaults = [],
949- defaults = []
950- ),
951- body = [
952- ast .Expr (
953- value = ast .Call (
954- func = ast .Name (id = 'seed' , ctx = ast .Load ()),
955- args = [ast .Name (id = 'seed_value' , ctx = ast .Load ())],
956- keywords = []
957- )
958- ),
959- ast .Expr (
960- value = ast .Call (
961- func = ast .Attribute (
962- value = ast .Name (id = 'torch' , ctx = ast .Load ()),
963- attr = 'manual_seed'
964- ),
965- args = [ast .Name (id = 'seed_value' , ctx = ast .Load ())],
966- keywords = []
967- )
968- ),
969- ast .Expr (
970- value = ast .Call (
971- func = ast .Attribute (
972- value = ast .Name (id = 'np_random' , ctx = ast .Load ()),
973- attr = 'seed'
974- ),
975- args = [ast .Name (id = 'seed_value' , ctx = ast .Load ())],
976- keywords = []
977- )
978- )
979- ],
980- decorator_list = [],
981- returns = None
982- )
983- new_statements .append (seed_func_def )
984-
985- # Process input tensors
986- for given_stmt in self .cur_inline_test .given_stmts :
987- input_var = given_stmt .targets [0 ].id
988- ref_var = f"{ input_var } _ref"
989-
990- # Always clone inputs for in-place operations
991- new_statements .append (
992- ast .Assign (
993- targets = [ast .Name (id = ref_var , ctx = ast .Store ())],
994- value = ast .Call (
995- func = ast .Attribute (
996- value = given_stmt .value ,
997- attr = "clone"
998- ),
999- args = [],
1000- keywords = []
1001- )
1002- )
1003- )
1004-
1005- # Create device-specific versions
1006- for device in self .cur_inline_test .devices :
1007- device_var = f"{ input_var } _{ device } "
1008-
1009- new_statements .append (
1010- ast .Assign (
1011- targets = [ast .Name (id = device_var , ctx = ast .Store ())],
1012- value = ast .Call (
1013- func = ast .Attribute (
1014- value = ast .Name (id = ref_var , ctx = ast .Load ()),
1015- attr = "to"
1016- ),
1017- args = [ast .Constant (value = device )],
1018- keywords = []
1019- )
1020- )
1021- )
1022-
1023- # Create device-specific operations
1024- device_input_map = {device : {} for device in self .cur_inline_test .devices }
1025- for device in self .cur_inline_test .devices :
1026- for given_stmt in self .cur_inline_test .given_stmts :
1027- input_var = given_stmt .targets [0 ].id
1028- device_input_map [device ][input_var ] = f"{ input_var } _{ device } "
1029-
1030- # Always set seed before each device operation - no condition check
1031- new_statements .append (
1032- ast .Expr (
1033- value = ast .Call (
1034- func = ast .Name (id = 'set_random_seed' , ctx = ast .Load ()),
1035- args = [ast .Constant (value = 42 )], # Use constant seed 42
1036- keywords = []
1037- )
1038- )
1039- )
1040-
1041- device_op = copy .deepcopy (original_op )
1042-
1043- # Replace input references
1044- class ReplaceInputs (ast .NodeTransformer ):
1045- def visit_Name (self , node ):
1046- if node .id in device_input_map [device ]:
1047- return ast .Name (id = device_input_map [device ][node .id ], ctx = node .ctx )
1048- return node
1049-
1050- device_op = ReplaceInputs ().visit (device_op )
1051- device_output = f"output_{ device } "
1052-
1053- new_statements .append (
1054- ast .Assign (
1055- targets = [ast .Name (id = device_output , ctx = ast .Store ())],
1056- value = device_op
1057- )
1058- )
1059- device_outputs .append (device_output )
1060-
1061- # Standard comparison method for all operations - no condition check
1062- comparisons = []
1063- for i in range (len (device_outputs ) - 1 ):
1064- dev1 = device_outputs [i ]
1065- dev2 = device_outputs [i + 1 ]
1066-
1067- dev1_cpu = f"{ dev1 } _cpu"
1068- dev2_cpu = f"{ dev2 } _cpu"
1069-
1070- # Move outputs back to CPU for comparison
1071- new_statements .append (
1072- ast .Assign (
1073- targets = [ast .Name (id = dev1_cpu , ctx = ast .Store ())],
1074- value = ast .Call (
1075- func = ast .Attribute (
1076- value = ast .Name (id = dev1 , ctx = ast .Load ()),
1077- attr = "to"
1078- ),
1079- args = [ast .Constant (value = "cpu" )],
1080- keywords = []
1081- )
1082- )
1083- )
1084-
1085- new_statements .append (
1086- ast .Assign (
1087- targets = [ast .Name (id = dev2_cpu , ctx = ast .Store ())],
1088- value = ast .Call (
1089- func = ast .Attribute (
1090- value = ast .Name (id = dev2 , ctx = ast .Load ()),
1091- attr = "to"
1092- ),
1093- args = [ast .Constant (value = "cpu" )],
1094- keywords = []
1095- )
1096- )
1097- )
1098-
1099- # Standard allclose comparison
1100- comparison = self .build_assert_eq (
1101- ast .Call (
1102- func = ast .Attribute (
1103- value = ast .Name (id = dev1_cpu , ctx = ast .Load ()),
1104- attr = "allclose"
1105- ),
1106- args = [
1107- ast .Name (id = dev2_cpu , ctx = ast .Load ())
1108- ],
1109- keywords = [
1110- ast .keyword (arg = "rtol" , value = ast .Constant (value = 1e-4 )),
1111- ast .keyword (arg = "atol" , value = ast .Constant (value = 1e-4 )),
1112- ast .keyword (arg = "equal_nan" , value = ast .Constant (value = True ))
1113- ]
1114- ),
1115- ast .Constant (value = True )
1116- )
1117- comparisons .append (comparison )
1118-
1119- # Replace statements
1120- self .cur_inline_test .previous_stmts = new_statements
1121- self .cur_inline_test .check_stmts = comparisons
1122-
1123-
1124-
1125-
1126886 def build_fail (self ):
1127887 equal_node = ast .Compare (
1128888 left = ast .Constant (0 ),
@@ -1223,13 +983,11 @@ def parse_inline_test(self, node):
1223983 self .parse_check_same (call )
1224984 elif call .func .attr == self .check_not_same :
1225985 self .parse_check_not_same (call )
1226- elif call .func .attr == self .diff_test_str :
1227- self .parse_diff_test (call )
1228986 elif call .func .attr == self .fail_str :
1229987 self .parse_fail (call )
1230988 elif call .func .attr == self .given_str :
1231989 raise MalformedException (
1232- f"inline test: given() must be called before check_eq()/check_true()/check_false()/diff_test() "
990+ f"inline test: given() must be called before check_eq()/check_true()/check_false()"
1233991 )
1234992 else :
1235993 raise MalformedException (f"inline test: invalid function call { self .node_to_source_code (call .func )} " )
@@ -1370,8 +1128,6 @@ def _find(self, tests, obj, module, globs, seen):
13701128######################################################################
13711129class InlineTestRunner :
13721130 def run (self , test : InlineTest , out : List ) -> None :
1373- test_str = test .to_test ()
1374- print (test_str )
13751131 tree = ast .parse (test .to_test ())
13761132 codeobj = compile (tree , filename = "<ast>" , mode = "exec" )
13771133 start_time = time .time ()
0 commit comments