Skip to content
This repository was archived by the owner on Aug 26, 2020. It is now read-only.

Commit b23264e

Browse files
authored
Mvs prepare (#138)
* bug-fix: reintroduce _modules.prepare for backwards compatibility
1 parent b2e5150 commit b23264e

File tree

5 files changed

+120
-12
lines changed

5 files changed

+120
-12
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
CHANGELOG
33
=========
44

5+
2.3.2
6+
=====
7+
8+
* bug-fix: reintroduce _modules.prepare for backwards compatibility
9+
510
2.3.1
611
=====
712

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def read(file_name):
3434

3535
setup(
3636
name='sagemaker_containers',
37-
version='2.3.1',
37+
version='2.3.2',
3838
description='Open source library for creating containers to run on Amazon SageMaker.',
3939

4040
packages=packages,

src/sagemaker_containers/_modules.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
import shlex
1818
import sys
19+
import textwrap
1920
import warnings
2021

2122
import six
@@ -48,6 +49,48 @@ def has_requirements(path): # type: (str) -> None
4849
return os.path.exists(os.path.join(path, 'requirements.txt'))
4950

5051

52+
def prepare(path, name): # type: (str, str) -> None
53+
"""Prepare a Python script (or module) to be imported as a module.
54+
If the script does not contain a setup.py file, it creates a minimal setup.
55+
Args:
56+
path (str): path to directory with the script or module.
57+
name (str): name of the script or module.
58+
"""
59+
setup_path = os.path.join(path, 'setup.py')
60+
if not os.path.exists(setup_path):
61+
data = textwrap.dedent("""
62+
from setuptools import setup
63+
setup(packages=[''],
64+
name="%s",
65+
version='1.0.0',
66+
include_package_data=True)
67+
""" % name)
68+
69+
logger.info('Module %s does not provide a setup.py. \nGenerating setup.py' % name)
70+
71+
_files.write_file(setup_path, data)
72+
73+
data = textwrap.dedent("""
74+
[wheel]
75+
universal = 1
76+
""")
77+
78+
logger.info('Generating setup.cfg')
79+
80+
_files.write_file(os.path.join(path, 'setup.cfg'), data)
81+
82+
data = textwrap.dedent("""
83+
recursive-include . *
84+
recursive-exclude . __pycache__*
85+
recursive-exclude . *.pyc
86+
recursive-exclude . *.pyo
87+
""")
88+
89+
logger.info('Generating MANIFEST.in')
90+
91+
_files.write_file(os.path.join(path, 'MANIFEST.in'), data)
92+
93+
5194
def install(path): # type: (str) -> None
5295
"""Install a Python module in the executing Python environment.
5396
Args:
@@ -173,6 +216,7 @@ def run_module(uri, args, env_vars=None, name=DEFAULT_MODULE_NAME, cache=None, w
173216

174217
_files.download_and_extract(uri, name, _env.code_dir)
175218

219+
prepare(_env.code_dir, name)
176220
install(_env.code_dir)
177221

178222
_env.write_env_vars(env_vars)

test/functional/test_training_framework.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,13 @@ def framework_training_with_script_mode_fn():
241241
training_env.to_env_vars())
242242

243243

244+
def framework_training_with_run_modules_fn():
245+
training_env = sagemaker_containers.training_env()
246+
247+
modules.run_module(training_env.module_dir, training_env.to_cmd_args(),
248+
training_env.to_env_vars(), training_env.module_name)
249+
250+
244251
def test_parameter_server():
245252
module = test.UserModule(test.File(name='user_script.py', data=PARAMETER_SERVER_SCRIPT))
246253
hyperparameters = dict(sagemaker_program='user_script.py')
@@ -254,8 +261,10 @@ def test_parameter_server():
254261
process.kill()
255262

256263

257-
@pytest.mark.parametrize('user_script', [USER_MODE_SCRIPT])
258-
def test_script_mode(user_script):
264+
@pytest.mark.parametrize('user_script, training_fn', [
265+
[USER_MODE_SCRIPT, framework_training_with_script_mode_fn],
266+
[USER_MODE_SCRIPT, framework_training_with_run_modules_fn]])
267+
def test_script_mode(user_script, training_fn):
259268
channel = test.Channel.create(name='training')
260269

261270
features = [1, 2, 3, 4]
@@ -269,7 +278,7 @@ def test_script_mode(user_script):
269278

270279
test.prepare(user_module=module, hyperparameters=hyperparameters, channels=[channel])
271280

272-
assert execute_an_wrap_exit(framework_training_with_script_mode_fn) == trainer.SUCCESS_CODE
281+
assert execute_an_wrap_exit(training_fn) == trainer.SUCCESS_CODE
273282

274283
model_path = os.path.join(env.model_dir, 'saved_model')
275284

@@ -281,8 +290,10 @@ def test_script_mode(user_script):
281290
assert model.optimizer == 'SGD'
282291

283292

284-
@pytest.mark.parametrize('user_script', [USER_MODE_SCRIPT])
285-
def test_script_mode_local_directory(user_script, tmpdir):
293+
@pytest.mark.parametrize('user_script, training_fn', [
294+
[USER_MODE_SCRIPT, framework_training_with_script_mode_fn],
295+
[USER_MODE_SCRIPT, framework_training_with_run_modules_fn]])
296+
def test_script_mode_local_directory(user_script, training_fn, tmpdir):
286297
channel = test.Channel.create(name='training')
287298

288299
features = [1, 2, 3, 4]
@@ -300,7 +311,7 @@ def test_script_mode_local_directory(user_script, tmpdir):
300311

301312
test.prepare(user_module=module, hyperparameters=hyperparameters, channels=[channel], local=True)
302313

303-
assert execute_an_wrap_exit(framework_training_with_script_mode_fn) == trainer.SUCCESS_CODE
314+
assert execute_an_wrap_exit(training_fn) == trainer.SUCCESS_CODE
304315

305316
model_path = os.path.join(env.model_dir, 'saved_model')
306317

@@ -318,7 +329,10 @@ def test_script_mode_local_directory(user_script, tmpdir):
318329
"""
319330

320331

321-
def test_script_mode_client_error():
332+
@pytest.mark.parametrize('training_fn', [
333+
framework_training_with_script_mode_fn,
334+
framework_training_with_run_modules_fn])
335+
def test_script_mode_client_error(training_fn):
322336
channel = test.Channel.create(name='training')
323337

324338
module = test.UserModule(test.File(name='user_script.py', data=USER_MODE_SCRIPT_WITH_ERROR))
@@ -328,13 +342,16 @@ def test_script_mode_client_error():
328342
test.prepare(user_module=module, hyperparameters=hyperparameters, channels=[channel])
329343

330344
with pytest.raises(errors.ExecuteUserScriptError) as e:
331-
framework_training_with_script_mode_fn()
345+
training_fn()
332346

333347
message = str(e.value)
334348
assert 'ExecuteUserScriptError' in message
335349

336350

337-
def test_script_mode_client_import_error():
351+
@pytest.mark.parametrize('training_fn', [
352+
framework_training_with_script_mode_fn,
353+
framework_training_with_run_modules_fn])
354+
def test_script_mode_client_import_error(training_fn):
338355
channel = test.Channel.create(name='training')
339356

340357
requirements_file = test.File('requirements.txt', '42/0')
@@ -347,7 +364,7 @@ def test_script_mode_client_import_error():
347364
test.prepare(user_module=module, hyperparameters=hyperparameters, channels=[channel])
348365

349366
with pytest.raises(errors.InstallModuleError) as e:
350-
framework_training_with_script_mode_fn()
367+
training_fn()
351368

352369
message = str(e.value)
353370
assert 'InstallModuleError:' in message

test/unit/test_modules.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
import contextlib
1616
import os
1717
import sys
18+
import textwrap
1819

19-
from mock import call, patch
20+
from mock import call, mock_open, patch
2021
import pytest
2122
from six import PY2
2223

@@ -77,6 +78,47 @@ def patch_tmpdir():
7778
yield '/tmp'
7879

7980

81+
@patch(builtins_open, mock_open())
82+
@patch('os.path.exists', lambda x: False)
83+
def test_prepare():
84+
_modules.prepare('c:/path/to/', 'my-module')
85+
86+
open.assert_any_call('c:/path/to/setup.py', 'w')
87+
open.assert_any_call('c:/path/to/setup.cfg', 'w')
88+
open.assert_any_call('c:/path/to/MANIFEST.in', 'w')
89+
90+
data = textwrap.dedent("""
91+
from setuptools import setup
92+
setup(packages=[''],
93+
name="my-module",
94+
version='1.0.0',
95+
include_package_data=True)
96+
""")
97+
98+
open().write.assert_any_call(data)
99+
100+
data = textwrap.dedent("""
101+
[wheel]
102+
universal = 1
103+
""")
104+
open().write.assert_any_call(data)
105+
106+
data = textwrap.dedent("""
107+
recursive-include . *
108+
recursive-exclude . __pycache__*
109+
recursive-exclude . *.pyc
110+
recursive-exclude . *.pyo
111+
""")
112+
open().write.assert_any_call(data)
113+
114+
115+
@patch(builtins_open, mock_open())
116+
@patch('os.path.exists', lambda x: True)
117+
def test_prepare_already_prepared():
118+
_modules.prepare('c:/path/to/', 'my-module')
119+
open.assert_not_called()
120+
121+
80122
@patch('importlib.import_module')
81123
def test_exists(import_module):
82124
assert _modules.exists('my_module')

0 commit comments

Comments
 (0)