Skip to content

Commit

Permalink
Merge pull request #129 from stanfordnlp/zen/sharedwxpos
Browse files Browse the repository at this point in the history
[Minor] Allow sharing interventions across multiple positions
  • Loading branch information
frankaging authored Mar 14, 2024
2 parents 96db4e9 + 271fa09 commit 7504fb6
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 49 deletions.
3 changes: 2 additions & 1 deletion pyvene/models/interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def __init__(self, **kwargs):
super().__init__()
self.trainable = False
self.is_source_constant = False


self.keep_last_dim = kwargs["keep_last_dim"] if "keep_last_dim" in kwargs else False
self.use_fast = kwargs["use_fast"] if "use_fast" in kwargs else False
self.subspace_partition = (
kwargs["subspace_partition"] if "subspace_partition" in kwargs else None
Expand Down
8 changes: 4 additions & 4 deletions pyvene/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,8 @@ def do_intervention(
original_base_shape = base_representation.shape
if len(original_base_shape) == 2 or (
isinstance(intervention, LocalistRepresentationIntervention)
):
# no pos dimension, e.g., gru
) or intervention.keep_last_dim:
# no pos dimension, e.g., gru, or opt-out concate last two dims
base_representation_f = base_representation
source_representation_f = source_representation
elif len(original_base_shape) == 3:
Expand All @@ -459,8 +459,8 @@ def do_intervention(
# unflatten
if len(original_base_shape) == 2 or isinstance(
intervention, LocalistRepresentationIntervention
):
# no pos dimension, e.g., gru
) or intervention.keep_last_dim:
# no pos dimension, e.g., gru or opt-out concate last two dims
pass
elif len(original_base_shape) == 3:
intervened_representation = b_sd_to_bsd(intervened_representation, num_unit)
Expand Down
171 changes: 127 additions & 44 deletions pyvene_101.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -126,34 +126,36 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"id": "17c7f2f6-b0d3-4fe2-8e4f-c044b93f3ef0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loaded model\n"
]
}
],
"outputs": [],
"source": [
"import torch\n",
"import pyvene as pv\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"\n",
"_, tokenizer, gpt2 = pv.create_gpt2()\n",
"model_name = \"gpt2\"\n",
"model = AutoModelForCausalLM.from_pretrained(model_name)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
"pv_gpt2 = pv.IntervenableModel({\n",
" \"layer\": 10,\n",
" \"component\": \"attention_weight\",\n",
" \"intervention_type\": pv.CollectIntervention}, model=gpt2)\n",
"# create a dict-based intervention config\n",
"pv_config = pv.IntervenableConfig({\n",
" \"component\": \"transformer.h[0].mlp.output\"},\n",
" intervention_types=pv.VanillaIntervention\n",
")\n",
"# wrap your model with the config\n",
"pv_gpt2 = pv.IntervenableModel(pv_config, model=model)\n",
"\n",
"base = \"When John and Mary went to the shops, Mary gave the bag to\"\n",
"collected_attn_w = pv_gpt2(\n",
" base = tokenizer(base, return_tensors=\"pt\"\n",
" ), unit_locations={\"base\": [h for h in range(12)]}\n",
")[0][-1][0]"
"# run an interchange intervention (activation swap between two examples)\n",
"intervened_outputs = pv_gpt2(\n",
" # the base input\n",
" base=tokenizer(\"The capital of Spain is\", return_tensors = \"pt\"), \n",
" # the source input\n",
" sources=tokenizer(\"The capital of Italy is\", return_tensors = \"pt\"), \n",
" # the location to intervene at (3rd token)\n",
" unit_locations={\"sources->base\": 3},\n",
" output_original_output=True # False then the first element in the tuple is None\n",
")"
]
},
{
Expand All @@ -166,46 +168,49 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 12,
"id": "1ef4a1db-5187-4457-9878-f1dc03e9859b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"GPT2Model(\n",
" (wte): Embedding(50257, 768)\n",
" (wpe): Embedding(1024, 768)\n",
" (drop): Dropout(p=0.1, inplace=False)\n",
" (h): ModuleList(\n",
" (0-11): 12 x GPT2Block(\n",
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (attn): GPT2Attention(\n",
" (c_attn): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): GPT2MLP(\n",
" (c_fc): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (act): NewGELUActivation()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
"GPT2LMHeadModel(\n",
" (transformer): GPT2Model(\n",
" (wte): Embedding(50257, 768)\n",
" (wpe): Embedding(1024, 768)\n",
" (drop): Dropout(p=0.1, inplace=False)\n",
" (h): ModuleList(\n",
" (0-11): 12 x GPT2Block(\n",
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (attn): GPT2Attention(\n",
" (c_attn): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): GPT2MLP(\n",
" (c_fc): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (act): NewGELUActivation()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
")"
]
},
"execution_count": 15,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gpt2"
"model"
]
},
{
Expand Down Expand Up @@ -1978,6 +1983,84 @@
"print(torch.equal(pv_out3.last_hidden_state, pv_out4.last_hidden_state))"
]
},
{
"cell_type": "markdown",
"id": "243f146f-1b9a-4574-ba2c-ebf455a96c16",
"metadata": {},
"source": [
"Other than intervention linking, you can also share interventions at the same component across multiple positions via setting a flag in the intervention object. It will have the same effect as creating one intervention per location and linking them all together."
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "7c647943-c7e1-4024-8c07-b51062e668ba",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loaded model\n"
]
},
{
"data": {
"text/plain": [
"tensor([[[0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.]]])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"import pyvene as pv\n",
"\n",
"_, tokenizer, gpt2 = pv.create_gpt2()\n",
"\n",
"config = pv.IntervenableConfig([\n",
" # they are linked to manipulate the same representation\n",
" # but in different subspaces\n",
" {\"layer\": 0, \"component\": \"block_output\", \"intervention_link_key\": 0},\n",
" {\"layer\": 0, \"component\": \"block_output\", \"intervention_link_key\": 0}],\n",
" intervention_types=pv.VanillaIntervention,\n",
")\n",
"pv_gpt2 = pv.IntervenableModel(config, model=gpt2)\n",
"\n",
"base = tokenizer(\"The capital of Spain is\", return_tensors=\"pt\")\n",
"source = tokenizer(\"The capital of Italy is\", return_tensors=\"pt\")\n",
"\n",
"_, pv_out = pv_gpt2(\n",
" base,\n",
" [source, source],\n",
" # swap 3rd and 4th token reprs from the same source to the base\n",
" {\"sources->base\": ([[[4]], [[3]]], [[[4]], [[3]]])},\n",
")\n",
"\n",
"keep_last_dim_config = pv.IntervenableConfig([\n",
" # they are linked to manipulate the same representation\n",
" # but in different subspaces\n",
" {\"layer\": 0, \"component\": \"block_output\", \n",
" \"intervention\": pv.VanillaIntervention(keep_last_dim=True)}]\n",
")\n",
"keep_last_dim_pv_gpt2 = pv.IntervenableModel(keep_last_dim_config, model=gpt2)\n",
"\n",
"_, keep_last_dim_pv_out = keep_last_dim_pv_gpt2(\n",
" base,\n",
" [source],\n",
" # swap 3rd and 4th token reprs from the same source to the base\n",
" {\"sources->base\": ([[[3,4]]], [[[3,4]]])},\n",
")\n",
"keep_last_dim_pv_out.last_hidden_state - pv_out.last_hidden_state"
]
},
{
"cell_type": "markdown",
"id": "ef5b7a3e",
Expand Down

0 comments on commit 7504fb6

Please sign in to comment.