Merge pull request #199 from atticusg/atticusDev
Add Qwen model
aryamanarora authored Jan 28, 2025
2 parents 0d3f1f8 + 129fbc8 commit 7a7a96a
Showing 2 changed files with 85 additions and 2 deletions.
10 changes: 8 additions & 2 deletions pyvene/models/intervenable_modelcard.py
@@ -12,6 +12,7 @@
from .blip.modelings_intervenable_blip_itm import *
from .backpack_gpt2.modelings_intervenable_backpack_gpt2 import *
from .llava.modelings_intervenable_llava import *
from .qwen2.modelings_intervenable_qwen2 import *
from .olmo.modelings_intervenable_olmo import *

#########################################################################
@@ -71,7 +72,9 @@
    GRULMHeadModel: gru_lm_type_to_module_mapping,
    GRUForClassification: gru_classifier_type_to_module_mapping,
    BackpackGPT2LMHeadModel: backpack_gpt2_lm_type_to_module_mapping,
    # new model type goes here after defining the model files
    hf_models.qwen2.modeling_qwen2.Qwen2Model: qwen2_type_to_module_mapping,
    hf_models.qwen2.modeling_qwen2.Qwen2ForCausalLM: qwen2_lm_type_to_module_mapping,
    hf_models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification: qwen2_classifier_type_to_module_mapping,
}
if enable_blip:
    type_to_module_mapping[BlipWrapper] = blip_wrapper_type_to_module_mapping
@@ -106,8 +109,11 @@
    GRULMHeadModel: gru_lm_type_to_dimension_mapping,
    GRUForClassification: gru_classifier_type_to_dimension_mapping,
    BackpackGPT2LMHeadModel: backpack_gpt2_lm_type_to_dimension_mapping,
    # new model type goes here after defining the model files
    hf_models.qwen2.modeling_qwen2.Qwen2Model: qwen2_type_to_dimension_mapping,
    hf_models.qwen2.modeling_qwen2.Qwen2ForCausalLM: qwen2_lm_type_to_dimension_mapping,
    hf_models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification: qwen2_classifier_type_to_dimension_mapping,
}

if enable_blip:
    type_to_dimension_mapping[BlipWrapper] = blip_wrapper_type_to_dimension_mapping
    type_to_dimension_mapping[BlipITMWrapper] = blip_itm_wrapper_type_to_dimension_mapping
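For context: these registrations are what let pyvene resolve the abstract component names defined in the new modeling file below to concrete Qwen2 submodules. A minimal sketch of how a registered template names a module, under the assumption that the mapping shapes match the dicts below ("Qwen/Qwen2-0.5B" is an illustrative checkpoint, not one referenced by this commit):

# Illustration only; this is not pyvene's internal resolution code.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B", torch_dtype=torch.bfloat16
)

# For Qwen2ForCausalLM, "block_output" maps to ("model.layers[%s]", CONST_OUTPUT_HOOK);
# filling the %s placeholder with a layer index names the module to hook.
template = "model.layers[%s]"
print(template % 5)            # -> model.layers[5]
block = model.model.layers[5]  # the decoder layer an output hook would attach to
print(type(block).__name__)    # -> Qwen2DecoderLayer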
77 changes: 77 additions & 0 deletions pyvene/models/qwen2/modelings_intervenable_qwen2.py
@@ -0,0 +1,77 @@
"""
Each modeling file in this library is a mapping between
abstract naming of intervention anchor points and actual
model module defined in the huggingface library.
We also want to let the intervention library know how to
config the dimensions of intervention based on model config
defined in the huggingface library.
"""
import torch
from ..constants import *

qwen2_type_to_module_mapping = {
    "block_input": ("layers[%s]", CONST_INPUT_HOOK),
    "block_output": ("layers[%s]", CONST_OUTPUT_HOOK),
    "mlp_activation": ("layers[%s].mlp.act_fn", CONST_OUTPUT_HOOK),
    "mlp_output": ("layers[%s].mlp", CONST_OUTPUT_HOOK),
    "mlp_input": ("layers[%s].mlp", CONST_INPUT_HOOK),
    "attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
    "head_attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
    "attention_output": ("layers[%s].self_attn", CONST_OUTPUT_HOOK),
    "attention_input": ("layers[%s].self_attn", CONST_INPUT_HOOK),
    "query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
    "key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
    "value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
    "head_query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
    "head_key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
    "head_value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")),
}

qwen2_type_to_dimension_mapping = {
    "n_head": ("num_attention_heads",),
    "n_kv_head": ("num_key_value_heads",),
    "block_input": ("hidden_size",),
    "block_output": ("hidden_size",),
    "mlp_activation": ("intermediate_size",),
    "mlp_output": ("hidden_size",),
    "mlp_input": ("hidden_size",),
    "attention_value_output": ("hidden_size",),
    "head_attention_value_output": ("head_dim",),
    "attention_output": ("hidden_size",),
    "attention_input": ("hidden_size",),
    "query_output": ("hidden_size",),
    "key_output": ("hidden_size",),
    "value_output": ("hidden_size",),
    "head_query_output": ("head_dim",),
    "head_key_output": ("head_dim",),
    "head_value_output": ("head_dim",),
}

"""qwen2 model with LM head"""
qwen2_lm_type_to_module_mapping = {}
for k, v in qwen2_type_to_module_mapping.items():
    qwen2_lm_type_to_module_mapping[k] = (f"model.{v[0]}", ) + v[1:]
qwen2_lm_type_to_dimension_mapping = qwen2_type_to_dimension_mapping

"""qwen2 model with classifier head"""
qwen2_classifier_type_to_module_mapping = {}
for k, v in qwen2_type_to_module_mapping.items():
    qwen2_classifier_type_to_module_mapping[k] = (f"model.{v[0]}", ) + v[1:]
qwen2_classifier_type_to_dimension_mapping = qwen2_type_to_dimension_mapping

def create_qwen2(
    name="Qwen/Qwen2-7B", cache_dir=None, dtype=torch.bfloat16
):
"""Creates a Causal LM model, config, and tokenizer from the given name and revision"""
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

config = AutoConfig.from_pretrained(name, cache_dir=cache_dir)
tokenizer = AutoTokenizer.from_pretrained(name, cache_dir=cache_dir)
model = AutoModelForCausalLM.from_pretrained(
name,
config=config,
cache_dir=cache_dir,
torch_dtype=dtype,
)
print("loaded model")
return config, tokenizer, model
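
A minimal end-to-end usage sketch, assuming create_qwen2 is re-exported at the package level like pyvene's other create_* helpers and that IntervenableModel accepts the dict-style config shown in the project README; the checkpoint, layer, and zero-source intervention are illustrative:

import torch
import pyvene as pv

# Assumes pv.create_qwen2 is exported; otherwise import it from
# pyvene.models.qwen2.modelings_intervenable_qwen2.
config, tokenizer, model = pv.create_qwen2(name="Qwen/Qwen2-0.5B")

# Zero out layer 0's MLP output, one of the anchor points mapped above.
pv_model = pv.IntervenableModel({
    "layer": 0,
    "component": "mlp_output",
    "source_representation": torch.zeros(
        config.hidden_size, dtype=torch.bfloat16
    ),
}, model=model)

base = tokenizer("The capital of Spain is", return_tensors="pt")
original_outputs, intervened_outputs = pv_model(
    base, unit_locations={"base": 3}
)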
