From 2253664c6538a8c9373ab648e45167c94c4e20e6 Mon Sep 17 00:00:00 2001 From: sedrick-keh-tri Date: Tue, 16 Jul 2024 07:15:19 +0000 Subject: [PATCH 01/13] add hf --- open_lm/hf/__init__.py | 3 + open_lm/hf/configuration_openlm.py | 24 ++++ open_lm/hf/modeling_openlm.py | 190 +++++++++++++++++++++++++++++ open_lm/hf/tokenization_openlm.py | 18 +++ 4 files changed, 235 insertions(+) create mode 100644 open_lm/hf/__init__.py create mode 100644 open_lm/hf/configuration_openlm.py create mode 100644 open_lm/hf/modeling_openlm.py create mode 100644 open_lm/hf/tokenization_openlm.py diff --git a/open_lm/hf/__init__.py b/open_lm/hf/__init__.py new file mode 100644 index 00000000..b33c5506 --- /dev/null +++ b/open_lm/hf/__init__.py @@ -0,0 +1,3 @@ +from .configuration_openlm import OpenLMConfig +from .modeling_openlm import OpenLMForCausalLM +from .tokenization_openlm import OpenLMTokenizerFast \ No newline at end of file diff --git a/open_lm/hf/configuration_openlm.py b/open_lm/hf/configuration_openlm.py new file mode 100644 index 00000000..2300fa79 --- /dev/null +++ b/open_lm/hf/configuration_openlm.py @@ -0,0 +1,24 @@ +# Follows OLMo's HF template + +""" +OpenLM configuration +""" + +from transformers import AutoConfig, PretrainedConfig +from transformers.utils import logging + +from open_lm.model import Params + +logger = logging.get_logger(__name__) + + +class OpenLMConfig(PretrainedConfig): + model_type = "openlm" + + def __init__(self, **kwargs): + kwargs["architectures"] = ["OpenLMForCausalLM"] + super().__init__(**kwargs) + + +# Register the config class so that it is available for transformer pipelines, auto-loading etc. +AutoConfig.register("openlm", OpenLMConfig) \ No newline at end of file diff --git a/open_lm/hf/modeling_openlm.py b/open_lm/hf/modeling_openlm.py new file mode 100644 index 00000000..548b9637 --- /dev/null +++ b/open_lm/hf/modeling_openlm.py @@ -0,0 +1,190 @@ +# Follows OLMo's HF template + +import logging +from dataclasses import fields +from typing import List, Optional, Tuple, Union + +import torch +from transformers import PreTrainedModel +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.auto import AutoModelForCausalLM + +from open_lm.model import Params, Transformer +from open_lm.norms import get_norm_class + +from .configuration_openlm import OpenLMConfig + +log = logging.getLogger(__name__) + + +def create_model_config_from_pretrained_config(config: OpenLMConfig): + """ + Utility function + """ + + kwargs = {} + for field in fields(Params): + if hasattr(config, field.name): + kwargs[field.name] = getattr(config, field.name) + + model_config = Params(**kwargs) + + if hasattr(config, "norm_type"): + model_config.norm_type = get_norm_class(config.norm_type) + + return model_config + + +class OpenLMForCausalLM(PreTrainedModel): + """ + Extremely barebones HF model wrapper. + """ + + config_class = OpenLMConfig + base_model_prefix = "model" + + def __init__(self, config: OpenLMConfig, model: Optional[Transformer] = None): + super().__init__(config) + + if not model: + self.model_config = create_model_config_from_pretrained_config(config) + # Initialize model (always on CPU to start with so we don't run out of GPU memory). 
+ self.model_config.init_device = "cpu" + self.model = Transformer(self.model_config) + + else: + self.model = model + + def forward( + self, + input_ids: torch.LongTensor = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[ + Cache + ] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426 + ) -> Union[Tuple, CausalLMOutputWithPast]: + + if inputs_embeds is not None: + log.warning("inputs_embeds is set but OpenLM does not support it yet") + if attention_bias is not None: + log.warning("attention_bias is et but OpenLM does not support it yet") + if use_cache is None: + use_cache = True + if output_attentions: + raise ValueError("output_attentions is not yet supported in OpenLM") + if output_hidden_states: + raise ValueError("output_hidden_states is not yet supported in OpenLM") + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + # print("outer past_key_values: ", type(past_key_values)) + # if past_key_values is not None: + # print(len(past_key_values), type(past_key_values[0])) + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + ) + + logits = outputs[0] + past_key_values = outputs[2] + hidden_states = None + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = torch.nn.CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.model_config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=past_key_values, + hidden_states=hidden_states, + ) + + def can_generate(self) -> bool: + return True + + def prepare_inputs_for_generation( + self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs + ): + if past_key_values is not None: + if isinstance(past_key_values[0][1], int): + # This assumes that the second item of past key values is the length of the past (this is the case for linear attention) + past_length = past_key_values[0][1] + else: + # This assumes that the first item of past key values is a list of all the past keys, thus the shape 1 is the length of the past (this is the case for attention without window) + past_length = past_key_values[0][0].shape[1] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + model_inputs = { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.pop("use_cache", True), + } + return model_inputs + + def get_input_embeddings(self) -> torch.nn.Module: + return self.model.tok_embeddings + 
+ def set_input_embeddings(self, value: torch.nn.Module): + self.model.tok_embeddings = value + + def get_output_embeddings(self): + if self.model_config.weight_tying: + return self.model.tok_embeddings + else: + return self.model.output + + def set_output_embeddings(self, value: torch.nn.Module): + if self.model_config.weight_tying: + self.model.tok_embeddings = value + else: + self.model.output = value + + def tie_weights(self): + """ + Copied from OLMo (description below). I removed it and the results just became garbage, so this pass is needed. + This function is intentionally left as a no-op. + Weight tying is handled as follows: + - When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying` configuration. + See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`. + - When computing logits, the `wte` weights are used directly if `weight_tying` is enabled. + See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the `forward` method. + Therefore, there is no need to explicitly tie the weights in this function. + """ + pass + + def resize_token_embeddings( + self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None + ) -> torch.nn.Embedding: + raise NotImplementedError + + +# Register the model so that it is available for transformer pipelines, auto-loading, etc. +AutoModelForCausalLM.register(OpenLMConfig, OpenLMForCausalLM) \ No newline at end of file diff --git a/open_lm/hf/tokenization_openlm.py b/open_lm/hf/tokenization_openlm.py new file mode 100644 index 00000000..ae6bab2d --- /dev/null +++ b/open_lm/hf/tokenization_openlm.py @@ -0,0 +1,18 @@ +# Follows OLMo's HF template + +from transformers import AutoTokenizer, PreTrainedTokenizerFast + +from open_lm.open_lm_hf.configuration_openlm import OpenLMConfig + + +class OpenLMTokenizerFast(PreTrainedTokenizerFast): + # Note: OpenLM's tokenizer is already a wrapper around huggingface. This is potentially unnecessary. + pass + + # def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + # # This is required to make the implementation complete. + # pass + + +# Register the tokenizer class so that it is available for transformer pipelines, auto-loading etc. 
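+# Registration below is keyed on the config class, so AutoTokenizer.from_pretrained can
+# resolve this wrapper for checkpoints whose config loads as an OpenLMConfig.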
+AutoTokenizer.register(OpenLMConfig, fast_tokenizer_class=OpenLMTokenizerFast) \ No newline at end of file From f77205207b8f27817648ca58ef014da941bb8fc9 Mon Sep 17 00:00:00 2001 From: sedrick-keh-tri Date: Tue, 16 Jul 2024 07:18:27 +0000 Subject: [PATCH 02/13] import fix --- open_lm/hf/tokenization_openlm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_lm/hf/tokenization_openlm.py b/open_lm/hf/tokenization_openlm.py index ae6bab2d..d31fef8d 100644 --- a/open_lm/hf/tokenization_openlm.py +++ b/open_lm/hf/tokenization_openlm.py @@ -2,7 +2,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizerFast -from open_lm.open_lm_hf.configuration_openlm import OpenLMConfig +from open_lm.hf.configuration_openlm import OpenLMConfig class OpenLMTokenizerFast(PreTrainedTokenizerFast): From 410bb9f584b3c4055a2533ccccd87a578d4ca20e Mon Sep 17 00:00:00 2001 From: sedrick-keh-tri Date: Tue, 16 Jul 2024 07:31:31 +0000 Subject: [PATCH 03/13] black reformat --- open_lm/hf/__init__.py | 2 +- open_lm/hf/configuration_openlm.py | 2 +- open_lm/hf/modeling_openlm.py | 2 +- open_lm/hf/tokenization_openlm.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/open_lm/hf/__init__.py b/open_lm/hf/__init__.py index b33c5506..84931689 100644 --- a/open_lm/hf/__init__.py +++ b/open_lm/hf/__init__.py @@ -1,3 +1,3 @@ from .configuration_openlm import OpenLMConfig from .modeling_openlm import OpenLMForCausalLM -from .tokenization_openlm import OpenLMTokenizerFast \ No newline at end of file +from .tokenization_openlm import OpenLMTokenizerFast diff --git a/open_lm/hf/configuration_openlm.py b/open_lm/hf/configuration_openlm.py index 2300fa79..75663962 100644 --- a/open_lm/hf/configuration_openlm.py +++ b/open_lm/hf/configuration_openlm.py @@ -21,4 +21,4 @@ def __init__(self, **kwargs): # Register the config class so that it is available for transformer pipelines, auto-loading etc. -AutoConfig.register("openlm", OpenLMConfig) \ No newline at end of file +AutoConfig.register("openlm", OpenLMConfig) diff --git a/open_lm/hf/modeling_openlm.py b/open_lm/hf/modeling_openlm.py index 548b9637..24ffafcc 100644 --- a/open_lm/hf/modeling_openlm.py +++ b/open_lm/hf/modeling_openlm.py @@ -187,4 +187,4 @@ def resize_token_embeddings( # Register the model so that it is available for transformer pipelines, auto-loading, etc. -AutoModelForCausalLM.register(OpenLMConfig, OpenLMForCausalLM) \ No newline at end of file +AutoModelForCausalLM.register(OpenLMConfig, OpenLMForCausalLM) diff --git a/open_lm/hf/tokenization_openlm.py b/open_lm/hf/tokenization_openlm.py index d31fef8d..e8abdd69 100644 --- a/open_lm/hf/tokenization_openlm.py +++ b/open_lm/hf/tokenization_openlm.py @@ -15,4 +15,4 @@ class OpenLMTokenizerFast(PreTrainedTokenizerFast): # Register the tokenizer class so that it is available for transformer pipelines, auto-loading etc. 
-AutoTokenizer.register(OpenLMConfig, fast_tokenizer_class=OpenLMTokenizerFast) \ No newline at end of file +AutoTokenizer.register(OpenLMConfig, fast_tokenizer_class=OpenLMTokenizerFast) From 194ba283df9b2e304fd8c25e047cc4c0e57b3f2a Mon Sep 17 00:00:00 2001 From: sedrick-keh-tri Date: Tue, 16 Jul 2024 16:27:59 +0000 Subject: [PATCH 04/13] get_attn_func in hf loader --- open_lm/hf/modeling_openlm.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/open_lm/hf/modeling_openlm.py b/open_lm/hf/modeling_openlm.py index 24ffafcc..12ae0cc0 100644 --- a/open_lm/hf/modeling_openlm.py +++ b/open_lm/hf/modeling_openlm.py @@ -12,6 +12,7 @@ from open_lm.model import Params, Transformer from open_lm.norms import get_norm_class +from open_lm.attention import get_attn_func from .configuration_openlm import OpenLMConfig @@ -32,6 +33,9 @@ def create_model_config_from_pretrained_config(config: OpenLMConfig): if hasattr(config, "norm_type"): model_config.norm_type = get_norm_class(config.norm_type) + + if hasattr(config, "attn_name"): + model_config.attn_name = get_attn_func(config.attn_name) return model_config From b6c9da755436ae54c8305f6b34515f96bf4364ef Mon Sep 17 00:00:00 2001 From: sedrick-keh-tri Date: Tue, 16 Jul 2024 16:37:56 +0000 Subject: [PATCH 05/13] typo --- open_lm/hf/modeling_openlm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_lm/hf/modeling_openlm.py b/open_lm/hf/modeling_openlm.py index 12ae0cc0..120782b0 100644 --- a/open_lm/hf/modeling_openlm.py +++ b/open_lm/hf/modeling_openlm.py @@ -35,7 +35,7 @@ def create_model_config_from_pretrained_config(config: OpenLMConfig): model_config.norm_type = get_norm_class(config.norm_type) if hasattr(config, "attn_name"): - model_config.attn_name = get_attn_func(config.attn_name) + model_config.attn_func = get_attn_func(config.attn_name) return model_config From e4c94a0e74689777871c4667f0e6d738b885c606 Mon Sep 17 00:00:00 2001 From: sedrick-keh-tri Date: Wed, 17 Jul 2024 00:18:07 +0000 Subject: [PATCH 06/13] black reformat --- open_lm/hf/modeling_openlm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/open_lm/hf/modeling_openlm.py b/open_lm/hf/modeling_openlm.py index 120782b0..c7429815 100644 --- a/open_lm/hf/modeling_openlm.py +++ b/open_lm/hf/modeling_openlm.py @@ -134,7 +134,8 @@ def prepare_inputs_for_generation( # This assumes that the second item of past key values is the length of the past (this is the case for linear attention) past_length = past_key_values[0][1] else: - # This assumes that the first item of past key values is a list of all the past keys, thus the shape 1 is the length of the past (this is the case for attention without window) + # This assumes that the first item of past key values is a list of all the past keys, thus the + # shape 1 is the length of the past (this is the case for attention without window) past_length = past_key_values[0][0].shape[1] # Some generation methods already pass only the last input ID From c12c987dfc10bc65fb414f363c9ee95708826b8e Mon Sep 17 00:00:00 2001 From: sedrick-keh-tri Date: Wed, 17 Jul 2024 00:28:49 +0000 Subject: [PATCH 07/13] black reformat --- open_lm/hf/modeling_openlm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/open_lm/hf/modeling_openlm.py b/open_lm/hf/modeling_openlm.py index c7429815..687bd054 100644 --- a/open_lm/hf/modeling_openlm.py +++ b/open_lm/hf/modeling_openlm.py @@ -33,7 +33,7 @@ def create_model_config_from_pretrained_config(config: OpenLMConfig): if hasattr(config, "norm_type"): 
model_config.norm_type = get_norm_class(config.norm_type) - + if hasattr(config, "attn_name"): model_config.attn_func = get_attn_func(config.attn_name) @@ -134,7 +134,7 @@ def prepare_inputs_for_generation( # This assumes that the second item of past key values is the length of the past (this is the case for linear attention) past_length = past_key_values[0][1] else: - # This assumes that the first item of past key values is a list of all the past keys, thus the + # This assumes that the first item of past key values is a list of all the past keys, thus the # shape 1 is the length of the past (this is the case for attention without window) past_length = past_key_values[0][0].shape[1] From 539ca60d0360d8364f3ef1c35ee69f4dd9a2930b Mon Sep 17 00:00:00 2001 From: sedrick-keh-tri Date: Wed, 17 Jul 2024 00:38:46 +0000 Subject: [PATCH 08/13] black reformat --- open_lm/hf/modeling_openlm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/open_lm/hf/modeling_openlm.py b/open_lm/hf/modeling_openlm.py index 687bd054..67ee1e4f 100644 --- a/open_lm/hf/modeling_openlm.py +++ b/open_lm/hf/modeling_openlm.py @@ -76,7 +76,6 @@ def forward( Cache ] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426 ) -> Union[Tuple, CausalLMOutputWithPast]: - if inputs_embeds is not None: log.warning("inputs_embeds is set but OpenLM does not support it yet") if attention_bias is not None: From ce8df74de651c732a771aa9be40c70c0ed84fdb8 Mon Sep 17 00:00:00 2001 From: sedrick-keh-tri Date: Wed, 17 Jul 2024 07:00:51 +0000 Subject: [PATCH 09/13] reqs --- requirements_test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_test.txt b/requirements_test.txt index 61f15ce7..517aa276 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -3,4 +3,4 @@ pytest-cov==3.0.0 pytest-xdist==2.5.0 pytest==7.0.1 tensorboard==2.14.1 -llm-foundry>=0.4.0 +llm-foundry==0.4.0 From 6a2357335de437ea98288414dd7d7ffc0955d463 Mon Sep 17 00:00:00 2001 From: sedrick-keh-tri Date: Wed, 17 Jul 2024 07:11:21 +0000 Subject: [PATCH 10/13] reqs --- requirements_test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_test.txt b/requirements_test.txt index 517aa276..bbcdb356 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -3,4 +3,4 @@ pytest-cov==3.0.0 pytest-xdist==2.5.0 pytest==7.0.1 tensorboard==2.14.1 -llm-foundry==0.4.0 +llm-foundry==0.7.0 From ed1de8ec93f6109bd28e572c446534a39f988861 Mon Sep 17 00:00:00 2001 From: sedrick-keh-tri Date: Wed, 17 Jul 2024 07:44:12 +0000 Subject: [PATCH 11/13] reqs --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index d8787cfb..898d184f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ pandas==2.1.4 fsspec tqdm jsonlines -boto3 +boto3==1.26.90 Pillow zstandard pysimdjson From 7438ac419baa85279d3c3fa2b8851cb6b025425c Mon Sep 17 00:00:00 2001 From: sedrick-keh-tri Date: Wed, 17 Jul 2024 07:53:02 +0000 Subject: [PATCH 12/13] reqs --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 898d184f..d8787cfb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ pandas==2.1.4 fsspec tqdm jsonlines -boto3==1.26.90 +boto3 Pillow zstandard pysimdjson From e439ab6ebecb02a0acb31f13cfdf4cfbbab6e1ef Mon Sep 17 00:00:00 2001 From: sedrick-keh-tri Date: Wed, 17 Jul 2024 22:03:23 +0000 Subject: 
[PATCH 13/13] indexing bug --- open_lm/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_lm/attention.py b/open_lm/attention.py index e0e8aba5..7f2e2f4c 100644 --- a/open_lm/attention.py +++ b/open_lm/attention.py @@ -111,7 +111,7 @@ def torch_attn(queries, keys, values, is_causal, attention_mask=None): if attention_mask is None: bias = None # If we only have one query, assume we don't need to be in causal mode (can attend to all keys). - if queries.shape == 1: + if queries.shape[1] == 1: is_causal = False else: if not is_causal:
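
The final patch fixes an off-by-type check in `torch_attn`: `queries.shape` is a `torch.Size` (a tuple subclass), so `queries.shape == 1` is always False and the single-query fast path described by the comment never fired, leaving causal masking on even when only one new token was being decoded. `queries.shape[1] == 1` tests the query sequence length instead. A minimal illustration (assuming queries laid out as (batch, seq_len, heads, head_dim), the layout open_lm's attention helpers expect):

import torch

queries = torch.randn(2, 1, 8, 64)   # one new query token per sequence
print(queries.shape == 1)            # False: a torch.Size never equals an int
print(queries.shape[1] == 1)         # True: the query sequence length is 1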
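
With the config, model, and tokenizer registered under the `openlm` model type, a converted checkpoint can be driven through the standard transformers Auto classes. The sketch below is illustrative only: the checkpoint directory name is hypothetical, and it assumes the directory already contains a config.json with "model_type": "openlm", compatible weights, and tokenizer files.

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

import open_lm.hf  # noqa: F401 -- importing the package runs the Auto* registrations added in this PR

ckpt_dir = "path/to/openlm-checkpoint"  # hypothetical local checkpoint directory

config = AutoConfig.from_pretrained(ckpt_dir)                           # resolves to OpenLMConfig
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)                     # resolves to OpenLMTokenizerFast
model = AutoModelForCausalLM.from_pretrained(ckpt_dir, config=config)   # resolves to OpenLMForCausalLM

inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    out = model.generate(inputs["input_ids"], max_new_tokens=16, use_cache=True)
print(tokenizer.decode(out[0], skip_special_tokens=True))

Generation with use_cache=True exercises prepare_inputs_for_generation above, which trims input_ids down to the newly generated token once past_key_values are populated.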