
Any plans to support the modified ViT arch based on the ViT-22B paper #426

JianbangZ opened this issue Feb 14, 2023 · 4 comments


The changes include QK normalization, parallel layers, etc. It would be cool to see how CLIP performs by applying those changes to ViT-B, ViT-L, and ViT-H.
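For context, a minimal sketch of what ViT-22B-style QK normalization looks like (illustrative only; the tensor shapes are hypothetical, and a plain `layer_norm` without learned affine parameters stands in for the paper's learned norm):

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: (batch, heads, tokens, head_dim)
B, H, N, D = 2, 8, 197, 64
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))

# ViT-22B-style QK normalization: normalize queries and keys before the
# attention logits are formed (plain layer_norm here for brevity; the paper
# uses a learned LayerNorm, and some follow-ups use RMSNorm instead).
q = F.layer_norm(q, (D,))
k = F.layer_norm(k, (D,))

attn = (q @ k.transpose(-2, -1)) * D ** -0.5   # scaled dot-product logits
out = attn.softmax(dim=-1) @ v                 # (B, H, N, D)
```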

rwightman commented Feb 14, 2023

@JianbangZ

  1. QK normalization might improve stability at the larger end of the model scale. So far I've managed to mitigate the instability with bfloat16 + AMP and by lowering the AdamW beta2, but not perfectly for large-scale models.

  2. The parallel MLP + attn won't make much difference within a single device until it's (non-trivially) supported in distributed training via model/tensor parallel code. Kernels typically use the full GPU and don't execute in parallel, even if you try.

I tried torch.compile (inductor) with the parallel attn; I thought I was seeing gains at the B/16 model size, but then it reversed at H/14. I talked to an expert and he figures it's just compiler behaviour (a good fit vs. not so great, with some differences due to the parallel layout), BUT there were no parallel-execution optimizations in the compiler (it does not try).
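For reference, a rough sketch of the parallel attention + MLP block layout being discussed (PaLM / ViT-22B style); this is illustrative only and not the open_clip or timm implementation:

```python
import torch.nn as nn

class ParallelBlock(nn.Module):
    """Parallel attention + MLP transformer block (PaLM / ViT-22B layout):
    both branches read the same normalized input and their outputs are summed,
    instead of the usual sequential attn -> MLP residuals. Illustrative sketch."""
    def __init__(self, dim, num_heads=8, mlp_ratio=4.0):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim),
        )

    def forward(self, x):
        y = self.norm(x)
        attn_out, _ = self.attn(y, y, y, need_weights=False)
        # Single residual sum: x + Attn(LN(x)) + MLP(LN(x))
        return x + attn_out + self.mlp(y)
```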


lucidrains commented Feb 15, 2023

@JianbangZ Yeah, I can add the QK RMSNorm this week as an option.

However, just FYI, there are still researchers who disagree on whether it helps.

Also, I feel like most of the instability has been ironed out since switching to bfloat16? Ross already mentioned this.


rwightman commented Feb 15, 2023

I have a different approach for the parallel blocks (manually fusing them) that looks like it'll work better; I'm trying it in timm. My first naive approach did not yield gains, as mentioned, and the compiler couldn't do much with it either.

For QK norm here, it won't work with the default transformer block (that block fully relies on the builtin nn.MultiheadAttention, and I don't plan to alter that right now); it would need to be added to the custom attention block. PyTorch 2.0 has a new F.scaled_dot_product_attention, a fused kernel with flash attention (or xformers memory-efficient attention), that can bring the custom implementation closer to nn.MultiheadAttention in performance. QK norm would work with that, since the fused kernel only covers the scaled dot product.

WIP https://github.com/rwightman/pytorch-image-models/blob/b6eb652924a40555d6bfcee64c7ef4c8d6e4aa9c/timm/models/vision_transformer.py#L54-L102
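As a rough illustration of the point above (a sketch, not the linked timm code; layer names are assumptions): a custom attention block can route the scaled-dot-product step through PyTorch 2.0's F.scaled_dot_product_attention while still applying QK norm, because the fused kernel only covers the scaled dot product itself:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomAttention(nn.Module):
    """Attention with optional QK LayerNorm, using PyTorch 2.0's fused
    F.scaled_dot_product_attention for the scaled-dot-product step."""
    def __init__(self, dim, num_heads=8, qk_norm=True):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.q_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)   # (B, heads, N, head_dim)
        q, k = self.q_norm(q), self.k_norm(k)            # QK normalization
        # Fused kernel (flash / mem-efficient attention) handles only the
        # scaled dot product, so QK norm composes with it cleanly.
        x = F.scaled_dot_product_attention(q, k, v)
        x = x.transpose(1, 2).reshape(B, N, C)
        return self.proj(x)
```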


lucidrains commented Feb 15, 2023

@rwightman Ugh, yeah, I forgot about that (re: nn.MHA). OK, let's leave it untouched then, until flash attention is released properly in PyTorch 2.0.
