Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Llava integration to Pyvene #151

Merged
merged 9 commits into from
May 1, 2024
1 change: 1 addition & 0 deletions pyvene/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .models.gru.modelings_intervenable_gru import create_gru
from .models.gru.modelings_intervenable_gru import create_gru_lm
from .models.gru.modelings_intervenable_gru import create_gru_classifier
from .models.llava.modelings_intervenable_llava import create_llava
from .models.gru.modelings_gru import GRUConfig
from .models.llama.modelings_intervenable_llama import create_llama
from .models.mlp.modelings_intervenable_mlp import create_mlp_classifier
Expand Down
3 changes: 3 additions & 0 deletions pyvene/models/intervenable_modelcard.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .blip.modelings_intervenable_blip import *
from .blip.modelings_intervenable_blip_itm import *
from .backpack_gpt2.modelings_intervenable_backpack_gpt2 import *
from .llava.modelings_intervenable_llava import *


#########################################################################
Expand Down Expand Up @@ -41,6 +42,7 @@
hf_models.llama.modeling_llama.LlamaModel: llama_type_to_module_mapping,
hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_module_mapping,
hf_models.llama.modeling_llama.LlamaForSequenceClassification: llama_classifier_type_to_module_mapping,
hf_models.llava.modeling_llava.LlavaForConditionalGeneration: llava_type_to_module_mapping,
hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_module_mapping,
hf_models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM: gpt_neo_lm_type_to_module_mapping,
hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXModel: gpt_neox_type_to_module_mapping,
Expand Down Expand Up @@ -71,6 +73,7 @@
hf_models.llama.modeling_llama.LlamaModel: llama_type_to_dimension_mapping,
hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_dimension_mapping,
hf_models.llama.modeling_llama.LlamaForSequenceClassification: llama_classifier_type_to_dimension_mapping,
hf_models.llava.modeling_llava.LlavaForConditionalGeneration: llava_type_to_dimension_mapping,
hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_dimension_mapping,
hf_models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM: gpt_neo_lm_type_to_dimension_mapping,
hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXModel: gpt_neox_type_to_dimension_mapping,
Expand Down
Empty file added pyvene/models/llava/__init__.py
Empty file.
92 changes: 92 additions & 0 deletions pyvene/models/llava/modelings_intervenable_llava.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""
Each modeling file in this library is a mapping between
abstract naming of intervention anchor points and actual
model module defined in the huggingface library.

We also want to let the intervention library know how to
config the dimensions of intervention based on model config
defined in the huggingface library.
"""


import torch
from ..constants import *

# Maps abstract intervention anchor names to (module-path template, hook kind)
# for LlavaForConditionalGeneration. The "%s" placeholder is substituted with
# a decoder-layer index; every anchor lives on the language-model side.
_LLAVA_LAYER = "language_model.model.layers[%s]"
_LLAVA_ATTN = _LLAVA_LAYER + ".self_attn"
_LLAVA_MLP = _LLAVA_LAYER + ".mlp"

llava_type_to_module_mapping = {
    "block_input": (_LLAVA_LAYER, CONST_INPUT_HOOK),
    "block_output": (_LLAVA_LAYER, CONST_OUTPUT_HOOK),
    "mlp_activation": (_LLAVA_MLP + ".act_fn", CONST_OUTPUT_HOOK),
    "mlp_output": (_LLAVA_MLP, CONST_OUTPUT_HOOK),
    "mlp_input": (_LLAVA_MLP, CONST_INPUT_HOOK),
    # Head-level value interventions hook the *input* of o_proj, since that
    # is where per-head value streams are still separable.
    "attention_value_output": (_LLAVA_ATTN + ".o_proj", CONST_INPUT_HOOK),
    "head_attention_value_output": (_LLAVA_ATTN + ".o_proj", CONST_INPUT_HOOK),
    "attention_output": (_LLAVA_ATTN, CONST_OUTPUT_HOOK),
    "attention_input": (_LLAVA_ATTN, CONST_INPUT_HOOK),
    "query_output": (_LLAVA_ATTN + ".q_proj", CONST_OUTPUT_HOOK),
    "key_output": (_LLAVA_ATTN + ".k_proj", CONST_OUTPUT_HOOK),
    "value_output": (_LLAVA_ATTN + ".v_proj", CONST_OUTPUT_HOOK),
    "head_query_output": (_LLAVA_ATTN + ".q_proj", CONST_OUTPUT_HOOK),
    "head_key_output": (_LLAVA_ATTN + ".k_proj", CONST_OUTPUT_HOOK),
    "head_value_output": (_LLAVA_ATTN + ".v_proj", CONST_OUTPUT_HOOK),
}


# Dimension-lookup expressions (resolved against the HF LlavaConfig) for each
# intervention anchor. Llava nests its decoder sizes under `text_config`.
_LLAVA_HIDDEN = ("text_config.hidden_size",)
_LLAVA_HEAD = ("text_config.hidden_size/text_config.num_attention_heads",)

llava_type_to_dimension_mapping = {
    "block_input": _LLAVA_HIDDEN,
    "block_output": _LLAVA_HIDDEN,
    # The MLP activation sits in the expanded intermediate space.
    "mlp_activation": ("text_config.intermediate_size",),
    "mlp_output": _LLAVA_HIDDEN,
    "mlp_input": _LLAVA_HIDDEN,
    "attention_value_output": _LLAVA_HIDDEN,
    "head_attention_value_output": _LLAVA_HEAD,
    "attention_output": _LLAVA_HIDDEN,
    "attention_input": _LLAVA_HIDDEN,
    "query_output": _LLAVA_HIDDEN,
    "key_output": _LLAVA_HIDDEN,
    "value_output": _LLAVA_HIDDEN,
    "head_query_output": _LLAVA_HEAD,
    "head_key_output": _LLAVA_HEAD,
    "head_value_output": _LLAVA_HEAD,
}


"""llava model with LM head"""
llava_lm_type_to_module_mapping = {}
for k, v in llava_type_to_module_mapping.items():
llava_lm_type_to_module_mapping[k] = (f"model.{v[0]}", v[1])


llava_lm_type_to_dimension_mapping = llava_type_to_dimension_mapping


"""llava model with classifier head"""
llava_classifier_type_to_module_mapping = {}
for k, v in llava_type_to_module_mapping.items():
llava_classifier_type_to_module_mapping[k] = (f"model.{v[0]}", v[1])


llava_classifier_type_to_dimension_mapping = llava_type_to_dimension_mapping




def create_llava(
    name="llava-hf/llava-1.5-7b-hf", cache_dir=None, dtype=torch.bfloat16
):
    """Create a Llava conditional-generation model with its config, tokenizer,
    and processor.

    Args:
        name: HuggingFace Hub id or local path of the Llava checkpoint.
        cache_dir: Optional directory for the HuggingFace download cache.
        dtype: torch dtype to load the model weights in.

    Returns:
        A ``(config, tokenizer, llava, image_processor)`` tuple.
    """
    from transformers import (
        AutoProcessor,
        AutoTokenizer,
        LlavaConfig,
        LlavaForConditionalGeneration,
    )

    config = LlavaConfig.from_pretrained(name, cache_dir=cache_dir)
    # Fix: forward cache_dir to the tokenizer and processor as well, so every
    # downloaded artifact lands in the requested cache instead of the default.
    tokenizer = AutoTokenizer.from_pretrained(
        name, cache_dir=cache_dir, use_fast=False
    )
    llava = LlavaForConditionalGeneration.from_pretrained(
        name,
        config=config,
        cache_dir=cache_dir,
        torch_dtype=dtype,
    )
    image_processor = AutoProcessor.from_pretrained(name, cache_dir=cache_dir)
    print("loaded model")
    return config, tokenizer, llava, image_processor

Loading