diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py
index c81e8ca61..a3c6402da 100644
--- a/bitsandbytes/optim/adamw.py
+++ b/bitsandbytes/optim/adamw.py
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 import torch

-from bitsandbytes.optim.optimizer import Optimizer2State
+from bitsandbytes.optim.optimizer import GaLoreWrappedParameter, Optimizer2State

 _galore_available = False
 try:
@@ -220,7 +220,6 @@ def step(self, closure=None):
                     lor_update = torch.zeros_like(
                         grad, dtype=p.data.dtype, device=p.data.device, requires_grad=grad.requires_grad
                     )
-                    p.grad = grad

                 if "state1" not in state:
                     self.init_state(group, p, gindex, pindex)
@@ -228,7 +227,8 @@ def step(self, closure=None):
                 self.prefetch_state(p)

                 if "rank" in group:
-                    self.update_step(group, p, gindex, pindex, return_updates=lor_update)
+                    galore_p = GaLoreWrappedParameter(p=p, grad=grad)
+                    self.update_step(group, galore_p, gindex, pindex, return_updates=lor_update)

                     # GaLore Projection Back
                     p.data.add_(state["projector"].project_back(lor_update))
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 3b21f09d2..2c0f77295 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -4,8 +4,9 @@
 # LICENSE file in the root directory of this source tree.
 from collections import abc as container_abcs, defaultdict
 from copy import deepcopy
+from dataclasses import dataclass
 from itertools import chain
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union

 import torch

@@ -18,6 +19,12 @@ def __init__(self, initial_data):
             setattr(self, key, initial_data[key])


+@dataclass
+class GaLoreWrappedParameter:
+    p: torch.Tensor
+    grad: torch.Tensor
+
+
 class GlobalOptimManager:
     """
     A global optimizer manager for enabling custom optimizer configs.
@@ -497,17 +504,22 @@ def init_state(self, group, p, gindex, pindex):
     def update_step(
         self,
         group: Dict[str, Any],
-        p: torch.Tensor,
+        p: Union[torch.Tensor, GaLoreWrappedParameter],
         gindex: int,
         pindex: int,
         return_updates: Optional[torch.Tensor] = None,
     ):
-        # avoid update error from non-contiguous memory layout
-        p.data = p.data.contiguous()
-        p.grad = p.grad.contiguous()
+        if isinstance(p, GaLoreWrappedParameter):
+            # Unwrap for GaLore
+            param_to_optimize = p.p
+        else:
+            param_to_optimize = p

-        state = self.state[p]
-        grad = p.grad
+        state = self.state[param_to_optimize]
+
+        # avoid update error from non-contiguous memory layout
+        param_to_optimize.data = param_to_optimize.data.contiguous()
+        grad = p.grad.contiguous()

         config = self.get_config(gindex, pindex, group)

@@ -528,7 +540,7 @@ def update_step(
             F.optimizer_update_32bit(
                 self.optimizer_name,
                 grad,
-                p,
+                param_to_optimize,
                 state["state1"],
                 config["betas"][0],
                 config["eps"],
@@ -550,7 +562,7 @@ def update_step(
             F.optimizer_update_8bit(
                 self.optimizer_name,
                 grad,
-                p,
+                param_to_optimize,
                 state["state1"],
                 state["state2"],
                 config["betas"][0],
@@ -578,7 +590,7 @@ def update_step(
             F.optimizer_update_8bit_blockwise(
                 self.optimizer_name,
                 grad,
-                p,
+                param_to_optimize,
                 state["state1"],
                 state["state2"],
                 config["betas"][0],
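
To make the intent of the patch concrete, here is a minimal, self-contained sketch (illustrative only, not part of the diff: the projector, the plain-SGD update, and the helper name update_step_sketch are stand-ins). It mirrors how GaLoreWrappedParameter is meant to flow through update_step: optimizer state stays keyed on the real parameter p, the update is computed against the projected low-rank gradient, and the caller projects the returned update back onto the full-rank weights, so p.grad itself is never overwritten.

    # Illustrative sketch only, not part of the patch. The projector and the
    # plain-SGD update below are stand-ins for bitsandbytes' 8-bit kernels.
    from dataclasses import dataclass

    import torch


    @dataclass
    class GaLoreWrappedParameter:  # mirrors the dataclass added in optimizer.py
        p: torch.Tensor
        grad: torch.Tensor


    def update_step_sketch(state, p, lr=1e-3):
        """Toy stand-in for Optimizer8bit.update_step."""
        # Unwrap exactly as the patch does: state keys on the real parameter,
        # while the (possibly projected) gradient comes from the wrapper.
        param = p.p if isinstance(p, GaLoreWrappedParameter) else p
        grad = p.grad.contiguous()
        state.setdefault(param, {"step": 0})
        state[param]["step"] += 1
        return -lr * grad  # plays the role of the returned low-rank update


    # Hypothetical usage mirroring the "rank" branch of GaLoreAdamW8bit.step
    p = torch.randn(8, 8, requires_grad=True)
    p.grad = torch.randn(8, 8)

    proj = torch.linalg.qr(torch.randn(8, 2)).Q  # stand-in low-rank projector
    low_rank_grad = p.grad @ proj                # "project": full grad -> rank-2 grad
    state = {}

    galore_p = GaLoreWrappedParameter(p=p, grad=low_rank_grad)
    lor_update = update_step_sketch(state, galore_p)

    with torch.no_grad():
        p.add_(lor_update @ proj.T)              # "project back" onto the full shape

Keying state on the unwrapped parameter is what lets the patch drop the earlier `p.grad = grad` assignment: the projected gradient now travels alongside p in the wrapper instead of replacing its .grad.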