From bd33ec9259669b4755b0d5c3d1bdb1294d22ae67 Mon Sep 17 00:00:00 2001
From: PinetreePantry
Date: Wed, 29 Jan 2025 14:00:43 -0800
Subject: [PATCH 1/5] [P0] Fix test failures due to transformers version change

---
 pyvene/models/intervenable_base.py | 11 ++++++++++-
 tests/utils.py                     | 24 ++++++++++++++++++++----
 2 files changed, 30 insertions(+), 5 deletions(-)

diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py
index db84b57d..61e51e96 100644
--- a/pyvene/models/intervenable_base.py
+++ b/pyvene/models/intervenable_base.py
@@ -1292,7 +1292,16 @@ def load(
             binary_filename = f"intkey_{k}.bin"
             intervention.is_source_constant = \
                 saving_config.intervention_constant_sources[i]
-            intervention.set_interchange_dim(saving_config.intervention_dimensions[i])
+            dim = saving_config.intervention_dimensions[i]
+            if dim is None:
+                # Infer interchange dimension from component name to be compatible with old versions
+                component_name = saving_config.representations[i].component
+                if component_name.startswith("head_"):
+                    dim = model.config.hidden_size // model.config.num_attention_heads
+                else:
+                    dim = model.config.hidden_size
+
+            intervention.set_interchange_dim(dim)
             if saving_config.intervention_constant_sources[i] and \
                 not isinstance(intervention, ZeroIntervention) and \
                 not isinstance(intervention, SourcelessIntervention):
diff --git a/tests/utils.py b/tests/utils.py
index 4d6a9076..e3a9bfb5 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -79,6 +79,22 @@ def is_package_installed(package_name):
 forward calls to fetch activations or run with cached activations
 """
 
+def split_heads(tensor, num_heads, attn_head_size):
+    """
+    Splits hidden_size dim into attn_head_size and num_heads
+    """
+    new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+    tensor = tensor.view(new_shape)
+    return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
+
+def merge_heads(tensor, num_heads, attn_head_size):
+    """
+    Merges attn_head_size dim and num_attn_heads dim into hidden_size
+    """
+    tensor = tensor.permute(0, 2, 1, 3).contiguous()
+    new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
+    return tensor.view(new_shape)
+
 
 def DO_INTERVENTION(name, orig_hidden_states, INTERVENTION_ACTIVATIONS):
     if name in INTERVENTION_ACTIVATIONS:
@@ -100,9 +116,9 @@ def GPT2_SELF_ATTENTION_RUN(
     value = DO_INTERVENTION(f"{i}.value_output", value, INTERVENTION_ACTIVATIONS)
     CACHE_ACTIVATIONS[f"{i}.value_output"] = value
 
-    head_query = self_attn._split_heads(query, self_attn.num_heads, self_attn.head_dim)
-    head_key = self_attn._split_heads(key, self_attn.num_heads, self_attn.head_dim)
-    head_value = self_attn._split_heads(value, self_attn.num_heads, self_attn.head_dim)
+    head_query = split_heads(query, self_attn.num_heads, self_attn.head_dim)
+    head_key = split_heads(key, self_attn.num_heads, self_attn.head_dim)
+    head_value = split_heads(value, self_attn.num_heads, self_attn.head_dim)
 
     head_query = DO_INTERVENTION(
         f"{i}.head_query_output", head_query, INTERVENTION_ACTIVATIONS
@@ -128,7 +144,7 @@ def GPT2_SELF_ATTENTION_RUN(
     )
     CACHE_ACTIVATIONS[f"{i}.head_attention_value_output"] = head_attention_value_output
 
-    attn_value_output = self_attn._merge_heads(
+    attn_value_output = merge_heads(
         head_attention_value_output, self_attn.num_heads, self_attn.head_dim
     )
     attn_value_output = DO_INTERVENTION(

From 3f307939bb441ce52ece10c08d3607b0cb08d6ef Mon Sep 17 00:00:00 2001
From: PinetreePantry
Date: Wed, 29 Jan 2025 14:24:37 -0800
Subject: [PATCH 2/5] Additional test fixes

---
 tests/utils.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/tests/utils.py b/tests/utils.py
index e3a9bfb5..62927c85 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -7,6 +7,7 @@
 import pandas as pd
 import numpy as np
 from transformers import GPT2Config, LlamaConfig
+from transformers.models.gpt2.modeling_gpt2 import eager_attention_forward
 import math
 from torch import nn
 
@@ -133,10 +134,14 @@ def GPT2_SELF_ATTENTION_RUN(
     )
     CACHE_ACTIVATIONS[f"{i}.head_value_output"] = head_value
 
-    head_attention_value_output, attn_weights = self_attn._attn(
-        head_query, head_key, head_value
+    head_attention_value_output, attn_weights = eager_attention_forward(
+        module=self_attn,
+        query=head_query,
+        key=head_key,
+        value=head_value,
     )
+
     head_attention_value_output = DO_INTERVENTION(
         f"{i}.head_attention_value_output",

From 529446283247d6f5911ffec804adac97f0b551ea Mon Sep 17 00:00:00 2001
From: PinetreePantry
Date: Wed, 29 Jan 2025 14:29:19 -0800
Subject: [PATCH 3/5] Additional test fixes

---
 tests/utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/utils.py b/tests/utils.py
index 62927c85..94dce25e 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -134,11 +134,12 @@ def GPT2_SELF_ATTENTION_RUN(
     )
     CACHE_ACTIVATIONS[f"{i}.head_value_output"] = head_value
 
-    head_attention_value_output, attn_weights = eager_attention_forward(
+    head_attention_value_output = eager_attention_forward(
         module=self_attn,
         query=head_query,
         key=head_key,
         value=head_value,
+        attention_mask=None,
     )

From 3c1aabe067e580d7e79a95e3a0eadadb241fa992 Mon Sep 17 00:00:00 2001
From: PinetreePantry
Date: Wed, 29 Jan 2025 15:58:27 -0800
Subject: [PATCH 4/5] Fix test cases

---
 tests/utils.py | 52 ++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 42 insertions(+), 10 deletions(-)

diff --git a/tests/utils.py b/tests/utils.py
index 94dce25e..69f27cca 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -134,7 +134,7 @@ def GPT2_SELF_ATTENTION_RUN(
     )
     CACHE_ACTIVATIONS[f"{i}.head_value_output"] = head_value
 
-    head_attention_value_output = eager_attention_forward(
+    head_attention_value_output, _ = eager_attention_forward(
         module=self_attn,
         query=head_query,
         key=head_key,
         value=head_value,
         attention_mask=None,
     )
 
+    head_attention_value_output = head_attention_value_output.permute(0, 2, 1, 3)
     head_attention_value_output = DO_INTERVENTION(
         f"{i}.head_attention_value_output",
@@ -309,7 +310,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed
 
 def Llama_SELF_ATTENTION_RUN(
-    self_attn, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
+    self_attn,
+    hidden_states,
+    i,
+    CACHE_ACTIVATIONS,
+    INTERVENTION_ACTIVATIONS,
+    num_heads,
+    num_key_value_heads,
+    rotary_emb
 ):
     bsz, q_len, _ = hidden_states.size()
 
@@ -324,9 +332,9 @@ def Llama_SELF_ATTENTION_RUN(
     value = DO_INTERVENTION(f"{i}.value_output", value, INTERVENTION_ACTIVATIONS)
     CACHE_ACTIVATIONS[f"{i}.value_output"] = value
 
-    head_query = query.view(bsz, q_len, self_attn.num_heads, self_attn.head_dim).transpose(1, 2)
-    head_key = key.view(bsz, q_len, self_attn.num_key_value_heads, self_attn.head_dim).transpose(1, 2)
-    head_value = value.view(bsz, q_len, self_attn.num_key_value_heads, self_attn.head_dim).transpose(1, 2)
+    head_query = query.view(bsz, q_len, num_heads, self_attn.head_dim).transpose(1, 2)
+    head_key = key.view(bsz, q_len, num_key_value_heads, self_attn.head_dim).transpose(1, 2)
+    head_value = value.view(bsz, q_len, num_key_value_heads, self_attn.head_dim).transpose(1, 2)
 
     head_query = DO_INTERVENTION(
         f"{i}.head_query_output", head_query, INTERVENTION_ACTIVATIONS
@@ -342,7 +350,7 @@ def Llama_SELF_ATTENTION_RUN(
     CACHE_ACTIVATIONS[f"{i}.head_value_output"] = head_value
 
     position_ids = torch.arange(q_len, device=hidden_states.device).repeat(bsz, 1)
-    cos, sin = self_attn.rotary_emb(head_value, position_ids)
+    cos, sin = rotary_emb(head_value, position_ids)
     head_query, head_key = apply_rotary_pos_emb(head_query, head_key, cos, sin)
 
     head_key = repeat_kv(head_key, self_attn.num_key_value_groups)
@@ -362,7 +370,7 @@ def Llama_SELF_ATTENTION_RUN(
         INTERVENTION_ACTIVATIONS,
     )
     CACHE_ACTIVATIONS[f"{i}.head_attention_value_output"] = head_attention_value_output
-    attn_value_output = head_attention_value_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self_attn.hidden_size)
+    attn_value_output = head_attention_value_output.transpose(1, 2).contiguous().reshape(bsz, q_len, num_heads * self_attn.head_dim)
     attn_value_output = DO_INTERVENTION(
         f"{i}.attention_value_output", attn_value_output, INTERVENTION_ACTIVATIONS
     )
@@ -386,7 +394,14 @@ def Llama_MLP_RUN(mlp, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVAT
     return hidden_states_down_proj
 
 def Llama_BLOCK_RUN(
-    block, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
+    block,
+    hidden_states,
+    i,
+    CACHE_ACTIVATIONS,
+    INTERVENTION_ACTIVATIONS,
+    num_heads,
+    num_key_value_heads,
+    rotary_emb
 ):
     # self attention + residual
     residual = hidden_states
@@ -398,7 +413,14 @@ def Llama_BLOCK_RUN(
     CACHE_ACTIVATIONS[f"{i}.attention_input"] = hidden_states
 
     attn_outputs = Llama_SELF_ATTENTION_RUN(
-        block.self_attn, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
+        block.self_attn,
+        hidden_states,
+        i,
+        CACHE_ACTIVATIONS,
+        INTERVENTION_ACTIVATIONS,
+        num_heads,
+        num_key_value_heads,
+        rotary_emb
     )
 
     attn_outputs = DO_INTERVENTION(
@@ -439,6 +461,9 @@ def Llama_RUN(llama, input_ids, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS):
     """
     # embed
     inputs_embeds = llama.model.embed_tokens(input_ids)
+    num_heads = llama.model.config.num_attention_heads
+    num_key_value_heads = llama.model.config.num_key_value_heads
+    rotary_emb = llama.model.rotary_emb
     hidden_states = inputs_embeds
     for i, block in enumerate(llama.model.layers):
         hidden_states = DO_INTERVENTION(
@@ -447,7 +472,14 @@ def Llama_RUN(llama, input_ids, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS):
         CACHE_ACTIVATIONS[f"{i}.block_input"] = hidden_states
 
         hidden_states = Llama_BLOCK_RUN(
-            block, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
+            block,
+            hidden_states,
+            i,
+            CACHE_ACTIVATIONS,
+            INTERVENTION_ACTIVATIONS,
+            num_heads,
+            num_key_value_heads,
+            rotary_emb
         )
 
         hidden_states = DO_INTERVENTION(

From 8695a0964d55224add3996e9d2a2876e37c8934f Mon Sep 17 00:00:00 2001
From: PinetreePantry
Date: Wed, 29 Jan 2025 16:08:04 -0800
Subject: [PATCH 5/5] Fix pyvene 101 notebook

---
 pyvene_101.ipynb | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pyvene_101.ipynb b/pyvene_101.ipynb
index e2b20765..18817bd6 100644
--- a/pyvene_101.ipynb
+++ b/pyvene_101.ipynb
@@ -1877,7 +1877,7 @@
     "    tokenizer(\"The capital of Italy is\", return_tensors=\"pt\"),\n",
     "]\n",
     "base_outputs, counterfactual_outputs = pv_gpt2(\n",
-    "    base, sources, {\"sources->base\": ([[[3]]], [[[3]]])}\n",
+    "    base, sources, {\"sources->base\": ([[[3]]], [[[3]]])}, output_original_output=True\n",
     ")\n",
     "print(counterfactual_outputs.last_hidden_state - base_outputs.last_hidden_state)\n",
     "# call backward will put gradients on model's weights\n",
@@ -2785,7 +2785,7 @@
     "    model=resnet\n",
     ")\n",
     "intervened_outputs = pv_resnet(\n",
-    "    base_inputs, [source_inputs], return_dict=True\n",
+    "    base_inputs, [source_inputs], return_dict=True, output_original_output=True\n",
     ")\n",
     "(intervened_outputs.intervened_outputs.logits - intervened_outputs.original_outputs.logits).sum()"
    ]
@@ -2842,7 +2842,7 @@
     ")\n",
     "\n",
     "intervened_outputs = pv_resnet(\n",
-    "    base_inputs, [source_inputs], return_dict=True\n",
+    "    base_inputs, [source_inputs], return_dict=True, output_original_output=True\n",
    ")\n",
     "(intervened_outputs.intervened_outputs.logits - intervened_outputs.original_outputs.logits).sum()"
    ]