Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[P2] Add OLMo models. #182

Merged
merged 1 commit into from
Aug 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyvene/models/intervenable_modelcard.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .blip.modelings_intervenable_blip_itm import *
from .backpack_gpt2.modelings_intervenable_backpack_gpt2 import *
from .llava.modelings_intervenable_llava import *

from .olmo.modelings_intervenable_olmo import *

#########################################################################
"""
Expand Down Expand Up @@ -52,6 +52,8 @@
hf_models.gemma.modeling_gemma.GemmaModel: gemma_type_to_module_mapping,
hf_models.gemma.modeling_gemma.GemmaForCausalLM: gemma_lm_type_to_module_mapping,
hf_models.gemma.modeling_gemma.GemmaForSequenceClassification: gemma_classifier_type_to_module_mapping,
hf_models.olmo.modeling_olmo.OlmoModel: olmo_type_to_module_mapping,
hf_models.olmo.modeling_olmo.OlmoForCausalLM: olmo_lm_type_to_module_mapping,
hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_module_mapping,
hf_models.blip.modeling_blip.BlipForImageTextRetrieval: blip_itm_type_to_module_mapping,
BlipWrapper: blip_wrapper_type_to_module_mapping,
Expand Down Expand Up @@ -83,6 +85,8 @@
hf_models.gemma.modeling_gemma.GemmaModel: gemma_type_to_dimension_mapping,
hf_models.gemma.modeling_gemma.GemmaForCausalLM: gemma_lm_type_to_dimension_mapping,
hf_models.gemma.modeling_gemma.GemmaForSequenceClassification: gemma_classifier_type_to_dimension_mapping,
hf_models.olmo.modeling_olmo.OlmoModel: olmo_type_to_dimension_mapping,
hf_models.olmo.modeling_olmo.OlmoForCausalLM: olmo_lm_type_to_dimension_mapping,
hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_dimension_mapping,
hf_models.blip.modeling_blip.BlipForImageTextRetrieval: blip_itm_type_to_dimension_mapping,
BlipWrapper: blip_wrapper_type_to_dimension_mapping,
Expand Down
Empty file added pyvene/models/olmo/__init__.py
Empty file.
93 changes: 93 additions & 0 deletions pyvene/models/olmo/modelings_intervenable_olmo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
Each modeling file in this library is a mapping between
abstract naming of intervention anchor points and actual
model module defined in the huggingface library.

We also want to let the intervention library know how to
config the dimensions of intervention based on model config
defined in the huggingface library.
"""


import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from ..constants import *


olmo_type_to_module_mapping = {
"block_input": ("layers[%s]", CONST_INPUT_HOOK),
"block_output": ("layers[%s]", CONST_OUTPUT_HOOK),
"mlp_activation": ("layers[%s].mlp.act_fn", CONST_OUTPUT_HOOK),
"mlp_output": ("layers[%s].mlp", CONST_OUTPUT_HOOK),
"mlp_input": ("layers[%s].mlp", CONST_INPUT_HOOK),
"attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
"attention_output": ("layers[%s].self_attn", CONST_OUTPUT_HOOK),
"attention_input": ("layers[%s].self_attn", CONST_INPUT_HOOK),
"query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
"key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
"value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
"head_query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
"head_key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
"head_value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
}


olmo_type_to_dimension_mapping = {
"n_head": ("num_attention_heads",),
"n_kv_head": ("num_key_value_heads",),
"block_input": ("hidden_size",),
"block_output": ("hidden_size",),
"mlp_activation": ("intermediate_size",),
"mlp_output": ("hidden_size",),
"mlp_input": ("hidden_size",),
"attention_value_output": ("hidden_size",),
"head_attention_value_output": ("hidden_size/num_attention_heads",),
"attention_output": ("hidden_size",),
"attention_input": ("hidden_size",),
"query_output": ("hidden_size",),
"key_output": ("hidden_size",),
"value_output": ("hidden_size",),
"head_query_output": ("hidden_size/num_attention_heads",),
"head_key_output": ("hidden_size/num_attention_heads",),
"head_value_output": ("hidden_size/num_attention_heads",),
}


"""olmo model with LM head"""
olmo_lm_type_to_module_mapping = {}
for k, v in olmo_type_to_module_mapping.items():
olmo_lm_type_to_module_mapping[k] = (f"model.{v[0]}", ) + v[1:]


olmo_lm_type_to_dimension_mapping = olmo_type_to_dimension_mapping


"""olmo model with classifier head"""
olmo_classifier_type_to_module_mapping = {}
for k, v in olmo_type_to_module_mapping.items():
olmo_classifier_type_to_module_mapping[k] = (f"model.{v[0]}", ) + v[1:]


olmo_classifier_type_to_dimension_mapping = olmo_type_to_dimension_mapping


def create_olmo(
name="allenai/OLMo-7B-0424-hf", cache_dir=None, dtype=torch.bfloat16, config=None,
revision='main'
):
"""Creates a OLMo Causal LM model, config, and tokenizer from the given name and revision"""
if config is None:
config = AutoConfig.from_pretrained(name, cache_dir=cache_dir)
olmo = AutoModelForCausalLM.from_pretrained(
name,
config=config,
cache_dir=cache_dir,
torch_dtype=dtype,
)
tokenizer = AutoTokenizer.from_pretrained(name, cache_dir=cache_dir)
else:
olmo = AutoModelForCausalLM(config, cache_dir=cache_dir, revision=revision)
tokenizer = AutoTokenizer.from_pretrained(name, cache_dir=cache_dir)
print("loaded model")
return config, tokenizer, olmo
Loading