From a832163c39c920fe56060e756a068b2b356e2cea Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 16 Jan 2025 09:31:03 -0800 Subject: [PATCH 01/32] Add dataset cache / mixing support --- open_instruct/dataset_transformation.py | 618 ++++++++++++++++++++++++ 1 file changed, 618 insertions(+) create mode 100644 open_instruct/dataset_transformation.py diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py new file mode 100644 index 000000000..e12cb68f4 --- /dev/null +++ b/open_instruct/dataset_transformation.py @@ -0,0 +1,618 @@ +# this file deals with dataset pre-processing before training + +# 1. PPO (prompt) +# 2. SFT (prompt + demonstration), there is also packing. +# 3. ✅ RM / DPO (chosen and rejected) +# 4. ✅ Visualization of length distributions? +# 5. ✅ Filter? +# * Smart truncation? +# 6. ✅ dataset_num_proc +# 7. ✅ check EOS token +# 8. dataset mixer? +# 9. ✅ pretty print that show tokenization? +# 10. ✅ hashable tokneization? +# 11. inputs / labels / attention_mask +# 12. ✅ always set a `tokenizer.pad_token_id`? +# 13. a new DataCollatorForLanguageModeling? +# 14. ✅ `add_bos_token` and `add_eos_token`? E.g., LLAMA models +# 15. ✅ generate properties: has eos_token, bos_token (through chat template) + +# ✅ get tokenizer revision +# ✅ get dataset revision +# create a cached tokenized dataset, with tokenized revision, dataset revision, tokenization function name. + +# too many names related to "maximum length": +# * `max_seq_length` in SFT +# * `max_length`, `max_target_length` in RM / DPO, +# * `max_prompt_length` in DPO +""" +This file contains the utility to transform and cache datasets with different configurations. +The main things we are looking for are: +* handle dataset mixing +* handle different tokenization functions +* **cache** the tokenized dataset so we don't have to re-tokenize every time + * This is especially important when we have 405B SFT models: 32 nodes are just spending like + 5 minutes to tokenize the dataset. This translates to 32 * 5 * 8 = 1280 minutes = 21 hours of + wasted H100 time. + * Sometimes we also launch on places that don't have a shared cache (e.g., GCP), so we would + download individual datasets 32 times, and wait for concatenation and tokenization (actually + twice because the `with accelerator.main_process_first()` function assumes a shared cache) +""" + +import copy +import hashlib +import json +from dataclasses import dataclass, field, asdict +import multiprocessing +import os +from typing import Any, Dict, List, Optional +from transformers import ( + AutoTokenizer, + PreTrainedTokenizer, + GPTNeoXTokenizerFast, + LlamaTokenizer, + LlamaTokenizerFast, +) +import transformers +from transformers.utils.hub import cached_file, extract_commit_hash +from datasets import Dataset, load_dataset, concatenate_datasets +from huggingface_hub import HfApi, revision_exists + +from open_instruct.dataset_processor import CHAT_TEMPLATES + + +# ---------------------------------------------------------------------------- +# Utilities +def get_commit_hash(model_name_or_path: str, revision: str, filename: str = "config.json", repo_type: str = "model"): + file = cached_file(model_name_or_path, revision=revision, filename=filename, repo_type=repo_type) + commit_hash = extract_commit_hash(file, None) + return commit_hash + +# Performance tuning. Some rough numbers: +APPLY_CHAT_TEMPLATE_EXAMPLE_PER_SECOND_PER_CPU = 400 +FILTER_EXAMPLE_PER_SECOND_PER_CPU = 1130 + +def get_num_proc(dataset_len: int, num_available_cpus: int, example_per_second_per_cpu) -> int: + num_required_cpus = max(1, dataset_len // example_per_second_per_cpu) + return min(num_required_cpus, num_available_cpus) + + +# ---------------------------------------------------------------------------- +# Tokenization +@dataclass +class TokenizerConfig: + model_name_or_path: str + revision: str + trust_remote_code: bool = True + use_fast: bool = True + chat_template_name: Optional[str] = None + add_bos: bool = False + get_tokenizer_fn: str = "get_tokenizer_v1" + + # for tracking purposes + tokenizer_commit_hash: Optional[str] = None + + def __post_init__(self): + self.tokenizer_commit_hash = get_commit_hash(self.model_name_or_path, self.revision) + + +def get_tokenizer_v1(tc: TokenizerConfig): + tokenizer = AutoTokenizer.from_pretrained( + tc.model_name_or_path, + revision=tc.revision, + trust_remote_code=tc.trust_remote_code, + use_fast=tc.use_fast, + ) + # no default pad token for llama! + # here we add all special tokens again, because the default ones are not in the special_tokens_map + if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, LlamaTokenizerFast): + num_added_tokens = tokenizer.add_special_tokens( + { + "bos_token": "", + "eos_token": "", + "unk_token": "", + "pad_token": "", + } + ) + assert num_added_tokens in [ + 0, + 1, + ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present." + elif isinstance(tokenizer, GPTNeoXTokenizerFast): + # OLMo newer models use this tokenizer + if tokenizer.bos_token is None: + tokenizer.bos_token = tokenizer.eos_token + assert ( + tc.add_bos + ), "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence." + # else, pythia / other models + else: + num_added_tokens = tokenizer.add_special_tokens( + { + "pad_token": "", + } + ) + assert ( + num_added_tokens <= 1 + ), "GPTNeoXTokenizer should only add one special token - the pad_token (or no tokens if already set in SFT)." + # NOTE: (Costa) I just commented the `OPTForCausalLM` because we are not likely to use it. + # elif isinstance(tokenizer, GPT2Tokenizer) and isinstance(model, OPTForCausalLM): + # num_added_tokens = tokenizer.add_special_tokens({"unk_token": ""}) + elif isinstance(tokenizer, transformers.PreTrainedTokenizerFast) and tokenizer.pad_token is None: + num_added_tokens = tokenizer.add_special_tokens({"pad_token": ""}) + assert num_added_tokens == 1, "We detected no padding token but add_special_tokens did not add one." + + # set the tokenizer chat template to the training format + # this will be used for encoding the training examples + # and saved together with the tokenizer to be used later. + if tc.chat_template_name in CHAT_TEMPLATES: + tokenizer.chat_template = CHAT_TEMPLATES[tc.chat_template_name] + else: + try: + tokenizer.chat_template = AutoTokenizer.from_pretrained(tc.chat_template_name).chat_template + except Exception: + raise ValueError(f"Could not find chat template for {tc.chat_template_name}.") + + if tc.add_bos: + if tokenizer.chat_template.startswith("{{ bos_token }}") or ( + tokenizer.bos_token is not None and tokenizer.chat_template.startswith(tokenizer.bos_token) + ): + raise ValueError( + "You specified add_bos=True, but the chat template already has a bos_token at the beginning." + ) + # also add bos in the chat template if not already there + tokenizer.chat_template = "{{ bos_token }}" + tokenizer.chat_template + + + # TODO: test it out: PPO should have the same tokenizer as SFT / DPO. + # # create a tokenizer (pad from right) + # config = AutoConfig.from_pretrained(model_config.model_name_or_path, revision=model_config.model_revision) + # tokenizer = AutoTokenizer.from_pretrained( + # model_config.model_name_or_path, revision=model_config.model_revision, padding_side="right" + # ) + # if config.architectures == "LlamaForCausalLM" and config.bos_token_id == 128000: + # tokenizer.pad_token_id = 128002 # <|reserved_special_token_0|> + # else: + # tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # NOTE: we do not resize the embedding + # if dataset_config.chat_template is not None: + # tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template] + return tokenizer + +GET_TOKENIZER_FNS = { + "get_tokenizer_v1": get_tokenizer_v1, +} + +def get_tokenizer(tc: 'TokenizerConfig'): + return GET_TOKENIZER_FNS[tc.get_tokenizer_fn](tc) + + +# TODO: for testing, we should load the tokenizer from the sft / dpo / rl and make sure they are all the same. + + +# ---------------------------------------------------------------------------- +# Dataset Transformation +COLORS = ["on red", "on green", "on blue", "on yellow", "on magenta"] +# Preference dataset +INPUT_IDS_CHOSEN_KEY = "input_ids_chosen" +ATTENTION_MASK_CHOSEN_KEY = "attention_mask_chosen" +INPUT_IDS_REJECTED_KEY = "input_ids_rejected" +ATTENTION_MASK_REJECTED_KEY = "attention_mask_rejected" +INPUT_IDS_PROMPT_KEY = "input_ids_prompt" +ATTENTION_MASK_PROMPT_KEY = "attention_mask_prompt" +GROUND_TRUTHS_KEY = "ground_truth" +DATASET_SOURCE_KEY = "dataset" + +# NOTE (Costa): the `INPUT_IDS_PROMPT_KEY` is just for visualization purposes only +# also we don't really need `ATTENTION_MASK_CHOSEN_KEY` and `ATTENTION_MASK_REJECTED_KEY` +# since we are always padding from the right with a collator; however they might become +# more useful if we want to do some sort of packing in the future. The nice thing is +# that the tokenization logic would work for both DPO and RM training. +TOKENIZED_PREFERENCE_DATASET_KEYS = [ + INPUT_IDS_CHOSEN_KEY, + INPUT_IDS_REJECTED_KEY, + # ATTENTION_MASK_CHOSEN_KEY, + # ATTENTION_MASK_REJECTED_KEY, + # INPUT_IDS_PROMPT_KEY, + # ATTENTION_MASK_PROMPT_KEY, +] + +# SFT dataset +SFT_MESSAGE_KEY = "messages" +INPUT_IDS_KEY = "input_ids" +ATTENTION_MASK_KEY = "attention_mask" +LABELS_KEY = "labels" + +# Binary dataset +BINARY_LABEL_KEY = "binary_labels" +BINARY_DATASET_KEYS = [ + INPUT_IDS_KEY, + LABELS_KEY, + BINARY_LABEL_KEY, +] + + +def sft_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer): + if len(row[SFT_MESSAGE_KEY]) == 1: + prompt = row[SFT_MESSAGE_KEY] + else: + prompt = row[SFT_MESSAGE_KEY][:-1] + + row[INPUT_IDS_PROMPT_KEY] = tokenizer.apply_chat_template( + prompt, + add_generation_prompt=True, + ) + row[INPUT_IDS_KEY] = tokenizer.apply_chat_template(row[SFT_MESSAGE_KEY]) + row[ATTENTION_MASK_KEY] = [1] * len(row[INPUT_IDS_KEY]) + labels = copy.deepcopy(row[INPUT_IDS_KEY]) + row[LABELS_KEY] = labels + return row + + +def sft_tokenize_mask_out_prompt_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer): + """mask out the prompt tokens by manipulating labels""" + if len(row[SFT_MESSAGE_KEY]) == 1: + prompt = row[SFT_MESSAGE_KEY] + else: + prompt = row[SFT_MESSAGE_KEY][:-1] + + row[INPUT_IDS_PROMPT_KEY] = tokenizer.apply_chat_template( + prompt, + add_generation_prompt=True, + ) + row[INPUT_IDS_KEY] = tokenizer.apply_chat_template(row[SFT_MESSAGE_KEY]) + row[ATTENTION_MASK_KEY] = [1] * len(row[INPUT_IDS_KEY]) + labels = copy.deepcopy(row[INPUT_IDS_KEY]) + labels[: len(row[INPUT_IDS_PROMPT_KEY])] = [-100] * len(row[INPUT_IDS_PROMPT_KEY]) + row[LABELS_KEY] = labels + return row + + +def sft_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_prompt_token_length: Optional[int] = None, max_token_length: Optional[int] = None, need_contain_labels: bool = True): + max_prompt_token_length_ok = True + if max_prompt_token_length is not None: + max_prompt_token_length_ok = len(row[INPUT_IDS_PROMPT_KEY]) <= max_prompt_token_length + + max_token_length_ok = True + if max_token_length is not None: + max_token_length_ok = len(row[INPUT_IDS_KEY]) <= max_token_length + + contain_some_labels = any(x != -100 for x in row[LABELS_KEY]) + return ( + max_prompt_token_length_ok and max_token_length_ok and (contain_some_labels or not need_contain_labels) + ) + +def preference_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer): + # Extract prompt (all messages except the last one) + prompt = row["chosen"][:-1] + + # Tokenize prompt + row[INPUT_IDS_PROMPT_KEY] = tokenizer.apply_chat_template( + prompt, + add_generation_prompt=True, + ) + row[ATTENTION_MASK_PROMPT_KEY] = [1] * len(row[INPUT_IDS_PROMPT_KEY]) + + # Tokenize chosen completion + row[INPUT_IDS_CHOSEN_KEY] = tokenizer.apply_chat_template(row["chosen"]) + row[ATTENTION_MASK_CHOSEN_KEY] = [1] * len(row[INPUT_IDS_CHOSEN_KEY]) + + # Tokenize rejected completion + row[INPUT_IDS_REJECTED_KEY] = tokenizer.apply_chat_template(row["rejected"]) + row[ATTENTION_MASK_REJECTED_KEY] = [1] * len(row[INPUT_IDS_REJECTED_KEY]) + + return row + +def preference_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_prompt_token_length: Optional[int] = None, max_token_length: Optional[int] = None): + # Check prompt length if specified + if max_prompt_token_length is not None: + if len(row[INPUT_IDS_PROMPT_KEY]) > max_prompt_token_length: + return False + + # Check total sequence lengths if specified + if max_token_length is not None: + if len(row[INPUT_IDS_CHOSEN_KEY]) > max_token_length: + return False + if len(row[INPUT_IDS_REJECTED_KEY]) > max_token_length: + return False + + return True + +TRANSFORM_FNS = { + "sft_tokenize_v1": (sft_tokenize_v1, "map"), + "sft_tokenize_mask_out_prompt_v1": (sft_tokenize_mask_out_prompt_v1, "map"), + "sft_filter_v1": (sft_filter_v1, "filter"), + "preference_tokenize_v1": (preference_tokenize_v1, "map"), + "preference_filter_v1": (preference_filter_v1, "filter"), +} + +# ---------------------------------------------------------------------------- +# Dataset Configuration and Caching +@dataclass +class DatasetConfig: + dataset_name: str + dataset_split: str + dataset_revision: str + transform_fn: List[str] = field(default_factory=list) + transform_fn_args: Dict[str, Dict[str, Any]] = field(default_factory=dict) + get_dataset_fn: str = "get_dataset_v1" + + # for tracking purposes + dataset_commit_hash: Optional[str] = None + + def __post_init__(self): + self.dataset_commit_hash = get_commit_hash(self.dataset_name, self.dataset_revision, "README.md", "dataset") + +def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): + # beaker specific logic; we may get assigned 15.5 CPU, so we convert it to float then int + num_proc = int(float(os.environ.get("BEAKER_ASSIGNED_CPU_COUNT", multiprocessing.cpu_count()))) + + tokenizer = get_tokenizer(tc) + dataset = load_dataset( + dc.dataset_name, + split=dc.dataset_split, + revision=dc.dataset_revision, + ) + + for fn_name in dc.transform_fn: + fn, fn_type = TRANSFORM_FNS[fn_name] + # always pass in tokenizer and other args if needed + fn_kwargs = {"tokenizer": tokenizer} + if fn_name in dc.transform_fn_args: + fn_kwargs.update(dc.transform_fn_args[fn_name]) + + # perform the transformation + if fn_type == "map": + dataset = dataset.map( + fn, + fn_kwargs=fn_kwargs, + num_proc=get_num_proc(len(dataset), num_proc, APPLY_CHAT_TEMPLATE_EXAMPLE_PER_SECOND_PER_CPU), + ) + elif fn_type == "filter": + dataset = dataset.filter( + fn, + fn_kwargs=fn_kwargs, + num_proc=get_num_proc(len(dataset), num_proc, FILTER_EXAMPLE_PER_SECOND_PER_CPU), + ) + # NOTE: elif we can implement packing here to create a packed SFT dataset. Low priority for now. + else: + raise ValueError(f"Unknown transform function type: {fn_type}") + + return dataset + +class DatasetTransformationCache: + def __init__(self, hf_entity: Optional[str] = None): + self.hf_entity = hf_entity or HfApi().whoami()["name"] + + def compute_config_hash(self, dcs: List[DatasetConfig], tc: TokenizerConfig) -> str: + """Compute a deterministic hash of both configs for caching.""" + dc_dicts = [ + {k: v for k, v in asdict(dc).items() if v is not None} + for dc in dcs + ] + tc_dict = {k: v for k, v in asdict(tc).items() if v is not None} + combined_dict = { + "dataset_configs": dc_dicts, + "tokenizer_config": tc_dict + } + config_str = json.dumps(combined_dict, sort_keys=True) + return hashlib.sha256(config_str.encode()).hexdigest()[:10] + + def load_or_transform_dataset(self, dcs: List[DatasetConfig], tc: TokenizerConfig) -> Dataset: + """Load dataset from cache if it exists, otherwise transform and cache it.""" + config_hash = self.compute_config_hash(dcs, tc) + repo_name = f"{self.hf_entity}/dataset-mix-cached" + + # Check if the revision exists + if revision_exists(repo_name, config_hash, repo_type="dataset"): + print(f"Found cached dataset at {repo_name}@{config_hash}") + # Use the split from the first dataset config as default + return load_dataset( + repo_name, + split=dcs[0].dataset_split, + revision=config_hash + ) + + print(f"Cache not found, transforming datasets...") + + # Transform each dataset + transformed_datasets = [] + for dc in dcs: + dataset = get_dataset_v1(dc, tc) + # Add id column if not present + if "id" not in dataset.column_names: + base_name = dc.dataset_name.split("/")[-1] + id_col = [f"{base_name}-{i}" for i in range(len(dataset))] + dataset = dataset.add_column("id", id_col) + transformed_datasets.append(dataset) + + # Combine datasets + combined_dataset = concatenate_datasets(transformed_datasets) + + # Push to hub with config hash as revision + combined_dataset.push_to_hub( + repo_name, + private=True, + revision=config_hash, + commit_message=f"Cache combined dataset with configs hash: {config_hash}" + ) + print(f"Pushed transformed dataset to {repo_name}@{config_hash}") + + # NOTE: Load the dataset again to make sure it's downloaded to the HF cache + print(f"Found cached dataset at {repo_name}@{config_hash}") + return load_dataset( + repo_name, + split=dc.dataset_split, + revision=config_hash + ) + + +def get_cached_dataset(dcs: List[DatasetConfig], tc: TokenizerConfig, hf_entity: Optional[str] = None) -> Dataset: + """Get transformed and cached dataset from multiple dataset configs.""" + cache = DatasetTransformationCache(hf_entity=hf_entity) + return cache.load_or_transform_dataset(dcs, tc) + +def test_config_hash_different(): + """Test that different configurations produce different hashes.""" + tc = TokenizerConfig( + model_name_or_path="meta-llama/Llama-3.1-8B", + revision="main", + chat_template_name="tulu" + ) + + dcs1 = [ + DatasetConfig( + dataset_name="allenai/tulu-3-sft-personas-algebra", + dataset_split="train", + dataset_revision="main", + transform_fn=["sft_tokenize_v1"], + transform_fn_args={} + ) + ] + + dcs2 = [ + DatasetConfig( + dataset_name="allenai/tulu-3-sft-personas-algebra", + dataset_split="train", + dataset_revision="main", + transform_fn=["sft_tokenize_mask_out_prompt_v1"], + transform_fn_args={} + ) + ] + + cache = DatasetTransformationCache() + hash1 = cache.compute_config_hash(dcs1, tc) + hash2 = cache.compute_config_hash(dcs2, tc) + assert hash1 != hash2, "Different configs should have different hashes" + +def test_sft_dataset_caching(): + """Test caching functionality for SFT datasets.""" + tc = TokenizerConfig( + model_name_or_path="meta-llama/Llama-3.1-8B", + revision="main", + chat_template_name="tulu" + ) + + dcs = [ + DatasetConfig( + dataset_name="allenai/tulu-3-sft-personas-algebra", + dataset_split="train", + dataset_revision="main", + transform_fn=["sft_tokenize_v1"], + transform_fn_args={} + ), + DatasetConfig( + dataset_name="allenai/tulu-3-hard-coded-10x", + dataset_split="train", + dataset_revision="main", + transform_fn=["sft_tokenize_v1"], + transform_fn_args={} + ) + ] + + # First transformation should cache + dataset1 = get_cached_dataset(dcs, tc) + + # Second load should use cache + dataset1_cached = get_cached_dataset(dcs, tc) + + # Verify the datasets are the same + assert len(dataset1) == len(dataset1_cached), "Cached dataset should have same length" + +def test_sft_different_transform(): + """Test different transform functions produce different cached datasets.""" + tc = TokenizerConfig( + model_name_or_path="meta-llama/Llama-3.1-8B", + revision="main", + chat_template_name="tulu" + ) + + dcs = [ + DatasetConfig( + dataset_name="allenai/tulu-3-sft-personas-algebra", + dataset_split="train", + dataset_revision="main", + transform_fn=["sft_tokenize_mask_out_prompt_v1"], + transform_fn_args={} + ), + DatasetConfig( + dataset_name="allenai/tulu-3-hard-coded-10x", + dataset_split="train", + dataset_revision="main", + transform_fn=["sft_tokenize_mask_out_prompt_v1"], + transform_fn_args={} + ) + ] + + dataset = get_cached_dataset(dcs, tc) + assert dataset is not None, "Should successfully create dataset with different transform" + + +def test_sft_filter(): + """Test different transform functions produce different cached datasets.""" + tc = TokenizerConfig( + model_name_or_path="meta-llama/Llama-3.1-8B", + revision="main", + chat_template_name="tulu" + ) + + ARBITRARY_MAX_LENGTH = 1000 + dcs = [ + DatasetConfig( + dataset_name="allenai/tulu-3-sft-personas-algebra", + dataset_split="train", + dataset_revision="main", + transform_fn=["sft_tokenize_v1", "sft_filter_v1"], # First tokenize, then filter + transform_fn_args={ + "sft_filter_v1": { + "max_token_length": ARBITRARY_MAX_LENGTH # Filter to sequences <= ARBITRARY_MAX_LENGTH tokens + } + } + ) + ] + + filtered_dataset = get_cached_dataset(dcs, tc) + # Verify that all sequences are <= ARBITRARY_MAX_LENGTH tokens + max_length = max(len(example[INPUT_IDS_KEY]) for example in filtered_dataset) + assert max_length <= ARBITRARY_MAX_LENGTH, f"Found sequence with length {max_length} > {ARBITRARY_MAX_LENGTH}" + + print("Filter test passed! Max sequence length:", max_length) + print("All tests passed!") + assert filtered_dataset is not None, "Should successfully create dataset with different transform" + + +def test_preference_dataset(): + """Test caching functionality for preference datasets.""" + tc = TokenizerConfig( + model_name_or_path="meta-llama/Llama-3.1-8B", + revision="main", + chat_template_name="tulu" + ) + + dcs_pref = [ + DatasetConfig( + dataset_name="allenai/tulu-3-pref-personas-instruction-following", + dataset_split="train", + dataset_revision="main", + transform_fn=["preference_tokenize_v1"], + transform_fn_args={} + ), + DatasetConfig( + dataset_name="allenai/tulu-3-wildchat-reused-on-policy-70b", + dataset_split="train", + dataset_revision="main", + transform_fn=["preference_tokenize_v1"], + transform_fn_args={} + ) + ] + + dataset_pref = get_cached_dataset(dcs_pref, tc) + assert dataset_pref is not None, "Should successfully create preference dataset" + + +if __name__ == "__main__": + test_config_hash_different() + test_sft_dataset_caching() + test_sft_different_transform() + test_preference_dataset() + test_sft_filter() + print("All tests passed!") \ No newline at end of file From e92da057146d192fb5185a513ddb12e7565a4aec Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 17 Jan 2025 09:06:35 -0800 Subject: [PATCH 02/32] prototyping --- open_instruct/dataset_transformation.py | 351 ++++++++++++++++-------- open_instruct/finetune.py | 232 +--------------- 2 files changed, 247 insertions(+), 336 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index e12cb68f4..fe7fadb07 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -46,6 +46,7 @@ import multiprocessing import os from typing import Any, Dict, List, Optional +import torch from transformers import ( AutoTokenizer, PreTrainedTokenizer, @@ -80,111 +81,99 @@ def get_num_proc(dataset_len: int, num_available_cpus: int, example_per_second_p # ---------------------------------------------------------------------------- # Tokenization @dataclass -class TokenizerConfig: +class TokenizerConfigV1: model_name_or_path: str revision: str trust_remote_code: bool = True use_fast: bool = True - chat_template_name: Optional[str] = None + chat_template_name: Optional[str] = None # TODO: should I give an option to force override? add_bos: bool = False - get_tokenizer_fn: str = "get_tokenizer_v1" # for tracking purposes tokenizer_commit_hash: Optional[str] = None def __post_init__(self): self.tokenizer_commit_hash = get_commit_hash(self.model_name_or_path, self.revision) - - -def get_tokenizer_v1(tc: TokenizerConfig): - tokenizer = AutoTokenizer.from_pretrained( - tc.model_name_or_path, - revision=tc.revision, - trust_remote_code=tc.trust_remote_code, - use_fast=tc.use_fast, - ) - # no default pad token for llama! - # here we add all special tokens again, because the default ones are not in the special_tokens_map - if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, LlamaTokenizerFast): - num_added_tokens = tokenizer.add_special_tokens( - { - "bos_token": "", - "eos_token": "", - "unk_token": "", - "pad_token": "", - } + + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_name_or_path, + revision=self.revision, + trust_remote_code=self.trust_remote_code, + use_fast=self.use_fast, ) - assert num_added_tokens in [ - 0, - 1, - ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present." - elif isinstance(tokenizer, GPTNeoXTokenizerFast): - # OLMo newer models use this tokenizer - if tokenizer.bos_token is None: - tokenizer.bos_token = tokenizer.eos_token - assert ( - tc.add_bos - ), "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence." - # else, pythia / other models - else: - num_added_tokens = tokenizer.add_special_tokens( + # no default pad token for llama! + # here we add all special tokens again, because the default ones are not in the special_tokens_map + if isinstance(self.tokenizer, LlamaTokenizer) or isinstance(self.tokenizer, LlamaTokenizerFast): + num_added_tokens = self.tokenizer.add_special_tokens( { + "bos_token": "", + "eos_token": "", + "unk_token": "", "pad_token": "", } ) - assert ( - num_added_tokens <= 1 - ), "GPTNeoXTokenizer should only add one special token - the pad_token (or no tokens if already set in SFT)." - # NOTE: (Costa) I just commented the `OPTForCausalLM` because we are not likely to use it. - # elif isinstance(tokenizer, GPT2Tokenizer) and isinstance(model, OPTForCausalLM): - # num_added_tokens = tokenizer.add_special_tokens({"unk_token": ""}) - elif isinstance(tokenizer, transformers.PreTrainedTokenizerFast) and tokenizer.pad_token is None: - num_added_tokens = tokenizer.add_special_tokens({"pad_token": ""}) - assert num_added_tokens == 1, "We detected no padding token but add_special_tokens did not add one." - - # set the tokenizer chat template to the training format - # this will be used for encoding the training examples - # and saved together with the tokenizer to be used later. - if tc.chat_template_name in CHAT_TEMPLATES: - tokenizer.chat_template = CHAT_TEMPLATES[tc.chat_template_name] - else: - try: - tokenizer.chat_template = AutoTokenizer.from_pretrained(tc.chat_template_name).chat_template - except Exception: - raise ValueError(f"Could not find chat template for {tc.chat_template_name}.") - - if tc.add_bos: - if tokenizer.chat_template.startswith("{{ bos_token }}") or ( - tokenizer.bos_token is not None and tokenizer.chat_template.startswith(tokenizer.bos_token) - ): - raise ValueError( - "You specified add_bos=True, but the chat template already has a bos_token at the beginning." - ) - # also add bos in the chat template if not already there - tokenizer.chat_template = "{{ bos_token }}" + tokenizer.chat_template - - - # TODO: test it out: PPO should have the same tokenizer as SFT / DPO. - # # create a tokenizer (pad from right) - # config = AutoConfig.from_pretrained(model_config.model_name_or_path, revision=model_config.model_revision) - # tokenizer = AutoTokenizer.from_pretrained( - # model_config.model_name_or_path, revision=model_config.model_revision, padding_side="right" - # ) - # if config.architectures == "LlamaForCausalLM" and config.bos_token_id == 128000: - # tokenizer.pad_token_id = 128002 # <|reserved_special_token_0|> - # else: - # tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # NOTE: we do not resize the embedding - # if dataset_config.chat_template is not None: - # tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template] - return tokenizer - -GET_TOKENIZER_FNS = { - "get_tokenizer_v1": get_tokenizer_v1, -} - -def get_tokenizer(tc: 'TokenizerConfig'): - return GET_TOKENIZER_FNS[tc.get_tokenizer_fn](tc) - + assert num_added_tokens in [ + 0, + 1, + ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present." + elif isinstance(self.tokenizer, GPTNeoXTokenizerFast): + # OLMo newer models use this self.tokenizer + if self.tokenizer.bos_token is None: + self.tokenizer.bos_token = self.tokenizer.eos_token + assert ( + self.add_bos + ), "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence." + # else, pythia / other models + else: + num_added_tokens = self.tokenizer.add_special_tokens( + { + "pad_token": "", + } + ) + assert ( + num_added_tokens <= 1 + ), "GPTNeoXTokenizer should only add one special token - the pad_token (or no tokens if already set in SFT)." + # NOTE: (Costa) I just commented the `OPTForCausalLM` because we are not likely to use it. + # elif isinstance(self.tokenizer, GPT2Tokenizer) and isinstance(model, OPTForCausalLM): + # num_added_tokens = self.tokenizer.add_special_tokens({"unk_token": ""}) + elif isinstance(self.tokenizer, transformers.PreTrainedTokenizerFast) and self.tokenizer.pad_token is None: + num_added_tokens = self.tokenizer.add_special_tokens({"pad_token": ""}) + assert num_added_tokens == 1, "We detected no padding token but add_special_tokens did not add one." + + # set the self.tokenizer chat template to the training format + # this will be used for encoding the training examples + # and saved together with the self.tokenizer to be used later. + if self.chat_template_name in CHAT_TEMPLATES: + self.tokenizer.chat_template = CHAT_TEMPLATES[self.chat_template_name] + else: + try: + self.tokenizer.chat_template = AutoTokenizer.from_pretrained(self.chat_template_name).chat_template + except Exception: + raise ValueError(f"Could not find chat template for {self.chat_template_name}.") + + if self.add_bos: + if self.tokenizer.chat_template.startswith("{{ bos_token }}") or ( + self.tokenizer.bos_token is not None and self.tokenizer.chat_template.startswith(self.tokenizer.bos_token) + ): + raise ValueError( + "You specified add_bos=True, but the chat template already has a bos_token at the beginning." + ) + # also add bos in the chat template if not already there + self.tokenizer.chat_template = "{{ bos_token }}" + self.tokenizer.chat_template + + + # TODO: test it out: PPO should have the sametokenizer as SFT / DPO. + # # create a tokenizer (pad from right) + # config = AutoConfig.from_pretrained(model_config.model_name_or_path, revision=model_config.model_revision) + # tokenizer = AutoTokenizer.from_pretrained( + # model_config.model_name_or_path, revision=model_config.model_revision, padding_side="right" + # ) + # if config.architectures == "LlamaForCausalLM" and config.bos_token_id == 128000: + # tokenizer.pad_token_id = 128002 # <|reserved_special_token_0|> + # else: + # tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # NOTE: we do not resize the embedding + # if dataset_config.chat_template is not None: + # tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template] # TODO: for testing, we should load the tokenizer from the sft / dpo / rl and make sure they are all the same. @@ -221,6 +210,11 @@ def get_tokenizer(tc: 'TokenizerConfig'): INPUT_IDS_KEY = "input_ids" ATTENTION_MASK_KEY = "attention_mask" LABELS_KEY = "labels" +TOKENIZED_SFT_DATASET_KEYS = [ + INPUT_IDS_KEY, + ATTENTION_MASK_KEY, + LABELS_KEY, +] # Binary dataset BINARY_LABEL_KEY = "binary_labels" @@ -281,6 +275,77 @@ def sft_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_promp max_prompt_token_length_ok and max_token_length_ok and (contain_some_labels or not need_contain_labels) ) +def sft_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_seq_length: int): + """taken directly from https://github.com/allenai/open-instruct/blob/ba11286e5b9eb00d4ce5b40ef4cac1389888416a/open_instruct/finetune.py#L385""" + messages = row["messages"] + if len(messages) == 0: + raise ValueError("messages field is empty.") + input_ids = tokenizer.apply_chat_template( + conversation=messages, + tokenize=True, + return_tensors="pt", + padding=False, + truncation=True, + max_length=max_seq_length, + add_generation_prompt=False, + ) + labels = input_ids.clone() + # mask the non-assistant part for avoiding loss + for message_idx, message in enumerate(messages): + if message["role"] != "assistant": + # we calculate the start index of this non-assistant message + if message_idx == 0: + message_start_idx = 0 + else: + message_start_idx = tokenizer.apply_chat_template( + conversation=messages[:message_idx], # here marks the end of the previous messages + tokenize=True, + return_tensors="pt", + padding=False, + truncation=True, + max_length=max_seq_length, + add_generation_prompt=False, + ).shape[1] + # next, we calculate the end index of this non-assistant message + if message_idx < len(messages) - 1 and messages[message_idx + 1]["role"] == "assistant": + # for intermediate messages that follow with an assistant message, we need to + # set `add_generation_prompt=True` to avoid the assistant generation prefix being included in the loss + # (e.g., `<|assistant|>`) + message_end_idx = tokenizer.apply_chat_template( + conversation=messages[: message_idx + 1], + tokenize=True, + return_tensors="pt", + padding=False, + truncation=True, + max_length=max_seq_length, + add_generation_prompt=True, + ).shape[1] + else: + # for the last message or the message that doesn't follow with an assistant message, + # we don't need to add the assistant generation prefix + message_end_idx = tokenizer.apply_chat_template( + conversation=messages[: message_idx + 1], + tokenize=True, + return_tensors="pt", + padding=False, + truncation=True, + max_length=max_seq_length, + add_generation_prompt=False, + ).shape[1] + # set the label to -100 for the non-assistant part + labels[:, message_start_idx:message_end_idx] = -100 + if max_seq_length and message_end_idx >= max_seq_length: + break + attention_mask = torch.ones_like(input_ids) + row[INPUT_IDS_KEY] = input_ids.flatten() + row[LABELS_KEY] = labels.flatten() + row[ATTENTION_MASK_KEY] = attention_mask.flatten() + return row + + +def sft_tulu_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer): + return row["labels"] != -100 + def preference_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer): # Extract prompt (all messages except the last one) prompt = row["chosen"][:-1] @@ -321,43 +386,58 @@ def preference_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, ma "sft_tokenize_v1": (sft_tokenize_v1, "map"), "sft_tokenize_mask_out_prompt_v1": (sft_tokenize_mask_out_prompt_v1, "map"), "sft_filter_v1": (sft_filter_v1, "filter"), - "preference_tokenize_v1": (preference_tokenize_v1, "map"), + "sft_tulu_tokenize_and_truncate_v1": (sft_tulu_tokenize_and_truncate_v1, "map"), + "sft_tulu_filter_v1": (sft_tulu_filter_v1, "filter"), + "preference_tokenize_v1": (preference_tokenize_v1, "map", TOKENIZED_PREFERENCE_DATASET_KEYS), "preference_filter_v1": (preference_filter_v1, "filter"), } # ---------------------------------------------------------------------------- # Dataset Configuration and Caching @dataclass -class DatasetConfig: +class DatasetConfigV1: dataset_name: str dataset_split: str dataset_revision: str + dataset_range: Optional[int] = None transform_fn: List[str] = field(default_factory=list) transform_fn_args: Dict[str, Dict[str, Any]] = field(default_factory=dict) - get_dataset_fn: str = "get_dataset_v1" + get_dataset_fn_version: str = "v1" # for tracking purposes dataset_commit_hash: Optional[str] = None def __post_init__(self): self.dataset_commit_hash = get_commit_hash(self.dataset_name, self.dataset_revision, "README.md", "dataset") + self.dataset = load_dataset( + self.dataset_name, + split=self.dataset_split, + revision=self.dataset_revision, + ) + if self.dataset_range is None: + dataset_range = len(self.dataset) + self.update_range(dataset_range) + + def update_range(self, dataset_range: int): + self.dataset_range = dataset_range + if self.dataset_range > len(self.dataset): + raise ValueError("Dataset range exceeds dataset length") + self.dataset = self.dataset.select(range(self.dataset_range)) -def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): +def get_dataset_v1(dc: DatasetConfigV1, tc: TokenizerConfigV1): # beaker specific logic; we may get assigned 15.5 CPU, so we convert it to float then int num_proc = int(float(os.environ.get("BEAKER_ASSIGNED_CPU_COUNT", multiprocessing.cpu_count()))) - tokenizer = get_tokenizer(tc) - dataset = load_dataset( - dc.dataset_name, - split=dc.dataset_split, - revision=dc.dataset_revision, - ) + tokenizer = tc.tokenizer + dataset = dc.dataset for fn_name in dc.transform_fn: fn, fn_type = TRANSFORM_FNS[fn_name] # always pass in tokenizer and other args if needed fn_kwargs = {"tokenizer": tokenizer} + target_columns = dataset.column_names if fn_name in dc.transform_fn_args: + target_columns = dc.transform_fn_args[fn_name].pop("target_columns", dataset.column_names) fn_kwargs.update(dc.transform_fn_args[fn_name]) # perform the transformation @@ -365,6 +445,7 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): dataset = dataset.map( fn, fn_kwargs=fn_kwargs, + remove_columns=[col for col in dataset.column_names if col not in target_columns], num_proc=get_num_proc(len(dataset), num_proc, APPLY_CHAT_TEMPLATE_EXAMPLE_PER_SECOND_PER_CPU), ) elif fn_type == "filter": @@ -377,13 +458,15 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): else: raise ValueError(f"Unknown transform function type: {fn_type}") + if len(dataset) == 0: + raise ValueError("No examples left after transformation") return dataset class DatasetTransformationCache: def __init__(self, hf_entity: Optional[str] = None): self.hf_entity = hf_entity or HfApi().whoami()["name"] - def compute_config_hash(self, dcs: List[DatasetConfig], tc: TokenizerConfig) -> str: + def compute_config_hash(self, dcs: List[DatasetConfigV1], tc: TokenizerConfigV1) -> str: """Compute a deterministic hash of both configs for caching.""" dc_dicts = [ {k: v for k, v in asdict(dc).items() if v is not None} @@ -397,14 +480,14 @@ def compute_config_hash(self, dcs: List[DatasetConfig], tc: TokenizerConfig) -> config_str = json.dumps(combined_dict, sort_keys=True) return hashlib.sha256(config_str.encode()).hexdigest()[:10] - def load_or_transform_dataset(self, dcs: List[DatasetConfig], tc: TokenizerConfig) -> Dataset: + def load_or_transform_dataset(self, dcs: List[DatasetConfigV1], tc: TokenizerConfigV1) -> Dataset: """Load dataset from cache if it exists, otherwise transform and cache it.""" config_hash = self.compute_config_hash(dcs, tc) repo_name = f"{self.hf_entity}/dataset-mix-cached" # Check if the revision exists if revision_exists(repo_name, config_hash, repo_type="dataset"): - print(f"Found cached dataset at {repo_name}@{config_hash}") + print(f"✅ Found cached dataset at https://huggingface.co/datasets/{repo_name}/tree/{config_hash}") # Use the split from the first dataset config as default return load_dataset( repo_name, @@ -435,10 +518,10 @@ def load_or_transform_dataset(self, dcs: List[DatasetConfig], tc: TokenizerConfi revision=config_hash, commit_message=f"Cache combined dataset with configs hash: {config_hash}" ) - print(f"Pushed transformed dataset to {repo_name}@{config_hash}") + print(f"🚀 Pushed transformed dataset to https://huggingface.co/datasets/{repo_name}/tree/{config_hash}") # NOTE: Load the dataset again to make sure it's downloaded to the HF cache - print(f"Found cached dataset at {repo_name}@{config_hash}") + print(f"✅ Found cached dataset at https://huggingface.co/datasets/{repo_name}/tree/{config_hash}") return load_dataset( repo_name, split=dc.dataset_split, @@ -446,21 +529,51 @@ def load_or_transform_dataset(self, dcs: List[DatasetConfig], tc: TokenizerConfi ) -def get_cached_dataset(dcs: List[DatasetConfig], tc: TokenizerConfig, hf_entity: Optional[str] = None) -> Dataset: +def get_cached_dataset(dcs: List[DatasetConfigV1], tc: TokenizerConfigV1, hf_entity: Optional[str] = None) -> Dataset: + """Get transformed and cached dataset from multiple dataset configs.""" + cache = DatasetTransformationCache(hf_entity=hf_entity) + return cache.load_or_transform_dataset(dcs, tc) + +# TODO: need to support length / range of the dataset +def get_cached_dataset_tulu(dataset_mixer_list: List[str], tc: TokenizerConfigV1, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: """Get transformed and cached dataset from multiple dataset configs.""" + dcs = [] + assert len(dataset_mixer_list) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer_list}" + for i in range(0, len(dataset_mixer_list), 2): + dataset_name = dataset_mixer_list[i] + frac_or_num_samples = dataset_mixer_list[i + 1] + + dataset_config = DatasetConfigV1( + dataset_name=dataset_name, + dataset_split="train", + dataset_revision="main", + transform_fn=["sft_tulu_tokenize_and_truncate_v1", "sft_tulu_filter_v1"], + transform_fn_args={ + "sft_tulu_tokenize_and_truncate_v1": { + "max_seq_length": max_seq_length, + "target_columns": TOKENIZED_SFT_DATASET_KEYS, + } + } + ) + if frac_or_num_samples > 1.0: + new_range = int(frac_or_num_samples) + else: + new_range = int(frac_or_num_samples * len(dataset_config.dataset)) + dataset_config.update_range(new_range) + dcs.append(dataset_config) cache = DatasetTransformationCache(hf_entity=hf_entity) return cache.load_or_transform_dataset(dcs, tc) def test_config_hash_different(): """Test that different configurations produce different hashes.""" - tc = TokenizerConfig( + tc = TokenizerConfigV1( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" ) dcs1 = [ - DatasetConfig( + DatasetConfigV1( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", @@ -470,7 +583,7 @@ def test_config_hash_different(): ] dcs2 = [ - DatasetConfig( + DatasetConfigV1( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", @@ -486,21 +599,21 @@ def test_config_hash_different(): def test_sft_dataset_caching(): """Test caching functionality for SFT datasets.""" - tc = TokenizerConfig( + tc = TokenizerConfigV1( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" ) dcs = [ - DatasetConfig( + DatasetConfigV1( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_v1"], transform_fn_args={} ), - DatasetConfig( + DatasetConfigV1( dataset_name="allenai/tulu-3-hard-coded-10x", dataset_split="train", dataset_revision="main", @@ -520,21 +633,21 @@ def test_sft_dataset_caching(): def test_sft_different_transform(): """Test different transform functions produce different cached datasets.""" - tc = TokenizerConfig( + tc = TokenizerConfigV1( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" ) dcs = [ - DatasetConfig( + DatasetConfigV1( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_mask_out_prompt_v1"], transform_fn_args={} ), - DatasetConfig( + DatasetConfigV1( dataset_name="allenai/tulu-3-hard-coded-10x", dataset_split="train", dataset_revision="main", @@ -549,7 +662,7 @@ def test_sft_different_transform(): def test_sft_filter(): """Test different transform functions produce different cached datasets.""" - tc = TokenizerConfig( + tc = TokenizerConfigV1( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" @@ -557,7 +670,7 @@ def test_sft_filter(): ARBITRARY_MAX_LENGTH = 1000 dcs = [ - DatasetConfig( + DatasetConfigV1( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", @@ -582,21 +695,21 @@ def test_sft_filter(): def test_preference_dataset(): """Test caching functionality for preference datasets.""" - tc = TokenizerConfig( + tc = TokenizerConfigV1( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" ) dcs_pref = [ - DatasetConfig( + DatasetConfigV1( dataset_name="allenai/tulu-3-pref-personas-instruction-following", dataset_split="train", dataset_revision="main", transform_fn=["preference_tokenize_v1"], transform_fn_args={} ), - DatasetConfig( + DatasetConfigV1( dataset_name="allenai/tulu-3-wildchat-reused-on-policy-70b", dataset_split="train", dataset_revision="main", diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index 30ee740da..f6804ed18 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -54,6 +54,7 @@ ) from open_instruct.dataset_processor import CHAT_TEMPLATES +from open_instruct.dataset_transformation import TokenizerConfigV1, get_cached_dataset_tulu from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate from open_instruct.utils import ( ArgumentParserPlus, @@ -382,79 +383,6 @@ def __post_init__(self): raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.") -def encode_sft_example(example, tokenizer, max_seq_length): - """ - This function encodes a single example into a format that can be used for sft training. - Here, we assume each example has a 'messages' field. Each message in it is a dict with 'role' and 'content' fields. - We use the `apply_chat_template` function from the tokenizer to tokenize the messages and prepare the input and label tensors. - """ - messages = example["messages"] - if len(messages) == 0: - raise ValueError("messages field is empty.") - input_ids = tokenizer.apply_chat_template( - conversation=messages, - tokenize=True, - return_tensors="pt", - padding=False, - truncation=True, - max_length=max_seq_length, - add_generation_prompt=False, - ) - labels = input_ids.clone() - # mask the non-assistant part for avoiding loss - for message_idx, message in enumerate(messages): - if message["role"] != "assistant": - # we calculate the start index of this non-assistant message - if message_idx == 0: - message_start_idx = 0 - else: - message_start_idx = tokenizer.apply_chat_template( - conversation=messages[:message_idx], # here marks the end of the previous messages - tokenize=True, - return_tensors="pt", - padding=False, - truncation=True, - max_length=max_seq_length, - add_generation_prompt=False, - ).shape[1] - # next, we calculate the end index of this non-assistant message - if message_idx < len(messages) - 1 and messages[message_idx + 1]["role"] == "assistant": - # for intermediate messages that follow with an assistant message, we need to - # set `add_generation_prompt=True` to avoid the assistant generation prefix being included in the loss - # (e.g., `<|assistant|>`) - message_end_idx = tokenizer.apply_chat_template( - conversation=messages[: message_idx + 1], - tokenize=True, - return_tensors="pt", - padding=False, - truncation=True, - max_length=max_seq_length, - add_generation_prompt=True, - ).shape[1] - else: - # for the last message or the message that doesn't follow with an assistant message, - # we don't need to add the assistant generation prefix - message_end_idx = tokenizer.apply_chat_template( - conversation=messages[: message_idx + 1], - tokenize=True, - return_tensors="pt", - padding=False, - truncation=True, - max_length=max_seq_length, - add_generation_prompt=False, - ).shape[1] - # set the label to -100 for the non-assistant part - labels[:, message_start_idx:message_end_idx] = -100 - if max_seq_length and message_end_idx >= max_seq_length: - break - attention_mask = torch.ones_like(input_ids) - return { - "input_ids": input_ids.flatten(), - "labels": labels.flatten(), - "attention_mask": attention_mask.flatten(), - } - - def main(args: FlatArguments): # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers @@ -515,40 +443,20 @@ def main(args: FlatArguments): accelerator.wait_for_everyone() - if args.dataset_name is not None: - # Downloading and loading a dataset from the hub. - raw_datasets = load_dataset( - args.dataset_name, - args.dataset_config_name, - ) - elif args.dataset_mixer is not None: - # mixing datasets via config - raw_datasets = get_datasets( - args.dataset_mixer, - configs=args.dataset_config_name, - splits=["train"], - save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None, - columns_to_keep=["messages"], - ) - elif args.dataset_mixer_list is not None: - # mixing datasets via config - raw_datasets = get_datasets( - args.dataset_mixer_list, - configs=args.dataset_config_name, - splits=["train"], - save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None, - columns_to_keep=["messages"], - ) - else: - data_files = {} - dataset_args = {} - if args.train_file is not None: - data_files["train"] = args.train_file - raw_datasets = load_dataset( - "json", - data_files=data_files, - **dataset_args, - ) + + tc = TokenizerConfigV1( + model_name_or_path=args.model_name_or_path, + revision=args.model_revision, + use_fast=not args.use_slow_tokenizer, + chat_template_name=args.chat_template_name, + add_bos=args.add_bos, + ) + tokenizer = tc.tokenizer + train_dataset = get_cached_dataset_tulu( + args.dataset_mixer_list, + tc, + args.max_seq_length, + ) # Load pretrained model and tokenizer if args.config_name: @@ -568,34 +476,6 @@ def main(args: FlatArguments): "You are instantiating a new config instance from scratch. This is not supported by this script." ) - tokenizer_revision = args.model_revision if args.tokenizer_revision is None else args.tokenizer_revision - if tokenizer_revision != args.model_revision: - # Warn user if tokenizer and model use different revisions; this is an unusual - # use case. - warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different - from the model revision `{args.model_revision}`.""" - logger.warning(warning) - - if args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer_name, - revision=tokenizer_revision, - trust_remote_code=args.trust_remote_code, - use_fast=not args.use_slow_tokenizer, - ) - elif args.model_name_or_path: - tokenizer = AutoTokenizer.from_pretrained( - args.model_name_or_path, - revision=tokenizer_revision, - trust_remote_code=args.trust_remote_code, - use_fast=not args.use_slow_tokenizer, - ) - else: - raise ValueError( - "You are instantiating a new tokenizer from scratch. This is not supported by this script." - "You can do it from another script, save it, and load it from here, using --tokenizer_name." - ) - if args.model_name_or_path: if args.use_qlora: bnb_config = BitsAndBytesConfig( @@ -632,42 +512,6 @@ def main(args: FlatArguments): logger.info("Training new model from scratch") model = AutoModelForCausalLM.from_config(config) - # no default pad token for llama! - # here we add all special tokens again, because the default ones are not in the special_tokens_map - if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, LlamaTokenizerFast): - num_added_tokens = tokenizer.add_special_tokens( - { - "bos_token": "", - "eos_token": "", - "unk_token": "", - "pad_token": "", - } - ) - assert num_added_tokens in [ - 0, - 1, - ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present." - elif isinstance(tokenizer, GPTNeoXTokenizerFast): - # OLMo newer models use this tokenizer - if tokenizer.bos_token is None: - tokenizer.bos_token = tokenizer.eos_token - assert ( - args.add_bos - ), "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence." - # else, pythia / other models - else: - num_added_tokens = tokenizer.add_special_tokens( - { - "pad_token": "", - } - ) - assert num_added_tokens == 1, "GPTNeoXTokenizer should only add one special token - the pad_token." - elif isinstance(tokenizer, GPT2Tokenizer) and isinstance(model, OPTForCausalLM): - num_added_tokens = tokenizer.add_special_tokens({"unk_token": ""}) - elif isinstance(tokenizer, transformers.PreTrainedTokenizerFast) and tokenizer.pad_token is None: - num_added_tokens = tokenizer.add_special_tokens({"pad_token": ""}) - assert num_added_tokens == 1, "We detected no padding token but add_special_tokens did not add one." - # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch # on a small vocab and want a smaller embedding size, remove this test. # gather deepspeed to get "real" embedding size @@ -683,27 +527,6 @@ def main(args: FlatArguments): with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None): embedding_size = embeddings.weight.shape[0] - # set the tokenizer chat template to the training format - # this will be used for encoding the training examples - # and saved together with the tokenizer to be used later. - if args.chat_template_name in CHAT_TEMPLATES: - tokenizer.chat_template = CHAT_TEMPLATES[args.chat_template_name] - else: - try: - tokenizer.chat_template = AutoTokenizer.from_pretrained(args.chat_template_name).chat_template - except Exception: - raise ValueError(f"Could not find chat template for {args.chat_template_name}.") - - if args.add_bos: - if tokenizer.chat_template.startswith("{{ bos_token }}") or ( - tokenizer.bos_token is not None and tokenizer.chat_template.startswith(tokenizer.bos_token) - ): - raise ValueError( - "You specified add_bos=True, but the chat template already has a bos_token at the beginning." - ) - # also add bos in the chat template if not already there - tokenizer.chat_template = "{{ bos_token }}" + tokenizer.chat_template - if args.use_lora: if args.use_qlora: model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing) @@ -722,31 +545,6 @@ def main(args: FlatArguments): elif args.gradient_checkpointing: model.gradient_checkpointing_enable() - train_dataset = raw_datasets["train"] - # debugging tool for fewer samples - if args.max_train_samples is not None: - max_train_samples = min(len(train_dataset), args.max_train_samples) - logger.info(f"Limiting training samples to {max_train_samples} from {len(train_dataset)}.") - train_dataset = train_dataset.select(range(max_train_samples)) - - with accelerator.main_process_first(): - train_dataset = train_dataset.map( - partial(encode_sft_example, tokenizer=tokenizer, max_seq_length=args.max_seq_length), - batched=False, - num_proc=args.preprocessing_num_workers, - load_from_cache_file=not args.overwrite_cache, - remove_columns=[ - name for name in train_dataset.column_names if name not in ["input_ids", "labels", "attention_mask"] - ], - desc="Tokenizing and reformatting instruction data", - ) - train_dataset.set_format(type="pt") - train_dataset = train_dataset.filter(lambda example: (example["labels"] != -100).any()) - - # Log a few random samples from the training set: - for index in random.sample(range(len(train_dataset)), 3): - logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") - # DataLoaders creation: train_dataloader = DataLoader( train_dataset, From a73d9620bab01d7ee187517e214e7d159146b7e3 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 17 Jan 2025 12:10:26 -0800 Subject: [PATCH 03/32] makesure to shuffle with seed --- open_instruct/finetune.py | 1 + 1 file changed, 1 insertion(+) diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index f6804ed18..da25c0af9 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -457,6 +457,7 @@ def main(args: FlatArguments): tc, args.max_seq_length, ) + train_dataset.shuffle(seed=args.seed) # Load pretrained model and tokenizer if args.config_name: From f4021c04a82c2ade91ffbc62461c41a101f2a455 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 17 Jan 2025 12:24:53 -0800 Subject: [PATCH 04/32] quick fix --- open_instruct/dataset_transformation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index fe7fadb07..5e8404b48 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -542,6 +542,10 @@ def get_cached_dataset_tulu(dataset_mixer_list: List[str], tc: TokenizerConfigV1 for i in range(0, len(dataset_mixer_list), 2): dataset_name = dataset_mixer_list[i] frac_or_num_samples = dataset_mixer_list[i + 1] + if "." in frac_or_num_samples: + frac_or_num_samples = float(frac_or_num_samples) + else: + frac_or_num_samples = int(frac_or_num_samples) dataset_config = DatasetConfigV1( dataset_name=dataset_name, From c8b6a62a7ea15f5f856a9c10b7c05ee975e79197 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 17 Jan 2025 15:12:02 -0800 Subject: [PATCH 05/32] add finetune1 --- open_instruct/dataset_transformation.py | 12 +- open_instruct/finetune1.py | 915 ++++++++++++++++++++++++ 2 files changed, 922 insertions(+), 5 deletions(-) create mode 100644 open_instruct/finetune1.py diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 5e8404b48..ce9c642f3 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -25,6 +25,8 @@ # * `max_seq_length` in SFT # * `max_length`, `max_target_length` in RM / DPO, # * `max_prompt_length` in DPO + +# TODO: note that tokenizer doesn't change but model name does change. Should be mindful of this. """ This file contains the utility to transform and cache datasets with different configurations. The main things we are looking for are: @@ -501,11 +503,11 @@ def load_or_transform_dataset(self, dcs: List[DatasetConfigV1], tc: TokenizerCon transformed_datasets = [] for dc in dcs: dataset = get_dataset_v1(dc, tc) - # Add id column if not present - if "id" not in dataset.column_names: - base_name = dc.dataset_name.split("/")[-1] - id_col = [f"{base_name}-{i}" for i in range(len(dataset))] - dataset = dataset.add_column("id", id_col) + # # Add id column if not present + # if "id" not in dataset.column_names: + # base_name = dc.dataset_name.split("/")[-1] + # id_col = [f"{base_name}-{i}" for i in range(len(dataset))] + # dataset = dataset.add_column("id", id_col) transformed_datasets.append(dataset) # Combine datasets diff --git a/open_instruct/finetune1.py b/open_instruct/finetune1.py new file mode 100644 index 000000000..9fd85f506 --- /dev/null +++ b/open_instruct/finetune1.py @@ -0,0 +1,915 @@ +# !/usr/bin/env python +# coding=utf-8 +# Copyright 2024 AllenAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import math +import os +import random +import shutil +import subprocess +import time +from dataclasses import dataclass, field +from datetime import timedelta +from functools import partial +from typing import List, Optional, Union + +import datasets +import deepspeed +import torch +import transformers +from accelerate import Accelerator, DataLoaderConfiguration +from accelerate.logging import get_logger +from accelerate.utils import InitProcessGroupKwargs, set_seed +from datasets import load_dataset +from huggingface_hub import HfApi +from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + DataCollatorForSeq2Seq, + GPT2Tokenizer, + GPTNeoXTokenizerFast, + LlamaTokenizer, + LlamaTokenizerFast, + OPTForCausalLM, + get_scheduler, +) + +from open_instruct.dataset_processor import CHAT_TEMPLATES +from open_instruct.dataset_transformation import TokenizerConfigV1, get_cached_dataset_tulu +from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate +from open_instruct.utils import ( + ArgumentParserPlus, + clean_last_n_checkpoints, + get_datasets, + get_last_checkpoint_path, + get_wandb_tags, + is_beaker_job, + maybe_get_beaker_config, + maybe_use_ai2_hf_entity, + maybe_use_ai2_wandb_entity, + upload_metadata_to_hf, +) + +logger = get_logger(__name__) + + +@dataclass +class FlatArguments: + """ + Full arguments class for all fine-tuning jobs. + """ + + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """The name of this experiment""" + run_name: Optional[str] = None + """A unique name of this run""" + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch." + ) + }, + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + tokenizer_revision: Optional[str] = field( + default=None, + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + chat_template_name: str = field( + default="tulu", + metadata={ + "help": ( + f"The name of the chat template to use. " + f"You can choose one of our pre-defined templates: {', '.join(CHAT_TEMPLATES.keys())}." + f"Or, you can provide a tokenizer name or path here and we will apply its chat template." + ) + }, + ) + use_flash_attn: bool = field( + default=True, + metadata={"help": "Whether to use flash attention in the model training"}, + ) + use_slow_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the slow tokenizer or not (which is then fast tokenizer)."}, + ) + model_revision: Optional[str] = field( + default=None, + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + trust_remote_code: bool = field( + default=False, + metadata={ + "help": ( + "Whether or not to allow for custom models defined on the Hub in their own modeling files. " + "This option should only be set to `True` for repositories you trust and in which you " + "have read the code, as it will execute code present on the Hub on your local machine." + ) + }, + ) + low_cpu_mem_usage: bool = field( + default=False, + metadata={ + "help": ( + "It is an option to create the model as an empty shell, " + "then only materialize its parameters when the pretrained weights are loaded. " + "set True will benefit LLM loading time and RAM consumption." + ) + }, + ) + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_mixer: Optional[dict] = field( + default=None, metadata={"help": "A dictionary of datasets (local or HF) to sample from."} + ) + dataset_mixer_list: Optional[list[str]] = field( + default=None, metadata={"help": "A list of datasets (local or HF) to sample from."} + ) + dataset_mix_dir: Optional[str] = field( + default=None, metadata={"help": "The directory to save the mixed dataset to disk."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field( + default=None, metadata={"help": "The input training data file (a json/jsonl file)."} + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ) + }, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_seq_length: Optional[int] = field( + default=None, + metadata={ + "help": ( + "The maximum total input sequence length after tokenization. " + "Sequences longer than this will be truncated," + ) + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + add_bos: bool = field( + default=False, + metadata={ + "help": "Forcibly add bos token to the beginning of the input sequence." + " Use only when tokenizer does not add bos token by default." + }, + ) + clip_grad_norm: float = field( + default=-1, + metadata={"help": "Clip gradient norm. Not compatible with deepspeed (use deepspeed config instead)."}, + ) + gradient_accumulation_steps: int = field( + default=1, + metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, + ) + learning_rate: float = field( + default=2e-5, + metadata={"help": "The initial learning rate for AdamW optimizer."}, + ) + logging_steps: Optional[int] = field( + default=None, + metadata={"help": "Log the training loss and learning rate every logging_steps steps."}, + ) + lora_rank: int = field( + default=64, + metadata={"help": "The rank of lora."}, + ) + lora_alpha: float = field( + default=16, + metadata={"help": "The alpha parameter of lora."}, + ) + lora_dropout: float = field( + default=0.1, + metadata={"help": "The dropout rate of lora modules."}, + ) + lr_scheduler_type: str = field( + default="linear", + metadata={ + "help": "The scheduler type to use for learning rate adjustment.", + "choices": ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + }, + ) + num_train_epochs: int = field( + default=2, + metadata={"help": "Total number of training epochs to perform."}, + ) + output_dir: str = field( + default="output/", + metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, + ) + per_device_train_batch_size: int = field( + default=8, + metadata={"help": "Batch size per GPU/TPU core/CPU for training."}, + ) + use_lora: bool = field( + default=False, + metadata={"help": "If True, will use LORA (low-rank parameter-efficient training) to train the model."}, + ) + use_qlora: bool = field( + default=False, + metadata={"help": "Use qLoRA training - initializes model in quantized form. Not compatible with deepspeed."}, + ) + use_8bit_optimizer: bool = field( + default=False, + metadata={"help": "Use 8bit optimizer from bitsandbytes. Not compatible with deepspeed."}, + ) + warmup_ratio: float = field( + default=0.03, + metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}, + ) + weight_decay: float = field( + default=0.0, + metadata={"help": "Weight decay for AdamW if we apply some."}, + ) + timeout: int = field( + default=1800, + metadata={ + "help": "Timeout for the training process in seconds." + "Useful if tokenization process is long. Default is 1800 seconds (30 minutes)." + }, + ) + reduce_loss: str = field( + default="mean", + metadata={ + "help": "How to reduce loss over tokens. Options are 'mean' or 'sum'." + "Using 'sum' can improve chat model performance." + }, + ) + wandb_entity: Optional[str] = field( + default=None, + metadata={"help": "Entity to use for logging to wandb."}, + ) + resume_from_checkpoint: Optional[str] = field( + default=None, + metadata={"help": "If the training should continue from a checkpoint folder."}, + ) + with_tracking: bool = field( + default=False, + metadata={"help": "Whether to enable experiment trackers for logging."}, + ) + report_to: Union[str, List[str]] = field( + default="all", + metadata={ + "help": "The integration(s) to report results and logs to. " + "Can be a single string or a list of strings. " + "Options are 'tensorboard', 'wandb', 'comet_ml', 'clearml', or 'all'. " + "Specify multiple by listing them: e.g., ['tensorboard', 'wandb']" + }, + ) + save_to_hub: Optional[str] = field( + default=None, + metadata={"help": "Save the model to the Hub under this name. E.g allenai/your-model"}, + ) + gradient_checkpointing: bool = field( + default=False, + metadata={"help": "Turn on gradient checkpointing. Saves memory but slows training."}, + ) + max_train_steps: Optional[int] = field( + default=None, + metadata={"help": "If set, overrides the number of training steps. Otherwise, num_train_epochs is used."}, + ) + seed: int = field(default=42, metadata={"help": "Random seed for initialization and dataset shuffling."}) + checkpointing_steps: Optional[str] = field( + default=None, + metadata={ + "help": "Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch." # noqa + }, + ) + overwrite_output_dir: bool = field( + default=False, + metadata={ + "help": "Overwrite the content of the output directory. Means that resumption will always start from scratch." + }, + ) + keep_last_n_checkpoints: int = field( + default=3, + metadata={"help": "How many checkpoints to keep in the output directory. -1 for all."}, + ) + fused_optimizer: bool = field( + default=True, + metadata={ + "help": "Whether to use fused AdamW or not.", + }, + ) + load_balancing_loss: bool = field( + default=False, + metadata={ + "help": "Whether to include a load balancing loss (for OLMoE) or not.", + }, + ) + load_balancing_weight: float = field( + default=0.5, + metadata={"help": "Weight for load balancing loss if applicable."}, + ) + try_auto_save_to_beaker: bool = True + """Whether to try to save the model to Beaker dataset `/output` after training""" + push_to_hub: bool = True + """Whether to upload the saved model to huggingface""" + hf_entity: Optional[str] = None + """The user or org name of the model repository from the Hugging Face Hub""" + hf_repo_id: Optional[str] = None + """The id of the saved model in the Hugging Face Hub (can be autoset if not given)""" + hf_repo_revision: Optional[str] = None + """The revision of the saved model in the Hugging Face Hub (can be autoset if not given)""" + hf_repo_url: Optional[str] = None + """The url of the saved model in the Hugging Face Hub (will be autoset)""" + try_launch_beaker_eval_jobs: bool = True + """Whether to launch beaker evaluation jobs after training""" + hf_metadata_dataset: Optional[str] = "allenai/tulu-3-evals" + """What dataset to upload the metadata to. If unset, don't upload metadata""" + + def __post_init__(self): + if self.reduce_loss not in ["mean", "sum"]: + raise ValueError("reduce_loss must be either 'mean' or 'sum'") + if ( + self.dataset_name is None + and self.train_file is None + and self.dataset_mixer is None + and self.dataset_mixer_list is None + ): + raise ValueError("Need either a dataset name, dataset mixer, or a training file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["json", "jsonl"], "`train_file` should be a json or a jsonl file." + if ( + (self.dataset_name is not None and (self.dataset_mixer is not None or self.dataset_mixer_list is not None)) + or (self.dataset_name is not None and self.train_file is not None) + or ( + (self.dataset_mixer is not None or self.dataset_mixer_list is not None) and self.train_file is not None + ) + or (self.dataset_mixer is not None and self.dataset_mixer_list is not None) + ): + raise ValueError("Cannot provide two dataset selection mechanisms.") + if self.try_launch_beaker_eval_jobs and not self.push_to_hub: + raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.") + + +def main(args: FlatArguments): + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. + # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers + # in the environment + args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + if args.push_to_hub: + if args.hf_repo_id is None: # auto-generate one + args.hf_repo_id = "open_instruct_dev" + if args.hf_entity is None: # first try to use AI2 entity + args.hf_entity = maybe_use_ai2_hf_entity() + if args.hf_entity is None: # then try to use the user's entity + args.hf_entity = HfApi().whoami()["name"] + args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}" + if args.hf_repo_revision is None: + args.hf_repo_revision = args.run_name + args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}" + + if is_beaker_job(): + beaker_config = maybe_get_beaker_config() + + accelerator_log_kwargs = {} + + if args.with_tracking: + accelerator_log_kwargs["log_with"] = args.report_to + accelerator_log_kwargs["project_dir"] = args.output_dir + + # if you get timeouts (e.g. due to long tokenization) increase this. + timeout_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.timeout)) + dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + dataloader_config=dataloader_config, + **accelerator_log_kwargs, + kwargs_handlers=[timeout_kwargs], + ) + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + accelerator.wait_for_everyone() + + + tc = TokenizerConfigV1( + model_name_or_path=args.model_name_or_path, + revision=args.model_revision, + use_fast=not args.use_slow_tokenizer, + chat_template_name=args.chat_template_name, + add_bos=args.add_bos, + ) + tokenizer = tc.tokenizer + train_dataset = get_cached_dataset_tulu( + args.dataset_mixer_list, + tc, + args.max_seq_length, + ) + train_dataset.shuffle(seed=args.seed) + train_dataset.set_format(type="pt") + + # Load pretrained model and tokenizer + if args.config_name: + config = AutoConfig.from_pretrained( + args.config_name, + revision=args.model_revision, + trust_remote_code=args.trust_remote_code, + ) + elif args.model_name_or_path: + config = AutoConfig.from_pretrained( + args.model_name_or_path, + revision=args.model_revision, + trust_remote_code=args.trust_remote_code, + ) + else: + raise ValueError( + "You are instantiating a new config instance from scratch. This is not supported by this script." + ) + + if args.model_name_or_path: + if args.use_qlora: + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + ) + device_index = accelerator.local_process_index + device_map = {"": device_index} # force data-parallel training. + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + revision=args.model_revision, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + trust_remote_code=args.trust_remote_code, + quantization_config=bnb_config, + device_map=device_map, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2" if args.use_flash_attn else "eager", + ) + else: + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + revision=args.model_revision, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + trust_remote_code=args.trust_remote_code, + low_cpu_mem_usage=args.low_cpu_mem_usage, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2" if args.use_flash_attn else "eager", + ) + else: + logger.info("Training new model from scratch") + model = AutoModelForCausalLM.from_config(config) + + # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch + # on a small vocab and want a smaller embedding size, remove this test. + # gather deepspeed to get "real" embedding size + embeddings = model.get_input_embeddings() + with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None): + embedding_size = embeddings.weight.shape[0] + # resize does its own gather + if len(tokenizer) > embedding_size: + # pad to multiple for tensor cores. + model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8) + # update embedding size after resizing for sum loss + embeddings = model.get_input_embeddings() + with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None): + embedding_size = embeddings.weight.shape[0] + + if args.use_lora: + if args.use_qlora: + model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing) + + logger.info("Initializing LORA model...") + peft_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + target_modules=["q_proj", "o_proj", "v_proj", "k_proj", "gate_proj", "up_proj", "down_proj"], + ) + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() + elif args.gradient_checkpointing: + model.gradient_checkpointing_enable() + + # DataLoaders creation: + train_dataloader = DataLoader( + train_dataset, + shuffle=True, + collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest"), + batch_size=args.per_device_train_batch_size, + ) + + # Optimizer + # Split weights in two groups, one with weight decay and the other not. + no_decay = ["bias", "layer_norm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + if args.use_qlora: + from bitsandbytes.optim import AdamW + + optimizer = AdamW( + optimizer_grouped_parameters, + lr=args.learning_rate, + optim_bits=8 if args.use_8bit_optimizer else 32, + is_paged=True, + ) + else: + optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, fused=args.fused_optimizer) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + # Create the learning rate scheduler. + # Note: the current accelerator.step() calls the .step() of the real scheduler + # for the `num_processes` times. This is because they assume + # the user initialize the scheduler with the entire training set. + # In the case of data parallel training, each process only + # sees a subset (1/num_processes) of the training set. + # So each time the process needs to update the lr multiple times so that the total + # number of updates in the end matches the num_training_steps here. + # Here we need to set the num_training_steps to either using the + # entire training set (when epochs is specified) or we need to multiply the + # num_training_steps by num_processes so that the total number of + # updates matches the num_training_steps. + num_training_steps_for_scheduler = ( + args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes + ) + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_training_steps=num_training_steps_for_scheduler, + num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio), + ) + # Prepare everything with `accelerator`. + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Figure out how many steps we should save the Accelerator states + checkpointing_steps = args.checkpointing_steps + if checkpointing_steps is not None and str(checkpointing_steps).lower() != "epoch": + checkpointing_steps = int(checkpointing_steps) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if args.with_tracking: + experiment_config = vars(args) + # TensorBoard cannot log Enums, need the raw value + experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"] + + # (Optional) Ai2 internal tracking + if args.wandb_entity is None: + args.wandb_entity = maybe_use_ai2_wandb_entity() + if is_beaker_job(): + experiment_config.update(vars(beaker_config)) + accelerator.init_trackers( + "open_instruct_internal", + experiment_config, + init_kwargs={ + "wandb": { + "name": args.run_name, + "entity": args.wandb_entity, + "tags": [args.exp_name] + get_wandb_tags(), + } + }, + ) + wandb_tracker = accelerator.get_tracker("wandb") + + # Train! + total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + completed_steps = 0 + starting_epoch = 0 + + # Potentially load in the weights and states from a previous save + last_checkpoint_path = get_last_checkpoint_path(args) + if last_checkpoint_path: + accelerator.print(f"Resumed from checkpoint: {last_checkpoint_path}") + accelerator.load_state(last_checkpoint_path) + # Extract `epoch_{i}` or `step_{i}` + last_checkpoint_path = os.path.basename(last_checkpoint_path) + training_difference = os.path.splitext(last_checkpoint_path)[0] + + if "epoch" in training_difference: + starting_epoch = int(training_difference.replace("epoch_", "")) + 1 + resume_step = None + completed_steps = starting_epoch * num_update_steps_per_epoch + else: + # need to multiply `gradient_accumulation_steps` to reflect real steps + resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps + starting_epoch = resume_step // len(train_dataloader) + completed_steps = resume_step // args.gradient_accumulation_steps + resume_step -= starting_epoch * len(train_dataloader) + + print(f"Starting from epoch {starting_epoch} and step {completed_steps}.") + # update the progress_bar if load from checkpoint + progress_bar.update(completed_steps) + local_total_tokens = torch.tensor(0, dtype=torch.int64, device=accelerator.device) + total_token_including_padding = torch.tensor(0, dtype=torch.int64, device=accelerator.device) + start_time = time.time() + for epoch in range(starting_epoch, args.num_train_epochs): + model.train() + train_dataloader.set_epoch(epoch) + total_loss = 0 + total_aux_loss = 0 + if last_checkpoint_path and resume_step is not None: + # We skip the first `n` batches in the dataloader when resuming from a checkpoint + active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) + else: + active_dataloader = train_dataloader + for step, batch in enumerate(active_dataloader): + local_total_tokens += batch["attention_mask"].sum() + total_token_including_padding += batch["attention_mask"].numel() + with accelerator.accumulate(model): + if args.load_balancing_loss: + outputs = model(**batch, use_cache=False, output_router_logits=True) + else: + outputs = model(**batch, use_cache=False) + if args.reduce_loss == "mean": + loss = outputs.loss + else: + # reduce loss is sum + # this ensures that we weight all tokens in the dataset equally, + # rather than weighting each overall example equally when + # using high amounts of gradient accumulation. + # this can result in > 5 point improvements in AlpacaEval + # see https://github.com/huggingface/transformers/issues/24725 for + # more discussion and details. + logits = outputs.logits + labels = batch["labels"] + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = torch.nn.CrossEntropyLoss(reduction="sum") + shift_logits = shift_logits.view(-1, embedding_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + if args.load_balancing_loss: + aux_loss = args.load_balancing_weight * outputs.aux_loss + loss += aux_loss + # We keep track of the loss at each logged step + total_loss += loss.detach().float() + accelerator.backward(loss) + if args.load_balancing_loss: + total_aux_loss += aux_loss.detach().float() + # clip gradient norm. don't do this with deepspeed + if accelerator.sync_gradients and args.clip_grad_norm > 0: + accelerator.clip_grad_norm_(model.parameters(), args.clip_grad_norm) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + completed_steps += 1 + if args.logging_steps and completed_steps % args.logging_steps == 0: + avg_loss = ( + accelerator.gather(total_loss).mean().item() + / args.gradient_accumulation_steps + / args.logging_steps + ) + total_tokens = accelerator.gather(local_total_tokens).sum().item() + total_tokens_including_padding = accelerator.gather(total_token_including_padding).sum().item() + metrics_to_log = { + "learning_rate": lr_scheduler.get_last_lr()[0], + "train_loss": avg_loss, + "total_tokens": total_tokens, + "per_device_tps": total_tokens / accelerator.num_processes / (time.time() - start_time), + "total_tokens_including_padding": total_tokens_including_padding, + "per_device_tps_including_padding": total_tokens_including_padding + / accelerator.num_processes + / (time.time() - start_time), + } + if args.load_balancing_loss: + avg_aux_loss = ( + accelerator.gather(total_aux_loss).mean().item() + / args.gradient_accumulation_steps + / args.logging_steps + ) + logger.info( + f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}, Aux Loss: {avg_aux_loss}, TPS: {total_tokens / (time.time() - start_time)}" + ) + metrics_to_log["aux_loss"] = avg_aux_loss + else: + logger.info( + f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}, TPS: {total_tokens / (time.time() - start_time)}" + ) + if args.with_tracking: + accelerator.log( + metrics_to_log, + step=completed_steps, + ) + total_loss = 0 + total_aux_loss = 0 + + if isinstance(checkpointing_steps, int): + if completed_steps % checkpointing_steps == 0: + output_dir = f"step_{completed_steps}" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + # use this to mark the checkpoint as completely saved, to avoid restoring from garbled checkpoints + with open( + os.path.join(get_last_checkpoint_path(args, incomplete=True), "COMPLETED"), "w" + ) as f: + f.write("COMPLETED") # annoyingly, empty files arent uploaded by beaker. + if accelerator.is_local_main_process: + clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints) + accelerator.wait_for_everyone() + + if completed_steps >= args.max_train_steps: + break + + if checkpointing_steps == "epoch": + output_dir = f"epoch_{epoch}" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + # use this to mark the checkpoint as completely saved, to avoid restoring from garbled checkpoints + with open(os.path.join(get_last_checkpoint_path(args, incomplete=True), "COMPLETED"), "w") as f: + f.write("COMPLETED") # annoyingly, empty files arent uploaded by beaker. + if accelerator.is_local_main_process: + clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints) + accelerator.wait_for_everyone() + + if args.output_dir is not None: + save_with_accelerate( + accelerator, + model, + tokenizer, + args.output_dir, + args.use_lora, + ) + + # remove all checkpoints to save space + if accelerator.is_local_main_process: + clean_last_n_checkpoints(args.output_dir, keep_last_n_checkpoints=0) + + if ( + args.try_auto_save_to_beaker + and accelerator.is_main_process + and len(beaker_config.beaker_dataset_id_urls) > 0 + and args.output_dir.rstrip("/") != "/output" + ): + shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True) + + if is_beaker_job() and accelerator.is_main_process: + # dpo script only supports these two options right now for datasets + if args.dataset_mixer: + dataset_list = list(args.dataset_mixer.keys()) + elif args.dataset_mixer_list: + dataset_list = args.dataset_mixer_list[::2] # even indices + elif args.dataset_name: + dataset_list = [args.dataset_name] + else: + dataset_list = [args.train_file] + # mainly just focussing here on what would be useful for the leaderboard. + # wandb will have even more useful information. + metadata_blob = { + "model_name": args.exp_name, + "model_type": "sft", + "datasets": dataset_list, + "base_model": args.model_name_or_path, + "wandb_path": wandb_tracker.run.get_url(), + "beaker_experiment": beaker_config.beaker_experiment_url, + "beaker_datasets": beaker_config.beaker_dataset_id_urls, + } + # save metadata to the output directory. then it should also get pushed to HF. + with open(os.path.join(args.output_dir, "metadata.json"), "w") as f: + json.dump(metadata_blob, f) + + # upload metadata to the dataset if set + if args.hf_metadata_dataset: + upload_metadata_to_hf( + metadata_blob, + "metadata.json", + args.hf_metadata_dataset, + "results/" + args.run_name, # to match what the auto-evals name as. + ) + + if args.try_launch_beaker_eval_jobs: + command = f"""\ + python mason.py \ + --cluster ai2/ganymede-cirrascale ai2/ceres-cirrascale ai2/neptune-cirrascale ai2/saturn-cirrascale ai2/jupiter-cirrascale-2 \ + --priority low \ + --preemptible \ + --budget ai2/allennlp \ + --workspace ai2/tulu-2-improvements \ + --image nathanl/open_instruct_auto \ + --pure_docker_mode \ + --gpus 0 -- python scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py \ + --beaker_workload_id {beaker_config.beaker_workload_id} \ + --upload_to_hf {args.hf_metadata_dataset} \ + --model_name {args.run_name} \ + --run_id {wandb_tracker.run.get_url()} + """ + process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + print(f"Submit jobs after model training is finished - Stdout:\n{stdout.decode()}") + print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") + print(f"Submit jobs after model training is finished - process return code: {process.returncode}") + + if args.push_to_hub: + push_folder_to_hub( + accelerator, + args.output_dir, + args.hf_repo_id, + args.hf_repo_revision, + ) + accelerator.wait_for_everyone() + if args.with_tracking: + accelerator.end_training() + + +if __name__ == "__main__": + parser = ArgumentParserPlus((FlatArguments)) + args = parser.parse() + main(args) From 15ba99940b1bac252da9234a7852c4c0cc4bb201 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 06:34:13 -0800 Subject: [PATCH 06/32] support preference tuning as well --- open_instruct/dataset_transformation.py | 167 +++- open_instruct/dpo_tune_cache1.py | 1087 +++++++++++++++++++++++ open_instruct/finetune1.py | 18 +- 3 files changed, 1218 insertions(+), 54 deletions(-) create mode 100644 open_instruct/dpo_tune_cache1.py diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index ce9c642f3..2acb4ac0d 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -5,7 +5,6 @@ # 3. ✅ RM / DPO (chosen and rejected) # 4. ✅ Visualization of length distributions? # 5. ✅ Filter? -# * Smart truncation? # 6. ✅ dataset_num_proc # 7. ✅ check EOS token # 8. dataset mixer? @@ -49,6 +48,8 @@ import os from typing import Any, Dict, List, Optional import torch +from rich.console import Console +from rich.text import Text from transformers import ( AutoTokenizer, PreTrainedTokenizer, @@ -79,6 +80,17 @@ def get_num_proc(dataset_len: int, num_available_cpus: int, example_per_second_p num_required_cpus = max(1, dataset_len // example_per_second_per_cpu) return min(num_required_cpus, num_available_cpus) +COLORS = ["on red", "on green", "on blue", "on yellow", "on magenta"] + +def visualize_token(tokens: list[int], tokenizer: PreTrainedTokenizer): + i = 0 + console = Console() + rich_text = Text() + for i, token in enumerate(tokens): + color = COLORS[i % len(COLORS)] + decoded_token = tokenizer.decode(token) + rich_text.append(f"{decoded_token}", style=color) + console.print(rich_text) # ---------------------------------------------------------------------------- # Tokenization @@ -95,7 +107,7 @@ class TokenizerConfigV1: tokenizer_commit_hash: Optional[str] = None def __post_init__(self): - self.tokenizer_commit_hash = get_commit_hash(self.model_name_or_path, self.revision) + self.tokenizer_commit_hash = get_commit_hash(self.model_name_or_path, self.revision, filename="tokenizer_config.json") self.tokenizer = AutoTokenizer.from_pretrained( self.model_name_or_path, @@ -182,31 +194,6 @@ def __post_init__(self): # ---------------------------------------------------------------------------- # Dataset Transformation -COLORS = ["on red", "on green", "on blue", "on yellow", "on magenta"] -# Preference dataset -INPUT_IDS_CHOSEN_KEY = "input_ids_chosen" -ATTENTION_MASK_CHOSEN_KEY = "attention_mask_chosen" -INPUT_IDS_REJECTED_KEY = "input_ids_rejected" -ATTENTION_MASK_REJECTED_KEY = "attention_mask_rejected" -INPUT_IDS_PROMPT_KEY = "input_ids_prompt" -ATTENTION_MASK_PROMPT_KEY = "attention_mask_prompt" -GROUND_TRUTHS_KEY = "ground_truth" -DATASET_SOURCE_KEY = "dataset" - -# NOTE (Costa): the `INPUT_IDS_PROMPT_KEY` is just for visualization purposes only -# also we don't really need `ATTENTION_MASK_CHOSEN_KEY` and `ATTENTION_MASK_REJECTED_KEY` -# since we are always padding from the right with a collator; however they might become -# more useful if we want to do some sort of packing in the future. The nice thing is -# that the tokenization logic would work for both DPO and RM training. -TOKENIZED_PREFERENCE_DATASET_KEYS = [ - INPUT_IDS_CHOSEN_KEY, - INPUT_IDS_REJECTED_KEY, - # ATTENTION_MASK_CHOSEN_KEY, - # ATTENTION_MASK_REJECTED_KEY, - # INPUT_IDS_PROMPT_KEY, - # ATTENTION_MASK_PROMPT_KEY, -] - # SFT dataset SFT_MESSAGE_KEY = "messages" INPUT_IDS_KEY = "input_ids" @@ -218,15 +205,38 @@ def __post_init__(self): LABELS_KEY, ] -# Binary dataset -BINARY_LABEL_KEY = "binary_labels" -BINARY_DATASET_KEYS = [ - INPUT_IDS_KEY, - LABELS_KEY, - BINARY_LABEL_KEY, +# Preference dataset +# NOTE (Costa): the `INPUT_IDS_PROMPT_KEY` is just for visualization purposes only +# also we don't really need `CHOSEN_ATTENTION_MASK_KEY` and `REJECTED_ATTENTION_MASK_KEY` +# since we are always padding from the right with a collator; however they might become +# more useful if we want to do some sort of packing in the future. The nice thing is +# that the tokenization logic would work for both DPO and RM training. +CHOSEN_KEY = "chosen" +REJECTED_KEY = "rejected" +CHOSEN_INPUT_IDS_KEY = "chosen_input_ids" +CHOSEN_ATTENTION_MASK_KEY = "chosen_attention_mask" +CHOSEN_LABELS_KEY = "chosen_labels" +REJECTED_INPUT_IDS_KEY = "rejected_input_ids" +REJECTED_ATTENTION_MASK_KEY = "rejected_attention_mask" +REJECTED_LABELS_KEY = "rejected_labels" + +INPUT_IDS_PROMPT_KEY = "input_ids_prompt" +ATTENTION_MASK_PROMPT_KEY = "attention_mask_prompt" +GROUND_TRUTHS_KEY = "ground_truth" +DATASET_SOURCE_KEY = "dataset" + +TOKENIZED_PREFERENCE_DATASET_KEYS = [ + CHOSEN_INPUT_IDS_KEY, + CHOSEN_LABELS_KEY, + CHOSEN_ATTENTION_MASK_KEY, + REJECTED_INPUT_IDS_KEY, + REJECTED_LABELS_KEY, + REJECTED_ATTENTION_MASK_KEY, ] + +# TODO: allow passing in sft_message key, so we can train on "chosen" of pref dataset. def sft_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer): if len(row[SFT_MESSAGE_KEY]) == 1: prompt = row[SFT_MESSAGE_KEY] @@ -346,7 +356,7 @@ def sft_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: PreTrained def sft_tulu_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer): - return row["labels"] != -100 + return any(x != -100 for x in row[LABELS_KEY]) def preference_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer): # Extract prompt (all messages except the last one) @@ -360,12 +370,12 @@ def preference_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer): row[ATTENTION_MASK_PROMPT_KEY] = [1] * len(row[INPUT_IDS_PROMPT_KEY]) # Tokenize chosen completion - row[INPUT_IDS_CHOSEN_KEY] = tokenizer.apply_chat_template(row["chosen"]) - row[ATTENTION_MASK_CHOSEN_KEY] = [1] * len(row[INPUT_IDS_CHOSEN_KEY]) + row[CHOSEN_INPUT_IDS_KEY] = tokenizer.apply_chat_template(row["chosen"]) + row[CHOSEN_ATTENTION_MASK_KEY] = [1] * len(row[CHOSEN_INPUT_IDS_KEY]) # Tokenize rejected completion - row[INPUT_IDS_REJECTED_KEY] = tokenizer.apply_chat_template(row["rejected"]) - row[ATTENTION_MASK_REJECTED_KEY] = [1] * len(row[INPUT_IDS_REJECTED_KEY]) + row[REJECTED_INPUT_IDS_KEY] = tokenizer.apply_chat_template(row["rejected"]) + row[REJECTED_ATTENTION_MASK_KEY] = [1] * len(row[REJECTED_INPUT_IDS_KEY]) return row @@ -377,13 +387,44 @@ def preference_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, ma # Check total sequence lengths if specified if max_token_length is not None: - if len(row[INPUT_IDS_CHOSEN_KEY]) > max_token_length: + if len(row[CHOSEN_INPUT_IDS_KEY]) > max_token_length: return False - if len(row[INPUT_IDS_REJECTED_KEY]) > max_token_length: + if len(row[REJECTED_INPUT_IDS_KEY]) > max_token_length: return False return True + +def preference_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_seq_length: int): + """ + Here we assume each example has a rejected and chosen field, both of which are a list of messages. + Each message is a dict with 'role' and 'content' fields. + We assume only the last message is different, and the prompt is contained in the list of messages. + """ + chosen_messages = row[CHOSEN_KEY] + rejected_messages = row[REJECTED_KEY] + if len(chosen_messages) == 0: + raise ValueError("chosen messages field is empty.") + if len(rejected_messages) == 0: + raise ValueError("rejected messages field is empty.") + + chosen_encoded = sft_tulu_tokenize_and_truncate_v1({SFT_MESSAGE_KEY: chosen_messages}, tokenizer, max_seq_length) + rejected_encoded = sft_tulu_tokenize_and_truncate_v1({SFT_MESSAGE_KEY: rejected_messages}, tokenizer, max_seq_length) + + return { + CHOSEN_INPUT_IDS_KEY: chosen_encoded["input_ids"], + CHOSEN_LABELS_KEY: chosen_encoded["labels"], + CHOSEN_ATTENTION_MASK_KEY: chosen_encoded["attention_mask"], + REJECTED_INPUT_IDS_KEY: rejected_encoded["input_ids"], + REJECTED_LABELS_KEY: rejected_encoded["labels"], + REJECTED_ATTENTION_MASK_KEY: rejected_encoded["attention_mask"], + } + + +def preference_tulu_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer): + return any(x != -100 for x in row[CHOSEN_LABELS_KEY]) and any(x != -100 for x in row[REJECTED_LABELS_KEY]) + + TRANSFORM_FNS = { "sft_tokenize_v1": (sft_tokenize_v1, "map"), "sft_tokenize_mask_out_prompt_v1": (sft_tokenize_mask_out_prompt_v1, "map"), @@ -392,6 +433,8 @@ def preference_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, ma "sft_tulu_filter_v1": (sft_tulu_filter_v1, "filter"), "preference_tokenize_v1": (preference_tokenize_v1, "map", TOKENIZED_PREFERENCE_DATASET_KEYS), "preference_filter_v1": (preference_filter_v1, "filter"), + "preference_tulu_tokenize_and_truncate_v1": (preference_tulu_tokenize_and_truncate_v1, "map"), + "preference_tulu_filter_v1": (preference_tulu_filter_v1, "filter"), } # ---------------------------------------------------------------------------- @@ -503,11 +546,6 @@ def load_or_transform_dataset(self, dcs: List[DatasetConfigV1], tc: TokenizerCon transformed_datasets = [] for dc in dcs: dataset = get_dataset_v1(dc, tc) - # # Add id column if not present - # if "id" not in dataset.column_names: - # base_name = dc.dataset_name.split("/")[-1] - # id_col = [f"{base_name}-{i}" for i in range(len(dataset))] - # dataset = dataset.add_column("id", id_col) transformed_datasets.append(dataset) # Combine datasets @@ -532,13 +570,10 @@ def load_or_transform_dataset(self, dcs: List[DatasetConfigV1], tc: TokenizerCon def get_cached_dataset(dcs: List[DatasetConfigV1], tc: TokenizerConfigV1, hf_entity: Optional[str] = None) -> Dataset: - """Get transformed and cached dataset from multiple dataset configs.""" cache = DatasetTransformationCache(hf_entity=hf_entity) return cache.load_or_transform_dataset(dcs, tc) -# TODO: need to support length / range of the dataset -def get_cached_dataset_tulu(dataset_mixer_list: List[str], tc: TokenizerConfigV1, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: - """Get transformed and cached dataset from multiple dataset configs.""" +def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConfigV1, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: dcs = [] assert len(dataset_mixer_list) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer_list}" for i in range(0, len(dataset_mixer_list), 2): @@ -570,6 +605,40 @@ def get_cached_dataset_tulu(dataset_mixer_list: List[str], tc: TokenizerConfigV1 cache = DatasetTransformationCache(hf_entity=hf_entity) return cache.load_or_transform_dataset(dcs, tc) + +def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: TokenizerConfigV1, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: + dcs = [] + assert len(dataset_mixer_list) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer_list}" + for i in range(0, len(dataset_mixer_list), 2): + dataset_name = dataset_mixer_list[i] + frac_or_num_samples = dataset_mixer_list[i + 1] + if "." in frac_or_num_samples: + frac_or_num_samples = float(frac_or_num_samples) + else: + frac_or_num_samples = int(frac_or_num_samples) + + dataset_config = DatasetConfigV1( + dataset_name=dataset_name, + dataset_split="train", + dataset_revision="main", + transform_fn=["preference_tulu_tokenize_and_truncate_v1", "preference_tulu_filter_v1"], + transform_fn_args={ + "preference_tulu_tokenize_and_truncate_v1": { + "max_seq_length": max_seq_length, + "target_columns": TOKENIZED_PREFERENCE_DATASET_KEYS, + } + } + ) + if frac_or_num_samples > 1.0: + new_range = int(frac_or_num_samples) + else: + new_range = int(frac_or_num_samples * len(dataset_config.dataset)) + dataset_config.update_range(new_range) + dcs.append(dataset_config) + cache = DatasetTransformationCache(hf_entity=hf_entity) + return cache.load_or_transform_dataset(dcs, tc) + + def test_config_hash_different(): """Test that different configurations produce different hashes.""" tc = TokenizerConfigV1( diff --git a/open_instruct/dpo_tune_cache1.py b/open_instruct/dpo_tune_cache1.py new file mode 100644 index 000000000..c056e8df1 --- /dev/null +++ b/open_instruct/dpo_tune_cache1.py @@ -0,0 +1,1087 @@ +# !/usr/bin/env python +# coding=utf-8 +# Copyright 2024 AllenAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +DPO tuning script. Adapted from our finetuning script. +""" + +import json +import logging +import math +import os +import random +import shutil +import subprocess +import time +from dataclasses import dataclass, field +from datetime import timedelta +from functools import partial +from typing import Callable, List, Optional, Union + +import datasets +import deepspeed +import torch +import torch.utils +import torch.utils.data +import transformers +from accelerate import Accelerator, DataLoaderConfiguration +from accelerate.logging import get_logger +from accelerate.utils import InitProcessGroupKwargs, set_seed +from datasets import load_dataset +from huggingface_hub import HfApi +from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + GPT2Tokenizer, + GPTNeoXTokenizerFast, + LlamaTokenizer, + LlamaTokenizerFast, + OPTForCausalLM, + get_scheduler, +) + +from open_instruct.dataset_processor import CHAT_TEMPLATES +from open_instruct.dataset_transformation import TokenizerConfigV1, get_cached_dataset_tulu_preference +from open_instruct.dpo_utils import ( + DataCollatorForSeq2SeqDPO, + concatenated_forward, + dpo_loss, + separate_forward, + simpo_loss, + wpo_loss, +) +from open_instruct.finetune import encode_sft_example +from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate +from open_instruct.utils import ( + ArgumentParserPlus, + clean_last_n_checkpoints, + get_datasets, + get_last_checkpoint_path, + get_wandb_tags, + is_beaker_job, + maybe_get_beaker_config, + maybe_use_ai2_hf_entity, + maybe_use_ai2_wandb_entity, + upload_metadata_to_hf, +) + +logger = get_logger(__name__) + + +@dataclass +class FlatArguments: + """ + Full arguments class for all fine-tuning jobs. + """ + + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """The name of this experiment""" + run_name: Optional[str] = None + """A unique name of this run""" + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch." + ) + }, + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + dpo_use_paged_optimizer: bool = field( + default=False, + metadata={ + "help": "Use paged optimizer from bitsandbytes." + " Not compatible with deepspeed (use deepspeed config instead)." + }, + ) + dpo_beta: float = field( + default=0.1, + metadata={"help": "Beta parameter for DPO loss. Default is 0.1."}, + ) + dpo_loss_type: str = field( + default="dpo", + metadata={"help": "Type of DPO loss to use. Options are 'dpo', 'dpo_norm', 'simpo', 'wpo'."}, + ) + dpo_gamma_beta_ratio: float = field( + default=0.3, + metadata={"help": "Gamma to beta ratio for SimPO loss. Default is 0.3. Not used for DPO loss."}, + ) + dpo_label_smoothing: float = field( + default=0.0, + metadata={"help": "Label smoothing for DPO/SimPO loss. Default is 0 (no smoothing)."}, + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + tokenizer_revision: Optional[str] = field( + default=None, + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + chat_template_name: str = field( + default="tulu", + metadata={ + "help": ( + f"The name of the chat template to use. " + f"You can choose one of our pre-defined templates: {', '.join(CHAT_TEMPLATES.keys())}." + f"Or, you can provide a tokenizer name or path here and we will apply its chat template." + ) + }, + ) + use_flash_attn: bool = field( + default=True, + metadata={"help": "Whether to use flash attention in the model training"}, + ) + use_slow_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the slow tokenizer or not (which is then fast tokenizer)."}, + ) + model_revision: Optional[str] = field( + default=None, + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + trust_remote_code: bool = field( + default=False, + metadata={ + "help": ( + "Whether or not to allow for custom models defined on the Hub in their own modeling files. " + "This option should only be set to `True` for repositories you trust and in which you " + "have read the code, as it will execute code present on the Hub on your local machine." + ) + }, + ) + low_cpu_mem_usage: bool = field( + default=False, + metadata={ + "help": ( + "It is an option to create the model as an empty shell, " + "then only materialize its parameters when the pretrained weights are loaded. " + "set True will benefit LLM loading time and RAM consumption." + ) + }, + ) + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_mixer: Optional[dict] = field( + default=None, metadata={"help": "A dictionary of datasets (local or HF) to sample from."} + ) + dataset_mixer_list: Optional[list[str]] = field( + default=None, metadata={"help": "A list of datasets (local or HF) to sample from."} + ) + dataset_mix_dir: Optional[str] = field( + default=None, metadata={"help": "The directory to save the mixed dataset to disk."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field( + default=None, metadata={"help": "The input training data file (a json/jsonl file)."} + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ) + }, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_seq_length: Optional[int] = field( + default=None, + metadata={ + "help": ( + "The maximum total input sequence length after tokenization. " + "Sequences longer than this will be truncated," + ) + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + add_bos: bool = field( + default=False, + metadata={ + "help": "Forcibly add bos token to the beginning of the input sequence." + " Use only when tokenizer does not add bos token by default." + }, + ) + clip_grad_norm: float = field( + default=-1, + metadata={"help": "Clip gradient norm. Not compatible with deepspeed (use deepspeed config instead)."}, + ) + gradient_accumulation_steps: int = field( + default=1, + metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, + ) + learning_rate: float = field( + default=2e-5, + metadata={"help": "The initial learning rate for AdamW optimizer."}, + ) + logging_steps: Optional[int] = field( + default=None, + metadata={"help": "Log the training loss and learning rate every logging_steps steps."}, + ) + lora_rank: int = field( + default=64, + metadata={"help": "The rank of lora."}, + ) + lora_alpha: float = field( + default=16, + metadata={"help": "The alpha parameter of lora."}, + ) + lora_dropout: float = field( + default=0.1, + metadata={"help": "The dropout rate of lora modules."}, + ) + lr_scheduler_type: str = field( + default="linear", + metadata={ + "help": "The scheduler type to use for learning rate adjustment.", + "choices": ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + }, + ) + num_train_epochs: int = field( + default=2, + metadata={"help": "Total number of training epochs to perform."}, + ) + output_dir: str = field( + default="output/", + metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, + ) + per_device_train_batch_size: int = field( + default=8, + metadata={"help": "Batch size per GPU/TPU core/CPU for training."}, + ) + use_lora: bool = field( + default=False, + metadata={"help": "If True, will use LORA (low-rank parameter-efficient training) to train the model."}, + ) + use_qlora: bool = field( + default=False, + metadata={"help": "Use qLoRA training - initializes model in quantized form. Not compatible with deepspeed."}, + ) + use_8bit_optimizer: bool = field( + default=False, + metadata={"help": "Use 8bit optimizer from bitsandbytes. Not compatible with deepspeed."}, + ) + warmup_ratio: float = field( + default=0.03, + metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}, + ) + weight_decay: float = field( + default=0.0, + metadata={"help": "Weight decay for AdamW if we apply some."}, + ) + timeout: int = field( + default=1800, + metadata={ + "help": "Timeout for the training process in seconds." + "Useful if tokenization process is long. Default is 1800 seconds (30 minutes)." + }, + ) + reduce_loss: str = field( + default="mean", + metadata={ + "help": "How to reduce loss over tokens. Options are 'mean' or 'sum'." + "Using 'sum' can improve chat model performance." + }, + ) + wandb_entity: Optional[str] = field( + default=None, + metadata={"help": "Entity to use for logging to wandb."}, + ) + resume_from_checkpoint: Optional[str] = field( + default=None, + metadata={"help": "If the training should continue from a checkpoint folder."}, + ) + with_tracking: bool = field( + default=False, + metadata={"help": "Whether to enable experiment trackers for logging."}, + ) + report_to: Union[str, List[str]] = field( + default="all", + metadata={ + "help": "The integration(s) to report results and logs to. " + "Can be a single string or a list of strings. " + "Options are 'tensorboard', 'wandb', 'comet_ml', 'clearml', or 'all'. " + "Specify multiple by listing them: e.g., ['tensorboard', 'wandb']" + }, + ) + save_to_hub: Optional[str] = field( + default=None, + metadata={"help": "Save the model to the Hub under this name. E.g allenai/your-model"}, + ) + gradient_checkpointing: bool = field( + default=False, + metadata={"help": "Turn on gradient checkpointing. Saves memory but slows training."}, + ) + max_train_steps: Optional[int] = field( + default=None, + metadata={"help": "If set, overrides the number of training steps. Otherwise, num_train_epochs is used."}, + ) + seed: int = field(default=42, metadata={"help": "Random seed for initialization and dataset shuffling."}) + checkpointing_steps: Optional[str] = field( + default=None, + metadata={ + "help": "Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch." # noqa + }, + ) + overwrite_output_dir: bool = field( + default=False, + metadata={ + "help": "Overwrite the content of the output directory. Means that resumption will always start from scratch." + }, + ) + keep_last_n_checkpoints: int = field( + default=3, + metadata={"help": "How many checkpoints to keep in the output directory. -1 for all."}, + ) + fused_optimizer: bool = field( + default=True, + metadata={ + "help": "Whether to use fused AdamW or not.", + }, + ) + load_balancing_loss: bool = field( + default=False, + metadata={ + "help": "Whether to include a load balancing loss (for OLMoE) or not.", + }, + ) + load_balancing_weight: float = field( + default=0.001, + metadata={"help": "Weight for load balancing loss if applicable."}, + ) + concatenated_forward: bool = True + """Whether to concatenate chosen and rejected for DPO training; True is good but you can set to False for saving memory.""" + try_auto_save_to_beaker: bool = True + """Whether to try to save the model to Beaker dataset `/output` after training""" + push_to_hub: bool = True + """Whether to upload the saved model to huggingface""" + hf_entity: Optional[str] = None + """The user or org name of the model repository from the Hugging Face Hub""" + hf_repo_id: Optional[str] = None + """The id of the saved model in the Hugging Face Hub (can be autoset if not given)""" + hf_repo_revision: Optional[str] = None + """The revision of the saved model in the Hugging Face Hub (can be autoset if not given)""" + hf_repo_url: Optional[str] = None + """The url of the saved model in the Hugging Face Hub (will be autoset)""" + try_launch_beaker_eval_jobs: bool = True + """Whether to launch beaker evaluation jobs after training""" + hf_metadata_dataset: Optional[str] = "allenai/tulu-3-evals" + """What dataset to upload the metadata to. If unset, don't upload metadata""" + + def __post_init__(self): + if self.reduce_loss not in ["mean", "sum"]: + raise ValueError("reduce_loss must be either 'mean' or 'sum'") + if ( + self.dataset_name is None + and self.train_file is None + and self.dataset_mixer is None + and self.dataset_mixer_list is None + ): + raise ValueError("Need either a dataset name, dataset mixer, or a training file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["json", "jsonl"], "`train_file` should be a json or a jsonl file." + if ( + (self.dataset_name is not None and (self.dataset_mixer is not None or self.dataset_mixer_list is not None)) + or (self.dataset_name is not None and self.train_file is not None) + or ( + (self.dataset_mixer is not None or self.dataset_mixer_list is not None) and self.train_file is not None + ) + or (self.dataset_mixer is not None and self.dataset_mixer_list is not None) + ): + raise ValueError("Cannot provide two dataset selection mechanisms.") + + if self.try_launch_beaker_eval_jobs and not self.push_to_hub: + raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.") + + + +def get_cache_ref_logprobs( + model: torch.nn.Module, + active_dataloader: torch.utils.data.DataLoader, + accelerator: Accelerator, + average_log_prob: bool, + last_checkpoint_path: Optional[str], + resume_step: int, + epoch_range: range, + forward_fn: Callable, +): + epoch_cached_reference_chosen_logps = [] + epoch_cached_reference_rejected_logps = [] + for epoch in epoch_range: + active_dataloader.set_epoch(epoch) + if last_checkpoint_path and resume_step is not None: + # We skip the first `n` batches in the dataloader when resuming from a checkpoint + active_dataloader = accelerator.skip_first_batches(active_dataloader, resume_step) + cached_reference_chosen_logps = [] + cached_reference_rejected_logps = [] + with torch.no_grad(): + for step, batch in tqdm(enumerate(active_dataloader), disable=not accelerator.is_local_main_process): + if args.use_lora: + with accelerator.unwrap_model(model).disable_adapter(): + reference_chosen_logps, reference_rejected_logps, _ = forward_fn( + model, batch, average_log_prob=average_log_prob + ) + else: + reference_chosen_logps, reference_rejected_logps, _ = forward_fn( + model, batch, average_log_prob=average_log_prob + ) + cached_reference_chosen_logps.append(reference_chosen_logps.cpu()) + cached_reference_rejected_logps.append(reference_rejected_logps.cpu()) + epoch_cached_reference_chosen_logps.append(cached_reference_chosen_logps) + epoch_cached_reference_rejected_logps.append(cached_reference_rejected_logps) + return epoch_cached_reference_chosen_logps, epoch_cached_reference_rejected_logps + + +def main(args: FlatArguments): + init_gpu_memory = None + if torch.cuda.is_available(): + init_gpu_memory = torch.cuda.mem_get_info()[0] + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. + # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers + # in the environment + args.run_name = f"{args.exp_name}__{args.model_name_or_path.replace('/', '_')}__{args.seed}__{int(time.time())}" + if args.push_to_hub: + if args.hf_repo_id is None: # auto-generate one + args.hf_repo_id = "open_instruct_dev" + if args.hf_entity is None: # first try to use AI2 entity + args.hf_entity = maybe_use_ai2_hf_entity() + if args.hf_entity is None: # then try to use the user's entity + args.hf_entity = HfApi().whoami()["name"] + args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}" + if args.hf_repo_revision is None: + args.hf_repo_revision = args.run_name + args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}" + + if is_beaker_job(): + beaker_config = maybe_get_beaker_config() + + accelerator_log_kwargs = {} + + if args.with_tracking: + accelerator_log_kwargs["log_with"] = args.report_to + accelerator_log_kwargs["project_dir"] = args.output_dir + + # if you get timeouts (e.g. due to long tokenization) increase this. + timeout_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.timeout)) + dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + dataloader_config=dataloader_config, + **accelerator_log_kwargs, + kwargs_handlers=[timeout_kwargs], + ) + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + accelerator.wait_for_everyone() + + tokenizer_revision = args.model_revision if args.tokenizer_revision is None else args.tokenizer_revision + tokenizer_name = args.tokenizer_name if args.tokenizer_name is not None else args.model_name_or_path + if tokenizer_revision != args.model_revision: + # Warn user if tokenizer and model use different revisions; this is an unusual + # use case. + warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different + from the model revision `{args.model_revision}`.""" + logger.warning(warning) + tc = TokenizerConfigV1( + model_name_or_path=tokenizer_name, + revision=args.model_revision, + use_fast=not args.use_slow_tokenizer, + chat_template_name=args.chat_template_name, + add_bos=args.add_bos, + ) + tokenizer = tc.tokenizer + train_dataset = get_cached_dataset_tulu_preference( + args.dataset_mixer_list, + tc, + args.max_seq_length, + ) + train_dataset.shuffle(seed=args.seed) + train_dataset.set_format(type="pt") + + # Load pretrained model and tokenizer + if args.config_name: + config = AutoConfig.from_pretrained( + args.config_name, + revision=args.model_revision, + trust_remote_code=args.trust_remote_code, + ) + elif args.model_name_or_path: + config = AutoConfig.from_pretrained( + args.model_name_or_path, + revision=args.model_revision, + trust_remote_code=args.trust_remote_code, + ) + else: + raise ValueError( + "You are instantiating a new config instance from scratch. This is not supported by this script." + ) + + def load_model(): + if args.model_name_or_path: + if args.use_qlora: + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + ) + device_index = accelerator.local_process_index + device_map = {"": device_index} # force data-parallel training. + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + revision=args.model_revision, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + trust_remote_code=args.trust_remote_code, + quantization_config=bnb_config, + device_map=device_map, + torch_dtype=torch.bfloat16, + use_flash_attention_2=True if args.use_flash_attn else False, + ) + else: + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + revision=args.model_revision, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + trust_remote_code=args.trust_remote_code, + low_cpu_mem_usage=args.low_cpu_mem_usage, + use_flash_attention_2=True if args.use_flash_attn else False, + ) + else: + logger.info("Training new model from scratch") + model = AutoModelForCausalLM.from_config(config) + return model + + model = load_model() + print("=============model loaded") + print_gpu_stats(init_gpu_memory) + + # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch + # on a small vocab and want a smaller embedding size, remove this test. + # gather deepspeed to get "real" embedding size + embeddings = model.get_input_embeddings() + with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None): + if len(tokenizer) > embeddings.weight.shape[0]: + model.resize_token_embeddings(len(tokenizer)) + + if args.use_lora: + if args.use_qlora: + model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing) + + logger.info("Initializing LORA model...") + peft_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + target_modules=["q_proj", "o_proj", "v_proj", "k_proj", "gate_proj", "up_proj", "down_proj"], + ) + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() + elif args.gradient_checkpointing: + model.gradient_checkpointing_enable() + + # # Preprocessing the datasets. + # if "prompt" in raw_datasets["train"].column_names and "completion" in raw_datasets["train"].column_names: + # raise ValueError("Sorry, prompt-completion format is not supported for DPO training.") + # elif "chosen" in raw_datasets["train"].column_names and "rejected" in raw_datasets["train"].column_names: + # encode_function = partial( + # encode_dpo_example, + # tokenizer=tokenizer, + # max_seq_length=args.max_seq_length, + # ) + # else: + # raise ValueError("You need to have 'chosen' and 'rejected in your column names.") + + # train_dataset = raw_datasets["train"] + + # debugging tool for fewer samples + if args.max_train_samples is not None: + max_train_samples = min(len(train_dataset), args.max_train_samples) + logger.info(f"Limiting training samples to {max_train_samples} from {len(train_dataset)}.") + train_dataset = train_dataset.select(range(max_train_samples)) + + # Log a few random samples from the training set: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # DataLoaders creation: + train_dataloader = DataLoader( + train_dataset, + shuffle=True, + collate_fn=DataCollatorForSeq2SeqDPO(tokenizer=tokenizer, model=model, padding="longest"), + batch_size=args.per_device_train_batch_size, + ) + + # Optimizer + # Split weights in two groups, one with weight decay and the other not. + no_decay = ["bias", "layer_norm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + if args.use_qlora or args.dpo_use_paged_optimizer: + from bitsandbytes.optim import AdamW + + optimizer = AdamW( + optimizer_grouped_parameters, + lr=args.learning_rate, + optim_bits=8 if args.use_8bit_optimizer else 32, + is_paged=True, + ) + else: + optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, fused=args.fused_optimizer) + print("=============optimizer loaded") + print_gpu_stats(init_gpu_memory) + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + # Create the learning rate scheduler. + # Note: the current accelerator.step() calls the .step() of the real scheduler + # for the `num_processes` times. This is because they assume + # the user initialize the scheduler with the entire training set. + # In the case of data parallel training, each process only + # sees a subset (1/num_processes) of the training set. + # So each time the process needs to update the lr multiple times so that the total + # number of updates in the end matches the num_training_steps here. + # Here we need to set the num_training_steps to either using the + # entire training set (when epochs is specified) or we need to multiply the + # num_training_steps by num_processes so that the total number of + # updates matches the num_training_steps. + num_training_steps_for_scheduler = ( + args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes + ) + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_training_steps=num_training_steps_for_scheduler, + num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio), + ) + # Prepare everything with `accelerator`. + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + print("=============accelerate prepared") + print_gpu_stats(init_gpu_memory) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Figure out how many steps we should save the Accelerator states + checkpointing_steps = args.checkpointing_steps + if checkpointing_steps is not None and str(checkpointing_steps).lower() != "epoch": + checkpointing_steps = int(checkpointing_steps) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if args.with_tracking: + experiment_config = vars(args) + # TensorBoard cannot log Enums, need the raw value + experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"] + + # (Optional) Ai2 internal tracking + if args.wandb_entity is None: + args.wandb_entity = maybe_use_ai2_wandb_entity() + if is_beaker_job(): + experiment_config.update(vars(beaker_config)) + accelerator.init_trackers( + "open_instruct_internal", + experiment_config, + init_kwargs={ + "wandb": { + "name": args.run_name, + "entity": args.wandb_entity, + "tags": [args.exp_name] + get_wandb_tags(), + } + }, + ) + wandb_tracker = accelerator.get_tracker("wandb") + + # Train! + total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + + completed_steps = 0 + starting_epoch = 0 + + # Potentially load in the weights and states from a previous save + last_checkpoint_path = get_last_checkpoint_path(args) + resume_step = None + if last_checkpoint_path: + accelerator.print(f"Resumed from checkpoint: {last_checkpoint_path}") + accelerator.load_state(last_checkpoint_path) + # Extract `epoch_{i}` or `step_{i}` + last_checkpoint_path = os.path.basename(last_checkpoint_path) + training_difference = os.path.splitext(last_checkpoint_path)[0] + + if "epoch" in training_difference: + starting_epoch = int(training_difference.replace("epoch_", "")) + 1 + resume_step = None + completed_steps = starting_epoch * num_update_steps_per_epoch + else: + # need to multiply `gradient_accumulation_steps` to reflect real steps + resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps + starting_epoch = resume_step // len(train_dataloader) + completed_steps = resume_step // args.gradient_accumulation_steps + resume_step -= starting_epoch * len(train_dataloader) + + print(f"Starting from epoch {starting_epoch} and step {completed_steps}.") + + print("=============before cache logprobs") + print_gpu_stats(init_gpu_memory) + + # Cache the logprobs + average_log_prob_loss_types = ["simpo", "dpo_norm"] + average_log_prob = args.dpo_loss_type in average_log_prob_loss_types + forward_fn = concatenated_forward if args.concatenated_forward else separate_forward + if args.dpo_loss_type == "dpo" or args.dpo_loss_type == "dpo_norm": + epoch_cached_reference_chosen_logps, epoch_cached_reference_rejected_logps = get_cache_ref_logprobs( + model, + train_dataloader, + accelerator, + average_log_prob, + last_checkpoint_path, + resume_step, + range(starting_epoch, args.num_train_epochs), + forward_fn, + ) + print("=============after cache logprobs") + print_gpu_stats(init_gpu_memory) + torch.cuda.empty_cache() # clear cache + + print("=============after cache logprobs; clear cache") + print_gpu_stats(init_gpu_memory) + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + # update the progress_bar if load from checkpoint + progress_bar.update(completed_steps) + + local_metrics = torch.zeros((20), device=accelerator.device) + episode = 0 + for epoch in range(starting_epoch, args.num_train_epochs): + model.train() + train_dataloader.set_epoch(epoch) + if last_checkpoint_path and resume_step is not None: + # We skip the first `n` batches in the dataloader when resuming from a checkpoint + active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) + else: + active_dataloader = train_dataloader + # we need to average the log probs for simpo loss + for step, batch in enumerate(active_dataloader): + episode += len(batch["chosen_input_ids"]) * accelerator.num_processes + # dpo forward pass & loss + with accelerator.accumulate(model): + policy_chosen_logps, policy_rejected_logps, aux_loss = forward_fn( + model, batch, average_log_prob=average_log_prob, output_router_logits=args.load_balancing_loss + ) # `aux_loss` is only used when `args.load_balancing_loss = True` + if args.dpo_loss_type == "dpo" or args.dpo_loss_type == "dpo_norm": + p_device = policy_chosen_logps.device + reference_chosen_logps = epoch_cached_reference_chosen_logps[epoch][step].to(p_device) + reference_rejected_logps = epoch_cached_reference_rejected_logps[epoch][step].to(p_device) + losses, _, _ = dpo_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + beta=args.dpo_beta, + label_smoothing=args.dpo_label_smoothing, + ) + elif args.dpo_loss_type == "simpo": + losses, _, _ = simpo_loss( + policy_chosen_logps, + policy_rejected_logps, + beta=args.dpo_beta, + gamma_beta_ratio=args.dpo_gamma_beta_ratio, + label_smoothing=args.dpo_label_smoothing, + ) + elif args.dpo_loss_type == "wpo": + losses, _, _ = wpo_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + beta=args.dpo_beta, + label_smoothing=args.dpo_label_smoothing, + chosen_loss_mask=batch["chosen_labels"] != -100, + rejected_loss_mask=batch["rejected_labels"] != -100, + ) + else: + raise ValueError(f"Invalid dpo loss type {args.dpo_loss_type}.") + # TODO: metric logging + loss = losses.mean() + if args.load_balancing_loss: + weighted_aux_loss = args.load_balancing_weight * aux_loss + loss += weighted_aux_loss + accelerator.backward(loss) + # clip gradient norm. don't do this with deepspeed + if accelerator.sync_gradients and args.clip_grad_norm > 0: + accelerator.clip_grad_norm_(model.parameters(), args.clip_grad_norm) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + # We keep track of the loss at each logged step + with torch.no_grad(): + local_metrics[0] += loss + if args.dpo_loss_type == "dpo" or args.dpo_loss_type == "dpo_norm": + chosen_rewards = (args.dpo_beta * (policy_chosen_logps - reference_chosen_logps)).mean() + rejected_rewards = (args.dpo_beta * (policy_rejected_logps - reference_rejected_logps)).mean() + average_rewards = (chosen_rewards + rejected_rewards) / 2 + accuracy = (chosen_rewards > rejected_rewards).float().mean() + margin = (chosen_rewards - rejected_rewards).mean() + local_metrics[1] += chosen_rewards + local_metrics[2] += rejected_rewards + local_metrics[3] += average_rewards + local_metrics[4] += accuracy + local_metrics[5] += margin + local_metrics[6] += policy_chosen_logps.mean() + local_metrics[7] += policy_rejected_logps.mean() + if args.load_balancing_loss: + local_metrics[19] += weighted_aux_loss + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + completed_steps += 1 + if args.logging_steps and completed_steps % args.logging_steps == 0: + # single all reduce to save time, avoiding per metric all reduce + global_metrics = accelerator.reduce(local_metrics, reduction="mean") + global_metrics /= args.gradient_accumulation_steps * args.logging_steps + global_metrics = global_metrics.tolist() + metrics_to_log = { + "training_step": completed_steps, + "learning_rate": lr_scheduler.get_last_lr()[0], + "epoch": episode / len(train_dataset), + "train_loss": global_metrics[0], + "logps/chosen": global_metrics[6], + "logps/rejected": global_metrics[7], + } + if args.dpo_loss_type == "dpo" or args.dpo_loss_type == "dpo_norm": + metrics_to_log.update( + { + "rewards/chosen": global_metrics[1], + "rewards/rejected": global_metrics[2], + "rewards/average": global_metrics[3], + "rewards/accuracy": global_metrics[4], + "rewards/margin": global_metrics[5], + } + ) + logger_str = ( + f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {global_metrics[0]}" + ) + if args.load_balancing_loss: + logger_str += f" Aux Loss: {global_metrics[19]}" + metrics_to_log["aux_loss"] = global_metrics[19] + logger.info(logger_str) + if args.with_tracking: + accelerator.log( + metrics_to_log, + step=completed_steps, + ) + # Reset the local metrics + local_metrics.zero_() + + if isinstance(checkpointing_steps, int): + if completed_steps % checkpointing_steps == 0: + output_dir = f"step_{completed_steps}" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + # use this to mark the checkpoint as completely saved, to avoid restoring from garbled checkpoints + with open( + os.path.join(get_last_checkpoint_path(args, incomplete=True), "COMPLETED"), "w" + ) as f: + f.write("COMPLETED") # annoyingly, empty files arent uploaded by beaker. + if accelerator.is_local_main_process: + clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints) + accelerator.wait_for_everyone() + + if completed_steps >= args.max_train_steps: + break + + if checkpointing_steps == "epoch": + output_dir = f"epoch_{epoch}" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + # use this to mark the checkpoint as completely saved, to avoid restoring from garbled checkpoints + with open(os.path.join(get_last_checkpoint_path(args, incomplete=True), "COMPLETED"), "w") as f: + f.write("COMPLETED") # annoyingly, empty files arent uploaded by beaker. + if accelerator.is_local_main_process: + clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints) + accelerator.wait_for_everyone() + + if args.output_dir is not None: + save_with_accelerate( + accelerator, + model, + tokenizer, + args.output_dir, + args.use_lora, + ) + + # remove all checkpoints to save space + if accelerator.is_local_main_process: + clean_last_n_checkpoints(args.output_dir, keep_last_n_checkpoints=0) + + if ( + args.try_auto_save_to_beaker + and accelerator.is_main_process + and len(beaker_config.beaker_dataset_id_urls) > 0 + and args.output_dir.rstrip("/") != "/output" + ): + shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True) + + if is_beaker_job() and accelerator.is_main_process: + # dpo script only supports these two options right now for datasets + if args.dataset_mixer: + dataset_list = list(args.dataset_mixer.keys()) + elif args.dataset_mixer_list: + dataset_list = args.dataset_mixer_list[::2] # even indices + elif args.dataset_name: + dataset_list = [args.dataset_name] + else: + dataset_list = [args.train_file] + # mainly just focussing here on what would be useful for the leaderboard. + # wandb will have even more useful information. + metadata_blob = { + "model_name": args.exp_name, + "model_type": "dpo", + "datasets": dataset_list, + "base_model": args.model_name_or_path, + "wandb_path": wandb_tracker.run.get_url(), + "beaker_experiment": beaker_config.beaker_experiment_url, + "beaker_datasets": beaker_config.beaker_dataset_id_urls, + } + # save metadata to the output directory. then it should also get pushed to HF. + with open(os.path.join(args.output_dir, "metadata.json"), "w") as f: + json.dump(metadata_blob, f) + + # upload metadata to the dataset if set + if args.hf_metadata_dataset: + upload_metadata_to_hf( + metadata_blob, + "metadata.json", + args.hf_metadata_dataset, + "results/" + args.run_name, # to match what the auto-evals name as. + ) + + if args.try_launch_beaker_eval_jobs: + command = f"""\ + python mason.py \ + --cluster ai2/ganymede-cirrascale ai2/ceres-cirrascale ai2/neptune-cirrascale ai2/saturn-cirrascale ai2/jupiter-cirrascale-2 \ + --priority low \ + --preemptible \ + --budget ai2/allennlp \ + --workspace ai2/tulu-2-improvements \ + --image nathanl/open_instruct_auto \ + --pure_docker_mode \ + --gpus 0 -- python scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py \ + --beaker_workload_id {beaker_config.beaker_workload_id} \ + --upload_to_hf {args.hf_metadata_dataset} \ + --model_name {args.run_name} \ + --run_id {wandb_tracker.run.get_url()} + """ + process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + print(f"Submit jobs after model training is finished - Stdout:\n{stdout.decode()}") + print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") + print(f"Submit jobs after model training is finished - process return code: {process.returncode}") + + if args.push_to_hub: + push_folder_to_hub( + accelerator, + args.output_dir, + args.hf_repo_id, + args.hf_repo_revision, + ) + accelerator.wait_for_everyone() + if args.with_tracking: + accelerator.end_training() + + +def print_gpu_stats(init_gpu_memory: Optional[int]): + if torch.cuda.is_available(): + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + peak_memory = init_gpu_memory - free_gpu_memory + print(f"Peak memory usage: {peak_memory / 1024 ** 3:.2f} GB") + print(f"Total memory usage: {total_gpu_memory / 1024 ** 3:.2f} GB") + print(f"Free memory: {free_gpu_memory / 1024 ** 3:.2f} GB") + + +if __name__ == "__main__": + parser = ArgumentParserPlus((FlatArguments)) + args = parser.parse() + main(args) diff --git a/open_instruct/finetune1.py b/open_instruct/finetune1.py index 9fd85f506..c7f3c61a3 100644 --- a/open_instruct/finetune1.py +++ b/open_instruct/finetune1.py @@ -54,7 +54,7 @@ ) from open_instruct.dataset_processor import CHAT_TEMPLATES -from open_instruct.dataset_transformation import TokenizerConfigV1, get_cached_dataset_tulu +from open_instruct.dataset_transformation import TokenizerConfigV1, get_cached_dataset_tulu_sft from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate from open_instruct.utils import ( ArgumentParserPlus, @@ -443,16 +443,23 @@ def main(args: FlatArguments): accelerator.wait_for_everyone() - + tokenizer_revision = args.model_revision if args.tokenizer_revision is None else args.tokenizer_revision + tokenizer_name = args.tokenizer_name if args.tokenizer_name is not None else args.model_name_or_path + if tokenizer_revision != args.model_revision: + # Warn user if tokenizer and model use different revisions; this is an unusual + # use case. + warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different + from the model revision `{args.model_revision}`.""" + logger.warning(warning) tc = TokenizerConfigV1( - model_name_or_path=args.model_name_or_path, + model_name_or_path=tokenizer_name, revision=args.model_revision, use_fast=not args.use_slow_tokenizer, chat_template_name=args.chat_template_name, add_bos=args.add_bos, ) tokenizer = tc.tokenizer - train_dataset = get_cached_dataset_tulu( + train_dataset = get_cached_dataset_tulu_sft( args.dataset_mixer_list, tc, args.max_seq_length, @@ -708,6 +715,7 @@ def main(args: FlatArguments): if args.load_balancing_loss: outputs = model(**batch, use_cache=False, output_router_logits=True) else: + # TODO: we have calculated the mean loss here anyway, so doubling the calculation outputs = model(**batch, use_cache=False) if args.reduce_loss == "mean": loss = outputs.loss @@ -801,7 +809,7 @@ def main(args: FlatArguments): os.path.join(get_last_checkpoint_path(args, incomplete=True), "COMPLETED"), "w" ) as f: f.write("COMPLETED") # annoyingly, empty files arent uploaded by beaker. - if accelerator.is_local_main_process: + if accelerator.is_local_main_process: # TODO: in mason local model this is gonna error out if using something like output/test; because mason used the same shared file ssytem. clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints) accelerator.wait_for_everyone() From 966d9d1b4c8d4e7115e5ca0110bfabe2b4c19597 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 06:36:46 -0800 Subject: [PATCH 07/32] refactord --- open_instruct/dataset_transformation.py | 51 +++++++++++++------------ 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 2acb4ac0d..51ae9cc2e 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -95,13 +95,14 @@ def visualize_token(tokens: list[int], tokenizer: PreTrainedTokenizer): # ---------------------------------------------------------------------------- # Tokenization @dataclass -class TokenizerConfigV1: +class TokenizerConfig: model_name_or_path: str revision: str trust_remote_code: bool = True use_fast: bool = True chat_template_name: Optional[str] = None # TODO: should I give an option to force override? add_bos: bool = False + version: str = "v1" # for tracking purposes tokenizer_commit_hash: Optional[str] = None @@ -440,14 +441,14 @@ def preference_tulu_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenize # ---------------------------------------------------------------------------- # Dataset Configuration and Caching @dataclass -class DatasetConfigV1: +class DatasetConfig: dataset_name: str dataset_split: str dataset_revision: str dataset_range: Optional[int] = None transform_fn: List[str] = field(default_factory=list) transform_fn_args: Dict[str, Dict[str, Any]] = field(default_factory=dict) - get_dataset_fn_version: str = "v1" + version: str = "v1" # for tracking purposes dataset_commit_hash: Optional[str] = None @@ -469,7 +470,7 @@ def update_range(self, dataset_range: int): raise ValueError("Dataset range exceeds dataset length") self.dataset = self.dataset.select(range(self.dataset_range)) -def get_dataset_v1(dc: DatasetConfigV1, tc: TokenizerConfigV1): +def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): # beaker specific logic; we may get assigned 15.5 CPU, so we convert it to float then int num_proc = int(float(os.environ.get("BEAKER_ASSIGNED_CPU_COUNT", multiprocessing.cpu_count()))) @@ -511,7 +512,7 @@ class DatasetTransformationCache: def __init__(self, hf_entity: Optional[str] = None): self.hf_entity = hf_entity or HfApi().whoami()["name"] - def compute_config_hash(self, dcs: List[DatasetConfigV1], tc: TokenizerConfigV1) -> str: + def compute_config_hash(self, dcs: List[DatasetConfig], tc: TokenizerConfig) -> str: """Compute a deterministic hash of both configs for caching.""" dc_dicts = [ {k: v for k, v in asdict(dc).items() if v is not None} @@ -525,7 +526,7 @@ def compute_config_hash(self, dcs: List[DatasetConfigV1], tc: TokenizerConfigV1) config_str = json.dumps(combined_dict, sort_keys=True) return hashlib.sha256(config_str.encode()).hexdigest()[:10] - def load_or_transform_dataset(self, dcs: List[DatasetConfigV1], tc: TokenizerConfigV1) -> Dataset: + def load_or_transform_dataset(self, dcs: List[DatasetConfig], tc: TokenizerConfig) -> Dataset: """Load dataset from cache if it exists, otherwise transform and cache it.""" config_hash = self.compute_config_hash(dcs, tc) repo_name = f"{self.hf_entity}/dataset-mix-cached" @@ -569,11 +570,11 @@ def load_or_transform_dataset(self, dcs: List[DatasetConfigV1], tc: TokenizerCon ) -def get_cached_dataset(dcs: List[DatasetConfigV1], tc: TokenizerConfigV1, hf_entity: Optional[str] = None) -> Dataset: +def get_cached_dataset(dcs: List[DatasetConfig], tc: TokenizerConfig, hf_entity: Optional[str] = None) -> Dataset: cache = DatasetTransformationCache(hf_entity=hf_entity) return cache.load_or_transform_dataset(dcs, tc) -def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConfigV1, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: +def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConfig, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: dcs = [] assert len(dataset_mixer_list) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer_list}" for i in range(0, len(dataset_mixer_list), 2): @@ -584,7 +585,7 @@ def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConf else: frac_or_num_samples = int(frac_or_num_samples) - dataset_config = DatasetConfigV1( + dataset_config = DatasetConfig( dataset_name=dataset_name, dataset_split="train", dataset_revision="main", @@ -606,7 +607,7 @@ def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConf return cache.load_or_transform_dataset(dcs, tc) -def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: TokenizerConfigV1, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: +def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: TokenizerConfig, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: dcs = [] assert len(dataset_mixer_list) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer_list}" for i in range(0, len(dataset_mixer_list), 2): @@ -617,7 +618,7 @@ def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: Tokeni else: frac_or_num_samples = int(frac_or_num_samples) - dataset_config = DatasetConfigV1( + dataset_config = DatasetConfig( dataset_name=dataset_name, dataset_split="train", dataset_revision="main", @@ -641,14 +642,14 @@ def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: Tokeni def test_config_hash_different(): """Test that different configurations produce different hashes.""" - tc = TokenizerConfigV1( + tc = TokenizerConfig( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" ) dcs1 = [ - DatasetConfigV1( + DatasetConfig( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", @@ -658,7 +659,7 @@ def test_config_hash_different(): ] dcs2 = [ - DatasetConfigV1( + DatasetConfig( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", @@ -674,21 +675,21 @@ def test_config_hash_different(): def test_sft_dataset_caching(): """Test caching functionality for SFT datasets.""" - tc = TokenizerConfigV1( + tc = TokenizerConfig( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" ) dcs = [ - DatasetConfigV1( + DatasetConfig( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_v1"], transform_fn_args={} ), - DatasetConfigV1( + DatasetConfig( dataset_name="allenai/tulu-3-hard-coded-10x", dataset_split="train", dataset_revision="main", @@ -708,21 +709,21 @@ def test_sft_dataset_caching(): def test_sft_different_transform(): """Test different transform functions produce different cached datasets.""" - tc = TokenizerConfigV1( + tc = TokenizerConfig( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" ) dcs = [ - DatasetConfigV1( + DatasetConfig( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_mask_out_prompt_v1"], transform_fn_args={} ), - DatasetConfigV1( + DatasetConfig( dataset_name="allenai/tulu-3-hard-coded-10x", dataset_split="train", dataset_revision="main", @@ -737,7 +738,7 @@ def test_sft_different_transform(): def test_sft_filter(): """Test different transform functions produce different cached datasets.""" - tc = TokenizerConfigV1( + tc = TokenizerConfig( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" @@ -745,7 +746,7 @@ def test_sft_filter(): ARBITRARY_MAX_LENGTH = 1000 dcs = [ - DatasetConfigV1( + DatasetConfig( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", @@ -770,21 +771,21 @@ def test_sft_filter(): def test_preference_dataset(): """Test caching functionality for preference datasets.""" - tc = TokenizerConfigV1( + tc = TokenizerConfig( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" ) dcs_pref = [ - DatasetConfigV1( + DatasetConfig( dataset_name="allenai/tulu-3-pref-personas-instruction-following", dataset_split="train", dataset_revision="main", transform_fn=["preference_tokenize_v1"], transform_fn_args={} ), - DatasetConfigV1( + DatasetConfig( dataset_name="allenai/tulu-3-wildchat-reused-on-policy-70b", dataset_split="train", dataset_revision="main", From 41592557b797ccdc4da1efec1d0bcb48a2d9d88e Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 06:38:10 -0800 Subject: [PATCH 08/32] allow passing in an SFT message key --- open_instruct/dataset_transformation.py | 26 ++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 51ae9cc2e..827721ac1 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -196,7 +196,7 @@ def __post_init__(self): # ---------------------------------------------------------------------------- # Dataset Transformation # SFT dataset -SFT_MESSAGE_KEY = "messages" +DEFAULT_SFT_MESSAGE_KEY = "messages" INPUT_IDS_KEY = "input_ids" ATTENTION_MASK_KEY = "attention_mask" LABELS_KEY = "labels" @@ -238,35 +238,35 @@ def __post_init__(self): # TODO: allow passing in sft_message key, so we can train on "chosen" of pref dataset. -def sft_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer): - if len(row[SFT_MESSAGE_KEY]) == 1: - prompt = row[SFT_MESSAGE_KEY] +def sft_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, sft_message_key: str = DEFAULT_SFT_MESSAGE_KEY): + if len(row[sft_message_key]) == 1: + prompt = row[sft_message_key] else: - prompt = row[SFT_MESSAGE_KEY][:-1] + prompt = row[sft_message_key][:-1] row[INPUT_IDS_PROMPT_KEY] = tokenizer.apply_chat_template( prompt, add_generation_prompt=True, ) - row[INPUT_IDS_KEY] = tokenizer.apply_chat_template(row[SFT_MESSAGE_KEY]) + row[INPUT_IDS_KEY] = tokenizer.apply_chat_template(row[sft_message_key]) row[ATTENTION_MASK_KEY] = [1] * len(row[INPUT_IDS_KEY]) labels = copy.deepcopy(row[INPUT_IDS_KEY]) row[LABELS_KEY] = labels return row -def sft_tokenize_mask_out_prompt_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer): +def sft_tokenize_mask_out_prompt_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, sft_message_key: str = DEFAULT_SFT_MESSAGE_KEY): """mask out the prompt tokens by manipulating labels""" - if len(row[SFT_MESSAGE_KEY]) == 1: - prompt = row[SFT_MESSAGE_KEY] + if len(row[sft_message_key]) == 1: + prompt = row[sft_message_key] else: - prompt = row[SFT_MESSAGE_KEY][:-1] + prompt = row[sft_message_key][:-1] row[INPUT_IDS_PROMPT_KEY] = tokenizer.apply_chat_template( prompt, add_generation_prompt=True, ) - row[INPUT_IDS_KEY] = tokenizer.apply_chat_template(row[SFT_MESSAGE_KEY]) + row[INPUT_IDS_KEY] = tokenizer.apply_chat_template(row[sft_message_key]) row[ATTENTION_MASK_KEY] = [1] * len(row[INPUT_IDS_KEY]) labels = copy.deepcopy(row[INPUT_IDS_KEY]) labels[: len(row[INPUT_IDS_PROMPT_KEY])] = [-100] * len(row[INPUT_IDS_PROMPT_KEY]) @@ -409,8 +409,8 @@ def preference_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: Pre if len(rejected_messages) == 0: raise ValueError("rejected messages field is empty.") - chosen_encoded = sft_tulu_tokenize_and_truncate_v1({SFT_MESSAGE_KEY: chosen_messages}, tokenizer, max_seq_length) - rejected_encoded = sft_tulu_tokenize_and_truncate_v1({SFT_MESSAGE_KEY: rejected_messages}, tokenizer, max_seq_length) + chosen_encoded = sft_tulu_tokenize_and_truncate_v1({DEFAULT_SFT_MESSAGE_KEY: chosen_messages}, tokenizer, max_seq_length) + rejected_encoded = sft_tulu_tokenize_and_truncate_v1({DEFAULT_SFT_MESSAGE_KEY: rejected_messages}, tokenizer, max_seq_length) return { CHOSEN_INPUT_IDS_KEY: chosen_encoded["input_ids"], From 367c9e5d9f40ec5f5cd9136622fab4c35c680367 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 06:39:18 -0800 Subject: [PATCH 09/32] allow customizing chosen / rejected key --- open_instruct/dataset_transformation.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 827721ac1..2f44e7fca 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -212,8 +212,8 @@ def __post_init__(self): # since we are always padding from the right with a collator; however they might become # more useful if we want to do some sort of packing in the future. The nice thing is # that the tokenization logic would work for both DPO and RM training. -CHOSEN_KEY = "chosen" -REJECTED_KEY = "rejected" +DEFAULT_CHOSEN_KEY = "chosen" +DEFAULT_REJECTED_KEY = "rejected" CHOSEN_INPUT_IDS_KEY = "chosen_input_ids" CHOSEN_ATTENTION_MASK_KEY = "chosen_attention_mask" CHOSEN_LABELS_KEY = "chosen_labels" @@ -396,14 +396,14 @@ def preference_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, ma return True -def preference_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_seq_length: int): +def preference_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_seq_length: int, chosen_key: str = DEFAULT_CHOSEN_KEY, rejected_key: str = DEFAULT_REJECTED_KEY): """ Here we assume each example has a rejected and chosen field, both of which are a list of messages. Each message is a dict with 'role' and 'content' fields. We assume only the last message is different, and the prompt is contained in the list of messages. """ - chosen_messages = row[CHOSEN_KEY] - rejected_messages = row[REJECTED_KEY] + chosen_messages = row[chosen_key] + rejected_messages = row[rejected_key] if len(chosen_messages) == 0: raise ValueError("chosen messages field is empty.") if len(rejected_messages) == 0: From 22d724fe584203bc4aca88bfd2d1ab46d7ab46e8 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 06:45:42 -0800 Subject: [PATCH 10/32] fix tests --- open_instruct/dataset_transformation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 2f44e7fca..a21d7b82b 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -432,7 +432,7 @@ def preference_tulu_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenize "sft_filter_v1": (sft_filter_v1, "filter"), "sft_tulu_tokenize_and_truncate_v1": (sft_tulu_tokenize_and_truncate_v1, "map"), "sft_tulu_filter_v1": (sft_tulu_filter_v1, "filter"), - "preference_tokenize_v1": (preference_tokenize_v1, "map", TOKENIZED_PREFERENCE_DATASET_KEYS), + "preference_tokenize_v1": (preference_tokenize_v1, "map"), "preference_filter_v1": (preference_filter_v1, "filter"), "preference_tulu_tokenize_and_truncate_v1": (preference_tulu_tokenize_and_truncate_v1, "map"), "preference_tulu_filter_v1": (preference_tulu_filter_v1, "filter"), From 0257767b6bf86410098a7923bc63b6bdd4ac6473 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 07:11:07 -0800 Subject: [PATCH 11/32] add huggingface card --- open_instruct/dataset_transformation.py | 270 +++++++++++++++--------- 1 file changed, 174 insertions(+), 96 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index a21d7b82b..52e75b681 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -57,6 +57,7 @@ LlamaTokenizer, LlamaTokenizerFast, ) +from huggingface_hub import ModelCard import transformers from transformers.utils.hub import cached_file, extract_commit_hash from datasets import Dataset, load_dataset, concatenate_datasets @@ -94,89 +95,108 @@ def visualize_token(tokens: list[int], tokenizer: PreTrainedTokenizer): # ---------------------------------------------------------------------------- # Tokenization +def get_tokenizer_simple_v1(tc: 'TokenizerConfigV1'): + tokenizer = AutoTokenizer.from_pretrained( + tc.model_name_or_path, + revision=tc.revision, + trust_remote_code=tc.trust_remote_code, + use_fast=tc.use_fast, + ) + return tokenizer + +def get_tokenizer_tulu_v1(tc: 'TokenizerConfigV1'): + tokenizer = AutoTokenizer.from_pretrained( + tc.model_name_or_path, + revision=tc.revision, + trust_remote_code=tc.trust_remote_code, + use_fast=tc.use_fast, + ) + # no default pad token for llama! + # here we add all special tokens again, because the default ones are not in the special_tokens_map + if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, LlamaTokenizerFast): + num_added_tokens = tokenizer.add_special_tokens( + { + "bos_token": "", + "eos_token": "", + "unk_token": "", + "pad_token": "", + } + ) + assert num_added_tokens in [ + 0, + 1, + ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present." + elif isinstance(tokenizer, GPTNeoXTokenizerFast): + # OLMo newer models use this tokenizer + if tokenizer.bos_token is None: + tokenizer.bos_token = tokenizer.eos_token + assert ( + tc.add_bos + ), "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence." + # else, pythia / other models + else: + num_added_tokens = tokenizer.add_special_tokens( + { + "pad_token": "", + } + ) + assert ( + num_added_tokens <= 1 + ), "GPTNeoXTokenizer should only add one special token - the pad_token (or no tokens if already set in SFT)." + # NOTE: (Costa) I just commented the `OPTForCausalLM` because we are not likely to use it. + # elif isinstance(tokenizer, GPT2Tokenizer) and isinstance(model, OPTForCausalLM): + # num_added_tokens = tokenizer.add_special_tokens({"unk_token": ""}) + elif isinstance(tokenizer, transformers.PreTrainedTokenizerFast) and tokenizer.pad_token is None: + num_added_tokens = tokenizer.add_special_tokens({"pad_token": ""}) + assert num_added_tokens == 1, "We detected no padding token but add_special_tokens did not add one." + + # set the tokenizer chat template to the training format + # this will be used for encoding the training examples + # and saved together with the tokenizer to be used later. + if tc.chat_template_name in CHAT_TEMPLATES: + tokenizer.chat_template = CHAT_TEMPLATES[tc.chat_template_name] + else: + try: + tokenizer.chat_template = AutoTokenizer.from_pretrained(tc.chat_template_name).chat_template + except Exception: + raise ValueError(f"Could not find chat template for {tc.chat_template_name}.") + + if tc.add_bos: + if tokenizer.chat_template.startswith("{{ bos_token }}") or ( + tokenizer.bos_token is not None and tokenizer.chat_template.startswith(tokenizer.bos_token) + ): + raise ValueError( + "You specified add_bos=True, but the chat template already has a bos_token at the beginning." + ) + # also add bos in the chat template if not already there + tokenizer.chat_template = "{{ bos_token }}" + tokenizer.chat_template + + return tokenizer + + +GET_TOKENIZER_FN = { + "get_tokenizer_simple_v1": get_tokenizer_simple_v1, + "get_tokenizer_tulu_v1": get_tokenizer_tulu_v1, +} + + + @dataclass -class TokenizerConfig: +class TokenizerConfigV1: model_name_or_path: str revision: str trust_remote_code: bool = True use_fast: bool = True chat_template_name: Optional[str] = None # TODO: should I give an option to force override? add_bos: bool = False - version: str = "v1" + get_tokenizer_fn: str = "get_tokenizer_tulu_v1" # for tracking purposes tokenizer_commit_hash: Optional[str] = None - + def __post_init__(self): self.tokenizer_commit_hash = get_commit_hash(self.model_name_or_path, self.revision, filename="tokenizer_config.json") - - self.tokenizer = AutoTokenizer.from_pretrained( - self.model_name_or_path, - revision=self.revision, - trust_remote_code=self.trust_remote_code, - use_fast=self.use_fast, - ) - # no default pad token for llama! - # here we add all special tokens again, because the default ones are not in the special_tokens_map - if isinstance(self.tokenizer, LlamaTokenizer) or isinstance(self.tokenizer, LlamaTokenizerFast): - num_added_tokens = self.tokenizer.add_special_tokens( - { - "bos_token": "", - "eos_token": "", - "unk_token": "", - "pad_token": "", - } - ) - assert num_added_tokens in [ - 0, - 1, - ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present." - elif isinstance(self.tokenizer, GPTNeoXTokenizerFast): - # OLMo newer models use this self.tokenizer - if self.tokenizer.bos_token is None: - self.tokenizer.bos_token = self.tokenizer.eos_token - assert ( - self.add_bos - ), "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence." - # else, pythia / other models - else: - num_added_tokens = self.tokenizer.add_special_tokens( - { - "pad_token": "", - } - ) - assert ( - num_added_tokens <= 1 - ), "GPTNeoXTokenizer should only add one special token - the pad_token (or no tokens if already set in SFT)." - # NOTE: (Costa) I just commented the `OPTForCausalLM` because we are not likely to use it. - # elif isinstance(self.tokenizer, GPT2Tokenizer) and isinstance(model, OPTForCausalLM): - # num_added_tokens = self.tokenizer.add_special_tokens({"unk_token": ""}) - elif isinstance(self.tokenizer, transformers.PreTrainedTokenizerFast) and self.tokenizer.pad_token is None: - num_added_tokens = self.tokenizer.add_special_tokens({"pad_token": ""}) - assert num_added_tokens == 1, "We detected no padding token but add_special_tokens did not add one." - - # set the self.tokenizer chat template to the training format - # this will be used for encoding the training examples - # and saved together with the self.tokenizer to be used later. - if self.chat_template_name in CHAT_TEMPLATES: - self.tokenizer.chat_template = CHAT_TEMPLATES[self.chat_template_name] - else: - try: - self.tokenizer.chat_template = AutoTokenizer.from_pretrained(self.chat_template_name).chat_template - except Exception: - raise ValueError(f"Could not find chat template for {self.chat_template_name}.") - - if self.add_bos: - if self.tokenizer.chat_template.startswith("{{ bos_token }}") or ( - self.tokenizer.bos_token is not None and self.tokenizer.chat_template.startswith(self.tokenizer.bos_token) - ): - raise ValueError( - "You specified add_bos=True, but the chat template already has a bos_token at the beginning." - ) - # also add bos in the chat template if not already there - self.tokenizer.chat_template = "{{ bos_token }}" + self.tokenizer.chat_template - - + self.tokenizer = GET_TOKENIZER_FN[self.get_tokenizer_fn](self) # TODO: test it out: PPO should have the sametokenizer as SFT / DPO. # # create a tokenizer (pad from right) # config = AutoConfig.from_pretrained(model_config.model_name_or_path, revision=model_config.model_revision) @@ -441,14 +461,13 @@ def preference_tulu_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenize # ---------------------------------------------------------------------------- # Dataset Configuration and Caching @dataclass -class DatasetConfig: +class DatasetConfigV1: dataset_name: str dataset_split: str dataset_revision: str dataset_range: Optional[int] = None transform_fn: List[str] = field(default_factory=list) transform_fn_args: Dict[str, Dict[str, Any]] = field(default_factory=dict) - version: str = "v1" # for tracking purposes dataset_commit_hash: Optional[str] = None @@ -470,7 +489,7 @@ def update_range(self, dataset_range: int): raise ValueError("Dataset range exceeds dataset length") self.dataset = self.dataset.select(range(self.dataset_range)) -def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): +def get_dataset_v1(dc: DatasetConfigV1, tc: TokenizerConfigV1): # beaker specific logic; we may get assigned 15.5 CPU, so we convert it to float then int num_proc = int(float(os.environ.get("BEAKER_ASSIGNED_CPU_COUNT", multiprocessing.cpu_count()))) @@ -512,7 +531,7 @@ class DatasetTransformationCache: def __init__(self, hf_entity: Optional[str] = None): self.hf_entity = hf_entity or HfApi().whoami()["name"] - def compute_config_hash(self, dcs: List[DatasetConfig], tc: TokenizerConfig) -> str: + def compute_config_hash(self, dcs: List[DatasetConfigV1], tc: TokenizerConfigV1) -> str: """Compute a deterministic hash of both configs for caching.""" dc_dicts = [ {k: v for k, v in asdict(dc).items() if v is not None} @@ -526,7 +545,7 @@ def compute_config_hash(self, dcs: List[DatasetConfig], tc: TokenizerConfig) -> config_str = json.dumps(combined_dict, sort_keys=True) return hashlib.sha256(config_str.encode()).hexdigest()[:10] - def load_or_transform_dataset(self, dcs: List[DatasetConfig], tc: TokenizerConfig) -> Dataset: + def load_or_transform_dataset(self, dcs: List[DatasetConfigV1], tc: TokenizerConfigV1) -> Dataset: """Load dataset from cache if it exists, otherwise transform and cache it.""" config_hash = self.compute_config_hash(dcs, tc) repo_name = f"{self.hf_entity}/dataset-mix-cached" @@ -561,6 +580,32 @@ def load_or_transform_dataset(self, dcs: List[DatasetConfig], tc: TokenizerConfi ) print(f"🚀 Pushed transformed dataset to https://huggingface.co/datasets/{repo_name}/tree/{config_hash}") + + model_card = ModelCard(f"""\ +--- +tags: [open-instruct] +--- + +# Cached Tokenized Datasets + +## Summary + +This is a cached dataset produced by https://github.com/allenai/open-instruct + +## Configuration + +`TokenizerConfigV1`: +```json +{json.dumps(asdict(tc), indent=2)} +``` + +`List[DatasetConfigV1]`: +```json +{json.dumps([asdict(dc) for dc in dcs], indent=2)} +``` +""") + model_card.push_to_hub(repo_name, repo_type="dataset", revision=config_hash) + # NOTE: Load the dataset again to make sure it's downloaded to the HF cache print(f"✅ Found cached dataset at https://huggingface.co/datasets/{repo_name}/tree/{config_hash}") return load_dataset( @@ -570,11 +615,11 @@ def load_or_transform_dataset(self, dcs: List[DatasetConfig], tc: TokenizerConfi ) -def get_cached_dataset(dcs: List[DatasetConfig], tc: TokenizerConfig, hf_entity: Optional[str] = None) -> Dataset: +def get_cached_dataset(dcs: List[DatasetConfigV1], tc: TokenizerConfigV1, hf_entity: Optional[str] = None) -> Dataset: cache = DatasetTransformationCache(hf_entity=hf_entity) return cache.load_or_transform_dataset(dcs, tc) -def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConfig, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: +def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConfigV1, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: dcs = [] assert len(dataset_mixer_list) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer_list}" for i in range(0, len(dataset_mixer_list), 2): @@ -585,7 +630,7 @@ def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConf else: frac_or_num_samples = int(frac_or_num_samples) - dataset_config = DatasetConfig( + dataset_config = DatasetConfigV1( dataset_name=dataset_name, dataset_split="train", dataset_revision="main", @@ -607,7 +652,7 @@ def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConf return cache.load_or_transform_dataset(dcs, tc) -def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: TokenizerConfig, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: +def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: TokenizerConfigV1, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: dcs = [] assert len(dataset_mixer_list) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer_list}" for i in range(0, len(dataset_mixer_list), 2): @@ -618,7 +663,7 @@ def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: Tokeni else: frac_or_num_samples = int(frac_or_num_samples) - dataset_config = DatasetConfig( + dataset_config = DatasetConfigV1( dataset_name=dataset_name, dataset_split="train", dataset_revision="main", @@ -640,16 +685,48 @@ def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: Tokeni return cache.load_or_transform_dataset(dcs, tc) +def test_sft_dpo_same_tokenizer(): + base_to_sft_tc = TokenizerConfigV1( + model_name_or_path="meta-llama/Llama-3.1-8B", + revision="main", + chat_template_name="tulu" + ) + sft_to_dpo_tc = TokenizerConfigV1( + model_name_or_path="allenai/Llama-3.1-Tulu-3-8B-SFT", + revision="main", + chat_template_name="tulu" + ) + dpo_to_rl_tc = TokenizerConfigV1( + model_name_or_path="allenai/Llama-3.1-Tulu-3-8B-DPO", + revision="main", + chat_template_name="tulu" + ) + + def equal_tokenizer(tc1, tc2): + tok1 = tc1.tokenizer + tok2 = tc2.tokenizer + assert tok1.vocab_size == tok2.vocab_size, "Vocab size should be the same" + assert tok1.model_max_length == tok2.model_max_length, "Model max length should be the same" + assert tok1.is_fast == tok2.is_fast, "is_fast should be the same" + assert tok1.padding_side == tok2.padding_side, "padding_side should be the same" + assert tok1.truncation_side == tok2.truncation_side, "truncation_side should be the same" + assert tok1.clean_up_tokenization_spaces == tok2.clean_up_tokenization_spaces, "clean_up_tokenization_spaces should be the same" + assert tok1.added_tokens_decoder == tok2.added_tokens_decoder, "added_tokens_decoder should be the same" + equal_tokenizer(base_to_sft_tc, sft_to_dpo_tc) + equal_tokenizer(sft_to_dpo_tc, dpo_to_rl_tc) + equal_tokenizer(base_to_sft_tc, dpo_to_rl_tc) + + def test_config_hash_different(): """Test that different configurations produce different hashes.""" - tc = TokenizerConfig( + tc = TokenizerConfigV1( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" ) - + dcs1 = [ - DatasetConfig( + DatasetConfigV1( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", @@ -659,7 +736,7 @@ def test_config_hash_different(): ] dcs2 = [ - DatasetConfig( + DatasetConfigV1( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", @@ -675,21 +752,21 @@ def test_config_hash_different(): def test_sft_dataset_caching(): """Test caching functionality for SFT datasets.""" - tc = TokenizerConfig( + tc = TokenizerConfigV1( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" ) dcs = [ - DatasetConfig( + DatasetConfigV1( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_v1"], transform_fn_args={} ), - DatasetConfig( + DatasetConfigV1( dataset_name="allenai/tulu-3-hard-coded-10x", dataset_split="train", dataset_revision="main", @@ -709,21 +786,21 @@ def test_sft_dataset_caching(): def test_sft_different_transform(): """Test different transform functions produce different cached datasets.""" - tc = TokenizerConfig( + tc = TokenizerConfigV1( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" ) dcs = [ - DatasetConfig( + DatasetConfigV1( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_mask_out_prompt_v1"], transform_fn_args={} ), - DatasetConfig( + DatasetConfigV1( dataset_name="allenai/tulu-3-hard-coded-10x", dataset_split="train", dataset_revision="main", @@ -738,7 +815,7 @@ def test_sft_different_transform(): def test_sft_filter(): """Test different transform functions produce different cached datasets.""" - tc = TokenizerConfig( + tc = TokenizerConfigV1( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" @@ -746,7 +823,7 @@ def test_sft_filter(): ARBITRARY_MAX_LENGTH = 1000 dcs = [ - DatasetConfig( + DatasetConfigV1( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", @@ -771,21 +848,21 @@ def test_sft_filter(): def test_preference_dataset(): """Test caching functionality for preference datasets.""" - tc = TokenizerConfig( + tc = TokenizerConfigV1( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" ) dcs_pref = [ - DatasetConfig( + DatasetConfigV1( dataset_name="allenai/tulu-3-pref-personas-instruction-following", dataset_split="train", dataset_revision="main", transform_fn=["preference_tokenize_v1"], transform_fn_args={} ), - DatasetConfig( + DatasetConfigV1( dataset_name="allenai/tulu-3-wildchat-reused-on-policy-70b", dataset_split="train", dataset_revision="main", @@ -799,6 +876,7 @@ def test_preference_dataset(): if __name__ == "__main__": + test_sft_dpo_same_tokenizer() test_config_hash_different() test_sft_dataset_caching() test_sft_different_transform() From f22f99279a726cb3557314cd09226e7cc9c248f1 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 07:12:26 -0800 Subject: [PATCH 12/32] refactor --- open_instruct/dataset_transformation.py | 76 +++++++++++-------------- 1 file changed, 32 insertions(+), 44 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 52e75b681..f56e4517a 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -95,7 +95,7 @@ def visualize_token(tokens: list[int], tokenizer: PreTrainedTokenizer): # ---------------------------------------------------------------------------- # Tokenization -def get_tokenizer_simple_v1(tc: 'TokenizerConfigV1'): +def get_tokenizer_simple_v1(tc: 'TokenizerConfig'): tokenizer = AutoTokenizer.from_pretrained( tc.model_name_or_path, revision=tc.revision, @@ -104,7 +104,7 @@ def get_tokenizer_simple_v1(tc: 'TokenizerConfigV1'): ) return tokenizer -def get_tokenizer_tulu_v1(tc: 'TokenizerConfigV1'): +def get_tokenizer_tulu_v1(tc: 'TokenizerConfig'): tokenizer = AutoTokenizer.from_pretrained( tc.model_name_or_path, revision=tc.revision, @@ -182,7 +182,7 @@ def get_tokenizer_tulu_v1(tc: 'TokenizerConfigV1'): @dataclass -class TokenizerConfigV1: +class TokenizerConfig: model_name_or_path: str revision: str trust_remote_code: bool = True @@ -196,19 +196,7 @@ class TokenizerConfigV1: def __post_init__(self): self.tokenizer_commit_hash = get_commit_hash(self.model_name_or_path, self.revision, filename="tokenizer_config.json") - self.tokenizer = GET_TOKENIZER_FN[self.get_tokenizer_fn](self) - # TODO: test it out: PPO should have the sametokenizer as SFT / DPO. - # # create a tokenizer (pad from right) - # config = AutoConfig.from_pretrained(model_config.model_name_or_path, revision=model_config.model_revision) - # tokenizer = AutoTokenizer.from_pretrained( - # model_config.model_name_or_path, revision=model_config.model_revision, padding_side="right" - # ) - # if config.architectures == "LlamaForCausalLM" and config.bos_token_id == 128000: - # tokenizer.pad_token_id = 128002 # <|reserved_special_token_0|> - # else: - # tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # NOTE: we do not resize the embedding - # if dataset_config.chat_template is not None: - # tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template] + self.tokenizer = GET_TOKENIZER_FN[self.get_tokenizer_fn](self) # TODO: for testing, we should load the tokenizer from the sft / dpo / rl and make sure they are all the same. @@ -461,7 +449,7 @@ def preference_tulu_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenize # ---------------------------------------------------------------------------- # Dataset Configuration and Caching @dataclass -class DatasetConfigV1: +class DatasetConfig: dataset_name: str dataset_split: str dataset_revision: str @@ -489,7 +477,7 @@ def update_range(self, dataset_range: int): raise ValueError("Dataset range exceeds dataset length") self.dataset = self.dataset.select(range(self.dataset_range)) -def get_dataset_v1(dc: DatasetConfigV1, tc: TokenizerConfigV1): +def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): # beaker specific logic; we may get assigned 15.5 CPU, so we convert it to float then int num_proc = int(float(os.environ.get("BEAKER_ASSIGNED_CPU_COUNT", multiprocessing.cpu_count()))) @@ -531,7 +519,7 @@ class DatasetTransformationCache: def __init__(self, hf_entity: Optional[str] = None): self.hf_entity = hf_entity or HfApi().whoami()["name"] - def compute_config_hash(self, dcs: List[DatasetConfigV1], tc: TokenizerConfigV1) -> str: + def compute_config_hash(self, dcs: List[DatasetConfig], tc: TokenizerConfig) -> str: """Compute a deterministic hash of both configs for caching.""" dc_dicts = [ {k: v for k, v in asdict(dc).items() if v is not None} @@ -545,7 +533,7 @@ def compute_config_hash(self, dcs: List[DatasetConfigV1], tc: TokenizerConfigV1) config_str = json.dumps(combined_dict, sort_keys=True) return hashlib.sha256(config_str.encode()).hexdigest()[:10] - def load_or_transform_dataset(self, dcs: List[DatasetConfigV1], tc: TokenizerConfigV1) -> Dataset: + def load_or_transform_dataset(self, dcs: List[DatasetConfig], tc: TokenizerConfig) -> Dataset: """Load dataset from cache if it exists, otherwise transform and cache it.""" config_hash = self.compute_config_hash(dcs, tc) repo_name = f"{self.hf_entity}/dataset-mix-cached" @@ -594,12 +582,12 @@ def load_or_transform_dataset(self, dcs: List[DatasetConfigV1], tc: TokenizerCon ## Configuration -`TokenizerConfigV1`: +`TokenizerConfig`: ```json {json.dumps(asdict(tc), indent=2)} ``` -`List[DatasetConfigV1]`: +`List[DatasetConfig]`: ```json {json.dumps([asdict(dc) for dc in dcs], indent=2)} ``` @@ -615,11 +603,11 @@ def load_or_transform_dataset(self, dcs: List[DatasetConfigV1], tc: TokenizerCon ) -def get_cached_dataset(dcs: List[DatasetConfigV1], tc: TokenizerConfigV1, hf_entity: Optional[str] = None) -> Dataset: +def get_cached_dataset(dcs: List[DatasetConfig], tc: TokenizerConfig, hf_entity: Optional[str] = None) -> Dataset: cache = DatasetTransformationCache(hf_entity=hf_entity) return cache.load_or_transform_dataset(dcs, tc) -def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConfigV1, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: +def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConfig, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: dcs = [] assert len(dataset_mixer_list) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer_list}" for i in range(0, len(dataset_mixer_list), 2): @@ -630,7 +618,7 @@ def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConf else: frac_or_num_samples = int(frac_or_num_samples) - dataset_config = DatasetConfigV1( + dataset_config = DatasetConfig( dataset_name=dataset_name, dataset_split="train", dataset_revision="main", @@ -652,7 +640,7 @@ def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConf return cache.load_or_transform_dataset(dcs, tc) -def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: TokenizerConfigV1, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: +def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: TokenizerConfig, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: dcs = [] assert len(dataset_mixer_list) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer_list}" for i in range(0, len(dataset_mixer_list), 2): @@ -663,7 +651,7 @@ def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: Tokeni else: frac_or_num_samples = int(frac_or_num_samples) - dataset_config = DatasetConfigV1( + dataset_config = DatasetConfig( dataset_name=dataset_name, dataset_split="train", dataset_revision="main", @@ -686,17 +674,17 @@ def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: Tokeni def test_sft_dpo_same_tokenizer(): - base_to_sft_tc = TokenizerConfigV1( + base_to_sft_tc = TokenizerConfig( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" ) - sft_to_dpo_tc = TokenizerConfigV1( + sft_to_dpo_tc = TokenizerConfig( model_name_or_path="allenai/Llama-3.1-Tulu-3-8B-SFT", revision="main", chat_template_name="tulu" ) - dpo_to_rl_tc = TokenizerConfigV1( + dpo_to_rl_tc = TokenizerConfig( model_name_or_path="allenai/Llama-3.1-Tulu-3-8B-DPO", revision="main", chat_template_name="tulu" @@ -719,14 +707,14 @@ def equal_tokenizer(tc1, tc2): def test_config_hash_different(): """Test that different configurations produce different hashes.""" - tc = TokenizerConfigV1( + tc = TokenizerConfig( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" ) dcs1 = [ - DatasetConfigV1( + DatasetConfig( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", @@ -736,7 +724,7 @@ def test_config_hash_different(): ] dcs2 = [ - DatasetConfigV1( + DatasetConfig( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", @@ -752,21 +740,21 @@ def test_config_hash_different(): def test_sft_dataset_caching(): """Test caching functionality for SFT datasets.""" - tc = TokenizerConfigV1( + tc = TokenizerConfig( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" ) dcs = [ - DatasetConfigV1( + DatasetConfig( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_v1"], transform_fn_args={} ), - DatasetConfigV1( + DatasetConfig( dataset_name="allenai/tulu-3-hard-coded-10x", dataset_split="train", dataset_revision="main", @@ -786,21 +774,21 @@ def test_sft_dataset_caching(): def test_sft_different_transform(): """Test different transform functions produce different cached datasets.""" - tc = TokenizerConfigV1( + tc = TokenizerConfig( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" ) dcs = [ - DatasetConfigV1( + DatasetConfig( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_mask_out_prompt_v1"], transform_fn_args={} ), - DatasetConfigV1( + DatasetConfig( dataset_name="allenai/tulu-3-hard-coded-10x", dataset_split="train", dataset_revision="main", @@ -815,7 +803,7 @@ def test_sft_different_transform(): def test_sft_filter(): """Test different transform functions produce different cached datasets.""" - tc = TokenizerConfigV1( + tc = TokenizerConfig( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" @@ -823,7 +811,7 @@ def test_sft_filter(): ARBITRARY_MAX_LENGTH = 1000 dcs = [ - DatasetConfigV1( + DatasetConfig( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", @@ -848,21 +836,21 @@ def test_sft_filter(): def test_preference_dataset(): """Test caching functionality for preference datasets.""" - tc = TokenizerConfigV1( + tc = TokenizerConfig( model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" ) dcs_pref = [ - DatasetConfigV1( + DatasetConfig( dataset_name="allenai/tulu-3-pref-personas-instruction-following", dataset_split="train", dataset_revision="main", transform_fn=["preference_tokenize_v1"], transform_fn_args={} ), - DatasetConfigV1( + DatasetConfig( dataset_name="allenai/tulu-3-wildchat-reused-on-policy-70b", dataset_split="train", dataset_revision="main", From a82f5745ce164c349ee1ae2c523b0931325b316b Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 07:17:09 -0800 Subject: [PATCH 13/32] quick fix --- open_instruct/dpo_tune_cache1.py | 6 ++---- open_instruct/finetune1.py | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/open_instruct/dpo_tune_cache1.py b/open_instruct/dpo_tune_cache1.py index c056e8df1..18603edbc 100644 --- a/open_instruct/dpo_tune_cache1.py +++ b/open_instruct/dpo_tune_cache1.py @@ -58,7 +58,7 @@ ) from open_instruct.dataset_processor import CHAT_TEMPLATES -from open_instruct.dataset_transformation import TokenizerConfigV1, get_cached_dataset_tulu_preference +from open_instruct.dataset_transformation import TokenizerConfig, get_cached_dataset_tulu_preference from open_instruct.dpo_utils import ( DataCollatorForSeq2SeqDPO, concatenated_forward, @@ -67,12 +67,10 @@ simpo_loss, wpo_loss, ) -from open_instruct.finetune import encode_sft_example from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate from open_instruct.utils import ( ArgumentParserPlus, clean_last_n_checkpoints, - get_datasets, get_last_checkpoint_path, get_wandb_tags, is_beaker_job, @@ -531,7 +529,7 @@ def main(args: FlatArguments): warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different from the model revision `{args.model_revision}`.""" logger.warning(warning) - tc = TokenizerConfigV1( + tc = TokenizerConfig( model_name_or_path=tokenizer_name, revision=args.model_revision, use_fast=not args.use_slow_tokenizer, diff --git a/open_instruct/finetune1.py b/open_instruct/finetune1.py index c7f3c61a3..cd424ad7b 100644 --- a/open_instruct/finetune1.py +++ b/open_instruct/finetune1.py @@ -54,7 +54,7 @@ ) from open_instruct.dataset_processor import CHAT_TEMPLATES -from open_instruct.dataset_transformation import TokenizerConfigV1, get_cached_dataset_tulu_sft +from open_instruct.dataset_transformation import TokenizerConfig, get_cached_dataset_tulu_sft from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate from open_instruct.utils import ( ArgumentParserPlus, @@ -451,7 +451,7 @@ def main(args: FlatArguments): warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different from the model revision `{args.model_revision}`.""" logger.warning(warning) - tc = TokenizerConfigV1( + tc = TokenizerConfig( model_name_or_path=tokenizer_name, revision=args.model_revision, use_fast=not args.use_slow_tokenizer, From 90d18467158d4b53c229935c9ebe0bd93d87d32d Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 08:30:45 -0800 Subject: [PATCH 14/32] add an option to cache dataset only --- open_instruct/dataset_transformation.py | 71 ++++++++++++++++++++++++- open_instruct/dpo_tune_cache1.py | 36 +++++-------- open_instruct/finetune1.py | 24 +++++---- 3 files changed, 96 insertions(+), 35 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index f56e4517a..dfcdb4809 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -63,8 +63,6 @@ from datasets import Dataset, load_dataset, concatenate_datasets from huggingface_hub import HfApi, revision_exists -from open_instruct.dataset_processor import CHAT_TEMPLATES - # ---------------------------------------------------------------------------- # Utilities @@ -95,6 +93,75 @@ def visualize_token(tokens: list[int], tokenizer: PreTrainedTokenizer): # ---------------------------------------------------------------------------- # Tokenization +# Chat templates +# flake8: noqa +# note we added `{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}` +# because we want the template to not output eos_token if `add_generation_prompt=True` +CHAT_TEMPLATES = { + "simple_concat_with_space": ( + "{% for message in messages %}" + "{{ ' ' if not loop.first else '' }}" + "{{ message['content'] }}" + "{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}" + "{% endfor %}" + ), + "simple_concat_with_new_line": ( + "{% for message in messages %}" + "{{ '\n' if not loop.first else '' }}" + "{{ message['content'] }}" + "{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}" + "{% endfor %}" + ), + "simple_chat": ( + "{% for message in messages %}" + "{{ '\n\n' if not loop.first else '' }}" + "{{ message['role'].capitalize() + ': ' + message['content'] }}" + "{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}" + "{% endfor %}" + ), + "assistant_message_only": ( + "{% for message in messages %}" + "{% if message['role'] == 'assistant' %}" + "{{ message['content'] }}" + "{% endif %}" + "{% endfor %}" + ), + "zephyr": ( + "{% for message in messages %}" + "{% if message['role'] == 'user' %}" + "{{ '<|user|>\n' + message['content'] + eos_token + '\n' }}" + "{% elif message['role'] == 'system' %}" + "{{ '<|system|>\n' + message['content'] + eos_token + '\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ '<|assistant|>\n' + message['content'] + eos_token + '\n' }}" + "{% endif %}" + "{% if loop.last and add_generation_prompt %}" + "{{ '<|assistant|>\n' }}" + "{% endif %}" + "{% endfor %}" + ), + "tulu": ( + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{{ '<|system|>\n' + message['content'] + '\n' }}" + "{% elif message['role'] == 'user' %}" + "{{ '<|user|>\n' + message['content'] + '\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{% if not loop.last %}" + "{{ '<|assistant|>\n' + message['content'] + eos_token + '\n' }}" + "{% else %}" + "{{ '<|assistant|>\n' + message['content'] + eos_token }}" + "{% endif %}" + "{% endif %}" + "{% if loop.last and add_generation_prompt %}" + "{{ '<|assistant|>\n' }}" + "{% endif %}" + "{% endfor %}" + ), +} +# flake8: noqa + + def get_tokenizer_simple_v1(tc: 'TokenizerConfig'): tokenizer = AutoTokenizer.from_pretrained( tc.model_name_or_path, diff --git a/open_instruct/dpo_tune_cache1.py b/open_instruct/dpo_tune_cache1.py index 18603edbc..2305ca394 100644 --- a/open_instruct/dpo_tune_cache1.py +++ b/open_instruct/dpo_tune_cache1.py @@ -57,8 +57,7 @@ get_scheduler, ) -from open_instruct.dataset_processor import CHAT_TEMPLATES -from open_instruct.dataset_transformation import TokenizerConfig, get_cached_dataset_tulu_preference +from open_instruct.dataset_transformation import CHAT_TEMPLATES, TokenizerConfig, get_cached_dataset_tulu_preference from open_instruct.dpo_utils import ( DataCollatorForSeq2SeqDPO, concatenated_forward, @@ -373,6 +372,8 @@ class FlatArguments: default=0.001, metadata={"help": "Weight for load balancing loss if applicable."}, ) + cache_dataset_only: bool = False + """Immediately exit after caching the dataset""" concatenated_forward: bool = True """Whether to concatenate chosen and rejected for DPO training; True is good but you can set to False for saving memory.""" try_auto_save_to_beaker: bool = True @@ -537,13 +538,16 @@ def main(args: FlatArguments): add_bos=args.add_bos, ) tokenizer = tc.tokenizer - train_dataset = get_cached_dataset_tulu_preference( - args.dataset_mixer_list, - tc, - args.max_seq_length, - ) - train_dataset.shuffle(seed=args.seed) - train_dataset.set_format(type="pt") + with accelerator.main_process_first(): + train_dataset = get_cached_dataset_tulu_preference( + args.dataset_mixer_list, + tc, + args.max_seq_length, + ) + train_dataset.shuffle(seed=args.seed) + train_dataset.set_format(type="pt") + if args.cache_dataset_only: + return # Load pretrained model and tokenizer if args.config_name: @@ -630,20 +634,6 @@ def load_model(): elif args.gradient_checkpointing: model.gradient_checkpointing_enable() - # # Preprocessing the datasets. - # if "prompt" in raw_datasets["train"].column_names and "completion" in raw_datasets["train"].column_names: - # raise ValueError("Sorry, prompt-completion format is not supported for DPO training.") - # elif "chosen" in raw_datasets["train"].column_names and "rejected" in raw_datasets["train"].column_names: - # encode_function = partial( - # encode_dpo_example, - # tokenizer=tokenizer, - # max_seq_length=args.max_seq_length, - # ) - # else: - # raise ValueError("You need to have 'chosen' and 'rejected in your column names.") - - # train_dataset = raw_datasets["train"] - # debugging tool for fewer samples if args.max_train_samples is not None: max_train_samples = min(len(train_dataset), args.max_train_samples) diff --git a/open_instruct/finetune1.py b/open_instruct/finetune1.py index cd424ad7b..367136476 100644 --- a/open_instruct/finetune1.py +++ b/open_instruct/finetune1.py @@ -53,8 +53,7 @@ get_scheduler, ) -from open_instruct.dataset_processor import CHAT_TEMPLATES -from open_instruct.dataset_transformation import TokenizerConfig, get_cached_dataset_tulu_sft +from open_instruct.dataset_transformation import CHAT_TEMPLATES, TokenizerConfig, get_cached_dataset_tulu_sft from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate from open_instruct.utils import ( ArgumentParserPlus, @@ -339,6 +338,8 @@ class FlatArguments: default=0.5, metadata={"help": "Weight for load balancing loss if applicable."}, ) + cache_dataset_only: bool = False + """Immediately exit after caching the dataset""" try_auto_save_to_beaker: bool = True """Whether to try to save the model to Beaker dataset `/output` after training""" push_to_hub: bool = True @@ -459,14 +460,17 @@ def main(args: FlatArguments): add_bos=args.add_bos, ) tokenizer = tc.tokenizer - train_dataset = get_cached_dataset_tulu_sft( - args.dataset_mixer_list, - tc, - args.max_seq_length, - ) - train_dataset.shuffle(seed=args.seed) - train_dataset.set_format(type="pt") - + with accelerator.main_process_first(): + train_dataset = get_cached_dataset_tulu_sft( + args.dataset_mixer_list, + tc, + args.max_seq_length, + ) + train_dataset.shuffle(seed=args.seed) + train_dataset.set_format(type="pt") + if args.cache_dataset_only: + return + # Load pretrained model and tokenizer if args.config_name: config = AutoConfig.from_pretrained( From b11afedb325fc6b9d297eef9f50fbd5f1e62e47e Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 08:52:26 -0800 Subject: [PATCH 15/32] Add some logic for caching. --- mason.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/mason.py b/mason.py index e31dcc90d..813e4aea9 100644 --- a/mason.py +++ b/mason.py @@ -412,7 +412,7 @@ def get_datasets(beaker_datasets, cluster: List[str]): return res -def make_task_spec(args, command, i, beaker_secrets, whoami, resumable: bool): +def make_task_spec(args, command: List[str], i: int, beaker_secrets: str, whoami: str, resumable: bool): # special logic to deal with escape like # python mason.py ... -- python x.py --dataset_mixer '{"trl-internal-testing/sentiment-trl-style": 1.0}' # we need to wrap the json string with single quote @@ -427,6 +427,28 @@ def make_task_spec(args, command, i, beaker_secrets, whoami, resumable: bool): "git config --global safe.directory '*' && " # fix the permission issue with git "umask 000 && " # fix the permission issue with the cache folder ) + + # HACK: Cache dataset logic: + # Here we basically try to run the tokenization full_command locally before running it on beaker + # We could in theory submit a cpu only job to beaker to do this, but that requires setting up + # dependency jobs somehow. Since tokenization is like ~5 minutes, we can just run it locally. + # Once it's cached, we don't need to cache it again. + + def find_list_idx(lst: List[str], item: str): + for i in range(len(lst)): + if item == lst[i]: + return i + return -1 + + idx = find_list_idx(full_command, "open_instruct/finetune1.py") + if idx != -1: + # then try executing the same full_command with + caching_command = "python " + " ".join(full_command[idx:]) + " --cache_dataset_only" + print(f"📦📦📦 Running the caching full_command: {caching_command}") + os.system(caching_command) + print("✅✅✅ Finished running the caching full_command") + + if not args.pure_docker_mode: setup_commands += f"cd {os.getcwd()} && " From c17c9cf83ae674bc882fa2378822daa14430a33c Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 09:58:41 -0800 Subject: [PATCH 16/32] make mason work with the latest change --- mason.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mason.py b/mason.py index 813e4aea9..8e26504a3 100644 --- a/mason.py +++ b/mason.py @@ -440,14 +440,14 @@ def find_list_idx(lst: List[str], item: str): return i return -1 - idx = find_list_idx(full_command, "open_instruct/finetune1.py") - if idx != -1: - # then try executing the same full_command with - caching_command = "python " + " ".join(full_command[idx:]) + " --cache_dataset_only" - print(f"📦📦📦 Running the caching full_command: {caching_command}") - os.system(caching_command) - print("✅✅✅ Finished running the caching full_command") - + for file in ["open_instruct/finetune1.py", "open_instruct/dpo_tune_cache1.py"]: + idx = find_list_idx(full_command, file) + if idx != -1: + # then try executing the same full_command with + caching_command = "python " + " ".join(full_command[idx:]) + " --cache_dataset_only" + print(f"📦📦📦 Running the caching full_command: {caching_command}") + os.system(caching_command) + print("✅✅✅ Finished running the caching full_command") if not args.pure_docker_mode: setup_commands += f"cd {os.getcwd()} && " From 25c636a376f54f69cf624da06217643e2f36becc Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 10:43:34 -0800 Subject: [PATCH 17/32] Use the latest dataset caching logic --- docs/ai2_internal.md | 226 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 198 insertions(+), 28 deletions(-) diff --git a/docs/ai2_internal.md b/docs/ai2_internal.md index 36e4deb09..1081cd188 100644 --- a/docs/ai2_internal.md +++ b/docs/ai2_internal.md @@ -1,30 +1,214 @@ -# Job submissions +# Beaker job submission scripts This document details some best practices when submitting jobs in our cluster. -## Caching on Weka (Ai2-specific) +## First-time setup -Most of our cluster comes with shared file systems (e.g., [WEKA](https://beaker-docs.apps.allenai.org/)). To avoid downloading the same models hundreds and thousands of times, we should cache the models ande datasets in the shared file system. This can be done via +You need to first obtain API key or tokens from the following website: +* `BEAKER_TOKEN`: https://beaker.org/user +* `WANDB_API_KEY`: https://wandb.ai/authorize +* `HF_TOKEN`: https://huggingface.co/settings/tokens + +Then you need to write them in beaker secret as follows (replace the `xxxx` with your own API key or token) ```bash -python mason.py \ - --cluster ai2/jupiter-cirrascale-2 ai2/saturn-cirrascale ai2/neptune-cirrascale --image nathanl/open_instruct_auto --pure_docker_mode \ - --workspace ai2/tulu-3-dev \ - --priority normal \ - --preemptible \ - --budget ai2/allennlp \ - --gpus 0 -- python scripts/cache_hf.py \ - --model_name_or_path "allenai/Llama-3.1-Tulu-3-8B-DPO" \ - --model_revision "1208_dpo_13b_tune8e-7__allenai_open_instruct_dev__8__1733807565" \ - --dataset_mixer_list allenai/RLVR-GSM-MATH-IF-Mixed-Constraints 1.0 +beaker_whoami=$(beaker account whoami --format json | jq -r '.[0].name') +beaker secret write -w ai2/tulu-2-improvements "${beaker_whoami}_BEAKER_TOKEN" xxxx +beaker secret write -w ai2/tulu-2-improvements "${beaker_whoami}_WANDB_API_KEY" xxxx +beaker secret write -w ai2/tulu-2-improvements "${beaker_whoami}_HF_TOKEN" xxxx ``` +## Job submission + `mason.py` is our job submission script. It takes in the command after `--` and runs it in the specified clusters. During the job submission, it automatically tries to setup a shared Hugging Face cache with environment variables. For example, it sets * `HF_HOME=/weka/oe-adapt-default/allennlp/.cache/huggingface`. * `HF_DATASETS_CACHE=/weka/oe-adapt-default/allennlp/.cache/huggingface` * `HF_HUB_CACHE=/weka/oe-adapt-default/allennlp/.cache/hub` -As a result, the `allenai/Llama-3.1-Tulu-3-8B-DPO` and `allenai/RLVR-GSM-MATH-IF-Mixed-Constraints` will be cached in the shared file system. + +You can run things like below for a quick spin. + +```bash +python mason.py \ + --cluster ai2/jupiter-cirrascale-2 ai2/saturn-cirrascale ai2/neptune-cirrascale \ + --workspace ai2/tulu-3-dev \ + --image nathanl/open_instruct_auto --pure_docker_mode \ + --priority normal \ + --budget ai2/oe-adapt \ + --gpus 1 -- python -c "import torch; print(torch.cuda.is_available())" +``` + +### Supervised Fine-tuning (SFT): + +```bash +python mason.py \ + --cluster ai2/jupiter-cirrascale-2 ai2/saturn-cirrascale \ + --workspace ai2/tulu-3-dev \ + --priority normal \ + --preemptible \ + --num_nodes 4 \ + --budget ai2/oe-adapt \ + --gpus 8 -- accelerate launch \ + --mixed_precision bf16 \ + --num_processes 8 \ + --use_deepspeed \ + --deepspeed_config_file configs/ds_configs/stage3_no_offloading_accelerate.conf \ + --deepspeed_multinode_launcher standard \ + open_instruct/finetune1.py \ + --model_name_or_path meta-llama/Llama-3.1-8B \ + --tokenizer_name meta-llama/Llama-3.1-8B \ + --use_slow_tokenizer \ + --use_flash_attn \ + --max_seq_length 4096 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 4 \ + --learning_rate 5e-06 \ + --lr_scheduler_type linear \ + --warmup_ratio 0.03 \ + --weight_decay 0.0 \ + --num_train_epochs 2 \ + --output_dir output/sft_8b \ + --with_tracking \ + --report_to wandb \ + --logging_steps 1 \ + --reduce_loss sum \ + --model_revision main \ + --dataset_mixer_list allenai/tulu-3-sft-mixture 100 \ + --dataset_mix_dir output/sft_8b \ + --exp_name tulu-3-8b-sft \ + --seed 123 +``` + +Note that during job submission, we will try to tokenize and cache the dataset so we are not running these CPU-heavy workloads in GPU jobs. Specifically, `mason.py` will parse out `python` command you are running and attempts to run it with `--cache_dataset_only` flag. For example, you will see output like + +```bash +📦📦📦 Running the caching full_command: python open_instruct/dpo_tune_cache1.py --model_name_or_path allenai/Llama-3.1-Tulu-3-8B-SFT --use_flash_attn --tokenizer_name allenai/Llama-3.1-Tulu-3-8B-SFT --max_seq_length 2048 --preprocessing_num_workers 16 --per_device_train_batch_size 1 --gradient_accumulation_steps 16 --learning_rate 5e-07 --lr_scheduler_type linear --warmup_ratio 0.1 --weight_decay 0.0 --num_train_epochs 1 --output_dir output/dpo_8b --with_tracking --report_to wandb --logging_steps 1 --model_revision main --gradient_checkpointing --dataset_mixer_list allenai/llama-3.1-tulu-3-8b-preference-mixture 1.0 --use_slow_tokenizer --use_lora False --dpo_loss_type dpo_norm --dpo_beta 5 --exp_name tulu-3-8b-dpo --cache_dataset_only +[2025-01-21 09:58:09,342] [WARNING] [real_accelerator.py:162:get_accelerator] Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it. +[2025-01-21 09:58:09,354] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cpu (auto detect) +Failed to get Beaker experiment: b'Error: experiment "01JD3WCQYTBPE195GVWPVMDHVV" not found\n\n' + +# .... +Cache not found, transforming datasets... +Map (num_proc=192): 100%|██████████████████| 271409/271409 [00:40<00:00, 6690.00 examples/s] +Filter (num_proc=192): 100%|███████████████| 271409/271409 [00:36<00:00, 7492.34 examples/s] +Creating parquet from Arrow format: 100%|███████████████████| 17/17 [00:02<00:00, 7.26ba/s] +Creating parquet from Arrow format: 100%|███████████████████| 17/17 [00:02<00:00, 7.77ba/s +🚀 Pushed transformed dataset to https://huggingface.co/datasets/vwxyzjn/dataset-mix-cached/tree/992c2b51ba +✅ Found cached dataset at https://huggingface.co/datasets/vwxyzjn/dataset-mix-cached/tree/992c2b51ba + +# ... + +Kicked off Beaker job. https://beaker.org/ex/01JJ50D88M757GZD14W9CNN7NT +``` + +It would be most helpful if you run the `mason.py` command on a vscode session with access to weka, that way, the dataset is also automatically downloaded to `HF_HOME`, etc. + +When you inspect the job, it's going to have the following outputs, meaning the cached dataset is found and used: + +``` +2025-01-21T18:02:04.840723691Z +2025-01-21T18:02:05.948433221Z ✅ Found cached dataset at https://huggingface.co/datasets/vwxyzjn/dataset-mix-cached/tree/992c2b51ba +2025-01-21T18:02:06.120806993Z ✅ Found cached dataset at https://huggingface.co/datasets/vwxyzjn/dataset-mix-cached/tree/992c2b51ba +2025-01-21T18:02:06.190569046Z ✅ Found cached dataset at https://huggingface.co/datasets/vwxyzjn/dataset-mix-cached/tree/992c2b51ba +2025-01-21T18:02:06.197208582Z ✅ Found cached dataset at https://huggingface.co/datasets/vwxyzjn/dataset-mix-cached/tree/992c2b51ba +2025-01-21T18:02:06.333301775Z ✅ Found cached dataset at https://huggingface.co/datasets/vwxyzjn/dataset-mix-cached/tree/992c2b51ba +2025-01-21T18:02:06.338503095Z ✅ Found cached dataset at https://huggingface.co/datasets/vwxyzjn/dataset-mix-cached/tree/992c2b51ba +2025-01-21T18:02:06.385010439Z ✅ Found cached dataset at https://huggingface.co/datasets/vwxyzjn/dataset-mix-cached/tree/992c2b51ba +``` + + +### Direct Preference Optimization (DPO): + + +```bash +python mason.py \ + --cluster ai2/jupiter-cirrascale-2 ai2/saturn-cirrascale \ + --workspace ai2/tulu-3-dev \ + --priority normal \ + --preemptible \ + --num_nodes 4 \ + --budget ai2/oe-adapt \ + --gpus 8 -- accelerate launch \ + --mixed_precision bf16 \ + --num_processes 8 \ + --use_deepspeed \ + --deepspeed_config_file configs/ds_configs/stage3_no_offloading_accelerate.conf \ + --deepspeed_multinode_launcher standard \ + open_instruct/dpo_tune_cache1.py \ + --model_name_or_path allenai/Llama-3.1-Tulu-3-8B-SFT \ + --use_flash_attn \ + --tokenizer_name allenai/Llama-3.1-Tulu-3-8B-SFT \ + --max_seq_length 2048 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 16 \ + --learning_rate 5e-07 \ + --lr_scheduler_type linear \ + --warmup_ratio 0.1 \ + --weight_decay 0.0 \ + --num_train_epochs 1 \ + --output_dir /output \ + --with_tracking \ + --report_to wandb \ + --logging_steps 1 \ + --model_revision main \ + --gradient_checkpointing \ + --dataset_mixer_list allenai/llama-3.1-tulu-3-8b-preference-mixture 1.0 \ + --use_slow_tokenizer \ + --use_lora False \ + --dpo_loss_type dpo_norm \ + --dpo_beta 5 \ + --exp_name tulu-3-8b-dpo +``` + +## RLVR: + +```bash +python mason.py \ + --cluster ai2/jupiter-cirrascale-2 ai2/saturn-cirrascale \ + --workspace ai2/tulu-3-dev \ + --priority high \ + --preemptible \ + --budget ai2/oe-adapt \ + --gpus 8 -- python open_instruct/ppo_vllm_thread_ray_gtrl.py \ + --exp_name tulu-3-8b-rlvr \ + --dataset_mixer '{"allenai/RLVR-GSM-MATH-IF-Mixed-Constraints": 1.0}' \ + --dataset_train_splits train \ + --dataset_eval_mixer '{"allenai/RLVR-GSM-MATH-IF-Mixed-Constraints": 128}' \ + --dataset_eval_splits train \ + --max_token_length 2048 \ + --max_prompt_token_length 2048 \ + --response_length 2048 \ + --model_name_or_path allenai/Llama-3.1-Tulu-3-8B-DPO \ + --reward_model_path allenai/Llama-3.1-Tulu-3-8B-RM \ + --non_stop_penalty \ + --stop_token eos \ + --temperature 1.0 \ + --ground_truths_key ground_truth \ + --chat_template tulu \ + --sft_messages_key messages \ + --learning_rate 3e-7 \ + --total_episodes 10000000 \ + --penalty_reward_value -10.0 \ + --deepspeed_stage 3 \ + --per_device_train_batch_size 1 \ + --local_rollout_forward_batch_size 1 \ + --local_mini_batch_size 4 \ + --local_rollout_batch_size 4 \ + --actor_num_gpus_per_node 6 \ + --vllm_tensor_parallel_size 2 \ + --vllm_enforce_eager \ + --beta 0.05 \ + --apply_verifiable_reward true \ + --output_dir /weka/oe-adapt-default/costah/models/tulu-3-8b-rlvr/test \ + --seed 3 \ + --num_evals 3 \ + --save_freq 2 \ + --reward_model_multiplier 0.0 \ + --gradient_checkpointing \ + --try_launch_beaker_eval_jobs_on_weka \ + --with_tracking +``` + ### Ai2 Internal Evaluation @@ -105,20 +289,6 @@ gantry run --workspace {your_workspace} --cluster {cluster} --gpus {n_gpus} \ Because the `--report_output_path` argument is set to `/results/metrics.json`, the output will automatically get logged to Beaker metrics in the experiment page ([example](https://beaker.org/ex/01HW8NKZ458MA1PSB1X4YQTH94/tasks/01HW8NKZ4DTDA8FEFDGWA7Q8XX/job/01HW8NM2QR5AYB53PYP32J2VAA)). -### Gantry One-Time Setup - -Before you can use gantry, there are a couple of things to set up. For the workspace you use, ensure it is owned by the `ai2` organization, or gantry won't be able to create the experiments. - -1. Run `pip install beaker-gantry beaker-py` -2. Create a [GitHub personal access token](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) with "repo" scope -3. Go to https://github.com/settings/tokens and authorize your token to configure SSO access to the allenai organization -4. Run `gantry config set-gh-token` and paste the token created above when prompted -5. Create a [HuggingFace access token](https://huggingface.co/settings/tokens) with "read" scope (this is used to authenticate for using restricted models like Llama series) -6. Run `beaker secret write --workspace {your_workspace} hf_token {your_token}` -7. Obtain an OpenAI API key and run `beaker secret write --workspace {your_workspace} openai_api_key {your_api_key} - -Doing these steps once will set up your workspace to use gantry. - ### Common Gotchas From 2c2c085ed180d4c66d6175f51523dd1c74ed1846 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 10:46:01 -0800 Subject: [PATCH 18/32] restore change --- open_instruct/finetune.py | 233 +++++++++++++++++++++++++++++++++++--- 1 file changed, 217 insertions(+), 16 deletions(-) diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index da25c0af9..30ee740da 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -54,7 +54,6 @@ ) from open_instruct.dataset_processor import CHAT_TEMPLATES -from open_instruct.dataset_transformation import TokenizerConfigV1, get_cached_dataset_tulu from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate from open_instruct.utils import ( ArgumentParserPlus, @@ -383,6 +382,79 @@ def __post_init__(self): raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.") +def encode_sft_example(example, tokenizer, max_seq_length): + """ + This function encodes a single example into a format that can be used for sft training. + Here, we assume each example has a 'messages' field. Each message in it is a dict with 'role' and 'content' fields. + We use the `apply_chat_template` function from the tokenizer to tokenize the messages and prepare the input and label tensors. + """ + messages = example["messages"] + if len(messages) == 0: + raise ValueError("messages field is empty.") + input_ids = tokenizer.apply_chat_template( + conversation=messages, + tokenize=True, + return_tensors="pt", + padding=False, + truncation=True, + max_length=max_seq_length, + add_generation_prompt=False, + ) + labels = input_ids.clone() + # mask the non-assistant part for avoiding loss + for message_idx, message in enumerate(messages): + if message["role"] != "assistant": + # we calculate the start index of this non-assistant message + if message_idx == 0: + message_start_idx = 0 + else: + message_start_idx = tokenizer.apply_chat_template( + conversation=messages[:message_idx], # here marks the end of the previous messages + tokenize=True, + return_tensors="pt", + padding=False, + truncation=True, + max_length=max_seq_length, + add_generation_prompt=False, + ).shape[1] + # next, we calculate the end index of this non-assistant message + if message_idx < len(messages) - 1 and messages[message_idx + 1]["role"] == "assistant": + # for intermediate messages that follow with an assistant message, we need to + # set `add_generation_prompt=True` to avoid the assistant generation prefix being included in the loss + # (e.g., `<|assistant|>`) + message_end_idx = tokenizer.apply_chat_template( + conversation=messages[: message_idx + 1], + tokenize=True, + return_tensors="pt", + padding=False, + truncation=True, + max_length=max_seq_length, + add_generation_prompt=True, + ).shape[1] + else: + # for the last message or the message that doesn't follow with an assistant message, + # we don't need to add the assistant generation prefix + message_end_idx = tokenizer.apply_chat_template( + conversation=messages[: message_idx + 1], + tokenize=True, + return_tensors="pt", + padding=False, + truncation=True, + max_length=max_seq_length, + add_generation_prompt=False, + ).shape[1] + # set the label to -100 for the non-assistant part + labels[:, message_start_idx:message_end_idx] = -100 + if max_seq_length and message_end_idx >= max_seq_length: + break + attention_mask = torch.ones_like(input_ids) + return { + "input_ids": input_ids.flatten(), + "labels": labels.flatten(), + "attention_mask": attention_mask.flatten(), + } + + def main(args: FlatArguments): # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers @@ -443,21 +515,40 @@ def main(args: FlatArguments): accelerator.wait_for_everyone() - - tc = TokenizerConfigV1( - model_name_or_path=args.model_name_or_path, - revision=args.model_revision, - use_fast=not args.use_slow_tokenizer, - chat_template_name=args.chat_template_name, - add_bos=args.add_bos, - ) - tokenizer = tc.tokenizer - train_dataset = get_cached_dataset_tulu( - args.dataset_mixer_list, - tc, - args.max_seq_length, - ) - train_dataset.shuffle(seed=args.seed) + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset( + args.dataset_name, + args.dataset_config_name, + ) + elif args.dataset_mixer is not None: + # mixing datasets via config + raw_datasets = get_datasets( + args.dataset_mixer, + configs=args.dataset_config_name, + splits=["train"], + save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None, + columns_to_keep=["messages"], + ) + elif args.dataset_mixer_list is not None: + # mixing datasets via config + raw_datasets = get_datasets( + args.dataset_mixer_list, + configs=args.dataset_config_name, + splits=["train"], + save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None, + columns_to_keep=["messages"], + ) + else: + data_files = {} + dataset_args = {} + if args.train_file is not None: + data_files["train"] = args.train_file + raw_datasets = load_dataset( + "json", + data_files=data_files, + **dataset_args, + ) # Load pretrained model and tokenizer if args.config_name: @@ -477,6 +568,34 @@ def main(args: FlatArguments): "You are instantiating a new config instance from scratch. This is not supported by this script." ) + tokenizer_revision = args.model_revision if args.tokenizer_revision is None else args.tokenizer_revision + if tokenizer_revision != args.model_revision: + # Warn user if tokenizer and model use different revisions; this is an unusual + # use case. + warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different + from the model revision `{args.model_revision}`.""" + logger.warning(warning) + + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name, + revision=tokenizer_revision, + trust_remote_code=args.trust_remote_code, + use_fast=not args.use_slow_tokenizer, + ) + elif args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.model_name_or_path, + revision=tokenizer_revision, + trust_remote_code=args.trust_remote_code, + use_fast=not args.use_slow_tokenizer, + ) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script." + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + if args.model_name_or_path: if args.use_qlora: bnb_config = BitsAndBytesConfig( @@ -513,6 +632,42 @@ def main(args: FlatArguments): logger.info("Training new model from scratch") model = AutoModelForCausalLM.from_config(config) + # no default pad token for llama! + # here we add all special tokens again, because the default ones are not in the special_tokens_map + if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, LlamaTokenizerFast): + num_added_tokens = tokenizer.add_special_tokens( + { + "bos_token": "", + "eos_token": "", + "unk_token": "", + "pad_token": "", + } + ) + assert num_added_tokens in [ + 0, + 1, + ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present." + elif isinstance(tokenizer, GPTNeoXTokenizerFast): + # OLMo newer models use this tokenizer + if tokenizer.bos_token is None: + tokenizer.bos_token = tokenizer.eos_token + assert ( + args.add_bos + ), "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence." + # else, pythia / other models + else: + num_added_tokens = tokenizer.add_special_tokens( + { + "pad_token": "", + } + ) + assert num_added_tokens == 1, "GPTNeoXTokenizer should only add one special token - the pad_token." + elif isinstance(tokenizer, GPT2Tokenizer) and isinstance(model, OPTForCausalLM): + num_added_tokens = tokenizer.add_special_tokens({"unk_token": ""}) + elif isinstance(tokenizer, transformers.PreTrainedTokenizerFast) and tokenizer.pad_token is None: + num_added_tokens = tokenizer.add_special_tokens({"pad_token": ""}) + assert num_added_tokens == 1, "We detected no padding token but add_special_tokens did not add one." + # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch # on a small vocab and want a smaller embedding size, remove this test. # gather deepspeed to get "real" embedding size @@ -528,6 +683,27 @@ def main(args: FlatArguments): with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None): embedding_size = embeddings.weight.shape[0] + # set the tokenizer chat template to the training format + # this will be used for encoding the training examples + # and saved together with the tokenizer to be used later. + if args.chat_template_name in CHAT_TEMPLATES: + tokenizer.chat_template = CHAT_TEMPLATES[args.chat_template_name] + else: + try: + tokenizer.chat_template = AutoTokenizer.from_pretrained(args.chat_template_name).chat_template + except Exception: + raise ValueError(f"Could not find chat template for {args.chat_template_name}.") + + if args.add_bos: + if tokenizer.chat_template.startswith("{{ bos_token }}") or ( + tokenizer.bos_token is not None and tokenizer.chat_template.startswith(tokenizer.bos_token) + ): + raise ValueError( + "You specified add_bos=True, but the chat template already has a bos_token at the beginning." + ) + # also add bos in the chat template if not already there + tokenizer.chat_template = "{{ bos_token }}" + tokenizer.chat_template + if args.use_lora: if args.use_qlora: model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing) @@ -546,6 +722,31 @@ def main(args: FlatArguments): elif args.gradient_checkpointing: model.gradient_checkpointing_enable() + train_dataset = raw_datasets["train"] + # debugging tool for fewer samples + if args.max_train_samples is not None: + max_train_samples = min(len(train_dataset), args.max_train_samples) + logger.info(f"Limiting training samples to {max_train_samples} from {len(train_dataset)}.") + train_dataset = train_dataset.select(range(max_train_samples)) + + with accelerator.main_process_first(): + train_dataset = train_dataset.map( + partial(encode_sft_example, tokenizer=tokenizer, max_seq_length=args.max_seq_length), + batched=False, + num_proc=args.preprocessing_num_workers, + load_from_cache_file=not args.overwrite_cache, + remove_columns=[ + name for name in train_dataset.column_names if name not in ["input_ids", "labels", "attention_mask"] + ], + desc="Tokenizing and reformatting instruction data", + ) + train_dataset.set_format(type="pt") + train_dataset = train_dataset.filter(lambda example: (example["labels"] != -100).any()) + + # Log a few random samples from the training set: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + # DataLoaders creation: train_dataloader = DataLoader( train_dataset, From b911f0a24183a7eb68838506e0800b8c9ddbd8fe Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 11:23:17 -0800 Subject: [PATCH 19/32] push docs --- docs/ai2_internal.md | 43 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/docs/ai2_internal.md b/docs/ai2_internal.md index 1081cd188..17e70bb05 100644 --- a/docs/ai2_internal.md +++ b/docs/ai2_internal.md @@ -35,7 +35,7 @@ python mason.py \ --image nathanl/open_instruct_auto --pure_docker_mode \ --priority normal \ --budget ai2/oe-adapt \ - --gpus 1 -- python -c "import torch; print(torch.cuda.is_available())" + --gpus 1 -- which python ``` ### Supervised Fine-tuning (SFT): @@ -79,6 +79,47 @@ python mason.py \ --seed 123 ``` + +```bash +python mason.py \ + --image nathanl/open_instruct_auto --pure_docker_mode \ + --cluster ai2/augusta-google-1 \ + --workspace ai2/tulu-3-dev \ + --priority normal \ + --preemptible \ + --num_nodes 4 \ + --budget ai2/oe-adapt \ + --gpus 8 -- accelerate launch \ + --mixed_precision bf16 \ + --num_processes 8 \ + --use_deepspeed \ + --deepspeed_config_file configs/ds_configs/stage3_no_offloading_accelerate.conf \ + --deepspeed_multinode_launcher standard \ + open_instruct/finetune1.py \ + --model_name_or_path meta-llama/Llama-3.1-8B \ + --tokenizer_name meta-llama/Llama-3.1-8B \ + --use_slow_tokenizer \ + --use_flash_attn \ + --max_seq_length 4096 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 4 \ + --learning_rate 5e-06 \ + --lr_scheduler_type linear \ + --warmup_ratio 0.03 \ + --weight_decay 0.0 \ + --num_train_epochs 2 \ + --output_dir output/sft_8b \ + --with_tracking \ + --report_to wandb \ + --logging_steps 1 \ + --reduce_loss sum \ + --model_revision main \ + --dataset_mixer_list allenai/tulu-3-sft-mixture 100 \ + --dataset_mix_dir output/sft_8b \ + --exp_name tulu-3-8b-sft \ + --seed 123 +``` + Note that during job submission, we will try to tokenize and cache the dataset so we are not running these CPU-heavy workloads in GPU jobs. Specifically, `mason.py` will parse out `python` command you are running and attempts to run it with `--cache_dataset_only` flag. For example, you will see output like ```bash From 952611e3d0c3810d69d36733f85cad4b3a6803f3 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 14:33:08 -0800 Subject: [PATCH 20/32] update docs --- docs/ai2_internal.md | 40 ---------------------------------------- 1 file changed, 40 deletions(-) diff --git a/docs/ai2_internal.md b/docs/ai2_internal.md index 17e70bb05..a4a93aa98 100644 --- a/docs/ai2_internal.md +++ b/docs/ai2_internal.md @@ -80,46 +80,6 @@ python mason.py \ ``` -```bash -python mason.py \ - --image nathanl/open_instruct_auto --pure_docker_mode \ - --cluster ai2/augusta-google-1 \ - --workspace ai2/tulu-3-dev \ - --priority normal \ - --preemptible \ - --num_nodes 4 \ - --budget ai2/oe-adapt \ - --gpus 8 -- accelerate launch \ - --mixed_precision bf16 \ - --num_processes 8 \ - --use_deepspeed \ - --deepspeed_config_file configs/ds_configs/stage3_no_offloading_accelerate.conf \ - --deepspeed_multinode_launcher standard \ - open_instruct/finetune1.py \ - --model_name_or_path meta-llama/Llama-3.1-8B \ - --tokenizer_name meta-llama/Llama-3.1-8B \ - --use_slow_tokenizer \ - --use_flash_attn \ - --max_seq_length 4096 \ - --per_device_train_batch_size 1 \ - --gradient_accumulation_steps 4 \ - --learning_rate 5e-06 \ - --lr_scheduler_type linear \ - --warmup_ratio 0.03 \ - --weight_decay 0.0 \ - --num_train_epochs 2 \ - --output_dir output/sft_8b \ - --with_tracking \ - --report_to wandb \ - --logging_steps 1 \ - --reduce_loss sum \ - --model_revision main \ - --dataset_mixer_list allenai/tulu-3-sft-mixture 100 \ - --dataset_mix_dir output/sft_8b \ - --exp_name tulu-3-8b-sft \ - --seed 123 -``` - Note that during job submission, we will try to tokenize and cache the dataset so we are not running these CPU-heavy workloads in GPU jobs. Specifically, `mason.py` will parse out `python` command you are running and attempts to run it with `--cache_dataset_only` flag. For example, you will see output like ```bash From 3e4c0fb49509b7162aecce01513d626527f4b2e1 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 14:34:39 -0800 Subject: [PATCH 21/32] quick update --- mason.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/mason.py b/mason.py index 8e26504a3..274d5be29 100644 --- a/mason.py +++ b/mason.py @@ -105,8 +105,9 @@ def get_args(): parser.add_argument( "--resumable", action="store_true", help="If given, make the job resumable" ) - - + parser.add_argument( + "--no_auto_dataset_cache", action="store_true", help="If given, don't cache the dataset automatically" + ) # Split up the mason args from the Python args. mason_args, command_args = parser.parse_known_args() commands = parse_commands(command_args) @@ -433,21 +434,20 @@ def make_task_spec(args, command: List[str], i: int, beaker_secrets: str, whoami # We could in theory submit a cpu only job to beaker to do this, but that requires setting up # dependency jobs somehow. Since tokenization is like ~5 minutes, we can just run it locally. # Once it's cached, we don't need to cache it again. - def find_list_idx(lst: List[str], item: str): for i in range(len(lst)): if item == lst[i]: return i return -1 - - for file in ["open_instruct/finetune1.py", "open_instruct/dpo_tune_cache1.py"]: - idx = find_list_idx(full_command, file) - if idx != -1: - # then try executing the same full_command with - caching_command = "python " + " ".join(full_command[idx:]) + " --cache_dataset_only" - print(f"📦📦📦 Running the caching full_command: {caching_command}") - os.system(caching_command) - print("✅✅✅ Finished running the caching full_command") + if not args.no_auto_dataset_cache: + for file in ["open_instruct/finetune1.py", "open_instruct/dpo_tune_cache1.py"]: + idx = find_list_idx(full_command, file) + if idx != -1: + # then try executing the same full_command with + caching_command = "python " + " ".join(full_command[idx:]) + " --cache_dataset_only" + print(f"📦📦📦 Running the caching full_command: {caching_command}") + os.system(caching_command) + print("✅✅✅ Finished running the caching full_command") if not args.pure_docker_mode: setup_commands += f"cd {os.getcwd()} && " From 01da6ab578fab902291b15b5edd1f24fa57fdc8b Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 14:38:55 -0800 Subject: [PATCH 22/32] quick push --- open_instruct/dpo_tune_cache1.py | 2 ++ open_instruct/finetune1.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/open_instruct/dpo_tune_cache1.py b/open_instruct/dpo_tune_cache1.py index 2305ca394..c478261ea 100644 --- a/open_instruct/dpo_tune_cache1.py +++ b/open_instruct/dpo_tune_cache1.py @@ -538,6 +538,8 @@ def main(args: FlatArguments): add_bos=args.add_bos, ) tokenizer = tc.tokenizer + if args.dataset_mixer is not None: + args.dataset_mixer_list = [item for pair in args.dataset_mixer.items() for item in pair] with accelerator.main_process_first(): train_dataset = get_cached_dataset_tulu_preference( args.dataset_mixer_list, diff --git a/open_instruct/finetune1.py b/open_instruct/finetune1.py index 367136476..45f222996 100644 --- a/open_instruct/finetune1.py +++ b/open_instruct/finetune1.py @@ -385,6 +385,7 @@ def __post_init__(self): def main(args: FlatArguments): + breakpoint() # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers # in the environment @@ -460,6 +461,8 @@ def main(args: FlatArguments): add_bos=args.add_bos, ) tokenizer = tc.tokenizer + if args.dataset_mixer is not None: + args.dataset_mixer_list = [item for pair in args.dataset_mixer.items() for item in pair] with accelerator.main_process_first(): train_dataset = get_cached_dataset_tulu_sft( args.dataset_mixer_list, From 60a7ecbb330c4b2396f1b2372449c5b0e1131f29 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 14:41:27 -0800 Subject: [PATCH 23/32] push --- open_instruct/finetune1.py | 1 - 1 file changed, 1 deletion(-) diff --git a/open_instruct/finetune1.py b/open_instruct/finetune1.py index 45f222996..ada18b103 100644 --- a/open_instruct/finetune1.py +++ b/open_instruct/finetune1.py @@ -385,7 +385,6 @@ def __post_init__(self): def main(args: FlatArguments): - breakpoint() # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers # in the environment From ad5ec2b385417260cc3c8008f0e40ac4a98129cc Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 14:42:12 -0800 Subject: [PATCH 24/32] Just replace the existing dpo / finetune --- open_instruct/dpo_tune_cache.py | 231 +++++------------------------ open_instruct/finetune.py | 251 ++++---------------------------- 2 files changed, 71 insertions(+), 411 deletions(-) diff --git a/open_instruct/dpo_tune_cache.py b/open_instruct/dpo_tune_cache.py index e95673bae..c478261ea 100644 --- a/open_instruct/dpo_tune_cache.py +++ b/open_instruct/dpo_tune_cache.py @@ -57,7 +57,7 @@ get_scheduler, ) -from open_instruct.dataset_processor import CHAT_TEMPLATES +from open_instruct.dataset_transformation import CHAT_TEMPLATES, TokenizerConfig, get_cached_dataset_tulu_preference from open_instruct.dpo_utils import ( DataCollatorForSeq2SeqDPO, concatenated_forward, @@ -66,12 +66,10 @@ simpo_loss, wpo_loss, ) -from open_instruct.finetune import encode_sft_example from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate from open_instruct.utils import ( ArgumentParserPlus, clean_last_n_checkpoints, - get_datasets, get_last_checkpoint_path, get_wandb_tags, is_beaker_job, @@ -374,6 +372,8 @@ class FlatArguments: default=0.001, metadata={"help": "Weight for load balancing loss if applicable."}, ) + cache_dataset_only: bool = False + """Immediately exit after caching the dataset""" concatenated_forward: bool = True """Whether to concatenate chosen and rejected for DPO training; True is good but you can set to False for saving memory.""" try_auto_save_to_beaker: bool = True @@ -421,31 +421,6 @@ def __post_init__(self): raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.") -def encode_dpo_example(example, tokenizer, max_seq_length): - """ - Here we assume each example has a rejected and chosen field, both of which are a list of messages. - Each message is a dict with 'role' and 'content' fields. - We assume only the last message is different, and the prompt is contained in the list of messages. - """ - chosen_messages = example["chosen"] - rejected_messages = example["rejected"] - if len(chosen_messages) == 0: - raise ValueError("chosen messages field is empty.") - if len(rejected_messages) == 0: - raise ValueError("rejected messages field is empty.") - - chosen_encoded = encode_sft_example({"messages": chosen_messages}, tokenizer, max_seq_length) - rejected_encoded = encode_sft_example({"messages": rejected_messages}, tokenizer, max_seq_length) - - return { - "chosen_input_ids": chosen_encoded["input_ids"], - "chosen_labels": chosen_encoded["labels"], - "chosen_attention_mask": chosen_encoded["attention_mask"], - "rejected_input_ids": rejected_encoded["input_ids"], - "rejected_labels": rejected_encoded["labels"], - "rejected_attention_mask": rejected_encoded["attention_mask"], - } - def get_cache_ref_logprobs( model: torch.nn.Module, @@ -485,7 +460,9 @@ def get_cache_ref_logprobs( def main(args: FlatArguments): - init_gpu_memory = torch.cuda.mem_get_info()[0] + init_gpu_memory = None + if torch.cuda.is_available(): + init_gpu_memory = torch.cuda.mem_get_info()[0] # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers # in the environment @@ -545,40 +522,34 @@ def main(args: FlatArguments): accelerator.wait_for_everyone() - if args.dataset_name is not None: - # Downloading and loading a dataset from the hub. - raw_datasets = load_dataset( - args.dataset_name, - args.dataset_config_name, - ) - elif args.dataset_mixer is not None: - # mixing datasets via config - raw_datasets = get_datasets( - args.dataset_mixer, - configs=args.dataset_config_name, - splits=["train"], - save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None, - columns_to_keep=["chosen", "rejected"], - ) - elif args.dataset_mixer_list is not None: - # mixing datasets via config - raw_datasets = get_datasets( + tokenizer_revision = args.model_revision if args.tokenizer_revision is None else args.tokenizer_revision + tokenizer_name = args.tokenizer_name if args.tokenizer_name is not None else args.model_name_or_path + if tokenizer_revision != args.model_revision: + # Warn user if tokenizer and model use different revisions; this is an unusual + # use case. + warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different + from the model revision `{args.model_revision}`.""" + logger.warning(warning) + tc = TokenizerConfig( + model_name_or_path=tokenizer_name, + revision=args.model_revision, + use_fast=not args.use_slow_tokenizer, + chat_template_name=args.chat_template_name, + add_bos=args.add_bos, + ) + tokenizer = tc.tokenizer + if args.dataset_mixer is not None: + args.dataset_mixer_list = [item for pair in args.dataset_mixer.items() for item in pair] + with accelerator.main_process_first(): + train_dataset = get_cached_dataset_tulu_preference( args.dataset_mixer_list, - configs=args.dataset_config_name, - splits=["train"], - save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None, - columns_to_keep=["chosen", "rejected"], - ) - else: - data_files = {} - dataset_args = {} - if args.train_file is not None: - data_files["train"] = args.train_file - raw_datasets = load_dataset( - "json", - data_files=data_files, - **dataset_args, + tc, + args.max_seq_length, ) + train_dataset.shuffle(seed=args.seed) + train_dataset.set_format(type="pt") + if args.cache_dataset_only: + return # Load pretrained model and tokenizer if args.config_name: @@ -598,34 +569,6 @@ def main(args: FlatArguments): "You are instantiating a new config instance from scratch. This is not supported by this script." ) - tokenizer_revision = args.model_revision if args.tokenizer_revision is None else args.tokenizer_revision - if tokenizer_revision != args.model_revision: - # Warn user if tokenizer and model use different revisions; this is an unusual - # use case. - warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different - from the model revision `{args.model_revision}`.""" - logger.warning(warning) - - if args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer_name, - revision=tokenizer_revision, - trust_remote_code=args.trust_remote_code, - use_fast=not args.use_slow_tokenizer, - ) - elif args.model_name_or_path: - tokenizer = AutoTokenizer.from_pretrained( - args.model_name_or_path, - revision=tokenizer_revision, - trust_remote_code=args.trust_remote_code, - use_fast=not args.use_slow_tokenizer, - ) - else: - raise ValueError( - "You are instantiating a new tokenizer from scratch. This is not supported by this script." - "You can do it from another script, save it, and load it from here, using --tokenizer_name." - ) - def load_model(): if args.model_name_or_path: if args.use_qlora: @@ -666,43 +609,6 @@ def load_model(): model = load_model() print("=============model loaded") print_gpu_stats(init_gpu_memory) - # no default pad token for llama! - # here we add all special tokens again, because the default ones are not in the special_tokens_map - if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, LlamaTokenizerFast): - num_added_tokens = tokenizer.add_special_tokens( - { - "bos_token": "", - "eos_token": "", - "unk_token": "", - "pad_token": "", - } - ) - assert num_added_tokens in [ - 0, - 1, - ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present." - elif isinstance(tokenizer, GPTNeoXTokenizerFast): - # OLMo newer models use this tokenizer - if tokenizer.bos_token is None: - tokenizer.bos_token = tokenizer.eos_token - assert ( - args.add_bos - ), "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence." - # else, pythia / other models - else: - num_added_tokens = tokenizer.add_special_tokens( - { - "pad_token": "", - } - ) - assert ( - num_added_tokens <= 1 - ), "GPTNeoXTokenizer should only add one special token - the pad_token (or no tokens)." - elif isinstance(tokenizer, GPT2Tokenizer) and isinstance(model, OPTForCausalLM): - num_added_tokens = tokenizer.add_special_tokens({"unk_token": ""}) - elif isinstance(tokenizer, transformers.PreTrainedTokenizerFast) and tokenizer.pad_token is None: - num_added_tokens = tokenizer.add_special_tokens({"pad_token": ""}) - assert num_added_tokens == 1, "We detected no padding token but add_special_tokens did not add one." # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch # on a small vocab and want a smaller embedding size, remove this test. @@ -712,27 +618,6 @@ def load_model(): if len(tokenizer) > embeddings.weight.shape[0]: model.resize_token_embeddings(len(tokenizer)) - # set the tokenizer chat template to the training format - # this will be used for encoding the training examples - # and saved together with the tokenizer to be used later. - if args.chat_template_name in CHAT_TEMPLATES: - tokenizer.chat_template = CHAT_TEMPLATES[args.chat_template_name] - else: - try: - tokenizer.chat_template = AutoTokenizer.from_pretrained(args.chat_template_name).chat_template - except Exception: - raise ValueError(f"Could not find chat template for {args.chat_template_name}.") - - if args.add_bos: - if tokenizer.chat_template.startswith("{{ bos_token }}") or ( - tokenizer.bos_token is not None and tokenizer.chat_template.startswith(tokenizer.bos_token) - ): - raise ValueError( - "You specified add_bos=True, but the chat template already has a bos_token at the beginning." - ) - # add bos in the chat template if not already there - tokenizer.chat_template = "{{ bos_token }}" + tokenizer.chat_template - if args.use_lora: if args.use_qlora: model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing) @@ -751,51 +636,12 @@ def load_model(): elif args.gradient_checkpointing: model.gradient_checkpointing_enable() - # Preprocessing the datasets. - if "prompt" in raw_datasets["train"].column_names and "completion" in raw_datasets["train"].column_names: - raise ValueError("Sorry, prompt-completion format is not supported for DPO training.") - elif "chosen" in raw_datasets["train"].column_names and "rejected" in raw_datasets["train"].column_names: - encode_function = partial( - encode_dpo_example, - tokenizer=tokenizer, - max_seq_length=args.max_seq_length, - ) - else: - raise ValueError("You need to have 'chosen' and 'rejected in your column names.") - - train_dataset = raw_datasets["train"] - # debugging tool for fewer samples if args.max_train_samples is not None: max_train_samples = min(len(train_dataset), args.max_train_samples) logger.info(f"Limiting training samples to {max_train_samples} from {len(train_dataset)}.") train_dataset = train_dataset.select(range(max_train_samples)) - with accelerator.main_process_first(): - train_dataset = train_dataset.map( - encode_function, - batched=False, - num_proc=args.preprocessing_num_workers, - remove_columns=[ - name - for name in train_dataset.column_names - if name - not in [ - "chosen_input_ids", - "chosen_labels", - "chosen_attention_mask", - "rejected_input_ids", - "rejected_labels", - "rejected_attention_mask", - ] - ], - desc="Tokenizing and reformatting instruction data", - ) - train_dataset.set_format(type="pt") - # our thresholding mighta meant some examples have no labels, remove. - train_dataset = train_dataset.filter(lambda example: (example["chosen_labels"] != -100).any()) - train_dataset = train_dataset.filter(lambda example: (example["rejected_labels"] != -100).any()) - # Log a few random samples from the training set: for index in random.sample(range(len(train_dataset)), 3): logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") @@ -1216,12 +1062,13 @@ def load_model(): accelerator.end_training() -def print_gpu_stats(init_gpu_memory): - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() - peak_memory = init_gpu_memory - free_gpu_memory - print(f"Peak memory usage: {peak_memory / 1024 ** 3:.2f} GB") - print(f"Total memory usage: {total_gpu_memory / 1024 ** 3:.2f} GB") - print(f"Free memory: {free_gpu_memory / 1024 ** 3:.2f} GB") +def print_gpu_stats(init_gpu_memory: Optional[int]): + if torch.cuda.is_available(): + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + peak_memory = init_gpu_memory - free_gpu_memory + print(f"Peak memory usage: {peak_memory / 1024 ** 3:.2f} GB") + print(f"Total memory usage: {total_gpu_memory / 1024 ** 3:.2f} GB") + print(f"Free memory: {free_gpu_memory / 1024 ** 3:.2f} GB") if __name__ == "__main__": diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index 30ee740da..ada18b103 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -53,7 +53,7 @@ get_scheduler, ) -from open_instruct.dataset_processor import CHAT_TEMPLATES +from open_instruct.dataset_transformation import CHAT_TEMPLATES, TokenizerConfig, get_cached_dataset_tulu_sft from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate from open_instruct.utils import ( ArgumentParserPlus, @@ -338,6 +338,8 @@ class FlatArguments: default=0.5, metadata={"help": "Weight for load balancing loss if applicable."}, ) + cache_dataset_only: bool = False + """Immediately exit after caching the dataset""" try_auto_save_to_beaker: bool = True """Whether to try to save the model to Beaker dataset `/output` after training""" push_to_hub: bool = True @@ -382,79 +384,6 @@ def __post_init__(self): raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.") -def encode_sft_example(example, tokenizer, max_seq_length): - """ - This function encodes a single example into a format that can be used for sft training. - Here, we assume each example has a 'messages' field. Each message in it is a dict with 'role' and 'content' fields. - We use the `apply_chat_template` function from the tokenizer to tokenize the messages and prepare the input and label tensors. - """ - messages = example["messages"] - if len(messages) == 0: - raise ValueError("messages field is empty.") - input_ids = tokenizer.apply_chat_template( - conversation=messages, - tokenize=True, - return_tensors="pt", - padding=False, - truncation=True, - max_length=max_seq_length, - add_generation_prompt=False, - ) - labels = input_ids.clone() - # mask the non-assistant part for avoiding loss - for message_idx, message in enumerate(messages): - if message["role"] != "assistant": - # we calculate the start index of this non-assistant message - if message_idx == 0: - message_start_idx = 0 - else: - message_start_idx = tokenizer.apply_chat_template( - conversation=messages[:message_idx], # here marks the end of the previous messages - tokenize=True, - return_tensors="pt", - padding=False, - truncation=True, - max_length=max_seq_length, - add_generation_prompt=False, - ).shape[1] - # next, we calculate the end index of this non-assistant message - if message_idx < len(messages) - 1 and messages[message_idx + 1]["role"] == "assistant": - # for intermediate messages that follow with an assistant message, we need to - # set `add_generation_prompt=True` to avoid the assistant generation prefix being included in the loss - # (e.g., `<|assistant|>`) - message_end_idx = tokenizer.apply_chat_template( - conversation=messages[: message_idx + 1], - tokenize=True, - return_tensors="pt", - padding=False, - truncation=True, - max_length=max_seq_length, - add_generation_prompt=True, - ).shape[1] - else: - # for the last message or the message that doesn't follow with an assistant message, - # we don't need to add the assistant generation prefix - message_end_idx = tokenizer.apply_chat_template( - conversation=messages[: message_idx + 1], - tokenize=True, - return_tensors="pt", - padding=False, - truncation=True, - max_length=max_seq_length, - add_generation_prompt=False, - ).shape[1] - # set the label to -100 for the non-assistant part - labels[:, message_start_idx:message_end_idx] = -100 - if max_seq_length and message_end_idx >= max_seq_length: - break - attention_mask = torch.ones_like(input_ids) - return { - "input_ids": input_ids.flatten(), - "labels": labels.flatten(), - "attention_mask": attention_mask.flatten(), - } - - def main(args: FlatArguments): # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers @@ -515,41 +444,35 @@ def main(args: FlatArguments): accelerator.wait_for_everyone() - if args.dataset_name is not None: - # Downloading and loading a dataset from the hub. - raw_datasets = load_dataset( - args.dataset_name, - args.dataset_config_name, - ) - elif args.dataset_mixer is not None: - # mixing datasets via config - raw_datasets = get_datasets( - args.dataset_mixer, - configs=args.dataset_config_name, - splits=["train"], - save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None, - columns_to_keep=["messages"], - ) - elif args.dataset_mixer_list is not None: - # mixing datasets via config - raw_datasets = get_datasets( + tokenizer_revision = args.model_revision if args.tokenizer_revision is None else args.tokenizer_revision + tokenizer_name = args.tokenizer_name if args.tokenizer_name is not None else args.model_name_or_path + if tokenizer_revision != args.model_revision: + # Warn user if tokenizer and model use different revisions; this is an unusual + # use case. + warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different + from the model revision `{args.model_revision}`.""" + logger.warning(warning) + tc = TokenizerConfig( + model_name_or_path=tokenizer_name, + revision=args.model_revision, + use_fast=not args.use_slow_tokenizer, + chat_template_name=args.chat_template_name, + add_bos=args.add_bos, + ) + tokenizer = tc.tokenizer + if args.dataset_mixer is not None: + args.dataset_mixer_list = [item for pair in args.dataset_mixer.items() for item in pair] + with accelerator.main_process_first(): + train_dataset = get_cached_dataset_tulu_sft( args.dataset_mixer_list, - configs=args.dataset_config_name, - splits=["train"], - save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None, - columns_to_keep=["messages"], - ) - else: - data_files = {} - dataset_args = {} - if args.train_file is not None: - data_files["train"] = args.train_file - raw_datasets = load_dataset( - "json", - data_files=data_files, - **dataset_args, + tc, + args.max_seq_length, ) - + train_dataset.shuffle(seed=args.seed) + train_dataset.set_format(type="pt") + if args.cache_dataset_only: + return + # Load pretrained model and tokenizer if args.config_name: config = AutoConfig.from_pretrained( @@ -568,34 +491,6 @@ def main(args: FlatArguments): "You are instantiating a new config instance from scratch. This is not supported by this script." ) - tokenizer_revision = args.model_revision if args.tokenizer_revision is None else args.tokenizer_revision - if tokenizer_revision != args.model_revision: - # Warn user if tokenizer and model use different revisions; this is an unusual - # use case. - warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different - from the model revision `{args.model_revision}`.""" - logger.warning(warning) - - if args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer_name, - revision=tokenizer_revision, - trust_remote_code=args.trust_remote_code, - use_fast=not args.use_slow_tokenizer, - ) - elif args.model_name_or_path: - tokenizer = AutoTokenizer.from_pretrained( - args.model_name_or_path, - revision=tokenizer_revision, - trust_remote_code=args.trust_remote_code, - use_fast=not args.use_slow_tokenizer, - ) - else: - raise ValueError( - "You are instantiating a new tokenizer from scratch. This is not supported by this script." - "You can do it from another script, save it, and load it from here, using --tokenizer_name." - ) - if args.model_name_or_path: if args.use_qlora: bnb_config = BitsAndBytesConfig( @@ -632,42 +527,6 @@ def main(args: FlatArguments): logger.info("Training new model from scratch") model = AutoModelForCausalLM.from_config(config) - # no default pad token for llama! - # here we add all special tokens again, because the default ones are not in the special_tokens_map - if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, LlamaTokenizerFast): - num_added_tokens = tokenizer.add_special_tokens( - { - "bos_token": "", - "eos_token": "", - "unk_token": "", - "pad_token": "", - } - ) - assert num_added_tokens in [ - 0, - 1, - ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present." - elif isinstance(tokenizer, GPTNeoXTokenizerFast): - # OLMo newer models use this tokenizer - if tokenizer.bos_token is None: - tokenizer.bos_token = tokenizer.eos_token - assert ( - args.add_bos - ), "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence." - # else, pythia / other models - else: - num_added_tokens = tokenizer.add_special_tokens( - { - "pad_token": "", - } - ) - assert num_added_tokens == 1, "GPTNeoXTokenizer should only add one special token - the pad_token." - elif isinstance(tokenizer, GPT2Tokenizer) and isinstance(model, OPTForCausalLM): - num_added_tokens = tokenizer.add_special_tokens({"unk_token": ""}) - elif isinstance(tokenizer, transformers.PreTrainedTokenizerFast) and tokenizer.pad_token is None: - num_added_tokens = tokenizer.add_special_tokens({"pad_token": ""}) - assert num_added_tokens == 1, "We detected no padding token but add_special_tokens did not add one." - # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch # on a small vocab and want a smaller embedding size, remove this test. # gather deepspeed to get "real" embedding size @@ -683,27 +542,6 @@ def main(args: FlatArguments): with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None): embedding_size = embeddings.weight.shape[0] - # set the tokenizer chat template to the training format - # this will be used for encoding the training examples - # and saved together with the tokenizer to be used later. - if args.chat_template_name in CHAT_TEMPLATES: - tokenizer.chat_template = CHAT_TEMPLATES[args.chat_template_name] - else: - try: - tokenizer.chat_template = AutoTokenizer.from_pretrained(args.chat_template_name).chat_template - except Exception: - raise ValueError(f"Could not find chat template for {args.chat_template_name}.") - - if args.add_bos: - if tokenizer.chat_template.startswith("{{ bos_token }}") or ( - tokenizer.bos_token is not None and tokenizer.chat_template.startswith(tokenizer.bos_token) - ): - raise ValueError( - "You specified add_bos=True, but the chat template already has a bos_token at the beginning." - ) - # also add bos in the chat template if not already there - tokenizer.chat_template = "{{ bos_token }}" + tokenizer.chat_template - if args.use_lora: if args.use_qlora: model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing) @@ -722,31 +560,6 @@ def main(args: FlatArguments): elif args.gradient_checkpointing: model.gradient_checkpointing_enable() - train_dataset = raw_datasets["train"] - # debugging tool for fewer samples - if args.max_train_samples is not None: - max_train_samples = min(len(train_dataset), args.max_train_samples) - logger.info(f"Limiting training samples to {max_train_samples} from {len(train_dataset)}.") - train_dataset = train_dataset.select(range(max_train_samples)) - - with accelerator.main_process_first(): - train_dataset = train_dataset.map( - partial(encode_sft_example, tokenizer=tokenizer, max_seq_length=args.max_seq_length), - batched=False, - num_proc=args.preprocessing_num_workers, - load_from_cache_file=not args.overwrite_cache, - remove_columns=[ - name for name in train_dataset.column_names if name not in ["input_ids", "labels", "attention_mask"] - ], - desc="Tokenizing and reformatting instruction data", - ) - train_dataset.set_format(type="pt") - train_dataset = train_dataset.filter(lambda example: (example["labels"] != -100).any()) - - # Log a few random samples from the training set: - for index in random.sample(range(len(train_dataset)), 3): - logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") - # DataLoaders creation: train_dataloader = DataLoader( train_dataset, @@ -888,7 +701,6 @@ def main(args: FlatArguments): print(f"Starting from epoch {starting_epoch} and step {completed_steps}.") # update the progress_bar if load from checkpoint progress_bar.update(completed_steps) - local_total_tokens = torch.tensor(0, dtype=torch.int64, device=accelerator.device) total_token_including_padding = torch.tensor(0, dtype=torch.int64, device=accelerator.device) start_time = time.time() @@ -909,6 +721,7 @@ def main(args: FlatArguments): if args.load_balancing_loss: outputs = model(**batch, use_cache=False, output_router_logits=True) else: + # TODO: we have calculated the mean loss here anyway, so doubling the calculation outputs = model(**batch, use_cache=False) if args.reduce_loss == "mean": loss = outputs.loss @@ -1002,7 +815,7 @@ def main(args: FlatArguments): os.path.join(get_last_checkpoint_path(args, incomplete=True), "COMPLETED"), "w" ) as f: f.write("COMPLETED") # annoyingly, empty files arent uploaded by beaker. - if accelerator.is_local_main_process: + if accelerator.is_local_main_process: # TODO: in mason local model this is gonna error out if using something like output/test; because mason used the same shared file ssytem. clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints) accelerator.wait_for_everyone() From dacc3f201f740d270fa9cda686a19baf4a0a87ea Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 14:43:43 -0800 Subject: [PATCH 25/32] remove unused files --- open_instruct/dpo_tune.py | 1202 ------------------------------ open_instruct/dpo_tune_cache1.py | 1077 -------------------------- open_instruct/finetune1.py | 929 ----------------------- 3 files changed, 3208 deletions(-) delete mode 100644 open_instruct/dpo_tune.py delete mode 100644 open_instruct/dpo_tune_cache1.py delete mode 100644 open_instruct/finetune1.py diff --git a/open_instruct/dpo_tune.py b/open_instruct/dpo_tune.py deleted file mode 100644 index a4bd874ad..000000000 --- a/open_instruct/dpo_tune.py +++ /dev/null @@ -1,1202 +0,0 @@ -# !/usr/bin/env python -# coding=utf-8 -# Copyright 2024 AllenAI. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -DPO tuning script. Adapted from our finetuning script. -""" - -import json -import logging -import math -import os -import random -import shutil -import subprocess -import time -from copy import deepcopy -from dataclasses import dataclass, field -from datetime import timedelta -from functools import partial -from typing import List, Optional, Union - -import datasets -import deepspeed -import torch -import transformers -from accelerate import Accelerator, DataLoaderConfiguration -from accelerate.logging import get_logger -from accelerate.utils import InitProcessGroupKwargs, set_seed -from datasets import load_dataset -from huggingface_hub import HfApi -from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training -from torch.utils.data import DataLoader -from tqdm.auto import tqdm -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoTokenizer, - BitsAndBytesConfig, - GPT2Tokenizer, - GPTNeoXTokenizerFast, - LlamaTokenizer, - LlamaTokenizerFast, - OPTForCausalLM, - get_scheduler, -) - -from open_instruct.dataset_processor import CHAT_TEMPLATES -from open_instruct.dpo_utils import ( - DataCollatorForSeq2SeqDPO, - concatenated_forward, - dpo_loss, - simpo_loss, - wpo_loss, -) -from open_instruct.finetune import encode_sft_example -from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate -from open_instruct.utils import ( - ArgumentParserPlus, - clean_last_n_checkpoints, - get_datasets, - get_last_checkpoint_path, - get_wandb_tags, - is_beaker_job, - maybe_get_beaker_config, - maybe_use_ai2_hf_entity, - maybe_use_ai2_wandb_entity, - upload_metadata_to_hf, -) - -logger = get_logger(__name__) - - -@dataclass -class FlatArguments: - """ - Full arguments class for all fine-tuning jobs. - """ - - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """The name of this experiment""" - run_name: Optional[str] = None - """A unique name of this run""" - model_name_or_path: Optional[str] = field( - default=None, - metadata={ - "help": ( - "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch." - ) - }, - ) - config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} - ) - dpo_use_paged_optimizer: bool = field( - default=False, - metadata={ - "help": "Use paged optimizer from bitsandbytes." - " Not compatible with deepspeed (use deepspeed config instead)." - }, - ) - dpo_beta: float = field( - default=0.1, - metadata={"help": "Beta parameter for DPO loss. Default is 0.1."}, - ) - dpo_loss_type: str = field( - default="dpo", - metadata={"help": "Type of DPO loss to use. Options are 'dpo', 'dpo_norm', 'simpo', 'wpo'."}, - ) - dpo_gamma_beta_ratio: float = field( - default=0.3, - metadata={"help": "Gamma to beta ratio for SimPO loss. Default is 0.3. Not used for DPO loss."}, - ) - dpo_label_smoothing: float = field( - default=0.0, - metadata={"help": "Label smoothing for DPO/SimPO loss. Default is 0 (no smoothing)."}, - ) - tokenizer_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} - ) - tokenizer_revision: Optional[str] = field( - default=None, - metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, - ) - chat_template_name: str = field( - default="tulu", - metadata={ - "help": ( - f"The name of the chat template to use. " - f"You can choose one of our pre-defined templates: {', '.join(CHAT_TEMPLATES.keys())}." - f"Or, you can provide a tokenizer name or path here and we will apply its chat template." - ) - }, - ) - use_flash_attn: bool = field( - default=True, - metadata={"help": "Whether to use flash attention in the model training"}, - ) - use_slow_tokenizer: bool = field( - default=True, - metadata={"help": "Whether to use one of the slow tokenizer or not (which is then fast tokenizer)."}, - ) - model_revision: Optional[str] = field( - default=None, - metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, - ) - trust_remote_code: bool = field( - default=False, - metadata={ - "help": ( - "Whether or not to allow for custom models defined on the Hub in their own modeling files. " - "This option should only be set to `True` for repositories you trust and in which you " - "have read the code, as it will execute code present on the Hub on your local machine." - ) - }, - ) - low_cpu_mem_usage: bool = field( - default=False, - metadata={ - "help": ( - "It is an option to create the model as an empty shell, " - "then only materialize its parameters when the pretrained weights are loaded. " - "set True will benefit LLM loading time and RAM consumption." - ) - }, - ) - dataset_name: Optional[str] = field( - default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} - ) - dataset_mixer: Optional[dict] = field( - default=None, metadata={"help": "A dictionary of datasets (local or HF) to sample from."} - ) - dataset_mixer_list: Optional[list[str]] = field( - default=None, metadata={"help": "A list of datasets (local or HF) to sample from."} - ) - dataset_mix_dir: Optional[str] = field( - default=None, metadata={"help": "The directory to save the mixed dataset to disk."} - ) - dataset_config_name: Optional[str] = field( - default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} - ) - train_file: Optional[str] = field( - default=None, metadata={"help": "The input training data file (a json/jsonl file)."} - ) - max_train_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of training examples to this " - "value if set." - ) - }, - ) - preprocessing_num_workers: Optional[int] = field( - default=None, - metadata={"help": "The number of processes to use for the preprocessing."}, - ) - max_seq_length: Optional[int] = field( - default=None, - metadata={ - "help": ( - "The maximum total input sequence length after tokenization. " - "Sequences longer than this will be truncated," - ) - }, - ) - overwrite_cache: bool = field( - default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} - ) - add_bos: bool = field( - default=False, - metadata={ - "help": "Forcibly add bos token to the beginning of the input sequence." - " Use only when tokenizer does not add bos token by default." - }, - ) - clip_grad_norm: float = field( - default=-1, - metadata={"help": "Clip gradient norm. Not compatible with deepspeed (use deepspeed config instead)."}, - ) - gradient_accumulation_steps: int = field( - default=1, - metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, - ) - learning_rate: float = field( - default=2e-5, - metadata={"help": "The initial learning rate for AdamW optimizer."}, - ) - logging_steps: Optional[int] = field( - default=None, - metadata={"help": "Log the training loss and learning rate every logging_steps steps."}, - ) - lora_rank: int = field( - default=64, - metadata={"help": "The rank of lora."}, - ) - lora_alpha: float = field( - default=16, - metadata={"help": "The alpha parameter of lora."}, - ) - lora_dropout: float = field( - default=0.1, - metadata={"help": "The dropout rate of lora modules."}, - ) - lr_scheduler_type: str = field( - default="linear", - metadata={ - "help": "The scheduler type to use for learning rate adjustment.", - "choices": ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], - }, - ) - num_train_epochs: int = field( - default=2, - metadata={"help": "Total number of training epochs to perform."}, - ) - output_dir: str = field( - default="output/", - metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, - ) - per_device_train_batch_size: int = field( - default=8, - metadata={"help": "Batch size per GPU/TPU core/CPU for training."}, - ) - use_lora: bool = field( - default=False, - metadata={"help": "If True, will use LORA (low-rank parameter-efficient training) to train the model."}, - ) - use_qlora: bool = field( - default=False, - metadata={"help": "Use qLoRA training - initializes model in quantized form. Not compatible with deepspeed."}, - ) - use_8bit_optimizer: bool = field( - default=False, - metadata={"help": "Use 8bit optimizer from bitsandbytes. Not compatible with deepspeed."}, - ) - warmup_ratio: float = field( - default=0.03, - metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}, - ) - weight_decay: float = field( - default=0.0, - metadata={"help": "Weight decay for AdamW if we apply some."}, - ) - timeout: int = field( - default=1800, - metadata={ - "help": "Timeout for the training process in seconds." - "Useful if tokenization process is long. Default is 1800 seconds (30 minutes)." - }, - ) - reduce_loss: str = field( - default="mean", - metadata={ - "help": "How to reduce loss over tokens. Options are 'mean' or 'sum'." - "Using 'sum' can improve chat model performance." - }, - ) - wandb_entity: Optional[str] = field( - default=None, - metadata={"help": "Entity to use for logging to wandb."}, - ) - resume_from_checkpoint: Optional[str] = field( - default=None, - metadata={"help": "If the training should continue from a checkpoint folder."}, - ) - with_tracking: bool = field( - default=False, - metadata={"help": "Whether to enable experiment trackers for logging."}, - ) - report_to: Union[str, List[str]] = field( - default="all", - metadata={ - "help": "The integration(s) to report results and logs to. " - "Can be a single string or a list of strings. " - "Options are 'tensorboard', 'wandb', 'comet_ml', 'clearml', or 'all'. " - "Specify multiple by listing them: e.g., ['tensorboard', 'wandb']" - }, - ) - save_to_hub: Optional[str] = field( - default=None, - metadata={"help": "Save the model to the Hub under this name. E.g allenai/your-model"}, - ) - gradient_checkpointing: bool = field( - default=False, - metadata={"help": "Turn on gradient checkpointing. Saves memory but slows training."}, - ) - max_train_steps: Optional[int] = field( - default=None, - metadata={"help": "If set, overrides the number of training steps. Otherwise, num_train_epochs is used."}, - ) - seed: int = field(default=42, metadata={"help": "Random seed for initialization and dataset shuffling."}) - checkpointing_steps: Optional[str] = field( - default=None, - metadata={ - "help": "Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch." # noqa - }, - ) - overwrite_output_dir: bool = field( - default=False, - metadata={ - "help": "Overwrite the content of the output directory. Means that resumption will always start from scratch." - }, - ) - keep_last_n_checkpoints: int = field( - default=3, - metadata={"help": "How many checkpoints to keep in the output directory. -1 for all."}, - ) - load_balancing_loss: bool = field( - default=False, - metadata={ - "help": "Whether to include a load balancing loss (for OLMoE) or not.", - }, - ) - load_balancing_weight: float = field( - default=0.001, - metadata={"help": "Weight for load balancing loss if applicable."}, - ) - try_auto_save_to_beaker: bool = True - """Whether to try to save the model to Beaker dataset `/output` after training""" - push_to_hub: bool = True - """Whether to upload the saved model to huggingface""" - hf_entity: Optional[str] = None - """The user or org name of the model repository from the Hugging Face Hub""" - hf_repo_id: Optional[str] = None - """The id of the saved model in the Hugging Face Hub (can be autoset if not given)""" - hf_repo_revision: Optional[str] = None - """The revision of the saved model in the Hugging Face Hub (can be autoset if not given)""" - hf_repo_url: Optional[str] = None - """The url of the saved model in the Hugging Face Hub (will be autoset)""" - try_launch_beaker_eval_jobs: bool = True - """Whether to launch beaker evaluation jobs after training""" - hf_metadata_dataset: Optional[str] = "allenai/tulu-3-evals" - """What dataset to upload the metadata to. If unset, don't upload metadata""" - - def __post_init__(self): - if self.reduce_loss not in ["mean", "sum"]: - raise ValueError("reduce_loss must be either 'mean' or 'sum'") - if ( - self.dataset_name is None - and self.train_file is None - and self.dataset_mixer is None - and self.dataset_mixer_list is None - ): - raise ValueError("Need either a dataset name, dataset mixer, or a training file.") - else: - if self.train_file is not None: - extension = self.train_file.split(".")[-1] - assert extension in ["json", "jsonl"], "`train_file` should be a json or a jsonl file." - if ( - (self.dataset_name is not None and (self.dataset_mixer is not None or self.dataset_mixer_list is not None)) - or (self.dataset_name is not None and self.train_file is not None) - or ( - (self.dataset_mixer is not None or self.dataset_mixer_list is not None) and self.train_file is not None - ) - or (self.dataset_mixer is not None and self.dataset_mixer_list is not None) - ): - raise ValueError("Cannot provide two dataset selection mechanisms.") - - if self.try_launch_beaker_eval_jobs and not self.push_to_hub: - raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.") - - -def encode_dpo_example(example, tokenizer, max_seq_length): - """ - Here we assume each example has a rejected and chosen field, both of which are a list of messages. - Each message is a dict with 'role' and 'content' fields. - We assume only the last message is different, and the prompt is contained in the list of messages. - """ - chosen_messages = example["chosen"] - rejected_messages = example["rejected"] - if len(chosen_messages) == 0: - raise ValueError("chosen messages field is empty.") - if len(rejected_messages) == 0: - raise ValueError("rejected messages field is empty.") - - chosen_encoded = encode_sft_example({"messages": chosen_messages}, tokenizer, max_seq_length) - rejected_encoded = encode_sft_example({"messages": rejected_messages}, tokenizer, max_seq_length) - - return { - "chosen_input_ids": chosen_encoded["input_ids"], - "chosen_labels": chosen_encoded["labels"], - "chosen_attention_mask": chosen_encoded["attention_mask"], - "rejected_input_ids": rejected_encoded["input_ids"], - "rejected_labels": rejected_encoded["labels"], - "rejected_attention_mask": rejected_encoded["attention_mask"], - } - - -# from trl, we have to prep the ref model separately. -def prepare_deepspeed(accelerator, model): - deepspeed_plugin = accelerator.state.deepspeed_plugin - config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config) - - if model is not None: - if hasattr(model, "config"): - hidden_size = ( - max(model.config.hidden_sizes) - if getattr(model.config, "hidden_sizes", None) - else getattr(model.config, "hidden_size", None) - ) - if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: - # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: - # `Invalidate trace cache @ step 0: expected module 1, but got module 0` - # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 - config_kwargs.update( - { - "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, - "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, - "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, - } - ) - - # If ZeRO-3 is used, we shard both the active and reference model. - # Otherwise, we assume the reference model fits in memory and - # is initialized on each device with ZeRO disabled (stage 0) - if config_kwargs["zero_optimization"]["stage"] != 3: - config_kwargs["zero_optimization"]["stage"] = 0 - model, *_ = deepspeed.initialize(model=model, config=config_kwargs) - model.eval() - return model - - -def main(args: FlatArguments): - # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. - # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers - # in the environment - args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - if args.push_to_hub: - if args.hf_repo_id is None: # auto-generate one - args.hf_repo_id = "open_instruct_dev" - if args.hf_entity is None: # first try to use AI2 entity - args.hf_entity = maybe_use_ai2_hf_entity() - if args.hf_entity is None: # then try to use the user's entity - args.hf_entity = HfApi().whoami()["name"] - args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}" - if args.hf_repo_revision is None: - args.hf_repo_revision = args.run_name - args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}" - - if is_beaker_job(): - beaker_config = maybe_get_beaker_config() - - accelerator_log_kwargs = {} - - if args.with_tracking: - accelerator_log_kwargs["log_with"] = args.report_to - accelerator_log_kwargs["project_dir"] = args.output_dir - - # if you get timeouts (e.g. due to long tokenization) increase this. - timeout_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.timeout)) - dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True) - - accelerator = Accelerator( - gradient_accumulation_steps=args.gradient_accumulation_steps, - dataloader_config=dataloader_config, - **accelerator_log_kwargs, - kwargs_handlers=[timeout_kwargs], - ) - # Make one log on every process with the configuration for debugging. - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - logger.info(accelerator.state, main_process_only=False) - if accelerator.is_local_main_process: - datasets.utils.logging.set_verbosity_warning() - transformers.utils.logging.set_verbosity_info() - else: - datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() - - # If passed along, set the training seed now. - if args.seed is not None: - set_seed(args.seed) - - if accelerator.is_main_process: - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - accelerator.wait_for_everyone() - - if args.dataset_name is not None: - # Downloading and loading a dataset from the hub. - raw_datasets = load_dataset( - args.dataset_name, - args.dataset_config_name, - ) - elif args.dataset_mixer is not None: - # mixing datasets via config - raw_datasets = get_datasets( - args.dataset_mixer, - configs=args.dataset_config_name, - splits=["train"], - save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None, - columns_to_keep=["chosen", "rejected"], - ) - elif args.dataset_mixer_list is not None: - # mixing datasets via config - raw_datasets = get_datasets( - args.dataset_mixer_list, - configs=args.dataset_config_name, - splits=["train"], - save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None, - columns_to_keep=["chosen", "rejected"], - ) - else: - data_files = {} - dataset_args = {} - if args.train_file is not None: - data_files["train"] = args.train_file - raw_datasets = load_dataset( - "json", - data_files=data_files, - **dataset_args, - ) - - # Load pretrained model and tokenizer - if args.config_name: - config = AutoConfig.from_pretrained( - args.config_name, - revision=args.model_revision, - trust_remote_code=args.trust_remote_code, - ) - elif args.model_name_or_path: - config = AutoConfig.from_pretrained( - args.model_name_or_path, - revision=args.model_revision, - trust_remote_code=args.trust_remote_code, - ) - else: - raise ValueError( - "You are instantiating a new config instance from scratch. This is not supported by this script." - ) - - tokenizer_revision = args.model_revision if args.tokenizer_revision is None else args.tokenizer_revision - if tokenizer_revision != args.model_revision: - # Warn user if tokenizer and model use different revisions; this is an unusual - # use case. - warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different - from the model revision `{args.model_revision}`.""" - logger.warning(warning) - - if args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer_name, - revision=tokenizer_revision, - trust_remote_code=args.trust_remote_code, - use_fast=not args.use_slow_tokenizer, - ) - elif args.model_name_or_path: - tokenizer = AutoTokenizer.from_pretrained( - args.model_name_or_path, - revision=tokenizer_revision, - trust_remote_code=args.trust_remote_code, - use_fast=not args.use_slow_tokenizer, - ) - else: - raise ValueError( - "You are instantiating a new tokenizer from scratch. This is not supported by this script." - "You can do it from another script, save it, and load it from here, using --tokenizer_name." - ) - - def load_model(): - if args.model_name_or_path: - if args.use_qlora: - bnb_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=torch.bfloat16, - ) - device_index = accelerator.local_process_index - device_map = {"": device_index} # force data-parallel training. - model = AutoModelForCausalLM.from_pretrained( - args.model_name_or_path, - revision=args.model_revision, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - trust_remote_code=args.trust_remote_code, - quantization_config=bnb_config, - device_map=device_map, - torch_dtype=torch.bfloat16, - use_flash_attention_2=True if args.use_flash_attn else False, - ) - else: - model = AutoModelForCausalLM.from_pretrained( - args.model_name_or_path, - revision=args.model_revision, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - trust_remote_code=args.trust_remote_code, - low_cpu_mem_usage=args.low_cpu_mem_usage, - use_flash_attention_2=True if args.use_flash_attn else False, - ) - else: - logger.info("Training new model from scratch") - model = AutoModelForCausalLM.from_config(config) - return model - - model = load_model() - # only simpo is reference model free rn - if args.dpo_loss_type != "simpo": - if not args.use_lora: - reference_model = load_model() - else: - reference_model = model - else: - reference_model = None - - # no default pad token for llama! - # here we add all special tokens again, because the default ones are not in the special_tokens_map - if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, LlamaTokenizerFast): - num_added_tokens = tokenizer.add_special_tokens( - { - "bos_token": "", - "eos_token": "", - "unk_token": "", - "pad_token": "", - } - ) - assert num_added_tokens in [ - 0, - 1, - ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present." - elif isinstance(tokenizer, GPTNeoXTokenizerFast): - # OLMo newer models use this tokenizer - if tokenizer.bos_token is None: - tokenizer.bos_token = tokenizer.eos_token - assert ( - args.add_bos - ), "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence." - # else, pythia / other models - else: - num_added_tokens = tokenizer.add_special_tokens( - { - "pad_token": "", - } - ) - assert ( - num_added_tokens <= 1 - ), "GPTNeoXTokenizer should only add one special token - the pad_token (or no tokens)." - elif isinstance(tokenizer, GPT2Tokenizer) and isinstance(model, OPTForCausalLM): - num_added_tokens = tokenizer.add_special_tokens({"unk_token": ""}) - elif isinstance(tokenizer, transformers.PreTrainedTokenizerFast) and tokenizer.pad_token is None: - num_added_tokens = tokenizer.add_special_tokens({"pad_token": ""}) - assert num_added_tokens == 1, "We detected no padding token but add_special_tokens did not add one." - - # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch - # on a small vocab and want a smaller embedding size, remove this test. - # gather deepspeed to get "real" embedding size - embeddings = model.get_input_embeddings() - with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None): - if len(tokenizer) > embeddings.weight.shape[0]: - model.resize_token_embeddings(len(tokenizer)) - if reference_model is not None: - reference_embeddings = reference_model.get_input_embeddings() - with deepspeed.zero.GatheredParameters(reference_embeddings.weight, modifier_rank=None): - if len(tokenizer) > reference_embeddings.weight.shape[0]: - reference_model.resize_token_embeddings(len(tokenizer)) - - # set the tokenizer chat template to the training format - # this will be used for encoding the training examples - # and saved together with the tokenizer to be used later. - if args.chat_template_name in CHAT_TEMPLATES: - tokenizer.chat_template = CHAT_TEMPLATES[args.chat_template_name] - else: - try: - tokenizer.chat_template = AutoTokenizer.from_pretrained(args.chat_template_name).chat_template - except Exception: - raise ValueError(f"Could not find chat template for {args.chat_template_name}.") - - if args.add_bos: - if tokenizer.chat_template.startswith("{{ bos_token }}") or ( - tokenizer.bos_token is not None and tokenizer.chat_template.startswith(tokenizer.bos_token) - ): - raise ValueError( - "You specified add_bos=True, but the chat template already has a bos_token at the beginning." - ) - # add bos in the chat template if not already there - tokenizer.chat_template = "{{ bos_token }}" + tokenizer.chat_template - - if args.use_lora: - if args.use_qlora: - model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing) - - logger.info("Initializing LORA model...") - peft_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, - inference_mode=False, - r=args.lora_rank, - lora_alpha=args.lora_alpha, - lora_dropout=args.lora_dropout, - target_modules=["q_proj", "o_proj", "v_proj", "k_proj", "gate_proj", "up_proj", "down_proj"], - ) - model = get_peft_model(model, peft_config) - model.print_trainable_parameters() - elif args.gradient_checkpointing: - model.gradient_checkpointing_enable() - - # Preprocessing the datasets. - if "prompt" in raw_datasets["train"].column_names and "completion" in raw_datasets["train"].column_names: - raise ValueError("Sorry, prompt-completion format is not supported for DPO training.") - elif "chosen" in raw_datasets["train"].column_names and "rejected" in raw_datasets["train"].column_names: - encode_function = partial( - encode_dpo_example, - tokenizer=tokenizer, - max_seq_length=args.max_seq_length, - ) - else: - raise ValueError("You need to have 'chosen' and 'rejected in your column names.") - - train_dataset = raw_datasets["train"] - - # debugging tool for fewer samples - if args.max_train_samples is not None: - max_train_samples = min(len(train_dataset), args.max_train_samples) - logger.info(f"Limiting training samples to {max_train_samples} from {len(train_dataset)}.") - train_dataset = train_dataset.select(range(max_train_samples)) - - with accelerator.main_process_first(): - train_dataset = train_dataset.map( - encode_function, - batched=False, - num_proc=args.preprocessing_num_workers, - remove_columns=[ - name - for name in train_dataset.column_names - if name - not in [ - "chosen_input_ids", - "chosen_labels", - "chosen_attention_mask", - "rejected_input_ids", - "rejected_labels", - "rejected_attention_mask", - ] - ], - desc="Tokenizing and reformatting instruction data", - ) - train_dataset.set_format(type="pt") - # our thresholding mighta meant some examples have no labels, remove. - train_dataset = train_dataset.filter(lambda example: (example["chosen_labels"] != -100).any()) - train_dataset = train_dataset.filter(lambda example: (example["rejected_labels"] != -100).any()) - - # Log a few random samples from the training set: - for index in random.sample(range(len(train_dataset)), 3): - logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") - - # DataLoaders creation: - train_dataloader = DataLoader( - train_dataset, - shuffle=True, - collate_fn=DataCollatorForSeq2SeqDPO(tokenizer=tokenizer, model=model, padding="longest"), - batch_size=args.per_device_train_batch_size, - ) - - # Optimizer - # Split weights in two groups, one with weight decay and the other not. - no_decay = ["bias", "layer_norm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], - "weight_decay": args.weight_decay, - }, - { - "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], - "weight_decay": 0.0, - }, - ] - if args.use_qlora or args.dpo_use_paged_optimizer: - from bitsandbytes.optim import AdamW - - optimizer = AdamW( - optimizer_grouped_parameters, - lr=args.learning_rate, - optim_bits=8 if args.use_8bit_optimizer else 32, - is_paged=True, - ) - else: - optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) - - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - - # Create the learning rate scheduler. - # Note: the current accelerator.step() calls the .step() of the real scheduler - # for the `num_processes` times. This is because they assume - # the user initialize the scheduler with the entire training set. - # In the case of data parallel training, each process only - # sees a subset (1/num_processes) of the training set. - # So each time the process needs to update the lr multiple times so that the total - # number of updates in the end matches the num_training_steps here. - # Here we need to set the num_training_steps to either using the - # entire training set (when epochs is specified) or we need to multiply the - # num_training_steps by num_processes so that the total number of - # updates matches the num_training_steps. - num_training_steps_for_scheduler = ( - args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes - ) - lr_scheduler = get_scheduler( - name=args.lr_scheduler_type, - optimizer=optimizer, - num_training_steps=num_training_steps_for_scheduler, - num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio), - ) - # Prepare everything with `accelerator`. - model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - model, optimizer, train_dataloader, lr_scheduler - ) - # reference model may not be none with e.g. SimPO loss. - if not args.use_lora and reference_model is not None: - reference_model = prepare_deepspeed(accelerator, reference_model) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - - # Figure out how many steps we should save the Accelerator states - checkpointing_steps = args.checkpointing_steps - if checkpointing_steps is not None and str(checkpointing_steps).lower() != "epoch": - checkpointing_steps = int(checkpointing_steps) - - # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. - if args.with_tracking: - experiment_config = vars(args) - # TensorBoard cannot log Enums, need the raw value - experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"] - - # (Optional) Ai2 internal tracking - if args.wandb_entity is None: - args.wandb_entity = maybe_use_ai2_wandb_entity() - if is_beaker_job(): - experiment_config.update(vars(beaker_config)) - accelerator.init_trackers( - "open_instruct_internal", - experiment_config, - init_kwargs={ - "wandb": { - "name": args.run_name, - "entity": args.wandb_entity, - "tags": [args.exp_name] + get_wandb_tags(), - } - }, - ) - wandb_tracker = accelerator.get_tracker("wandb") - - # Train! - total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) - completed_steps = 0 - starting_epoch = 0 - - # Potentially load in the weights and states from a previous save - last_checkpoint_path = get_last_checkpoint_path(args) - if last_checkpoint_path: - accelerator.print(f"Resumed from checkpoint: {last_checkpoint_path}") - accelerator.load_state(last_checkpoint_path) - # Extract `epoch_{i}` or `step_{i}` - last_checkpoint_path = os.path.basename(last_checkpoint_path) - training_difference = os.path.splitext(last_checkpoint_path)[0] - - if "epoch" in training_difference: - starting_epoch = int(training_difference.replace("epoch_", "")) + 1 - resume_step = None - completed_steps = starting_epoch * num_update_steps_per_epoch - else: - # need to multiply `gradient_accumulation_steps` to reflect real steps - resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps - starting_epoch = resume_step // len(train_dataloader) - completed_steps = resume_step // args.gradient_accumulation_steps - resume_step -= starting_epoch * len(train_dataloader) - - print(f"Starting from epoch {starting_epoch} and step {completed_steps}.") - # update the progress_bar if load from checkpoint - progress_bar.update(completed_steps) - - local_metrics = torch.zeros((20), device=accelerator.device) - episode = 0 - for epoch in range(starting_epoch, args.num_train_epochs): - model.train() - train_dataloader.set_epoch(epoch) - if last_checkpoint_path and resume_step is not None: - # We skip the first `n` batches in the dataloader when resuming from a checkpoint - active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) - else: - active_dataloader = train_dataloader - # we need to average the log probs for simpo loss - average_log_prob_loss_types = ["simpo", "dpo_norm"] - average_log_prob = args.dpo_loss_type in average_log_prob_loss_types - for step, batch in enumerate(active_dataloader): - episode += len(batch["chosen_input_ids"]) * accelerator.num_processes - # dpo forward pass & loss - with accelerator.accumulate(model): - policy_chosen_logps, policy_rejected_logps, aux_loss = concatenated_forward( - model, batch, average_log_prob=average_log_prob, output_router_logits=args.load_balancing_loss - ) # `aux_loss` is only used when `args.load_balancing_loss = True` - if args.dpo_loss_type == "dpo" or args.dpo_loss_type == "dpo_norm": - with torch.no_grad(): - if args.use_lora: - with accelerator.unwrap_model(model).disable_adapter(): - reference_chosen_logps, reference_rejected_logps, _ = concatenated_forward( - model, batch, average_log_prob=average_log_prob - ) - else: - reference_chosen_logps, reference_rejected_logps, _ = concatenated_forward( - reference_model, batch, average_log_prob=average_log_prob - ) - losses, _, _ = dpo_loss( - policy_chosen_logps, - policy_rejected_logps, - reference_chosen_logps, - reference_rejected_logps, - beta=args.dpo_beta, - label_smoothing=args.dpo_label_smoothing, - ) - elif args.dpo_loss_type == "simpo": - losses, _, _ = simpo_loss( - policy_chosen_logps, - policy_rejected_logps, - beta=args.dpo_beta, - gamma_beta_ratio=args.dpo_gamma_beta_ratio, - label_smoothing=args.dpo_label_smoothing, - ) - elif args.dpo_loss_type == "wpo": - losses, _, _ = wpo_loss( - policy_chosen_logps, - policy_rejected_logps, - reference_chosen_logps, - reference_rejected_logps, - beta=args.dpo_beta, - label_smoothing=args.dpo_label_smoothing, - chosen_loss_mask=batch["chosen_labels"] != -100, - rejected_loss_mask=batch["rejected_labels"] != -100, - ) - else: - raise ValueError(f"Invalid dpo loss type {args.dpo_loss_type}.") - # TODO: metric logging - loss = losses.mean() - if args.load_balancing_loss: - weighted_aux_loss = args.load_balancing_weight * aux_loss - loss += weighted_aux_loss - accelerator.backward(loss) - # clip gradient norm. don't do this with deepspeed - if accelerator.sync_gradients and args.clip_grad_norm > 0: - accelerator.clip_grad_norm_(model.parameters(), args.clip_grad_norm) - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() - - # We keep track of the loss at each logged step - with torch.no_grad(): - local_metrics[0] += loss - if args.dpo_loss_type == "dpo" or args.dpo_loss_type == "dpo_norm": - chosen_rewards = (args.dpo_beta * (policy_chosen_logps - reference_chosen_logps)).mean() - rejected_rewards = (args.dpo_beta * (policy_rejected_logps - reference_rejected_logps)).mean() - average_rewards = (chosen_rewards + rejected_rewards) / 2 - accuracy = (chosen_rewards > rejected_rewards).float().mean() - margin = (chosen_rewards - rejected_rewards).mean() - local_metrics[1] += chosen_rewards - local_metrics[2] += rejected_rewards - local_metrics[3] += average_rewards - local_metrics[4] += accuracy - local_metrics[5] += margin - local_metrics[6] += policy_chosen_logps.mean() - local_metrics[7] += policy_rejected_logps.mean() - if args.load_balancing_loss: - local_metrics[19] += weighted_aux_loss - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - completed_steps += 1 - if args.logging_steps and completed_steps % args.logging_steps == 0: - # single all reduce to save time, avoiding per metric all reduce - global_metrics = accelerator.reduce(local_metrics, reduction="mean") - global_metrics /= args.gradient_accumulation_steps * args.logging_steps - global_metrics = global_metrics.tolist() - metrics_to_log = { - "training_step": completed_steps, - "learning_rate": lr_scheduler.get_last_lr()[0], - "epoch": episode / len(train_dataset), - "train_loss": global_metrics[0], - "logps/chosen": global_metrics[6], - "logps/rejected": global_metrics[7], - } - if args.dpo_loss_type == "dpo" or args.dpo_loss_type == "dpo_norm": - metrics_to_log.update( - { - "rewards/chosen": global_metrics[1], - "rewards/rejected": global_metrics[2], - "rewards/average": global_metrics[3], - "rewards/accuracy": global_metrics[4], - "rewards/margin": global_metrics[5], - } - ) - logger_str = ( - f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {global_metrics[0]}" - ) - if args.load_balancing_loss: - logger_str += f" Aux Loss: {global_metrics[19]}" - metrics_to_log["aux_loss"] = global_metrics[19] - logger.info(logger_str) - if args.with_tracking: - accelerator.log( - metrics_to_log, - step=completed_steps, - ) - # Reset the local metrics - local_metrics.zero_() - - if isinstance(checkpointing_steps, int): - if completed_steps % checkpointing_steps == 0: - output_dir = f"step_{completed_steps}" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) - accelerator.save_state(output_dir) - # use this to mark the checkpoint as completely saved, to avoid restoring from garbled checkpoints - with open( - os.path.join(get_last_checkpoint_path(args, incomplete=True), "COMPLETED"), "w" - ) as f: - f.write("COMPLETED") # annoyingly, empty files arent uploaded by beaker. - if accelerator.is_local_main_process: - clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints) - accelerator.wait_for_everyone() - - if completed_steps >= args.max_train_steps: - break - - if checkpointing_steps == "epoch": - output_dir = f"epoch_{epoch}" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) - accelerator.save_state(output_dir) - # use this to mark the checkpoint as completely saved, to avoid restoring from garbled checkpoints - with open(os.path.join(get_last_checkpoint_path(args, incomplete=True), "COMPLETED"), "w") as f: - f.write("COMPLETED") # annoyingly, empty files arent uploaded by beaker. - if accelerator.is_local_main_process: - clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints) - accelerator.wait_for_everyone() - - if args.output_dir is not None: - save_with_accelerate( - accelerator, - model, - tokenizer, - args.output_dir, - args.use_lora, - ) - - # remove all checkpoints to save space - if accelerator.is_local_main_process: - clean_last_n_checkpoints(args.output_dir, keep_last_n_checkpoints=0) - - if ( - args.try_auto_save_to_beaker - and accelerator.is_main_process - and len(beaker_config.beaker_dataset_id_urls) > 0 - and args.output_dir.rstrip("/") != "/output" - ): - shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True) - - if is_beaker_job() and accelerator.is_main_process: - # dpo script only supports these two options right now for datasets - if args.dataset_mixer: - dataset_list = list(args.dataset_mixer.keys()) - elif args.dataset_mixer_list: - dataset_list = args.dataset_mixer_list[::2] # even indices - elif args.dataset_name: - dataset_list = [args.dataset_name] - else: - dataset_list = [args.train_file] - # mainly just focussing here on what would be useful for the leaderboard. - # wandb will have even more useful information. - metadata_blob = { - "model_name": args.exp_name, - "model_type": "dpo", - "datasets": dataset_list, - "base_model": args.model_name_or_path, - "wandb_path": wandb_tracker.run.get_url(), - "beaker_experiment": beaker_config.beaker_experiment_url, - "beaker_datasets": beaker_config.beaker_dataset_id_urls, - } - # save metadata to the output directory. then it should also get pushed to HF. - with open(os.path.join(args.output_dir, "metadata.json"), "w") as f: - json.dump(metadata_blob, f) - - # upload metadata to the dataset if set - if args.hf_metadata_dataset: - upload_metadata_to_hf( - metadata_blob, - "metadata.json", - args.hf_metadata_dataset, - "results/" + args.run_name, # to match what the auto-evals name as. - ) - - if args.try_launch_beaker_eval_jobs: - command = f"""\ - python mason.py \ - --cluster ai2/ganymede-cirrascale ai2/ceres-cirrascale ai2/neptune-cirrascale ai2/saturn-cirrascale ai2/jupiter-cirrascale-2 \ - --priority low \ - --preemptible \ - --budget ai2/allennlp \ - --workspace ai2/tulu-2-improvements \ - --image nathanl/open_instruct_auto \ - --pure_docker_mode \ - --gpus 0 -- python scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py \ - --beaker_workload_id {beaker_config.beaker_workload_id} \ - --upload_to_hf {args.hf_metadata_dataset} \ - --model_name {args.run_name} \ - --run_id {wandb_tracker.run.get_url()} - """ - process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stdout, stderr = process.communicate() - print(f"Submit jobs after model training is finished - Stdout:\n{stdout.decode()}") - print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") - print(f"Submit jobs after model training is finished - process return code: {process.returncode}") - - if args.push_to_hub: - push_folder_to_hub( - accelerator, - args.output_dir, - args.hf_repo_id, - args.hf_repo_revision, - ) - accelerator.wait_for_everyone() - if args.with_tracking: - accelerator.end_training() - - -if __name__ == "__main__": - parser = ArgumentParserPlus((FlatArguments)) - args = parser.parse() - main(args) diff --git a/open_instruct/dpo_tune_cache1.py b/open_instruct/dpo_tune_cache1.py deleted file mode 100644 index c478261ea..000000000 --- a/open_instruct/dpo_tune_cache1.py +++ /dev/null @@ -1,1077 +0,0 @@ -# !/usr/bin/env python -# coding=utf-8 -# Copyright 2024 AllenAI. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -DPO tuning script. Adapted from our finetuning script. -""" - -import json -import logging -import math -import os -import random -import shutil -import subprocess -import time -from dataclasses import dataclass, field -from datetime import timedelta -from functools import partial -from typing import Callable, List, Optional, Union - -import datasets -import deepspeed -import torch -import torch.utils -import torch.utils.data -import transformers -from accelerate import Accelerator, DataLoaderConfiguration -from accelerate.logging import get_logger -from accelerate.utils import InitProcessGroupKwargs, set_seed -from datasets import load_dataset -from huggingface_hub import HfApi -from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training -from torch.utils.data import DataLoader -from tqdm.auto import tqdm -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoTokenizer, - BitsAndBytesConfig, - GPT2Tokenizer, - GPTNeoXTokenizerFast, - LlamaTokenizer, - LlamaTokenizerFast, - OPTForCausalLM, - get_scheduler, -) - -from open_instruct.dataset_transformation import CHAT_TEMPLATES, TokenizerConfig, get_cached_dataset_tulu_preference -from open_instruct.dpo_utils import ( - DataCollatorForSeq2SeqDPO, - concatenated_forward, - dpo_loss, - separate_forward, - simpo_loss, - wpo_loss, -) -from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate -from open_instruct.utils import ( - ArgumentParserPlus, - clean_last_n_checkpoints, - get_last_checkpoint_path, - get_wandb_tags, - is_beaker_job, - maybe_get_beaker_config, - maybe_use_ai2_hf_entity, - maybe_use_ai2_wandb_entity, - upload_metadata_to_hf, -) - -logger = get_logger(__name__) - - -@dataclass -class FlatArguments: - """ - Full arguments class for all fine-tuning jobs. - """ - - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """The name of this experiment""" - run_name: Optional[str] = None - """A unique name of this run""" - model_name_or_path: Optional[str] = field( - default=None, - metadata={ - "help": ( - "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch." - ) - }, - ) - config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} - ) - dpo_use_paged_optimizer: bool = field( - default=False, - metadata={ - "help": "Use paged optimizer from bitsandbytes." - " Not compatible with deepspeed (use deepspeed config instead)." - }, - ) - dpo_beta: float = field( - default=0.1, - metadata={"help": "Beta parameter for DPO loss. Default is 0.1."}, - ) - dpo_loss_type: str = field( - default="dpo", - metadata={"help": "Type of DPO loss to use. Options are 'dpo', 'dpo_norm', 'simpo', 'wpo'."}, - ) - dpo_gamma_beta_ratio: float = field( - default=0.3, - metadata={"help": "Gamma to beta ratio for SimPO loss. Default is 0.3. Not used for DPO loss."}, - ) - dpo_label_smoothing: float = field( - default=0.0, - metadata={"help": "Label smoothing for DPO/SimPO loss. Default is 0 (no smoothing)."}, - ) - tokenizer_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} - ) - tokenizer_revision: Optional[str] = field( - default=None, - metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, - ) - chat_template_name: str = field( - default="tulu", - metadata={ - "help": ( - f"The name of the chat template to use. " - f"You can choose one of our pre-defined templates: {', '.join(CHAT_TEMPLATES.keys())}." - f"Or, you can provide a tokenizer name or path here and we will apply its chat template." - ) - }, - ) - use_flash_attn: bool = field( - default=True, - metadata={"help": "Whether to use flash attention in the model training"}, - ) - use_slow_tokenizer: bool = field( - default=True, - metadata={"help": "Whether to use one of the slow tokenizer or not (which is then fast tokenizer)."}, - ) - model_revision: Optional[str] = field( - default=None, - metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, - ) - trust_remote_code: bool = field( - default=False, - metadata={ - "help": ( - "Whether or not to allow for custom models defined on the Hub in their own modeling files. " - "This option should only be set to `True` for repositories you trust and in which you " - "have read the code, as it will execute code present on the Hub on your local machine." - ) - }, - ) - low_cpu_mem_usage: bool = field( - default=False, - metadata={ - "help": ( - "It is an option to create the model as an empty shell, " - "then only materialize its parameters when the pretrained weights are loaded. " - "set True will benefit LLM loading time and RAM consumption." - ) - }, - ) - dataset_name: Optional[str] = field( - default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} - ) - dataset_mixer: Optional[dict] = field( - default=None, metadata={"help": "A dictionary of datasets (local or HF) to sample from."} - ) - dataset_mixer_list: Optional[list[str]] = field( - default=None, metadata={"help": "A list of datasets (local or HF) to sample from."} - ) - dataset_mix_dir: Optional[str] = field( - default=None, metadata={"help": "The directory to save the mixed dataset to disk."} - ) - dataset_config_name: Optional[str] = field( - default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} - ) - train_file: Optional[str] = field( - default=None, metadata={"help": "The input training data file (a json/jsonl file)."} - ) - max_train_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of training examples to this " - "value if set." - ) - }, - ) - preprocessing_num_workers: Optional[int] = field( - default=None, - metadata={"help": "The number of processes to use for the preprocessing."}, - ) - max_seq_length: Optional[int] = field( - default=None, - metadata={ - "help": ( - "The maximum total input sequence length after tokenization. " - "Sequences longer than this will be truncated," - ) - }, - ) - overwrite_cache: bool = field( - default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} - ) - add_bos: bool = field( - default=False, - metadata={ - "help": "Forcibly add bos token to the beginning of the input sequence." - " Use only when tokenizer does not add bos token by default." - }, - ) - clip_grad_norm: float = field( - default=-1, - metadata={"help": "Clip gradient norm. Not compatible with deepspeed (use deepspeed config instead)."}, - ) - gradient_accumulation_steps: int = field( - default=1, - metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, - ) - learning_rate: float = field( - default=2e-5, - metadata={"help": "The initial learning rate for AdamW optimizer."}, - ) - logging_steps: Optional[int] = field( - default=None, - metadata={"help": "Log the training loss and learning rate every logging_steps steps."}, - ) - lora_rank: int = field( - default=64, - metadata={"help": "The rank of lora."}, - ) - lora_alpha: float = field( - default=16, - metadata={"help": "The alpha parameter of lora."}, - ) - lora_dropout: float = field( - default=0.1, - metadata={"help": "The dropout rate of lora modules."}, - ) - lr_scheduler_type: str = field( - default="linear", - metadata={ - "help": "The scheduler type to use for learning rate adjustment.", - "choices": ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], - }, - ) - num_train_epochs: int = field( - default=2, - metadata={"help": "Total number of training epochs to perform."}, - ) - output_dir: str = field( - default="output/", - metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, - ) - per_device_train_batch_size: int = field( - default=8, - metadata={"help": "Batch size per GPU/TPU core/CPU for training."}, - ) - use_lora: bool = field( - default=False, - metadata={"help": "If True, will use LORA (low-rank parameter-efficient training) to train the model."}, - ) - use_qlora: bool = field( - default=False, - metadata={"help": "Use qLoRA training - initializes model in quantized form. Not compatible with deepspeed."}, - ) - use_8bit_optimizer: bool = field( - default=False, - metadata={"help": "Use 8bit optimizer from bitsandbytes. Not compatible with deepspeed."}, - ) - warmup_ratio: float = field( - default=0.03, - metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}, - ) - weight_decay: float = field( - default=0.0, - metadata={"help": "Weight decay for AdamW if we apply some."}, - ) - timeout: int = field( - default=1800, - metadata={ - "help": "Timeout for the training process in seconds." - "Useful if tokenization process is long. Default is 1800 seconds (30 minutes)." - }, - ) - reduce_loss: str = field( - default="mean", - metadata={ - "help": "How to reduce loss over tokens. Options are 'mean' or 'sum'." - "Using 'sum' can improve chat model performance." - }, - ) - wandb_entity: Optional[str] = field( - default=None, - metadata={"help": "Entity to use for logging to wandb."}, - ) - resume_from_checkpoint: Optional[str] = field( - default=None, - metadata={"help": "If the training should continue from a checkpoint folder."}, - ) - with_tracking: bool = field( - default=False, - metadata={"help": "Whether to enable experiment trackers for logging."}, - ) - report_to: Union[str, List[str]] = field( - default="all", - metadata={ - "help": "The integration(s) to report results and logs to. " - "Can be a single string or a list of strings. " - "Options are 'tensorboard', 'wandb', 'comet_ml', 'clearml', or 'all'. " - "Specify multiple by listing them: e.g., ['tensorboard', 'wandb']" - }, - ) - save_to_hub: Optional[str] = field( - default=None, - metadata={"help": "Save the model to the Hub under this name. E.g allenai/your-model"}, - ) - gradient_checkpointing: bool = field( - default=False, - metadata={"help": "Turn on gradient checkpointing. Saves memory but slows training."}, - ) - max_train_steps: Optional[int] = field( - default=None, - metadata={"help": "If set, overrides the number of training steps. Otherwise, num_train_epochs is used."}, - ) - seed: int = field(default=42, metadata={"help": "Random seed for initialization and dataset shuffling."}) - checkpointing_steps: Optional[str] = field( - default=None, - metadata={ - "help": "Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch." # noqa - }, - ) - overwrite_output_dir: bool = field( - default=False, - metadata={ - "help": "Overwrite the content of the output directory. Means that resumption will always start from scratch." - }, - ) - keep_last_n_checkpoints: int = field( - default=3, - metadata={"help": "How many checkpoints to keep in the output directory. -1 for all."}, - ) - fused_optimizer: bool = field( - default=True, - metadata={ - "help": "Whether to use fused AdamW or not.", - }, - ) - load_balancing_loss: bool = field( - default=False, - metadata={ - "help": "Whether to include a load balancing loss (for OLMoE) or not.", - }, - ) - load_balancing_weight: float = field( - default=0.001, - metadata={"help": "Weight for load balancing loss if applicable."}, - ) - cache_dataset_only: bool = False - """Immediately exit after caching the dataset""" - concatenated_forward: bool = True - """Whether to concatenate chosen and rejected for DPO training; True is good but you can set to False for saving memory.""" - try_auto_save_to_beaker: bool = True - """Whether to try to save the model to Beaker dataset `/output` after training""" - push_to_hub: bool = True - """Whether to upload the saved model to huggingface""" - hf_entity: Optional[str] = None - """The user or org name of the model repository from the Hugging Face Hub""" - hf_repo_id: Optional[str] = None - """The id of the saved model in the Hugging Face Hub (can be autoset if not given)""" - hf_repo_revision: Optional[str] = None - """The revision of the saved model in the Hugging Face Hub (can be autoset if not given)""" - hf_repo_url: Optional[str] = None - """The url of the saved model in the Hugging Face Hub (will be autoset)""" - try_launch_beaker_eval_jobs: bool = True - """Whether to launch beaker evaluation jobs after training""" - hf_metadata_dataset: Optional[str] = "allenai/tulu-3-evals" - """What dataset to upload the metadata to. If unset, don't upload metadata""" - - def __post_init__(self): - if self.reduce_loss not in ["mean", "sum"]: - raise ValueError("reduce_loss must be either 'mean' or 'sum'") - if ( - self.dataset_name is None - and self.train_file is None - and self.dataset_mixer is None - and self.dataset_mixer_list is None - ): - raise ValueError("Need either a dataset name, dataset mixer, or a training file.") - else: - if self.train_file is not None: - extension = self.train_file.split(".")[-1] - assert extension in ["json", "jsonl"], "`train_file` should be a json or a jsonl file." - if ( - (self.dataset_name is not None and (self.dataset_mixer is not None or self.dataset_mixer_list is not None)) - or (self.dataset_name is not None and self.train_file is not None) - or ( - (self.dataset_mixer is not None or self.dataset_mixer_list is not None) and self.train_file is not None - ) - or (self.dataset_mixer is not None and self.dataset_mixer_list is not None) - ): - raise ValueError("Cannot provide two dataset selection mechanisms.") - - if self.try_launch_beaker_eval_jobs and not self.push_to_hub: - raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.") - - - -def get_cache_ref_logprobs( - model: torch.nn.Module, - active_dataloader: torch.utils.data.DataLoader, - accelerator: Accelerator, - average_log_prob: bool, - last_checkpoint_path: Optional[str], - resume_step: int, - epoch_range: range, - forward_fn: Callable, -): - epoch_cached_reference_chosen_logps = [] - epoch_cached_reference_rejected_logps = [] - for epoch in epoch_range: - active_dataloader.set_epoch(epoch) - if last_checkpoint_path and resume_step is not None: - # We skip the first `n` batches in the dataloader when resuming from a checkpoint - active_dataloader = accelerator.skip_first_batches(active_dataloader, resume_step) - cached_reference_chosen_logps = [] - cached_reference_rejected_logps = [] - with torch.no_grad(): - for step, batch in tqdm(enumerate(active_dataloader), disable=not accelerator.is_local_main_process): - if args.use_lora: - with accelerator.unwrap_model(model).disable_adapter(): - reference_chosen_logps, reference_rejected_logps, _ = forward_fn( - model, batch, average_log_prob=average_log_prob - ) - else: - reference_chosen_logps, reference_rejected_logps, _ = forward_fn( - model, batch, average_log_prob=average_log_prob - ) - cached_reference_chosen_logps.append(reference_chosen_logps.cpu()) - cached_reference_rejected_logps.append(reference_rejected_logps.cpu()) - epoch_cached_reference_chosen_logps.append(cached_reference_chosen_logps) - epoch_cached_reference_rejected_logps.append(cached_reference_rejected_logps) - return epoch_cached_reference_chosen_logps, epoch_cached_reference_rejected_logps - - -def main(args: FlatArguments): - init_gpu_memory = None - if torch.cuda.is_available(): - init_gpu_memory = torch.cuda.mem_get_info()[0] - # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. - # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers - # in the environment - args.run_name = f"{args.exp_name}__{args.model_name_or_path.replace('/', '_')}__{args.seed}__{int(time.time())}" - if args.push_to_hub: - if args.hf_repo_id is None: # auto-generate one - args.hf_repo_id = "open_instruct_dev" - if args.hf_entity is None: # first try to use AI2 entity - args.hf_entity = maybe_use_ai2_hf_entity() - if args.hf_entity is None: # then try to use the user's entity - args.hf_entity = HfApi().whoami()["name"] - args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}" - if args.hf_repo_revision is None: - args.hf_repo_revision = args.run_name - args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}" - - if is_beaker_job(): - beaker_config = maybe_get_beaker_config() - - accelerator_log_kwargs = {} - - if args.with_tracking: - accelerator_log_kwargs["log_with"] = args.report_to - accelerator_log_kwargs["project_dir"] = args.output_dir - - # if you get timeouts (e.g. due to long tokenization) increase this. - timeout_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.timeout)) - dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True) - - accelerator = Accelerator( - gradient_accumulation_steps=args.gradient_accumulation_steps, - dataloader_config=dataloader_config, - **accelerator_log_kwargs, - kwargs_handlers=[timeout_kwargs], - ) - # Make one log on every process with the configuration for debugging. - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - logger.info(accelerator.state, main_process_only=False) - if accelerator.is_local_main_process: - datasets.utils.logging.set_verbosity_warning() - transformers.utils.logging.set_verbosity_info() - else: - datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() - - # If passed along, set the training seed now. - if args.seed is not None: - set_seed(args.seed) - - if accelerator.is_main_process: - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - accelerator.wait_for_everyone() - - tokenizer_revision = args.model_revision if args.tokenizer_revision is None else args.tokenizer_revision - tokenizer_name = args.tokenizer_name if args.tokenizer_name is not None else args.model_name_or_path - if tokenizer_revision != args.model_revision: - # Warn user if tokenizer and model use different revisions; this is an unusual - # use case. - warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different - from the model revision `{args.model_revision}`.""" - logger.warning(warning) - tc = TokenizerConfig( - model_name_or_path=tokenizer_name, - revision=args.model_revision, - use_fast=not args.use_slow_tokenizer, - chat_template_name=args.chat_template_name, - add_bos=args.add_bos, - ) - tokenizer = tc.tokenizer - if args.dataset_mixer is not None: - args.dataset_mixer_list = [item for pair in args.dataset_mixer.items() for item in pair] - with accelerator.main_process_first(): - train_dataset = get_cached_dataset_tulu_preference( - args.dataset_mixer_list, - tc, - args.max_seq_length, - ) - train_dataset.shuffle(seed=args.seed) - train_dataset.set_format(type="pt") - if args.cache_dataset_only: - return - - # Load pretrained model and tokenizer - if args.config_name: - config = AutoConfig.from_pretrained( - args.config_name, - revision=args.model_revision, - trust_remote_code=args.trust_remote_code, - ) - elif args.model_name_or_path: - config = AutoConfig.from_pretrained( - args.model_name_or_path, - revision=args.model_revision, - trust_remote_code=args.trust_remote_code, - ) - else: - raise ValueError( - "You are instantiating a new config instance from scratch. This is not supported by this script." - ) - - def load_model(): - if args.model_name_or_path: - if args.use_qlora: - bnb_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=torch.bfloat16, - ) - device_index = accelerator.local_process_index - device_map = {"": device_index} # force data-parallel training. - model = AutoModelForCausalLM.from_pretrained( - args.model_name_or_path, - revision=args.model_revision, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - trust_remote_code=args.trust_remote_code, - quantization_config=bnb_config, - device_map=device_map, - torch_dtype=torch.bfloat16, - use_flash_attention_2=True if args.use_flash_attn else False, - ) - else: - model = AutoModelForCausalLM.from_pretrained( - args.model_name_or_path, - revision=args.model_revision, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - trust_remote_code=args.trust_remote_code, - low_cpu_mem_usage=args.low_cpu_mem_usage, - use_flash_attention_2=True if args.use_flash_attn else False, - ) - else: - logger.info("Training new model from scratch") - model = AutoModelForCausalLM.from_config(config) - return model - - model = load_model() - print("=============model loaded") - print_gpu_stats(init_gpu_memory) - - # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch - # on a small vocab and want a smaller embedding size, remove this test. - # gather deepspeed to get "real" embedding size - embeddings = model.get_input_embeddings() - with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None): - if len(tokenizer) > embeddings.weight.shape[0]: - model.resize_token_embeddings(len(tokenizer)) - - if args.use_lora: - if args.use_qlora: - model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing) - - logger.info("Initializing LORA model...") - peft_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, - inference_mode=False, - r=args.lora_rank, - lora_alpha=args.lora_alpha, - lora_dropout=args.lora_dropout, - target_modules=["q_proj", "o_proj", "v_proj", "k_proj", "gate_proj", "up_proj", "down_proj"], - ) - model = get_peft_model(model, peft_config) - model.print_trainable_parameters() - elif args.gradient_checkpointing: - model.gradient_checkpointing_enable() - - # debugging tool for fewer samples - if args.max_train_samples is not None: - max_train_samples = min(len(train_dataset), args.max_train_samples) - logger.info(f"Limiting training samples to {max_train_samples} from {len(train_dataset)}.") - train_dataset = train_dataset.select(range(max_train_samples)) - - # Log a few random samples from the training set: - for index in random.sample(range(len(train_dataset)), 3): - logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") - - # DataLoaders creation: - train_dataloader = DataLoader( - train_dataset, - shuffle=True, - collate_fn=DataCollatorForSeq2SeqDPO(tokenizer=tokenizer, model=model, padding="longest"), - batch_size=args.per_device_train_batch_size, - ) - - # Optimizer - # Split weights in two groups, one with weight decay and the other not. - no_decay = ["bias", "layer_norm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], - "weight_decay": args.weight_decay, - }, - { - "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], - "weight_decay": 0.0, - }, - ] - if args.use_qlora or args.dpo_use_paged_optimizer: - from bitsandbytes.optim import AdamW - - optimizer = AdamW( - optimizer_grouped_parameters, - lr=args.learning_rate, - optim_bits=8 if args.use_8bit_optimizer else 32, - is_paged=True, - ) - else: - optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, fused=args.fused_optimizer) - print("=============optimizer loaded") - print_gpu_stats(init_gpu_memory) - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - - # Create the learning rate scheduler. - # Note: the current accelerator.step() calls the .step() of the real scheduler - # for the `num_processes` times. This is because they assume - # the user initialize the scheduler with the entire training set. - # In the case of data parallel training, each process only - # sees a subset (1/num_processes) of the training set. - # So each time the process needs to update the lr multiple times so that the total - # number of updates in the end matches the num_training_steps here. - # Here we need to set the num_training_steps to either using the - # entire training set (when epochs is specified) or we need to multiply the - # num_training_steps by num_processes so that the total number of - # updates matches the num_training_steps. - num_training_steps_for_scheduler = ( - args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes - ) - lr_scheduler = get_scheduler( - name=args.lr_scheduler_type, - optimizer=optimizer, - num_training_steps=num_training_steps_for_scheduler, - num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio), - ) - # Prepare everything with `accelerator`. - model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - model, optimizer, train_dataloader, lr_scheduler - ) - print("=============accelerate prepared") - print_gpu_stats(init_gpu_memory) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - - # Figure out how many steps we should save the Accelerator states - checkpointing_steps = args.checkpointing_steps - if checkpointing_steps is not None and str(checkpointing_steps).lower() != "epoch": - checkpointing_steps = int(checkpointing_steps) - - # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. - if args.with_tracking: - experiment_config = vars(args) - # TensorBoard cannot log Enums, need the raw value - experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"] - - # (Optional) Ai2 internal tracking - if args.wandb_entity is None: - args.wandb_entity = maybe_use_ai2_wandb_entity() - if is_beaker_job(): - experiment_config.update(vars(beaker_config)) - accelerator.init_trackers( - "open_instruct_internal", - experiment_config, - init_kwargs={ - "wandb": { - "name": args.run_name, - "entity": args.wandb_entity, - "tags": [args.exp_name] + get_wandb_tags(), - } - }, - ) - wandb_tracker = accelerator.get_tracker("wandb") - - # Train! - total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") - - completed_steps = 0 - starting_epoch = 0 - - # Potentially load in the weights and states from a previous save - last_checkpoint_path = get_last_checkpoint_path(args) - resume_step = None - if last_checkpoint_path: - accelerator.print(f"Resumed from checkpoint: {last_checkpoint_path}") - accelerator.load_state(last_checkpoint_path) - # Extract `epoch_{i}` or `step_{i}` - last_checkpoint_path = os.path.basename(last_checkpoint_path) - training_difference = os.path.splitext(last_checkpoint_path)[0] - - if "epoch" in training_difference: - starting_epoch = int(training_difference.replace("epoch_", "")) + 1 - resume_step = None - completed_steps = starting_epoch * num_update_steps_per_epoch - else: - # need to multiply `gradient_accumulation_steps` to reflect real steps - resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps - starting_epoch = resume_step // len(train_dataloader) - completed_steps = resume_step // args.gradient_accumulation_steps - resume_step -= starting_epoch * len(train_dataloader) - - print(f"Starting from epoch {starting_epoch} and step {completed_steps}.") - - print("=============before cache logprobs") - print_gpu_stats(init_gpu_memory) - - # Cache the logprobs - average_log_prob_loss_types = ["simpo", "dpo_norm"] - average_log_prob = args.dpo_loss_type in average_log_prob_loss_types - forward_fn = concatenated_forward if args.concatenated_forward else separate_forward - if args.dpo_loss_type == "dpo" or args.dpo_loss_type == "dpo_norm": - epoch_cached_reference_chosen_logps, epoch_cached_reference_rejected_logps = get_cache_ref_logprobs( - model, - train_dataloader, - accelerator, - average_log_prob, - last_checkpoint_path, - resume_step, - range(starting_epoch, args.num_train_epochs), - forward_fn, - ) - print("=============after cache logprobs") - print_gpu_stats(init_gpu_memory) - torch.cuda.empty_cache() # clear cache - - print("=============after cache logprobs; clear cache") - print_gpu_stats(init_gpu_memory) - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) - # update the progress_bar if load from checkpoint - progress_bar.update(completed_steps) - - local_metrics = torch.zeros((20), device=accelerator.device) - episode = 0 - for epoch in range(starting_epoch, args.num_train_epochs): - model.train() - train_dataloader.set_epoch(epoch) - if last_checkpoint_path and resume_step is not None: - # We skip the first `n` batches in the dataloader when resuming from a checkpoint - active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) - else: - active_dataloader = train_dataloader - # we need to average the log probs for simpo loss - for step, batch in enumerate(active_dataloader): - episode += len(batch["chosen_input_ids"]) * accelerator.num_processes - # dpo forward pass & loss - with accelerator.accumulate(model): - policy_chosen_logps, policy_rejected_logps, aux_loss = forward_fn( - model, batch, average_log_prob=average_log_prob, output_router_logits=args.load_balancing_loss - ) # `aux_loss` is only used when `args.load_balancing_loss = True` - if args.dpo_loss_type == "dpo" or args.dpo_loss_type == "dpo_norm": - p_device = policy_chosen_logps.device - reference_chosen_logps = epoch_cached_reference_chosen_logps[epoch][step].to(p_device) - reference_rejected_logps = epoch_cached_reference_rejected_logps[epoch][step].to(p_device) - losses, _, _ = dpo_loss( - policy_chosen_logps, - policy_rejected_logps, - reference_chosen_logps, - reference_rejected_logps, - beta=args.dpo_beta, - label_smoothing=args.dpo_label_smoothing, - ) - elif args.dpo_loss_type == "simpo": - losses, _, _ = simpo_loss( - policy_chosen_logps, - policy_rejected_logps, - beta=args.dpo_beta, - gamma_beta_ratio=args.dpo_gamma_beta_ratio, - label_smoothing=args.dpo_label_smoothing, - ) - elif args.dpo_loss_type == "wpo": - losses, _, _ = wpo_loss( - policy_chosen_logps, - policy_rejected_logps, - reference_chosen_logps, - reference_rejected_logps, - beta=args.dpo_beta, - label_smoothing=args.dpo_label_smoothing, - chosen_loss_mask=batch["chosen_labels"] != -100, - rejected_loss_mask=batch["rejected_labels"] != -100, - ) - else: - raise ValueError(f"Invalid dpo loss type {args.dpo_loss_type}.") - # TODO: metric logging - loss = losses.mean() - if args.load_balancing_loss: - weighted_aux_loss = args.load_balancing_weight * aux_loss - loss += weighted_aux_loss - accelerator.backward(loss) - # clip gradient norm. don't do this with deepspeed - if accelerator.sync_gradients and args.clip_grad_norm > 0: - accelerator.clip_grad_norm_(model.parameters(), args.clip_grad_norm) - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() - - # We keep track of the loss at each logged step - with torch.no_grad(): - local_metrics[0] += loss - if args.dpo_loss_type == "dpo" or args.dpo_loss_type == "dpo_norm": - chosen_rewards = (args.dpo_beta * (policy_chosen_logps - reference_chosen_logps)).mean() - rejected_rewards = (args.dpo_beta * (policy_rejected_logps - reference_rejected_logps)).mean() - average_rewards = (chosen_rewards + rejected_rewards) / 2 - accuracy = (chosen_rewards > rejected_rewards).float().mean() - margin = (chosen_rewards - rejected_rewards).mean() - local_metrics[1] += chosen_rewards - local_metrics[2] += rejected_rewards - local_metrics[3] += average_rewards - local_metrics[4] += accuracy - local_metrics[5] += margin - local_metrics[6] += policy_chosen_logps.mean() - local_metrics[7] += policy_rejected_logps.mean() - if args.load_balancing_loss: - local_metrics[19] += weighted_aux_loss - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - completed_steps += 1 - if args.logging_steps and completed_steps % args.logging_steps == 0: - # single all reduce to save time, avoiding per metric all reduce - global_metrics = accelerator.reduce(local_metrics, reduction="mean") - global_metrics /= args.gradient_accumulation_steps * args.logging_steps - global_metrics = global_metrics.tolist() - metrics_to_log = { - "training_step": completed_steps, - "learning_rate": lr_scheduler.get_last_lr()[0], - "epoch": episode / len(train_dataset), - "train_loss": global_metrics[0], - "logps/chosen": global_metrics[6], - "logps/rejected": global_metrics[7], - } - if args.dpo_loss_type == "dpo" or args.dpo_loss_type == "dpo_norm": - metrics_to_log.update( - { - "rewards/chosen": global_metrics[1], - "rewards/rejected": global_metrics[2], - "rewards/average": global_metrics[3], - "rewards/accuracy": global_metrics[4], - "rewards/margin": global_metrics[5], - } - ) - logger_str = ( - f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {global_metrics[0]}" - ) - if args.load_balancing_loss: - logger_str += f" Aux Loss: {global_metrics[19]}" - metrics_to_log["aux_loss"] = global_metrics[19] - logger.info(logger_str) - if args.with_tracking: - accelerator.log( - metrics_to_log, - step=completed_steps, - ) - # Reset the local metrics - local_metrics.zero_() - - if isinstance(checkpointing_steps, int): - if completed_steps % checkpointing_steps == 0: - output_dir = f"step_{completed_steps}" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) - accelerator.save_state(output_dir) - # use this to mark the checkpoint as completely saved, to avoid restoring from garbled checkpoints - with open( - os.path.join(get_last_checkpoint_path(args, incomplete=True), "COMPLETED"), "w" - ) as f: - f.write("COMPLETED") # annoyingly, empty files arent uploaded by beaker. - if accelerator.is_local_main_process: - clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints) - accelerator.wait_for_everyone() - - if completed_steps >= args.max_train_steps: - break - - if checkpointing_steps == "epoch": - output_dir = f"epoch_{epoch}" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) - accelerator.save_state(output_dir) - # use this to mark the checkpoint as completely saved, to avoid restoring from garbled checkpoints - with open(os.path.join(get_last_checkpoint_path(args, incomplete=True), "COMPLETED"), "w") as f: - f.write("COMPLETED") # annoyingly, empty files arent uploaded by beaker. - if accelerator.is_local_main_process: - clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints) - accelerator.wait_for_everyone() - - if args.output_dir is not None: - save_with_accelerate( - accelerator, - model, - tokenizer, - args.output_dir, - args.use_lora, - ) - - # remove all checkpoints to save space - if accelerator.is_local_main_process: - clean_last_n_checkpoints(args.output_dir, keep_last_n_checkpoints=0) - - if ( - args.try_auto_save_to_beaker - and accelerator.is_main_process - and len(beaker_config.beaker_dataset_id_urls) > 0 - and args.output_dir.rstrip("/") != "/output" - ): - shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True) - - if is_beaker_job() and accelerator.is_main_process: - # dpo script only supports these two options right now for datasets - if args.dataset_mixer: - dataset_list = list(args.dataset_mixer.keys()) - elif args.dataset_mixer_list: - dataset_list = args.dataset_mixer_list[::2] # even indices - elif args.dataset_name: - dataset_list = [args.dataset_name] - else: - dataset_list = [args.train_file] - # mainly just focussing here on what would be useful for the leaderboard. - # wandb will have even more useful information. - metadata_blob = { - "model_name": args.exp_name, - "model_type": "dpo", - "datasets": dataset_list, - "base_model": args.model_name_or_path, - "wandb_path": wandb_tracker.run.get_url(), - "beaker_experiment": beaker_config.beaker_experiment_url, - "beaker_datasets": beaker_config.beaker_dataset_id_urls, - } - # save metadata to the output directory. then it should also get pushed to HF. - with open(os.path.join(args.output_dir, "metadata.json"), "w") as f: - json.dump(metadata_blob, f) - - # upload metadata to the dataset if set - if args.hf_metadata_dataset: - upload_metadata_to_hf( - metadata_blob, - "metadata.json", - args.hf_metadata_dataset, - "results/" + args.run_name, # to match what the auto-evals name as. - ) - - if args.try_launch_beaker_eval_jobs: - command = f"""\ - python mason.py \ - --cluster ai2/ganymede-cirrascale ai2/ceres-cirrascale ai2/neptune-cirrascale ai2/saturn-cirrascale ai2/jupiter-cirrascale-2 \ - --priority low \ - --preemptible \ - --budget ai2/allennlp \ - --workspace ai2/tulu-2-improvements \ - --image nathanl/open_instruct_auto \ - --pure_docker_mode \ - --gpus 0 -- python scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py \ - --beaker_workload_id {beaker_config.beaker_workload_id} \ - --upload_to_hf {args.hf_metadata_dataset} \ - --model_name {args.run_name} \ - --run_id {wandb_tracker.run.get_url()} - """ - process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stdout, stderr = process.communicate() - print(f"Submit jobs after model training is finished - Stdout:\n{stdout.decode()}") - print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") - print(f"Submit jobs after model training is finished - process return code: {process.returncode}") - - if args.push_to_hub: - push_folder_to_hub( - accelerator, - args.output_dir, - args.hf_repo_id, - args.hf_repo_revision, - ) - accelerator.wait_for_everyone() - if args.with_tracking: - accelerator.end_training() - - -def print_gpu_stats(init_gpu_memory: Optional[int]): - if torch.cuda.is_available(): - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() - peak_memory = init_gpu_memory - free_gpu_memory - print(f"Peak memory usage: {peak_memory / 1024 ** 3:.2f} GB") - print(f"Total memory usage: {total_gpu_memory / 1024 ** 3:.2f} GB") - print(f"Free memory: {free_gpu_memory / 1024 ** 3:.2f} GB") - - -if __name__ == "__main__": - parser = ArgumentParserPlus((FlatArguments)) - args = parser.parse() - main(args) diff --git a/open_instruct/finetune1.py b/open_instruct/finetune1.py deleted file mode 100644 index ada18b103..000000000 --- a/open_instruct/finetune1.py +++ /dev/null @@ -1,929 +0,0 @@ -# !/usr/bin/env python -# coding=utf-8 -# Copyright 2024 AllenAI. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import logging -import math -import os -import random -import shutil -import subprocess -import time -from dataclasses import dataclass, field -from datetime import timedelta -from functools import partial -from typing import List, Optional, Union - -import datasets -import deepspeed -import torch -import transformers -from accelerate import Accelerator, DataLoaderConfiguration -from accelerate.logging import get_logger -from accelerate.utils import InitProcessGroupKwargs, set_seed -from datasets import load_dataset -from huggingface_hub import HfApi -from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training -from torch.utils.data import DataLoader -from tqdm.auto import tqdm -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoTokenizer, - BitsAndBytesConfig, - DataCollatorForSeq2Seq, - GPT2Tokenizer, - GPTNeoXTokenizerFast, - LlamaTokenizer, - LlamaTokenizerFast, - OPTForCausalLM, - get_scheduler, -) - -from open_instruct.dataset_transformation import CHAT_TEMPLATES, TokenizerConfig, get_cached_dataset_tulu_sft -from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate -from open_instruct.utils import ( - ArgumentParserPlus, - clean_last_n_checkpoints, - get_datasets, - get_last_checkpoint_path, - get_wandb_tags, - is_beaker_job, - maybe_get_beaker_config, - maybe_use_ai2_hf_entity, - maybe_use_ai2_wandb_entity, - upload_metadata_to_hf, -) - -logger = get_logger(__name__) - - -@dataclass -class FlatArguments: - """ - Full arguments class for all fine-tuning jobs. - """ - - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """The name of this experiment""" - run_name: Optional[str] = None - """A unique name of this run""" - model_name_or_path: Optional[str] = field( - default=None, - metadata={ - "help": ( - "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch." - ) - }, - ) - config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} - ) - tokenizer_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} - ) - tokenizer_revision: Optional[str] = field( - default=None, - metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, - ) - chat_template_name: str = field( - default="tulu", - metadata={ - "help": ( - f"The name of the chat template to use. " - f"You can choose one of our pre-defined templates: {', '.join(CHAT_TEMPLATES.keys())}." - f"Or, you can provide a tokenizer name or path here and we will apply its chat template." - ) - }, - ) - use_flash_attn: bool = field( - default=True, - metadata={"help": "Whether to use flash attention in the model training"}, - ) - use_slow_tokenizer: bool = field( - default=True, - metadata={"help": "Whether to use one of the slow tokenizer or not (which is then fast tokenizer)."}, - ) - model_revision: Optional[str] = field( - default=None, - metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, - ) - trust_remote_code: bool = field( - default=False, - metadata={ - "help": ( - "Whether or not to allow for custom models defined on the Hub in their own modeling files. " - "This option should only be set to `True` for repositories you trust and in which you " - "have read the code, as it will execute code present on the Hub on your local machine." - ) - }, - ) - low_cpu_mem_usage: bool = field( - default=False, - metadata={ - "help": ( - "It is an option to create the model as an empty shell, " - "then only materialize its parameters when the pretrained weights are loaded. " - "set True will benefit LLM loading time and RAM consumption." - ) - }, - ) - dataset_name: Optional[str] = field( - default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} - ) - dataset_mixer: Optional[dict] = field( - default=None, metadata={"help": "A dictionary of datasets (local or HF) to sample from."} - ) - dataset_mixer_list: Optional[list[str]] = field( - default=None, metadata={"help": "A list of datasets (local or HF) to sample from."} - ) - dataset_mix_dir: Optional[str] = field( - default=None, metadata={"help": "The directory to save the mixed dataset to disk."} - ) - dataset_config_name: Optional[str] = field( - default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} - ) - train_file: Optional[str] = field( - default=None, metadata={"help": "The input training data file (a json/jsonl file)."} - ) - max_train_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of training examples to this " - "value if set." - ) - }, - ) - preprocessing_num_workers: Optional[int] = field( - default=None, - metadata={"help": "The number of processes to use for the preprocessing."}, - ) - max_seq_length: Optional[int] = field( - default=None, - metadata={ - "help": ( - "The maximum total input sequence length after tokenization. " - "Sequences longer than this will be truncated," - ) - }, - ) - overwrite_cache: bool = field( - default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} - ) - add_bos: bool = field( - default=False, - metadata={ - "help": "Forcibly add bos token to the beginning of the input sequence." - " Use only when tokenizer does not add bos token by default." - }, - ) - clip_grad_norm: float = field( - default=-1, - metadata={"help": "Clip gradient norm. Not compatible with deepspeed (use deepspeed config instead)."}, - ) - gradient_accumulation_steps: int = field( - default=1, - metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, - ) - learning_rate: float = field( - default=2e-5, - metadata={"help": "The initial learning rate for AdamW optimizer."}, - ) - logging_steps: Optional[int] = field( - default=None, - metadata={"help": "Log the training loss and learning rate every logging_steps steps."}, - ) - lora_rank: int = field( - default=64, - metadata={"help": "The rank of lora."}, - ) - lora_alpha: float = field( - default=16, - metadata={"help": "The alpha parameter of lora."}, - ) - lora_dropout: float = field( - default=0.1, - metadata={"help": "The dropout rate of lora modules."}, - ) - lr_scheduler_type: str = field( - default="linear", - metadata={ - "help": "The scheduler type to use for learning rate adjustment.", - "choices": ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], - }, - ) - num_train_epochs: int = field( - default=2, - metadata={"help": "Total number of training epochs to perform."}, - ) - output_dir: str = field( - default="output/", - metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, - ) - per_device_train_batch_size: int = field( - default=8, - metadata={"help": "Batch size per GPU/TPU core/CPU for training."}, - ) - use_lora: bool = field( - default=False, - metadata={"help": "If True, will use LORA (low-rank parameter-efficient training) to train the model."}, - ) - use_qlora: bool = field( - default=False, - metadata={"help": "Use qLoRA training - initializes model in quantized form. Not compatible with deepspeed."}, - ) - use_8bit_optimizer: bool = field( - default=False, - metadata={"help": "Use 8bit optimizer from bitsandbytes. Not compatible with deepspeed."}, - ) - warmup_ratio: float = field( - default=0.03, - metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}, - ) - weight_decay: float = field( - default=0.0, - metadata={"help": "Weight decay for AdamW if we apply some."}, - ) - timeout: int = field( - default=1800, - metadata={ - "help": "Timeout for the training process in seconds." - "Useful if tokenization process is long. Default is 1800 seconds (30 minutes)." - }, - ) - reduce_loss: str = field( - default="mean", - metadata={ - "help": "How to reduce loss over tokens. Options are 'mean' or 'sum'." - "Using 'sum' can improve chat model performance." - }, - ) - wandb_entity: Optional[str] = field( - default=None, - metadata={"help": "Entity to use for logging to wandb."}, - ) - resume_from_checkpoint: Optional[str] = field( - default=None, - metadata={"help": "If the training should continue from a checkpoint folder."}, - ) - with_tracking: bool = field( - default=False, - metadata={"help": "Whether to enable experiment trackers for logging."}, - ) - report_to: Union[str, List[str]] = field( - default="all", - metadata={ - "help": "The integration(s) to report results and logs to. " - "Can be a single string or a list of strings. " - "Options are 'tensorboard', 'wandb', 'comet_ml', 'clearml', or 'all'. " - "Specify multiple by listing them: e.g., ['tensorboard', 'wandb']" - }, - ) - save_to_hub: Optional[str] = field( - default=None, - metadata={"help": "Save the model to the Hub under this name. E.g allenai/your-model"}, - ) - gradient_checkpointing: bool = field( - default=False, - metadata={"help": "Turn on gradient checkpointing. Saves memory but slows training."}, - ) - max_train_steps: Optional[int] = field( - default=None, - metadata={"help": "If set, overrides the number of training steps. Otherwise, num_train_epochs is used."}, - ) - seed: int = field(default=42, metadata={"help": "Random seed for initialization and dataset shuffling."}) - checkpointing_steps: Optional[str] = field( - default=None, - metadata={ - "help": "Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch." # noqa - }, - ) - overwrite_output_dir: bool = field( - default=False, - metadata={ - "help": "Overwrite the content of the output directory. Means that resumption will always start from scratch." - }, - ) - keep_last_n_checkpoints: int = field( - default=3, - metadata={"help": "How many checkpoints to keep in the output directory. -1 for all."}, - ) - fused_optimizer: bool = field( - default=True, - metadata={ - "help": "Whether to use fused AdamW or not.", - }, - ) - load_balancing_loss: bool = field( - default=False, - metadata={ - "help": "Whether to include a load balancing loss (for OLMoE) or not.", - }, - ) - load_balancing_weight: float = field( - default=0.5, - metadata={"help": "Weight for load balancing loss if applicable."}, - ) - cache_dataset_only: bool = False - """Immediately exit after caching the dataset""" - try_auto_save_to_beaker: bool = True - """Whether to try to save the model to Beaker dataset `/output` after training""" - push_to_hub: bool = True - """Whether to upload the saved model to huggingface""" - hf_entity: Optional[str] = None - """The user or org name of the model repository from the Hugging Face Hub""" - hf_repo_id: Optional[str] = None - """The id of the saved model in the Hugging Face Hub (can be autoset if not given)""" - hf_repo_revision: Optional[str] = None - """The revision of the saved model in the Hugging Face Hub (can be autoset if not given)""" - hf_repo_url: Optional[str] = None - """The url of the saved model in the Hugging Face Hub (will be autoset)""" - try_launch_beaker_eval_jobs: bool = True - """Whether to launch beaker evaluation jobs after training""" - hf_metadata_dataset: Optional[str] = "allenai/tulu-3-evals" - """What dataset to upload the metadata to. If unset, don't upload metadata""" - - def __post_init__(self): - if self.reduce_loss not in ["mean", "sum"]: - raise ValueError("reduce_loss must be either 'mean' or 'sum'") - if ( - self.dataset_name is None - and self.train_file is None - and self.dataset_mixer is None - and self.dataset_mixer_list is None - ): - raise ValueError("Need either a dataset name, dataset mixer, or a training file.") - else: - if self.train_file is not None: - extension = self.train_file.split(".")[-1] - assert extension in ["json", "jsonl"], "`train_file` should be a json or a jsonl file." - if ( - (self.dataset_name is not None and (self.dataset_mixer is not None or self.dataset_mixer_list is not None)) - or (self.dataset_name is not None and self.train_file is not None) - or ( - (self.dataset_mixer is not None or self.dataset_mixer_list is not None) and self.train_file is not None - ) - or (self.dataset_mixer is not None and self.dataset_mixer_list is not None) - ): - raise ValueError("Cannot provide two dataset selection mechanisms.") - if self.try_launch_beaker_eval_jobs and not self.push_to_hub: - raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.") - - -def main(args: FlatArguments): - # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. - # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers - # in the environment - args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - if args.push_to_hub: - if args.hf_repo_id is None: # auto-generate one - args.hf_repo_id = "open_instruct_dev" - if args.hf_entity is None: # first try to use AI2 entity - args.hf_entity = maybe_use_ai2_hf_entity() - if args.hf_entity is None: # then try to use the user's entity - args.hf_entity = HfApi().whoami()["name"] - args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}" - if args.hf_repo_revision is None: - args.hf_repo_revision = args.run_name - args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}" - - if is_beaker_job(): - beaker_config = maybe_get_beaker_config() - - accelerator_log_kwargs = {} - - if args.with_tracking: - accelerator_log_kwargs["log_with"] = args.report_to - accelerator_log_kwargs["project_dir"] = args.output_dir - - # if you get timeouts (e.g. due to long tokenization) increase this. - timeout_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.timeout)) - dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True) - - accelerator = Accelerator( - gradient_accumulation_steps=args.gradient_accumulation_steps, - dataloader_config=dataloader_config, - **accelerator_log_kwargs, - kwargs_handlers=[timeout_kwargs], - ) - # Make one log on every process with the configuration for debugging. - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - logger.info(accelerator.state, main_process_only=False) - if accelerator.is_local_main_process: - datasets.utils.logging.set_verbosity_warning() - transformers.utils.logging.set_verbosity_info() - else: - datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() - - # If passed along, set the training seed now. - if args.seed is not None: - set_seed(args.seed) - - if accelerator.is_main_process: - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - accelerator.wait_for_everyone() - - tokenizer_revision = args.model_revision if args.tokenizer_revision is None else args.tokenizer_revision - tokenizer_name = args.tokenizer_name if args.tokenizer_name is not None else args.model_name_or_path - if tokenizer_revision != args.model_revision: - # Warn user if tokenizer and model use different revisions; this is an unusual - # use case. - warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different - from the model revision `{args.model_revision}`.""" - logger.warning(warning) - tc = TokenizerConfig( - model_name_or_path=tokenizer_name, - revision=args.model_revision, - use_fast=not args.use_slow_tokenizer, - chat_template_name=args.chat_template_name, - add_bos=args.add_bos, - ) - tokenizer = tc.tokenizer - if args.dataset_mixer is not None: - args.dataset_mixer_list = [item for pair in args.dataset_mixer.items() for item in pair] - with accelerator.main_process_first(): - train_dataset = get_cached_dataset_tulu_sft( - args.dataset_mixer_list, - tc, - args.max_seq_length, - ) - train_dataset.shuffle(seed=args.seed) - train_dataset.set_format(type="pt") - if args.cache_dataset_only: - return - - # Load pretrained model and tokenizer - if args.config_name: - config = AutoConfig.from_pretrained( - args.config_name, - revision=args.model_revision, - trust_remote_code=args.trust_remote_code, - ) - elif args.model_name_or_path: - config = AutoConfig.from_pretrained( - args.model_name_or_path, - revision=args.model_revision, - trust_remote_code=args.trust_remote_code, - ) - else: - raise ValueError( - "You are instantiating a new config instance from scratch. This is not supported by this script." - ) - - if args.model_name_or_path: - if args.use_qlora: - bnb_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=torch.bfloat16, - ) - device_index = accelerator.local_process_index - device_map = {"": device_index} # force data-parallel training. - model = AutoModelForCausalLM.from_pretrained( - args.model_name_or_path, - revision=args.model_revision, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - trust_remote_code=args.trust_remote_code, - quantization_config=bnb_config, - device_map=device_map, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2" if args.use_flash_attn else "eager", - ) - else: - model = AutoModelForCausalLM.from_pretrained( - args.model_name_or_path, - revision=args.model_revision, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - trust_remote_code=args.trust_remote_code, - low_cpu_mem_usage=args.low_cpu_mem_usage, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2" if args.use_flash_attn else "eager", - ) - else: - logger.info("Training new model from scratch") - model = AutoModelForCausalLM.from_config(config) - - # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch - # on a small vocab and want a smaller embedding size, remove this test. - # gather deepspeed to get "real" embedding size - embeddings = model.get_input_embeddings() - with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None): - embedding_size = embeddings.weight.shape[0] - # resize does its own gather - if len(tokenizer) > embedding_size: - # pad to multiple for tensor cores. - model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8) - # update embedding size after resizing for sum loss - embeddings = model.get_input_embeddings() - with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None): - embedding_size = embeddings.weight.shape[0] - - if args.use_lora: - if args.use_qlora: - model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing) - - logger.info("Initializing LORA model...") - peft_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, - inference_mode=False, - r=args.lora_rank, - lora_alpha=args.lora_alpha, - lora_dropout=args.lora_dropout, - target_modules=["q_proj", "o_proj", "v_proj", "k_proj", "gate_proj", "up_proj", "down_proj"], - ) - model = get_peft_model(model, peft_config) - model.print_trainable_parameters() - elif args.gradient_checkpointing: - model.gradient_checkpointing_enable() - - # DataLoaders creation: - train_dataloader = DataLoader( - train_dataset, - shuffle=True, - collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest"), - batch_size=args.per_device_train_batch_size, - ) - - # Optimizer - # Split weights in two groups, one with weight decay and the other not. - no_decay = ["bias", "layer_norm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], - "weight_decay": args.weight_decay, - }, - { - "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], - "weight_decay": 0.0, - }, - ] - if args.use_qlora: - from bitsandbytes.optim import AdamW - - optimizer = AdamW( - optimizer_grouped_parameters, - lr=args.learning_rate, - optim_bits=8 if args.use_8bit_optimizer else 32, - is_paged=True, - ) - else: - optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, fused=args.fused_optimizer) - - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - - # Create the learning rate scheduler. - # Note: the current accelerator.step() calls the .step() of the real scheduler - # for the `num_processes` times. This is because they assume - # the user initialize the scheduler with the entire training set. - # In the case of data parallel training, each process only - # sees a subset (1/num_processes) of the training set. - # So each time the process needs to update the lr multiple times so that the total - # number of updates in the end matches the num_training_steps here. - # Here we need to set the num_training_steps to either using the - # entire training set (when epochs is specified) or we need to multiply the - # num_training_steps by num_processes so that the total number of - # updates matches the num_training_steps. - num_training_steps_for_scheduler = ( - args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes - ) - lr_scheduler = get_scheduler( - name=args.lr_scheduler_type, - optimizer=optimizer, - num_training_steps=num_training_steps_for_scheduler, - num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio), - ) - # Prepare everything with `accelerator`. - model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - model, optimizer, train_dataloader, lr_scheduler - ) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - - # Figure out how many steps we should save the Accelerator states - checkpointing_steps = args.checkpointing_steps - if checkpointing_steps is not None and str(checkpointing_steps).lower() != "epoch": - checkpointing_steps = int(checkpointing_steps) - - # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. - if args.with_tracking: - experiment_config = vars(args) - # TensorBoard cannot log Enums, need the raw value - experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"] - - # (Optional) Ai2 internal tracking - if args.wandb_entity is None: - args.wandb_entity = maybe_use_ai2_wandb_entity() - if is_beaker_job(): - experiment_config.update(vars(beaker_config)) - accelerator.init_trackers( - "open_instruct_internal", - experiment_config, - init_kwargs={ - "wandb": { - "name": args.run_name, - "entity": args.wandb_entity, - "tags": [args.exp_name] + get_wandb_tags(), - } - }, - ) - wandb_tracker = accelerator.get_tracker("wandb") - - # Train! - total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) - completed_steps = 0 - starting_epoch = 0 - - # Potentially load in the weights and states from a previous save - last_checkpoint_path = get_last_checkpoint_path(args) - if last_checkpoint_path: - accelerator.print(f"Resumed from checkpoint: {last_checkpoint_path}") - accelerator.load_state(last_checkpoint_path) - # Extract `epoch_{i}` or `step_{i}` - last_checkpoint_path = os.path.basename(last_checkpoint_path) - training_difference = os.path.splitext(last_checkpoint_path)[0] - - if "epoch" in training_difference: - starting_epoch = int(training_difference.replace("epoch_", "")) + 1 - resume_step = None - completed_steps = starting_epoch * num_update_steps_per_epoch - else: - # need to multiply `gradient_accumulation_steps` to reflect real steps - resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps - starting_epoch = resume_step // len(train_dataloader) - completed_steps = resume_step // args.gradient_accumulation_steps - resume_step -= starting_epoch * len(train_dataloader) - - print(f"Starting from epoch {starting_epoch} and step {completed_steps}.") - # update the progress_bar if load from checkpoint - progress_bar.update(completed_steps) - local_total_tokens = torch.tensor(0, dtype=torch.int64, device=accelerator.device) - total_token_including_padding = torch.tensor(0, dtype=torch.int64, device=accelerator.device) - start_time = time.time() - for epoch in range(starting_epoch, args.num_train_epochs): - model.train() - train_dataloader.set_epoch(epoch) - total_loss = 0 - total_aux_loss = 0 - if last_checkpoint_path and resume_step is not None: - # We skip the first `n` batches in the dataloader when resuming from a checkpoint - active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) - else: - active_dataloader = train_dataloader - for step, batch in enumerate(active_dataloader): - local_total_tokens += batch["attention_mask"].sum() - total_token_including_padding += batch["attention_mask"].numel() - with accelerator.accumulate(model): - if args.load_balancing_loss: - outputs = model(**batch, use_cache=False, output_router_logits=True) - else: - # TODO: we have calculated the mean loss here anyway, so doubling the calculation - outputs = model(**batch, use_cache=False) - if args.reduce_loss == "mean": - loss = outputs.loss - else: - # reduce loss is sum - # this ensures that we weight all tokens in the dataset equally, - # rather than weighting each overall example equally when - # using high amounts of gradient accumulation. - # this can result in > 5 point improvements in AlpacaEval - # see https://github.com/huggingface/transformers/issues/24725 for - # more discussion and details. - logits = outputs.logits - labels = batch["labels"] - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = torch.nn.CrossEntropyLoss(reduction="sum") - shift_logits = shift_logits.view(-1, embedding_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - if args.load_balancing_loss: - aux_loss = args.load_balancing_weight * outputs.aux_loss - loss += aux_loss - # We keep track of the loss at each logged step - total_loss += loss.detach().float() - accelerator.backward(loss) - if args.load_balancing_loss: - total_aux_loss += aux_loss.detach().float() - # clip gradient norm. don't do this with deepspeed - if accelerator.sync_gradients and args.clip_grad_norm > 0: - accelerator.clip_grad_norm_(model.parameters(), args.clip_grad_norm) - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - completed_steps += 1 - if args.logging_steps and completed_steps % args.logging_steps == 0: - avg_loss = ( - accelerator.gather(total_loss).mean().item() - / args.gradient_accumulation_steps - / args.logging_steps - ) - total_tokens = accelerator.gather(local_total_tokens).sum().item() - total_tokens_including_padding = accelerator.gather(total_token_including_padding).sum().item() - metrics_to_log = { - "learning_rate": lr_scheduler.get_last_lr()[0], - "train_loss": avg_loss, - "total_tokens": total_tokens, - "per_device_tps": total_tokens / accelerator.num_processes / (time.time() - start_time), - "total_tokens_including_padding": total_tokens_including_padding, - "per_device_tps_including_padding": total_tokens_including_padding - / accelerator.num_processes - / (time.time() - start_time), - } - if args.load_balancing_loss: - avg_aux_loss = ( - accelerator.gather(total_aux_loss).mean().item() - / args.gradient_accumulation_steps - / args.logging_steps - ) - logger.info( - f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}, Aux Loss: {avg_aux_loss}, TPS: {total_tokens / (time.time() - start_time)}" - ) - metrics_to_log["aux_loss"] = avg_aux_loss - else: - logger.info( - f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}, TPS: {total_tokens / (time.time() - start_time)}" - ) - if args.with_tracking: - accelerator.log( - metrics_to_log, - step=completed_steps, - ) - total_loss = 0 - total_aux_loss = 0 - - if isinstance(checkpointing_steps, int): - if completed_steps % checkpointing_steps == 0: - output_dir = f"step_{completed_steps}" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) - accelerator.save_state(output_dir) - # use this to mark the checkpoint as completely saved, to avoid restoring from garbled checkpoints - with open( - os.path.join(get_last_checkpoint_path(args, incomplete=True), "COMPLETED"), "w" - ) as f: - f.write("COMPLETED") # annoyingly, empty files arent uploaded by beaker. - if accelerator.is_local_main_process: # TODO: in mason local model this is gonna error out if using something like output/test; because mason used the same shared file ssytem. - clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints) - accelerator.wait_for_everyone() - - if completed_steps >= args.max_train_steps: - break - - if checkpointing_steps == "epoch": - output_dir = f"epoch_{epoch}" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) - accelerator.save_state(output_dir) - # use this to mark the checkpoint as completely saved, to avoid restoring from garbled checkpoints - with open(os.path.join(get_last_checkpoint_path(args, incomplete=True), "COMPLETED"), "w") as f: - f.write("COMPLETED") # annoyingly, empty files arent uploaded by beaker. - if accelerator.is_local_main_process: - clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints) - accelerator.wait_for_everyone() - - if args.output_dir is not None: - save_with_accelerate( - accelerator, - model, - tokenizer, - args.output_dir, - args.use_lora, - ) - - # remove all checkpoints to save space - if accelerator.is_local_main_process: - clean_last_n_checkpoints(args.output_dir, keep_last_n_checkpoints=0) - - if ( - args.try_auto_save_to_beaker - and accelerator.is_main_process - and len(beaker_config.beaker_dataset_id_urls) > 0 - and args.output_dir.rstrip("/") != "/output" - ): - shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True) - - if is_beaker_job() and accelerator.is_main_process: - # dpo script only supports these two options right now for datasets - if args.dataset_mixer: - dataset_list = list(args.dataset_mixer.keys()) - elif args.dataset_mixer_list: - dataset_list = args.dataset_mixer_list[::2] # even indices - elif args.dataset_name: - dataset_list = [args.dataset_name] - else: - dataset_list = [args.train_file] - # mainly just focussing here on what would be useful for the leaderboard. - # wandb will have even more useful information. - metadata_blob = { - "model_name": args.exp_name, - "model_type": "sft", - "datasets": dataset_list, - "base_model": args.model_name_or_path, - "wandb_path": wandb_tracker.run.get_url(), - "beaker_experiment": beaker_config.beaker_experiment_url, - "beaker_datasets": beaker_config.beaker_dataset_id_urls, - } - # save metadata to the output directory. then it should also get pushed to HF. - with open(os.path.join(args.output_dir, "metadata.json"), "w") as f: - json.dump(metadata_blob, f) - - # upload metadata to the dataset if set - if args.hf_metadata_dataset: - upload_metadata_to_hf( - metadata_blob, - "metadata.json", - args.hf_metadata_dataset, - "results/" + args.run_name, # to match what the auto-evals name as. - ) - - if args.try_launch_beaker_eval_jobs: - command = f"""\ - python mason.py \ - --cluster ai2/ganymede-cirrascale ai2/ceres-cirrascale ai2/neptune-cirrascale ai2/saturn-cirrascale ai2/jupiter-cirrascale-2 \ - --priority low \ - --preemptible \ - --budget ai2/allennlp \ - --workspace ai2/tulu-2-improvements \ - --image nathanl/open_instruct_auto \ - --pure_docker_mode \ - --gpus 0 -- python scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py \ - --beaker_workload_id {beaker_config.beaker_workload_id} \ - --upload_to_hf {args.hf_metadata_dataset} \ - --model_name {args.run_name} \ - --run_id {wandb_tracker.run.get_url()} - """ - process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stdout, stderr = process.communicate() - print(f"Submit jobs after model training is finished - Stdout:\n{stdout.decode()}") - print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") - print(f"Submit jobs after model training is finished - process return code: {process.returncode}") - - if args.push_to_hub: - push_folder_to_hub( - accelerator, - args.output_dir, - args.hf_repo_id, - args.hf_repo_revision, - ) - accelerator.wait_for_everyone() - if args.with_tracking: - accelerator.end_training() - - -if __name__ == "__main__": - parser = ArgumentParserPlus((FlatArguments)) - args = parser.parse() - main(args) From 4dd45981a841f17b004251b3309dc30eca8d5eaa Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 21 Jan 2025 17:44:14 -0500 Subject: [PATCH 26/32] Apply suggestions from code review Co-authored-by: Nathan Lambert --- docs/ai2_internal.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/ai2_internal.md b/docs/ai2_internal.md index a4a93aa98..5bf00a9ca 100644 --- a/docs/ai2_internal.md +++ b/docs/ai2_internal.md @@ -26,7 +26,7 @@ beaker secret write -w ai2/tulu-2-improvements "${beaker_whoami}_HF_TOKEN" xxxx * `HF_HUB_CACHE=/weka/oe-adapt-default/allennlp/.cache/hub` -You can run things like below for a quick spin. +You can run things like below for a quick spin. This example just starts a beaker job to print the Python version in your beaker image. ```bash python mason.py \ @@ -40,6 +40,7 @@ python mason.py \ ### Supervised Fine-tuning (SFT): +Otherwise, the `mason.py` command can be used to launch all of our other training jobs. ```bash python mason.py \ --cluster ai2/jupiter-cirrascale-2 ai2/saturn-cirrascale \ From a041044a9a21cce5973da2d606febd49f14e70d7 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 22 Jan 2025 17:06:27 -0800 Subject: [PATCH 27/32] quick change --- docs/ai2_internal.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/ai2_internal.md b/docs/ai2_internal.md index 5bf00a9ca..7a5904b32 100644 --- a/docs/ai2_internal.md +++ b/docs/ai2_internal.md @@ -43,7 +43,7 @@ python mason.py \ Otherwise, the `mason.py` command can be used to launch all of our other training jobs. ```bash python mason.py \ - --cluster ai2/jupiter-cirrascale-2 ai2/saturn-cirrascale \ + --cluster ai2/jupiter-cirrascale-2 \ --workspace ai2/tulu-3-dev \ --priority normal \ --preemptible \ @@ -55,7 +55,7 @@ python mason.py \ --use_deepspeed \ --deepspeed_config_file configs/ds_configs/stage3_no_offloading_accelerate.conf \ --deepspeed_multinode_launcher standard \ - open_instruct/finetune1.py \ + open_instruct/finetune.py \ --model_name_or_path meta-llama/Llama-3.1-8B \ --tokenizer_name meta-llama/Llama-3.1-8B \ --use_slow_tokenizer \ @@ -124,7 +124,7 @@ When you inspect the job, it's going to have the following outputs, meaning the ```bash python mason.py \ - --cluster ai2/jupiter-cirrascale-2 ai2/saturn-cirrascale \ + --cluster ai2/jupiter-cirrascale-2 \ --workspace ai2/tulu-3-dev \ --priority normal \ --preemptible \ From 8e2b0f9f88906311007321604d492af0780225fa Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 24 Jan 2025 12:24:41 -0800 Subject: [PATCH 28/32] update docs and fix mason --- docs/ai2_internal.md | 85 +++++++++++++++++++++++++++++++++++++++----- mason.py | 2 +- 2 files changed, 77 insertions(+), 10 deletions(-) diff --git a/docs/ai2_internal.md b/docs/ai2_internal.md index 7a5904b32..721868494 100644 --- a/docs/ai2_internal.md +++ b/docs/ai2_internal.md @@ -165,14 +165,26 @@ python mason.py \ ## RLVR: ```bash +# make sure to match up the GPUs. E.g., +# `--actor_num_gpus_per_node 6 8 8 8` +# `--vllm_tensor_parallel_size 2` +# translates to 6 + 2 + 8 + 8 + 8 = 32 GPUs +# which matches up with `--num_nodes 4 --gpus 8` +for beta in 0.05; do +exp_name="0112_ppo_rlvr_${beta}_${RANDOM}" python mason.py \ - --cluster ai2/jupiter-cirrascale-2 ai2/saturn-cirrascale \ + --cluster ai2/jupiter-cirrascale-2 \ --workspace ai2/tulu-3-dev \ --priority high \ --preemptible \ --budget ai2/oe-adapt \ - --gpus 8 -- python open_instruct/ppo_vllm_thread_ray_gtrl.py \ - --exp_name tulu-3-8b-rlvr \ + --num_nodes 4 \ + --gpus 8 -- source configs/beaker_configs/ray_node_setup.sh \&\& python open_instruct/ppo_vllm_thread_ray_gtrl.py \ + --exp_name $exp_name \ + --beta $beta \ + --output_dir /weka/oe-adapt-default/allennlp/deletable_checkpoint/$exp_name \ + --try_launch_beaker_eval_jobs_on_weka \ + --try_launch_beaker_eval_jobs False \ --dataset_mixer '{"allenai/RLVR-GSM-MATH-IF-Mixed-Constraints": 1.0}' \ --dataset_train_splits train \ --dataset_eval_mixer '{"allenai/RLVR-GSM-MATH-IF-Mixed-Constraints": 128}' \ @@ -189,29 +201,84 @@ python mason.py \ --chat_template tulu \ --sft_messages_key messages \ --learning_rate 3e-7 \ - --total_episodes 10000000 \ + --total_episodes 200000 \ --penalty_reward_value -10.0 \ --deepspeed_stage 3 \ --per_device_train_batch_size 1 \ --local_rollout_forward_batch_size 1 \ --local_mini_batch_size 4 \ --local_rollout_batch_size 4 \ - --actor_num_gpus_per_node 6 \ + --actor_num_gpus_per_node 6 8 8 8 \ --vllm_tensor_parallel_size 2 \ --vllm_enforce_eager \ - --beta 0.05 \ --apply_verifiable_reward true \ - --output_dir /weka/oe-adapt-default/costah/models/tulu-3-8b-rlvr/test \ --seed 3 \ - --num_evals 3 \ - --save_freq 2 \ + --num_evals 1000 \ + --save_freq 40 \ --reward_model_multiplier 0.0 \ --gradient_checkpointing \ + --with_tracking +done +``` + +GRPO: +```bash +for beta in 0.0 0.05 0.03; do +for nspp in 4 8 16; do +exp_name="0112_grpo_math_zs_${beta}_${nspp}_${RANDOM}" +python mason.py \ + --cluster ai2/jupiter-cirrascale-2 \ + --workspace ai2/tulu-3-dev \ + --priority high \ + --preemptible \ + --num_nodes 4 \ + --budget ai2/oe-adapt \ + --gpus 8 -- source configs/beaker_configs/ray_node_setup.sh \&\& uv run python open_instruct/grpo_vllm_thread_ray_gtrl.py \ + --exp_name $exp_name \ + --beta $beta \ + --local_mini_batch_size 8 \ + --number_samples_per_prompt $nspp \ + --output_dir /weka/oe-adapt-default/costah/$exp_name \ + --dataset_mixer "{\"ai2-adapt-dev/math_ground_truth_zs\": 1.0}" \ + --dataset_train_splits train \ + --dataset_eval_mixer "{\"ai2-adapt-dev/math_ground_truth_zs\": 32}" \ + --dataset_eval_splits train \ + --max_token_length 2048 \ + --max_prompt_token_length 2048 \ + --response_length 4096 \ + --model_name_or_path allenai/Llama-3.1-Tulu-3-8B-DPO \ + --non_stop_penalty \ + --stop_token eos \ + --temperature 1.0 \ + --ground_truths_key ground_truth \ + --chat_template tulu \ + --sft_messages_key messages \ + --learning_rate 5e-7 \ + --total_episodes 1000000 \ + --penalty_reward_value 0.0 \ + --deepspeed_stage 3 \ + --per_device_train_batch_size 2 \ + --local_rollout_forward_batch_size 2 \ + --local_rollout_batch_size 8 \ + --actor_num_gpus_per_node 7 8 8 8 \ + --num_epochs 1 \ + --vllm_tensor_parallel_size 1 \ + --lr_scheduler_type constant \ + --apply_verifiable_reward true \ + --seed 1 \ + --num_evals 1000 \ + --save_freq 40 \ + --reward_model_multiplier 0.0 \ + --no_try_launch_beaker_eval_jobs \ --try_launch_beaker_eval_jobs_on_weka \ + --gradient_checkpointing \ --with_tracking +done +done ``` + ### Ai2 Internal Evaluation We provide a script integrated with beaker for use internally at Ai2. For example, to run all the tulu 3 evals with easy uploading: diff --git a/mason.py b/mason.py index 274d5be29..262234349 100644 --- a/mason.py +++ b/mason.py @@ -440,7 +440,7 @@ def find_list_idx(lst: List[str], item: str): return i return -1 if not args.no_auto_dataset_cache: - for file in ["open_instruct/finetune1.py", "open_instruct/dpo_tune_cache1.py"]: + for file in ["open_instruct/finetune.py", "open_instruct/dpo_tune_cache.py"]: idx = find_list_idx(full_command, file) if idx != -1: # then try executing the same full_command with From 4bdc5a163b07035f236d5906b5e44af57f20d8f8 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 27 Jan 2025 13:57:46 -0800 Subject: [PATCH 29/32] quick update on docs --- docs/ai2_internal.md | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/docs/ai2_internal.md b/docs/ai2_internal.md index 721868494..9b94f4772 100644 --- a/docs/ai2_internal.md +++ b/docs/ai2_internal.md @@ -38,6 +38,32 @@ python mason.py \ --gpus 1 -- which python ``` + +### Caching model (Weka-only) + +You can run the following command to cache models to the share Hugging Face cache. We recommend this for weka (shared filesystem) users because otherwise your jobs would 1) waste time downloading the model while GPU is idle and 2) risk potential download failures. + +```bash +python mason.py \ + --cluster ai2/jupiter-cirrascale-2 ai2/saturn-cirrascale ai2/neptune-cirrascale --image nathanl/open_instruct_auto --pure_docker_mode \ + --workspace ai2/tulu-3-dev \ + --priority normal \ + --preemptible \ + --budget ai2/allennlp \ + --gpus 0 -- python scripts/cache_hf.py \ + --model_name_or_path "allenai/open_instruct_dev" \ + --model_revision "reward_modeling__1__1737836233" +``` + +If you have the weka environment setup you can also just run + +```bash +python scripts/cache_hf.py \ + --model_name_or_path "allenai/open_instruct_dev" \ + --model_revision "reward_modeling__1__1737836233" +``` + + ### Supervised Fine-tuning (SFT): Otherwise, the `mason.py` command can be used to launch all of our other training jobs. @@ -84,7 +110,7 @@ python mason.py \ Note that during job submission, we will try to tokenize and cache the dataset so we are not running these CPU-heavy workloads in GPU jobs. Specifically, `mason.py` will parse out `python` command you are running and attempts to run it with `--cache_dataset_only` flag. For example, you will see output like ```bash -📦📦📦 Running the caching full_command: python open_instruct/dpo_tune_cache1.py --model_name_or_path allenai/Llama-3.1-Tulu-3-8B-SFT --use_flash_attn --tokenizer_name allenai/Llama-3.1-Tulu-3-8B-SFT --max_seq_length 2048 --preprocessing_num_workers 16 --per_device_train_batch_size 1 --gradient_accumulation_steps 16 --learning_rate 5e-07 --lr_scheduler_type linear --warmup_ratio 0.1 --weight_decay 0.0 --num_train_epochs 1 --output_dir output/dpo_8b --with_tracking --report_to wandb --logging_steps 1 --model_revision main --gradient_checkpointing --dataset_mixer_list allenai/llama-3.1-tulu-3-8b-preference-mixture 1.0 --use_slow_tokenizer --use_lora False --dpo_loss_type dpo_norm --dpo_beta 5 --exp_name tulu-3-8b-dpo --cache_dataset_only +📦📦📦 Running the caching full_command: python open_instruct/dpo_tune_cache.py --model_name_or_path allenai/Llama-3.1-Tulu-3-8B-SFT --use_flash_attn --tokenizer_name allenai/Llama-3.1-Tulu-3-8B-SFT --max_seq_length 2048 --preprocessing_num_workers 16 --per_device_train_batch_size 1 --gradient_accumulation_steps 16 --learning_rate 5e-07 --lr_scheduler_type linear --warmup_ratio 0.1 --weight_decay 0.0 --num_train_epochs 1 --output_dir output/dpo_8b --with_tracking --report_to wandb --logging_steps 1 --model_revision main --gradient_checkpointing --dataset_mixer_list allenai/llama-3.1-tulu-3-8b-preference-mixture 1.0 --use_slow_tokenizer --use_lora False --dpo_loss_type dpo_norm --dpo_beta 5 --exp_name tulu-3-8b-dpo --cache_dataset_only [2025-01-21 09:58:09,342] [WARNING] [real_accelerator.py:162:get_accelerator] Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it. [2025-01-21 09:58:09,354] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cpu (auto detect) Failed to get Beaker experiment: b'Error: experiment "01JD3WCQYTBPE195GVWPVMDHVV" not found\n\n' @@ -103,7 +129,7 @@ Creating parquet from Arrow format: 100%|█████████████ Kicked off Beaker job. https://beaker.org/ex/01JJ50D88M757GZD14W9CNN7NT ``` -It would be most helpful if you run the `mason.py` command on a vscode session with access to weka, that way, the dataset is also automatically downloaded to `HF_HOME`, etc. +It would be most helpful if you run the `mason.py` command on a vscode session with access to weka, that way, the dataset is also automatically downloaded to the shared `HF_HOME` on weka, etc. When you inspect the job, it's going to have the following outputs, meaning the cached dataset is found and used: @@ -136,7 +162,7 @@ python mason.py \ --use_deepspeed \ --deepspeed_config_file configs/ds_configs/stage3_no_offloading_accelerate.conf \ --deepspeed_multinode_launcher standard \ - open_instruct/dpo_tune_cache1.py \ + open_instruct/dpo_tune_cache.py \ --model_name_or_path allenai/Llama-3.1-Tulu-3-8B-SFT \ --use_flash_attn \ --tokenizer_name allenai/Llama-3.1-Tulu-3-8B-SFT \ From e5b992320e1f36627b377e400bffda32676d9559 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 28 Jan 2025 05:57:27 -0800 Subject: [PATCH 30/32] use the default entity --- open_instruct/dpo_tune_cache.py | 1 + open_instruct/finetune.py | 1 + scripts/cache_hf.py | 5 ++--- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/open_instruct/dpo_tune_cache.py b/open_instruct/dpo_tune_cache.py index c478261ea..ee3a37868 100644 --- a/open_instruct/dpo_tune_cache.py +++ b/open_instruct/dpo_tune_cache.py @@ -545,6 +545,7 @@ def main(args: FlatArguments): args.dataset_mixer_list, tc, args.max_seq_length, + args.hf_entity, ) train_dataset.shuffle(seed=args.seed) train_dataset.set_format(type="pt") diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index ada18b103..235037290 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -467,6 +467,7 @@ def main(args: FlatArguments): args.dataset_mixer_list, tc, args.max_seq_length, + args.hf_entity, ) train_dataset.shuffle(seed=args.seed) train_dataset.set_format(type="pt") diff --git a/scripts/cache_hf.py b/scripts/cache_hf.py index e7455826d..88e6043e8 100644 --- a/scripts/cache_hf.py +++ b/scripts/cache_hf.py @@ -22,9 +22,8 @@ --preemptible \ --budget ai2/allennlp \ --gpus 0 -- python scripts/cache_hf.py \ - --model_name_or_path "allenai/Llama-3.1-Tulu-3-8B-DPO" \ - --model_revision "1208_dpo_13b_tune8e-7__allenai_open_instruct_dev__8__1733807565" \ - --dataset_mixer_list allenai/RLVR-GSM-MATH-IF-Mixed-Constraints 1.0 + --model_name_or_path "allenai/open_instruct_dev" \ + --model_revision "reward_modeling__1__1737836233" \ """ From 8e2356f2974d4227975e6dcc025612d8c7a7d860 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 28 Jan 2025 06:00:44 -0800 Subject: [PATCH 31/32] format --- open_instruct/dpo_tune_cache.py | 7 +++++-- open_instruct/finetune.py | 12 +++++++++--- open_instruct/olmo_adapter/__init__.py | 18 +++++++++++++----- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/open_instruct/dpo_tune_cache.py b/open_instruct/dpo_tune_cache.py index ee3a37868..6f2975915 100644 --- a/open_instruct/dpo_tune_cache.py +++ b/open_instruct/dpo_tune_cache.py @@ -57,7 +57,11 @@ get_scheduler, ) -from open_instruct.dataset_transformation import CHAT_TEMPLATES, TokenizerConfig, get_cached_dataset_tulu_preference +from open_instruct.dataset_transformation import ( + CHAT_TEMPLATES, + TokenizerConfig, + get_cached_dataset_tulu_preference, +) from open_instruct.dpo_utils import ( DataCollatorForSeq2SeqDPO, concatenated_forward, @@ -421,7 +425,6 @@ def __post_init__(self): raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.") - def get_cache_ref_logprobs( model: torch.nn.Module, active_dataloader: torch.utils.data.DataLoader, diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index 235037290..db4df123a 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -53,7 +53,11 @@ get_scheduler, ) -from open_instruct.dataset_transformation import CHAT_TEMPLATES, TokenizerConfig, get_cached_dataset_tulu_sft +from open_instruct.dataset_transformation import ( + CHAT_TEMPLATES, + TokenizerConfig, + get_cached_dataset_tulu_sft, +) from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate from open_instruct.utils import ( ArgumentParserPlus, @@ -473,7 +477,7 @@ def main(args: FlatArguments): train_dataset.set_format(type="pt") if args.cache_dataset_only: return - + # Load pretrained model and tokenizer if args.config_name: config = AutoConfig.from_pretrained( @@ -816,7 +820,9 @@ def main(args: FlatArguments): os.path.join(get_last_checkpoint_path(args, incomplete=True), "COMPLETED"), "w" ) as f: f.write("COMPLETED") # annoyingly, empty files arent uploaded by beaker. - if accelerator.is_local_main_process: # TODO: in mason local model this is gonna error out if using something like output/test; because mason used the same shared file ssytem. + if ( + accelerator.is_local_main_process + ): # TODO: in mason local model this is gonna error out if using something like output/test; because mason used the same shared file ssytem. clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints) accelerator.wait_for_everyone() diff --git a/open_instruct/olmo_adapter/__init__.py b/open_instruct/olmo_adapter/__init__.py index f39796b11..71776201d 100644 --- a/open_instruct/olmo_adapter/__init__.py +++ b/open_instruct/olmo_adapter/__init__.py @@ -2,10 +2,18 @@ import torch from torch import nn -from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.modeling_outputs import SequenceClassifierOutputWithPast -from transformers.models.olmo2.modeling_olmo2 import Olmo2Config, Olmo2Model, Olmo2PreTrainedModel -from transformers.models.olmoe.modeling_olmoe import OlmoeConfig, OlmoeModel, OlmoePreTrainedModel +from transformers.models.olmo2.modeling_olmo2 import ( + Olmo2Config, + Olmo2Model, + Olmo2PreTrainedModel, +) +from transformers.models.olmoe.modeling_olmoe import ( + OlmoeConfig, + OlmoeModel, + OlmoePreTrainedModel, +) class Olmo2ForSequenceClassification(Olmo2PreTrainedModel): @@ -36,7 +44,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[Tuple, 'SequenceClassifierOutputWithPast']: + ) -> Union[Tuple, "SequenceClassifierOutputWithPast"]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., @@ -116,7 +124,7 @@ def forward( class OlmoeForSequenceClassification(OlmoePreTrainedModel): - def __init__(self, config: Olmo2Config): + def __init__(self, config: OlmoeConfig): super().__init__(config) self.num_labels = config.num_labels self.model = OlmoeModel(config) From e0625af5cc64ebaca81134b3a841171c97c3290d Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 28 Jan 2025 06:11:17 -0800 Subject: [PATCH 32/32] Format --- open_instruct/dataset_transformation.py | 287 ++++++++++++------------ 1 file changed, 146 insertions(+), 141 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index dfcdb4809..6d6b2447f 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -43,25 +43,25 @@ import copy import hashlib import json -from dataclasses import dataclass, field, asdict import multiprocessing import os +from dataclasses import asdict, dataclass, field from typing import Any, Dict, List, Optional + import torch +import transformers +from datasets import Dataset, concatenate_datasets, load_dataset +from huggingface_hub import HfApi, ModelCard, revision_exists from rich.console import Console from rich.text import Text from transformers import ( AutoTokenizer, - PreTrainedTokenizer, GPTNeoXTokenizerFast, LlamaTokenizer, LlamaTokenizerFast, + PreTrainedTokenizer, ) -from huggingface_hub import ModelCard -import transformers from transformers.utils.hub import cached_file, extract_commit_hash -from datasets import Dataset, load_dataset, concatenate_datasets -from huggingface_hub import HfApi, revision_exists # ---------------------------------------------------------------------------- @@ -71,16 +71,20 @@ def get_commit_hash(model_name_or_path: str, revision: str, filename: str = "con commit_hash = extract_commit_hash(file, None) return commit_hash + # Performance tuning. Some rough numbers: APPLY_CHAT_TEMPLATE_EXAMPLE_PER_SECOND_PER_CPU = 400 FILTER_EXAMPLE_PER_SECOND_PER_CPU = 1130 + def get_num_proc(dataset_len: int, num_available_cpus: int, example_per_second_per_cpu) -> int: num_required_cpus = max(1, dataset_len // example_per_second_per_cpu) return min(num_required_cpus, num_available_cpus) + COLORS = ["on red", "on green", "on blue", "on yellow", "on magenta"] + def visualize_token(tokens: list[int], tokenizer: PreTrainedTokenizer): i = 0 console = Console() @@ -91,6 +95,7 @@ def visualize_token(tokens: list[int], tokenizer: PreTrainedTokenizer): rich_text.append(f"{decoded_token}", style=color) console.print(rich_text) + # ---------------------------------------------------------------------------- # Tokenization # Chat templates @@ -162,7 +167,7 @@ def visualize_token(tokens: list[int], tokenizer: PreTrainedTokenizer): # flake8: noqa -def get_tokenizer_simple_v1(tc: 'TokenizerConfig'): +def get_tokenizer_simple_v1(tc: "TokenizerConfig"): tokenizer = AutoTokenizer.from_pretrained( tc.model_name_or_path, revision=tc.revision, @@ -171,7 +176,8 @@ def get_tokenizer_simple_v1(tc: 'TokenizerConfig'): ) return tokenizer -def get_tokenizer_tulu_v1(tc: 'TokenizerConfig'): + +def get_tokenizer_tulu_v1(tc: "TokenizerConfig"): tokenizer = AutoTokenizer.from_pretrained( tc.model_name_or_path, revision=tc.revision, @@ -197,9 +203,7 @@ def get_tokenizer_tulu_v1(tc: 'TokenizerConfig'): # OLMo newer models use this tokenizer if tokenizer.bos_token is None: tokenizer.bos_token = tokenizer.eos_token - assert ( - tc.add_bos - ), "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence." + assert tc.add_bos, "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence." # else, pythia / other models else: num_added_tokens = tokenizer.add_special_tokens( @@ -216,7 +220,7 @@ def get_tokenizer_tulu_v1(tc: 'TokenizerConfig'): elif isinstance(tokenizer, transformers.PreTrainedTokenizerFast) and tokenizer.pad_token is None: num_added_tokens = tokenizer.add_special_tokens({"pad_token": ""}) assert num_added_tokens == 1, "We detected no padding token but add_special_tokens did not add one." - + # set the tokenizer chat template to the training format # this will be used for encoding the training examples # and saved together with the tokenizer to be used later. @@ -237,7 +241,7 @@ def get_tokenizer_tulu_v1(tc: 'TokenizerConfig'): ) # also add bos in the chat template if not already there tokenizer.chat_template = "{{ bos_token }}" + tokenizer.chat_template - + return tokenizer @@ -247,24 +251,26 @@ def get_tokenizer_tulu_v1(tc: 'TokenizerConfig'): } - @dataclass class TokenizerConfig: model_name_or_path: str revision: str trust_remote_code: bool = True use_fast: bool = True - chat_template_name: Optional[str] = None # TODO: should I give an option to force override? + chat_template_name: Optional[str] = None # TODO: should I give an option to force override? add_bos: bool = False get_tokenizer_fn: str = "get_tokenizer_tulu_v1" - + # for tracking purposes tokenizer_commit_hash: Optional[str] = None def __post_init__(self): - self.tokenizer_commit_hash = get_commit_hash(self.model_name_or_path, self.revision, filename="tokenizer_config.json") + self.tokenizer_commit_hash = get_commit_hash( + self.model_name_or_path, self.revision, filename="tokenizer_config.json" + ) self.tokenizer = GET_TOKENIZER_FN[self.get_tokenizer_fn](self) + # TODO: for testing, we should load the tokenizer from the sft / dpo / rl and make sure they are all the same. @@ -311,14 +317,15 @@ def __post_init__(self): ] - # TODO: allow passing in sft_message key, so we can train on "chosen" of pref dataset. -def sft_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, sft_message_key: str = DEFAULT_SFT_MESSAGE_KEY): +def sft_tokenize_v1( + row: Dict[str, Any], tokenizer: PreTrainedTokenizer, sft_message_key: str = DEFAULT_SFT_MESSAGE_KEY +): if len(row[sft_message_key]) == 1: prompt = row[sft_message_key] else: prompt = row[sft_message_key][:-1] - + row[INPUT_IDS_PROMPT_KEY] = tokenizer.apply_chat_template( prompt, add_generation_prompt=True, @@ -328,15 +335,17 @@ def sft_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, sft_mes labels = copy.deepcopy(row[INPUT_IDS_KEY]) row[LABELS_KEY] = labels return row - -def sft_tokenize_mask_out_prompt_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, sft_message_key: str = DEFAULT_SFT_MESSAGE_KEY): + +def sft_tokenize_mask_out_prompt_v1( + row: Dict[str, Any], tokenizer: PreTrainedTokenizer, sft_message_key: str = DEFAULT_SFT_MESSAGE_KEY +): """mask out the prompt tokens by manipulating labels""" if len(row[sft_message_key]) == 1: prompt = row[sft_message_key] else: prompt = row[sft_message_key][:-1] - + row[INPUT_IDS_PROMPT_KEY] = tokenizer.apply_chat_template( prompt, add_generation_prompt=True, @@ -349,7 +358,13 @@ def sft_tokenize_mask_out_prompt_v1(row: Dict[str, Any], tokenizer: PreTrainedTo return row -def sft_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_prompt_token_length: Optional[int] = None, max_token_length: Optional[int] = None, need_contain_labels: bool = True): +def sft_filter_v1( + row: Dict[str, Any], + tokenizer: PreTrainedTokenizer, + max_prompt_token_length: Optional[int] = None, + max_token_length: Optional[int] = None, + need_contain_labels: bool = True, +): max_prompt_token_length_ok = True if max_prompt_token_length is not None: max_prompt_token_length_ok = len(row[INPUT_IDS_PROMPT_KEY]) <= max_prompt_token_length @@ -359,10 +374,9 @@ def sft_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_promp max_token_length_ok = len(row[INPUT_IDS_KEY]) <= max_token_length contain_some_labels = any(x != -100 for x in row[LABELS_KEY]) - return ( - max_prompt_token_length_ok and max_token_length_ok and (contain_some_labels or not need_contain_labels) - ) - + return max_prompt_token_length_ok and max_token_length_ok and (contain_some_labels or not need_contain_labels) + + def sft_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_seq_length: int): """taken directly from https://github.com/allenai/open-instruct/blob/ba11286e5b9eb00d4ce5b40ef4cac1389888416a/open_instruct/finetune.py#L385""" messages = row["messages"] @@ -433,45 +447,58 @@ def sft_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: PreTrained def sft_tulu_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer): return any(x != -100 for x in row[LABELS_KEY]) - + + def preference_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer): # Extract prompt (all messages except the last one) prompt = row["chosen"][:-1] - + # Tokenize prompt row[INPUT_IDS_PROMPT_KEY] = tokenizer.apply_chat_template( prompt, add_generation_prompt=True, ) row[ATTENTION_MASK_PROMPT_KEY] = [1] * len(row[INPUT_IDS_PROMPT_KEY]) - + # Tokenize chosen completion row[CHOSEN_INPUT_IDS_KEY] = tokenizer.apply_chat_template(row["chosen"]) row[CHOSEN_ATTENTION_MASK_KEY] = [1] * len(row[CHOSEN_INPUT_IDS_KEY]) - + # Tokenize rejected completion row[REJECTED_INPUT_IDS_KEY] = tokenizer.apply_chat_template(row["rejected"]) row[REJECTED_ATTENTION_MASK_KEY] = [1] * len(row[REJECTED_INPUT_IDS_KEY]) - + return row -def preference_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_prompt_token_length: Optional[int] = None, max_token_length: Optional[int] = None): + +def preference_filter_v1( + row: Dict[str, Any], + tokenizer: PreTrainedTokenizer, + max_prompt_token_length: Optional[int] = None, + max_token_length: Optional[int] = None, +): # Check prompt length if specified if max_prompt_token_length is not None: if len(row[INPUT_IDS_PROMPT_KEY]) > max_prompt_token_length: return False - + # Check total sequence lengths if specified if max_token_length is not None: if len(row[CHOSEN_INPUT_IDS_KEY]) > max_token_length: return False if len(row[REJECTED_INPUT_IDS_KEY]) > max_token_length: return False - + return True -def preference_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_seq_length: int, chosen_key: str = DEFAULT_CHOSEN_KEY, rejected_key: str = DEFAULT_REJECTED_KEY): +def preference_tulu_tokenize_and_truncate_v1( + row: Dict[str, Any], + tokenizer: PreTrainedTokenizer, + max_seq_length: int, + chosen_key: str = DEFAULT_CHOSEN_KEY, + rejected_key: str = DEFAULT_REJECTED_KEY, +): """ Here we assume each example has a rejected and chosen field, both of which are a list of messages. Each message is a dict with 'role' and 'content' fields. @@ -484,8 +511,12 @@ def preference_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: Pre if len(rejected_messages) == 0: raise ValueError("rejected messages field is empty.") - chosen_encoded = sft_tulu_tokenize_and_truncate_v1({DEFAULT_SFT_MESSAGE_KEY: chosen_messages}, tokenizer, max_seq_length) - rejected_encoded = sft_tulu_tokenize_and_truncate_v1({DEFAULT_SFT_MESSAGE_KEY: rejected_messages}, tokenizer, max_seq_length) + chosen_encoded = sft_tulu_tokenize_and_truncate_v1( + {DEFAULT_SFT_MESSAGE_KEY: chosen_messages}, tokenizer, max_seq_length + ) + rejected_encoded = sft_tulu_tokenize_and_truncate_v1( + {DEFAULT_SFT_MESSAGE_KEY: rejected_messages}, tokenizer, max_seq_length + ) return { CHOSEN_INPUT_IDS_KEY: chosen_encoded["input_ids"], @@ -513,6 +544,7 @@ def preference_tulu_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenize "preference_tulu_filter_v1": (preference_tulu_filter_v1, "filter"), } + # ---------------------------------------------------------------------------- # Dataset Configuration and Caching @dataclass @@ -523,10 +555,10 @@ class DatasetConfig: dataset_range: Optional[int] = None transform_fn: List[str] = field(default_factory=list) transform_fn_args: Dict[str, Dict[str, Any]] = field(default_factory=dict) - + # for tracking purposes dataset_commit_hash: Optional[str] = None - + def __post_init__(self): self.dataset_commit_hash = get_commit_hash(self.dataset_name, self.dataset_revision, "README.md", "dataset") self.dataset = load_dataset( @@ -537,20 +569,21 @@ def __post_init__(self): if self.dataset_range is None: dataset_range = len(self.dataset) self.update_range(dataset_range) - + def update_range(self, dataset_range: int): self.dataset_range = dataset_range if self.dataset_range > len(self.dataset): raise ValueError("Dataset range exceeds dataset length") self.dataset = self.dataset.select(range(self.dataset_range)) + def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): # beaker specific logic; we may get assigned 15.5 CPU, so we convert it to float then int num_proc = int(float(os.environ.get("BEAKER_ASSIGNED_CPU_COUNT", multiprocessing.cpu_count()))) - + tokenizer = tc.tokenizer dataset = dc.dataset - + for fn_name in dc.transform_fn: fn, fn_type = TRANSFORM_FNS[fn_name] # always pass in tokenizer and other args if needed @@ -570,7 +603,7 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): ) elif fn_type == "filter": dataset = dataset.filter( - fn, + fn, fn_kwargs=fn_kwargs, num_proc=get_num_proc(len(dataset), num_proc, FILTER_EXAMPLE_PER_SECOND_PER_CPU), ) @@ -582,21 +615,16 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): raise ValueError("No examples left after transformation") return dataset + class DatasetTransformationCache: def __init__(self, hf_entity: Optional[str] = None): self.hf_entity = hf_entity or HfApi().whoami()["name"] - + def compute_config_hash(self, dcs: List[DatasetConfig], tc: TokenizerConfig) -> str: """Compute a deterministic hash of both configs for caching.""" - dc_dicts = [ - {k: v for k, v in asdict(dc).items() if v is not None} - for dc in dcs - ] + dc_dicts = [{k: v for k, v in asdict(dc).items() if v is not None} for dc in dcs] tc_dict = {k: v for k, v in asdict(tc).items() if v is not None} - combined_dict = { - "dataset_configs": dc_dicts, - "tokenizer_config": tc_dict - } + combined_dict = {"dataset_configs": dc_dicts, "tokenizer_config": tc_dict} config_str = json.dumps(combined_dict, sort_keys=True) return hashlib.sha256(config_str.encode()).hexdigest()[:10] @@ -604,39 +632,35 @@ def load_or_transform_dataset(self, dcs: List[DatasetConfig], tc: TokenizerConfi """Load dataset from cache if it exists, otherwise transform and cache it.""" config_hash = self.compute_config_hash(dcs, tc) repo_name = f"{self.hf_entity}/dataset-mix-cached" - + # Check if the revision exists if revision_exists(repo_name, config_hash, repo_type="dataset"): print(f"✅ Found cached dataset at https://huggingface.co/datasets/{repo_name}/tree/{config_hash}") # Use the split from the first dataset config as default - return load_dataset( - repo_name, - split=dcs[0].dataset_split, - revision=config_hash - ) - + return load_dataset(repo_name, split=dcs[0].dataset_split, revision=config_hash) + print(f"Cache not found, transforming datasets...") - + # Transform each dataset transformed_datasets = [] for dc in dcs: dataset = get_dataset_v1(dc, tc) transformed_datasets.append(dataset) - + # Combine datasets combined_dataset = concatenate_datasets(transformed_datasets) - + # Push to hub with config hash as revision combined_dataset.push_to_hub( repo_name, private=True, revision=config_hash, - commit_message=f"Cache combined dataset with configs hash: {config_hash}" + commit_message=f"Cache combined dataset with configs hash: {config_hash}", ) print(f"🚀 Pushed transformed dataset to https://huggingface.co/datasets/{repo_name}/tree/{config_hash}") - - model_card = ModelCard(f"""\ + model_card = ModelCard( + f"""\ --- tags: [open-instruct] --- @@ -658,23 +682,23 @@ def load_or_transform_dataset(self, dcs: List[DatasetConfig], tc: TokenizerConfi ```json {json.dumps([asdict(dc) for dc in dcs], indent=2)} ``` -""") +""" + ) model_card.push_to_hub(repo_name, repo_type="dataset", revision=config_hash) # NOTE: Load the dataset again to make sure it's downloaded to the HF cache print(f"✅ Found cached dataset at https://huggingface.co/datasets/{repo_name}/tree/{config_hash}") - return load_dataset( - repo_name, - split=dc.dataset_split, - revision=config_hash - ) + return load_dataset(repo_name, split=dc.dataset_split, revision=config_hash) def get_cached_dataset(dcs: List[DatasetConfig], tc: TokenizerConfig, hf_entity: Optional[str] = None) -> Dataset: cache = DatasetTransformationCache(hf_entity=hf_entity) return cache.load_or_transform_dataset(dcs, tc) -def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConfig, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: + +def get_cached_dataset_tulu_sft( + dataset_mixer_list: List[str], tc: TokenizerConfig, max_seq_length: int, hf_entity: Optional[str] = None +) -> Dataset: dcs = [] assert len(dataset_mixer_list) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer_list}" for i in range(0, len(dataset_mixer_list), 2): @@ -684,7 +708,7 @@ def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConf frac_or_num_samples = float(frac_or_num_samples) else: frac_or_num_samples = int(frac_or_num_samples) - + dataset_config = DatasetConfig( dataset_name=dataset_name, dataset_split="train", @@ -695,7 +719,7 @@ def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConf "max_seq_length": max_seq_length, "target_columns": TOKENIZED_SFT_DATASET_KEYS, } - } + }, ) if frac_or_num_samples > 1.0: new_range = int(frac_or_num_samples) @@ -707,7 +731,9 @@ def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConf return cache.load_or_transform_dataset(dcs, tc) -def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: TokenizerConfig, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: +def get_cached_dataset_tulu_preference( + dataset_mixer_list: List[str], tc: TokenizerConfig, max_seq_length: int, hf_entity: Optional[str] = None +) -> Dataset: dcs = [] assert len(dataset_mixer_list) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer_list}" for i in range(0, len(dataset_mixer_list), 2): @@ -717,7 +743,7 @@ def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: Tokeni frac_or_num_samples = float(frac_or_num_samples) else: frac_or_num_samples = int(frac_or_num_samples) - + dataset_config = DatasetConfig( dataset_name=dataset_name, dataset_split="train", @@ -728,7 +754,7 @@ def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: Tokeni "max_seq_length": max_seq_length, "target_columns": TOKENIZED_PREFERENCE_DATASET_KEYS, } - } + }, ) if frac_or_num_samples > 1.0: new_range = int(frac_or_num_samples) @@ -742,21 +768,15 @@ def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: Tokeni def test_sft_dpo_same_tokenizer(): base_to_sft_tc = TokenizerConfig( - model_name_or_path="meta-llama/Llama-3.1-8B", - revision="main", - chat_template_name="tulu" + model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" ) sft_to_dpo_tc = TokenizerConfig( - model_name_or_path="allenai/Llama-3.1-Tulu-3-8B-SFT", - revision="main", - chat_template_name="tulu" + model_name_or_path="allenai/Llama-3.1-Tulu-3-8B-SFT", revision="main", chat_template_name="tulu" ) dpo_to_rl_tc = TokenizerConfig( - model_name_or_path="allenai/Llama-3.1-Tulu-3-8B-DPO", - revision="main", - chat_template_name="tulu" + model_name_or_path="allenai/Llama-3.1-Tulu-3-8B-DPO", revision="main", chat_template_name="tulu" ) - + def equal_tokenizer(tc1, tc2): tok1 = tc1.tokenizer tok2 = tc2.tokenizer @@ -765,8 +785,11 @@ def equal_tokenizer(tc1, tc2): assert tok1.is_fast == tok2.is_fast, "is_fast should be the same" assert tok1.padding_side == tok2.padding_side, "padding_side should be the same" assert tok1.truncation_side == tok2.truncation_side, "truncation_side should be the same" - assert tok1.clean_up_tokenization_spaces == tok2.clean_up_tokenization_spaces, "clean_up_tokenization_spaces should be the same" + assert ( + tok1.clean_up_tokenization_spaces == tok2.clean_up_tokenization_spaces + ), "clean_up_tokenization_spaces should be the same" assert tok1.added_tokens_decoder == tok2.added_tokens_decoder, "added_tokens_decoder should be the same" + equal_tokenizer(base_to_sft_tc, sft_to_dpo_tc) equal_tokenizer(sft_to_dpo_tc, dpo_to_rl_tc) equal_tokenizer(base_to_sft_tc, dpo_to_rl_tc) @@ -774,11 +797,7 @@ def equal_tokenizer(tc1, tc2): def test_config_hash_different(): """Test that different configurations produce different hashes.""" - tc = TokenizerConfig( - model_name_or_path="meta-llama/Llama-3.1-8B", - revision="main", - chat_template_name="tulu" - ) + tc = TokenizerConfig(model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu") dcs1 = [ DatasetConfig( @@ -786,96 +805,86 @@ def test_config_hash_different(): dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_v1"], - transform_fn_args={} + transform_fn_args={}, ) ] - + dcs2 = [ DatasetConfig( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_mask_out_prompt_v1"], - transform_fn_args={} + transform_fn_args={}, ) ] - + cache = DatasetTransformationCache() hash1 = cache.compute_config_hash(dcs1, tc) hash2 = cache.compute_config_hash(dcs2, tc) assert hash1 != hash2, "Different configs should have different hashes" + def test_sft_dataset_caching(): """Test caching functionality for SFT datasets.""" - tc = TokenizerConfig( - model_name_or_path="meta-llama/Llama-3.1-8B", - revision="main", - chat_template_name="tulu" - ) - + tc = TokenizerConfig(model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu") + dcs = [ DatasetConfig( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_v1"], - transform_fn_args={} + transform_fn_args={}, ), DatasetConfig( dataset_name="allenai/tulu-3-hard-coded-10x", dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_v1"], - transform_fn_args={} - ) + transform_fn_args={}, + ), ] - + # First transformation should cache dataset1 = get_cached_dataset(dcs, tc) - + # Second load should use cache dataset1_cached = get_cached_dataset(dcs, tc) - + # Verify the datasets are the same assert len(dataset1) == len(dataset1_cached), "Cached dataset should have same length" + def test_sft_different_transform(): """Test different transform functions produce different cached datasets.""" - tc = TokenizerConfig( - model_name_or_path="meta-llama/Llama-3.1-8B", - revision="main", - chat_template_name="tulu" - ) - + tc = TokenizerConfig(model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu") + dcs = [ DatasetConfig( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_mask_out_prompt_v1"], - transform_fn_args={} + transform_fn_args={}, ), DatasetConfig( dataset_name="allenai/tulu-3-hard-coded-10x", dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_mask_out_prompt_v1"], - transform_fn_args={} - ) + transform_fn_args={}, + ), ] - + dataset = get_cached_dataset(dcs, tc) assert dataset is not None, "Should successfully create dataset with different transform" def test_sft_filter(): """Test different transform functions produce different cached datasets.""" - tc = TokenizerConfig( - model_name_or_path="meta-llama/Llama-3.1-8B", - revision="main", - chat_template_name="tulu" - ) - + tc = TokenizerConfig(model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu") + ARBITRARY_MAX_LENGTH = 1000 dcs = [ DatasetConfig( @@ -887,15 +896,15 @@ def test_sft_filter(): "sft_filter_v1": { "max_token_length": ARBITRARY_MAX_LENGTH # Filter to sequences <= ARBITRARY_MAX_LENGTH tokens } - } + }, ) ] - + filtered_dataset = get_cached_dataset(dcs, tc) # Verify that all sequences are <= ARBITRARY_MAX_LENGTH tokens max_length = max(len(example[INPUT_IDS_KEY]) for example in filtered_dataset) assert max_length <= ARBITRARY_MAX_LENGTH, f"Found sequence with length {max_length} > {ARBITRARY_MAX_LENGTH}" - + print("Filter test passed! Max sequence length:", max_length) print("All tests passed!") assert filtered_dataset is not None, "Should successfully create dataset with different transform" @@ -903,29 +912,25 @@ def test_sft_filter(): def test_preference_dataset(): """Test caching functionality for preference datasets.""" - tc = TokenizerConfig( - model_name_or_path="meta-llama/Llama-3.1-8B", - revision="main", - chat_template_name="tulu" - ) - + tc = TokenizerConfig(model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu") + dcs_pref = [ DatasetConfig( dataset_name="allenai/tulu-3-pref-personas-instruction-following", dataset_split="train", dataset_revision="main", transform_fn=["preference_tokenize_v1"], - transform_fn_args={} + transform_fn_args={}, ), DatasetConfig( dataset_name="allenai/tulu-3-wildchat-reused-on-policy-70b", dataset_split="train", dataset_revision="main", transform_fn=["preference_tokenize_v1"], - transform_fn_args={} - ) + transform_fn_args={}, + ), ] - + dataset_pref = get_cached_dataset(dcs_pref, tc) assert dataset_pref is not None, "Should successfully create preference dataset" @@ -937,4 +942,4 @@ def test_preference_dataset(): test_sft_different_transform() test_preference_dataset() test_sft_filter() - print("All tests passed!") \ No newline at end of file + print("All tests passed!")