
support preference tuning as well
vwxyzjn committed Jan 21, 2025
1 parent c8b6a62 commit 15ba999
Showing 3 changed files with 1,218 additions and 54 deletions.
167 changes: 118 additions & 49 deletions open_instruct/dataset_transformation.py
@@ -5,7 +5,6 @@
# 3. ✅ RM / DPO (chosen and rejected)
# 4. ✅ Visualization of length distributions?
# 5. ✅ Filter?
# * Smart truncation?
# 6. ✅ dataset_num_proc
# 7. ✅ check EOS token
# 8. dataset mixer?
@@ -49,6 +48,8 @@
import os
from typing import Any, Dict, List, Optional
import torch
from rich.console import Console
from rich.text import Text
from transformers import (
AutoTokenizer,
PreTrainedTokenizer,
@@ -79,6 +80,17 @@ def get_num_proc(dataset_len: int, num_available_cpus: int, example_per_second_p
num_required_cpus = max(1, dataset_len // example_per_second_per_cpu)
return min(num_required_cpus, num_available_cpus)
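For intuition, a quick worked example of this heuristic (the numbers are illustrative):

# 10,000 examples at ~120 examples/sec per CPU -> 10_000 // 120 = 83 CPUs "required",
# then capped at the 8 CPUs actually available:
assert get_num_proc(10_000, 8, 120) == 8
# A tiny dataset falls back to a single process: max(1, 50 // 120) == 1.
assert get_num_proc(50, 8, 120) == 1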

COLORS = ["on red", "on green", "on blue", "on yellow", "on magenta"]

def visualize_token(tokens: list[int], tokenizer: PreTrainedTokenizer):
    console = Console()
    rich_text = Text()
    for i, token in enumerate(tokens):
        color = COLORS[i % len(COLORS)]
        decoded_token = tokenizer.decode(token)
        rich_text.append(decoded_token, style=color)
    console.print(rich_text)
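A minimal usage sketch, assuming any Hugging Face tokenizer (the model name here is illustrative):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer with decode() works
token_ids = tok.encode("Preference tuning!")
visualize_token(token_ids, tok)  # prints each token on an alternating background color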

# ----------------------------------------------------------------------------
# Tokenization
@@ -95,7 +107,7 @@ class TokenizerConfigV1:
tokenizer_commit_hash: Optional[str] = None

def __post_init__(self):
self.tokenizer_commit_hash = get_commit_hash(self.model_name_or_path, self.revision)
self.tokenizer_commit_hash = get_commit_hash(self.model_name_or_path, self.revision, filename="tokenizer_config.json")

self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name_or_path,
@@ -182,31 +194,6 @@ def __post_init__(self):

# ----------------------------------------------------------------------------
# Dataset Transformation
COLORS = ["on red", "on green", "on blue", "on yellow", "on magenta"]
# Preference dataset
INPUT_IDS_CHOSEN_KEY = "input_ids_chosen"
ATTENTION_MASK_CHOSEN_KEY = "attention_mask_chosen"
INPUT_IDS_REJECTED_KEY = "input_ids_rejected"
ATTENTION_MASK_REJECTED_KEY = "attention_mask_rejected"
INPUT_IDS_PROMPT_KEY = "input_ids_prompt"
ATTENTION_MASK_PROMPT_KEY = "attention_mask_prompt"
GROUND_TRUTHS_KEY = "ground_truth"
DATASET_SOURCE_KEY = "dataset"

# NOTE (Costa): the `INPUT_IDS_PROMPT_KEY` is just for visualization purposes only
# also we don't really need `ATTENTION_MASK_CHOSEN_KEY` and `ATTENTION_MASK_REJECTED_KEY`
# since we are always padding from the right with a collator; however they might become
# more useful if we want to do some sort of packing in the future. The nice thing is
# that the tokenization logic would work for both DPO and RM training.
TOKENIZED_PREFERENCE_DATASET_KEYS = [
INPUT_IDS_CHOSEN_KEY,
INPUT_IDS_REJECTED_KEY,
# ATTENTION_MASK_CHOSEN_KEY,
# ATTENTION_MASK_REJECTED_KEY,
# INPUT_IDS_PROMPT_KEY,
# ATTENTION_MASK_PROMPT_KEY,
]

# SFT dataset
SFT_MESSAGE_KEY = "messages"
INPUT_IDS_KEY = "input_ids"
@@ -218,15 +205,38 @@ def __post_init__(self):
LABELS_KEY,
]

# Binary dataset
BINARY_LABEL_KEY = "binary_labels"
BINARY_DATASET_KEYS = [
INPUT_IDS_KEY,
LABELS_KEY,
BINARY_LABEL_KEY,
# Preference dataset
# NOTE (Costa): `INPUT_IDS_PROMPT_KEY` is for visualization purposes only.
# We also don't strictly need `CHOSEN_ATTENTION_MASK_KEY` and `REJECTED_ATTENTION_MASK_KEY`,
# since we always pad from the right with a collator; however, they might become
# more useful if we want to do some sort of packing in the future. The nice thing is
# that the tokenization logic works for both DPO and RM training.
CHOSEN_KEY = "chosen"
REJECTED_KEY = "rejected"
CHOSEN_INPUT_IDS_KEY = "chosen_input_ids"
CHOSEN_ATTENTION_MASK_KEY = "chosen_attention_mask"
CHOSEN_LABELS_KEY = "chosen_labels"
REJECTED_INPUT_IDS_KEY = "rejected_input_ids"
REJECTED_ATTENTION_MASK_KEY = "rejected_attention_mask"
REJECTED_LABELS_KEY = "rejected_labels"

INPUT_IDS_PROMPT_KEY = "input_ids_prompt"
ATTENTION_MASK_PROMPT_KEY = "attention_mask_prompt"
GROUND_TRUTHS_KEY = "ground_truth"
DATASET_SOURCE_KEY = "dataset"

TOKENIZED_PREFERENCE_DATASET_KEYS = [
CHOSEN_INPUT_IDS_KEY,
CHOSEN_LABELS_KEY,
CHOSEN_ATTENTION_MASK_KEY,
REJECTED_INPUT_IDS_KEY,
REJECTED_LABELS_KEY,
REJECTED_ATTENTION_MASK_KEY,
]
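For reference, a hedged sketch of the raw preference row these keys assume, matching the docstring of `preference_tulu_tokenize_and_truncate_v1` below (the message contents are illustrative):

example_row = {
    CHOSEN_KEY: [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "2 + 2 = 4."},
    ],
    REJECTED_KEY: [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "2 + 2 = 5."},
    ],
}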



# TODO: allow passing in the sft_message key, so we can train on the "chosen" field of preference datasets.
def sft_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer):
if len(row[SFT_MESSAGE_KEY]) == 1:
prompt = row[SFT_MESSAGE_KEY]
@@ -346,7 +356,7 @@ def sft_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: PreTrained


def sft_tulu_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer):
return row["labels"] != -100
return any(x != -100 for x in row[LABELS_KEY])

def preference_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer):
# Extract prompt (all messages except the last one)
@@ -360,12 +370,12 @@ def preference_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer):
row[ATTENTION_MASK_PROMPT_KEY] = [1] * len(row[INPUT_IDS_PROMPT_KEY])

# Tokenize chosen completion
row[INPUT_IDS_CHOSEN_KEY] = tokenizer.apply_chat_template(row["chosen"])
row[ATTENTION_MASK_CHOSEN_KEY] = [1] * len(row[INPUT_IDS_CHOSEN_KEY])
row[CHOSEN_INPUT_IDS_KEY] = tokenizer.apply_chat_template(row["chosen"])
row[CHOSEN_ATTENTION_MASK_KEY] = [1] * len(row[CHOSEN_INPUT_IDS_KEY])

# Tokenize rejected completion
row[INPUT_IDS_REJECTED_KEY] = tokenizer.apply_chat_template(row["rejected"])
row[ATTENTION_MASK_REJECTED_KEY] = [1] * len(row[INPUT_IDS_REJECTED_KEY])
row[REJECTED_INPUT_IDS_KEY] = tokenizer.apply_chat_template(row["rejected"])
row[REJECTED_ATTENTION_MASK_KEY] = [1] * len(row[REJECTED_INPUT_IDS_KEY])

return row

@@ -377,13 +387,44 @@ def preference_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, ma

# Check total sequence lengths if specified
if max_token_length is not None:
if len(row[INPUT_IDS_CHOSEN_KEY]) > max_token_length:
if len(row[CHOSEN_INPUT_IDS_KEY]) > max_token_length:
return False
if len(row[INPUT_IDS_REJECTED_KEY]) > max_token_length:
if len(row[REJECTED_INPUT_IDS_KEY]) > max_token_length:
return False

return True


def preference_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_seq_length: int):
"""
Here we assume each example has 'chosen' and 'rejected' fields, both of which are lists of messages.
Each message is a dict with 'role' and 'content' fields.
We assume only the last message differs between the two lists, and that the prompt is contained in the list of messages.
"""
chosen_messages = row[CHOSEN_KEY]
rejected_messages = row[REJECTED_KEY]
if len(chosen_messages) == 0:
raise ValueError("chosen messages field is empty.")
if len(rejected_messages) == 0:
raise ValueError("rejected messages field is empty.")

chosen_encoded = sft_tulu_tokenize_and_truncate_v1({SFT_MESSAGE_KEY: chosen_messages}, tokenizer, max_seq_length)
rejected_encoded = sft_tulu_tokenize_and_truncate_v1({SFT_MESSAGE_KEY: rejected_messages}, tokenizer, max_seq_length)

return {
CHOSEN_INPUT_IDS_KEY: chosen_encoded["input_ids"],
CHOSEN_LABELS_KEY: chosen_encoded["labels"],
CHOSEN_ATTENTION_MASK_KEY: chosen_encoded["attention_mask"],
REJECTED_INPUT_IDS_KEY: rejected_encoded["input_ids"],
REJECTED_LABELS_KEY: rejected_encoded["labels"],
REJECTED_ATTENTION_MASK_KEY: rejected_encoded["attention_mask"],
}
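A hedged sketch of what this transform returns for the `example_row` above, assuming a tokenizer `tok` with a chat template (`max_seq_length` is illustrative):

encoded = preference_tulu_tokenize_and_truncate_v1(example_row, tok, max_seq_length=2048)
# One tokenized copy per side; the -100 label masking is exactly what
# preference_tulu_filter_v1 below checks for:
assert set(encoded) == set(TOKENIZED_PREFERENCE_DATASET_KEYS)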


def preference_tulu_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer):
return any(x != -100 for x in row[CHOSEN_LABELS_KEY]) and any(x != -100 for x in row[REJECTED_LABELS_KEY])


TRANSFORM_FNS = {
"sft_tokenize_v1": (sft_tokenize_v1, "map"),
"sft_tokenize_mask_out_prompt_v1": (sft_tokenize_mask_out_prompt_v1, "map"),
@@ -392,6 +433,8 @@ def preference_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, ma
"sft_tulu_filter_v1": (sft_tulu_filter_v1, "filter"),
"preference_tokenize_v1": (preference_tokenize_v1, "map", TOKENIZED_PREFERENCE_DATASET_KEYS),
"preference_filter_v1": (preference_filter_v1, "filter"),
"preference_tulu_tokenize_and_truncate_v1": (preference_tulu_tokenize_and_truncate_v1, "map"),
"preference_tulu_filter_v1": (preference_tulu_filter_v1, "filter"),
}
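A sketch of how a caller might dispatch through this registry; the real wiring lives in `get_dataset_v1` (not shown in this hunk), so treat this as an assumption about its shape. The `[:2]` slice tolerates the optional third element (target columns) on entries like `preference_tokenize_v1`:

fn, fn_type = TRANSFORM_FNS["preference_tulu_filter_v1"][:2]
if fn_type == "map":
    dataset = dataset.map(fn, fn_kwargs={"tokenizer": tokenizer})
elif fn_type == "filter":
    dataset = dataset.filter(fn, fn_kwargs={"tokenizer": tokenizer})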

# ----------------------------------------------------------------------------
@@ -503,11 +546,6 @@ def load_or_transform_dataset(self, dcs: List[DatasetConfigV1], tc: TokenizerCon
transformed_datasets = []
for dc in dcs:
dataset = get_dataset_v1(dc, tc)
# # Add id column if not present
# if "id" not in dataset.column_names:
# base_name = dc.dataset_name.split("/")[-1]
# id_col = [f"{base_name}-{i}" for i in range(len(dataset))]
# dataset = dataset.add_column("id", id_col)
transformed_datasets.append(dataset)

# Combine datasets
@@ -532,13 +570,10 @@


def get_cached_dataset(dcs: List[DatasetConfigV1], tc: TokenizerConfigV1, hf_entity: Optional[str] = None) -> Dataset:
"""Get transformed and cached dataset from multiple dataset configs."""
cache = DatasetTransformationCache(hf_entity=hf_entity)
return cache.load_or_transform_dataset(dcs, tc)

# TODO: need to support length / range of the dataset
def get_cached_dataset_tulu(dataset_mixer_list: List[str], tc: TokenizerConfigV1, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset:
"""Get transformed and cached dataset from multiple dataset configs."""
def get_cached_dataset_tulu_sft(dataset_mixer_list: List[str], tc: TokenizerConfigV1, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset:
dcs = []
assert len(dataset_mixer_list) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer_list}"
for i in range(0, len(dataset_mixer_list), 2):
@@ -570,6 +605,40 @@ def get_cached_dataset_tulu(dataset_mixer_list: List[str], tc: TokenizerConfigV1
cache = DatasetTransformationCache(hf_entity=hf_entity)
return cache.load_or_transform_dataset(dcs, tc)


def get_cached_dataset_tulu_preference(dataset_mixer_list: List[str], tc: TokenizerConfigV1, max_seq_length: int, hf_entity: Optional[str] = None) -> Dataset:
dcs = []
assert len(dataset_mixer_list) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer_list}"
for i in range(0, len(dataset_mixer_list), 2):
dataset_name = dataset_mixer_list[i]
frac_or_num_samples = dataset_mixer_list[i + 1]
if "." in frac_or_num_samples:
frac_or_num_samples = float(frac_or_num_samples)
else:
frac_or_num_samples = int(frac_or_num_samples)

dataset_config = DatasetConfigV1(
dataset_name=dataset_name,
dataset_split="train",
dataset_revision="main",
transform_fn=["preference_tulu_tokenize_and_truncate_v1", "preference_tulu_filter_v1"],
transform_fn_args={
"preference_tulu_tokenize_and_truncate_v1": {
"max_seq_length": max_seq_length,
"target_columns": TOKENIZED_PREFERENCE_DATASET_KEYS,
}
}
)
if frac_or_num_samples > 1.0:
new_range = int(frac_or_num_samples)
else:
new_range = int(frac_or_num_samples * len(dataset_config.dataset))
dataset_config.update_range(new_range)
dcs.append(dataset_config)
cache = DatasetTransformationCache(hf_entity=hf_entity)
return cache.load_or_transform_dataset(dcs, tc)
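A hedged usage sketch of the new preference entry point. The dataset names are hypothetical, and the `TokenizerConfigV1` field names are taken from the class above. Each dataset in the mixer list is followed by either a fraction (parsed as a float when the string contains a '.') or an absolute sample count:

tc = TokenizerConfigV1(model_name_or_path="my-org/my-model", revision="main")
train_dataset = get_cached_dataset_tulu_preference(
    dataset_mixer_list=["my-org/my-preference-mix", "1.0", "my-org/extra-prefs", "10000"],
    tc=tc,
    max_seq_length=2048,
)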


def test_config_hash_different():
"""Test that different configurations produce different hashes."""
tc = TokenizerConfigV1(
