Skip to content

Commit c70f114

Browse files
authored
Merge pull request #10 from hanse7962/isolated-imports
Import Functionality
2 parents 59c2ff1 + 53e16f7 commit c70f114

File tree

2 files changed

+121
-9
lines changed

2 files changed

+121
-9
lines changed

src/inline/plugin.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def __init__(self):
159159
self.check_stmts = []
160160
self.given_stmts = []
161161
self.previous_stmts = []
162+
self.import_stmts = []
162163
self.prev_stmt_type = PrevStmtType.StmtExpr
163164
# the line number of test statement
164165
self.lineno = 0
@@ -174,10 +175,18 @@ def __init__(self):
174175
self.devices = None
175176
self.globs = {}
176177

178+
def write_imports(self):
179+
import_str = ""
180+
for n in self.import_stmts:
181+
import_str += ExtractInlineTest.node_to_source_code(n) + "\n"
182+
return import_str
183+
177184
def to_test(self):
185+
prefix = "\n"
186+
178187
if self.prev_stmt_type == PrevStmtType.CondExpr:
179188
if self.assume_stmts == []:
180-
return "\n".join(
189+
return prefix.join(
181190
[ExtractInlineTest.node_to_source_code(n) for n in self.given_stmts]
182191
+ [ExtractInlineTest.node_to_source_code(n) for n in self.check_stmts]
183192
)
@@ -187,11 +196,11 @@ def to_test(self):
187196
)
188197
assume_statement = self.assume_stmts[0]
189198
assume_node = self.build_assume_node(assume_statement, body_nodes)
190-
return "\n".join(ExtractInlineTest.node_to_source_code(assume_node))
199+
return prefix.join(ExtractInlineTest.node_to_source_code(assume_node))
191200

192201
else:
193202
if self.assume_stmts is None or self.assume_stmts == []:
194-
return "\n".join(
203+
return prefix.join(
195204
[ExtractInlineTest.node_to_source_code(n) for n in self.given_stmts]
196205
+ [ExtractInlineTest.node_to_source_code(n) for n in self.previous_stmts]
197206
+ [ExtractInlineTest.node_to_source_code(n) for n in self.check_stmts]
@@ -202,7 +211,7 @@ def to_test(self):
202211
)
203212
assume_statement = self.assume_stmts[0]
204213
assume_node = self.build_assume_node(assume_statement, body_nodes)
205-
return "\n".join([ExtractInlineTest.node_to_source_code(assume_node)])
214+
return prefix.join([ExtractInlineTest.node_to_source_code(assume_node)])
206215

207216
def build_assume_node(self, assumption_node, body_nodes):
208217
return ast.If(assumption_node, body_nodes, [])
@@ -296,6 +305,11 @@ class ExtractInlineTest(ast.NodeTransformer):
296305
arg_timeout_str = "timeout"
297306

298307
assume = "assume"
308+
309+
import_str = "import"
310+
from_str = "from"
311+
as_str = "as"
312+
299313
inline_module_imported = False
300314

301315
def __init__(self):
@@ -360,6 +374,23 @@ def collect_inline_test_calls(self, node, inline_test_calls: List[ast.Call]):
360374
inline_test_calls.append(node)
361375
self.collect_inline_test_calls(node.func, inline_test_calls)
362376

377+
def collect_import_calls(self, node, import_calls: List[ast.Import], import_from_calls: List[ast.ImportFrom]):
378+
"""
379+
collect all import calls in the node (should be done first)
380+
"""
381+
382+
while not isinstance(node, ast.Module) and node.parent != None:
383+
node = node.parent
384+
385+
if not isinstance(node, ast.Module):
386+
return
387+
388+
for child in node.children:
389+
if isinstance(child, ast.Import):
390+
import_calls.append(child)
391+
elif isinstance(child, ast.ImportFrom):
392+
import_from_calls.append(child)
393+
363394
def parse_constructor(self, node):
364395
"""
365396
Parse a constructor call.
@@ -931,8 +962,13 @@ def parse_parameterized_test(self):
931962
parameterized_test.test_name = self.cur_inline_test.test_name + "_" + str(index)
932963

933964
def parse_inline_test(self, node):
934-
inline_test_calls = []
965+
import_calls = []
966+
import_from_calls = []
967+
inline_test_calls = []
968+
935969
self.collect_inline_test_calls(node, inline_test_calls)
970+
self.collect_import_calls(node, import_calls, import_from_calls)
971+
936972
inline_test_calls.reverse()
937973

938974
if len(inline_test_calls) <= 1:
@@ -953,14 +989,20 @@ def parse_inline_test(self, node):
953989
self.parse_assume(call)
954990
inline_test_call_index += 1
955991

956-
# "given(a, 1)"
957992
for call in inline_test_calls[inline_test_call_index:]:
958-
if isinstance(call.func, ast.Attribute) and call.func.attr == self.given_str:
959-
self.parse_given(call)
960-
inline_test_call_index += 1
993+
if isinstance(call.func, ast.Attribute):
994+
if call.func.attr == self.given_str:
995+
self.parse_given(call)
996+
inline_test_call_index += 1
961997
else:
962998
break
963999

1000+
for import_stmt in import_calls:
1001+
self.cur_inline_test.import_stmts.append(import_stmt)
1002+
for import_stmt in import_from_calls:
1003+
self.cur_inline_test.import_stmts.append(import_stmt)
1004+
1005+
9641006
# "check_eq" or "check_true" or "check_false" or "check_neq"
9651007
for call in inline_test_calls[inline_test_call_index:]:
9661008
# "check_eq(a, 1)"

tests/test_plugin.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,76 @@ def m(a):
3131
items, reprec = pytester.inline_genitems(x)
3232
assert len(items) == 0
3333

34+
def test_inline_detects_imports(self, pytester: Pytester):
35+
checkfile = pytester.makepyfile(
36+
"""
37+
from inline import itest
38+
import datetime
39+
40+
def m(a):
41+
b = a + datetime.timedelta(days=365)
42+
itest().given(a, datetime.timedelta(days=1)).check_eq(b, datetime.timedelta(days=366))
43+
"""
44+
)
45+
for x in (pytester.path, checkfile):
46+
items, reprec = pytester.inline_genitems(x)
47+
assert len(items) == 1
48+
res = pytester.runpytest()
49+
assert res.ret != 1
50+
51+
def test_inline_detects_import_alias(self, pytester: Pytester):
52+
checkfile = pytester.makepyfile(
53+
"""
54+
from inline import itest
55+
import datetime as dt
56+
57+
def m(a):
58+
b = a + dt.timedelta(days=365)
59+
itest().given(a, dt.timedelta(days=1)).check_eq(b, dt.timedelta(days=366))
60+
"""
61+
)
62+
for x in (pytester.path, checkfile):
63+
items, reprec = pytester.inline_genitems(x)
64+
assert len(items) == 1
65+
res = pytester.runpytest()
66+
assert res.ret != 1
67+
68+
def test_inline_detects_from_imports(self, pytester: Pytester):
69+
checkfile = pytester.makepyfile(
70+
"""
71+
from inline import itest
72+
from enum import Enum
73+
74+
class Choice(Enum):
75+
YES = 0
76+
NO = 1
77+
78+
def m(a):
79+
b = a
80+
itest().given(a, Choice.YES).check_eq(b, Choice.YES)
81+
"""
82+
)
83+
for x in (pytester.path, checkfile):
84+
items, reprec = pytester.inline_genitems(x)
85+
assert len(items) == 1
86+
res = pytester.runpytest()
87+
assert res.ret == 0
88+
89+
def test_fail_on_importing_missing_module(self, pytester: Pytester):
90+
checkfile = pytester.makepyfile(
91+
"""
92+
from inline import itest
93+
from scipy import owijef as st
94+
95+
def m(n, p):
96+
b = st.binom(n, p)
97+
itest().given(n, 100).given(p, 0.5).check_eq(b.mean(), n * p)
98+
"""
99+
)
100+
for x in (pytester.path, checkfile):
101+
items, reprec = pytester.inline_genitems(x)
102+
assert len(items) == 0
103+
34104
def test_inline_malformed_given(self, pytester: Pytester):
35105
checkfile = pytester.makepyfile(
36106
"""

0 commit comments

Comments
 (0)