diff --git a/pyvene/models/gemma2/__init__.py b/pyvene/models/gemma2/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pyvene/models/gemma2/modelings_intervenable_gemma2.py b/pyvene/models/gemma2/modelings_intervenable_gemma2.py
new file mode 100644
index 00000000..ed704d83
--- /dev/null
+++ b/pyvene/models/gemma2/modelings_intervenable_gemma2.py
@@ -0,0 +1,89 @@
+"""
+Each modeling file in this library is a mapping between
+abstract naming of intervention anchor points and actual
+model modules defined in the huggingface library.
+
+We also want to let the intervention library know how to
+configure the dimensions of intervention based on model config
+defined in the huggingface library.
+"""
+
+
+import torch
+from ..constants import *
+
+
+gemma2_type_to_module_mapping = {
+    "block_input": ("layers[%s]", CONST_INPUT_HOOK),
+    "block_output": ("layers[%s]", CONST_OUTPUT_HOOK),
+    "mlp_activation": ("layers[%s].mlp.act_fn", CONST_OUTPUT_HOOK),
+    "mlp_output": ("layers[%s].mlp", CONST_OUTPUT_HOOK),
+    "mlp_input": ("layers[%s].mlp", CONST_INPUT_HOOK),
+    "attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
+    "head_attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
+    "attention_output": ("layers[%s].self_attn", CONST_OUTPUT_HOOK),
+    "attention_input": ("layers[%s].self_attn", CONST_INPUT_HOOK),
+    "query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
+    "key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
+    "value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
+    "head_query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
+    "head_key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
+    "head_value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
+}
+
+
+gemma2_type_to_dimension_mapping = {
+    "n_head": ("num_attention_heads",),
+    "n_kv_head": ("num_key_value_heads",),
+    "block_input": ("hidden_size",),
+    "block_output": ("hidden_size",),
+    "mlp_activation": ("intermediate_size",),
+    "mlp_output": ("hidden_size",),
+    "mlp_input": ("hidden_size",),
+    "attention_value_output": ("hidden_size",),
+    "head_attention_value_output": ("head_dim",),
+    "attention_output": ("hidden_size",),
+    "attention_input": ("hidden_size",),
+    "query_output": ("hidden_size",),
+    "key_output": ("hidden_size",),
+    "value_output": ("hidden_size",),
+    "head_query_output": ("head_dim",),
+    "head_key_output": ("head_dim",),
+    "head_value_output": ("head_dim",),
+}
+
+
+"""gemma2 model with LM head"""
+gemma2_lm_type_to_module_mapping = {}
+for k, v in gemma2_type_to_module_mapping.items():
+    gemma2_lm_type_to_module_mapping[k] = (f"model.{v[0]}", ) + v[1:]
+
+
+gemma2_lm_type_to_dimension_mapping = gemma2_type_to_dimension_mapping
+
+
+"""gemma2 model with classifier head"""
+gemma2_classifier_type_to_module_mapping = {}
+for k, v in gemma2_type_to_module_mapping.items():
+    gemma2_classifier_type_to_module_mapping[k] = (f"model.{v[0]}", ) + v[1:]
+
+
+gemma2_classifier_type_to_dimension_mapping = gemma2_type_to_dimension_mapping
+
+
+def create_gemma2(
+    name="google/gemma-2-2b", cache_dir=None, dtype=torch.bfloat16
+):
+    """Creates a Gemma2 Causal LM model, config, and tokenizer from the given name"""
+    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+    config = AutoConfig.from_pretrained(name, cache_dir=cache_dir)
+    tokenizer = AutoTokenizer.from_pretrained(name, cache_dir=cache_dir)
+    gemma = AutoModelForCausalLM.from_pretrained(
+        name,
+        config=config,
+        cache_dir=cache_dir,
+        torch_dtype=dtype,
+    )
+    print("loaded model")
+    return config, tokenizer, gemma
diff --git a/pyvene/models/intervenable_modelcard.py b/pyvene/models/intervenable_modelcard.py
index a0dba200..1f7329f4 100644
--- a/pyvene/models/intervenable_modelcard.py
+++ b/pyvene/models/intervenable_modelcard.py
@@ -2,6 +2,7 @@
 from .llama.modelings_intervenable_llama import *
 from .mistral.modellings_intervenable_mistral import *
 from .gemma.modelings_intervenable_gemma import *
+from .gemma2.modelings_intervenable_gemma2 import *
 from .gpt2.modelings_intervenable_gpt2 import *
 from .gpt_neo.modelings_intervenable_gpt_neo import *
 from .gpt_neox.modelings_intervenable_gpt_neox import *
@@ -58,6 +59,8 @@
     hf_models.gemma.modeling_gemma.GemmaModel: gemma_type_to_module_mapping,
     hf_models.gemma.modeling_gemma.GemmaForCausalLM: gemma_lm_type_to_module_mapping,
     hf_models.gemma.modeling_gemma.GemmaForSequenceClassification: gemma_classifier_type_to_module_mapping,
+    hf_models.gemma2.modeling_gemma2.Gemma2Model: gemma2_type_to_module_mapping,
+    hf_models.gemma2.modeling_gemma2.Gemma2ForCausalLM: gemma2_lm_type_to_module_mapping,
     hf_models.olmo.modeling_olmo.OlmoModel: olmo_type_to_module_mapping,
     hf_models.olmo.modeling_olmo.OlmoForCausalLM: olmo_lm_type_to_module_mapping,
     hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_module_mapping,
@@ -91,6 +94,8 @@
     hf_models.gemma.modeling_gemma.GemmaModel: gemma_type_to_dimension_mapping,
     hf_models.gemma.modeling_gemma.GemmaForCausalLM: gemma_lm_type_to_dimension_mapping,
     hf_models.gemma.modeling_gemma.GemmaForSequenceClassification: gemma_classifier_type_to_dimension_mapping,
+    hf_models.gemma2.modeling_gemma2.Gemma2Model: gemma2_type_to_dimension_mapping,
+    hf_models.gemma2.modeling_gemma2.Gemma2ForCausalLM: gemma2_lm_type_to_dimension_mapping,
     hf_models.olmo.modeling_olmo.OlmoModel: olmo_type_to_dimension_mapping,
     hf_models.olmo.modeling_olmo.OlmoForCausalLM: olmo_lm_type_to_dimension_mapping,
     hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_dimension_mapping,
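
For review context (not part of the patch): the sketch below shows how the new mappings and `create_gemma2` would be exercised end to end through pyvene's existing top-level API (`IntervenableModel`, `VanillaIntervention`, and the `unit_locations` calling convention). It assumes the `google/gemma-2-2b` weights are available locally and imports the factory by its full module path, since this diff does not touch pyvene's `__init__.py` exports; treat it as an illustrative smoke test, not as code this PR adds.

```python
# Illustrative smoke test only (not part of this patch). Assumes pyvene's
# existing public API and local access to the google/gemma-2-2b weights.
import pyvene as pv
from pyvene.models.gemma2.modelings_intervenable_gemma2 import create_gemma2

config, tokenizer, gemma2 = create_gemma2()

# Wrapping the ForCausalLM model routes "block_output" through
# gemma2_lm_type_to_module_mapping, i.e. to model.layers[0]'s output hook.
pv_gemma2 = pv.IntervenableModel(
    {
        "layer": 0,
        "component": "block_output",
        "unit": "pos",
        "intervention_type": pv.VanillaIntervention,
    },
    model=gemma2,
)

# Interchange intervention: the layer-0 block output at position 4 of the
# source run replaces the same activation in the base run.
orig_outputs, intervened_outputs = pv_gemma2(
    base=tokenizer("The capital of Spain is", return_tensors="pt"),
    sources=tokenizer("The capital of Italy is", return_tensors="pt"),
    unit_locations={"sources->base": 4},
)
print(intervened_outputs.logits.shape)  # (1, seq_len, vocab_size)
```

A run like this touches both files changed above: the modelcard registration is what lets `IntervenableModel` find `gemma2_lm_type_to_module_mapping` for `Gemma2ForCausalLM`, and the dimension mapping sizes the intervention to `hidden_size`.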