diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py
index f4660588e..c81e8ca61 100644
--- a/bitsandbytes/optim/adamw.py
+++ b/bitsandbytes/optim/adamw.py
@@ -220,7 +220,7 @@ def step(self, closure=None):
                     lor_update = torch.zeros_like(
                         grad, dtype=p.data.dtype, device=p.data.device, requires_grad=grad.requires_grad
                     )
-                    lor_update.grad = grad
+                    p.grad = grad
 
                 if "state1" not in state:
                     self.init_state(group, p, gindex, pindex)
@@ -228,7 +228,7 @@ def step(self, closure=None):
                 self.prefetch_state(p)
 
                 if "rank" in group:
-                    self.update_step(group, lor_update, gindex, pindex, return_updates=lor_update)
+                    self.update_step(group, p, gindex, pindex, return_updates=lor_update)
 
                     # GaLore Projection Back
                     p.data.add_(state["projector"].project_back(lor_update))