Olmo rl #417

Closed · wants to merge 1 commit
10 changes: 5 additions & 5 deletions open_instruct/dataset_processor.py
@@ -174,7 +174,7 @@ class DatasetConfig:
 
     # filter config
     max_token_length: Optional[int] = None
-    max_prompt_token_lenth: Optional[int] = None
+    max_prompt_token_length: Optional[int] = None
 
     # dataset.map config
     sanity_check: bool = False
@@ -314,8 +314,8 @@ def tokenize_fn(row):
     def filter(self, dataset: Union[Dataset, DatasetDict]):
         def filter_fn(row):
             return (
-                len(row[INPUT_IDS_PROMPT_KEY]) <= self.config.max_prompt_token_lenth
-                if self.config.max_prompt_token_lenth is not None
+                len(row[INPUT_IDS_PROMPT_KEY]) <= self.config.max_prompt_token_length
+                if self.config.max_prompt_token_length is not None
                 else (
                     True and len(row[INPUT_IDS_CHOSEN_KEY]) <= self.config.max_token_length
                     if self.config.max_token_length is not None
@@ -388,8 +388,8 @@ def tokenize_fn(row):
     def filter(self, dataset: Dataset):
         def filter_fn(row):
             max_prompt_token_length_ok = True
-            if self.config.max_prompt_token_lenth is not None:
-                max_prompt_token_length_ok = len(row[INPUT_IDS_PROMPT_KEY]) <= self.config.max_prompt_token_lenth
+            if self.config.max_prompt_token_length is not None:
+                max_prompt_token_length_ok = len(row[INPUT_IDS_PROMPT_KEY]) <= self.config.max_prompt_token_length
 
             max_token_length_ok = True
             if self.config.max_token_length is not None:
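Note: the change above fixes a typo in the config field name (`max_prompt_token_lenth` -> `max_prompt_token_length`). A minimal standalone sketch of the prompt-length filter this field drives; the row contents, key value, and limit are made up for illustration:

    from typing import Optional

    INPUT_IDS_PROMPT_KEY = "input_ids_prompt"  # key name as used in open_instruct

    def prompt_length_ok(row: dict, max_prompt_token_length: Optional[int]) -> bool:
        # With no limit configured, every row passes the filter.
        if max_prompt_token_length is None:
            return True
        return len(row[INPUT_IDS_PROMPT_KEY]) <= max_prompt_token_length

    rows = [
        {INPUT_IDS_PROMPT_KEY: list(range(10))},   # 10 prompt tokens -> kept
        {INPUT_IDS_PROMPT_KEY: list(range(600))},  # 600 prompt tokens -> dropped at limit 512
    ]
    kept = [row for row in rows if prompt_length_ok(row, max_prompt_token_length=512)]
    print(len(kept))  # 1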
76 changes: 72 additions & 4 deletions open_instruct/model_utils.py
@@ -165,7 +165,7 @@ def get_reward(
 
     # Calculate position IDs for each token, considering the cumulative sum of the attention mask (to exclude padding)
     # Shape: (batch_size, sequence_length)
-    position_ids = attention_mask.cumsum(1) - attention_mask.long()  # exclusive cumsum
+    # position_ids = attention_mask.cumsum(1) - attention_mask.long()  # exclusive cumsum
 
     # Access the LM backbone from the reward model using its base model prefix
     lm_backbone = getattr(model, model.base_model_prefix)
@@ -176,7 +176,7 @@ def get_reward(
     output = lm_backbone(
         input_ids=input_ids,
         attention_mask=attention_mask,
-        position_ids=position_ids,
+        # position_ids=position_ids,
         return_dict=True,
         output_hidden_states=True,
         use_cache=False,  # otherwise mistral-based RM would error out
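The line being commented out computed exclusive cumulative-sum position IDs, so that left-padding never advances the position counter. A tiny standalone check of that arithmetic (values are illustrative):

    import torch

    attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # two padding tokens, then three real ones
    position_ids = attention_mask.cumsum(1) - attention_mask.long()
    print(position_ids)  # tensor([[0, 0, 0, 1, 2]]): padding stays at 0, real tokens count 0, 1, 2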
@@ -206,6 +206,74 @@ def get_reward(
    )


def get_reward_olmo(
    model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    This function computes reward scores for a batch of query responses based on a pre-trained reward model.

    Args:
        model (torch.nn.Module): The pre-trained reward model.
        query_responses (torch.Tensor): Tensor containing the tokenized responses for which to compute rewards.
            Shape: (batch_size, sequence_length)
        pad_token_id (int): The ID used for padding tokens in the tokenized sequences.
        context_length (int): The length of the prompt or context preceding the completions.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
            - reward_logits: The logits output from the model for all tokens in the sequences.
              Shape: (batch_size, sequence_length, 1)
            - final_scores: The final reward scores, one for each sequence, after adjusting for sequence lengths.
              Shape: (batch_size,)
            - sequence_lengths: The lengths of each sequence (excluding padding).
              Shape: (batch_size,)
    """

    # Create an attention mask where tokens that are not padding have a value of 1, and padding tokens have a value of 0
    # Shape: (batch_size, sequence_length)
    attention_mask = query_responses != pad_token_id

    # Unlike `get_reward`, no position IDs are passed: the OLMo backbone computes positions internally.
    # position_ids = attention_mask.cumsum(1) - attention_mask.long()  # exclusive cumsum

    # Access the LM backbone from the reward model using its base model prefix
    lm_backbone = getattr(model, model.base_model_prefix)

    # Replace padding tokens with zeros in the input IDs (so padding tokens won't affect the model's processing)
    # Shape: (batch_size, sequence_length)
    input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
    output = lm_backbone(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        use_cache=False,  # otherwise mistral-based RM would error out
    )
    reward_logits = model.score(output.hidden_states[-1])  # (batch_size, sequence_length, 1)

    # Calculate the length of each sequence by finding the first occurrence of a padding token after the context
    # sequence_lengths shape: (batch_size,)
    sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length
    assert (
        reward_logits.shape[-1] == 1
    ), "Reward model should output a single scalar per token. Check if you added `num_labels=1` when doing `AutoModelForSequenceClassification.from_pretrained(...)`."
    # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454

    # Return the reward logits for all tokens, the final reward scores for each sequence, and the sequence lengths
    return (
        # reward_logits shape: (batch_size, sequence_length, 1)
        reward_logits,
        # final_scores shape: (batch_size,)
        reward_logits[
            torch.arange(reward_logits.size(0), device=reward_logits.device),
            sequence_lengths,
        ].squeeze(-1),
        sequence_lengths,
    )


def forward(
    model: torch.nn.Module,
    query_responses: torch.Tensor,

@@ -225,12 +293,12 @@ def forward(
         The output of the model, including hidden states.
     """
     attention_mask = query_responses != pad_token_id
-    position_ids = attention_mask.cumsum(1) - attention_mask.long()
+    # position_ids = attention_mask.cumsum(1) - attention_mask.long()
     input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
     return model(
         input_ids=input_ids,
         attention_mask=attention_mask,
-        position_ids=position_ids,
+        # position_ids=position_ids,
         return_dict=True,
         output_hidden_states=True,
     )
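A hedged usage sketch for the new `get_reward_olmo` helper. The token IDs below are placeholders, and `reward_model` is assumed to be an `OLMoForSequenceClassification` instance (defined in the `modeling_olmo2.py` file added below) with `num_labels=1`:

    import torch

    pad_token_id = 0
    queries = torch.tensor([[5, 6, 7], [11, 12, 13]])        # prompts, both length 3
    responses = torch.tensor([[8, 9, 0, 0], [8, 9, 10, 0]])  # completions, right-padded
    query_responses = torch.cat([queries, responses], dim=1)

    reward_logits, final_scores, sequence_lengths = get_reward_olmo(
        model=reward_model,               # assumed: OLMoForSequenceClassification, constructed elsewhere
        query_responses=query_responses,
        pad_token_id=pad_token_id,
        context_length=queries.shape[1],  # completions start right after the prompt
    )
    print(final_scores.shape)  # torch.Size([2]): one scalar reward per sequence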
Empty file.
226 changes: 226 additions & 0 deletions open_instruct/olmo_adapter/modeling_olmo2.py
@@ -0,0 +1,226 @@
from typing import Callable, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutputWithPast

from hf_olmo import OLMoConfig
from hf_olmo.modeling_olmo import (
    ActivationCheckpointingStrategy,
    OLMo,
    create_model_config_from_pretrained_config,
)


class OLMoForSequenceClassification(PreTrainedModel):

    config_class = OLMoConfig
    base_model_prefix = "model"
    _no_split_modules = ["OLMoBlock"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    supports_gradient_checkpointing = True

    def __init__(self, config: OLMoConfig, model: Optional[OLMo] = None, init_params: bool = False):
        super().__init__(config)
        self._gradient_checkpointing_func: Optional[Callable] = None
        self._gradient_checkpointing = False

        self.num_labels = config.num_labels
        if not model:
            model_config = create_model_config_from_pretrained_config(config)
            # Initialize model (always on CPU to start with so we don't run out of GPU memory).
            model_config.init_device = "cpu"
            self.model = OLMo(model_config, init_params=init_params)
        else:
            self.model = model

        # Classification head mapping the final hidden state to num_labels scores per token.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)


    @property
    def gradient_checkpointing(self) -> bool:
        return self._gradient_checkpointing

    @gradient_checkpointing.setter
    def gradient_checkpointing(self, enabled: bool):
        if self._gradient_checkpointing == enabled:
            return

        # HF does not specify a way to pass checkpointing strategies, so we pick
        # whole layer as our strategy. We can make this configurable later if needed.
        checkpointing_strategy = ActivationCheckpointingStrategy.whole_layer if enabled else None
        self.model.set_activation_checkpointing(
            checkpointing_strategy, checkpoint_func=self._gradient_checkpointing_func
        )
        self._gradient_checkpointing = enabled

    def get_input_embeddings(self) -> torch.nn.Module:
        return self.model.transformer.wte

    def set_input_embeddings(self, value: torch.nn.Module):
        self.model.transformer.wte = value

    def get_output_embeddings(self):
        if self.config.weight_tying:
            return self.model.transformer.wte
        else:
            return self.model.transformer.ff_out

    def set_output_embeddings(self, value: torch.nn.Module):
        if self.config.weight_tying:
            self.model.transformer.wte = value
        else:
            self.model.transformer.ff_out = value

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> torch.nn.Embedding:
        """
        Resizes the input token embeddings matrix of the model if `new_num_tokens != config.embedding_size`.

        Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:
            new_num_tokens (`int`, *optional*):
                The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`,
                just returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing
                anything.
            pad_to_multiple_of (`int`, *optional*):
                If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is
                set to `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute
                capability `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple
                of 128. For more details about this, or help on choosing the correct value for resizing, refer to
                this guide:
                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc

        Return:
            `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.

        Note:
            This method differs from the base class implementation by resizing the `embedding_size` attribute of
            the model configuration instead of the `vocab_size`. It also includes a warning if the resized
            `embedding_size` is less than the `vocab_size`. In OLMo, `embedding_size` refers to the dimensionality
            of the model's token embeddings, while `vocab_size` refers to the number of unique tokens in the
            vocabulary.
        """
        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        if new_num_tokens is None and pad_to_multiple_of is None:
            return model_embeds

        # Update base model and current model config
        self.config.embedding_size = model_embeds.weight.shape[0]
        self.model.config.embedding_size = model_embeds.weight.shape[0]

        # Check if the embedding size is less than the vocab size
        if self.config.embedding_size < self.config.vocab_size:
            warning_message = (
                f"Resizing token embeddings to size {self.config.embedding_size}, which is less than the vocab size "
                f"{self.config.vocab_size} defined in the model configuration. Make sure your tokenizer's vocabulary "
                "size is less than or equal to the new token embedding size."
            )
            # log.warning(warning_message)
            print(warning_message)

        # Tie weights again if needed
        self.tie_weights()

        return model_embeds


    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        """
        Forward pass for sequence classification with OLMo.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            position_ids: Position IDs for positional encoding (accepted for API compatibility but not
                forwarded to the OLMo backbone, which computes positions internally)
            past_key_values: Past key values for incremental decoding
            inputs_embeds: Pre-computed input embeddings
            labels: Labels for computing loss
            use_cache: Whether to use cached key/values
            output_attentions: Whether to output attention weights
            output_hidden_states: Whether to output hidden states
            return_dict: Whether to return a ModelOutput object

        Returns:
            SequenceClassifierOutputWithPast or tuple: Classification outputs
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # Per-token logits; shape: (batch_size, sequence_length, num_labels)
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Index of the last non-padding token in each row: argmax finds the first pad
                # (or 0 if there is none), minus one steps back to the last real token, and the
                # modulo wraps -1 to the final position when a row has no padding at all.
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        # Pool one logit per sequence: the logit at the last non-padding position.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
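To tie the pieces together, a hedged sketch of how this wrapper might be used as a reward model. The checkpoint name is a placeholder and the inputs are random; loading and initialization details depend on `hf_olmo` and are not specified in this PR:

    import torch
    from hf_olmo import OLMoConfig

    # Hypothetical setup: an OLMo config with a single reward head (num_labels=1).
    config = OLMoConfig.from_pretrained("allenai/OLMo-1B")  # placeholder checkpoint name
    config.num_labels = 1
    reward_model = OLMoForSequenceClassification(config)  # randomly initialized backbone

    pad_token_id = config.pad_token_id or 1
    query_responses = torch.randint(2, config.vocab_size, (2, 16))
    query_responses[:, -3:] = pad_token_id  # simulate right padding

    with torch.no_grad():
        out = reward_model(
            input_ids=query_responses,
            attention_mask=query_responses != pad_token_id,
        )
    print(out.logits.shape)  # torch.Size([2, 1]): one pooled score per sequence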