Skip to content

Commit

Permalink
Additional test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
PinetreePantry committed Jan 29, 2025
1 parent bd33ec9 commit 3f30793
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd
import numpy as np
from transformers import GPT2Config, LlamaConfig
from transformers.models.gpt2.modeling_gpt2 import eager_attention_forward
import math
from torch import nn

Expand Down Expand Up @@ -133,10 +134,14 @@ def GPT2_SELF_ATTENTION_RUN(
)
CACHE_ACTIVATIONS[f"{i}.head_value_output"] = head_value

head_attention_value_output, attn_weights = self_attn._attn(
head_query, head_key, head_value
head_attention_value_output, attn_weights = eager_attention_forward(
module=self_attn,
query=head_query,
key=head_key,
value=head_value,
)


head_attention_value_output = DO_INTERVENTION(
f"{i}.head_attention_value_output",
head_attention_value_output,
Expand Down

0 comments on commit 3f30793

Please sign in to comment.