adding CoCa #256
Conversation
It would be best to see if it's possible to unify both the tower support and the losses with the existing code.
Sure, I will try to reuse as much as I can. For now it is mostly copied from the coca-pytorch repo; I will probably ask for some help as I move on :)
@rom1504 I will reuse the visual_model from open_clip. However, in coca-pytorch the transformer layers for the text model are different from the regular ones: the feed_forward and attention branches run in parallel. Do you prefer that, or the regular layout? I have no idea how much difference it makes. Also, even if I use the regular attention, I think the current implementation doesn't allow cross-attention. Would you prefer a separate CrossAttention layer, or adding cross-attention support to the regular attention via kwargs to the forward?
I think let's bring options into the current text model so they support CoCa.
Thanks for working on this!
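To make the two layouts under discussion concrete, here is a minimal PyTorch sketch (hypothetical names, not the open_clip API) of a residual block supporting both the sequential style used by open_clip's text transformer and the parallel attention/feed-forward style used by coca-pytorch, with cross-attention exposed through an optional `context` keyword argument as suggested above:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Residual transformer block; parallel=True mimics the coca-pytorch layout."""

    def __init__(self, d_model: int, n_head: int, parallel: bool = False):
        super().__init__()
        self.parallel = parallel
        self.ln_1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
        self.ln_2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
        # cross-attend to `context` (e.g. image tokens) when given, else self-attend
        q = self.ln_1(x)
        kv = q if context is None else context
        attn_out = self.attn(q, kv, kv, need_weights=False)[0]
        if self.parallel:
            # parallel: attention and feed-forward both branch off the block input
            return x + attn_out + self.mlp(self.ln_2(x))
        # sequential: attention first, then feed-forward on the updated stream
        x = x + attn_out
        return x + self.mlp(self.ln_2(x))
```

Either variant keeps the same parameter count; the parallel form mainly changes where the residual branches split.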
@rom1504 I am moving forward. If you have time, could you have a look at how the cross-attention and decoder are added to the existing models, to see whether the integration is going in a reasonable direction?
Another bonus-feature idea (probably not for this PR): support many HF decoders for the "multimodal transformer" that got added here.
@gpucce Do you think you could give me push access to your fork? I'd love to help out, but I don't want you to have to manually merge all of my suggested changes each time I make them.
Sure, I will do it as soon as I am at a computer.
Made some code review comments; the most important points are inline below,
and then test, test, test.
Thanks for the review @rwightman. @iejMac had raised a similar point to the second one. To coordinate with everyone (@rom1504), since the list of todos is getting longer: I am planning to address all of them, but I can't proceed too fast. If someone is working on any of them to speed the whole process up, please let me know. Otherwise I will implement everything as suggested, taking a bit of time, starting with the generative part.
Yes, I think starting with implementing captioning will give us confidence that things are working.
src/open_clip/coca_model.py
    else LayerNorm
)

text = _build_input_dependent_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype, multimodal=False)
could be self.text here
I think adding this would make the state_dict of the model you have just trained incompatible. Is that fine?
yes, we'll retrain
text_embs = text_embs.permute(1, 0, 2)  # NLD -> LND
image_embs = image_embs.permute(1, 0, 2)  # NLD -> LND

for r, ca in zip(self.resblocks, self.cross_attn):
Could you name those `resblock` and `cross_attn`? `r` and `ca` are a bit confusing.
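As a self-contained sketch of what that loop is doing (hypothetical module using stock PyTorch attention, not the PR's exact code), with the more descriptive names the review asks for: each layer first self-attends over the text tokens, then cross-attends to the image tokens.

```python
import torch
import torch.nn as nn

class MultimodalDecoder(nn.Module):
    """Alternates text self-attention and text-to-image cross-attention layers."""

    def __init__(self, width: int, heads: int, layers: int):
        super().__init__()
        self.resblocks = nn.ModuleList(
            nn.MultiheadAttention(width, heads, batch_first=True) for _ in range(layers)
        )
        self.cross_attn = nn.ModuleList(
            nn.MultiheadAttention(width, heads, batch_first=True) for _ in range(layers)
        )

    def forward(self, image_embs: torch.Tensor, text_embs: torch.Tensor) -> torch.Tensor:
        # image_embs: (batch, n_image_tokens, width)
        # text_embs:  (batch, n_text_tokens, width)
        for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
            # text tokens attend to each other
            text_embs = text_embs + resblock(text_embs, text_embs, text_embs, need_weights=False)[0]
            # text tokens (queries) attend to image tokens (keys/values)
            text_embs = text_embs + cross_attn(text_embs, image_embs, image_embs, need_weights=False)[0]
        return text_embs
```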
mask.triu_(1)  # zero out the lower diagonal
return mask

def forward(self, image_embs, text_embs):
Could you add a comment explaining the shapes of those? I think it'll help.
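For reference, the mask built in the snippet above typically looks like this minimal sketch (assumed helper name; the additive convention matches PyTorch's `attn_mask`):

```python
import torch

def build_causal_mask(seq_len: int) -> torch.Tensor:
    # Additive attention mask: 0 where attention is allowed, -inf strictly
    # above the diagonal so position i cannot attend to positions j > i.
    mask = torch.empty(seq_len, seq_len)
    mask.fill_(float("-inf"))
    mask.triu_(1)  # zero out the diagonal and everything below it
    return mask
```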
def _repeat(self, t, N):
    return t.reshape(1, 1, -1).repeat(N, 1, 1)

def encode_text(self, text, normalize=True, return_tokens=False):
Same comment as for the visual part: can we reuse the text tower much more?
This one is a bit harder than the visual one I think
What is missing? Don't we simply need to add that `return_tokens` option in the text encoder too?
Hello, I would like to pre-train CoCa, and from the CoCa implementation's repo I found this PR/branch. I am not familiar with the OpenCLIP code base, but I think it could be a good opportunity for me to get my hands dirty.
@AwePhD help would definitely be appreciated. See the comments above for what we need. You can open PRs against the branch of this PR.
I also want to help. I've done many experiments with the CoCa model on most public datasets, and with caption generation as well (but without HF compatibility).
Ok. I think the best path forward is: I'm going to merge this into a coca branch in this repo, and then we all open PRs towards that branch. When we're happy that everything is good, we merge to main.
Recent news: we have confirmed with Maciej that:
- this gets the same performance on zero-shot classification as CLIP B/32
- captioning seems to work a bit
We can probably share this model to help with testing.
We need a proper AR sampling implementation to know whether captioning is working well.
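An AR sampler can be as simple as a greedy decoding loop; here is a hypothetical sketch (not the PR's code), where `model(image, tokens)` is assumed to return next-token logits of shape `(batch, seq, vocab)`:

```python
import torch

@torch.no_grad()
def greedy_decode(model, image, sot_id: int, eot_id: int, max_len: int = 30):
    # Start every sequence with the start-of-text token, then repeatedly
    # append the argmax of the logits at the last position.
    tokens = torch.full((image.shape[0], 1), sot_id, dtype=torch.long)
    for _ in range(max_len):
        logits = model(image, tokens)
        next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_tok], dim=1)
        if (next_tok == eot_id).all():
            break  # every sequence in the batch has emitted end-of-text
    return tokens
```

A real implementation would also want sampling with temperature/top-k and key-value caching, but greedy decoding is enough to eyeball whether captions are sensible.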
Nice that it gets the same performance! If you can wait before merging, shortly I should be able to simplify the logic for the visual part; for the text part, more things would need changing.
And if you can somehow share the model, that would be very useful.
@@ -160,19 +155,31 @@ def encode_image(self, images, normalize=True, return_tokens=False):
    def _repeat(self, t, N):
        return t.reshape(1, 1, -1).repeat(N, 1, 1)

    def _build_cls_mask(self, text, cast_dtype):
What should be the impact of this change?
I think right now the cls token at the end can attend to pad tokens in the sequence; this should not be possible with this extra mask.
sounds good!
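A rough sketch of such a mask (hypothetical names and `pad_id`, not the PR's `_build_cls_mask`): every position, including the cls token appended at the end of the sequence, is blocked from attending to padding.

```python
import torch

def build_cls_mask(text: torch.Tensor, pad_id: int = 0) -> torch.Tensor:
    # text: (batch, seq_len) token ids
    # returns an additive mask of shape (batch, seq_len + 1, seq_len + 1),
    # where the extra slot is the cls token appended after the sequence
    valid = text != pad_id                                    # (batch, seq_len)
    cls_slot = torch.ones(valid.shape[0], 1, dtype=torch.bool)
    valid = torch.cat([valid, cls_slot], dim=1)               # cls key is always valid
    keep = valid.unsqueeze(1).expand(-1, valid.shape[1], -1)  # same key mask for every query
    mask = torch.zeros(keep.shape)
    mask.masked_fill_(~keep, float("-inf"))                   # -inf blocks attention to pads
    return mask
```

In practice this would be combined (added) with the causal mask before being passed to attention.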
@rom1504 the visual part should be simpler now; you can make the new branch now. I will still work on the generative part as soon as I have time.
@@ -465,6 +465,9 @@ def forward(self, x: torch.Tensor):
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        if output_tokens:
            return x

        if self.global_average_pool:
            x = x.mean(dim=1)
I'm wondering if this pooling can be done after the ln_post.
If yes, that would make it possible to apply ln_post only here and not in CoCa.
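The ordering being suggested can be sketched like this (hypothetical module, not open_clip's actual `VisionTransformer`): normalizing the full token sequence first means both the token path and the pooled path go through the same `ln_post`, so CoCa would not need to call it separately.

```python
import torch
import torch.nn as nn

class VisionTail(nn.Module):
    """Tail of a ViT forward pass with ln_post applied before pooling."""

    def __init__(self, width: int, global_average_pool: bool = True):
        super().__init__()
        self.ln_post = nn.LayerNorm(width)
        self.global_average_pool = global_average_pool

    def forward(self, x: torch.Tensor, output_tokens: bool = False) -> torch.Tensor:
        x = self.ln_post(x)           # normalize the full token sequence first
        if output_tokens:
            return x                  # (batch, seq, width) token features
        if self.global_average_pool:
            return x.mean(dim=1)      # (batch, width) pooled feature
        return x[:, 0]                # otherwise take the class token
```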
x = x.permute(1, 0, 2)  # NLD -> LND
x = self.visual.transformer(x)
x = x.permute(1, 0, 2)  # LND -> NLD
x = self.visual(images, output_tokens=True)
x = self.visual.ln_post(x)
This ln_post call makes a big assumption about the API of the visual encoder.
@gpucce I merged into the coca branch. I excluded the two last commits you added today to avoid discrepancies with our trained model. All comments mentioned here stay valid.
Please refer to #308 for follow-ups.
The idea of this PR is to add the CoCa model as implemented in https://github.com/lucidrains/CoCa-pytorch, reusing existing parts as much as possible.
Ideally it also adds the possibility to choose between the custom and non-custom Attention implementations, as is done for CLIP.