From 4365dea3d1a6111e8b2712af06b22a4512a0df88 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 15 Jan 2025 08:57:19 -0500 Subject: [PATCH] Fix auto save logic (#519) --- open_instruct/dpo_tune.py | 2 +- open_instruct/dpo_tune_cache.py | 2 +- open_instruct/finetune.py | 2 +- open_instruct/ppo_vllm_thread_ray.py | 2 +- open_instruct/ppo_vllm_thread_ray_gtrl.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/open_instruct/dpo_tune.py b/open_instruct/dpo_tune.py index f8d767ee7..a4bd874ad 100644 --- a/open_instruct/dpo_tune.py +++ b/open_instruct/dpo_tune.py @@ -1124,7 +1124,7 @@ def load_model(): args.try_auto_save_to_beaker and accelerator.is_main_process and len(beaker_config.beaker_dataset_id_urls) > 0 - and args.output_dir != "/output" + and args.output_dir.rstrip("/") != "/output" ): shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True) diff --git a/open_instruct/dpo_tune_cache.py b/open_instruct/dpo_tune_cache.py index 4b92dd1a9..e95673bae 100644 --- a/open_instruct/dpo_tune_cache.py +++ b/open_instruct/dpo_tune_cache.py @@ -1144,7 +1144,7 @@ def load_model(): args.try_auto_save_to_beaker and accelerator.is_main_process and len(beaker_config.beaker_dataset_id_urls) > 0 - and args.output_dir != "/output" + and args.output_dir.rstrip("/") != "/output" ): shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True) diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index 3e688b3f4..30ee740da 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -1038,7 +1038,7 @@ def main(args: FlatArguments): args.try_auto_save_to_beaker and accelerator.is_main_process and len(beaker_config.beaker_dataset_id_urls) > 0 - and args.output_dir != "/output" + and args.output_dir.rstrip("/") != "/output" ): shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True) diff --git a/open_instruct/ppo_vllm_thread_ray.py b/open_instruct/ppo_vllm_thread_ray.py index 4e9ce0c00..985cc829e 100644 --- a/open_instruct/ppo_vllm_thread_ray.py +++ b/open_instruct/ppo_vllm_thread_ray.py @@ -1307,7 +1307,7 @@ def vllm_generate( args.try_auto_save_to_beaker and self.rank == 0 and len(self.beaker_config.beaker_dataset_id_urls) > 0 - and args.output_dir != "/output" + and args.output_dir.rstrip("/") != "/output" ): shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True) print("finished training") diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl.py b/open_instruct/ppo_vllm_thread_ray_gtrl.py index 6f5255abf..989b676cc 100644 --- a/open_instruct/ppo_vllm_thread_ray_gtrl.py +++ b/open_instruct/ppo_vllm_thread_ray_gtrl.py @@ -1384,7 +1384,7 @@ def vllm_generate( args.try_auto_save_to_beaker and self.rank == 0 and len(self.beaker_config.beaker_dataset_id_urls) > 0 - and args.output_dir != "/output" + and args.output_dir.rstrip("/") != "/output" ): shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True) print("finished training")