diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index dfcdb4809..c587af116 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -43,25 +43,25 @@ import copy import hashlib import json -from dataclasses import dataclass, field, asdict import multiprocessing import os +from dataclasses import asdict, dataclass, field from typing import Any, Dict, List, Optional + import torch +import transformers +from datasets import Dataset, concatenate_datasets, load_dataset +from huggingface_hub import HfApi, ModelCard, revision_exists from rich.console import Console from rich.text import Text from transformers import ( AutoTokenizer, - PreTrainedTokenizer, GPTNeoXTokenizerFast, LlamaTokenizer, LlamaTokenizerFast, + PreTrainedTokenizer, ) -from huggingface_hub import ModelCard -import transformers from transformers.utils.hub import cached_file, extract_commit_hash -from datasets import Dataset, load_dataset, concatenate_datasets -from huggingface_hub import HfApi, revision_exists # ---------------------------------------------------------------------------- @@ -71,16 +71,20 @@ def get_commit_hash(model_name_or_path: str, revision: str, filename: str = "con commit_hash = extract_commit_hash(file, None) return commit_hash + # Performance tuning. Some rough numbers: APPLY_CHAT_TEMPLATE_EXAMPLE_PER_SECOND_PER_CPU = 400 FILTER_EXAMPLE_PER_SECOND_PER_CPU = 1130 + def get_num_proc(dataset_len: int, num_available_cpus: int, example_per_second_per_cpu) -> int: num_required_cpus = max(1, dataset_len // example_per_second_per_cpu) return min(num_required_cpus, num_available_cpus) + COLORS = ["on red", "on green", "on blue", "on yellow", "on magenta"] + def visualize_token(tokens: list[int], tokenizer: PreTrainedTokenizer): i = 0 console = Console() @@ -91,6 +95,7 @@ def visualize_token(tokens: list[int], tokenizer: PreTrainedTokenizer): rich_text.append(f"{decoded_token}", style=color) console.print(rich_text) + # ---------------------------------------------------------------------------- # Tokenization # Chat templates @@ -162,7 +167,7 @@ def visualize_token(tokens: list[int], tokenizer: PreTrainedTokenizer): # flake8: noqa -def get_tokenizer_simple_v1(tc: 'TokenizerConfig'): +def get_tokenizer_simple_v1(tc: "TokenizerConfig"): tokenizer = AutoTokenizer.from_pretrained( tc.model_name_or_path, revision=tc.revision, @@ -171,7 +176,8 @@ def get_tokenizer_simple_v1(tc: 'TokenizerConfig'): ) return tokenizer -def get_tokenizer_tulu_v1(tc: 'TokenizerConfig'): + +def get_tokenizer_tulu_v1(tc: "TokenizerConfig"): tokenizer = AutoTokenizer.from_pretrained( tc.model_name_or_path, revision=tc.revision, @@ -197,9 +203,7 @@ def get_tokenizer_tulu_v1(tc: 'TokenizerConfig'): # OLMo newer models use this tokenizer if tokenizer.bos_token is None: tokenizer.bos_token = tokenizer.eos_token - assert ( - tc.add_bos - ), "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence." + assert tc.add_bos, "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence." # else, pythia / other models else: num_added_tokens = tokenizer.add_special_tokens( @@ -216,7 +220,7 @@ def get_tokenizer_tulu_v1(tc: 'TokenizerConfig'): elif isinstance(tokenizer, transformers.PreTrainedTokenizerFast) and tokenizer.pad_token is None: num_added_tokens = tokenizer.add_special_tokens({"pad_token": ""}) assert num_added_tokens == 1, "We detected no padding token but add_special_tokens did not add one." 
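The new `get_num_proc` helper above sizes the `datasets` map/filter worker pool from rough per-CPU throughput numbers instead of always spawning one process per CPU. A small self-contained sketch of that behaviour, with the helper copied from the patch (the 16-CPU figure and the dataset sizes are made-up illustrations):

```python
APPLY_CHAT_TEMPLATE_EXAMPLE_PER_SECOND_PER_CPU = 400
FILTER_EXAMPLE_PER_SECOND_PER_CPU = 1130


def get_num_proc(dataset_len: int, num_available_cpus: int, example_per_second_per_cpu) -> int:
    # Ask for roughly one CPU per `example_per_second_per_cpu` examples, capped at what is available.
    num_required_cpus = max(1, dataset_len // example_per_second_per_cpu)
    return min(num_required_cpus, num_available_cpus)


for size in (1_000, 100_000, 1_000_000):
    print(size, get_num_proc(size, 16, APPLY_CHAT_TEMPLATE_EXAMPLE_PER_SECOND_PER_CPU))
# 1_000 -> 2 workers, 100_000 -> 16 (capped), 1_000_000 -> 16 (capped)
```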
- + # set the tokenizer chat template to the training format # this will be used for encoding the training examples # and saved together with the tokenizer to be used later. @@ -237,7 +241,7 @@ def get_tokenizer_tulu_v1(tc: 'TokenizerConfig'): ) # also add bos in the chat template if not already there tokenizer.chat_template = "{{ bos_token }}" + tokenizer.chat_template - + return tokenizer @@ -247,24 +251,26 @@ def get_tokenizer_tulu_v1(tc: 'TokenizerConfig'): } - @dataclass class TokenizerConfig: model_name_or_path: str revision: str trust_remote_code: bool = True use_fast: bool = True - chat_template_name: Optional[str] = None # TODO: should I give an option to force override? + chat_template_name: Optional[str] = None # TODO: should I give an option to force override? add_bos: bool = False get_tokenizer_fn: str = "get_tokenizer_tulu_v1" - + # for tracking purposes tokenizer_commit_hash: Optional[str] = None def __post_init__(self): - self.tokenizer_commit_hash = get_commit_hash(self.model_name_or_path, self.revision, filename="tokenizer_config.json") + self.tokenizer_commit_hash = get_commit_hash( + self.model_name_or_path, self.revision, filename="tokenizer_config.json" + ) self.tokenizer = GET_TOKENIZER_FN[self.get_tokenizer_fn](self) + # TODO: for testing, we should load the tokenizer from the sft / dpo / rl and make sure they are all the same. @@ -311,14 +317,15 @@ def __post_init__(self): ] - # TODO: allow passing in sft_message key, so we can train on "chosen" of pref dataset. -def sft_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, sft_message_key: str = DEFAULT_SFT_MESSAGE_KEY): +def sft_tokenize_v1( + row: Dict[str, Any], tokenizer: PreTrainedTokenizer, sft_message_key: str = DEFAULT_SFT_MESSAGE_KEY +): if len(row[sft_message_key]) == 1: prompt = row[sft_message_key] else: prompt = row[sft_message_key][:-1] - + row[INPUT_IDS_PROMPT_KEY] = tokenizer.apply_chat_template( prompt, add_generation_prompt=True, @@ -328,15 +335,17 @@ def sft_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, sft_mes labels = copy.deepcopy(row[INPUT_IDS_KEY]) row[LABELS_KEY] = labels return row - -def sft_tokenize_mask_out_prompt_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, sft_message_key: str = DEFAULT_SFT_MESSAGE_KEY): + +def sft_tokenize_mask_out_prompt_v1( + row: Dict[str, Any], tokenizer: PreTrainedTokenizer, sft_message_key: str = DEFAULT_SFT_MESSAGE_KEY +): """mask out the prompt tokens by manipulating labels""" if len(row[sft_message_key]) == 1: prompt = row[sft_message_key] else: prompt = row[sft_message_key][:-1] - + row[INPUT_IDS_PROMPT_KEY] = tokenizer.apply_chat_template( prompt, add_generation_prompt=True, @@ -349,7 +358,13 @@ def sft_tokenize_mask_out_prompt_v1(row: Dict[str, Any], tokenizer: PreTrainedTo return row -def sft_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_prompt_token_length: Optional[int] = None, max_token_length: Optional[int] = None, need_contain_labels: bool = True): +def sft_filter_v1( + row: Dict[str, Any], + tokenizer: PreTrainedTokenizer, + max_prompt_token_length: Optional[int] = None, + max_token_length: Optional[int] = None, + need_contain_labels: bool = True, +): max_prompt_token_length_ok = True if max_prompt_token_length is not None: max_prompt_token_length_ok = len(row[INPUT_IDS_PROMPT_KEY]) <= max_prompt_token_length @@ -359,10 +374,9 @@ def sft_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_promp max_token_length_ok = len(row[INPUT_IDS_KEY]) <= max_token_length 
contain_some_labels = any(x != -100 for x in row[LABELS_KEY]) - return ( - max_prompt_token_length_ok and max_token_length_ok and (contain_some_labels or not need_contain_labels) - ) - + return max_prompt_token_length_ok and max_token_length_ok and (contain_some_labels or not need_contain_labels) + + def sft_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_seq_length: int): """taken directly from https://github.com/allenai/open-instruct/blob/ba11286e5b9eb00d4ce5b40ef4cac1389888416a/open_instruct/finetune.py#L385""" messages = row["messages"] @@ -433,45 +447,58 @@ def sft_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: PreTrained def sft_tulu_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer): return any(x != -100 for x in row[LABELS_KEY]) - + + def preference_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer): # Extract prompt (all messages except the last one) prompt = row["chosen"][:-1] - + # Tokenize prompt row[INPUT_IDS_PROMPT_KEY] = tokenizer.apply_chat_template( prompt, add_generation_prompt=True, ) row[ATTENTION_MASK_PROMPT_KEY] = [1] * len(row[INPUT_IDS_PROMPT_KEY]) - + # Tokenize chosen completion row[CHOSEN_INPUT_IDS_KEY] = tokenizer.apply_chat_template(row["chosen"]) row[CHOSEN_ATTENTION_MASK_KEY] = [1] * len(row[CHOSEN_INPUT_IDS_KEY]) - + # Tokenize rejected completion row[REJECTED_INPUT_IDS_KEY] = tokenizer.apply_chat_template(row["rejected"]) row[REJECTED_ATTENTION_MASK_KEY] = [1] * len(row[REJECTED_INPUT_IDS_KEY]) - + return row -def preference_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_prompt_token_length: Optional[int] = None, max_token_length: Optional[int] = None): + +def preference_filter_v1( + row: Dict[str, Any], + tokenizer: PreTrainedTokenizer, + max_prompt_token_length: Optional[int] = None, + max_token_length: Optional[int] = None, +): # Check prompt length if specified if max_prompt_token_length is not None: if len(row[INPUT_IDS_PROMPT_KEY]) > max_prompt_token_length: return False - + # Check total sequence lengths if specified if max_token_length is not None: if len(row[CHOSEN_INPUT_IDS_KEY]) > max_token_length: return False if len(row[REJECTED_INPUT_IDS_KEY]) > max_token_length: return False - + return True -def preference_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_seq_length: int, chosen_key: str = DEFAULT_CHOSEN_KEY, rejected_key: str = DEFAULT_REJECTED_KEY): +def preference_tulu_tokenize_and_truncate_v1( + row: Dict[str, Any], + tokenizer: PreTrainedTokenizer, + max_seq_length: int, + chosen_key: str = DEFAULT_CHOSEN_KEY, + rejected_key: str = DEFAULT_REJECTED_KEY, +): """ Here we assume each example has a rejected and chosen field, both of which are a list of messages. Each message is a dict with 'role' and 'content' fields. 
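The two SFT tokenizers above differ only in how they treat labels: `sft_tokenize_v1` copies `input_ids` into labels, while `sft_tokenize_mask_out_prompt_v1` (per its docstring) additionally masks the prompt positions with -100. Both derive the prompt the same way; a minimal illustration of that prompt-splitting rule, with a toy conversation that is not from the repo:

```python
from typing import Dict, List


def split_prompt(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    # Single-message rows are used whole as the prompt; otherwise drop the final (target) message.
    return messages if len(messages) == 1 else messages[:-1]


conversation = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]
print(split_prompt(conversation))  # [{'role': 'user', 'content': 'What is 2 + 2?'}]
```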
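`sft_filter_v1` keeps a row only if the prompt and the full sequence fit their length budgets and, unless `need_contain_labels=False`, at least one label is unmasked. A self-contained sketch of the same predicate with the row reduced to plain lists (the helper name and argument layout are illustrative, not the module's):

```python
from typing import List, Optional


def keep_row(
    prompt_ids: List[int],
    input_ids: List[int],
    labels: List[int],
    max_prompt_token_length: Optional[int] = None,
    max_token_length: Optional[int] = None,
    need_contain_labels: bool = True,
) -> bool:
    prompt_ok = max_prompt_token_length is None or len(prompt_ids) <= max_prompt_token_length
    total_ok = max_token_length is None or len(input_ids) <= max_token_length
    contain_some_labels = any(x != -100 for x in labels)
    return prompt_ok and total_ok and (contain_some_labels or not need_contain_labels)


print(keep_row([1, 2], [1, 2, 3], [-100, -100, -100]))                          # False: fully masked
print(keep_row([1, 2], [1, 2, 3], [-100, -100, 7], max_token_length=3))         # True
print(keep_row([1, 2], [1, 2, 3, 4], [-100, -100, 7, 8], max_token_length=3))   # False: too long
```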
@@ -484,8 +511,12 @@ def preference_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: Pre if len(rejected_messages) == 0: raise ValueError("rejected messages field is empty.") - chosen_encoded = sft_tulu_tokenize_and_truncate_v1({DEFAULT_SFT_MESSAGE_KEY: chosen_messages}, tokenizer, max_seq_length) - rejected_encoded = sft_tulu_tokenize_and_truncate_v1({DEFAULT_SFT_MESSAGE_KEY: rejected_messages}, tokenizer, max_seq_length) + chosen_encoded = sft_tulu_tokenize_and_truncate_v1( + {DEFAULT_SFT_MESSAGE_KEY: chosen_messages}, tokenizer, max_seq_length + ) + rejected_encoded = sft_tulu_tokenize_and_truncate_v1( + {DEFAULT_SFT_MESSAGE_KEY: rejected_messages}, tokenizer, max_seq_length + ) return { CHOSEN_INPUT_IDS_KEY: chosen_encoded["input_ids"], @@ -513,6 +544,7 @@ def preference_tulu_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenize "preference_tulu_filter_v1": (preference_tulu_filter_v1, "filter"), } + # ---------------------------------------------------------------------------- # Dataset Configuration and Caching @dataclass @@ -523,10 +555,10 @@ class DatasetConfig: dataset_range: Optional[int] = None transform_fn: List[str] = field(default_factory=list) transform_fn_args: Dict[str, Dict[str, Any]] = field(default_factory=dict) - + # for tracking purposes dataset_commit_hash: Optional[str] = None - + def __post_init__(self): self.dataset_commit_hash = get_commit_hash(self.dataset_name, self.dataset_revision, "README.md", "dataset") self.dataset = load_dataset( @@ -537,20 +569,21 @@ def __post_init__(self): if self.dataset_range is None: dataset_range = len(self.dataset) self.update_range(dataset_range) - + def update_range(self, dataset_range: int): self.dataset_range = dataset_range if self.dataset_range > len(self.dataset): raise ValueError("Dataset range exceeds dataset length") self.dataset = self.dataset.select(range(self.dataset_range)) + def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): # beaker specific logic; we may get assigned 15.5 CPU, so we convert it to float then int num_proc = int(float(os.environ.get("BEAKER_ASSIGNED_CPU_COUNT", multiprocessing.cpu_count()))) - + tokenizer = tc.tokenizer dataset = dc.dataset - + for fn_name in dc.transform_fn: fn, fn_type = TRANSFORM_FNS[fn_name] # always pass in tokenizer and other args if needed @@ -570,7 +603,7 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): ) elif fn_type == "filter": dataset = dataset.filter( - fn, + fn, fn_kwargs=fn_kwargs, num_proc=get_num_proc(len(dataset), num_proc, FILTER_EXAMPLE_PER_SECOND_PER_CPU), ) @@ -582,21 +615,16 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): raise ValueError("No examples left after transformation") return dataset + class DatasetTransformationCache: def __init__(self, hf_entity: Optional[str] = None): self.hf_entity = hf_entity or HfApi().whoami()["name"] - + def compute_config_hash(self, dcs: List[DatasetConfig], tc: TokenizerConfig) -> str: """Compute a deterministic hash of both configs for caching.""" - dc_dicts = [ - {k: v for k, v in asdict(dc).items() if v is not None} - for dc in dcs - ] + dc_dicts = [{k: v for k, v in asdict(dc).items() if v is not None} for dc in dcs] tc_dict = {k: v for k, v in asdict(tc).items() if v is not None} - combined_dict = { - "dataset_configs": dc_dicts, - "tokenizer_config": tc_dict - } + combined_dict = {"dataset_configs": dc_dicts, "tokenizer_config": tc_dict} config_str = json.dumps(combined_dict, sort_keys=True) return hashlib.sha256(config_str.encode()).hexdigest()[:10] @@ -604,39 
+632,35 @@ def load_or_transform_dataset(self, dcs: List[DatasetConfig], tc: TokenizerConfi """Load dataset from cache if it exists, otherwise transform and cache it.""" config_hash = self.compute_config_hash(dcs, tc) repo_name = f"{self.hf_entity}/dataset-mix-cached" - + # Check if the revision exists if revision_exists(repo_name, config_hash, repo_type="dataset"): print(f"✅ Found cached dataset at https://huggingface.co/datasets/{repo_name}/tree/{config_hash}") # Use the split from the first dataset config as default - return load_dataset( - repo_name, - split=dcs[0].dataset_split, - revision=config_hash - ) - + return load_dataset(repo_name, split=dcs[0].dataset_split, revision=config_hash) + print(f"Cache not found, transforming datasets...") - + # Transform each dataset transformed_datasets = [] for dc in dcs: dataset = get_dataset_v1(dc, tc) transformed_datasets.append(dataset) - + # Combine datasets combined_dataset = concatenate_datasets(transformed_datasets) - + # Push to hub with config hash as revision combined_dataset.push_to_hub( repo_name, private=True, revision=config_hash, - commit_message=f"Cache combined dataset with configs hash: {config_hash}" + commit_message=f"Cache combined dataset with configs hash: {config_hash}", ) print(f"🚀 Pushed transformed dataset to https://huggingface.co/datasets/{repo_name}/tree/{config_hash}") - - model_card = ModelCard(f"""\ + model_card = ModelCard( + f"""\ --- tags: [open-instruct] --- @@ -658,23 +682,23 @@ def load_or_transform_dataset(self, dcs: List[DatasetConfig], tc: TokenizerConfi ```json {json.dumps([asdict(dc) for dc in dcs], indent=2)} ``` -""") +""" + ) model_card.push_to_hub(repo_name, repo_type="dataset", revision=config_hash) # NOTE: Load the dataset again to make sure it's downloaded to the HF cache print(f"✅ Found cached dataset at https://huggingface.co/datasets/{repo_name}/tree/{config_hash}") - return load_dataset( - repo_name, - split=dc.dataset_split, - revision=config_hash - ) + return load_dataset(repo_name, split=dc.dataset_split, revision=config_hash) def get_cached_dataset(dcs: List[DatasetConfig], tc: TokenizerConfig, hf_entity: Optional[str] = None) -> Dataset: cache = DatasetTransformationCache(hf_entity=hf_entity) return cache.load_or_transform_dataset(dcs, tc) -def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConfig, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: + +def get_cached_dataset_tulu_sft( + dataset_mixer_list: List[str], tc: TokenizerConfig, max_seq_length: int, hf_entity: Optional[str] = None +) -> Dataset: dcs = [] assert len(dataset_mixer_list) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer_list}" for i in range(0, len(dataset_mixer_list), 2): @@ -684,7 +708,7 @@ def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConf frac_or_num_samples = float(frac_or_num_samples) else: frac_or_num_samples = int(frac_or_num_samples) - + dataset_config = DatasetConfig( dataset_name=dataset_name, dataset_split="train", @@ -695,7 +719,7 @@ def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConf "max_seq_length": max_seq_length, "target_columns": TOKENIZED_SFT_DATASET_KEYS, } - } + }, ) if frac_or_num_samples > 1.0: new_range = int(frac_or_num_samples) @@ -707,7 +731,9 @@ def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConf return cache.load_or_transform_dataset(dcs, tc) -def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: TokenizerConfig, 
max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset: +def get_cached_dataset_tulu_preference( + dataset_mixer_list: List[str], tc: TokenizerConfig, max_seq_length: int, hf_entity: Optional[str] = None +) -> Dataset: dcs = [] assert len(dataset_mixer_list) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer_list}" for i in range(0, len(dataset_mixer_list), 2): @@ -717,7 +743,7 @@ def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: Tokeni frac_or_num_samples = float(frac_or_num_samples) else: frac_or_num_samples = int(frac_or_num_samples) - + dataset_config = DatasetConfig( dataset_name=dataset_name, dataset_split="train", @@ -728,7 +754,7 @@ def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: Tokeni "max_seq_length": max_seq_length, "target_columns": TOKENIZED_PREFERENCE_DATASET_KEYS, } - } + }, ) if frac_or_num_samples > 1.0: new_range = int(frac_or_num_samples) @@ -742,21 +768,15 @@ def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: Tokeni def test_sft_dpo_same_tokenizer(): base_to_sft_tc = TokenizerConfig( - model_name_or_path="meta-llama/Llama-3.1-8B", - revision="main", - chat_template_name="tulu" + model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu" ) sft_to_dpo_tc = TokenizerConfig( - model_name_or_path="allenai/Llama-3.1-Tulu-3-8B-SFT", - revision="main", - chat_template_name="tulu" + model_name_or_path="allenai/Llama-3.1-Tulu-3-8B-SFT", revision="main", chat_template_name="tulu" ) dpo_to_rl_tc = TokenizerConfig( - model_name_or_path="allenai/Llama-3.1-Tulu-3-8B-DPO", - revision="main", - chat_template_name="tulu" + model_name_or_path="allenai/Llama-3.1-Tulu-3-8B-DPO", revision="main", chat_template_name="tulu" ) - + def equal_tokenizer(tc1, tc2): tok1 = tc1.tokenizer tok2 = tc2.tokenizer @@ -765,8 +785,11 @@ def equal_tokenizer(tc1, tc2): assert tok1.is_fast == tok2.is_fast, "is_fast should be the same" assert tok1.padding_side == tok2.padding_side, "padding_side should be the same" assert tok1.truncation_side == tok2.truncation_side, "truncation_side should be the same" - assert tok1.clean_up_tokenization_spaces == tok2.clean_up_tokenization_spaces, "clean_up_tokenization_spaces should be the same" + assert ( + tok1.clean_up_tokenization_spaces == tok2.clean_up_tokenization_spaces + ), "clean_up_tokenization_spaces should be the same" assert tok1.added_tokens_decoder == tok2.added_tokens_decoder, "added_tokens_decoder should be the same" + equal_tokenizer(base_to_sft_tc, sft_to_dpo_tc) equal_tokenizer(sft_to_dpo_tc, dpo_to_rl_tc) equal_tokenizer(base_to_sft_tc, dpo_to_rl_tc) @@ -774,11 +797,7 @@ def equal_tokenizer(tc1, tc2): def test_config_hash_different(): """Test that different configurations produce different hashes.""" - tc = TokenizerConfig( - model_name_or_path="meta-llama/Llama-3.1-8B", - revision="main", - chat_template_name="tulu" - ) + tc = TokenizerConfig(model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu") dcs1 = [ DatasetConfig( @@ -786,96 +805,86 @@ def test_config_hash_different(): dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_v1"], - transform_fn_args={} + transform_fn_args={}, ) ] - + dcs2 = [ DatasetConfig( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_mask_out_prompt_v1"], - transform_fn_args={} + transform_fn_args={}, ) ] - + cache = DatasetTransformationCache() hash1 = 
cache.compute_config_hash(dcs1, tc) hash2 = cache.compute_config_hash(dcs2, tc) assert hash1 != hash2, "Different configs should have different hashes" + def test_sft_dataset_caching(): """Test caching functionality for SFT datasets.""" - tc = TokenizerConfig( - model_name_or_path="meta-llama/Llama-3.1-8B", - revision="main", - chat_template_name="tulu" - ) - + tc = TokenizerConfig(model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu") + dcs = [ DatasetConfig( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_v1"], - transform_fn_args={} + transform_fn_args={}, ), DatasetConfig( dataset_name="allenai/tulu-3-hard-coded-10x", dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_v1"], - transform_fn_args={} - ) + transform_fn_args={}, + ), ] - + # First transformation should cache dataset1 = get_cached_dataset(dcs, tc) - + # Second load should use cache dataset1_cached = get_cached_dataset(dcs, tc) - + # Verify the datasets are the same assert len(dataset1) == len(dataset1_cached), "Cached dataset should have same length" + def test_sft_different_transform(): """Test different transform functions produce different cached datasets.""" - tc = TokenizerConfig( - model_name_or_path="meta-llama/Llama-3.1-8B", - revision="main", - chat_template_name="tulu" - ) - + tc = TokenizerConfig(model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu") + dcs = [ DatasetConfig( dataset_name="allenai/tulu-3-sft-personas-algebra", dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_mask_out_prompt_v1"], - transform_fn_args={} + transform_fn_args={}, ), DatasetConfig( dataset_name="allenai/tulu-3-hard-coded-10x", dataset_split="train", dataset_revision="main", transform_fn=["sft_tokenize_mask_out_prompt_v1"], - transform_fn_args={} - ) + transform_fn_args={}, + ), ] - + dataset = get_cached_dataset(dcs, tc) assert dataset is not None, "Should successfully create dataset with different transform" def test_sft_filter(): """Test different transform functions produce different cached datasets.""" - tc = TokenizerConfig( - model_name_or_path="meta-llama/Llama-3.1-8B", - revision="main", - chat_template_name="tulu" - ) - + tc = TokenizerConfig(model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu") + ARBITRARY_MAX_LENGTH = 1000 dcs = [ DatasetConfig( @@ -887,15 +896,15 @@ def test_sft_filter(): "sft_filter_v1": { "max_token_length": ARBITRARY_MAX_LENGTH # Filter to sequences <= ARBITRARY_MAX_LENGTH tokens } - } + }, ) ] - + filtered_dataset = get_cached_dataset(dcs, tc) # Verify that all sequences are <= ARBITRARY_MAX_LENGTH tokens max_length = max(len(example[INPUT_IDS_KEY]) for example in filtered_dataset) assert max_length <= ARBITRARY_MAX_LENGTH, f"Found sequence with length {max_length} > {ARBITRARY_MAX_LENGTH}" - + print("Filter test passed! 
Max sequence length:", max_length) print("All tests passed!") assert filtered_dataset is not None, "Should successfully create dataset with different transform" @@ -903,29 +912,25 @@ def test_sft_filter(): def test_preference_dataset(): """Test caching functionality for preference datasets.""" - tc = TokenizerConfig( - model_name_or_path="meta-llama/Llama-3.1-8B", - revision="main", - chat_template_name="tulu" - ) - + tc = TokenizerConfig(model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu") + dcs_pref = [ DatasetConfig( dataset_name="allenai/tulu-3-pref-personas-instruction-following", dataset_split="train", dataset_revision="main", transform_fn=["preference_tokenize_v1"], - transform_fn_args={} + transform_fn_args={}, ), DatasetConfig( dataset_name="allenai/tulu-3-wildchat-reused-on-policy-70b", dataset_split="train", dataset_revision="main", transform_fn=["preference_tokenize_v1"], - transform_fn_args={} - ) + transform_fn_args={}, + ), ] - + dataset_pref = get_cached_dataset(dcs_pref, tc) assert dataset_pref is not None, "Should successfully create preference dataset"
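For reference, the cache revision used by `DatasetTransformationCache` is just the first 10 hex characters of a SHA-256 over the JSON-serialized, non-None fields of every `DatasetConfig` plus the `TokenizerConfig`, exactly as `compute_config_hash` does. A self-contained sketch of that key derivation (the field values below are example stand-ins for what `asdict` would produce):

```python
import hashlib
import json

# Example stand-ins for the non-None fields asdict() would yield for the configs.
dc_dicts = [
    {
        "dataset_name": "allenai/tulu-3-sft-personas-algebra",
        "dataset_split": "train",
        "dataset_revision": "main",
        "transform_fn": ["sft_tokenize_v1"],
        "transform_fn_args": {},
    }
]
tc_dict = {"model_name_or_path": "meta-llama/Llama-3.1-8B", "revision": "main", "chat_template_name": "tulu"}

combined_dict = {"dataset_configs": dc_dicts, "tokenizer_config": tc_dict}
config_str = json.dumps(combined_dict, sort_keys=True)
config_hash = hashlib.sha256(config_str.encode()).hexdigest()[:10]
print(config_hash)  # deterministic 10-char string, reused as the Hub dataset revision name
```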
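`get_cached_dataset_tulu_sft` and `get_cached_dataset_tulu_preference` both take a flat "mixer list" that alternates dataset name and either a fraction of the split or an absolute sample count. A sketch of how each pair is interpreted, runnable offline without building the actual `DatasetConfig`s:

```python
dataset_mixer_list = [
    "allenai/tulu-3-sft-personas-algebra", "1.0",   # "." present -> parsed as a fraction of the split
    "allenai/tulu-3-hard-coded-10x", "500",         # no "." -> parsed as an absolute number of rows
]
assert len(dataset_mixer_list) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer_list}"

for i in range(0, len(dataset_mixer_list), 2):
    dataset_name, frac_or_num_samples = dataset_mixer_list[i], dataset_mixer_list[i + 1]
    frac_or_num_samples = float(frac_or_num_samples) if "." in frac_or_num_samples else int(frac_or_num_samples)
    if frac_or_num_samples > 1.0:
        note = f"select the first {int(frac_or_num_samples)} rows"
    else:
        note = f"select {frac_or_num_samples:.0%} of the split"
    print(dataset_name, "->", note)
```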
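`test_sft_dpo_same_tokenizer` boils down to comparing a handful of tokenizer attributes across the base/SFT/DPO checkpoints. The same comparison can be run on any two tokenizers; the sketch below uses plain `AutoTokenizer` rather than the patch's `TokenizerConfig` (which also adds special tokens and the chat template), so it only approximates the test, covers only the asserts visible in this hunk, and needs Hub access to the two allenai checkpoints:

```python
from transformers import AutoTokenizer


def tokenizers_match(tok1, tok2) -> None:
    # A subset of the attribute checks from equal_tokenizer (the ones visible in this hunk).
    assert tok1.is_fast == tok2.is_fast, "is_fast should be the same"
    assert tok1.padding_side == tok2.padding_side, "padding_side should be the same"
    assert tok1.truncation_side == tok2.truncation_side, "truncation_side should be the same"
    assert (
        tok1.clean_up_tokenization_spaces == tok2.clean_up_tokenization_spaces
    ), "clean_up_tokenization_spaces should be the same"
    assert tok1.added_tokens_decoder == tok2.added_tokens_decoder, "added_tokens_decoder should be the same"


tok_sft = AutoTokenizer.from_pretrained("allenai/Llama-3.1-Tulu-3-8B-SFT", revision="main")
tok_dpo = AutoTokenizer.from_pretrained("allenai/Llama-3.1-Tulu-3-8B-DPO", revision="main")
tokenizers_match(tok_sft, tok_dpo)
```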
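A sketch of building one of the preference mixes from `test_preference_dataset` and peeking at the tokenized columns. It needs Hub access (plus a writable `<entity>/dataset-mix-cached` repo for the cache), and importing the `*_KEY` constants is an assumption about the module layout, since they are referenced but not defined in this patch:

```python
from open_instruct.dataset_transformation import (
    CHOSEN_INPUT_IDS_KEY,
    INPUT_IDS_PROMPT_KEY,
    REJECTED_INPUT_IDS_KEY,
    DatasetConfig,
    TokenizerConfig,
    get_cached_dataset,
)

tc = TokenizerConfig(model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu")
dcs = [
    DatasetConfig(
        dataset_name="allenai/tulu-3-pref-personas-instruction-following",
        dataset_split="train",
        dataset_revision="main",
        transform_fn=["preference_tokenize_v1"],
        transform_fn_args={},
    )
]
row = get_cached_dataset(dcs, tc)[0]
print(len(row[INPUT_IDS_PROMPT_KEY]), len(row[CHOSEN_INPUT_IDS_KEY]), len(row[REJECTED_INPUT_IDS_KEY]))
```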
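Finally, the new checks are plain module-level functions rather than a pytest class, so they can presumably be invoked directly as well as collected by pytest; a run needs Hub access to the Llama tokenizer, the allenai/tulu-3-* datasets, and push rights to the caching repo:

```python
from open_instruct.dataset_transformation import (
    test_config_hash_different,
    test_preference_dataset,
    test_sft_dataset_caching,
    test_sft_filter,
)

test_config_hash_different()
test_sft_dataset_caching()
test_sft_filter()
test_preference_dataset()
```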