
Commit fa1bdce

a-r-r-o-w, tolgacangoz, and stevhliu authored
[docs] Improve SVD pipeline docs (#7087)
* update svd docs
* fix example doc string
* update return type hints/docs
* update type hints
* Fix typos in pipeline_stable_video_diffusion.py
* make style && make fix-copies
* Update src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py

  Co-authored-by: Steven Liu <[email protected]>
* Update src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py

  Co-authored-by: Steven Liu <[email protected]>
* update based on suggestion

---------

Co-authored-by: M. Tolga Cangöz <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
1 parent ca6cdc7 commit fa1bdce

File tree

1 file changed (+77, -67 lines)


src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py

Lines changed: 77 additions & 67 deletions
@@ -21,16 +21,33 @@
 import torch
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
 from ...schedulers import EulerDiscreteScheduler
-from ...utils import BaseOutput, logging
+from ...utils import BaseOutput, logging, replace_example_docstring
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> from diffusers import StableVideoDiffusionPipeline
+        >>> from diffusers.utils import load_image, export_to_video
+
+        >>> pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
+        >>> pipe.to("cuda")
+
+        >>> image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd-docstring-example.jpeg")
+        >>> image = image.resize((1024, 576))
+
+        >>> frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
+        >>> export_to_video(frames, "generated.mp4", fps=7)
+        ```
+"""
+
 
 def _append_dims(x, target_dims):
     """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
@@ -41,7 +58,7 @@ def _append_dims(x, target_dims):
 
 
 # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
-def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
+def tensor2vid(video: torch.Tensor, processor: VaeImageProcessor, output_type: str = "np"):
     batch_size, channels, num_frames, height, width = video.shape
     outputs = []
     for batch_idx in range(batch_size):
@@ -65,15 +82,15 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
 @dataclass
 class StableVideoDiffusionPipelineOutput(BaseOutput):
     r"""
-    Output class for zero-shot text-to-video pipeline.
+    Output class for Stable Video Diffusion pipeline.
 
     Args:
-        frames (`[List[PIL.Image.Image]`, `np.ndarray`]):
-            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
-            num_channels)`.
+        frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `torch.FloatTensor`]):
+            List of denoised PIL images of length `batch_size` or numpy array or torch tensor
+            of shape `(batch_size, num_frames, height, width, num_channels)`.
     """
 
-    frames: Union[List[PIL.Image.Image], np.ndarray]
+    frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.FloatTensor]
 
 
 class StableVideoDiffusionPipeline(DiffusionPipeline):
@@ -119,7 +136,13 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
-    def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance):
+    def _encode_image(
+        self,
+        image: PipelineImageInput,
+        device: Union[str, torch.device],
+        num_videos_per_prompt: int,
+        do_classifier_free_guidance: bool,
+    ) -> torch.FloatTensor:
         dtype = next(self.image_encoder.parameters()).dtype
 
         if not isinstance(image, torch.Tensor):
@@ -164,9 +187,9 @@ def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free
     def _encode_vae_image(
         self,
         image: torch.Tensor,
-        device,
-        num_videos_per_prompt,
-        do_classifier_free_guidance,
+        device: Union[str, torch.device],
+        num_videos_per_prompt: int,
+        do_classifier_free_guidance: bool,
     ):
         image = image.to(device=device)
         image_latents = self.vae.encode(image).latent_dist.mode()
@@ -186,13 +209,13 @@ def _encode_vae_image(
 
     def _get_add_time_ids(
         self,
-        fps,
-        motion_bucket_id,
-        noise_aug_strength,
-        dtype,
-        batch_size,
-        num_videos_per_prompt,
-        do_classifier_free_guidance,
+        fps: int,
+        motion_bucket_id: int,
+        noise_aug_strength: float,
+        dtype: torch.dtype,
+        batch_size: int,
+        num_videos_per_prompt: int,
+        do_classifier_free_guidance: bool,
     ):
         add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
 
@@ -212,7 +235,7 @@ def _get_add_time_ids(
 
         return add_time_ids
 
-    def decode_latents(self, latents, num_frames, decode_chunk_size=14):
+    def decode_latents(self, latents: torch.FloatTensor, num_frames: int, decode_chunk_size: int = 14):
         # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
         latents = latents.flatten(0, 1)
 
@@ -257,15 +280,15 @@ def check_inputs(self, image, height, width):
 
     def prepare_latents(
         self,
-        batch_size,
-        num_frames,
-        num_channels_latents,
-        height,
-        width,
-        dtype,
-        device,
-        generator,
-        latents=None,
+        batch_size: int,
+        num_frames: int,
+        num_channels_latents: int,
+        height: int,
+        width: int,
+        dtype: torch.dtype,
+        device: Union[str, torch.device],
+        generator: torch.Generator,
+        latents: Optional[torch.FloatTensor] = None,
     ):
         shape = (
             batch_size,
@@ -307,6 +330,7 @@ def num_timesteps(self):
         return self._num_timesteps
 
     @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
@@ -333,15 +357,16 @@ def __call__(
 
         Args:
             image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
-                Image or images to guide image generation. If you provide a tensor, the expected value range is between `[0,1]`.
+                Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0, 1]`.
             height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The width in pixels of the generated image.
             num_frames (`int`, *optional*):
-                The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`
+                The number of video frames to generate. Defaults to `self.unet.config.num_frames`
+                (14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
             num_inference_steps (`int`, *optional*, defaults to 25):
-                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
                 expense of slower inference. This parameter is modulated by `strength`.
             min_guidance_scale (`float`, *optional*, defaults to 1.0):
                 The minimum guidance scale. Used for the classifier free guidance with first frame.
@@ -351,29 +376,29 @@ def __call__(
                 Frames per second. The rate at which the generated images shall be exported to a video after generation.
                 Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
             motion_bucket_id (`int`, *optional*, defaults to 127):
-                The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
+                Used for conditioning the amount of motion for the generation. The higher the number the more motion
+                will be in the video.
             noise_aug_strength (`float`, *optional*, defaults to 0.02):
                 The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
             decode_chunk_size (`int`, *optional*):
-                The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
-                between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once
-                for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
+                The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the expense of more memory usage. By default, the decoder decodes all frames at once for maximal
+                quality. For lower memory usage, reduce `decode_chunk_size`.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
-                The number of images to generate per prompt.
+                The number of videos to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                 generation deterministic.
             latents (`torch.FloatTensor`, *optional*):
-                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor is generated by sampling using the supplied random `generator`.
             output_type (`str`, *optional*, defaults to `"pil"`):
-                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+                The output format of the generated image. Choose between `pil`, `np` or `pt`.
             callback_on_step_end (`Callable`, *optional*):
-                A function that calls at the end of each denoising steps during the inference. The function is called
-                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
-                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
-                `callback_on_step_end_tensor_inputs`.
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments:
+                `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
+                `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
             callback_on_step_end_tensor_inputs (`List`, *optional*):
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
@@ -382,26 +407,12 @@ def __call__(
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
 
+        Examples:
+
         Returns:
             [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
                 If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
-                otherwise a `tuple` is returned where the first element is a list of list with the generated frames.
-
-        Examples:
-
-        ```py
-        from diffusers import StableVideoDiffusionPipeline
-        from diffusers.utils import load_image, export_to_video
-
-        pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
-        pipe.to("cuda")
-
-        image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
-        image = image.resize((1024, 576))
-
-        frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
-        export_to_video(frames, "generated.mp4", fps=7)
-        ```
+                otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.FloatTensor`) is returned.
         """
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
@@ -429,8 +440,7 @@ def __call__(
         # 3. Encode input image
         image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
 
-        # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
-        # is why it is reduced here.
+        # NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here.
         # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
         fps = fps - 1
 
@@ -471,11 +481,11 @@ def __call__(
         )
         added_time_ids = added_time_ids.to(device)
 
-        # 4. Prepare timesteps
+        # 6. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
 
-        # 5. Prepare latent variables
+        # 7. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
@@ -489,15 +499,15 @@ def __call__(
             latents,
         )
 
-        # 7. Prepare guidance scale
+        # 8. Prepare guidance scale
         guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
         guidance_scale = guidance_scale.to(device, latents.dtype)
         guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
         guidance_scale = _append_dims(guidance_scale, latents.ndim)
 
         self._guidance_scale = guidance_scale
 
-        # 8. Denoising loop
+        # 9. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -506,7 +516,7 @@ def __call__(
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
-                # Concatenate image_latents over channels dimention
+                # Concatenate image_latents over channels dimension
                 latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
 
                 # predict the noise residual
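
For reference, the example added to `EXAMPLE_DOC_STRING` can be run as a standalone script. The sketch below reproduces the documented snippet; `import torch` is added here because the docstring lines reference `torch.float16` without importing it, and running it assumes a CUDA device with enough memory.

```py
# Standalone version of the snippet added in EXAMPLE_DOC_STRING above.
# `import torch` is added here; the docstring example assumes torch is already imported.
import torch

from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import export_to_video, load_image

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.to("cuda")

# The conditioning image is resized to 1024x576, matching the documented example.
image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd-docstring-example.jpeg"
)
image = image.resize((1024, 576))

# decode_chunk_size trades memory for temporal consistency, as described in the updated docstring.
frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```

The revised docstring also spells out the `callback_on_step_end` signature. A minimal, hypothetical hook consistent with that description could look like the following; `log_step` is an illustrative name, and the tensors available in `callback_kwargs` are limited to those requested via `callback_on_step_end_tensor_inputs`.

```py
# Hypothetical per-step hook matching the documented signature:
# callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)
def log_step(pipeline, step, timestep, callback_kwargs):
    print(f"finished denoising step {step} (timestep {timestep})")
    return callback_kwargs  # return the (possibly modified) tensor dict

frames = pipe(
    image,
    decode_chunk_size=8,
    callback_on_step_end=log_step,
    callback_on_step_end_tensor_inputs=["latents"],
).frames[0]
```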
