Commit

Fix test cases
PinetreePantry committed Jan 29, 2025
1 parent 5294462 commit 3c1aabe
Showing 1 changed file with 42 additions and 10 deletions.
52 changes: 42 additions & 10 deletions tests/utils.py
@@ -134,14 +134,15 @@ def GPT2_SELF_ATTENTION_RUN(
     )
     CACHE_ACTIVATIONS[f"{i}.head_value_output"] = head_value

-    head_attention_value_output = eager_attention_forward(
+    head_attention_value_output, _ = eager_attention_forward(
         module=self_attn,
         query=head_query,
         key=head_key,
         value=head_value,
         attention_mask=None,
     )

+    head_attention_value_output = head_attention_value_output.permute(0, 2, 1, 3)

     head_attention_value_output = DO_INTERVENTION(
         f"{i}.head_attention_value_output",
@@ -309,7 +310,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed

 def Llama_SELF_ATTENTION_RUN(
-    self_attn, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
+    self_attn,
+    hidden_states,
+    i,
+    CACHE_ACTIVATIONS,
+    INTERVENTION_ACTIVATIONS,
+    num_heads,
+    num_key_value_heads,
+    rotary_emb
 ):
     bsz, q_len, _ = hidden_states.size()

@@ -324,9 +332,9 @@ def Llama_SELF_ATTENTION_RUN(
     value = DO_INTERVENTION(f"{i}.value_output", value, INTERVENTION_ACTIVATIONS)
     CACHE_ACTIVATIONS[f"{i}.value_output"] = value

-    head_query = query.view(bsz, q_len, self_attn.num_heads, self_attn.head_dim).transpose(1, 2)
-    head_key = key.view(bsz, q_len, self_attn.num_key_value_heads, self_attn.head_dim).transpose(1, 2)
-    head_value = value.view(bsz, q_len, self_attn.num_key_value_heads, self_attn.head_dim).transpose(1, 2)
+    head_query = query.view(bsz, q_len, num_heads, self_attn.head_dim).transpose(1, 2)
+    head_key = key.view(bsz, q_len, num_key_value_heads, self_attn.head_dim).transpose(1, 2)
+    head_value = value.view(bsz, q_len, num_key_value_heads, self_attn.head_dim).transpose(1, 2)

     head_query = DO_INTERVENTION(
         f"{i}.head_query_output", head_query, INTERVENTION_ACTIVATIONS
@@ -342,7 +350,7 @@ def Llama_SELF_ATTENTION_RUN(
     CACHE_ACTIVATIONS[f"{i}.head_value_output"] = head_value

     position_ids = torch.arange(q_len, device=hidden_states.device).repeat(bsz, 1)
-    cos, sin = self_attn.rotary_emb(head_value, position_ids)
+    cos, sin = rotary_emb(head_value, position_ids)
     head_query, head_key = apply_rotary_pos_emb(head_query, head_key, cos, sin)

     head_key = repeat_kv(head_key, self_attn.num_key_value_groups)
@@ -362,7 +370,7 @@ def Llama_SELF_ATTENTION_RUN(
         INTERVENTION_ACTIVATIONS,
     )
     CACHE_ACTIVATIONS[f"{i}.head_attention_value_output"] = head_attention_value_output
-    attn_value_output = head_attention_value_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self_attn.hidden_size)
+    attn_value_output = head_attention_value_output.transpose(1, 2).contiguous().reshape(bsz, q_len, num_heads * self_attn.head_dim)
     attn_value_output = DO_INTERVENTION(
         f"{i}.attention_value_output", attn_value_output, INTERVENTION_ACTIVATIONS
     )
@@ -386,7 +394,14 @@ def Llama_MLP_RUN(mlp, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS):
     return hidden_states_down_proj

 def Llama_BLOCK_RUN(
-    block, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
+    block,
+    hidden_states,
+    i,
+    CACHE_ACTIVATIONS,
+    INTERVENTION_ACTIVATIONS,
+    num_heads,
+    num_key_value_heads,
+    rotary_emb
 ):
     # self attention + residual
     residual = hidden_states
@@ -398,7 +413,14 @@ def Llama_BLOCK_RUN(
     CACHE_ACTIVATIONS[f"{i}.attention_input"] = hidden_states

     attn_outputs = Llama_SELF_ATTENTION_RUN(
-        block.self_attn, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
+        block.self_attn,
+        hidden_states,
+        i,
+        CACHE_ACTIVATIONS,
+        INTERVENTION_ACTIVATIONS,
+        num_heads,
+        num_key_value_heads,
+        rotary_emb
     )

     attn_outputs = DO_INTERVENTION(
@@ -439,6 +461,9 @@ def Llama_RUN(llama, input_ids, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS):
     """
     # embed
     inputs_embeds = llama.model.embed_tokens(input_ids)
+    num_heads = llama.model.config.num_attention_heads
+    num_key_value_heads = llama.model.config.num_key_value_heads
+    rotary_emb = llama.model.rotary_emb
     hidden_states = inputs_embeds
     for i, block in enumerate(llama.model.layers):
         hidden_states = DO_INTERVENTION(
@@ -447,7 +472,14 @@ def Llama_RUN(llama, input_ids, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS):
         CACHE_ACTIVATIONS[f"{i}.block_input"] = hidden_states

         hidden_states = Llama_BLOCK_RUN(
-            block, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
+            block,
+            hidden_states,
+            i,
+            CACHE_ACTIVATIONS,
+            INTERVENTION_ACTIVATIONS,
+            num_heads,
+            num_key_value_heads,
+            rotary_emb
         )

         hidden_states = DO_INTERVENTION(
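For context, a minimal usage sketch of how the refactored Llama_RUN might be driven after this change. It is not part of the commit: the checkpoint name is illustrative, the import path assumes tests/utils.py is importable as tests.utils, and the behavior of an empty INTERVENTION_ACTIVATIONS dict (no interventions applied) is inferred rather than shown in these hunks.

import torch
from transformers import AutoTokenizer, LlamaForCausalLM

from tests.utils import Llama_RUN  # helper refactored in this commit

# Illustrative checkpoint; any Llama-style model whose rotary embedding lives at
# model.model.rotary_emb (recent transformers releases) matches the new signature.
model_name = "meta-llama/Llama-3.2-1B"
model = LlamaForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids

CACHE_ACTIVATIONS = {}         # populated in place with per-layer activations
INTERVENTION_ACTIVATIONS = {}  # assumption: an empty dict means no interventions are applied

with torch.no_grad():
    out = Llama_RUN(model, input_ids, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS)
    # out is whatever Llama_RUN returns; its tail is not visible in this diff

# Cache keys follow the f"{layer}.{site}" pattern used throughout tests/utils.py, e.g.
layer0_heads = CACHE_ACTIVATIONS["0.head_attention_value_output"]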
