use TextEncoder in coca encode_image #321

Merged
merged 25 commits on Jan 6, 2023
Changes from 6 commits
74 changes: 20 additions & 54 deletions src/open_clip/coca_model.py
@@ -76,56 +76,45 @@ def __init__(
):
super().__init__()


norm_layer = (
LayerNormFp32
if cast_dtype in (torch.float16, torch.bfloat16)
else LayerNorm
)

text = _build_input_dependent_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype, multimodal=False)
self.transformer = text.transformer
self.vocab_size = text.vocab_size
self.token_embedding = text.token_embedding
self.positional_embedding = text.positional_embedding
self.ln_final = text.ln_final
self.text_projection = text.text_projection
self.register_buffer("attn_mask", text.attn_mask, persistent=False)
self.context_length = self.positional_embedding.shape[0] - 1

self.cls_token = nn.Parameter(torch.randn(embed_dim))
self.text = text
self.context_length = self.text.positional_embedding.shape[0]
self.heads = text_cfg["heads"]

self.visual = _build_vision_tower(
embed_dim, vision_cfg, quick_gelu, cast_dtype
)
self.heads = text_cfg["heads"]

self.multimodal_decoder, multimodal_cfg = _build_input_dependent_text_tower(
embed_dim, multimodal_cfg, quick_gelu, cast_dtype
)

self.img_attn_pool = AttentionalPooler(
multimodal_cfg.width, multimodal_cfg.heads, n_queries=n_queries + 1
multimodal_cfg.width, multimodal_cfg.heads, n_queries=n_queries + 1 # extra query for contrastive_loss
)

self.img_attn_pool_norm = norm_layer(embed_dim)

self.dim_latents = multimodal_cfg.dim_latents if multimodal_cfg.dim_latents else multimodal_cfg.width
self.to_text_latent = nn.Linear(embed_dim, self.dim_latents, bias=False)

self.to_logits = nn.Sequential(
norm_layer(embed_dim), nn.Linear(embed_dim, self.vocab_size, bias=False)
norm_layer(embed_dim), nn.Linear(embed_dim, self.text.vocab_size, bias=False)
)

# tie embedding weights and projection
self.to_logits[-1].weight = self.token_embedding.weight
self.to_logits[-1].weight = self.text.token_embedding.weight

self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.pad_id = 0

@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.transformer.grad_checkpointing = enable
self.text.set_grad_checkpointing(enable)
self.multimodal_decoder.grad_checkpointing = enable

def encode_image(self, images, normalize=True, return_tokens=False):
@@ -148,56 +137,33 @@ def encode_image(self, images, normalize=True, return_tokens=False):
def _repeat(self, t, N):
return t.reshape(1, 1, -1).repeat(N, 1, 1)

def _build_cls_mask(self, text, cast_dtype):
cls_mask = (text != self.pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
additive_mask = torch.empty(*cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
return additive_mask

def encode_text(self, text, normalize=True, return_tokens=False):
text = text[:, :-1] # make space for CLS token
cast_dtype = self.transformer.get_cast_dtype()
seq_len = text.shape[1]
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = torch.cat(
[
x + self.positional_embedding[:seq_len, :].to(cast_dtype),
self._repeat(self.cls_token + self.positional_embedding[-1, :], x.shape[0])
],
dim=1
)
seq_len += 1 # seq is 1 longer as we added CLS
attn_mask = self.attn_mask[None, :seq_len, :seq_len].expand(
text.shape[0] * self.heads, seq_len, seq_len
)
cls_mask = self._build_cls_mask(text, cast_dtype)

x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=attn_mask + cls_mask)
x = x.permute(1, 0, 2) # LND -> NLD

x = x[torch.arange(x.shape[0]), :] @ self.text_projection
cls_emb, token_emb = self.text(text, output_tokens=True)

cls_emb = x[torch.arange(x.shape[0]), -1]
token_emb = x[torch.arange(x.shape[0]), :-1]
if hasattr(self.text, "text_projection") and self.text.text_projection is not None:
text_latent = cls_emb @ self.text.text_projection

cls_emb = self.ln_final(cls_emb)
text_latent = self.to_text_latent(cls_emb)
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent

return (text_latent, token_emb) if return_tokens else text_latent

def forward(self, image, text):
def forward(self, image, text, output_dict=False):
labels = text[:, 1:]

text_latents, text_tokens = self.encode_text(text, return_tokens=True)
image_latents, image_tokens = self.encode_image(image, return_tokens=True)

text_tokens = self.multimodal_decoder(image_tokens, text_tokens)
logits = self.to_logits(text_tokens)
if output_dict:
return {
"image_features":image_latents,
"text_features":text_latents,
"logits":logits,
"labels":labels,
"logit_scale":self.logit_scale.exp()
}

return image_latents, text_latents, logits, labels, self.logit_scale.exp()
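
Note: a minimal sketch of consuming the new optional dict output. The tensors below are hypothetical stand-ins with illustrative shapes (batch 2, 76 caption positions, the 49408-token vocab from coca_ViT-B-32); they are not produced by this PR.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-ins for the dict returned by CoCa.forward(..., output_dict=True).
out = {
    "image_features": torch.randn(2, 512),
    "text_features": torch.randn(2, 512),
    "logits": torch.randn(2, 76, 49408),
    "labels": torch.randint(0, 49408, (2, 76)),
    "logit_scale": torch.tensor(1 / 0.07),  # already exponentiated, as returned by forward
}

# Captioning loss over the decoder logits; cross_entropy expects [batch, classes, seq].
caption_loss = F.cross_entropy(out["logits"].permute(0, 2, 1), out["labels"])

# Contrastive logits from the pooled features, scaled as in CLIP.
contrastive_logits = out["logit_scale"] * out["image_features"] @ out["text_features"].T
```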

20 changes: 18 additions & 2 deletions src/open_clip/model.py
@@ -51,6 +51,8 @@ class CLIPTextCfg:
hf_model_pretrained: bool = True
proj: str = 'mlp'
pooler_type: str = 'mean_pooler'
embed_cls: bool = False
pad_id: int = 0


def get_cast_dtype(precision: str):
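
Note: a minimal sketch of setting the two new fields programmatically; values mirror the coca_ViT-B-32 config further down and are only an illustration, not part of the PR.

```python
from open_clip.model import CLIPTextCfg

# embed_cls asks the text tower to append a learned CLS token;
# pad_id is the token id treated as padding when masking attention to that CLS token.
text_cfg = CLIPTextCfg(
    context_length=76,
    vocab_size=49408,
    width=512,
    heads=8,
    layers=12,
    embed_cls=True,
    pad_id=0,
)
```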
@@ -146,6 +148,8 @@ def _build_text_tower(
layers=text_cfg.layers,
ls_init_value=text_cfg.ls_init_value,
output_dim=embed_dim,
embed_cls=text_cfg.embed_cls,
pad_id=text_cfg.pad_id,
act_layer=act_layer,
norm_layer=norm_layer,
)
@@ -202,9 +206,15 @@ def encode_text(self, text, normalize: bool = False):
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return F.normalize(x, dim=-1) if normalize else x

def forward(self, image, text):
def forward(self, image, text, output_dict=False):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
if output_dict:
return {
"image_features":image_features,
"text_features":text_features,
"logit_scale":self.logit_scale.exp()
}
return image_features, text_features, self.logit_scale.exp()


@@ -242,9 +252,15 @@ def encode_text(self, text, normalize: bool = False):
features = self.text(text)
return F.normalize(features, dim=-1) if normalize else features

def forward(self, image, text):
def forward(self, image, text, output_dict=False):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
if output_dict:
return {
"image_features":image_features,
"text_features":text_features,
"logit_scale":self.logit_scale.exp()
}
return image_features, text_features, self.logit_scale.exp()
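
Note: since forward now has two return styles, downstream code (for example a generic training step) might normalize them with a small helper. This is a hypothetical sketch, not part of the PR.

```python
def unpack_model_out(model_out):
    """Accept either the legacy tuple or the new dict returned by forward()."""
    if isinstance(model_out, dict):
        return (
            model_out["image_features"],
            model_out["text_features"],
            model_out["logit_scale"],
        )
    image_features, text_features, logit_scale = model_out
    return image_features, text_features, logit_scale
```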


5 changes: 3 additions & 2 deletions src/open_clip/model_configs/coca_ViT-B-32.json
@@ -7,11 +7,12 @@
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"context_length": 76,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
"layers": 12,
"embed_cls": true
},
"multimodal_cfg": {
"context_length": 76,
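
Note: the drop from 77 to 76 pairs with `embed_cls: true`; per the TextTransformer changes below, the tower adds one position internally for the learned CLS token, so the positional embedding table keeps 77 rows. A quick check of that arithmetic:

```python
# context_length from the config, +1 added by TextTransformer when embed_cls is set
cfg_context_length = 76
positional_rows = cfg_context_length + 1
assert positional_rows == 77
```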
96 changes: 46 additions & 50 deletions src/open_clip/transformer.py
@@ -124,26 +124,12 @@ def __init__(
self.out_proj = nn.Linear(dim, dim)
self.out_drop = nn.Dropout(proj_drop)

def forward(self,
q_x,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None
):

L, N, C = q_x.shape
k_x = k_x if k_x is not None else q_x
v_x = v_x if v_x is not None else q_x

w_q, w_k, w_v = self.in_proj_weight.split(3, dim=0)

q = F.linear(q_x, w_q, self.in_proj_bias)
k = F.linear(k_x, w_k, self.in_proj_bias)
v = F.linear(v_x, w_v, self.in_proj_bias)

q = q.view(L, N * self.num_heads, -1).transpose(0, 1)
k = k.view(L, N * self.num_heads, -1).transpose(0, 1)
v = v.view(L, N * self.num_heads, -1).transpose(0, 1)
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
L, N, C = x.shape
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)

if self.logit_scale is not None:
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
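
Note: the rewritten forward relies on the standard fused-QKV layout, where a single `in_proj` followed by `chunk(3, dim=-1)` matches three separate projections taken from row slices of the same weight. A small self-contained check of that equivalence, with hypothetical sizes:

```python
import torch
import torch.nn.functional as F

dim, batch = 8, 4
x = torch.randn(batch, dim)
w = torch.randn(3 * dim, dim)   # fused [q; k; v] weight, as in in_proj_weight
b = torch.randn(3 * dim)        # fused bias, as in in_proj_bias

q, k, v = F.linear(x, w, b).chunk(3, dim=-1)
q_ref = F.linear(x, w[:dim], b[:dim])          # q from the first row block
v_ref = F.linear(x, w[2 * dim:], b[2 * dim:])  # v from the last row block
assert torch.allclose(q, q_ref) and torch.allclose(v, v_ref)
```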
@@ -266,21 +252,17 @@ def __init__(
scale_heads: bool = False,
scale_attn: bool = False,
scale_fc: bool = False,
is_cross_attention: bool = False,
):
super().__init__()

self.ln_1 = norm_layer(d_model)
if is_cross_attention:
self.ln_1_kv = norm_layer(d_model)

self.attn = Attention(
d_model, n_head,
scaled_cosine=scale_cosine_attn,
scale_heads=scale_heads,
)
self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value else nn.Identity()
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
@@ -290,22 +272,10 @@ def __init__(
("gelu", act_layer()),
("c_proj", nn.Linear(mlp_width, d_model))
]))
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value else nn.Identity()

def forward(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None
):
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

k_x = self.ln_1_kv(k_x) if k_x is not None else None
v_x = self.ln_1_kv(v_x) if v_x is not None else None

x = q_x + self.ls_1(
self.ln_attn(self.attn(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
)
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
x = x + self.ls_2(self.mlp(self.ln_2(x)))
return x
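
Note: the switch to `is not None` matters when `ls_init_value` is `0.0`, which is falsy. A quick illustration of the difference between the two guards (not from the PR):

```python
for ls_init_value in (None, 0.0, 1e-5):
    old_guard = bool(ls_init_value)          # 0.0 -> False: LayerScale silently skipped
    new_guard = ls_init_value is not None    # 0.0 -> True: LayerScale kept
    print(ls_init_value, old_guard, new_guard)
```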

@@ -492,13 +462,22 @@ def __init__(
output_dim: int = 512,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
embed_cls: bool = False,
pad_id: int = 0,
):
super().__init__()
self.context_length = context_length
self.vocab_size = vocab_size
self.width = width
self.output_dim = output_dim


if embed_cls:
self.embed_cls = embed_cls
self.cls_emb = nn.Parameter(torch.empty(width))
self.heads = heads
self.pad_id = pad_id
self.context_length += 1

self.token_embedding = nn.Embedding(vocab_size, width)
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
self.transformer = Transformer(
@@ -511,14 +490,15 @@ def __init__(
)
self.ln_final = norm_layer(width)
self.text_projection = nn.Parameter(torch.empty(width, output_dim))

self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)

self.init_parameters()

def init_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
if hasattr(self, "embed_cls") and self.embed_cls:
nn.init.normal_(self.cls_emb, std=0.01)

proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
@@ -543,20 +523,40 @@ def build_attention_mask(self):
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask

def forward(self, text):

def build_cls_mask(self, text, cast_dtype):
cls_mask = (text != self.pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
additive_mask = torch.empty(*cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
return additive_mask

def _repeat(self, t, N):
return t.reshape(1, 1, -1).repeat(N, 1, 1)

def forward(self, text, output_tokens: bool = False):
cast_dtype = self.transformer.get_cast_dtype()

x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
attn_mask = self.attn_mask
if hasattr(self, "embed_cls") and self.embed_cls:
x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
cls_mask = self.build_cls_mask(text, cast_dtype)
attn_mask = attn_mask.unsqueeze(0) + cls_mask

x = x + self.positional_embedding.to(cast_dtype)
x = x + self.positional_embedding.to(cast_dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=self.attn_mask)
x = self.transformer(x, attn_mask=attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x)

# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
if output_tokens:
return x[:, -1], x[:, :-1]

x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

return x
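
Note: a minimal usage sketch of the updated tower with small hypothetical sizes. With `embed_cls=True` a learned CLS token is appended to the sequence, and `output_tokens=True` returns the CLS embedding alongside the per-token embeddings.

```python
import torch
from open_clip.transformer import TextTransformer

text_tower = TextTransformer(
    context_length=16, vocab_size=1000, width=64, heads=4, layers=2,
    output_dim=64, embed_cls=True, pad_id=0,
)
tokens = torch.randint(1, 1000, (2, 16))      # [batch, context_length]
cls_emb, token_emb = text_tower(tokens, output_tokens=True)
print(cls_emb.shape, token_emb.shape)         # torch.Size([2, 64]) torch.Size([2, 16, 64])
```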
@@ -614,10 +614,6 @@ def init_parameters(self):
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable

def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf