adding CoCa #256

Closed: wants to merge 123 commits
1189487
initial setup
gpucce Nov 25, 2022
91d01fa
add coca loss
gpucce Nov 27, 2022
efb6540
remove loss from the model
gpucce Nov 27, 2022
669a3a0
fix loss
gpucce Nov 27, 2022
f081dc4
add underscores
gpucce Nov 27, 2022
0b1c895
name changes
gpucce Nov 27, 2022
27369b6
Merge remote-tracking branch 'upstream/main' into add_coca
gpucce Nov 27, 2022
d518dd0
add cross attention to Residual and CustomResidual
gpucce Nov 29, 2022
11bf57c
fix if
gpucce Nov 29, 2022
f3dedf6
add transformer 'decoder'
gpucce Nov 29, 2022
50c4726
minor fix
gpucce Nov 29, 2022
1e41d83
looks better
gpucce Nov 29, 2022
0d91609
initialize coca model structure
gpucce Nov 29, 2022
50e0cbe
clean
gpucce Nov 29, 2022
93b4236
typo and format
gpucce Nov 29, 2022
97e3c0f
checkpoint signature
gpucce Nov 29, 2022
6ae6f8c
adjust multimodal decoder and add CoCaTransformer
gpucce Nov 30, 2022
0975dfe
keep older logic
gpucce Dec 1, 2022
f2265ec
remove chunk
gpucce Dec 1, 2022
9d47f0e
typo
gpucce Dec 1, 2022
6a101ec
fix
gpucce Dec 1, 2022
e259851
make chunk dim explicit
gpucce Dec 1, 2022
7fff61d
adjust cfg names
gpucce Dec 1, 2022
abd132d
add attentionalpooling
gpucce Dec 1, 2022
452d7d2
add attentional pooling to coca
gpucce Dec 1, 2022
43ce18f
small change
gpucce Dec 1, 2022
3f0f012
add cocatransformer variants and AttentionPooling
gpucce Dec 2, 2022
3e745ec
remove older attention pooler
gpucce Dec 2, 2022
4f4d3b7
adapt embed text to coca text transformer
gpucce Dec 2, 2022
42539aa
rm coca layers
gpucce Dec 2, 2022
914a570
rename and remove useless CoCa models
gpucce Dec 3, 2022
6215d4a
make attentionpooler pooler only
gpucce Dec 3, 2022
b97db74
refactor for one transformer only
gpucce Dec 5, 2022
d89f018
coca forward works
Dec 5, 2022
9a8c15d
separate context and n_queries
Dec 5, 2022
c8b9236
add initial coca_base config
Dec 5, 2022
d0f995a
remove config
Dec 5, 2022
5260774
small loss change
Dec 5, 2022
7a2b84e
init training file
Dec 5, 2022
3ef1d17
make variable order right
gpucce Dec 5, 2022
86f47bb
remove print
gpucce Dec 5, 2022
c6834b5
uniform names
gpucce Dec 5, 2022
7489c68
renaming
gpucce Dec 5, 2022
59503df
add coca funcs to init
gpucce Dec 5, 2022
504febd
add coca config and exclude from testing
gpucce Dec 5, 2022
72a7e96
add and comment simple test (no trained model)
gpucce Dec 5, 2022
d8a94be
add L2 norm
Dec 6, 2022
d250eac
make L2 same as in clip
Dec 6, 2022
8d9dfa6
remove unused temperature
Dec 6, 2022
1f2578c
type
Dec 6, 2022
d8ff1bd
clean
Dec 6, 2022
fa24047
fix config
Dec 6, 2022
f61f9d5
make rename and move cfg
Dec 6, 2022
4b76187
rename
Dec 6, 2022
b8777fe
temptative add coca to factory
Dec 6, 2022
42aa408
fix config
Dec 6, 2022
1044f36
update config
Dec 6, 2022
dab7d7d
embed contrastive cls token in model
Dec 7, 2022
d0ae683
remove unused arg
Dec 7, 2022
5a40804
import create_loss
Dec 7, 2022
6789438
make factory accept coca
Dec 7, 2022
60865ef
make caption loss distributed
Dec 7, 2022
ac617bf
make loss customizable
Dec 7, 2022
b9c2b25
pass loss through training_epoch
Dec 7, 2022
ccfd1e4
add coca specific params to params
Dec 7, 2022
c1556d4
removed decoder unused parameters
Dec 7, 2022
68d608a
remove unused attributes
Dec 7, 2022
59d4db4
adjust coca_config
Dec 7, 2022
4ee12e1
Merge remote-tracking branch 'upstream/main' into add_coca
Dec 7, 2022
732f15f
fix config and remove unused parameters
Dec 7, 2022
17072c6
remove comment
Dec 7, 2022
74d5e37
remove more comments
gpucce Dec 7, 2022
578aadf
rename attention pooler
Dec 8, 2022
08f43a3
rename TransformerDecoder
Dec 8, 2022
812a8bb
make AttentionalPooler clearer
Dec 8, 2022
f69f4e0
add local loss logic to cocaloss
Dec 8, 2022
3c02aa5
only create loss if train in data
Dec 8, 2022
979cef4
remove wrong file
Dec 8, 2022
2ec204b
fix attentional pooler call
Dec 8, 2022
29c7dfa
not ready for testing
Dec 8, 2022
5a4126b
really not ready for testing
Dec 8, 2022
6e49474
eof line
Dec 8, 2022
288ddf3
Merge remote-tracking branch 'upstream/main' into add_coca
Dec 9, 2022
599d448
uniform names
Dec 9, 2022
d7953da
add possible generative loss to evaluate
Dec 9, 2022
e2042d4
change _build function names
Dec 9, 2022
15c69f8
remove wrong import
Dec 9, 2022
c219381
remove local_loss from captioning loss
Dec 9, 2022
360408e
Merge branch 'main' into add_coca
rom1504 Dec 9, 2022
5c77e4d
indexing error
Dec 9, 2022
3f095a6
finish renaming
Dec 9, 2022
60f35f3
adjust configs
Dec 9, 2022
a53f477
add training test for coca
rom1504 Dec 9, 2022
b3f3d68
simplify captioning loss
Dec 9, 2022
8eb4772
remove hf
Dec 9, 2022
cf0f857
fix evaluate and loss
Dec 9, 2022
d547017
remove print
Dec 9, 2022
75be611
move projection
Dec 9, 2022
356fb7d
add coca vit 32 config
Dec 9, 2022
8008f25
test on new config
Dec 9, 2022
5b54a4b
adjust coca_base config
Dec 9, 2022
292fa6e
Merge branch 'main' into add_coca
rom1504 Dec 10, 2022
720dabf
remove coca from test_inference
gpucce Dec 10, 2022
bcb82c4
maybe fix regression test
gpucce Dec 10, 2022
d0f4947
make logits and labels contiguous
gpucce Dec 10, 2022
39f20e6
simpler logic
gpucce Dec 10, 2022
2dde78d
make contiguous after transpose
gpucce Dec 10, 2022
de4c063
last test
gpucce Dec 10, 2022
00aa464
try fix loss
Dec 12, 2022
b7bea09
Merge remote-tracking branch 'upstream/main' into add_coca
Dec 12, 2022
27bfc7d
CoCa PR: loss fix + rename file
Dec 17, 2022
e694999
wait for feedback on this
Dec 17, 2022
5427b0a
cleanup
Dec 17, 2022
cc6d13f
Merge pull request #1 from iejMac/add_coca2
gpucce Dec 17, 2022
abd7849
CoCa PR: add set_grad_checkpointing + fix checkpoint API
Dec 17, 2022
19300ad
Merge pull request #2 from iejMac/grad_checkpointing
gpucce Dec 17, 2022
919f5a0
CoCa PR: fix eval (which uses encode_x instead of forward)
Dec 18, 2022
5b29ec0
move making space for CLS token into encode_text
Dec 18, 2022
752de0a
revert zs changes + fix
Dec 18, 2022
1360fcd
Merge pull request #3 from iejMac/eval_fix
gpucce Dec 18, 2022
cd91d32
Merge remote-tracking branch 'upstream/main' into add_coca
gpucce Dec 18, 2022
64c33d8
add cls mask for pad ids
Dec 20, 2022
17813eb
simplify encode image
Dec 20, 2022
17 changes: 1 addition & 16 deletions in src/open_clip/coca_model.py

@@ -123,22 +123,7 @@ def set_grad_checkpointing(self, enable=True):
         self.multimodal_decoder.grad_checkpointing = enable

     def encode_image(self, images, normalize=True, return_tokens=False):
-        x = self.visual.conv1(images)  # shape = [*, width, grid, grid]
-        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
-        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
-        x = torch.cat(
-            [
-                self.visual.class_embedding.to(x.dtype)
-                + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
-                x,
-            ],
-            dim=1,
-        )  # shape = [*, grid ** 2 + 1, width]
-        x = x + self.visual.positional_embedding.to(x.dtype)
-        x = self.visual.ln_pre(x)
-        x = x.permute(1, 0, 2)  # NLD -> LND
-        x = self.visual.transformer(x)
-        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.visual(images, output_tokens=True)
         x = self.visual.ln_post(x)
Collaborator comment: this ln_post call makes a big assumption about the API of the visual encoder.
if self.visual.proj is not None:
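The deleted block re-implemented the ViT stem (patch embedding, class token, positional embedding, transformer) inside CoCa; the single added line delegates all of that to the visual tower and only asks for the pre-pool token sequence. The following is a hedged, toy sketch of that delegation pattern, assuming only torch; `ToyVisualTower` and its fields are illustrative stand-ins, not the real open_clip `VisionTransformer`:

```python
import torch
import torch.nn as nn

class ToyVisualTower(nn.Module):
    """Illustrative stand-in for the visual encoder (not the real open_clip class)."""

    def __init__(self, width=8):
        super().__init__()
        self.transformer = nn.Identity()   # placeholder for the transformer blocks
        self.ln_post = nn.LayerNorm(width)

    def forward(self, x, output_tokens=False):
        x = self.transformer(x)            # [batch, seq, width]
        if output_tokens:
            return x                       # raw per-token features: no pooling, no ln_post
        return self.ln_post(x.mean(dim=1)) # default path: pooled embedding

visual = ToyVisualTower()
images_as_tokens = torch.randn(2, 5, 8)    # pretend already-patchified input

# CoCa-side encode_image after the refactor: one delegated call, then ln_post
x = visual(images_as_tokens, output_tokens=True)
x = visual.ln_post(x)
print(x.shape)  # torch.Size([2, 5, 8])
```

Note the caller still applies `ln_post` itself, which is exactly the coupling the review comment above flags.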
5 changes: 4 additions & 1 deletion in src/open_clip/transformer.py

@@ -448,7 +448,7 @@ def init_parameters(self):
     def set_grad_checkpointing(self, enable=True):
         self.transformer.grad_checkpointing = enable

-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, output_tokens = False):
         x = self.conv1(x)  # shape = [*, width, grid, grid]
         x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
         x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

@@ -465,6 +465,9 @@ def forward(self, x: torch.Tensor):
         x = self.transformer(x)
         x = x.permute(1, 0, 2)  # LND -> NLD

+        if output_tokens:
+            return x
+
         if self.global_average_pool:
             x = x.mean(dim=1)
Collaborator comment: I'm wondering if this can be done after the ln_post; if yes, then it will make it possible to do the ln_post only here and not in CoCa.

else:
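The reviewer's suggestion can be sketched as follows: move the `output_tokens` early return to after `ln_post`, so the tower owns the final layer norm and CoCa never has to call it. This is a hedged, illustrative sketch (not the open_clip implementation); assuming torch only:

```python
import torch
import torch.nn as nn

width = 8
ln_post = nn.LayerNorm(width)
x = torch.randn(2, 5, width)       # tokens after the transformer, [batch, seq, width]

# Suggested ordering inside the tower's forward():
x = ln_post(x)                     # normalize every token first
output_tokens = True
if output_tokens:
    tokens = x                     # CoCa could consume these directly, no extra ln_post
else:
    tokens = x.mean(dim=1)         # global-average-pool branch

print(tokens.shape)  # torch.Size([2, 5, 8])
```

One caveat worth checking before adopting this: LayerNorm and mean-pooling do not commute in general, so normalizing per token before pooling is not numerically identical to pooling first and normalizing the pooled vector.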