Add move_to_device kwarg to the optimizer's load_state_dict (bitsandbytes-foundation#1344)

This makes it possible to load an optimizer checkpoint without
automatically moving the optimizer's state to the GPU.
koute authored and matthewdouglas committed Oct 28, 2024
1 parent dc0f4c1 commit 5292aa4
Showing 1 changed file with 5 additions and 2 deletions: bitsandbytes/optim/optimizer.py
```diff
@@ -154,12 +154,14 @@ def fill_qmap(self):
     def __setstate__(self, state):
         super().__setstate__(state)

-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict, move_to_device=True):
         """Load an optimizer state.

         Arguments:
             state_dict (`dict`):
                 An optimizer state (should be returned from a call to `state_dict`) to load.
+            move_to_device (`bool`, defaults to `True`):
+                Whether to move the optimizer's state to the device.
         """
         # deepcopy, to be consistent with module API
         state_dict = deepcopy(state_dict)
@@ -196,7 +198,8 @@ def cast(param, value):
            elif isinstance(value, dict):
                for k, v in value.items():
                    if k in self.non_castable_tensor_keys:
-                        value[k] = v.to(param.device)
+                        if move_to_device:
+                            value[k] = v.to(param.device)
                    else:
                        value[k] = cast(param, v)
```
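The guard added in this commit can be sketched without a torch dependency. In the sketch below, `FakeTensor`, `NON_CASTABLE_KEYS`, and `load_value` are hypothetical stand-ins (not part of bitsandbytes) that mirror the changed logic: values under non-castable keys are moved to the parameter's device only when `move_to_device` is true, otherwise they are left where they are (e.g. on CPU).

```python
class FakeTensor:
    """Hypothetical stand-in for a torch.Tensor that only tracks its device."""

    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        # Like torch.Tensor.to, returns a tensor on the target device.
        return FakeTensor(device)


# Illustrative key set; the real optimizer keeps these in
# self.non_castable_tensor_keys.
NON_CASTABLE_KEYS = {"qmap1", "qmap2", "absmax1", "absmax2"}


def load_value(param, key, value, move_to_device=True):
    """Mirror of the commit's guard inside cast(): only move non-castable
    tensor values to param.device when move_to_device is set."""
    if key in NON_CASTABLE_KEYS:
        if move_to_device:
            return value.to(param.device)
        return value  # left on its current device, e.g. CPU
    return value


param = FakeTensor(device="cuda:0")
state = FakeTensor(device="cpu")

moved = load_value(param, "qmap1", state)                       # moved to "cuda:0"
kept = load_value(param, "qmap1", state, move_to_device=False)  # stays on "cpu"
```

With `move_to_device=False`, a checkpoint's optimizer state can be inspected or loaded without pulling the non-castable tensors (such as quantization maps) onto the GPU, which is the behavior the commit message describes.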
