adding CoCa #256
Changes from 1 commit
```diff
@@ -30,24 +30,22 @@ def _build_input_dependent_text_tower(
     multimodal_cfg: MultimodalCfg,
     quick_gelu: bool = False,
     cast_dtype: Optional[torch.dtype] = None,
-    multimodal:bool = True
+    multimodal: bool = True,
 ):

     if not multimodal:
         return _build_text_tower(
             embed_dim=embed_dim,
             text_cfg=multimodal_cfg,
             quick_gelu=quick_gelu,
-            cast_dtype=cast_dtype
+            cast_dtype=cast_dtype,
         )

     if isinstance(multimodal_cfg, dict):
         multimodal_cfg = MultimodalCfg(**multimodal_cfg)

     act_layer = QuickGELU if quick_gelu else nn.GELU
-    norm_layer = (
-        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
-    )
+    norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm

     text = MultimodalTransformer(
         context_length=multimodal_cfg.context_length,
```
```diff
@@ -76,14 +74,11 @@ def __init__(
     ):
         super().__init__()

-        norm_layer = (
-            LayerNormFp32
-            if cast_dtype in (torch.float16, torch.bfloat16)
-            else LayerNorm
-        )
-
-        text = _build_input_dependent_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype, multimodal=False)
+        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+
+        text = _build_input_dependent_text_tower(
+            embed_dim, text_cfg, quick_gelu, cast_dtype, multimodal=False
+        )
         self.transformer = text.transformer
         self.vocab_size = text.vocab_size
         self.token_embedding = text.token_embedding
```
```diff
@@ -92,10 +87,9 @@ def __init__(
         self.text_projection = text.text_projection
         self.register_buffer("attn_mask", text.attn_mask, persistent=False)

+        self.heads = text_cfg["heads"]
         self.cls_token = nn.Parameter(torch.randn(embed_dim))
-        self.visual = _build_vision_tower(
-            embed_dim, vision_cfg, quick_gelu, cast_dtype
-        )
+        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

         self.multimodal_decoder, multimodal_cfg = _build_input_dependent_text_tower(
             embed_dim, multimodal_cfg, quick_gelu, cast_dtype
```
```diff
@@ -107,7 +101,9 @@ def __init__(

         self.img_attn_pool_norm = norm_layer(embed_dim)

-        self.dim_latents = multimodal_cfg.dim_latents if multimodal_cfg.dim_latents else multimodal_cfg.width
+        self.dim_latents = (
+            multimodal_cfg.dim_latents if multimodal_cfg.dim_latents else multimodal_cfg.width
+        )
         self.to_text_latent = nn.Linear(embed_dim, self.dim_latents, bias=False)

         self.to_logits = nn.Sequential(
```
@@ -118,6 +114,7 @@ def __init__( | |
self.to_logits[-1].weight = self.token_embedding.weight | ||
|
||
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) | ||
self.pad_id = 0 | ||
|
||
@torch.jit.ignore | ||
def set_grad_checkpointing(self, enable=True): | ||
|
```diff
@@ -132,9 +129,7 @@ def encode_image(self, images, normalize=True, return_tokens=False):
         x = torch.cat(
             [
                 self.visual.class_embedding.to(x.dtype)
-                + torch.zeros(
-                    x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
-                ),
+                + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
                 x,
             ],
             dim=1,
```
```diff
@@ -160,19 +155,31 @@ def encode_image(self, images, normalize=True, return_tokens=False):
     def _repeat(self, t, N):
         return t.reshape(1, 1, -1).repeat(N, 1, 1)

+    def _build_cls_mask(self, text, cast_dtype):
+        cls_mask = (text != self.pad_id).unsqueeze(1)
+        cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
+        additive_mask = torch.empty(*cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
+        additive_mask.fill_(0)
+        additive_mask.masked_fill_(~cls_mask, float("-inf"))
+        additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
+        return additive_mask
+
     def encode_text(self, text, normalize=True, return_tokens=False):
-        text = text[:, :-1] # make space for CLS token
+        text = text[:, :-1]  # make space for CLS token
         cast_dtype = self.transformer.get_cast_dtype()

-        # cls_mask = (text!=self.pad_id).unsqueeze(1)
-        # attn_mask = F.pad(cls_mask, (0, 1, text.shape[1], 0), value=True)
+        attn_mask = self.attn_mask[None, :].expand(
+            text.shape[0] * self.heads, *self.attn_mask.shape
+        )
+        cls_mask = self._build_cls_mask(text, cast_dtype)
+        # attn_mask = F.pad(self.attn_mask, (0, 1, 0, 1), value=0.0)

         x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
         x = torch.cat([x, self._repeat(self.cls_token, x.shape[0])], dim=1)
         x = x + self.positional_embedding.to(cast_dtype)
         x = x.permute(1, 0, 2)  # NLD -> LND
-        x = self.transformer(x, attn_mask=self.attn_mask)
+        x = self.transformer(x, attn_mask=attn_mask + cls_mask)
         x = x.permute(1, 0, 2)  # LND -> NLD

         # x.shape = [batch_size, n_ctx, transformer.width]
```

Review discussion on `encode_text`:

- "same comment as for visual, can we use the text tower much more ?"
- "This one is a bit harder than the visual one I think"
- "what is missing? don't we simply need to add that return tokens option in the text encoder too ?"
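To make the suggestion in the thread above concrete, here is a rough sketch of what a "return tokens" option on a text encoder could look like, so that a CoCa-style model could reuse a shared text tower instead of re-implementing `encode_text`. This is illustrative only: `TinyTextEncoder`, its layer choices, and the EOT-style pooling are simplified assumptions, not open_clip's actual `TextTransformer`.

```python
# Hypothetical sketch: a text encoder whose forward() can also return per-token
# features. The class and its internals are assumptions, not open_clip code.
import torch
import torch.nn as nn


class TinyTextEncoder(nn.Module):
    def __init__(self, vocab_size=1000, width=64, context_length=16, heads=4, layers=2):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, width)
        self.positional_embedding = nn.Parameter(torch.zeros(context_length, width))
        encoder_layer = nn.TransformerEncoderLayer(width, heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, layers)
        self.ln_final = nn.LayerNorm(width)

    def forward(self, text, return_tokens: bool = False):
        x = self.token_embedding(text) + self.positional_embedding[: text.shape[1]]
        x = self.ln_final(self.transformer(x))
        # Pool at the position of the highest token id (EOT-style pooling).
        pooled = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
        if return_tokens:
            # A CoCa-style caller gets the full token sequence as well,
            # so it does not need to re-implement the text tower.
            return pooled, x
        return pooled


enc = TinyTextEncoder()
tokens_in = torch.randint(1, 1000, (2, 16))
pooled, tokens = enc(tokens_in, return_tokens=True)  # (2, 64) and (2, 16, 64)
```

With such a flag, the contrastive head would consume the pooled output while the multimodal decoder would consume the token sequence.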
Review discussion on the new attention mask:

- "What should be the impact of this change?"
- "I think right now the cls token at the end can attend to pad tokens in the sequence, this should not be possible with this extra mask"
- "sounds good!"
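To illustrate the point about padding, here is a minimal, self-contained sketch of the masking idea. It is not the PR's `_build_cls_mask` (which also has to line up with the causal `attn_mask` and the model's cast dtype); the helper name `build_cls_pad_mask` and its defaults are hypothetical. Pad key positions receive `-inf` in an additive attention mask, so the CLS token appended at the end of the sequence cannot attend to them.

```python
import torch


def build_cls_pad_mask(text: torch.Tensor, pad_id: int = 0, heads: int = 8) -> torch.Tensor:
    """Additive mask of shape (batch * heads, L, L), with L = text.shape[1] + 1
    to account for a CLS token appended after the text tokens."""
    batch, seq_len = text.shape
    # Keys that may be attended to: real (non-pad) tokens plus the CLS slot itself.
    keep = torch.cat([text != pad_id, torch.ones(batch, 1, dtype=torch.bool)], dim=1)
    # Every query row shares the same key mask.
    keep = keep[:, None, :].expand(batch, seq_len + 1, seq_len + 1)
    additive = torch.zeros(batch, seq_len + 1, seq_len + 1)
    additive = additive.masked_fill(~keep, float("-inf"))
    # nn.MultiheadAttention accepts a (batch * heads, L, L) additive attn_mask.
    return torch.repeat_interleave(additive, heads, dim=0)


text = torch.tensor([[5, 7, 0, 0]])  # two real tokens, two pads (pad_id = 0)
mask = build_cls_pad_mask(text, heads=2)
print(mask.shape)   # torch.Size([2, 5, 5])
print(mask[0, -1])  # CLS query row: -inf at the two pad key positions
```

Adding this on top of the causal mask, as the diff does with `attn_mask + cls_mask`, drives the attention weights for pad positions to zero after the softmax.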