diff --git a/src/open_clip/__init__.py b/src/open_clip/__init__.py index b76dd51b9..a4e11b4f7 100644 --- a/src/open_clip/__init__.py +++ b/src/open_clip/__init__.py @@ -1,9 +1,10 @@ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD -from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer +from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss from .factory import list_models, add_model_config, get_model_config, load_checkpoint -from .loss import ClipLoss +from .loss import ClipLoss, CoCaLoss from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype +from .coca_model import CoCa from .openai import load_openai_model, list_openai_models from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py new file mode 100644 index 000000000..3b13cc782 --- /dev/null +++ b/src/open_clip/coca_model.py @@ -0,0 +1,192 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F +import numpy as np +from dataclasses import dataclass + +from .transformer import ( + LayerNormFp32, + LayerNorm, + QuickGELU, + MultimodalTransformer, + AttentionalPooler, +) +from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower + + +@dataclass +class MultimodalCfg(CLIPTextCfg): + mlp_ratio: int = 4 + dim_head: int = 64 + heads: int = 8 + n_queries: int = 256 + dim_latents: int = None + + +def _build_input_dependent_text_tower( + embed_dim: int, + multimodal_cfg: MultimodalCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + multimodal: bool = True, +): + + if not multimodal: + return _build_text_tower( + embed_dim=embed_dim, + text_cfg=multimodal_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + if isinstance(multimodal_cfg, dict): + multimodal_cfg = MultimodalCfg(**multimodal_cfg) + + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + + text = MultimodalTransformer( + context_length=multimodal_cfg.context_length, + width=multimodal_cfg.width, + heads=multimodal_cfg.heads, + layers=multimodal_cfg.layers, + ls_init_value=multimodal_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return text, multimodal_cfg + + +class CoCa(nn.Module): + def __init__( + self, + embed_dim, + multimodal_cfg: MultimodalCfg, + text_cfg: CLIPTextCfg, + vision_cfg: CLIPVisionCfg, + n_queries: int = 256, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + + text = _build_input_dependent_text_tower( + embed_dim, text_cfg, quick_gelu, cast_dtype, multimodal=False + ) + self.transformer = text.transformer + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.register_buffer("attn_mask", text.attn_mask, persistent=False) + + self.heads = text_cfg["heads"] + 
self.cls_token = nn.Parameter(torch.randn(embed_dim)) + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + + self.multimodal_decoder, multimodal_cfg = _build_input_dependent_text_tower( + embed_dim, multimodal_cfg, quick_gelu, cast_dtype + ) + + self.img_attn_pool = AttentionalPooler( + multimodal_cfg.width, multimodal_cfg.heads, n_queries=n_queries + 1 + ) + + self.img_attn_pool_norm = norm_layer(embed_dim) + + self.dim_latents = ( + multimodal_cfg.dim_latents if multimodal_cfg.dim_latents else multimodal_cfg.width + ) + self.to_text_latent = nn.Linear(embed_dim, self.dim_latents, bias=False) + + self.to_logits = nn.Sequential( + norm_layer(embed_dim), nn.Linear(embed_dim, self.vocab_size, bias=False) + ) + + # tie embedding weights and projection + self.to_logits[-1].weight = self.token_embedding.weight + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.pad_id = 0 + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.transformer.grad_checkpointing = enable + self.multimodal_decoder.grad_checkpointing = enable + + def encode_image(self, images, normalize=True, return_tokens=False): + x = self.visual(images, output_tokens=True) + x = self.visual.ln_post(x) + + if self.visual.proj is not None: + x = x @ self.visual.proj + + x = self.img_attn_pool(x, x) + x = self.img_attn_pool_norm(x) + + image_latent = x[:, 0] + image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent + + return (image_latent, x[:, 1:]) if return_tokens else image_latent + + def _repeat(self, t, N): + return t.reshape(1, 1, -1).repeat(N, 1, 1) + + def _build_cls_mask(self, text, cast_dtype): + cls_mask = (text != self.pad_id).unsqueeze(1) + cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True) + additive_mask = torch.empty(*cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) + additive_mask.fill_(0) + additive_mask.masked_fill_(~cls_mask, float("-inf")) + additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) + return additive_mask + + def encode_text(self, text, normalize=True, return_tokens=False): + text = text[:, :-1] # make space for CLS token + cast_dtype = self.transformer.get_cast_dtype() + + # cls_mask = (text!=self.pad_id).unsqueeze(1) + attn_mask = self.attn_mask[None, :].expand( + text.shape[0] * self.heads, *self.attn_mask.shape + ) + cls_mask = self._build_cls_mask(text, cast_dtype) + # attn_mask = F.pad(self.attn_mask, (0, 1, 0, 1), value=0.0) + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + x = torch.cat([x, self._repeat(self.cls_token, x.shape[0])], dim=1) + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=attn_mask + cls_mask) + x = x.permute(1, 0, 2) # LND -> NLD + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), :] @ self.text_projection + + cls_emb = x[torch.arange(x.shape[0]), -1] + token_emb = x[torch.arange(x.shape[0]), :-1] + + cls_emb = self.ln_final(cls_emb) + text_latent = self.to_text_latent(cls_emb) + text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent + + return (text_latent, token_emb) if return_tokens else text_latent + + def forward(self, image, text): + labels = text[:, 1:] + + text_latents, text_tokens = self.encode_text(text, return_tokens=True) 
+ image_latents, image_tokens = self.encode_image(image, return_tokens=True) + + text_tokens = self.multimodal_decoder(image_tokens, text_tokens) + logits = self.to_logits(text_tokens) + + return image_latents, text_latents, logits, labels, self.logit_scale.exp() diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index bf07009bc..c2297cf50 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -12,6 +12,8 @@ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ resize_pos_embed, get_cast_dtype +from .coca_model import CoCa +from .loss import ClipLoss, CoCaLoss from .openai import load_openai_model from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model from .transform import image_transform @@ -72,8 +74,7 @@ def get_model_config(model_name): def get_tokenizer(model_name): config = get_model_config(model_name) - tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize - return tokenizer + return HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize def load_state_dict(checkpoint_path: str, map_location='cpu'): @@ -152,7 +153,10 @@ def create_model( if custom_text: if 'hf_model_name' in model_cfg.get('text_cfg', {}): model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf - model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + if "coca" in model_name: + model = CoCa(**model_cfg, cast_dtype=cast_dtype) + else: + model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) else: model = CLIP(**model_cfg, cast_dtype=cast_dtype) @@ -188,6 +192,25 @@ def create_model( return model +def create_loss(args): + if "coca" in args.model.lower(): + return CoCaLoss( + caption_loss_weight=args.coca_caption_loss_weight, + clip_loss_weight=args.coca_contrastive_loss_weight, + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod) + return ClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod) + def create_model_and_transforms( model_name: str, diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index de31426df..9f26dd4f8 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -119,3 +119,42 @@ def forward(self, image_features, text_features, logit_scale): F.cross_entropy(logits_per_text, labels) ) / 2 return total_loss + + +class CoCaLoss(ClipLoss): + def __init__( + self, + caption_loss_weight, + clip_loss_weight, + pad_id=-100, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + ): + super().__init__( + local_loss=local_loss, + gather_with_grad=gather_with_grad, + cache_labels=cache_labels, + rank=rank, + world_size=world_size, + use_horovod=use_horovod + ) + + self.clip_loss_weight = clip_loss_weight + self.caption_loss_weight = caption_loss_weight + self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) + + def forward(self, image_features, text_features, logits, labels, logit_scale): + clip_loss = super().forward(image_features, text_features, logit_scale) + clip_loss = self.clip_loss_weight * clip_loss + + caption_loss = self.caption_loss( + 
logits.permute(0, 2, 1), + labels, + ) + caption_loss = caption_loss * self.caption_loss_weight + + return clip_loss + caption_loss diff --git a/src/open_clip/model_configs/coca_ViT-B-32.json b/src/open_clip/model_configs/coca_ViT-B-32.json new file mode 100644 index 000000000..3efdc24d8 --- /dev/null +++ b/src/open_clip/model_configs/coca_ViT-B-32.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/coca_base.json b/src/open_clip/model_configs/coca_base.json new file mode 100644 index 000000000..525203d4d --- /dev/null +++ b/src/open_clip/model_configs/coca_base.json @@ -0,0 +1,26 @@ +{ + "embed_dim": 768, + "multimodal_cfg": { + "width": 768, + "context_length": 76, + "mlp_ratio": 4, + "layers": 12, + "dim_head": 64, + "heads": 12, + "n_queries": 256 + }, + "vision_cfg": { + "image_size": 288, + "layers": 12, + "width": 768, + "patch_size": 18 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 64000, + "layers": 12, + "heads": 12, + "width": 768 + }, + "custom_text": "True" +} \ No newline at end of file diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 3f337f9b1..0d2f2e7ae 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -124,12 +124,26 @@ def __init__( self.out_proj = nn.Linear(dim, dim) self.out_drop = nn.Dropout(proj_drop) - def forward(self, x, attn_mask: Optional[torch.Tensor] = None): - L, N, C = x.shape - q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) - q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + def forward(self, + q_x, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None + ): + + L, N, C = q_x.shape + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + + w_q, w_k, w_v = self.in_proj_weight.split(3, dim=0) + + q = F.linear(q_x, w_q, self.in_proj_bias) + k = F.linear(k_x, w_k, self.in_proj_bias) + v = F.linear(v_x, w_v, self.in_proj_bias) + + q = q.view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.view(L, N * self.num_heads, -1).transpose(0, 1) if self.logit_scale is not None: attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) @@ -159,6 +173,26 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None): x = self.out_drop(x) return x +class AttentionalPooler(nn.Module): + def __init__( + self, + d_model: int, + n_head: int = 8, + n_queries: int = 256, + ): + super().__init__() + self.query = nn.Parameter(torch.randn(n_queries, d_model)) + self.attn = nn.MultiheadAttention(d_model, n_head) + + def forward(self, k: torch.Tensor, v: torch.Tensor): + k, v = k.permute(1, 0, 2), v.permute(1, 0 ,2) # NLD -> LND + N = k.shape[1] + out = self.attn(self._repeat(self.query, N), k, v, need_weights=False)[0] + return out.permute(1, 0, 2) # LND -> NLD + + def _repeat(self, query, N): + return query.unsqueeze(1).repeat(1, N, 1) + class 
ResidualAttentionBlock(nn.Module): def __init__( @@ -169,12 +203,15 @@ def __init__( ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, + is_cross_attention: bool = False, ): super().__init__() self.ln_1 = norm_layer(d_model) self.attn = nn.MultiheadAttention(d_model, n_head) self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value else nn.Identity() + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) @@ -185,12 +222,33 @@ def __init__( ])) self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value else nn.Identity() - def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None - return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None + ): - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask)) + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + return self.attn( + q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask + )[0] + + def forward(self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None + ): + + k_x = self.ln_1_kv(k_x) if k_x is not None else None + v_x = self.ln_1_kv(v_x) if v_x is not None else None + + x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) x = x + self.ls_2(self.mlp(self.ln_2(x))) return x @@ -208,10 +266,14 @@ def __init__( scale_heads: bool = False, scale_attn: bool = False, scale_fc: bool = False, + is_cross_attention: bool = False, ): super().__init__() self.ln_1 = norm_layer(d_model) + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) + self.attn = Attention( d_model, n_head, scaled_cosine=scale_cosine_attn, @@ -230,8 +292,20 @@ def __init__( ])) self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value else nn.Identity() - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None + ): + + k_x = self.ln_1_kv(k_x) if k_x is not None else None + v_x = self.ln_1_kv(v_x) if v_x is not None else None + + x = q_x + self.ls_1( + self.ln_attn(self.attn(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) + ) x = x + self.ls_2(self.mlp(self.ln_2(x))) return x @@ -264,12 +338,12 @@ def get_cast_dtype(self) -> torch.dtype: def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): for r in self.resblocks: if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(r, x, attn_mask) + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, None, None, attn_mask) else: x = r(x, attn_mask=attn_mask) return x - class VisionTransformer(nn.Module): def __init__( self, @@ -374,7 +448,7 @@ def init_parameters(self): def set_grad_checkpointing(self, enable=True): 
self.transformer.grad_checkpointing = enable - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor, output_tokens = False): x = self.conv1(x) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] @@ -391,6 +465,9 @@ def forward(self, x: torch.Tensor): x = self.transformer(x) x = x.permute(1, 0, 2) # LND -> NLD + if output_tokens: + return x + if self.global_average_pool: x = x.mean(dim=1) else: @@ -403,7 +480,6 @@ def forward(self, x: torch.Tensor): return x - class TextTransformer(nn.Module): def __init__( @@ -485,3 +561,86 @@ def forward(self, text): x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection return x + + +class MultimodalTransformer(Transformer): + def __init__( + self, + width: int, + layers: int, + heads: int, + context_length: int = 77, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_dim: int = 512, + ): + + super().__init__( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.context_length = context_length + self.cross_attn = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, is_cross_attention=True) + for _ in range(layers) + ]) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.ln_final = norm_layer(width) + + def init_parameters(self): + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + for block in self.transformer.cross_attn: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, image_embs, text_embs): + text_embs = text_embs.permute(1, 0, 2) # NLD -> LND + image_embs = image_embs.permute(1, 0, 2) # NLD -> LND + + for r, ca in zip(self.resblocks, self.cross_attn): + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + text_embs = checkpoint(r, text_embs, None, None, self.attn_mask) + text_embs = checkpoint(ca, text_embs, image_embs, image_embs, None) + else: + text_embs = r(text_embs, attn_mask=self.attn_mask) + text_embs = ca(text_embs, 
k_x=image_embs, v_x=image_embs) + + x = text_embs.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + return x diff --git a/src/training/main.py b/src/training/main.py index 26d4cc529..a2f9f5b42 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -24,7 +24,7 @@ except ImportError: hvd = None -from open_clip import create_model_and_transforms, trace_model, get_tokenizer +from open_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss from training.data import get_data from training.distributed import is_master, init_distributed_device, world_info_from_env from training.logger import setup_logging @@ -264,11 +264,13 @@ def main(args): evaluate(model, data, start_epoch, args, writer) return + loss = create_loss(args) + for epoch in range(start_epoch, args.epochs): if is_master(args): logging.info(f'Start epoch {epoch}') - train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer) + train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args, tb_writer=writer) completed_epoch = epoch + 1 if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')): diff --git a/src/training/params.py b/src/training/params.py index abc07dd50..a35ce2732 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -332,6 +332,18 @@ def parse_args(args): default=100, help="Log every n steps to tensorboard/console/wandb.", ) + parser.add_argument( + "--coca-caption-loss-weight", + type=float, + default=2.0, + help="Weight assigned to caption loss in CoCa." + ) + parser.add_argument( + "--coca-contrastive-loss-weight", + type=float, + default=1.0, + help="Weight assigned to contrastive loss when training CoCa." + ) args = parser.parse_args(args) diff --git a/src/training/train.py b/src/training/train.py index 83f4f6fa7..3899590c3 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -13,7 +13,7 @@ except ImportError: wandb = None -from open_clip import ClipLoss, get_cast_dtype +from open_clip import get_cast_dtype from .distributed import is_master from .zero_shot import zero_shot_eval from .precision import get_autocast @@ -51,20 +51,12 @@ def backward(total_loss, scaler): else: total_loss.backward() - -def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None): +def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args, tb_writer=None): device = torch.device(args.device) autocast = get_autocast(args.precision) cast_dtype = get_cast_dtype(args.precision) model.train() - loss = ClipLoss( - local_loss=args.local_loss, - gather_with_grad=args.gather_with_grad, - cache_labels=True, - rank=args.rank, - world_size=args.world_size, - use_horovod=args.horovod) data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch dataloader = data['train'].dataloader @@ -94,8 +86,9 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w if args.accum_freq == 1: with autocast(): - image_features, text_features, logit_scale = model(images, texts) - total_loss = loss(image_features, text_features, logit_scale) + model_out = model(images, texts) + logit_scale = model_out[-1] + total_loss = loss(*model_out) backward(total_loss, scaler) else: @@ -222,6 +215,7 @@ def evaluate(model, data, epoch, args, tb_writer=None): # FIXME this does not scale past small eval datasets # all_image_features @ all_text_features will blow up memory and compute very quickly cumulative_loss = 0.0 + cumulative_gen_loss = 0.0 
all_image_features, all_text_features = [], [] with torch.no_grad(): for i, batch in enumerate(dataloader): @@ -230,7 +224,10 @@ def evaluate(model, data, epoch, args, tb_writer=None): texts = texts.to(device=device, non_blocking=True) with autocast(): - image_features, text_features, logit_scale = model(images, texts) + model_out = model(images, texts) + image_features = model_out[0] + text_features = model_out[1] + logit_scale = model_out[-1] # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly # however, system RAM is easily exceeded and compute time becomes problematic all_image_features.append(image_features.cpu()) @@ -246,22 +243,33 @@ def evaluate(model, data, epoch, args, tb_writer=None): F.cross_entropy(logits_per_text, labels) ) / 2 + gen_loss = maybe_compute_generative_loss(model_out) + cumulative_loss += total_loss * batch_size num_samples += batch_size if is_master(args) and (i % 100) == 0: logging.info( f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t" - f"Loss: {cumulative_loss / num_samples:.6f}\t") + f"Clip Loss: {cumulative_loss / num_samples:.6f}\t") + + if gen_loss is not None: + cumulative_gen_loss += gen_loss * batch_size + logging.info( + f"Generative Loss: {cumulative_gen_loss / num_samples:.6f}\t") - val_metrics = get_metrics( + + val_metrics = get_clip_metrics( image_features=torch.cat(all_image_features), text_features=torch.cat(all_text_features), logit_scale=logit_scale.cpu(), ) loss = cumulative_loss / num_samples metrics.update( - {**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples} + {**val_metrics, "clip_val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples} ) + if gen_loss is not None: + gen_loss = cumulative_gen_loss / num_samples + metrics.update({"val_generative_loss": gen_loss.item()}) if not metrics: return metrics @@ -288,7 +296,7 @@ def evaluate(model, data, epoch, args, tb_writer=None): return metrics -def get_metrics(image_features, text_features, logit_scale): +def get_clip_metrics(image_features, text_features, logit_scale): metrics = {} logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu() logits_per_text = logits_per_image.t().detach().cpu() @@ -306,3 +314,10 @@ def get_metrics(image_features, text_features, logit_scale): metrics[f"{name}_R@{k}"] = np.mean(preds < k) return metrics + + +def maybe_compute_generative_loss(model_out): + if len(model_out) > 3: + token_logits = model_out[2] + token_labels = model_out[3] + return F.cross_entropy(token_logits.permute(0, 2, 1), token_labels) diff --git a/tests/test_inference.py b/tests/test_inference.py index 0b3fdc1e0..a97b53aeb 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -19,6 +19,8 @@ 'ViT-bigG-14', 'ViT-e-14', 'mt5-xl-ViT-H-14', + 'coca_base', + 'coca_ViT-B-32' }) if 'OPEN_CLIP_TEST_REG_MODELS' in os.environ: diff --git a/tests/test_training_simple.py b/tests/test_training_simple.py index fe55b3328..5e1f649e7 100644 --- a/tests/test_training_simple.py +++ b/tests/test_training_simple.py @@ -24,6 +24,22 @@ def test_training(): '--model', 'RN50' ]) +@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") +def test_training_coca(): + main([ + '--save-frequency', '1', + '--zeroshot-frequency', '1', + '--dataset-type', "synthetic", + '--train-num-samples', '16', + '--warmup', '1', + '--batch-size', '4', + '--lr', '1e-3', + '--wd', '0.1', + '--epochs', '1', + '--workers', '2', + '--model', 'coca_ViT-B-32' + ]) + 
@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals") def test_training_mt5(): main([ @@ -60,4 +76,4 @@ def test_training_unfreezing_vit(): '--model', 'ViT-B-32', '--lock-image', '--lock-image-unlocked-groups', '5' - ]) \ No newline at end of file + ])
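
Usage sketch (not part of the diff): a minimal illustration of how the pieces added above fit together, assuming the coca_ViT-B-32 config introduced here. The dummy tensors and the loss weights (matching the new --coca-caption-loss-weight / --coca-contrastive-loss-weight defaults) are illustrative only, not something this PR prescribes.

# Illustrative sketch only: wiring up the new CoCa model and CoCaLoss,
# assuming the coca_ViT-B-32 config added in this diff.
import torch
import open_clip

# "coca" in the model name makes create_model build a CoCa instead of CustomTextCLIP
model, _, _ = open_clip.create_model_and_transforms("coca_ViT-B-32")
tokenizer = open_clip.get_tokenizer("coca_ViT-B-32")

images = torch.randn(4, 3, 224, 224)                          # dummy image batch
texts = tokenizer(["a diagram", "a dog", "a cat", "a bird"])  # [4, 77] token ids

# CoCa.forward returns a 5-tuple rather than CLIP's 3-tuple:
# (image_latents, text_latents, token_logits, token_labels, logit_scale)
model_out = model(images, texts)

# CoCaLoss = weighted contrastive (CLIP) loss + captioning cross-entropy over the
# decoder logits; train.py now simply calls loss(*model_out) for both CLIP and CoCa.
loss_fn = open_clip.CoCaLoss(caption_loss_weight=2.0, clip_loss_weight=1.0)
total_loss = loss_fn(*model_out)

This mirrors what the new test_training_coca test exercises end to end through training.main with the synthetic dataset.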