@@ -147,8 +147,9 @@ def framework_training_fn():
147147 model .save (model_file )
148148
149149
150- @pytest .mark .parametrize ('user_script' , [USER_SCRIPT_WITH_SAVE , USER_SCRIPT_WITH_SAVE ])
151- def test_training_framework (user_script ):
150+ @pytest .mark .parametrize ('user_script, capture_error' ,
151+ [[USER_SCRIPT_WITH_SAVE , False ], [USER_SCRIPT_WITH_SAVE , True ]])
152+ def test_training_framework (user_script , capture_error ):
152153 with pytest .raises (ImportError ):
153154 importlib .import_module (modules .DEFAULT_MODULE_NAME )
154155
@@ -234,18 +235,19 @@ def test_trainer_report_failure():
234235 assert 'No such file or directory' in message
235236
236237
237- def framework_training_with_script_mode_fn ():
238+ def framework_training_with_script_mode_fn (capture_error ):
238239 training_env = sagemaker_containers .training_env ()
239240
240241 entry_point .run (training_env .module_dir , training_env .user_entry_point , training_env .to_cmd_args (),
241- training_env .to_env_vars ())
242+ training_env .to_env_vars (), capture_error = capture_error )
242243
243244
244- def framework_training_with_run_modules_fn ():
245+ def framework_training_with_run_modules_fn (capture_error ):
245246 training_env = sagemaker_containers .training_env ()
246247
247248 modules .run_module (training_env .module_dir , training_env .to_cmd_args (),
248- training_env .to_env_vars (), training_env .module_name )
249+ training_env .to_env_vars (), training_env .module_name ,
250+ capture_error = capture_error )
249251
250252
251253def test_parameter_server ():
@@ -261,10 +263,10 @@ def test_parameter_server():
261263 process .kill ()
262264
263265
264- @pytest .mark .parametrize ('user_script, training_fn' , [
265- [USER_MODE_SCRIPT , framework_training_with_script_mode_fn ],
266- [USER_MODE_SCRIPT , framework_training_with_run_modules_fn ]])
267- def test_script_mode (user_script , training_fn ):
266+ @pytest .mark .parametrize ('user_script, training_fn, capture_error ' , [
267+ [USER_MODE_SCRIPT , framework_training_with_script_mode_fn , True ],
268+ [USER_MODE_SCRIPT , framework_training_with_run_modules_fn , False ]])
269+ def test_script_mode (user_script , training_fn , capture_error ):
268270 channel = test .Channel .create (name = 'training' )
269271
270272 features = [1 , 2 , 3 , 4 ]
@@ -278,7 +280,7 @@ def test_script_mode(user_script, training_fn):
278280
279281 test .prepare (user_module = module , hyperparameters = hyperparameters , channels = [channel ])
280282
281- assert execute_an_wrap_exit (training_fn ) == trainer .SUCCESS_CODE
283+ assert execute_an_wrap_exit (training_fn , capture_error = capture_error ) == trainer .SUCCESS_CODE
282284
283285 model_path = os .path .join (env .model_dir , 'saved_model' )
284286
@@ -290,10 +292,10 @@ def test_script_mode(user_script, training_fn):
290292 assert model .optimizer == 'SGD'
291293
292294
293- @pytest .mark .parametrize ('user_script, training_fn' , [
294- [USER_MODE_SCRIPT , framework_training_with_script_mode_fn ],
295- [USER_MODE_SCRIPT , framework_training_with_run_modules_fn ]])
296- def test_script_mode_local_directory (user_script , training_fn , tmpdir ):
295+ @pytest .mark .parametrize ('user_script, training_fn, capture_error ' , [
296+ [USER_MODE_SCRIPT , framework_training_with_script_mode_fn , False ],
297+ [USER_MODE_SCRIPT , framework_training_with_run_modules_fn , True ]])
298+ def test_script_mode_local_directory (user_script , training_fn , capture_error , tmpdir ):
297299 channel = test .Channel .create (name = 'training' )
298300
299301 features = [1 , 2 , 3 , 4 ]
@@ -311,7 +313,7 @@ def test_script_mode_local_directory(user_script, training_fn, tmpdir):
311313
312314 test .prepare (user_module = module , hyperparameters = hyperparameters , channels = [channel ], local = True )
313315
314- assert execute_an_wrap_exit (training_fn ) == trainer .SUCCESS_CODE
316+ assert execute_an_wrap_exit (training_fn , capture_error = capture_error ) == trainer .SUCCESS_CODE
315317
316318 model_path = os .path .join (env .model_dir , 'saved_model' )
317319
@@ -329,10 +331,10 @@ def test_script_mode_local_directory(user_script, training_fn, tmpdir):
329331"""
330332
331333
332- @pytest .mark .parametrize ('training_fn' , [
333- framework_training_with_script_mode_fn ,
334- framework_training_with_run_modules_fn ])
335- def test_script_mode_client_error (training_fn ):
334+ @pytest .mark .parametrize ('training_fn, capture_error ' , [
335+ ( framework_training_with_script_mode_fn , True ) ,
336+ ( framework_training_with_run_modules_fn , False ) ])
337+ def test_script_mode_client_error (training_fn , capture_error ):
336338 channel = test .Channel .create (name = 'training' )
337339
338340 module = test .UserModule (test .File (name = 'user_script.py' , data = USER_MODE_SCRIPT_WITH_ERROR ))
@@ -342,16 +344,18 @@ def test_script_mode_client_error(training_fn):
342344 test .prepare (user_module = module , hyperparameters = hyperparameters , channels = [channel ])
343345
344346 with pytest .raises (errors .ExecuteUserScriptError ) as e :
345- training_fn ()
347+ training_fn (capture_error )
346348
347349 message = str (e .value )
348350 assert 'ExecuteUserScriptError' in message
351+ if capture_error :
352+ assert 'ZeroDivisionError' in message
349353
350354
351- @pytest .mark .parametrize ('training_fn' , [
352- framework_training_with_script_mode_fn ,
353- framework_training_with_run_modules_fn ])
354- def test_script_mode_client_import_error (training_fn ):
355+ @pytest .mark .parametrize ('training_fn, capture_error ' , [
356+ [ framework_training_with_script_mode_fn , True ] ,
357+ [ framework_training_with_run_modules_fn , False ] ])
358+ def test_script_mode_client_import_error (training_fn , capture_error ):
355359 channel = test .Channel .create (name = 'training' )
356360
357361 requirements_file = test .File ('requirements.txt' , '42/0' )
@@ -364,20 +368,24 @@ def test_script_mode_client_import_error(training_fn):
364368 test .prepare (user_module = module , hyperparameters = hyperparameters , channels = [channel ])
365369
366370 with pytest .raises (errors .InstallModuleError ) as e :
367- training_fn ()
371+ training_fn (capture_error )
368372
369373 message = str (e .value )
370374 assert 'InstallModuleError:' in message
371375
376+ if capture_error :
377+ assert "Invalid requirement: \' 42/0\' " in message
378+ assert "It looks like a path. File \' 42/0\' does not exist." in message
379+
372380
373381def failure_message ():
374382 with open (os .path .join (env .output_dir , 'failure' )) as f :
375383 return f .read ()
376384
377385
378- def execute_an_wrap_exit (fn ):
386+ def execute_an_wrap_exit (fn , ** kargs ):
379387 try :
380- fn ()
388+ fn (** kargs )
381389 return trainer .SUCCESS_CODE
382390 except ValueError as e :
383391 return int (str (e ))
0 commit comments