
Commit

attentions
Esmail-ibraheem authored May 14, 2024
1 parent 60d49e7 commit 35bf970
Showing 3 changed files with 599 additions and 0 deletions.
262 changes: 262 additions & 0 deletions models/attentions.py
@@ -0,0 +1,262 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F

import math

from transformers.cache_utils import Cache
from transformers.utils import logging

from model import (
LlamaRotaryEmbeddings,
LlamaLinearScalingRotaryEmbeddings,
LlamaDynamicNTKScalingRotaryEmbeddings,
repeat_kv,
apply_rotary_embeddings
)
from config import LlamaConfig

logger = logging.get_logger(__name__)

class LlamaScalableGroupedQueryAttention(nn.Module):
def __init__(self, config: LlamaConfig, layer_idx: int) -> None:
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_heads_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True

if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)

self.query_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.key_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.value_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.output_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
self._init_rope()

def _init_rope(self):
if self.config.rope_scaling is None:
self.emb_rotary = LlamaRotaryEmbeddings(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta
)
        else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.emb_rotary = LlamaLinearScalingRotaryEmbeddings(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta
)
elif scaling_type == "dynamic":
self.emb_rotary = LlamaDynamicNTKScalingRotaryEmbeddings(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta
)
else:
raise ValueError(f"Unkown scaling type of RoPE {scaling_type}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        attention_output: bool = False,
        position_cache: Optional[torch.LongTensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor, Optional[Cache]]:
bsz, q_len, _ = hidden_states.size()

if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.query_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)

key_slices = self.key_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.value_proj.weight.split(key_value_slicing, dim=0)

query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)

key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)

value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)

else:
query_states = self.query_proj(hidden_states)
key_states = self.key_proj(hidden_states)
value_states = self.value_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1,2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1,2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1,2)

        past_key_value = getattr(self, "past_key_value", past_key_value)
        cos, sin = self.emb_rotary(value_states, position_ids)
        query_states, key_states = apply_rotary_embeddings(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # Cache the un-repeated key/value heads, using the "cache_position" key read by `transformers` cache implementations.
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": position_cache}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand the key/value heads to match the number of query heads (grouped-query attention).
        key_states = repeat_kv(key_states, self.num_key_value_heads_groups)
        value_states = repeat_kv(value_states, self.num_key_value_heads_groups)

attention_weights = torch.matmul(query_states, key_states.transpose(3,2)) / math.sqrt(self.head_dim)

if attention_mask is not None:
causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
attention_weights = attention_weights + causal_mask

attention_weights = nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attention_weights = nn.functional.dropout(attention_weights, p=self.config.attention_dropout, training=self.training)
output_attention = torch.matmul(attention_weights, value_states)

        if output_attention.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`output_attention` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {output_attention.size()}"
            )
output_attention = output_attention.transpose(1,2).contiguous()
output_attention = output_attention.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
output_attention = output_attention.split(self.hidden_size // self.config.pretraining_tp, dim=2)
output_slices = self.output_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
output_attention = sum([F.linear(output_attention[i], output_slices[i]) for i in range(self.config.pretraining_tp)])
else:
output_attention = self.output_proj(output_attention)

        if not attention_output:
attention_weights = None

return attention_weights, output_attention, past_key_value
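
# Hedged usage sketch (illustration only): a minimal smoke test for LlamaScalableGroupedQueryAttention.
# It assumes the `model`/`config` imports at the top of this file resolve and that `LlamaConfig`
# exposes the fields read in __init__ (hidden_size, num_attention_heads, num_key_value_heads,
# attention_dropout, attention_bias, max_position_embeddings, rope_theta, rope_scaling,
# pretraining_tp); the sizes below are arbitrary.
def _example_scalable_gqa():
    config = LlamaConfig(hidden_size=256, num_attention_heads=8, num_key_value_heads=2)
    attn = LlamaScalableGroupedQueryAttention(config, layer_idx=0)
    hidden_states = torch.randn(2, 16, config.hidden_size)
    position_ids = torch.arange(16).unsqueeze(0).expand(2, -1)
    # With attention_output=True the softmax weights are returned alongside the projected output.
    weights, output, _ = attn(hidden_states, position_ids=position_ids, attention_output=True)
    assert output.shape == (2, 16, config.hidden_size)
    assert weights.shape == (2, config.num_attention_heads, 16, 16)
    return output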

class LlamaFixedGroupedQueryAttention(nn.Module):
def __init__(self, args: LlamaConfig):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
self.n_heads_q = args.n_heads
self.n_rep = self.n_heads_q // self.n_kv_heads
self.head_dim = args.dim // args.n_heads

self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))
self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))

def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_complex: torch.Tensor
):
batch_size, seq_len, _ = x.shape # (B, 1, Dim)

xq = self.wq(x)
xk = self.wk(x)
xv = self.wv(x)

xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim)
xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)

xq = apply_rotary_embeddings(xq, freqs_complex, device=x.device)
xk = apply_rotary_embeddings(xk, freqs_complex, device=x.device)

self.cache_k[:batch_size, start_pos : start_pos + seq_len] = xk
self.cache_v[:batch_size, start_pos : start_pos + seq_len] = xv

keys = self.cache_k[:batch_size, : start_pos + seq_len]
values = self.cache_v[:batch_size, : start_pos + seq_len]

keys = repeat_kv(keys, self.n_rep)
values = repeat_kv(values, self.n_rep)

xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)

scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values)
output = (output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1))
return self.wo(output) # (B, 1, Dim) -> (B, 1, Dim)
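
# Hedged usage sketch (illustration only): one decoding step through LlamaFixedGroupedQueryAttention.
# `SimpleNamespace` stands in for the config and carries only the fields read in __init__
# (dim, n_heads, n_kv_heads, max_batch_size, max_seq_len). It assumes the module-level imports
# resolve to the complex-valued RoPE helpers defined in models/rotary_embeddings.py, and the
# `from rotary_embeddings import ...` path below is likewise an assumption about the package layout.
def _example_fixed_gqa():
    from types import SimpleNamespace
    from rotary_embeddings import precompute_theta_pos_frequencies

    args = SimpleNamespace(dim=256, n_heads=8, n_kv_heads=2, max_batch_size=4, max_seq_len=64)
    attn = LlamaFixedGroupedQueryAttention(args)
    head_dim = args.dim // args.n_heads
    freqs_complex = precompute_theta_pos_frequencies(head_dim, args.max_seq_len, device="cpu")
    x = torch.randn(2, 1, args.dim)  # one new token per sequence in the batch
    # Decode position 3: slice out the rotary frequencies for that single position.
    out = attn(x, start_pos=3, freqs_complex=freqs_complex[3:4])
    assert out.shape == (2, 1, args.dim)
    return out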

def MultiHeadAttention():
    # One decoding step of standard multi-head attention with a cached key/value prefix,
    # written as explicit einsum contractions so the per-head projections stay visible.
d_model, batch, heads, key, value = 512, 32, 8, (512 // 8), (512 // 8)

m = 5 # suppose we have already cached "m" tokens
previous_key = torch.rand(batch, heads, m, key)
previous_value = torch.rand(batch, heads, m, value)

X = torch.rand(batch, d_model) # query
M = torch.rand(batch, d_model) # key and value

P_q = torch.rand(heads, d_model, key)
P_k = torch.rand(heads, d_model, key)
P_v = torch.rand(heads, d_model, value)
P_o = torch.rand(heads, d_model, value)

q = torch.einsum("bd,hdk->bhk", X, P_q)
new_k = torch.concat([previous_key, torch.einsum("bd,hdk->bhk", M, P_k).unsqueeze(2)], axis=2)
new_v = torch.concat([previous_value, torch.einsum("bd,hdk->bhk", M, P_v).unsqueeze(2)], axis=2)

logits = torch.einsum("bhk,bhmk->bhm", q, new_k)
weights = torch.softmax(logits, dim=-1)
output = torch.einsum("bhm,bhmv->bhv", weights, new_v)
y = torch.einsum("bhv,hdv->bd", output, P_o)

return y, new_k, new_v

def MultiQueryAttention():
    # Same single-step setup as MultiHeadAttention above, but every head shares one key and one
    # value projection, so the cached keys/values carry no head dimension.
d_model, batch, heads, key, value = 512, 32, 8, (512 // 8), (512 // 8)

m = 5 # suppose we have already cached "m" tokens
previous_key = torch.rand(batch, m, key)
previous_value = torch.rand(batch, m, value)

X = torch.rand(batch, d_model) # query
M = torch.rand(batch, d_model) # key and value

P_q = torch.rand(heads, d_model, key)
P_k = torch.rand(d_model, key)
P_v = torch.rand(d_model, value)
P_o = torch.rand(heads, d_model, value)

q = torch.einsum("bd,hdk->bhk", X, P_q)
k = torch.concat([previous_key, torch.einsum("bd,dk->bk", M, P_k).unsqueeze(1)], axis=1)
v = torch.concat([previous_value, torch.einsum("bd,dv->bv", M, P_v).unsqueeze(1)], axis=1)

logits = torch.einsum("bhk,bmk->bhm", q, k)
weights = torch.softmax(logits, dim=-1)
    output = torch.einsum("bhm,bmv->bhv", weights, v)
y = torch.einsum("bhv,hdv->bd", output, P_o)

return y, k, v
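
# Shape check (illustration only) for the two snippets above: with d_model=512, batch=32, heads=8
# and m=5 previously cached tokens, both return a (batch, d_model) output, but the multi-head cache
# stores per-head keys/values of shape (batch, heads, m + 1, k) while the multi-query cache shares
# them across heads, shape (batch, m + 1, k).
def _check_attention_snippets():
    y_mha, k_mha, v_mha = MultiHeadAttention()
    y_mqa, k_mqa, v_mqa = MultiQueryAttention()
    assert y_mha.shape == (32, 512) and y_mqa.shape == (32, 512)
    assert k_mha.shape == (32, 8, 6, 64) and v_mha.shape == (32, 8, 6, 64)
    assert k_mqa.shape == (32, 6, 64) and v_mqa.shape == (32, 6, 64)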
98 changes: 98 additions & 0 deletions models/rotary_embeddings.py
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
from transformers.utils import logging

logger = logging.get_logger(__name__)

class LlamaRotaryEmbeddings(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0) -> None:
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.scaling_factor = scaling_factor
        # inv_frequency[i] = 1 / base^(2i / dim), the standard RoPE inverse frequencies.
        inv_frequency = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_frequency", inv_frequency, persistent=False)
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(0, self.max_seq_len_cached, device=device)
        t = t / scaling_factor
        freqs = torch.outer(t, self.inv_frequency)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
        self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)

    @property
    def sin_cached(self):
        logger.warning_once(
            "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
        )
        return self._sin_cached

    @property
    def cos_cached(self):
        logger.warning_once(
            "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
        )
        return self._cos_cached

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x is only used for its device and dtype; position_ids has shape (batch, seq_len).
        inv_frequency_expanded = self.inv_frequency[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        # Compute cos/sin in full precision, since low-precision angles lose positional information.
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_frequency_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

class LlamaLinearScalingRotaryEmbeddings(LlamaRotaryEmbeddings):
def forward(self, x, position_ids):
position_ids = position_ids.float() / self.scaling_factor
cos, sin = super().forward(x, position_ids)
return cos, sin

class LlamaDynamicNTKScalingRotaryEmbeddings(LlamaRotaryEmbeddings):
def forward(self, x, position_ids):
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_position_embeddings:
base = self.base * (self.scaling_factor * seq_len / self.max_position_embeddings - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
            inv_frequency = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
self.register_buffer("inv_frequency", inv_frequency, persistent=False)
cos, sin = super().forward(x, position_ids)
return cos, sin

LLAMA_ROTARY_EMBEDDINGS_CLASSES = {
"rotary": LlamaRotaryEmbeddings,
"linear": LlamaLinearScalingRotaryEmbeddings,
"dynamic": LlamaDynamicNTKScalingRotaryEmbeddings,
}
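
# Hedged usage sketch (illustration only): selecting a RoPE variant from the registry above and
# producing cos/sin tables for a batch of positions. Sizes are arbitrary; `x` is only consumed for
# its device and dtype.
def _example_rotary_registry(scaling_type: str = "linear"):
    rope_cls = LLAMA_ROTARY_EMBEDDINGS_CLASSES[scaling_type]
    rope = rope_cls(dim=64, max_position_embeddings=2048, scaling_factor=2.0)
    x = torch.randn(2, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
    position_ids = torch.arange(16).unsqueeze(0).expand(2, -1)
    cos, sin = rope(x, position_ids)
    assert cos.shape == (2, 16, 64) and sin.shape == (2, 16, 64)
    return cos, sin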

def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    assert head_dim % 2 == 0, "dimension must be divisible by 2"
    theta_numerator = torch.arange(0, head_dim, 2).float()
    # Inverse frequencies 1 / theta^(2i / head_dim), placed on the target device.
    theta = 1.0 / (theta ** (theta_numerator / head_dim)).to(device)
m = torch.arange(seq_len, device=device)
freqs = torch.outer(m, theta).float()
freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
return freqs_complex

def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    # Pair up the last dimension into complex numbers: (B, seq_len, H, head_dim) -> (B, seq_len, H, head_dim / 2).
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Broadcast the (seq_len, head_dim / 2) frequencies over the batch and head dimensions.
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
    x_rotated = x_complex * freqs_complex
    x_out = torch.view_as_real(x_rotated)
    x_out = x_out.reshape(*x.shape)
    return x_out.type_as(x).to(device)

def repeat_kv(x:torch.Tensor, n_rep: int):
batch_size, seq_len, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
.reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
)
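
# Hedged usage sketch (illustration only): the complex-valued RoPE path used by
# LlamaFixedGroupedQueryAttention, followed by repeat_kv. Shapes follow the conventions above:
# activations are (batch, seq_len, n_heads, head_dim) and freqs_complex is (seq_len, head_dim / 2).
def _example_complex_rope():
    batch, seq_len, n_kv_heads, n_rep, head_dim = 2, 16, 2, 4, 64
    freqs_complex = precompute_theta_pos_frequencies(head_dim, seq_len, device="cpu")
    xk = torch.randn(batch, seq_len, n_kv_heads, head_dim)
    xk = apply_rotary_embeddings(xk, freqs_complex, device="cpu")
    assert xk.shape == (batch, seq_len, n_kv_heads, head_dim)
    # Expand the 2 key/value heads to 2 * 4 = 8 heads for grouped-query attention.
    keys = repeat_kv(xk, n_rep)
    assert keys.shape == (batch, seq_len, n_kv_heads * n_rep, head_dim)
    return keys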