use TextEncoder in coca encode_image (#321)
* use self.text in encode image

* unused var

* revert Attention and CustomResidualAttentionBlock

* remove blank line

* add dict output

* integrate self.text attributes

* HF compatibility

* better config and minor fixes

* clean

* remove embed_cls option from HF

* use cls_token_position

* fix cls masking

* resize labels

* text -> self.text

* split loss logging

* add total loss

* minor logs formatting

* fix generate

* simpler logic

* disentangle proj for HF too

* adjust config

* only norm cls

* move attn_pool to VisionTransformer

* adjust coca_base config

* fix grad checkpointing in MultimodalTransformer

Co-authored-by: gpucce <[email protected]>
Co-authored-by: iejMac <[email protected]>
3 people authored Jan 6, 2023
1 parent dee1ea5 commit 30a73d4
Showing 10 changed files with 306 additions and 246 deletions.
192 changes: 69 additions & 123 deletions src/open_clip/coca_model.py
@@ -11,7 +11,6 @@
LayerNorm,
QuickGELU,
MultimodalTransformer,
AttentionalPooler,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
from .generation_utils import top_a, top_k, top_p
@@ -22,34 +21,50 @@ class MultimodalCfg(CLIPTextCfg):
dim_head: int = 64
heads: int = 8
n_queries: int = 256
dim_latents: int = None
attn_pooler_heads: int = 8
latent_dim: int = 512

class CoCaEncoderDecoder(nn.Module):
def __init__(self, encoder, decoder) -> None:
super().__init__()
self.encoder = encoder
self.decoder = decoder

@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.encoder.set_grad_checkpointing(enable)
self.decoder.set_grad_checkpointing(enable)

def _build_input_dependent_text_tower(
embed_dim: int,
multimodal_cfg: MultimodalCfg,
def _build_encoder_decoder_tower(
embed_dim,
multimodal_cfg,
text_cfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
multimodal:bool = True
):

if not multimodal:
return _build_text_tower(
embed_dim=embed_dim,
text_cfg=multimodal_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype
)

if isinstance(multimodal_cfg, dict):
multimodal_cfg = MultimodalCfg(**multimodal_cfg)
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg

encoder = _build_text_tower(
multimodal_cfg.latent_dim,
text_cfg=text_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype
)

vocab_size = (
encoder.config.vocab_size # for hf models
if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
else multimodal_cfg.vocab_size
)

act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = (
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
)

text = MultimodalTransformer(
decoder = MultimodalTransformer(
context_length=multimodal_cfg.context_length,
width=multimodal_cfg.width,
heads=multimodal_cfg.heads,
@@ -59,154 +74,88 @@ def _build_input_dependent_text_tower(
act_layer=act_layer,
norm_layer=norm_layer,
)

return text, multimodal_cfg



return CoCaEncoderDecoder(encoder, decoder), multimodal_cfg, vocab_size

class CoCa(nn.Module):
def __init__(
self,
embed_dim,
multimodal_cfg: MultimodalCfg,
text_cfg: CLIPTextCfg,
vision_cfg: CLIPVisionCfg,
n_queries: int = 256,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
pad_id: int = 0
):
super().__init__()

multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

norm_layer = (
LayerNormFp32
if cast_dtype in (torch.float16, torch.bfloat16)
else LayerNorm
)

text = _build_input_dependent_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype, multimodal=False)
self.transformer = text.transformer
self.vocab_size = text.vocab_size
self.token_embedding = text.token_embedding
self.positional_embedding = text.positional_embedding
self.ln_final = text.ln_final
self.text_projection = text.text_projection
self.register_buffer("attn_mask", text.attn_mask, persistent=False)
self.context_length = self.positional_embedding.shape[0] - 1

self.cls_token = nn.Parameter(torch.randn(embed_dim))
self.visual = _build_vision_tower(
embed_dim, vision_cfg, quick_gelu, cast_dtype
)
self.heads = text_cfg["heads"]

self.multimodal_decoder, multimodal_cfg = _build_input_dependent_text_tower(
embed_dim, multimodal_cfg, quick_gelu, cast_dtype
self.text, multimodal_cfg, vocab_size = _build_encoder_decoder_tower(
embed_dim, multimodal_cfg, text_cfg, quick_gelu, cast_dtype
)

self.img_attn_pool = AttentionalPooler(
multimodal_cfg.width, multimodal_cfg.heads, n_queries=n_queries + 1
self.visual = _build_vision_tower(
multimodal_cfg.latent_dim, vision_cfg, quick_gelu, cast_dtype
)

self.img_attn_pool_norm = norm_layer(embed_dim)

self.dim_latents = multimodal_cfg.dim_latents if multimodal_cfg.dim_latents else multimodal_cfg.width
self.to_text_latent = nn.Linear(embed_dim, self.dim_latents, bias=False)

self.to_logits = nn.Sequential(
norm_layer(embed_dim), nn.Linear(embed_dim, self.vocab_size, bias=False)
norm_layer(multimodal_cfg.width), nn.Linear(multimodal_cfg.width, vocab_size, bias=False)
)

# tie embedding weights and projection
self.to_logits[-1].weight = self.token_embedding.weight

self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.pad_id = 0
self.pad_id = pad_id

@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.transformer.grad_checkpointing = enable
self.multimodal_decoder.grad_checkpointing = enable
self.text.set_grad_checkpointing(enable)

def encode_image(self, images, normalize=True, return_tokens=False):
x = self.visual(images, output_tokens=True)

if hasattr(self.visual, "ln_post"):
x = self.visual.ln_post(x)

if hasattr(self.visual, "proj") and self.visual.proj is not None:
x = x @ self.visual.proj

x = self.img_attn_pool(x, x)
x = self.img_attn_pool_norm(x)

image_latent = x[:, 0]
image_latent, tokens_embs = self.visual(images, output_tokens=True)
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent

return (image_latent, x[:, 1:]) if return_tokens else image_latent

def _repeat(self, t, N):
return t.reshape(1, 1, -1).repeat(N, 1, 1)

def _build_cls_mask(self, text, cast_dtype):
cls_mask = (text != self.pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
additive_mask = torch.empty(*cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
return additive_mask
return (image_latent, tokens_embs) if return_tokens else image_latent

def encode_text(self, text, normalize=True, return_tokens=False):
text = text[:, :-1] # make space for CLS token
cast_dtype = self.transformer.get_cast_dtype()
seq_len = text.shape[1]
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = torch.cat(
[
x + self.positional_embedding[:seq_len, :].to(cast_dtype),
self._repeat(self.cls_token + self.positional_embedding[-1, :], x.shape[0])
],
dim=1
)
seq_len += 1 # seq is 1 longer as we added CLS
attn_mask = self.attn_mask[None, :seq_len, :seq_len].expand(
text.shape[0] * self.heads, seq_len, seq_len
)
cls_mask = self._build_cls_mask(text, cast_dtype)

x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=attn_mask + cls_mask)
x = x.permute(1, 0, 2) # LND -> NLD

x = x[torch.arange(x.shape[0]), :] @ self.text_projection

cls_emb = x[torch.arange(x.shape[0]), -1]
token_emb = x[torch.arange(x.shape[0]), :-1]

cls_emb = self.ln_final(cls_emb)
text_latent = self.to_text_latent(cls_emb)
text_latent, token_emb = self.text.encoder(text, output_tokens=True)
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent

return (text_latent, token_emb) if return_tokens else text_latent

def forward(self, image, text):
labels = text[:, 1:]

text_latents, text_tokens = self.encode_text(text, return_tokens=True)
image_latents, image_tokens = self.encode_image(image, return_tokens=True)
def forward(self, image, text, output_dict=False):

text_tokens = self.multimodal_decoder(image_tokens, text_tokens)
logits = self.to_logits(text_tokens)

return image_latents, text_latents, logits, labels, self.logit_scale.exp()
text_latent, token_embs = self.encode_text(text, return_tokens=True)
image_latent, image_embs = self.encode_image(image, return_tokens=True)

# TODO: add assertion to avoid bugs?
labels = text[:, -token_embs.shape[1]:]

token_embs = self.text.decoder(image_embs, token_embs)
logits = self.to_logits(token_embs)
if output_dict:
return {
"image_features":image_latent,
"text_features":text_latent,
"logits":logits,
"labels":labels,
"logit_scale":self.logit_scale.exp()
}

return image_latent, text_latent, logits, labels, self.logit_scale.exp()

def generate(
self,
image,
text,
seq_len,
max_seq_len=None,
max_seq_len=77,
mask_prob = 0.0,
temperature = 1.,
filter_logits_fn = top_k,
@@ -217,9 +166,6 @@ def generate(

assert mask_prob < 1, "mask_prob must be smaller than 1."

if max_seq_len is None:
max_seq_len = self.context_length

was_training = self.training
num_dims = len(text.shape)

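For orientation, a small self-contained sketch of the label alignment performed by the refactored forward above (the line labels = text[:, -token_embs.shape[1]:]). All shapes below are invented; in practice the text tower tends to return one embedding fewer than the number of input token ids, since the last position is reserved for the learned CLS token, so the slice amounts to next-token targets for the caption logits. With output_dict=True, forward returns these quantities under the keys image_features, text_features, logits, labels, and logit_scale, which the dict-aware losses further down consume.

import torch

# Invented shapes: batch of 4, 8 token ids per caption, 512-d embeddings,
# and a text tower that returns 7 token embeddings per sample.
text = torch.randint(0, 49408, (4, 8))
token_embs = torch.randn(4, 7, 512)

# Same slicing as in CoCa.forward: keep the trailing token ids so each label
# position lines up with a decoder logit position.
labels = text[:, -token_embs.shape[1]:]
assert labels.shape[1] == token_embs.shape[1]  # the check hinted at by the TODO in the diff
print(labels.shape)  # torch.Size([4, 7])
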
27 changes: 18 additions & 9 deletions src/open_clip/hf_model.py
@@ -79,7 +79,6 @@ def forward(self, x: BaseModelOutput, attention_mask: TensorType):

return x.last_hidden_state[:, self.cls_token_position, :]


class HFTextEncoder(nn.Module):
"""HuggingFace model adapter"""

@@ -90,7 +89,8 @@ def __init__(
config: PretrainedConfig = None,
pooler_type: str = None,
proj: str = None,
pretrained: bool = True):
pretrained: bool = True
):
super().__init__()

self.output_dim = output_dim
@@ -113,11 +113,10 @@ def __init__(
else:
self.config = config
self.transformer = AutoModel.from_config(config)

if pooler_type is None: # get default arch pooler
self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
else:
self.pooler = _POOLERS[pooler_type]()
pooler_type = (arch_dict[self.config.model_type]["pooler"])

self.pooler = _POOLERS[pooler_type]()

d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
if (d_model == output_dim) and (proj is None): # do we always need a proj?
@@ -132,12 +131,22 @@ def __init__(
nn.Linear(hidden_size, output_dim, bias=False),
)

def forward(self, x: TensorType) -> TensorType:
def forward(self, x: TensorType, output_tokens=False) -> TensorType:
attn_mask = (x != self.config.pad_token_id).long()
out = self.transformer(input_ids=x, attention_mask=attn_mask)
pooled_out = self.pooler(out, attn_mask)

return self.proj(pooled_out)
projected = self.proj(pooled_out)

if output_tokens:
seq_len = out.last_hidden_state.shape[1]
tokens = (
out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
if type(self.pooler) == ClsPooler
else out.last_hidden_state
)
return projected, tokens

return projected

def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
if not unlocked_layers: # full freezing
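The output_tokens branch added to HFTextEncoder.forward above returns the per-token hidden states alongside the projected pooled output, dropping the pooled CLS position when a ClsPooler is in use. A tiny standalone sketch of that indexing, with invented shapes:

import torch

# Invented shapes: batch 2, sequence of 5 hidden states, width 8.
last_hidden_state = torch.randn(2, 5, 8)
cls_token_position = 0  # ClsPooler pools the first position by default

keep = torch.arange(last_hidden_state.shape[1]) != cls_token_position
tokens = last_hidden_state[:, keep, :]
print(tokens.shape)  # torch.Size([2, 4, 8]), the pooled CLS position is excluded
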
12 changes: 8 additions & 4 deletions src/open_clip/loss.py
@@ -85,7 +85,7 @@ def __init__(
self.prev_num_logits = 0
self.labels = {}

def forward(self, image_features, text_features, logit_scale):
def forward(self, image_features, text_features, logit_scale, output_dict=False):
device = image_features.device
if self.world_size > 1:
all_image_features, all_text_features = gather_features(
@@ -118,7 +118,7 @@ def forward(self, image_features, text_features, logit_scale):
F.cross_entropy(logits_per_image, labels) +
F.cross_entropy(logits_per_text, labels)
) / 2
return total_loss

return {"contrastive_loss": total_loss} if output_dict else total_loss


class CoCaLoss(ClipLoss):
@@ -147,7 +148,7 @@ def __init__(
self.caption_loss_weight = caption_loss_weight
self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)

def forward(self, image_features, text_features, logits, labels, logit_scale):
def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
clip_loss = super().forward(image_features, text_features, logit_scale)
clip_loss = self.clip_loss_weight * clip_loss

@@ -157,4 +158,7 @@ def forward(self, image_features, text_features, logits, labels, logit_scale):
)
caption_loss = caption_loss * self.caption_loss_weight

return clip_loss + caption_loss
if output_dict:
return {"contrastive_loss":clip_loss, "caption_loss":caption_loss}

return clip_loss, caption_loss
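
To show how the new dict-aware losses are meant to be consumed, here is a minimal sketch of a training-step fragment using CoCaLoss with output_dict=True, in the spirit of the "split loss logging" and "add total loss" items in the commit message. The constructor keyword names follow the attributes visible in the diff, and the loss weights, shapes, and vocabulary size are invented; treat it as a sketch, not the project's actual training loop.

import torch
from open_clip.loss import CoCaLoss

# Loss weights here are invented for illustration.
loss_fn = CoCaLoss(caption_loss_weight=2.0, clip_loss_weight=1.0, pad_id=0)

batch, dim, vocab, seq = 4, 512, 1000, 7
image_features = torch.randn(batch, dim, requires_grad=True)
text_features = torch.randn(batch, dim, requires_grad=True)
logits = torch.randn(batch, seq, vocab, requires_grad=True)  # caption logits from the multimodal decoder
labels = torch.randint(0, vocab, (batch, seq))
logit_scale = torch.ones([]).exp()

losses = loss_fn(image_features, text_features, logits, labels, logit_scale, output_dict=True)
total_loss = sum(losses.values())        # "add total loss"
total_loss.backward()
for name, value in losses.items():       # "split loss logging"
    print(f"{name}: {value.item():.4f}")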