Olmo2 and olmoe support (#525)
* Support olmo2 and olmoe

* push

* push changes

* add note
vwxyzjn authored Jan 23, 2025
1 parent 6485393 commit 1d17da6
Showing 4 changed files with 241 additions and 3 deletions.
1 change: 1 addition & 0 deletions configs/beaker_configs/ray_node_setup.sh
@@ -1,6 +1,7 @@
export CURRENT_DATETIME=$(python -c "import datetime; import pytz; print(datetime.datetime.now(pytz.timezone('America/Los_Angeles')).strftime('%m%d%y_%H%M%S'))")
export PYTHONPATH=$REPO_PATH
export PATH="/root/.local/bin:$PATH"
export NCCL_CUMEM_ENABLE=0


echo CURRENT_DATETIME=$CURRENT_DATETIME
226 changes: 226 additions & 0 deletions open_instruct/olmo_adapter/__init__.py
@@ -0,0 +1,226 @@
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers.models.olmo2.modeling_olmo2 import Olmo2Config, Olmo2Model, Olmo2PreTrainedModel
from transformers.models.olmoe.modeling_olmoe import OlmoeConfig, OlmoeModel, OlmoePreTrainedModel


class Olmo2ForSequenceClassification(Olmo2PreTrainedModel):
def __init__(self, config: Olmo2Config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = Olmo2Model(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.model.embed_tokens

def set_input_embeddings(self, value):
self.model.embed_tokens = value

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)

if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]

if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1

pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"

if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)


class OlmoeForSequenceClassification(OlmoePreTrainedModel):
def __init__(self, config: OlmoeConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.model = OlmoeModel(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.model.embed_tokens

def set_input_embeddings(self, value):
self.model.embed_tokens = value

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformer_outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)

if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]

if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1

pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"

if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
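
For context, a minimal usage sketch (not part of the commit) of how these adapter classes are meant to plug into the Hugging Face Auto class machinery, mirroring the registration calls added in the training scripts below. The checkpoint path and `num_labels` value are placeholders.

from transformers import AutoModelForSequenceClassification, AutoTokenizer

from open_instruct.olmo_adapter import (
    Olmo2Config,
    Olmo2ForSequenceClassification,
    OlmoeConfig,
    OlmoeForSequenceClassification,
)

# Register the adapters so AutoModelForSequenceClassification can resolve
# OLMo 2 / OLMoE configs to the classes defined above.
AutoModelForSequenceClassification.register(Olmo2Config, Olmo2ForSequenceClassification)
AutoModelForSequenceClassification.register(OlmoeConfig, OlmoeForSequenceClassification)

# "path/to/olmo2-checkpoint" is a placeholder; num_labels=1 yields a scalar reward/value head.
model = AutoModelForSequenceClassification.from_pretrained("path/to/olmo2-checkpoint", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("path/to/olmo2-checkpoint")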
11 changes: 9 additions & 2 deletions open_instruct/ppo_vllm_thread_ray_gtrl.py
@@ -634,6 +634,9 @@ def from_pretrained(
self.model.train()

# value model
from open_instruct.olmo_adapter import Olmo2Config, Olmo2ForSequenceClassification, OlmoeConfig, OlmoeForSequenceClassification
AutoModelForSequenceClassification.register(Olmo2Config, Olmo2ForSequenceClassification)
AutoModelForSequenceClassification.register(OlmoeConfig, OlmoeForSequenceClassification)
self.value_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(
args.reward_model_path,
revision=args.reward_model_revision,
@@ -670,7 +673,9 @@ def from_pretrained(
# reference model
ds_config = get_eval_ds_config(
offload=False,
stage=args.deepspeed_stage,
# inference model only has stage 3 (sharding) or stage 0 (no sharding)
# stage 2 is optimizer sharding which doesn't apply to inference
stage=args.deepspeed_stage if args.deepspeed_stage == 3 else 0,
bf16=True,
)
ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
@@ -709,7 +714,9 @@ def from_pretrained(
disable_dropout_in_model(self.reward_model)
ds_config = get_eval_ds_config(
offload=False,
stage=args.deepspeed_stage,
# inference model only has stage 3 (sharding) or stage 0 (no sharding)
# stage 2 is optimizer sharding which doesn't apply to inference
stage=args.deepspeed_stage if args.deepspeed_stage == 3 else 0,
bf16=True,
)
ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
6 changes: 5 additions & 1 deletion open_instruct/reward_modeling.py
@@ -195,6 +195,9 @@ def layer_init(layer: nn.Module, std: float):


def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
from open_instruct.olmo_adapter import Olmo2Config, Olmo2ForSequenceClassification, OlmoeConfig, OlmoeForSequenceClassification
AutoModelForSequenceClassification.register(Olmo2Config, Olmo2ForSequenceClassification)
AutoModelForSequenceClassification.register(OlmoeConfig, OlmoeForSequenceClassification)
accelerator = calculate_runtime_args_and_accelerator(args, model_config)
local_seed = args.seed + accelerator.process_index

@@ -241,7 +244,8 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
tokenizer.pad_token_id = 128002 # <|reserved_special_token_0|>
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # NOTE: we do not resize the embedding
tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template]
if dataset_config.chat_template is not None:
tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template]

# create the dataset
dataset_dict = DatasetDict()
