@@ -241,6 +241,13 @@ def framework_training_with_script_mode_fn():
241241 training_env .to_env_vars ())
242242
243243
244+ def framework_training_with_run_modules_fn ():
245+ training_env = sagemaker_containers .training_env ()
246+
247+ modules .run_module (training_env .module_dir , training_env .to_cmd_args (),
248+ training_env .to_env_vars (), training_env .module_name )
249+
250+
244251def test_parameter_server ():
245252 module = test .UserModule (test .File (name = 'user_script.py' , data = PARAMETER_SERVER_SCRIPT ))
246253 hyperparameters = dict (sagemaker_program = 'user_script.py' )
@@ -254,8 +261,10 @@ def test_parameter_server():
254261 process .kill ()
255262
256263
257- @pytest .mark .parametrize ('user_script' , [USER_MODE_SCRIPT ])
258- def test_script_mode (user_script ):
264+ @pytest .mark .parametrize ('user_script, training_fn' , [
265+ [USER_MODE_SCRIPT , framework_training_with_script_mode_fn ],
266+ [USER_MODE_SCRIPT , framework_training_with_run_modules_fn ]])
267+ def test_script_mode (user_script , training_fn ):
259268 channel = test .Channel .create (name = 'training' )
260269
261270 features = [1 , 2 , 3 , 4 ]
@@ -269,7 +278,7 @@ def test_script_mode(user_script):
269278
270279 test .prepare (user_module = module , hyperparameters = hyperparameters , channels = [channel ])
271280
272- assert execute_an_wrap_exit (framework_training_with_script_mode_fn ) == trainer .SUCCESS_CODE
281+ assert execute_an_wrap_exit (training_fn ) == trainer .SUCCESS_CODE
273282
274283 model_path = os .path .join (env .model_dir , 'saved_model' )
275284
@@ -281,8 +290,10 @@ def test_script_mode(user_script):
281290 assert model .optimizer == 'SGD'
282291
283292
284- @pytest .mark .parametrize ('user_script' , [USER_MODE_SCRIPT ])
285- def test_script_mode_local_directory (user_script , tmpdir ):
293+ @pytest .mark .parametrize ('user_script, training_fn' , [
294+ [USER_MODE_SCRIPT , framework_training_with_script_mode_fn ],
295+ [USER_MODE_SCRIPT , framework_training_with_run_modules_fn ]])
296+ def test_script_mode_local_directory (user_script , training_fn , tmpdir ):
286297 channel = test .Channel .create (name = 'training' )
287298
288299 features = [1 , 2 , 3 , 4 ]
@@ -300,7 +311,7 @@ def test_script_mode_local_directory(user_script, tmpdir):
300311
301312 test .prepare (user_module = module , hyperparameters = hyperparameters , channels = [channel ], local = True )
302313
303- assert execute_an_wrap_exit (framework_training_with_script_mode_fn ) == trainer .SUCCESS_CODE
314+ assert execute_an_wrap_exit (training_fn ) == trainer .SUCCESS_CODE
304315
305316 model_path = os .path .join (env .model_dir , 'saved_model' )
306317
@@ -318,7 +329,10 @@ def test_script_mode_local_directory(user_script, tmpdir):
318329"""
319330
320331
321- def test_script_mode_client_error ():
332+ @pytest .mark .parametrize ('training_fn' , [
333+ framework_training_with_script_mode_fn ,
334+ framework_training_with_run_modules_fn ])
335+ def test_script_mode_client_error (training_fn ):
322336 channel = test .Channel .create (name = 'training' )
323337
324338 module = test .UserModule (test .File (name = 'user_script.py' , data = USER_MODE_SCRIPT_WITH_ERROR ))
@@ -328,13 +342,16 @@ def test_script_mode_client_error():
328342 test .prepare (user_module = module , hyperparameters = hyperparameters , channels = [channel ])
329343
330344 with pytest .raises (errors .ExecuteUserScriptError ) as e :
331- framework_training_with_script_mode_fn ()
345+ training_fn ()
332346
333347 message = str (e .value )
334348 assert 'ExecuteUserScriptError' in message
335349
336350
337- def test_script_mode_client_import_error ():
351+ @pytest .mark .parametrize ('training_fn' , [
352+ framework_training_with_script_mode_fn ,
353+ framework_training_with_run_modules_fn ])
354+ def test_script_mode_client_import_error (training_fn ):
338355 channel = test .Channel .create (name = 'training' )
339356
340357 requirements_file = test .File ('requirements.txt' , '42/0' )
@@ -347,7 +364,7 @@ def test_script_mode_client_import_error():
347364 test .prepare (user_module = module , hyperparameters = hyperparameters , channels = [channel ])
348365
349366 with pytest .raises (errors .InstallModuleError ) as e :
350- framework_training_with_script_mode_fn ()
367+ training_fn ()
351368
352369 message = str (e .value )
353370 assert 'InstallModuleError:' in message
0 commit comments