Add coca trained (#307) #308

Merged: 34 commits merged into main from coca on Jan 29, 2023
Changes from 9 commits
Commits (34)
1b86601
Add coca trained (#307)
rom1504 Dec 20, 2022
29fa332
Add coca to CI
rom1504 Dec 21, 2022
911c737
Add coca to CI pr
rom1504 Dec 21, 2022
b4881bc
simplify encode_iamge (#313)
gpucce Dec 21, 2022
50bc599
Add cls mask (#312)
gpucce Dec 21, 2022
279e088
Ignore pad tokens in captioning loss (#316)
gpucce Dec 22, 2022
dee1ea5
add `generate` to coca model (#314)
gpucce Dec 22, 2022
30a73d4
use `TextEncoder` in coca `encode_image` (#321)
gpucce Jan 6, 2023
f616050
Merge branch 'main' into coca
rom1504 Jan 6, 2023
061482b
Get some basic PEP changes out of the way
rwightman Jan 9, 2023
d0bd09e
Add tests bis (#355)
gpucce Jan 21, 2023
ef80b7b
Merge branch 'main' into coca
rom1504 Jan 21, 2023
2ab47b7
train.py: fix is_clip when doing distributed (#364)
iejMac Jan 21, 2023
c0e5950
add README (#365)
iejMac Jan 22, 2023
9ab881e
Merge branch 'main' into coca
rom1504 Jan 22, 2023
3f5b0fb
remove output_dict argument (#368)
gpucce Jan 22, 2023
de343fb
do same thing for _encode_image (#366)
iejMac Jan 22, 2023
88aa6ce
CoCa/forward: remove unused output_dict param
iejMac Jan 23, 2023
3b66f37
Revert "do same thing for _encode_image (#366)"
gpucce Jan 24, 2023
cdb91dd
refactor
gpucce Jan 24, 2023
58eb5bd
white space
gpucce Jan 24, 2023
cbd66ed
remove extra layer norm
gpucce Jan 24, 2023
bf6ef3e
move to_logits into decoder
gpucce Jan 24, 2023
03dfeab
leave for later
gpucce Jan 24, 2023
15d6223
better torchscript
gpucce Jan 23, 2023
9beb0d4
annotate hf too
gpucce Jan 23, 2023
fde2aee
Add CoCa-ViT-L/14 config (#379)
iejMac Jan 27, 2023
24e454d
Merge branch 'main' into coca
rom1504 Jan 27, 2023
f7c566b
Remove dead LN code, refactor attn_pool conditional for more clarity,…
rwightman Jan 28, 2023
9533575
latent_dim to embed_dim
gpucce Jan 28, 2023
f5e0c5a
remove extra cfg
gpucce Jan 28, 2023
1ba2ab6
A bit more cleanup, keep context_length as context len, 'num_pos' to …
rwightman Jan 28, 2023
f0847fa
CoCa: add B/32 pretrained (#389)
iejMac Jan 29, 2023
ba081d3
remove coca from ci.yml
rom1504 Jan 29, 2023
6 changes: 4 additions & 2 deletions .github/workflows/ci.yml
@@ -3,7 +3,8 @@ name: Continuous integration
on:
push:
branches:
- main
- main
- coca
paths-ignore:
- '**.md'
- 'CITATION.cff'
@@ -12,7 +13,8 @@ on:
- 'docs/**'
pull_request:
branches:
- main
- main
- coca
paths-ignore:
- '**.md'
- 'CITATION.cff'
5 changes: 3 additions & 2 deletions src/open_clip/__init__.py
@@ -1,9 +1,10 @@
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss
from .loss import ClipLoss, CoCaLoss
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
from .coca_model import CoCa
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
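The new exports above make the CoCa pieces importable straight from the package namespace. A minimal sketch, assuming this branch is installed:

from open_clip import CoCa, CoCaLoss, create_loss, create_model_and_transforms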
206 changes: 206 additions & 0 deletions src/open_clip/coca_model.py
@@ -0,0 +1,206 @@
from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from dataclasses import dataclass

from .transformer import (
LayerNormFp32,
LayerNorm,
QuickGELU,
MultimodalTransformer,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
from .generation_utils import top_a, top_k, top_p

@dataclass
class MultimodalCfg(CLIPTextCfg):
mlp_ratio: int = 4
dim_head: int = 64
heads: int = 8
n_queries: int = 256
attn_pooler_heads: int = 8
latent_dim: int = 512

class CoCaEncoderDecoder(nn.Module):
def __init__(self, encoder, decoder) -> None:
Review comment from @rwightman (Collaborator), Jan 23, 2023:

Is there a reason for this module to exist, versus a text_encoder and text_decoder in the CoCa model? Having modules organized like this (two modules in a class with no forward that exercises them) is atypical and breaks some assumptions for profilers, etc.

Review comment (Collaborator):

FYI, checkpoints can easily be remapped, so existing checkpoints are not a reason not to make changes right now.

Review comment (Collaborator):

More thoughts: to keep it a bit closer to the CLIP model on the encoder side,

self.text = _build_text_encoder()
self.visual = _build_vision_tower()
self.decoder = _build_multimodal_decoder()  # or self.multimodal_decoder or .text_decoder?
self.decoder_norm = 
self.decoder_logits =

super().__init__()
self.encoder = encoder
self.decoder = decoder

@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.encoder.set_grad_checkpointing(enable)
self.decoder.set_grad_checkpointing(enable)

def _build_encoder_decoder_tower(
embed_dim,
multimodal_cfg,
text_cfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
Review comment (Collaborator):

As per the encoder/decoder module above, with those split, this can be split to have text(_encoder) + text_decoder.


multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg

encoder = _build_text_tower(
multimodal_cfg.latent_dim,
text_cfg=text_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype
)

vocab_size = (
encoder.config.vocab_size # for hf models
if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
else multimodal_cfg.vocab_size
)

act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = (
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
)

decoder = MultimodalTransformer(
context_length=multimodal_cfg.context_length,
width=multimodal_cfg.width,
heads=multimodal_cfg.heads,
layers=multimodal_cfg.layers,
ls_init_value=multimodal_cfg.ls_init_value,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
)

return CoCaEncoderDecoder(encoder, decoder), multimodal_cfg, vocab_size

class CoCa(nn.Module):
def __init__(
self,
embed_dim,
multimodal_cfg: MultimodalCfg,
text_cfg: CLIPTextCfg,
vision_cfg: CLIPVisionCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
pad_id: int = 0
):
super().__init__()
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

norm_layer = (
LayerNormFp32
if cast_dtype in (torch.float16, torch.bfloat16)
else LayerNorm
)

self.text, multimodal_cfg, vocab_size = _build_encoder_decoder_tower(
embed_dim, multimodal_cfg, text_cfg, quick_gelu, cast_dtype
)
self.visual = _build_vision_tower(
multimodal_cfg.latent_dim, vision_cfg, quick_gelu, cast_dtype
)

self.to_logits = nn.Sequential(
norm_layer(multimodal_cfg.width), nn.Linear(multimodal_cfg.width, vocab_size, bias=False)
Review comment (Collaborator):

I'm still not a fan of these being in an nn.Sequential with no names. They should either be separate attributes like .decoder_norm and .decoder_logits, or have names like 'norm' and 'fc' added to the nn.Sequential; without that you'd break checkpoint compatibility if you, say, wanted to experiment and add a dropout between the norm and the linear layer.

)

self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.pad_id = pad_id

@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.text.set_grad_checkpointing(enable)

def encode_image(self, images, normalize=True, return_tokens=False):
image_latent, tokens_embs = self.visual(images, output_tokens=True)
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
return (image_latent, tokens_embs) if return_tokens else image_latent

def encode_text(self, text, normalize=True, return_tokens=False):
text = text[:, :-1] # make space for CLS token
text_latent, token_emb = self.text.encoder(text, output_tokens=True)
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
return (text_latent, token_emb) if return_tokens else text_latent

def forward(self, image, text, output_dict=False):

text_latent, token_embs = self.encode_text(text, return_tokens=True)
image_latent, image_embs = self.encode_image(image, return_tokens=True)

# TODO: add assertion to avoid bugs?
labels = text[:, -token_embs.shape[1]:]

token_embs = self.text.decoder(image_embs, token_embs)
logits = self.to_logits(token_embs)
if output_dict:
return {
"image_features":image_latent,
"text_features":text_latent,
"logits":logits,
"labels":labels,
"logit_scale":self.logit_scale.exp()
}

return image_latent, text_latent, logits, labels, self.logit_scale.exp()

def generate(
self,
image,
text,
seq_len,
max_seq_len=77,
mask_prob = 0.0,
temperature = 1.,
filter_logits_fn = top_k,
filter_thres = 0.9,
min_p_pow = 2.0,
min_p_ratio = 0.02,
):

assert mask_prob < 1, "mask_prob must be smaller than 1."

was_training = self.training
num_dims = len(text.shape)

if num_dims == 1:
text = text[None, :]

_, t = text.shape
self.eval()
out = text

for _ in range(seq_len):
x = out[:, -max_seq_len:]

# TODO: adjust for dict output
logits = self(image, x)[2][:, -1]

if filter_logits_fn in {top_k, top_p}:
filtered_logits = filter_logits_fn(logits, thres=filter_thres)
probs = F.softmax(filtered_logits / temperature, dim=-1)

elif filter_logits_fn is top_a:
filtered_logits = filter_logits_fn(
logits, min_p_pow=min_p_pow, min_p_ratio=min_p_ratio
)
probs = F.softmax(filtered_logits / temperature, dim=-1)

sample = torch.multinomial(probs, 1)

out = torch.cat((out, sample), dim=-1)


out = out[:, t:]

if num_dims == 1:
out = out.squeeze(0)

self.train(was_training)
return out
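To make the new generate path concrete, here is a minimal usage sketch. The model name, image path, and prompt below are illustrative assumptions rather than part of this diff, and with randomly initialized weights the call only exercises the code path; it will not produce a meaningful caption.

import torch
from PIL import Image
import open_clip

# "coca" in the model name routes create_model to the CoCa class (see the factory.py change below)
model, _, transform = open_clip.create_model_and_transforms("coca_ViT-B-32")
tokenizer = open_clip.get_tokenizer("coca_ViT-B-32")

image = transform(Image.open("cat.jpg")).unsqueeze(0)  # (1, 3, H, W); placeholder image path
text = tokenizer(["a photo of"])                       # (1, 77) token ids from the default tokenizer

with torch.no_grad():
    out = model.generate(image, text, seq_len=10)      # top_k filtering and sampling by default

print(out.shape)  # (1, 10): sampled token ids with the prompt stripped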
29 changes: 26 additions & 3 deletions src/open_clip/factory.py
Expand Up @@ -12,6 +12,8 @@
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
resize_pos_embed, get_cast_dtype
from .coca_model import CoCa
from .loss import ClipLoss, CoCaLoss
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
from .transform import image_transform
@@ -72,8 +74,7 @@ def get_model_config(model_name):

def get_tokenizer(model_name):
config = get_model_config(model_name)
tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
return tokenizer
return HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize


def load_state_dict(checkpoint_path: str, map_location='cpu'):
@@ -152,7 +153,10 @@ def create_model(
if custom_text:
if 'hf_model_name' in model_cfg.get('text_cfg', {}):
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
if "coca" in model_name:
model = CoCa(**model_cfg, cast_dtype=cast_dtype)
else:
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)

@@ -188,6 +192,25 @@

return model

def create_loss(args):
if "coca" in args.model.lower():
return CoCaLoss(
caption_loss_weight=args.coca_caption_loss_weight,
clip_loss_weight=args.coca_contrastive_loss_weight,
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod)
return ClipLoss(
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod)


def create_model_and_transforms(
model_name: str,
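As a sketch of how the new create_loss helper is meant to be driven from a training script, the snippet below builds a stand-in for the argparse namespace; the field values are illustrative, not defaults taken from this repository.

from types import SimpleNamespace
from open_clip import create_loss

# Hand-built stand-in for the training script's argparse namespace
args = SimpleNamespace(
    model="coca_ViT-B-32",              # "coca" in the (lowercased) name selects CoCaLoss
    coca_caption_loss_weight=2.0,       # illustrative loss weights
    coca_contrastive_loss_weight=1.0,
    local_loss=False,
    gather_with_grad=False,
    rank=0,
    world_size=1,
    horovod=False,
)

loss_fn = create_loss(args)  # CoCaLoss here; a plain ClipLoss for any non-CoCa model name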
38 changes: 38 additions & 0 deletions src/open_clip/generation_utils.py
@@ -0,0 +1,38 @@
from math import ceil
import torch
from torch import nn
import torch.nn.functional as F

def exists(val):
return val is not None

# nucleus

def top_p(logits, thres = 0.9):
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

sorted_indices_to_remove = cum_probs > (1 - thres)
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = 0

sorted_logits[sorted_indices_to_remove] = float('-inf')
return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# topk

def top_k(logits, thres = 0.9):
k = ceil((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(1, ind, val)
return probs

# top_a

def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02):
probs = F.softmax(logits, dim=-1)
limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
logits[probs < limit] = float('-inf')
logits[probs >= limit] = 1
return logits
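For illustration, the filtering helpers can be exercised on their own with toy logits; this mirrors the sampling step inside CoCa.generate and is only a sketch, not part of the diff.

import torch
import torch.nn.functional as F
from open_clip.generation_utils import top_k, top_p

logits = torch.randn(2, 50)                # toy next-token logits for a batch of 2

filtered = top_k(logits, thres=0.9)        # keeps k = ceil(0.1 * 50) = 5 logits, -inf elsewhere
probs = F.softmax(filtered / 1.0, dim=-1)  # temperature 1.0, matching the generate default
sample = torch.multinomial(probs, 1)       # (2, 1) sampled token ids

nucleus = top_p(logits, thres=0.9)         # top_p filters on cumulative probability instead;
                                           # with this convention thres=0.9 keeps ~10% of the mass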
27 changes: 18 additions & 9 deletions src/open_clip/hf_model.py
@@ -79,7 +79,6 @@ def forward(self, x: BaseModelOutput, attention_mask: TensorType):

return x.last_hidden_state[:, self.cls_token_position, :]


class HFTextEncoder(nn.Module):
"""HuggingFace model adapter"""

@@ -90,7 +89,8 @@ def __init__(
config: PretrainedConfig = None,
pooler_type: str = None,
proj: str = None,
pretrained: bool = True):
pretrained: bool = True
):
super().__init__()

self.output_dim = output_dim
@@ -113,11 +113,10 @@ def __init__(
else:
self.config = config
self.transformer = AutoModel.from_config(config)

if pooler_type is None: # get default arch pooler
self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
else:
self.pooler = _POOLERS[pooler_type]()
pooler_type = (arch_dict[self.config.model_type]["pooler"])

self.pooler = _POOLERS[pooler_type]()

d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
if (d_model == output_dim) and (proj is None): # do we always need a proj?
@@ -132,12 +131,22 @@ def __init__(
nn.Linear(hidden_size, output_dim, bias=False),
)

def forward(self, x: TensorType) -> TensorType:
def forward(self, x: TensorType, output_tokens=False) -> TensorType:
attn_mask = (x != self.config.pad_token_id).long()
out = self.transformer(input_ids=x, attention_mask=attn_mask)
pooled_out = self.pooler(out, attn_mask)

return self.proj(pooled_out)
projected = self.proj(pooled_out)

if output_tokens:
seq_len = out.last_hidden_state.shape[1]
tokens = (
out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
if type(self.pooler) == ClsPooler
else out.last_hidden_state
)
return projected, tokens

return projected

def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
if not unlocked_layers: # full freezing
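The output_tokens flag above lets callers get both the pooled projection and the per-token hidden states from the HF text tower. A rough sketch of that contract, where the backbone name, dims, pooler, and projection are illustrative assumptions (and pretrained weights are downloaded on first use):

import torch
from open_clip.hf_model import HFTextEncoder

# Illustrative backbone; any architecture in open_clip's arch_dict should behave the same way
text_tower = HFTextEncoder("roberta-base", output_dim=512, proj="linear", pooler_type="cls_pooler")

ids = torch.randint(5, 1000, (2, 16))   # fake token ids: batch of 2, sequence length 16

pooled = text_tower(ids)                # (2, 512) projected pooled output, unchanged behaviour
pooled, tokens = text_tower(ids, output_tokens=True)
# with a ClsPooler the CLS position is dropped, so tokens is (2, 15, hidden_size)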