Skip to content

Commit

Permalink
allow customizing chosen / rejected key
Browse files Browse the repository at this point in the history
  • Loading branch information
vwxyzjn committed Jan 21, 2025
1 parent 4159255 commit 367c9e5
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions open_instruct/dataset_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ def __post_init__(self):
# since we are always padding from the right with a collator; however they might become
# more useful if we want to do some sort of packing in the future. The nice thing is
# that the tokenization logic would work for both DPO and RM training.
CHOSEN_KEY = "chosen"
REJECTED_KEY = "rejected"
DEFAULT_CHOSEN_KEY = "chosen"
DEFAULT_REJECTED_KEY = "rejected"
CHOSEN_INPUT_IDS_KEY = "chosen_input_ids"
CHOSEN_ATTENTION_MASK_KEY = "chosen_attention_mask"
CHOSEN_LABELS_KEY = "chosen_labels"
Expand Down Expand Up @@ -396,14 +396,14 @@ def preference_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, ma
return True


def preference_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_seq_length: int):
def preference_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_seq_length: int, chosen_key: str = DEFAULT_CHOSEN_KEY, rejected_key: str = DEFAULT_REJECTED_KEY):
"""
Here we assume each example has a rejected and chosen field, both of which are a list of messages.
Each message is a dict with 'role' and 'content' fields.
We assume only the last message is different, and the prompt is contained in the list of messages.
"""
chosen_messages = row[CHOSEN_KEY]
rejected_messages = row[REJECTED_KEY]
chosen_messages = row[chosen_key]
rejected_messages = row[rejected_key]
if len(chosen_messages) == 0:
raise ValueError("chosen messages field is empty.")
if len(rejected_messages) == 0:
Expand Down

0 comments on commit 367c9e5

Please sign in to comment.