Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Edenzzzz's fix for min_8bit_size functionality in Optimizer base clas…
Browse files Browse the repository at this point in the history
…ses (bitsandbytes-foundation#1286)

* fix min_8bit_size invalid bug

* Apply same fix to other optimizer base class

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
2 people authored and matthewdouglas committed Oct 28, 2024
1 parent 7ca6dcd commit 90657e7
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions bitsandbytes/optim/optimizer.py
Original file line number Diff line number Diff line change
@@ -438,7 +438,7 @@ def init_state(self, group, p, gindex, pindex):
state = self.state[p]
state["step"] = 0

if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
if dtype == torch.float32:
state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
state["state2"] = self.get_state_buffer(p, dtype=torch.float32)
elif dtype == torch.uint8:
@@ -667,7 +667,7 @@ def init_state(self, group, p, gindex, pindex):
state = self.state[p]
state["step"] = 0

if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
if dtype == torch.float32:
state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
elif dtype == torch.uint8:
if state["step"] == 0:

0 comments on commit 90657e7

Please sign in to comment.