From 7e3b5ff1eca58000f29ace3303736e28b79fe731 Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Wed, 30 Oct 2024 15:40:41 -0400
Subject: [PATCH] Introduce GaLoreWrappedParameter to decouple grad from param

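GaLore previously wrote the projected gradient back to p.grad inside step()
so that update_step() could pick it up. update_step() now also accepts a
GaLoreWrappedParameter, a small dataclass carrying the parameter and the
gradient to use side by side: optimizer state stays keyed on the underlying
parameter, while the gradient is read from the wrapper, so the projected
gradient no longer needs to overwrite p.grad. Non-GaLore optimizers continue
to pass bare tensors through the same path.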
---
 bitsandbytes/optim/adamw.py     |  6 +++---
 bitsandbytes/optim/optimizer.py | 32 ++++++++++++++++++++++----------
 2 files changed, 25 insertions(+), 13 deletions(-)

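Note (not part of the commit message or the diff below): a minimal,
self-contained sketch of the dispatch that update_step() performs after this
patch. The helper name resolve_param_and_grad and the example shapes are
illustrative only, not part of the bitsandbytes API.

    from dataclasses import dataclass

    import torch


    @dataclass
    class GaLoreWrappedParameter:
        p: torch.Tensor     # real parameter; optimizer state stays keyed on this
        grad: torch.Tensor  # gradient to apply, decoupled from p.grad


    def resolve_param_and_grad(p):
        # Mirrors the isinstance() check added to update_step() in this patch.
        if isinstance(p, GaLoreWrappedParameter):
            return p.p, p.grad.contiguous()
        return p, p.grad.contiguous()


    param = torch.nn.Parameter(torch.randn(8, 8))
    param.grad = torch.randn(8, 8)
    projected = torch.randn(8, 2)  # stand-in for a GaLore-projected gradient

    p_out, g_out = resolve_param_and_grad(GaLoreWrappedParameter(p=param, grad=projected))
    assert p_out is param           # state lookup still uses the real parameter
    assert g_out is not param.grad  # the update sees the projected gradient instead
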
diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py
index c81e8ca61..a3c6402da 100644
--- a/bitsandbytes/optim/adamw.py
+++ b/bitsandbytes/optim/adamw.py
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 import torch
 
-from bitsandbytes.optim.optimizer import Optimizer2State
+from bitsandbytes.optim.optimizer import GaLoreWrappedParameter, Optimizer2State
 
 _galore_available = False
 try:
@@ -220,7 +220,6 @@ def step(self, closure=None):
                     lor_update = torch.zeros_like(
                         grad, dtype=p.data.dtype, device=p.data.device, requires_grad=grad.requires_grad
                     )
-                    p.grad = grad
 
                 if "state1" not in state:
                     self.init_state(group, p, gindex, pindex)
@@ -228,7 +227,8 @@ def step(self, closure=None):
                 self.prefetch_state(p)
 
                 if "rank" in group:
-                    self.update_step(group, p, gindex, pindex, return_updates=lor_update)
+                    galore_p = GaLoreWrappedParameter(p=p, grad=grad)
+                    self.update_step(group, galore_p, gindex, pindex, return_updates=lor_update)
 
                     # GaLore Projection Back
                     p.data.add_(state["projector"].project_back(lor_update))
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 3b21f09d2..2c0f77295 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -4,8 +4,9 @@
 # LICENSE file in the root directory of this source tree.
 from collections import abc as container_abcs, defaultdict
 from copy import deepcopy
+from dataclasses import dataclass
 from itertools import chain
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 
 import torch
 
@@ -18,6 +19,12 @@ def __init__(self, initial_data):
             setattr(self, key, initial_data[key])
 
 
+@dataclass
+class GaLoreWrappedParameter:
+    p: torch.Tensor
+    grad: torch.Tensor
+
+
 class GlobalOptimManager:
     """
     A global optimizer manager for enabling custom optimizer configs.
@@ -497,17 +504,22 @@ def init_state(self, group, p, gindex, pindex):
     def update_step(
         self,
         group: Dict[str, Any],
-        p: torch.Tensor,
+        p: Union[torch.Tensor, GaLoreWrappedParameter],
         gindex: int,
         pindex: int,
         return_updates: Optional[torch.Tensor] = None,
     ):
-        # avoid update error from non-contiguous memory layout
-        p.data = p.data.contiguous()
-        p.grad = p.grad.contiguous()
+        if isinstance(p, GaLoreWrappedParameter):
+            # Unwrap for GaLore
+            param_to_optimize = p.p
+        else:
+            param_to_optimize = p
 
-        state = self.state[p]
-        grad = p.grad
+        state = self.state[param_to_optimize]
+
+        # avoid update error from non-contiguous memory layout
+        param_to_optimize.data = param_to_optimize.data.contiguous()
+        grad = p.grad.contiguous()
 
         config = self.get_config(gindex, pindex, group)
 
 
@@ -528,7 +540,7 @@ def update_step(
             F.optimizer_update_32bit(
                 self.optimizer_name,
                 grad,
-                p,
+                param_to_optimize,
                 state["state1"],
                 config["betas"][0],
                 config["eps"],
@@ -550,7 +562,7 @@ def update_step(
             F.optimizer_update_8bit(
                 self.optimizer_name,
                 grad,
-                p,
+                param_to_optimize,
                 state["state1"],
                 state["state2"],
                 config["betas"][0],
@@ -578,7 +590,7 @@ def update_step(
             F.optimizer_update_8bit_blockwise(
                 self.optimizer_name,
                 grad,
-                p,
+                param_to_optimize,
                 state["state1"],
                 state["state2"],
                 config["betas"][0],