diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b5698eca0..c7314f628 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,7 +3,7 @@ name: Continuous integration on: push: branches: - - main + - main paths-ignore: - '**.md' - 'CITATION.cff' @@ -12,7 +12,7 @@ on: - 'docs/**' pull_request: branches: - - main + - main paths-ignore: - '**.md' - 'CITATION.cff' @@ -81,7 +81,7 @@ jobs: --group ${{ matrix.job }} \ -m regression_test \ tests \ - | head -n -2 | grep -Po 'test_inference_with_data\[\K[^]]*(?=])' \ + | head -n -2 | grep -Po 'test_inference_with_data\[\K[^]]*(?=-False]|-True])' \ > models_gh_runner.txt if [ -n "${{ inputs.manual_revision_reference }}" ]; then REVISION_REFERENCE=${{ inputs.manual_revision_reference }} diff --git a/README.md b/README.md index c69d288d5..49178c754 100644 --- a/README.md +++ b/README.md @@ -258,6 +258,20 @@ python -m training.main \ --resume /path/to/checkpoints/epoch_K.pt ``` +### Training CoCa: +Training [CoCa](https://arxiv.org/abs/2205.01917) models is enabled through specifying a CoCa config using the ```--model``` parameter of the training script. Currently available configs are "coca_base", "coca_ViT-B-32", and "coca_roberta-ViT-B-32" (which uses RoBERTa as the text encoder). CoCa configs are different from CLIP configs because they have an additional "multimodal_cfg" component which specifies parameters for the multimodal text decoder. Here's an example from the coca_ViT-B-32 config: +```json +"multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "latent_dim": 512, + "attn_pooler_heads": 8 +} +``` + ### Training with pre-trained language models as text encoder: If you wish to use different language models as the text encoder for CLIP you can do so by using one of the Hugging Face model configs in ```src/open_clip/model_configs``` and passing in it's tokenizer as the ```--model``` and ```--hf-tokenizer-name``` parameters respectively. Currently we only support RoBERTa ("test-roberta" config), however adding new models should be trivial. You can also determine how many layers, from the end, to leave unfrozen with the ```--lock-text-unlocked-layers``` parameter. Here's an example command to train CLIP with the RoBERTa LM that has it's last 10 layers unfrozen: @@ -485,7 +499,8 @@ Future trained models will use nn.GELU. 
('ViT-bigG-14', 'laion2b_s39b_b160k'), ('roberta-ViT-B-32', 'laion2b_s12b_b32k'), ('xlm-roberta-base-ViT-B-32', 'laion5b_s13b_b90k'), - ('xlm-roberta-large-ViT-H-14', 'frozen_laion5b_s13b_b90k'),] + ('xlm-roberta-large-ViT-H-14', 'frozen_laion5b_s13b_b90k'), + ('coca_ViT-B-32', 'laion2B-s13B-b90k'),] >>> model, train_transform, eval_transform = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k') ``` diff --git a/src/open_clip/__init__.py b/src/open_clip/__init__.py index 3cf72e928..c513c8aac 100644 --- a/src/open_clip/__init__.py +++ b/src/open_clip/__init__.py @@ -1,9 +1,10 @@ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD -from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer +from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss from .factory import list_models, add_model_config, get_model_config, load_checkpoint -from .loss import ClipLoss +from .loss import ClipLoss, CoCaLoss from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype +from .coca_model import CoCa from .openai import load_openai_model, list_openai_models from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py new file mode 100644 index 000000000..bd6a0ba00 --- /dev/null +++ b/src/open_clip/coca_model.py @@ -0,0 +1,193 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F +import numpy as np +from dataclasses import dataclass + +from .transformer import ( + LayerNormFp32, + LayerNorm, + QuickGELU, + MultimodalTransformer, +) +from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower +from .generation_utils import top_a, top_k, top_p + + +@dataclass +class MultimodalCfg(CLIPTextCfg): + mlp_ratio: int = 4 + dim_head: int = 64 + heads: int = 8 + n_queries: int = 256 + attn_pooler_heads: int = 8 + + +def _build_text_decoder_tower( + embed_dim, + multimodal_cfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = ( + LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + ) + + decoder = MultimodalTransformer( + context_length=multimodal_cfg.context_length, + width=multimodal_cfg.width, + heads=multimodal_cfg.heads, + layers=multimodal_cfg.layers, + ls_init_value=multimodal_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return decoder + + +class CoCa(nn.Module): + def __init__( + self, + embed_dim, + multimodal_cfg: MultimodalCfg, + text_cfg: CLIPTextCfg, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + pad_id: int = 0, + ): + super().__init__() + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg + vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg + + self.text = 
_build_text_tower( + embed_dim=embed_dim, + text_cfg=text_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + vocab_size = ( + text_cfg.vocab_size # for hf models + if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None + else text_cfg.vocab_size + ) + + self.visual = _build_vision_tower( + embed_dim=embed_dim, + vision_cfg=vision_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.text_decoder = _build_text_decoder_tower( + vocab_size, + multimodal_cfg=multimodal_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.pad_id = pad_id + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + self.text_decoder.set_grad_checkpointing(enable) + + def _encode_image(self, images, normalize=True): + image_latent, tokens_embs = self.visual(images) + image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent + return image_latent, tokens_embs + + def _encode_text(self, text, normalize=True): + text = text[:, :-1] # make space for CLS token + text_latent, token_emb = self.text(text) + text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent + return text_latent, token_emb + + def encode_image(self, images, normalize=True): + image_latent, _ = self._encode_image(images, normalize=normalize) + return image_latent + + def encode_text(self, text, normalize=True): + text_latent, _ = self._encode_text(text, normalize=normalize) + return text_latent + + def forward(self, image, text): + text_latent, token_embs = self._encode_text(text) + image_latent, image_embs = self._encode_image(image) + + # TODO: add assertion to avoid bugs? + labels = text[:, -token_embs.shape[1]:] + + logits = self.text_decoder(image_embs, token_embs) + return { + "image_features": image_latent, + "text_features": text_latent, + "logits": logits, + "labels": labels, + "logit_scale": self.logit_scale.exp() + } + + def generate( + self, + image, + text, + seq_len, + max_seq_len=77, + mask_prob=0.0, + temperature=1., + filter_logits_fn=top_k, + filter_thres=0.9, + min_p_pow=2.0, + min_p_ratio=0.02, + ): + + assert mask_prob < 1, "mask_prob must be smaller than 1." 
+ + was_training = self.training + num_dims = len(text.shape) + + if num_dims == 1: + text = text[None, :] + + _, t = text.shape + self.eval() + out = text + + for _ in range(seq_len): + x = out[:, -max_seq_len:] + + # TODO: adjust for dict output + logits = self(image, x)["logits"][:, -1] + + if filter_logits_fn in {top_k, top_p}: + filtered_logits = filter_logits_fn(logits, thres=filter_thres) + probs = F.softmax(filtered_logits / temperature, dim=-1) + + elif filter_logits_fn is top_a: + filtered_logits = filter_logits_fn( + logits, min_p_pow=min_p_pow, min_p_ratio=min_p_ratio + ) + probs = F.softmax(filtered_logits / temperature, dim=-1) + + sample = torch.multinomial(probs, 1) + + out = torch.cat((out, sample), dim=-1) + + out = out[:, t:] + + if num_dims == 1: + out = out.squeeze(0) + + self.train(was_training) + return out diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 6ac41f877..d8476f29c 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -12,6 +12,8 @@ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ resize_pos_embed, get_cast_dtype +from .coca_model import CoCa +from .loss import ClipLoss, CoCaLoss from .openai import load_openai_model from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf from .transform import image_transform, AugmentationCfg @@ -177,7 +179,10 @@ def create_model( if custom_text: if is_hf_model: model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf - model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + if "coca" in model_name: + model = CoCa(**model_cfg, cast_dtype=cast_dtype) + else: + model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) else: model = CLIP(**model_cfg, cast_dtype=cast_dtype) @@ -216,6 +221,28 @@ def create_model( return model +def create_loss(args): + if "coca" in args.model.lower(): + return CoCaLoss( + caption_loss_weight=args.coca_caption_loss_weight, + clip_loss_weight=args.coca_contrastive_loss_weight, + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + return ClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + + def create_model_and_transforms( model_name: str, pretrained: Optional[str] = None, diff --git a/src/open_clip/generation_utils.py b/src/open_clip/generation_utils.py new file mode 100644 index 000000000..fade1f0ae --- /dev/null +++ b/src/open_clip/generation_utils.py @@ -0,0 +1,37 @@ +from math import ceil +import torch +from torch import nn +import torch.nn.functional as F + + +def exists(val): + return val is not None + + +def top_p(logits, thres=0.9): + # nucleus + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + sorted_indices_to_remove = cum_probs > (1 - thres) + sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() + sorted_indices_to_remove[:, 0] = 0 + + sorted_logits[sorted_indices_to_remove] = float('-inf') + return sorted_logits.scatter(1, sorted_indices, sorted_logits) + + +def top_k(logits, thres=0.9): + k = ceil((1 - thres) * logits.shape[-1]) + val, ind = torch.topk(logits, k) + probs = 
torch.full_like(logits, float('-inf')) + probs.scatter_(1, ind, val) + return probs + + +def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02): + probs = F.softmax(logits, dim=-1) + limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio + logits[probs < limit] = float('-inf') + logits[probs >= limit] = 1 + return logits diff --git a/src/open_clip/hf_model.py b/src/open_clip/hf_model.py index b9f1103d6..fbccc8127 100644 --- a/src/open_clip/hf_model.py +++ b/src/open_clip/hf_model.py @@ -82,6 +82,7 @@ def forward(self, x: BaseModelOutput, attention_mask: TensorType): class HFTextEncoder(nn.Module): """HuggingFace model adapter""" + output_tokens: torch.jit.Final[bool] def __init__( self, @@ -90,9 +91,11 @@ def __init__( config: PretrainedConfig = None, pooler_type: str = None, proj: str = None, - pretrained: bool = True): + pretrained: bool = True, + output_tokens: bool = False, + ): super().__init__() - + self.output_tokens = output_tokens self.output_dim = output_dim # TODO: find better way to get this information @@ -113,11 +116,10 @@ def __init__( else: self.config = config self.transformer = AutoModel.from_config(config) - if pooler_type is None: # get default arch pooler - self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]() - else: - self.pooler = _POOLERS[pooler_type]() + pooler_type = (arch_dict[self.config.model_type]["pooler"]) + + self.pooler = _POOLERS[pooler_type]() d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"]) if (d_model == output_dim) and (proj is None): # do we always need a proj? @@ -132,12 +134,22 @@ def __init__( nn.Linear(hidden_size, output_dim, bias=False), ) - def forward(self, x: TensorType) -> TensorType: + def forward(self, x: TensorType): attn_mask = (x != self.config.pad_token_id).long() out = self.transformer(input_ids=x, attention_mask=attn_mask) pooled_out = self.pooler(out, attn_mask) - - return self.proj(pooled_out) + projected = self.proj(pooled_out) + + seq_len = out.last_hidden_state.shape[1] + tokens = ( + out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :] + if type(self.pooler) == ClsPooler + else out.last_hidden_state + ) + + if self.output_tokens: + return projected, tokens + return projected def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): if not unlocked_layers: # full freezing diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index de31426df..5a112125d 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -5,6 +5,7 @@ try: import torch.distributed.nn from torch import distributed as dist + has_distributed = True except ImportError: has_distributed = False @@ -85,7 +86,7 @@ def __init__( self.prev_num_logits = 0 self.labels = {} - def forward(self, image_features, text_features, logit_scale): + def forward(self, image_features, text_features, logit_scale, output_dict=False): device = image_features.device if self.world_size > 1: all_image_features, all_text_features = gather_features( @@ -115,7 +116,50 @@ def forward(self, image_features, text_features, logit_scale): labels = self.labels[device] total_loss = ( - F.cross_entropy(logits_per_image, labels) + - F.cross_entropy(logits_per_text, labels) - ) / 2 - return total_loss + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + return {"contrastive_loss": total_loss} if output_dict else total_loss + + +class CoCaLoss(ClipLoss): + def __init__( + self, + caption_loss_weight, + clip_loss_weight, + pad_id=0, # 
pad_token for open_clip custom tokenizer + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + ): + super().__init__( + local_loss=local_loss, + gather_with_grad=gather_with_grad, + cache_labels=cache_labels, + rank=rank, + world_size=world_size, + use_horovod=use_horovod + ) + + self.clip_loss_weight = clip_loss_weight + self.caption_loss_weight = caption_loss_weight + self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) + + def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): + clip_loss = super().forward(image_features, text_features, logit_scale) + clip_loss = self.clip_loss_weight * clip_loss + + caption_loss = self.caption_loss( + logits.permute(0, 2, 1), + labels, + ) + caption_loss = caption_loss * self.caption_loss_weight + + if output_dict: + return {"contrastive_loss": clip_loss, "caption_loss": caption_loss} + + return clip_loss, caption_loss diff --git a/src/open_clip/model.py b/src/open_clip/model.py index a3a504f6e..a0f4b8501 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -31,6 +31,9 @@ class CLIPVisionCfg: ls_init_value: Optional[float] = None # layer scale initial value patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) + attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer + n_queries: int = 256 # n_queries for attentional pooler + attn_pooler_heads: int = 8 # n heads for attentional_pooling timm_model_name: str = None # a valid model name overrides layers, width, patch_size timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') @@ -38,6 +41,7 @@ class CLIPVisionCfg: timm_proj_bias: bool = False # enable bias final projection timm_drop: float = 0. 
# head dropout timm_drop_path: Optional[float] = None # backbone stochastic depth + output_tokens: bool = False @dataclass @@ -53,6 +57,9 @@ class CLIPTextCfg: hf_model_pretrained: bool = True proj: str = 'mlp' pooler_type: str = 'mean_pooler' + embed_cls: bool = False + pad_id: int = 0 + output_tokens: bool = False def get_cast_dtype(precision: str): @@ -88,7 +95,7 @@ def _build_vision_tower( drop=vision_cfg.timm_drop, drop_path=vision_cfg.timm_drop_path, embed_dim=embed_dim, - image_size=vision_cfg.image_size + image_size=vision_cfg.image_size, ) act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models elif isinstance(vision_cfg.layers, (tuple, list)): @@ -98,7 +105,7 @@ def _build_vision_tower( output_dim=embed_dim, heads=vision_heads, image_size=vision_cfg.image_size, - width=vision_cfg.width + width=vision_cfg.width, ) else: vision_heads = vision_cfg.width // vision_cfg.head_width @@ -113,6 +120,10 @@ def _build_vision_tower( ls_init_value=vision_cfg.ls_init_value, patch_dropout=vision_cfg.patch_dropout, global_average_pool=vision_cfg.global_average_pool, + attentional_pool=vision_cfg.attentional_pool, + n_queries=vision_cfg.n_queries, + attn_pooler_heads=vision_cfg.attn_pooler_heads, + output_tokens=vision_cfg.output_tokens, output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, @@ -136,7 +147,8 @@ def _build_text_tower( output_dim=embed_dim, proj=text_cfg.proj, pooler_type=text_cfg.pooler_type, - pretrained=text_cfg.hf_model_pretrained + pretrained=text_cfg.hf_model_pretrained, + output_tokens=text_cfg.output_tokens, ) else: act_layer = QuickGELU if quick_gelu else nn.GELU @@ -150,6 +162,9 @@ def _build_text_tower( layers=text_cfg.layers, ls_init_value=text_cfg.ls_init_value, output_dim=embed_dim, + embed_cls=text_cfg.embed_cls, + output_tokens=text_cfg.output_tokens, + pad_id=text_cfg.pad_id, act_layer=act_layer, norm_layer=norm_layer, ) @@ -157,6 +172,8 @@ def _build_text_tower( class CLIP(nn.Module): + output_dict: torch.jit.Final[bool] + def __init__( self, embed_dim: int, @@ -164,8 +181,10 @@ def __init__( text_cfg: CLIPTextCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, ): super().__init__() + self.output_dict = output_dict self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) @@ -209,10 +228,18 @@ def encode_text(self, text, normalize: bool = False): def forward(self, image, text): image_features = self.encode_image(image, normalize=True) text_features = self.encode_text(text, normalize=True) + if self.output_dict: + return { + "image_features":image_features, + "text_features":text_features, + "logit_scale":self.logit_scale.exp() + } return image_features, text_features, self.logit_scale.exp() class CustomTextCLIP(nn.Module): + output_dict: torch.jit.Final[bool] + def __init__( self, embed_dim: int, @@ -220,8 +247,10 @@ def __init__( text_cfg: CLIPTextCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, ): super().__init__() + self.output_dict = output_dict self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) @@ -249,6 +278,12 @@ def encode_text(self, text, normalize: bool = False): def forward(self, image, text): image_features = self.encode_image(image, normalize=True) 
text_features = self.encode_text(text, normalize=True) + if self.output_dict: + return { + "image_features":image_features, + "text_features":text_features, + "logit_scale":self.logit_scale.exp() + } return image_features, text_features, self.logit_scale.exp() @@ -340,7 +375,7 @@ def build_model_from_openai_state_dict( vocab_size=vocab_size, width=transformer_width, heads=transformer_heads, - layers=transformer_layers + layers=transformer_layers, ) model = CLIP( embed_dim, diff --git a/src/open_clip/model_configs/coca_ViT-B-32.json b/src/open_clip/model_configs/coca_ViT-B-32.json new file mode 100644 index 000000000..7e7eb520a --- /dev/null +++ b/src/open_clip/model_configs/coca_ViT-B-32.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "attn_pooler_heads": 8 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/coca_ViT-L-14.json b/src/open_clip/model_configs/coca_ViT-L-14.json new file mode 100644 index 000000000..3d5ca4ca2 --- /dev/null +++ b/src/open_clip/model_configs/coca_ViT-L-14.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "attn_pooler_heads": 12 + }, + "custom_text": true +} diff --git a/src/open_clip/model_configs/coca_base.json b/src/open_clip/model_configs/coca_base.json new file mode 100644 index 000000000..cf8c6cecb --- /dev/null +++ b/src/open_clip/model_configs/coca_base.json @@ -0,0 +1,31 @@ +{ + "embed_dim": 512, + "multimodal_cfg": { + "width": 768, + "context_length": 76, + "vocab_size": 64000, + "mlp_ratio": 4, + "layers": 12, + "dim_head": 64, + "heads": 12, + "n_queries": 256, + "attn_pooler_heads": 8 + }, + "vision_cfg": { + "image_size": 288, + "layers": 12, + "width": 768, + "patch_size": 18, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 64000, + "layers": 12, + "heads": 12, + "width": 768, + "embed_cls": true, + "output_tokens": true + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/coca_roberta-ViT-B-32.json b/src/open_clip/model_configs/coca_roberta-ViT-B-32.json new file mode 100644 index 000000000..fb46354b9 --- /dev/null +++ b/src/open_clip/model_configs/coca_roberta-ViT-B-32.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "output_tokens": true + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "proj": "linear", + "width": 768, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "width": 768, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} diff --git 
a/src/open_clip/pretrained.py b/src/open_clip/pretrained.py index 73643f95d..7b2359e56 100644 --- a/src/open_clip/pretrained.py +++ b/src/open_clip/pretrained.py @@ -174,6 +174,10 @@ def _pcfg(url='', hf_hub='', mean=None, std=None): laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), ) +_coca_VITB32 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/') +) + _PRETRAINED = { "RN50": _RN50, @@ -198,6 +202,7 @@ def _pcfg(url='', hf_hub='', mean=None, std=None): "convnext_base": _convnext_base, "convnext_base_w": _convnext_base_w, "convnext_base_w_320": _convnext_base_w_320, + "coca_ViT-B-32": _coca_VITB32, } diff --git a/src/open_clip/tokenizer.py b/src/open_clip/tokenizer.py index 01e9f9d25..109d2aef7 100644 --- a/src/open_clip/tokenizer.py +++ b/src/open_clip/tokenizer.py @@ -186,16 +186,23 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.Lo class HFTokenizer: - "HuggingFace tokenizer wrapper" - def __init__(self, tokenizer_name:str): + """HuggingFace tokenizer wrapper""" + + def __init__(self, tokenizer_name: str): from transformers import AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor: + def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: # same cleaning as for default tokenizer, except lowercasing # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance if isinstance(texts, str): texts = [texts] texts = [whitespace_clean(basic_clean(text)) for text in texts] - input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids + input_ids = self.tokenizer( + texts, + return_tensors='pt', + max_length=context_length, + padding='max_length', + truncation=True, + ).input_ids return input_ids diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index d73d7050c..65085642a 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -1,6 +1,6 @@ from collections import OrderedDict import math -from typing import Callable, Optional, Sequence +from typing import Callable, Optional, Sequence, Tuple import torch from torch import nn @@ -160,6 +160,32 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None): return x +class AttentionalPooler(nn.Module): + def __init__( + self, + d_model: int, + context_dim: int, + n_head: int = 8, + n_queries: int = 256, + norm_layer: Callable = LayerNorm + ): + super().__init__() + self.query = nn.Parameter(torch.randn(n_queries, d_model)) + self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim) + self.ln_q = norm_layer(d_model) + self.ln_k = norm_layer(context_dim) + + def forward(self, x: torch.Tensor): + x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0] + return out.permute(1, 0, 2) # LND -> NLD + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + class ResidualAttentionBlock(nn.Module): def __init__( self, @@ -169,12 +195,15 @@ def __init__( ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, + is_cross_attention: bool = False, ): super().__init__() self.ln_1 = norm_layer(d_model) self.attn = 
nn.MultiheadAttention(d_model, n_head) self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) @@ -185,12 +214,32 @@ def __init__( ])) self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() - def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None - return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask)) + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + return self.attn( + q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask + )[0] + + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None + v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None + + x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) x = x + self.ls_2(self.mlp(self.ln_2(x))) return x @@ -264,13 +313,16 @@ def get_cast_dtype(self) -> torch.dtype: def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): for r in self.resblocks: if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(r, x, attn_mask) + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, None, None, attn_mask) else: x = r(x, attn_mask=attn_mask) return x class VisionTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + def __init__( self, image_size: int, @@ -281,12 +333,17 @@ def __init__( mlp_ratio: float, ls_init_value: float = None, global_average_pool: bool = False, + attentional_pool: bool = False, + n_queries: int = 256, + attn_pooler_heads: int = 8, output_dim: int = 512, patch_dropout: float = 0., act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, + output_tokens: bool = False ): super().__init__() + self.output_tokens = output_tokens self.image_size = to_2tuple(image_size) self.patch_size = to_2tuple(patch_size) self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1]) @@ -312,8 +369,14 @@ def __init__( ) self.global_average_pool = global_average_pool - self.ln_post = norm_layer(width) - self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + if attentional_pool: + self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries) + self.ln_post = norm_layer(output_dim) + self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim)) + else: + self.attn_pool = None + self.ln_post = norm_layer(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) self.init_parameters() @@ -374,6 +437,12 @@ def init_parameters(self): def set_grad_checkpointing(self, enable=True): self.transformer.grad_checkpointing = enable + def 
_global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.global_average_pool: + return x.mean(dim=1), x + else: + return x[:, 0], x[:, 1:] + def forward(self, x: torch.Tensor): x = self.conv1(x) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] @@ -391,20 +460,25 @@ def forward(self, x: torch.Tensor): x = self.transformer(x) x = x.permute(1, 0, 2) # LND -> NLD - if self.global_average_pool: - x = x.mean(dim=1) + if self.attn_pool is not None: + x = self.attn_pool(x) + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) else: - x = x[:, 0] - - x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + pooled = self.ln_post(pooled) if self.proj is not None: - x = x @ self.proj + pooled = pooled @ self.proj - return x + if self.output_tokens: + return pooled, tokens + + return pooled class TextTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] def __init__( self, @@ -417,15 +491,30 @@ def __init__( output_dim: int = 512, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, + embed_cls: bool = False, + pad_id: int = 0, + output_tokens: bool = False, ): super().__init__() - self.context_length = context_length + self.output_tokens = output_tokens + self.num_pos = self.context_length = context_length self.vocab_size = vocab_size self.width = width self.output_dim = output_dim + self.embed_cls = embed_cls + self.heads = heads + self.pad_id = pad_id + + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + if self.embed_cls: + self.cls_emb = nn.Parameter(torch.empty(width)) + self.num_pos += 1 + else: + self.cls_emb = None self.token_embedding = nn.Embedding(vocab_size, width) - self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width)) + self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) self.transformer = Transformer( width=width, layers=layers, @@ -435,7 +524,6 @@ def __init__( norm_layer=norm_layer, ) self.ln_final = norm_layer(width) - self.text_projection = nn.Parameter(torch.empty(width, output_dim)) self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) @@ -444,6 +532,8 @@ def __init__( def init_parameters(self): nn.init.normal_(self.token_embedding.weight, std=0.02) nn.init.normal_(self.positional_embedding, std=0.01) + if hasattr(self, "embed_cls") and self.embed_cls: + nn.init.normal_(self.cls_emb, std=0.01) proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) attn_std = self.transformer.width ** -0.5 @@ -462,26 +552,150 @@ def set_grad_checkpointing(self, enable=True): self.transformer.grad_checkpointing = enable def build_attention_mask(self): - # lazily create causal attention mask, with full attention between the vision tokens + # lazily create causal attention mask, with full attention between the tokens # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(self.context_length, self.context_length) + mask = torch.empty(self.num_pos, self.num_pos) mask.fill_(float("-inf")) mask.triu_(1) # zero out the lower diagonal return mask + def build_cls_mask(self, text, cast_dtype: torch.dtype): + cls_mask = (text != self.pad_id).unsqueeze(1) + cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0) + additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) + additive_mask.fill_(0) + additive_mask.masked_fill_(~cls_mask, float("-inf")) + additive_mask = torch.repeat_interleave(additive_mask, 
self.heads, 0) + return additive_mask + + def _repeat(self, t, N: int): + return t.reshape(1, 1, -1).repeat(N, 1, 1) + def forward(self, text): cast_dtype = self.transformer.get_cast_dtype() + seq_len = text.shape[1] x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] - - x = x + self.positional_embedding.to(cast_dtype) + attn_mask = self.attn_mask + if self.embed_cls is not None: + seq_len += 1 + x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1) + cls_mask = self.build_cls_mask(text, cast_dtype) + attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] + + x = x + self.positional_embedding[:seq_len].to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x, attn_mask=self.attn_mask) + x = self.transformer(x, attn_mask=attn_mask) x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x) # x.shape = [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + if self.embed_cls is not None: + pooled, tokens = x[:, -1], x[:, :-1] + pooled = self.ln_final(pooled) + else: + x = self.ln_final(x) + pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x + + if self.text_projection is not None: + pooled = pooled @ self.text_projection + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class MultimodalTransformer(Transformer): + def __init__( + self, + width: int, + layers: int, + heads: int, + context_length: int = 77, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_dim: int = 512, + ): + + super().__init__( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.context_length = context_length + self.cross_attn = nn.ModuleList([ + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + is_cross_attention=True, + ) + for _ in range(layers) + ]) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.ln_final = norm_layer(width) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + def init_parameters(self): + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + for block in self.transformer.cross_attn: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # 
zero out the lower diagonal + return mask + + def forward(self, image_embs, text_embs): + text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq + image_embs = image_embs.permute(1, 0, 2) # NLD -> LND + seq_len = text_embs.shape[0] + + for resblock, cross_attn in zip(self.resblocks, self.cross_attn): + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len]) + text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None) + else: + text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len]) + text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) + + x = text_embs.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + if self.text_projection is not None: + x = x @ self.text_projection return x + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable diff --git a/src/training/main.py b/src/training/main.py index e648099c0..aad36de37 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -27,7 +27,7 @@ except ImportError: hvd = None -from open_clip import create_model_and_transforms, trace_model, get_tokenizer +from open_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss from training.data import get_data from training.distributed import is_master, init_distributed_device, broadcast_object from training.logger import setup_logging @@ -367,11 +367,13 @@ def main(args): evaluate(model, data, start_epoch, args, writer) return + loss = create_loss(args) + for epoch in range(start_epoch, args.epochs): if is_master(args): logging.info(f'Start epoch {epoch}') - train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer) + train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args, tb_writer=writer) completed_epoch = epoch + 1 if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')): diff --git a/src/training/params.py b/src/training/params.py index 44db413a5..3d2352eb0 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -368,6 +368,18 @@ def parse_args(args): default=100, help="Log every n steps to tensorboard/console/wandb.", ) + parser.add_argument( + "--coca-caption-loss-weight", + type=float, + default=2.0, + help="Weight assigned to caption loss in CoCa." + ) + parser.add_argument( + "--coca-contrastive-loss-weight", + type=float, + default=1.0, + help="Weight assigned to contrastive loss when training CoCa." 
+ ) parser.add_argument( "--remote-sync", type=str, diff --git a/src/training/train.py b/src/training/train.py index bf42f1475..88e4c5a34 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -7,13 +7,14 @@ import numpy as np import torch import torch.nn.functional as F +from torch.nn.parallel.distributed import DistributedDataParallel try: import wandb except ImportError: wandb = None -from open_clip import ClipLoss, get_cast_dtype +from open_clip import get_cast_dtype, CLIP, CustomTextCLIP from .distributed import is_master from .zero_shot import zero_shot_eval from .precision import get_autocast @@ -37,6 +38,15 @@ def update(self, val, n=1): self.count += n self.avg = self.sum / self.count +def is_clip(model): + return type(model) in [CLIP, CustomTextCLIP] + +def postprocess_clip_output(model_out): + return { + "image_features": model_out[0], + "text_features": model_out[1], + "logit_scale": model_out[2] + } def unwrap_model(model): if hasattr(model, 'module'): @@ -52,19 +62,12 @@ def backward(total_loss, scaler): total_loss.backward() -def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None): +def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args, tb_writer=None): device = torch.device(args.device) autocast = get_autocast(args.precision) cast_dtype = get_cast_dtype(args.precision) model.train() - loss = ClipLoss( - local_loss=args.local_loss, - gather_with_grad=args.gather_with_grad, - cache_labels=True, - rank=args.rank, - world_size=args.world_size, - use_horovod=args.horovod) data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch dataloader = data['train'].dataloader @@ -72,9 +75,9 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) if args.accum_freq > 1: - accum_images, accum_texts, accum_image_features, accum_text_features = [], [], [], [] + accum_images, accum_texts, accum_features = [], [], {} - loss_m = AverageMeter() + losses_m = {} batch_time_m = AverageMeter() data_time_m = AverageMeter() end = time.time() @@ -94,17 +97,33 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w if args.accum_freq == 1: with autocast(): - image_features, text_features, logit_scale = model(images, texts) - total_loss = loss(image_features, text_features, logit_scale) + model_out = model(images, texts) + # for clip if it does not output_dict + module = model.module if type(model) == DistributedDataParallel else model + if is_clip(module) and not module.output_dict: + model_out = postprocess_clip_output(model_out) + logit_scale = model_out["logit_scale"] + losses = loss(**model_out, output_dict=True) + + total_loss = sum(losses.values()) + losses["loss"] = total_loss backward(total_loss, scaler) else: # First, cache the features without any gradient tracking. 
with torch.no_grad(): with autocast(): - chunk_image_features, chunk_text_features, _ = model(images, texts) - accum_image_features.append(chunk_image_features) - accum_text_features.append(chunk_text_features) + model_out = model(images, texts) + # for clip if it does not output_dict + module = model.module if type(model) == DistributedDataParallel else model + if is_clip(module) and not module.output_dict: + model_out = postprocess_clip_output(model_out) + model_out.pop("logit_scale") + for key, val in model_out.items(): + if key in accum_features: + accum_features[key].append(val) + else: + accum_features[key] = [val] accum_images.append(images) accum_texts.append(texts) @@ -122,12 +141,18 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w images = accum_images[j] texts = accum_texts[j] with autocast(): - chunk_image_features, chunk_text_features, logit_scale = model(images, texts) - image_features = torch.cat( - accum_image_features[:j] + [chunk_image_features] + accum_image_features[j + 1:]) - text_features = torch.cat( - accum_text_features[:j] + [chunk_text_features] + accum_text_features[j + 1:]) - total_loss = loss(image_features, text_features, logit_scale) + model_out = model(images, texts, output_dict=True) + # for clip if it does not output_dict + module = model.module if type(model) == DistributedDataParallel else model + if is_clip(module) and not model.output_dict: + model_out = postprocess_clip_output(model_out) + logit_scale = model_out.pop("logit_scale") + for key, val in accum_features: + accumulated = accum_features[key] + accumulated = accumulated[:j] + [model_out[key]] + accumulated[j + 1:] + losses = loss(**accumulated, logit_scale=logit_scale, output_dict=True) + total_loss = sum(losses.values()) + losses["loss"] = total_loss backward(total_loss, scaler) if scaler is not None: @@ -151,7 +176,7 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w # reset gradient accum, if enabled if args.accum_freq > 1: - accum_images, accum_texts, accum_image_features, accum_text_features = [], [], [], [] + accum_images, accum_texts, accum_features = [], [], {} # Note: we clamp to 4.6052 = ln(100), as in the original paper. with torch.no_grad(): @@ -167,26 +192,36 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w percent_complete = 100.0 * batch_count / num_batches_per_epoch # NOTE loss is coarsely sampled, just master node and per log update - loss_m.update(total_loss.item(), batch_size) + for key, val in losses.items(): + if key not in losses_m: + losses_m[key] = AverageMeter() + losses_m[key].update(val.item(), batch_size) + logit_scale_scalar = logit_scale.item() + loss_log = " ".join( + [ + f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})" + for loss_name, loss_m in losses_m.items() + ] + ) logging.info( f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " - f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " f"Data (t): {data_time_m.avg:.3f} " f"Batch (t): {batch_time_m.avg:.3f}, {args.accum_freq * args.batch_size * args.world_size / batch_time_m.val:#g}/s " f"LR: {optimizer.param_groups[0]['lr']:5f} " - f"Logit Scale: {logit_scale_scalar:.3f}" + f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log ) # Save train loss / etc. 
Using non avg meter values as loggers have their own smoothing log_data = { - "loss": loss_m.val, "data_time": data_time_m.val, "batch_time": batch_time_m.val, "samples_per_second": args.accum_freq * args.batch_size * args.world_size / batch_time_m.val, "scale": logit_scale_scalar, "lr": optimizer.param_groups[0]["lr"] - } + } + log_data.update({name:val.val for name,val in losses_m.items()}) + for name, val in log_data.items(): name = "train/" + name if tb_writer is not None: @@ -222,6 +257,7 @@ def evaluate(model, data, epoch, args, tb_writer=None): # FIXME this does not scale past small eval datasets # all_image_features @ all_text_features will blow up memory and compute very quickly cumulative_loss = 0.0 + cumulative_gen_loss = 0.0 all_image_features, all_text_features = [], [] with torch.no_grad(): for i, batch in enumerate(dataloader): @@ -230,7 +266,14 @@ def evaluate(model, data, epoch, args, tb_writer=None): texts = texts.to(device=device, non_blocking=True) with autocast(): - image_features, text_features, logit_scale = model(images, texts) + model_out = model(images, texts, output_dict=True) + # for clip if it does not output_dict + module = model.module if type(model) == DistributedDataParallel else model + if is_clip(module) and not module.output_dict: + model_out = postprocess_clip_output(model_out) + image_features = model_out["image_features"] + text_features = model_out["text_features"] + logit_scale = model_out["logit_scale"] # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly # however, system RAM is easily exceeded and compute time becomes problematic all_image_features.append(image_features.cpu()) @@ -246,22 +289,32 @@ def evaluate(model, data, epoch, args, tb_writer=None): F.cross_entropy(logits_per_text, labels) ) / 2 + gen_loss = maybe_compute_generative_loss(model_out) + cumulative_loss += total_loss * batch_size num_samples += batch_size if is_master(args) and (i % 100) == 0: logging.info( f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t" - f"Loss: {cumulative_loss / num_samples:.6f}\t") + f"Clip Loss: {cumulative_loss / num_samples:.6f}\t") + + if gen_loss is not None: + cumulative_gen_loss += gen_loss * batch_size + logging.info( + f"Generative Loss: {cumulative_gen_loss / num_samples:.6f}\t") - val_metrics = get_metrics( + val_metrics = get_clip_metrics( image_features=torch.cat(all_image_features), text_features=torch.cat(all_text_features), logit_scale=logit_scale.cpu(), ) loss = cumulative_loss / num_samples metrics.update( - {**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples} + {**val_metrics, "clip_val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples} ) + if gen_loss is not None: + gen_loss = cumulative_gen_loss / num_samples + metrics.update({"val_generative_loss": gen_loss.item()}) if not metrics: return metrics @@ -288,7 +341,7 @@ def evaluate(model, data, epoch, args, tb_writer=None): return metrics -def get_metrics(image_features, text_features, logit_scale): +def get_clip_metrics(image_features, text_features, logit_scale): metrics = {} logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu() logits_per_text = logits_per_image.t().detach().cpu() @@ -306,3 +359,10 @@ def get_metrics(image_features, text_features, logit_scale): metrics[f"{name}_R@{k}"] = np.mean(preds < k) return metrics + + +def maybe_compute_generative_loss(model_out): + if "logits" in model_out and "labels" in model_out: + token_logits = model_out["logits"] + 
token_labels = model_out["labels"] + return F.cross_entropy(token_logits.permute(0, 2, 1), token_labels) diff --git a/tests/test_inference.py b/tests/test_inference.py index ecd46d072..dca8dc44c 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -7,6 +7,12 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '' +if hasattr(torch._C, '_jit_set_profiling_executor'): + # legacy executor is too slow to compile large models for unit tests + # no need for the fusion performance here + torch._C._jit_set_profiling_executor(True) + torch._C._jit_set_profiling_mode(False) + models_to_test = set(open_clip.list_models()) # testing excemptions @@ -21,6 +27,9 @@ 'ViT-bigG-14', 'ViT-e-14', 'mt5-xl-ViT-H-14', + 'coca_base', + 'coca_ViT-B-32', + 'coca_roberta-ViT-B-32' }) if 'OPEN_CLIP_TEST_REG_MODELS' in os.environ: @@ -29,17 +38,25 @@ models_to_test = set(f.read().splitlines()).intersection(models_to_test) print(f"Selected models from {external_model_list}: {models_to_test}") +# TODO: add "coca_ViT-B-32" onece https://github.com/pytorch/pytorch/issues/92073 gets fixed models_to_test = list(models_to_test) models_to_test.sort() +models_to_test = [(model_name, False) for model_name in models_to_test] + +models_to_jit_test = {"ViT-B-32"} +models_to_jit_test = list(models_to_jit_test) +models_to_jit_test = [(model_name, True) for model_name in models_to_jit_test] +models_to_test_fully = models_to_test + models_to_jit_test + @pytest.mark.regression_test -@pytest.mark.parametrize('model_name', models_to_test) +@pytest.mark.parametrize("model_name,jit", models_to_test_fully) def test_inference_with_data( model_name, + jit, pretrained = None, pretrained_hf = False, precision = 'fp32', - jit = False, force_quick_gelu = False, ): util_test.seed_all() @@ -78,5 +95,39 @@ def test_inference_with_data( gt_image = torch.load(gt_image_path) y_image = util_test.inference_image(model, preprocess_val, input_image) assert (y_image == gt_image).all(), f"image output differs @ {input_image_path}" + + if not jit: + model.eval() + model_out = util_test.forward_model(model, model_name, preprocess_val, input_image, input_text) + if type(model) not in [open_clip.CLIP, open_clip.CustomTextCLIP]: + assert type(model_out) == dict + else: + model.output_dict = True + model_out_dict = util_test.forward_model(model, model_name, preprocess_val, input_image, input_text) + assert (model_out_dict["image_features"] == model_out[0]).all() + assert (model_out_dict["text_features"] == model_out[1]).all() + assert (model_out_dict["logit_scale"] == model_out[2]).all() + model.output_dict = None + else: + model, _, preprocess_val = open_clip.create_model_and_transforms( + model_name, + pretrained = pretrained, + precision = precision, + jit = False, + force_quick_gelu = force_quick_gelu, + pretrained_hf = pretrained_hf + ) + + test_model = util_test.TestWrapper(model, model_name, output_dict=False) + test_model = torch.jit.script(test_model) + model_out = util_test.forward_model(test_model, model_name, preprocess_val, input_image, input_text) + assert model_out["test_output"].shape[-1] == 2 + + test_model = util_test.TestWrapper(model, model_name, output_dict=True) + test_model = torch.jit.script(test_model) + model_out = util_test.forward_model(test_model, model_name, preprocess_val, input_image, input_text) + assert model_out["test_output"].shape[-1] == 2 + + diff --git a/tests/test_training_simple.py b/tests/test_training_simple.py index fe55b3328..5e1f649e7 100644 --- a/tests/test_training_simple.py +++ b/tests/test_training_simple.py @@ 
-24,6 +24,22 @@ def test_training(): '--model', 'RN50' ]) +@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") +def test_training_coca(): + main([ + '--save-frequency', '1', + '--zeroshot-frequency', '1', + '--dataset-type', "synthetic", + '--train-num-samples', '16', + '--warmup', '1', + '--batch-size', '4', + '--lr', '1e-3', + '--wd', '0.1', + '--epochs', '1', + '--workers', '2', + '--model', 'coca_ViT-B-32' + ]) + @pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") def test_training_mt5(): main([ @@ -60,4 +76,4 @@ def test_training_unfreezing_vit(): '--model', 'ViT-B-32', '--lock-image', '--lock-image-unlocked-groups', '5' - ]) \ No newline at end of file + ]) diff --git a/tests/util_test.py b/tests/util_test.py index b2a2c9c3d..d09b09a2d 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -33,6 +33,24 @@ def inference_image(model, preprocess_val, batches): x = torch.stack([preprocess_val(img) for img in x]) y.append(model.encode_image(x)) return torch.stack(y) + +def forward_model(model, model_name, preprocess_val, image_batch, text_batch): + y = [] + tokenizer = open_clip.get_tokenizer(model_name) + with torch.no_grad(): + for x_im, x_txt in zip(image_batch, text_batch): + x_im = torch.stack([preprocess_val(im) for im in x_im]) + x_txt = tokenizer(x_txt) + y.append(model(x_im, x_txt)) + if type(y[0]) == dict: + out = {} + for key in y[0].keys(): + out[key] = torch.stack([batch_out[key] for batch_out in y]) + else: + out = [] + for i in range(len(y[0])): + out.append(torch.stack([batch_out[i] for batch_out in y])) + return out def random_image_batch(batch_size, size): h, w = size @@ -178,6 +196,25 @@ def create_test_data( def _sytem_assert(string): assert os.system(string) == 0 +class TestWrapper(torch.nn.Module): + output_dict: torch.jit.Final[bool] + def __init__(self, model, model_name, output_dict=True) -> None: + super().__init__() + self.model = model + self.output_dict = output_dict + if type(model) in [open_clip.CLIP, open_clip.CustomTextCLIP]: + self.model.output_dict = self.output_dict + config = open_clip.get_model_config(model_name) + self.head = torch.nn.Linear(config["embed_dim"], 2) + + def forward(self, image, text): + x = self.model(image, text) + if self.output_dict: + out = self.head(x["image_features"]) + else: + out = self.head(x[0]) + return {"test_output": out} + def main(args): global open_clip import importlib
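A minimal sketch of loading the new CoCa config through the existing factory API and using it like any other CLIP model. The `coca_ViT-B-32` name comes from the added model config, and `laion2b_s13b_b90k` is the tag registered in `pretrained.py` above (listed in the README as `laion2B-s13B-b90k`); availability of those weights is assumed, and `example.jpg` is a placeholder path.

```python
# Sketch only: the CoCa tower reuses the CLIP-style encode_image / encode_text API.
import torch
from PIL import Image
import open_clip

model, _, preprocess = open_clip.create_model_and_transforms(
    'coca_ViT-B-32', pretrained='laion2b_s13b_b90k')  # tag from pretrained.py; weights assumed published
tokenizer = open_clip.get_tokenizer('coca_ViT-B-32')

image = preprocess(Image.open('example.jpg')).unsqueeze(0)  # placeholder image file
text = tokenizer(['a diagram', 'a dog', 'a cat'])

with torch.no_grad():
    image_features = model.encode_image(image)  # already L2-normalized (normalize=True default in CoCa)
    text_features = model.encode_text(text)
    similarity = (model.logit_scale.exp() * image_features @ text_features.T).softmax(dim=-1)
```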
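The training changes above route everything through dictionaries: `CoCa.forward` returns the contrastive features plus decoder `logits`/`labels`, `create_loss` selects `CoCaLoss`, and `train_one_epoch` just sums whatever the loss returns. A sketch of that contract on random data, using the default weights from the new `--coca-caption-loss-weight` / `--coca-contrastive-loss-weight` flags:

```python
# Sketch of the dict-based loss contract: CoCa.forward() produces exactly the
# keyword arguments CoCaLoss.forward() expects, and the trainer sums the dict.
import torch
import open_clip
from open_clip import CoCaLoss

model, _, _ = open_clip.create_model_and_transforms('coca_ViT-B-32')
tokenizer = open_clip.get_tokenizer('coca_ViT-B-32')

images = torch.randn(2, 3, 224, 224)  # random stand-in batch
texts = tokenizer(['a photo of a cat', 'a photo of a dog'])

# defaults mirror params.py: caption loss weighted 2.0, contrastive loss 1.0
loss_fn = CoCaLoss(caption_loss_weight=2.0, clip_loss_weight=1.0)

model_out = model(images, texts)  # dict: image_features, text_features, logits, labels, logit_scale
losses = loss_fn(**model_out, output_dict=True)  # {"contrastive_loss": ..., "caption_loss": ...}
total_loss = sum(losses.values())
total_loss.backward()
```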
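`CoCa.generate` in `coca_model.py` does plain autoregressive sampling with the filters from `generation_utils` (top-k by default). A rough, untested sketch; seeding generation with a few un-padded prompt tokens is an assumption about how callers are expected to use it:

```python
# Rough sketch of CoCa.generate(); each step feeds the image plus the tokens so far
# through the model and samples from the filtered next-token distribution.
import torch
import open_clip

model, _, preprocess = open_clip.create_model_and_transforms('coca_ViT-B-32')
tokenizer = open_clip.get_tokenizer('coca_ViT-B-32')
model.eval()

image = torch.randn(1, 3, 224, 224)      # stand-in for preprocess(Image.open(...))
prompt = tokenizer(['a photo of'])[:, :4]  # assumption: keep only the leading prompt tokens, no padding

with torch.no_grad():
    generated = model.generate(image, prompt, seq_len=15)  # returns only the newly sampled tokens
```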
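The `attentional_pool` option added to `VisionTransformer` swaps CLS/mean pooling for cross-attention against a fixed set of learned queries. A quick shape check of the new `AttentionalPooler`, with dimensions matching the `coca_ViT-B-32` vision tower (width 768, embed_dim 512, 256 queries):

```python
# Shape sketch for AttentionalPooler: learned queries cross-attend over the ViT
# token sequence, so the output length is n_queries regardless of input length.
import torch
from open_clip.transformer import AttentionalPooler

pooler = AttentionalPooler(d_model=512, context_dim=768, n_head=8, n_queries=256)
vit_tokens = torch.randn(2, 50, 768)  # [batch, patches(+cls), width] from the vision tower
pooled = pooler(vit_tokens)           # -> [2, 256, 512]
print(pooled.shape)
```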
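For plain CLIP models the tuple interface is unchanged; the new `output_dict` attribute (together with `postprocess_clip_output` in `train.py`) only changes what `forward` returns when explicitly enabled, which is how the updated tests exercise both paths. A small sketch:

```python
# output_dict defaults to False, so existing CLIP callers keep getting a tuple;
# flipping the attribute (as the new tests do) yields the dict the loss code expects.
import torch
import open_clip

model, _, _ = open_clip.create_model_and_transforms('ViT-B-32')
tokenizer = open_clip.get_tokenizer('ViT-B-32')

images = torch.randn(2, 3, 224, 224)
texts = tokenizer(['a cat', 'a dog'])

image_features, text_features, logit_scale = model(images, texts)  # default tuple output

model.output_dict = True
out = model(images, texts)
assert set(out) == {"image_features", "text_features", "logit_scale"}
```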