From d275e2bcc802e9aa55f3f6933ee01cd3f54f108e Mon Sep 17 00:00:00 2001
From: Nathan Lambert
Date: Thu, 29 Aug 2024 09:09:34 -0700
Subject: [PATCH] DPO Data Mixing (#308)

* init pr

* update metadata

* nit

* typo

* add dummy 1b config

* obvious fix

* add dummy config, fix
---
 .../dpo/olmo_1b_mix_example.yaml | 26 ++++++++++++++
 open_instruct/dpo_tune.py        | 35 ++++++++++++++++---
 open_instruct/finetune.py        |  5 ++-
 open_instruct/test_utils.py      |  8 +++++
 open_instruct/utils.py           | 16 ++++++----
 5 files changed, 76 insertions(+), 14 deletions(-)
 create mode 100644 configs/train_configs/dpo/olmo_1b_mix_example.yaml

diff --git a/configs/train_configs/dpo/olmo_1b_mix_example.yaml b/configs/train_configs/dpo/olmo_1b_mix_example.yaml
new file mode 100644
index 000000000..44305d43f
--- /dev/null
+++ b/configs/train_configs/dpo/olmo_1b_mix_example.yaml
@@ -0,0 +1,26 @@
+model_name_or_path: allenai/OLMo-1B-0724-hf
+use_flash_attn: false
+gradient_checkpointing: true
+tokenizer_name: allenai/OLMo-1B-0724-hf
+dataset_mixer:
+  allenai/ultrafeedback_binarized_cleaned_train: 61814
+  ai2-adapt-dev/DaringAnteater-prefs-RM-filter: 1618
+  ai2-adapt-dev/hh-rlhf-helpful: 20000
+  ai2-adapt-dev/webgpt-binarized: 14346
+  ai2-adapt-dev/WildChat-prefs-280824: 11487
+max_seq_length: 2048
+preprocessing_num_workers: 16
+per_device_train_batch_size: 8
+gradient_accumulation_steps: 4  # designed for 1 GPU, so effective batch size 32
+learning_rate: 5.0e-7
+dpo_beta: 0.1
+lr_scheduler_type: linear
+warmup_ratio: 0.1
+weight_decay: 0.0
+num_train_epochs: 3
+output_dir: output/olmo_instruct_dpo/
+with_tracking: true
+report_to:
+  - wandb
+logging_steps: 1
+add_bos: true
\ No newline at end of file
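For context on the `dataset_mixer` block above: each key is a Hugging Face dataset and each value is the number of training examples to draw from it (the repo's tests also exercise unit fractions, which read as proportions of a split). Below is a minimal sketch of that sampling behavior using only the `datasets` library; it is an illustration of the idea, not the repo's actual `get_datasets` implementation:

```python
# Illustrative sketch (not the repo's get_datasets) of resolving a
# dataset_mixer mapping: values > 1 are absolute sample counts,
# values <= 1 are fractions of the split.
from datasets import Dataset, concatenate_datasets, load_dataset


def mix_datasets(dataset_mixer: dict[str, float], split: str = "train") -> Dataset:
    parts = []
    for name, frac_or_count in dataset_mixer.items():
        ds = load_dataset(name, split=split)
        n = int(frac_or_count if frac_or_count > 1 else frac_or_count * len(ds))
        parts.append(ds.select(range(min(n, len(ds)))))  # truncate to requested size
    return concatenate_datasets(parts).shuffle(seed=42)
```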
diff --git a/open_instruct/dpo_tune.py b/open_instruct/dpo_tune.py
index 310a69dfa..4998aebee 100644
--- a/open_instruct/dpo_tune.py
+++ b/open_instruct/dpo_tune.py
@@ -64,6 +64,7 @@
 from open_instruct.utils import (
     ArgumentParserPlus,
     clean_last_n_checkpoints,
+    get_datasets,
     get_last_checkpoint_path,
     get_wandb_tags,
     is_beaker_job,
@@ -559,6 +560,24 @@
             args.dataset_name,
             args.dataset_config_name,
         )
+    elif args.dataset_mixer is not None:
+        # mixing datasets via a config-file mapping
+        raw_datasets = get_datasets(
+            args.dataset_mixer,
+            configs=args.dataset_config_name,
+            splits=["train"],
+            save_data_dir=args.dataset_mix_dir,
+            columns_to_keep=["chosen", "rejected"],
+        )
+    elif args.dataset_mixer_list is not None:
+        # mixing datasets via a flat command-line list
+        raw_datasets = get_datasets(
+            args.dataset_mixer_list,
+            configs=args.dataset_config_name,
+            splits=["train"],
+            save_data_dir=args.dataset_mix_dir,
+            columns_to_keep=["chosen", "rejected"],  # DPO needs preference pairs, not "messages"
+        )
     else:
         data_files = {}
         dataset_args = {}
@@ -1039,24 +1058,30 @@
             hf_repo_revision=args.hf_repo_revision,
         )
     if args.hf_metadata_dataset and accelerator.is_main_process and is_beaker_job():
-        # dpo script only supports these two options right now for datasets
-        dataset_name = args.dataset_name if args.dataset_name else args.train_file
+        if args.dataset_mixer:
+            dataset_list = list(args.dataset_mixer.keys())
+        elif args.dataset_mixer_list:
+            dataset_list = args.dataset_mixer_list[::2]  # even indices are dataset names
+        elif args.dataset_name:
+            dataset_list = [args.dataset_name]
+        else:
+            dataset_list = [args.train_file]
         # mainly just focusing here on what would be useful for the leaderboard.
         # wandb will have even more useful information.
         metadata_blob = {
             "model_name": args.exp_name,
             "model_type": "sft",
-            "datasets": [dataset_name],
+            "datasets": dataset_list,
             "base_model": args.model_name_or_path,
             "wandb_path": wandb_tracker.run.get_url(),
             "beaker_experiment": beaker_config.beaker_experiment_url,
-            "beaker_datasets": beaker_config.beaker_dataset_id_urls
+            "beaker_datasets": beaker_config.beaker_dataset_id_urls,
         }
         upload_metadata_to_hf(
             metadata_blob,
             "metadata.json",
             args.hf_metadata_dataset,
-            'results/' + args.hf_repo_revision,  # to match what the auto-evals name as.
+            "results/" + args.hf_repo_revision,  # to match the directory name the auto-evals use.
         )

     accelerator.wait_for_everyone()

diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py
index 4d2ce10bb..f8aeafbcb 100644
--- a/open_instruct/finetune.py
+++ b/open_instruct/finetune.py
@@ -1018,7 +1018,6 @@ def main(args: FlatArguments):
         hf_repo_revision=args.hf_repo_revision,
     )
     if args.hf_metadata_dataset and accelerator.is_main_process and is_beaker_job():
-        # dpo script only supports these two options right now for datasets
         if args.dataset_mixer:
             dataset_list = args.dataset_mixer.keys()
         elif args.dataset_mixer_list:
@@ -1036,13 +1035,13 @@
             "base_model": args.model_name_or_path,
             "wandb_path": wandb_tracker.run.get_url(),
             "beaker_experiment": beaker_config.beaker_experiment_url,
-            "beaker_datasets": beaker_config.beaker_dataset_id_urls
+            "beaker_datasets": beaker_config.beaker_dataset_id_urls,
         }
         upload_metadata_to_hf(
             metadata_blob,
             "metadata.json",
             args.hf_metadata_dataset,
-            'results/' + args.hf_repo_revision,  # to match what the auto-evals name as.
+            "results/" + args.hf_repo_revision,  # to match the directory name the auto-evals use.
         )

     accelerator.wait_for_everyone()

diff --git a/open_instruct/test_utils.py b/open_instruct/test_utils.py
index 3e1f14f0a..49f7a3996 100644
--- a/open_instruct/test_utils.py
+++ b/open_instruct/test_utils.py
@@ -68,6 +68,14 @@ def test_loading_single_split_with_unit_fractions(self):
         self.assertEqual(len(datasets["test"]), 100)
         self.assertRaises(KeyError, lambda: datasets["train"])

+    def test_loading_preference_data(self):
+        dataset_mixer = {
+            "ai2-adapt-dev/ultrafeedback-small": 1000,
+            "ai2-adapt-dev/summarize_from_feedback_small": 1000,
+        }
+        pref_datasets = get_datasets(dataset_mixer, splits=["train"], columns_to_keep=["chosen", "rejected"])
+        self.assertEqual(len(pref_datasets["train"]), 2000)
+

 # useful for checking if public datasets are still available
 # class CheckTuluDatasetsTest(unittest.TestCase):
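One note on the metadata changes above: `metadata_blob` is ultimately written out as `metadata.json`, so every value in it must be JSON-serializable, which is why `dataset_list` is built with `list(args.dataset_mixer.keys())` rather than a bare `.keys()` view. A small, self-contained illustration (the blob contents here are hypothetical, not repo code):

```python
# Hypothetical illustration of why dataset_list must be a plain list:
# dict_keys is not JSON-serializable, so json.dump would raise
# "TypeError: Object of type dict_keys is not JSON serializable".
import json

dataset_mixer = {"allenai/ultrafeedback_binarized_cleaned_train": 61814}

metadata_blob = {
    "model_type": "dpo",
    "datasets": list(dataset_mixer.keys()),  # list(...), not dict_keys
}
with open("metadata.json", "w") as f:
    json.dump(metadata_blob, f)
```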
beaker_dataset_id_urls = [f"https://beaker.org/ds/{dataset_id}" for dataset_id in beaker_dataset_ids] return BeakerRuntimeConfig( beaker_workload_id=os.environ["BEAKER_WORKLOAD_ID"], beaker_node_hostname=os.environ["BEAKER_NODE_HOSTNAME"], @@ -691,10 +695,10 @@ def submit_beaker_eval_jobs( def upload_metadata_to_hf( - metadata_dict, - filename, - hf_dataset_name, - hf_dataset_save_dir, + metadata_dict, + filename, + hf_dataset_name, + hf_dataset_save_dir, ): # upload a random dict to HF. Originally for uploading metadata to HF # about a model for leaderboard displays.