Skip to content
This repository was archived by the owner on Aug 26, 2020. It is now read-only.

Commit 3bdea78

Browse files
authored
Add capture_error flag to process.check_error and process.create (#140)
* feature: add capture_error flag to process.check_error and process.create and to all functions that runs process: modules.run, modules,run_module, and entry_point.run
1 parent d3c901c commit 3bdea78

File tree

10 files changed

+142
-70
lines changed

10 files changed

+142
-70
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
CHANGELOG
33
=========
44

5+
2.3.4
6+
=====
7+
8+
* feature: add capture_error flag to process.check_error and process.create and to all functions that runs process: modules.run, modules,run_module, and entry_point.run
9+
510
2.3.3
611
=====
712

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def read(file_name):
3434

3535
setup(
3636
name='sagemaker_containers',
37-
version='2.3.3',
37+
version='2.3.4',
3838
description='Open source library for creating containers to run on Amazon SageMaker.',
3939

4040
packages=packages,

src/sagemaker_containers/_errors.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import textwrap
1616

17+
import six
18+
1719

1820
class ClientError(Exception):
1921
pass
@@ -27,12 +29,20 @@ class _CalledProcessError(ClientError):
2729
cmd, return_code, output
2830
"""
2931

30-
def __init__(self, cmd, return_code=None):
32+
def __init__(self, cmd, return_code=None, output=None):
3133
self.return_code = return_code
3234
self.cmd = cmd
35+
self.output = output
3336

3437
def __str__(self):
35-
message = '%s:\nCommand "%s"' % (type(self).__name__, self.cmd)
38+
if six.PY3 and self.output:
39+
error_msg = '\n%s' % self.output.decode('latin1')
40+
elif self.output:
41+
error_msg = '\n%s' % self.output
42+
else:
43+
error_msg = ''
44+
45+
message = '%s:\nCommand "%s"%s' % (type(self).__name__, self.cmd, error_msg)
3646
return message.strip()
3747

3848

src/sagemaker_containers/_modules.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,12 @@ def prepare(path, name): # type: (str, str) -> None
9191
_files.write_file(os.path.join(path, 'MANIFEST.in'), data)
9292

9393

94-
def install(path): # type: (str) -> None
94+
def install(path, capture_error=False): # type: (str, bool) -> None
9595
"""Install a Python module in the executing Python environment.
9696
Args:
9797
path (str): Real path location of the Python module.
98+
capture_error (bool): Default false. If True, the running process captures the
99+
stderr, and appends it to the returned Exception message in case of errors.
98100
"""
99101
cmd = '%s -m pip install -U . ' % _process.python_executable()
100102

@@ -103,10 +105,11 @@ def install(path): # type: (str) -> None
103105

104106
logger.info('Installing module with the following command:\n%s', cmd)
105107

106-
_process.check_error(shlex.split(cmd), _errors.InstallModuleError, cwd=path)
108+
_process.check_error(shlex.split(cmd), _errors.InstallModuleError, cwd=path, capture_error=capture_error)
107109

108110

109-
def run(module_name, args=None, env_vars=None, wait=True): # type: (str, list, dict, bool) -> Popen
111+
def run(module_name, args=None, env_vars=None, wait=True, capture_error=False):
112+
# type: (str, list, dict, bool, bool) -> Popen
110113
"""Run Python module as a script.
111114
112115
Search sys.path for the named module and execute its contents as the __main__ module.
@@ -154,6 +157,8 @@ def run(module_name, args=None, env_vars=None, wait=True): # type: (str, list,
154157
module_name (str): module name in the same format required by python -m <module-name> cli command.
155158
args (list): A list of program arguments.
156159
env_vars (dict): A map containing the environment variables to be written.
160+
capture_error (bool): Default false. If True, the running process captures the
161+
stderr, and appends it to the returned Exception message in case of errors.
157162
"""
158163
args = args or []
159164
env_vars = env_vars or {}
@@ -163,10 +168,10 @@ def run(module_name, args=None, env_vars=None, wait=True): # type: (str, list,
163168
_logging.log_script_invocation(cmd, env_vars)
164169

165170
if wait:
166-
return _process.check_error(cmd, _errors.ExecuteUserScriptError)
171+
return _process.check_error(cmd, _errors.ExecuteUserScriptError, capture_error=capture_error)
167172

168173
else:
169-
return _process.create(cmd, _errors.ExecuteUserScriptError)
174+
return _process.create(cmd, _errors.ExecuteUserScriptError, capture_error=capture_error)
170175

171176

172177
def import_module(uri, name=DEFAULT_MODULE_NAME, cache=None): # type: (str, str, bool) -> module
@@ -195,8 +200,8 @@ def import_module(uri, name=DEFAULT_MODULE_NAME, cache=None): # type: (str, str
195200
six.reraise(_errors.ImportModuleError, _errors.ImportModuleError(e), sys.exc_info()[2])
196201

197202

198-
def run_module(uri, args, env_vars=None, name=DEFAULT_MODULE_NAME, cache=None, wait=True):
199-
# type: (str, list, dict, str, bool, bool) -> Popen
203+
def run_module(uri, args, env_vars=None, name=DEFAULT_MODULE_NAME, cache=None, wait=True, capture_error=False):
204+
# type: (str, list, dict, str, bool, bool, bool) -> Popen
200205
"""Download, prepare and executes a compressed tar file from S3 or provided directory as a module.
201206
202207
SageMaker Python SDK saves the user provided scripts as compressed tar files in S3
@@ -222,7 +227,7 @@ def run_module(uri, args, env_vars=None, name=DEFAULT_MODULE_NAME, cache=None, w
222227

223228
_env.write_env_vars(env_vars)
224229

225-
return run(name, args, env_vars, wait)
230+
return run(name, args, env_vars, wait, capture_error)
226231

227232

228233
def _warning_cache_deprecation(cache):

src/sagemaker_containers/_process.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,26 @@
2121
from sagemaker_containers import _env
2222

2323

24-
def create(cmd, error_class, cwd=None, **kwargs):
24+
def create(cmd, error_class, cwd=None, capture_error=False, **kwargs):
2525
try:
26-
return subprocess.Popen(cmd, env=os.environ, cwd=cwd or _env.code_dir, **kwargs)
26+
stderr = subprocess.PIPE if capture_error else None
27+
return subprocess.Popen(cmd, env=os.environ, cwd=cwd or _env.code_dir, stderr=stderr, **kwargs)
2728
except Exception as e:
2829
six.reraise(error_class, error_class(e), sys.exc_info()[2])
2930

3031

31-
def check_error(cmd, error_class, **kwargs):
32-
process = create(cmd, error_class, **kwargs)
33-
return_code = process.wait()
32+
def check_error(cmd, error_class, capture_error=False, **kwargs):
33+
process = create(cmd, error_class, capture_error=capture_error, **kwargs)
34+
35+
if capture_error:
36+
_, stderr = process.communicate()
37+
return_code = process.poll()
38+
else:
39+
stderr = None
40+
return_code = process.wait()
3441

3542
if return_code:
36-
raise error_class(return_code=return_code, cmd=' '.join(cmd))
43+
raise error_class(return_code=return_code, cmd=' '.join(cmd), output=stderr)
3744
return process
3845

3946

src/sagemaker_containers/entry_point.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from sagemaker_containers import _env, _errors, _files, _logging, _modules, _process
2020

2121

22-
def run(uri, user_entry_point, args, env_vars=None, wait=True):
23-
# type: (str, str, list, dict, bool) -> subprocess.Popen
22+
def run(uri, user_entry_point, args, env_vars=None, wait=True, capture_error=False):
23+
# type: (str, str, list, dict, bool, bool) -> subprocess.Popen
2424
"""Download, prepare and executes a compressed tar file from S3 or provided directory as an user
2525
entrypoint. Runs the user entry point, passing env_vars as environment variables and args as command
2626
arguments.
@@ -59,39 +59,45 @@ def run(uri, user_entry_point, args, env_vars=None, wait=True):
5959
uri (str): the location of the module.
6060
wait (bool): If True, holds the process executing the user entry-point.
6161
If False, returns the process that is executing it.
62+
capture_error (bool): Default false. If True, the running process captures the
63+
stderr, and appends it to the returned Exception message in case of errors.
64+
6265
"""
6366
env_vars = env_vars or {}
6467
env_vars = env_vars.copy()
6568

6669
_files.download_and_extract(uri, user_entry_point, _env.code_dir)
6770

68-
install(user_entry_point, _env.code_dir)
71+
install(user_entry_point, _env.code_dir, capture_error)
6972

7073
_env.write_env_vars(env_vars)
7174

72-
return _call(user_entry_point, args, env_vars, wait)
75+
return _call(user_entry_point, args, env_vars, wait, capture_error)
7376

7477

75-
def install(name, dst):
78+
def install(name, dst, capture_error=False):
7679
"""Install the user provided entry point to be executed as follow:
7780
- add the path to sys path
7881
- if the user entry point is a command, gives exec permissions to the script
7982
8083
Args:
8184
name (str): name of the script or module.
8285
dst (str): path to directory with the script or module.
86+
capture_error (bool): Default false. If True, the running process captures the
87+
stderr, and appends it to the returned Exception message in case of errors.
8388
"""
8489
if dst not in sys.path:
8590
sys.path.insert(0, dst)
8691

8792
entrypoint_type = entry_point_type(dst, name)
8893
if entrypoint_type is EntryPointType.PYTHON_PACKAGE:
89-
_modules.install(dst)
94+
_modules.install(dst, capture_error)
9095
if entrypoint_type is EntryPointType.COMMAND:
9196
os.chmod(os.path.join(dst, name), 511)
9297

9398

94-
def _call(user_entry_point, args=None, env_vars=None, wait=True): # type: (str, list, dict, bool) -> Popen
99+
def _call(user_entry_point, args=None, env_vars=None, wait=True, capture_error=False):
100+
# type: (str, list, dict, bool, bool) -> Popen
95101
args = args or []
96102
env_vars = env_vars or {}
97103

@@ -107,10 +113,10 @@ def _call(user_entry_point, args=None, env_vars=None, wait=True): # type: (str,
107113
_logging.log_script_invocation(cmd, env_vars)
108114

109115
if wait:
110-
return _process.check_error(cmd, _errors.ExecuteUserScriptError)
116+
return _process.check_error(cmd, _errors.ExecuteUserScriptError, capture_error=capture_error)
111117

112118
else:
113-
return _process.create(cmd, _errors.ExecuteUserScriptError)
119+
return _process.create(cmd, _errors.ExecuteUserScriptError, capture_error=capture_error)
114120

115121

116122
class EntryPointType(enum.Enum):

test/functional/test_training_framework.py

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,9 @@ def framework_training_fn():
147147
model.save(model_file)
148148

149149

150-
@pytest.mark.parametrize('user_script', [USER_SCRIPT_WITH_SAVE, USER_SCRIPT_WITH_SAVE])
151-
def test_training_framework(user_script):
150+
@pytest.mark.parametrize('user_script, capture_error',
151+
[[USER_SCRIPT_WITH_SAVE, False], [USER_SCRIPT_WITH_SAVE, True]])
152+
def test_training_framework(user_script, capture_error):
152153
with pytest.raises(ImportError):
153154
importlib.import_module(modules.DEFAULT_MODULE_NAME)
154155

@@ -234,18 +235,19 @@ def test_trainer_report_failure():
234235
assert 'No such file or directory' in message
235236

236237

237-
def framework_training_with_script_mode_fn():
238+
def framework_training_with_script_mode_fn(capture_error):
238239
training_env = sagemaker_containers.training_env()
239240

240241
entry_point.run(training_env.module_dir, training_env.user_entry_point, training_env.to_cmd_args(),
241-
training_env.to_env_vars())
242+
training_env.to_env_vars(), capture_error=capture_error)
242243

243244

244-
def framework_training_with_run_modules_fn():
245+
def framework_training_with_run_modules_fn(capture_error):
245246
training_env = sagemaker_containers.training_env()
246247

247248
modules.run_module(training_env.module_dir, training_env.to_cmd_args(),
248-
training_env.to_env_vars(), training_env.module_name)
249+
training_env.to_env_vars(), training_env.module_name,
250+
capture_error=capture_error)
249251

250252

251253
def test_parameter_server():
@@ -261,10 +263,10 @@ def test_parameter_server():
261263
process.kill()
262264

263265

264-
@pytest.mark.parametrize('user_script, training_fn', [
265-
[USER_MODE_SCRIPT, framework_training_with_script_mode_fn],
266-
[USER_MODE_SCRIPT, framework_training_with_run_modules_fn]])
267-
def test_script_mode(user_script, training_fn):
266+
@pytest.mark.parametrize('user_script, training_fn, capture_error', [
267+
[USER_MODE_SCRIPT, framework_training_with_script_mode_fn, True],
268+
[USER_MODE_SCRIPT, framework_training_with_run_modules_fn, False]])
269+
def test_script_mode(user_script, training_fn, capture_error):
268270
channel = test.Channel.create(name='training')
269271

270272
features = [1, 2, 3, 4]
@@ -278,7 +280,7 @@ def test_script_mode(user_script, training_fn):
278280

279281
test.prepare(user_module=module, hyperparameters=hyperparameters, channels=[channel])
280282

281-
assert execute_an_wrap_exit(training_fn) == trainer.SUCCESS_CODE
283+
assert execute_an_wrap_exit(training_fn, capture_error=capture_error) == trainer.SUCCESS_CODE
282284

283285
model_path = os.path.join(env.model_dir, 'saved_model')
284286

@@ -290,10 +292,10 @@ def test_script_mode(user_script, training_fn):
290292
assert model.optimizer == 'SGD'
291293

292294

293-
@pytest.mark.parametrize('user_script, training_fn', [
294-
[USER_MODE_SCRIPT, framework_training_with_script_mode_fn],
295-
[USER_MODE_SCRIPT, framework_training_with_run_modules_fn]])
296-
def test_script_mode_local_directory(user_script, training_fn, tmpdir):
295+
@pytest.mark.parametrize('user_script, training_fn, capture_error', [
296+
[USER_MODE_SCRIPT, framework_training_with_script_mode_fn, False],
297+
[USER_MODE_SCRIPT, framework_training_with_run_modules_fn, True]])
298+
def test_script_mode_local_directory(user_script, training_fn, capture_error, tmpdir):
297299
channel = test.Channel.create(name='training')
298300

299301
features = [1, 2, 3, 4]
@@ -311,7 +313,7 @@ def test_script_mode_local_directory(user_script, training_fn, tmpdir):
311313

312314
test.prepare(user_module=module, hyperparameters=hyperparameters, channels=[channel], local=True)
313315

314-
assert execute_an_wrap_exit(training_fn) == trainer.SUCCESS_CODE
316+
assert execute_an_wrap_exit(training_fn, capture_error=capture_error) == trainer.SUCCESS_CODE
315317

316318
model_path = os.path.join(env.model_dir, 'saved_model')
317319

@@ -329,10 +331,10 @@ def test_script_mode_local_directory(user_script, training_fn, tmpdir):
329331
"""
330332

331333

332-
@pytest.mark.parametrize('training_fn', [
333-
framework_training_with_script_mode_fn,
334-
framework_training_with_run_modules_fn])
335-
def test_script_mode_client_error(training_fn):
334+
@pytest.mark.parametrize('training_fn, capture_error', [
335+
(framework_training_with_script_mode_fn, True),
336+
(framework_training_with_run_modules_fn, False)])
337+
def test_script_mode_client_error(training_fn, capture_error):
336338
channel = test.Channel.create(name='training')
337339

338340
module = test.UserModule(test.File(name='user_script.py', data=USER_MODE_SCRIPT_WITH_ERROR))
@@ -342,16 +344,18 @@ def test_script_mode_client_error(training_fn):
342344
test.prepare(user_module=module, hyperparameters=hyperparameters, channels=[channel])
343345

344346
with pytest.raises(errors.ExecuteUserScriptError) as e:
345-
training_fn()
347+
training_fn(capture_error)
346348

347349
message = str(e.value)
348350
assert 'ExecuteUserScriptError' in message
351+
if capture_error:
352+
assert 'ZeroDivisionError' in message
349353

350354

351-
@pytest.mark.parametrize('training_fn', [
352-
framework_training_with_script_mode_fn,
353-
framework_training_with_run_modules_fn])
354-
def test_script_mode_client_import_error(training_fn):
355+
@pytest.mark.parametrize('training_fn, capture_error', [
356+
[framework_training_with_script_mode_fn, True],
357+
[framework_training_with_run_modules_fn, False]])
358+
def test_script_mode_client_import_error(training_fn, capture_error):
355359
channel = test.Channel.create(name='training')
356360

357361
requirements_file = test.File('requirements.txt', '42/0')
@@ -364,20 +368,24 @@ def test_script_mode_client_import_error(training_fn):
364368
test.prepare(user_module=module, hyperparameters=hyperparameters, channels=[channel])
365369

366370
with pytest.raises(errors.InstallModuleError) as e:
367-
training_fn()
371+
training_fn(capture_error)
368372

369373
message = str(e.value)
370374
assert 'InstallModuleError:' in message
371375

376+
if capture_error:
377+
assert "Invalid requirement: \'42/0\'" in message
378+
assert "It looks like a path. File \'42/0\' does not exist." in message
379+
372380

373381
def failure_message():
374382
with open(os.path.join(env.output_dir, 'failure')) as f:
375383
return f.read()
376384

377385

378-
def execute_an_wrap_exit(fn):
386+
def execute_an_wrap_exit(fn, **kargs):
379387
try:
380-
fn()
388+
fn(**kargs)
381389
return trainer.SUCCESS_CODE
382390
except ValueError as e:
383391
return int(str(e))

0 commit comments

Comments
 (0)