support layer decay and different lr for text/visual encoder #268
base: main
Conversation
Signed-off-by: quansun <[email protected]>
Just a follow-up. Is anyone taking a look?
Hi @Quan-Sun, thanks for the PR! Can you give more context on why this is helpful? Have you observed better results with either of these changes?
Signed-off-by: Sun Quan <[email protected]>
Hello Gabriel, thanks for your reply! This PR adds layer-wise learning rate decay and separate learning rates for the text and visual encoders. Layer-wise lr decay is a common trick when training a big model initialized from pre-trained weights. Because the text and visual encoders differ (images and text are naturally different modalities), I think different learning rates and different layer-decay values should be applied to each tower.
P.S. the batch size can reach 57k when using gradient checkpointing, DeepSpeed fp16, ZeRO stage 1, and the local loss.
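For readers skimming the thread, here is a minimal sketch of the technique being proposed, assuming illustrative parameter-name patterns (the regex, the visual/text prefixes, and the helper names below are not the PR's actual code): each parameter's lr is its tower's base lr scaled by a power of the layer-decay factor that shrinks toward the input layers.

```python
# Hedged sketch of layer-wise lr decay with per-tower base lrs.
# NOT the PR's implementation; the name patterns below are illustrative only.
import re

def layer_id_from_name(name, num_layers):
    """Map a parameter name to a depth index: embeddings -> 0, block i -> i + 1, head -> num_layers + 1."""
    match = re.search(r"resblocks\.(\d+)\.", name)  # assumed transformer block naming
    if match:
        return int(match.group(1)) + 1
    return 0 if "embedding" in name else num_layers + 1

def build_param_groups(model, base_lrs, layer_decay, num_layers):
    """base_lrs is a dict like {"visual": 5e-4, "text": 5e-5}; returns torch-style param groups."""
    groups = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        tower = "visual" if name.startswith("visual.") else "text"
        layer_id = layer_id_from_name(name, num_layers)
        scale = layer_decay ** (num_layers + 1 - layer_id)  # deeper layers keep a larger lr
        key = (tower, layer_id)
        groups.setdefault(key, {"params": [], "lr": base_lrs[tower] * scale})
        groups[key]["params"].append(param)
    return list(groups.values())
```

The resulting groups can be passed straight to a standard torch optimizer, which respects a per-group "lr".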
Signed-off-by: Sun Quan <[email protected]>
@Quan-Sun @gabrielilharco I think the changes are reasonable. I use layer decay extensively in timm fine-tuning these days, and I feel it'd be useful here, especially when initializing one or both of the towers with pretrained weights. Also, I was thinking of pushing optimizer creation into a factory method as well, since I wanted to try some other optimizers at some point and that'd make it a bit cleaner. Question (for Quan-Sun): why are the assigners in main? They're just created and passed to the optim factory; wouldn't it be cleaner to create them in the factory since they're not needed in main?
src/training/main.py (Outdated)

    if visual_ld < 1.0:
        visual_num_layers = model_without_ddp.visual.get_num_layers()
        assigner_visual = LayerDecayValueAssigner(
            list(visual_ld ** (visual_num_layers + 1 - i) for i in range(visual_num_layers + 2)))
consider using a more descriptive variable name here, explaining what is being assigned (e.g. lr_assigner_visual). Also, is there any reason why the exponent logic is not inside the Assigner class? I.e. the constructor could take in the layer decay and number of layers, and compute the values accordingly.
+1 on that, plus getting this out of main
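A hedged sketch of what that refactor could look like (the constructor signature and method names here are illustrative, not necessarily what the PR ended up with):

```python
# Illustrative only: the assigner computes the per-layer lr scales itself,
# mirroring the formula used in the diff above.
class LayerDecayValueAssigner:
    def __init__(self, layer_decay: float, num_layers: int):
        # index 0 = embeddings, index num_layers + 1 = head/projection
        self.values = [layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)]

    def get_scale(self, layer_id: int) -> float:
        return self.values[layer_id]

# usage sketch:
# assigner_visual = LayerDecayValueAssigner(visual_ld, model.visual.get_num_layers())
```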
src/training/main.py (Outdated)

    else:
        assigner_visual = None

    if text_ld < 1.0:
should this be != 1.0?
same for visual_ld above
src/training/optim.py (Outdated)

    "text.token_embedding",
    "text.transformer.embeddings.word_embeddings",
    "text.transformer.embeddings.position_embeddings",
    "text.transformer.embeddings.token_type_embeddings"
will this work with all architectures we support? E.g. text models from HF
@gabrielilharco no, that's something else that needs to be addressed. As implemented it only works for the built-in ViT + text transformer; it needs to at least detect whether it will work and warn that it cannot be used for other models...
It'd be very useful for pretrained timm and HF text models. timm has functions that can calculate the layer decay grouping, but they need to be called when a timm model is used (that can be a different PR), and I'm not sure HF has any built-in support for calculating layer decay (discriminative LR) in a general way...
src/training/optim.py (Outdated)

    parameters = get_all_parameters(args, model, assigner_visual, assigner_text)

    optimizer_args = dict(
            betas=(args.beta1, args.beta2),
nit: extra tab?
Thanks for the context @Quan-Sun @rwightman. I agree with @rwightman re: the assigners, and also left some other minor comments there. I'll test it from my side after the changes. Thanks!
FWIW, timm's LD impl is here: https://github.com/rwightman/pytorch-image-models/blob/e98c93264cde1657b188f974dc928b9d73303b18/timm/optim/optim_factory.py#L92-L153 ... all models have a fn that returns the group metadata, and the grouper fn can be used, so that would be the basis for applying LD to the vision tower if a timm tower is used.
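For context, a usage sketch of that timm helper (hedged: the import path, the param_groups_layer_decay function, and the "lr_scale" group key are as of the linked commit and may differ across timm releases):

```python
import timm
import torch
from timm.optim import optim_factory

# Build layer-decayed parameter groups for a timm vision tower.
vision_tower = timm.create_model("vit_base_patch16_224", pretrained=True)
base_lr = 5e-4
param_groups = optim_factory.param_groups_layer_decay(
    vision_tower, weight_decay=0.05, layer_decay=0.75)

# The returned groups carry an "lr_scale"; fold it into an explicit "lr" so a
# plain torch optimizer can be used without timm's scheduler applying the scale.
for group in param_groups:
    group["lr"] = base_lr * group.get("lr_scale", 1.0)

optimizer = torch.optim.AdamW(param_groups, lr=base_lr)
```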
@gabrielilharco @rwightman Thanks for your comments. I will work on these changes ASAP.
…ide the Assigner class Signed-off-by: Quan Sun <[email protected]>
Thanks for the update @Quan-Sun. IIUC this would still only work for the built-in ViT + text transformer, is that right? As Ross pointed out, we should at least detect when this is not the case and warn users that other models are not supported.
Hi @gabrielilharco. You are right, get_num_layer_for_transformer(...) is not flexible; it should warn users if the model is not supported. Do you think we can have a white list here? For example, white_list = ["visual.blocks", "visual.transformer.resblocks", "text.transformer.resblocks", "text.transformer.encoder.layer"], and then detect whether the model_param_name is in this white_list.
Yes, that should work. If we're being very conservative we could also whitelist specific |
add a white_list then detecting if the model_param_name is in this white_list
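A minimal sketch of that whitelist check (the prefixes are the ones proposed above; the helper name and where the warning is issued are illustrative):

```python
import logging

# Parameter-name prefixes for which the layer-id heuristic is known to work.
SUPPORTED_PREFIXES = (
    "visual.blocks",
    "visual.transformer.resblocks",
    "text.transformer.resblocks",
    "text.transformer.encoder.layer",
)

def layer_decay_supported(model) -> bool:
    """True if any parameter name starts with a whitelisted transformer-block prefix."""
    return any(
        name.startswith(SUPPORTED_PREFIXES)
        for name, _ in model.named_parameters()
    )

# e.g. in the optimizer factory (illustrative):
# if (args.visual_ld != 1.0 or args.text_ld != 1.0) and not layer_decay_supported(model):
#     logging.warning("Layer-wise lr decay is not supported for this architecture; ignoring it.")
```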
Thanks for the update @Quan-Sun! Could you check whether “Allow edits from maintainers.” is checked on your side? I want to make some small changes before merging.
Hi @gabrielilharco. I have checked “Allow edits from maintainers.” on my side. Please let me know if anything was missed.
Signed-off-by: quansun [email protected]
Add new features: