Skip to content

Commit 965f8ec

Browse files
committed
Fix lazy imports of objects on Python 3.6.
1 parent 42f0e2c commit 965f8ec

File tree

2 files changed

+52
-21
lines changed

2 files changed

+52
-21
lines changed

src/websockets/imports.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import importlib
21
import sys
32
import warnings
43
from typing import Any, Dict, Iterable, Optional
@@ -7,6 +6,27 @@
76
__all__ = ["lazy_import"]
87

98

9+
def import_name(name: str, source: str, namespace: Dict[str, Any]) -> Any:
10+
"""
11+
Import <name> from <source> in <namespace>.
12+
13+
There are two cases:
14+
15+
- <name> is an object defined in <source>
16+
- <name> is a submodule of source
17+
18+
Neither __import__ nor importlib.import_module does exactly this.
19+
__import__ is closer to the intended behavior.
20+
21+
"""
22+
level = 0
23+
while source[level] == ".":
24+
level += 1
25+
assert level < len(source), "importing from parent isn't supported"
26+
module = __import__(source[level:], namespace, None, [name], level)
27+
return getattr(module, name)
28+
29+
1030
def lazy_import(
1131
namespace: Dict[str, Any],
1232
aliases: Optional[Dict[str, str]] = None,
@@ -58,8 +78,7 @@ def __getattr__(name: str) -> Any:
5878
except KeyError:
5979
pass
6080
else:
61-
module = importlib.import_module(source, package)
62-
return getattr(module, name)
81+
return import_name(name, source, namespace)
6382

6483
assert deprecated_aliases is not None # mypy cannot figure this out
6584
try:
@@ -72,8 +91,7 @@ def __getattr__(name: str) -> Any:
7291
DeprecationWarning,
7392
stacklevel=2,
7493
)
75-
module = importlib.import_module(source, package)
76-
return getattr(module, name)
94+
return import_name(name, source, namespace)
7795

7896
raise AttributeError(f"module {package!r} has no attribute {name!r}")
7997

@@ -87,9 +105,7 @@ def __dir__() -> Iterable[str]:
87105
else: # pragma: no cover
88106

89107
for name, source in aliases.items():
90-
module = importlib.import_module(source, package)
91-
namespace[name] = getattr(module, name)
108+
namespace[name] = import_name(name, source, namespace)
92109

93110
for name, source in deprecated_aliases.items():
94-
module = importlib.import_module(source, package)
95-
namespace[name] = getattr(module, name)
111+
namespace[name] = import_name(name, source, namespace)

tests/test_imports.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import types
23
import unittest
34
import warnings
@@ -11,18 +12,30 @@
1112

1213

1314
class ImportsTests(unittest.TestCase):
15+
def setUp(self):
16+
self.mod = types.ModuleType("tests.test_imports.test_alias")
17+
self.mod.__package__ = self.mod.__name__
18+
1419
def test_get_alias(self):
15-
mod = types.ModuleType("tests.test_imports.test_alias")
16-
lazy_import(vars(mod), aliases={"foo": ".."})
20+
lazy_import(
21+
vars(self.mod),
22+
aliases={"foo": "...test_imports"},
23+
)
1724

18-
self.assertEqual(mod.foo, foo)
25+
self.assertEqual(self.mod.foo, foo)
1926

2027
def test_get_deprecated_alias(self):
21-
mod = types.ModuleType("tests.test_imports.test_alias")
22-
lazy_import(vars(mod), deprecated_aliases={"bar": ".."})
28+
lazy_import(
29+
vars(self.mod),
30+
deprecated_aliases={"bar": "...test_imports"},
31+
)
2332

2433
with warnings.catch_warnings(record=True) as recorded_warnings:
25-
self.assertEqual(mod.bar, bar)
34+
self.assertEqual(self.mod.bar, bar)
35+
36+
# No warnings raised on pre-PEP 526 Python.
37+
if sys.version_info[:2] < (3, 7): # pragma: no cover
38+
return
2639

2740
self.assertEqual(len(recorded_warnings), 1)
2841
warning = recorded_warnings[0].message
@@ -32,20 +45,22 @@ def test_get_deprecated_alias(self):
3245
self.assertEqual(type(warning), DeprecationWarning)
3346

3447
def test_dir(self):
35-
mod = types.ModuleType("tests.test_imports.test_alias")
36-
lazy_import(vars(mod), aliases={"foo": ".."}, deprecated_aliases={"bar": ".."})
48+
lazy_import(
49+
vars(self.mod),
50+
aliases={"foo": "...test_imports"},
51+
deprecated_aliases={"bar": "...test_imports"},
52+
)
3753

3854
self.assertEqual(
39-
[item for item in dir(mod) if not item[:2] == item[-2:] == "__"],
55+
[item for item in dir(self.mod) if not item[:2] == item[-2:] == "__"],
4056
["bar", "foo"],
4157
)
4258

4359
def test_attribute_error(self):
44-
mod = types.ModuleType("tests.test_imports.test_alias")
45-
lazy_import(vars(mod))
60+
lazy_import(vars(self.mod))
4661

4762
with self.assertRaises(AttributeError) as raised:
48-
mod.foo
63+
self.mod.foo
4964

5065
self.assertEqual(
5166
str(raised.exception),

0 commit comments

Comments
 (0)