[P0] Fix test failures due to transformers version change
PinetreePantry committed Jan 29, 2025
1 parent 22d4d60 commit bd33ec9
Showing 2 changed files with 30 additions and 5 deletions.

pyvene/models/intervenable_base.py (10 additions, 1 deletion)
@@ -1292,7 +1292,16 @@ def load(
     binary_filename = f"intkey_{k}.bin"
     intervention.is_source_constant = \
         saving_config.intervention_constant_sources[i]
-    intervention.set_interchange_dim(saving_config.intervention_dimensions[i])
+    dim = saving_config.intervention_dimensions[i]
+    if dim is None:
+        # Infer interchange dimension from component name to be compatible with old versions
+        component_name = saving_config.representations[i].component
+        if component_name.startswith("head_"):
+            dim = model.config.hidden_size // model.config.num_attention_heads
+        else:
+            dim = model.config.hidden_size
+
+    intervention.set_interchange_dim(dim)
     if saving_config.intervention_constant_sources[i] and \
         not isinstance(intervention, ZeroIntervention) and \
         not isinstance(intervention, SourcelessIntervention):
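
The fallback above handles configs saved by older pyvene versions, where intervention_dimensions may hold None: any component whose name starts with "head_" acts on a single attention head, so its interchange width is hidden_size // num_attention_heads, while every other component spans the full hidden state. A minimal standalone sketch of the same rule (the infer_interchange_dim helper and the GPT-2-small config values are illustrative, not part of this commit):

from types import SimpleNamespace

# Illustrative config mirroring GPT-2 small: hidden_size=768, 12 heads.
config = SimpleNamespace(hidden_size=768, num_attention_heads=12)

def infer_interchange_dim(component_name, config):
    # Per-head components intervene on one head's slice of the hidden state.
    if component_name.startswith("head_"):
        return config.hidden_size // config.num_attention_heads
    # All other components intervene on the full hidden state.
    return config.hidden_size

assert infer_interchange_dim("head_query_output", config) == 64
assert infer_interchange_dim("block_output", config) == 768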

tests/utils.py (20 additions, 4 deletions)
@@ -79,6 +79,22 @@ def is_package_installed(package_name):
 forward calls to fetch activations or run with cached activations
 """

+def split_heads(tensor, num_heads, attn_head_size):
+    """
+    Splits hidden_size dim into attn_head_size and num_heads
+    """
+    new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+    tensor = tensor.view(new_shape)
+    return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
+
+def merge_heads(tensor, num_heads, attn_head_size):
+    """
+    Merges attn_head_size dim and num_attn_heads dim into hidden_size
+    """
+    tensor = tensor.permute(0, 2, 1, 3).contiguous()
+    new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
+    return tensor.view(new_shape)
+

 def DO_INTERVENTION(name, orig_hidden_states, INTERVENTION_ACTIVATIONS):
     if name in INTERVENTION_ACTIVATIONS:
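
These module-level helpers replicate the private _split_heads and _merge_heads methods that older transformers releases defined on GPT2Attention; newer releases removed them, which is what broke these tests. The two functions are exact inverses, as a quick round-trip check shows (shapes are illustrative; assumes torch and the helpers above are in scope):

import torch

batch, seq_len, num_heads, head_dim = 2, 5, 12, 64
hidden = torch.randn(batch, seq_len, num_heads * head_dim)

# (batch, seq, hidden) -> (batch, head, seq, head_dim)
per_head = split_heads(hidden, num_heads, head_dim)
assert per_head.shape == (batch, num_heads, seq_len, head_dim)

# Merging the heads back restores the original tensor exactly.
assert torch.equal(merge_heads(per_head, num_heads, head_dim), hidden)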

@@ -100,9 +116,9 @@ def GPT2_SELF_ATTENTION_RUN(
     value = DO_INTERVENTION(f"{i}.value_output", value, INTERVENTION_ACTIVATIONS)
     CACHE_ACTIVATIONS[f"{i}.value_output"] = value

-    head_query = self_attn._split_heads(query, self_attn.num_heads, self_attn.head_dim)
-    head_key = self_attn._split_heads(key, self_attn.num_heads, self_attn.head_dim)
-    head_value = self_attn._split_heads(value, self_attn.num_heads, self_attn.head_dim)
+    head_query = split_heads(query, self_attn.num_heads, self_attn.head_dim)
+    head_key = split_heads(key, self_attn.num_heads, self_attn.head_dim)
+    head_value = split_heads(value, self_attn.num_heads, self_attn.head_dim)

     head_query = DO_INTERVENTION(
         f"{i}.head_query_output", head_query, INTERVENTION_ACTIVATIONS

@@ -128,7 +144,7 @@
     )
     CACHE_ACTIVATIONS[f"{i}.head_attention_value_output"] = head_attention_value_output

-    attn_value_output = self_attn._merge_heads(
+    attn_value_output = merge_heads(
         head_attention_value_output, self_attn.num_heads, self_attn.head_dim
     )
     attn_value_output = DO_INTERVENTION(
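
Both call sites change the same way: the self_attn._split_heads and self_attn._merge_heads calls become calls to the local helpers with identical arguments, while self_attn.num_heads and self_attn.head_dim are still read from the live attention module. A harness that had to support both old and new transformers at once could instead dispatch on whether the private method still exists; a hypothetical sketch, not what this commit does:

def get_split_heads(self_attn):
    # Older transformers exposed GPT2Attention._split_heads; newer releases
    # dropped it, so fall back to the local helper defined above.
    if hasattr(self_attn, "_split_heads"):
        return self_attn._split_heads
    return split_heads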
