Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[P2] Add Gemma 2 model #190

Merged
merged 1 commit into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
89 changes: 89 additions & 0 deletions pyvene/models/gemma2/modelings_intervenable_gemma2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
Each modeling file in this library is a mapping between
abstract naming of intervention anchor points and actual
model module defined in the huggingface library.

We also want to let the intervention library know how to
config the dimensions of intervention based on model config
defined in the huggingface library.
"""


import torch
from ..constants import *


gemma2_type_to_module_mapping = {
"block_input": ("layers[%s]", CONST_INPUT_HOOK),
"block_output": ("layers[%s]", CONST_OUTPUT_HOOK),
"mlp_activation": ("layers[%s].mlp.act_fn", CONST_OUTPUT_HOOK),
"mlp_output": ("layers[%s].mlp", CONST_OUTPUT_HOOK),
"mlp_input": ("layers[%s].mlp", CONST_INPUT_HOOK),
"attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
"attention_output": ("layers[%s].self_attn", CONST_OUTPUT_HOOK),
"attention_input": ("layers[%s].self_attn", CONST_INPUT_HOOK),
"query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
"key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
"value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
"head_query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
"head_key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
"head_value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
}


gemma2_type_to_dimension_mapping = {
"n_head": ("num_attention_heads",),
"n_kv_head": ("num_key_value_heads",),
"block_input": ("hidden_size",),
"block_output": ("hidden_size",),
"mlp_activation": ("intermediate_size",),
"mlp_output": ("hidden_size",),
"mlp_input": ("hidden_size",),
"attention_value_output": ("hidden_size",),
"head_attention_value_output": ("head_dim",),
"attention_output": ("hidden_size",),
"attention_input": ("hidden_size",),
"query_output": ("hidden_size",),
"key_output": ("hidden_size",),
"value_output": ("hidden_size",),
"head_query_output": ("head_dim",),
"head_key_output": ("head_dim",),
"head_value_output": ("hhead_dim",),
}


"""gemma2 model with LM head"""
gemma2_lm_type_to_module_mapping = {}
for k, v in gemma2_type_to_module_mapping.items():
gemma2_lm_type_to_module_mapping[k] = (f"model.{v[0]}", ) + v[1:]


gemma2_lm_type_to_dimension_mapping = gemma2_type_to_dimension_mapping


"""gemma2 model with classifier head"""
gemma2_classifier_type_to_module_mapping = {}
for k, v in gemma2_type_to_module_mapping.items():
gemma2_classifier_type_to_module_mapping[k] = (f"model.{v[0]}", ) + v[1:]


gemma2_classifier_type_to_dimension_mapping = gemma2_type_to_dimension_mapping


def create_gemma2(
name="google/gemma2-2b", cache_dir=None, dtype=torch.bfloat16
):
"""Creates a Causal LM model, config, and tokenizer from the given name and revision"""
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

config = AutoConfig.from_pretrained(name, cache_dir=cache_dir)
tokenizer = AutoTokenizer.from_pretrained(name, cache_dir=cache_dir)
gemma = AutoModelForCausalLM.from_pretrained(
name,
config=config,
cache_dir=cache_dir,
torch_dtype=dtype,
)
print("loaded model")
return config, tokenizer, gemma
5 changes: 5 additions & 0 deletions pyvene/models/intervenable_modelcard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .llama.modelings_intervenable_llama import *
from .mistral.modellings_intervenable_mistral import *
from .gemma.modelings_intervenable_gemma import *
from .gemma2.modelings_intervenable_gemma2 import *
from .gpt2.modelings_intervenable_gpt2 import *
from .gpt_neo.modelings_intervenable_gpt_neo import *
from .gpt_neox.modelings_intervenable_gpt_neox import *
Expand Down Expand Up @@ -58,6 +59,8 @@
hf_models.gemma.modeling_gemma.GemmaModel: gemma_type_to_module_mapping,
hf_models.gemma.modeling_gemma.GemmaForCausalLM: gemma_lm_type_to_module_mapping,
hf_models.gemma.modeling_gemma.GemmaForSequenceClassification: gemma_classifier_type_to_module_mapping,
hf_models.gemma2.modeling_gemma2.Gemma2Model: gemma2_type_to_module_mapping,
hf_models.gemma2.modeling_gemma2.Gemma2ForCausalLM: gemma2_lm_type_to_module_mapping,
hf_models.olmo.modeling_olmo.OlmoModel: olmo_type_to_module_mapping,
hf_models.olmo.modeling_olmo.OlmoForCausalLM: olmo_lm_type_to_module_mapping,
hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_module_mapping,
Expand Down Expand Up @@ -91,6 +94,8 @@
hf_models.gemma.modeling_gemma.GemmaModel: gemma_type_to_dimension_mapping,
hf_models.gemma.modeling_gemma.GemmaForCausalLM: gemma_lm_type_to_dimension_mapping,
hf_models.gemma.modeling_gemma.GemmaForSequenceClassification: gemma_classifier_type_to_dimension_mapping,
hf_models.gemma2.modeling_gemma2.Gemma2Model: gemma2_type_to_dimension_mapping,
hf_models.gemma2.modeling_gemma2.Gemma2ForCausalLM: gemma2_lm_type_to_dimension_mapping,
hf_models.olmo.modeling_olmo.OlmoModel: olmo_type_to_dimension_mapping,
hf_models.olmo.modeling_olmo.OlmoForCausalLM: olmo_lm_type_to_dimension_mapping,
hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_dimension_mapping,
Expand Down
Loading