Skip to content

Commit 099eac9

Browse files
committed
Add an option for users to remove inputs and container artifacts when using the local model trainer
1 parent 3c818cb commit 099eac9

File tree

3 files changed

+82
-9
lines changed

3 files changed

+82
-9
lines changed

src/sagemaker/modules/local_core/local_container.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ class _LocalContainer(BaseModel):
108108
container_entrypoint: Optional[List[str]]
109109
container_arguments: Optional[List[str]]
110110

111+
_temperary_folders: List[str] = []
112+
111113
def model_post_init(self, __context: Any):
112114
"""Post init method to perform custom validation and set default values."""
113115
self.hosts = [f"algo-{i}" for i in range(1, self.instance_count + 1)]
@@ -146,12 +148,15 @@ def model_post_init(self, __context: Any):
146148
def train(
147149
self,
148150
wait: bool,
151+
remove_inputs_and_container_artifacts: Optional[bool] = True,
149152
) -> str:
150153
"""Run a training job locally using docker-compose.
151154
152155
Args:
153156
wait (bool):
154157
Whether to wait for the training output before exiting.
158+
remove_inputs_and_container_artifacts (Optional[bool]):
159+
Whether to remove inputs and container artifacts after training.
155160
"""
156161
# create output/data folder since sagemaker-containers 2.0 expects it
157162
os.makedirs(os.path.join(self.container_root, "output", "data"), exist_ok=True)
@@ -201,6 +206,13 @@ def train(
201206

202207
# Print our Job Complete line
203208
logger.info("Local training job completed, output artifacts saved to %s", artifacts)
209+
210+
if remove_inputs_and_container_artifacts:
211+
shutil.rmtree(os.path.join(self.container_root, "input"))
212+
for host in self.hosts:
213+
shutil.rmtree(os.path.join(self.container_root, host))
214+
for folder in self._temperary_folders:
215+
shutil.rmtree(os.path.join(self.container_root, folder))
204216
return artifacts
205217

206218
def retrieve_artifacts(
@@ -540,6 +552,7 @@ def _get_data_source_local_path(self, data_source: DataSource):
540552
uri = data_source.s3_data_source.s3_uri
541553
parsed_uri = urlparse(uri)
542554
local_dir = TemporaryDirectory(prefix=os.path.join(self.container_root + "/")).name
555+
self._temperary_folders.append(local_dir)
543556
download_folder(parsed_uri.netloc, parsed_uri.path, local_dir, self.sagemaker_session)
544557
return local_dir
545558
else:

src/sagemaker/modules/train/model_trainer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ class ModelTrainer(BaseModel):
203203
local_container_root (Optional[str]):
204204
The local root directory to store artifacts from a training job launched in
205205
"LOCAL_CONTAINER" mode.
206+
remove_inputs_and_container_artifacts (Optional[bool]):
207+
Whether to remove inputs and container artifacts after training.
206208
"""
207209

208210
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
@@ -227,6 +229,7 @@ class ModelTrainer(BaseModel):
227229
hyperparameters: Optional[Dict[str, Any]] = {}
228230
tags: Optional[List[Tag]] = None
229231
local_container_root: Optional[str] = os.getcwd()
232+
remove_inputs_and_container_artifacts: Optional[bool] = True
230233

231234
# Created Artifacts
232235
_latest_training_job: Optional[resources.TrainingJob] = PrivateAttr(default=None)
@@ -646,7 +649,7 @@ def train(
646649
hyper_parameters=string_hyper_parameters,
647650
environment=self.environment,
648651
)
649-
local_container.train(wait)
652+
local_container.train(wait, self.remove_inputs_and_container_artifacts)
650653

651654
def create_input_data_channel(
652655
self, channel_name: str, data_source: DataSourceType, key_prefix: Optional[str] = None

tests/integ/sagemaker/modules/train/test_local_model_trainer.py

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,18 +92,15 @@ def test_single_container_local_mode_local_data(modules_sagemaker_session):
9292
"compressed_artifacts",
9393
"artifacts",
9494
"model",
95-
"shared",
96-
"input",
9795
"output",
98-
"algo-1",
9996
]
10097

10198
for directory in directories:
10299
path = os.path.join(CWD, directory)
103100
delete_local_path(path)
104101

105102

106-
def test_single_container_local_mode_s3_data(modules_sagemaker_session):
103+
def test_single_container_local_mode_s3_data_remove_input(modules_sagemaker_session):
107104
with lock.lock(LOCK_PATH):
108105
try:
109106
# upload local data to s3
@@ -145,6 +142,70 @@ def test_single_container_local_mode_s3_data(modules_sagemaker_session):
145142
training_mode=Mode.LOCAL_CONTAINER,
146143
)
147144

145+
model_trainer.train()
146+
assert os.path.exists(os.path.join(CWD, "compressed_artifacts/model.tar.gz"))
147+
finally:
148+
subprocess.run(["docker", "compose", "down", "-v"])
149+
150+
assert not os.path.exists(os.path.join(CWD, "shared"))
151+
assert not os.path.exists(os.path.join(CWD, "input"))
152+
assert not os.path.exists(os.path.join(CWD, "algo-1"))
153+
154+
directories = [
155+
"compressed_artifacts",
156+
"artifacts",
157+
"model",
158+
"output",
159+
]
160+
161+
for directory in directories:
162+
path = os.path.join(CWD, directory)
163+
delete_local_path(path)
164+
165+
166+
def test_single_container_local_mode_s3_data_not_remove_input(modules_sagemaker_session):
167+
with lock.lock(LOCK_PATH):
168+
try:
169+
# upload local data to s3
170+
session = modules_sagemaker_session
171+
bucket = session.default_bucket()
172+
session.upload_data(
173+
path=os.path.join(SOURCE_DIR, "data/train/"),
174+
bucket=bucket,
175+
key_prefix="data/train",
176+
)
177+
session.upload_data(
178+
path=os.path.join(SOURCE_DIR, "data/test/"),
179+
bucket=bucket,
180+
key_prefix="data/test",
181+
)
182+
183+
source_code = SourceCode(
184+
source_dir=SOURCE_DIR,
185+
entry_script="local_training_script.py",
186+
)
187+
188+
compute = Compute(
189+
instance_type="local_cpu",
190+
instance_count=1,
191+
)
192+
193+
# read input data from s3
194+
train_data = InputData(channel_name="train", data_source=f"s3://{bucket}/data/train/")
195+
196+
test_data = InputData(channel_name="test", data_source=f"s3://{bucket}/data/test/")
197+
198+
model_trainer = ModelTrainer(
199+
training_image=DEFAULT_CPU_IMAGE,
200+
sagemaker_session=modules_sagemaker_session,
201+
source_code=source_code,
202+
compute=compute,
203+
input_data_config=[train_data, test_data],
204+
base_job_name="local_mode_single_container_s3_data",
205+
training_mode=Mode.LOCAL_CONTAINER,
206+
remove_inputs_and_container_artifacts=False,
207+
)
208+
148209
model_trainer.train()
149210
assert os.path.exists(os.path.join(CWD, "compressed_artifacts/model.tar.gz"))
150211
finally:
@@ -213,11 +274,7 @@ def test_multi_container_local_mode(modules_sagemaker_session):
213274
"compressed_artifacts",
214275
"artifacts",
215276
"model",
216-
"shared",
217-
"input",
218277
"output",
219-
"algo-1",
220-
"algo-2",
221278
]
222279

223280
for directory in directories:

0 commit comments

Comments
 (0)