Merge pull request #93 from stanfordnlp/zen/stringaccess
[Major] Update with string access and code refactoring (#83)
frankaging authored Jan 26, 2024
2 parents cdeab77 + 4758ebc commit d1f7c58
Showing 18 changed files with 433 additions and 1,203 deletions.
4 changes: 3 additions & 1 deletion .gitattributes
@@ -1,2 +1,4 @@
tutorials/* linguist-vendored
tutorials/basic_tutorials/* linguist-vendored
tutorials/advanced_tutorials/* linguist-vendored
pyvene_101.ipynb
tests/qa_runbook.ipynb linguist-vendored
pyvene/models/backpack_gpt2/modelings_intervenable_backpack_gpt2.py
@@ -9,64 +9,17 @@
"""


from ..constants import CONST_INPUT_HOOK, CONST_OUTPUT_HOOK, CONST_QKV_INDICES
from ..constants import *


"""gpt2 base model"""
backpack_gpt2_lm_type_to_module_mapping = {
"block_input": ("backpack.gpt2_model.h[%s]", CONST_INPUT_HOOK),
"block_output": ("backpack.gpt2_model.h[%s]", CONST_OUTPUT_HOOK),
"mlp_activation": ("backpack.gpt2_model.h[%s].mlp.act", CONST_OUTPUT_HOOK),
"mlp_output": ("backpack.gpt2_model.h[%s].mlp", CONST_OUTPUT_HOOK),
"mlp_input": ("backpack.gpt2_model.h[%s].mlp", CONST_INPUT_HOOK),
"attention_value_output": ("backpack.gpt2_model.h[%s].attn.c_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("backpack.gpt2_model.h[%s].attn.c_proj", CONST_INPUT_HOOK),
"attention_output": ("backpack.gpt2_model.h[%s].attn", CONST_OUTPUT_HOOK),
"attention_input": ("backpack.gpt2_model.h[%s].attn", CONST_INPUT_HOOK),
"query_output": ("backpack.gpt2_model.h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"key_output": ("backpack.gpt2_model.h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"value_output": ("backpack.gpt2_model.h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"head_query_output": ("backpack.gpt2_model.h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"head_key_output": ("backpack.gpt2_model.h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"head_value_output": ("backpack.gpt2_model.h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"sense_output": ("backpack.sense_network", CONST_OUTPUT_HOOK),
"sense_block_output": ("backpack.sense_network.block", CONST_OUTPUT_HOOK),
"sense_mlp_input": ("backpack.sense_network.final_mlp", CONST_INPUT_HOOK),
"sense_mlp_output": ("backpack.sense_network.final_mlp", CONST_OUTPUT_HOOK),
"sense_mlp_activation": ("backpack.sense_network.final_mlp.act", CONST_OUTPUT_HOOK),
"sense_weight_input": ("backpack.sense_weight_net", CONST_INPUT_HOOK),
"sense_weight_output": ("backpack.sense_weight_net", CONST_OUTPUT_HOOK),
"sense_network_output": ("backpack.sense_network", CONST_OUTPUT_HOOK),
}


backpack_gpt2_lm_type_to_dimension_mapping = {
"block_input": ("n_embd",),
"block_output": ("n_embd",),
"mlp_activation": (
"n_inner",
"n_embd*4",
),
"mlp_output": ("n_embd",),
"mlp_input": ("n_embd",),
"attention_value_output": ("n_embd",),
"head_attention_value_output": ("n_embd/n_head",),
"attention_output": ("n_embd",),
"attention_input": ("n_embd",),
"query_output": ("n_embd",),
"key_output": ("n_embd",),
"value_output": ("n_embd",),
"head_query_output": ("n_embd/n_head",),
"head_key_output": ("n_embd/n_head",),
"head_value_output": ("n_embd/n_head",),
"sense_output": ("n_embd",),
"sense_block_output": ("n_embd",),
"sense_mlp_input": ("n_embd",),
"sense_mlp_output": ("n_embd",),
"num_senses": ("num_senses",),
"sense_mlp_activation": (
"n_inner",
"n_embd*4",
),
"sense_network_output": ("n_embd",),
}

def create_backpack_gpt2(name="stanfordnlp/backpack-gpt2", cache_dir=None):
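Each value in these mapping dictionaries pairs a module-path template (where "%s" stands in for the layer index) with the name of the PyTorch hook-registration method to use (CONST_INPUT_HOOK or CONST_OUTPUT_HOOK). As a rough illustration of how such a template could be resolved to a concrete submodule — `resolve_module` below is a hypothetical helper written for this note, not pyvene's actual resolver:

```python
import re
import torch.nn as nn

def resolve_module(model: nn.Module, template: str, layer: int) -> nn.Module:
    """Walk a dotted path such as 'backpack.gpt2_model.h[%s].mlp' to a submodule.

    Hypothetical helper for illustration; pyvene's real resolver lives in its
    modeling utilities and may differ.
    """
    path = template % layer if "%s" in template else template
    module = model
    for part in path.split("."):
        indexed = re.match(r"(\w+)\[(\d+)\]$", part)
        if indexed:  # indexed access, e.g. "h[3]" into an nn.ModuleList
            module = getattr(module, indexed.group(1))[int(indexed.group(2))]
        else:
            module = getattr(module, part)
    return module

# e.g. resolve_module(backpack_lm, "backpack.gpt2_model.h[%s].mlp", 3) returns
# the MLP of block 3; the second slot of the mapping ("register_forward_hook" /
# "register_forward_pre_hook") names the method used to attach the hook there.
```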
2 changes: 1 addition & 1 deletion pyvene/models/blip/modelings_intervenable_blip.py
@@ -9,7 +9,7 @@
"""


from ..constants import CONST_INPUT_HOOK, CONST_OUTPUT_HOOK, CONST_QKV_INDICES
from ..constants import *

"""blip base model"""
blip_type_to_module_mapping = {
91 changes: 10 additions & 81 deletions pyvene/models/constants.py
@@ -1,88 +1,17 @@
CONST_VALID_INTERVENABLE_UNIT = ["pos", "h", "h.pos", "t"]

import torch

CONST_INPUT_HOOK = "register_forward_pre_hook"
CONST_OUTPUT_HOOK = "register_forward_hook"
CONST_GRAD_HOOK = "register_hook"


CONST_TRANSFORMER_TOPOLOGICAL_ORDER = [
"block_input",
"query_output",
"head_query_output",
"key_output",
"head_key_output",
"value_output",
"head_value_output",
"attention_input",
"attention_weight",
"head_attention_value_output",
"attention_value_output",
"attention_output",
"cross_attention_input",
"head_cross_attention_value_output",
"cross_attention_value_output",
"cross_attention_output",
"mlp_input",
"mlp_activation",
"mlp_output",
"block_output",
# special keys for backpack model
"sense_block_output",
"sense_mlp_input",
"sense_mlp_activation",
"sense_mlp_output",
"sense_output",
]


CONST_MLP_TOPOLOGICAL_ORDER = [
"block_input",
"mlp_activation",
"block_output",
]


CONST_GRU_TOPOLOGICAL_ORDER = [
"cell_input",
"x2h_output",
"h2h_output",
"reset_x2h_output",
"update_x2h_output",
"new_x2h_output",
"reset_h2h_output",
"update_h2h_output",
"new_h2h_output",
"reset_gate_input",
"update_gate_input",
"new_gate_input",
"reset_gate_output",
"update_gate_output",
"new_gate_output",
"cell_output",
]


CONST_QKV_INDICES = {
"query_output": 0,
"key_output": 1,
"value_output": 2,
"head_query_output": 0,
"head_key_output": 1,
"head_value_output": 2,
"reset_x2h_output": 0,
"update_x2h_output": 1,
"new_x2h_output": 2,
"reset_h2h_output": 0,
"update_h2h_output": 1,
"new_h2h_output": 2,
}
split_and_select = lambda x, num_slice, selct_index: torch.chunk(x, num_slice, dim=-1)[selct_index]
def split_heads(tensor, num_heads, attn_head_size):
"""Splits hidden_size dim into attn_head_size and num_heads."""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)

CONST_RUN_INDICES = {
"reset_x2h_output": 0,
"update_x2h_output": 1,
"new_x2h_output": 2,
"reset_h2h_output": 0,
"update_h2h_output": 1,
"new_h2h_output": 2,
}
split_half = lambda x, selct_index: torch.chunk(x, 2, dim=-1)[selct_index]
split_three = lambda x, selct_index: torch.chunk(x, 3, dim=-1)[selct_index]
split_head_and_permute = lambda x, num_head: split_heads(x, num_head, x.shape[-1]//num_head)
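The refactor replaces the static CONST_QKV_INDICES and CONST_RUN_INDICES lookup tables with the small slicing helpers above, which the per-model mappings now reference directly. A minimal sketch of what they do, using illustrative GPT-2-small shapes (n_embd=768, n_head=12):

```python
import torch
# Helpers defined in pyvene/models/constants.py as of this commit.
from pyvene.models.constants import split_three, split_head_and_permute

# Fused QKV projection output of h[i].attn.c_attn: (batch, seq, 3 * n_embd)
qkv = torch.randn(2, 5, 3 * 768)

q = split_three(qkv, 0)                     # query slice -> (2, 5, 768)
k = split_three(qkv, 1)                     # key slice   -> (2, 5, 768)
per_head_q = split_head_and_permute(q, 12)  # -> (2, 12, 5, 64), head-major

print(q.shape, k.shape, per_head_q.shape)
```

This lets a single hooked module expose query/key/value or per-head views without the hard-coded index constants removed here.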
21 changes: 10 additions & 11 deletions pyvene/models/gpt2/modelings_intervenable_gpt2.py
@@ -9,7 +9,7 @@
"""


from ..constants import CONST_INPUT_HOOK, CONST_OUTPUT_HOOK, CONST_QKV_INDICES
from ..constants import *


"""gpt2 base model"""
@@ -20,20 +20,21 @@
"mlp_output": ("h[%s].mlp", CONST_OUTPUT_HOOK),
"mlp_input": ("h[%s].mlp", CONST_INPUT_HOOK),
"attention_value_output": ("h[%s].attn.c_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("h[%s].attn.c_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("h[%s].attn.c_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
"attention_weight": ("h[%s].attn.attn_dropout", CONST_INPUT_HOOK),
"attention_output": ("h[%s].attn", CONST_OUTPUT_HOOK),
"attention_input": ("h[%s].attn", CONST_INPUT_HOOK),
"query_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"key_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"value_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"head_query_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"head_key_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"head_value_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"query_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 0)),
"key_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 1)),
"value_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 2)),
"head_query_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 0), (split_head_and_permute, "n_head")),
"head_key_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 1), (split_head_and_permute, "n_head")),
"head_value_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 2), (split_head_and_permute, "n_head")),
}


gpt2_type_to_dimension_mapping = {
"n_head": ("n_head", ),
"block_input": ("n_embd",),
"block_output": ("n_embd",),
"mlp_activation": (
@@ -44,7 +45,6 @@
"mlp_input": ("n_embd",),
"attention_value_output": ("n_embd",),
"head_attention_value_output": ("n_embd/n_head",),
# attention weight dimension does not really matter
"attention_weight": ("max_position_embeddings", ),
"attention_output": ("n_embd",),
"attention_input": ("n_embd",),
@@ -60,8 +60,7 @@
"""gpt2 model with LM head"""
gpt2_lm_type_to_module_mapping = {}
for k, v in gpt2_type_to_module_mapping.items():
gpt2_lm_type_to_module_mapping[k] = (f"transformer.{v[0]}", v[1])

gpt2_lm_type_to_module_mapping[k] = (f"transformer.{v[0]}", ) + v[1:]

gpt2_lm_type_to_dimension_mapping = gpt2_type_to_dimension_mapping

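The gpt2 entries above now append optional transform specs after the hook name, e.g. `("h[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 0), (split_head_and_permute, "n_head"))`; presumably the hook machinery applies each (fn, arg) pair in order to the raw activation, resolving string arguments such as "n_head" against the model config (hence the new "n_head" entry in the dimension mapping). The LM-head loop keeps those extras by prefixing with `(f"transformer.{v[0]}", ) + v[1:]`, where the old `(f"transformer.{v[0]}", v[1])` would have dropped them. A hedged sketch of how one entry might be applied — `apply_specs` and `Cfg` are illustrative names, not pyvene API:

```python
import torch
from pyvene.models.constants import split_three, split_head_and_permute

def apply_specs(activation, specs, config):
    """Apply trailing (fn, arg) transform specs from a module-mapping entry."""
    for fn, arg in specs:
        if isinstance(arg, str):   # e.g. "n_head" is looked up on the config
            arg = getattr(config, arg)
        activation = fn(activation, arg)
    return activation

class Cfg:       # stand-in for a GPT2Config
    n_head = 12

# Transform specs from the "head_query_output" entry above.
specs = ((split_three, 0), (split_head_and_permute, "n_head"))
fused = torch.randn(2, 5, 3 * 768)        # raw c_attn output
heads = apply_specs(fused, specs, Cfg())  # -> (2, 12, 5, 64)
```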
11 changes: 6 additions & 5 deletions pyvene/models/gpt_neo/modelings_intervenable_gpt_neo.py
@@ -9,7 +9,7 @@
"""


from ..constants import CONST_INPUT_HOOK, CONST_OUTPUT_HOOK, CONST_QKV_INDICES
from ..constants import *


"""gpt_neo base model"""
@@ -20,19 +20,20 @@
"mlp_output": ("h[%s].mlp", CONST_OUTPUT_HOOK),
"mlp_input": ("h[%s].mlp", CONST_INPUT_HOOK),
"attention_value_output": ("h[%s].attn.out_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("h[%s].attn.out_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("h[%s].attn.out_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
"attention_output": ("h[%s].attn", CONST_OUTPUT_HOOK),
"attention_input": ("h[%s].attn", CONST_INPUT_HOOK),
"query_output": ("h[%s].attn.q_proj", CONST_OUTPUT_HOOK),
"key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK),
"value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK),
"head_query_output": ("h[%s].attn.q_proj", CONST_OUTPUT_HOOK),
"head_key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK),
"head_value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK),
"head_query_output": ("h[%s].attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
"head_key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
"head_value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
}


gpt_neo_type_to_dimension_mapping = {
"n_head": "num_heads",
"block_input": ("hidden_size",),
"block_output": ("hidden_size",),
"mlp_activation": (
5 changes: 3 additions & 2 deletions pyvene/models/gpt_neox/modelings_intervenable_gpt_neox.py
@@ -9,7 +9,7 @@
"""


from ..constants import CONST_INPUT_HOOK, CONST_OUTPUT_HOOK, CONST_QKV_INDICES
from ..constants import *


"""gpt_neox base model"""
@@ -20,7 +20,7 @@
"mlp_output": ("layers[%s].mlp", CONST_OUTPUT_HOOK),
"mlp_input": ("layers[%s].mlp", CONST_INPUT_HOOK),
"attention_value_output": ("layers[%s].attention.dense", CONST_INPUT_HOOK),
"head_attention_value_output": ("layers[%s].attention.dense", CONST_INPUT_HOOK),
"head_attention_value_output": ("layers[%s].attention.dense", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
"attention_output": ("layers[%s].attention", CONST_OUTPUT_HOOK),
"attention_input": ("layers[%s].attention", CONST_INPUT_HOOK),
# 'query_output': ("layers[%s].attention.query_key_value", CONST_OUTPUT_HOOK),
@@ -33,6 +33,7 @@


gpt_neox_type_to_dimension_mapping = {
"n_head": "num_attention_heads",
"block_input": ("hidden_size",),
"block_output": ("hidden_size",),
"mlp_activation": (
16 changes: 8 additions & 8 deletions pyvene/models/gru/modelings_intervenable_gru.py
@@ -9,7 +9,7 @@
"""


from ..constants import CONST_INPUT_HOOK, CONST_OUTPUT_HOOK
from ..constants import *


"""gru base model"""
@@ -23,12 +23,12 @@
"new_gate_output": ("cells[%s].new_act", CONST_OUTPUT_HOOK),
"x2h_output": ("cells[%s].x2h", CONST_OUTPUT_HOOK),
"h2h_output": ("cells[%s].h2h", CONST_OUTPUT_HOOK),
"reset_x2h_output": ("cells[%s].x2h", CONST_OUTPUT_HOOK),
"update_x2h_output": ("cells[%s].x2h", CONST_OUTPUT_HOOK),
"new_x2h_output": ("cells[%s].x2h", CONST_OUTPUT_HOOK),
"reset_h2h_output": ("cells[%s].h2h", CONST_OUTPUT_HOOK),
"update_h2h_output": ("cells[%s].h2h", CONST_OUTPUT_HOOK),
"new_h2h_output": ("cells[%s].h2h", CONST_OUTPUT_HOOK),
"reset_x2h_output": ("cells[%s].x2h", CONST_OUTPUT_HOOK, (split_three, 0)),
"update_x2h_output": ("cells[%s].x2h", CONST_OUTPUT_HOOK, (split_three, 1)),
"new_x2h_output": ("cells[%s].x2h", CONST_OUTPUT_HOOK, (split_three, 2)),
"reset_h2h_output": ("cells[%s].h2h", CONST_OUTPUT_HOOK, (split_three, 0)),
"update_h2h_output": ("cells[%s].h2h", CONST_OUTPUT_HOOK, (split_three, 1)),
"new_h2h_output": ("cells[%s].h2h", CONST_OUTPUT_HOOK, (split_three, 2)),
"cell_output": ("cells[%s]", CONST_OUTPUT_HOOK),
}

@@ -56,7 +56,7 @@
"""mlp model with classification head"""
gru_classifier_type_to_module_mapping = {}
for k, v in gru_type_to_module_mapping.items():
gru_classifier_type_to_module_mapping[k] = (f"gru.{v[0]}", v[1])
gru_classifier_type_to_module_mapping[k] = (f"gru.{v[0]}", ) + v[1:]

gru_classifier_type_to_dimension_mapping = gru_type_to_dimension_mapping

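For the GRU cell, the same split_three trick slices the fused x2h and h2h projections into reset/update/new gate pre-activations (indices 0, 1, 2 in the mapping above). A short sketch of the standard GRU gate equations using those slices — the toy computation below follows the textbook formulation and is only assumed to match pyvene's tutorial GRU implementation:

```python
import torch
from pyvene.models.constants import split_three

hidden = 64
x2h_out = torch.randn(2, 3 * hidden)   # output of cells[i].x2h (fused gates)
h2h_out = torch.randn(2, 3 * hidden)   # output of cells[i].h2h
h_prev = torch.randn(2, hidden)

# Standard GRU gating; index order matches the mapping: reset=0, update=1, new=2.
r = torch.sigmoid(split_three(x2h_out, 0) + split_three(h2h_out, 0))
z = torch.sigmoid(split_three(x2h_out, 1) + split_three(h2h_out, 1))
n = torch.tanh(split_three(x2h_out, 2) + r * split_three(h2h_out, 2))
h_new = (1 - z) * n + z * h_prev       # cell_output
```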