From 90657e7bac03edfde445d0da64690619a28169ef Mon Sep 17 00:00:00 2001 From: Titus <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Mon, 22 Jul 2024 16:36:31 +0200 Subject: [PATCH] Edenzzzz's fix for min_8bit_size functionality in Optimizer base classes (#1286) * fix min_8bit_size invalid bug * Apply same fix to other optimizer base class --------- Co-authored-by: Edenzzzz --- bitsandbytes/optim/optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index e3830507c..2d50cd9b7 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -438,7 +438,7 @@ def init_state(self, group, p, gindex, pindex): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): + if dtype == torch.float32: state["state1"] = self.get_state_buffer(p, dtype=torch.float32) state["state2"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: @@ -667,7 +667,7 @@ def init_state(self, group, p, gindex, pindex): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): + if dtype == torch.float32: state["state1"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: if state["step"] == 0: