Merge pull request #203 from stanfordnlp/peterwz-versions
[P0] Fix test failures due to transformers version change
frankaging authored Jan 30, 2025
2 parents f5119c1 + 8695a09 commit 4be6f6e
Showing 3 changed files with 82 additions and 19 deletions.
11 changes: 10 additions & 1 deletion pyvene/models/intervenable_base.py
@@ -1292,7 +1292,16 @@ def load(
binary_filename = f"intkey_{k}.bin"
intervention.is_source_constant = \
saving_config.intervention_constant_sources[i]
intervention.set_interchange_dim(saving_config.intervention_dimensions[i])
dim = saving_config.intervention_dimensions[i]
if dim is None:
# Infer interchange dimension from component name to be compatible with old versions
component_name = saving_config.representations[i].component
if component_name.startswith("head_"):
dim = model.config.hidden_size // model.config.num_attention_heads
else:
dim = model.config.hidden_size

intervention.set_interchange_dim(dim)
if saving_config.intervention_constant_sources[i] and \
not isinstance(intervention, ZeroIntervention) and \
not isinstance(intervention, SourcelessIntervention):
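The new code above falls back to inferring the interchange dimension from the intervened component's name when an older saved config has None in intervention_dimensions. A minimal standalone sketch of that fallback, assuming a Hugging Face-style config exposing hidden_size and num_attention_heads (the helper name infer_interchange_dim is hypothetical, not part of pyvene):

from transformers import GPT2Config

def infer_interchange_dim(component_name, config):
    # Hypothetical standalone version of the fallback above: older saved
    # configs may store None in intervention_dimensions, so the dimension
    # is inferred from the name of the intervened component.
    if component_name.startswith("head_"):
        # Head-level components operate on a single attention head.
        return config.hidden_size // config.num_attention_heads
    # All other components operate on the full hidden state.
    return config.hidden_size

# GPT-2 defaults: hidden_size=768, num_attention_heads=12.
cfg = GPT2Config()
assert infer_interchange_dim("head_attention_value_output", cfg) == 64
assert infer_interchange_dim("mlp_output", cfg) == 768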
6 changes: 3 additions & 3 deletions pyvene_101.ipynb
@@ -1877,7 +1877,7 @@
" tokenizer(\"The capital of Italy is\", return_tensors=\"pt\"),\n",
"]\n",
"base_outputs, counterfactual_outputs = pv_gpt2(\n",
" base, sources, {\"sources->base\": ([[[3]]], [[[3]]])}\n",
" base, sources, {\"sources->base\": ([[[3]]], [[[3]]])}, output_original_output=True\n",
")\n",
"print(counterfactual_outputs.last_hidden_state - base_outputs.last_hidden_state)\n",
"# call backward will put gradients on model's weights\n",
@@ -2785,7 +2785,7 @@
" model=resnet\n",
")\n",
"intervened_outputs = pv_resnet(\n",
" base_inputs, [source_inputs], return_dict=True\n",
" base_inputs, [source_inputs], return_dict=True, output_original_output=True\n",
")\n",
"(intervened_outputs.intervened_outputs.logits - intervened_outputs.original_outputs.logits).sum()"
]
@@ -2842,7 +2842,7 @@
")\n",
"\n",
"intervened_outputs = pv_resnet(\n",
" base_inputs, [source_inputs], return_dict=True\n",
" base_inputs, [source_inputs], return_dict=True, output_original_output=True\n",
")\n",
"(intervened_outputs.intervened_outputs.logits - intervened_outputs.original_outputs.logits).sum()"
]
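The notebook edits above add output_original_output=True so each forward call also returns the un-intervened outputs that the next line subtracts. A sketch of the first cell's pattern, assuming pv_gpt2, base, and sources are defined as in the earlier cells of pyvene_101.ipynb:

# output_original_output=True makes the wrapped model return the original
# (un-intervened) outputs alongside the counterfactual ones, so the two
# hidden states can be compared directly.
base_outputs, counterfactual_outputs = pv_gpt2(
    base,
    sources,
    {"sources->base": ([[[3]]], [[[3]]])},  # swap activations at position 3
    output_original_output=True,
)
print(counterfactual_outputs.last_hidden_state - base_outputs.last_hidden_state)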
84 changes: 69 additions & 15 deletions tests/utils.py
@@ -7,6 +7,7 @@
import pandas as pd
import numpy as np
from transformers import GPT2Config, LlamaConfig
from transformers.models.gpt2.modeling_gpt2 import eager_attention_forward
import math
from torch import nn

@@ -79,6 +80,22 @@ def is_package_installed(package_name):
forward calls to fetch activations or run with cached activations
"""

def split_heads(tensor, num_heads, attn_head_size):
"""
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)

def merge_heads(tensor, num_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden_size
"""
tensor = tensor.permute(0, 2, 1, 3).contiguous()
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)


def DO_INTERVENTION(name, orig_hidden_states, INTERVENTION_ACTIVATIONS):
if name in INTERVENTION_ACTIVATIONS:
@@ -100,9 +117,9 @@ def GPT2_SELF_ATTENTION_RUN(
value = DO_INTERVENTION(f"{i}.value_output", value, INTERVENTION_ACTIVATIONS)
CACHE_ACTIVATIONS[f"{i}.value_output"] = value

head_query = self_attn._split_heads(query, self_attn.num_heads, self_attn.head_dim)
head_key = self_attn._split_heads(key, self_attn.num_heads, self_attn.head_dim)
head_value = self_attn._split_heads(value, self_attn.num_heads, self_attn.head_dim)
head_query = split_heads(query, self_attn.num_heads, self_attn.head_dim)
head_key = split_heads(key, self_attn.num_heads, self_attn.head_dim)
head_value = split_heads(value, self_attn.num_heads, self_attn.head_dim)

head_query = DO_INTERVENTION(
f"{i}.head_query_output", head_query, INTERVENTION_ACTIVATIONS
@@ -117,18 +134,24 @@
)
CACHE_ACTIVATIONS[f"{i}.head_value_output"] = head_value

head_attention_value_output, attn_weights = self_attn._attn(
head_query, head_key, head_value
head_attention_value_output, _ = eager_attention_forward(
module=self_attn,
query=head_query,
key=head_key,
value=head_value,
attention_mask=None,
)

head_attention_value_output = head_attention_value_output.permute(0, 2, 1, 3)

head_attention_value_output = DO_INTERVENTION(
f"{i}.head_attention_value_output",
head_attention_value_output,
INTERVENTION_ACTIVATIONS,
)
CACHE_ACTIVATIONS[f"{i}.head_attention_value_output"] = head_attention_value_output

attn_value_output = self_attn._merge_heads(
attn_value_output = merge_heads(
head_attention_value_output, self_attn.num_heads, self_attn.head_dim
)
attn_value_output = DO_INTERVENTION(
@@ -287,7 +310,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_embed, k_embed

def Llama_SELF_ATTENTION_RUN(
self_attn, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
self_attn,
hidden_states,
i,
CACHE_ACTIVATIONS,
INTERVENTION_ACTIVATIONS,
num_heads,
num_key_value_heads,
rotary_emb
):
bsz, q_len, _ = hidden_states.size()

@@ -302,9 +332,9 @@ def Llama_SELF_ATTENTION_RUN(
value = DO_INTERVENTION(f"{i}.value_output", value, INTERVENTION_ACTIVATIONS)
CACHE_ACTIVATIONS[f"{i}.value_output"] = value

head_query = query.view(bsz, q_len, self_attn.num_heads, self_attn.head_dim).transpose(1, 2)
head_key = key.view(bsz, q_len, self_attn.num_key_value_heads, self_attn.head_dim).transpose(1, 2)
head_value = value.view(bsz, q_len, self_attn.num_key_value_heads, self_attn.head_dim).transpose(1, 2)
head_query = query.view(bsz, q_len, num_heads, self_attn.head_dim).transpose(1, 2)
head_key = key.view(bsz, q_len, num_key_value_heads, self_attn.head_dim).transpose(1, 2)
head_value = value.view(bsz, q_len, num_key_value_heads, self_attn.head_dim).transpose(1, 2)

head_query = DO_INTERVENTION(
f"{i}.head_query_output", head_query, INTERVENTION_ACTIVATIONS
@@ -320,7 +350,7 @@ def Llama_SELF_ATTENTION_RUN(
CACHE_ACTIVATIONS[f"{i}.head_value_output"] = head_value

position_ids = torch.arange(q_len, device=hidden_states.device).repeat(bsz, 1)
cos, sin = self_attn.rotary_emb(head_value, position_ids)
cos, sin = rotary_emb(head_value, position_ids)
head_query, head_key = apply_rotary_pos_emb(head_query, head_key, cos, sin)

head_key = repeat_kv(head_key, self_attn.num_key_value_groups)
@@ -340,7 +370,7 @@ def Llama_SELF_ATTENTION_RUN(
INTERVENTION_ACTIVATIONS,
)
CACHE_ACTIVATIONS[f"{i}.head_attention_value_output"] = head_attention_value_output
attn_value_output = head_attention_value_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self_attn.hidden_size)
attn_value_output = head_attention_value_output.transpose(1, 2).contiguous().reshape(bsz, q_len, num_heads * self_attn.head_dim)
attn_value_output = DO_INTERVENTION(
f"{i}.attention_value_output", attn_value_output, INTERVENTION_ACTIVATIONS
)
@@ -364,7 +394,14 @@ def Llama_MLP_RUN(mlp, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVAT
return hidden_states_down_proj

def Llama_BLOCK_RUN(
block, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
block,
hidden_states,
i,
CACHE_ACTIVATIONS,
INTERVENTION_ACTIVATIONS,
num_heads,
num_key_value_heads,
rotary_emb
):
# self attention + residual
residual = hidden_states
@@ -376,7 +413,14 @@ def Llama_BLOCK_RUN(
CACHE_ACTIVATIONS[f"{i}.attention_input"] = hidden_states

attn_outputs = Llama_SELF_ATTENTION_RUN(
block.self_attn, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
block.self_attn,
hidden_states,
i,
CACHE_ACTIVATIONS,
INTERVENTION_ACTIVATIONS,
num_heads,
num_key_value_heads,
rotary_emb
)

attn_outputs = DO_INTERVENTION(
@@ -417,6 +461,9 @@ def Llama_RUN(llama, input_ids, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS):
"""
# embed
inputs_embeds = llama.model.embed_tokens(input_ids)
num_heads = llama.model.config.num_attention_heads
num_key_value_heads = llama.model.config.num_key_value_heads
rotary_emb = llama.model.rotary_emb
hidden_states = inputs_embeds
for i, block in enumerate(llama.model.layers):
hidden_states = DO_INTERVENTION(
@@ -425,7 +472,14 @@ def Llama_RUN(llama, input_ids, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS):
CACHE_ACTIVATIONS[f"{i}.block_input"] = hidden_states

hidden_states = Llama_BLOCK_RUN(
block, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
block,
hidden_states,
i,
CACHE_ACTIVATIONS,
INTERVENTION_ACTIVATIONS,
num_heads,
num_key_value_heads,
rotary_emb
)

hidden_states = DO_INTERVENTION(
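As a quick sanity check on the split_heads and merge_heads helpers added to tests/utils.py above, a small round-trip sketch (the shapes are arbitrary GPT-2-like values chosen for illustration; it assumes the repository root is on sys.path so tests.utils is importable, otherwise copy the two helpers from the hunk above):

import torch
from tests.utils import split_heads, merge_heads

# batch=2, seq=5, 12 heads of size 64 (GPT-2-like, for illustration only).
batch, seq_len, num_heads, head_dim = 2, 5, 12, 64
hidden = torch.randn(batch, seq_len, num_heads * head_dim)

# split_heads reshapes (batch, seq, hidden) into (batch, head, seq, head_dim);
# merge_heads reverses it exactly.
per_head = split_heads(hidden, num_heads, head_dim)
assert per_head.shape == (batch, num_heads, seq_len, head_dim)

merged = merge_heads(per_head, num_heads, head_dim)
assert torch.equal(merged, hidden)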
