Add coca trained (#307) (#308)
* Add coca trained (#307)

* initial setup

* add coca loss

* remove loss from the model

* fix loss

* add underscores

* name changes

* add cross attention to Residual and CustomResidual

* fix if

* add transformer 'decoder'

* minor fix

* looks better

* initialize coca model structure

* clean

* typo and format

* checkpoint signature

* adjust multimodal decoder and add CoCaTransformer

* keep older logic

* remove chunk

* typo

* fix

* make chunk dim explicit

* adjust cfg names

* add attentionalpooling

* add attentional pooling to coca

* small change

* add cocatransformer variants and AttentionPooling

* remove older attention pooler

* adapt embed text to coca text transformer

* rm coca layers

* rename and remove useless CoCa models

* make attentionpooler pooler only

* refactor for one transformer only

* coca forward works

* separate context and n_queries

* add initial coca_base config

* remove config

* small loss change

* init training file

* make variable order right

* remove print

* uniform names

* renaming

* add coca funcs to init

* add coca config and exclude from testing

* add and comment simple test (no trained model)

* add L2 norm

* make L2 same as in clip

* remove unused temperature

* type

* clean

* fix config

* rename and move cfg

* rename

* tentatively add coca to factory

* fix config

* update config

* embed contrastive cls token in model

* remove unused arg

* import create_loss

* make factory accept coca

* make caption loss distributed

* make loss customizable

* pass loss through training_epoch

* add coca specific params to params

* removed decoder unused parameters

* remove unused attributes

* adjust coca_config

* fix config and remove unused parameters

* remove comment

* remove more comments

* rename attention pooler

* rename TransformerDecoder

* make AttentionalPooler clearer

* add local loss logic to cocaloss

* only create loss if train in data

* remove wrong file

* fix attentional pooler call

* not ready for testing

* really not ready for testing

* eof line

* uniform names

* add possible generative loss to evaluate

* change _build function names

* remove wrong import

* remove local_loss from captioning loss

* indexing error

* finish renaming

* adjust configs

* add training test for coca

* simplify captioning loss

* remove hf

* fix evaluate and loss

* remove print

* move projection

* add coca vit 32 config

* test on new config

* adjust coca_base config

* remove coca from test_inference

* maybe fix regression test

* make logits and labels contiguous

* simpler logic

* make contiguous after transpose

* last test

* try fix loss

* CoCa PR: loss fix + rename file

* wait for feedback on this

* cleanup

* CoCa PR: add set_grad_checkpointing + fix checkpoint API

* CoCa PR: fix eval (which uses encode_x instead of forward)

* move making space for CLS token into encode_text

* revert zs changes + fix

Co-authored-by: gpucce <[email protected]>
Co-authored-by: gpucce <[email protected]>
Co-authored-by: iejmac <[email protected]>

* Add coca to CI

* Add coca to CI pr

* simplify encode_image (#313)

Co-authored-by: Romain Beaumont <[email protected]>

* Add cls mask (#312)

* build_cls_mask

* add cls_mask to encode_text

* add model properties

Co-authored-by: Romain Beaumont <[email protected]>
Co-authored-by: gpucce <[email protected]>

* Ignore pad tokens in captioning loss (#316)

* add ignore_index

* just need to pick right index

Co-authored-by: gpucce <[email protected]>

* add `generate` to coca model (#314)

* add initial generative support

* make generation context_length independent

* remove kwargs

* last positional embeddings for CLS

* typo

* fix mask len

* add comment

* remove unused args

* simpler logic for input shorter than context length

Co-authored-by: gpucce <[email protected]>

* use `TextEncoder` in coca `encode_image` (#321)

* use self.text in encode image

* unused var

* revert Attention and CustomResidualAttentionBlock

* remove whiteline

* add dict output

* integrate self.text attributes

* HF compatibility

* better config and minor fixes

* clean

* remove embed_cls option from HF

* use cls_token_position

* fix cls masking

* resize labels

* text -> self.text

* split loss logging

* add total loss

* minor logs formatting

* fix generate

* simpler logic

* disentangle proj for HF too

* adjust config

* only norm cls

* move attn_pool to VisionTransformer

* adjust coca_base config

* fix grad checkpointing in MultimodalTransformer

Co-authored-by: gpucce <[email protected]>
Co-authored-by: iejMac <[email protected]>

* Get some basic PEP changes out of the way

* Add tests bis (#355)

* make jit compilable

* redundant annotation

* less tests

* less annotations

* even less annotations

* fix name check in ci

* some annotations back

* make it simpler

* make hf simpler too

* better jit support with tests

* remove extra line

* add customtextclip

* more jit tests

* missing assert

* add eval

* typo

* revert forward changes

* clean coca model

* more cleaning

* last cleaning

* train.py: fix is_clip when doing distributed (#364)

* add README (#365)

* add README

* multimodal_cfg info

* multimodal

* remove output_dict argument (#368)

* remove output_dict argument

* cleaner

* do same thing for _encode_image (#366)

* do same thing for _encode_image

* encoder

* try this

* adjust inference tests

* fix syntax

* True not None

* dumb

* CoCa/forward: remove unused output_dict param

* Revert "do same thing for _encode_image (#366)"

This reverts commit de343fb.

* refactor

* white space

* remove extra layer norm

* move to_logits into decoder

* leave for later

* better torchscript

* annotate hf too

* Add CoCa-ViT-L/14 config (#379)

* Remove dead LN code, refactor attn_pool conditional for more clarity, minor formatting tweaks

* latent_dim to embed_dim

* remove extra cfg

* A bit more cleanup: keep context_length as context len, 'num_pos' to include extra tokens; None type check for embed_cls instead of getattr

* CoCa: add B/32 pretrained (#389)

* add B/32 pretrained

* fix

* no capital

* slash

* remove coca from ci.yml

---------

Co-authored-by: gpucce <[email protected]>
Co-authored-by: gpucce <[email protected]>
Co-authored-by: iejmac <[email protected]>
Co-authored-by: iejMac <[email protected]>
Co-authored-by: Ross Wightman <[email protected]>
6 people authored Jan 29, 2023
1 parent aefb471 commit 76c8f85
Showing 22 changed files with 976 additions and 93 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
@@ -3,7 +3,7 @@ name: Continuous integration
on:
push:
branches:
- main
- main
paths-ignore:
- '**.md'
- 'CITATION.cff'
@@ -12,7 +12,7 @@ on:
- 'docs/**'
pull_request:
branches:
- main
- main
paths-ignore:
- '**.md'
- 'CITATION.cff'
@@ -81,7 +81,7 @@ jobs:
--group ${{ matrix.job }} \
-m regression_test \
tests \
| head -n -2 | grep -Po 'test_inference_with_data\[\K[^]]*(?=])' \
| head -n -2 | grep -Po 'test_inference_with_data\[\K[^]]*(?=-False]|-True])' \
> models_gh_runner.txt
if [ -n "${{ inputs.manual_revision_reference }}" ]; then
REVISION_REFERENCE=${{ inputs.manual_revision_reference }}
17 changes: 16 additions & 1 deletion README.md
@@ -258,6 +258,20 @@ python -m training.main \
--resume /path/to/checkpoints/epoch_K.pt
```
### Training CoCa:
Training [CoCa](https://arxiv.org/abs/2205.01917) models is enabled by specifying a CoCa config with the ```--model``` parameter of the training script. Currently available configs are "coca_base", "coca_ViT-B-32", and "coca_roberta-ViT-B-32" (which uses RoBERTa as the text encoder). CoCa configs differ from CLIP configs in that they have an additional "multimodal_cfg" component which specifies parameters for the multimodal text decoder. Here's an example from the coca_ViT-B-32 config:
```json
"multimodal_cfg": {
"context_length": 76,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"latent_dim": 512,
"attn_pooler_heads": 8
}
```
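Once such a config has been trained (or one of the pretrained CoCa tags listed further down this README is used), the model can be loaded like any other open_clip model. A minimal sketch, assuming the `coca_ViT-B-32` config and the `laion2B-s13B-b90k` pretrained tag shown in the pretrained-model list below:

```python
import open_clip

# sketch only: load a CoCa model and its tokenizer by config name;
# the pretrained tag is the coca_ViT-B-32 entry from the pretrained list
model, _, preprocess = open_clip.create_model_and_transforms(
    "coca_ViT-B-32", pretrained="laion2B-s13B-b90k"
)
tokenizer = open_clip.get_tokenizer("coca_ViT-B-32")
```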

### Training with pre-trained language models as text encoder:

If you wish to use different language models as the text encoder for CLIP you can do so by using one of the Hugging Face model configs in ```src/open_clip/model_configs``` and passing its name as the ```--model``` parameter and the corresponding tokenizer as ```--hf-tokenizer-name```. Currently we only support RoBERTa ("test-roberta" config), however adding new models should be trivial. You can also determine how many layers, from the end, to leave unfrozen with the ```--lock-text-unlocked-layers``` parameter. Here's an example command to train CLIP with the RoBERTa LM that has its last 10 layers unfrozen:
@@ -485,7 +499,8 @@ Future trained models will use nn.GELU.
('ViT-bigG-14', 'laion2b_s39b_b160k'),
('roberta-ViT-B-32', 'laion2b_s12b_b32k'),
('xlm-roberta-base-ViT-B-32', 'laion5b_s13b_b90k'),
('xlm-roberta-large-ViT-H-14', 'frozen_laion5b_s13b_b90k'),]
('xlm-roberta-large-ViT-H-14', 'frozen_laion5b_s13b_b90k'),
('coca_ViT-B-32', 'laion2B-s13B-b90k'),]
>>> model, train_transform, eval_transform = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
```
5 changes: 3 additions & 2 deletions src/open_clip/__init__.py
@@ -1,9 +1,10 @@
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss
from .loss import ClipLoss, CoCaLoss
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
from .coca_model import CoCa
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
193 changes: 193 additions & 0 deletions src/open_clip/coca_model.py
@@ -0,0 +1,193 @@
from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from dataclasses import dataclass

from .transformer import (
LayerNormFp32,
LayerNorm,
QuickGELU,
MultimodalTransformer,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
from .generation_utils import top_a, top_k, top_p


@dataclass
class MultimodalCfg(CLIPTextCfg):
mlp_ratio: int = 4
dim_head: int = 64
heads: int = 8
n_queries: int = 256
attn_pooler_heads: int = 8


def _build_text_decoder_tower(
embed_dim,
multimodal_cfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = (
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
)

decoder = MultimodalTransformer(
context_length=multimodal_cfg.context_length,
width=multimodal_cfg.width,
heads=multimodal_cfg.heads,
layers=multimodal_cfg.layers,
ls_init_value=multimodal_cfg.ls_init_value,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
)

return decoder


class CoCa(nn.Module):
def __init__(
self,
embed_dim,
multimodal_cfg: MultimodalCfg,
text_cfg: CLIPTextCfg,
vision_cfg: CLIPVisionCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
pad_id: int = 0,
):
super().__init__()
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

self.text = _build_text_tower(
embed_dim=embed_dim,
text_cfg=text_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)

vocab_size = (
text_cfg.vocab_size # for hf models
if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
else text_cfg.vocab_size
)

self.visual = _build_vision_tower(
embed_dim=embed_dim,
vision_cfg=vision_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)

self.text_decoder = _build_text_decoder_tower(
vocab_size,
multimodal_cfg=multimodal_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)

self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.pad_id = pad_id

@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.text.set_grad_checkpointing(enable)
self.text_decoder.set_grad_checkpointing(enable)

def _encode_image(self, images, normalize=True):
image_latent, tokens_embs = self.visual(images)
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
return image_latent, tokens_embs

def _encode_text(self, text, normalize=True):
text = text[:, :-1] # make space for CLS token
text_latent, token_emb = self.text(text)
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
return text_latent, token_emb

def encode_image(self, images, normalize=True):
image_latent, _ = self._encode_image(images, normalize=normalize)
return image_latent

def encode_text(self, text, normalize=True):
text_latent, _ = self._encode_text(text, normalize=normalize)
return text_latent

def forward(self, image, text):
text_latent, token_embs = self._encode_text(text)
image_latent, image_embs = self._encode_image(image)

# TODO: add assertion to avoid bugs?
labels = text[:, -token_embs.shape[1]:]

logits = self.text_decoder(image_embs, token_embs)
return {
"image_features": image_latent,
"text_features": text_latent,
"logits": logits,
"labels": labels,
"logit_scale": self.logit_scale.exp()
}

def generate(
self,
image,
text,
seq_len,
max_seq_len=77,
mask_prob=0.0,
temperature=1.,
filter_logits_fn=top_k,
filter_thres=0.9,
min_p_pow=2.0,
min_p_ratio=0.02,
):

assert mask_prob < 1, "mask_prob must be smaller than 1."

was_training = self.training
num_dims = len(text.shape)

if num_dims == 1:
text = text[None, :]

_, t = text.shape
self.eval()
out = text

for _ in range(seq_len):
x = out[:, -max_seq_len:]

# TODO: adjust for dict output
logits = self(image, x)["logits"][:, -1]

if filter_logits_fn in {top_k, top_p}:
filtered_logits = filter_logits_fn(logits, thres=filter_thres)
probs = F.softmax(filtered_logits / temperature, dim=-1)

elif filter_logits_fn is top_a:
filtered_logits = filter_logits_fn(
logits, min_p_pow=min_p_pow, min_p_ratio=min_p_ratio
)
probs = F.softmax(filtered_logits / temperature, dim=-1)

sample = torch.multinomial(probs, 1)

out = torch.cat((out, sample), dim=-1)

out = out[:, t:]

if num_dims == 1:
out = out.squeeze(0)

self.train(was_training)
return out
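For reference, a usage sketch of the model defined above, assuming the `coca_ViT-B-32` config added in this PR; the shapes in the comments are illustrative:

```python
import torch
import open_clip

# sketch only: random weights, no pretrained tag
model, _, preprocess = open_clip.create_model_and_transforms("coca_ViT-B-32")
tokenizer = open_clip.get_tokenizer("coca_ViT-B-32")

images = torch.randn(2, 3, 224, 224)                          # stand-in for a preprocessed batch
texts = tokenizer(["a photo of a cat", "a photo of a dog"])   # [batch, context_length]

with torch.no_grad():
    out = model(images, texts)

# contrastive features plus captioning logits/labels, as returned by forward()
print(out["image_features"].shape)   # [batch, embed_dim]
print(out["logits"].shape)           # [batch, seq_len, vocab_size]
print(out["labels"].shape)           # [batch, seq_len]
```

`model.generate(images, texts, seq_len)` then produces captions autoregressively using the filtering functions from `generation_utils` below.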
29 changes: 28 additions & 1 deletion src/open_clip/factory.py
@@ -12,6 +12,8 @@
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
resize_pos_embed, get_cast_dtype
from .coca_model import CoCa
from .loss import ClipLoss, CoCaLoss
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
from .transform import image_transform, AugmentationCfg
@@ -177,7 +179,10 @@ def create_model(
if custom_text:
if is_hf_model:
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
if "coca" in model_name:
model = CoCa(**model_cfg, cast_dtype=cast_dtype)
else:
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)

@@ -216,6 +221,28 @@ def create_model(
return model


def create_loss(args):
if "coca" in args.model.lower():
return CoCaLoss(
caption_loss_weight=args.coca_caption_loss_weight,
clip_loss_weight=args.coca_contrastive_loss_weight,
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod,
)
return ClipLoss(
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod,
)


def create_model_and_transforms(
model_name: str,
pretrained: Optional[str] = None,
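A sketch of how the new `create_loss` helper is driven from the training script's args; the field values below are illustrative placeholders, not the script's actual defaults:

```python
from argparse import Namespace

from open_clip import create_loss

# illustrative args namespace; field names mirror those read by create_loss above
args = Namespace(
    model="coca_ViT-B-32",
    coca_caption_loss_weight=2.0,        # assumed/illustrative weighting
    coca_contrastive_loss_weight=1.0,    # assumed/illustrative weighting
    local_loss=False,
    gather_with_grad=False,
    rank=0,
    world_size=1,
    horovod=False,
)

loss_fn = create_loss(args)   # CoCaLoss here, ClipLoss for non-CoCa model names
```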
37 changes: 37 additions & 0 deletions src/open_clip/generation_utils.py
@@ -0,0 +1,37 @@
from math import ceil
import torch
from torch import nn
import torch.nn.functional as F


def exists(val):
return val is not None


def top_p(logits, thres=0.9):
# nucleus
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

sorted_indices_to_remove = cum_probs > (1 - thres)
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = 0

sorted_logits[sorted_indices_to_remove] = float('-inf')
return sorted_logits.scatter(1, sorted_indices, sorted_logits)


def top_k(logits, thres=0.9):
k = ceil((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(1, ind, val)
return probs


def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02):
probs = F.softmax(logits, dim=-1)
limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
logits[probs < limit] = float('-inf')
logits[probs >= limit] = 1
return logits
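A toy sketch of how these filters are used, mirroring the sampling loop in `CoCa.generate` above:

```python
import torch
import torch.nn.functional as F

from open_clip.generation_utils import top_k

# toy next-token logits over a 10-token vocabulary
logits = torch.randn(1, 10)

# thres=0.9 keeps ceil(0.1 * 10) = 1 candidate; everything else becomes -inf
filtered = top_k(logits, thres=0.9)

# sample the next token, as CoCa.generate does after filtering
probs = F.softmax(filtered / 1.0, dim=-1)   # temperature = 1.0
next_token = torch.multinomial(probs, 1)
print(next_token.shape)                     # [1, 1]
```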
