Skip to content

Commit

Permalink
Allow using tokenizer chat template (#513)
Browse files Browse the repository at this point in the history
* Allow using tokenizer chat template

* add suggestion
  • Loading branch information
hamishivi authored Jan 14, 2025
1 parent bcb991d commit ec76d40
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
6 changes: 3 additions & 3 deletions open_instruct/dataset_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@
@dataclass
class DatasetConfig:
# dataset specs
chat_template: str = "simple_chat"
chat_template: Optional[str] = None

# columns names for preference dataset
preference_chosen_key: str = "chosen"
Expand Down Expand Up @@ -206,8 +206,8 @@ def __post_init__(self):
self.num_proc = int(float(os.environ.get("BEAKER_ASSIGNED_CPU_COUNT", multiprocessing.cpu_count())))
self.load_from_cache_file = True

if self.chat_template not in CHAT_TEMPLATES:
raise ValueError(f"chat_template must be one of {list(CHAT_TEMPLATES.keys())}")
if self.chat_template is not None and self.chat_template not in CHAT_TEMPLATES:
raise ValueError(f"chat_template must None or one of {list(CHAT_TEMPLATES.keys())}")


def get_num_proc(dataset_len: int, num_available_cpus: int, example_per_second_per_cpu) -> int:
Expand Down
3 changes: 2 additions & 1 deletion open_instruct/ppo_vllm_thread_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1531,7 +1531,8 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
tokenizer.pad_token_id = 128002 # <|reserved_special_token_0|>
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # NOTE: we do not resize the embedding
tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template]
if dataset_config.chat_template is not None:
tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template]

# create the dataset
dataset_dict = DatasetDict()
Expand Down
3 changes: 2 additions & 1 deletion open_instruct/ppo_vllm_thread_ray_gtrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,7 +1606,8 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
tokenizer.pad_token_id = 128002 # <|reserved_special_token_0|>
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # NOTE: we do not resize the embedding
tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template]
if dataset_config.chat_template is not None:
tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template]

# create the dataset
dataset_dict = DatasetDict()
Expand Down

0 comments on commit ec76d40

Please sign in to comment.