* Support olmo2 and olmoe
* push
* push changes
* add note
Showing 4 changed files with 241 additions and 3 deletions.
@@ -0,0 +1,226 @@
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers.models.olmo2.modeling_olmo2 import Olmo2Config, Olmo2Model, Olmo2PreTrainedModel
from transformers.models.olmoe.modeling_olmoe import OlmoeConfig, OlmoeModel, OlmoePreTrainedModel

class Olmo2ForSequenceClassification(Olmo2PreTrainedModel):
    def __init__(self, config: Olmo2Config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Olmo2Model(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in
            `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed
            (Mean-Square loss). If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

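        # Pool by taking the logits at each sequence's last non-padding token
        # (fall back to the final position when no pad token is configured).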
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
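            # Infer the problem type from num_labels and the label dtype when it is not set explicitly.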
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

class OlmoeForSequenceClassification(OlmoePreTrainedModel):
    def __init__(self, config: OlmoeConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = OlmoeModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in
            `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed
            (Mean-Square loss). If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

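        # Pool by taking the logits at each sequence's last non-padding token
        # (fall back to the final position when no pad token is configured).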
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
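            # Infer the problem type from num_labels and the label dtype when it is not set explicitly.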
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
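
A minimal usage sketch for the new classification heads follows. The tiny Olmo2Config values and the random-weight instantiation are illustrative assumptions, not part of this commit; with a real OLMo 2 checkpoint the weights would instead be loaded via from_pretrained.

    # Sketch only: assumed toy config values, randomly initialized weights.
    import torch
    from transformers.models.olmo2.modeling_olmo2 import Olmo2Config

    config = Olmo2Config(
        vocab_size=128,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        num_labels=3,
        pad_token_id=0,
    )
    model = Olmo2ForSequenceClassification(config)

    input_ids = torch.tensor([[5, 17, 42, 0, 0], [9, 8, 7, 6, 3]])  # 0 is the pad token here
    attention_mask = (input_ids != config.pad_token_id).long()
    labels = torch.tensor([1, 2])

    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    print(outputs.logits.shape)  # (batch_size, num_labels) -> torch.Size([2, 3])
    print(outputs.loss)          # cross-entropy loss, since num_labels > 1 and labels are ints

OlmoeForSequenceClassification follows the same pattern with an OlmoeConfig in place of the Olmo2Config.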