Skip to content

Commit f91e684

Browse files
committed
Addressed Changes from Pull Request
- Removed IDE Debug statements - Removed Diff_test functionality; will be added back later in future PR
1 parent 67f0a04 commit f91e684

File tree

2 files changed

+5
-253
lines changed

2 files changed

+5
-253
lines changed

src/inline/plugin.py

Lines changed: 5 additions & 249 deletions
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,7 @@ class ExtractInlineTest(ast.NodeTransformer):
294294
arg_tag_str = "tag"
295295
arg_disabled_str = "disabled"
296296
arg_timeout_str = "timeout"
297-
arg_devices_str = "devices"
298-
diff_test_str = "diff_test"
297+
299298
assume = "assume"
300299
inline_module_imported = False
301300

@@ -384,10 +383,9 @@ def parse_constructor(self, node):
384383
self.arg_tag_str : 3,
385384
self.arg_disabled_str : 4,
386385
self.arg_timeout_str : 5,
387-
self.arg_devices_str : 6
388386
}
389387

390-
NUM_OF_ARGUMENTS = 7
388+
NUM_OF_ARGUMENTS = 6
391389
if len(node.args) + len(node.keywords) <= NUM_OF_ARGUMENTS:
392390
# positional arguments
393391
self.parse_constructor_args(node.args)
@@ -418,7 +416,6 @@ class ConstrArgs(enum.Enum):
418416
TAG_STR = 3
419417
DISABLED = 4
420418
TIMEOUT = 5
421-
DEVICES = 6
422419

423420
property_names = {
424421
ConstrArgs.TEST_NAME : "test_name",
@@ -427,7 +424,6 @@ class ConstrArgs(enum.Enum):
427424
ConstrArgs.TAG_STR : "tag",
428425
ConstrArgs.DISABLED : "disabled",
429426
ConstrArgs.TIMEOUT : "timeout",
430-
ConstrArgs.DEVICES : "devices"
431427
}
432428

433429
pre_38_val_names = {
@@ -437,7 +433,6 @@ class ConstrArgs(enum.Enum):
437433
ConstrArgs.TAG_STR : "s",
438434
ConstrArgs.DISABLED : "value",
439435
ConstrArgs.TIMEOUT : "n",
440-
ConstrArgs.DEVICES : ""
441436
}
442437

443438
pre_38_expec_ast_arg_type = {
@@ -465,27 +460,17 @@ class ConstrArgs(enum.Enum):
465460
ConstrArgs.TAG_STR : [None],
466461
ConstrArgs.DISABLED : [bool],
467462
ConstrArgs.TIMEOUT : [float, int],
468-
ConstrArgs.DEVICES : [str]
469463
}
470464

471-
NUM_OF_ARGUMENTS = 7
465+
NUM_OF_ARGUMENTS = 6
472466

473467
# Arguments organized by expected ast type, value type, and index in that order
474468
for index, arg in enumerate(args):
475469
# Skips over null arguments; needed for keywords
476470
if arg == None:
477471
continue
478472

479-
# Devices are not referenced in versions before 3.8; all other arguments can be from any version
480-
if index == ConstrArgs.DEVICES and isinstance(arg, ast.List):
481-
devices = []
482-
for elt in arg.elts:
483-
if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)):
484-
raise MalformedException("devices can only be List of string")
485-
if elt.value not in {"cpu", "cuda", "mps"}:
486-
raise MalformedException(f"Invalid device: {elt.value}. Must be one of ['cpu', 'cuda', 'mps']")
487-
devices.append(elt.value)
488-
self.cur_inline_test.devices = devices
473+
489474
# Assumes version is past 3.8, no explicit references to ast.Constant before 3.8
490475
else:
491476
corr_arg_type = False
@@ -898,231 +883,6 @@ def parse_check_not_same(self, node):
898883
else:
899884
raise MalformedException("inline test: invalid check_not_same(), expected 2 args")
900885

901-
def parse_diff_test(self, node):
902-
if not self.cur_inline_test.devices:
903-
raise MalformedException("diff_test can only be used with the 'devices' parameter.")
904-
905-
if len(node.args) != 1:
906-
raise MalformedException("diff_test() requires exactly 1 argument.")
907-
908-
output_node = self.parse_group(node.args[0])
909-
910-
# Get the original operation
911-
original_op = None
912-
for stmt in self.cur_inline_test.previous_stmts:
913-
if isinstance(stmt, ast.Assign) and stmt.targets[0].id == output_node.id:
914-
original_op = stmt.value
915-
break
916-
917-
if not original_op:
918-
raise MalformedException("Could not find original operation for diff_test")
919-
920-
# Create our new statements
921-
new_statements = []
922-
device_outputs = []
923-
924-
# Import necessary modules for seed setting - Always add these
925-
# Import random module
926-
import_random = ast.ImportFrom(
927-
module='random',
928-
names=[ast.alias(name='seed', asname=None)],
929-
level=0
930-
)
931-
new_statements.append(import_random)
932-
933-
# Import numpy.random
934-
import_np = ast.ImportFrom(
935-
module='numpy',
936-
names=[ast.alias(name='random', asname='np_random')],
937-
level=0
938-
)
939-
new_statements.append(import_np)
940-
941-
# Create seed function - Always add this
942-
seed_func_def = ast.FunctionDef(
943-
name='set_random_seed',
944-
args=ast.arguments(
945-
posonlyargs=[],
946-
args=[ast.arg(arg='seed_value', annotation=None)],
947-
kwonlyargs=[],
948-
kw_defaults=[],
949-
defaults=[]
950-
),
951-
body=[
952-
ast.Expr(
953-
value=ast.Call(
954-
func=ast.Name(id='seed', ctx=ast.Load()),
955-
args=[ast.Name(id='seed_value', ctx=ast.Load())],
956-
keywords=[]
957-
)
958-
),
959-
ast.Expr(
960-
value=ast.Call(
961-
func=ast.Attribute(
962-
value=ast.Name(id='torch', ctx=ast.Load()),
963-
attr='manual_seed'
964-
),
965-
args=[ast.Name(id='seed_value', ctx=ast.Load())],
966-
keywords=[]
967-
)
968-
),
969-
ast.Expr(
970-
value=ast.Call(
971-
func=ast.Attribute(
972-
value=ast.Name(id='np_random', ctx=ast.Load()),
973-
attr='seed'
974-
),
975-
args=[ast.Name(id='seed_value', ctx=ast.Load())],
976-
keywords=[]
977-
)
978-
)
979-
],
980-
decorator_list=[],
981-
returns=None
982-
)
983-
new_statements.append(seed_func_def)
984-
985-
# Process input tensors
986-
for given_stmt in self.cur_inline_test.given_stmts:
987-
input_var = given_stmt.targets[0].id
988-
ref_var = f"{input_var}_ref"
989-
990-
# Always clone inputs for in-place operations
991-
new_statements.append(
992-
ast.Assign(
993-
targets=[ast.Name(id=ref_var, ctx=ast.Store())],
994-
value=ast.Call(
995-
func=ast.Attribute(
996-
value=given_stmt.value,
997-
attr="clone"
998-
),
999-
args=[],
1000-
keywords=[]
1001-
)
1002-
)
1003-
)
1004-
1005-
# Create device-specific versions
1006-
for device in self.cur_inline_test.devices:
1007-
device_var = f"{input_var}_{device}"
1008-
1009-
new_statements.append(
1010-
ast.Assign(
1011-
targets=[ast.Name(id=device_var, ctx=ast.Store())],
1012-
value=ast.Call(
1013-
func=ast.Attribute(
1014-
value=ast.Name(id=ref_var, ctx=ast.Load()),
1015-
attr="to"
1016-
),
1017-
args=[ast.Constant(value=device)],
1018-
keywords=[]
1019-
)
1020-
)
1021-
)
1022-
1023-
# Create device-specific operations
1024-
device_input_map = {device: {} for device in self.cur_inline_test.devices}
1025-
for device in self.cur_inline_test.devices:
1026-
for given_stmt in self.cur_inline_test.given_stmts:
1027-
input_var = given_stmt.targets[0].id
1028-
device_input_map[device][input_var] = f"{input_var}_{device}"
1029-
1030-
# Always set seed before each device operation - no condition check
1031-
new_statements.append(
1032-
ast.Expr(
1033-
value=ast.Call(
1034-
func=ast.Name(id='set_random_seed', ctx=ast.Load()),
1035-
args=[ast.Constant(value=42)], # Use constant seed 42
1036-
keywords=[]
1037-
)
1038-
)
1039-
)
1040-
1041-
device_op = copy.deepcopy(original_op)
1042-
1043-
# Replace input references
1044-
class ReplaceInputs(ast.NodeTransformer):
1045-
def visit_Name(self, node):
1046-
if node.id in device_input_map[device]:
1047-
return ast.Name(id=device_input_map[device][node.id], ctx=node.ctx)
1048-
return node
1049-
1050-
device_op = ReplaceInputs().visit(device_op)
1051-
device_output = f"output_{device}"
1052-
1053-
new_statements.append(
1054-
ast.Assign(
1055-
targets=[ast.Name(id=device_output, ctx=ast.Store())],
1056-
value=device_op
1057-
)
1058-
)
1059-
device_outputs.append(device_output)
1060-
1061-
# Standard comparison method for all operations - no condition check
1062-
comparisons = []
1063-
for i in range(len(device_outputs) - 1):
1064-
dev1 = device_outputs[i]
1065-
dev2 = device_outputs[i + 1]
1066-
1067-
dev1_cpu = f"{dev1}_cpu"
1068-
dev2_cpu = f"{dev2}_cpu"
1069-
1070-
# Move outputs back to CPU for comparison
1071-
new_statements.append(
1072-
ast.Assign(
1073-
targets=[ast.Name(id=dev1_cpu, ctx=ast.Store())],
1074-
value=ast.Call(
1075-
func=ast.Attribute(
1076-
value=ast.Name(id=dev1, ctx=ast.Load()),
1077-
attr="to"
1078-
),
1079-
args=[ast.Constant(value="cpu")],
1080-
keywords=[]
1081-
)
1082-
)
1083-
)
1084-
1085-
new_statements.append(
1086-
ast.Assign(
1087-
targets=[ast.Name(id=dev2_cpu, ctx=ast.Store())],
1088-
value=ast.Call(
1089-
func=ast.Attribute(
1090-
value=ast.Name(id=dev2, ctx=ast.Load()),
1091-
attr="to"
1092-
),
1093-
args=[ast.Constant(value="cpu")],
1094-
keywords=[]
1095-
)
1096-
)
1097-
)
1098-
1099-
# Standard allclose comparison
1100-
comparison = self.build_assert_eq(
1101-
ast.Call(
1102-
func=ast.Attribute(
1103-
value=ast.Name(id=dev1_cpu, ctx=ast.Load()),
1104-
attr="allclose"
1105-
),
1106-
args=[
1107-
ast.Name(id=dev2_cpu, ctx=ast.Load())
1108-
],
1109-
keywords=[
1110-
ast.keyword(arg="rtol", value=ast.Constant(value=1e-4)),
1111-
ast.keyword(arg="atol", value=ast.Constant(value=1e-4)),
1112-
ast.keyword(arg="equal_nan", value=ast.Constant(value=True))
1113-
]
1114-
),
1115-
ast.Constant(value=True)
1116-
)
1117-
comparisons.append(comparison)
1118-
1119-
# Replace statements
1120-
self.cur_inline_test.previous_stmts = new_statements
1121-
self.cur_inline_test.check_stmts = comparisons
1122-
1123-
1124-
1125-
1126886
def build_fail(self):
1127887
equal_node = ast.Compare(
1128888
left=ast.Constant(0),
@@ -1223,13 +983,11 @@ def parse_inline_test(self, node):
1223983
self.parse_check_same(call)
1224984
elif call.func.attr == self.check_not_same:
1225985
self.parse_check_not_same(call)
1226-
elif call.func.attr == self.diff_test_str:
1227-
self.parse_diff_test(call)
1228986
elif call.func.attr == self.fail_str:
1229987
self.parse_fail(call)
1230988
elif call.func.attr == self.given_str:
1231989
raise MalformedException(
1232-
f"inline test: given() must be called before check_eq()/check_true()/check_false()/diff_test()"
990+
f"inline test: given() must be called before check_eq()/check_true()/check_false()"
1233991
)
1234992
else:
1235993
raise MalformedException(f"inline test: invalid function call {self.node_to_source_code(call.func)}")
@@ -1370,8 +1128,6 @@ def _find(self, tests, obj, module, globs, seen):
13701128
######################################################################
13711129
class InlineTestRunner:
13721130
def run(self, test: InlineTest, out: List) -> None:
1373-
test_str = test.to_test()
1374-
print(test_str)
13751131
tree = ast.parse(test.to_test())
13761132
codeobj = compile(tree, filename="<ast>", mode="exec")
13771133
start_time = time.time()

tests/test_plugin.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,6 @@
22
from _pytest.pytester import Pytester
33
import pytest
44

5-
# For testing in Spyder only
6-
# if __name__ == "__main__":
7-
# pytest.main(['-v', '-s'])
8-
95

106
# pytest -p pytester
117
class TestInlinetests:

0 commit comments

Comments
 (0)