
[bug] test_checkpoint test not passing when any lr scale is set to 0 #265

Closed
@oleksost

Description

🐞 Describe the Bug

Whenever the lr scale of any component is set to 0, e.g. `model.base_model.transformer.mlp_lr_scale=0`, `test_checkpoint` fails with:

`FAILED tests/test_checkpoint.py::test_load_pretrained_distributed_checkpoint - AssertionError: torch.Size([0]) != torch.Size([786432])`

I wonder how critical this is for loading/saving checkpoints that were trained with lr scaling.
Maybe related to #256.

🔄 Steps to Reproduce

Steps to reproduce the behavior:
Add e.g. `model.base_model.transformer.mlp_lr_scale=0` here and run `test_checkpoint`.

The same happens when the lr is set to zero using the per-layer lr scale from #243 and #258 (though in that case more than one test in `test_checkpoint` fails).
Importantly, if the line `self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale)` here is replaced with the plain `self.requires_grad = requires_grad`, the test passes. See the sketch below for what I suspect is going on.
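
A minimal sketch of the suspected mechanism, using hypothetical names rather than the repo's actual classes: if `requires_grad` is gated on the lr scale, a parameter with a zero scale is treated as frozen and excluded from the trainable shard, so the loading side ends up comparing an empty tensor against the saved full-size weight (786432 elements would correspond to e.g. a 1024 × 768 MLP weight, which is an assumption here):

```python
import torch


class ParamMeta:
    """Stand-in for a parameter metadata object (hypothetical, for illustration only)."""

    def __init__(self, shape: tuple[int, ...], requires_grad: bool = True, lr_scale=(1.0,)):
        self.shape = shape
        # The problematic gating: a zero lr scale silently flips requires_grad off.
        self.requires_grad = requires_grad and any(s != 0 for s in lr_scale)


def trainable_shard(metas: list[ParamMeta]) -> torch.Tensor:
    # Assume the distributed checkpoint shard is sized from trainable parameters only.
    numel = sum(torch.Size(m.shape).numel() for m in metas if m.requires_grad)
    return torch.empty(numel)


mlp_weight = ParamMeta((1024, 768), lr_scale=(0.0,))  # frozen by the gating above
print(trainable_shard([mlp_weight]).shape)  # torch.Size([0])
print(torch.Size((1024, 768)).numel())      # 786432 -> the two sides of the failing assertion
```

With the plain `self.requires_grad = requires_grad`, the parameter stays in the shard and the sizes match, which is consistent with the test passing after that change.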

🎯 Expected Behavior

Test passes.

📜 Environment Information

📝 Additional Context
