From 4c6793a8af37091d465184d6cc3552785c4a6f17 Mon Sep 17 00:00:00 2001
From: Jan Bujak <j@exia.io>
Date: Sat, 31 Aug 2024 10:58:06 +0000
Subject: [PATCH] Add `move_to_device` kwarg to the optimizer's
 `load_state_dict`

This makes it possible to load an optimizer checkpoint without
automatically moving the optimizer's state to the GPU.
---
 bitsandbytes/optim/optimizer.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index e9c857d49..f6fcd171d 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -153,12 +153,14 @@ def fill_qmap(self):
     def __setstate__(self, state):
         super().__setstate__(state)
 
-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict, move_to_device=True):
         """Load an optimizer state.
 
         Arguments:
             state_dict (`dict`):
                 An optimizer state (should be returned from a call to `state_dict`) to load.
+            move_to_device (`bool`, defaults to `True`):
+                Whether to move the optimizer's state to the parameters' device.
         """
         # deepcopy, to be consistent with module API
         state_dict = deepcopy(state_dict)
@@ -195,7 +197,8 @@ def cast(param, value):
             elif isinstance(value, dict):
                 for k, v in value.items():
                     if k in self.non_castable_tensor_keys:
-                        value[k] = v.to(param.device)
+                        if move_to_device:
+                            value[k] = v.to(param.device)
                     else:
                         value[k] = cast(param, v)
 
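
For reviewers, a minimal usage sketch of the new kwarg. The model, optimizer choice, and checkpoint path below are hypothetical and not part of this patch; any `bitsandbytes` optimizer inheriting this `load_state_dict` behaves the same way.

```python
import torch
import bitsandbytes as bnb

# Hypothetical model and optimizer for illustration.
model = torch.nn.Linear(1024, 1024)
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

# A checkpoint previously saved with `torch.save(optimizer.state_dict(), ...)`.
state = torch.load("optimizer.pt", map_location="cpu")

# Default behaviour, unchanged: non-castable state tensors are moved to each
# parameter's device while loading.
optimizer.load_state_dict(state)

# New: don't move the non-castable state tensors (e.g. the quantization maps)
# to each parameter's device; they stay where `torch.load` placed them.
optimizer.load_state_dict(state, move_to_device=False)
```

Defaulting `move_to_device` to `True` preserves the existing behaviour, so the new argument is backwards compatible for current callers.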