Skip to content

Commit 2e10f28

Browse files
HarikrishnanBalagopal and kmehant
authored and committed
fix: crash when output directory doesn't exist
Signed-off-by: Harikrishnan Balagopal <[email protected]>
Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent c59da48 commit 2e10f28

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

tuning/sft_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,9 @@ def main():
636636
combined_tracker_configs.file_logger_config = file_logger_config
637637
combined_tracker_configs.aim_config = aim_config
638638

639+
if model_args.output_dir:
640+
os.makedirs(model_args.output_dir, exist_ok=True)
641+
logger.info("using the output directory at %s", model_args.output_dir)
639642
try:
640643
trainer, additional_train_info = train(
641644
model_args=model_args,

tuning/utils/preprocessing_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def get_data_collator(
167167
max_seq_length: int,
168168
tokens_field: str = True,
169169
instruction_template: Optional[str] = None,
170-
padding_free: str = None,
170+
padding_free: Any = None,
171171
) -> Callable:
172172
"""Create and return the appropriate collator type based on the configuration for packing,
173173
response_template, and dataset_text_field.
@@ -187,7 +187,7 @@ def get_data_collator(
187187
feature having tokens
188188
instruction_template: Optional[str]
189189
start of user answer.
190-
padding_free: str
190+
padding_free: Any
191191
padding free method
192192
193193
Returns:
@@ -243,7 +243,7 @@ def get_data_collator(
243243
return DataCollatorForSeq2Seq(
244244
tokenizer=tokenizer, padding=True, max_length=max_seq_length
245245
)
246-
if padding_free:
246+
if padding_free is not None and padding_free.method:
247247
# when packing is false but padding_free is used and no response template is used
248248
# then it's a pretraining scenario.
249249
return DataCollatorForSeq2Seq(

0 commit comments

Comments
 (0)