Commit cc865bf

Merge branch 'main' of github.com:stanfordnlp/pyvene into main

frankaging committed Jun 3, 2024
2 parents 3342628 + d29f959
Showing 18 changed files with 257 additions and 26 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -274,4 +274,5 @@ If you would like to read more works on this area, here is a list of papers that

## Star History

-[![Star History Chart](https://api.star-history.com/svg?repos=stanfordnlp/pyvene&type=Date)](https://star-history.com/#stanfordnlp/pyvene&Date)
+[![Star History Chart](https://api.star-history.com/svg?repos=stanfordnlp/pyvene,stanfordnlp/pyreft&type=Date)](https://star-history.com/#stanfordnlp/pyvene&stanfordnlp/pyreft&Date)

1 change: 1 addition & 0 deletions pyvene/__init__.py
@@ -41,6 +41,7 @@
from .models.gru.modelings_intervenable_gru import create_gru
from .models.gru.modelings_intervenable_gru import create_gru_lm
from .models.gru.modelings_intervenable_gru import create_gru_classifier
+from .models.llava.modelings_intervenable_llava import create_llava
from .models.gru.modelings_gru import GRUConfig
from .models.llama.modelings_intervenable_llama import create_llama
from .models.mlp.modelings_intervenable_mlp import create_mlp_classifier
28 changes: 28 additions & 0 deletions pyvene/analyses/visualization.py
@@ -0,0 +1,28 @@
import seaborn
import torch

def rotation_token_heatmap(rotate_layer,
                           tokens,
                           token_size,
                           variables,
                           intervention_size):

    W = rotate_layer.weight.data
    in_dim, out_dim = W.shape

    assert in_dim % token_size == 0
    assert in_dim / token_size >= len(tokens)

    assert out_dim % intervention_size == 0
    assert out_dim / intervention_size >= len(variables)

    heatmap = []
    for j in range(len(variables)):
        row = []
        for i in range(len(tokens)):
            row.append(torch.norm(W[i*token_size:(i+1)*token_size, j*intervention_size:(j+1)*intervention_size]))
        mean = sum(row)
        heatmap.append([x/mean for x in row])
    return seaborn.heatmap(heatmap,
                           xticklabels=tokens,
                           yticklabels=variables)
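For orientation: the helper reads the rotation weight matrix as a (tokens × token_size) by (variables × intervention_size) grid of blocks, plots each block's norm, and normalizes each row by its sum (the variable named mean holds the row total). A minimal usage sketch; ToyRotateLayer below is a hypothetical stand-in for pyvene's trainable rotation layer, since the function only touches its .weight attribute:

import torch

# Hypothetical stand-in for pyvene's rotation layer; only .weight is used.
class ToyRotateLayer(torch.nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(in_dim, out_dim))

# 4 tokens with hidden size 16 each (in_dim = 64); 2 variables, each owning
# an 8-column slice of the rotated space (out_dim = 16).
ax = rotation_token_heatmap(
    ToyRotateLayer(64, 16),
    tokens=["The", "cat", "sat", "down"],
    token_size=16,
    variables=["X", "Y"],
    intervention_size=8,
)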
21 changes: 10 additions & 11 deletions pyvene/data_generators/causal_model.py
@@ -35,9 +35,6 @@ def __init__(
assert variable in self.values
assert variable in self.children
assert variable in self.functions
-assert len(inspect.getfullargspec(self.functions[variable])[0]) == len(
-    self.parents[variable]
-)
if timesteps is not None:
assert variable in timesteps
for variable2 in copy.copy(self.variables):
@@ -79,6 +76,8 @@ def __init__(
self.equiv_classes = equiv_classes
else:
self.equiv_classes = {}

+def generate_equiv_classes(self):
for var in self.variables:
if var in self.inputs or var in self.equiv_classes:
continue
@@ -113,7 +112,7 @@ def generate_timesteps(self):
def marginalize(self, target):
pass

-def print_structure(self, pos=None):
+def print_structure(self, pos=None, font=12, node_size=1000):
G = nx.DiGraph()
G.add_edges_from(
[
@@ -123,7 +122,7 @@ def print_structure(self, pos=None):
]
)
plt.figure(figsize=(10, 10))
-nx.draw_networkx(G, with_labels=True, node_color="green", pos=self.pos)
+nx.draw_networkx(G, with_labels=True, node_color="green", pos=self.pos, font_size=font, node_size=node_size)
plt.show()

def find_live_paths(self, intervention):
@@ -149,12 +148,9 @@ def find_live_paths(self, intervention):
del paths[1]
return paths

-def print_setting(self, total_setting, display=None):
-    labeler = lambda var: var + ": " + str(total_setting[var]) \
-        if display is None or display[var] \
-        else var
+def print_setting(self, total_setting, font=12, node_size=1000):
relabeler = {
-    var: labeler(var) for var in self.variables
+    var: var + ": " + str(total_setting[var]) for var in self.variables
}
G = nx.DiGraph()
G.add_edges_from(
@@ -170,7 +166,7 @@
if self.pos is not None:
for var in self.pos:
newpos[relabeler[var]] = self.pos[var]
-nx.draw_networkx(G, with_labels=True, node_color="green", pos=newpos)
+nx.draw_networkx(G, with_labels=True, node_color="green", pos=newpos, font_size=font, node_size=node_size)
plt.show()

def run_forward(self, intervention=None):
@@ -233,11 +229,14 @@ def sample_input(self, mandatory=None):

def sample_input_tree_balanced(self, output_var=None, output_var_value=None):
assert output_var is not None or len(self.outputs) == 1
+self.generate_equiv_classes()

if output_var is None:
output_var = self.outputs[0]
if output_var_value is None:
output_var_value = random.choice(self.values[output_var])


def create_input(var, value, input={}):
parent_values = random.choice(self.equiv_classes[var][value])
for parent in parent_values:
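Net effect of this file's changes: the function-arity assert is dropped, equivalence classes are now built on demand by the sampler, and both plotting helpers accept font and node_size. A usage sketch, assuming a CausalModel built as in pyvene's tutorials (the three-variable model is illustrative):

# Hypothetical three-variable model; constructor order follows pyvene's tutorials.
variables = ["A", "B", "C"]
values = {v: [0, 1] for v in variables}
parents = {"A": [], "B": [], "C": ["A", "B"]}
functions = {
    "A": lambda: 0,
    "B": lambda: 1,
    "C": lambda a, b: a and b,
}
model = CausalModel(variables, values, parents, functions)

# The new keyword arguments control label size and node area in the drawn graph.
model.print_structure(font=14, node_size=1500)

# generate_equiv_classes() is now called inside the sampler, so no separate
# setup step is needed before sampling inputs.
input_setting = model.sample_input_tree_balanced()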
Empty file added pyvene/models/gemma/__init__.py
87 changes: 87 additions & 0 deletions pyvene/models/gemma/modelings_intervenable_gemma.py
@@ -0,0 +1,87 @@
"""
Each modeling file in this library is a mapping between
abstract naming of intervention anchor points and actual
model module defined in the huggingface library.
We also want to let the intervention library know how to
config the dimensions of intervention based on model config
defined in the huggingface library.
"""


import torch
from ..constants import *


gemma_type_to_module_mapping = {
    "block_input": ("layers[%s]", CONST_INPUT_HOOK),
    "block_output": ("layers[%s]", CONST_OUTPUT_HOOK),
    "mlp_activation": ("layers[%s].mlp.act_fn", CONST_OUTPUT_HOOK),
    "mlp_output": ("layers[%s].mlp", CONST_OUTPUT_HOOK),
    "mlp_input": ("layers[%s].mlp", CONST_INPUT_HOOK),
    "attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
    "head_attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
    "attention_output": ("layers[%s].self_attn", CONST_OUTPUT_HOOK),
    "attention_input": ("layers[%s].self_attn", CONST_INPUT_HOOK),
    "query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
    "key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
    "value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
    "head_query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
    "head_key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
    "head_value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
}


gemma_type_to_dimension_mapping = {
    "block_input": ("hidden_size",),
    "block_output": ("hidden_size",),
    "mlp_activation": ("intermediate_size",),
    "mlp_output": ("hidden_size",),
    "mlp_input": ("hidden_size",),
    "attention_value_output": ("hidden_size",),
    "head_attention_value_output": ("hidden_size/num_attention_heads",),
    "attention_output": ("hidden_size",),
    "attention_input": ("hidden_size",),
    "query_output": ("hidden_size",),
    "key_output": ("hidden_size",),
    "value_output": ("hidden_size",),
    "head_query_output": ("hidden_size/num_attention_heads",),
    "head_key_output": ("hidden_size/num_attention_heads",),
    "head_value_output": ("hidden_size/num_attention_heads",),
}
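For context: the %s in each module template is later filled with a layer index, and each dimension string is arithmetic over Hugging Face config fields. An illustrative resolution (pyvene performs this internally; the snippet only unpacks the tables above):

# Illustrative only; not pyvene's actual resolution code.
template, hook = gemma_type_to_module_mapping["head_attention_value_output"]
print(template % 7)  # -> layers[7].self_attn.o_proj

# "hidden_size/num_attention_heads" denotes per-head width: for gemma-2b
# (hidden_size=2048, num_attention_heads=8) that evaluates to 256.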


"""gemma model with LM head"""
gemma_lm_type_to_module_mapping = {}
for k, v in gemma_type_to_module_mapping.items():
gemma_lm_type_to_module_mapping[k] = (f"model.{v[0]}", v[1])


gemma_lm_type_to_dimension_mapping = gemma_type_to_dimension_mapping


"""gemma model with classifier head"""
gemma_classifier_type_to_module_mapping = {}
for k, v in gemma_type_to_module_mapping.items():
gemma_classifier_type_to_module_mapping[k] = (f"model.{v[0]}", v[1])


gemma_classifier_type_to_dimension_mapping = gemma_type_to_dimension_mapping


def create_gemma(
    name="google/gemma-2b-it", cache_dir=None, dtype=torch.bfloat16
):
    """Creates a Gemma Causal LM model, config, and tokenizer from the given name"""
    from transformers import GemmaForCausalLM, GemmaTokenizer, GemmaConfig

    config = GemmaConfig.from_pretrained(name, cache_dir=cache_dir)
    tokenizer = GemmaTokenizer.from_pretrained(name, cache_dir=cache_dir)
    gemma = GemmaForCausalLM.from_pretrained(
        name,
        config=config,
        cache_dir=cache_dir,
        torch_dtype=dtype,  # save memory
    )
    print("loaded model")
    return config, tokenizer, gemma
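The factory drops into the usual pyvene workflow. A sketch following the IntervenableModel pattern from pyvene's README (the zero-vector intervention and token position are illustrative choices):

import torch
import pyvene as pv

config, tokenizer, gemma = create_gemma()

# Zero out the block output of layer 0 at token position 3.
pv_gemma = pv.IntervenableModel({
    "layer": 0,
    "component": "block_output",
    "source_representation": torch.zeros(config.hidden_size, dtype=torch.bfloat16),
}, model=gemma)

base = tokenizer("The capital of Spain is", return_tensors="pt")
orig_outputs, intervened_outputs = pv_gemma(
    base, unit_locations={"base": 3}, output_original_output=True
)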
15 changes: 10 additions & 5 deletions pyvene/models/intervenable_base.py
@@ -47,7 +47,7 @@ def __init__(self, config, model, **kwargs):
self.mode = config.mode
intervention_type = config.intervention_types
self.is_model_stateless = is_stateless(model)
-self.config.model_type = type(model) # backfill
+self.config.model_type = str(type(model)) # backfill
self.use_fast = kwargs["use_fast"] if "use_fast" in kwargs else False

self.model_has_grad = False
@@ -1320,6 +1320,7 @@ def forward(
labels: Optional[torch.LongTensor] = None,
output_original_output: Optional[bool] = False,
return_dict: Optional[bool] = None,
+use_cache: Optional[bool] = True,
):
"""
Main forward function that serves a wrapper to
@@ -1439,10 +1440,14 @@
)

# run intervened forward
-if labels is not None:
-    counterfactual_outputs = self.model(**base, labels=labels)
-else:
-    counterfactual_outputs = self.model(**base)
+model_kwargs = {}
+if labels is not None: # for training
+    model_kwargs["labels"] = labels
+if 'use_cache' in self.model.config.to_dict(): # for transformer models
+    model_kwargs["use_cache"] = use_cache
+
+counterfactual_outputs = self.model(**base, **model_kwargs)

set_handlers_to_remove.remove()

self._output_validation()
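The rewritten call site assembles keyword arguments conditionally instead of branching on labels, so use_cache is forwarded only to models whose config actually has that field. The same guard pattern in isolation (a toy sketch, not pyvene's code):

def call_with_optional_kwargs(model, base, labels=None, use_cache=True):
    model_kwargs = {}
    if labels is not None:  # training: let the model compute a loss
        model_kwargs["labels"] = labels
    if "use_cache" in model.config.to_dict():  # only models that accept it
        model_kwargs["use_cache"] = use_cache
    return model(**base, **model_kwargs)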
12 changes: 12 additions & 0 deletions pyvene/models/intervenable_modelcard.py
@@ -1,6 +1,7 @@
from .constants import *
from .llama.modelings_intervenable_llama import *
from .mistral.modellings_intervenable_mistral import *
+from .gemma.modelings_intervenable_gemma import *
from .gpt2.modelings_intervenable_gpt2 import *
from .gpt_neo.modelings_intervenable_gpt_neo import *
from .gpt_neox.modelings_intervenable_gpt_neox import *
@@ -9,6 +10,7 @@
from .blip.modelings_intervenable_blip import *
from .blip.modelings_intervenable_blip_itm import *
from .backpack_gpt2.modelings_intervenable_backpack_gpt2 import *
+from .llava.modelings_intervenable_llava import *


#########################################################################
@@ -39,12 +41,17 @@
hf_models.gpt2.modeling_gpt2.GPT2ForSequenceClassification: gpt2_classifier_type_to_module_mapping,
hf_models.llama.modeling_llama.LlamaModel: llama_type_to_module_mapping,
hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_module_mapping,
+hf_models.llama.modeling_llama.LlamaForSequenceClassification: llama_classifier_type_to_module_mapping,
+hf_models.llava.modeling_llava.LlavaForConditionalGeneration: llava_type_to_module_mapping,
hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_module_mapping,
hf_models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM: gpt_neo_lm_type_to_module_mapping,
hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXModel: gpt_neox_type_to_module_mapping,
hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM: gpt_neox_lm_type_to_module_mapping,
hf_models.mistral.modeling_mistral.MistralModel: mistral_type_to_module_mapping,
hf_models.mistral.modeling_mistral.MistralForCausalLM: mistral_lm_type_to_module_mapping,
+hf_models.gemma.modeling_gemma.GemmaModel: gemma_type_to_module_mapping,
+hf_models.gemma.modeling_gemma.GemmaForCausalLM: gemma_lm_type_to_module_mapping,
+hf_models.gemma.modeling_gemma.GemmaForSequenceClassification: gemma_classifier_type_to_module_mapping,
hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_module_mapping,
hf_models.blip.modeling_blip.BlipForImageTextRetrieval: blip_itm_type_to_module_mapping,
BlipWrapper: blip_wrapper_type_to_module_mapping,
@@ -65,12 +72,17 @@
hf_models.gpt2.modeling_gpt2.GPT2ForSequenceClassification: gpt2_classifier_type_to_dimension_mapping,
hf_models.llama.modeling_llama.LlamaModel: llama_type_to_dimension_mapping,
hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_dimension_mapping,
+hf_models.llama.modeling_llama.LlamaForSequenceClassification: llama_classifier_type_to_dimension_mapping,
+hf_models.llava.modeling_llava.LlavaForConditionalGeneration: llava_type_to_dimension_mapping,
hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_dimension_mapping,
hf_models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM: gpt_neo_lm_type_to_dimension_mapping,
hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXModel: gpt_neox_type_to_dimension_mapping,
hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM: gpt_neox_lm_type_to_dimension_mapping,
hf_models.mistral.modeling_mistral.MistralModel: mistral_type_to_dimension_mapping,
hf_models.mistral.modeling_mistral.MistralForCausalLM: mistral_lm_type_to_dimension_mapping,
+hf_models.gemma.modeling_gemma.GemmaModel: gemma_type_to_dimension_mapping,
+hf_models.gemma.modeling_gemma.GemmaForCausalLM: gemma_lm_type_to_dimension_mapping,
+hf_models.gemma.modeling_gemma.GemmaForSequenceClassification: gemma_classifier_type_to_dimension_mapping,
hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_dimension_mapping,
hf_models.blip.modeling_blip.BlipForImageTextRetrieval: blip_itm_type_to_dimension_mapping,
BlipWrapper: blip_wrapper_type_to_dimension_mapping,
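These registries key on concrete Hugging Face classes, so a model instance finds its mappings by type. An illustrative lookup; the registry names below follow the file's type_to_*_mapping convention, which this excerpt's hunk headers elide:

from transformers import GemmaForCausalLM

model = GemmaForCausalLM.from_pretrained("google/gemma-2b-it")
module_map = type_to_module_mapping[type(model)]   # -> gemma_lm_type_to_module_mapping
dim_map = type_to_dimension_mapping[type(model)]   # -> gemma_lm_type_to_dimension_mapping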
9 changes: 9 additions & 0 deletions pyvene/models/llama/modelings_intervenable_llama.py
@@ -60,6 +60,15 @@
llama_lm_type_to_dimension_mapping = llama_type_to_dimension_mapping


"""llama model with classifier head"""
llama_classifier_type_to_module_mapping = {}
for k, v in llama_type_to_module_mapping.items():
llama_classifier_type_to_module_mapping[k] = (f"model.{v[0]}", v[1])


llama_classifier_type_to_dimension_mapping = llama_type_to_dimension_mapping


def create_llama(
name="sharpbai/alpaca-7b-merged", cache_dir=None, dtype=torch.bfloat16
):
Empty file added pyvene/models/llava/__init__.py
(Diff for the remaining 8 changed files not shown.)
