From 1189487ef23e08d29ff87f44babb9859d42440ad Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 25 Nov 2022 16:20:55 +0100 Subject: [PATCH 001/113] initial setup --- src/open_clip/coca_layers.py | 228 +++++++++++++++++++++++++++++++++++ src/open_clip/coca_model.py | 227 ++++++++++++++++++++++++++++++++++ src/open_clip/transformer.py | 2 +- 3 files changed, 456 insertions(+), 1 deletion(-) create mode 100644 src/open_clip/coca_layers.py create mode 100644 src/open_clip/coca_model.py diff --git a/src/open_clip/coca_layers.py b/src/open_clip/coca_layers.py new file mode 100644 index 000000000..873737fb8 --- /dev/null +++ b/src/open_clip/coca_layers.py @@ -0,0 +1,228 @@ +import torch +import torch.nn.functional as F +from torch import nn, einsum + +from einops import rearrange, repeat + +from .transformer import LayerNorm + + +class SwiGLU(nn.Module): + def forward(self, x): + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, max_seq_len, *, device): + seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = einsum("i , j -> i j", seq, self.inv_freq) + return torch.cat((freqs, freqs), dim=-1) + + +def rotate_half(x): + x = rearrange(x, "... (j d) -> ... j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(pos, t): + return (t * pos.cos()) + (rotate_half(t) * pos.sin()) + + +class ParallelTransformerBlock(nn.Module): + def __init__(self, dim, dim_head=64, heads=8, ff_mult=4): + super().__init__() + self.norm = LayerNorm(dim) + + attn_inner_dim = dim_head * heads + ff_inner_dim = dim * ff_mult + self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2)) + + self.heads = heads + self.scale = dim_head**-0.5 + self.rotary_emb = RotaryEmbedding(dim_head) + + self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False) + self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False) + + self.ff_out = nn.Sequential(SwiGLU(), nn.Linear(ff_inner_dim, dim, bias=False)) + + # for caching causal mask and rotary embeddings + + self.register_buffer("mask", None, persistent=False) + self.register_buffer("pos_emb", None, persistent=False) + + def get_mask(self, n, device): + if self.mask is not None and self.mask.shape[-1] >= n: + return self.mask[:n, :n] + + mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) + self.register_buffer("mask", mask, persistent=False) + return mask + + def get_rotary_embedding(self, n, device): + if self.pos_emb is not None and self.pos_emb.shape[-2] >= n: + return self.pos_emb[:n] + + pos_emb = self.rotary_emb(n, device=device) + self.register_buffer("pos_emb", pos_emb, persistent=False) + return pos_emb + + def forward(self, x, attn_mask=None): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + n, device, h = x.shape[1], x.device, self.heads + + # pre layernorm + x = self.norm(x) + + # attention queries, keys, values, and feedforward inner + q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1) + + # split heads + # they use multi-query single-key-value attention, yet another Noam Shazeer paper + # they found no performance loss past a certain scale, and more efficient decoding obviously + # 
https://arxiv.org/abs/1911.02150 + + q = rearrange(q, "b n (h d) -> b h n d", h=h) + + # rotary embeddings + positions = self.get_rotary_embedding(n, device) + q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k)) + + # scale + q = q * self.scale + + # similarity + sim = einsum("b h i d, b j d -> b h i j", q, k) + + # causal mask + causal_mask = self.get_mask(n, device) + sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) + + # extra attention mask - for masking out attention from text CLS token to padding + + if attn_mask is not None: + attn_mask = rearrange(attn_mask, "b i j -> b 1 i j") + sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max) + + # attention + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + # aggregate values + out = einsum("b h i j, b j d -> b h i d", attn, v) + + # merge heads + out = rearrange(out, "b h n d -> b n (h d)") + return (self.attn_out(out) + self.ff_out(ff)) + x + + +class CrossAttention(nn.Module): + def __init__( + self, + dim, + *, + context_dim=None, + dim_head=64, + heads=8, + parallel_ff=False, + ff_mult=4, + norm_context=False, + residual=False + ): + super().__init__() + self.heads = heads + self.scale = dim_head**-0.5 + self.residual = residual + inner_dim = heads * dim_head + context_dim = context_dim if context_dim is not None else dim + + self.norm = LayerNorm(dim) + self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity() + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + # whether to have parallel feedforward + + ff_inner_dim = ff_mult * dim + + self.ff = ( + nn.Sequential( + nn.Linear(dim, ff_inner_dim * 2, bias=False), + SwiGLU(), + nn.Linear(ff_inner_dim, dim, bias=False), + ) + if parallel_ff + else None + ) + + def forward(self, x, context): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + # pre-layernorm, for queries and context + + x = self.norm(x) + context = self.context_norm(context) + + # get queries + + q = self.to_q(x) + q = rearrange(q, "b n (h d) -> b h n d", h=self.heads) + + # scale + + q = q * self.scale + + # get key / values + + k, v = self.to_kv(context).chunk(2, dim=-1) + + # query / key similarity + + sim = einsum("b h i d, b j d -> b h i j", q, k) + + # attention + + sim = sim - sim.amax(dim=-1, keepdim=True) + attn = sim.softmax(dim=-1) + + # aggregate + + out = einsum("b h i j, b j d -> b h i d", attn, v) + + # merge and combine heads + + out = rearrange(out, "b h n d -> b n (h d)") + out = self.to_out(out) + + # add parallel feedforward (for multimodal layers) + + if self.ff is not None: + out = out + self.ff(x) + + if self.residual: + out = out + x + + return out diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py new file mode 100644 index 000000000..944e70d2c --- /dev/null +++ b/src/open_clip/coca_model.py @@ -0,0 +1,227 @@ +import torch +import torch.nn.functional as F +from torch import nn, einsum + +from einops import rearrange, repeat +from dataclasses import dataclass + +from .transformer import LayerNorm +from .coca_layers import ParallelTransformerBlock, CrossAttention +from .model import CLIPTextCfg, _build_vision_tower + + +@dataclass +class CoCaCfg: + model_name: str = "CoCa_base" + coca_dim: int = 768 + coca_image_dim: int = 768 + coca_ff_mult: int = 4 + 
coca_unimodal_depth: int = 12 + coca_multimodal_depth: int = 12 + coca_dim_head: int = 64 + coca_heads: int = 12 + coca_contrastive_loss_weight: float = 1.0 + coca_caption_loss_weight: float = 2.0 + + # vit_image_size: int = 288 + # vit_patch_size: int = 18 + # # vit_num_classes: int = 1000 + # vit_dim: int = 768 + # vit_depth: int = 12 + # vit_heads: int = 12 + # vit_mlp_dim: int = 3072 + + +class CoCa(nn.Module): + def __init__(self, coca_cfg: CoCaCfg, vit_cfg: CLIPTextCfg, tokenizer): + super().__init__() + + unimodal_depth = coca_cfg.coca_unimodal_depth + multimodal_depth = coca_cfg.coca_multimodal_depth + image_dim = coca_cfg.coca_image_dim + num_img_queries = 256 + dim_head = coca_cfg.coca_dim_head + heads = coca_cfg.coca_heads + ff_mult = coca_cfg.coca_ff_mult + + self.dim = coca_cfg.coca_dim + self.caption_loss_weight = coca_cfg.coca_caption_loss_weight + self.contrastive_loss_weight = coca_cfg.coca_contrastive_loss_weight + self.pad_id = coca_cfg.pad_id + + self.tokenizer = tokenizer + num_tokens = len(self.tokenizer) + self.img_encoder = _build_vision_tower(vit_cfg) + self.token_emb = nn.Embedding(num_tokens, self.dim) + self.text_cls_token = nn.Parameter(torch.randn(self.dim)) + + # num image queries for multimodal, but 1 extra CLS for contrastive learning + self.img_queries = nn.Parameter(torch.randn(num_img_queries + 1, self.dim)) + self.img_attn_pool = CrossAttention( + dim=self.dim, + context_dim=image_dim, + dim_head=dim_head, + heads=heads, + norm_context=True, + ) + + self.img_attn_pool_norm = LayerNorm(self.dim) + self.text_cls_norm = LayerNorm(self.dim) + + # contrastive learning temperature + + self.temperature = nn.Parameter(torch.Tensor([1.0])) + + # unimodal layers + + self.unimodal_layers = nn.ModuleList([]) + for ind in range(unimodal_depth): + self.unimodal_layers.append( + ParallelTransformerBlock( + dim=self.dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult + ), + ) + + # multimodal layers + + self.multimodal_layers = nn.ModuleList([]) + for ind in range(multimodal_depth): + self.multimodal_layers.append( + nn.ModuleList( + [ + ParallelTransformerBlock( + dim=self.dim, + dim_head=dim_head, + heads=heads, + ff_mult=ff_mult, + ), + CrossAttention( + dim=self.dim, + dim_head=dim_head, + heads=heads, + residual=True, + parallel_ff=True, + ff_mult=ff_mult, + ), + ] + ) + ) + + # to logits + + self.to_logits = nn.Sequential( + LayerNorm(self.dim), nn.Linear(self.dim, num_tokens, bias=False) + ) + + # they used embedding weight tied projection out to logits, not common, but works + self.to_logits[-1].weight = self.token_emb.weight + nn.init.normal_(self.token_emb.weight, std=0.02) + + def embed_text(self, text): + batch, device = text.shape[0], text.device + + seq = text.shape[1] + + text_tokens = self.token_emb(text) + + # append text cls tokens + + text_cls_tokens = repeat(self.text_cls_token, "d -> b 1 d", b=batch) + text_tokens = torch.cat((text_tokens, text_cls_tokens), dim=-2) + + # create specific mask for text cls token at the end + # to prevent it from attending to padding + + cls_mask = rearrange(text != self.pad_id, "b j -> b 1 j") + attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True) + + # go through unimodal layers + + for attn_ff in self.unimodal_layers: + text_tokens = attn_ff(text_tokens, attn_mask=attn_mask) + + # get text cls token + + text_tokens, text_cls_tokens = text_tokens[:, :-1], text_tokens[:, -1] + text_embeds = self.text_cls_norm(text_cls_tokens) + return text_embeds, text_tokens + + def embed_image(self, images=None, 
image_tokens=None): + # encode images into embeddings + # with the img_encoder passed in at init + # it can also accept precomputed image tokens + + assert images is None or image_tokens is None + + if images is not None: + assert ( + self.img_encoder is not None + ), "img_encoder must be passed in for automatic image encoding" + image_tokens = self.img_encoder(images) + + # attention pool image tokens + + img_queries = repeat(self.img_queries, "n d -> b n d", b=image_tokens.shape[0]) + img_queries = self.img_attn_pool(img_queries, image_tokens) + img_queries = self.img_attn_pool_norm(img_queries) + + return img_queries[:, 0], img_queries[:, 1:] + + def forward( + self, + text, + images=None, + image_tokens=None, + labels=None, + return_loss=False, + return_embeddings=False, + ): + batch, device = text.shape[0], text.device + + if return_loss and labels is None: + text, labels = text[:, :-1], text[:, 1:] + + text_embeds, text_tokens = self.embed_text(text) + + image_embeds, image_tokens = self.embed_image( + images=images, image_tokens=image_tokens + ) + + # return embeddings if that is what the researcher wants + + if return_embeddings: + return text_embeds, image_embeds + + # go through multimodal layers + + for attn_ff, cross_attn in self.multimodal_layers: + text_tokens = attn_ff(text_tokens) + text_tokens = cross_attn(text_tokens, image_tokens) + + logits = self.to_logits(text_tokens) + + if not return_loss: + return logits + + # shorthand + + ce = F.cross_entropy + + # calculate caption loss (cross entropy loss) + + logits = rearrange(logits, "b n c -> b c n") + caption_loss = ce(logits, labels, ignore_index=self.pad_id) + caption_loss = caption_loss * self.caption_loss_weight + + # calculate contrastive loss + + sim = einsum("i d, j d -> i j", text_embeds, image_embeds) + sim = sim * self.temperature.exp() + contrastive_labels = torch.arange(batch, device=device) + + contrastive_loss = ( + ce(sim, contrastive_labels) + ce(sim.t(), contrastive_labels) + ) * 0.5 + contrastive_loss = contrastive_loss * self.contrastive_loss_weight + + return caption_loss + contrastive_loss diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index a36fa5f5d..066301d47 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -274,7 +274,7 @@ def __init__( def lock(self, unlocked_groups=0, freeze_bn_stats=False): for param in self.parameters(): param.requires_grad = False - + if unlocked_groups != 0: groups = [ [ From 91d01fa002fba1d85272d9f8c3c68527cf1983f9 Mon Sep 17 00:00:00 2001 From: gpucce Date: Sun, 27 Nov 2022 08:20:14 +0100 Subject: [PATCH 002/113] add coca loss --- src/open_clip/loss.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index de31426df..d8b917478 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -119,3 +119,23 @@ def forward(self, image_features, text_features, logit_scale): F.cross_entropy(logits_per_text, labels) ) / 2 return total_loss + +class CoCaLoss(nn.Module): + + def __init__(self, captionloss_weight, cliploss_weight, pad_id): + super().__init__() + self.clip_loss = ClipLoss() + self.clip_loss_weight = cliploss_weight + self.caption_loss = nn.CrossEntropyLoss() + self.caption_loss_weight = captionloss_weight + self.pad_id = pad_id + + def forward(self, image_features, text_features, logits, logit_scale, labels): + clip_loss = self.clip_loss(image_features, text_features, logit_scale) + clip_loss = self.clip_loss_weight * clip_loss + + logits = 
logits.permute(0, 2, 1) + caption_loss = self.caption_loss(logits, labels, ignore_index=self.pad_id) + caption_loss = caption_loss * self.caption_loss_weight + + return clip_loss + caption_loss From efb6540b4fa6b23b74d12469570723110e3e49bc Mon Sep 17 00:00:00 2001 From: gpucce Date: Sun, 27 Nov 2022 08:22:42 +0100 Subject: [PATCH 003/113] remove loss from the model --- src/open_clip/coca_model.py | 33 +-------------------------------- 1 file changed, 1 insertion(+), 32 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 944e70d2c..efd6f8c10 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -187,41 +187,10 @@ def forward( images=images, image_tokens=image_tokens ) - # return embeddings if that is what the researcher wants - - if return_embeddings: - return text_embeds, image_embeds - - # go through multimodal layers - for attn_ff, cross_attn in self.multimodal_layers: text_tokens = attn_ff(text_tokens) text_tokens = cross_attn(text_tokens, image_tokens) logits = self.to_logits(text_tokens) - if not return_loss: - return logits - - # shorthand - - ce = F.cross_entropy - - # calculate caption loss (cross entropy loss) - - logits = rearrange(logits, "b n c -> b c n") - caption_loss = ce(logits, labels, ignore_index=self.pad_id) - caption_loss = caption_loss * self.caption_loss_weight - - # calculate contrastive loss - - sim = einsum("i d, j d -> i j", text_embeds, image_embeds) - sim = sim * self.temperature.exp() - contrastive_labels = torch.arange(batch, device=device) - - contrastive_loss = ( - ce(sim, contrastive_labels) + ce(sim.t(), contrastive_labels) - ) * 0.5 - contrastive_loss = contrastive_loss * self.contrastive_loss_weight - - return caption_loss + contrastive_loss + return text_embeds, image_embeds, logits, labels From 669a3a0ff4ee2a4e83b066bc4ef7410dd96cd2df Mon Sep 17 00:00:00 2001 From: gpucce Date: Sun, 27 Nov 2022 08:53:45 +0100 Subject: [PATCH 004/113] fix loss --- src/open_clip/loss.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index d8b917478..c2dcbee4c 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -120,11 +120,29 @@ def forward(self, image_features, text_features, logit_scale): ) / 2 return total_loss -class CoCaLoss(nn.Module): - def __init__(self, captionloss_weight, cliploss_weight, pad_id): +class CoCaLoss(nn.Module): + def __init__( + self, + captionloss_weight, + cliploss_weight, + pad_id, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + ): super().__init__() - self.clip_loss = ClipLoss() + self.clip_loss = ClipLoss( + local_loss=local_loss, + gather_with_grad=gather_with_grad, + cache_labels=cache_labels, + rank=rank, + world_size=world_size, + use_horovod=use_horovod + ) self.clip_loss_weight = cliploss_weight self.caption_loss = nn.CrossEntropyLoss() self.caption_loss_weight = captionloss_weight From f081dc4769b844bd2e1ed369f0c5ccf2b3337047 Mon Sep 17 00:00:00 2001 From: gpucce Date: Sun, 27 Nov 2022 08:56:01 +0100 Subject: [PATCH 005/113] add underscores --- src/open_clip/loss.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index c2dcbee4c..2e71b30a4 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -124,8 +124,8 @@ def forward(self, image_features, text_features, logit_scale): class CoCaLoss(nn.Module): def __init__( self, - 
captionloss_weight, - cliploss_weight, + caption_loss_weight, + clip_loss_weight, pad_id, local_loss=False, gather_with_grad=False, @@ -143,9 +143,9 @@ def __init__( world_size=world_size, use_horovod=use_horovod ) - self.clip_loss_weight = cliploss_weight + self.clip_loss_weight = clip_loss_weight self.caption_loss = nn.CrossEntropyLoss() - self.caption_loss_weight = captionloss_weight + self.caption_loss_weight = caption_loss_weight self.pad_id = pad_id def forward(self, image_features, text_features, logits, logit_scale, labels): From 0b1c895417d09f4d376e543d4688403d1bb4990d Mon Sep 17 00:00:00 2001 From: gpucce Date: Sun, 27 Nov 2022 08:56:40 +0100 Subject: [PATCH 006/113] name changes --- src/open_clip/coca_model.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index efd6f8c10..d2162ab6c 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -13,15 +13,15 @@ @dataclass class CoCaCfg: model_name: str = "CoCa_base" - coca_dim: int = 768 - coca_image_dim: int = 768 - coca_ff_mult: int = 4 - coca_unimodal_depth: int = 12 - coca_multimodal_depth: int = 12 - coca_dim_head: int = 64 - coca_heads: int = 12 - coca_contrastive_loss_weight: float = 1.0 - coca_caption_loss_weight: float = 2.0 + dim: int = 768 + image_dim: int = 768 + ff_mult: int = 4 + unimodal_depth: int = 12 + multimodal_depth: int = 12 + dim_head: int = 64 + heads: int = 12 + contrastive_loss_weight: float = 1.0 + caption_loss_weight: float = 2.0 # vit_image_size: int = 288 # vit_patch_size: int = 18 @@ -36,17 +36,17 @@ class CoCa(nn.Module): def __init__(self, coca_cfg: CoCaCfg, vit_cfg: CLIPTextCfg, tokenizer): super().__init__() - unimodal_depth = coca_cfg.coca_unimodal_depth - multimodal_depth = coca_cfg.coca_multimodal_depth - image_dim = coca_cfg.coca_image_dim + unimodal_depth = coca_cfg.unimodal_depth + multimodal_depth = coca_cfg.multimodal_depth + image_dim = coca_cfg.image_dim num_img_queries = 256 - dim_head = coca_cfg.coca_dim_head - heads = coca_cfg.coca_heads - ff_mult = coca_cfg.coca_ff_mult + dim_head = coca_cfg.dim_head + heads = coca_cfg.heads + ff_mult = coca_cfg.ff_mult - self.dim = coca_cfg.coca_dim - self.caption_loss_weight = coca_cfg.coca_caption_loss_weight - self.contrastive_loss_weight = coca_cfg.coca_contrastive_loss_weight + self.dim = coca_cfg.dim + self.caption_loss_weight = coca_cfg.caption_loss_weight + self.contrastive_loss_weight = coca_cfg.contrastive_loss_weight self.pad_id = coca_cfg.pad_id self.tokenizer = tokenizer From d518dd0b4f21e09d2ad155b8d5a08bc8c5fc0f11 Mon Sep 17 00:00:00 2001 From: gpucce Date: Tue, 29 Nov 2022 10:18:23 +0100 Subject: [PATCH 007/113] add cross attention to Residual and CustomResidual --- src/open_clip/transformer.py | 81 +++++++++++++++++++++++++++++------- 1 file changed, 66 insertions(+), 15 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 066301d47..61b969464 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -67,11 +67,14 @@ def __init__( self.logit_scale_max = logit_scale_max # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original - self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) - if qkv_bias: - self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) - else: - self.in_proj_bias = None + self.in_proj_weight = nn.ModuleDict() + self.in_proj_bias = nn.ModuleDict() + for k in ['q', 'k', 
'v']: + self.in_proj_weight[k + "_weight"] = nn.Parameter(torch.randn((dim, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias[k + "_bias"] = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None if self.scaled_cosine: self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) @@ -85,9 +88,17 @@ def __init__( self.out_proj = nn.Linear(dim, dim) self.out_drop = nn.Dropout(proj_drop) - def forward(self, x, attn_mask: Optional[torch.Tensor] = None): - L, N, C = x.shape - q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + def forward(self, + q_x, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None + ): + L, N, C = q_x.shape + q = F.linear(q_x, self.in_proj_weight['q_weight'], self.in_proj_bias['q_bias']) + k = F.linear(k_x if k_x else q_x, self.in_proj_weight['k_weight'], self.in_proj_bias['k_bias']) + v = F.linear(v_x if v_x else q_x, self.in_proj_weight['v_weight'], self.in_proj_bias['v_bias']) + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) @@ -130,12 +141,15 @@ def __init__( ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, + is_cross_attention: bool = False, ): super().__init__() self.ln_1 = norm_layer(d_model) self.attn = nn.MultiheadAttention(d_model, n_head) self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value else nn.Identity() + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) @@ -146,12 +160,33 @@ def __init__( ])) self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value else nn.Identity() - def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None - return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None + ): + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + return self.attn( + q_x, + k_x if k_x else q_x, + v_x if v_x else q_x, + need_weights=False, + attn_mask=attn_mask + )[0] + + def forward(self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None + ): - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask)) + k_x = self.ln_1_kv(k_x) if k_x else self.ln_1(q_x) + v_x = self.ln_1_kv(v_x) if v_x else self.ln_1(q_x) + + x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) x = x + self.ls_2(self.mlp(self.ln_2(x))) return x @@ -169,10 +204,14 @@ def __init__( scale_heads: bool = False, scale_attn: bool = False, scale_fc: bool = False, + is_cross_attention: bool = False, ): super().__init__() self.ln_1 = norm_layer(d_model) + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) + self.attn = Attention( d_model, n_head, scaled_cosine=scale_cosine_attn, @@ -191,8 +230,20 @@ def __init__( ])) self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value else nn.Identity() - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = 
None): - x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None + ): + + k_x = self.ln_1_kv(k_x) if k_x else self.ln_1(q_x) + v_x = self.ln_1_kv(v_x) if v_x else self.ln_1(q_x) + + x = q_x + self.ls_1( + self.ln_attn(self.attn(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) + ) x = x + self.ls_2(self.mlp(self.ln_2(x))) return x From 11bf57c9f7ba4342b0b6a8781591214561f709cc Mon Sep 17 00:00:00 2001 From: gpucce Date: Tue, 29 Nov 2022 10:24:07 +0100 Subject: [PATCH 008/113] fix if --- src/open_clip/transformer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 61b969464..b6ca3061a 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -96,8 +96,8 @@ def forward(self, ): L, N, C = q_x.shape q = F.linear(q_x, self.in_proj_weight['q_weight'], self.in_proj_bias['q_bias']) - k = F.linear(k_x if k_x else q_x, self.in_proj_weight['k_weight'], self.in_proj_bias['k_bias']) - v = F.linear(v_x if v_x else q_x, self.in_proj_weight['v_weight'], self.in_proj_bias['v_bias']) + k = F.linear(k_x if k_x is not None else q_x, self.in_proj_weight['k_weight'], self.in_proj_bias['k_bias']) + v = F.linear(v_x if v_x is not None else q_x, self.in_proj_weight['v_weight'], self.in_proj_bias['v_bias']) q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) @@ -170,8 +170,8 @@ def attention( attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None return self.attn( q_x, - k_x if k_x else q_x, - v_x if v_x else q_x, + k_x if k_x is not None else q_x, + v_x if v_x is not None else q_x, need_weights=False, attn_mask=attn_mask )[0] @@ -183,8 +183,8 @@ def forward(self, attn_mask: Optional[torch.Tensor] = None ): - k_x = self.ln_1_kv(k_x) if k_x else self.ln_1(q_x) - v_x = self.ln_1_kv(v_x) if v_x else self.ln_1(q_x) + k_x = self.ln_1_kv(k_x) if k_x is not None else self.ln_1(q_x) + v_x = self.ln_1_kv(v_x) if v_x is not None else self.ln_1(q_x) x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) x = x + self.ls_2(self.mlp(self.ln_2(x))) @@ -238,8 +238,8 @@ def forward( attn_mask: Optional[torch.Tensor] = None ): - k_x = self.ln_1_kv(k_x) if k_x else self.ln_1(q_x) - v_x = self.ln_1_kv(v_x) if v_x else self.ln_1(q_x) + k_x = self.ln_1_kv(k_x) if k_x is not None else self.ln_1(q_x) + v_x = self.ln_1_kv(v_x) if v_x is not None else self.ln_1(q_x) x = q_x + self.ls_1( self.ln_attn(self.attn(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) From f3dedf6217073214e101598f03d57443ad6efddd Mon Sep 17 00:00:00 2001 From: gpucce Date: Tue, 29 Nov 2022 15:29:55 +0100 Subject: [PATCH 009/113] =?UTF-8?q?=C3=A4dd=20transformer=20'decoder'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/open_clip/transformer.py | 58 ++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index b6ca3061a..8e1f99559 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -281,6 +281,50 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): x = r(x, attn_mask=attn_mask) return x +class 
TransformerDecoder(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + ): + + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList( + zip( + [ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) + for _ in range(layers) + ], + [Attention(width, heads) for _ in range(layers)] + ) + ) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor], + v_x: Optional[torch.Tensor], + attn_mask: Optional[torch.Tensor] = None + ): + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + q_x = checkpoint(r, q_x, k_x, v_x, attn_mask) + else: + q_x = r(q_x=q_x, k_x=k_x, v_x=v_x, attn_mask=attn_mask) + return q_x class VisionTransformer(nn.Module): def __init__( @@ -482,3 +526,17 @@ def forward(self, text): x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection return x + +class CoCaMultiModalTransformer(nn.Module): + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + ): \ No newline at end of file From 50c472636451f5f25141f8683ea9032c7e17d598 Mon Sep 17 00:00:00 2001 From: gpucce Date: Tue, 29 Nov 2022 15:35:03 +0100 Subject: [PATCH 010/113] minor fix --- src/open_clip/transformer.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 8e1f99559..921a1c989 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -281,7 +281,7 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): x = r(x, attn_mask=attn_mask) return x -class TransformerDecoder(nn.Module): +class CoCaMultiModalTransformer(nn.Module): def __init__( self, width: int, @@ -298,16 +298,16 @@ def __init__( self.layers = layers self.grad_checkpointing = False - self.resblocks = nn.ModuleList( - zip( - [ - ResidualAttentionBlock( - width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) - for _ in range(layers) - ], - [Attention(width, heads) for _ in range(layers)] + all_layers = [] + for _ in range(layers): + all_layers.append( + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer + ) ) - ) + all_layers.append(Attention(width, heads)) + + self.resblocks = nn.ModuleList(all_layers) def get_cast_dtype(self) -> torch.dtype: return self.resblocks[0].mlp.c_fc.weight.dtype @@ -539,4 +539,5 @@ def __init__( output_dim: int = 512, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, - ): \ No newline at end of file + ): + super().__init__() \ No newline at end of file From 1e41d837d77c4b3a08ff774642ce8cd54d4c71b3 Mon Sep 17 00:00:00 2001 From: gpucce Date: Tue, 29 Nov 2022 15:38:10 +0100 Subject: [PATCH 011/113] looks better --- src/open_clip/transformer.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 
921a1c989..8cd007a75 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -94,10 +94,14 @@ def forward(self, v_x: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None ): + L, N, C = q_x.shape + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + q = F.linear(q_x, self.in_proj_weight['q_weight'], self.in_proj_bias['q_bias']) - k = F.linear(k_x if k_x is not None else q_x, self.in_proj_weight['k_weight'], self.in_proj_bias['k_bias']) - v = F.linear(v_x if v_x is not None else q_x, self.in_proj_weight['v_weight'], self.in_proj_bias['v_bias']) + k = F.linear(k_x, self.in_proj_weight['k_weight'], self.in_proj_bias['k_bias']) + v = F.linear(v_x, self.in_proj_weight['v_weight'], self.in_proj_bias['v_bias']) q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) @@ -167,13 +171,13 @@ def attention( v_x: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None ): + + k_x = k_x if k_x is not None else q_x, + v_x = v_x if v_x is not None else q_x, + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None return self.attn( - q_x, - k_x if k_x is not None else q_x, - v_x if v_x is not None else q_x, - need_weights=False, - attn_mask=attn_mask + q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask )[0] def forward(self, From 0d91609d9b0e8714346be900cf578afb655f55b5 Mon Sep 17 00:00:00 2001 From: gpucce Date: Tue, 29 Nov 2022 16:05:09 +0100 Subject: [PATCH 012/113] initlize coca model structure --- src/open_clip/coca_model.py | 123 +++++++++++++++++++----------------- 1 file changed, 66 insertions(+), 57 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index d2162ab6c..1a5514dfb 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -1,3 +1,5 @@ +from typing import Callable, Optional + import torch import torch.nn.functional as F from torch import nn, einsum @@ -5,10 +7,9 @@ from einops import rearrange, repeat from dataclasses import dataclass -from .transformer import LayerNorm +from .transformer import LayerNormFp32, LayerNorm, QuickGELU, CoCaMultiModalTransformer from .coca_layers import ParallelTransformerBlock, CrossAttention -from .model import CLIPTextCfg, _build_vision_tower - +from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower @dataclass class CoCaCfg: @@ -25,43 +26,86 @@ class CoCaCfg: # vit_image_size: int = 288 # vit_patch_size: int = 18 - # # vit_num_classes: int = 1000 # vit_dim: int = 768 # vit_depth: int = 12 # vit_heads: int = 12 # vit_mlp_dim: int = 3072 +def _build_coca_multimodal_tower( + embed_dim: int, + coca_cfg: CoCaCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(coca_cfg, dict): + coca_cfg = CoCaCfg(**coca_cfg) + + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + + text = CoCaMultiModalTransformer( + context_length=coca_cfg.context_length, + width=coca_cfg.width, + heads=coca_cfg.heads, + layers=coca_cfg.layers, + ls_init_value=coca_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return text + + class CoCa(nn.Module): - def __init__(self, coca_cfg: CoCaCfg, vit_cfg: CLIPTextCfg, tokenizer): + def __init__( + self, + embed_dim, + coca_cfg: CoCaCfg, + text_cfg: CLIPTextCfg, + vit_cfg: CLIPVisionCfg, + quick_gelu: 
bool = False, + cast_dtype: Optional[torch.dtype] = None, + ): super().__init__() - unimodal_depth = coca_cfg.unimodal_depth - multimodal_depth = coca_cfg.multimodal_depth - image_dim = coca_cfg.image_dim - num_img_queries = 256 - dim_head = coca_cfg.dim_head - heads = coca_cfg.heads - ff_mult = coca_cfg.ff_mult + text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.transformer = text.transformer + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.register_buffer('attn_mask', text.attn_mask, persistent=False) + + self.img_encoder = _build_vision_tower(vit_cfg) + + self.multimodal = _build_coca_multimodal_tower(embed_dim, coca_cfg, quick_gelu, cast_dtype) - self.dim = coca_cfg.dim - self.caption_loss_weight = coca_cfg.caption_loss_weight - self.contrastive_loss_weight = coca_cfg.contrastive_loss_weight - self.pad_id = coca_cfg.pad_id + # multimodal_depth = coca_cfg.multimodal_depth + # image_dim = coca_cfg.image_dim + # num_img_queries = 256 + # dim_head = coca_cfg.dim_head + # heads = coca_cfg.heads + # ff_mult = coca_cfg.ff_mult + # self.dim = coca_cfg.dim + # self.caption_loss_weight = coca_cfg.caption_loss_weight + # self.contrastive_loss_weight = coca_cfg.contrastive_loss_weight + # self.pad_id = coca_cfg.pad_id - self.tokenizer = tokenizer num_tokens = len(self.tokenizer) - self.img_encoder = _build_vision_tower(vit_cfg) + self.token_emb = nn.Embedding(num_tokens, self.dim) self.text_cls_token = nn.Parameter(torch.randn(self.dim)) # num image queries for multimodal, but 1 extra CLS for contrastive learning self.img_queries = nn.Parameter(torch.randn(num_img_queries + 1, self.dim)) self.img_attn_pool = CrossAttention( - dim=self.dim, - context_dim=image_dim, - dim_head=dim_head, - heads=heads, + dim=coca_cfg.dim, + context_dim=coca_cfg.image_dim, + dim_head=coca_cfg.dim_head, + heads=coca_cfg.heads, norm_context=True, ) @@ -72,41 +116,6 @@ def __init__(self, coca_cfg: CoCaCfg, vit_cfg: CLIPTextCfg, tokenizer): self.temperature = nn.Parameter(torch.Tensor([1.0])) - # unimodal layers - - self.unimodal_layers = nn.ModuleList([]) - for ind in range(unimodal_depth): - self.unimodal_layers.append( - ParallelTransformerBlock( - dim=self.dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult - ), - ) - - # multimodal layers - - self.multimodal_layers = nn.ModuleList([]) - for ind in range(multimodal_depth): - self.multimodal_layers.append( - nn.ModuleList( - [ - ParallelTransformerBlock( - dim=self.dim, - dim_head=dim_head, - heads=heads, - ff_mult=ff_mult, - ), - CrossAttention( - dim=self.dim, - dim_head=dim_head, - heads=heads, - residual=True, - parallel_ff=True, - ff_mult=ff_mult, - ), - ] - ) - ) - # to logits self.to_logits = nn.Sequential( From 50e0cbe6d9fdfbe6b66b2b08d485f66bd7dd7303 Mon Sep 17 00:00:00 2001 From: gpucce Date: Tue, 29 Nov 2022 16:05:50 +0100 Subject: [PATCH 013/113] clean --- src/open_clip/transformer.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 8cd007a75..ccecb0219 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -285,6 +285,7 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): x = r(x, attn_mask=attn_mask) return x + class CoCaMultiModalTransformer(nn.Module): def __init__( self, @@ -530,18 +531,3 @@ def 
forward(self, text): x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection return x - -class CoCaMultiModalTransformer(nn.Module): - def __init__( - self, - context_length: int = 77, - vocab_size: int = 49408, - width: int = 512, - heads: int = 8, - layers: int = 12, - ls_init_value: float = None, - output_dim: int = 512, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - ): - super().__init__() \ No newline at end of file From 93b4236086d0c0519f5f6e2f29a32e676878eda1 Mon Sep 17 00:00:00 2001 From: gpucce Date: Tue, 29 Nov 2022 16:53:35 +0100 Subject: [PATCH 014/113] typo and format --- src/open_clip/coca_model.py | 13 +++++++++---- src/open_clip/transformer.py | 4 ++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 1a5514dfb..52844c5e9 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -11,6 +11,7 @@ from .coca_layers import ParallelTransformerBlock, CrossAttention from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower + @dataclass class CoCaCfg: model_name: str = "CoCa_base" @@ -42,7 +43,9 @@ def _build_coca_multimodal_tower( coca_cfg = CoCaCfg(**coca_cfg) act_layer = QuickGELU if quick_gelu else nn.GELU - norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + norm_layer = ( + LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + ) text = CoCaMultiModalTransformer( context_length=coca_cfg.context_length, @@ -77,15 +80,17 @@ def __init__( self.positional_embedding = text.positional_embedding self.ln_final = text.ln_final self.text_projection = text.text_projection - self.register_buffer('attn_mask', text.attn_mask, persistent=False) + self.register_buffer("attn_mask", text.attn_mask, persistent=False) self.img_encoder = _build_vision_tower(vit_cfg) - self.multimodal = _build_coca_multimodal_tower(embed_dim, coca_cfg, quick_gelu, cast_dtype) + self.multimodal = _build_coca_multimodal_tower( + embed_dim, coca_cfg, quick_gelu, cast_dtype + ) + num_img_queries = 256 # multimodal_depth = coca_cfg.multimodal_depth # image_dim = coca_cfg.image_dim - # num_img_queries = 256 # dim_head = coca_cfg.dim_head # heads = coca_cfg.heads # ff_mult = coca_cfg.ff_mult diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index ccecb0219..3eb83cdc2 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -172,8 +172,8 @@ def attention( attn_mask: Optional[torch.Tensor] = None ): - k_x = k_x if k_x is not None else q_x, - v_x = v_x if v_x is not None else q_x, + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None return self.attn( From 97e3c0f46e76dc042574a239afdba9c31591e1f1 Mon Sep 17 00:00:00 2001 From: gpucce Date: Tue, 29 Nov 2022 16:57:10 +0100 Subject: [PATCH 015/113] checkpoint signature --- src/open_clip/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 3eb83cdc2..e3fba3055 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -280,7 +280,7 @@ def get_cast_dtype(self) -> torch.dtype: def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): for r in self.resblocks: if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(r, x, attn_mask) + x = checkpoint(r, x, 
attn_mask=attn_mask) else: x = r(x, attn_mask=attn_mask) return x From 6ae6f8c2b90a7e5913958e8b7ef4f3201eabc179 Mon Sep 17 00:00:00 2001 From: gpucce Date: Wed, 30 Nov 2022 16:37:37 +0100 Subject: [PATCH 016/113] adjust multimodal decoder and add CoCaTransformer --- src/open_clip/transformer.py | 105 ++++++++++++++++++++++++++++++----- 1 file changed, 92 insertions(+), 13 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index e3fba3055..87cc90f25 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -286,7 +286,7 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): return x -class CoCaMultiModalTransformer(nn.Module): +class MultimodalTransformerDecoder(nn.Module): def __init__( self, width: int, @@ -303,16 +303,17 @@ def __init__( self.layers = layers self.grad_checkpointing = False - all_layers = [] - for _ in range(layers): - all_layers.append( - ResidualAttentionBlock( - width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer - ) - ) - all_layers.append(Attention(width, heads)) + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) + for _ in range(layers) + ]) - self.resblocks = nn.ModuleList(all_layers) + self.cross_attn = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) + for _ in range(layers) + ]) def get_cast_dtype(self) -> torch.dtype: return self.resblocks[0].mlp.c_fc.weight.dtype @@ -324,11 +325,13 @@ def forward( v_x: Optional[torch.Tensor], attn_mask: Optional[torch.Tensor] = None ): - for r in self.resblocks: + for r, ca in zip(self.resblocks, self.cross_attn): if self.grad_checkpointing and not torch.jit.is_scripting(): - q_x = checkpoint(r, q_x, k_x, v_x, attn_mask) + q_x = checkpoint(r, q_x, attn_mask=attn_mask) + q_x = checkpoint(ca, q_x, k_x=k_x, v_x=v_x) else: - q_x = r(q_x=q_x, k_x=k_x, v_x=v_x, attn_mask=attn_mask) + q_x = r(q_x, attn_mask=attn_mask) + q_x = ca(q_x, k_x=k_x, v_x=v_x) return q_x class VisionTransformer(nn.Module): @@ -531,3 +534,79 @@ def forward(self, text): x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection return x + +class CoCaMultimodalTransformer(nn.Module): + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + ): + super().__init__() + self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + + self.transformer = MultimodalTransformerDecoder( + width=width, + layers=layers, + heads=heads, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.ln_final = norm_layer(width) + # this will be shared with the textual decoder (in CoCa) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + def init_parameters(self): + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + 
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + for block in self.transformer.cross_attn: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, text_embs, image_embs, text_eot_mask): + + text_embs = text_embs.permute(1, 0, 2) # NLD -> LND + image_embs = image_embs.permute(1, 0, 2) # NLD -> LND + x = self.transformer(text_embs, image_embs, image_embs, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text_eot_mask] @ self.text_projection + + return x From 0975dfe9f80423884f9deea1613041fcc87eefdb Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 1 Dec 2022 14:32:17 +0100 Subject: [PATCH 017/113] keep older logic --- src/open_clip/transformer.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 87cc90f25..662fae433 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -67,14 +67,11 @@ def __init__( self.logit_scale_max = logit_scale_max # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original - self.in_proj_weight = nn.ModuleDict() - self.in_proj_bias = nn.ModuleDict() - for k in ['q', 'k', 'v']: - self.in_proj_weight[k + "_weight"] = nn.Parameter(torch.randn((dim, dim)) * self.scale) - if qkv_bias: - self.in_proj_bias[k + "_bias"] = nn.Parameter(torch.zeros(dim * 3)) - else: - self.in_proj_bias = None + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None if self.scaled_cosine: self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) @@ -99,6 +96,12 @@ def forward(self, k_x = k_x if k_x is not None else q_x v_x = v_x if v_x is not None else q_x + w_q, w_k, w_v = self.in_proj_weight.chunk(3) + + q = F.linear(q_x, w_q, self.in_proj_bias) + k = F.linear(k_x, w_k, self.in_proj_bias) + v = F.linear(v_x, w_q, self.in_proj_bias).chunk(3, dim=-1) + q = F.linear(q_x, self.in_proj_weight['q_weight'], self.in_proj_bias['q_bias']) k = F.linear(k_x, self.in_proj_weight['k_weight'], self.in_proj_bias['k_bias']) v = F.linear(v_x, self.in_proj_weight['v_weight'], self.in_proj_bias['v_bias']) From f2265ec584aea0b47bbfab4d38cf12ccac162a12 Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 1 Dec 2022 14:32:48 +0100 Subject: [PATCH 018/113] remove chunk --- src/open_clip/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/open_clip/transformer.py 
b/src/open_clip/transformer.py index 662fae433..4b26f9a0c 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -100,7 +100,7 @@ def forward(self, q = F.linear(q_x, w_q, self.in_proj_bias) k = F.linear(k_x, w_k, self.in_proj_bias) - v = F.linear(v_x, w_q, self.in_proj_bias).chunk(3, dim=-1) + v = F.linear(v_x, w_q, self.in_proj_bias) q = F.linear(q_x, self.in_proj_weight['q_weight'], self.in_proj_bias['q_bias']) k = F.linear(k_x, self.in_proj_weight['k_weight'], self.in_proj_bias['k_bias']) From 9d47f0e98f06b3145643114463255807d29a2f9e Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 1 Dec 2022 14:33:13 +0100 Subject: [PATCH 019/113] typo --- src/open_clip/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 4b26f9a0c..cbcacabf4 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -100,7 +100,7 @@ def forward(self, q = F.linear(q_x, w_q, self.in_proj_bias) k = F.linear(k_x, w_k, self.in_proj_bias) - v = F.linear(v_x, w_q, self.in_proj_bias) + v = F.linear(v_x, w_v, self.in_proj_bias) q = F.linear(q_x, self.in_proj_weight['q_weight'], self.in_proj_bias['q_bias']) k = F.linear(k_x, self.in_proj_weight['k_weight'], self.in_proj_bias['k_bias']) From 6a101ecac5848194695c2a50900dd3757c6be01b Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 1 Dec 2022 14:33:40 +0100 Subject: [PATCH 020/113] fix --- src/open_clip/transformer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index cbcacabf4..42d32bd95 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -102,10 +102,6 @@ def forward(self, k = F.linear(k_x, w_k, self.in_proj_bias) v = F.linear(v_x, w_v, self.in_proj_bias) - q = F.linear(q_x, self.in_proj_weight['q_weight'], self.in_proj_bias['q_bias']) - k = F.linear(k_x, self.in_proj_weight['k_weight'], self.in_proj_bias['k_bias']) - v = F.linear(v_x, self.in_proj_weight['v_weight'], self.in_proj_bias['v_bias']) - q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) From e25985102abb782ae39d447071561c4786b709af Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 1 Dec 2022 14:34:16 +0100 Subject: [PATCH 021/113] make chunk dim explicit --- src/open_clip/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 42d32bd95..5feb77514 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -96,7 +96,7 @@ def forward(self, k_x = k_x if k_x is not None else q_x v_x = v_x if v_x is not None else q_x - w_q, w_k, w_v = self.in_proj_weight.chunk(3) + w_q, w_k, w_v = self.in_proj_weight.chunk(3, dim=0) q = F.linear(q_x, w_q, self.in_proj_bias) k = F.linear(k_x, w_k, self.in_proj_bias) From 7fff61d275270f3d6f247fa9f2696c16061dd7ea Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 1 Dec 2022 16:35:48 +0100 Subject: [PATCH 022/113] adjust cfg names --- src/open_clip/coca_model.py | 77 +++++++++++++++++++------------------ 1 file changed, 40 insertions(+), 37 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 52844c5e9..1048a4218 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -7,7 +7,7 @@ from einops import rearrange, repeat from dataclasses import 
dataclass -from .transformer import LayerNormFp32, LayerNorm, QuickGELU, CoCaMultiModalTransformer +from .transformer import LayerNormFp32, LayerNorm, QuickGELU, CoCaMultimodalTransformer, ResidualAttentionBlock from .coca_layers import ParallelTransformerBlock, CrossAttention from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower @@ -15,13 +15,15 @@ @dataclass class CoCaCfg: model_name: str = "CoCa_base" - dim: int = 768 + context_length = 77 + width: int = 768 image_dim: int = 768 - ff_mult: int = 4 - unimodal_depth: int = 12 - multimodal_depth: int = 12 + mlp_ratio: int = 4 + ls_init_value: Optional[float] = None + layers: int = 12 dim_head: int = 64 heads: int = 12 + num_image_queries: int = 256 contrastive_loss_weight: float = 1.0 caption_loss_weight: float = 2.0 @@ -47,7 +49,7 @@ def _build_coca_multimodal_tower( LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm ) - text = CoCaMultiModalTransformer( + text = CoCaMultimodalTransformer( context_length=coca_cfg.context_length, width=coca_cfg.width, heads=coca_cfg.heads, @@ -67,12 +69,15 @@ def __init__( embed_dim, coca_cfg: CoCaCfg, text_cfg: CLIPTextCfg, - vit_cfg: CLIPVisionCfg, + vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, ): super().__init__() + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + act_layer = QuickGELU if quick_gelu else nn.GELU + text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) self.transformer = text.transformer self.vocab_size = text.vocab_size @@ -82,40 +87,32 @@ def __init__( self.text_projection = text.text_projection self.register_buffer("attn_mask", text.attn_mask, persistent=False) - self.img_encoder = _build_vision_tower(vit_cfg) + self.img_encoder = _build_vision_tower( + embed_dim, vision_cfg, quick_gelu, cast_dtype + ) - self.multimodal = _build_coca_multimodal_tower( + self.multimodal_decoder = _build_coca_multimodal_tower( embed_dim, coca_cfg, quick_gelu, cast_dtype ) - num_img_queries = 256 - - # multimodal_depth = coca_cfg.multimodal_depth - # image_dim = coca_cfg.image_dim - # dim_head = coca_cfg.dim_head - # heads = coca_cfg.heads - # ff_mult = coca_cfg.ff_mult - # self.dim = coca_cfg.dim - # self.caption_loss_weight = coca_cfg.caption_loss_weight - # self.contrastive_loss_weight = coca_cfg.contrastive_loss_weight - # self.pad_id = coca_cfg.pad_id - - num_tokens = len(self.tokenizer) - - self.token_emb = nn.Embedding(num_tokens, self.dim) - self.text_cls_token = nn.Parameter(torch.randn(self.dim)) + num_img_queries = coca_cfg.num_image_queries + self.width = coca_cfg.width + num_tokens = text_cfg.vocab_size + self.text_cls_token = nn.Parameter(torch.randn(self.width)) # num image queries for multimodal, but 1 extra CLS for contrastive learning - self.img_queries = nn.Parameter(torch.randn(num_img_queries + 1, self.dim)) - self.img_attn_pool = CrossAttention( - dim=coca_cfg.dim, - context_dim=coca_cfg.image_dim, - dim_head=coca_cfg.dim_head, - heads=coca_cfg.heads, - norm_context=True, + self.img_queries = nn.Parameter(torch.randn(num_img_queries + 1, self.width)) + self.img_attn_pool = ResidualAttentionBlock( + d_model=coca_cfg.width, + n_head=coca_cfg.heads, + mlp_ratio=coca_cfg.mlp_ratio, + act_layer=act_layer, + norm_layer=norm_layer, + is_cross_attention=True, + is_pooler=True, ) - self.img_attn_pool_norm = LayerNorm(self.dim) - self.text_cls_norm = LayerNorm(self.dim) + self.img_attn_pool_norm = norm_layer(self.width) + 
self.text_cls_norm = norm_layer(self.width) # contrastive learning temperature @@ -124,12 +121,18 @@ def __init__( # to logits self.to_logits = nn.Sequential( - LayerNorm(self.dim), nn.Linear(self.dim, num_tokens, bias=False) + norm_layer(self.width), nn.Linear(self.width, num_tokens, bias=False) ) + # get the token embeddings whether the encoder is HF or custom + for mod in self.transformer.state_dict(): + if any((emb_name in mod) and ("weight" in mod) for emb_name in ["word_embeddings", "token_embeddings"]): + token_embeddings = self.transformer.get_parameter(mod) + break + # they used embedding weight tied projection out to logits, not common, but works - self.to_logits[-1].weight = self.token_emb.weight - nn.init.normal_(self.token_emb.weight, std=0.02) + self.to_logits[-1].weight = token_embeddings + nn.init.normal_(token_embeddings, std=0.02) def embed_text(self, text): batch, device = text.shape[0], text.device From abd132d3ed45ae66d9427c834d2251844c33b79a Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 1 Dec 2022 16:55:06 +0100 Subject: [PATCH 023/113] add attentionalpooling --- src/open_clip/hf_model.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/open_clip/hf_model.py b/src/open_clip/hf_model.py index 9829c96a2..9ee14f2fe 100644 --- a/src/open_clip/hf_model.py +++ b/src/open_clip/hf_model.py @@ -18,6 +18,7 @@ class BaseModelOutput: pass class PretrainedConfig: pass from .hf_configs import arch_dict +from .transformer import Attention # utils def _camel2snake(s): @@ -64,6 +65,24 @@ def forward(self, x:BaseModelOutput, attention_mask:TensorType): return x.last_hidden_state[:, self.cls_token_position, :] +@register_pooler +class AttentionalPooler(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.ln_q = nn.LayerNorm() + self.ln_kv = nn.LayerNorm() + self.attn = Attention( + dim=dim, num_heads=num_heads, causal=False, attn_dropout=0.0, ff_dropout=0.0 + ) + self.to_out = nn.Linear(dim, dim, bias=False) + + def forward(self, q_x, k_x: TensorType, v_x: TensorType): + q_x = self.ln_q(q_x) + k_x = self.ln_kv(k_x) + v_x = self.ln_kv(v_x) + x = self.attn(q_x, k_x, v_x) + return self.to_out(x) + class HFTextEncoder(nn.Module): """HuggingFace model adapter""" def __init__( From 452d7d2b671ef0c4bf28be3fe0444a014f698fb6 Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 1 Dec 2022 16:55:53 +0100 Subject: [PATCH 024/113] add attentional pooling to coca --- src/open_clip/coca_model.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 1048a4218..339675e48 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -7,9 +7,9 @@ from einops import rearrange, repeat from dataclasses import dataclass -from .transformer import LayerNormFp32, LayerNorm, QuickGELU, CoCaMultimodalTransformer, ResidualAttentionBlock -from .coca_layers import ParallelTransformerBlock, CrossAttention +from .transformer import LayerNormFp32, LayerNorm, QuickGELU, CoCaMultimodalTransformer from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower +from .hf_model import AttentionalPooler @dataclass @@ -101,15 +101,7 @@ def __init__( # num image queries for multimodal, but 1 extra CLS for contrastive learning self.img_queries = nn.Parameter(torch.randn(num_img_queries + 1, self.width)) - self.img_attn_pool = ResidualAttentionBlock( - d_model=coca_cfg.width, - n_head=coca_cfg.heads, - mlp_ratio=coca_cfg.mlp_ratio, - 
act_layer=act_layer, - norm_layer=norm_layer, - is_cross_attention=True, - is_pooler=True, - ) + self.img_attn_pool = AttentionalPooler(coca_cfg.width, coca_cfg.heads) self.img_attn_pool_norm = norm_layer(self.width) self.text_cls_norm = norm_layer(self.width) From 43ce18fe6b954eda086a58a06df93071308e7da8 Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 1 Dec 2022 16:57:10 +0100 Subject: [PATCH 025/113] small change --- src/open_clip/transformer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 5feb77514..46aa55872 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -96,7 +96,7 @@ def forward(self, k_x = k_x if k_x is not None else q_x v_x = v_x if v_x is not None else q_x - w_q, w_k, w_v = self.in_proj_weight.chunk(3, dim=0) + w_q, w_k, w_v = self.in_proj_weight.split(3, dim=0) q = F.linear(q_x, w_q, self.in_proj_bias) k = F.linear(k_x, w_k, self.in_proj_bias) @@ -597,7 +597,6 @@ def build_attention_mask(self): return mask def forward(self, text_embs, image_embs, text_eot_mask): - text_embs = text_embs.permute(1, 0, 2) # NLD -> LND image_embs = image_embs.permute(1, 0, 2) # NLD -> LND x = self.transformer(text_embs, image_embs, image_embs, attn_mask=self.attn_mask) From 3f0f012467b7c419bfc1aedc9ce356a154056f53 Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 2 Dec 2022 10:47:53 +0100 Subject: [PATCH 026/113] add cocatransformer variants and AttentionPooling --- src/open_clip/transformer.py | 64 ++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 46aa55872..7a4ee3fca 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -134,6 +134,27 @@ def forward(self, x = self.out_drop(x) return x +class AttentionPooler(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + norm_layer: Callable = LayerNorm, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.ln_1_kv = norm_layer(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + + def attention(self, q_x: torch.Tensor, kv_x: torch.Tensor): + return self.attn(q_x, kv_x, kv_x, need_weights=False)[0] + + def forward(self, q_x: torch.Tensor, kv_x: torch.Tensor): + return self.ls_1( + self.attention(q_x=self.ln_1(q_x), k_x=self.ln_1_kv(kv_x),) + ) + class ResidualAttentionBlock(nn.Module): def __init__( @@ -451,6 +472,25 @@ def forward(self, x: torch.Tensor): return x +class CoCaVisionTransformer(VisionTransformer): + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat( + [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_post(x) + + img_embs, img_tokens = x[:, 0, :], x[:, 1:, :] + + return img_embs, img_tokens class TextTransformer(nn.Module): @@ -534,6 +574,30 @@ def forward(self, text): return x +class COCaTextTransformer(TextTransformer): + + + def forward(self, text): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x 
+ self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + eot_emb = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + # looking at the tokenizer this seems ok + token_emb = x[torch.arange(x.shape[0]), :-1, :] @ self.text_projection + + return eot_emb, token_emb + + class CoCaMultimodalTransformer(nn.Module): def __init__( self, From 3e745ec2b572f17985c3979acd5d6702032fd934 Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 2 Dec 2022 10:48:35 +0100 Subject: [PATCH 027/113] remoive older attention pooler --- src/open_clip/hf_model.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/open_clip/hf_model.py b/src/open_clip/hf_model.py index 9ee14f2fe..afd087a76 100644 --- a/src/open_clip/hf_model.py +++ b/src/open_clip/hf_model.py @@ -65,24 +65,6 @@ def forward(self, x:BaseModelOutput, attention_mask:TensorType): return x.last_hidden_state[:, self.cls_token_position, :] -@register_pooler -class AttentionalPooler(nn.Module): - def __init__(self, dim, num_heads): - super().__init__() - self.ln_q = nn.LayerNorm() - self.ln_kv = nn.LayerNorm() - self.attn = Attention( - dim=dim, num_heads=num_heads, causal=False, attn_dropout=0.0, ff_dropout=0.0 - ) - self.to_out = nn.Linear(dim, dim, bias=False) - - def forward(self, q_x, k_x: TensorType, v_x: TensorType): - q_x = self.ln_q(q_x) - k_x = self.ln_kv(k_x) - v_x = self.ln_kv(v_x) - x = self.attn(q_x, k_x, v_x) - return self.to_out(x) - class HFTextEncoder(nn.Module): """HuggingFace model adapter""" def __init__( From 4f4d3b7f373988ac5008dbe46d7172855233b24c Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 2 Dec 2022 10:53:52 +0100 Subject: [PATCH 028/113] adapt embed text to coca text transformer --- src/open_clip/coca_model.py | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 339675e48..60135e47c 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -7,9 +7,8 @@ from einops import rearrange, repeat from dataclasses import dataclass -from .transformer import LayerNormFp32, LayerNorm, QuickGELU, CoCaMultimodalTransformer +from .transformer import LayerNormFp32, LayerNorm, QuickGELU, CoCaMultimodalTransformer, ResidualAttentionBlock, AttentionPooler from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower -from .hf_model import AttentionalPooler @dataclass @@ -101,7 +100,7 @@ def __init__( # num image queries for multimodal, but 1 extra CLS for contrastive learning self.img_queries = nn.Parameter(torch.randn(num_img_queries + 1, self.width)) - self.img_attn_pool = AttentionalPooler(coca_cfg.width, coca_cfg.heads) + self.img_attn_pool = AttentionPooler(coca_cfg.width, coca_cfg.heads, norm_layer=norm_layer) self.img_attn_pool_norm = norm_layer(self.width) self.text_cls_norm = norm_layer(self.width) @@ -111,20 +110,12 @@ def __init__( self.temperature = nn.Parameter(torch.Tensor([1.0])) # to logits - self.to_logits = nn.Sequential( norm_layer(self.width), nn.Linear(self.width, num_tokens, bias=False) ) - # get the token embeddings whether the encoder is HF or custom - for mod in self.transformer.state_dict(): - if any((emb_name in mod) and ("weight" in mod) for emb_name in 
["word_embeddings", "token_embeddings"]): - token_embeddings = self.transformer.get_parameter(mod) - break - # they used embedding weight tied projection out to logits, not common, but works - self.to_logits[-1].weight = token_embeddings - nn.init.normal_(token_embeddings, std=0.02) + self.to_logits[-1].weight = self.token_embedding.weight def embed_text(self, text): batch, device = text.shape[0], text.device @@ -144,14 +135,7 @@ def embed_text(self, text): cls_mask = rearrange(text != self.pad_id, "b j -> b 1 j") attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True) - # go through unimodal layers - - for attn_ff in self.unimodal_layers: - text_tokens = attn_ff(text_tokens, attn_mask=attn_mask) - - # get text cls token - - text_tokens, text_cls_tokens = text_tokens[:, :-1], text_tokens[:, -1] + text_tokens, text_cls_tokens = self.transformer(text_tokens, attn_mask=attn_mask) text_embeds = self.text_cls_norm(text_cls_tokens) return text_embeds, text_tokens From 42539aa9008385d1218d44bdde84d450bd6ceb09 Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 2 Dec 2022 10:59:04 +0100 Subject: [PATCH 029/113] rm coca layers --- src/open_clip/coca_layers.py | 228 ----------------------------------- 1 file changed, 228 deletions(-) delete mode 100644 src/open_clip/coca_layers.py diff --git a/src/open_clip/coca_layers.py b/src/open_clip/coca_layers.py deleted file mode 100644 index 873737fb8..000000000 --- a/src/open_clip/coca_layers.py +++ /dev/null @@ -1,228 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import nn, einsum - -from einops import rearrange, repeat - -from .transformer import LayerNorm - - -class SwiGLU(nn.Module): - def forward(self, x): - x, gate = x.chunk(2, dim=-1) - return F.silu(gate) * x - - -class RotaryEmbedding(nn.Module): - def __init__(self, dim): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer("inv_freq", inv_freq) - - def forward(self, max_seq_len, *, device): - seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = einsum("i , j -> i j", seq, self.inv_freq) - return torch.cat((freqs, freqs), dim=-1) - - -def rotate_half(x): - x = rearrange(x, "... (j d) -> ... 
j d", j=2) - x1, x2 = x.unbind(dim=-2) - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(pos, t): - return (t * pos.cos()) + (rotate_half(t) * pos.sin()) - - -class ParallelTransformerBlock(nn.Module): - def __init__(self, dim, dim_head=64, heads=8, ff_mult=4): - super().__init__() - self.norm = LayerNorm(dim) - - attn_inner_dim = dim_head * heads - ff_inner_dim = dim * ff_mult - self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2)) - - self.heads = heads - self.scale = dim_head**-0.5 - self.rotary_emb = RotaryEmbedding(dim_head) - - self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False) - self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False) - - self.ff_out = nn.Sequential(SwiGLU(), nn.Linear(ff_inner_dim, dim, bias=False)) - - # for caching causal mask and rotary embeddings - - self.register_buffer("mask", None, persistent=False) - self.register_buffer("pos_emb", None, persistent=False) - - def get_mask(self, n, device): - if self.mask is not None and self.mask.shape[-1] >= n: - return self.mask[:n, :n] - - mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) - self.register_buffer("mask", mask, persistent=False) - return mask - - def get_rotary_embedding(self, n, device): - if self.pos_emb is not None and self.pos_emb.shape[-2] >= n: - return self.pos_emb[:n] - - pos_emb = self.rotary_emb(n, device=device) - self.register_buffer("pos_emb", pos_emb, persistent=False) - return pos_emb - - def forward(self, x, attn_mask=None): - """ - einstein notation - b - batch - h - heads - n, i, j - sequence length (base sequence length, source, target) - d - feature dimension - """ - - n, device, h = x.shape[1], x.device, self.heads - - # pre layernorm - x = self.norm(x) - - # attention queries, keys, values, and feedforward inner - q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1) - - # split heads - # they use multi-query single-key-value attention, yet another Noam Shazeer paper - # they found no performance loss past a certain scale, and more efficient decoding obviously - # https://arxiv.org/abs/1911.02150 - - q = rearrange(q, "b n (h d) -> b h n d", h=h) - - # rotary embeddings - positions = self.get_rotary_embedding(n, device) - q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k)) - - # scale - q = q * self.scale - - # similarity - sim = einsum("b h i d, b j d -> b h i j", q, k) - - # causal mask - causal_mask = self.get_mask(n, device) - sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) - - # extra attention mask - for masking out attention from text CLS token to padding - - if attn_mask is not None: - attn_mask = rearrange(attn_mask, "b i j -> b 1 i j") - sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max) - - # attention - - sim = sim - sim.amax(dim=-1, keepdim=True).detach() - attn = sim.softmax(dim=-1) - - # aggregate values - out = einsum("b h i j, b j d -> b h i d", attn, v) - - # merge heads - out = rearrange(out, "b h n d -> b n (h d)") - return (self.attn_out(out) + self.ff_out(ff)) + x - - -class CrossAttention(nn.Module): - def __init__( - self, - dim, - *, - context_dim=None, - dim_head=64, - heads=8, - parallel_ff=False, - ff_mult=4, - norm_context=False, - residual=False - ): - super().__init__() - self.heads = heads - self.scale = dim_head**-0.5 - self.residual = residual - inner_dim = heads * dim_head - context_dim = context_dim if context_dim is not None else dim - - self.norm = LayerNorm(dim) - self.context_norm = LayerNorm(context_dim) if 
norm_context else nn.Identity() - - self.to_q = nn.Linear(dim, inner_dim, bias=False) - self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False) - self.to_out = nn.Linear(inner_dim, dim, bias=False) - - # whether to have parallel feedforward - - ff_inner_dim = ff_mult * dim - - self.ff = ( - nn.Sequential( - nn.Linear(dim, ff_inner_dim * 2, bias=False), - SwiGLU(), - nn.Linear(ff_inner_dim, dim, bias=False), - ) - if parallel_ff - else None - ) - - def forward(self, x, context): - """ - einstein notation - b - batch - h - heads - n, i, j - sequence length (base sequence length, source, target) - d - feature dimension - """ - - # pre-layernorm, for queries and context - - x = self.norm(x) - context = self.context_norm(context) - - # get queries - - q = self.to_q(x) - q = rearrange(q, "b n (h d) -> b h n d", h=self.heads) - - # scale - - q = q * self.scale - - # get key / values - - k, v = self.to_kv(context).chunk(2, dim=-1) - - # query / key similarity - - sim = einsum("b h i d, b j d -> b h i j", q, k) - - # attention - - sim = sim - sim.amax(dim=-1, keepdim=True) - attn = sim.softmax(dim=-1) - - # aggregate - - out = einsum("b h i j, b j d -> b h i d", attn, v) - - # merge and combine heads - - out = rearrange(out, "b h n d -> b n (h d)") - out = self.to_out(out) - - # add parallel feedforward (for multimodal layers) - - if self.ff is not None: - out = out + self.ff(x) - - if self.residual: - out = out + x - - return out From 914a5708d7dc41f3452910eda22e210168c256fa Mon Sep 17 00:00:00 2001 From: gpucce Date: Sat, 3 Dec 2022 11:36:00 +0100 Subject: [PATCH 030/113] rename and remove useless CoCa models --- src/open_clip/coca_model.py | 40 +++++++++++++++++++++-- src/open_clip/transformer.py | 61 ++++-------------------------------- 2 files changed, 44 insertions(+), 57 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 60135e47c..9736d8273 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -7,7 +7,7 @@ from einops import rearrange, repeat from dataclasses import dataclass -from .transformer import LayerNormFp32, LayerNorm, QuickGELU, CoCaMultimodalTransformer, ResidualAttentionBlock, AttentionPooler +from .transformer import LayerNormFp32, LayerNorm, QuickGELU, MultimodalTransformer, AttentionPooler from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower @@ -48,7 +48,7 @@ def _build_coca_multimodal_tower( LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm ) - text = CoCaMultimodalTransformer( + text = MultimodalTransformer( context_length=coca_cfg.context_length, width=coca_cfg.width, heads=coca_cfg.heads, @@ -118,6 +118,26 @@ def __init__( self.to_logits[-1].weight = self.token_embedding.weight def embed_text(self, text): + # def forward(self, text): + # cast_dtype = self.transformer.get_cast_dtype() + + # x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + # x = x + self.positional_embedding.to(cast_dtype) + # x = x.permute(1, 0, 2) # NLD -> LND + # x = self.transformer(x, attn_mask=self.attn_mask) + # x = x.permute(1, 0, 2) # LND -> NLD + # x = self.ln_final(x) + + # # x.shape = [batch_size, n_ctx, transformer.width] + # # take features from the eot embedding (eot_token is the highest number in each sequence) + # eot_emb = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + # # looking at the tokenizer this seems ok + # token_emb = x[torch.arange(x.shape[0]), :-1, :] @ self.text_projection + + # return 
eot_emb, token_emb + batch, device = text.shape[0], text.device seq = text.shape[1] @@ -144,6 +164,22 @@ def embed_image(self, images=None, image_tokens=None): # with the img_encoder passed in at init # it can also accept precomputed image tokens + # def forward(self, x: torch.Tensor): + # x = self.conv1(x) # shape = [*, width, grid, grid] + # x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + # x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + # x = torch.cat( + # [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + # x], dim=1) # shape = [*, grid ** 2 + 1, width] + # x = x + self.positional_embedding.to(x.dtype) + # x = self.ln_pre(x) + # x = x.permute(1, 0, 2) # NLD -> LND + # x = self.transformer(x) + # x = x.permute(1, 0, 2) # LND -> NLD + # x = self.ln_post(x) + # img_embs, img_tokens = x[:, 0, :], x[:, 1:, :] + # return img_embs, img_tokens + assert images is None or image_tokens is None if images is not None: diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 7a4ee3fca..dcaffc464 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -138,22 +138,16 @@ class AttentionPooler(nn.Module): def __init__( self, d_model: int, - n_head: int, + n_head: int = 1, + n_queries: int = 256, norm_layer: Callable = LayerNorm, ): super().__init__() - - self.ln_1 = norm_layer(d_model) - self.ln_1_kv = norm_layer(d_model) + self.query = nn.Parameter(torch.randn(n_queries, d_model)) self.attn = nn.MultiheadAttention(d_model, n_head) - def attention(self, q_x: torch.Tensor, kv_x: torch.Tensor): - return self.attn(q_x, kv_x, kv_x, need_weights=False)[0] - def forward(self, q_x: torch.Tensor, kv_x: torch.Tensor): - return self.ls_1( - self.attention(q_x=self.ln_1(q_x), k_x=self.ln_1_kv(kv_x),) - ) + return self.attn(q_x, self.query, self.query, need_weights=False)[0] class ResidualAttentionBlock(nn.Module): @@ -306,7 +300,7 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): return x -class MultimodalTransformerDecoder(nn.Module): +class TransformerDecoder(nn.Module): def __init__( self, width: int, @@ -472,26 +466,6 @@ def forward(self, x: torch.Tensor): return x -class CoCaVisionTransformer(VisionTransformer): - def forward(self, x: torch.Tensor): - x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat( - [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), - x], dim=1) # shape = [*, grid ** 2 + 1, width] - x = x + self.positional_embedding.to(x.dtype) - x = self.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_post(x) - - img_embs, img_tokens = x[:, 0, :], x[:, 1:, :] - - return img_embs, img_tokens - class TextTransformer(nn.Module): def __init__( @@ -574,31 +548,8 @@ def forward(self, text): return x -class COCaTextTransformer(TextTransformer): - - - def forward(self, text): - cast_dtype = self.transformer.get_cast_dtype() - - x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] - - x = x + self.positional_embedding.to(cast_dtype) - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x, attn_mask=self.attn_mask) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x) - - # x.shape = [batch_size, n_ctx, transformer.width] 
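# For reference, the EOT pooling used throughout these text towers: the end-of-text
# token has the highest id in the CLIP BPE vocabulary, so text.argmax(dim=-1) recovers
# its position and the hidden state at that position is projected into the joint
# embedding space. A minimal sketch under that assumption (illustrative helper, not the
# repo's code):
import torch

def pool_eot(hidden: torch.Tensor, text: torch.Tensor, proj: torch.Tensor) -> torch.Tensor:
    # hidden: [batch, n_ctx, width], text: [batch, n_ctx] token ids, proj: [width, embed_dim]
    eot_pos = text.argmax(dim=-1)                                   # per-sequence EOT index
    return hidden[torch.arange(hidden.shape[0]), eot_pos] @ proj    # [batch, embed_dim]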
- # take features from the eot embedding (eot_token is the highest number in each sequence) - eot_emb = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection - - # looking at the tokenizer this seems ok - token_emb = x[torch.arange(x.shape[0]), :-1, :] @ self.text_projection - - return eot_emb, token_emb - -class CoCaMultimodalTransformer(nn.Module): +class MultimodalTransformer(nn.Module): def __init__( self, context_length: int = 77, From 6215d4a1abba6bccc3be91a6fc196ca7d35af7b8 Mon Sep 17 00:00:00 2001 From: gpucce Date: Sat, 3 Dec 2022 11:45:49 +0100 Subject: [PATCH 031/113] make attentionpooler pooler only --- src/open_clip/transformer.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index dcaffc464..b2bd30f73 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -146,8 +146,14 @@ def __init__( self.query = nn.Parameter(torch.randn(n_queries, d_model)) self.attn = nn.MultiheadAttention(d_model, n_head) - def forward(self, q_x: torch.Tensor, kv_x: torch.Tensor): - return self.attn(q_x, self.query, self.query, need_weights=False)[0] + def forward(self, kv: torch.Tensor): + kv = kv.reshape(1, 0 ,2) + N = kv.shape[1] + return self.attn(self._repeat(self.query, N), kv, kv, need_weights=False)[0] + + def _repeat(self, query, N): + L, D = query.shape + return query.unsqueeze(0).repeat(L, N, D) class ResidualAttentionBlock(nn.Module): From b97db74d397b251c6415aae3d268a057bca7704e Mon Sep 17 00:00:00 2001 From: gpucce Date: Mon, 5 Dec 2022 07:04:25 +0100 Subject: [PATCH 032/113] refactor for one transformer only --- src/open_clip/coca_model.py | 138 +++++++++++---------------------- src/open_clip/transformer.py | 146 ++++++++++++++--------------------- 2 files changed, 103 insertions(+), 181 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 9736d8273..d288adc77 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -7,7 +7,7 @@ from einops import rearrange, repeat from dataclasses import dataclass -from .transformer import LayerNormFp32, LayerNorm, QuickGELU, MultimodalTransformer, AttentionPooler +from .transformer import LayerNormFp32, LayerNorm, QuickGELU, TransformerDecoder, AttentionPooler from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower @@ -34,7 +34,7 @@ class CoCaCfg: # vit_mlp_dim: int = 3072 -def _build_coca_multimodal_tower( +def _build_text_decoder_tower( embed_dim: int, coca_cfg: CoCaCfg, quick_gelu: bool = False, @@ -48,7 +48,7 @@ def _build_coca_multimodal_tower( LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm ) - text = MultimodalTransformer( + text = TransformerDecoder( context_length=coca_cfg.context_length, width=coca_cfg.width, heads=coca_cfg.heads, @@ -75,7 +75,6 @@ def __init__( super().__init__() norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm - act_layer = QuickGELU if quick_gelu else nn.GELU text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) self.transformer = text.transformer @@ -90,17 +89,17 @@ def __init__( embed_dim, vision_cfg, quick_gelu, cast_dtype ) - self.multimodal_decoder = _build_coca_multimodal_tower( + self.multimodal_decoder = _build_text_decoder_tower( embed_dim, coca_cfg, quick_gelu, cast_dtype ) + num_img_queries = coca_cfg.num_image_queries self.width = coca_cfg.width num_tokens = text_cfg.vocab_size self.text_cls_token = 
nn.Parameter(torch.randn(self.width)) # num image queries for multimodal, but 1 extra CLS for contrastive learning - self.img_queries = nn.Parameter(torch.randn(num_img_queries + 1, self.width)) - self.img_attn_pool = AttentionPooler(coca_cfg.width, coca_cfg.heads, norm_layer=norm_layer) + self.img_attn_pool = AttentionPooler(coca_cfg.width, coca_cfg.heads) self.img_attn_pool_norm = norm_layer(self.width) self.text_cls_norm = norm_layer(self.width) @@ -118,83 +117,45 @@ def __init__( self.to_logits[-1].weight = self.token_embedding.weight def embed_text(self, text): - # def forward(self, text): - # cast_dtype = self.transformer.get_cast_dtype() - - # x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] - - # x = x + self.positional_embedding.to(cast_dtype) - # x = x.permute(1, 0, 2) # NLD -> LND - # x = self.transformer(x, attn_mask=self.attn_mask) - # x = x.permute(1, 0, 2) # LND -> NLD - # x = self.ln_final(x) - - # # x.shape = [batch_size, n_ctx, transformer.width] - # # take features from the eot embedding (eot_token is the highest number in each sequence) - # eot_emb = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection - - # # looking at the tokenizer this seems ok - # token_emb = x[torch.arange(x.shape[0]), :-1, :] @ self.text_projection - - # return eot_emb, token_emb - - batch, device = text.shape[0], text.device - - seq = text.shape[1] - - text_tokens = self.token_emb(text) - - # append text cls tokens - - text_cls_tokens = repeat(self.text_cls_token, "d -> b 1 d", b=batch) - text_tokens = torch.cat((text_tokens, text_cls_tokens), dim=-2) - - # create specific mask for text cls token at the end - # to prevent it from attending to padding - - cls_mask = rearrange(text != self.pad_id, "b j -> b 1 j") - attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True) - - text_tokens, text_cls_tokens = self.transformer(text_tokens, attn_mask=attn_mask) - text_embeds = self.text_cls_norm(text_cls_tokens) - return text_embeds, text_tokens - - def embed_image(self, images=None, image_tokens=None): - # encode images into embeddings - # with the img_encoder passed in at init - # it can also accept precomputed image tokens - - # def forward(self, x: torch.Tensor): - # x = self.conv1(x) # shape = [*, width, grid, grid] - # x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - # x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - # x = torch.cat( - # [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), - # x], dim=1) # shape = [*, grid ** 2 + 1, width] - # x = x + self.positional_embedding.to(x.dtype) - # x = self.ln_pre(x) - # x = x.permute(1, 0, 2) # NLD -> LND - # x = self.transformer(x) - # x = x.permute(1, 0, 2) # LND -> NLD - # x = self.ln_post(x) - # img_embs, img_tokens = x[:, 0, :], x[:, 1:, :] - # return img_embs, img_tokens - - assert images is None or image_tokens is None - - if images is not None: - assert ( - self.img_encoder is not None - ), "img_encoder must be passed in for automatic image encoding" - image_tokens = self.img_encoder(images) + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot 
embedding (eot_token is the highest number in each sequence) + cls_emb = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + # looking at the tokenizer this seems ok + token_emb = x[torch.arange(x.shape[0]), :-1] @ self.text_projection + cls_emb = self.text_cls_norm(cls_emb) + return cls_emb, token_emb + + def embed_image(self, images=None): + x = self.img_encoder.conv1(images) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat( + [self.img_encoder.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.img_encoder.positional_embedding.to(x.dtype) + x = self.img_encoder.ln_pre(x) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.img_encoder.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.img_encoder.ln_post(x) # attention pool image tokens + print("################", x.shape) + x = self.img_attn_pool(x) + x = self.img_attn_pool_norm(x) - img_queries = repeat(self.img_queries, "n d -> b n d", b=image_tokens.shape[0]) - img_queries = self.img_attn_pool(img_queries, image_tokens) - img_queries = self.img_attn_pool_norm(img_queries) - - return img_queries[:, 0], img_queries[:, 1:] + return x[:, 0], x[:, 1:] def forward( self, @@ -202,24 +163,15 @@ def forward( images=None, image_tokens=None, labels=None, - return_loss=False, - return_embeddings=False, ): - batch, device = text.shape[0], text.device - if return_loss and labels is None: + if labels is None: text, labels = text[:, :-1], text[:, 1:] text_embeds, text_tokens = self.embed_text(text) + image_embeds, image_tokens = self.embed_image(images) - image_embeds, image_tokens = self.embed_image( - images=images, image_tokens=image_tokens - ) - - for attn_ff, cross_attn in self.multimodal_layers: - text_tokens = attn_ff(text_tokens) - text_tokens = cross_attn(text_tokens, image_tokens) - + text_tokens = self.multimodal_decoder(image_tokens, text_tokens, eot_token_mask=text.argmax(dim=-1)) logits = self.to_logits(text_tokens) return text_embeds, image_embeds, logits, labels diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index b2bd30f73..96bd61330 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -140,20 +140,19 @@ def __init__( d_model: int, n_head: int = 1, n_queries: int = 256, - norm_layer: Callable = LayerNorm, ): super().__init__() self.query = nn.Parameter(torch.randn(n_queries, d_model)) self.attn = nn.MultiheadAttention(d_model, n_head) def forward(self, kv: torch.Tensor): - kv = kv.reshape(1, 0 ,2) + kv = kv.permute(1, 0 ,2) # NLD -> LND N = kv.shape[1] - return self.attn(self._repeat(self.query, N), kv, kv, need_weights=False)[0] + kv = self.attn(self._repeat(self.query, N), kv, kv, need_weights=False)[0] + return kv.permute(1, 0, 2) # LND -> NLD def _repeat(self, query, N): - L, D = query.shape - return query.unsqueeze(0).repeat(L, N, D) + return query.unsqueeze(1).repeat(1, N, 1) class ResidualAttentionBlock(nn.Module): @@ -196,6 +195,13 @@ def attention( v_x = v_x if v_x is not None else q_x attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + print("####################", q_x.shape) + print("#", k_x.shape) + print("##", v_x.shape) + try: + print("###", attn_mask.shape) + except: + pass return self.attn( q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask )[0] @@ -305,55 +311,6 
@@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): x = r(x, attn_mask=attn_mask) return x - -class TransformerDecoder(nn.Module): - def __init__( - self, - width: int, - layers: int, - heads: int, - mlp_ratio: float = 4.0, - ls_init_value: float = None, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - ): - - super().__init__() - self.width = width - self.layers = layers - self.grad_checkpointing = False - - self.resblocks = nn.ModuleList([ - ResidualAttentionBlock( - width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) - for _ in range(layers) - ]) - - self.cross_attn = nn.ModuleList([ - ResidualAttentionBlock( - width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) - for _ in range(layers) - ]) - - def get_cast_dtype(self) -> torch.dtype: - return self.resblocks[0].mlp.c_fc.weight.dtype - - def forward( - self, - q_x: torch.Tensor, - k_x: Optional[torch.Tensor], - v_x: Optional[torch.Tensor], - attn_mask: Optional[torch.Tensor] = None - ): - for r, ca in zip(self.resblocks, self.cross_attn): - if self.grad_checkpointing and not torch.jit.is_scripting(): - q_x = checkpoint(r, q_x, attn_mask=attn_mask) - q_x = checkpoint(ca, q_x, k_x=k_x, v_x=v_x) - else: - q_x = r(q_x, attn_mask=attn_mask) - q_x = ca(q_x, k_x=k_x, v_x=v_x) - return q_x - class VisionTransformer(nn.Module): def __init__( self, @@ -428,24 +385,24 @@ def _unlock(x): _unlock(groups[-unlocked_groups:]) def init_parameters(self): - # FIXME OpenAI CLIP did not define an init for the VisualTransformer - # TODO experiment if default PyTorch init, below, or alternate init is best. - - # nn.init.normal_(self.class_embedding, std=self.scale) - # nn.init.normal_(self.positional_embedding, std=self.scale) - # - # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) - # attn_std = self.transformer.width ** -0.5 - # fc_std = (2 * self.transformer.width) ** -0.5 - # for block in self.transformer.resblocks: - # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - # - # if self.text_projection is not None: - # nn.init.normal_(self.text_projection, std=self.scale) - pass + # FIXME OpenAI CLIP did not define an init for the VisualTransformer + # TODO experiment if default PyTorch init, below, or alternate init is best. 
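# For reference, the decoder pattern that the TransformerDecoder removed just above
# (and its replacement later in this patch) both follow: each layer runs causally
# masked self-attention over the text tokens, then cross-attention whose keys and
# values are the pooled image tokens. A self-contained sketch with plain
# MultiheadAttention standing in for ResidualAttentionBlock (no MLPs or layer norms,
# illustrative only):
import torch
from torch import nn

class TinyDecoderLayer(nn.Module):
    def __init__(self, width: int, heads: int):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(width, heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(width, heads, batch_first=True)

    def forward(self, text: torch.Tensor, image: torch.Tensor, causal_mask: torch.Tensor):
        # text: [batch, n_ctx, width], image: [batch, n_img_tokens, width]
        text = text + self.self_attn(text, text, text, attn_mask=causal_mask, need_weights=False)[0]
        text = text + self.cross_attn(text, image, image, need_weights=False)[0]
        return text

# usage sketch: the causal mask blocks attention to future text positions
n_ctx = 8
mask = torch.full((n_ctx, n_ctx), float("-inf")).triu(1)
out = TinyDecoderLayer(64, 4)(torch.randn(2, n_ctx, 64), torch.randn(2, 16, 64), mask)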
+ + # nn.init.normal_(self.class_embedding, std=self.scale) + # nn.init.normal_(self.positional_embedding, std=self.scale) + # + # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + # attn_std = self.transformer.width ** -0.5 + # fc_std = (2 * self.transformer.width) ** -0.5 + # for block in self.transformer.resblocks: + # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + # + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=self.scale) + pass @torch.jit.ignore def set_grad_checkpointing(self, enable=True): @@ -555,39 +512,44 @@ def forward(self, text): return x -class MultimodalTransformer(nn.Module): +class TransformerDecoder(Transformer): def __init__( self, + width: int, + layers: int, + heads: int, context_length: int = 77, - vocab_size: int = 49408, - width: int = 512, - heads: int = 8, - layers: int = 12, + mlp_ratio: float = 4.0, ls_init_value: float = None, - output_dim: int = 512, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, + output_dim: int = 512, ): - super().__init__() - self.context_length = context_length - self.vocab_size = vocab_size - self.width = width - self.output_dim = output_dim - self.transformer = MultimodalTransformerDecoder( + super().__init__( width=width, layers=layers, heads=heads, + mlp_ratio=mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, ) + self.context_length = context_length + self.cross_attn = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) + for _ in range(layers) + ]) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + self.ln_final = norm_layer(width) + # this will be shared with the textual decoder (in CoCa) self.text_projection = nn.Parameter(torch.empty(width, output_dim)) def init_parameters(self): - proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) attn_std = self.transformer.width ** -0.5 fc_std = (2 * self.transformer.width) ** -0.5 @@ -617,15 +579,23 @@ def build_attention_mask(self): mask.triu_(1) # zero out the lower diagonal return mask - def forward(self, text_embs, image_embs, text_eot_mask): + def forward(self, text_embs, image_embs, eot_token_mask): text_embs = text_embs.permute(1, 0, 2) # NLD -> LND image_embs = image_embs.permute(1, 0, 2) # NLD -> LND - x = self.transformer(text_embs, image_embs, image_embs, attn_mask=self.attn_mask) + + for r, ca in zip(self.resblocks, self.cross_attn): + if self.grad_checkpointing and not torch.jit.is_scripting(): + text_embs = checkpoint(r, text_embs, attn_mask=self.attn_mask) + text_embs = checkpoint(ca, text_embs, k_x=image_embs, v_x=image_embs) + else: + text_embs = r(text_embs, attn_mask=self.attn_mask) + text_embs = ca(text_embs, k_x=image_embs, v_x=image_embs) + x = x.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x) # x.shape = [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), text_eot_mask] @ self.text_projection + x = x[torch.arange(x.shape[0]), eot_token_mask] @ self.text_projection return x From d89f018dc028d10e7bd068d3348f694f6468bff9 Mon Sep 17 00:00:00 2001 From: gpucce Date: Mon, 5 Dec 2022 17:08:10 
+0900 Subject: [PATCH 033/113] coca forward works --- src/open_clip/coca_model.py | 52 ++++++++++++++++++++++++------------ src/open_clip/transformer.py | 17 +++++------- 2 files changed, 41 insertions(+), 28 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index d288adc77..adc8a8111 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -7,7 +7,13 @@ from einops import rearrange, repeat from dataclasses import dataclass -from .transformer import LayerNormFp32, LayerNorm, QuickGELU, TransformerDecoder, AttentionPooler +from .transformer import ( + LayerNormFp32, + LayerNorm, + QuickGELU, + TransformerDecoder, + AttentionPooler, +) from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower @@ -22,9 +28,9 @@ class CoCaCfg: layers: int = 12 dim_head: int = 64 heads: int = 12 - num_image_queries: int = 256 contrastive_loss_weight: float = 1.0 caption_loss_weight: float = 2.0 + vocab_size = 49408 # vit_image_size: int = 288 # vit_patch_size: int = 18 @@ -74,7 +80,11 @@ def __init__( ): super().__init__() - norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + norm_layer = ( + LayerNormFp32 + if cast_dtype in (torch.float16, torch.bfloat16) + else LayerNorm + ) text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) self.transformer = text.transformer @@ -93,13 +103,15 @@ def __init__( embed_dim, coca_cfg, quick_gelu, cast_dtype ) - num_img_queries = coca_cfg.num_image_queries self.width = coca_cfg.width - num_tokens = text_cfg.vocab_size - self.text_cls_token = nn.Parameter(torch.randn(self.width)) + # self.text_cls_token = nn.Parameter(torch.randn(self.width)) + # self.text_cls_pos_emb = nn.Parameter(torch.empty(self.width)) + # nn.init.normal_(self.text_cls_pos_emb, std=0.01) # num image queries for multimodal, but 1 extra CLS for contrastive learning - self.img_attn_pool = AttentionPooler(coca_cfg.width, coca_cfg.heads) + self.img_attn_pool = AttentionPooler( + coca_cfg.width, coca_cfg.heads, n_queries=coca_cfg.context_length + 1 + ) self.img_attn_pool_norm = norm_layer(self.width) self.text_cls_norm = norm_layer(self.width) @@ -110,7 +122,7 @@ def __init__( # to logits self.to_logits = nn.Sequential( - norm_layer(self.width), nn.Linear(self.width, num_tokens, bias=False) + norm_layer(self.width), nn.Linear(self.width, text_cfg.vocab_size, bias=False) ) # they used embedding weight tied projection out to logits, not common, but works @@ -120,7 +132,6 @@ def embed_text(self, text): cast_dtype = self.transformer.get_cast_dtype() x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] - x = x + self.positional_embedding.to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x, attn_mask=self.attn_mask) @@ -129,11 +140,11 @@ def embed_text(self, text): # x.shape = [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) - cls_emb = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + x = x[torch.arange(x.shape[0]), :] @ self.text_projection # looking at the tokenizer this seems ok - token_emb = x[torch.arange(x.shape[0]), :-1] @ self.text_projection - cls_emb = self.text_cls_norm(cls_emb) + cls_emb = self.text_cls_norm(x[torch.arange(x.shape[0]), -1]) + token_emb = x[torch.arange(x.shape[0]), :-1] return cls_emb, token_emb def embed_image(self, images=None): @@ -141,8 +152,15 @@ def embed_image(self, images=None): x = 
x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] x = torch.cat( - [self.img_encoder.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), - x], dim=1) # shape = [*, grid ** 2 + 1, width] + [ + self.img_encoder.class_embedding.to(x.dtype) + + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device + ), + x, + ], + dim=1, + ) # shape = [*, grid ** 2 + 1, width] x = x + self.img_encoder.positional_embedding.to(x.dtype) x = self.img_encoder.ln_pre(x) x = x.permute(1, 0, 2) # NLD -> LND @@ -150,8 +168,6 @@ def embed_image(self, images=None): x = x.permute(1, 0, 2) # LND -> NLD x = self.img_encoder.ln_post(x) - # attention pool image tokens - print("################", x.shape) x = self.img_attn_pool(x) x = self.img_attn_pool_norm(x) @@ -171,7 +187,9 @@ def forward( text_embeds, text_tokens = self.embed_text(text) image_embeds, image_tokens = self.embed_image(images) - text_tokens = self.multimodal_decoder(image_tokens, text_tokens, eot_token_mask=text.argmax(dim=-1)) + text_tokens = self.multimodal_decoder( + image_tokens, text_tokens, eot_token_mask=text.argmax(dim=-1) + ) logits = self.to_logits(text_tokens) return text_embeds, image_embeds, logits, labels diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 96bd61330..866174bfe 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -149,7 +149,9 @@ def forward(self, kv: torch.Tensor): kv = kv.permute(1, 0 ,2) # NLD -> LND N = kv.shape[1] kv = self.attn(self._repeat(self.query, N), kv, kv, need_weights=False)[0] - return kv.permute(1, 0, 2) # LND -> NLD + out = kv.permute(1, 0, 2) # LND -> NLD + print("################### attn_pool", out.shape) + return out def _repeat(self, query, N): return query.unsqueeze(1).repeat(1, N, 1) @@ -195,13 +197,6 @@ def attention( v_x = v_x if v_x is not None else q_x attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None - print("####################", q_x.shape) - print("#", k_x.shape) - print("##", v_x.shape) - try: - print("###", attn_mask.shape) - except: - pass return self.attn( q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask )[0] @@ -538,7 +533,7 @@ def __init__( self.context_length = context_length self.cross_attn = nn.ModuleList([ ResidualAttentionBlock( - width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, is_cross_attention=True) for _ in range(layers) ]) @@ -591,11 +586,11 @@ def forward(self, text_embs, image_embs, eot_token_mask): text_embs = r(text_embs, attn_mask=self.attn_mask) text_embs = ca(text_embs, k_x=image_embs, v_x=image_embs) - x = x.permute(1, 0, 2) # LND -> NLD + x = text_embs.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x) # x.shape = [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), eot_token_mask] @ self.text_projection + # x = x[torch.arange(x.shape[0]), eot_token_mask] @ self.text_projection return x From 9a8c15dff46c056a14a8c5ee91781425c0146a62 Mon Sep 17 00:00:00 2001 From: gpucce Date: Mon, 5 Dec 2022 18:24:14 +0900 Subject: [PATCH 034/113] separatae context and n_queries --- src/open_clip/coca_model.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/open_clip/coca_model.py 
b/src/open_clip/coca_model.py index adc8a8111..2cd1af5e3 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -30,7 +30,7 @@ class CoCaCfg: heads: int = 12 contrastive_loss_weight: float = 1.0 caption_loss_weight: float = 2.0 - vocab_size = 49408 + n_queries: int = 256 # vit_image_size: int = 288 # vit_patch_size: int = 18 @@ -104,13 +104,9 @@ def __init__( ) self.width = coca_cfg.width - # self.text_cls_token = nn.Parameter(torch.randn(self.width)) - # self.text_cls_pos_emb = nn.Parameter(torch.empty(self.width)) - # nn.init.normal_(self.text_cls_pos_emb, std=0.01) - # num image queries for multimodal, but 1 extra CLS for contrastive learning self.img_attn_pool = AttentionPooler( - coca_cfg.width, coca_cfg.heads, n_queries=coca_cfg.context_length + 1 + coca_cfg.width, coca_cfg.heads, n_queries=coca_cfg.n_queries + 1 ) self.img_attn_pool_norm = norm_layer(self.width) From c8b9236366dadf23bae6eb9118434be8e36f26d7 Mon Sep 17 00:00:00 2001 From: gpucce Date: Mon, 5 Dec 2022 18:24:36 +0900 Subject: [PATCH 035/113] add inital coca_base config --- src/open_clip/model_configs/coca_base.json | 32 ++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 src/open_clip/model_configs/coca_base.json diff --git a/src/open_clip/model_configs/coca_base.json b/src/open_clip/model_configs/coca_base.json new file mode 100644 index 000000000..3bb938c08 --- /dev/null +++ b/src/open_clip/model_configs/coca_base.json @@ -0,0 +1,32 @@ +{ + "embed_dim": 768, + "coca_cfg": { + "width":768, + "model_name": "coca_base", + "context_length": 77, + "image_dim": 768, + "mlp_ratio": 4, + "ls_init_value": "None", + "layers": 12, + "dim_head": 64, + "heads": 12, + "contrastive_loss_weight": 1.0, + "caption_loss_weight": 2.0, + "n_queries": 256 + }, + "vision_cfg": { + "image_size": 288, + "layers": 12, + "width": 768, + "patch_size": 18 + }, + "text_cfg": { + "vocab_size": 64000, + "layers": 12, + "dim_head": 64, + "heads":12, + "mlp_ration": 4, + "context_length": 512, + "ls_init_value": "None" + } +} From d0f995ae4b14b3eb8a4bed2cb149e98a53df666c Mon Sep 17 00:00:00 2001 From: gpucce Date: Mon, 5 Dec 2022 18:32:33 +0900 Subject: [PATCH 036/113] remove config --- src/open_clip/model_configs/coca_base.json | 32 ---------------------- 1 file changed, 32 deletions(-) delete mode 100644 src/open_clip/model_configs/coca_base.json diff --git a/src/open_clip/model_configs/coca_base.json b/src/open_clip/model_configs/coca_base.json deleted file mode 100644 index 3bb938c08..000000000 --- a/src/open_clip/model_configs/coca_base.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "embed_dim": 768, - "coca_cfg": { - "width":768, - "model_name": "coca_base", - "context_length": 77, - "image_dim": 768, - "mlp_ratio": 4, - "ls_init_value": "None", - "layers": 12, - "dim_head": 64, - "heads": 12, - "contrastive_loss_weight": 1.0, - "caption_loss_weight": 2.0, - "n_queries": 256 - }, - "vision_cfg": { - "image_size": 288, - "layers": 12, - "width": 768, - "patch_size": 18 - }, - "text_cfg": { - "vocab_size": 64000, - "layers": 12, - "dim_head": 64, - "heads":12, - "mlp_ration": 4, - "context_length": 512, - "ls_init_value": "None" - } -} From 52607747f7e48ca7bb9cdb6a137edee023d0fbad Mon Sep 17 00:00:00 2001 From: gpucce Date: Mon, 5 Dec 2022 19:00:45 +0900 Subject: [PATCH 037/113] small loss change --- src/open_clip/__init__.py | 2 +- src/open_clip/coca_model.py | 21 ++++++++------------- src/open_clip/loss.py | 2 +- 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/src/open_clip/__init__.py 
b/src/open_clip/__init__.py index b76dd51b9..d3bc068a7 100644 --- a/src/open_clip/__init__.py +++ b/src/open_clip/__init__.py @@ -1,7 +1,7 @@ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer from .factory import list_models, add_model_config, get_model_config, load_checkpoint -from .loss import ClipLoss +from .loss import ClipLoss, CoCaLoss from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype from .openai import load_openai_model, list_openai_models diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 2cd1af5e3..50ff1142c 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -1,10 +1,8 @@ -from typing import Callable, Optional +from typing import Optional import torch -import torch.nn.functional as F -from torch import nn, einsum - -from einops import rearrange, repeat +from torch import nn +import numpy as np from dataclasses import dataclass from .transformer import ( @@ -32,13 +30,6 @@ class CoCaCfg: caption_loss_weight: float = 2.0 n_queries: int = 256 - # vit_image_size: int = 288 - # vit_patch_size: int = 18 - # vit_dim: int = 768 - # vit_depth: int = 12 - # vit_heads: int = 12 - # vit_mlp_dim: int = 3072 - def _build_text_decoder_tower( embed_dim: int, @@ -99,6 +90,8 @@ def __init__( embed_dim, vision_cfg, quick_gelu, cast_dtype ) + + self.multimodal_decoder = _build_text_decoder_tower( embed_dim, coca_cfg, quick_gelu, cast_dtype ) @@ -124,6 +117,8 @@ def __init__( # they used embedding weight tied projection out to logits, not common, but works self.to_logits[-1].weight = self.token_embedding.weight + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + def embed_text(self, text): cast_dtype = self.transformer.get_cast_dtype() @@ -188,4 +183,4 @@ def forward( ) logits = self.to_logits(text_tokens) - return text_embeds, image_embeds, logits, labels + return text_embeds, image_embeds, logits, labels, self.logits_scale diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 2e71b30a4..6101c9d3d 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -148,7 +148,7 @@ def __init__( self.caption_loss_weight = caption_loss_weight self.pad_id = pad_id - def forward(self, image_features, text_features, logits, logit_scale, labels): + def forward(self, image_features, text_features, logits, labels, logit_scale): clip_loss = self.clip_loss(image_features, text_features, logit_scale) clip_loss = self.clip_loss_weight * clip_loss From 7a2b84ef7408cd83609850b5e217ae0209a69685 Mon Sep 17 00:00:00 2001 From: gpucce Date: Mon, 5 Dec 2022 19:00:59 +0900 Subject: [PATCH 038/113] init training file --- src/training/coca_train.py | 263 +++++++++++++++++++++++++++++++++++++ 1 file changed, 263 insertions(+) create mode 100644 src/training/coca_train.py diff --git a/src/training/coca_train.py b/src/training/coca_train.py new file mode 100644 index 000000000..28e2084af --- /dev/null +++ b/src/training/coca_train.py @@ -0,0 +1,263 @@ +import json +import logging +import math +import os +import time + +import numpy as np +import torch +import torch.nn.functional as F + +try: + import wandb +except ImportError: + wandb = None + +from open_clip import CoCaLoss, get_cast_dtype +from .distributed import is_master +from .zero_shot import zero_shot_eval +from .precision import get_autocast + + +class 
AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def unwrap_model(model): + if hasattr(model, 'module'): + return model.module + else: + return model + + +def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None): + device = torch.device(args.device) + autocast = get_autocast(args.precision) + cast_dtype = get_cast_dtype(args.precision) + + model.train() + loss = CoCaLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + clip_loss_weight=1.0, + caption_loss_weight=2.0 + ) + + data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch + dataloader = data['train'].dataloader + num_batches_per_epoch = dataloader.num_batches + sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) + + loss_m = AverageMeter() + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + for i, batch in enumerate(dataloader): + step = num_batches_per_epoch * epoch + i + + if not args.skip_scheduler: + scheduler(step) + + images, texts = batch + images = images.to(device=device, dtype=cast_dtype, non_blocking=True) + texts = texts.to(device=device, non_blocking=True) + + data_time_m.update(time.time() - end) + optimizer.zero_grad() + + with autocast(): + image_features, text_features, logits, labels, logit_scale = model(images, texts) + total_loss = loss(image_features, text_features, logits, labels, logit_scale) + + if scaler is not None: + scaler.scale(total_loss).backward() + if args.horovod: + optimizer.synchronize() + scaler.unscale_(optimizer) + if args.grad_clip_norm is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) + with optimizer.skip_synchronize(): + scaler.step(optimizer) + else: + if args.grad_clip_norm is not None: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) + scaler.step(optimizer) + scaler.update() + else: + total_loss.backward() + if args.grad_clip_norm is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) + optimizer.step() + + # Note: we clamp to 4.6052 = ln(100), as in the original paper. 
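# Why the clamp below: logit_scale is a learned inverse temperature on the cosine
# similarities (the model initializes it to log(1/0.07)), and capping it at ln(100)
# keeps exp(logit_scale) at most 100 so the contrastive softmax cannot saturate.
# Illustrative arithmetic only:
#   math.log(1 / 0.07) ≈ 2.659,  exp(2.659) ≈ 14.3  at initialization
#   4.6052 ≈ ln(100),            exp(4.6052) ≈ 100.0 at the clamped ceiling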
+ with torch.no_grad(): + unwrap_model(model).logit_scale.clamp_(0, math.log(100)) + + batch_time_m.update(time.time() - end) + end = time.time() + batch_count = i + 1 + if is_master(args) and (i % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch): + batch_size = len(images) + num_samples = batch_count * batch_size * args.world_size + samples_per_epoch = dataloader.num_samples + percent_complete = 100.0 * batch_count / num_batches_per_epoch + + # NOTE loss is coarsely sampled, just master node and per log update + loss_m.update(total_loss.item(), batch_size) + logit_scale_scalar = logit_scale.item() + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f}, {args.batch_size*args.world_size / batch_time_m.val:#g}/s " + f"LR: {optimizer.param_groups[0]['lr']:5f} " + f"Logit Scale: {logit_scale_scalar:.3f}" + ) + + # Save train loss / etc. Using non avg meter values as loggers have their own smoothing + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "samples_per_scond": args.batch_size*args.world_size / batch_time_m.val, + "scale": logit_scale_scalar, + "lr": optimizer.param_groups[0]["lr"] + } + for name, val in log_data.items(): + name = "train/" + name + if tb_writer is not None: + tb_writer.add_scalar(name, val, step) + if args.wandb: + assert wandb is not None, 'Please install wandb.' + wandb.log({name: val, 'step': step}) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # end for + + +def evaluate(model, data, epoch, args, tb_writer=None): + metrics = {} + if not is_master(args): + return metrics + device = torch.device(args.device) + model.eval() + + zero_shot_metrics = zero_shot_eval(model, data, epoch, args) + metrics.update(zero_shot_metrics) + + autocast = get_autocast(args.precision) + cast_dtype = get_cast_dtype(args.precision) + + if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)): + dataloader = data['val'].dataloader + num_samples = 0 + samples_per_val = dataloader.num_samples + + # FIXME this does not scale past small eval datasets + # all_image_features @ all_text_features will blow up memory and compute very quickly + cumulative_loss = 0.0 + all_image_features, all_text_features = [], [] + with torch.no_grad(): + for i, batch in enumerate(dataloader): + images, texts = batch + images = images.to(device=device, dtype=cast_dtype, non_blocking=True) + texts = texts.to(device=device, non_blocking=True) + + with autocast(): + image_features, text_features, logit_scale = model(images, texts) + # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly + # however, system RAM is easily exceeded and compute time becomes problematic + all_image_features.append(image_features.cpu()) + all_text_features.append(text_features.cpu()) + logit_scale = logit_scale.mean() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + batch_size = images.shape[0] + labels = torch.arange(batch_size, device=device).long() + total_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + cumulative_loss += total_loss * batch_size + num_samples += batch_size + if is_master(args) and (i % 100) == 0: + 
logging.info( + f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t" + f"Loss: {cumulative_loss / num_samples:.6f}\t") + + val_metrics = get_metrics( + image_features=torch.cat(all_image_features), + text_features=torch.cat(all_text_features), + logit_scale=logit_scale.cpu(), + ) + loss = cumulative_loss / num_samples + metrics.update( + {**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples} + ) + + if not metrics: + return metrics + + logging.info( + f"Eval Epoch: {epoch} " + + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) + ) + + if args.save_logs: + for name, val in metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val/{name}", val, epoch) + + with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: + f.write(json.dumps(metrics)) + f.write("\n") + + if args.wandb: + assert wandb is not None, 'Please install wandb.' + for name, val in metrics.items(): + wandb.log({f"val/{name}": val, 'epoch': epoch}) + + return metrics + + +def get_metrics(image_features, text_features, logit_scale): + metrics = {} + logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu() + logits_per_text = logits_per_image.t().detach().cpu() + + logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text} + ground_truth = torch.arange(len(text_features)).view(-1, 1) + + for name, logit in logits.items(): + ranking = torch.argsort(logit, descending=True) + preds = torch.where(ranking == ground_truth)[1] + preds = preds.detach().cpu().numpy() + metrics[f"{name}_mean_rank"] = preds.mean() + 1 + metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1 + for k in [1, 5, 10]: + metrics[f"{name}_R@{k}"] = np.mean(preds < k) + + return metrics From 3ef1d17ebb44b39ecf840494a486acaa57d9dbe9 Mon Sep 17 00:00:00 2001 From: gpucce Date: Mon, 5 Dec 2022 15:09:05 +0100 Subject: [PATCH 039/113] make variable order right --- src/open_clip/coca_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 50ff1142c..0a883a69b 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -179,8 +179,8 @@ def forward( image_embeds, image_tokens = self.embed_image(images) text_tokens = self.multimodal_decoder( - image_tokens, text_tokens, eot_token_mask=text.argmax(dim=-1) + text_tokens, image_tokens, eot_token_mask=text.argmax(dim=-1) ) logits = self.to_logits(text_tokens) - return text_embeds, image_embeds, logits, labels, self.logits_scale + return text_embeds, image_embeds, logits, labels, self.logit_scale From 86f47bb89375fa41bfe2f88b3abd887e15484100 Mon Sep 17 00:00:00 2001 From: gpucce Date: Mon, 5 Dec 2022 15:09:28 +0100 Subject: [PATCH 040/113] remove print --- src/open_clip/transformer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 866174bfe..0ced0cafd 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -150,7 +150,6 @@ def forward(self, kv: torch.Tensor): N = kv.shape[1] kv = self.attn(self._repeat(self.query, N), kv, kv, need_weights=False)[0] out = kv.permute(1, 0, 2) # LND -> NLD - print("################### attn_pool", out.shape) return out def _repeat(self, query, N): From c6834b50859f79b5d34ed667169b7eee95d1b792 Mon Sep 17 00:00:00 2001 From: gpucce Date: Mon, 5 Dec 2022 15:45:36 +0100 Subject: [PATCH 041/113] uniform names --- src/open_clip/coca_model.py | 16 ++++++++-------- 1 
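The rank bookkeeping in get_metrics is terse; a toy 3x3 retrieval matrix (made-up logits, matching pairs on the diagonal) shows how argsort plus the diagonal ground truth yields mean rank and R@K:

import torch
import numpy as np

logits = torch.tensor([[9.0, 1.0, 0.0],
                       [0.5, 0.2, 8.0],
                       [0.1, 7.0, 0.3]])
ground_truth = torch.arange(3).view(-1, 1)

ranking = torch.argsort(logits, descending=True)         # candidate indices, best match first
preds = torch.where(ranking == ground_truth)[1].numpy()  # 0-based rank of the true match per row
# preds == [0, 2, 1]: query 0 ranks its match first, query 1 third, query 2 second
print("mean_rank:", preds.mean() + 1)                            # 2.0
print("R@1:", np.mean(preds < 1), "R@5:", np.mean(preds < 5))    # 0.333..., 1.0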
file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 0a883a69b..f0f8eb50f 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -18,14 +18,14 @@ @dataclass class CoCaCfg: model_name: str = "CoCa_base" - context_length = 77 - width: int = 768 - image_dim: int = 768 + context_length:int = 76 + width: int = 512 + image_dim: int = 512 mlp_ratio: int = 4 ls_init_value: Optional[float] = None layers: int = 12 dim_head: int = 64 - heads: int = 12 + heads: int = 8 contrastive_loss_weight: float = 1.0 caption_loss_weight: float = 2.0 n_queries: int = 256 @@ -119,7 +119,7 @@ def __init__( self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - def embed_text(self, text): + def encode_text(self, text): cast_dtype = self.transformer.get_cast_dtype() x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] @@ -138,7 +138,7 @@ def embed_text(self, text): token_emb = x[torch.arange(x.shape[0]), :-1] return cls_emb, token_emb - def embed_image(self, images=None): + def encode_image(self, images=None): x = self.img_encoder.conv1(images) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] @@ -175,8 +175,8 @@ def forward( if labels is None: text, labels = text[:, :-1], text[:, 1:] - text_embeds, text_tokens = self.embed_text(text) - image_embeds, image_tokens = self.embed_image(images) + text_embeds, text_tokens = self.encode_text(text) + image_embeds, image_tokens = self.encode_image(images) text_tokens = self.multimodal_decoder( text_tokens, image_tokens, eot_token_mask=text.argmax(dim=-1) From 7489c6897bda88dbb35562ad7b9ed2b54df65ae1 Mon Sep 17 00:00:00 2001 From: gpucce Date: Mon, 5 Dec 2022 16:09:20 +0100 Subject: [PATCH 042/113] renaming --- src/open_clip/coca_model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index f0f8eb50f..2c8a270fd 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -86,7 +86,7 @@ def __init__( self.text_projection = text.text_projection self.register_buffer("attn_mask", text.attn_mask, persistent=False) - self.img_encoder = _build_vision_tower( + self.visual = _build_vision_tower( embed_dim, vision_cfg, quick_gelu, cast_dtype ) @@ -139,12 +139,12 @@ def encode_text(self, text): return cls_emb, token_emb def encode_image(self, images=None): - x = self.img_encoder.conv1(images) # shape = [*, width, grid, grid] + x = self.visual.conv1(images) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] x = torch.cat( [ - self.img_encoder.class_embedding.to(x.dtype) + self.visual.class_embedding.to(x.dtype) + torch.zeros( x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device ), @@ -152,12 +152,12 @@ def encode_image(self, images=None): ], dim=1, ) # shape = [*, grid ** 2 + 1, width] - x = x + self.img_encoder.positional_embedding.to(x.dtype) - x = self.img_encoder.ln_pre(x) + x = x + self.visual.positional_embedding.to(x.dtype) + x = self.visual.ln_pre(x) x = x.permute(1, 0, 2) # NLD -> LND - x = self.img_encoder.transformer(x) + x = self.visual.transformer(x) x = x.permute(1, 0, 2) # LND -> NLD - x = self.img_encoder.ln_post(x) + x = self.visual.ln_post(x) x = self.img_attn_pool(x) x = self.img_attn_pool_norm(x) From 
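For the manual ViT forward in encode_image, the token bookkeeping is easiest to see with concrete numbers; the sketch below uses illustrative values (288-pixel images, 18-pixel patches, width 768) and a zero tensor standing in for the learned class embedding:

import torch
from torch import nn

image_size, patch_size, width = 288, 18, 768
conv1 = nn.Conv2d(3, width, kernel_size=patch_size, stride=patch_size, bias=False)

x = torch.randn(2, 3, image_size, image_size)
x = conv1(x)                                # [2, 768, 16, 16]  one column per patch
x = x.reshape(x.shape[0], x.shape[1], -1)   # [2, 768, 256]
x = x.permute(0, 2, 1)                      # [2, 256, 768]     one token per patch
cls = torch.zeros(2, 1, width)              # stand-in for the class embedding broadcast
x = torch.cat([cls, x], dim=1)              # [2, 257, 768]     16*16 patches + 1 CLS token
print(x.shape)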
59503df4c8a295a90657621ea3025d33db788be5 Mon Sep 17 00:00:00 2001 From: gpucce Date: Mon, 5 Dec 2022 16:29:27 +0100 Subject: [PATCH 043/113] add coca funcs to init --- src/open_clip/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/open_clip/__init__.py b/src/open_clip/__init__.py index d3bc068a7..f06c5d62c 100644 --- a/src/open_clip/__init__.py +++ b/src/open_clip/__init__.py @@ -4,6 +4,7 @@ from .loss import ClipLoss, CoCaLoss from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype +from .coca_model import CoCa, CoCaCfg from .openai import load_openai_model, list_openai_models from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained From 504febde258ac875a545484fec98287bb3070293 Mon Sep 17 00:00:00 2001 From: gpucce Date: Mon, 5 Dec 2022 16:31:39 +0100 Subject: [PATCH 044/113] add coca config and exclude from testing --- src/open_clip/model_configs/coca_base.json | 32 ++++++++++++++++++++++ tests/test_inference.py | 1 + 2 files changed, 33 insertions(+) create mode 100644 src/open_clip/model_configs/coca_base.json diff --git a/src/open_clip/model_configs/coca_base.json b/src/open_clip/model_configs/coca_base.json new file mode 100644 index 000000000..3bb938c08 --- /dev/null +++ b/src/open_clip/model_configs/coca_base.json @@ -0,0 +1,32 @@ +{ + "embed_dim": 768, + "coca_cfg": { + "width":768, + "model_name": "coca_base", + "context_length": 77, + "image_dim": 768, + "mlp_ratio": 4, + "ls_init_value": "None", + "layers": 12, + "dim_head": 64, + "heads": 12, + "contrastive_loss_weight": 1.0, + "caption_loss_weight": 2.0, + "n_queries": 256 + }, + "vision_cfg": { + "image_size": 288, + "layers": 12, + "width": 768, + "patch_size": 18 + }, + "text_cfg": { + "vocab_size": 64000, + "layers": 12, + "dim_head": 64, + "heads":12, + "mlp_ration": 4, + "context_length": 512, + "ls_init_value": "None" + } +} diff --git a/tests/test_inference.py b/tests/test_inference.py index de7ec2b48..f232e02ac 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -20,6 +20,7 @@ 'ViT-G-14', 'ViT-e-14', 'mt5-xl-ViT-H-14', + 'coca_base' }) @pytest.mark.parametrize('model_name', models_to_test) From 72a7e96ae75c2a732ac16143095b301d53e16ace Mon Sep 17 00:00:00 2001 From: gpucce Date: Mon, 5 Dec 2022 16:32:20 +0100 Subject: [PATCH 045/113] add and comment simple test (no trained model) --- tests/test_inference_simple.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/test_inference_simple.py b/tests/test_inference_simple.py index fb6bb4958..038a139fc 100644 --- a/tests/test_inference_simple.py +++ b/tests/test_inference_simple.py @@ -3,6 +3,7 @@ from PIL import Image from open_clip.factory import get_tokenizer import pytest +import numpy as np import open_clip import os os.environ["CUDA_VISIBLE_DEVICES"] = "" @@ -24,3 +25,22 @@ def test_inference_simple(model_type, pretrained): text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) assert text_probs.cpu().numpy()[0].tolist() == [1.0, 0.0, 0.0] + + +# def test_inference_simple_coca(): +# model = open_clip.CoCa(512, open_clip.CoCaCfg(), open_clip.CLIPTextCfg(), open_clip.CLIPVisionCfg(width=512)) +# preprocess = open_clip.image_transform(model.visual.image_size, is_train=False, mean=0, std=0) + +# tokenizer = open_clip.tokenize +# current_dir = 
os.path.dirname(os.path.realpath(__file__)) + +# image = preprocess(Image.open(current_dir + "/../docs/CLIP.png")).unsqueeze(0) +# text = tokenizer(["a diagram", "a dog", "a cat"]) + +# with torch.no_grad(): +# image_features = model.encode_image(image)[0] +# text_features = model.encode_text(text)[0] + +# text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) + +# assert text_probs.cpu().numpy()[0].tolist() == [1.0, 0.0, 0.0] \ No newline at end of file From d8a94be37d5197af91923adc9d312f04e31a8a0a Mon Sep 17 00:00:00 2001 From: gpucce Date: Tue, 6 Dec 2022 12:21:23 +0900 Subject: [PATCH 046/113] add L2 norm --- src/open_clip/coca_model.py | 15 +++++++++++---- src/open_clip/transformer.py | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 2c8a270fd..f4ad7e8fb 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -2,6 +2,7 @@ import torch from torch import nn +from torch.nn import functional as F import numpy as np from dataclasses import dataclass @@ -29,6 +30,7 @@ class CoCaCfg: contrastive_loss_weight: float = 1.0 caption_loss_weight: float = 2.0 n_queries: int = 256 + dim_latents: int = None def _build_text_decoder_tower( @@ -91,7 +93,6 @@ def __init__( ) - self.multimodal_decoder = _build_text_decoder_tower( embed_dim, coca_cfg, quick_gelu, cast_dtype ) @@ -109,6 +110,10 @@ def __init__( self.temperature = nn.Parameter(torch.Tensor([1.0])) + self.dim_latents = coca_cfg.dim_latents if coca_cfg.dim_latents else coca_cfg.width + self.to_text_latents = nn.Linear(self.width, self.dim_latents, bias=False) + self.to_image_latents = nn.Linear(self.width, self.dim_latents, bias=False) + # to logits self.to_logits = nn.Sequential( norm_layer(self.width), nn.Linear(self.width, text_cfg.vocab_size, bias=False) @@ -127,16 +132,15 @@ def encode_text(self, text): x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x, attn_mask=self.attn_mask) x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x) # x.shape = [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) x = x[torch.arange(x.shape[0]), :] @ self.text_projection # looking at the tokenizer this seems ok - cls_emb = self.text_cls_norm(x[torch.arange(x.shape[0]), -1]) + cls_emb = x[torch.arange(x.shape[0]), -1] token_emb = x[torch.arange(x.shape[0]), :-1] - return cls_emb, token_emb + return self.text_cls_norm(cls_emb), token_emb def encode_image(self, images=None): x = self.visual.conv1(images) # shape = [*, width, grid, grid] @@ -178,6 +182,9 @@ def forward( text_embeds, text_tokens = self.encode_text(text) image_embeds, image_tokens = self.encode_image(images) + text_embeds = F.normalize(self.to_text_latents(text_embeds), dim=-1) + image_embeds = F.normalize(self.to_image_latents(image_embeds), dim=-1) + text_tokens = self.multimodal_decoder( text_tokens, image_tokens, eot_token_mask=text.argmax(dim=-1) ) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 0ced0cafd..6cae57d0c 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -138,7 +138,7 @@ class AttentionPooler(nn.Module): def __init__( self, d_model: int, - n_head: int = 1, + n_head: int = 8, n_queries: int = 256, ): super().__init__() From d250eace0132b5c9bbffc9efb5c66f4efe1a7075 Mon Sep 17 00:00:00 2001 From: gpucce Date: Tue, 6 Dec 2022 12:41:50 +0900 Subject: [PATCH 047/113] make L2 same as in clip --- 
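Projecting both towers to dim_latents and L2-normalizing means the contrastive logits are just scaled cosine similarities, as in CLIP. A small sketch with random stand-in latents:

import torch
import torch.nn.functional as F

image_latents = F.normalize(torch.randn(4, 512), dim=-1)   # stands in for the pooled image token
text_latents = F.normalize(torch.randn(4, 512), dim=-1)    # stands in for the projected text CLS
logit_scale = 1 / 0.07   # ~14.3, i.e. exp of the model's log(1 / 0.07) parameter

# each entry is logit_scale * cos(image_i, text_j); the diagonal holds the matching pairs
logits_per_image = logit_scale * image_latents @ text_latents.t()
print(logits_per_image.shape)   # [4, 4]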
src/open_clip/coca_model.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index f4ad7e8fb..3ee58cecc 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -111,8 +111,8 @@ def __init__( self.temperature = nn.Parameter(torch.Tensor([1.0])) self.dim_latents = coca_cfg.dim_latents if coca_cfg.dim_latents else coca_cfg.width - self.to_text_latents = nn.Linear(self.width, self.dim_latents, bias=False) - self.to_image_latents = nn.Linear(self.width, self.dim_latents, bias=False) + self.to_text_latent = nn.Linear(self.width, self.dim_latents, bias=False) + self.to_image_latent = nn.Linear(self.width, self.dim_latents, bias=False) # to logits self.to_logits = nn.Sequential( @@ -124,7 +124,7 @@ def __init__( self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - def encode_text(self, text): + def encode_text(self, text, normalize=True): cast_dtype = self.transformer.get_cast_dtype() x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] @@ -140,9 +140,14 @@ def encode_text(self, text): # looking at the tokenizer this seems ok cls_emb = x[torch.arange(x.shape[0]), -1] token_emb = x[torch.arange(x.shape[0]), :-1] - return self.text_cls_norm(cls_emb), token_emb - def encode_image(self, images=None): + cls_emb = self.text_cls_norm(cls_emb) + text_latent = self.to_text_latent(cls_emb) + text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latents + + return text_latent, token_emb + + def encode_image(self, images=None, normalize=True): x = self.visual.conv1(images) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] @@ -166,7 +171,10 @@ def encode_image(self, images=None): x = self.img_attn_pool(x) x = self.img_attn_pool_norm(x) - return x[:, 0], x[:, 1:] + image_latent = self.to_image_latent(x[:, 0]) + image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent + + return image_latent, x[:, 1:] def forward( self, @@ -179,15 +187,12 @@ def forward( if labels is None: text, labels = text[:, :-1], text[:, 1:] - text_embeds, text_tokens = self.encode_text(text) - image_embeds, image_tokens = self.encode_image(images) - - text_embeds = F.normalize(self.to_text_latents(text_embeds), dim=-1) - image_embeds = F.normalize(self.to_image_latents(image_embeds), dim=-1) + text_latents, text_tokens = self.encode_text(text) + image_latents, image_tokens = self.encode_image(images) text_tokens = self.multimodal_decoder( text_tokens, image_tokens, eot_token_mask=text.argmax(dim=-1) ) logits = self.to_logits(text_tokens) - return text_embeds, image_embeds, logits, labels, self.logit_scale + return text_latents, image_latents, logits, labels, self.logit_scale From 8d9dfa63af1df6def9032021539e86797580cd57 Mon Sep 17 00:00:00 2001 From: gpucce Date: Tue, 6 Dec 2022 12:43:43 +0900 Subject: [PATCH 048/113] remove unused temperature --- src/open_clip/coca_model.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 3ee58cecc..e91e99d63 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -92,7 +92,6 @@ def __init__( embed_dim, vision_cfg, quick_gelu, cast_dtype ) - self.multimodal_decoder = _build_text_decoder_tower( embed_dim, coca_cfg, quick_gelu, cast_dtype ) @@ -106,10 +105,6 @@ def __init__( self.img_attn_pool_norm = 
norm_layer(self.width) self.text_cls_norm = norm_layer(self.width) - # contrastive learning temperature - - self.temperature = nn.Parameter(torch.Tensor([1.0])) - self.dim_latents = coca_cfg.dim_latents if coca_cfg.dim_latents else coca_cfg.width self.to_text_latent = nn.Linear(self.width, self.dim_latents, bias=False) self.to_image_latent = nn.Linear(self.width, self.dim_latents, bias=False) From 1f2578c8a9df0899daa43c3e169b467bc1412155 Mon Sep 17 00:00:00 2001 From: gpucce Date: Tue, 6 Dec 2022 12:44:11 +0900 Subject: [PATCH 049/113] type --- src/open_clip/coca_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index e91e99d63..5ef3d5198 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -138,7 +138,7 @@ def encode_text(self, text, normalize=True): cls_emb = self.text_cls_norm(cls_emb) text_latent = self.to_text_latent(cls_emb) - text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latents + text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent return text_latent, token_emb From d8ff1bd25425a1c0066785c6052e87d074ef841b Mon Sep 17 00:00:00 2001 From: gpucce Date: Tue, 6 Dec 2022 12:57:37 +0900 Subject: [PATCH 050/113] clean --- src/open_clip/transformer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 6cae57d0c..231118873 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -588,8 +588,4 @@ def forward(self, text_embs, image_embs, eot_token_mask): x = text_embs.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x) - # x.shape = [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - # x = x[torch.arange(x.shape[0]), eot_token_mask] @ self.text_projection - return x From fa24047239e02c88cf7e433df4d802e89131569a Mon Sep 17 00:00:00 2001 From: gpucce Date: Tue, 6 Dec 2022 14:38:22 +0900 Subject: [PATCH 051/113] fix config --- src/open_clip/model_configs/coca_base.json | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/open_clip/model_configs/coca_base.json b/src/open_clip/model_configs/coca_base.json index 3bb938c08..297ca75b0 100644 --- a/src/open_clip/model_configs/coca_base.json +++ b/src/open_clip/model_configs/coca_base.json @@ -1,12 +1,10 @@ { "embed_dim": 768, - "coca_cfg": { - "width":768, - "model_name": "coca_base", + "decoder_cfg": { + "width": 768, "context_length": 77, "image_dim": 768, "mlp_ratio": 4, - "ls_init_value": "None", "layers": 12, "dim_head": 64, "heads": 12, @@ -23,10 +21,8 @@ "text_cfg": { "vocab_size": 64000, "layers": 12, - "dim_head": 64, - "heads":12, - "mlp_ration": 4, - "context_length": 512, - "ls_init_value": "None" + "heads": 12, + "width": 768, + "context_length": 512 } -} +} \ No newline at end of file From f61f9d5f77767d28ffd0cd306b0f8046d20441f8 Mon Sep 17 00:00:00 2001 From: gpucce Date: Tue, 6 Dec 2022 14:38:51 +0900 Subject: [PATCH 052/113] make rename and move cfg --- src/open_clip/coca_model.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 5ef3d5198..df6755030 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -17,8 +17,7 @@ @dataclass -class CoCaCfg: - model_name: str = "CoCa_base" +class TextDecoderCfg: context_length:int = 76 width: int = 512 
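Dropping the separate temperature parameter is consistent with CLIP, where logit_scale already plays that role: it is initialized to log(1/0.07) and clamped in the training loop at ln(100). The numbers, for reference:

import math

init = math.log(1 / 0.07)   # ~2.659, so exp(init) ~ 14.3 at initialization
clamp_max = math.log(100)   # ~4.605, the 4.6052 clamp applied after each optimizer step
print(round(math.exp(init), 2), round(math.exp(clamp_max), 2))
# exp(logit_scale) stays in [1, 100], i.e. an effective softmax temperature between 1.0 and 0.01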
image_dim: int = 512 @@ -35,12 +34,12 @@ class CoCaCfg: def _build_text_decoder_tower( embed_dim: int, - coca_cfg: CoCaCfg, + decoder_cfg: TextDecoderCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, ): - if isinstance(coca_cfg, dict): - coca_cfg = CoCaCfg(**coca_cfg) + if isinstance(decoder_cfg, dict): + decoder_cfg = TextDecoderCfg(**decoder_cfg) act_layer = QuickGELU if quick_gelu else nn.GELU norm_layer = ( @@ -48,24 +47,24 @@ def _build_text_decoder_tower( ) text = TransformerDecoder( - context_length=coca_cfg.context_length, - width=coca_cfg.width, - heads=coca_cfg.heads, - layers=coca_cfg.layers, - ls_init_value=coca_cfg.ls_init_value, + context_length=decoder_cfg.context_length, + width=decoder_cfg.width, + heads=decoder_cfg.heads, + layers=decoder_cfg.layers, + ls_init_value=decoder_cfg.ls_init_value, output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, ) - return text + return text, decoder_cfg class CoCa(nn.Module): def __init__( self, embed_dim, - coca_cfg: CoCaCfg, + decoder_cfg: TextDecoderCfg, text_cfg: CLIPTextCfg, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, @@ -73,6 +72,7 @@ def __init__( ): super().__init__() + norm_layer = ( LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) @@ -92,26 +92,26 @@ def __init__( embed_dim, vision_cfg, quick_gelu, cast_dtype ) - self.multimodal_decoder = _build_text_decoder_tower( - embed_dim, coca_cfg, quick_gelu, cast_dtype + self.multimodal_decoder, decoder_cfg = _build_text_decoder_tower( + embed_dim, decoder_cfg, quick_gelu, cast_dtype ) - self.width = coca_cfg.width + self.width = decoder_cfg.width self.img_attn_pool = AttentionPooler( - coca_cfg.width, coca_cfg.heads, n_queries=coca_cfg.n_queries + 1 + decoder_cfg.width, decoder_cfg.heads, n_queries=decoder_cfg.n_queries + 1 ) self.img_attn_pool_norm = norm_layer(self.width) self.text_cls_norm = norm_layer(self.width) - self.dim_latents = coca_cfg.dim_latents if coca_cfg.dim_latents else coca_cfg.width + self.dim_latents = decoder_cfg.dim_latents if decoder_cfg.dim_latents else decoder_cfg.width self.to_text_latent = nn.Linear(self.width, self.dim_latents, bias=False) self.to_image_latent = nn.Linear(self.width, self.dim_latents, bias=False) # to logits self.to_logits = nn.Sequential( - norm_layer(self.width), nn.Linear(self.width, text_cfg.vocab_size, bias=False) + norm_layer(self.width), nn.Linear(self.width, self.vocab_size, bias=False) ) # they used embedding weight tied projection out to logits, not common, but works From 4b76187bbe10170da23df04adc97f625c13d6de8 Mon Sep 17 00:00:00 2001 From: gpucce Date: Tue, 6 Dec 2022 14:39:06 +0900 Subject: [PATCH 053/113] rename --- src/open_clip/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/open_clip/__init__.py b/src/open_clip/__init__.py index f06c5d62c..0e8f4ae33 100644 --- a/src/open_clip/__init__.py +++ b/src/open_clip/__init__.py @@ -4,7 +4,7 @@ from .loss import ClipLoss, CoCaLoss from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype -from .coca_model import CoCa, CoCaCfg +from .coca_model import CoCa from .openai import load_openai_model, list_openai_models from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained From b8777feadd06d5a4229a19486851b0fb63d62b8b Mon Sep 17 00:00:00 2001 From: gpucce 
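_build_text_decoder_tower accepts either a TextDecoderCfg or the plain dict parsed from the JSON model config and coerces the latter. A stripped-down illustration of the same pattern, using a hypothetical mirror of the dataclass rather than the real one:

from dataclasses import dataclass
from typing import Optional, Union

@dataclass
class DecoderCfgSketch:          # hypothetical stand-in with a subset of TextDecoderCfg's fields
    width: int = 512
    heads: int = 8
    layers: int = 12
    dim_latents: Optional[int] = None

def build_decoder_cfg(cfg: Union[DecoderCfgSketch, dict]) -> DecoderCfgSketch:
    if isinstance(cfg, dict):           # JSON configs arrive as plain dicts
        cfg = DecoderCfgSketch(**cfg)   # unknown keys raise immediately, catching config typos
    return cfg

print(build_decoder_cfg({"width": 768, "heads": 12}))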
Date: Tue, 6 Dec 2022 14:39:30 +0900 Subject: [PATCH 054/113] temptative add coca to factory --- src/open_clip/factory.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 73dbdabe0..9d84b7885 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -12,6 +12,8 @@ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ resize_pos_embed, get_cast_dtype +from .coca_model import CoCa +from .loss import ClipLoss, CoCaLoss from .openai import load_openai_model from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model from .transform import image_transform @@ -149,7 +151,10 @@ def create_model( model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) else: - model = CLIP(**model_cfg, cast_dtype=cast_dtype) + if "coca" in model_name: + model = CoCa(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) pretrained_cfg = {} if pretrained: @@ -183,6 +188,26 @@ def create_model( return model +def create_loss(args): + + if "coca" not in args.model.lower(): + return CoCaLoss( + caption_loss_weight=args.caption_loss_weight, + clip_loss_weight=args.clip_loss_weight, + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod) + return ClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod) + def create_model_and_transforms( model_name: str, From 42aa408646b99a71c0e9f95d3712350e93cf4418 Mon Sep 17 00:00:00 2001 From: gpucce Date: Tue, 6 Dec 2022 15:57:53 +0900 Subject: [PATCH 055/113] fix config --- src/open_clip/coca_model.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index df6755030..f40fa4aec 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -26,7 +26,7 @@ class TextDecoderCfg: layers: int = 12 dim_head: int = 64 heads: int = 8 - contrastive_loss_weight: float = 1.0 + clip_loss_weight: float = 1.0 caption_loss_weight: float = 2.0 n_queries: int = 256 dim_latents: int = None @@ -96,6 +96,11 @@ def __init__( embed_dim, decoder_cfg, quick_gelu, cast_dtype ) + self.loss_parameters = { + "caption_loss_weight": decoder_cfg.caption_loss_weight, + "clip_loss_weight": decoder_cfg.clip_loss_weight + } + self.width = decoder_cfg.width self.img_attn_pool = AttentionPooler( From 1044f363457aab637d2acd59bc52e32c69981f47 Mon Sep 17 00:00:00 2001 From: gpucce Date: Tue, 6 Dec 2022 15:59:10 +0900 Subject: [PATCH 056/113] update config --- src/open_clip/model_configs/coca_base.json | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/open_clip/model_configs/coca_base.json b/src/open_clip/model_configs/coca_base.json index 297ca75b0..fe2cf84ea 100644 --- a/src/open_clip/model_configs/coca_base.json +++ b/src/open_clip/model_configs/coca_base.json @@ -8,7 +8,7 @@ "layers": 12, "dim_head": 64, "heads": 12, - "contrastive_loss_weight": 1.0, + "clip_loss_weight": 1.0, "caption_loss_weight": 2.0, "n_queries": 256 }, @@ -24,5 +24,6 @@ "heads": 12, "width": 768, "context_length": 512 - } + 
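Because the dispatch is purely name-based ("coca" as a substring of the model name), any model config whose filename contains coca is routed to the CoCa class. A usage sketch, assuming this factory wiring and the coca_base config resolve as-is; the model is randomly initialized, since no pretrained CoCa weights exist at this point in the series:

import open_clip

# hypothetical check that the name-based dispatch picks the CoCa class
model = open_clip.create_model("coca_base")
print(type(model).__name__)   # expected: "CoCa"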
}, + "custom_text": "True" } \ No newline at end of file From dab7d7d12c5bbfa6d43ee0f2b454f181ad53d24e Mon Sep 17 00:00:00 2001 From: gpucce Date: Wed, 7 Dec 2022 14:12:12 +0900 Subject: [PATCH 057/113] embed contrastive cls token in model --- src/open_clip/coca_model.py | 43 ++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index f40fa4aec..78a839e46 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -18,7 +18,7 @@ @dataclass class TextDecoderCfg: - context_length:int = 76 + context_length: int = 77 width: int = 512 image_dim: int = 512 mlp_ratio: int = 4 @@ -67,6 +67,7 @@ def __init__( decoder_cfg: TextDecoderCfg, text_cfg: CLIPTextCfg, vision_cfg: CLIPVisionCfg, + n_queries: int = 256, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, ): @@ -88,6 +89,7 @@ def __init__( self.text_projection = text.text_projection self.register_buffer("attn_mask", text.attn_mask, persistent=False) + self.cls_token = nn.Parameter(torch.randn(embed_dim)) self.visual = _build_vision_tower( embed_dim, vision_cfg, quick_gelu, cast_dtype ) @@ -101,33 +103,39 @@ def __init__( "clip_loss_weight": decoder_cfg.clip_loss_weight } - self.width = decoder_cfg.width self.img_attn_pool = AttentionPooler( - decoder_cfg.width, decoder_cfg.heads, n_queries=decoder_cfg.n_queries + 1 + decoder_cfg.width, decoder_cfg.heads, n_queries=n_queries + 1 ) - self.img_attn_pool_norm = norm_layer(self.width) - self.text_cls_norm = norm_layer(self.width) + self.img_attn_pool_norm = norm_layer(embed_dim) + self.text_cls_norm = norm_layer(embed_dim) self.dim_latents = decoder_cfg.dim_latents if decoder_cfg.dim_latents else decoder_cfg.width - self.to_text_latent = nn.Linear(self.width, self.dim_latents, bias=False) - self.to_image_latent = nn.Linear(self.width, self.dim_latents, bias=False) + self.to_text_latent = nn.Linear(embed_dim, self.dim_latents, bias=False) + self.to_image_latent = nn.Linear(embed_dim, self.dim_latents, bias=False) # to logits self.to_logits = nn.Sequential( - norm_layer(self.width), nn.Linear(self.width, self.vocab_size, bias=False) + norm_layer(embed_dim), nn.Linear(embed_dim, self.vocab_size, bias=False) ) - # they used embedding weight tied projection out to logits, not common, but works + # tie embedding weights and projection self.to_logits[-1].weight = self.token_embedding.weight self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + def _repeat(self, t, N): + return t.reshape(1, 1, -1).repeat(N, 1, 1) + def encode_text(self, text, normalize=True): cast_dtype = self.transformer.get_cast_dtype() + # cls_mask = (text!=self.pad_id).unsqueeze(1) + # attn_mask = F.pad(cls_mask, (0, 1, text.shape[1], 0), value=True) + # attn_mask = F.pad(self.attn_mask, (0, 1, 0, 1), value=0.0) x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + x = torch.cat([x, self._repeat(self.cls_token, x.shape[0])], dim=1) x = x + self.positional_embedding.to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x, attn_mask=self.attn_mask) @@ -176,23 +184,14 @@ def encode_image(self, images=None, normalize=True): return image_latent, x[:, 1:] - def forward( - self, - text, - images=None, - image_tokens=None, - labels=None, - ): + def forward(self, image, text,): - if labels is None: - text, labels = text[:, :-1], text[:, 1:] + text, labels = text[:, :-1], text[:, 1:] text_latents, text_tokens = self.encode_text(text) - image_latents, 
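Tying to_logits to the token embedding works because nn.Embedding and a bias-free nn.Linear that maps back to the vocabulary share the same [vocab_size, width] weight shape, so one tensor (and one set of gradients) serves both ends. A minimal demonstration with made-up sizes:

import torch
from torch import nn

vocab_size, width = 1000, 64
token_embedding = nn.Embedding(vocab_size, width)      # weight: [vocab_size, width]
to_logits = nn.Linear(width, vocab_size, bias=False)   # weight: [vocab_size, width] as well

to_logits.weight = token_embedding.weight              # tie: both modules now share one Parameter
assert to_logits.weight.data_ptr() == token_embedding.weight.data_ptr()

tokens = torch.randint(0, vocab_size, (2, 5))
logits = to_logits(token_embedding(tokens))            # [2, 5, vocab_size]
print(logits.shape)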
image_tokens = self.encode_image(images) + image_latents, image_tokens = self.encode_image(image) - text_tokens = self.multimodal_decoder( - text_tokens, image_tokens, eot_token_mask=text.argmax(dim=-1) - ) + text_tokens = self.multimodal_decoder(text_tokens, image_tokens) logits = self.to_logits(text_tokens) return text_latents, image_latents, logits, labels, self.logit_scale From d0ae6832634367c4a4cdd32f7706257865a2f45a Mon Sep 17 00:00:00 2001 From: gpucce Date: Wed, 7 Dec 2022 14:13:06 +0900 Subject: [PATCH 058/113] remove unused arg --- src/open_clip/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 231118873..d223e5f33 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -573,7 +573,7 @@ def build_attention_mask(self): mask.triu_(1) # zero out the lower diagonal return mask - def forward(self, text_embs, image_embs, eot_token_mask): + def forward(self, text_embs, image_embs): text_embs = text_embs.permute(1, 0, 2) # NLD -> LND image_embs = image_embs.permute(1, 0, 2) # NLD -> LND From 5a40804e1fc92371124c092477d88bbd1402cc08 Mon Sep 17 00:00:00 2001 From: gpucce Date: Wed, 7 Dec 2022 16:25:10 +0900 Subject: [PATCH 059/113] import create_loss --- src/open_clip/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/open_clip/__init__.py b/src/open_clip/__init__.py index 0e8f4ae33..a4e11b4f7 100644 --- a/src/open_clip/__init__.py +++ b/src/open_clip/__init__.py @@ -1,5 +1,5 @@ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD -from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer +from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss from .factory import list_models, add_model_config, get_model_config, load_checkpoint from .loss import ClipLoss, CoCaLoss from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\ From 67894380fe2ab038135fcb523f0f6b8a3a0d740c Mon Sep 17 00:00:00 2001 From: gpucce Date: Wed, 7 Dec 2022 16:25:46 +0900 Subject: [PATCH 060/113] make factory accept coca --- src/open_clip/factory.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 9d84b7885..d2293652e 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -74,8 +74,7 @@ def get_model_config(model_name): def get_tokenizer(model_name): config = get_model_config(model_name) - tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize - return tokenizer + return HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize def load_state_dict(checkpoint_path: str, map_location='cpu'): @@ -149,12 +148,12 @@ def create_model( if custom_text: if 'hf_model_name' in model_cfg.get('text_cfg', {}): model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf - model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) - else: if "coca" in model_name: model = CoCa(**model_cfg, cast_dtype=cast_dtype) else: - model = CLIP(**model_cfg, cast_dtype=cast_dtype) + model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) pretrained_cfg = {} if pretrained: @@ -189,17 +188,16 @@ def create_model( return model def create_loss(args): - - if "coca" not 
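The forward above shifts the caption by one position so the decoder is trained with teacher forcing: the token at position t is asked to predict the token at position t+1. The same slicing on a toy sequence (token ids are made up):

import torch

text = torch.tensor([[101, 7, 42, 13, 102]])   # e.g. <start> a cat sat <end>

inputs = text[:, :-1]   # [[101,  7, 42, 13]]  fed to the text tower and decoder
labels = text[:, 1:]    # [[  7, 42, 13, 102]] next-token targets, shifted left by one
print(inputs.shape, labels.shape)   # both [1, 4]; logits later come out as [batch, seq, vocab]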
in args.model.lower(): + if "coca" in args.model.lower(): return CoCaLoss( - caption_loss_weight=args.caption_loss_weight, - clip_loss_weight=args.clip_loss_weight, - local_loss=args.local_loss, - gather_with_grad=args.gather_with_grad, - cache_labels=True, - rank=args.rank, - world_size=args.world_size, - use_horovod=args.horovod) + caption_loss_weight=args.coca_caption_loss_weight, + clip_loss_weight=args.coca_contrastive_loss_weight, + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod) return ClipLoss( local_loss=args.local_loss, gather_with_grad=args.gather_with_grad, From 60865ef6836414e1e9163bdb2c81c9fe14da580c Mon Sep 17 00:00:00 2001 From: gpucce Date: Wed, 7 Dec 2022 16:26:08 +0900 Subject: [PATCH 061/113] make caption loss distributed --- src/open_clip/loss.py | 70 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 61 insertions(+), 9 deletions(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 6101c9d3d..5462bc619 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -61,6 +61,51 @@ def gather_features( return all_image_features, all_text_features +# def gather_logits_and_labels( +# logits, +# labels, +# local_loss=False, +# gather_with_grad=False, +# rank=0, +# world_size=1, +# use_horovod=False +# ): +# assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.' +# if use_horovod: +# assert hvd is not None, 'Please install horovod' +# if gather_with_grad: +# all_logits = hvd.allgather(logits) +# all_labels = hvd.allgather(labels) +# else: +# with torch.no_grad(): +# all_logits = hvd.allgather(logits) +# all_labels = hvd.allgather(labels) +# if not local_loss: +# # ensure grads for local rank when all_* features don't have a gradient +# gathered_logits = list(logits.chunk(world_size, dim=0)) +# gathered_labels = list(labels.chunk(world_size, dim=0)) +# gathered_logits[rank] = logits +# gathered_labels[rank] = labels +# all_image_features = torch.cat(gathered_logits, dim=0) +# all_text_features = torch.cat(gathered_labels, dim=0) +# else: +# # We gather tensors from all gpus +# if gather_with_grad: +# all_logits = torch.cat(torch.distributed.nn.all_gather(logits), dim=0) +# all_labels = torch.cat(torch.distributed.nn.all_gather(labels), dim=0) +# else: +# gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] +# gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] +# dist.all_gather(gathered_image_features, image_features) +# dist.all_gather(gathered_text_features, text_features) +# if not local_loss: +# # ensure grads for local rank when all_* features don't have a gradient +# gathered_image_features[rank] = image_features +# gathered_text_features[rank] = text_features +# all_image_features = torch.cat(gathered_image_features, dim=0) +# all_text_features = torch.cat(gathered_text_features, dim=0) + +# return all_image_features, all_text_features class ClipLoss(nn.Module): @@ -121,12 +166,12 @@ def forward(self, image_features, text_features, logit_scale): return total_loss -class CoCaLoss(nn.Module): +class CoCaLoss(ClipLoss): def __init__( self, caption_loss_weight, clip_loss_weight, - pad_id, + pad_id=-100, local_loss=False, gather_with_grad=False, cache_labels=False, @@ -134,8 +179,7 @@ def __init__( world_size=1, use_horovod=False, ): - super().__init__() - self.clip_loss = ClipLoss( + super().__init__( 
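With the condition in this form, create_loss can be exercised directly; the namespace below stands in for the parsed CLI arguments, and the attribute names match the ones the factory reads (a sketch, assuming an open_clip checkout containing these patches):

from types import SimpleNamespace
import open_clip

args = SimpleNamespace(
    model="coca_base",
    coca_caption_loss_weight=2.0, coca_contrastive_loss_weight=1.0,
    local_loss=False, gather_with_grad=False,
    rank=0, world_size=1, horovod=False,
)
print(type(open_clip.create_loss(args)).__name__)   # expected: "CoCaLoss"; other names get ClipLoss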
local_loss=local_loss, gather_with_grad=gather_with_grad, cache_labels=cache_labels, @@ -143,17 +187,25 @@ def __init__( world_size=world_size, use_horovod=use_horovod ) + self.clip_loss_weight = clip_loss_weight - self.caption_loss = nn.CrossEntropyLoss() self.caption_loss_weight = caption_loss_weight - self.pad_id = pad_id + self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) def forward(self, image_features, text_features, logits, labels, logit_scale): - clip_loss = self.clip_loss(image_features, text_features, logit_scale) + clip_loss = super().forward(image_features, text_features, logit_scale) clip_loss = self.clip_loss_weight * clip_loss - logits = logits.permute(0, 2, 1) - caption_loss = self.caption_loss(logits, labels, ignore_index=self.pad_id) + if self.world_size > 1: + all_logits, all_labels = gather_features( + logits.contiguous(), labels.contiguous(), + self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) + else: + all_logits = logits + all_labels = labels + + all_logits = all_logits.permute(0, 2, 1) + caption_loss = self.caption_loss(all_logits, all_labels) caption_loss = caption_loss * self.caption_loss_weight return clip_loss + caption_loss From ac617bf56bbce694d7d345dc537767bbf58a1254 Mon Sep 17 00:00:00 2001 From: gpucce Date: Wed, 7 Dec 2022 16:30:54 +0900 Subject: [PATCH 062/113] make loss customizable --- src/training/main.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index b0e56de27..a1c20209b 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -24,7 +24,7 @@ except ImportError: hvd = None -from open_clip import create_model_and_transforms, trace_model, get_tokenizer +from open_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss from training.data import get_data from training.distributed import is_master, init_distributed_device, world_info_from_env from training.logger import setup_logging @@ -267,7 +267,8 @@ def main(args): if is_master(args): logging.info(f'Start epoch {epoch}') - train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer) + loss = create_loss(args) + train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args, tb_writer=writer) completed_epoch = epoch + 1 if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')): From b9c2b259605c4cabc7e222b02419542c898b152f Mon Sep 17 00:00:00 2001 From: gpucce Date: Wed, 7 Dec 2022 16:31:22 +0900 Subject: [PATCH 063/113] pass loss trhough training_epoch --- src/training/train.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/src/training/train.py b/src/training/train.py index c3b953a4a..812f2890e 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -13,7 +13,7 @@ except ImportError: wandb = None -from open_clip import ClipLoss, get_cast_dtype +from open_clip import get_cast_dtype from .distributed import is_master from .zero_shot import zero_shot_eval from .precision import get_autocast @@ -44,19 +44,12 @@ def unwrap_model(model): return model -def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None): +def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args, tb_writer=None): device = torch.device(args.device) autocast = get_autocast(args.precision) cast_dtype = get_cast_dtype(args.precision) model.train() - loss = ClipLoss( - local_loss=args.local_loss, - gather_with_grad=args.gather_with_grad, - 
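The caption term is ordinary next-token cross-entropy. The permute puts the vocabulary dimension where nn.CrossEntropyLoss expects the class axis, and ignore_index drops padded positions from the average. A sketch with dummy tensors:

import torch
from torch import nn

batch, seq_len, vocab = 2, 6, 100
logits = torch.randn(batch, seq_len, vocab)          # decoder output: [B, T, V]
labels = torch.randint(0, vocab, (batch, seq_len))   # shifted targets: [B, T]
labels[1, 4:] = -100                                 # pretend the second caption ends early (padding)

caption_loss = nn.CrossEntropyLoss(ignore_index=-100)   # -100 matches the pad_id default above
loss = caption_loss(logits.permute(0, 2, 1), labels)    # CrossEntropyLoss wants [B, V, T] vs [B, T]
print(loss)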
cache_labels=True, - rank=args.rank, - world_size=args.world_size, - use_horovod=args.horovod) data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch dataloader = data['train'].dataloader @@ -69,7 +62,7 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w end = time.time() for i, batch in enumerate(dataloader): step = num_batches_per_epoch * epoch + i - + if not args.skip_scheduler: scheduler(step) @@ -81,8 +74,9 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w optimizer.zero_grad() with autocast(): - image_features, text_features, logit_scale = model(images, texts) - total_loss = loss(image_features, text_features, logit_scale) + loss_args = model(images, texts) + logit_scale = loss_args[-1] + total_loss = loss(*loss_args) if scaler is not None: scaler.scale(total_loss).backward() @@ -182,7 +176,10 @@ def evaluate(model, data, epoch, args, tb_writer=None): texts = texts.to(device=device, non_blocking=True) with autocast(): - image_features, text_features, logit_scale = model(images, texts) + model_out = model(images, texts) + image_features = model_out[0] + text_features = model_out[1] + logit_scale = model_out[-1] # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly # however, system RAM is easily exceeded and compute time becomes problematic all_image_features.append(image_features.cpu()) From ccfd1e47f9afbcd20f95f71bec3e999b59c9e929 Mon Sep 17 00:00:00 2001 From: gpucce Date: Wed, 7 Dec 2022 16:31:48 +0900 Subject: [PATCH 064/113] add coca specific params to params --- src/training/params.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/training/params.py b/src/training/params.py index d6d7a8231..e4ba3d5b1 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -317,6 +317,18 @@ def parse_args(args): default=100, help="Log every n steps to tensorboard/console/wandb.", ) + parser.add_argument( + "--coca-caption-loss-weight", + type=float, + default=2.0, + help="Weight assigned to caption loss in CoCa." + ) + parser.add_argument( + "--coca-contrastive-loss-weight", + type=float, + default=1.0, + help="Weight assigned to contrastive loss when training CoCa." 
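Unpacking the model output with * keeps a single training loop for both architectures: CLIP returns three values, CoCa five, and each loss consumes exactly what its matching model emits, with the logit scale always last. In miniature (the lambdas are stand-ins, not the real modules):

clip_model = lambda images, texts: ("feats_a", "feats_b", "scale")                      # 3 outputs
coca_model = lambda images, texts: ("feats_a", "feats_b", "logits", "labels", "scale")  # 5 outputs
clip_loss = lambda a, b, scale: 0.0
coca_loss = lambda a, b, logits, labels, scale: 0.0

for model, loss_fn in [(clip_model, clip_loss), (coca_model, coca_loss)]:
    loss_args = model("images", "texts")
    logit_scale = loss_args[-1]       # the scale is last in both cases, as the loop assumes
    total_loss = loss_fn(*loss_args)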
+ ) args = parser.parse_args(args) From c1556d4f0ff902cb9be5c0ffbf6aa9d9a8d9a482 Mon Sep 17 00:00:00 2001 From: gpucce Date: Wed, 7 Dec 2022 16:32:23 +0900 Subject: [PATCH 065/113] removed decoder unused parameters --- src/open_clip/transformer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index d223e5f33..84bbc9765 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -540,9 +540,6 @@ def __init__( self.ln_final = norm_layer(width) - # this will be shared with the textual decoder (in CoCa) - self.text_projection = nn.Parameter(torch.empty(width, output_dim)) - def init_parameters(self): proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) attn_std = self.transformer.width ** -0.5 From 68d608a4f0988c4569d734e227dcf4daf7b273cb Mon Sep 17 00:00:00 2001 From: gpucce Date: Wed, 7 Dec 2022 16:32:50 +0900 Subject: [PATCH 066/113] remove unused attributes --- src/open_clip/coca_model.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 78a839e46..68d0699d0 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -98,18 +98,11 @@ def __init__( embed_dim, decoder_cfg, quick_gelu, cast_dtype ) - self.loss_parameters = { - "caption_loss_weight": decoder_cfg.caption_loss_weight, - "clip_loss_weight": decoder_cfg.clip_loss_weight - } - - self.img_attn_pool = AttentionPooler( decoder_cfg.width, decoder_cfg.heads, n_queries=n_queries + 1 ) self.img_attn_pool_norm = norm_layer(embed_dim) - self.text_cls_norm = norm_layer(embed_dim) self.dim_latents = decoder_cfg.dim_latents if decoder_cfg.dim_latents else decoder_cfg.width self.to_text_latent = nn.Linear(embed_dim, self.dim_latents, bias=False) @@ -149,7 +142,7 @@ def encode_text(self, text, normalize=True): cls_emb = x[torch.arange(x.shape[0]), -1] token_emb = x[torch.arange(x.shape[0]), :-1] - cls_emb = self.text_cls_norm(cls_emb) + cls_emb = self.ln_final(cls_emb) text_latent = self.to_text_latent(cls_emb) text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent @@ -179,7 +172,9 @@ def encode_image(self, images=None, normalize=True): x = self.img_attn_pool(x) x = self.img_attn_pool_norm(x) - image_latent = self.to_image_latent(x[:, 0]) + image_latent = x[:, 0] + if self.visual.proj is not None: + image_latent = image_latent @ self.visual.proj image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent return image_latent, x[:, 1:] From 59d4db4924725d7fce9bdfc15763ccdc306a6e51 Mon Sep 17 00:00:00 2001 From: gpucce Date: Wed, 7 Dec 2022 16:33:26 +0900 Subject: [PATCH 067/113] adjust coca_config --- src/open_clip/model_configs/coca_base.json | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/open_clip/model_configs/coca_base.json b/src/open_clip/model_configs/coca_base.json index fe2cf84ea..99995de8c 100644 --- a/src/open_clip/model_configs/coca_base.json +++ b/src/open_clip/model_configs/coca_base.json @@ -2,7 +2,7 @@ "embed_dim": 768, "decoder_cfg": { "width": 768, - "context_length": 77, + "context_length": 76, "image_dim": 768, "mlp_ratio": 4, "layers": 12, @@ -16,14 +16,16 @@ "image_size": 288, "layers": 12, "width": 768, - "patch_size": 18 + "patch_size": 18, + "final_proj": false }, "text_cfg": { + "context_length": 77, "vocab_size": 64000, "layers": 12, "heads": 12, "width": 768, - "context_length": 512 + "hf_tokenizer_name": "roberta-base" }, 
"custom_text": "True" } \ No newline at end of file From 732f15f1dec37a99d8ccb117545a654a44e1fea0 Mon Sep 17 00:00:00 2001 From: gpucce Date: Wed, 7 Dec 2022 16:41:41 +0900 Subject: [PATCH 068/113] fix config and remove unused parameters --- src/open_clip/coca_model.py | 1 - src/open_clip/model_configs/coca_base.json | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 68d0699d0..ea1c3f712 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -106,7 +106,6 @@ def __init__( self.dim_latents = decoder_cfg.dim_latents if decoder_cfg.dim_latents else decoder_cfg.width self.to_text_latent = nn.Linear(embed_dim, self.dim_latents, bias=False) - self.to_image_latent = nn.Linear(embed_dim, self.dim_latents, bias=False) # to logits self.to_logits = nn.Sequential( diff --git a/src/open_clip/model_configs/coca_base.json b/src/open_clip/model_configs/coca_base.json index 99995de8c..e40bdc1e6 100644 --- a/src/open_clip/model_configs/coca_base.json +++ b/src/open_clip/model_configs/coca_base.json @@ -16,8 +16,7 @@ "image_size": 288, "layers": 12, "width": 768, - "patch_size": 18, - "final_proj": false + "patch_size": 18 }, "text_cfg": { "context_length": 77, From 17072c6d5aff8c91fb04318c91298d959d1bdd29 Mon Sep 17 00:00:00 2001 From: gpucce Date: Wed, 7 Dec 2022 17:27:12 +0900 Subject: [PATCH 069/113] remove comment --- src/open_clip/loss.py | 45 ------------------------------------------- 1 file changed, 45 deletions(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 5462bc619..e1a062d59 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -61,51 +61,6 @@ def gather_features( return all_image_features, all_text_features -# def gather_logits_and_labels( -# logits, -# labels, -# local_loss=False, -# gather_with_grad=False, -# rank=0, -# world_size=1, -# use_horovod=False -# ): -# assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.' 
-# if use_horovod: -# assert hvd is not None, 'Please install horovod' -# if gather_with_grad: -# all_logits = hvd.allgather(logits) -# all_labels = hvd.allgather(labels) -# else: -# with torch.no_grad(): -# all_logits = hvd.allgather(logits) -# all_labels = hvd.allgather(labels) -# if not local_loss: -# # ensure grads for local rank when all_* features don't have a gradient -# gathered_logits = list(logits.chunk(world_size, dim=0)) -# gathered_labels = list(labels.chunk(world_size, dim=0)) -# gathered_logits[rank] = logits -# gathered_labels[rank] = labels -# all_image_features = torch.cat(gathered_logits, dim=0) -# all_text_features = torch.cat(gathered_labels, dim=0) -# else: -# # We gather tensors from all gpus -# if gather_with_grad: -# all_logits = torch.cat(torch.distributed.nn.all_gather(logits), dim=0) -# all_labels = torch.cat(torch.distributed.nn.all_gather(labels), dim=0) -# else: -# gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] -# gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] -# dist.all_gather(gathered_image_features, image_features) -# dist.all_gather(gathered_text_features, text_features) -# if not local_loss: -# # ensure grads for local rank when all_* features don't have a gradient -# gathered_image_features[rank] = image_features -# gathered_text_features[rank] = text_features -# all_image_features = torch.cat(gathered_image_features, dim=0) -# all_text_features = torch.cat(gathered_text_features, dim=0) - -# return all_image_features, all_text_features class ClipLoss(nn.Module): From 74d5e377f3fede3d4630a0d032761d29a2ac4d6c Mon Sep 17 00:00:00 2001 From: gpucce Date: Wed, 7 Dec 2022 14:09:16 +0100 Subject: [PATCH 070/113] remove more comments --- src/open_clip/coca_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index ea1c3f712..bdce002b8 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -107,7 +107,6 @@ def __init__( self.dim_latents = decoder_cfg.dim_latents if decoder_cfg.dim_latents else decoder_cfg.width self.to_text_latent = nn.Linear(embed_dim, self.dim_latents, bias=False) - # to logits self.to_logits = nn.Sequential( norm_layer(embed_dim), nn.Linear(embed_dim, self.vocab_size, bias=False) ) @@ -126,6 +125,7 @@ def encode_text(self, text, normalize=True): # cls_mask = (text!=self.pad_id).unsqueeze(1) # attn_mask = F.pad(cls_mask, (0, 1, text.shape[1], 0), value=True) # attn_mask = F.pad(self.attn_mask, (0, 1, 0, 1), value=0.0) + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] x = torch.cat([x, self._repeat(self.cls_token, x.shape[0])], dim=1) x = x + self.positional_embedding.to(cast_dtype) @@ -137,7 +137,6 @@ def encode_text(self, text, normalize=True): # take features from the eot embedding (eot_token is the highest number in each sequence) x = x[torch.arange(x.shape[0]), :] @ self.text_projection - # looking at the tokenizer this seems ok cls_emb = x[torch.arange(x.shape[0]), -1] token_emb = x[torch.arange(x.shape[0]), :-1] From 578aadf56304293cefd9b0d38c06999ae31fe5ff Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 8 Dec 2022 13:43:16 +0900 Subject: [PATCH 071/113] rename attention pooler --- src/open_clip/coca_model.py | 4 ++-- src/open_clip/transformer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index bdce002b8..78bfde16e 100644 --- 
a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -11,7 +11,7 @@ LayerNorm, QuickGELU, TransformerDecoder, - AttentionPooler, + AttentionalPooler, ) from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower @@ -98,7 +98,7 @@ def __init__( embed_dim, decoder_cfg, quick_gelu, cast_dtype ) - self.img_attn_pool = AttentionPooler( + self.img_attn_pool = AttentionalPooler( decoder_cfg.width, decoder_cfg.heads, n_queries=n_queries + 1 ) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 84bbc9765..cca1b6682 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -134,7 +134,7 @@ def forward(self, x = self.out_drop(x) return x -class AttentionPooler(nn.Module): +class AttentionalPooler(nn.Module): def __init__( self, d_model: int, From 08f43a3ec4b315aee325686b3d68bd99d57f5681 Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 8 Dec 2022 13:59:03 +0900 Subject: [PATCH 072/113] rename TransformerDecoder --- src/open_clip/coca_model.py | 4 ++-- src/open_clip/transformer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 78bfde16e..ff48fb248 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -10,7 +10,7 @@ LayerNormFp32, LayerNorm, QuickGELU, - TransformerDecoder, + MultimodalTransformer, AttentionalPooler, ) from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower @@ -46,7 +46,7 @@ def _build_text_decoder_tower( LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm ) - text = TransformerDecoder( + text = MultimodalTransformer( context_length=decoder_cfg.context_length, width=decoder_cfg.width, heads=decoder_cfg.heads, diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index cca1b6682..7031b4e62 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -506,7 +506,7 @@ def forward(self, text): return x -class TransformerDecoder(Transformer): +class MultimodalTransformer(Transformer): def __init__( self, width: int, From 812a8bbbeb0717f864d8a31aa1a35f35da5961d7 Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 8 Dec 2022 16:12:46 +0900 Subject: [PATCH 073/113] make AttentionalPooler clearer --- src/open_clip/transformer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 7031b4e62..49517a971 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -145,12 +145,11 @@ def __init__( self.query = nn.Parameter(torch.randn(n_queries, d_model)) self.attn = nn.MultiheadAttention(d_model, n_head) - def forward(self, kv: torch.Tensor): - kv = kv.permute(1, 0 ,2) # NLD -> LND - N = kv.shape[1] - kv = self.attn(self._repeat(self.query, N), kv, kv, need_weights=False)[0] - out = kv.permute(1, 0, 2) # LND -> NLD - return out + def forward(self, k: torch.Tensor, v: torch.Tensor): + k, v = k.permute(1, 0, 2), v.permute(1, 0 ,2) # NLD -> LND + N = k.shape[1] + out = self.attn(self._repeat(self.query, N), k, v, need_weights=False)[0] + return out.permute(1, 0, 2) # LND -> NLD def _repeat(self, query, N): return query.unsqueeze(1).repeat(1, N, 1) From f69f4e064b0ac0dbbeae0784d6aa378b293de739 Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 8 Dec 2022 16:58:18 +0900 Subject: [PATCH 074/113] add local loss logic to cocaloss --- src/open_clip/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
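Attentional pooling compresses a variable-length token sequence into a fixed number of learned query slots (n_queries + 1 in the model, giving one pooled contrastive token plus tokens for the decoder). A self-contained sketch of the mechanism, with smaller example sizes than the model's:

import torch
from torch import nn

d_model, n_head, n_queries = 768, 8, 32
query = nn.Parameter(torch.randn(n_queries, d_model))   # learned queries
attn = nn.MultiheadAttention(d_model, n_head)           # expects [seq, batch, dim] by default

tokens = torch.randn(2, 257, d_model)        # e.g. a batch of ViT patch + CLS tokens
kv = tokens.permute(1, 0, 2)                 # NLD -> LND
q = query.unsqueeze(1).repeat(1, 2, 1)       # broadcast the queries over the batch
pooled, _ = attn(q, kv, kv, need_weights=False)
print(pooled.permute(1, 0, 2).shape)         # [2, 32, 768]: fixed size regardless of input length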
--git a/src/open_clip/loss.py b/src/open_clip/loss.py index e1a062d59..87c643fa1 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -151,7 +151,7 @@ def forward(self, image_features, text_features, logits, labels, logit_scale): clip_loss = super().forward(image_features, text_features, logit_scale) clip_loss = self.clip_loss_weight * clip_loss - if self.world_size > 1: + if self.world_size > 1 and not self.local_loss: all_logits, all_labels = gather_features( logits.contiguous(), labels.contiguous(), self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) From 3c02aa5f64a951f9b8819f724961e0965884c33a Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 8 Dec 2022 17:30:15 +0900 Subject: [PATCH 075/113] only create loss if train in data --- src/training/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/training/main.py b/src/training/main.py index a1c20209b..d17912435 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -263,11 +263,12 @@ def main(args): evaluate(model, data, start_epoch, args, writer) return + loss = create_loss(args) + for epoch in range(start_epoch, args.epochs): if is_master(args): logging.info(f'Start epoch {epoch}') - loss = create_loss(args) train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args, tb_writer=writer) completed_epoch = epoch + 1 From 979cef4350a1325198f31f0adce3562329aa85b2 Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 8 Dec 2022 17:31:59 +0900 Subject: [PATCH 076/113] remove wrong file --- src/training/coca_train.py | 263 ------------------------------------- 1 file changed, 263 deletions(-) delete mode 100644 src/training/coca_train.py diff --git a/src/training/coca_train.py b/src/training/coca_train.py deleted file mode 100644 index 28e2084af..000000000 --- a/src/training/coca_train.py +++ /dev/null @@ -1,263 +0,0 @@ -import json -import logging -import math -import os -import time - -import numpy as np -import torch -import torch.nn.functional as F - -try: - import wandb -except ImportError: - wandb = None - -from open_clip import CoCaLoss, get_cast_dtype -from .distributed import is_master -from .zero_shot import zero_shot_eval -from .precision import get_autocast - - -class AverageMeter(object): - """Computes and stores the average and current value""" - def __init__(self): - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - -def unwrap_model(model): - if hasattr(model, 'module'): - return model.module - else: - return model - - -def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None): - device = torch.device(args.device) - autocast = get_autocast(args.precision) - cast_dtype = get_cast_dtype(args.precision) - - model.train() - loss = CoCaLoss( - local_loss=args.local_loss, - gather_with_grad=args.gather_with_grad, - cache_labels=True, - rank=args.rank, - world_size=args.world_size, - use_horovod=args.horovod, - clip_loss_weight=1.0, - caption_loss_weight=2.0 - ) - - data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch - dataloader = data['train'].dataloader - num_batches_per_epoch = dataloader.num_batches - sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) - - loss_m = AverageMeter() - batch_time_m = AverageMeter() - data_time_m = AverageMeter() - end = time.time() - for i, 
batch in enumerate(dataloader): - step = num_batches_per_epoch * epoch + i - - if not args.skip_scheduler: - scheduler(step) - - images, texts = batch - images = images.to(device=device, dtype=cast_dtype, non_blocking=True) - texts = texts.to(device=device, non_blocking=True) - - data_time_m.update(time.time() - end) - optimizer.zero_grad() - - with autocast(): - image_features, text_features, logits, labels, logit_scale = model(images, texts) - total_loss = loss(image_features, text_features, logits, labels, logit_scale) - - if scaler is not None: - scaler.scale(total_loss).backward() - if args.horovod: - optimizer.synchronize() - scaler.unscale_(optimizer) - if args.grad_clip_norm is not None: - torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) - with optimizer.skip_synchronize(): - scaler.step(optimizer) - else: - if args.grad_clip_norm is not None: - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) - scaler.step(optimizer) - scaler.update() - else: - total_loss.backward() - if args.grad_clip_norm is not None: - torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) - optimizer.step() - - # Note: we clamp to 4.6052 = ln(100), as in the original paper. - with torch.no_grad(): - unwrap_model(model).logit_scale.clamp_(0, math.log(100)) - - batch_time_m.update(time.time() - end) - end = time.time() - batch_count = i + 1 - if is_master(args) and (i % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch): - batch_size = len(images) - num_samples = batch_count * batch_size * args.world_size - samples_per_epoch = dataloader.num_samples - percent_complete = 100.0 * batch_count / num_batches_per_epoch - - # NOTE loss is coarsely sampled, just master node and per log update - loss_m.update(total_loss.item(), batch_size) - logit_scale_scalar = logit_scale.item() - logging.info( - f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " - f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " - f"Data (t): {data_time_m.avg:.3f} " - f"Batch (t): {batch_time_m.avg:.3f}, {args.batch_size*args.world_size / batch_time_m.val:#g}/s " - f"LR: {optimizer.param_groups[0]['lr']:5f} " - f"Logit Scale: {logit_scale_scalar:.3f}" - ) - - # Save train loss / etc. Using non avg meter values as loggers have their own smoothing - log_data = { - "loss": loss_m.val, - "data_time": data_time_m.val, - "batch_time": batch_time_m.val, - "samples_per_scond": args.batch_size*args.world_size / batch_time_m.val, - "scale": logit_scale_scalar, - "lr": optimizer.param_groups[0]["lr"] - } - for name, val in log_data.items(): - name = "train/" + name - if tb_writer is not None: - tb_writer.add_scalar(name, val, step) - if args.wandb: - assert wandb is not None, 'Please install wandb.' 
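A condensed, device-agnostic restatement of the update order used in the loop being deleted here (the same order lives on in src/training/train.py): scale, backward, unscale_, clip_grad_norm_, step, update, then clamp logit_scale to ln(100). The tiny linear model and MSE objective are placeholders for the real model and loss.

import math
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 8).to(device)
logit_scale = torch.nn.Parameter(torch.ones([], device=device) * math.log(1 / 0.07))
optimizer = torch.optim.AdamW(list(model.parameters()) + [logit_scale], lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

x, y = torch.randn(4, 8, device=device), torch.randn(4, 8, device=device)
optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x) * logit_scale.exp(), y)

scaler.scale(loss).backward()
scaler.unscale_(optimizer)                                   # bring grads back to true scale
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0, norm_type=2.0)
scaler.step(optimizer)                                       # skipped if grads overflowed
scaler.update()

with torch.no_grad():
    logit_scale.clamp_(0, math.log(100))                     # 4.6052, as in the loop above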
- wandb.log({name: val, 'step': step}) - - # resetting batch / data time meters per log window - batch_time_m.reset() - data_time_m.reset() - # end for - - -def evaluate(model, data, epoch, args, tb_writer=None): - metrics = {} - if not is_master(args): - return metrics - device = torch.device(args.device) - model.eval() - - zero_shot_metrics = zero_shot_eval(model, data, epoch, args) - metrics.update(zero_shot_metrics) - - autocast = get_autocast(args.precision) - cast_dtype = get_cast_dtype(args.precision) - - if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)): - dataloader = data['val'].dataloader - num_samples = 0 - samples_per_val = dataloader.num_samples - - # FIXME this does not scale past small eval datasets - # all_image_features @ all_text_features will blow up memory and compute very quickly - cumulative_loss = 0.0 - all_image_features, all_text_features = [], [] - with torch.no_grad(): - for i, batch in enumerate(dataloader): - images, texts = batch - images = images.to(device=device, dtype=cast_dtype, non_blocking=True) - texts = texts.to(device=device, non_blocking=True) - - with autocast(): - image_features, text_features, logit_scale = model(images, texts) - # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly - # however, system RAM is easily exceeded and compute time becomes problematic - all_image_features.append(image_features.cpu()) - all_text_features.append(text_features.cpu()) - logit_scale = logit_scale.mean() - logits_per_image = logit_scale * image_features @ text_features.t() - logits_per_text = logits_per_image.t() - - batch_size = images.shape[0] - labels = torch.arange(batch_size, device=device).long() - total_loss = ( - F.cross_entropy(logits_per_image, labels) + - F.cross_entropy(logits_per_text, labels) - ) / 2 - - cumulative_loss += total_loss * batch_size - num_samples += batch_size - if is_master(args) and (i % 100) == 0: - logging.info( - f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t" - f"Loss: {cumulative_loss / num_samples:.6f}\t") - - val_metrics = get_metrics( - image_features=torch.cat(all_image_features), - text_features=torch.cat(all_text_features), - logit_scale=logit_scale.cpu(), - ) - loss = cumulative_loss / num_samples - metrics.update( - {**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples} - ) - - if not metrics: - return metrics - - logging.info( - f"Eval Epoch: {epoch} " - + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) - ) - - if args.save_logs: - for name, val in metrics.items(): - if tb_writer is not None: - tb_writer.add_scalar(f"val/{name}", val, epoch) - - with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: - f.write(json.dumps(metrics)) - f.write("\n") - - if args.wandb: - assert wandb is not None, 'Please install wandb.' 
- for name, val in metrics.items(): - wandb.log({f"val/{name}": val, 'epoch': epoch}) - - return metrics - - -def get_metrics(image_features, text_features, logit_scale): - metrics = {} - logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu() - logits_per_text = logits_per_image.t().detach().cpu() - - logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text} - ground_truth = torch.arange(len(text_features)).view(-1, 1) - - for name, logit in logits.items(): - ranking = torch.argsort(logit, descending=True) - preds = torch.where(ranking == ground_truth)[1] - preds = preds.detach().cpu().numpy() - metrics[f"{name}_mean_rank"] = preds.mean() + 1 - metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1 - for k in [1, 5, 10]: - metrics[f"{name}_R@{k}"] = np.mean(preds < k) - - return metrics From 2ec204b5a4277c38432257ce3f96928e1755c6df Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 8 Dec 2022 17:36:18 +0900 Subject: [PATCH 077/113] fix attentional pooler call --- src/open_clip/coca_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index ff48fb248..2807ac5b6 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -167,7 +167,7 @@ def encode_image(self, images=None, normalize=True): x = x.permute(1, 0, 2) # LND -> NLD x = self.visual.ln_post(x) - x = self.img_attn_pool(x) + x = self.img_attn_pool(x, x) x = self.img_attn_pool_norm(x) image_latent = x[:, 0] From 29c7dfae508f10612b27596c44e4c83a8e114e45 Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 8 Dec 2022 17:40:53 +0900 Subject: [PATCH 078/113] not ready for testing --- tests/test_inference.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_inference.py b/tests/test_inference.py index 9350ab6d1..f1409421b 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -20,7 +20,6 @@ 'ViT-bigG-14', 'ViT-e-14', 'mt5-xl-ViT-H-14', - 'coca_base' }) @pytest.mark.parametrize('model_name', models_to_test) From 5a4126b248023e9ee5b6d53a354845377364829b Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 8 Dec 2022 17:44:19 +0900 Subject: [PATCH 079/113] really not ready for testing --- tests/test_inference_simple.py | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/tests/test_inference_simple.py b/tests/test_inference_simple.py index 038a139fc..99dce5a63 100644 --- a/tests/test_inference_simple.py +++ b/tests/test_inference_simple.py @@ -3,7 +3,6 @@ from PIL import Image from open_clip.factory import get_tokenizer import pytest -import numpy as np import open_clip import os os.environ["CUDA_VISIBLE_DEVICES"] = "" @@ -24,23 +23,4 @@ def test_inference_simple(model_type, pretrained): text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) - assert text_probs.cpu().numpy()[0].tolist() == [1.0, 0.0, 0.0] - - -# def test_inference_simple_coca(): -# model = open_clip.CoCa(512, open_clip.CoCaCfg(), open_clip.CLIPTextCfg(), open_clip.CLIPVisionCfg(width=512)) -# preprocess = open_clip.image_transform(model.visual.image_size, is_train=False, mean=0, std=0) - -# tokenizer = open_clip.tokenize -# current_dir = os.path.dirname(os.path.realpath(__file__)) - -# image = preprocess(Image.open(current_dir + "/../docs/CLIP.png")).unsqueeze(0) -# text = tokenizer(["a diagram", "a dog", "a cat"]) - -# with torch.no_grad(): -# image_features = model.encode_image(image)[0] -# text_features = model.encode_text(text)[0] - -# text_probs = 
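The retrieval metrics in the file deleted above are the same ones kept in src/training/train.py (renamed get_clip_metrics a few patches below): rank the candidates for each query, then report mean rank, median rank and recall@k. A standalone restatement with a toy usage; the feature sizes and noise level are illustrative.

import numpy as np
import torch


def retrieval_metrics(image_features, text_features, logit_scale):
    metrics = {}
    logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
    logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_image.t()}
    ground_truth = torch.arange(len(text_features)).view(-1, 1)
    for name, logit in logits.items():
        ranking = torch.argsort(logit, descending=True)
        preds = torch.where(ranking == ground_truth)[1].numpy()   # rank of the true match
        metrics[f"{name}_mean_rank"] = preds.mean() + 1
        metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
        for k in (1, 5, 10):
            metrics[f"{name}_R@{k}"] = np.mean(preds < k)
    return metrics


imgs = torch.nn.functional.normalize(torch.randn(32, 16), dim=-1)
txts = torch.nn.functional.normalize(imgs + 0.1 * torch.randn(32, 16), dim=-1)
print(retrieval_metrics(imgs, txts, logit_scale=100.0))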
(100.0 * image_features @ text_features.T).softmax(dim=-1) - -# assert text_probs.cpu().numpy()[0].tolist() == [1.0, 0.0, 0.0] \ No newline at end of file + assert text_probs.cpu().numpy()[0].tolist() == [1.0, 0.0, 0.0] \ No newline at end of file From 6e4947403ba8cbc65b1da70085db45cd8bce89c3 Mon Sep 17 00:00:00 2001 From: gpucce Date: Thu, 8 Dec 2022 17:45:10 +0900 Subject: [PATCH 080/113] eof lien --- tests/test_inference_simple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_inference_simple.py b/tests/test_inference_simple.py index 99dce5a63..fb6bb4958 100644 --- a/tests/test_inference_simple.py +++ b/tests/test_inference_simple.py @@ -23,4 +23,4 @@ def test_inference_simple(model_type, pretrained): text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) - assert text_probs.cpu().numpy()[0].tolist() == [1.0, 0.0, 0.0] \ No newline at end of file + assert text_probs.cpu().numpy()[0].tolist() == [1.0, 0.0, 0.0] From 599d448ed1936bbc085df1bf06ae113d56ec08ff Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 9 Dec 2022 11:13:20 +0900 Subject: [PATCH 081/113] uniform names --- src/training/train.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/training/train.py b/src/training/train.py index 3b30f8e79..1cfed09e2 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -83,9 +83,9 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args if args.accum_freq == 1: with autocast(): - loss_args = model(images, texts) - logit_scale = loss_args[-1] - total_loss = loss(*loss_args) + model_out = model(images, texts) + logit_scale = model_out[-1] + total_loss = loss(*model_out) backward(total_loss, scaler) else: @@ -242,16 +242,16 @@ def evaluate(model, data, epoch, args, tb_writer=None): if is_master(args) and (i % 100) == 0: logging.info( f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t" - f"Loss: {cumulative_loss / num_samples:.6f}\t") + f"Clip Loss: {cumulative_loss / num_samples:.6f}\t") - val_metrics = get_metrics( + val_metrics = get_clip_metrics( image_features=torch.cat(all_image_features), text_features=torch.cat(all_text_features), logit_scale=logit_scale.cpu(), ) loss = cumulative_loss / num_samples metrics.update( - {**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples} + {**val_metrics, "clip_val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples} ) if not metrics: @@ -279,7 +279,7 @@ def evaluate(model, data, epoch, args, tb_writer=None): return metrics -def get_metrics(image_features, text_features, logit_scale): +def get_clip_metrics(image_features, text_features, logit_scale): metrics = {} logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu() logits_per_text = logits_per_image.t().detach().cpu() @@ -297,3 +297,7 @@ def get_metrics(image_features, text_features, logit_scale): metrics[f"{name}_R@{k}"] = np.mean(preds < k) return metrics + + +def get_generative_metrics(logits, labels): + loss = F.cross_entropy(logits.reshape(0, 2, 1), labels) From d7953dacb8190b106c801b1db0a072246a7650d6 Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 9 Dec 2022 11:48:28 +0900 Subject: [PATCH 082/113] add possible generative loss to evaluate --- src/training/train.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/training/train.py b/src/training/train.py index 1cfed09e2..3c8389944 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -210,6 
+210,7 @@ def evaluate(model, data, epoch, args, tb_writer=None): # FIXME this does not scale past small eval datasets # all_image_features @ all_text_features will blow up memory and compute very quickly cumulative_loss = 0.0 + cumulative_gen_loss = 0.0 all_image_features, all_text_features = [], [] with torch.no_grad(): for i, batch in enumerate(dataloader): @@ -237,6 +238,8 @@ def evaluate(model, data, epoch, args, tb_writer=None): F.cross_entropy(logits_per_text, labels) ) / 2 + gen_loss = maybe_compute_generative_loss(model_out) + cumulative_loss += total_loss * batch_size num_samples += batch_size if is_master(args) and (i % 100) == 0: @@ -244,6 +247,12 @@ def evaluate(model, data, epoch, args, tb_writer=None): f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t" f"Clip Loss: {cumulative_loss / num_samples:.6f}\t") + if gen_loss is not None: + cumulative_gen_loss += gen_loss * batch_size + logging.info( + f"Generative Loss: {cumulative_gen_loss / num_samples:.6f}\t") + + val_metrics = get_clip_metrics( image_features=torch.cat(all_image_features), text_features=torch.cat(all_text_features), @@ -253,6 +262,9 @@ def evaluate(model, data, epoch, args, tb_writer=None): metrics.update( {**val_metrics, "clip_val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples} ) + if gen_loss is not None: + gen_loss = cumulative_gen_loss / num_samples + metrics.update({"val_generative_loss": gen_loss.item()}) if not metrics: return metrics @@ -299,5 +311,8 @@ def get_clip_metrics(image_features, text_features, logit_scale): return metrics -def get_generative_metrics(logits, labels): - loss = F.cross_entropy(logits.reshape(0, 2, 1), labels) +def maybe_compute_generative_loss(model_out): + if len(model_out) > 3: + token_logits = model_out[2] + token_labels = model_out[3] + return F.cross_entropy(token_logits.reshape(0, 2, 1), token_labels) From e2042d4d7a63772657d23df40446045bea4a23d1 Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 9 Dec 2022 12:01:19 +0900 Subject: [PATCH 083/113] change _build function names --- src/open_clip/coca_model.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 2807ac5b6..1f917338f 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -17,7 +17,7 @@ @dataclass -class TextDecoderCfg: +class MultimodalCfg: context_length: int = 77 width: int = 512 image_dim: int = 512 @@ -32,14 +32,24 @@ class TextDecoderCfg: dim_latents: int = None -def _build_text_decoder_tower( +def _build_input_dependent_text_tower( embed_dim: int, - decoder_cfg: TextDecoderCfg, + multimodal_cfg: MultimodalCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, + multimodal:bool = True ): + + if not multimodal: + return _build_text_tower( + embed_dim=embed_dim, + text_cfg=multimodal_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype + ) + if isinstance(decoder_cfg, dict): - decoder_cfg = TextDecoderCfg(**decoder_cfg) + decoder_cfg = MultimodalCfg(**decoder_cfg) act_layer = QuickGELU if quick_gelu else nn.GELU norm_layer = ( @@ -64,7 +74,7 @@ class CoCa(nn.Module): def __init__( self, embed_dim, - decoder_cfg: TextDecoderCfg, + multimodal_cfg: MultimodalCfg, text_cfg: CLIPTextCfg, vision_cfg: CLIPVisionCfg, n_queries: int = 256, @@ -80,7 +90,7 @@ def __init__( else LayerNorm ) - text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + text = _build_input_dependent_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) 
self.transformer = text.transformer self.vocab_size = text.vocab_size self.token_embedding = text.token_embedding @@ -94,7 +104,7 @@ def __init__( embed_dim, vision_cfg, quick_gelu, cast_dtype ) - self.multimodal_decoder, decoder_cfg = _build_text_decoder_tower( + self.multimodal_decoder, decoder_cfg = _build_input_dependent_text_tower( embed_dim, decoder_cfg, quick_gelu, cast_dtype ) From 15c69f8b8ff41bf906e99a29e9eebe26f779a769 Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 9 Dec 2022 12:02:51 +0900 Subject: [PATCH 084/113] remove wrong import --- src/open_clip/hf_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/open_clip/hf_model.py b/src/open_clip/hf_model.py index afd087a76..9829c96a2 100644 --- a/src/open_clip/hf_model.py +++ b/src/open_clip/hf_model.py @@ -18,7 +18,6 @@ class BaseModelOutput: pass class PretrainedConfig: pass from .hf_configs import arch_dict -from .transformer import Attention # utils def _camel2snake(s): From c219381a0f16e21122ef601659a983873942dcc2 Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 9 Dec 2022 12:17:29 +0900 Subject: [PATCH 085/113] remove local_loss from captioning loss --- src/open_clip/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 87c643fa1..e1a062d59 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -151,7 +151,7 @@ def forward(self, image_features, text_features, logits, labels, logit_scale): clip_loss = super().forward(image_features, text_features, logit_scale) clip_loss = self.clip_loss_weight * clip_loss - if self.world_size > 1 and not self.local_loss: + if self.world_size > 1: all_logits, all_labels = gather_features( logits.contiguous(), labels.contiguous(), self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) From 5c77e4d14a298b7a236a3d835ae6402fdbfc0ca5 Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 9 Dec 2022 13:43:35 +0900 Subject: [PATCH 086/113] indexing error --- src/training/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/training/train.py b/src/training/train.py index 3c8389944..62787f60b 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -313,6 +313,6 @@ def get_clip_metrics(image_features, text_features, logit_scale): def maybe_compute_generative_loss(model_out): if len(model_out) > 3: - token_logits = model_out[2] - token_labels = model_out[3] + token_logits = model_out[3] + token_labels = model_out[4] return F.cross_entropy(token_logits.reshape(0, 2, 1), token_labels) From 3f095a6da6d6c5550680eb2759b5c44b0f4c144b Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 9 Dec 2022 13:43:53 +0900 Subject: [PATCH 087/113] finish renaming --- src/open_clip/coca_model.py | 34 ++++++++++++++-------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 1f917338f..db1c28ab4 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -17,17 +17,11 @@ @dataclass -class MultimodalCfg: - context_length: int = 77 - width: int = 512 +class MultimodalCfg(CLIPTextCfg): image_dim: int = 512 mlp_ratio: int = 4 - ls_init_value: Optional[float] = None - layers: int = 12 dim_head: int = 64 heads: int = 8 - clip_loss_weight: float = 1.0 - caption_loss_weight: float = 2.0 n_queries: int = 256 dim_latents: int = None @@ -48,8 +42,8 @@ def _build_input_dependent_text_tower( cast_dtype=cast_dtype ) - if isinstance(decoder_cfg, dict): - decoder_cfg = 
MultimodalCfg(**decoder_cfg) + if isinstance(multimodal_cfg, dict): + multimodal_cfg = MultimodalCfg(**multimodal_cfg) act_layer = QuickGELU if quick_gelu else nn.GELU norm_layer = ( @@ -57,17 +51,17 @@ def _build_input_dependent_text_tower( ) text = MultimodalTransformer( - context_length=decoder_cfg.context_length, - width=decoder_cfg.width, - heads=decoder_cfg.heads, - layers=decoder_cfg.layers, - ls_init_value=decoder_cfg.ls_init_value, + context_length=multimodal_cfg.context_length, + width=multimodal_cfg.width, + heads=multimodal_cfg.heads, + layers=multimodal_cfg.layers, + ls_init_value=multimodal_cfg.ls_init_value, output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, ) - return text, decoder_cfg + return text, multimodal_cfg class CoCa(nn.Module): @@ -90,7 +84,7 @@ def __init__( else LayerNorm ) - text = _build_input_dependent_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + text = _build_input_dependent_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype, multimodal=False) self.transformer = text.transformer self.vocab_size = text.vocab_size self.token_embedding = text.token_embedding @@ -104,17 +98,17 @@ def __init__( embed_dim, vision_cfg, quick_gelu, cast_dtype ) - self.multimodal_decoder, decoder_cfg = _build_input_dependent_text_tower( - embed_dim, decoder_cfg, quick_gelu, cast_dtype + self.multimodal_decoder, multimodal_cfg = _build_input_dependent_text_tower( + embed_dim, multimodal_cfg, quick_gelu, cast_dtype ) self.img_attn_pool = AttentionalPooler( - decoder_cfg.width, decoder_cfg.heads, n_queries=n_queries + 1 + multimodal_cfg.width, multimodal_cfg.heads, n_queries=n_queries + 1 ) self.img_attn_pool_norm = norm_layer(embed_dim) - self.dim_latents = decoder_cfg.dim_latents if decoder_cfg.dim_latents else decoder_cfg.width + self.dim_latents = multimodal_cfg.dim_latents if multimodal_cfg.dim_latents else multimodal_cfg.width self.to_text_latent = nn.Linear(embed_dim, self.dim_latents, bias=False) self.to_logits = nn.Sequential( From 60f35f31eb24cf2aae770217cc3a2d5654f98aaf Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 9 Dec 2022 13:45:20 +0900 Subject: [PATCH 088/113] adjust configs --- src/open_clip/model_configs/coca_base.json | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/open_clip/model_configs/coca_base.json b/src/open_clip/model_configs/coca_base.json index e40bdc1e6..502333493 100644 --- a/src/open_clip/model_configs/coca_base.json +++ b/src/open_clip/model_configs/coca_base.json @@ -1,6 +1,6 @@ { "embed_dim": 768, - "decoder_cfg": { + "multimodal_cfg": { "width": 768, "context_length": 76, "image_dim": 768, @@ -8,8 +8,6 @@ "layers": 12, "dim_head": 64, "heads": 12, - "clip_loss_weight": 1.0, - "caption_loss_weight": 2.0, "n_queries": 256 }, "vision_cfg": { From a53f4773819264a9f0babc4c93652fb886f439c1 Mon Sep 17 00:00:00 2001 From: Romain Beaumont Date: Fri, 9 Dec 2022 05:50:59 +0100 Subject: [PATCH 089/113] add training test for coca --- tests/test_training_simple.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/test_training_simple.py b/tests/test_training_simple.py index fe55b3328..9db18cb24 100644 --- a/tests/test_training_simple.py +++ b/tests/test_training_simple.py @@ -24,6 +24,22 @@ def test_training(): '--model', 'RN50' ]) +@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") +def test_training_coca(): + main([ + '--save-frequency', '1', + '--zeroshot-frequency', '1', + '--dataset-type', "synthetic", + 
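An illustrative mirror of the config refactor above: MultimodalCfg now inherits the shared text-tower fields from CLIPTextCfg and only adds the decoder-specific ones, so the multimodal_cfg block in coca_base.json / coca_ViT-B-32.json is parsed with the same keys as a text config. Only a subset of the real fields is mirrored here, and the class names and defaults are for the sketch only.

from dataclasses import dataclass
from typing import Optional


@dataclass
class CLIPTextCfgSketch:
    context_length: int = 77
    vocab_size: int = 49408
    width: int = 512
    heads: int = 8
    layers: int = 12
    ls_init_value: Optional[float] = None


@dataclass
class MultimodalCfgSketch(CLIPTextCfgSketch):
    mlp_ratio: int = 4
    dim_head: int = 64
    n_queries: int = 256
    dim_latents: Optional[int] = None


cfg = MultimodalCfgSketch(**{"context_length": 76, "width": 512, "heads": 8, "layers": 12})
print(cfg.vocab_size, cfg.n_queries)   # inherited default, decoder-specific default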
'--train-num-samples', '16', + '--warmup', '1', + '--batch-size', '4', + '--lr', '1e-3', + '--wd', '0.1', + '--epochs', '1', + '--workers', '2', + '--model', 'coca_base' + ]) + @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") def test_training_mt5(): main([ @@ -60,4 +76,4 @@ def test_training_unfreezing_vit(): '--model', 'ViT-B-32', '--lock-image', '--lock-image-unlocked-groups', '5' - ]) \ No newline at end of file + ]) From b3f3d68aa36a74226ff65c22b7fcdf15d703799f Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 9 Dec 2022 15:06:38 +0900 Subject: [PATCH 090/113] simplify captioning loss --- src/open_clip/loss.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index e1a062d59..87824781f 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -151,13 +151,13 @@ def forward(self, image_features, text_features, logits, labels, logit_scale): clip_loss = super().forward(image_features, text_features, logit_scale) clip_loss = self.clip_loss_weight * clip_loss - if self.world_size > 1: - all_logits, all_labels = gather_features( - logits.contiguous(), labels.contiguous(), - self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) - else: - all_logits = logits - all_labels = labels + # if self.world_size > 1 and not self.local_loss: + # all_logits, all_labels = gather_features( + # logits.contiguous(), labels.contiguous(), + # self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) + # else: + all_logits = logits + all_labels = labels all_logits = all_logits.permute(0, 2, 1) caption_loss = self.caption_loss(all_logits, all_labels) From 8eb4772b660f5560739fe96aae7608530895f7ab Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 9 Dec 2022 15:18:51 +0900 Subject: [PATCH 091/113] remove hf --- src/open_clip/model_configs/coca_base.json | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/open_clip/model_configs/coca_base.json b/src/open_clip/model_configs/coca_base.json index 502333493..6840ecb38 100644 --- a/src/open_clip/model_configs/coca_base.json +++ b/src/open_clip/model_configs/coca_base.json @@ -21,8 +21,7 @@ "vocab_size": 64000, "layers": 12, "heads": 12, - "width": 768, - "hf_tokenizer_name": "roberta-base" + "width": 768 }, "custom_text": "True" } \ No newline at end of file From cf0f85718b8b84d7c992e92b8a07e43c8ffd1fcd Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 9 Dec 2022 16:32:35 +0900 Subject: [PATCH 092/113] fix evaluate and loss --- src/open_clip/loss.py | 15 ++++++++------- src/training/train.py | 8 +++++--- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 87824781f..1b2d0763d 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -151,13 +151,14 @@ def forward(self, image_features, text_features, logits, labels, logit_scale): clip_loss = super().forward(image_features, text_features, logit_scale) clip_loss = self.clip_loss_weight * clip_loss - # if self.world_size > 1 and not self.local_loss: - # all_logits, all_labels = gather_features( - # logits.contiguous(), labels.contiguous(), - # self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) - # else: - all_logits = logits - all_labels = labels + if self.world_size > 1: + all_logits, all_labels = gather_features( + logits.contiguous(), labels.contiguous(), + self.local_loss, self.gather_with_grad, self.rank, 
self.world_size, self.use_horovod) + + else: + all_logits = logits + all_labels = labels all_logits = all_logits.permute(0, 2, 1) caption_loss = self.caption_loss(all_logits, all_labels) diff --git a/src/training/train.py b/src/training/train.py index 62787f60b..10523cc64 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -313,6 +313,8 @@ def get_clip_metrics(image_features, text_features, logit_scale): def maybe_compute_generative_loss(model_out): if len(model_out) > 3: - token_logits = model_out[3] - token_labels = model_out[4] - return F.cross_entropy(token_logits.reshape(0, 2, 1), token_labels) + token_logits = model_out[2] + token_labels = model_out[3] + print("TOKEN_LOGITS", token_logits.shape) + print("TOKEN_LABELS", token_labels.shape) + return F.cross_entropy(token_logits.permute(0, 2, 1), token_labels) From d547017d0459884ac60d5dae4f2efbce4c5858bd Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 9 Dec 2022 16:40:32 +0900 Subject: [PATCH 093/113] remove print --- src/training/train.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/training/train.py b/src/training/train.py index 10523cc64..a380f898f 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -315,6 +315,4 @@ def maybe_compute_generative_loss(model_out): if len(model_out) > 3: token_logits = model_out[2] token_labels = model_out[3] - print("TOKEN_LOGITS", token_logits.shape) - print("TOKEN_LABELS", token_labels.shape) return F.cross_entropy(token_logits.permute(0, 2, 1), token_labels) From 75be611af9148b2e09f5897ca36d6dca2d1aad79 Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 9 Dec 2022 17:50:42 +0900 Subject: [PATCH 094/113] move projection --- src/open_clip/coca_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index db1c28ab4..34df12fbc 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -18,7 +18,6 @@ @dataclass class MultimodalCfg(CLIPTextCfg): - image_dim: int = 512 mlp_ratio: int = 4 dim_head: int = 64 heads: int = 8 @@ -171,12 +170,13 @@ def encode_image(self, images=None, normalize=True): x = x.permute(1, 0, 2) # LND -> NLD x = self.visual.ln_post(x) + if self.visual.proj is not None: + x = x @ self.visual.proj + x = self.img_attn_pool(x, x) x = self.img_attn_pool_norm(x) image_latent = x[:, 0] - if self.visual.proj is not None: - image_latent = image_latent @ self.visual.proj image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent return image_latent, x[:, 1:] From 356fb7d58bac7bf4f2c4f6ae95a82cbedee0715a Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 9 Dec 2022 17:51:14 +0900 Subject: [PATCH 095/113] add coca vit 32 config --- .../model_configs/coca_ViT-B-32.json | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 src/open_clip/model_configs/coca_ViT-B-32.json diff --git a/src/open_clip/model_configs/coca_ViT-B-32.json b/src/open_clip/model_configs/coca_ViT-B-32.json new file mode 100644 index 000000000..3efdc24d8 --- /dev/null +++ b/src/open_clip/model_configs/coca_ViT-B-32.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file From 
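Consolidated form of the helper after the indexing and reshape/permute fixes above, spelled out with shapes: evaluate() can stay model-agnostic because CoCa returns token logits and token labels in addition to the usual CLIP outputs, so a tuple longer than three signals that a captioning head is present.

import torch.nn.functional as F


def maybe_compute_generative_loss(model_out):
    # CoCa's forward() returns the two sets of latents followed by token logits,
    # token labels and the logit scale; plain CLIP models fall through to None.
    if len(model_out) > 3:
        token_logits = model_out[2]     # (batch, seq_len, vocab_size)
        token_labels = model_out[3]     # (batch, seq_len)
        return F.cross_entropy(token_logits.permute(0, 2, 1), token_labels)
    return None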
8008f2545ae91c3f1b3101d02accb875046c1bbf Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 9 Dec 2022 17:52:27 +0900 Subject: [PATCH 096/113] test on new config --- tests/test_training_simple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_training_simple.py b/tests/test_training_simple.py index 9db18cb24..5e1f649e7 100644 --- a/tests/test_training_simple.py +++ b/tests/test_training_simple.py @@ -37,7 +37,7 @@ def test_training_coca(): '--wd', '0.1', '--epochs', '1', '--workers', '2', - '--model', 'coca_base' + '--model', 'coca_ViT-B-32' ]) @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") From 5b54a4b09abf0ae8bb1a41e5b4e514b4d66839ff Mon Sep 17 00:00:00 2001 From: gpucce Date: Fri, 9 Dec 2022 18:09:45 +0900 Subject: [PATCH 097/113] adjust coca_base config --- src/open_clip/model_configs/coca_base.json | 1 - 1 file changed, 1 deletion(-) diff --git a/src/open_clip/model_configs/coca_base.json b/src/open_clip/model_configs/coca_base.json index 6840ecb38..525203d4d 100644 --- a/src/open_clip/model_configs/coca_base.json +++ b/src/open_clip/model_configs/coca_base.json @@ -3,7 +3,6 @@ "multimodal_cfg": { "width": 768, "context_length": 76, - "image_dim": 768, "mlp_ratio": 4, "layers": 12, "dim_head": 64, From 720dabfddfea068e956d9ca33250382da74bec44 Mon Sep 17 00:00:00 2001 From: gpucce Date: Sat, 10 Dec 2022 11:48:29 +0900 Subject: [PATCH 098/113] remove coca from test_inference --- tests/test_inference.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_inference.py b/tests/test_inference.py index 0b3fdc1e0..a97b53aeb 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -19,6 +19,8 @@ 'ViT-bigG-14', 'ViT-e-14', 'mt5-xl-ViT-H-14', + 'coca_base', + 'coca_ViT-B-32' }) if 'OPEN_CLIP_TEST_REG_MODELS' in os.environ: From bcb82c4f42ce5149062c0a4ace333cf79f738449 Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Sat, 10 Dec 2022 16:08:58 +0100 Subject: [PATCH 099/113] maybe fix regression test --- src/open_clip/transformer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 605294e99..3c1b7affc 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -240,8 +240,8 @@ def forward(self, attn_mask: Optional[torch.Tensor] = None ): - k_x = self.ln_1_kv(k_x) if k_x is not None else self.ln_1(q_x) - v_x = self.ln_1_kv(v_x) if v_x is not None else self.ln_1(q_x) + k_x = self.ln_1_kv(k_x) if k_x is not None else None + v_x = self.ln_1_kv(v_x) if v_x is not None else None x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) x = x + self.ls_2(self.mlp(self.ln_2(x))) @@ -295,8 +295,8 @@ def forward( attn_mask: Optional[torch.Tensor] = None ): - k_x = self.ln_1_kv(k_x) if k_x is not None else self.ln_1(q_x) - v_x = self.ln_1_kv(v_x) if v_x is not None else self.ln_1(q_x) + k_x = self.ln_1_kv(k_x) if k_x is not None else None + v_x = self.ln_1_kv(v_x) if v_x is not None else None x = q_x + self.ls_1( self.ln_attn(self.attn(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) From d0f49471f3d7a2fd0e69efb3c1f30373ca5eb420 Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Sat, 10 Dec 2022 16:22:22 +0100 Subject: [PATCH 100/113] make logits and labels contiguous --- src/open_clip/loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 1b2d0763d..69842918f 100644 --- 
a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -157,8 +157,8 @@ def forward(self, image_features, text_features, logits, labels, logit_scale): self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) else: - all_logits = logits - all_labels = labels + all_logits = logits.contiguous() + all_labels = labels.contiguous() all_logits = all_logits.permute(0, 2, 1) caption_loss = self.caption_loss(all_logits, all_labels) From 39f20e6684a852e88c04df1c353bc9f5df5c008b Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Sat, 10 Dec 2022 16:25:00 +0100 Subject: [PATCH 101/113] simpler logic --- src/open_clip/loss.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 69842918f..2a99cdbbe 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -151,17 +151,15 @@ def forward(self, image_features, text_features, logits, labels, logit_scale): clip_loss = super().forward(image_features, text_features, logit_scale) clip_loss = self.clip_loss_weight * clip_loss + logits = logits.contiguous() + labels = labels.contiguous() if self.world_size > 1: - all_logits, all_labels = gather_features( - logits.contiguous(), labels.contiguous(), + logits, labels = gather_features( + logits, labels, self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) - else: - all_logits = logits.contiguous() - all_labels = labels.contiguous() - - all_logits = all_logits.permute(0, 2, 1) - caption_loss = self.caption_loss(all_logits, all_labels) + logits = logits.permute(0, 2, 1) + caption_loss = self.caption_loss(logits, labels) caption_loss = caption_loss * self.caption_loss_weight return clip_loss + caption_loss From 2dde78d3c0b1acf97c9b38334c83ddabb554db1c Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Sat, 10 Dec 2022 17:35:57 +0100 Subject: [PATCH 102/113] make contiguous after transpose --- src/open_clip/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 2a99cdbbe..bf3305503 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -158,7 +158,7 @@ def forward(self, image_features, text_features, logits, labels, logit_scale): logits, labels, self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) - logits = logits.permute(0, 2, 1) + logits = logits.permute(0, 2, 1).contiguous() caption_loss = self.caption_loss(logits, labels) caption_loss = caption_loss * self.caption_loss_weight From de4c06315104b48b5a9a46d1488e903334f16afb Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Sat, 10 Dec 2022 17:44:25 +0100 Subject: [PATCH 103/113] last test --- src/open_clip/loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index bf3305503..f6d23ad58 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -151,13 +151,13 @@ def forward(self, image_features, text_features, logits, labels, logit_scale): clip_loss = super().forward(image_features, text_features, logit_scale) clip_loss = self.clip_loss_weight * clip_loss - logits = logits.contiguous() - labels = labels.contiguous() + if self.world_size > 1: logits, labels = gather_features( logits, labels, self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) - + + labels = labels.contiguous() logits = logits.permute(0, 2, 1).contiguous() caption_loss = self.caption_loss(logits, labels) caption_loss = 
caption_loss * self.caption_loss_weight From 00aa464c20f60168d8220ea25da4adbbe5129eba Mon Sep 17 00:00:00 2001 From: gpucce Date: Mon, 12 Dec 2022 13:54:27 +0900 Subject: [PATCH 104/113] try fix loss --- src/open_clip/loss.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index f6d23ad58..4a4622446 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -153,13 +153,17 @@ def forward(self, image_features, text_features, logits, labels, logit_scale): if self.world_size > 1: - logits, labels = gather_features( - logits, labels, + all_logits, all_labels = gather_features( + logits.contiguous(), labels.contiguous(), self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) - - labels = labels.contiguous() - logits = logits.permute(0, 2, 1).contiguous() - caption_loss = self.caption_loss(logits, labels) + else: + all_logits = logits + all_labels = labels + + caption_loss = self.caption_loss( + all_logits.permute(0, 2, 1).contiguous(), + all_labels + ) caption_loss = caption_loss * self.caption_loss_weight return clip_loss + caption_loss From 27bfc7d00b3922982bf39610006c9b2bc116da2a Mon Sep 17 00:00:00 2001 From: iejmac Date: Sat, 17 Dec 2022 19:58:55 +0000 Subject: [PATCH 105/113] CoCa PR: loss fix + rename file --- src/open_clip/{model.py => clip_model.py} | 0 src/open_clip/coca_model.py | 5 ++--- src/open_clip/loss.py | 20 +++++++++----------- src/open_clip/transformer.py | 6 +++--- 4 files changed, 14 insertions(+), 17 deletions(-) rename src/open_clip/{model.py => clip_model.py} (100%) diff --git a/src/open_clip/model.py b/src/open_clip/clip_model.py similarity index 100% rename from src/open_clip/model.py rename to src/open_clip/clip_model.py diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 34df12fbc..692070ee1 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -181,8 +181,7 @@ def encode_image(self, images=None, normalize=True): return image_latent, x[:, 1:] - def forward(self, image, text,): - + def forward(self, image, text): text, labels = text[:, :-1], text[:, 1:] text_latents, text_tokens = self.encode_text(text) @@ -191,4 +190,4 @@ def forward(self, image, text,): text_tokens = self.multimodal_decoder(text_tokens, image_tokens) logits = self.to_logits(text_tokens) - return text_latents, image_latents, logits, labels, self.logit_scale + return image_latents, text_latents, logits, labels, self.logit_scale.exp() diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 4a4622446..c4f2d6f3c 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -44,6 +44,9 @@ def gather_features( all_text_features = torch.cat(gathered_text_features, dim=0) else: # We gather tensors from all gpus + print("in gather_features") + print(image_features.requires_grad) + print(text_features.requires_grad) if gather_with_grad: all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) @@ -59,6 +62,10 @@ def gather_features( all_image_features = torch.cat(gathered_image_features, dim=0) all_text_features = torch.cat(gathered_text_features, dim=0) + print('after gather') + print(all_image_features.requires_grad) + print(all_text_features.requires_grad) + return all_image_features, all_text_features @@ -151,18 +158,9 @@ def forward(self, image_features, text_features, logits, labels, logit_scale): clip_loss = 
super().forward(image_features, text_features, logit_scale) clip_loss = self.clip_loss_weight * clip_loss - - if self.world_size > 1: - all_logits, all_labels = gather_features( - logits.contiguous(), labels.contiguous(), - self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) - else: - all_logits = logits - all_labels = labels - caption_loss = self.caption_loss( - all_logits.permute(0, 2, 1).contiguous(), - all_labels + logits.permute(0, 2, 1), + labels, ) caption_loss = caption_loss * self.caption_loss_weight diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index c47479e7e..a83a0915a 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -137,9 +137,9 @@ def forward(self, k = F.linear(k_x, w_k, self.in_proj_bias) v = F.linear(v_x, w_v, self.in_proj_bias) - q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + q = q.view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.view(L, N * self.num_heads, -1).transpose(0, 1) if self.logit_scale is not None: attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) From e694999c1cd29f8c786bdda1d4671dfa2c444e27 Mon Sep 17 00:00:00 2001 From: iejmac Date: Sat, 17 Dec 2022 20:08:35 +0000 Subject: [PATCH 106/113] wait for feedback on this --- src/open_clip/{clip_model.py => model.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/open_clip/{clip_model.py => model.py} (100%) diff --git a/src/open_clip/clip_model.py b/src/open_clip/model.py similarity index 100% rename from src/open_clip/clip_model.py rename to src/open_clip/model.py From 5427b0ade49566338b1d8270a508e702c14454d3 Mon Sep 17 00:00:00 2001 From: iejmac Date: Sat, 17 Dec 2022 20:09:21 +0000 Subject: [PATCH 107/113] cleanup --- src/open_clip/loss.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index c4f2d6f3c..9f26dd4f8 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -44,9 +44,6 @@ def gather_features( all_text_features = torch.cat(gathered_text_features, dim=0) else: # We gather tensors from all gpus - print("in gather_features") - print(image_features.requires_grad) - print(text_features.requires_grad) if gather_with_grad: all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) @@ -62,10 +59,6 @@ def gather_features( all_image_features = torch.cat(gathered_image_features, dim=0) all_text_features = torch.cat(gathered_text_features, dim=0) - print('after gather') - print(all_image_features.requires_grad) - print(all_text_features.requires_grad) - return all_image_features, all_text_features From abd7849bd533c3148985e04e08bdc99b4696845a Mon Sep 17 00:00:00 2001 From: iejmac Date: Sat, 17 Dec 2022 23:40:22 +0000 Subject: [PATCH 108/113] CoCa PR: add set_grad_checkpointing + fix checkpoint API --- src/open_clip/coca_model.py | 67 ++++++++++++++++++++---------------- src/open_clip/transformer.py | 10 +++--- 2 files changed, 43 insertions(+), 34 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 692070ee1..85d5bd532 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -119,35 +119,11 @@ def __init__( self.logit_scale = 
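A single-process sketch of what the CoCa loss computes after this simplification, with the gather logic stripped out: the usual symmetric contrastive term plus a per-token captioning cross-entropy, combined with the 1.0 / 2.0 weights used elsewhere in this series. The helper name and the toy shapes are illustrative.

import torch
import torch.nn.functional as F


def coca_loss_sketch(image_features, text_features, token_logits, token_labels,
                     logit_scale, clip_weight=1.0, caption_weight=2.0):
    # contrastive part (single-process case of ClipLoss): symmetric cross-entropy
    # over the image->text and text->image similarity matrices
    logits_per_image = logit_scale * image_features @ text_features.T
    logits_per_text = logits_per_image.T
    targets = torch.arange(image_features.shape[0], device=image_features.device)
    clip_loss = (F.cross_entropy(logits_per_image, targets) +
                 F.cross_entropy(logits_per_text, targets)) / 2

    # captioning part: per-token cross-entropy; CrossEntropyLoss wants the class
    # dimension second, hence the (B, L, V) -> (B, V, L) permute from the patch
    caption_loss = F.cross_entropy(token_logits.permute(0, 2, 1), token_labels)

    return clip_weight * clip_loss + caption_weight * caption_loss


B, L, V, D = 4, 7, 100, 16
loss = coca_loss_sketch(
    F.normalize(torch.randn(B, D), dim=-1),
    F.normalize(torch.randn(B, D), dim=-1),
    torch.randn(B, L, V),
    torch.randint(0, V, (B, L)),
    logit_scale=torch.tensor(10.0),
)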
nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - def _repeat(self, t, N): - return t.reshape(1, 1, -1).repeat(N, 1, 1) - - def encode_text(self, text, normalize=True): - cast_dtype = self.transformer.get_cast_dtype() - - # cls_mask = (text!=self.pad_id).unsqueeze(1) - # attn_mask = F.pad(cls_mask, (0, 1, text.shape[1], 0), value=True) - # attn_mask = F.pad(self.attn_mask, (0, 1, 0, 1), value=0.0) - - x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] - x = torch.cat([x, self._repeat(self.cls_token, x.shape[0])], dim=1) - x = x + self.positional_embedding.to(cast_dtype) - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x, attn_mask=self.attn_mask) - x = x.permute(1, 0, 2) # LND -> NLD - - # x.shape = [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), :] @ self.text_projection - - cls_emb = x[torch.arange(x.shape[0]), -1] - token_emb = x[torch.arange(x.shape[0]), :-1] - - cls_emb = self.ln_final(cls_emb) - text_latent = self.to_text_latent(cls_emb) - text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent - - return text_latent, token_emb + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.transformer.grad_checkpointing = enable + self.multimodal_decoder.grad_checkpointing = enable def encode_image(self, images=None, normalize=True): x = self.visual.conv1(images) # shape = [*, width, grid, grid] @@ -181,13 +157,44 @@ def encode_image(self, images=None, normalize=True): return image_latent, x[:, 1:] + def _repeat(self, t, N): + return t.reshape(1, 1, -1).repeat(N, 1, 1) + + def encode_text(self, text, normalize=True): + cast_dtype = self.transformer.get_cast_dtype() + + # cls_mask = (text!=self.pad_id).unsqueeze(1) + # attn_mask = F.pad(cls_mask, (0, 1, text.shape[1], 0), value=True) + # attn_mask = F.pad(self.attn_mask, (0, 1, 0, 1), value=0.0) + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + x = torch.cat([x, self._repeat(self.cls_token, x.shape[0])], dim=1) + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), :] @ self.text_projection + + cls_emb = x[torch.arange(x.shape[0]), -1] + token_emb = x[torch.arange(x.shape[0]), :-1] + + cls_emb = self.ln_final(cls_emb) + text_latent = self.to_text_latent(cls_emb) + text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent + + return text_latent, token_emb + + def forward(self, image, text): text, labels = text[:, :-1], text[:, 1:] text_latents, text_tokens = self.encode_text(text) image_latents, image_tokens = self.encode_image(image) - text_tokens = self.multimodal_decoder(text_tokens, image_tokens) + text_tokens = self.multimodal_decoder(image_tokens, text_tokens) logits = self.to_logits(text_tokens) return image_latents, text_latents, logits, labels, self.logit_scale.exp() diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index a83a0915a..7a346bf6b 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -334,7 +334,8 @@ def get_cast_dtype(self) -> torch.dtype: def forward(self, x: torch.Tensor, attn_mask: 
Optional[torch.Tensor] = None): for r in self.resblocks: if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(r, x, attn_mask=attn_mask) + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, None, None, attn_mask) else: x = r(x, attn_mask=attn_mask) return x @@ -619,14 +620,15 @@ def build_attention_mask(self): mask.triu_(1) # zero out the lower diagonal return mask - def forward(self, text_embs, image_embs): + def forward(self, image_embs, text_embs): text_embs = text_embs.permute(1, 0, 2) # NLD -> LND image_embs = image_embs.permute(1, 0, 2) # NLD -> LND for r, ca in zip(self.resblocks, self.cross_attn): if self.grad_checkpointing and not torch.jit.is_scripting(): - text_embs = checkpoint(r, text_embs, attn_mask=self.attn_mask) - text_embs = checkpoint(ca, text_embs, k_x=image_embs, v_x=image_embs) + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + text_embs = checkpoint(r, text_embs, None, None, self.attn_mask) + text_embs = checkpoint(ca, text_embs, image_embs, image_embs, None) else: text_embs = r(text_embs, attn_mask=self.attn_mask) text_embs = ca(text_embs, k_x=image_embs, v_x=image_embs) From 919f5a09a7b027ef1c8db1a4b479c1f368e180ce Mon Sep 17 00:00:00 2001 From: iejmac Date: Sun, 18 Dec 2022 03:29:19 +0000 Subject: [PATCH 109/113] CoCa PR: fix eval (which uses encode_x instead of forward) --- src/open_clip/coca_model.py | 12 ++++++------ src/training/zero_shot.py | 5 +++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 85d5bd532..6bae3a11e 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -125,7 +125,7 @@ def set_grad_checkpointing(self, enable=True): self.transformer.grad_checkpointing = enable self.multimodal_decoder.grad_checkpointing = enable - def encode_image(self, images=None, normalize=True): + def encode_image(self, images, normalize=True, return_tokens=False): x = self.visual.conv1(images) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] @@ -155,12 +155,12 @@ def encode_image(self, images=None, normalize=True): image_latent = x[:, 0] image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent - return image_latent, x[:, 1:] + return (image_latent, x[:, 1:]) if return_tokens image_latent def _repeat(self, t, N): return t.reshape(1, 1, -1).repeat(N, 1, 1) - def encode_text(self, text, normalize=True): + def encode_text(self, text, normalize=True, return_tokens=False): cast_dtype = self.transformer.get_cast_dtype() # cls_mask = (text!=self.pad_id).unsqueeze(1) @@ -185,14 +185,14 @@ def encode_text(self, text, normalize=True): text_latent = self.to_text_latent(cls_emb) text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent - return text_latent, token_emb + return (text_latent, token_emb) if return_tokens else text_latent def forward(self, image, text): text, labels = text[:, :-1], text[:, 1:] - text_latents, text_tokens = self.encode_text(text) - image_latents, image_tokens = self.encode_image(image) + text_latents, text_tokens = self.encode_text(text, return_tokens=True) + image_latents, image_tokens = self.encode_image(image, return_tokens=True) text_tokens = self.multimodal_decoder(image_tokens, text_tokens) logits = self.to_logits(text_tokens) diff --git 
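A minimal demonstration of the workaround behind the TODO comments above: the re-entrant torch.utils.checkpoint.checkpoint used here does not forward keyword arguments to the wrapped module (see the linked PyTorch issue), so optional inputs such as the attention mask are passed positionally, with explicit None placeholders for arguments that are skipped. The encoder layer below is only a stand-in for the residual blocks in transformer.py.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

block = nn.TransformerEncoderLayer(d_model=64, nhead=4)
x = torch.randn(10, 2, 64, requires_grad=True)            # (seq, batch, width)
causal = torch.full((10, 10), float("-inf")).triu(1)      # additive causal mask

out = checkpoint(block, x, causal)   # src and src_mask both passed positionally
out.sum().backward()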
a/src/training/zero_shot.py b/src/training/zero_shot.py index e5768b4a3..7fca52b72 100644 --- a/src/training/zero_shot.py +++ b/src/training/zero_shot.py @@ -19,7 +19,8 @@ def zero_shot_classifier(model, classnames, templates, args): if args.distributed and not args.horovod: class_embeddings = model.module.encode_text(texts) else: - class_embeddings = model.encode_text(texts) + texts = texts[:, :-1] # TODO: temp do this bretter + class_embeddings, _ = model.encode_text(texts) class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) class_embedding /= class_embedding.norm() zeroshot_weights.append(class_embedding) @@ -49,7 +50,7 @@ def run(model, classifier, dataloader, args): if args.distributed and not args.horovod: image_features = model.module.encode_image(images) else: - image_features = model.encode_image(images) + image_features, _ = model.encode_image(images) image_features = F.normalize(image_features, dim=-1) logits = 100. * image_features @ classifier From 5b29ec064cb220fa00e9097d37855ead0fc6073c Mon Sep 17 00:00:00 2001 From: iejmac Date: Sun, 18 Dec 2022 03:33:51 +0000 Subject: [PATCH 110/113] move making space for CLS token into encode_text --- src/open_clip/coca_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 6bae3a11e..ddec3f5cb 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -161,6 +161,7 @@ def _repeat(self, t, N): return t.reshape(1, 1, -1).repeat(N, 1, 1) def encode_text(self, text, normalize=True, return_tokens=False): + text = text[:, :-1] # make space for CLS token cast_dtype = self.transformer.get_cast_dtype() # cls_mask = (text!=self.pad_id).unsqueeze(1) @@ -187,9 +188,8 @@ def encode_text(self, text, normalize=True, return_tokens=False): return (text_latent, token_emb) if return_tokens else text_latent - def forward(self, image, text): - text, labels = text[:, :-1], text[:, 1:] + labels = text[:, 1:] text_latents, text_tokens = self.encode_text(text, return_tokens=True) image_latents, image_tokens = self.encode_image(image, return_tokens=True) From 752de0ad122ea2bf719af7b44e661dcff9922a18 Mon Sep 17 00:00:00 2001 From: iejmac Date: Sun, 18 Dec 2022 04:18:54 +0000 Subject: [PATCH 111/113] rever zs changes + fix --- src/open_clip/coca_model.py | 2 +- src/training/zero_shot.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index ddec3f5cb..5d965889b 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -155,7 +155,7 @@ def encode_image(self, images, normalize=True, return_tokens=False): image_latent = x[:, 0] image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent - return (image_latent, x[:, 1:]) if return_tokens image_latent + return (image_latent, x[:, 1:]) if return_tokens else image_latent def _repeat(self, t, N): return t.reshape(1, 1, -1).repeat(N, 1, 1) diff --git a/src/training/zero_shot.py b/src/training/zero_shot.py index 7fca52b72..e5768b4a3 100644 --- a/src/training/zero_shot.py +++ b/src/training/zero_shot.py @@ -19,8 +19,7 @@ def zero_shot_classifier(model, classnames, templates, args): if args.distributed and not args.horovod: class_embeddings = model.module.encode_text(texts) else: - texts = texts[:, :-1] # TODO: temp do this bretter - class_embeddings, _ = model.encode_text(texts) + class_embeddings = model.encode_text(texts) class_embedding = F.normalize(class_embeddings, 

From 752de0ad122ea2bf719af7b44e661dcff9922a18 Mon Sep 17 00:00:00 2001
From: iejmac
Date: Sun, 18 Dec 2022 04:18:54 +0000
Subject: [PATCH 111/113] revert zs changes + fix

---
 src/open_clip/coca_model.py | 2 +-
 src/training/zero_shot.py   | 5 ++---
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py
index ddec3f5cb..5d965889b 100644
--- a/src/open_clip/coca_model.py
+++ b/src/open_clip/coca_model.py
@@ -155,7 +155,7 @@ def encode_image(self, images, normalize=True, return_tokens=False):
         image_latent = x[:, 0]
         image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
 
-        return (image_latent, x[:, 1:]) if return_tokens image_latent
+        return (image_latent, x[:, 1:]) if return_tokens else image_latent
 
     def _repeat(self, t, N):
         return t.reshape(1, 1, -1).repeat(N, 1, 1)
diff --git a/src/training/zero_shot.py b/src/training/zero_shot.py
index 7fca52b72..e5768b4a3 100644
--- a/src/training/zero_shot.py
+++ b/src/training/zero_shot.py
@@ -19,8 +19,7 @@ def zero_shot_classifier(model, classnames, templates, args):
             if args.distributed and not args.horovod:
                 class_embeddings = model.module.encode_text(texts)
             else:
-                texts = texts[:, :-1] # TODO: temp do this bretter
-                class_embeddings, _ = model.encode_text(texts)
+                class_embeddings = model.encode_text(texts)
             class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
             class_embedding /= class_embedding.norm()
             zeroshot_weights.append(class_embedding)
@@ -50,7 +49,7 @@ def run(model, classifier, dataloader, args):
                 if args.distributed and not args.horovod:
                     image_features = model.module.encode_image(images)
                 else:
-                    image_features, _ = model.encode_image(images)
+                    image_features = model.encode_image(images)
                 image_features = F.normalize(image_features, dim=-1)
 
                 logits = 100. * image_features @ classifier
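
The revert above restores the plain zero-shot path: encode_text() and encode_image() hand back a single pooled embedding unless return_tokens=True is requested, so evaluation code written for CLIP keeps working unchanged. A minimal sketch of that convention, with a stand-in encoder rather than the real towers:

    import torch
    import torch.nn.functional as F

    def encode(x, normalize=True, return_tokens=False):
        tokens = x * 2.0                                # pretend per-token features
        pooled = tokens[:, 0]                           # CLS-style pooling
        pooled = F.normalize(pooled, dim=-1) if normalize else pooled
        return (pooled, tokens[:, 1:]) if return_tokens else pooled

    x = torch.randn(2, 5, 4)
    features = encode(x)                                # zero-shot path: pooled only
    latent, tokens = encode(x, return_tokens=True)      # captioning path: pooled + tokens
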

From 64c33d81822ae5cf57ee4a141ee7bba0837f0b11 Mon Sep 17 00:00:00 2001
From: gpucce
Date: Wed, 21 Dec 2022 00:18:48 +0900
Subject: [PATCH 112/113] add cls mask for pad ids

---
 src/open_clip/coca_model.py | 49 +++++++++++++++++++++----------------
 1 file changed, 28 insertions(+), 21 deletions(-)

diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py
index 5d965889b..8952bbbdc 100644
--- a/src/open_clip/coca_model.py
+++ b/src/open_clip/coca_model.py
@@ -30,7 +30,7 @@ def _build_input_dependent_text_tower(
     multimodal_cfg: MultimodalCfg,
     quick_gelu: bool = False,
     cast_dtype: Optional[torch.dtype] = None,
-    multimodal:bool = True
+    multimodal: bool = True,
 ):
 
     if not multimodal:
@@ -38,16 +38,14 @@ def _build_input_dependent_text_tower(
             embed_dim=embed_dim,
             text_cfg=multimodal_cfg,
             quick_gelu=quick_gelu,
-            cast_dtype=cast_dtype
+            cast_dtype=cast_dtype,
         )
 
     if isinstance(multimodal_cfg, dict):
         multimodal_cfg = MultimodalCfg(**multimodal_cfg)
 
     act_layer = QuickGELU if quick_gelu else nn.GELU
-    norm_layer = (
-        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
-    )
+    norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
 
     text = MultimodalTransformer(
         context_length=multimodal_cfg.context_length,
@@ -76,14 +74,11 @@ def __init__(
     ):
         super().__init__()
 
+        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
 
-        norm_layer = (
-            LayerNormFp32
-            if cast_dtype in (torch.float16, torch.bfloat16)
-            else LayerNorm
+        text = _build_input_dependent_text_tower(
+            embed_dim, text_cfg, quick_gelu, cast_dtype, multimodal=False
         )
-
-        text = _build_input_dependent_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype, multimodal=False)
         self.transformer = text.transformer
         self.vocab_size = text.vocab_size
         self.token_embedding = text.token_embedding
@@ -92,10 +87,9 @@ def __init__(
         self.text_projection = text.text_projection
         self.register_buffer("attn_mask", text.attn_mask, persistent=False)
 
+        self.heads = text_cfg["heads"]
         self.cls_token = nn.Parameter(torch.randn(embed_dim))
-        self.visual = _build_vision_tower(
-            embed_dim, vision_cfg, quick_gelu, cast_dtype
-        )
+        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
 
         self.multimodal_decoder, multimodal_cfg = _build_input_dependent_text_tower(
             embed_dim, multimodal_cfg, quick_gelu, cast_dtype
@@ -107,7 +101,9 @@ def __init__(
 
         self.img_attn_pool_norm = norm_layer(embed_dim)
 
-        self.dim_latents = multimodal_cfg.dim_latents if multimodal_cfg.dim_latents else multimodal_cfg.width
+        self.dim_latents = (
+            multimodal_cfg.dim_latents if multimodal_cfg.dim_latents else multimodal_cfg.width
+        )
         self.to_text_latent = nn.Linear(embed_dim, self.dim_latents, bias=False)
 
         self.to_logits = nn.Sequential(
@@ -118,6 +114,7 @@ def __init__(
         self.to_logits[-1].weight = self.token_embedding.weight
 
         self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+        self.pad_id = 0
 
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
@@ -132,9 +129,7 @@ def encode_image(self, images, normalize=True, return_tokens=False):
         x = torch.cat(
             [
                 self.visual.class_embedding.to(x.dtype)
-                + torch.zeros(
-                    x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
-                ),
+                + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
                 x,
             ],
             dim=1,
@@ -160,19 +155,31 @@ def encode_image(self, images, normalize=True, return_tokens=False):
     def _repeat(self, t, N):
         return t.reshape(1, 1, -1).repeat(N, 1, 1)
 
+    def _build_cls_mask(self, text, cast_dtype):
+        cls_mask = (text != self.pad_id).unsqueeze(1)
+        cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
+        additive_mask = torch.empty(*cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
+        additive_mask.fill_(0)
+        additive_mask.masked_fill_(~cls_mask, float("-inf"))
+        additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
+        return additive_mask
+
     def encode_text(self, text, normalize=True, return_tokens=False):
-        text = text[:, :-1] # make space for CLS token
+        text = text[:, :-1]  # make space for CLS token
        cast_dtype = self.transformer.get_cast_dtype()
 
         # cls_mask = (text!=self.pad_id).unsqueeze(1)
-        attn_mask = F.pad(cls_mask, (0, 1, text.shape[1], 0), value=True)
+        attn_mask = self.attn_mask[None, :].expand(
+            text.shape[0] * self.heads, *self.attn_mask.shape
+        )
+        cls_mask = self._build_cls_mask(text, cast_dtype)
         # attn_mask = F.pad(self.attn_mask, (0, 1, 0, 1), value=0.0)
 
         x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
         x = torch.cat([x, self._repeat(self.cls_token, x.shape[0])], dim=1)
         x = x + self.positional_embedding.to(cast_dtype)
         x = x.permute(1, 0, 2)  # NLD -> LND
-        x = self.transformer(x, attn_mask=self.attn_mask)
+        x = self.transformer(x, attn_mask=attn_mask + cls_mask)
         x = x.permute(1, 0, 2)  # LND -> NLD
 
         # x.shape = [batch_size, n_ctx, transformer.width]
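
Patch 112 builds a per-sample padding mask alongside the causal mask and sums the two before they reach the text transformer. The sketch below reproduces only the shape bookkeeping with made-up sizes; it is not the model code itself:

    import torch
    import torch.nn.functional as F

    heads, pad_id = 2, 0
    text = torch.tensor([[5, 3, 0, 0], [7, 2, 9, 4]])   # [batch, seq_len - 1] token ids
    seq_len = text.shape[1] + 1                          # +1 for the appended CLS slot

    # Causal mask: 0 on and below the diagonal, -inf above it.
    causal = torch.empty(seq_len, seq_len).fill_(float("-inf")).triu_(1)

    # Boolean key mask padded out to the full attention shape, then made additive
    # (0 where attention is allowed, -inf where it is not) and repeated per head.
    cls_mask = (text != pad_id).unsqueeze(1)                              # [batch, 1, seq_len - 1]
    cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)  # [batch, seq_len, seq_len]
    additive = torch.zeros(cls_mask.shape)
    additive.masked_fill_(~cls_mask, float("-inf"))
    additive = torch.repeat_interleave(additive, heads, 0)                # [batch * heads, L, L]

    # Broadcasting the causal mask per head lets the two masks be summed into a
    # single 3-D attn_mask.
    attn_mask = causal[None, :].expand(text.shape[0] * heads, seq_len, seq_len)
    attn_mask = attn_mask + additive

Expanding the causal mask to [batch * heads, L, L] is what makes the element-wise sum with the repeated padding mask line up with nn.MultiheadAttention-style 3-D masks.
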

From 17813eb8bfff02b952a2a30de7e7dd961550ad9a Mon Sep 17 00:00:00 2001
From: gpucce
Date: Wed, 21 Dec 2022 02:19:39 +0900
Subject: [PATCH 113/113] simplify encode image

---
 src/open_clip/coca_model.py  | 17 +----------------
 src/open_clip/transformer.py |  5 ++++-
 2 files changed, 5 insertions(+), 17 deletions(-)

diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py
index 8952bbbdc..3b13cc782 100644
--- a/src/open_clip/coca_model.py
+++ b/src/open_clip/coca_model.py
@@ -123,22 +123,7 @@ def set_grad_checkpointing(self, enable=True):
         self.multimodal_decoder.grad_checkpointing = enable
 
     def encode_image(self, images, normalize=True, return_tokens=False):
-        x = self.visual.conv1(images)  # shape = [*, width, grid, grid]
-        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
-        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
-        x = torch.cat(
-            [
-                self.visual.class_embedding.to(x.dtype)
-                + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
-                x,
-            ],
-            dim=1,
-        )  # shape = [*, grid ** 2 + 1, width]
-        x = x + self.visual.positional_embedding.to(x.dtype)
-        x = self.visual.ln_pre(x)
-        x = x.permute(1, 0, 2)  # NLD -> LND
-        x = self.visual.transformer(x)
-        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.visual(images, output_tokens=True)
 
         x = self.visual.ln_post(x)
         if self.visual.proj is not None:
diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py
index 8a6de4846..0d2f2e7ae 100644
--- a/src/open_clip/transformer.py
+++ b/src/open_clip/transformer.py
@@ -448,7 +448,7 @@ def init_parameters(self):
     def set_grad_checkpointing(self, enable=True):
         self.transformer.grad_checkpointing = enable
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, output_tokens = False):
         x = self.conv1(x)  # shape = [*, width, grid, grid]
         x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
         x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
@@ -465,6 +465,9 @@ def forward(self, x: torch.Tensor):
         x = self.transformer(x)
         x = x.permute(1, 0, 2)  # LND -> NLD
 
+        if output_tokens:
+            return x
+
         if self.global_average_pool:
             x = x.mean(dim=1)
         else: