diff --git a/train.py b/train.py index 402d907..ceb4112 100644 --- a/train.py +++ b/train.py @@ -299,9 +299,12 @@ def _train_with_truefoundry(config_base: Path = Path("examples/"), **kwargs): shutil.copy2(readme_path, os.path.join(model_dir, "README.md")) logger.info(f"Merged model has been saved to {model_dir}") if cfg.truefoundry_ml_enable_reporting is True and cfg.truefoundry_ml_log_merged_model is True: - *_, model_name = cfg.base_model.rsplit("/", 1) - model_name = "-".join(["finetuned", model_name, timestamp]) - model_name = sanitize_name(model_name) + if not cfg.truefoundry_ml_finetuned_model_name: + *_, model_name = cfg.base_model.rsplit("/", 1) + model_name = "-".join(["finetuned", model_name, timestamp]) + model_name = sanitize_name(model_name) + else: + model_name = cfg.truefoundry_ml_finetuned_model_name run = get_or_create_run( ml_repo=cfg.truefoundry_ml_repo, run_name=cfg.truefoundry_ml_run_name,