@@ -218,17 +218,31 @@ def test_import_module(reload, import_module, install, download_and_extract):
218218 reload .assert_called_with (import_module (_modules .DEFAULT_MODULE_NAME ))
219219
220220
221- def test_download_and_install_local_directory ():
221+ @patch ("sagemaker_containers._modules.exists" , return_value = False )
222+ @patch ("sagemaker_containers._files.tmpdir" )
223+ @patch ("sagemaker_containers._files.download_and_extract" )
224+ @patch ("sagemaker_containers._modules.prepare" )
225+ @patch ("sagemaker_containers._modules.install" )
226+ def test_download_and_install (install , prepare , download_and_extract , files_tmpdir , module_exists ):
227+ files_tmpdir .return_value .__enter__ .return_value = "tmp"
228+ uri = "s3://foo/bar"
229+ _modules .download_and_install (uri )
230+
231+ module_path = os .path .join ("tmp" , "module_dir" )
232+ download_and_extract .assert_called_with (uri , module_path )
233+ prepare .assert_called_with (module_path , "default_user_module_name" )
234+ install .assert_called_with (module_path )
235+
236+
237+ @patch ("sagemaker_containers._files.s3_download" )
238+ @patch ("tarfile.open" )
239+ @patch ("sagemaker_containers._modules.prepare" )
240+ @patch ("sagemaker_containers._modules.install" )
241+ def test_download_and_install_local_directory (install , prepare , tarfile , s3_download ):
222242 uri = "/opt/ml/input/data/code/sourcedir.tar.gz"
243+ _modules .download_and_install (uri )
223244
224- with patch ("sagemaker_containers._files.s3_download" ) as s3_download , patch (
225- "tarfile.open"
226- ) as tarfile , patch ("sagemaker_containers._modules.prepare" ) as prepare , patch (
227- "sagemaker_containers._modules.install"
228- ) as install :
229- _modules .download_and_install (uri )
230-
231- s3_download .assert_not_called ()
232- tarfile .assert_called_with (name = "/opt/ml/input/data/code/sourcedir.tar.gz" , mode = "r:gz" )
233- prepare .assert_called_once ()
234- install .assert_called_once ()
245+ s3_download .assert_not_called ()
246+ tarfile .assert_called_with (name = "/opt/ml/input/data/code/sourcedir.tar.gz" , mode = "r:gz" )
247+ prepare .assert_called_once ()
248+ install .assert_called_once ()
0 commit comments