[P0] Fix test failures due to transformers version change #203

Merged 5 commits on Jan 30, 2025
11 changes: 10 additions & 1 deletion pyvene/models/intervenable_base.py
@@ -1292,7 +1292,16 @@ def load(
binary_filename = f"intkey_{k}.bin"
intervention.is_source_constant = \
saving_config.intervention_constant_sources[i]
intervention.set_interchange_dim(saving_config.intervention_dimensions[i])
dim = saving_config.intervention_dimensions[i]
if dim is None:
# Infer interchange dimension from component name to be compatible with old versions
component_name = saving_config.representations[i].component
if component_name.startswith("head_"):
dim = model.config.hidden_size // model.config.num_attention_heads
else:
dim = model.config.hidden_size

intervention.set_interchange_dim(dim)
if saving_config.intervention_constant_sources[i] and \
not isinstance(intervention, ZeroIntervention) and \
not isinstance(intervention, SourcelessIntervention):
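For context, the fallback added above can be read in isolation as the sketch below. The helper name `infer_interchange_dim` and the toy config are illustrative only (not part of pyvene); the logic mirrors the new branch: configs saved by older pyvene versions may store `None` for an intervention's dimension, in which case the width is derived from the component name.

```python
from dataclasses import dataclass


@dataclass
class _ToyConfig:
    # Stand-in for model.config; GPT-2 small values, for illustration.
    hidden_size: int = 768
    num_attention_heads: int = 12


def infer_interchange_dim(saved_dim, component_name, config):
    """Prefer the dimension stored in the saved config; otherwise infer it."""
    if saved_dim is not None:
        return saved_dim
    if component_name.startswith("head_"):
        # Head-level components operate on a single attention head's slice.
        return config.hidden_size // config.num_attention_heads
    # Everything else operates on the full hidden-state width.
    return config.hidden_size


cfg = _ToyConfig()
print(infer_interchange_dim(None, "head_attention_value_output", cfg))  # 64
print(infer_interchange_dim(None, "block_output", cfg))                 # 768
print(infer_interchange_dim(128, "block_output", cfg))                  # 128
```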
6 changes: 3 additions & 3 deletions pyvene_101.ipynb
@@ -1877,7 +1877,7 @@
" tokenizer(\"The capital of Italy is\", return_tensors=\"pt\"),\n",
"]\n",
"base_outputs, counterfactual_outputs = pv_gpt2(\n",
" base, sources, {\"sources->base\": ([[[3]]], [[[3]]])}\n",
" base, sources, {\"sources->base\": ([[[3]]], [[[3]]])}, output_original_output=True\n",
")\n",
"print(counterfactual_outputs.last_hidden_state - base_outputs.last_hidden_state)\n",
"# call backward will put gradients on model's weights\n",
@@ -2785,7 +2785,7 @@
" model=resnet\n",
")\n",
"intervened_outputs = pv_resnet(\n",
" base_inputs, [source_inputs], return_dict=True\n",
" base_inputs, [source_inputs], return_dict=True, output_original_output=True\n",
")\n",
"(intervened_outputs.intervened_outputs.logits - intervened_outputs.original_outputs.logits).sum()"
]
@@ -2842,7 +2842,7 @@
")\n",
"\n",
"intervened_outputs = pv_resnet(\n",
" base_inputs, [source_inputs], return_dict=True\n",
" base_inputs, [source_inputs], return_dict=True, output_original_output=True\n",
")\n",
"(intervened_outputs.intervened_outputs.logits - intervened_outputs.original_outputs.logits).sum()"
]
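All three notebook edits above add `output_original_output=True`, so the wrapped model also returns the un-intervened forward pass and the cells can subtract it from the intervened one. A rough, self-contained version of the GPT-2 cell is sketched below; the intervention config here is only an assumption for illustration (the notebook's own setup may differ), and `[[[3]]]` selects token position 3 in both source and base.

```python
import pyvene as pv
from transformers import AutoTokenizer, GPT2Model

tokenizer = AutoTokenizer.from_pretrained("gpt2")
gpt2 = GPT2Model.from_pretrained("gpt2")

# Wrap GPT-2 with a single vanilla interchange intervention on layer 0's block output.
pv_gpt2 = pv.IntervenableModel(
    {"layer": 0, "component": "block_output",
     "intervention_type": pv.VanillaIntervention},
    model=gpt2,
)

base = tokenizer("The capital of Spain is", return_tensors="pt")
sources = [tokenizer("The capital of Italy is", return_tensors="pt")]

base_outputs, counterfactual_outputs = pv_gpt2(
    base,
    sources,
    {"sources->base": ([[[3]]], [[[3]]])},  # swap activations at token position 3
    output_original_output=True,            # also run and return the original pass
)
# With the flag set, base_outputs is populated and the subtraction is meaningful.
print(counterfactual_outputs.last_hidden_state - base_outputs.last_hidden_state)
```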
84 changes: 69 additions & 15 deletions tests/utils.py
@@ -7,6 +7,7 @@
import pandas as pd
import numpy as np
from transformers import GPT2Config, LlamaConfig
from transformers.models.gpt2.modeling_gpt2 import eager_attention_forward
import math
from torch import nn

@@ -79,6 +80,22 @@ def is_package_installed(package_name):
forward calls to fetch activations or run with cached activations
"""

def split_heads(tensor, num_heads, attn_head_size):
"""
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)

def merge_heads(tensor, num_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden_size
"""
tensor = tensor.permute(0, 2, 1, 3).contiguous()
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)


def DO_INTERVENTION(name, orig_hidden_states, INTERVENTION_ACTIVATIONS):
if name in INTERVENTION_ACTIVATIONS:
@@ -100,9 +117,9 @@ def GPT2_SELF_ATTENTION_RUN(
value = DO_INTERVENTION(f"{i}.value_output", value, INTERVENTION_ACTIVATIONS)
CACHE_ACTIVATIONS[f"{i}.value_output"] = value

head_query = self_attn._split_heads(query, self_attn.num_heads, self_attn.head_dim)
head_key = self_attn._split_heads(key, self_attn.num_heads, self_attn.head_dim)
head_value = self_attn._split_heads(value, self_attn.num_heads, self_attn.head_dim)
head_query = split_heads(query, self_attn.num_heads, self_attn.head_dim)
head_key = split_heads(key, self_attn.num_heads, self_attn.head_dim)
head_value = split_heads(value, self_attn.num_heads, self_attn.head_dim)

head_query = DO_INTERVENTION(
f"{i}.head_query_output", head_query, INTERVENTION_ACTIVATIONS
@@ -117,18 +134,24 @@
)
CACHE_ACTIVATIONS[f"{i}.head_value_output"] = head_value

head_attention_value_output, attn_weights = self_attn._attn(
head_query, head_key, head_value
head_attention_value_output, _ = eager_attention_forward(
module=self_attn,
query=head_query,
key=head_key,
value=head_value,
attention_mask=None,
)

head_attention_value_output = head_attention_value_output.permute(0, 2, 1, 3)

head_attention_value_output = DO_INTERVENTION(
f"{i}.head_attention_value_output",
head_attention_value_output,
INTERVENTION_ACTIVATIONS,
)
CACHE_ACTIVATIONS[f"{i}.head_attention_value_output"] = head_attention_value_output

attn_value_output = self_attn._merge_heads(
attn_value_output = merge_heads(
head_attention_value_output, self_attn.num_heads, self_attn.head_dim
)
attn_value_output = DO_INTERVENTION(
@@ -287,7 +310,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_embed, k_embed

def Llama_SELF_ATTENTION_RUN(
self_attn, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
self_attn,
hidden_states,
i,
CACHE_ACTIVATIONS,
INTERVENTION_ACTIVATIONS,
num_heads,
num_key_value_heads,
rotary_emb
):
bsz, q_len, _ = hidden_states.size()

@@ -302,9 +332,9 @@ def Llama_SELF_ATTENTION_RUN(
value = DO_INTERVENTION(f"{i}.value_output", value, INTERVENTION_ACTIVATIONS)
CACHE_ACTIVATIONS[f"{i}.value_output"] = value

head_query = query.view(bsz, q_len, self_attn.num_heads, self_attn.head_dim).transpose(1, 2)
head_key = key.view(bsz, q_len, self_attn.num_key_value_heads, self_attn.head_dim).transpose(1, 2)
head_value = value.view(bsz, q_len, self_attn.num_key_value_heads, self_attn.head_dim).transpose(1, 2)
head_query = query.view(bsz, q_len, num_heads, self_attn.head_dim).transpose(1, 2)
head_key = key.view(bsz, q_len, num_key_value_heads, self_attn.head_dim).transpose(1, 2)
head_value = value.view(bsz, q_len, num_key_value_heads, self_attn.head_dim).transpose(1, 2)

head_query = DO_INTERVENTION(
f"{i}.head_query_output", head_query, INTERVENTION_ACTIVATIONS
@@ -320,7 +350,7 @@
CACHE_ACTIVATIONS[f"{i}.head_value_output"] = head_value

position_ids = torch.arange(q_len, device=hidden_states.device).repeat(bsz, 1)
cos, sin = self_attn.rotary_emb(head_value, position_ids)
cos, sin = rotary_emb(head_value, position_ids)
head_query, head_key = apply_rotary_pos_emb(head_query, head_key, cos, sin)

head_key = repeat_kv(head_key, self_attn.num_key_value_groups)
@@ -340,7 +370,7 @@
INTERVENTION_ACTIVATIONS,
)
CACHE_ACTIVATIONS[f"{i}.head_attention_value_output"] = head_attention_value_output
attn_value_output = head_attention_value_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self_attn.hidden_size)
attn_value_output = head_attention_value_output.transpose(1, 2).contiguous().reshape(bsz, q_len, num_heads * self_attn.head_dim)
attn_value_output = DO_INTERVENTION(
f"{i}.attention_value_output", attn_value_output, INTERVENTION_ACTIVATIONS
)
@@ -364,7 +394,14 @@ def Llama_MLP_RUN(mlp, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS):
return hidden_states_down_proj

def Llama_BLOCK_RUN(
block, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
block,
hidden_states,
i,
CACHE_ACTIVATIONS,
INTERVENTION_ACTIVATIONS,
num_heads,
num_key_value_heads,
rotary_emb
):
# self attention + residual
residual = hidden_states
@@ -376,7 +413,14 @@
CACHE_ACTIVATIONS[f"{i}.attention_input"] = hidden_states

attn_outputs = Llama_SELF_ATTENTION_RUN(
block.self_attn, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
block.self_attn,
hidden_states,
i,
CACHE_ACTIVATIONS,
INTERVENTION_ACTIVATIONS,
num_heads,
num_key_value_heads,
rotary_emb
)

attn_outputs = DO_INTERVENTION(
@@ -417,6 +461,9 @@ def Llama_RUN(llama, input_ids, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS):
"""
# embed
inputs_embeds = llama.model.embed_tokens(input_ids)
num_heads = llama.model.config.num_attention_heads
num_key_value_heads = llama.model.config.num_key_value_heads
rotary_emb = llama.model.rotary_emb
hidden_states = inputs_embeds
for i, block in enumerate(llama.model.layers):
hidden_states = DO_INTERVENTION(
@@ -425,7 +472,14 @@
CACHE_ACTIVATIONS[f"{i}.block_input"] = hidden_states

hidden_states = Llama_BLOCK_RUN(
block, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
block,
hidden_states,
i,
CACHE_ACTIVATIONS,
INTERVENTION_ACTIVATIONS,
num_heads,
num_key_value_heads,
rotary_emb
)

hidden_states = DO_INTERVENTION(
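The tests/utils.py changes above track two transformers-side API shifts: `GPT2Attention` no longer exposes the private `_split_heads`/`_merge_heads`/`_attn` helpers, so the reference implementation defines its own `split_heads`/`merge_heads` and calls the functional `eager_attention_forward` (which returns the context in `(batch, seq, heads, head_dim)` layout, hence the extra `permute`); and Llama attention modules no longer carry `num_heads`, `num_key_value_heads`, or `rotary_emb`, so those are now read from the config and the model and threaded through the helpers. A small standalone sketch of the GPT-2 side, assuming a transformers release recent enough to ship `eager_attention_forward` (variable names here are illustrative):

```python
import torch
from transformers import GPT2Model
from transformers.models.gpt2.modeling_gpt2 import eager_attention_forward

gpt2 = GPT2Model.from_pretrained("gpt2").eval()
attn = gpt2.h[0].attn
cfg = gpt2.config
head_dim = cfg.hidden_size // cfg.num_attention_heads

hidden = torch.randn(1, 4, cfg.hidden_size)
# Same projection + split that the test helper performs before intervening.
query, key, value = attn.c_attn(hidden).split(attn.split_size, dim=2)


def to_heads(t):
    # (batch, seq, hidden) -> (batch, heads, seq, head_dim), like split_heads above
    return t.view(1, 4, cfg.num_attention_heads, head_dim).permute(0, 2, 1, 3)


with torch.no_grad():
    ctx, _ = eager_attention_forward(
        module=attn,
        query=to_heads(query),
        key=to_heads(key),
        value=to_heads(value),
        attention_mask=None,
    )

print(ctx.shape)  # (1, 4, heads, head_dim); permute(0, 2, 1, 3) restores the old layout
```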