
Add coca trained (#307)
* initial setup

* add coca loss

* remove loss from the model

* fix loss

* add underscores

* name changes

* add cross attention to Residual and CustomResidual

* fix if

* add transformer 'decoder'

* minor fix

* looks better

* initialize coca model structure

* clean

* typo and format

* checkpoint signature

* adjust multimodal decoder and add CoCaTransformer

* keep older logic

* remove chunk

* typo

* fix

* make chunk dim explicit

* adjust cfg names

* add attentionalpooling

* add attentional pooling to coca

* small change

* add cocatransformer variants and AttentionPooling

* remove older attention pooler

* adapt embed text to coca text transformer

* rm coca layers

* rename and remove useless CoCa models

* make attentionpooler pooler only

* refactor for one transformer only

* coca forward works

* separate context and n_queries

* add initial coca_base config

* remove config

* small loss change

* init training file

* make variable order right

* remove print

* uniform names

* renaming

* add coca funcs to init

* add coca config and exclude from testing

* add and comment simple test (no trained model)

* add L2 norm

* make L2 same as in clip

* remove unused temperature

* type

* clean

* fix config

* rename and move cfg

* rename

* tentatively add coca to factory

* fix config

* update config

* embed contrastive cls token in model

* remove unused arg

* import create_loss

* make factory accept coca

* make caption loss distributed

* make loss customizable

* pass loss through training_epoch

* add coca specific params to params

* remove unused decoder parameters

* remove unused attributes

* adjust coca_config

* fix config and remove unused parameters

* remove comment

* remove more comments

* rename attention pooler

* rename TransformerDecoder

* make AttentionalPooler clearer

* add local loss logic to cocaloss

* only create loss if train in data

* remove wrong file

* fix attentional pooler call

* not ready for testing

* really not ready for testing

* eof line

* uniform names

* add possible generative loss to evaluate

* change _build function names

* remove wrong import

* remove local_loss from captioning loss

* indexing error

* finish renaming

* adjust configs

* add training test for coca

* simplify captioning loss

* remove hf

* fix evaluate and loss

* remove print

* move projection

* add coca vit 32 config

* test on new config

* adjust coca_base config

* remove coca from test_inference

* maybe fix regression test

* make logits and labels contiguous

* simpler logic

* make contiguous after transpose

* last test

* try fix loss

* CoCa PR: loss fix + rename file

* wait for feedback on this

* cleanup

* CoCa PR: add set_grad_checkpointing + fix checkpoint API

* CoCa PR: fix eval (which uses encode_x instead of forward)

* move making space for CLS token into encode_text

* revert zs changes + fix

Co-authored-by: gpucce <[email protected]>
Co-authored-by: iejmac <[email protected]>
4 people authored Dec 20, 2022
1 parent fa141ee commit 1b86601
Showing 12 changed files with 557 additions and 41 deletions.
5 changes: 3 additions & 2 deletions src/open_clip/__init__.py
@@ -1,9 +1,10 @@
 from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
-from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
+from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
 from .factory import list_models, add_model_config, get_model_config, load_checkpoint
-from .loss import ClipLoss
+from .loss import ClipLoss, CoCaLoss
 from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\
     convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
+from .coca_model import CoCa
 from .openai import load_openai_model, list_openai_models
 from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
     get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
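Not part of the diff — a one-line sanity check of the symbols this change exports at top level:

# all three names below are newly exported by this commit
from open_clip import CoCa, CoCaLoss, create_loss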
200 changes: 200 additions & 0 deletions src/open_clip/coca_model.py
@@ -0,0 +1,200 @@
from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from dataclasses import dataclass

from .transformer import (
    LayerNormFp32,
    LayerNorm,
    QuickGELU,
    MultimodalTransformer,
    AttentionalPooler,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower


@dataclass
class MultimodalCfg(CLIPTextCfg):
    mlp_ratio: int = 4
    dim_head: int = 64
    heads: int = 8
    n_queries: int = 256
    dim_latents: int = None


def _build_input_dependent_text_tower(
        embed_dim: int,
        multimodal_cfg: MultimodalCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
        multimodal: bool = True
):
    if not multimodal:
        return _build_text_tower(
            embed_dim=embed_dim,
            text_cfg=multimodal_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype
        )

    if isinstance(multimodal_cfg, dict):
        multimodal_cfg = MultimodalCfg(**multimodal_cfg)

    act_layer = QuickGELU if quick_gelu else nn.GELU
    norm_layer = (
        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    )

    text = MultimodalTransformer(
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )

    return text, multimodal_cfg


class CoCa(nn.Module):
    def __init__(
            self,
            embed_dim,
            multimodal_cfg: MultimodalCfg,
            text_cfg: CLIPTextCfg,
            vision_cfg: CLIPVisionCfg,
            n_queries: int = 256,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()

        norm_layer = (
            LayerNormFp32
            if cast_dtype in (torch.float16, torch.bfloat16)
            else LayerNorm
        )

        text = _build_input_dependent_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype, multimodal=False)
        self.transformer = text.transformer
        self.vocab_size = text.vocab_size
        self.token_embedding = text.token_embedding
        self.positional_embedding = text.positional_embedding
        self.ln_final = text.ln_final
        self.text_projection = text.text_projection
        self.register_buffer("attn_mask", text.attn_mask, persistent=False)

        self.cls_token = nn.Parameter(torch.randn(embed_dim))
        self.visual = _build_vision_tower(
            embed_dim, vision_cfg, quick_gelu, cast_dtype
        )

        self.multimodal_decoder, multimodal_cfg = _build_input_dependent_text_tower(
            embed_dim, multimodal_cfg, quick_gelu, cast_dtype
        )

        self.img_attn_pool = AttentionalPooler(
            multimodal_cfg.width, multimodal_cfg.heads, n_queries=n_queries + 1
        )

        self.img_attn_pool_norm = norm_layer(embed_dim)

        self.dim_latents = multimodal_cfg.dim_latents if multimodal_cfg.dim_latents else multimodal_cfg.width
        self.to_text_latent = nn.Linear(embed_dim, self.dim_latents, bias=False)

        self.to_logits = nn.Sequential(
            norm_layer(embed_dim), nn.Linear(embed_dim, self.vocab_size, bias=False)
        )

        # tie embedding weights and projection
        self.to_logits[-1].weight = self.token_embedding.weight

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable
        self.multimodal_decoder.grad_checkpointing = enable

    def encode_image(self, images, normalize=True, return_tokens=False):
        x = self.visual.conv1(images)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat(
            [
                self.visual.class_embedding.to(x.dtype)
                + torch.zeros(
                    x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
                ),
                x,
            ],
            dim=1,
        )  # shape = [*, grid ** 2 + 1, width]
        x = x + self.visual.positional_embedding.to(x.dtype)
        x = self.visual.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.visual.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.visual.ln_post(x)

        if self.visual.proj is not None:
            x = x @ self.visual.proj

        x = self.img_attn_pool(x, x)
        x = self.img_attn_pool_norm(x)

        image_latent = x[:, 0]
        image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent

        return (image_latent, x[:, 1:]) if return_tokens else image_latent

    def _repeat(self, t, N):
        return t.reshape(1, 1, -1).repeat(N, 1, 1)

    def encode_text(self, text, normalize=True, return_tokens=False):
        text = text[:, :-1]  # make space for CLS token
        cast_dtype = self.transformer.get_cast_dtype()

        # cls_mask = (text != self.pad_id).unsqueeze(1)
        # attn_mask = F.pad(cls_mask, (0, 1, text.shape[1], 0), value=True)
        # attn_mask = F.pad(self.attn_mask, (0, 1, 0, 1), value=0.0)

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        x = torch.cat([x, self._repeat(self.cls_token, x.shape[0])], dim=1)
        x = x + self.positional_embedding.to(cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # x.shape = [batch_size, n_ctx, transformer.width]
        # unlike CLIP, which pools at the EOT token, the contrastive embedding
        # here is the learned CLS token appended at the last position
        x = x[torch.arange(x.shape[0]), :] @ self.text_projection

        cls_emb = x[torch.arange(x.shape[0]), -1]
        token_emb = x[torch.arange(x.shape[0]), :-1]

        cls_emb = self.ln_final(cls_emb)
        text_latent = self.to_text_latent(cls_emb)
        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent

        return (text_latent, token_emb) if return_tokens else text_latent

    def forward(self, image, text):
        labels = text[:, 1:]

        text_latents, text_tokens = self.encode_text(text, return_tokens=True)
        image_latents, image_tokens = self.encode_image(image, return_tokens=True)

        text_tokens = self.multimodal_decoder(image_tokens, text_tokens)
        logits = self.to_logits(text_tokens)

        return image_latents, text_latents, logits, labels, self.logit_scale.exp()
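A hedged sketch (not part of the commit) to make the forward-pass shapes above concrete; the config dicts mirror model_configs/coca_ViT-B-32.json added later in this diff, and the tensors are random placeholders:

import torch
from open_clip.coca_model import CoCa

vision_cfg = {"image_size": 224, "layers": 12, "width": 768, "patch_size": 32}
text_cfg = {"context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12}
multimodal_cfg = {"context_length": 76, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12}

model = CoCa(512, multimodal_cfg, text_cfg, vision_cfg)

images = torch.randn(2, 3, 224, 224)
tokens = torch.randint(0, 49408, (2, 77))  # full 77-token context; encode_text trims one slot for the CLS token

image_latents, text_latents, logits, labels, logit_scale = model(images, tokens)
# image_latents: [2, 512], text_latents: [2, 512] (dim_latents defaults to the multimodal width)
# logits: [2, 76, 49408], aligned with labels = tokens[:, 1:]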
29 changes: 26 additions & 3 deletions src/open_clip/factory.py
Expand Up @@ -12,6 +12,8 @@
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
    resize_pos_embed, get_cast_dtype
from .coca_model import CoCa
from .loss import ClipLoss, CoCaLoss
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
from .transform import image_transform
@@ -72,8 +74,7 @@ def get_model_config(model_name):

 def get_tokenizer(model_name):
     config = get_model_config(model_name)
-    tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
-    return tokenizer
+    return HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize


 def load_state_dict(checkpoint_path: str, map_location='cpu'):
@@ -152,7 +153,10 @@ def create_model(
     if custom_text:
         if 'hf_model_name' in model_cfg.get('text_cfg', {}):
             model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
-        model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
+        if "coca" in model_name:
+            model = CoCa(**model_cfg, cast_dtype=cast_dtype)
+        else:
+            model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
     else:
         model = CLIP(**model_cfg, cast_dtype=cast_dtype)

@@ -188,6 +192,25 @@

    return model

def create_loss(args):
    if "coca" in args.model.lower():
        return CoCaLoss(
            caption_loss_weight=args.coca_caption_loss_weight,
            clip_loss_weight=args.coca_contrastive_loss_weight,
            local_loss=args.local_loss,
            gather_with_grad=args.gather_with_grad,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod)
    return ClipLoss(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod)


def create_model_and_transforms(
        model_name: str,
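A hedged usage sketch (not from the diff): create_loss only reads a handful of attributes from args, so a bare namespace with the fields referenced above is enough to exercise it; the loss weights here are illustrative, not repository defaults:

from types import SimpleNamespace
import open_clip

model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms("coca_ViT-B-32")

args = SimpleNamespace(
    model="coca_ViT-B-32",           # "coca" in the name selects CoCaLoss
    coca_caption_loss_weight=2.0,    # illustrative weights
    coca_contrastive_loss_weight=1.0,
    local_loss=False,
    gather_with_grad=False,
    rank=0,
    world_size=1,
    horovod=False,
)
loss_fn = open_clip.create_loss(args)  # CoCaLoss here; ClipLoss for any non-coca model name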
39 changes: 39 additions & 0 deletions src/open_clip/loss.py
@@ -119,3 +119,42 @@ def forward(self, image_features, text_features, logit_scale):
            F.cross_entropy(logits_per_text, labels)
        ) / 2
        return total_loss


class CoCaLoss(ClipLoss):
    def __init__(
            self,
            caption_loss_weight,
            clip_loss_weight,
            pad_id=-100,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__(
            local_loss=local_loss,
            gather_with_grad=gather_with_grad,
            cache_labels=cache_labels,
            rank=rank,
            world_size=world_size,
            use_horovod=use_horovod
        )

        self.clip_loss_weight = clip_loss_weight
        self.caption_loss_weight = caption_loss_weight
        self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)

    def forward(self, image_features, text_features, logits, labels, logit_scale):
        clip_loss = super().forward(image_features, text_features, logit_scale)
        clip_loss = self.clip_loss_weight * clip_loss

        caption_loss = self.caption_loss(
            logits.permute(0, 2, 1),
            labels,
        )
        caption_loss = caption_loss * self.caption_loss_weight

        return clip_loss + caption_loss
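Continuing the hedged forward-pass sketch from coca_model.py above (reusing model, images, tokens): logits.permute(0, 2, 1) moves the vocabulary axis to dim 1, where nn.CrossEntropyLoss expects class scores, and the default pad_id=-100 matches that loss's default ignore_index:

from open_clip.loss import CoCaLoss

loss_fn = CoCaLoss(caption_loss_weight=2.0, clip_loss_weight=1.0)  # illustrative weights

image_latents, text_latents, logits, labels, logit_scale = model(images, tokens)
total_loss = loss_fn(image_latents, text_latents, logits, labels, logit_scale)
total_loss.backward()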
24 changes: 24 additions & 0 deletions src/open_clip/model_configs/coca_ViT-B-32.json
@@ -0,0 +1,24 @@
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    },
    "multimodal_cfg": {
        "context_length": 76,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    },
    "custom_text": true
}
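A tokenizer check to go with this config (hedged sketch; since text_cfg has no hf_tokenizer_name, get_tokenizer falls back to the default tokenize, which pads or truncates to the 77-token context):

import open_clip

tokenizer = open_clip.get_tokenizer("coca_ViT-B-32")
tokens = tokenizer(["a photo of a cat"])  # LongTensor of shape [1, 77]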
26 changes: 26 additions & 0 deletions src/open_clip/model_configs/coca_base.json
@@ -0,0 +1,26 @@
{
    "embed_dim": 768,
    "multimodal_cfg": {
        "width": 768,
        "context_length": 76,
        "mlp_ratio": 4,
        "layers": 12,
        "dim_head": 64,
        "heads": 12,
        "n_queries": 256
    },
    "vision_cfg": {
        "image_size": 288,
        "layers": 12,
        "width": 768,
        "patch_size": 18
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 64000,
        "layers": 12,
        "heads": 12,
        "width": 768
    },
    "custom_text": true
}
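Both config files resolve by name through the factory; a quick hedged check, assuming they sit in src/open_clip/model_configs/ as committed:

import open_clip

cfg = open_clip.get_model_config("coca_base")
assert cfg["multimodal_cfg"]["n_queries"] == 256
assert "coca_ViT-B-32" in open_clip.list_models()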
(Diffs for the remaining 6 changed files did not load and are not shown.)
