diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 47f1f4199615..913499f9d3e3 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -126,8 +126,6 @@ def __init__(
 
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
-        if self.stream is None:
-            return cpu_param_dict
 
         for module in self.modules:
             for param in module.parameters():
@@ -239,24 +237,16 @@ def _offload_to_disk(self):
                 tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
 
     def _offload_to_memory(self):
-        if self.stream is not None:
-            if not self.record_stream:
-                self._torch_accelerator_module.current_stream().synchronize()
+        if not self.record_stream:
+            self._torch_accelerator_module.current_stream().synchronize()
 
-            for group_module in self.modules:
-                for param in group_module.parameters():
-                    param.data = self.cpu_param_dict[param]
-            for param in self.parameters:
+        for group_module in self.modules:
+            for param in group_module.parameters():
                 param.data = self.cpu_param_dict[param]
-            for buffer in self.buffers:
-                buffer.data = self.cpu_param_dict[buffer]
-        else:
-            for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=False)
-            for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=False)
-            for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
+        for param in self.parameters:
+            param.data = self.cpu_param_dict[param]
+        for buffer in self.buffers:
+            buffer.data = self.cpu_param_dict[buffer]
 
     @torch.compiler.disable()
     def onload_(self):
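
A minimal standalone sketch of the pattern this change moves to, not the diffusers implementation: the class and helper names below are illustrative, and only the structure mirrors the diff. The CPU parameter dictionary is built unconditionally (no early return when no stream is set), and offloading repoints param.data at a pre-allocated CPU copy instead of branching between a dict-based path and a blocking .to() path.

import torch
import torch.nn as nn


class GroupOffloadSketch:
    """Illustrative sketch: pre-allocated pinned CPU copies plus pointer-swap offload."""

    def __init__(self, modules, record_stream=False):
        self.modules = modules
        self.record_stream = record_stream
        # Built for every group now -- no `if self.stream is None: return {}` early exit.
        self.cpu_param_dict = self._init_cpu_param_dict()

    def _init_cpu_param_dict(self):
        cpu_param_dict = {}
        for module in self.modules:
            for param in module.parameters():
                cpu_copy = param.data.detach().to("cpu")
                if torch.cuda.is_available():
                    # Pinned memory keeps non_blocking host/device copies possible.
                    cpu_copy = cpu_copy.pin_memory()
                cpu_param_dict[param] = cpu_copy
        return cpu_param_dict

    def offload_(self):
        if not self.record_stream and torch.cuda.is_available():
            # Ensure in-flight copies on the current stream finish before the
            # device-side tensors are released.
            torch.cuda.current_stream().synchronize()
        for module in self.modules:
            for param in module.parameters():
                # Repoint the parameter at its pre-allocated CPU copy instead of
                # allocating fresh CPU storage with .to() on every offload.
                param.data = self.cpu_param_dict[param]


if __name__ == "__main__":
    layer = nn.Linear(4, 4)
    group = GroupOffloadSketch([layer])
    group.offload_()  # the layer's weights now alias the pre-allocated CPU copies

In the sketch the dictionary holds the canonical CPU copies for the lifetime of the group, so repeated offload/onload cycles reuse the same buffers. The diff above applies the same lookup to the group's standalone self.parameters and self.buffers, reading their CPU copies out of the same cpu_param_dict.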