Skip to content

Commit

Permalink
Support MistralModel and MistralForCausalLM
Browse files Browse the repository at this point in the history
  • Loading branch information
jiudingsun01 committed Mar 25, 2024
1 parent b57b660 commit bf09440
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 8 deletions.
13 changes: 5 additions & 8 deletions pyvene/models/intervenable_modelcard.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from .constants import *
from .llama.modelings_intervenable_llama import *
from .mistral.modellings_intervenable_mistral import *
from .gpt2.modelings_intervenable_gpt2 import *
from .gpt_neo.modelings_intervenable_gpt_neo import *
from .gpt_neox.modelings_intervenable_gpt_neox import *
from .mlp.modelings_intervenable_mlp import *
from .gru.modelings_intervenable_gru import *
from .blip.modelings_intervenable_blip import *
from .blip.modelings_intervenable_blip_itm import *
from .backpack_gpt2.modelings_intervenable_backpack_gpt2 import *


Expand All @@ -21,7 +21,6 @@

import transformers.models as hf_models
from .blip.modelings_blip import BlipWrapper
from .blip.modelings_blip_itm import BlipITMWrapper
from .mlp.modelings_mlp import MLPModel, MLPForClassification
from .gru.modelings_gru import GRUModel, GRULMHeadModel, GRUForClassification
from .backpack_gpt2.modelings_backpack_gpt2 import BackpackGPT2LMHeadModel
Expand All @@ -35,17 +34,16 @@
type_to_module_mapping = {
hf_models.gpt2.modeling_gpt2.GPT2Model: gpt2_type_to_module_mapping,
hf_models.gpt2.modeling_gpt2.GPT2LMHeadModel: gpt2_lm_type_to_module_mapping,
hf_models.gpt2.modeling_gpt2.GPT2ForSequenceClassification: gpt2_classifier_type_to_module_mapping,
hf_models.llama.modeling_llama.LlamaModel: llama_type_to_module_mapping,
hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_module_mapping,
hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_module_mapping,
hf_models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM: gpt_neo_lm_type_to_module_mapping,
hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXModel: gpt_neox_type_to_module_mapping,
hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM: gpt_neox_lm_type_to_module_mapping,
hf_models.mistral.modeling_mistral.MistralModel: mistral_type_to_module_mapping,
hf_models.mistral.modeling_mistral.MistralForCausalLM: mistral_lm_type_to_module_mapping,
hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_module_mapping,
hf_models.blip.modeling_blip.BlipForImageTextRetrieval: blip_itm_type_to_module_mapping,
BlipWrapper: blip_wrapper_type_to_module_mapping,
BlipITMWrapper: blip_itm_wrapper_type_to_module_mapping,
MLPModel: mlp_type_to_module_mapping,
MLPForClassification: mlp_classifier_type_to_module_mapping,
GRUModel: gru_type_to_module_mapping,
Expand All @@ -59,17 +57,16 @@
type_to_dimension_mapping = {
hf_models.gpt2.modeling_gpt2.GPT2Model: gpt2_type_to_dimension_mapping,
hf_models.gpt2.modeling_gpt2.GPT2LMHeadModel: gpt2_lm_type_to_dimension_mapping,
hf_models.gpt2.modeling_gpt2.GPT2ForSequenceClassification: gpt2_classifier_type_to_dimension_mapping,
hf_models.llama.modeling_llama.LlamaModel: llama_type_to_dimension_mapping,
hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_dimension_mapping,
hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_dimension_mapping,
hf_models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM: gpt_neo_lm_type_to_dimension_mapping,
hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXModel: gpt_neox_type_to_dimension_mapping,
hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM: gpt_neox_lm_type_to_dimension_mapping,
hf_models.mistral.modeling_mistral.MistralModel: mistral_type_to_dimension_mapping,
hf_models.mistral.modeling_mistral.MistralForCausalLM: mistral_lm_type_to_dimension_mapping,
hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_dimension_mapping,
hf_models.blip.modeling_blip.BlipForImageTextRetrieval: blip_itm_type_to_dimension_mapping,
BlipWrapper: blip_wrapper_type_to_dimension_mapping,
BlipITMWrapper: blip_itm_wrapper_type_to_dimension_mapping,
MLPModel: mlp_type_to_dimension_mapping,
MLPForClassification: mlp_classifier_type_to_dimension_mapping,
GRUModel: gru_type_to_dimension_mapping,
Expand Down
Empty file.
78 changes: 78 additions & 0 deletions pyvene/models/mistral/modellings_intervenable_mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""
Each modeling file in this library is a mapping between
abstract naming of intervention anchor points and actual
model module defined in the huggingface library.
We also want to let the intervention library know how to
config the dimensions of intervention based on model config
defined in the huggingface library.
"""


import torch
from ..constants import *


mistral_type_to_module_mapping = {
"block_input": ("layers[%s]", CONST_INPUT_HOOK),
"block_output": ("layers[%s]", CONST_OUTPUT_HOOK),
"mlp_activation": ("layers[%s].mlp.act_fn", CONST_OUTPUT_HOOK),
"mlp_output": ("layers[%s].mlp", CONST_OUTPUT_HOOK),
"mlp_input": ("layers[%s].mlp", CONST_INPUT_HOOK),
"attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
"attention_output": ("layers[%s].self_attn", CONST_OUTPUT_HOOK),
"attention_input": ("layers[%s].self_attn", CONST_INPUT_HOOK),
"query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
"key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
"value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
"head_query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
"head_key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
"head_value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
}


mistral_type_to_dimension_mapping = {
"block_input": ("hidden_size",),
"block_output": ("hidden_size",),
"mlp_activation": ("intermediate_size",),
"mlp_output": ("hidden_size",),
"mlp_input": ("hidden_size",),
"attention_value_output": ("hidden_size",),
"head_attention_value_output": ("hidden_size/num_attention_heads",),
"attention_output": ("hidden_size",),
"attention_input": ("hidden_size",),
"query_output": ("hidden_size",),
"key_output": ("hidden_size",),
"value_output": ("hidden_size",),
"head_query_output": ("hidden_size/num_attention_heads",),
"head_key_output": ("hidden_size/num_attention_heads",),
"head_value_output": ("hidden_size/num_attention_heads",),
}


"""llama model with LM head"""
mistral_lm_type_to_module_mapping = {}
for k, v in mistral_type_to_module_mapping.items():
mistral_lm_type_to_module_mapping[k] = (f"model.{v[0]}", v[1])


mistral_lm_type_to_dimension_mapping = mistral_type_to_dimension_mapping


def create_mistral(
name="mistralai/Mistral-7B-v0.1", cache_dir=None
):
"""Creates a Mistral Causal LM model, config, and tokenizer from the given name and revision"""
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

config = AutoConfig.from_pretrained(name, cache_dir=cache_dir)
tokenizer = AutoTokenizer.from_pretrained(name, cache_dir=cache_dir)
llama = AutoModelForCausalLM.from_pretrained(
name,
config=config,
cache_dir=cache_dir,
torch_dtype=torch.bfloat16, # save memory
)
print("loaded model")
return config, tokenizer, llama

3 comments on commit bf09440

@maraPislar
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add back the support for GPT2ForSequenceClassification?

@frankaging
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add back the support for GPT2ForSequenceClassification?

Thanks for raising this. Ticket created: #138; will resolve in 30 mins.

@frankaging
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add back the support for GPT2ForSequenceClassification?

resolved! please pull.

Please sign in to comment.