Olmo2 and olmoe support #525

Merged 5 commits on Jan 23, 2025
Changes from 4 commits
1 change: 1 addition & 0 deletions configs/beaker_configs/ray_node_setup.sh
@@ -1,6 +1,7 @@
export CURRENT_DATETIME=$(python -c "import datetime; import pytz; print(datetime.datetime.now(pytz.timezone('America/Los_Angeles')).strftime('%m%d%y_%H%M%S'))")
export PYTHONPATH=$REPO_PATH
export PATH="/root/.local/bin:$PATH"
export NCCL_CUMEM_ENABLE=0


echo CURRENT_DATETIME=$CURRENT_DATETIME
226 changes: 226 additions & 0 deletions open_instruct/olmo_adapter/__init__.py
@@ -0,0 +1,226 @@
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers.models.olmo2.modeling_olmo2 import Olmo2Config, Olmo2Model, Olmo2PreTrainedModel
from transformers.models.olmoe.modeling_olmoe import OlmoeConfig, OlmoeModel, OlmoePreTrainedModel


class Olmo2ForSequenceClassification(Olmo2PreTrainedModel):
    def __init__(self, config: Olmo2Config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Olmo2Model(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        # Pool the logits at the last non-padding token of each sequence.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


class OlmoeForSequenceClassification(OlmoePreTrainedModel):
    def __init__(self, config: OlmoeConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = OlmoeModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        # Pool the logits at the last non-padding token of each sequence.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
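Both classifier heads above pool the score of the last non-padding token. For readers skimming the diff, here is a tiny self-contained illustration of that indexing trick; the tensor values and pad id below are made up for the example and are not part of the PR:

import torch

pad_token_id = 1  # hypothetical pad id for this example
input_ids = torch.tensor([
    [5, 7, 9, 1, 1],  # three real tokens followed by padding
    [4, 6, 8, 2, 3],  # no padding at all
])

# Position of the first pad token, minus one, is the last real token.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
# If a row contains no pad token, argmax returns 0 and the -1 above becomes -1,
# i.e. the last position via reverse indexing; the modulo keeps that behaviour
# while staying ONNX-exportable, exactly as in the adapter code.
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 4])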
7 changes: 5 additions & 2 deletions open_instruct/ppo_vllm_thread_ray_gtrl.py
@@ -634,6 +634,9 @@ def from_pretrained(
        self.model.train()

        # value model
        from open_instruct.olmo_adapter import Olmo2Config, Olmo2ForSequenceClassification, OlmoeConfig, OlmoeForSequenceClassification
        AutoModelForSequenceClassification.register(Olmo2Config, Olmo2ForSequenceClassification)
        AutoModelForSequenceClassification.register(OlmoeConfig, OlmoeForSequenceClassification)
        self.value_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(
            args.reward_model_path,
            revision=args.reward_model_revision,
@@ -670,7 +673,7 @@ def from_pretrained(
        # reference model
        ds_config = get_eval_ds_config(
            offload=False,
            stage=args.deepspeed_stage,
            stage=args.deepspeed_stage if args.deepspeed_stage == 3 else 0,
            bf16=True,
        )
        ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
@@ -709,7 +712,7 @@ def from_pretrained(
        disable_dropout_in_model(self.reward_model)
        ds_config = get_eval_ds_config(
            offload=False,
            stage=args.deepspeed_stage,
            stage=args.deepspeed_stage if args.deepspeed_stage == 3 else 0,
            bf16=True,
        )
        ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
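The two get_eval_ds_config changes above keep ZeRO partitioning for the frozen reference and reward models only when the run itself uses ZeRO-3, and otherwise fall back to stage 0. A minimal sketch of that selection, assuming get_eval_ds_config takes only the arguments shown in the diff:

# Sketch of the stage-selection logic added above, not a drop-in replacement.
# Inference-only models carry no optimizer state or gradients, so ZeRO-1/2
# sharding buys nothing; ZeRO-3 still matters because it shards the parameters.
eval_stage = args.deepspeed_stage if args.deepspeed_stage == 3 else 0
ds_config = get_eval_ds_config(offload=False, stage=eval_stage, bf16=True)
ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size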
6 changes: 5 additions & 1 deletion open_instruct/reward_modeling.py
@@ -195,6 +195,9 @@ def layer_init(layer: nn.Module, std: float):


def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
    from open_instruct.olmo_adapter import Olmo2Config, Olmo2ForSequenceClassification, OlmoeConfig, OlmoeForSequenceClassification
    AutoModelForSequenceClassification.register(Olmo2Config, Olmo2ForSequenceClassification)
    AutoModelForSequenceClassification.register(OlmoeConfig, OlmoeForSequenceClassification)
    accelerator = calculate_runtime_args_and_accelerator(args, model_config)
    local_seed = args.seed + accelerator.process_index

@@ -241,7 +244,8 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
        tokenizer.pad_token_id = 128002  # <|reserved_special_token_0|>
    else:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})  # NOTE: we do not resize the embedding
    tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template]
    if dataset_config.chat_template is not None:
        tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template]

    # create the dataset
    dataset_dict = DatasetDict()
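Taken together, the adapter module plus the register() calls added to both training scripts let AutoModelForSequenceClassification resolve OLMo-2 and OLMoE checkpoints. A minimal usage sketch; the checkpoint path and num_labels here are placeholders for illustration, not values from this PR:

from transformers import AutoModelForSequenceClassification

from open_instruct.olmo_adapter import (
    Olmo2Config,
    Olmo2ForSequenceClassification,
    OlmoeConfig,
    OlmoeForSequenceClassification,
)

# Same registration the PR adds to reward_modeling.py and ppo_vllm_thread_ray_gtrl.py.
AutoModelForSequenceClassification.register(Olmo2Config, Olmo2ForSequenceClassification)
AutoModelForSequenceClassification.register(OlmoeConfig, OlmoeForSequenceClassification)

# Placeholder path; num_labels=1 gives the scalar head used for reward/value models.
rm = AutoModelForSequenceClassification.from_pretrained("path/to/olmo2-reward-model", num_labels=1)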