Add coca trained (#307) #308
Changes from 32 commits
@@ -0,0 +1,193 @@

```python
from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from dataclasses import dataclass

from .transformer import (
    LayerNormFp32,
    LayerNorm,
    QuickGELU,
    MultimodalTransformer,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
from .generation_utils import top_a, top_k, top_p


@dataclass
class MultimodalCfg(CLIPTextCfg):
    mlp_ratio: int = 4
    dim_head: int = 64
    heads: int = 8
    n_queries: int = 256
    attn_pooler_heads: int = 8


def _build_text_decoder_tower(
    embed_dim,
    multimodal_cfg,
    quick_gelu: bool = False,
    cast_dtype: Optional[torch.dtype] = None,
):
    multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
    act_layer = QuickGELU if quick_gelu else nn.GELU
    norm_layer = (
        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    )

    decoder = MultimodalTransformer(
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )

    return decoder


class CoCa(nn.Module):
    def __init__(
        self,
        embed_dim,
        multimodal_cfg: MultimodalCfg,
```
> **Review thread on the `multimodal_cfg` argument:**
> - I just noticed this: `embed_dim` isn't used for CoCa, as it's taken from `multimodal_cfg.latent_dim`. A little bit weird to have the value in the cfg and in the arg, and then not use it.
> - It doesn't look like the `MultimodalTransformer` tower uses `latent_dim` itself, so should that just be determined by `cfg['embed_dim']` like the other models, and `multimodal_cfg['latent_dim']` removed?
> - Is it resolved? If not, can you create an issue for it?
```python
        text_cfg: CLIPTextCfg,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
        pad_id: int = 0,
    ):
        super().__init__()
        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
        vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

        self.text = _build_text_tower(
            embed_dim=embed_dim,
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        vocab_size = (
            text_cfg.vocab_size  # for hf models
            if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
            else text_cfg.vocab_size
        )

        self.visual = _build_vision_tower(
            embed_dim=embed_dim,
            vision_cfg=vision_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        self.text_decoder = _build_text_decoder_tower(
            vocab_size,
            multimodal_cfg=multimodal_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.pad_id = pad_id

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)
        self.text_decoder.set_grad_checkpointing(enable)

    def _encode_image(self, images, normalize=True):
        image_latent, tokens_embs = self.visual(images)
        image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
        return image_latent, tokens_embs

    def _encode_text(self, text, normalize=True):
        text = text[:, :-1]  # make space for CLS token
        text_latent, token_emb = self.text(text)
        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
        return text_latent, token_emb

    def encode_image(self, images, normalize=True):
        image_latent, _ = self._encode_image(images, normalize=normalize)
        return image_latent

    def encode_text(self, text, normalize=True):
        text_latent, _ = self._encode_text(text, normalize=normalize)
        return text_latent

    def forward(self, image, text):
        text_latent, token_embs = self._encode_text(text)
        image_latent, image_embs = self._encode_image(image)

        # TODO: add assertion to avoid bugs?
        labels = text[:, -token_embs.shape[1]:]

        logits = self.text_decoder(image_embs, token_embs)
        return {
            "image_features": image_latent,
            "text_features": text_latent,
            "logits": logits,
            "labels": labels,
            "logit_scale": self.logit_scale.exp()
        }

    def generate(
        self,
        image,
        text,
        seq_len,
        max_seq_len=77,
        mask_prob=0.0,
        temperature=1.,
        filter_logits_fn=top_k,
        filter_thres=0.9,
        min_p_pow=2.0,
        min_p_ratio=0.02,
    ):

        assert mask_prob < 1, "mask_prob must be smaller than 1."

        was_training = self.training
        num_dims = len(text.shape)

        if num_dims == 1:
            text = text[None, :]

        _, t = text.shape
        self.eval()
        out = text

        for _ in range(seq_len):
            x = out[:, -max_seq_len:]

            # TODO: adjust for dict output
            logits = self(image, x)["logits"][:, -1]

            if filter_logits_fn in {top_k, top_p}:
                filtered_logits = filter_logits_fn(logits, thres=filter_thres)
                probs = F.softmax(filtered_logits / temperature, dim=-1)

            elif filter_logits_fn is top_a:
                filtered_logits = filter_logits_fn(
                    logits, min_p_pow=min_p_pow, min_p_ratio=min_p_ratio
                )
                probs = F.softmax(filtered_logits / temperature, dim=-1)

            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        self.train(was_training)
        return out
```
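For orientation, here is a minimal sketch of the calling convention defined above. It assumes `model` is an already constructed `CoCa` instance (built from the configs added elsewhere in this PR); the batch shapes, image size, vocabulary size, and prompt length are illustrative placeholders, and the output keys come directly from `forward`.

```python
import torch

# Hypothetical inputs: `model` is assumed to be a constructed CoCa instance;
# the 224x224 image size, 77-token context, and 49408 vocab are illustrative only.
images = torch.randn(4, 3, 224, 224)        # batch of preprocessed images
text = torch.randint(0, 49408, (4, 77))     # batch of tokenized captions

out = model(images, text)
out["image_features"]   # normalized image latent for the contrastive loss
out["text_features"]    # normalized text latent for the contrastive loss
out["logits"]           # decoder logits for the captioning loss
out["labels"]           # caption targets aligned with the logits
out["logit_scale"]      # exp() of the learned temperature

# Autoregressive captioning: sample `seq_len` new tokens from a short prompt,
# using the top_k filter from generation_utils by default.
prompt = text[:, :5]    # a few seed tokens (illustrative)
caption_tokens = model.generate(images, prompt, seq_len=20, filter_thres=0.9)
```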
@@ -0,0 +1,37 @@

```python
from math import ceil
import torch
from torch import nn
import torch.nn.functional as F


def exists(val):
    return val is not None


def top_p(logits, thres=0.9):
    # nucleus
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cum_probs > (1 - thres)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0

    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)


def top_k(logits, thres=0.9):
    k = ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs


def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02):
    probs = F.softmax(logits, dim=-1)
    limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
    logits[probs < limit] = float('-inf')
    logits[probs >= limit] = 1
    return logits
```
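The filtering helpers are plain tensor-to-tensor functions, so they can be sanity-checked in isolation. A minimal sketch, assuming the file lands as `generation_utils.py` inside the `open_clip` package (the relative import in the model file above suggests this); the batch size and vocabulary size are made up for illustration.

```python
import torch
import torch.nn.functional as F

from open_clip.generation_utils import top_k, top_p  # assumed module path

logits = torch.randn(2, 1000)               # fake (batch, vocab) logits

# top_k keeps the ceil((1 - thres) * vocab) largest logits and sets the rest to -inf
filtered = top_k(logits, thres=0.9)         # here: 100 surviving logits per row

# same sampling step as CoCa.generate: softmax over the filtered logits, then draw
probs = F.softmax(filtered / 1.0, dim=-1)   # temperature = 1.0
next_token = torch.multinomial(probs, 1)    # shape (2, 1), one token id per row

# nucleus filtering keeps the highest-probability tokens until their cumulative
# probability exceeds 1 - thres, masking everything else with -inf
nucleus = top_p(logits, thres=0.9)
```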
> **Review comment:** As per the Encoder/Decoder module above, with those split, this can be split to have `text` (encoder) + `text_decoder`.