
adding whole Linear8bitLt/Linear4bit module save/load serialization #1099

Merged
merged 1 commit into bitsandbytes-foundation:main on Mar 5, 2024

Conversation

rdyro
Contributor

rdyro commented Feb 28, 2024

The purpose of this pull request is to allow torch.save/torch.load directly on modules containing Linear4bit and Linear8bitLt submodules.

Currently, calling torch.save and then torch.load on a module containing Linear8bitLt (after the first forward pass) raises a missing-field error for CB on the Int8Params class. This PR makes torch aware of the CB and SCB fields of Int8Params.
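For context, a minimal sketch of the failing round trip (assuming a CUDA build of bitsandbytes; the layer sizes and file name are illustrative):

    import torch
    import bitsandbytes as bnb

    # A module containing a Linear8bitLt layer; moving it to the GPU and
    # running one forward pass quantizes the weight and populates CB/SCB.
    layer = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False).cuda()
    x = torch.randn(1, 64, dtype=torch.float16, device="cuda")
    layer(x)

    torch.save(layer, "layer.pt")
    restored = torch.load("layer.pt")  # needs weights_only=False on newer PyTorch

    # Before this PR, using the restored layer failed with a missing CB
    # attribute on the Int8Params weight.
    restored(x)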

The core of this PR is

    -        return torch.Tensor._make_subclass(cls, data, requires_grad)
    +        obj = torch.Tensor._make_subclass(cls, data, requires_grad)
    +        obj.CB, obj.SCB = cls.CB, cls.SCB
    +        return obj

in the Int8Params class.
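For context, a sketch of where the change lands inside Int8Params.__new__, simplified from the bitsandbytes source (abridged, not the exact upstream code):

    import torch

    class Int8Params(torch.nn.Parameter):
        def __new__(cls, data=None, requires_grad=True,
                    has_fp16_weights=False, CB=None, SCB=None):
            cls.has_fp16_weights = has_fp16_weights
            cls.CB = None
            cls.SCB = None
            if data is None:
                data = torch.empty(0)
            # Attaching CB/SCB to the instance (instead of returning the bare
            # subclass tensor) lets torch.save/torch.load carry them along.
            obj = torch.Tensor._make_subclass(cls, data, requires_grad)
            obj.CB, obj.SCB = cls.CB, cls.SCB
            return obj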

I also added a torch.save -> torch.load round-trip test for Linear4bit (which already worked) and Linear8bitLt (which did not work before this PR).
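A pytest-style sketch of such a round-trip test (illustrative, not the exact test added in this PR; the exact-match comparison assumes deterministic int8 inference):

    import torch
    import bitsandbytes as bnb

    def test_linear8bitlt_save_load_roundtrip(tmp_path):
        layer = bnb.nn.Linear8bitLt(32, 32, has_fp16_weights=False).cuda()
        x = torch.randn(4, 32, dtype=torch.float16, device="cuda")
        ref = layer(x)  # first forward populates CB/SCB

        path = tmp_path / "layer.pt"
        torch.save(layer, path)
        restored = torch.load(path)

        # The restored layer should reproduce the original output, since
        # both use the same int8 weights and quantization scales.
        torch.testing.assert_close(restored(x), ref)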

While saving whole modules directly with PyTorch's save and load is not good practice, the change needed to make it work is minimal, and it makes disk-caching modules during development easier.

@younesbelkada
Collaborator

cc @Titus-von-Koeller wdyt? Might be good to have for the next release, no?

@Titus-von-Koeller
Collaborator

Yes, I agree, this is looking good and should be merged before the release. I'll review it more in depth soon.

Thanks @rdyro for the good work and taking the initiative to contribute, really appreciated 🤗

@rdyro
Contributor Author

rdyro commented Feb 29, 2024

Thanks for the positive feedback! I really like your work with bitsandbytes.

Let me know if the new tests should ideally extend to all Linear layers, not just Linear4bit and Linear8bitLt.


@Titus-von-Koeller
Collaborator

Dear @rdyro,

I just reviewed your proposed changes and everything really looks good! I don't think any additional tests are needed; what you did already looks good the way it is.

I also ran the transformers integration tests and everything came through clean.

Thanks so much for your contribution and if you feel like contributing more, we'd be happy to support you!

@Titus-von-Koeller merged commit a1c0844 into bitsandbytes-foundation:main on Mar 5, 2024.
9 of 10 checks passed.
akx added a commit to akx/bitsandbytes that referenced this pull request Mar 5, 2024