diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 827721ac1..2f44e7fca 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -212,8 +212,8 @@ def __post_init__(self): # since we are always padding from the right with a collator; however they might become # more useful if we want to do some sort of packing in the future. The nice thing is # that the tokenization logic would work for both DPO and RM training. -CHOSEN_KEY = "chosen" -REJECTED_KEY = "rejected" +DEFAULT_CHOSEN_KEY = "chosen" +DEFAULT_REJECTED_KEY = "rejected" CHOSEN_INPUT_IDS_KEY = "chosen_input_ids" CHOSEN_ATTENTION_MASK_KEY = "chosen_attention_mask" CHOSEN_LABELS_KEY = "chosen_labels" @@ -396,14 +396,14 @@ def preference_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, ma return True -def preference_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_seq_length: int): +def preference_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_seq_length: int, chosen_key: str = DEFAULT_CHOSEN_KEY, rejected_key: str = DEFAULT_REJECTED_KEY): """ Here we assume each example has a rejected and chosen field, both of which are a list of messages. Each message is a dict with 'role' and 'content' fields. We assume only the last message is different, and the prompt is contained in the list of messages. """ - chosen_messages = row[CHOSEN_KEY] - rejected_messages = row[REJECTED_KEY] + chosen_messages = row[chosen_key] + rejected_messages = row[rejected_key] if len(chosen_messages) == 0: raise ValueError("chosen messages field is empty.") if len(rejected_messages) == 0: