Description
🐞 Describe the Bug
Whenever the lr scale of any component is set to 0, e.g. `model.base_model.transformer.mlp_lr_scale=0`, `test_checkpoint` fails with:

`FAILED tests/test_checkpoint.py::test_load_pretrained_distributed_checkpoint - AssertionError: torch.Size([0]) != torch.Size([786432])`
I wonder how critical this is for loading/saving checkpoints that were trained with lr scaling.
Maybe related to #256.
🔄 Steps to Reproduce
Steps to reproduce the behavior:
Add e.g. `model.base_model.transformer.mlp_lr_scale=0` here and run `test_checkpoint`.
The same happens when the lr is set to zero using the per-layer lr scale from #243 and #258 (though in this case more than one test in `test_checkpoint` fails).
Importantly, if the line `self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale)` here is replaced with a simple `self.requires_grad = requires_grad`, the test passes.
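To make the suspected mechanism concrete, here is a minimal, self-contained sketch (not Fast-LLM's actual checkpoint code): the `ToyParamMeta`, `save_shard`, and `load_shard` names and the "allocate shard space only for trainable parameters" behavior are assumptions made for illustration, and the tensor shape is only chosen so the numbers match the reported assertion.

```python
import torch
from math import prod


class ToyParamMeta:
    """Stand-in for a parameter definition whose requires_grad is gated on lr_scale."""

    def __init__(self, shape, requires_grad=True, lr_scale=(1.0,)):
        self.shape = shape
        self.lr_scale = lr_scale
        # The line under suspicion: lr_scale == 0 silently flips requires_grad off.
        self.requires_grad = requires_grad and any(s != 0 for s in lr_scale)


def save_shard(meta: ToyParamMeta) -> torch.Tensor:
    # Hypothetical writer that allocates shard space only for trainable parameters.
    return torch.zeros(prod(meta.shape)) if meta.requires_grad else torch.zeros(0)


def load_shard(shard: torch.Tensor, meta: ToyParamMeta) -> None:
    # Loading expects the full-sized flat tensor regardless of requires_grad.
    expected = torch.Size([prod(meta.shape)])
    assert shard.shape == expected, f"{shard.shape} != {expected}"


if __name__ == "__main__":
    meta = ToyParamMeta((1024, 768), lr_scale=(0.0,))  # e.g. mlp_lr_scale=0
    try:
        load_shard(save_shard(meta), meta)
    except AssertionError as e:
        print(e)  # torch.Size([0]) != torch.Size([786432])
```

In this toy, replacing the gated assignment with a plain `self.requires_grad = requires_grad` makes the shapes agree again, which mirrors the observed fix.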
🎯 Expected Behavior
Test passes.