HF integration #291

Merged
merged 14 commits on Jul 17, 2024
2 changes: 1 addition & 1 deletion open_lm/attention.py
@@ -111,7 +111,7 @@ def torch_attn(queries, keys, values, is_causal, attention_mask=None):
    if attention_mask is None:
        bias = None
        # If we only have one query, assume we don't need to be in causal mode (can attend to all keys).
-       if queries.shape == 1:
+       if queries.shape[1] == 1:
            is_causal = False
    else:
        if not is_causal:
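The one-line fix above matters for incremental decoding: `queries.shape` is a `torch.Size` tuple, so the old comparison `queries.shape == 1` was always False and single-token decode steps stayed in causal mode. The new check inspects the sequence-length dimension instead. A minimal sketch of the difference, assuming the (batch, seq_len, heads, head_dim) layout used here (illustrative only, not part of the PR):

import torch

queries = torch.randn(2, 1, 8, 64)  # a single query token: (batch, seq_len=1, heads, head_dim)
print(queries.shape == 1)     # False: comparing a torch.Size to an int never matches
print(queries.shape[1] == 1)  # True: one query, so the causal mask can be dropped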
3 changes: 3 additions & 0 deletions open_lm/hf/__init__.py
@@ -0,0 +1,3 @@
from .configuration_openlm import OpenLMConfig
from .modeling_openlm import OpenLMForCausalLM
from .tokenization_openlm import OpenLMTokenizerFast
24 changes: 24 additions & 0 deletions open_lm/hf/configuration_openlm.py
@@ -0,0 +1,24 @@
# Follows OLMo's HF template

"""
OpenLM configuration
"""

from transformers import AutoConfig, PretrainedConfig
from transformers.utils import logging

from open_lm.model import Params

logger = logging.get_logger(__name__)


class OpenLMConfig(PretrainedConfig):
    model_type = "openlm"

    def __init__(self, **kwargs):
        kwargs["architectures"] = ["OpenLMForCausalLM"]
        super().__init__(**kwargs)


# Register the config class so that it is available for transformer pipelines, auto-loading etc.
AutoConfig.register("openlm", OpenLMConfig)
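Registering the config class keys the "openlm" model type to `OpenLMConfig`, so any checkpoint whose `config.json` declares that model type resolves to this class through the usual auto machinery. A small usage sketch; the checkpoint path is a placeholder:

from transformers import AutoConfig
from open_lm.hf import OpenLMConfig  # importing the package runs the AutoConfig.register call above

# Placeholder directory: must contain a config.json with "model_type": "openlm".
config = AutoConfig.from_pretrained("path/to/openlm-checkpoint")
assert isinstance(config, OpenLMConfig)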
194 changes: 194 additions & 0 deletions open_lm/hf/modeling_openlm.py
@@ -0,0 +1,194 @@
# Follows OLMo's HF template

import logging
from dataclasses import fields
from typing import List, Optional, Tuple, Union

import torch
from transformers import PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto import AutoModelForCausalLM

from open_lm.model import Params, Transformer
from open_lm.norms import get_norm_class
from open_lm.attention import get_attn_func

from .configuration_openlm import OpenLMConfig

log = logging.getLogger(__name__)


def create_model_config_from_pretrained_config(config: OpenLMConfig):
    """
    Utility function: build an open_lm `Params` object from the fields stored on an `OpenLMConfig`.
    """

    kwargs = {}
    for field in fields(Params):
        if hasattr(config, field.name):
            kwargs[field.name] = getattr(config, field.name)

    model_config = Params(**kwargs)

    if hasattr(config, "norm_type"):
        model_config.norm_type = get_norm_class(config.norm_type)

    if hasattr(config, "attn_name"):
        model_config.attn_func = get_attn_func(config.attn_name)

    return model_config


class OpenLMForCausalLM(PreTrainedModel):
    """
    Extremely barebones HF model wrapper.
    """

    config_class = OpenLMConfig
    base_model_prefix = "model"

    def __init__(self, config: OpenLMConfig, model: Optional[Transformer] = None):
        super().__init__(config)

        if not model:
            self.model_config = create_model_config_from_pretrained_config(config)
            # Initialize model (always on CPU to start with so we don't run out of GPU memory).
            self.model_config.init_device = "cpu"
            self.model = Transformer(self.model_config)

        else:
            self.model = model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_bias: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[
            Cache
        ] = None,  # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if inputs_embeds is not None:
            log.warning("inputs_embeds is set but OpenLM does not support it yet")
        if attention_bias is not None:
            log.warning("attention_bias is set but OpenLM does not support it yet")
        if use_cache is None:
            use_cache = True
        if output_attentions:
            raise ValueError("output_attentions is not yet supported in OpenLM")
        if output_hidden_states:
            raise ValueError("output_hidden_states is not yet supported in OpenLM")

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        # print("outer past_key_values: ", type(past_key_values))
        # if past_key_values is not None:
        #     print(len(past_key_values), type(past_key_values[0]))
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        logits = outputs[0]
        past_key_values = outputs[2]
        hidden_states = None

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.model_config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=hidden_states,
        )

    def can_generate(self) -> bool:
        return True

    def prepare_inputs_for_generation(
        self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
    ):
        if past_key_values is not None:
            if isinstance(past_key_values[0][1], int):
                # This assumes that the second item of past key values is the length of the past (this is the case for linear attention).
                past_length = past_key_values[0][1]
            else:
                # This assumes that the first item of past key values is the cached past-keys tensor, so dimension 1
                # of its shape is the length of the past (this is the case for attention without a window).
                past_length = past_key_values[0][0].shape[1]

            # Some generation methods already pass only the last input ID.
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only the final ID.
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        model_inputs = {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.pop("use_cache", True),
        }
        return model_inputs

    def get_input_embeddings(self) -> torch.nn.Module:
        return self.model.tok_embeddings

    def set_input_embeddings(self, value: torch.nn.Module):
        self.model.tok_embeddings = value

    def get_output_embeddings(self):
        if self.model_config.weight_tying:
            return self.model.tok_embeddings
        else:
            return self.model.output

    def set_output_embeddings(self, value: torch.nn.Module):
        if self.model_config.weight_tying:
            self.model.tok_embeddings = value
        else:
            self.model.output = value

    def tie_weights(self):
        """
        Copied from OLMo (description below). Removing this method caused generation results to become garbage, so this no-op is needed.
        This function is intentionally left as a no-op.
        Weight tying is handled as follows:
        - When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying` configuration.
          See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`.
        - When computing logits, the `wte` weights are used directly if `weight_tying` is enabled.
          See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the `forward` method.
        Therefore, there is no need to explicitly tie the weights in this function.
        """
        pass

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> torch.nn.Embedding:
        raise NotImplementedError


# Register the model so that it is available for transformer pipelines, auto-loading, etc.
AutoModelForCausalLM.register(OpenLMConfig, OpenLMForCausalLM)
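With the config and the model both registered, the wrapper plugs into the standard `AutoModelForCausalLM` entry point and the `generate()` loop, which relies on the `can_generate()` and `prepare_inputs_for_generation()` methods above. A rough end-to-end sketch, assuming a local checkpoint directory saved in this format (the path and prompt are placeholders):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import open_lm.hf  # noqa: F401  (runs the Auto* registrations for the OpenLM classes)

# Placeholder path: a directory with an OpenLM config.json, model weights, and tokenizer files.
checkpoint = "path/to/openlm-checkpoint"
model = AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    # use_cache=True exercises the past_key_values handling in prepare_inputs_for_generation.
    output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))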
18 changes: 18 additions & 0 deletions open_lm/hf/tokenization_openlm.py
@@ -0,0 +1,18 @@
# Follows OLMo's HF template

from transformers import AutoTokenizer, PreTrainedTokenizerFast

from open_lm.hf.configuration_openlm import OpenLMConfig


class OpenLMTokenizerFast(PreTrainedTokenizerFast):
    # Note: OpenLM's tokenizer is already a wrapper around huggingface. This is potentially unnecessary.
    pass

    # def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    #     # This is required to make the implementation complete.
    #     pass


# Register the tokenizer class so that it is available for transformer pipelines, auto-loading etc.
AutoTokenizer.register(OpenLMConfig, fast_tokenizer_class=OpenLMTokenizerFast)
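Because the tokenizer is registered against `OpenLMConfig`, `AutoTokenizer` maps OpenLM checkpoints to `OpenLMTokenizerFast`, which behaves like any other fast tokenizer since it only subclasses `PreTrainedTokenizerFast`. A brief sketch with a placeholder path:

from transformers import AutoTokenizer
from open_lm.hf import OpenLMTokenizerFast  # importing the package runs the AutoTokenizer.register call above

# Placeholder directory: an OpenLM config.json plus the usual tokenizer.json files.
tokenizer = AutoTokenizer.from_pretrained("path/to/openlm-checkpoint")
print(isinstance(tokenizer, OpenLMTokenizerFast))  # expected True, unless tokenizer_config.json overrides the class
print(tokenizer("hello world").input_ids)          # normal fast-tokenizer encoding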