diff --git a/README.md b/README.md
index a39f508..990756d 100644
--- a/README.md
+++ b/README.md
@@ -377,3 +377,14 @@ loss.backward()
     year = {2021},
 }
 ```
+
+```bibtex
+@misc{su2021roformer,
+    title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
+    author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
+    year = {2021},
+    eprint = {2104.09864},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CL}
+}
+```
diff --git a/setup.py b/setup.py
index 797044c..f227a2c 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
   name = 'x-clip',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.2.4',
+  version = '0.3.0',
   license='MIT',
   description = 'X-CLIP',
   author = 'Phil Wang',
diff --git a/x_clip/x_clip.py b/x_clip/x_clip.py
index b213c03..675775c 100644
--- a/x_clip/x_clip.py
+++ b/x_clip/x_clip.py
@@ -92,6 +92,31 @@ def __init__(self, dim, fn):
     def forward(self, x, **kwargs):
         return self.fn(self.norm(x), **kwargs)
 
+# rotary positional embedding
+
+class RotaryEmbedding(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer('inv_freq', inv_freq)
+
+    def forward(self, seq_len, device):
+        inv_freq = self.inv_freq
+        t = torch.arange(seq_len, device = device).type_as(inv_freq)
+        freqs = torch.einsum('i , j -> i j', t, inv_freq)
+        return torch.cat((freqs, freqs), dim = -1)
+
+def rotate_half(x):
+    x = rearrange(x, '... (j d) -> ... j d', j = 2)
+    x1, x2 = x.unbind(dim = -2)
+    return torch.cat((-x2, x1), dim = -1)
+
+def apply_rotary_pos_emb(freqs, t):
+    rot_dim = freqs.shape[-1]
+    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
+    t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
+    return torch.cat((t, t_pass), dim = -1)
+
 # transformer
 
 class FeedForward(nn.Module):
@@ -120,11 +145,15 @@ def __init__(self, dim, dim_head = 64, heads = 8, dropout = 0.):
         self.to_out = nn.Linear(inner_dim, dim)
         self.dropout = nn.Dropout(dropout)
 
-    def forward(self, x, mask = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None):
         h = self.heads
         q, k, v = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
 
+        if exists(rotary_pos_emb):
+            apply_rotary = partial(apply_rotary_pos_emb, rotary_pos_emb)
+            q, k, v = map(apply_rotary, (q, k, v))
+
         sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
 
         if exists(mask):
@@ -161,9 +190,14 @@ def __init__(
 
         self.norm_out = nn.LayerNorm(dim)
 
-    def forward(self, x, mask = None):
+    def forward(
+        self,
+        x,
+        rotary_pos_emb = None,
+        mask = None
+    ):
         for attn, ff in self.layers:
-            x = attn(x, mask = mask) + x
+            x = attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb) + x
             x = ff(x) + x
 
         return self.norm_out(x)
@@ -177,22 +211,32 @@ def __init__(
         *,
         num_tokens,
         max_seq_len,
+        dim_head,
+        rotary_pos_emb = None,
         **kwargs
     ):
         super().__init__()
         self.token_emb = nn.Embedding(num_tokens, dim)
-        self.pos_emb = nn.Embedding(max_seq_len, dim)
+
+        self.abs_pos_emb = nn.Embedding(max_seq_len, dim) if not rotary_pos_emb else None
+        self.rotary_pos_emb = RotaryEmbedding(min(dim_head, 32)) if rotary_pos_emb else None
+
         self.cls_token = nn.Parameter(torch.randn(dim))
 
-        self.transformer = Transformer(dim, **kwargs)
+        self.transformer = Transformer(dim, dim_head = dim_head, **kwargs)
 
     def forward(self, x, mask = None):
         b, n, device = *x.shape, x.device
 
         x = self.token_emb(x)
 
-        pos_emb = self.pos_emb(torch.arange(n, device = device))
-        x = x + rearrange(pos_emb, 'n d -> 1 n d')
+        if exists(self.abs_pos_emb):
+            pos_emb = self.abs_pos_emb(torch.arange(n, device = device))
+            x = x + rearrange(pos_emb, 'n d -> 1 n d')
+
+        rotary_pos_emb = None
+        if exists(self.rotary_pos_emb):
+            rotary_pos_emb = self.rotary_pos_emb(n + 1, device = device)
 
         cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b = b)
         x = torch.cat((cls_tokens, x), dim = 1)
@@ -200,7 +244,7 @@ def forward(self, x, mask = None):
         if exists(mask):
             mask = F.pad(mask, (1, 0), value = True)
 
-        out = self.transformer(x, mask = mask)
+        out = self.transformer(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
         return out
 
 class VisionTransformer(nn.Module):
@@ -276,10 +320,13 @@ def __init__(
         text_enc_depth = 6,
         text_seq_len = 256,
         text_heads = 8,
+        text_dim_head = 64,
         text_has_cls_token = True,
         text_pad_id = 0,
+        text_rotary_pos_emb = False,
         visual_enc_depth = 6,
         visual_heads = 8,
+        visual_dim_head = 64,
         visual_image_size = 256,
         visual_patch_size = 32,
         visual_has_cls_token = True,
@@ -315,7 +362,9 @@ def __init__(
             num_tokens = num_text_tokens + (1 if use_mlm else 0),
             max_seq_len = text_seq_len,
             depth = text_enc_depth,
-            heads = text_heads
+            heads = text_heads,
+            dim_head = text_dim_head,
+            rotary_pos_emb = text_rotary_pos_emb
         )
 
         # instantiate image transformer
@@ -331,7 +380,8 @@ def __init__(
             patch_size = visual_patch_size,
             channels = channels,
             depth = visual_enc_depth,
-            heads = visual_heads
+            heads = visual_heads,
+            dim_head = visual_dim_head
        )
 
         # text ssl
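
The patch above adds an optional rotary positional embedding (RoPE) to the text encoder: with `text_rotary_pos_emb = True` the learned absolute position embedding is dropped, rotations are applied to the first `min(dim_head, 32)` dimensions of the queries, keys, and values, and the frequencies are computed for `n + 1` positions so the prepended CLS token is covered. Below is a minimal usage sketch, not part of the commit. The helper imports from `x_clip.x_clip` follow the file touched by this diff; the `dim_text` / `dim_image` / `dim_latent` constructor keywords and the `text_mask` / `return_loss` arguments of the forward call are not shown in this diff and are assumed to match the repository's existing README example.

```python
import torch
from x_clip import CLIP
from x_clip.x_clip import RotaryEmbedding, apply_rotary_pos_emb

# 1. the new helpers in isolation: rotate the first 32 dims of a (batch, heads, seq, dim_head) tensor
rotary = RotaryEmbedding(32)                 # matches the min(dim_head, 32) cap used by TextTransformer
q = torch.randn(1, 8, 16, 64)                # (batch, heads, seq_len, dim_head)
freqs = rotary(16, device = q.device)        # (seq_len, 32) after the (freqs, freqs) concat
q_rotated = apply_rotary_pos_emb(freqs, q)   # first 32 dims rotated, remaining dims passed through
assert q_rotated.shape == q.shape

# 2. assumed usage of the new CLIP flags introduced in this diff
clip = CLIP(
    dim_text = 512,                          # assumed keyword, not shown in this diff
    dim_image = 512,                         # assumed keyword, not shown in this diff
    dim_latent = 512,                        # assumed keyword, not shown in this diff
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    text_dim_head = 64,                      # new in this diff
    text_rotary_pos_emb = True,              # new in this diff: rotary instead of absolute positions
    visual_enc_depth = 6,
    visual_heads = 8,
    visual_dim_head = 64,                    # new in this diff
    visual_image_size = 256,
    visual_patch_size = 32
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones(4, 256).bool()

# forward keywords assumed from the README's existing example
loss = clip(text, images, text_mask = mask, return_loss = True)
loss.backward()
```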