Skip to content

Commit

Permalink
DPO Data Mixing (#308)
Browse files Browse the repository at this point in the history
* init pr

* update metadata

* nit

* typo

* add dummy 1b config

* obvious fix

* add dummy config, fix
  • Loading branch information
natolambert authored Aug 29, 2024
1 parent db30f2a commit d275e2b
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 14 deletions.
27 changes: 27 additions & 0 deletions configs/train_configs/dpo/olmo_1b_mix_example.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# DPO training config for OLMo-1B with a multi-dataset preference mix.
model_name_or_path: allenai/OLMo-1B-0724-hf
use_flash_attn: false
gradient_checkpointing: true
tokenizer_name: allenai/OLMo-1B-0724-hf
# Preference datasets mixed by example count (value = number of examples drawn).
dataset_mixer:
  allenai/ultrafeedback_binarized_cleaned_train: 61814
  ai2-adapt-dev/DaringAnteater-prefs-RM-filter: 1618
  ai2-adapt-dev/hh-rlhf-helpful: 20000
  ai2-adapt-dev/webgpt-binarized: 14346
  ai2-adapt-dev/WildChat-prefs-280824: 11487
max_seq_length: 2048
preprocessing_num_workers: 16
per_device_train_batch_size: 8
gradient_accumulation_steps: 4 # designed for 1 GPU, so effective batch size 32
learning_rate: 5.0e-7
dpo_beta: 0.1
lr_scheduler_type: linear
warmup_ratio: 0.1
weight_decay: 0.0
num_train_epochs: 3
output_dir: output/olmo_instruct_dpo/
with_tracking: true
report_to:
  - wandb
logging_steps: 1
add_bos: true
35 changes: 30 additions & 5 deletions open_instruct/dpo_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from open_instruct.utils import (
ArgumentParserPlus,
clean_last_n_checkpoints,
get_datasets,
get_last_checkpoint_path,
get_wandb_tags,
is_beaker_job,
Expand Down Expand Up @@ -559,6 +560,24 @@ def main(args: FlatArguments):
args.dataset_name,
args.dataset_config_name,
)
elif args.dataset_mixer is not None:
# mixing datasets via config
raw_datasets = get_datasets(
args.dataset_mixer,
configs=args.dataset_config_name,
splits=["train"],
save_data_dir=args.dataset_mix_dir,
columns_to_keep=["chosen", "rejected"],
)
elif args.dataset_mixer_list is not None:
# mixing datasets via config
raw_datasets = get_datasets(
args.dataset_mixer_list,
configs=args.dataset_config_name,
splits=["train"],
save_data_dir=args.dataset_mix_dir,
columns_to_keep=["messages"],
)
else:
data_files = {}
dataset_args = {}
Expand Down Expand Up @@ -1039,24 +1058,30 @@ def load_model():
hf_repo_revision=args.hf_repo_revision,
)
if args.hf_metadata_dataset and accelerator.is_main_process and is_beaker_job():
# dpo script only supports these two options right now for datasets
dataset_name = args.dataset_name if args.dataset_name else args.train_file
if args.dataset_mixer:
dataset_list = args.dataset_mixer.keys()
elif args.dataset_mixer_list:
dataset_list = args.dataset_mixer_list[::2] # even indices
elif args.dataset_name:
dataset_list = [args.dataset_name]
else:
dataset_list = [args.train_file]
# mainly just focussing here on what would be useful for the leaderboard.
# wandb will have even more useful information.
metadata_blob = {
"model_name": args.exp_name,
"model_type": "sft",
"datasets": [dataset_name],
"datasets": dataset_list,
"base_model": args.model_name_or_path,
"wandb_path": wandb_tracker.run.get_url(),
"beaker_experiment": beaker_config.beaker_experiment_url,
"beaker_datasets": beaker_config.beaker_dataset_id_urls
"beaker_datasets": beaker_config.beaker_dataset_id_urls,
}
upload_metadata_to_hf(
metadata_blob,
"metadata.json",
args.hf_metadata_dataset,
'results/' + args.hf_repo_revision, # to match what the auto-evals name as.
"results/" + args.hf_repo_revision, # to match what the auto-evals name as.
)

accelerator.wait_for_everyone()
Expand Down
5 changes: 2 additions & 3 deletions open_instruct/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,6 @@ def main(args: FlatArguments):
hf_repo_revision=args.hf_repo_revision,
)
if args.hf_metadata_dataset and accelerator.is_main_process and is_beaker_job():
# dpo script only supports these two options right now for datasets
if args.dataset_mixer:
dataset_list = args.dataset_mixer.keys()
elif args.dataset_mixer_list:
Expand All @@ -1036,13 +1035,13 @@ def main(args: FlatArguments):
"base_model": args.model_name_or_path,
"wandb_path": wandb_tracker.run.get_url(),
"beaker_experiment": beaker_config.beaker_experiment_url,
"beaker_datasets": beaker_config.beaker_dataset_id_urls
"beaker_datasets": beaker_config.beaker_dataset_id_urls,
}
upload_metadata_to_hf(
metadata_blob,
"metadata.json",
args.hf_metadata_dataset,
'results/' + args.hf_repo_revision, # to match what the auto-evals name as.
"results/" + args.hf_repo_revision, # to match what the auto-evals name as.
)

accelerator.wait_for_everyone()
Expand Down
8 changes: 8 additions & 0 deletions open_instruct/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ def test_loading_single_split_with_unit_fractions(self):
self.assertEqual(len(datasets["test"]), 100)
self.assertRaises(KeyError, lambda: datasets["train"])

def test_loading_preference_data(self):
    """Mixing two small preference datasets yields the requested total of train rows."""
    per_dataset_count = 1000
    mixer = {
        "ai2-adapt-dev/ultrafeedback-small": per_dataset_count,
        "ai2-adapt-dev/summarize_from_feedback_small": per_dataset_count,
    }
    # Preference data must retain the chosen/rejected columns for DPO-style training.
    mixed = get_datasets(mixer, splits=["train"], columns_to_keep=["chosen", "rejected"])
    self.assertEqual(len(mixed["train"]), 2 * per_dataset_count)


# useful for checking if public datasets are still available
# class CheckTuluDatasetsTest(unittest.TestCase):
Expand Down
16 changes: 10 additions & 6 deletions open_instruct/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ def clean_last_n_checkpoints(output_dir: str, keep_last_n_checkpoints: int) -> N
# find the checkpoint with the largest step
checkpoints = sorted(folders, key=lambda x: int(x.split("_")[-1]))
if len(checkpoints) > keep_last_n_checkpoints:
for checkpoint in checkpoints[:len(checkpoints) - keep_last_n_checkpoints]:
for checkpoint in checkpoints[: len(checkpoints) - keep_last_n_checkpoints]:
shutil.rmtree(os.path.join(output_dir, checkpoint))


Expand Down Expand Up @@ -625,7 +625,11 @@ def get_beaker_whoami() -> Optional[str]:

def maybe_get_beaker_config():
beaker_dataset_ids = get_beaker_dataset_ids(os.environ["BEAKER_WORKLOAD_ID"])
beaker_dataset_id_urls = [f"https://beaker.org/ds/{dataset_id}" for dataset_id in beaker_dataset_ids]
# fix condition on basic interactive jobs
if beaker_dataset_ids is None:
beaker_dataset_id_urls = []
else:
beaker_dataset_id_urls = [f"https://beaker.org/ds/{dataset_id}" for dataset_id in beaker_dataset_ids]
return BeakerRuntimeConfig(
beaker_workload_id=os.environ["BEAKER_WORKLOAD_ID"],
beaker_node_hostname=os.environ["BEAKER_NODE_HOSTNAME"],
Expand Down Expand Up @@ -691,10 +695,10 @@ def submit_beaker_eval_jobs(


def upload_metadata_to_hf(
metadata_dict,
filename,
hf_dataset_name,
hf_dataset_save_dir,
metadata_dict,
filename,
hf_dataset_name,
hf_dataset_save_dir,
):
# upload a random dict to HF. Originally for uploading metadata to HF
# about a model for leaderboard displays.
Expand Down

0 comments on commit d275e2b

Please sign in to comment.