Commit ccf593e

fix for model CPU offloading

1 parent 4221738 commit ccf593e

File tree

1 file changed (+6, -4 lines)


src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py

Lines changed: 6 additions & 4 deletions
@@ -206,7 +206,7 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix
             [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds"]
 
     def __init__(
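
Note on the first hunk: enable_model_cpu_offload() installs an offload hook on each component named in model_cpu_offload_seq, in order. Because image_encoder was missing from the sequence, it never received a hook and stayed wherever it was loaded, which breaks once the other components are being shuttled between CPU and GPU. A minimal sketch of the offload path, assuming a placeholder checkpoint id (not part of this commit):

import torch
from diffusers import HunyuanVideoFramepackPipeline

# Placeholder checkpoint id, for illustration only.
pipe = HunyuanVideoFramepackPipeline.from_pretrained(
    "your-org/your-framepack-checkpoint", torch_dtype=torch.bfloat16
)

# Moves each component in model_cpu_offload_seq to the GPU right before it
# runs and back to the CPU afterwards; with this commit that now includes
# image_encoder.
pipe.enable_model_cpu_offload()
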
@@ -386,12 +386,13 @@ def encode_prompt(
     def encode_image(
         self, image: torch.Tensor, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
     ):
+        device = device or self._execution_device
         image = (image + 1) / 2.0  # [-1, 1] -> [0, 1]
         image = self.feature_extractor(images=image, return_tensors="pt", do_rescale=False).to(
-            device=self.image_encoder.device, dtype=self.image_encoder.dtype
+            device=device, dtype=self.image_encoder.dtype
         )
         image_embeds = self.image_encoder(**image).last_hidden_state
-        return image_embeds.to(device=device, dtype=dtype)
+        return image_embeds.to(dtype=dtype)
 
     def check_inputs(
         self,
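
Note on the second hunk: with offloading active, self.image_encoder.device can report cpu between forward passes, so casting inputs to it sends them to the wrong place. The pipeline's _execution_device resolves the device the hooks will actually compute on, and the final .to(dtype=dtype) no longer moves devices because the embeddings already live there. A simplified sketch of the lookup, not the exact diffusers implementation:

import torch

def execution_device(pipe) -> torch.device:
    # Offloaded modules sit on the CPU, but their accelerate hook records
    # the device they will be moved to for compute.
    for module in pipe.components.values():
        if not isinstance(module, torch.nn.Module):
            continue
        hook = getattr(module, "_hf_hook", None)
        if hook is not None and getattr(hook, "execution_device", None) is not None:
            return torch.device(hook.execution_device)
    return pipe.device
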
@@ -477,8 +478,9 @@ def prepare_image_latents(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        device = device or self._execution_device
         if latents is None:
-            image = image.unsqueeze(2).to(device=self.vae.device, dtype=self.vae.dtype)
+            image = image.unsqueeze(2).to(device=device, dtype=self.vae.dtype)
             latents = self.vae.encode(image).latent_dist.sample(generator=generator)
             latents = latents * self.vae.config.scaling_factor
         return latents.to(device=device, dtype=dtype)
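
Note on the third hunk: the same reasoning applies to the VAE, whose .device can likewise be cpu mid-offload, so prepare_image_latents now resolves the execution device once and uses it for the encode. Continuing the loading sketch above, a quick smoke check (the input shape is an illustrative assumption):

image = torch.rand(1, 3, 512, 512) * 2.0 - 1.0   # encode_image expects [-1, 1]
image_embeds = pipe.encode_image(image, dtype=torch.bfloat16)
print(image_embeds.device)  # the execution device, e.g. cuda:0, not cpu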
