From 4c6793a8af37091d465184d6cc3552785c4a6f17 Mon Sep 17 00:00:00 2001
From: Jan Bujak <j@exia.io>
Date: Sat, 31 Aug 2024 10:58:06 +0000
Subject: [PATCH] Add `move_to_device` kwarg to the optimizer's
 `load_state_dict`

This makes it possible to load an optimizer checkpoint without
automatically moving the optimizer's state tensors to the device of
the parameters (typically the GPU).
---
 bitsandbytes/optim/optimizer.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index e9c857d49..f6fcd171d 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -153,12 +153,14 @@ def fill_qmap(self):
     def __setstate__(self, state):
         super().__setstate__(state)
 
-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict, move_to_device=True):
         """Load an optimizer state.
 
         Arguments:
             state_dict (`dict`):
                 An optimizer state (should be returned from a call to `state_dict`) to load.
+            move_to_device (`bool`, defaults to `True`):
+                Whether to move the optimizer's state tensors to the device of the corresponding parameters.
         """
         # deepcopy, to be consistent with module API
         state_dict = deepcopy(state_dict)
@@ -195,7 +197,8 @@ def cast(param, value):
             elif isinstance(value, dict):
                 for k, v in value.items():
                     if k in self.non_castable_tensor_keys:
-                        value[k] = v.to(param.device)
+                        if move_to_device:
+                            value[k] = v.to(param.device)
                     else:
                         value[k] = cast(param, v)