From eab1f79221b1f41d63461a901aea35e8c51bb843 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Wed, 22 Jan 2025 17:04:27 -0800
Subject: [PATCH 1/4] Support olmo2 and olmoe

---
 open_instruct/ppo_vllm_thread_ray_gtrl.py | 7 +++++--
 open_instruct/reward_modeling.py          | 3 ++-
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl.py b/open_instruct/ppo_vllm_thread_ray_gtrl.py
index 78979ce99..a747f5784 100644
--- a/open_instruct/ppo_vllm_thread_ray_gtrl.py
+++ b/open_instruct/ppo_vllm_thread_ray_gtrl.py
@@ -632,6 +632,9 @@ def from_pretrained(
         self.model.train()
 
         # value model
+        from open_instruct.olmo_adapter import Olmo2Config, Olmo2ForSequenceClassification, OlmoeConfig, OlmoeForSequenceClassification
+        AutoModelForSequenceClassification.register(Olmo2Config, Olmo2ForSequenceClassification)
+        AutoModelForSequenceClassification.register(OlmoeConfig, OlmoeForSequenceClassification)
         self.value_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(
             args.reward_model_path,
             revision=args.reward_model_revision,
@@ -668,7 +671,7 @@ def from_pretrained(
         # reference model
         ds_config = get_eval_ds_config(
             offload=False,
-            stage=args.deepspeed_stage,
+            stage=args.deepspeed_stage if args.deepspeed_stage == 3 else 0,
             bf16=True,
         )
         ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
@@ -707,7 +710,7 @@ def from_pretrained(
         disable_dropout_in_model(self.reward_model)
         ds_config = get_eval_ds_config(
             offload=False,
-            stage=args.deepspeed_stage,
+            stage=args.deepspeed_stage if args.deepspeed_stage == 3 else 0,
             bf16=True,
         )
         ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
diff --git a/open_instruct/reward_modeling.py b/open_instruct/reward_modeling.py
index c59140840..dc32222be 100644
--- a/open_instruct/reward_modeling.py
+++ b/open_instruct/reward_modeling.py
@@ -241,7 +241,8 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
         tokenizer.pad_token_id = 128002  # <|reserved_special_token_0|>
     else:
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})  # NOTE: we do not resize the embedding
-    tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template]
+    if dataset_config.chat_template is not None:
+        tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template]
 
     # create the dataset
     dataset_dict = DatasetDict()

From 226aedbaf31d6448794f015e623824dd8b1be996 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Wed, 22 Jan 2025 17:04:45 -0800
Subject: [PATCH 2/4] push

---
 open_instruct/reward_modeling.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/open_instruct/reward_modeling.py b/open_instruct/reward_modeling.py
index dc32222be..dd63d1644 100644
--- a/open_instruct/reward_modeling.py
+++ b/open_instruct/reward_modeling.py
@@ -195,6 +195,9 @@ def layer_init(layer: nn.Module, std: float):
 
 
 def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
+    from open_instruct.olmo_adapter import Olmo2Config, Olmo2ForSequenceClassification, OlmoeConfig, OlmoeForSequenceClassification
+    AutoModelForSequenceClassification.register(Olmo2Config, Olmo2ForSequenceClassification)
+    AutoModelForSequenceClassification.register(OlmoeConfig, OlmoeForSequenceClassification)
     accelerator = calculate_runtime_args_and_accelerator(args, model_config)
     local_seed = args.seed + accelerator.process_index
 
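The two `register` calls above are what make the adapter classes discoverable through the Auto API. A minimal standalone sketch of how a caller would rely on them (the checkpoint path below is a placeholder, not a real model):

    from transformers import AutoModelForSequenceClassification

    from open_instruct.olmo_adapter import (
        Olmo2Config,
        Olmo2ForSequenceClassification,
        OlmoeConfig,
        OlmoeForSequenceClassification,
    )

    # Map each OLMo config class to its sequence-classification head so the
    # Auto class can resolve it, exactly as the hunks above do at runtime.
    AutoModelForSequenceClassification.register(Olmo2Config, Olmo2ForSequenceClassification)
    AutoModelForSequenceClassification.register(OlmoeConfig, OlmoeForSequenceClassification)

    # Placeholder path for illustration; any OLMo 2 / OLMoE checkpoint would be
    # loaded as a single-logit reward/value head through the same call.
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        "path/to/olmo2-reward-model", num_labels=1
    )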
From e4aac3ba01300ccda36549a66925999f6ccdd657 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Thu, 23 Jan 2025 15:17:54 -0800
Subject: [PATCH 3/4] push changes

---
 configs/beaker_configs/ray_node_setup.sh |   1 +
 open_instruct/olmo_adapter/__init__.py   | 226 +++++++++++++++++++++++
 2 files changed, 227 insertions(+)
 create mode 100644 open_instruct/olmo_adapter/__init__.py

diff --git a/configs/beaker_configs/ray_node_setup.sh b/configs/beaker_configs/ray_node_setup.sh
index fe075749c..d6375f32c 100755
--- a/configs/beaker_configs/ray_node_setup.sh
+++ b/configs/beaker_configs/ray_node_setup.sh
@@ -1,6 +1,7 @@
 export CURRENT_DATETIME=$(python -c "import datetime; import pytz; print(datetime.datetime.now(pytz.timezone('America/Los_Angeles')).strftime('%m%d%y_%H%M%S'))")
 export PYTHONPATH=$REPO_PATH
 export PATH="/root/.local/bin:$PATH"
+export NCCL_CUMEM_ENABLE=0
 
 echo CURRENT_DATETIME=$CURRENT_DATETIME
 
diff --git a/open_instruct/olmo_adapter/__init__.py b/open_instruct/olmo_adapter/__init__.py
new file mode 100644
index 000000000..f39796b11
--- /dev/null
+++ b/open_instruct/olmo_adapter/__init__.py
@@ -0,0 +1,226 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss
+from transformers.modeling_outputs import SequenceClassifierOutputWithPast
+from transformers.models.olmo2.modeling_olmo2 import Olmo2Config, Olmo2Model, Olmo2PreTrainedModel
+from transformers.models.olmoe.modeling_olmoe import OlmoeConfig, OlmoeModel, OlmoePreTrainedModel
+
+
+class Olmo2ForSequenceClassification(Olmo2PreTrainedModel):
+    def __init__(self, config: Olmo2Config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = Olmo2Model(config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[List[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, 'SequenceClassifierOutputWithPast']:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +class OlmoeForSequenceClassification(OlmoePreTrainedModel): + def __init__(self, config: Olmo2Config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = OlmoeModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = 
+        output_router_logits: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            output_router_logits=output_router_logits,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+        hidden_states = transformer_outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]
+        else:
+            batch_size = inputs_embeds.shape[0]
+
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
+            else:
+                sequence_lengths = -1
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
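Both classification heads pool the hidden state of the last non-padding token. A quick sanity check of that indexing trick (a standalone sketch with made-up tensor values, not part of the patch):

    import torch

    pad_token_id = 0  # made-up pad id for illustration
    input_ids = torch.tensor([[5, 8, 3, 0, 0],
                              [7, 2, 9, 4, 6]])

    # Same computation as in forward(): index of the last non-pad token per row,
    # with the modulo keeping a fully un-padded row in bounds.
    sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
    sequence_lengths = sequence_lengths % input_ids.shape[-1]
    print(sequence_lengths)  # tensor([2, 4])
    # pooled_logits would then be logits[torch.arange(batch_size), sequence_lengths]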
From 86663077c35b29f9235a19bf2af194198dccfc15 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Thu, 23 Jan 2025 15:33:49 -0800
Subject: [PATCH 4/4] add note

---
 open_instruct/ppo_vllm_thread_ray_gtrl.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl.py b/open_instruct/ppo_vllm_thread_ray_gtrl.py
index 79d61cc4c..1c9c3e991 100644
--- a/open_instruct/ppo_vllm_thread_ray_gtrl.py
+++ b/open_instruct/ppo_vllm_thread_ray_gtrl.py
@@ -673,6 +673,8 @@ def from_pretrained(
         # reference model
         ds_config = get_eval_ds_config(
             offload=False,
+            # inference model only has stage 3 (sharding) or stage 0 (no sharding)
+            # stage 2 is optimizer sharding which doesn't apply to inference
             stage=args.deepspeed_stage if args.deepspeed_stage == 3 else 0,
             bf16=True,
         )
@@ -712,6 +714,8 @@ def from_pretrained(
         disable_dropout_in_model(self.reward_model)
         ds_config = get_eval_ds_config(
             offload=False,
+            # inference model only has stage 3 (sharding) or stage 0 (no sharding)
+            # stage 2 is optimizer sharding which doesn't apply to inference
             stage=args.deepspeed_stage if args.deepspeed_stage == 3 else 0,
             bf16=True,
         )
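The stage-selection rule those comments describe boils down to one expression. The helper below is hypothetical and only restates the expression used in both hunks:

    # Hypothetical helper mirroring the expression above: inference-time models
    # (reference and reward) hold no optimizer state, so ZeRO stage 2
    # (optimizer-state sharding) buys nothing; either keep full parameter
    # sharding (stage 3) or disable sharding entirely (stage 0).
    def eval_stage(deepspeed_stage: int) -> int:
        return deepspeed_stage if deepspeed_stage == 3 else 0

    assert eval_stage(3) == 3  # parameter sharding carries over to inference
    assert eval_stage(2) == 0  # optimizer sharding is irrelevant at inference
    assert eval_stage(1) == 0
    assert eval_stage(0) == 0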