@@ -539,6 +539,7 @@ def interrupt(self):
539539 def __call__ (
540540 self ,
541541 image : PipelineImageInput ,
542+ last_image : Optional [PipelineImageInput ] = None ,
542543 prompt : Union [str , List [str ]] = None ,
543544 prompt_2 : Union [str , List [str ]] = None ,
544545 negative_prompt : Union [str , List [str ]] = None ,
@@ -554,6 +555,7 @@ def __call__(
554555 num_videos_per_prompt : Optional [int ] = 1 ,
555556 generator : Optional [Union [torch .Generator , List [torch .Generator ]]] = None ,
556557 image_latents : Optional [torch .Tensor ] = None ,
558+ last_image_latents : Optional [torch .Tensor ] = None ,
557559 prompt_embeds : Optional [torch .Tensor ] = None ,
558560 pooled_prompt_embeds : Optional [torch .Tensor ] = None ,
559561 prompt_attention_mask : Optional [torch .Tensor ] = None ,
@@ -574,6 +576,11 @@ def __call__(
574576 The call function to the pipeline for generation.
575577
576578 Args:
579+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
580+ The image to be used as the starting point for the video generation.
581+ last_image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`, *optional*):
582+ The optional last image to be used as the ending point for the video generation. This is useful for
583+ generating transitions between two images.
577584 prompt (`str` or `List[str]`, *optional*):
578585 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
579586 instead.
@@ -616,7 +623,9 @@ def __call__(
616623 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
617624 generation deterministic.
618625 image_latents (`torch.Tensor`, *optional*):
619- Pre-encoded image latents.
626+ Pre-encoded image latents. If not provided, the image will be encoded using the VAE.
627+ last_image_latents (`torch.Tensor`, *optional*):
628+ Pre-encoded last image latents. If not provided, the last image will be encoded using the VAE.
620629 prompt_embeds (`torch.Tensor`, *optional*):
621630 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
622631 provided, text embeddings are generated from the `prompt` input argument.
@@ -734,6 +743,12 @@ def __call__(
734743 # 4. Prepare image
735744 image = self .video_processor .preprocess (image , height , width )
736745 image_embeds = self .encode_image (image , device = device ).to (transformer_dtype )
746+ if last_image is not None :
747+ # Credits: https://github.com/lllyasviel/FramePack/pull/167
748+ # Users can modify the weighting strategy applied here
749+ last_image = self .video_processor .preprocess (last_image , height , width )
750+ last_image_embeds = self .encode_image (last_image , device = device ).to (transformer_dtype )
751+ last_image_embeds = (image_embeds + last_image_embeds ) / 2
737752
738753 # 5. Prepare latent variables
739754 num_channels_latents = self .transformer .config .in_channels
@@ -757,6 +772,10 @@ def __call__(
757772 image_latents = self .prepare_image_latents (
758773 image , dtype = torch .float32 , device = device , generator = generator , latents = image_latents
759774 )
775+ if last_image is not None :
776+ last_image_latents = self .prepare_image_latents (
777+ last_image , dtype = torch .float32 , device = device , generator = generator
778+ )
760779
761780 latent_paddings = list (reversed (range (num_latent_sections )))
762781 if num_latent_sections > 4 :
@@ -767,7 +786,8 @@ def __call__(
767786
768787 # 7. Denoising loop
769788 for k in range (num_latent_sections ):
770- is_last_section = latent_paddings [k ] == 0
789+ is_first_section = k == 0
790+ is_last_section = k == num_latent_sections - 1
771791 latent_padding_size = latent_paddings [k ] * latent_window_size
772792
773793 indices = torch .arange (0 , sum ([1 , latent_padding_size , latent_window_size , * history_sizes ]))
@@ -786,6 +806,8 @@ def __call__(
786806 latents_postfix , latents_history_2x , latents_history_4x = history_latents [
787807 :, :, : sum (history_sizes )
788808 ].split (history_sizes , dim = 2 )
809+ if last_image is not None and is_first_section :
810+ latents_postfix = last_image_latents
789811 latents_clean = torch .cat ([latents_prefix , latents_postfix ], dim = 2 )
790812
791813 latents = self .prepare_latents (
0 commit comments