[Add] Gemma Example #25

Open · wants to merge 3 commits into base: master
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
__pycache__/
.ruff_cache
35 changes: 35 additions & 0 deletions README.md
@@ -98,6 +98,8 @@ To execute a demonstration of SelfExtend on the Passkey Retrieval task, you can
python llama_example.py # llama

python mistral_example.py # mistral

python gemma_example.py # gemma
```


@@ -170,6 +172,39 @@ SelfExtend: [What is the pass key? The pass key is 58328.]
-----------------------------------
```

For Gemma
```bash
-----------------------------------
#Tokens of Prompt: 5142 Passkey target: 89427
Gemma: [What is the pass key? The pass key is 89427.]
SelfExtend: [What is the pass key? The pass key is 89427.]
-----------------------------------

-----------------------------------
#Tokens of Prompt: 5142 Passkey target: 51906
Gemma: [What is the pass key? The pass key is 519. Here.]
SelfExtend: [What is the pass key? The pass key is 51906.]
-----------------------------------

-----------------------------------
#Tokens of Prompt: 5142 Passkey target: 38117
Gemma: [What is the pass key? The pass key is 38117.]
SelfExtend: [What is the pass key? The pass key is 38117.]
-----------------------------------

-----------------------------------
#Tokens of Prompt: 5142 Passkey target: 60151
Gemma: [What is the pass key? The pass key is 60151.]
SelfExtend: [What is the pass key? The pass key is 60151.]
-----------------------------------

-----------------------------------
#Tokens of Prompt: 5142 Passkey target: 23789
Gemma: [What is the pass key? The pass key is 2378. The]
SelfExtend: [What is the pass key? The pass key is 23789.]
-----------------------------------
```



## 4. How to choose the group_size and neighbor_window
61 changes: 61 additions & 0 deletions gemma_example.py
@@ -0,0 +1,61 @@
# transformers version: 4.38.1 (Gemma support; matches gemma_self_extend_patch.py)
import warnings

warnings.filterwarnings("ignore")

import gemma_self_extend_patch as GemmaSE
from modify_utils import modify_method_of_instance
from functools import partial
import json
from transformers.models.gemma.modeling_gemma import GemmaAttention
from transformers import AutoTokenizer, AutoModelForCausalLM

original_gemma_forward = GemmaAttention.forward
# bind the SelfExtend hyperparameters: group_size_1 is the grouped-position
# stride, group_size_2 the neighbor window kept at exact positions
self_extend_forward = partial(
    GemmaSE.self_extend_forward, group_size_1=8, group_size_2=1024
)

device = "cpu"
model_path = "google/gemma-2b-it"
model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.eval()


for line in open("passkey_examples_5k.jsonl", "r"):
example = json.loads(line)
prompt_postfix = "What is the pass key? The pass key is "
prompt = example["input"] + prompt_postfix
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
print("-----------------------------------")
print(f"#Tokens of Prompt:", input_ids.shape[1], end=" ")
print("Passkey target:", example["target"])

    # baseline: restore the stock GemmaAttention.forward before generating
    modify_method_of_instance(
        model, "GemmaAttention", "forward", original_gemma_forward
    )
tokens = model.generate(input_ids, max_new_tokens=6)
answer = (
"Gemma: ["
+ prompt_postfix
+ tokenizer.decode(
tokens[0].tolist()[input_ids.shape[1] :], skip_special_tokens=True
)
+ "]"
)
answer = answer.replace("\n", "\\n")
print(answer)

    # SelfExtend: swap in the patched forward with grouped positional encoding
    modify_method_of_instance(model, "GemmaAttention", "forward", self_extend_forward)
tokens = model.generate(input_ids, max_new_tokens=6)
answer = (
"SelfExtend: ["
+ prompt_postfix
+ tokenizer.decode(
tokens[0].tolist()[input_ids.shape[1] :], skip_special_tokens=True
)
+ "]"
)
answer = answer.replace("\n", "\\n")
print(answer)
print("-----------------------------------\n")
147 changes: 100 additions & 47 deletions gemma_self_extend_patch.py
@@ -1,18 +1,16 @@
# transformers version: 4.38.1

import torch
from transformers.models.llama.modeling_llama import *
from transformers.models.gpt_neox.modeling_gpt_neox import *
import numpy as np
import torch.nn as nn
import math
import warnings
from typing import Optional, Tuple
import transformers


if transformers.__version__ >= "4.36":
from transformers.cache_utils import Cache


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -21,9 +19,12 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
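
# e.g., with 2 KV heads shared across 8 query heads (n_rep = 4):
#   x = torch.randn(1, 2, 10, 64)   # (batch, num_key_value_heads, seq_len, head_dim)
#   repeat_kv(x, 4).shape           # -> torch.Size([1, 8, 10, 64])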


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -83,52 +84,88 @@ def self_extend_forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)

if log_scale_base > 0:
scaled_query = query_states * (
(position_ids + 1)[:, None, :, None].log() / np.log(log_scale_base)
).clip(1).to(query_states.dtype)
else:
scaled_query = query_states

past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
cache_kwargs = {"cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
kv_seq_len = key_states.shape[-2]

query_position = position_ids
key_position = (
position_ids
if q_len != 1
else torch.arange(kv_seq_len, dtype=position_ids.dtype)
.to(query_position.device)
.view(bsz, kv_seq_len)
)

neighbor_q_cos, neighbor_q_sin = self.rotary_emb(
value_states, query_position, seq_len=None
)
neighbor_k_cos, neighbor_k_sin = self.rotary_emb(
value_states, key_position, seq_len=None
)

    _re_group_size_2 = (
        0 if query_position.max() < group_size_2 else group_size_2
    )  # skip the shift when every query position still fits inside the neighbor window
group_query_position = (
query_position // group_size_1
+ _re_group_size_2
- _re_group_size_2 / group_size_1
)
group_key_position = key_position // group_size_1

group_q_cos, group_q_sin = self.rotary_emb(
value_states, group_query_position, seq_len=None
)
group_k_cos, group_k_sin = self.rotary_emb(
value_states, group_key_position, seq_len=None
)

neighbor_query_states, _ = apply_rotary_pos_emb(
scaled_query, None, neighbor_q_cos, neighbor_q_sin, None
)
_, neighbor_key_states = apply_rotary_pos_emb(
None, key_states, neighbor_k_cos, neighbor_k_sin, None
)
group_query_states, _ = apply_rotary_pos_emb(
scaled_query, None, group_q_cos, group_q_sin, None
)
_, group_key_states = apply_rotary_pos_emb(
None, key_states, group_k_cos, group_k_sin, None
)

neighbor_key_states = repeat_kv(neighbor_key_states, self.num_key_value_groups)
group_key_states = repeat_kv(group_key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

neighbor_attn_weights = torch.matmul(
neighbor_query_states, neighbor_key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
group_attn_weights = torch.matmul(
group_query_states, group_key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)

if attention_mask is not None: # no matter the length, we just slice it
if cache_position is not None:
@@ -139,23 +176,40 @@ def self_extend_forward(
neighbor_attn_weights = neighbor_attn_weights + causal_mask

if q_len == 1:
neighbor_attention_mask = torch.zeros(
(q_len, kv_seq_len), device=neighbor_attn_weights.device
)
neighbor_attention_mask[:, -group_size_2:] = 1
elif q_len == kv_seq_len:
neighbor_attention_mask = torch.ones(
(q_len, kv_seq_len), device=neighbor_attn_weights.device
)
neighbor_attention_mask = torch.tril(neighbor_attention_mask)
if q_len - group_size_2 > 0:
group_attention_mask = torch.tril(
torch.ones(
(q_len - group_size_2, kv_seq_len - group_size_2),
device=group_attn_weights.device,
)
)
neighbor_attention_mask[group_size_2:, :-group_size_2] -= (
group_attention_mask
)
else:
raise ValueError("q_len should be 1 or seq_len.")

neighbor_attention_mask = neighbor_attention_mask.bool()

    # inside the neighbor window the exact-position scores are used;
    # everything beyond it falls back to the grouped-position scores
attn_weights = torch.where(
neighbor_attention_mask, neighbor_attn_weights, group_attn_weights
)

# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = torch.nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -169,7 +223,6 @@ def self_extend_forward(
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

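
For reference, the core remapping the patch performs can be reproduced in isolation. The following standalone snippet (toy sizes rather than the `group_size_1=8`, `group_size_2=1024` of `gemma_example.py`; integer division used for clarity) shows how SelfExtend keeps exact relative positions inside the neighbor window while compressing everything beyond it by the group size:

```python
# Standalone toy illustration of SelfExtend's position remapping.
import torch

group_size_1 = 4  # group size: how many tokens share one grouped position
group_size_2 = 8  # neighbor window: tokens kept at their exact positions

query_position = torch.tensor([15])  # decoding step 15 (q_len == 1)
key_position = torch.arange(16)      # positions of all cached keys

# Shift grouped query positions so the neighbor and grouped position
# streams line up exactly at the window boundary.
shift = group_size_2 - group_size_2 // group_size_1
group_query_position = query_position // group_size_1 + shift
group_key_position = key_position // group_size_1

neighbor_rel = query_position - key_position           # exact distances
group_rel = group_query_position - group_key_position  # grouped distances

print(neighbor_rel[-group_size_2:])  # tensor([7, 6, 5, 4, 3, 2, 1, 0])
print(group_rel[:-group_size_2])     # tensor([9, 9, 9, 9, 8, 8, 8, 8])
```

Inside the window the model sees the same relative distances it saw during pretraining; outside it, distances grow `group_size_1` times more slowly, which is what keeps a long context within the pretrained position range.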