
Commit 0070763 (1 parent: ccf593e)

add last_image support; credits: lllyasviel/FramePack#167

File tree: 1 file changed (+24, −2 lines)

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py

Lines changed: 24 additions & 2 deletions
@@ -539,6 +539,7 @@ def interrupt(self):
     def __call__(
         self,
         image: PipelineImageInput,
+        last_image: Optional[PipelineImageInput] = None,
         prompt: Union[str, List[str]] = None,
         prompt_2: Union[str, List[str]] = None,
         negative_prompt: Union[str, List[str]] = None,
@@ -554,6 +555,7 @@ def __call__(
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image_latents: Optional[torch.Tensor] = None,
+        last_image_latents: Optional[torch.Tensor] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
         pooled_prompt_embeds: Optional[torch.Tensor] = None,
         prompt_attention_mask: Optional[torch.Tensor] = None,
@@ -574,6 +576,11 @@ def __call__(
         The call function to the pipeline for generation.
 
         Args:
+            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
+                The image to be used as the starting point for the video generation.
+            last_image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`, *optional*):
+                The optional last image to be used as the ending point for the video generation. This is useful for
+                generating transitions between two images.
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                 instead.
@@ -616,7 +623,9 @@ def __call__(
                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                 generation deterministic.
             image_latents (`torch.Tensor`, *optional*):
-                Pre-encoded image latents.
+                Pre-encoded image latents. If not provided, the image will be encoded using the VAE.
+            last_image_latents (`torch.Tensor`, *optional*):
+                Pre-encoded last image latents. If not provided, the last image will be encoded using the VAE.
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                 provided, text embeddings are generated from the `prompt` input argument.
@@ -734,6 +743,12 @@ def __call__(
         # 4. Prepare image
         image = self.video_processor.preprocess(image, height, width)
         image_embeds = self.encode_image(image, device=device).to(transformer_dtype)
+        if last_image is not None:
+            # Credits: https://github.com/lllyasviel/FramePack/pull/167
+            # Users can modify the weighting strategy applied here
+            last_image = self.video_processor.preprocess(last_image, height, width)
+            last_image_embeds = self.encode_image(last_image, device=device).to(transformer_dtype)
+            last_image_embeds = (image_embeds + last_image_embeds) / 2
 
         # 5. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
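
Note: the in-code comment above invites users to change the weighting strategy. As an illustration only (the `alpha` weight below is hypothetical and not part of this commit), the fixed 50/50 average of the two image embeddings could be generalized to a linear blend:

    # illustrative sketch, not in this commit: a tunable blend instead of the fixed average
    alpha = 0.5  # hypothetical weight; alpha = 1.0 keeps only the first image's embedding
    last_image_embeds = alpha * image_embeds + (1.0 - alpha) * last_image_embeds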
@@ -757,6 +772,10 @@ def __call__(
         image_latents = self.prepare_image_latents(
             image, dtype=torch.float32, device=device, generator=generator, latents=image_latents
         )
+        if last_image is not None:
+            last_image_latents = self.prepare_image_latents(
+                last_image, dtype=torch.float32, device=device, generator=generator
+            )
 
         latent_paddings = list(reversed(range(num_latent_sections)))
         if num_latent_sections > 4:
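
For context, the schedule driving the loop below is a reversed range (the `num_latent_sections > 4` branch is truncated in this hunk). A quick sketch of what it evaluates to:

    # illustration: padding schedule for a three-section generation
    num_latent_sections = 3
    latent_paddings = list(reversed(range(num_latent_sections)))  # [2, 1, 0]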
@@ -767,7 +786,8 @@ def __call__(
 
         # 7. Denoising loop
         for k in range(num_latent_sections):
-            is_last_section = latent_paddings[k] == 0
+            is_first_section = k == 0
+            is_last_section = k == num_latent_sections - 1
             latent_padding_size = latent_paddings[k] * latent_window_size
 
             indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, *history_sizes]))
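
Note: FramePack samples sections in inverted temporal order, so the first loop iteration (`k == 0`) generates the tail of the video; the new `is_first_section` flag is what lets the last-image latents be anchored there in the final hunk. The rewritten `is_last_section` check is equivalent to the old `latent_paddings[k] == 0` test whenever the schedule reaches zero only at its final entry.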
@@ -786,6 +806,8 @@ def __call__(
             latents_postfix, latents_history_2x, latents_history_4x = history_latents[
                 :, :, : sum(history_sizes)
             ].split(history_sizes, dim=2)
+            if last_image is not None and is_first_section:
+                latents_postfix = last_image_latents
             latents_clean = torch.cat([latents_prefix, latents_postfix], dim=2)
 
             latents = self.prepare_latents(
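
For reference, a minimal usage sketch of the new parameter. This is an assumption-laden illustration rather than documentation from this commit: it assumes `pipe` is an already-loaded instance of the FramePack pipeline patched above, and the file names and fps are arbitrary; `load_image` and `export_to_video` are standard diffusers utilities.

    from diffusers.utils import export_to_video, load_image

    first_frame = load_image("start.png")  # starting point of the video
    last_frame = load_image("end.png")     # new: optional ending point

    output = pipe(
        image=first_frame,
        last_image=last_frame,  # added by this commit
        prompt="a smooth transition between the two scenes",
    ).frames[0]
    export_to_video(output, "transition.mp4", fps=30)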
