use TextEncoder in coca encode_image #321

Merged
merged 25 commits on Jan 6, 2023
Changes from 1 commit
41 changes: 5 additions & 36 deletions src/open_clip/coca_model.py
@@ -91,8 +91,10 @@ def __init__(
self.ln_final = text.ln_final
self.text_projection = text.text_projection
self.register_buffer("attn_mask", text.attn_mask, persistent=False)
self.context_length = self.positional_embedding.shape[0] - 1

self.context_length = self.positional_embedding.shape[0]
self.text = text


self.cls_token = nn.Parameter(torch.randn(embed_dim))
self.visual = _build_vision_tower(
embed_dim, vision_cfg, quick_gelu, cast_dtype
@@ -148,43 +150,10 @@ def encode_image(self, images, normalize=True, return_tokens=False):
def _repeat(self, t, N):
return t.reshape(1, 1, -1).repeat(N, 1, 1)

def _build_cls_mask(self, text, cast_dtype):
cls_mask = (text != self.pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
additive_mask = torch.empty(*cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
return additive_mask

def encode_text(self, text, normalize=True, return_tokens=False):
text = text[:, :-1] # make space for CLS token
cast_dtype = self.transformer.get_cast_dtype()
seq_len = text.shape[1]
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = torch.cat(
[
x + self.positional_embedding[:seq_len, :].to(cast_dtype),
self._repeat(self.cls_token + self.positional_embedding[-1, :], x.shape[0])
],
dim=1
)
seq_len += 1 # seq is 1 longer as we added CLS
attn_mask = self.attn_mask[None, :seq_len, :seq_len].expand(
text.shape[0] * self.heads, seq_len, seq_len
)
cls_mask = self._build_cls_mask(text, cast_dtype)

x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=attn_mask + cls_mask)
x = x.permute(1, 0, 2) # LND -> NLD

x = x[torch.arange(x.shape[0]), :] @ self.text_projection

cls_emb = x[torch.arange(x.shape[0]), -1]
token_emb = x[torch.arange(x.shape[0]), :-1]
cls_emb, token_emb = self.text(text, output_tokens=True)

cls_emb = self.ln_final(cls_emb)
text_latent = self.to_text_latent(cls_emb)
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent

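Reviewer note: the net effect of this hunk is that CoCa's encode_text no longer builds the CLS token, positional embeddings, and masks by hand; it asks the text tower (stored as self.text) for the pooled CLS embedding and the per-token embeddings and only applies its own projection on top. A rough sketch of the resulting flow, written as a standalone helper; the final handling of return_tokens is an assumption, since it is not visible in this hunk:

import torch.nn.functional as F

def encode_text_sketch(coca, text, normalize=True, return_tokens=False):
    # the text tower now owns the CLS token, attention mask, and padding mask
    cls_emb, token_emb = coca.text(text, output_tokens=True)
    cls_emb = coca.ln_final(cls_emb)
    text_latent = coca.to_text_latent(cls_emb)
    text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
    return (text_latent, token_emb) if return_tokens else text_latent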
4 changes: 4 additions & 0 deletions src/open_clip/model.py
@@ -51,6 +51,8 @@ class CLIPTextCfg:
hf_model_pretrained: bool = True
proj: str = 'mlp'
pooler_type: str = 'mean_pooler'
embed_cls: bool = False
pad_id: int = 0


def get_cast_dtype(precision: str):
@@ -146,6 +148,8 @@ def _build_text_tower(
layers=text_cfg.layers,
ls_init_value=text_cfg.ls_init_value,
output_dim=embed_dim,
embed_cls=text_cfg.embed_cls,
pad_id=text_cfg.pad_id,
act_layer=act_layer,
norm_layer=norm_layer,
)
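Reviewer note: the two new CLIPTextCfg fields give model configs a switch for the learned CLS embedding and a configurable padding id, and _build_text_tower simply forwards them to the text tower. A hedged sketch of that flow; the import path and the assumption that the remaining CLIPTextCfg fields keep their defaults are mine, not part of this diff:

from open_clip.model import CLIPTextCfg, _build_text_tower  # import path assumed

text_cfg = CLIPTextCfg(embed_cls=True, pad_id=0)  # new fields; other fields left at their defaults
text_tower = _build_text_tower(embed_dim=512, text_cfg=text_cfg)
# embed_cls and pad_id end up as constructor arguments of the text tower in transformer.py below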
5 changes: 3 additions & 2 deletions src/open_clip/model_configs/coca_ViT-B-32.json
@@ -7,11 +7,12 @@
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"context_length": 76,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
"layers": 12,
"embed_cls": true
},
"multimodal_cfg": {
"context_length": 76,
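Reviewer note: the text tower's context_length drops from 77 to 76 because, with "embed_cls": true, the tower appends one learned CLS position and increments its internal context_length (see self.context_length += 1 in transformer.py below), so the positional embedding still covers 77 slots. A tiny sanity check of that arithmetic:

text_context_length = 76  # "context_length" in the text_cfg above
embed_cls = True          # "embed_cls": true
positions = text_context_length + (1 if embed_cls else 0)
assert positions == 77    # rows of the positional embedding after the CLS slot is added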
43 changes: 37 additions & 6 deletions src/open_clip/transformer.py
@@ -492,13 +492,22 @@ def __init__(
output_dim: int = 512,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
embed_cls: bool = False,
pad_id: int = 0,
):
super().__init__()
self.context_length = context_length
self.vocab_size = vocab_size
self.width = width
self.output_dim = output_dim


if embed_cls:
self.embed_cls = embed_cls
self.cls_emb = nn.Parameter(torch.empty(width))
self.heads = heads
self.pad_id = pad_id
self.context_length += 1

self.token_embedding = nn.Embedding(vocab_size, width)
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
self.transformer = Transformer(
@@ -511,14 +520,15 @@
)
self.ln_final = norm_layer(width)
self.text_projection = nn.Parameter(torch.empty(width, output_dim))

self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)

self.init_parameters()

def init_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
if hasattr(self, "embed_cls") and self.embed_cls:
nn.init.normal_(self.cls_emb, std=0.01)

proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
@@ -543,20 +553,41 @@ def build_attention_mask(self):
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask

def forward(self, text):

def build_cls_mask(self, text, cast_dtype):
cls_mask = (text != self.pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
additive_mask = torch.empty(*cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
return additive_mask

def _repeat(self, t, N):
return t.reshape(1, 1, -1).repeat(N, 1, 1)

def forward(self, text, output_tokens: bool = False):
seq_len = text.shape[1]
cast_dtype = self.transformer.get_cast_dtype()

x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
attn_mask = self.attn_mask
if hasattr(self, "embed_cls") and self.embed_cls:
x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
cls_mask = self.build_cls_mask(text, cast_dtype)
attn_mask = attn_mask.unsqueeze(0) + cls_mask

x = x + self.positional_embedding.to(cast_dtype)
x = x + self.positional_embedding.to(cast_dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=self.attn_mask)
x = self.transformer(x, attn_mask=attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x)

# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
if output_tokens:
return x[:, -1], x[:, :-1]

x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

return x
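Reviewer note: a toy reproduction of the masking logic added above, runnable outside the module. The tensors, pad_id, and heads values are made up for illustration; the operations mirror build_attention_mask, build_cls_mask, and the attn_mask.unsqueeze(0) + cls_mask line in the new forward:

import torch
import torch.nn.functional as F

pad_id, heads, seq_len = 0, 2, 5
text = torch.tensor([[5, 7, 9, 0, 0],
                     [3, 4, 0, 0, 0]])  # 0 marks padding

# causal mask over seq_len + 1 positions (the extra slot is the appended CLS token)
causal = torch.empty(seq_len + 1, seq_len + 1).fill_(float("-inf")).triu_(1)

# additive padding mask, as in build_cls_mask: only the final (CLS) query row
# blocks attention to padded key positions; the earlier rows stay zero
cls_mask = (text != pad_id).unsqueeze(1)                              # [2, 1, 5]
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)  # [2, 6, 6]
additive = torch.zeros(cls_mask.shape).masked_fill_(~cls_mask, float("-inf"))
additive = torch.repeat_interleave(additive, heads, 0)                # [4, 6, 6]

attn_mask = causal.unsqueeze(0) + additive  # broadcasts to [batch * heads, 6, 6]
print(attn_mask.shape)                      # torch.Size([4, 6, 6])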