Commit
add ability to use rotary embedding
lucidrains committed Mar 1, 2022
1 parent ad829e0 commit 1f6d8dd
Showing 3 changed files with 72 additions and 11 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -377,3 +377,14 @@ loss.backward()
year = {2021},
}
```

```bibtex
@misc{su2021roformer,
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
year = {2021},
eprint = {2104.09864},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
```
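The README hunk above only adds the RoFormer citation; the feature itself is toggled through the new `text_rotary_pos_emb` flag added to the `CLIP` constructor in `x_clip/x_clip.py` below. As a minimal sketch (not part of this commit), it might be enabled like this; any argument not visible in this diff, such as `dim_text`, `dim_image`, `dim_latent` and the `return_loss` keyword, is assumed from the existing README example rather than confirmed here:

```python
import torch
from x_clip import CLIP

clip = CLIP(
    dim_text = 512,              # assumed from the existing README example, not shown in this diff
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    text_dim_head = 64,          # new in this commit
    text_rotary_pos_emb = True,  # new in this commit: rotary embeddings for the text encoder
    visual_enc_depth = 6,
    visual_heads = 8,
    visual_dim_head = 64,        # new in this commit
    visual_image_size = 256,
    visual_patch_size = 32
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

loss = clip(text, images, return_loss = True)  # assumed forward signature, per the README usage
loss.backward()
```

Note that only the text transformer gains a rotary option in this commit; the vision transformer keeps its existing positional scheme and only picks up a configurable `visual_dim_head`.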
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
  name = 'x-clip',
  packages = find_packages(exclude=[]),
  include_package_data = True,
-  version = '0.2.4',
+  version = '0.3.0',
  license='MIT',
  description = 'X-CLIP',
  author = 'Phil Wang',
70 changes: 60 additions & 10 deletions x_clip/x_clip.py
@@ -92,6 +92,31 @@ def __init__(self, dim, fn):
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

+# rotary positional embedding

+class RotaryEmbedding(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer('inv_freq', inv_freq)

+    def forward(self, seq_len, device):
+        inv_freq = self.inv_freq
+        t = torch.arange(seq_len, device = device).type_as(inv_freq)
+        freqs = torch.einsum('i , j -> i j', t, inv_freq)
+        return torch.cat((freqs, freqs), dim = -1)

+def rotate_half(x):
+    x = rearrange(x, '... (j d) -> ... j d', j = 2)
+    x1, x2 = x.unbind(dim = -2)
+    return torch.cat((-x2, x1), dim = -1)

+def apply_rotary_pos_emb(freqs, t):
+    rot_dim = freqs.shape[-1]
+    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
+    t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
+    return torch.cat((t, t_pass), dim = -1)

# transformer

class FeedForward(nn.Module):
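Aside from the diff: the three helpers added above are self-contained, so they can be sanity-checked in isolation. A hedged sketch with illustrative shapes (it assumes the names remain importable from `x_clip.x_clip` at module level, as they appear in this hunk):

```python
import torch
from x_clip.x_clip import RotaryEmbedding, apply_rotary_pos_emb

rotary = RotaryEmbedding(32)
freqs = rotary(8, device = torch.device('cpu'))   # shape (8, 32): per-position rotation angles

q = torch.randn(1, 8, 8, 64)                      # (batch, heads, seq, dim_head), illustrative shape
q_rotated = apply_rotary_pos_emb(freqs, q)        # first 32 dims rotated, remaining 32 passed through

assert q_rotated.shape == q.shape
```

The text transformer further down builds `RotaryEmbedding(min(dim_head, 32))`, so at most the first 32 dimensions of each head get rotated while the rest pass through untouched; note also that `Attention.forward` in the next hunk rotates the values as well as the queries and keys.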
@@ -120,11 +145,15 @@ def __init__(self, dim, dim_head = 64, heads = 8, dropout = 0.):
        self.to_out = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)

-    def forward(self, x, mask = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None):
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

+        if exists(rotary_pos_emb):
+            apply_rotary = partial(apply_rotary_pos_emb, rotary_pos_emb)
+            q, k, v = map(apply_rotary, (q, k, v))

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if exists(mask):
@@ -161,9 +190,14 @@ def __init__(

        self.norm_out = nn.LayerNorm(dim)

-    def forward(self, x, mask = None):
+    def forward(
+        self,
+        x,
+        rotary_pos_emb = None,
+        mask = None
+    ):
        for attn, ff in self.layers:
-            x = attn(x, mask = mask) + x
+            x = attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb) + x
            x = ff(x) + x

        return self.norm_out(x)
@@ -177,30 +211,40 @@ def __init__(
        *,
        num_tokens,
        max_seq_len,
+        dim_head,
+        rotary_pos_emb = None,
        **kwargs
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
-        self.pos_emb = nn.Embedding(max_seq_len, dim)

+        self.abs_pos_emb = nn.Embedding(max_seq_len, dim) if not rotary_pos_emb else None
+        self.rotary_pos_emb = RotaryEmbedding(min(dim_head, 32)) if rotary_pos_emb else None

        self.cls_token = nn.Parameter(torch.randn(dim))

-        self.transformer = Transformer(dim, **kwargs)
+        self.transformer = Transformer(dim, dim_head = dim_head, **kwargs)

    def forward(self, x, mask = None):
        b, n, device = *x.shape, x.device

        x = self.token_emb(x)

-        pos_emb = self.pos_emb(torch.arange(n, device = device))
-        x = x + rearrange(pos_emb, 'n d -> 1 n d')
+        if exists(self.abs_pos_emb):
+            pos_emb = self.abs_pos_emb(torch.arange(n, device = device))
+            x = x + rearrange(pos_emb, 'n d -> 1 n d')

+        rotary_pos_emb = None
+        if exists(self.rotary_pos_emb):
+            rotary_pos_emb = self.rotary_pos_emb(n + 1, device = device)

        cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)

-        out = self.transformer(x, mask = mask)
+        out = self.transformer(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
        return out

class VisionTransformer(nn.Module):
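With `rotary_pos_emb = True`, `TextTransformer` drops the learned absolute position table and instead hands per-position rotation angles to every attention layer, computing them for `n + 1` positions because the CLS token is prepended. A hedged sketch of driving it directly (hyperparameter values are illustrative; `dim` is assumed to be the constructor's leading argument, and `depth`/`heads` are assumed to pass through `**kwargs` to the underlying `Transformer`, as the `CLIP` wiring further down suggests):

```python
import torch
from x_clip.x_clip import TextTransformer

text_encoder = TextTransformer(
    dim = 512,               # assumed; not part of this hunk
    num_tokens = 10000,
    max_seq_len = 256,       # still required, but unused for positions once rotary is on
    dim_head = 64,           # rotary dimension becomes min(dim_head, 32) = 32
    rotary_pos_emb = True,
    depth = 6,               # assumed to be forwarded via **kwargs to Transformer
    heads = 8
)

tokens = torch.randint(0, 10000, (4, 256))
mask = torch.ones(4, 256).bool()

encoded = text_encoder(tokens, mask = mask)   # (4, 257, 512): CLS token prepended to the 256 inputs
```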
@@ -276,10 +320,13 @@
        text_enc_depth = 6,
        text_seq_len = 256,
        text_heads = 8,
+        text_dim_head = 64,
        text_has_cls_token = True,
        text_pad_id = 0,
+        text_rotary_pos_emb = False,
        visual_enc_depth = 6,
        visual_heads = 8,
+        visual_dim_head = 64,
        visual_image_size = 256,
        visual_patch_size = 32,
        visual_has_cls_token = True,
@@ -315,7 +362,9 @@ def __init__(
            num_tokens = num_text_tokens + (1 if use_mlm else 0),
            max_seq_len = text_seq_len,
            depth = text_enc_depth,
-            heads = text_heads
+            heads = text_heads,
+            dim_head = text_dim_head,
+            rotary_pos_emb = text_rotary_pos_emb
        )

        # instantiate image transformer
@@ -331,7 +380,8 @@ def __init__(
            patch_size = visual_patch_size,
            channels = channels,
            depth = visual_enc_depth,
-            heads = visual_heads
+            heads = visual_heads,
+            dim_head = visual_dim_head
        )

        # text ssl
