From 1b8660101544343c035c6434d29bf974a87f71d3 Mon Sep 17 00:00:00 2001
From: Romain Beaumont
Date: Tue, 20 Dec 2022 23:17:52 +0100
Subject: [PATCH 01/30] Add coca trained (#307)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* initial setup
* add coca loss
* remove loss from the model
* fix loss
* add underscores
* name changes
* add cross attention to Residual and CustomResidual
* fix if
* add transformer 'decoder'
* minor fix
* looks better
* initialize coca model structure
* clean
* typo and format
* checkpoint signature
* adjust multimodal decoder and add CoCaTransformer
* keep older logic
* remove chunk
* typo
* fix
* make chunk dim explicit
* adjust cfg names
* add attentionalpooling
* add attentional pooling to coca
* small change
* add cocatransformer variants and AttentionPooling
* remove older attention pooler
* adapt embed text to coca text transformer
* rm coca layers
* rename and remove useless CoCa models
* make attentionpooler pooler only
* refactor for one transformer only
* coca forward works
* separate context and n_queries
* add initial coca_base config
* remove config
* small loss change
* init training file
* make variable order right
* remove print
* uniform names
* renaming
* add coca funcs to init
* add coca config and exclude from testing
* add and comment simple test (no trained model)
* add L2 norm
* make L2 same as in clip
* remove unused temperature
* type
* clean
* fix config
* make rename and move cfg
* rename
* tentatively add coca to factory
* fix config
* update config
* embed contrastive cls token in model
* remove unused arg
* import create_loss
* make factory accept coca
* make caption loss distributed
* make loss customizable
* pass loss through training_epoch
* add coca specific params to params
* removed unused decoder parameters
* remove unused attributes
* adjust coca_config
* fix config and remove unused parameters
* remove comment
* remove more comments
* rename attention pooler
* rename TransformerDecoder
* make AttentionalPooler clearer
* add local loss logic to cocaloss
* only create loss if train in data
* remove wrong file
* fix attentional pooler call
* not ready for testing
* really not ready for testing
* eof line
* uniform names
* add possible generative loss to evaluate
* change _build function names
* remove wrong import
* remove local_loss from captioning loss
* indexing error
* finish renaming
* adjust configs
* add training test for coca
* simplify captioning loss
* remove hf
* fix evaluate and loss
* remove print
* move projection
* add coca vit 32 config
* test on new config
* adjust coca_base config
* remove coca from test_inference
* maybe fix regression test
* make logits and labels contiguous
* simpler logic
* make contiguous after transpose
* last test
* try fix loss
* CoCa PR: loss fix + rename file
* wait for feedback on this
* cleanup
* CoCa PR: add set_grad_checkpointing + fix checkpoint API
* CoCa PR: fix eval (which uses encode_x instead of forward)
* move making space for CLS token into encode_text
* revert zs changes + fix

Co-authored-by: gpucce
Co-authored-by: gpucce
Co-authored-by: iejmac
---
 src/open_clip/__init__.py | 5 +-
 src/open_clip/coca_model.py | 200 ++++++++++++++++++
 src/open_clip/factory.py | 29 ++-
 src/open_clip/loss.py | 39 ++++
 .../model_configs/coca_ViT-B-32.json | 24 +++
 src/open_clip/model_configs/coca_base.json | 26 +++
 src/open_clip/transformer.py | 188 ++++++++++++++--
 src/training/main.py | 6 +-
 src/training/params.py | 12 ++
 src/training/train.py | 49
+++-- tests/test_inference.py | 2 + tests/test_training_simple.py | 18 +- 12 files changed, 557 insertions(+), 41 deletions(-) create mode 100644 src/open_clip/coca_model.py create mode 100644 src/open_clip/model_configs/coca_ViT-B-32.json create mode 100644 src/open_clip/model_configs/coca_base.json diff --git a/src/open_clip/__init__.py b/src/open_clip/__init__.py index b76dd51b9..a4e11b4f7 100644 --- a/src/open_clip/__init__.py +++ b/src/open_clip/__init__.py @@ -1,9 +1,10 @@ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD -from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer +from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss from .factory import list_models, add_model_config, get_model_config, load_checkpoint -from .loss import ClipLoss +from .loss import ClipLoss, CoCaLoss from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype +from .coca_model import CoCa from .openai import load_openai_model, list_openai_models from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py new file mode 100644 index 000000000..5d965889b --- /dev/null +++ b/src/open_clip/coca_model.py @@ -0,0 +1,200 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F +import numpy as np +from dataclasses import dataclass + +from .transformer import ( + LayerNormFp32, + LayerNorm, + QuickGELU, + MultimodalTransformer, + AttentionalPooler, +) +from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower + + +@dataclass +class MultimodalCfg(CLIPTextCfg): + mlp_ratio: int = 4 + dim_head: int = 64 + heads: int = 8 + n_queries: int = 256 + dim_latents: int = None + + +def _build_input_dependent_text_tower( + embed_dim: int, + multimodal_cfg: MultimodalCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + multimodal:bool = True +): + + if not multimodal: + return _build_text_tower( + embed_dim=embed_dim, + text_cfg=multimodal_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype + ) + + if isinstance(multimodal_cfg, dict): + multimodal_cfg = MultimodalCfg(**multimodal_cfg) + + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = ( + LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + ) + + text = MultimodalTransformer( + context_length=multimodal_cfg.context_length, + width=multimodal_cfg.width, + heads=multimodal_cfg.heads, + layers=multimodal_cfg.layers, + ls_init_value=multimodal_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return text, multimodal_cfg + + +class CoCa(nn.Module): + def __init__( + self, + embed_dim, + multimodal_cfg: MultimodalCfg, + text_cfg: CLIPTextCfg, + vision_cfg: CLIPVisionCfg, + n_queries: int = 256, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + + + norm_layer = ( + LayerNormFp32 + if cast_dtype in (torch.float16, torch.bfloat16) + else LayerNorm + ) + + text = _build_input_dependent_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype, multimodal=False) + self.transformer = text.transformer + 
self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.register_buffer("attn_mask", text.attn_mask, persistent=False) + + self.cls_token = nn.Parameter(torch.randn(embed_dim)) + self.visual = _build_vision_tower( + embed_dim, vision_cfg, quick_gelu, cast_dtype + ) + + self.multimodal_decoder, multimodal_cfg = _build_input_dependent_text_tower( + embed_dim, multimodal_cfg, quick_gelu, cast_dtype + ) + + self.img_attn_pool = AttentionalPooler( + multimodal_cfg.width, multimodal_cfg.heads, n_queries=n_queries + 1 + ) + + self.img_attn_pool_norm = norm_layer(embed_dim) + + self.dim_latents = multimodal_cfg.dim_latents if multimodal_cfg.dim_latents else multimodal_cfg.width + self.to_text_latent = nn.Linear(embed_dim, self.dim_latents, bias=False) + + self.to_logits = nn.Sequential( + norm_layer(embed_dim), nn.Linear(embed_dim, self.vocab_size, bias=False) + ) + + # tie embedding weights and projection + self.to_logits[-1].weight = self.token_embedding.weight + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.transformer.grad_checkpointing = enable + self.multimodal_decoder.grad_checkpointing = enable + + def encode_image(self, images, normalize=True, return_tokens=False): + x = self.visual.conv1(images) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat( + [ + self.visual.class_embedding.to(x.dtype) + + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device + ), + x, + ], + dim=1, + ) # shape = [*, grid ** 2 + 1, width] + x = x + self.visual.positional_embedding.to(x.dtype) + x = self.visual.ln_pre(x) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.visual.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.visual.ln_post(x) + + if self.visual.proj is not None: + x = x @ self.visual.proj + + x = self.img_attn_pool(x, x) + x = self.img_attn_pool_norm(x) + + image_latent = x[:, 0] + image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent + + return (image_latent, x[:, 1:]) if return_tokens else image_latent + + def _repeat(self, t, N): + return t.reshape(1, 1, -1).repeat(N, 1, 1) + + def encode_text(self, text, normalize=True, return_tokens=False): + text = text[:, :-1] # make space for CLS token + cast_dtype = self.transformer.get_cast_dtype() + + # cls_mask = (text!=self.pad_id).unsqueeze(1) + # attn_mask = F.pad(cls_mask, (0, 1, text.shape[1], 0), value=True) + # attn_mask = F.pad(self.attn_mask, (0, 1, 0, 1), value=0.0) + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + x = torch.cat([x, self._repeat(self.cls_token, x.shape[0])], dim=1) + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), :] @ self.text_projection + + cls_emb = x[torch.arange(x.shape[0]), -1] + token_emb = x[torch.arange(x.shape[0]), :-1] + + cls_emb = self.ln_final(cls_emb) + text_latent = 
self.to_text_latent(cls_emb) + text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent + + return (text_latent, token_emb) if return_tokens else text_latent + + def forward(self, image, text): + labels = text[:, 1:] + + text_latents, text_tokens = self.encode_text(text, return_tokens=True) + image_latents, image_tokens = self.encode_image(image, return_tokens=True) + + text_tokens = self.multimodal_decoder(image_tokens, text_tokens) + logits = self.to_logits(text_tokens) + + return image_latents, text_latents, logits, labels, self.logit_scale.exp() diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index bf07009bc..c2297cf50 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -12,6 +12,8 @@ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ resize_pos_embed, get_cast_dtype +from .coca_model import CoCa +from .loss import ClipLoss, CoCaLoss from .openai import load_openai_model from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model from .transform import image_transform @@ -72,8 +74,7 @@ def get_model_config(model_name): def get_tokenizer(model_name): config = get_model_config(model_name) - tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize - return tokenizer + return HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize def load_state_dict(checkpoint_path: str, map_location='cpu'): @@ -152,7 +153,10 @@ def create_model( if custom_text: if 'hf_model_name' in model_cfg.get('text_cfg', {}): model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf - model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + if "coca" in model_name: + model = CoCa(**model_cfg, cast_dtype=cast_dtype) + else: + model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) else: model = CLIP(**model_cfg, cast_dtype=cast_dtype) @@ -188,6 +192,25 @@ def create_model( return model +def create_loss(args): + if "coca" in args.model.lower(): + return CoCaLoss( + caption_loss_weight=args.coca_caption_loss_weight, + clip_loss_weight=args.coca_contrastive_loss_weight, + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod) + return ClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod) + def create_model_and_transforms( model_name: str, diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index de31426df..9f26dd4f8 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -119,3 +119,42 @@ def forward(self, image_features, text_features, logit_scale): F.cross_entropy(logits_per_text, labels) ) / 2 return total_loss + + +class CoCaLoss(ClipLoss): + def __init__( + self, + caption_loss_weight, + clip_loss_weight, + pad_id=-100, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + ): + super().__init__( + local_loss=local_loss, + gather_with_grad=gather_with_grad, + cache_labels=cache_labels, + rank=rank, + world_size=world_size, + use_horovod=use_horovod + ) + + self.clip_loss_weight = clip_loss_weight + self.caption_loss_weight = caption_loss_weight + 
self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) + + def forward(self, image_features, text_features, logits, labels, logit_scale): + clip_loss = super().forward(image_features, text_features, logit_scale) + clip_loss = self.clip_loss_weight * clip_loss + + caption_loss = self.caption_loss( + logits.permute(0, 2, 1), + labels, + ) + caption_loss = caption_loss * self.caption_loss_weight + + return clip_loss + caption_loss diff --git a/src/open_clip/model_configs/coca_ViT-B-32.json b/src/open_clip/model_configs/coca_ViT-B-32.json new file mode 100644 index 000000000..3efdc24d8 --- /dev/null +++ b/src/open_clip/model_configs/coca_ViT-B-32.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/coca_base.json b/src/open_clip/model_configs/coca_base.json new file mode 100644 index 000000000..525203d4d --- /dev/null +++ b/src/open_clip/model_configs/coca_base.json @@ -0,0 +1,26 @@ +{ + "embed_dim": 768, + "multimodal_cfg": { + "width": 768, + "context_length": 76, + "mlp_ratio": 4, + "layers": 12, + "dim_head": 64, + "heads": 12, + "n_queries": 256 + }, + "vision_cfg": { + "image_size": 288, + "layers": 12, + "width": 768, + "patch_size": 18 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 64000, + "layers": 12, + "heads": 12, + "width": 768 + }, + "custom_text": "True" +} \ No newline at end of file diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 3f337f9b1..8a6de4846 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -124,12 +124,26 @@ def __init__( self.out_proj = nn.Linear(dim, dim) self.out_drop = nn.Dropout(proj_drop) - def forward(self, x, attn_mask: Optional[torch.Tensor] = None): - L, N, C = x.shape - q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) - q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + def forward(self, + q_x, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None + ): + + L, N, C = q_x.shape + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + + w_q, w_k, w_v = self.in_proj_weight.split(3, dim=0) + + q = F.linear(q_x, w_q, self.in_proj_bias) + k = F.linear(k_x, w_k, self.in_proj_bias) + v = F.linear(v_x, w_v, self.in_proj_bias) + + q = q.view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.view(L, N * self.num_heads, -1).transpose(0, 1) if self.logit_scale is not None: attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) @@ -159,6 +173,26 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None): x = self.out_drop(x) return x +class AttentionalPooler(nn.Module): + def __init__( + self, + d_model: int, + n_head: int = 8, + n_queries: int = 256, + ): + super().__init__() + self.query = nn.Parameter(torch.randn(n_queries, d_model)) + self.attn = nn.MultiheadAttention(d_model, n_head) + + def forward(self, k: 
torch.Tensor, v: torch.Tensor): + k, v = k.permute(1, 0, 2), v.permute(1, 0 ,2) # NLD -> LND + N = k.shape[1] + out = self.attn(self._repeat(self.query, N), k, v, need_weights=False)[0] + return out.permute(1, 0, 2) # LND -> NLD + + def _repeat(self, query, N): + return query.unsqueeze(1).repeat(1, N, 1) + class ResidualAttentionBlock(nn.Module): def __init__( @@ -169,12 +203,15 @@ def __init__( ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, + is_cross_attention: bool = False, ): super().__init__() self.ln_1 = norm_layer(d_model) self.attn = nn.MultiheadAttention(d_model, n_head) self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value else nn.Identity() + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) @@ -185,12 +222,33 @@ def __init__( ])) self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value else nn.Identity() - def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None - return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None + ): + + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask)) + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + return self.attn( + q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask + )[0] + + def forward(self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None + ): + + k_x = self.ln_1_kv(k_x) if k_x is not None else None + v_x = self.ln_1_kv(v_x) if v_x is not None else None + + x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) x = x + self.ls_2(self.mlp(self.ln_2(x))) return x @@ -208,10 +266,14 @@ def __init__( scale_heads: bool = False, scale_attn: bool = False, scale_fc: bool = False, + is_cross_attention: bool = False, ): super().__init__() self.ln_1 = norm_layer(d_model) + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) + self.attn = Attention( d_model, n_head, scaled_cosine=scale_cosine_attn, @@ -230,8 +292,20 @@ def __init__( ])) self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value else nn.Identity() - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None + ): + + k_x = self.ln_1_kv(k_x) if k_x is not None else None + v_x = self.ln_1_kv(v_x) if v_x is not None else None + + x = q_x + self.ls_1( + self.ln_attn(self.attn(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) + ) x = x + self.ls_2(self.mlp(self.ln_2(x))) return x @@ -264,12 +338,12 @@ def get_cast_dtype(self) -> torch.dtype: def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): for r in self.resblocks: if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(r, x, attn_mask) + # TODO: handle kwargs 
https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, None, None, attn_mask) else: x = r(x, attn_mask=attn_mask) return x - class VisionTransformer(nn.Module): def __init__( self, @@ -403,7 +477,6 @@ def forward(self, x: torch.Tensor): return x - class TextTransformer(nn.Module): def __init__( @@ -485,3 +558,86 @@ def forward(self, text): x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection return x + + +class MultimodalTransformer(Transformer): + def __init__( + self, + width: int, + layers: int, + heads: int, + context_length: int = 77, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_dim: int = 512, + ): + + super().__init__( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.context_length = context_length + self.cross_attn = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, is_cross_attention=True) + for _ in range(layers) + ]) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.ln_final = norm_layer(width) + + def init_parameters(self): + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + for block in self.transformer.cross_attn: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, image_embs, text_embs): + text_embs = text_embs.permute(1, 0, 2) # NLD -> LND + image_embs = image_embs.permute(1, 0, 2) # NLD -> LND + + for r, ca in zip(self.resblocks, self.cross_attn): + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + text_embs = checkpoint(r, text_embs, None, None, self.attn_mask) + text_embs = checkpoint(ca, text_embs, image_embs, image_embs, None) + else: + text_embs = r(text_embs, attn_mask=self.attn_mask) + text_embs = ca(text_embs, k_x=image_embs, v_x=image_embs) + + x = text_embs.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + return x diff --git a/src/training/main.py b/src/training/main.py index 26d4cc529..a2f9f5b42 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -24,7 +24,7 @@ except ImportError: hvd = None -from 
open_clip import create_model_and_transforms, trace_model, get_tokenizer +from open_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss from training.data import get_data from training.distributed import is_master, init_distributed_device, world_info_from_env from training.logger import setup_logging @@ -264,11 +264,13 @@ def main(args): evaluate(model, data, start_epoch, args, writer) return + loss = create_loss(args) + for epoch in range(start_epoch, args.epochs): if is_master(args): logging.info(f'Start epoch {epoch}') - train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer) + train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args, tb_writer=writer) completed_epoch = epoch + 1 if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')): diff --git a/src/training/params.py b/src/training/params.py index abc07dd50..a35ce2732 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -332,6 +332,18 @@ def parse_args(args): default=100, help="Log every n steps to tensorboard/console/wandb.", ) + parser.add_argument( + "--coca-caption-loss-weight", + type=float, + default=2.0, + help="Weight assigned to caption loss in CoCa." + ) + parser.add_argument( + "--coca-contrastive-loss-weight", + type=float, + default=1.0, + help="Weight assigned to contrastive loss when training CoCa." + ) args = parser.parse_args(args) diff --git a/src/training/train.py b/src/training/train.py index 83f4f6fa7..3899590c3 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -13,7 +13,7 @@ except ImportError: wandb = None -from open_clip import ClipLoss, get_cast_dtype +from open_clip import get_cast_dtype from .distributed import is_master from .zero_shot import zero_shot_eval from .precision import get_autocast @@ -51,20 +51,12 @@ def backward(total_loss, scaler): else: total_loss.backward() - -def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None): +def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args, tb_writer=None): device = torch.device(args.device) autocast = get_autocast(args.precision) cast_dtype = get_cast_dtype(args.precision) model.train() - loss = ClipLoss( - local_loss=args.local_loss, - gather_with_grad=args.gather_with_grad, - cache_labels=True, - rank=args.rank, - world_size=args.world_size, - use_horovod=args.horovod) data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch dataloader = data['train'].dataloader @@ -94,8 +86,9 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w if args.accum_freq == 1: with autocast(): - image_features, text_features, logit_scale = model(images, texts) - total_loss = loss(image_features, text_features, logit_scale) + model_out = model(images, texts) + logit_scale = model_out[-1] + total_loss = loss(*model_out) backward(total_loss, scaler) else: @@ -222,6 +215,7 @@ def evaluate(model, data, epoch, args, tb_writer=None): # FIXME this does not scale past small eval datasets # all_image_features @ all_text_features will blow up memory and compute very quickly cumulative_loss = 0.0 + cumulative_gen_loss = 0.0 all_image_features, all_text_features = [], [] with torch.no_grad(): for i, batch in enumerate(dataloader): @@ -230,7 +224,10 @@ def evaluate(model, data, epoch, args, tb_writer=None): texts = texts.to(device=device, non_blocking=True) with autocast(): - image_features, text_features, logit_scale = model(images, texts) + model_out 
= model(images, texts) + image_features = model_out[0] + text_features = model_out[1] + logit_scale = model_out[-1] # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly # however, system RAM is easily exceeded and compute time becomes problematic all_image_features.append(image_features.cpu()) @@ -246,22 +243,33 @@ def evaluate(model, data, epoch, args, tb_writer=None): F.cross_entropy(logits_per_text, labels) ) / 2 + gen_loss = maybe_compute_generative_loss(model_out) + cumulative_loss += total_loss * batch_size num_samples += batch_size if is_master(args) and (i % 100) == 0: logging.info( f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t" - f"Loss: {cumulative_loss / num_samples:.6f}\t") + f"Clip Loss: {cumulative_loss / num_samples:.6f}\t") + + if gen_loss is not None: + cumulative_gen_loss += gen_loss * batch_size + logging.info( + f"Generative Loss: {cumulative_gen_loss / num_samples:.6f}\t") - val_metrics = get_metrics( + + val_metrics = get_clip_metrics( image_features=torch.cat(all_image_features), text_features=torch.cat(all_text_features), logit_scale=logit_scale.cpu(), ) loss = cumulative_loss / num_samples metrics.update( - {**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples} + {**val_metrics, "clip_val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples} ) + if gen_loss is not None: + gen_loss = cumulative_gen_loss / num_samples + metrics.update({"val_generative_loss": gen_loss.item()}) if not metrics: return metrics @@ -288,7 +296,7 @@ def evaluate(model, data, epoch, args, tb_writer=None): return metrics -def get_metrics(image_features, text_features, logit_scale): +def get_clip_metrics(image_features, text_features, logit_scale): metrics = {} logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu() logits_per_text = logits_per_image.t().detach().cpu() @@ -306,3 +314,10 @@ def get_metrics(image_features, text_features, logit_scale): metrics[f"{name}_R@{k}"] = np.mean(preds < k) return metrics + + +def maybe_compute_generative_loss(model_out): + if len(model_out) > 3: + token_logits = model_out[2] + token_labels = model_out[3] + return F.cross_entropy(token_logits.permute(0, 2, 1), token_labels) diff --git a/tests/test_inference.py b/tests/test_inference.py index 0b3fdc1e0..a97b53aeb 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -19,6 +19,8 @@ 'ViT-bigG-14', 'ViT-e-14', 'mt5-xl-ViT-H-14', + 'coca_base', + 'coca_ViT-B-32' }) if 'OPEN_CLIP_TEST_REG_MODELS' in os.environ: diff --git a/tests/test_training_simple.py b/tests/test_training_simple.py index fe55b3328..5e1f649e7 100644 --- a/tests/test_training_simple.py +++ b/tests/test_training_simple.py @@ -24,6 +24,22 @@ def test_training(): '--model', 'RN50' ]) +@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") +def test_training_coca(): + main([ + '--save-frequency', '1', + '--zeroshot-frequency', '1', + '--dataset-type', "synthetic", + '--train-num-samples', '16', + '--warmup', '1', + '--batch-size', '4', + '--lr', '1e-3', + '--wd', '0.1', + '--epochs', '1', + '--workers', '2', + '--model', 'coca_ViT-B-32' + ]) + @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") def test_training_mt5(): main([ @@ -60,4 +76,4 @@ def test_training_unfreezing_vit(): '--model', 'ViT-B-32', '--lock-image', '--lock-image-unlocked-groups', '5' - ]) \ No newline at end of file + ]) From 29fa3322ab7c4da0eacb6f46401fc520f94b9627 Mon Sep 
17 00:00:00 2001 From: Romain Beaumont Date: Wed, 21 Dec 2022 18:30:05 +0100 Subject: [PATCH 02/30] Add coca to CI --- .github/workflows/ci.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 80deb83e4..1b7210b93 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,7 +3,8 @@ name: Continuous integration on: push: branches: - - main + - main + - coca paths-ignore: - '**.md' - 'CITATION.cff' From 911c737831c095005dcc492dec5668b12563cd09 Mon Sep 17 00:00:00 2001 From: Romain Beaumont Date: Wed, 21 Dec 2022 18:31:50 +0100 Subject: [PATCH 03/30] Add coca to CI pr --- .github/workflows/ci.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1b7210b93..7021e4a20 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,7 +13,8 @@ on: - 'docs/**' pull_request: branches: - - main + - main + - coca paths-ignore: - '**.md' - 'CITATION.cff' From b4881bcc13d6e773a68c4476d0be44ac71afd7f1 Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Wed, 21 Dec 2022 22:58:35 +0100 Subject: [PATCH 04/30] simplify encode_iamge (#313) Co-authored-by: Romain Beaumont --- src/open_clip/coca_model.py | 25 +++++-------------------- src/open_clip/transformer.py | 18 ++++++++++-------- 2 files changed, 15 insertions(+), 28 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 5d965889b..fe750c019 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -126,27 +126,12 @@ def set_grad_checkpointing(self, enable=True): self.multimodal_decoder.grad_checkpointing = enable def encode_image(self, images, normalize=True, return_tokens=False): - x = self.visual.conv1(images) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat( - [ - self.visual.class_embedding.to(x.dtype) - + torch.zeros( - x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device - ), - x, - ], - dim=1, - ) # shape = [*, grid ** 2 + 1, width] - x = x + self.visual.positional_embedding.to(x.dtype) - x = self.visual.ln_pre(x) - x = x.permute(1, 0, 2) # NLD -> LND - x = self.visual.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.visual.ln_post(x) + x = self.visual(images, output_tokens=True) + + if hasattr(self.visual, "ln_post"): + x = self.visual.ln_post(x) - if self.visual.proj is not None: + if hasattr(self.visual, "proj") and self.visual.proj is not None: x = x @ self.visual.proj x = self.img_attn_pool(x, x) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 8a6de4846..e2d5d50bf 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -448,7 +448,7 @@ def init_parameters(self): def set_grad_checkpointing(self, enable=True): self.transformer.grad_checkpointing = enable - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor, output_tokens: bool = False): x = self.conv1(x) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] @@ -465,15 +465,17 @@ def forward(self, x: torch.Tensor): x = self.transformer(x) x = x.permute(1, 0, 2) # LND -> NLD - if self.global_average_pool: - x = x.mean(dim=1) - else: - x = x[:, 0] - x = self.ln_post(x) + if not output_tokens: + if self.global_average_pool: + x = 
x.mean(dim=1) + else: + x = x[:, 0] + + x = self.ln_post(x) - if self.proj is not None: - x = x @ self.proj + if self.proj is not None: + x = x @ self.proj return x From 50bc5991a82f601e14c155173e7c1ca536cd6af8 Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Wed, 21 Dec 2022 23:00:43 +0100 Subject: [PATCH 05/30] Add cls mask (#312) * buil_cls_mask * add cls_mask to encode_text * add model properties Co-authored-by: Romain Beaumont Co-authored-by: gpucce --- src/open_clip/coca_model.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index fe750c019..48723d601 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -96,6 +96,7 @@ def __init__( self.visual = _build_vision_tower( embed_dim, vision_cfg, quick_gelu, cast_dtype ) + self.heads = text_cfg["heads"] self.multimodal_decoder, multimodal_cfg = _build_input_dependent_text_tower( embed_dim, multimodal_cfg, quick_gelu, cast_dtype @@ -118,6 +119,7 @@ def __init__( self.to_logits[-1].weight = self.token_embedding.weight self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.pad_id = 0 @torch.jit.ignore def set_grad_checkpointing(self, enable=True): @@ -145,19 +147,29 @@ def encode_image(self, images, normalize=True, return_tokens=False): def _repeat(self, t, N): return t.reshape(1, 1, -1).repeat(N, 1, 1) + def _build_cls_mask(self, text, cast_dtype): + cls_mask = (text != self.pad_id).unsqueeze(1) + cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True) + additive_mask = torch.empty(*cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) + additive_mask.fill_(0) + additive_mask.masked_fill_(~cls_mask, float("-inf")) + additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) + return additive_mask + def encode_text(self, text, normalize=True, return_tokens=False): text = text[:, :-1] # make space for CLS token cast_dtype = self.transformer.get_cast_dtype() - # cls_mask = (text!=self.pad_id).unsqueeze(1) - # attn_mask = F.pad(cls_mask, (0, 1, text.shape[1], 0), value=True) - # attn_mask = F.pad(self.attn_mask, (0, 1, 0, 1), value=0.0) + attn_mask = self.attn_mask[None, :].expand( + text.shape[0] * self.heads, *self.attn_mask.shape + ) + cls_mask = self._build_cls_mask(text, cast_dtype) x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] x = torch.cat([x, self._repeat(self.cls_token, x.shape[0])], dim=1) x = x + self.positional_embedding.to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x, attn_mask=self.attn_mask) + x = self.transformer(x, attn_mask=attn_mask + cls_mask) x = x.permute(1, 0, 2) # LND -> NLD # x.shape = [batch_size, n_ctx, transformer.width] From 279e08813a82d5e13ac63cfd222730f3ba93335a Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Thu, 22 Dec 2022 14:37:25 +0100 Subject: [PATCH 06/30] Ignore pad tokens in captioning loss (#316) * add ignore_index * just need to pick right index Co-authored-by: gpucce --- src/open_clip/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 9f26dd4f8..555cf545d 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -126,7 +126,7 @@ def __init__( self, caption_loss_weight, clip_loss_weight, - pad_id=-100, + pad_id=0, # pad_token for open_clip custom tokenizer local_loss=False, gather_with_grad=False, cache_labels=False, From dee1ea50352919136f759980cefd18fa8bd5e32f Mon Sep 17 00:00:00 
2001 From: Giovanni Puccetti Date: Thu, 22 Dec 2022 16:45:50 +0100 Subject: [PATCH 07/30] add `generate` to coca model (#314) * add initial generative support * make generation context_length independend * remove kwargs * last positional embeddings for CLS * typo * fix mask len * add comment * remove unused args * simpler logic for input shorter than context length Co-authored-by: gpucce --- src/open_clip/coca_model.py | 83 +++++++++++++++++++++++++++---- src/open_clip/generation_utils.py | 38 ++++++++++++++ src/open_clip/transformer.py | 13 ++--- 3 files changed, 118 insertions(+), 16 deletions(-) create mode 100644 src/open_clip/generation_utils.py diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 48723d601..711d89f55 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -14,7 +14,7 @@ AttentionalPooler, ) from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower - +from .generation_utils import top_a, top_k, top_p @dataclass class MultimodalCfg(CLIPTextCfg): @@ -91,6 +91,7 @@ def __init__( self.ln_final = text.ln_final self.text_projection = text.text_projection self.register_buffer("attn_mask", text.attn_mask, persistent=False) + self.context_length = self.positional_embedding.shape[0] - 1 self.cls_token = nn.Parameter(torch.randn(embed_dim)) self.visual = _build_vision_tower( @@ -159,21 +160,25 @@ def _build_cls_mask(self, text, cast_dtype): def encode_text(self, text, normalize=True, return_tokens=False): text = text[:, :-1] # make space for CLS token cast_dtype = self.transformer.get_cast_dtype() - - attn_mask = self.attn_mask[None, :].expand( - text.shape[0] * self.heads, *self.attn_mask.shape + seq_len = text.shape[1] + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + x = torch.cat( + [ + x + self.positional_embedding[:seq_len, :].to(cast_dtype), + self._repeat(self.cls_token + self.positional_embedding[-1, :], x.shape[0]) + ], + dim=1 + ) + seq_len += 1 # seq is 1 longer as we added CLS + attn_mask = self.attn_mask[None, :seq_len, :seq_len].expand( + text.shape[0] * self.heads, seq_len, seq_len ) cls_mask = self._build_cls_mask(text, cast_dtype) - - x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] - x = torch.cat([x, self._repeat(self.cls_token, x.shape[0])], dim=1) - x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x, attn_mask=attn_mask + cls_mask) x = x.permute(1, 0, 2) # LND -> NLD - # x.shape = [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) x = x[torch.arange(x.shape[0]), :] @ self.text_projection cls_emb = x[torch.arange(x.shape[0]), -1] @@ -195,3 +200,61 @@ def forward(self, image, text): logits = self.to_logits(text_tokens) return image_latents, text_latents, logits, labels, self.logit_scale.exp() + + def generate( + self, + image, + text, + seq_len, + max_seq_len=None, + mask_prob = 0.0, + temperature = 1., + filter_logits_fn = top_k, + filter_thres = 0.9, + min_p_pow = 2.0, + min_p_ratio = 0.02, + ): + + assert mask_prob < 1, "mask_prob must be smaller than 1." 
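For orientation, a minimal usage sketch of the generate() method being added here (not part of the patch): it assumes the coca_ViT-B-32 config from the first commit, randomly initialized weights, and that each word of the prompt maps to a single token.

import torch
import open_clip

model, _, _ = open_clip.create_model_and_transforms("coca_ViT-B-32")
tokenize = open_clip.get_tokenizer("coca_ViT-B-32")
model.eval()

image = torch.randn(1, 3, 224, 224)         # stand-in for a preprocessed image batch
prompt = tokenize(["a photo of a"])[:, :5]  # keep sot + 4 word tokens, drop eot/padding (assumption)

with torch.no_grad():
    # sample 20 new tokens autoregressively; top_k filtering is the default filter_logits_fn
    token_ids = model.generate(image, prompt, seq_len=20)

print(token_ids.shape)  # (1, 20): only the newly sampled ids are returned, the prompt is stripped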
+ + if max_seq_len is None: + max_seq_len = self.context_length + + was_training = self.training + num_dims = len(text.shape) + + if num_dims == 1: + text = text[None, :] + + _, t = text.shape + self.eval() + out = text + + for _ in range(seq_len): + x = out[:, -max_seq_len:] + + # TODO: adjust for dict output + logits = self(image, x)[2][:, -1] + + if filter_logits_fn in {top_k, top_p}: + filtered_logits = filter_logits_fn(logits, thres=filter_thres) + probs = F.softmax(filtered_logits / temperature, dim=-1) + + elif filter_logits_fn is top_a: + filtered_logits = filter_logits_fn( + logits, min_p_pow=min_p_pow, min_p_ratio=min_p_ratio + ) + probs = F.softmax(filtered_logits / temperature, dim=-1) + + sample = torch.multinomial(probs, 1) + + out = torch.cat((out, sample), dim=-1) + + + out = out[:, t:] + + if num_dims == 1: + out = out.squeeze(0) + + self.train(was_training) + return out diff --git a/src/open_clip/generation_utils.py b/src/open_clip/generation_utils.py new file mode 100644 index 000000000..82b7fdefb --- /dev/null +++ b/src/open_clip/generation_utils.py @@ -0,0 +1,38 @@ +from math import ceil +import torch +from torch import nn +import torch.nn.functional as F + +def exists(val): + return val is not None + +# nucleus + +def top_p(logits, thres = 0.9): + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + sorted_indices_to_remove = cum_probs > (1 - thres) + sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() + sorted_indices_to_remove[:, 0] = 0 + + sorted_logits[sorted_indices_to_remove] = float('-inf') + return sorted_logits.scatter(1, sorted_indices, sorted_logits) + +# topk + +def top_k(logits, thres = 0.9): + k = ceil((1 - thres) * logits.shape[-1]) + val, ind = torch.topk(logits, k) + probs = torch.full_like(logits, float('-inf')) + probs.scatter_(1, ind, val) + return probs + +# top_a + +def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02): + probs = F.softmax(logits, dim=-1) + limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio + logits[probs < limit] = float('-inf') + logits[probs >= limit] = 1 + return logits \ No newline at end of file diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index e2d5d50bf..5c3dd48b1 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -627,17 +627,18 @@ def build_attention_mask(self): return mask def forward(self, image_embs, text_embs): - text_embs = text_embs.permute(1, 0, 2) # NLD -> LND + text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq image_embs = image_embs.permute(1, 0, 2) # NLD -> LND + seq_len = text_embs.shape[0] - for r, ca in zip(self.resblocks, self.cross_attn): + for resblock, cross_attn in zip(self.resblocks, self.cross_attn): if self.grad_checkpointing and not torch.jit.is_scripting(): # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 - text_embs = checkpoint(r, text_embs, None, None, self.attn_mask) - text_embs = checkpoint(ca, text_embs, image_embs, image_embs, None) + text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len]) + text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None) else: - text_embs = r(text_embs, attn_mask=self.attn_mask) - text_embs = ca(text_embs, k_x=image_embs, v_x=image_embs) + text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len]) + text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) x = 
text_embs.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x) From 30a73d41812579607a426fae90ec4adf69152e3a Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Fri, 6 Jan 2023 01:16:05 +0100 Subject: [PATCH 08/30] use `TextEncoder` in coca `encode_image` (#321) * use self.text in encode image * unused var * rever aAtention and CustoResidualAttentionBlock * remove whiteline * add dict output * bintegrate self.text attributes * HF compatibility * better config and minor fixes * clean * remove eembed_cls option from HF * use cls_token_position * fix cls masking * resize labels * text -> self.text * split loss logging * add total loss * minor logs formatting * fix generate * simpler logic * disentangle proj for HF too * adjust config * only norm cls * move attn_pool to VisionTransformer * adjust coca_base config * fix grad checkpointing in MultimodalTransformer Co-authored-by: gpucce Co-authored-by: iejMac --- src/open_clip/coca_model.py | 192 +++++++----------- src/open_clip/hf_model.py | 27 ++- src/open_clip/loss.py | 12 +- src/open_clip/model.py | 26 ++- .../model_configs/coca_ViT-B-32.json | 13 +- src/open_clip/model_configs/coca_base.json | 12 +- .../model_configs/coca_roberta-ViT-B-32.json | 23 +++ src/open_clip/transformer.py | 172 +++++++++------- src/training/train.py | 72 ++++--- tests/test_inference.py | 3 +- 10 files changed, 306 insertions(+), 246 deletions(-) create mode 100644 src/open_clip/model_configs/coca_roberta-ViT-B-32.json diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 711d89f55..62bdbd549 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -11,7 +11,6 @@ LayerNorm, QuickGELU, MultimodalTransformer, - AttentionalPooler, ) from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower from .generation_utils import top_a, top_k, top_p @@ -22,34 +21,50 @@ class MultimodalCfg(CLIPTextCfg): dim_head: int = 64 heads: int = 8 n_queries: int = 256 - dim_latents: int = None + attn_pooler_heads: int = 8 + latent_dim: int = 512 +class CoCaEncoderDecoder(nn.Module): + def __init__(self, encoder, decoder) -> None: + super().__init__() + self.encoder = encoder + self.decoder = decoder + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.encoder.set_grad_checkpointing(enable) + self.decoder.set_grad_checkpointing(enable) -def _build_input_dependent_text_tower( - embed_dim: int, - multimodal_cfg: MultimodalCfg, +def _build_encoder_decoder_tower( + embed_dim, + multimodal_cfg, + text_cfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, - multimodal:bool = True ): - if not multimodal: - return _build_text_tower( - embed_dim=embed_dim, - text_cfg=multimodal_cfg, - quick_gelu=quick_gelu, - cast_dtype=cast_dtype - ) - - if isinstance(multimodal_cfg, dict): - multimodal_cfg = MultimodalCfg(**multimodal_cfg) + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg + + encoder = _build_text_tower( + multimodal_cfg.latent_dim, + text_cfg=text_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype + ) + + vocab_size = ( + encoder.config.vocab_size # for hf models + if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None + else multimodal_cfg.vocab_size + ) act_layer = QuickGELU if quick_gelu else nn.GELU norm_layer = ( LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm ) - text = 
MultimodalTransformer( + decoder = MultimodalTransformer( context_length=multimodal_cfg.context_length, width=multimodal_cfg.width, heads=multimodal_cfg.heads, @@ -59,10 +74,9 @@ def _build_input_dependent_text_tower( act_layer=act_layer, norm_layer=norm_layer, ) - - return text, multimodal_cfg - - + + return CoCaEncoderDecoder(encoder, decoder), multimodal_cfg, vocab_size + class CoCa(nn.Module): def __init__( self, @@ -70,12 +84,14 @@ def __init__( multimodal_cfg: MultimodalCfg, text_cfg: CLIPTextCfg, vision_cfg: CLIPVisionCfg, - n_queries: int = 256, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, + pad_id: int = 0 ): super().__init__() - + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg + vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg norm_layer = ( LayerNormFp32 @@ -83,130 +99,63 @@ def __init__( else LayerNorm ) - text = _build_input_dependent_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype, multimodal=False) - self.transformer = text.transformer - self.vocab_size = text.vocab_size - self.token_embedding = text.token_embedding - self.positional_embedding = text.positional_embedding - self.ln_final = text.ln_final - self.text_projection = text.text_projection - self.register_buffer("attn_mask", text.attn_mask, persistent=False) - self.context_length = self.positional_embedding.shape[0] - 1 - - self.cls_token = nn.Parameter(torch.randn(embed_dim)) - self.visual = _build_vision_tower( - embed_dim, vision_cfg, quick_gelu, cast_dtype - ) - self.heads = text_cfg["heads"] - - self.multimodal_decoder, multimodal_cfg = _build_input_dependent_text_tower( - embed_dim, multimodal_cfg, quick_gelu, cast_dtype + self.text, multimodal_cfg, vocab_size = _build_encoder_decoder_tower( + embed_dim, multimodal_cfg, text_cfg, quick_gelu, cast_dtype ) - - self.img_attn_pool = AttentionalPooler( - multimodal_cfg.width, multimodal_cfg.heads, n_queries=n_queries + 1 + self.visual = _build_vision_tower( + multimodal_cfg.latent_dim, vision_cfg, quick_gelu, cast_dtype ) - self.img_attn_pool_norm = norm_layer(embed_dim) - - self.dim_latents = multimodal_cfg.dim_latents if multimodal_cfg.dim_latents else multimodal_cfg.width - self.to_text_latent = nn.Linear(embed_dim, self.dim_latents, bias=False) - self.to_logits = nn.Sequential( - norm_layer(embed_dim), nn.Linear(embed_dim, self.vocab_size, bias=False) + norm_layer(multimodal_cfg.width), nn.Linear(multimodal_cfg.width, vocab_size, bias=False) ) - # tie embedding weights and projection - self.to_logits[-1].weight = self.token_embedding.weight - self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - self.pad_id = 0 + self.pad_id = pad_id @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.visual.set_grad_checkpointing(enable) - self.transformer.grad_checkpointing = enable - self.multimodal_decoder.grad_checkpointing = enable + self.text.set_grad_checkpointing(enable) def encode_image(self, images, normalize=True, return_tokens=False): - x = self.visual(images, output_tokens=True) - - if hasattr(self.visual, "ln_post"): - x = self.visual.ln_post(x) - - if hasattr(self.visual, "proj") and self.visual.proj is not None: - x = x @ self.visual.proj - - x = self.img_attn_pool(x, x) - x = self.img_attn_pool_norm(x) - - image_latent = x[:, 0] + image_latent, tokens_embs = self.visual(images, output_tokens=True) image_latent = 
F.normalize(image_latent, dim=-1) if normalize else image_latent - - return (image_latent, x[:, 1:]) if return_tokens else image_latent - - def _repeat(self, t, N): - return t.reshape(1, 1, -1).repeat(N, 1, 1) - - def _build_cls_mask(self, text, cast_dtype): - cls_mask = (text != self.pad_id).unsqueeze(1) - cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True) - additive_mask = torch.empty(*cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) - additive_mask.fill_(0) - additive_mask.masked_fill_(~cls_mask, float("-inf")) - additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) - return additive_mask + return (image_latent, tokens_embs) if return_tokens else image_latent def encode_text(self, text, normalize=True, return_tokens=False): text = text[:, :-1] # make space for CLS token - cast_dtype = self.transformer.get_cast_dtype() - seq_len = text.shape[1] - x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] - x = torch.cat( - [ - x + self.positional_embedding[:seq_len, :].to(cast_dtype), - self._repeat(self.cls_token + self.positional_embedding[-1, :], x.shape[0]) - ], - dim=1 - ) - seq_len += 1 # seq is 1 longer as we added CLS - attn_mask = self.attn_mask[None, :seq_len, :seq_len].expand( - text.shape[0] * self.heads, seq_len, seq_len - ) - cls_mask = self._build_cls_mask(text, cast_dtype) - - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x, attn_mask=attn_mask + cls_mask) - x = x.permute(1, 0, 2) # LND -> NLD - - x = x[torch.arange(x.shape[0]), :] @ self.text_projection - - cls_emb = x[torch.arange(x.shape[0]), -1] - token_emb = x[torch.arange(x.shape[0]), :-1] - - cls_emb = self.ln_final(cls_emb) - text_latent = self.to_text_latent(cls_emb) + text_latent, token_emb = self.text.encoder(text, output_tokens=True) text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent - return (text_latent, token_emb) if return_tokens else text_latent - def forward(self, image, text): - labels = text[:, 1:] - - text_latents, text_tokens = self.encode_text(text, return_tokens=True) - image_latents, image_tokens = self.encode_image(image, return_tokens=True) + def forward(self, image, text, output_dict=False): - text_tokens = self.multimodal_decoder(image_tokens, text_tokens) - logits = self.to_logits(text_tokens) - - return image_latents, text_latents, logits, labels, self.logit_scale.exp() + text_latent, token_embs = self.encode_text(text, return_tokens=True) + image_latent, image_embs = self.encode_image(image, return_tokens=True) + + # TODO: add assertion to avoid bugs? + labels = text[:, -token_embs.shape[1]:] + + token_embs = self.text.decoder(image_embs, token_embs) + logits = self.to_logits(token_embs) + if output_dict: + return { + "image_features":image_latent, + "text_features":text_latent, + "logits":logits, + "labels":labels, + "logit_scale":self.logit_scale.exp() + } + + return image_latent, text_latent, logits, labels, self.logit_scale.exp() def generate( self, image, text, seq_len, - max_seq_len=None, + max_seq_len=77, mask_prob = 0.0, temperature = 1., filter_logits_fn = top_k, @@ -217,9 +166,6 @@ def generate( assert mask_prob < 1, "mask_prob must be smaller than 1." 
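As a quick sketch of the dict-style forward/loss interface this commit introduces (illustrative only, not part of the patch; the dummy images/texts and loss weights below are assumptions, and the model is randomly initialized):

import torch
import open_clip

model, _, _ = open_clip.create_model_and_transforms("coca_ViT-B-32")
tokenize = open_clip.get_tokenizer("coca_ViT-B-32")

images = torch.randn(2, 3, 224, 224)  # dummy image batch
texts = tokenize(["a dog", "a cat"])

out = model(images, texts, output_dict=True)
# out keys: image_features, text_features, logits, labels, logit_scale

loss_fn = open_clip.CoCaLoss(caption_loss_weight=2.0, clip_loss_weight=1.0)
losses = loss_fn(**out, output_dict=True)  # {"contrastive_loss": ..., "caption_loss": ...}, already weighted
sum(losses.values()).backward()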
- if max_seq_len is None: - max_seq_len = self.context_length - was_training = self.training num_dims = len(text.shape) diff --git a/src/open_clip/hf_model.py b/src/open_clip/hf_model.py index b9f1103d6..4ffd3b849 100644 --- a/src/open_clip/hf_model.py +++ b/src/open_clip/hf_model.py @@ -79,7 +79,6 @@ def forward(self, x: BaseModelOutput, attention_mask: TensorType): return x.last_hidden_state[:, self.cls_token_position, :] - class HFTextEncoder(nn.Module): """HuggingFace model adapter""" @@ -90,7 +89,8 @@ def __init__( config: PretrainedConfig = None, pooler_type: str = None, proj: str = None, - pretrained: bool = True): + pretrained: bool = True + ): super().__init__() self.output_dim = output_dim @@ -113,11 +113,10 @@ def __init__( else: self.config = config self.transformer = AutoModel.from_config(config) - if pooler_type is None: # get default arch pooler - self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]() - else: - self.pooler = _POOLERS[pooler_type]() + pooler_type = (arch_dict[self.config.model_type]["pooler"]) + + self.pooler = _POOLERS[pooler_type]() d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"]) if (d_model == output_dim) and (proj is None): # do we always need a proj? @@ -132,12 +131,22 @@ def __init__( nn.Linear(hidden_size, output_dim, bias=False), ) - def forward(self, x: TensorType) -> TensorType: + def forward(self, x: TensorType, output_tokens=False) -> TensorType: attn_mask = (x != self.config.pad_token_id).long() out = self.transformer(input_ids=x, attention_mask=attn_mask) pooled_out = self.pooler(out, attn_mask) - - return self.proj(pooled_out) + projected = self.proj(pooled_out) + + if output_tokens: + seq_len = out.last_hidden_state.shape[1] + tokens = ( + out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :] + if type(self.pooler) == ClsPooler + else out.last_hidden_state + ) + return projected, tokens + + return projected def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): if not unlocked_layers: # full freezing diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 555cf545d..a838863fb 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -85,7 +85,7 @@ def __init__( self.prev_num_logits = 0 self.labels = {} - def forward(self, image_features, text_features, logit_scale): + def forward(self, image_features, text_features, logit_scale, output_dict=False): device = image_features.device if self.world_size > 1: all_image_features, all_text_features = gather_features( @@ -118,7 +118,8 @@ def forward(self, image_features, text_features, logit_scale): F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels) ) / 2 - return total_loss + + return {"contrastive_loss": total_loss} if output_dict else total_loss class CoCaLoss(ClipLoss): @@ -147,7 +148,7 @@ def __init__( self.caption_loss_weight = caption_loss_weight self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) - def forward(self, image_features, text_features, logits, labels, logit_scale): + def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): clip_loss = super().forward(image_features, text_features, logit_scale) clip_loss = self.clip_loss_weight * clip_loss @@ -157,4 +158,7 @@ def forward(self, image_features, text_features, logits, labels, logit_scale): ) caption_loss = caption_loss * self.caption_loss_weight - return clip_loss + caption_loss + if output_dict: + return 
{"contrastive_loss":clip_loss, "caption_loss":caption_loss} + + return clip_loss, caption_loss diff --git a/src/open_clip/model.py b/src/open_clip/model.py index 6b1abbf16..5fc3a16e5 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -31,6 +31,9 @@ class CLIPVisionCfg: ls_init_value: Optional[float] = None # layer scale initial value patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) + attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer + n_queries: int = 256 # n_queries for attentional pooler + attn_pooler_heads: int = 8 # n heads for attentional_pooling timm_model_name: str = None # a valid model name overrides layers, width, patch_size timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') @@ -51,6 +54,8 @@ class CLIPTextCfg: hf_model_pretrained: bool = True proj: str = 'mlp' pooler_type: str = 'mean_pooler' + embed_cls: bool = False + pad_id: int = 0 def get_cast_dtype(precision: str): @@ -109,6 +114,9 @@ def _build_vision_tower( ls_init_value=vision_cfg.ls_init_value, patch_dropout=vision_cfg.patch_dropout, global_average_pool=vision_cfg.global_average_pool, + attentional_pool=vision_cfg.attentional_pool, + n_queries=vision_cfg.n_queries, + attn_pooler_heads=vision_cfg.attn_pooler_heads, output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, @@ -146,6 +154,8 @@ def _build_text_tower( layers=text_cfg.layers, ls_init_value=text_cfg.ls_init_value, output_dim=embed_dim, + embed_cls=text_cfg.embed_cls, + pad_id=text_cfg.pad_id, act_layer=act_layer, norm_layer=norm_layer, ) @@ -202,9 +212,15 @@ def encode_text(self, text, normalize: bool = False): x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection return F.normalize(x, dim=-1) if normalize else x - def forward(self, image, text): + def forward(self, image, text, output_dict=False): image_features = self.encode_image(image, normalize=True) text_features = self.encode_text(text, normalize=True) + if output_dict: + return { + "image_features":image_features, + "text_features":text_features, + "logit_scale":self.logit_scale.exp() + } return image_features, text_features, self.logit_scale.exp() @@ -242,9 +258,15 @@ def encode_text(self, text, normalize: bool = False): features = self.text(text) return F.normalize(features, dim=-1) if normalize else features - def forward(self, image, text): + def forward(self, image, text, output_dict=False): image_features = self.encode_image(image, normalize=True) text_features = self.encode_text(text, normalize=True) + if output_dict: + return { + "image_features":image_features, + "text_features":text_features, + "logit_scale":self.logit_scale.exp() + } return image_features, text_features, self.logit_scale.exp() diff --git a/src/open_clip/model_configs/coca_ViT-B-32.json b/src/open_clip/model_configs/coca_ViT-B-32.json index 3efdc24d8..8e13625c8 100644 --- a/src/open_clip/model_configs/coca_ViT-B-32.json +++ b/src/open_clip/model_configs/coca_ViT-B-32.json @@ -4,21 +4,26 @@ "image_size": 224, "layers": 12, "width": 768, - "patch_size": 32 + "patch_size": 32, + "attentional_pool": true, + 
"attn_pooler_heads": 8 }, "text_cfg": { - "context_length": 77, + "context_length": 76, "vocab_size": 49408, "width": 512, "heads": 8, - "layers": 12 + "layers": 12, + "embed_cls": true }, "multimodal_cfg": { "context_length": 76, "vocab_size": 49408, "width": 512, "heads": 8, - "layers": 12 + "layers": 12, + "latent_dim": 512, + "attn_pooler_heads": 8 }, "custom_text": true } \ No newline at end of file diff --git a/src/open_clip/model_configs/coca_base.json b/src/open_clip/model_configs/coca_base.json index 525203d4d..d26b61087 100644 --- a/src/open_clip/model_configs/coca_base.json +++ b/src/open_clip/model_configs/coca_base.json @@ -1,13 +1,16 @@ { - "embed_dim": 768, + "embed_dim": 512, "multimodal_cfg": { "width": 768, "context_length": 76, + "vocab_size": 64000, "mlp_ratio": 4, "layers": 12, "dim_head": 64, "heads": 12, - "n_queries": 256 + "n_queries": 256, + "latent_dim": 512, + "attn_pooler_heads": 8 }, "vision_cfg": { "image_size": 288, @@ -16,11 +19,12 @@ "patch_size": 18 }, "text_cfg": { - "context_length": 77, + "context_length": 76, "vocab_size": 64000, "layers": 12, "heads": 12, - "width": 768 + "width": 768, + "embed_cls": true }, "custom_text": "True" } \ No newline at end of file diff --git a/src/open_clip/model_configs/coca_roberta-ViT-B-32.json b/src/open_clip/model_configs/coca_roberta-ViT-B-32.json new file mode 100644 index 000000000..6b706c861 --- /dev/null +++ b/src/open_clip/model_configs/coca_roberta-ViT-B-32.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "proj": "linear", + "width": 768 + }, + "multimodal_cfg": { + "context_length": 76, + "width": 768, + "heads": 8, + "layers": 12, + "latent_dim": 512 + }, + "custom_text": true +} diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 5c3dd48b1..068f7875f 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -124,26 +124,12 @@ def __init__( self.out_proj = nn.Linear(dim, dim) self.out_drop = nn.Dropout(proj_drop) - def forward(self, - q_x, - k_x: Optional[torch.Tensor] = None, - v_x: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None - ): - - L, N, C = q_x.shape - k_x = k_x if k_x is not None else q_x - v_x = v_x if v_x is not None else q_x - - w_q, w_k, w_v = self.in_proj_weight.split(3, dim=0) - - q = F.linear(q_x, w_q, self.in_proj_bias) - k = F.linear(k_x, w_k, self.in_proj_bias) - v = F.linear(v_x, w_v, self.in_proj_bias) - - q = q.view(L, N * self.num_heads, -1).transpose(0, 1) - k = k.view(L, N * self.num_heads, -1).transpose(0, 1) - v = v.view(L, N * self.num_heads, -1).transpose(0, 1) + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) if self.logit_scale is not None: attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) @@ -177,17 +163,22 @@ class AttentionalPooler(nn.Module): def __init__( self, d_model: int, + context_dim: int, n_head: int = 8, n_queries: int = 256, + norm_layer: Callable = LayerNorm ): super().__init__() self.query = nn.Parameter(torch.randn(n_queries, d_model)) - self.attn = 
nn.MultiheadAttention(d_model, n_head) + self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim) + self.ln_q = norm_layer(d_model) + self.ln_k = norm_layer(context_dim) - def forward(self, k: torch.Tensor, v: torch.Tensor): - k, v = k.permute(1, 0, 2), v.permute(1, 0 ,2) # NLD -> LND - N = k.shape[1] - out = self.attn(self._repeat(self.query, N), k, v, need_weights=False)[0] + def forward(self, x: torch.Tensor): + x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0] return out.permute(1, 0, 2) # LND -> NLD def _repeat(self, query, N): @@ -266,21 +257,17 @@ def __init__( scale_heads: bool = False, scale_attn: bool = False, scale_fc: bool = False, - is_cross_attention: bool = False, ): super().__init__() self.ln_1 = norm_layer(d_model) - if is_cross_attention: - self.ln_1_kv = norm_layer(d_model) - self.attn = Attention( d_model, n_head, scaled_cosine=scale_cosine_attn, scale_heads=scale_heads, ) self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() - self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value else nn.Identity() + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) @@ -290,22 +277,10 @@ def __init__( ("gelu", act_layer()), ("c_proj", nn.Linear(mlp_width, d_model)) ])) - self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value else nn.Identity() + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() - def forward( - self, - q_x: torch.Tensor, - k_x: Optional[torch.Tensor] = None, - v_x: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None - ): - - k_x = self.ln_1_kv(k_x) if k_x is not None else None - v_x = self.ln_1_kv(v_x) if v_x is not None else None - - x = q_x + self.ls_1( - self.ln_attn(self.attn(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) - ) + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) x = x + self.ls_2(self.mlp(self.ln_2(x))) return x @@ -355,6 +330,9 @@ def __init__( mlp_ratio: float, ls_init_value: float = None, global_average_pool: bool = False, + attentional_pool: bool = False, + n_queries: int = 256, + attn_pooler_heads: int = 8, output_dim: int = 512, patch_dropout: float = 0., act_layer: Callable = nn.GELU, @@ -386,8 +364,13 @@ def __init__( ) self.global_average_pool = global_average_pool - self.ln_post = norm_layer(width) - self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + if attentional_pool: + self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries) + self.ln_post = norm_layer(output_dim) + self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim)) + else: + self.ln_post = norm_layer(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) self.init_parameters() @@ -466,18 +449,22 @@ def forward(self, x: torch.Tensor, output_tokens: bool = False): x = x.permute(1, 0, 2) # LND -> NLD - if not output_tokens: - if self.global_average_pool: - x = x.mean(dim=1) - else: - x = x[:, 0] - + if hasattr(self, "attn_pool"): + x = self.attn_pool(x) x = self.ln_post(x) - if self.proj is not None: - x = x @ self.proj + if self.global_average_pool: + pooled, tokens = x.mean(dim=1), x + else: + pooled, tokens = x[:, 0], x[:, 
1:] - return x + if not hasattr(self, "attn_pool"): + pooled = self.ln_post(pooled) + + if self.proj is not None: + pooled = pooled @ self.proj + + return (pooled, tokens) if output_tokens else pooled class TextTransformer(nn.Module): @@ -492,13 +479,24 @@ def __init__( output_dim: int = 512, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, + embed_cls: bool = False, + pad_id: int = 0, ): super().__init__() self.context_length = context_length self.vocab_size = vocab_size self.width = width self.output_dim = output_dim - + + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + if embed_cls: + self.embed_cls = embed_cls + self.cls_emb = nn.Parameter(torch.empty(width)) + self.heads = heads + self.pad_id = pad_id + self.context_length += 1 + + self.token_embedding = nn.Embedding(vocab_size, width) self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width)) self.transformer = Transformer( @@ -510,8 +508,7 @@ def __init__( norm_layer=norm_layer, ) self.ln_final = norm_layer(width) - self.text_projection = nn.Parameter(torch.empty(width, output_dim)) - + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) self.init_parameters() @@ -519,6 +516,8 @@ def __init__( def init_parameters(self): nn.init.normal_(self.token_embedding.weight, std=0.02) nn.init.normal_(self.positional_embedding, std=0.01) + if hasattr(self, "embed_cls") and self.embed_cls: + nn.init.normal_(self.cls_emb, std=0.01) proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) attn_std = self.transformer.width ** -0.5 @@ -543,23 +542,52 @@ def build_attention_mask(self): mask.fill_(float("-inf")) mask.triu_(1) # zero out the lower diagonal return mask - - def forward(self, text): + + def build_cls_mask(self, text, cast_dtype): + cls_mask = (text != self.pad_id).unsqueeze(1) + cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True) + additive_mask = torch.empty(*cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) + additive_mask.fill_(0) + additive_mask.masked_fill_(~cls_mask, float("-inf")) + additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) + return additive_mask + + def _repeat(self, t, N): + return t.reshape(1, 1, -1).repeat(N, 1, 1) + + def forward(self, text, output_tokens: bool = False): cast_dtype = self.transformer.get_cast_dtype() + seq_len = text.shape[1] x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] - - x = x + self.positional_embedding.to(cast_dtype) + attn_mask = self.attn_mask + if hasattr(self, "embed_cls") and self.embed_cls: + seq_len += 1 + x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1) + cls_mask = self.build_cls_mask(text, cast_dtype) + attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] + + x = x + self.positional_embedding[:seq_len].to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x, attn_mask=self.attn_mask) + x = self.transformer(x, attn_mask=attn_mask) x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x) # x.shape = [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection - return x + if hasattr(self, "embed_cls") and self.embed_cls: + pooled = x[:, -1] + tokens = x[:, :-1] + pooled = self.ln_final(pooled) + else: + x = self.ln_final(x) + pooled = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] + 
tokens = x + + if self.text_projection is not None: + pooled = pooled @ self.text_projection + + return (pooled, tokens) if output_tokens else pooled class MultimodalTransformer(Transformer): @@ -614,16 +642,12 @@ def init_parameters(self): if self.text_projection is not None: nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - self.transformer.grad_checkpointing = enable - def build_attention_mask(self): # lazily create causal attention mask, with full attention between the vision tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.context_length, self.context_length) mask.fill_(float("-inf")) - mask.triu_(1) # zero out the lower diagonal + mask.triu_(1) # zero out the lower diagonal return mask def forward(self, image_embs, text_embs): @@ -644,3 +668,7 @@ def forward(self, image_embs, text_embs): x = self.ln_final(x) return x + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable diff --git a/src/training/train.py b/src/training/train.py index 3899590c3..405c0a633 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -64,9 +64,9 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) if args.accum_freq > 1: - accum_images, accum_texts, accum_image_features, accum_text_features = [], [], [], [] + accum_images, accum_texts, accum_features = [], [], {} - loss_m = AverageMeter() + losses_m = {} batch_time_m = AverageMeter() data_time_m = AverageMeter() end = time.time() @@ -86,18 +86,24 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args if args.accum_freq == 1: with autocast(): - model_out = model(images, texts) - logit_scale = model_out[-1] - total_loss = loss(*model_out) + model_out = model(images, texts, output_dict=True) + logit_scale = model_out["logit_scale"] + losses = loss(**model_out, output_dict=True) + total_loss = sum(losses.values()) + losses["loss"] = total_loss backward(total_loss, scaler) else: # First, cache the features without any gradient tracking. 
with torch.no_grad(): with autocast(): - chunk_image_features, chunk_text_features, _ = model(images, texts) - accum_image_features.append(chunk_image_features) - accum_text_features.append(chunk_text_features) + model_out = model(images, texts, output_dict=True) + model_out.pop("logit_scale") + for key, val in model_out.items(): + if key in accum_features: + accum_features[key].append(val) + else: + accum_features[key] = [val] accum_images.append(images) accum_texts.append(texts) @@ -115,12 +121,14 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args images = accum_images[j] texts = accum_texts[j] with autocast(): - chunk_image_features, chunk_text_features, logit_scale = model(images, texts) - image_features = torch.cat( - accum_image_features[:j] + [chunk_image_features] + accum_image_features[j + 1:]) - text_features = torch.cat( - accum_text_features[:j] + [chunk_text_features] + accum_text_features[j + 1:]) - total_loss = loss(image_features, text_features, logit_scale) + model_out = model(images, texts, output_dict=True) + logit_scale = model_out.pop("logit_scale") + for key, val in accum_features: + accumulated = accum_features[key] + accumulated = accumulated[:j] + [model_out[key]] + accumulated[j + 1:] + losses = loss(**accumulated, logit_scale=logit_scale, output_dict=True) + total_loss = sum(losses.values()) + losses["loss"] = total_loss backward(total_loss, scaler) if scaler is not None: @@ -144,7 +152,7 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args # reset gradient accum, if enabled if args.accum_freq > 1: - accum_images, accum_texts, accum_image_features, accum_text_features = [], [], [], [] + accum_images, accum_texts, accum_features = [], [], {} # Note: we clamp to 4.6052 = ln(100), as in the original paper. with torch.no_grad(): @@ -160,26 +168,36 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args percent_complete = 100.0 * batch_count / num_batches_per_epoch # NOTE loss is coarsely sampled, just master node and per log update - loss_m.update(total_loss.item(), batch_size) + for key, val in losses.items(): + if key not in losses_m: + losses_m[key] = AverageMeter() + losses_m[key].update(val.item(), batch_size) + logit_scale_scalar = logit_scale.item() + loss_log = " ".join( + [ + f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})" + for loss_name, loss_m in losses_m.items() + ] + ) logging.info( f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " - f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " f"Data (t): {data_time_m.avg:.3f} " f"Batch (t): {batch_time_m.avg:.3f}, {args.accum_freq * args.batch_size * args.world_size / batch_time_m.val:#g}/s " f"LR: {optimizer.param_groups[0]['lr']:5f} " - f"Logit Scale: {logit_scale_scalar:.3f}" + f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log ) # Save train loss / etc. 
Using non avg meter values as loggers have their own smoothing log_data = { - "loss": loss_m.val, "data_time": data_time_m.val, "batch_time": batch_time_m.val, "samples_per_second": args.accum_freq * args.batch_size * args.world_size / batch_time_m.val, "scale": logit_scale_scalar, "lr": optimizer.param_groups[0]["lr"] - } + } + log_data.update({name:val.val for name,val in losses_m.items()}) + for name, val in log_data.items(): name = "train/" + name if tb_writer is not None: @@ -224,10 +242,10 @@ def evaluate(model, data, epoch, args, tb_writer=None): texts = texts.to(device=device, non_blocking=True) with autocast(): - model_out = model(images, texts) - image_features = model_out[0] - text_features = model_out[1] - logit_scale = model_out[-1] + model_out = model(images, texts, output_dict=True) + image_features = model_out["image_features"] + text_features = model_out["text_features"] + logit_scale = model_out["logit_scale"] # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly # however, system RAM is easily exceeded and compute time becomes problematic all_image_features.append(image_features.cpu()) @@ -317,7 +335,7 @@ def get_clip_metrics(image_features, text_features, logit_scale): def maybe_compute_generative_loss(model_out): - if len(model_out) > 3: - token_logits = model_out[2] - token_labels = model_out[3] + if "logits" in model_out and "labels" in model_out: + token_logits = model_out["logits"] + token_labels = model_out["labels"] return F.cross_entropy(token_logits.permute(0, 2, 1), token_labels) diff --git a/tests/test_inference.py b/tests/test_inference.py index a97b53aeb..4df65b546 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -20,7 +20,8 @@ 'ViT-e-14', 'mt5-xl-ViT-H-14', 'coca_base', - 'coca_ViT-B-32' + 'coca_ViT-B-32', + 'coca_roberta-ViT-B-32' }) if 'OPEN_CLIP_TEST_REG_MODELS' in os.environ: From 061482bf4fe6667c754bc0075e166d07afaa663b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 9 Jan 2023 09:33:34 -0800 Subject: [PATCH 09/30] Get some basic PEP changes out of the way --- src/open_clip/coca_model.py | 90 ++++++++++++++++--------------- src/open_clip/factory.py | 1 + src/open_clip/generation_utils.py | 11 ++-- src/open_clip/loss.py | 31 +++++------ src/open_clip/transformer.py | 52 +++++++++--------- src/training/train.py | 2 +- 6 files changed, 95 insertions(+), 92 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 62bdbd549..a7b1d869c 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -15,6 +15,7 @@ from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower from .generation_utils import top_a, top_k, top_p + @dataclass class MultimodalCfg(CLIPTextCfg): mlp_ratio: int = 4 @@ -24,37 +25,38 @@ class MultimodalCfg(CLIPTextCfg): attn_pooler_heads: int = 8 latent_dim: int = 512 + class CoCaEncoderDecoder(nn.Module): def __init__(self, encoder, decoder) -> None: super().__init__() self.encoder = encoder self.decoder = decoder - + @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.encoder.set_grad_checkpointing(enable) self.decoder.set_grad_checkpointing(enable) + def _build_encoder_decoder_tower( - embed_dim, - multimodal_cfg, - text_cfg, - quick_gelu: bool = False, - cast_dtype: Optional[torch.dtype] = None, + embed_dim, + multimodal_cfg, + text_cfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, ): - multimodal_cfg = MultimodalCfg(**multimodal_cfg) if 
isinstance(multimodal_cfg, dict) else multimodal_cfg text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg - + encoder = _build_text_tower( - multimodal_cfg.latent_dim, - text_cfg=text_cfg, - quick_gelu=quick_gelu, + multimodal_cfg.latent_dim, + text_cfg=text_cfg, + quick_gelu=quick_gelu, cast_dtype=cast_dtype ) - + vocab_size = ( - encoder.config.vocab_size # for hf models + encoder.config.vocab_size # for hf models if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None else multimodal_cfg.vocab_size ) @@ -74,19 +76,20 @@ def _build_encoder_decoder_tower( act_layer=act_layer, norm_layer=norm_layer, ) - + return CoCaEncoderDecoder(encoder, decoder), multimodal_cfg, vocab_size - + + class CoCa(nn.Module): def __init__( - self, - embed_dim, - multimodal_cfg: MultimodalCfg, - text_cfg: CLIPTextCfg, - vision_cfg: CLIPVisionCfg, - quick_gelu: bool = False, - cast_dtype: Optional[torch.dtype] = None, - pad_id: int = 0 + self, + embed_dim, + multimodal_cfg: MultimodalCfg, + text_cfg: CLIPTextCfg, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + pad_id: int = 0 ): super().__init__() multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg @@ -124,7 +127,7 @@ def encode_image(self, images, normalize=True, return_tokens=False): return (image_latent, tokens_embs) if return_tokens else image_latent def encode_text(self, text, normalize=True, return_tokens=False): - text = text[:, :-1] # make space for CLS token + text = text[:, :-1] # make space for CLS token text_latent, token_emb = self.text.encoder(text, output_tokens=True) text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent return (text_latent, token_emb) if return_tokens else text_latent @@ -133,36 +136,36 @@ def forward(self, image, text, output_dict=False): text_latent, token_embs = self.encode_text(text, return_tokens=True) image_latent, image_embs = self.encode_image(image, return_tokens=True) - + # TODO: add assertion to avoid bugs? labels = text[:, -token_embs.shape[1]:] - + token_embs = self.text.decoder(image_embs, token_embs) logits = self.to_logits(token_embs) if output_dict: return { - "image_features":image_latent, - "text_features":text_latent, - "logits":logits, - "labels":labels, - "logit_scale":self.logit_scale.exp() + "image_features": image_latent, + "text_features": text_latent, + "logits": logits, + "labels": labels, + "logit_scale": self.logit_scale.exp() } return image_latent, text_latent, logits, labels, self.logit_scale.exp() def generate( - self, - image, - text, - seq_len, - max_seq_len=77, - mask_prob = 0.0, - temperature = 1., - filter_logits_fn = top_k, - filter_thres = 0.9, - min_p_pow = 2.0, - min_p_ratio = 0.02, - ): + self, + image, + text, + seq_len, + max_seq_len=77, + mask_prob=0.0, + temperature=1., + filter_logits_fn=top_k, + filter_thres=0.9, + min_p_pow=2.0, + min_p_ratio=0.02, + ): assert mask_prob < 1, "mask_prob must be smaller than 1." 
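Editorial note: the generate() signature tidied up above is CoCa's captioning entry point. A rough call-shape sketch, assuming the factory helpers exported earlier in this series and the coca_ViT-B-32 config; no pretrained weights are loaded, so the sampled tokens are meaningless and the prompt slicing is approximate, not a definitive recipe:

```python
# Sketch only: exercising CoCa.generate() with the signature shown above.
# "coca_ViT-B-32" and the tokenizer helper come from this patch series;
# weights here are randomly initialized, so the output is just token ids.
import torch
import open_clip

model, _, preprocess = open_clip.create_model_and_transforms("coca_ViT-B-32")
tokenizer = open_clip.get_tokenizer("coca_ViT-B-32")
model.eval()

image = torch.randn(1, 3, 224, 224)        # stand-in for preprocess(pil_image)[None]
prompt = tokenizer(["a photo of"])[:, :4]  # roughly SOT + prompt tokens, no padding
with torch.no_grad():
    generated = model.generate(image, prompt, seq_len=20, temperature=1.0)
# `generated` holds the token ids sampled after the prompt
```

The remaining arguments keep their defaults from the signature above (filter_logits_fn=top_k with filter_thres=0.9, mask_prob=0.0, max_seq_len=77).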
@@ -196,7 +199,6 @@ def generate( out = torch.cat((out, sample), dim=-1) - out = out[:, t:] if num_dims == 1: diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index c2297cf50..912cdc0ac 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -192,6 +192,7 @@ def create_model( return model + def create_loss(args): if "coca" in args.model.lower(): return CoCaLoss( diff --git a/src/open_clip/generation_utils.py b/src/open_clip/generation_utils.py index 82b7fdefb..fade1f0ae 100644 --- a/src/open_clip/generation_utils.py +++ b/src/open_clip/generation_utils.py @@ -3,12 +3,13 @@ from torch import nn import torch.nn.functional as F + def exists(val): return val is not None -# nucleus -def top_p(logits, thres = 0.9): +def top_p(logits, thres=0.9): + # nucleus sorted_logits, sorted_indices = torch.sort(logits, descending=True) cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) @@ -19,20 +20,18 @@ def top_p(logits, thres = 0.9): sorted_logits[sorted_indices_to_remove] = float('-inf') return sorted_logits.scatter(1, sorted_indices, sorted_logits) -# topk -def top_k(logits, thres = 0.9): +def top_k(logits, thres=0.9): k = ceil((1 - thres) * logits.shape[-1]) val, ind = torch.topk(logits, k) probs = torch.full_like(logits, float('-inf')) probs.scatter_(1, ind, val) return probs -# top_a def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02): probs = F.softmax(logits, dim=-1) limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio logits[probs < limit] = float('-inf') logits[probs >= limit] = 1 - return logits \ No newline at end of file + return logits diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index a838863fb..5a112125d 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -5,6 +5,7 @@ try: import torch.distributed.nn from torch import distributed as dist + has_distributed = True except ImportError: has_distributed = False @@ -115,25 +116,25 @@ def forward(self, image_features, text_features, logit_scale, output_dict=False) labels = self.labels[device] total_loss = ( - F.cross_entropy(logits_per_image, labels) + - F.cross_entropy(logits_per_text, labels) - ) / 2 - + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + return {"contrastive_loss": total_loss} if output_dict else total_loss class CoCaLoss(ClipLoss): def __init__( - self, - caption_loss_weight, - clip_loss_weight, - pad_id=0, # pad_token for open_clip custom tokenizer - local_loss=False, - gather_with_grad=False, - cache_labels=False, - rank=0, - world_size=1, - use_horovod=False, + self, + caption_loss_weight, + clip_loss_weight, + pad_id=0, # pad_token for open_clip custom tokenizer + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, ): super().__init__( local_loss=local_loss, @@ -159,6 +160,6 @@ def forward(self, image_features, text_features, logits, labels, logit_scale, ou caption_loss = caption_loss * self.caption_loss_weight if output_dict: - return {"contrastive_loss":clip_loss, "caption_loss":caption_loss} + return {"contrastive_loss": clip_loss, "caption_loss": caption_loss} return clip_loss, caption_loss diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 81527044e..e958f8fa2 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -159,6 +159,7 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None): x = self.out_drop(x) return x + class AttentionalPooler(nn.Module): def __init__( self, 
@@ -175,11 +176,11 @@ def __init__( self.ln_k = norm_layer(context_dim) def forward(self, x: torch.Tensor): - x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND + x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND N = x.shape[1] q = self.ln_q(self.query) out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0] - return out.permute(1, 0, 2) # LND -> NLD + return out.permute(1, 0, 2) # LND -> NLD def _repeat(self, query, N): return query.unsqueeze(1).repeat(1, N, 1) @@ -214,13 +215,12 @@ def __init__( self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() def attention( - self, - q_x: torch.Tensor, - k_x: Optional[torch.Tensor] = None, - v_x: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None ): - k_x = k_x if k_x is not None else q_x v_x = v_x if v_x is not None else q_x @@ -230,12 +230,11 @@ def attention( )[0] def forward(self, - q_x: torch.Tensor, - k_x: Optional[torch.Tensor] = None, - v_x: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None - ): - + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None + ): k_x = self.ln_1_kv(k_x) if k_x is not None else None v_x = self.ln_1_kv(v_x) if v_x is not None else None @@ -319,6 +318,7 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): x = r(x, attn_mask=attn_mask) return x + class VisionTransformer(nn.Module): def __init__( self, @@ -448,7 +448,6 @@ def forward(self, x: torch.Tensor, output_tokens: bool = False): x = self.transformer(x) x = x.permute(1, 0, 2) # LND -> NLD - if hasattr(self, "attn_pool"): x = self.attn_pool(x) x = self.ln_post(x) @@ -466,6 +465,7 @@ def forward(self, x: torch.Tensor, output_tokens: bool = False): return (pooled, tokens) if output_tokens else pooled + class TextTransformer(nn.Module): def __init__( @@ -487,7 +487,7 @@ def __init__( self.vocab_size = vocab_size self.width = width self.output_dim = output_dim - + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) if embed_cls: self.embed_cls = embed_cls @@ -495,8 +495,7 @@ def __init__( self.heads = heads self.pad_id = pad_id self.context_length += 1 - - + self.token_embedding = nn.Embedding(vocab_size, width) self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width)) self.transformer = Transformer( @@ -508,7 +507,7 @@ def __init__( norm_layer=norm_layer, ) self.ln_final = norm_layer(width) - + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) self.init_parameters() @@ -542,7 +541,7 @@ def build_attention_mask(self): mask.fill_(float("-inf")) mask.triu_(1) # zero out the lower diagonal return mask - + def build_cls_mask(self, text, cast_dtype): cls_mask = (text != self.pad_id).unsqueeze(1) cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True) @@ -567,7 +566,7 @@ def forward(self, text, output_tokens: bool = False): cls_mask = self.build_cls_mask(text, cast_dtype) attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] - x = x + self.positional_embedding[:seq_len].to(cast_dtype) + x = x + self.positional_embedding[:seq_len].to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x, attn_mask=attn_mask) x = x.permute(1, 0, 2) # LND -> NLD @@ -583,10 +582,10 @@ def forward(self, text, output_tokens: 
bool = False): x = self.ln_final(x) pooled = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] tokens = x - + if self.text_projection is not None: pooled = pooled @ self.text_projection - + return (pooled, tokens) if output_tokens else pooled @@ -616,7 +615,8 @@ def __init__( self.context_length = context_length self.cross_attn = nn.ModuleList([ ResidualAttentionBlock( - width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, is_cross_attention=True) + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, + is_cross_attention=True) for _ in range(layers) ]) @@ -647,7 +647,7 @@ def build_attention_mask(self): # pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.context_length, self.context_length) mask.fill_(float("-inf")) - mask.triu_(1) # zero out the lower diagonal + mask.triu_(1) # zero out the lower diagonal return mask def forward(self, image_embs, text_embs): @@ -668,7 +668,7 @@ def forward(self, image_embs, text_embs): x = self.ln_final(x) return x - + @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable diff --git a/src/training/train.py b/src/training/train.py index 5c0b72cb3..e8c28182a 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -51,6 +51,7 @@ def backward(total_loss, scaler): else: total_loss.backward() + def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args, tb_writer=None): device = torch.device(args.device) autocast = get_autocast(args.precision) @@ -275,7 +276,6 @@ def evaluate(model, data, epoch, args, tb_writer=None): logging.info( f"Generative Loss: {cumulative_gen_loss / num_samples:.6f}\t") - val_metrics = get_clip_metrics( image_features=torch.cat(all_image_features), text_features=torch.cat(all_text_features), From d0bd09ebac1ee8a78a07255528ce0ccab23a9da9 Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Sat, 21 Jan 2023 21:52:30 +0100 Subject: [PATCH 10/30] Add tests bis (#355) * make jit compilable * redundant annotation * less tests * less annotations * even less annotations * fix name check in ci * some annotations back * make it simpler * make hf simpler too * better jit support with tests * remove extra line * add customtextclip * more jit tests * missing assert * add eval * typo * rever forward changes * clean coca model * more cleaning * last cleaning --- .github/workflows/ci.yml | 2 +- src/open_clip/coca_model.py | 46 +++++++----- src/open_clip/hf_model.py | 25 ++++--- src/open_clip/model.py | 23 +++++- .../model_configs/coca_ViT-B-32.json | 6 +- src/open_clip/model_configs/coca_base.json | 8 +- .../model_configs/coca_roberta-ViT-B-32.json | 6 +- src/open_clip/transformer.py | 46 ++++++++---- src/training/train.py | 28 ++++++- tests/test_inference.py | 73 ++++++++++++++++++- tests/util_test.py | 18 +++++ 11 files changed, 217 insertions(+), 64 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9400d6290..0bd877181 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -83,7 +83,7 @@ jobs: --group ${{ matrix.job }} \ -m regression_test \ tests \ - | head -n -2 | grep -Po 'test_inference_with_data\[\K[^]]*(?=])' \ + | head -n -2 | grep -Po 'test_inference_with_data\[\K[^]]*(?=-False]|-True])' \ > models_gh_runner.txt if [ -n "${{ inputs.manual_revision_reference }}" ]; then REVISION_REFERENCE=${{ inputs.manual_revision_reference }} diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py 
index a7b1d869c..bc0b66545 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -89,7 +89,7 @@ def __init__( vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, - pad_id: int = 0 + pad_id: int = 0, ): super().__init__() multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg @@ -121,37 +121,43 @@ def set_grad_checkpointing(self, enable=True): self.visual.set_grad_checkpointing(enable) self.text.set_grad_checkpointing(enable) - def encode_image(self, images, normalize=True, return_tokens=False): - image_latent, tokens_embs = self.visual(images, output_tokens=True) + def _encode_image(self, images, normalize=True): + image_latent, tokens_embs = self.visual(images) image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent - return (image_latent, tokens_embs) if return_tokens else image_latent + return image_latent, tokens_embs - def encode_text(self, text, normalize=True, return_tokens=False): - text = text[:, :-1] # make space for CLS token - text_latent, token_emb = self.text.encoder(text, output_tokens=True) + def _encode_text(self, text, normalize=True): + text = text[:, :-1] # make space for CLS token + text_latent, token_emb = self.text.encoder(text) text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent - return (text_latent, token_emb) if return_tokens else text_latent + return text_latent, token_emb + + def encode_image(self, images, normalize=True): + image_latent, _ = self._encode_image(images, normalize=normalize) + return image_latent + + def encode_text(self, text, normalize=True): + text_latent, _ = self._encode_text(text, normalize=normalize) + return text_latent + def forward(self, image, text, output_dict=False): - text_latent, token_embs = self.encode_text(text, return_tokens=True) - image_latent, image_embs = self.encode_image(image, return_tokens=True) + text_latent, token_embs = self._encode_text(text) + image_latent, image_embs = self._encode_image(image) # TODO: add assertion to avoid bugs? 
labels = text[:, -token_embs.shape[1]:] token_embs = self.text.decoder(image_embs, token_embs) logits = self.to_logits(token_embs) - if output_dict: - return { - "image_features": image_latent, - "text_features": text_latent, - "logits": logits, - "labels": labels, - "logit_scale": self.logit_scale.exp() - } - - return image_latent, text_latent, logits, labels, self.logit_scale.exp() + return { + "image_features": image_latent, + "text_features": text_latent, + "logits": logits, + "labels": labels, + "logit_scale": self.logit_scale.exp() + } def generate( self, diff --git a/src/open_clip/hf_model.py b/src/open_clip/hf_model.py index 4ffd3b849..e7df076e7 100644 --- a/src/open_clip/hf_model.py +++ b/src/open_clip/hf_model.py @@ -89,10 +89,13 @@ def __init__( config: PretrainedConfig = None, pooler_type: str = None, proj: str = None, - pretrained: bool = True + pretrained: bool = True, + output_tokens: bool = False ): super().__init__() - + self.output_tokens = None + if output_tokens: + self.output_tokens = output_tokens self.output_dim = output_dim # TODO: find better way to get this information @@ -131,21 +134,21 @@ def __init__( nn.Linear(hidden_size, output_dim, bias=False), ) - def forward(self, x: TensorType, output_tokens=False) -> TensorType: + def forward(self, x: TensorType): attn_mask = (x != self.config.pad_token_id).long() out = self.transformer(input_ids=x, attention_mask=attn_mask) pooled_out = self.pooler(out, attn_mask) projected = self.proj(pooled_out) - if output_tokens: - seq_len = out.last_hidden_state.shape[1] - tokens = ( - out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :] - if type(self.pooler) == ClsPooler - else out.last_hidden_state - ) - return projected, tokens + seq_len = out.last_hidden_state.shape[1] + tokens = ( + out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :] + if type(self.pooler) == ClsPooler + else out.last_hidden_state + ) + if self.output_tokens is not None: + return projected, tokens return projected def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): diff --git a/src/open_clip/model.py b/src/open_clip/model.py index 5fc3a16e5..e1064fbaf 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -39,6 +39,7 @@ class CLIPVisionCfg: timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') timm_proj_bias: bool = False # enable bias final projection + output_tokens: bool = False @dataclass @@ -56,6 +57,7 @@ class CLIPTextCfg: pooler_type: str = 'mean_pooler' embed_cls: bool = False pad_id: int = 0 + output_tokens: bool = False def get_cast_dtype(precision: str): @@ -117,6 +119,7 @@ def _build_vision_tower( attentional_pool=vision_cfg.attentional_pool, n_queries=vision_cfg.n_queries, attn_pooler_heads=vision_cfg.attn_pooler_heads, + output_tokens=vision_cfg.output_tokens, output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, @@ -140,7 +143,8 @@ def _build_text_tower( output_dim=embed_dim, proj=text_cfg.proj, pooler_type=text_cfg.pooler_type, - pretrained=text_cfg.hf_model_pretrained + pretrained=text_cfg.hf_model_pretrained, + output_tokens=text_cfg.output_tokens ) else: act_layer = QuickGELU if quick_gelu else nn.GELU @@ -155,6 +159,7 @@ def _build_text_tower( ls_init_value=text_cfg.ls_init_value, output_dim=embed_dim, embed_cls=text_cfg.embed_cls, + output_tokens=text_cfg.output_tokens, pad_id=text_cfg.pad_id, 
act_layer=act_layer, norm_layer=norm_layer, @@ -170,8 +175,12 @@ def __init__( text_cfg: CLIPTextCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False ): super().__init__() + self.output_dict = None + if output_dict: + self.output_dict = output_dict self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) @@ -212,10 +221,11 @@ def encode_text(self, text, normalize: bool = False): x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection return F.normalize(x, dim=-1) if normalize else x - def forward(self, image, text, output_dict=False): + def forward(self, image, text, output_dict:bool=False): image_features = self.encode_image(image, normalize=True) text_features = self.encode_text(text, normalize=True) - if output_dict: + output_dict = self.output_dict + if output_dict is not None: return { "image_features":image_features, "text_features":text_features, @@ -232,8 +242,12 @@ def __init__( text_cfg: CLIPTextCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False ): super().__init__() + self.output_dict = None + if output_dict: + self.output_dict = output_dict self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) @@ -261,7 +275,8 @@ def encode_text(self, text, normalize: bool = False): def forward(self, image, text, output_dict=False): image_features = self.encode_image(image, normalize=True) text_features = self.encode_text(text, normalize=True) - if output_dict: + output_dict = self.output_dict + if output_dict is not None: return { "image_features":image_features, "text_features":text_features, diff --git a/src/open_clip/model_configs/coca_ViT-B-32.json b/src/open_clip/model_configs/coca_ViT-B-32.json index 8e13625c8..bcf207e9d 100644 --- a/src/open_clip/model_configs/coca_ViT-B-32.json +++ b/src/open_clip/model_configs/coca_ViT-B-32.json @@ -6,7 +6,8 @@ "width": 768, "patch_size": 32, "attentional_pool": true, - "attn_pooler_heads": 8 + "attn_pooler_heads": 8, + "output_tokens": true }, "text_cfg": { "context_length": 76, @@ -14,7 +15,8 @@ "width": 512, "heads": 8, "layers": 12, - "embed_cls": true + "embed_cls": true, + "output_tokens": true }, "multimodal_cfg": { "context_length": 76, diff --git a/src/open_clip/model_configs/coca_base.json b/src/open_clip/model_configs/coca_base.json index d26b61087..30e04456b 100644 --- a/src/open_clip/model_configs/coca_base.json +++ b/src/open_clip/model_configs/coca_base.json @@ -16,7 +16,8 @@ "image_size": 288, "layers": 12, "width": 768, - "patch_size": 18 + "patch_size": 18, + "output_tokens": true }, "text_cfg": { "context_length": 76, @@ -24,7 +25,8 @@ "layers": 12, "heads": 12, "width": 768, - "embed_cls": true + "embed_cls": true, + "output_tokens": true }, - "custom_text": "True" + "custom_text": true } \ No newline at end of file diff --git a/src/open_clip/model_configs/coca_roberta-ViT-B-32.json b/src/open_clip/model_configs/coca_roberta-ViT-B-32.json index 6b706c861..a4e19ab5a 100644 --- a/src/open_clip/model_configs/coca_roberta-ViT-B-32.json +++ b/src/open_clip/model_configs/coca_roberta-ViT-B-32.json @@ -4,13 +4,15 @@ "image_size": 224, "layers": 12, "width": 768, - "patch_size": 32 + "patch_size": 32, + "output_tokens": true }, "text_cfg": { "hf_model_name": 
"roberta-base", "hf_tokenizer_name": "roberta-base", "proj": "linear", - "width": 768 + "width": 768, + "output_tokens": true }, "multimodal_cfg": { "context_length": 76, diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index e958f8fa2..11e9a883c 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -182,7 +182,7 @@ def forward(self, x: torch.Tensor): out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0] return out.permute(1, 0, 2) # LND -> NLD - def _repeat(self, query, N): + def _repeat(self, query, N: int): return query.unsqueeze(1).repeat(1, N, 1) @@ -230,13 +230,13 @@ def attention( )[0] def forward(self, - q_x: torch.Tensor, - k_x: Optional[torch.Tensor] = None, - v_x: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None - ): - k_x = self.ln_1_kv(k_x) if k_x is not None else None - v_x = self.ln_1_kv(v_x) if v_x is not None else None + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None + ): + k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None + v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) x = x + self.ls_2(self.mlp(self.ln_2(x))) @@ -337,8 +337,12 @@ def __init__( patch_dropout: float = 0., act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, + output_tokens: bool = False ): super().__init__() + self.output_tokens = None + if output_tokens: + self.output_tokens = output_tokens self.image_size = to_2tuple(image_size) self.patch_size = to_2tuple(patch_size) self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1]) @@ -431,7 +435,7 @@ def init_parameters(self): def set_grad_checkpointing(self, enable=True): self.transformer.grad_checkpointing = enable - def forward(self, x: torch.Tensor, output_tokens: bool = False): + def forward(self, x: torch.Tensor): x = self.conv1(x) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] @@ -463,7 +467,10 @@ def forward(self, x: torch.Tensor, output_tokens: bool = False): if self.proj is not None: pooled = pooled @ self.proj - return (pooled, tokens) if output_tokens else pooled + if self.output_tokens is not None: + return pooled, tokens + + return pooled class TextTransformer(nn.Module): @@ -481,8 +488,12 @@ def __init__( norm_layer: Callable = LayerNorm, embed_cls: bool = False, pad_id: int = 0, + output_tokens: bool = False ): super().__init__() + self.output_tokens = None + if output_tokens: + self.output_tokens = output_tokens self.context_length = context_length self.vocab_size = vocab_size self.width = width @@ -542,19 +553,19 @@ def build_attention_mask(self): mask.triu_(1) # zero out the lower diagonal return mask - def build_cls_mask(self, text, cast_dtype): + def build_cls_mask(self, text, cast_dtype: torch.dtype): cls_mask = (text != self.pad_id).unsqueeze(1) - cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True) - additive_mask = torch.empty(*cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) + cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0) + additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) additive_mask.fill_(0) additive_mask.masked_fill_(~cls_mask, float("-inf")) 
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) return additive_mask - def _repeat(self, t, N): + def _repeat(self, t, N: int): return t.reshape(1, 1, -1).repeat(N, 1, 1) - def forward(self, text, output_tokens: bool = False): + def forward(self, text): cast_dtype = self.transformer.get_cast_dtype() seq_len = text.shape[1] @@ -586,7 +597,10 @@ def forward(self, text, output_tokens: bool = False): if self.text_projection is not None: pooled = pooled @ self.text_projection - return (pooled, tokens) if output_tokens else pooled + if self.output_tokens is not None: + return pooled, tokens + + return pooled class MultimodalTransformer(Transformer): diff --git a/src/training/train.py b/src/training/train.py index e8c28182a..368d3163f 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -13,7 +13,7 @@ except ImportError: wandb = None -from open_clip import get_cast_dtype +from open_clip import get_cast_dtype, CLIP, CustomTextCLIP from .distributed import is_master from .zero_shot import zero_shot_eval from .precision import get_autocast @@ -37,6 +37,15 @@ def update(self, val, n=1): self.count += n self.avg = self.sum / self.count +def is_clip(model): + return type(model) in [CLIP, CustomTextCLIP] + +def postprocess_clip_output(model_out): + return { + "image_features": model_out[0], + "text_features": model_out[1], + "logit_scale": model_out[2] + } def unwrap_model(model): if hasattr(model, 'module'): @@ -87,9 +96,13 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args if args.accum_freq == 1: with autocast(): - model_out = model(images, texts, output_dict=True) + model_out = model(images, texts) + # for clip if it does not output_dict + if is_clip(model) and not model.output_dict: + model_out = postprocess_clip_output(model_out) logit_scale = model_out["logit_scale"] losses = loss(**model_out, output_dict=True) + total_loss = sum(losses.values()) losses["loss"] = total_loss @@ -98,7 +111,10 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args # First, cache the features without any gradient tracking. 
with torch.no_grad(): with autocast(): - model_out = model(images, texts, output_dict=True) + model_out = model(images, texts) + # for clip if it does not output_dict + if is_clip(model) and not model.output_dict: + model_out = postprocess_clip_output(model_out) model_out.pop("logit_scale") for key, val in model_out.items(): if key in accum_features: @@ -123,6 +139,9 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args texts = accum_texts[j] with autocast(): model_out = model(images, texts, output_dict=True) + # for clip if it does not output_dict + if is_clip(model) and not model_out.output_dict: + model_out = postprocess_clip_output(model_out) logit_scale = model_out.pop("logit_scale") for key, val in accum_features: accumulated = accum_features[key] @@ -244,6 +263,9 @@ def evaluate(model, data, epoch, args, tb_writer=None): with autocast(): model_out = model(images, texts, output_dict=True) + # for clip if it does not output_dict + if is_clip(model) and not model.output_dict: + model_out = postprocess_clip_output(model_out) image_features = model_out["image_features"] text_features = model_out["text_features"] logit_scale = model_out["logit_scale"] diff --git a/tests/test_inference.py b/tests/test_inference.py index 4ae1d6603..1f074e4b2 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -7,6 +7,12 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '' +if hasattr(torch._C, '_jit_set_profiling_executor'): + # legacy executor is too slow to compile large models for unit tests + # no need for the fusion performance here + torch._C._jit_set_profiling_executor(True) + torch._C._jit_set_profiling_mode(False) + models_to_test = set(open_clip.list_models()) # testing excemptions @@ -32,17 +38,46 @@ models_to_test = set(f.read().splitlines()).intersection(models_to_test) print(f"Selected models from {external_model_list}: {models_to_test}") +# TODO: add "coca_ViT-B-32" onece https://github.com/pytorch/pytorch/issues/92073 gets fixed models_to_test = list(models_to_test) models_to_test.sort() +models_to_test = [(model_name, False) for model_name in models_to_test] + +models_to_jit_test = {"ViT-B-32"} +models_to_jit_test = list(models_to_jit_test) +models_to_jit_test = [(model_name, True) for model_name in models_to_jit_test] +models_to_test_fully = models_to_test + models_to_jit_test + + +class TestWrapper(torch.nn.Module): + def __init__(self, model, model_name, output_dict=True) -> None: + super().__init__() + self.model = model + self.output_dict = None + if output_dict: + self.output_dict = output_dict + if type(model) in [open_clip.CLIP, open_clip.CustomTextCLIP]: + self.model.output_dict = self.output_dict + config = open_clip.get_model_config(model_name) + self.head = torch.nn.Linear(config["embed_dim"], 2) + + def forward(self, image, text): + output_dict = self.output_dict + x = self.model(image, text) + if output_dict is not None: + out = self.head(x["image_features"]) + else: + out = self.head(x[0]) + return {"test_output": out} @pytest.mark.regression_test -@pytest.mark.parametrize('model_name', models_to_test) +@pytest.mark.parametrize("model_name,jit", models_to_test_fully) def test_inference_with_data( model_name, + jit, pretrained = None, pretrained_hf = False, precision = 'fp32', - jit = False, force_quick_gelu = False, ): util_test.seed_all() @@ -81,5 +116,39 @@ def test_inference_with_data( gt_image = torch.load(gt_image_path) y_image = util_test.inference_image(model, preprocess_val, input_image) assert (y_image == gt_image).all(), f"image 
output differs @ {input_image_path}" + + if not jit: + model.eval() + model_out = util_test.forward_model(model, model_name, preprocess_val, input_image, input_text) + if type(model) not in [open_clip.CLIP, open_clip.CustomTextCLIP]: + assert type(model_out) == dict + else: + model.output_dict = True + model_out_dict = util_test.forward_model(model, model_name, preprocess_val, input_image, input_text) + assert (model_out_dict["image_features"] == model_out[0]).all() + assert (model_out_dict["text_features"] == model_out[1]).all() + assert (model_out_dict["logit_scale"] == model_out[2]).all() + model.output_dict = None + else: + model, _, preprocess_val = open_clip.create_model_and_transforms( + model_name, + pretrained = pretrained, + precision = precision, + jit = False, + force_quick_gelu = force_quick_gelu, + pretrained_hf = pretrained_hf + ) + + test_model = TestWrapper(model, model_name, output_dict=False) + test_model = torch.jit.script(test_model) + model_out = util_test.forward_model(test_model, model_name, preprocess_val, input_image, input_text) + assert model_out["test_output"].shape[-1] == 2 + + test_model = TestWrapper(model, model_name, output_dict=True) + test_model = torch.jit.script(test_model) + model_out = util_test.forward_model(test_model, model_name, preprocess_val, input_image, input_text) + assert model_out["test_output"].shape[-1] == 2 + + diff --git a/tests/util_test.py b/tests/util_test.py index b2a2c9c3d..e0a386a8b 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -33,6 +33,24 @@ def inference_image(model, preprocess_val, batches): x = torch.stack([preprocess_val(img) for img in x]) y.append(model.encode_image(x)) return torch.stack(y) + +def forward_model(model, model_name, preprocess_val, image_batch, text_batch): + y = [] + tokenizer = open_clip.get_tokenizer(model_name) + with torch.no_grad(): + for x_im, x_txt in zip(image_batch, text_batch): + x_im = torch.stack([preprocess_val(im) for im in x_im]) + x_txt = tokenizer(x_txt) + y.append(model(x_im, x_txt)) + if type(y[0]) == dict: + out = {} + for key in y[0].keys(): + out[key] = torch.stack([batch_out[key] for batch_out in y]) + else: + out = [] + for i in range(len(y[0])): + out.append(torch.stack([batch_out[i] for batch_out in y])) + return out def random_image_batch(batch_size, size): h, w = size From 2ab47b754ce9283ee297ef078e80c1fcd8e51cd9 Mon Sep 17 00:00:00 2001 From: Maciej Kilian Date: Sat, 21 Jan 2023 14:49:28 -0800 Subject: [PATCH 11/30] train.py: fix is_clip when doing distributed (#364) --- src/training/train.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/training/train.py b/src/training/train.py index 04dd3b3fd..88e4c5a34 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -7,6 +7,7 @@ import numpy as np import torch import torch.nn.functional as F +from torch.nn.parallel.distributed import DistributedDataParallel try: import wandb @@ -98,7 +99,8 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args with autocast(): model_out = model(images, texts) # for clip if it does not output_dict - if is_clip(model) and not model.output_dict: + module = model.module if type(model) == DistributedDataParallel else model + if is_clip(module) and not module.output_dict: model_out = postprocess_clip_output(model_out) logit_scale = model_out["logit_scale"] losses = loss(**model_out, output_dict=True) @@ -113,7 +115,8 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args with autocast(): 
model_out = model(images, texts) # for clip if it does not output_dict - if is_clip(model) and not model.output_dict: + module = model.module if type(model) == DistributedDataParallel else model + if is_clip(module) and not module.output_dict: model_out = postprocess_clip_output(model_out) model_out.pop("logit_scale") for key, val in model_out.items(): @@ -140,7 +143,8 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args with autocast(): model_out = model(images, texts, output_dict=True) # for clip if it does not output_dict - if is_clip(model) and not model_out.output_dict: + module = model.module if type(model) == DistributedDataParallel else model + if is_clip(module) and not model.output_dict: model_out = postprocess_clip_output(model_out) logit_scale = model_out.pop("logit_scale") for key, val in accum_features: @@ -264,7 +268,8 @@ def evaluate(model, data, epoch, args, tb_writer=None): with autocast(): model_out = model(images, texts, output_dict=True) # for clip if it does not output_dict - if is_clip(model) and not model.output_dict: + module = model.module if type(model) == DistributedDataParallel else model + if is_clip(module) and not module.output_dict: model_out = postprocess_clip_output(model_out) image_features = model_out["image_features"] text_features = model_out["text_features"] From c0e5950f6cd34d15552339faf6922fdebcb1bd0e Mon Sep 17 00:00:00 2001 From: Maciej Kilian Date: Sat, 21 Jan 2023 16:01:55 -0800 Subject: [PATCH 12/30] add README (#365) * add README * multimodal_cfg info * multimodal --- README.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/README.md b/README.md index 4bf3aaabd..b6a96372b 100644 --- a/README.md +++ b/README.md @@ -257,6 +257,20 @@ python -m training.main \ --resume /path/to/checkpoints/epoch_K.pt ``` +### Training CoCa: +Training [CoCa](https://arxiv.org/abs/2205.01917) models is enabled through specifying a CoCa config using the ```--model``` parameter of the training script. Currently available configs are "coca_base", "coca_ViT-B-32", and "coca_roberta-ViT-B-32" (which uses RoBERTa as the text encoder). CoCa configs are different from CLIP configs because they have an additional "multimodal_cfg" component which specifies parameters for the multimodal text decoder. Here's an example from the coca_ViT-B-32 config: +```json +"multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "latent_dim": 512, + "attn_pooler_heads": 8 +} +``` + ### Training with pre-trained language models as text encoder: If you wish to use different language models as the text encoder for CLIP you can do so by using one of the Hugging Face model configs in ```src/open_clip/model_configs``` and passing in it's tokenizer as the ```--model``` and ```--hf-tokenizer-name``` parameters respectively. Currently we only support RoBERTa ("test-roberta" config), however adding new models should be trivial. You can also determine how many layers, from the end, to leave unfrozen with the ```--lock-text-unlocked-layers``` parameter. 
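For reference, the snippet below is a minimal sketch of what a CoCa model built from one of the configs above produces at training time. It assumes the final state of this series (a `forward` that returns a dict); the random 224×224 tensor and the dummy captions are placeholders for real preprocessed data:

```python
import torch
import open_clip

# Build a CoCa model from the "coca_ViT-B-32" config shown above; the extra
# "multimodal_cfg" block is what drives construction of the multimodal text decoder.
model, _, preprocess = open_clip.create_model_and_transforms("coca_ViT-B-32")
tokenizer = open_clip.get_tokenizer("coca_ViT-B-32")

# Placeholder batch: random 224x224 images and dummy captions
# (in practice, apply `preprocess` to real images).
images = torch.randn(2, 3, 224, 224)
texts = tokenizer(["a photo of a cat", "a photo of a dog"])

with torch.no_grad():
    out = model(images, texts)

# CoCa returns both CLIP-style contrastive features and captioning logits,
# which the contrastive and caption losses consume during training.
print(out["image_features"].shape)  # [batch, embed_dim]
print(out["text_features"].shape)   # [batch, embed_dim]
print(out["logits"].shape)          # [batch, seq_len, vocab_size]
```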
Here's an example command to train CLIP with the RoBERTa LM that has it's last 10 layers unfrozen: From 3f5b0fb13973b6c5a8c495b84ebf92cc76078dac Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Sun, 22 Jan 2023 21:43:32 +0100 Subject: [PATCH 13/30] remove output_dict argument (#368) * remove output_dict argument * cleaner --- src/open_clip/model.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/open_clip/model.py b/src/open_clip/model.py index 1b8c1d09c..dbe62f9d2 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -225,11 +225,10 @@ def encode_text(self, text, normalize: bool = False): x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection return F.normalize(x, dim=-1) if normalize else x - def forward(self, image, text, output_dict:bool=False): + def forward(self, image, text): image_features = self.encode_image(image, normalize=True) text_features = self.encode_text(text, normalize=True) - output_dict = self.output_dict - if output_dict is not None: + if self.output_dict is not None: return { "image_features":image_features, "text_features":text_features, @@ -276,11 +275,10 @@ def encode_text(self, text, normalize: bool = False): features = self.text(text) return F.normalize(features, dim=-1) if normalize else features - def forward(self, image, text, output_dict=False): + def forward(self, image, text): image_features = self.encode_image(image, normalize=True) text_features = self.encode_text(text, normalize=True) - output_dict = self.output_dict - if output_dict is not None: + if self.output_dict is not None: return { "image_features":image_features, "text_features":text_features, From de343fb73e9512c63bcbf3d902359c652580aef0 Mon Sep 17 00:00:00 2001 From: Maciej Kilian Date: Sun, 22 Jan 2023 14:16:35 -0800 Subject: [PATCH 14/30] do same thing for _encode_image (#366) * do same thing for _encode_image * encoder * try this * adjust inference tests * fix syntax * True not None * dumb --- src/open_clip/coca_model.py | 27 +++++++++++---------------- tests/util_test.py | 6 ++++-- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index bc0b66545..4481f340d 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -121,30 +121,25 @@ def set_grad_checkpointing(self, enable=True): self.visual.set_grad_checkpointing(enable) self.text.set_grad_checkpointing(enable) - def _encode_image(self, images, normalize=True): + def encode_image(self, images, normalize=True): image_latent, tokens_embs = self.visual(images) image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent - return image_latent, tokens_embs + if getattr(self.visual, "output_tokens", False): + return image_latent, tokens_embs + return image_latent - def _encode_text(self, text, normalize=True): + def encode_text(self, text, normalize=True): text = text[:, :-1] # make space for CLS token - text_latent, token_emb = self.text.encoder(text) + text_latent, token_embs = self.text.encoder(text) text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent - return text_latent, token_emb + if getattr(self.text.encoder, "output_tokens", False): + return text_latent, token_embs + return text_latent - def encode_image(self, images, normalize=True): - image_latent, _ = self._encode_image(images, normalize=normalize) - return image_latent - - def encode_text(self, text, normalize=True): - text_latent, _ = self._encode_text(text, 
normalize=normalize) - return text_latent - - def forward(self, image, text, output_dict=False): - text_latent, token_embs = self._encode_text(text) - image_latent, image_embs = self._encode_image(image) + text_latent, token_embs = self.encode_text(text) + image_latent, image_embs = self.encode_image(image) # TODO: add assertion to avoid bugs? labels = text[:, -token_embs.shape[1]:] diff --git a/tests/util_test.py b/tests/util_test.py index e0a386a8b..fba404de3 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -23,7 +23,8 @@ def inference_text(model, model_name, batches): with torch.no_grad(): for x in batches: x = tokenizer(x) - y.append(model.encode_text(x)) + out = model.encode_text(x) + y.append(out[0] if isinstance(out, tuple) else out) return torch.stack(y) def inference_image(model, preprocess_val, batches): @@ -31,7 +32,8 @@ def inference_image(model, preprocess_val, batches): with torch.no_grad(): for x in batches: x = torch.stack([preprocess_val(img) for img in x]) - y.append(model.encode_image(x)) + out = model.encode_image(x) + y.append(out[0] if isinstance(out, tuple) else out) return torch.stack(y) def forward_model(model, model_name, preprocess_val, image_batch, text_batch): From 88aa6ce1749f8fe86ae50097535cf985b4c9b5f7 Mon Sep 17 00:00:00 2001 From: iejMac Date: Mon, 23 Jan 2023 03:07:10 +0000 Subject: [PATCH 15/30] CoCa/forward: remove unused output_dict param --- src/open_clip/coca_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 4481f340d..182786a2c 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -136,7 +136,7 @@ def encode_text(self, text, normalize=True): return text_latent, token_embs return text_latent - def forward(self, image, text, output_dict=False): + def forward(self, image, text): text_latent, token_embs = self.encode_text(text) image_latent, image_embs = self.encode_image(image) From 3b66f37926bc44f309d113283b9224ce75877f82 Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Tue, 24 Jan 2023 09:58:25 +0100 Subject: [PATCH 16/30] Revert "do same thing for _encode_image (#366)" This reverts commit de343fb73e9512c63bcbf3d902359c652580aef0. 
--- src/open_clip/coca_model.py | 29 +++++++++++++++++------------ tests/util_test.py | 6 ++---- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 182786a2c..bc0b66545 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -121,25 +121,30 @@ def set_grad_checkpointing(self, enable=True): self.visual.set_grad_checkpointing(enable) self.text.set_grad_checkpointing(enable) - def encode_image(self, images, normalize=True): + def _encode_image(self, images, normalize=True): image_latent, tokens_embs = self.visual(images) image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent - if getattr(self.visual, "output_tokens", False): - return image_latent, tokens_embs - return image_latent + return image_latent, tokens_embs - def encode_text(self, text, normalize=True): + def _encode_text(self, text, normalize=True): text = text[:, :-1] # make space for CLS token - text_latent, token_embs = self.text.encoder(text) + text_latent, token_emb = self.text.encoder(text) text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent - if getattr(self.text.encoder, "output_tokens", False): - return text_latent, token_embs - return text_latent + return text_latent, token_emb - def forward(self, image, text): + def encode_image(self, images, normalize=True): + image_latent, _ = self._encode_image(images, normalize=normalize) + return image_latent + + def encode_text(self, text, normalize=True): + text_latent, _ = self._encode_text(text, normalize=normalize) + return text_latent + + + def forward(self, image, text, output_dict=False): - text_latent, token_embs = self.encode_text(text) - image_latent, image_embs = self.encode_image(image) + text_latent, token_embs = self._encode_text(text) + image_latent, image_embs = self._encode_image(image) # TODO: add assertion to avoid bugs? 
labels = text[:, -token_embs.shape[1]:] diff --git a/tests/util_test.py b/tests/util_test.py index fba404de3..e0a386a8b 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -23,8 +23,7 @@ def inference_text(model, model_name, batches): with torch.no_grad(): for x in batches: x = tokenizer(x) - out = model.encode_text(x) - y.append(out[0] if isinstance(out, tuple) else out) + y.append(model.encode_text(x)) return torch.stack(y) def inference_image(model, preprocess_val, batches): @@ -32,8 +31,7 @@ def inference_image(model, preprocess_val, batches): with torch.no_grad(): for x in batches: x = torch.stack([preprocess_val(img) for img in x]) - out = model.encode_image(x) - y.append(out[0] if isinstance(out, tuple) else out) + y.append(model.encode_image(x)) return torch.stack(y) def forward_model(model, model_name, preprocess_val, image_batch, text_batch): From cdb91ddf3c421f7fa28f337ab094730ce52c0627 Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Tue, 24 Jan 2023 11:36:15 +0100 Subject: [PATCH 17/30] refactor --- src/open_clip/coca_model.py | 76 +++++++++++++++---------------------- 1 file changed, 31 insertions(+), 45 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index bc0b66545..51aeea831 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -26,41 +26,13 @@ class MultimodalCfg(CLIPTextCfg): latent_dim: int = 512 -class CoCaEncoderDecoder(nn.Module): - def __init__(self, encoder, decoder) -> None: - super().__init__() - self.encoder = encoder - self.decoder = decoder - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - self.encoder.set_grad_checkpointing(enable) - self.decoder.set_grad_checkpointing(enable) - - -def _build_encoder_decoder_tower( +def _build_text_decoder_tower( embed_dim, multimodal_cfg, - text_cfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, ): multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg - text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg - - encoder = _build_text_tower( - multimodal_cfg.latent_dim, - text_cfg=text_cfg, - quick_gelu=quick_gelu, - cast_dtype=cast_dtype - ) - - vocab_size = ( - encoder.config.vocab_size # for hf models - if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None - else multimodal_cfg.vocab_size - ) - act_layer = QuickGELU if quick_gelu else nn.GELU norm_layer = ( LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm @@ -77,7 +49,7 @@ def _build_encoder_decoder_tower( norm_layer=norm_layer, ) - return CoCaEncoderDecoder(encoder, decoder), multimodal_cfg, vocab_size + return decoder class CoCa(nn.Module): @@ -102,17 +74,29 @@ def __init__( else LayerNorm ) - self.text, multimodal_cfg, vocab_size = _build_encoder_decoder_tower( - embed_dim, multimodal_cfg, text_cfg, quick_gelu, cast_dtype + self.text = _build_text_tower( + multimodal_cfg.latent_dim, + text_cfg=text_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype + ) + + vocab_size = ( + text_cfg.vocab_size # for hf models + if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None + else text_cfg.vocab_size ) + self.visual = _build_vision_tower( multimodal_cfg.latent_dim, vision_cfg, quick_gelu, cast_dtype ) - - self.to_logits = nn.Sequential( - norm_layer(multimodal_cfg.width), nn.Linear(multimodal_cfg.width, vocab_size, bias=False) + + self.text_decoder = _build_text_decoder_tower( + embed_dim, multimodal_cfg, 
quick_gelu, cast_dtype ) + self.decoder_norm = norm_layer(multimodal_cfg.width) + self.decoder_logits = nn.Linear(multimodal_cfg.width, vocab_size, bias=False) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) self.pad_id = pad_id @@ -120,15 +104,16 @@ def __init__( def set_grad_checkpointing(self, enable=True): self.visual.set_grad_checkpointing(enable) self.text.set_grad_checkpointing(enable) + self.text_decoder.set_grad_checkpointing(enable) def _encode_image(self, images, normalize=True): image_latent, tokens_embs = self.visual(images) image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent return image_latent, tokens_embs - def _encode_text(self, text, normalize=True): - text = text[:, :-1] # make space for CLS token - text_latent, token_emb = self.text.encoder(text) + def _encode_text(self, text, normalize=True, embed_cls=False): + text = text[:, :-1] if embed_cls else text # make space for CLS token + text_latent, token_emb = self.text(text) text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent return text_latent, token_emb @@ -136,21 +121,22 @@ def encode_image(self, images, normalize=True): image_latent, _ = self._encode_image(images, normalize=normalize) return image_latent - def encode_text(self, text, normalize=True): - text_latent, _ = self._encode_text(text, normalize=normalize) + def encode_text(self, text, normalize=True, embed_cls=False): + text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls) return text_latent - def forward(self, image, text, output_dict=False): + def forward(self, image, text, embed_cls=True): - text_latent, token_embs = self._encode_text(text) + text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls) image_latent, image_embs = self._encode_image(image) # TODO: add assertion to avoid bugs? 
labels = text[:, -token_embs.shape[1]:] - token_embs = self.text.decoder(image_embs, token_embs) - logits = self.to_logits(token_embs) + token_embs = self.text_decoder(image_embs, token_embs) + token_embs = self.decoder_norm(token_embs) + logits = self.decoder_logits(token_embs) return { "image_features": image_latent, "text_features": text_latent, @@ -189,7 +175,7 @@ def generate( x = out[:, -max_seq_len:] # TODO: adjust for dict output - logits = self(image, x)[2][:, -1] + logits = self(image, x, embed_cls=False)["logits"][:, -1] if filter_logits_fn in {top_k, top_p}: filtered_logits = filter_logits_fn(logits, thres=filter_thres) From 58eb5bd214517ad153390ddfb91c0d4516fec337 Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Tue, 24 Jan 2023 18:52:38 +0100 Subject: [PATCH 18/30] white space --- src/open_clip/coca_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 51aeea831..47b2defa8 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -90,7 +90,7 @@ def __init__( self.visual = _build_vision_tower( multimodal_cfg.latent_dim, vision_cfg, quick_gelu, cast_dtype ) - + self.text_decoder = _build_text_decoder_tower( embed_dim, multimodal_cfg, quick_gelu, cast_dtype ) From cbd66ed6b702e13e0426d8a6a2ba9112971d3433 Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Tue, 24 Jan 2023 19:24:28 +0100 Subject: [PATCH 19/30] remove extra layer norm --- src/open_clip/coca_model.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 47b2defa8..7cda97b71 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -95,7 +95,6 @@ def __init__( embed_dim, multimodal_cfg, quick_gelu, cast_dtype ) - self.decoder_norm = norm_layer(multimodal_cfg.width) self.decoder_logits = nn.Linear(multimodal_cfg.width, vocab_size, bias=False) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) self.pad_id = pad_id @@ -111,8 +110,8 @@ def _encode_image(self, images, normalize=True): image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent return image_latent, tokens_embs - def _encode_text(self, text, normalize=True, embed_cls=False): - text = text[:, :-1] if embed_cls else text # make space for CLS token + def _encode_text(self, text, normalize=True): + text = text[:, :-1] # make space for CLS token text_latent, token_emb = self.text(text) text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent return text_latent, token_emb @@ -121,21 +120,20 @@ def encode_image(self, images, normalize=True): image_latent, _ = self._encode_image(images, normalize=normalize) return image_latent - def encode_text(self, text, normalize=True, embed_cls=False): - text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls) + def encode_text(self, text, normalize=True): + text_latent, _ = self._encode_text(text, normalize=normalize) return text_latent - def forward(self, image, text, embed_cls=True): + def forward(self, image, text): - text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls) + text_latent, token_embs = self._encode_text(text) image_latent, image_embs = self._encode_image(image) # TODO: add assertion to avoid bugs? 
labels = text[:, -token_embs.shape[1]:] token_embs = self.text_decoder(image_embs, token_embs) - token_embs = self.decoder_norm(token_embs) logits = self.decoder_logits(token_embs) return { "image_features": image_latent, From bf6ef3e8587750e824905eb89bd14fabd039f99a Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Tue, 24 Jan 2023 19:58:35 +0100 Subject: [PATCH 20/30] move to_logits into decoder --- src/open_clip/coca_model.py | 6 ++---- src/open_clip/transformer.py | 4 ++++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 7cda97b71..73ecdb383 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -92,10 +92,9 @@ def __init__( ) self.text_decoder = _build_text_decoder_tower( - embed_dim, multimodal_cfg, quick_gelu, cast_dtype + vocab_size, multimodal_cfg, quick_gelu, cast_dtype ) - self.decoder_logits = nn.Linear(multimodal_cfg.width, vocab_size, bias=False) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) self.pad_id = pad_id @@ -133,8 +132,7 @@ def forward(self, image, text): # TODO: add assertion to avoid bugs? labels = text[:, -token_embs.shape[1]:] - token_embs = self.text_decoder(image_embs, token_embs) - logits = self.decoder_logits(token_embs) + logits = self.text_decoder(image_embs, token_embs) return { "image_features": image_latent, "text_features": text_latent, diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 11e9a883c..840b3e039 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -637,6 +637,7 @@ def __init__( self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) self.ln_final = norm_layer(width) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) def init_parameters(self): proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) @@ -681,6 +682,9 @@ def forward(self, image_embs, text_embs): x = text_embs.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x) + if self.text_projection is not None: + x = x @ self.text_projection + return x @torch.jit.ignore From 03dfeab54ee1709cb4fc1daab644c22037a987a0 Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Tue, 24 Jan 2023 23:29:19 +0100 Subject: [PATCH 21/30] leave for later --- src/open_clip/coca_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 73ecdb383..2eeeb3dd0 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -171,7 +171,7 @@ def generate( x = out[:, -max_seq_len:] # TODO: adjust for dict output - logits = self(image, x, embed_cls=False)["logits"][:, -1] + logits = self(image, x)["logits"][:, -1] if filter_logits_fn in {top_k, top_p}: filtered_logits = filter_logits_fn(logits, thres=filter_thres) From 15d6223f7bc3731453859409c37aa5d2d5bf3602 Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Mon, 23 Jan 2023 10:52:47 +0100 Subject: [PATCH 22/30] better torchscript --- src/open_clip/model.py | 14 ++++++-------- src/open_clip/transformer.py | 15 ++++++--------- tests/test_inference.py | 25 ++----------------------- tests/util_test.py | 19 +++++++++++++++++++ 4 files changed, 33 insertions(+), 40 deletions(-) diff --git a/src/open_clip/model.py b/src/open_clip/model.py index dbe62f9d2..a2ddaf1be 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -172,6 +172,7 @@ def _build_text_tower( class CLIP(nn.Module): + output_dict: 
torch.jit.Final[bool] def __init__( self, embed_dim: int, @@ -182,9 +183,7 @@ def __init__( output_dict: bool = False ): super().__init__() - self.output_dict = None - if output_dict: - self.output_dict = output_dict + self.output_dict = output_dict self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) @@ -228,7 +227,7 @@ def encode_text(self, text, normalize: bool = False): def forward(self, image, text): image_features = self.encode_image(image, normalize=True) text_features = self.encode_text(text, normalize=True) - if self.output_dict is not None: + if self.output_dict: return { "image_features":image_features, "text_features":text_features, @@ -238,6 +237,7 @@ def forward(self, image, text): class CustomTextCLIP(nn.Module): + output_dict: torch.jit.Final[bool] def __init__( self, embed_dim: int, @@ -248,9 +248,7 @@ def __init__( output_dict: bool = False ): super().__init__() - self.output_dict = None - if output_dict: - self.output_dict = output_dict + self.output_dict = output_dict self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) @@ -278,7 +276,7 @@ def encode_text(self, text, normalize: bool = False): def forward(self, image, text): image_features = self.encode_image(image, normalize=True) text_features = self.encode_text(text, normalize=True) - if self.output_dict is not None: + if self.output_dict: return { "image_features":image_features, "text_features":text_features, diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 840b3e039..ce5cf03f7 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -320,6 +320,7 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): class VisionTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] def __init__( self, image_size: int, @@ -340,9 +341,7 @@ def __init__( output_tokens: bool = False ): super().__init__() - self.output_tokens = None - if output_tokens: - self.output_tokens = output_tokens + self.output_tokens = output_tokens self.image_size = to_2tuple(image_size) self.patch_size = to_2tuple(patch_size) self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1]) @@ -467,14 +466,14 @@ def forward(self, x: torch.Tensor): if self.proj is not None: pooled = pooled @ self.proj - if self.output_tokens is not None: + if self.output_tokens: return pooled, tokens return pooled class TextTransformer(nn.Module): - + output_tokens: torch.jit.Final[bool] def __init__( self, context_length: int = 77, @@ -491,9 +490,7 @@ def __init__( output_tokens: bool = False ): super().__init__() - self.output_tokens = None - if output_tokens: - self.output_tokens = output_tokens + self.output_tokens = output_tokens self.context_length = context_length self.vocab_size = vocab_size self.width = width @@ -597,7 +594,7 @@ def forward(self, text): if self.text_projection is not None: pooled = pooled @ self.text_projection - if self.output_tokens is not None: + if self.output_tokens: return pooled, tokens return pooled diff --git a/tests/test_inference.py b/tests/test_inference.py index 1f074e4b2..9be3378c8 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -49,27 +49,6 @@ models_to_test_fully = models_to_test + models_to_jit_test -class 
TestWrapper(torch.nn.Module): - def __init__(self, model, model_name, output_dict=True) -> None: - super().__init__() - self.model = model - self.output_dict = None - if output_dict: - self.output_dict = output_dict - if type(model) in [open_clip.CLIP, open_clip.CustomTextCLIP]: - self.model.output_dict = self.output_dict - config = open_clip.get_model_config(model_name) - self.head = torch.nn.Linear(config["embed_dim"], 2) - - def forward(self, image, text): - output_dict = self.output_dict - x = self.model(image, text) - if output_dict is not None: - out = self.head(x["image_features"]) - else: - out = self.head(x[0]) - return {"test_output": out} - @pytest.mark.regression_test @pytest.mark.parametrize("model_name,jit", models_to_test_fully) def test_inference_with_data( @@ -139,12 +118,12 @@ def test_inference_with_data( pretrained_hf = pretrained_hf ) - test_model = TestWrapper(model, model_name, output_dict=False) + test_model = util_test.TestWrapper(model, model_name, output_dict=False) test_model = torch.jit.script(test_model) model_out = util_test.forward_model(test_model, model_name, preprocess_val, input_image, input_text) assert model_out["test_output"].shape[-1] == 2 - test_model = TestWrapper(model, model_name, output_dict=True) + test_model = util_test.TestWrapper(model, model_name, output_dict=True) test_model = torch.jit.script(test_model) model_out = util_test.forward_model(test_model, model_name, preprocess_val, input_image, input_text) assert model_out["test_output"].shape[-1] == 2 diff --git a/tests/util_test.py b/tests/util_test.py index e0a386a8b..d09b09a2d 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -196,6 +196,25 @@ def create_test_data( def _sytem_assert(string): assert os.system(string) == 0 +class TestWrapper(torch.nn.Module): + output_dict: torch.jit.Final[bool] + def __init__(self, model, model_name, output_dict=True) -> None: + super().__init__() + self.model = model + self.output_dict = output_dict + if type(model) in [open_clip.CLIP, open_clip.CustomTextCLIP]: + self.model.output_dict = self.output_dict + config = open_clip.get_model_config(model_name) + self.head = torch.nn.Linear(config["embed_dim"], 2) + + def forward(self, image, text): + x = self.model(image, text) + if self.output_dict: + out = self.head(x["image_features"]) + else: + out = self.head(x[0]) + return {"test_output": out} + def main(args): global open_clip import importlib From 9beb0d46f274cfdeae6f68719c809586d62f6667 Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Mon, 23 Jan 2023 11:15:04 +0100 Subject: [PATCH 23/30] annotate hf too --- src/open_clip/hf_model.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/open_clip/hf_model.py b/src/open_clip/hf_model.py index e7df076e7..47a982209 100644 --- a/src/open_clip/hf_model.py +++ b/src/open_clip/hf_model.py @@ -81,7 +81,7 @@ def forward(self, x: BaseModelOutput, attention_mask: TensorType): class HFTextEncoder(nn.Module): """HuggingFace model adapter""" - + output_tokens: torch.jit.Final[bool] def __init__( self, model_name_or_path: str, @@ -93,9 +93,7 @@ def __init__( output_tokens: bool = False ): super().__init__() - self.output_tokens = None - if output_tokens: - self.output_tokens = output_tokens + self.output_tokens = output_tokens self.output_dim = output_dim # TODO: find better way to get this information @@ -147,7 +145,7 @@ def forward(self, x: TensorType): else out.last_hidden_state ) - if self.output_tokens is not None: + if self.output_tokens: return projected, tokens return 
projected From fde2aee191578d522e6e9e4783c6f722eac07e07 Mon Sep 17 00:00:00 2001 From: Maciej Kilian Date: Thu, 26 Jan 2023 16:57:29 -0800 Subject: [PATCH 24/30] Add CoCa-ViT-L/14 config (#379) --- .../model_configs/coca_ViT-L-14.json | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 src/open_clip/model_configs/coca_ViT-L-14.json diff --git a/src/open_clip/model_configs/coca_ViT-L-14.json b/src/open_clip/model_configs/coca_ViT-L-14.json new file mode 100644 index 000000000..40e3c7343 --- /dev/null +++ b/src/open_clip/model_configs/coca_ViT-L-14.json @@ -0,0 +1,31 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "latent_dim": 768, + "attn_pooler_heads": 12 + }, + "custom_text": true +} From f7c566bf8c28c289ff2e29d9f21131e0b997c9e9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 27 Jan 2023 17:23:30 -0800 Subject: [PATCH 25/30] Remove dead LN code, refactor attn_pool conditional for more clarity, minor formatting tweaks --- src/open_clip/coca_model.py | 22 +++++++++----------- src/open_clip/model.py | 14 +++++++------ src/open_clip/transformer.py | 40 +++++++++++++++++++++++------------- 3 files changed, 44 insertions(+), 32 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 2eeeb3dd0..79a3fce04 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -68,17 +68,11 @@ def __init__( text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg - norm_layer = ( - LayerNormFp32 - if cast_dtype in (torch.float16, torch.bfloat16) - else LayerNorm - ) - self.text = _build_text_tower( multimodal_cfg.latent_dim, text_cfg=text_cfg, quick_gelu=quick_gelu, - cast_dtype=cast_dtype + cast_dtype=cast_dtype, ) vocab_size = ( @@ -88,11 +82,17 @@ def __init__( ) self.visual = _build_vision_tower( - multimodal_cfg.latent_dim, vision_cfg, quick_gelu, cast_dtype + multimodal_cfg.latent_dim, + vision_cfg=vision_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, ) self.text_decoder = _build_text_decoder_tower( - vocab_size, multimodal_cfg, quick_gelu, cast_dtype + vocab_size, + multimodal_cfg=multimodal_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, ) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) @@ -110,7 +110,7 @@ def _encode_image(self, images, normalize=True): return image_latent, tokens_embs def _encode_text(self, text, normalize=True): - text = text[:, :-1] # make space for CLS token + text = text[:, :-1] # make space for CLS token text_latent, token_emb = self.text(text) text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent return text_latent, token_emb @@ -123,9 +123,7 @@ def encode_text(self, text, normalize=True): text_latent, _ = self._encode_text(text, normalize=normalize) return text_latent - def forward(self, image, text): - text_latent, token_embs = self._encode_text(text) image_latent, image_embs = self._encode_image(image) diff --git a/src/open_clip/model.py b/src/open_clip/model.py index a2ddaf1be..a0f4b8501 100644 --- 
a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -95,7 +95,7 @@ def _build_vision_tower( drop=vision_cfg.timm_drop, drop_path=vision_cfg.timm_drop_path, embed_dim=embed_dim, - image_size=vision_cfg.image_size + image_size=vision_cfg.image_size, ) act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models elif isinstance(vision_cfg.layers, (tuple, list)): @@ -105,7 +105,7 @@ def _build_vision_tower( output_dim=embed_dim, heads=vision_heads, image_size=vision_cfg.image_size, - width=vision_cfg.width + width=vision_cfg.width, ) else: vision_heads = vision_cfg.width // vision_cfg.head_width @@ -148,7 +148,7 @@ def _build_text_tower( proj=text_cfg.proj, pooler_type=text_cfg.pooler_type, pretrained=text_cfg.hf_model_pretrained, - output_tokens=text_cfg.output_tokens + output_tokens=text_cfg.output_tokens, ) else: act_layer = QuickGELU if quick_gelu else nn.GELU @@ -173,6 +173,7 @@ def _build_text_tower( class CLIP(nn.Module): output_dict: torch.jit.Final[bool] + def __init__( self, embed_dim: int, @@ -180,7 +181,7 @@ def __init__( text_cfg: CLIPTextCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, - output_dict: bool = False + output_dict: bool = False, ): super().__init__() self.output_dict = output_dict @@ -238,6 +239,7 @@ def forward(self, image, text): class CustomTextCLIP(nn.Module): output_dict: torch.jit.Final[bool] + def __init__( self, embed_dim: int, @@ -245,7 +247,7 @@ def __init__( text_cfg: CLIPTextCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, - output_dict: bool = False + output_dict: bool = False, ): super().__init__() self.output_dict = output_dict @@ -373,7 +375,7 @@ def build_model_from_openai_state_dict( vocab_size=vocab_size, width=transformer_width, heads=transformer_heads, - layers=transformer_layers + layers=transformer_layers, ) model = CLIP( embed_dim, diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index ce5cf03f7..977325888 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -1,6 +1,6 @@ from collections import OrderedDict import math -from typing import Callable, Optional, Sequence +from typing import Callable, Optional, Sequence, Tuple import torch from torch import nn @@ -219,7 +219,7 @@ def attention( q_x: torch.Tensor, k_x: Optional[torch.Tensor] = None, v_x: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None + attn_mask: Optional[torch.Tensor] = None, ): k_x = k_x if k_x is not None else q_x v_x = v_x if v_x is not None else q_x @@ -229,12 +229,13 @@ def attention( q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask )[0] - def forward(self, + def forward( + self, q_x: torch.Tensor, k_x: Optional[torch.Tensor] = None, v_x: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None - ): + attn_mask: Optional[torch.Tensor] = None, + ): k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None @@ -321,6 +322,7 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): class VisionTransformer(nn.Module): output_tokens: torch.jit.Final[bool] + def __init__( self, image_size: int, @@ -372,6 +374,7 @@ def __init__( self.ln_post = norm_layer(output_dim) self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim)) else: + self.attn_pool = None self.ln_post = norm_layer(width) self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) @@ -434,6 +437,12 @@ def 
init_parameters(self): def set_grad_checkpointing(self, enable=True): self.transformer.grad_checkpointing = enable + def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.global_average_pool: + return x.mean(dim=1), x + else: + return x[:, 0], x[:, 1:] + def forward(self, x: torch.Tensor): x = self.conv1(x) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] @@ -451,16 +460,12 @@ def forward(self, x: torch.Tensor): x = self.transformer(x) x = x.permute(1, 0, 2) # LND -> NLD - if hasattr(self, "attn_pool"): + if self.attn_pool is not None: x = self.attn_pool(x) x = self.ln_post(x) - - if self.global_average_pool: - pooled, tokens = x.mean(dim=1), x + pooled, tokens = self._global_pool(x) else: - pooled, tokens = x[:, 0], x[:, 1:] - - if not hasattr(self, "attn_pool"): + pooled, tokens = self._global_pool(x) pooled = self.ln_post(pooled) if self.proj is not None: @@ -474,6 +479,7 @@ def forward(self, x: torch.Tensor): class TextTransformer(nn.Module): output_tokens: torch.jit.Final[bool] + def __init__( self, context_length: int = 77, @@ -626,8 +632,14 @@ def __init__( self.context_length = context_length self.cross_attn = nn.ModuleList([ ResidualAttentionBlock( - width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, - is_cross_attention=True) + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + is_cross_attention=True, + ) for _ in range(layers) ]) From 953357505f9159fd39d00b406927ab8685d11dbd Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Sat, 28 Jan 2023 12:55:17 +0100 Subject: [PATCH 26/30] latent_dim to embed_dim --- src/open_clip/coca_model.py | 13 ++++++------- .../model_configs/ViT-B-32_output_dict.json | 17 +++++++++++++++++ src/open_clip/model_configs/coca_ViT-B-32.json | 1 - src/open_clip/model_configs/coca_ViT-L-14.json | 1 - src/open_clip/model_configs/coca_base.json | 1 - .../model_configs/coca_roberta-ViT-B-32.json | 3 +-- 6 files changed, 24 insertions(+), 12 deletions(-) create mode 100644 src/open_clip/model_configs/ViT-B-32_output_dict.json diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 79a3fce04..bd6a0ba00 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -23,7 +23,6 @@ class MultimodalCfg(CLIPTextCfg): heads: int = 8 n_queries: int = 256 attn_pooler_heads: int = 8 - latent_dim: int = 512 def _build_text_decoder_tower( @@ -69,12 +68,12 @@ def __init__( vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg self.text = _build_text_tower( - multimodal_cfg.latent_dim, + embed_dim=embed_dim, text_cfg=text_cfg, quick_gelu=quick_gelu, cast_dtype=cast_dtype, ) - + vocab_size = ( text_cfg.vocab_size # for hf models if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None @@ -82,7 +81,7 @@ def __init__( ) self.visual = _build_vision_tower( - multimodal_cfg.latent_dim, + embed_dim=embed_dim, vision_cfg=vision_cfg, quick_gelu=quick_gelu, cast_dtype=cast_dtype, @@ -114,14 +113,14 @@ def _encode_text(self, text, normalize=True): text_latent, token_emb = self.text(text) text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent return text_latent, token_emb - + def encode_image(self, images, normalize=True): image_latent, _ = self._encode_image(images, normalize=normalize) return image_latent - + def encode_text(self, text, normalize=True): text_latent, _ = 
self._encode_text(text, normalize=normalize) - return text_latent + return text_latent def forward(self, image, text): text_latent, token_embs = self._encode_text(text) diff --git a/src/open_clip/model_configs/ViT-B-32_output_dict.json b/src/open_clip/model_configs/ViT-B-32_output_dict.json new file mode 100644 index 000000000..77da264d7 --- /dev/null +++ b/src/open_clip/model_configs/ViT-B-32_output_dict.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + }, + "output_dict": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/coca_ViT-B-32.json b/src/open_clip/model_configs/coca_ViT-B-32.json index bcf207e9d..7e7eb520a 100644 --- a/src/open_clip/model_configs/coca_ViT-B-32.json +++ b/src/open_clip/model_configs/coca_ViT-B-32.json @@ -24,7 +24,6 @@ "width": 512, "heads": 8, "layers": 12, - "latent_dim": 512, "attn_pooler_heads": 8 }, "custom_text": true diff --git a/src/open_clip/model_configs/coca_ViT-L-14.json b/src/open_clip/model_configs/coca_ViT-L-14.json index 40e3c7343..3d5ca4ca2 100644 --- a/src/open_clip/model_configs/coca_ViT-L-14.json +++ b/src/open_clip/model_configs/coca_ViT-L-14.json @@ -24,7 +24,6 @@ "width": 768, "heads": 12, "layers": 12, - "latent_dim": 768, "attn_pooler_heads": 12 }, "custom_text": true diff --git a/src/open_clip/model_configs/coca_base.json b/src/open_clip/model_configs/coca_base.json index 30e04456b..cf8c6cecb 100644 --- a/src/open_clip/model_configs/coca_base.json +++ b/src/open_clip/model_configs/coca_base.json @@ -9,7 +9,6 @@ "dim_head": 64, "heads": 12, "n_queries": 256, - "latent_dim": 512, "attn_pooler_heads": 8 }, "vision_cfg": { diff --git a/src/open_clip/model_configs/coca_roberta-ViT-B-32.json b/src/open_clip/model_configs/coca_roberta-ViT-B-32.json index a4e19ab5a..fb46354b9 100644 --- a/src/open_clip/model_configs/coca_roberta-ViT-B-32.json +++ b/src/open_clip/model_configs/coca_roberta-ViT-B-32.json @@ -18,8 +18,7 @@ "context_length": 76, "width": 768, "heads": 8, - "layers": 12, - "latent_dim": 512 + "layers": 12 }, "custom_text": true } From f5e0c5a7dcb9c89e4f04ffa19c26dbb83df8e4e0 Mon Sep 17 00:00:00 2001 From: Giovanni Puccetti Date: Sat, 28 Jan 2023 13:05:54 +0100 Subject: [PATCH 27/30] remove extra cfg --- .../model_configs/ViT-B-32_output_dict.json | 17 ----------------- 1 file changed, 17 deletions(-) delete mode 100644 src/open_clip/model_configs/ViT-B-32_output_dict.json diff --git a/src/open_clip/model_configs/ViT-B-32_output_dict.json b/src/open_clip/model_configs/ViT-B-32_output_dict.json deleted file mode 100644 index 77da264d7..000000000 --- a/src/open_clip/model_configs/ViT-B-32_output_dict.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "embed_dim": 512, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 768, - "patch_size": 32 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - }, - "output_dict": true -} \ No newline at end of file From 1ba2ab6828bb7a304615e76f2855c6e60cdafcd7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 28 Jan 2023 14:58:30 -0800 Subject: [PATCH 28/30] A bit more cleanup, keep context_length as context len, 'num_pos' to incl extra tokens. 
None type check for embed_cls instead of getattr --- src/open_clip/factory.py | 6 ++++-- src/open_clip/hf_model.py | 6 ++++-- src/open_clip/tokenizer.py | 15 ++++++++++---- src/open_clip/transformer.py | 38 ++++++++++++++++++------------------ 4 files changed, 38 insertions(+), 27 deletions(-) diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 9c6240b43..d8476f29c 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -231,14 +231,16 @@ def create_loss(args): cache_labels=True, rank=args.rank, world_size=args.world_size, - use_horovod=args.horovod) + use_horovod=args.horovod, + ) return ClipLoss( local_loss=args.local_loss, gather_with_grad=args.gather_with_grad, cache_labels=True, rank=args.rank, world_size=args.world_size, - use_horovod=args.horovod) + use_horovod=args.horovod, + ) def create_model_and_transforms( diff --git a/src/open_clip/hf_model.py b/src/open_clip/hf_model.py index 47a982209..fbccc8127 100644 --- a/src/open_clip/hf_model.py +++ b/src/open_clip/hf_model.py @@ -79,9 +79,11 @@ def forward(self, x: BaseModelOutput, attention_mask: TensorType): return x.last_hidden_state[:, self.cls_token_position, :] + class HFTextEncoder(nn.Module): """HuggingFace model adapter""" output_tokens: torch.jit.Final[bool] + def __init__( self, model_name_or_path: str, @@ -90,8 +92,8 @@ def __init__( pooler_type: str = None, proj: str = None, pretrained: bool = True, - output_tokens: bool = False - ): + output_tokens: bool = False, + ): super().__init__() self.output_tokens = output_tokens self.output_dim = output_dim diff --git a/src/open_clip/tokenizer.py b/src/open_clip/tokenizer.py index 01e9f9d25..109d2aef7 100644 --- a/src/open_clip/tokenizer.py +++ b/src/open_clip/tokenizer.py @@ -186,16 +186,23 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.Lo class HFTokenizer: - "HuggingFace tokenizer wrapper" - def __init__(self, tokenizer_name:str): + """HuggingFace tokenizer wrapper""" + + def __init__(self, tokenizer_name: str): from transformers import AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor: + def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: # same cleaning as for default tokenizer, except lowercasing # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance if isinstance(texts, str): texts = [texts] texts = [whitespace_clean(basic_clean(text)) for text in texts] - input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids + input_ids = self.tokenizer( + texts, + return_tensors='pt', + max_length=context_length, + padding='max_length', + truncation=True, + ).input_ids return input_ids diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 977325888..65085642a 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -493,25 +493,28 @@ def __init__( norm_layer: Callable = LayerNorm, embed_cls: bool = False, pad_id: int = 0, - output_tokens: bool = False + output_tokens: bool = False, ): super().__init__() self.output_tokens = output_tokens - self.context_length = context_length + self.num_pos = self.context_length = context_length self.vocab_size = vocab_size self.width = width self.output_dim = output_dim + self.embed_cls = embed_cls + self.heads = heads + self.pad_id = pad_id self.text_projection = 
nn.Parameter(torch.empty(width, output_dim)) - if embed_cls: - self.embed_cls = embed_cls + + if self.embed_cls: self.cls_emb = nn.Parameter(torch.empty(width)) - self.heads = heads - self.pad_id = pad_id - self.context_length += 1 + self.num_pos += 1 + else: + self.cls_emb = None self.token_embedding = nn.Embedding(vocab_size, width) - self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width)) + self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) self.transformer = Transformer( width=width, layers=layers, @@ -549,9 +552,9 @@ def set_grad_checkpointing(self, enable=True): self.transformer.grad_checkpointing = enable def build_attention_mask(self): - # lazily create causal attention mask, with full attention between the vision tokens + # lazily create causal attention mask, with full attention between the tokens # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(self.context_length, self.context_length) + mask = torch.empty(self.num_pos, self.num_pos) mask.fill_(float("-inf")) mask.triu_(1) # zero out the lower diagonal return mask @@ -574,7 +577,7 @@ def forward(self, text): x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] attn_mask = self.attn_mask - if hasattr(self, "embed_cls") and self.embed_cls: + if self.embed_cls is not None: seq_len += 1 x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1) cls_mask = self.build_cls_mask(text, cast_dtype) @@ -587,22 +590,19 @@ def forward(self, text): # x.shape = [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) - - if hasattr(self, "embed_cls") and self.embed_cls: - pooled = x[:, -1] - tokens = x[:, :-1] + if self.embed_cls is not None: + pooled, tokens = x[:, -1], x[:, :-1] pooled = self.ln_final(pooled) else: x = self.ln_final(x) - pooled = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] - tokens = x + pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x if self.text_projection is not None: pooled = pooled @ self.text_projection if self.output_tokens: return pooled, tokens - + return pooled @@ -667,7 +667,7 @@ def init_parameters(self): nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) def build_attention_mask(self): - # lazily create causal attention mask, with full attention between the vision tokens + # lazily create causal attention mask, with full attention between the tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.context_length, self.context_length) mask.fill_(float("-inf")) From f0847fa270c5ada3ca132cb1c7c846f5080afd56 Mon Sep 17 00:00:00 2001 From: Maciej Kilian Date: Sat, 28 Jan 2023 16:38:06 -0800 Subject: [PATCH 29/30] CoCa: add B/32 pretrained (#389) * add B/32 pretrained * fix * no capital * slash --- README.md | 3 ++- src/open_clip/pretrained.py | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9e1b8df75..49178c754 100644 --- a/README.md +++ b/README.md @@ -499,7 +499,8 @@ Future trained models will use nn.GELU. 
('ViT-bigG-14', 'laion2b_s39b_b160k'), ('roberta-ViT-B-32', 'laion2b_s12b_b32k'), ('xlm-roberta-base-ViT-B-32', 'laion5b_s13b_b90k'), - ('xlm-roberta-large-ViT-H-14', 'frozen_laion5b_s13b_b90k'),] + ('xlm-roberta-large-ViT-H-14', 'frozen_laion5b_s13b_b90k'), + ('coca_ViT-B-32', 'laion2B-s13B-b90k'),] >>> model, train_transform, eval_transform = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k') ``` diff --git a/src/open_clip/pretrained.py b/src/open_clip/pretrained.py index 73643f95d..7b2359e56 100644 --- a/src/open_clip/pretrained.py +++ b/src/open_clip/pretrained.py @@ -174,6 +174,10 @@ def _pcfg(url='', hf_hub='', mean=None, std=None): laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), ) +_coca_VITB32 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/') +) + _PRETRAINED = { "RN50": _RN50, @@ -198,6 +202,7 @@ def _pcfg(url='', hf_hub='', mean=None, std=None): "convnext_base": _convnext_base, "convnext_base_w": _convnext_base_w, "convnext_base_w_320": _convnext_base_w_320, + "coca_ViT-B-32": _coca_VITB32, } From ba081d327f25805b8d354192733db1d68859b221 Mon Sep 17 00:00:00 2001 From: Romain Beaumont Date: Sun, 29 Jan 2023 01:41:23 +0100 Subject: [PATCH 30/30] remove coca from ci.yml --- .github/workflows/ci.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0bd877181..c7314f628 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,7 +4,6 @@ on: push: branches: - main - - coca paths-ignore: - '**.md' - 'CITATION.cff' @@ -14,7 +13,6 @@ on: pull_request: branches: - main - - coca paths-ignore: - '**.md' - 'CITATION.cff'
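Taken together, these patches leave CoCa wired into the same factory and pretrained machinery as the other model types. As a closing reference, here is a minimal inference sketch against the ViT-B/32 weights added above; the gray placeholder image and candidate captions are stand-ins for real data, and the softmax similarity is just the usual zero-shot scoring:

```python
import torch
from PIL import Image
import open_clip

# Load the CoCa ViT-B/32 checkpoint referenced in the pretrained table above.
model, _, preprocess = open_clip.create_model_and_transforms(
    "coca_ViT-B-32", pretrained="laion2B-s13B-b90k"
)
tokenizer = open_clip.get_tokenizer("coca_ViT-B-32")
model.eval()

# Stand-in inputs; `preprocess` resizes and normalizes the image for the model.
image = preprocess(Image.new("RGB", (256, 256), color="gray")).unsqueeze(0)
texts = tokenizer(["a diagram", "a dog", "a cat"])

with torch.no_grad():
    # encode_image / encode_text return L2-normalized latents (normalize=True by default).
    image_features = model.encode_image(image)
    text_features = model.encode_text(texts)
    similarity = (model.logit_scale.exp() * image_features @ text_features.T).softmax(dim=-1)

print(similarity)  # shape [1, 3]: relative match of each caption to the image
```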