use TextEncoder in coca encode_image #321

Merged: 25 commits, Jan 6, 2023
192 changes: 69 additions & 123 deletions src/open_clip/coca_model.py
@@ -11,7 +11,6 @@
LayerNorm,
QuickGELU,
MultimodalTransformer,
AttentionalPooler,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
from .generation_utils import top_a, top_k, top_p
@@ -22,34 +21,50 @@ class MultimodalCfg(CLIPTextCfg):
dim_head: int = 64
heads: int = 8
n_queries: int = 256
dim_latents: int = None
attn_pooler_heads: int = 8
latent_dim: int = 512

class CoCaEncoderDecoder(nn.Module):
def __init__(self, encoder, decoder) -> None:
super().__init__()
self.encoder = encoder
self.decoder = decoder

@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.encoder.set_grad_checkpointing(enable)
self.decoder.set_grad_checkpointing(enable)

def _build_input_dependent_text_tower(
embed_dim: int,
multimodal_cfg: MultimodalCfg,
def _build_encoder_decoder_tower(
embed_dim,
multimodal_cfg,
text_cfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
multimodal:bool = True
):

if not multimodal:
return _build_text_tower(
embed_dim=embed_dim,
text_cfg=multimodal_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype
)

if isinstance(multimodal_cfg, dict):
multimodal_cfg = MultimodalCfg(**multimodal_cfg)
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg

encoder = _build_text_tower(
multimodal_cfg.latent_dim,
text_cfg=text_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype
)

vocab_size = (
encoder.config.vocab_size # for hf models
if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
else multimodal_cfg.vocab_size
)

act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = (
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
)

text = MultimodalTransformer(
decoder = MultimodalTransformer(
context_length=multimodal_cfg.context_length,
width=multimodal_cfg.width,
heads=multimodal_cfg.heads,
@@ -59,154 +74,88 @@ def _build_input_dependent_text_tower(
act_layer=act_layer,
norm_layer=norm_layer,
)

return text, multimodal_cfg



return CoCaEncoderDecoder(encoder, decoder), multimodal_cfg, vocab_size

class CoCa(nn.Module):
def __init__(
self,
embed_dim,
multimodal_cfg: MultimodalCfg,
text_cfg: CLIPTextCfg,
vision_cfg: CLIPVisionCfg,
n_queries: int = 256,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
pad_id: int = 0
):
super().__init__()

multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

norm_layer = (
LayerNormFp32
if cast_dtype in (torch.float16, torch.bfloat16)
else LayerNorm
)

text = _build_input_dependent_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype, multimodal=False)
self.transformer = text.transformer
self.vocab_size = text.vocab_size
self.token_embedding = text.token_embedding
self.positional_embedding = text.positional_embedding
self.ln_final = text.ln_final
self.text_projection = text.text_projection
self.register_buffer("attn_mask", text.attn_mask, persistent=False)
self.context_length = self.positional_embedding.shape[0] - 1

self.cls_token = nn.Parameter(torch.randn(embed_dim))
self.visual = _build_vision_tower(
embed_dim, vision_cfg, quick_gelu, cast_dtype
)
self.heads = text_cfg["heads"]

self.multimodal_decoder, multimodal_cfg = _build_input_dependent_text_tower(
embed_dim, multimodal_cfg, quick_gelu, cast_dtype
self.text, multimodal_cfg, vocab_size = _build_encoder_decoder_tower(
embed_dim, multimodal_cfg, text_cfg, quick_gelu, cast_dtype
)

self.img_attn_pool = AttentionalPooler(
multimodal_cfg.width, multimodal_cfg.heads, n_queries=n_queries + 1
self.visual = _build_vision_tower(
multimodal_cfg.latent_dim, vision_cfg, quick_gelu, cast_dtype
)

self.img_attn_pool_norm = norm_layer(embed_dim)

self.dim_latents = multimodal_cfg.dim_latents if multimodal_cfg.dim_latents else multimodal_cfg.width
self.to_text_latent = nn.Linear(embed_dim, self.dim_latents, bias=False)

self.to_logits = nn.Sequential(
norm_layer(embed_dim), nn.Linear(embed_dim, self.vocab_size, bias=False)
norm_layer(multimodal_cfg.width), nn.Linear(multimodal_cfg.width, vocab_size, bias=False)
)

# tie embedding weights and projection
self.to_logits[-1].weight = self.token_embedding.weight

self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.pad_id = 0
self.pad_id = pad_id

@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.transformer.grad_checkpointing = enable
self.multimodal_decoder.grad_checkpointing = enable
self.text.set_grad_checkpointing(enable)

def encode_image(self, images, normalize=True, return_tokens=False):
x = self.visual(images, output_tokens=True)

if hasattr(self.visual, "ln_post"):
x = self.visual.ln_post(x)

if hasattr(self.visual, "proj") and self.visual.proj is not None:
x = x @ self.visual.proj

x = self.img_attn_pool(x, x)
x = self.img_attn_pool_norm(x)

image_latent = x[:, 0]
image_latent, tokens_embs = self.visual(images, output_tokens=True)
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent

return (image_latent, x[:, 1:]) if return_tokens else image_latent

def _repeat(self, t, N):
return t.reshape(1, 1, -1).repeat(N, 1, 1)

def _build_cls_mask(self, text, cast_dtype):
cls_mask = (text != self.pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
additive_mask = torch.empty(*cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
return additive_mask
return (image_latent, tokens_embs) if return_tokens else image_latent

def encode_text(self, text, normalize=True, return_tokens=False):
text = text[:, :-1] # make space for CLS token
cast_dtype = self.transformer.get_cast_dtype()
seq_len = text.shape[1]
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = torch.cat(
[
x + self.positional_embedding[:seq_len, :].to(cast_dtype),
self._repeat(self.cls_token + self.positional_embedding[-1, :], x.shape[0])
],
dim=1
)
seq_len += 1 # seq is 1 longer as we added CLS
attn_mask = self.attn_mask[None, :seq_len, :seq_len].expand(
text.shape[0] * self.heads, seq_len, seq_len
)
cls_mask = self._build_cls_mask(text, cast_dtype)

x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=attn_mask + cls_mask)
x = x.permute(1, 0, 2) # LND -> NLD

x = x[torch.arange(x.shape[0]), :] @ self.text_projection

cls_emb = x[torch.arange(x.shape[0]), -1]
token_emb = x[torch.arange(x.shape[0]), :-1]

cls_emb = self.ln_final(cls_emb)
text_latent = self.to_text_latent(cls_emb)
text_latent, token_emb = self.text.encoder(text, output_tokens=True)
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent

return (text_latent, token_emb) if return_tokens else text_latent

def forward(self, image, text):
labels = text[:, 1:]

text_latents, text_tokens = self.encode_text(text, return_tokens=True)
image_latents, image_tokens = self.encode_image(image, return_tokens=True)
def forward(self, image, text, output_dict=False):

text_tokens = self.multimodal_decoder(image_tokens, text_tokens)
logits = self.to_logits(text_tokens)

return image_latents, text_latents, logits, labels, self.logit_scale.exp()
text_latent, token_embs = self.encode_text(text, return_tokens=True)
image_latent, image_embs = self.encode_image(image, return_tokens=True)

# TODO: add assertion to avoid bugs?
labels = text[:, -token_embs.shape[1]:]

token_embs = self.text.decoder(image_embs, token_embs)
logits = self.to_logits(token_embs)
if output_dict:
return {
"image_features":image_latent,
"text_features":text_latent,
"logits":logits,
"labels":labels,
"logit_scale":self.logit_scale.exp()
}

return image_latent, text_latent, logits, labels, self.logit_scale.exp()

def generate(
self,
image,
text,
seq_len,
max_seq_len=None,
max_seq_len=77,
mask_prob = 0.0,
temperature = 1.,
filter_logits_fn = top_k,
@@ -217,9 +166,6 @@ def generate(

assert mask_prob < 1, "mask_prob must be smaller than 1."

if max_seq_len is None:
max_seq_len = self.context_length

was_training = self.training
num_dims = len(text.shape)

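A minimal usage sketch of the refactored coca_model.py API above, assuming an already-constructed CoCa instance named `model` (its construction is omitted here); the input tensors and shapes are illustrative only, not taken from the PR.

import torch

# `model` is assumed to be a CoCa instance built elsewhere; dummy inputs below.
images = torch.randn(2, 3, 224, 224)            # dummy image batch
text = torch.randint(0, 1000, (2, 77))          # dummy token ids, context length 77

# encode_image now delegates pooling to the visual tower via output_tokens=True
image_latent, image_embs = model.encode_image(images, return_tokens=True)

# encode_text routes through the encoder half of CoCaEncoderDecoder
text_latent, token_embs = model.encode_text(text, return_tokens=True)

# forward can return a dict, matching the new output_dict flag in the losses
out = model(images, text, output_dict=True)
print(out["logits"].shape, out["labels"].shape, out["logit_scale"])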
27 changes: 18 additions & 9 deletions src/open_clip/hf_model.py
@@ -79,7 +79,6 @@ def forward(self, x: BaseModelOutput, attention_mask: TensorType):

return x.last_hidden_state[:, self.cls_token_position, :]


class HFTextEncoder(nn.Module):
"""HuggingFace model adapter"""

@@ -90,7 +89,8 @@ def __init__(
config: PretrainedConfig = None,
pooler_type: str = None,
proj: str = None,
pretrained: bool = True):
pretrained: bool = True
):
super().__init__()

self.output_dim = output_dim
@@ -113,11 +113,10 @@ def __init__(
else:
self.config = config
self.transformer = AutoModel.from_config(config)

if pooler_type is None: # get default arch pooler
self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
else:
self.pooler = _POOLERS[pooler_type]()
pooler_type = (arch_dict[self.config.model_type]["pooler"])

self.pooler = _POOLERS[pooler_type]()

d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
if (d_model == output_dim) and (proj is None): # do we always need a proj?
@@ -132,12 +131,22 @@ def __init__(
nn.Linear(hidden_size, output_dim, bias=False),
)

def forward(self, x: TensorType) -> TensorType:
def forward(self, x: TensorType, output_tokens=False) -> TensorType:
attn_mask = (x != self.config.pad_token_id).long()
out = self.transformer(input_ids=x, attention_mask=attn_mask)
pooled_out = self.pooler(out, attn_mask)

return self.proj(pooled_out)
projected = self.proj(pooled_out)

if output_tokens:
seq_len = out.last_hidden_state.shape[1]
tokens = (
out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
if type(self.pooler) == ClsPooler
else out.last_hidden_state
)
return projected, tokens

return projected

def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
if not unlocked_layers: # full freezing
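A toy sketch of the token-selection logic that hf_model.py gains with output_tokens=True, assuming a CLS-style pooler; the shapes and the CLS position are placeholders, not values from a real checkpoint.

import torch

# Dummy hidden states standing in for out.last_hidden_state
last_hidden_state = torch.randn(2, 10, 512)     # (batch, seq_len, width)
cls_token_position = 0                          # assumed CLS position for a ClsPooler

seq_len = last_hidden_state.shape[1]
keep = torch.arange(seq_len) != cls_token_position
tokens = last_hidden_state[:, keep, :]          # (2, 9, 512): sequence without the pooled CLS token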
12 changes: 8 additions & 4 deletions src/open_clip/loss.py
@@ -85,7 +85,7 @@ def __init__(
self.prev_num_logits = 0
self.labels = {}

def forward(self, image_features, text_features, logit_scale):
def forward(self, image_features, text_features, logit_scale, output_dict=False):
device = image_features.device
if self.world_size > 1:
all_image_features, all_text_features = gather_features(
@@ -118,7 +118,8 @@ def forward(self, image_features, text_features, logit_scale):
F.cross_entropy(logits_per_image, labels) +
F.cross_entropy(logits_per_text, labels)
) / 2
return total_loss

return {"contrastive_loss": total_loss} if output_dict else total_loss


class CoCaLoss(ClipLoss):
@@ -147,7 +148,7 @@ def __init__(
self.caption_loss_weight = caption_loss_weight
self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)

def forward(self, image_features, text_features, logits, labels, logit_scale):
def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
clip_loss = super().forward(image_features, text_features, logit_scale)
clip_loss = self.clip_loss_weight * clip_loss

@@ -157,4 +158,7 @@ def forward(self, image_features, text_features, logits, labels, logit_scale):
)
caption_loss = caption_loss * self.caption_loss_weight

return clip_loss + caption_loss
if output_dict:
return {"contrastive_loss":clip_loss, "caption_loss":caption_loss}

return clip_loss, caption_loss
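A hypothetical training-step sketch showing how the two return modes of the updated CoCaLoss might be consumed; `model`, `loss_fn`, `images`, and `text` are assumptions, not part of this diff.

# CoCaLoss with output_dict=True returns the two already-weighted terms separately.
out = model(images, text, output_dict=True)
losses = loss_fn(
    out["image_features"], out["text_features"],
    out["logits"], out["labels"], out["logit_scale"],
    output_dict=True,
)
total_loss = sum(losses.values())               # contrastive_loss + caption_loss
total_loss.backward()

# Without output_dict, CoCaLoss.forward now returns a (clip_loss, caption_loss)
# tuple instead of their sum, so callers have to add the terms themselves.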