Commit 9d641f5

feat(diffusers/pipelines): add wan pipeline (#1021)
1 parent 1a2dc78 commit 9d641f5

25 files changed, +4958 -6 lines changed

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLWan

The 3D variational autoencoder (VAE) model with KL loss used in [Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.

The model can be loaded with the following code snippet.

```python
import mindspore as ms

from mindone.diffusers import AutoencoderKLWan

vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", mindspore_dtype=ms.float32)
```

::: mindone.diffusers.AutoencoderKLWan

::: mindone.diffusers.models.autoencoders.vae.DecoderOutput

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# WanTransformer3DModel

A Diffusion Transformer model for 3D video-like data was introduced in [Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.

The model can be loaded with the following code snippet.

```python
import mindspore as ms

from mindone.diffusers import WanTransformer3DModel

transformer = WanTransformer3DModel.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="transformer", mindspore_dtype=ms.bfloat16)
```

::: mindone.diffusers.WanTransformer3DModel

::: mindone.diffusers.models.modeling_outputs.Transformer2DModelOutput

Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# Wan

<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>

[Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.

<!-- TODO(aryan): update abstract once paper is out -->

## Generating Videos with Wan 2.1

We will first need to install some additional dependencies.

```shell
pip install -U ftfy imageio-ffmpeg imageio
```

### Text to Video Generation

The following example requires 11GB VRAM to run and uses the smaller `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` model. You can switch it out
for the larger `Wan-AI/Wan2.1-T2V-14B-Diffusers` model if you have at least 35GB VRAM available.

```python
from mindone.diffusers import WanPipeline
from mindone.diffusers.utils import export_to_video
import mindspore as ms

# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

pipe = WanPipeline.from_pretrained(model_id, mindspore_dtype=ms.bfloat16)

prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33

frames = pipe(prompt=prompt, negative_prompt=negative_prompt, num_frames=num_frames)[0][0]
export_to_video(frames, "wan-t2v.mp4", fps=16)
```

!!! tip

    You can improve the quality of the generated video by running the decoding step in full precision.

    ```python
    from mindone.diffusers import WanPipeline, AutoencoderKLWan
    from mindone.diffusers.utils import export_to_video
    import mindspore as ms

    model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

    vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", mindspore_dtype=ms.float32)
    pipe = WanPipeline.from_pretrained(model_id, vae=vae, mindspore_dtype=ms.bfloat16)

    prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
    negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
    num_frames = 33

    frames = pipe(prompt=prompt, negative_prompt=negative_prompt, num_frames=num_frames)[0][0]
    export_to_video(frames, "wan-t2v.mp4", fps=16)
    ```

### Image to Video Generation

The Image to Video pipeline requires loading the `AutoencoderKLWan` and the `CLIPVisionModel` components in full precision. The following example will need at least 35GB of VRAM to run.

```python
import mindspore as ms
import numpy as np
from mindone.diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from mindone.diffusers.utils import export_to_video, load_image
from mindone.transformers import CLIPVisionModel

# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", mindspore_dtype=ms.float32
)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", mindspore_dtype=ms.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id, vae=vae, image_encoder=image_encoder, mindspore_dtype=ms.bfloat16
)

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)

max_area = 480 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))

prompt = (
    "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
    "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

num_frames = 33

output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    guidance_scale=5.0,
)[0][0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```

### Video to Video Generation

```python
import mindspore as ms
from mindone.diffusers.utils import load_video, export_to_video
from mindone.diffusers import AutoencoderKLWan, WanVideoToVideoPipeline, UniPCMultistepScheduler

# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(
    model_id, subfolder="vae", mindspore_dtype=ms.float32
)
pipe = WanVideoToVideoPipeline.from_pretrained(
    model_id, vae=vae, mindspore_dtype=ms.bfloat16
)
flow_shift = 3.0  # 5.0 for 720P, 3.0 for 480P
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config, flow_shift=flow_shift
)

prompt = "A robot standing on a mountain top. The sun is setting in the background"
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
video = load_video(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
)
output = pipe(
    video=video,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=512,
    guidance_scale=7.0,
    strength=0.7,
)[0][0]

export_to_video(output, "wan-v2v.mp4", fps=16)
```

## Using Single File Loading with Wan 2.1

The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading method.

```python
import mindspore as ms
from mindone.diffusers import WanPipeline, WanTransformer3DModel

ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
transformer = WanTransformer3DModel.from_single_file(ckpt_path, mindspore_dtype=ms.bfloat16)

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
```
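
The example above only shows the transformer. A minimal sketch of the same approach for the VAE follows; it assumes a hypothetical local checkpoint path rather than a verified file.

```python
# Minimal sketch, not from the docs above: AutoencoderKLWan also supports
# from_single_file, as stated in this section. The checkpoint path below is a
# hypothetical placeholder.
import mindspore as ms
from mindone.diffusers import AutoencoderKLWan, WanPipeline

vae_ckpt_path = "path/to/wan2.1_vae.safetensors"  # hypothetical path
vae = AutoencoderKLWan.from_single_file(vae_ckpt_path, mindspore_dtype=ms.float32)

# The single-file VAE can then be passed to the pipeline like any other component.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", vae=vae, mindspore_dtype=ms.bfloat16)
```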

## Recommendations for Inference

- Keep `AutoencoderKLWan` in `ms.float32` for better decoding quality.
- `num_frames` should satisfy the following constraint: `(num_frames - 1) % 4 == 0`
- For smaller resolution videos, try lower values of `shift` (between `2.0` and `5.0`) in the [Scheduler](https://mindspore-lab.github.io/mindone/latest/diffusers/api/schedulers/flow_match_euler_discrete/#mindone.diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan. A short sketch applying these recommendations is shown below.
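
Below is a minimal sketch, not taken from this commit's docs, combining the `num_frames` constraint with a `shift` override. It assumes the pipeline's scheduler config can be reused via `FlowMatchEulerDiscreteScheduler.from_config`; the `shift=5.0` value is an illustrative choice for a smaller-resolution run, not a tuned setting.

```python
# Minimal sketch (assumption: the scheduler config of the pretrained pipeline can be
# reused with FlowMatchEulerDiscreteScheduler.from_config; shift value is illustrative).
import mindspore as ms
from mindone.diffusers import FlowMatchEulerDiscreteScheduler, WanPipeline
from mindone.diffusers.utils import export_to_video

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", mindspore_dtype=ms.bfloat16)

# Pick a frame count that satisfies (num_frames - 1) % 4 == 0, e.g. 33 frames ~ 2 s at 16 fps.
num_frames = 33
assert (num_frames - 1) % 4 == 0

# Lower shift for smaller resolutions, higher shift for larger ones.
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, shift=5.0)

frames = pipe(prompt="A robot standing on a mountain top at sunset", num_frames=num_frames)[0][0]
export_to_video(frames, "wan-t2v-sketch.mp4", fps=16)
```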

::: mindone.diffusers.WanPipeline

::: mindone.diffusers.WanImageToVideoPipeline

::: mindone.diffusers.WanVideoToVideoPipeline

::: mindone.diffusers.pipelines.wan.pipeline_output.WanPipelineOutput

mindone/diffusers/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -27,6 +27,7 @@
 "AutoencoderKLMagvit",
 "AutoencoderKLMochi",
 "AutoencoderKLTemporalDecoder",
+"AutoencoderKLWan",
 "AutoencoderOobleck",
 "AutoencoderTiny",
 "CogVideoXTransformer3DModel",
@@ -77,6 +78,7 @@
 "UNetSpatioTemporalConditionModel",
 "UVit2DModel",
 "VQModel",
+"WanTransformer3DModel",
 ],
 "optimization": [
 "get_constant_schedule",
@@ -264,6 +266,9 @@
 "UniDiffuserPipeline",
 "UniDiffuserTextDecoder",
 "VideoToVideoSDPipeline",
+"WanImageToVideoPipeline",
+"WanPipeline",
+"WanVideoToVideoPipeline",
 "WuerstchenCombinedPipeline",
 "WuerstchenDecoderPipeline",
 "WuerstchenPriorPipeline",
@@ -329,6 +334,7 @@
 AutoencoderKLMagvit,
 AutoencoderKLMochi,
 AutoencoderKLTemporalDecoder,
+AutoencoderKLWan,
 AutoencoderOobleck,
 AutoencoderTiny,
 CogVideoXTransformer3DModel,
@@ -379,6 +385,7 @@
 UNetSpatioTemporalConditionModel,
 UVit2DModel,
 VQModel,
+WanTransformer3DModel,
 )
 from .optimization import (
 get_constant_schedule,
@@ -566,6 +573,9 @@
 UniDiffuserPipeline,
 UniDiffuserTextDecoder,
 VideoToVideoSDPipeline,
+WanImageToVideoPipeline,
+WanPipeline,
+WanVideoToVideoPipeline,
 WuerstchenCombinedPipeline,
 WuerstchenDecoderPipeline,
 WuerstchenPriorPipeline,

mindone/diffusers/loaders/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -70,6 +70,7 @@ def text_encoder_attn_modules(text_encoder):
 "HunyuanVideoLoraLoaderMixin",
 "SanaLoraLoaderMixin",
 "Lumina2LoraLoaderMixin",
+"WanLoraLoaderMixin",
 ],
 "peft": ["PeftAdapterMixin"],
 "single_file": ["FromSingleFileMixin"],
@@ -93,6 +94,7 @@ def text_encoder_attn_modules(text_encoder):
 SD3LoraLoaderMixin,
 StableDiffusionLoraLoaderMixin,
 StableDiffusionXLLoraLoaderMixin,
+WanLoraLoaderMixin,
 )
 from .peft import PeftAdapterMixin
 from .single_file import FromSingleFileMixin

mindone/diffusers/loaders/lora_conversion_utils.py

Lines changed: 53 additions & 0 deletions
@@ -1288,3 +1288,56 @@ def process_block(prefix, index, convert_norm):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

     return converted_state_dict
+
+
+def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
+    converted_state_dict = {}
+    original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
+
+    num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
+    is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
+
+    for i in range(num_blocks):
+        # Self-attention
+        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
+                f"blocks.{i}.self_attn.{o}.lora_A.weight"
+            )
+            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
+                f"blocks.{i}.self_attn.{o}.lora_B.weight"
+            )
+
+        # Cross-attention
+        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
+                f"blocks.{i}.cross_attn.{o}.lora_A.weight"
+            )
+            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
+                f"blocks.{i}.cross_attn.{o}.lora_B.weight"
+            )
+
+        if is_i2v_lora:
+            for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
+                    f"blocks.{i}.cross_attn.{o}.lora_A.weight"
+                )
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
+                    f"blocks.{i}.cross_attn.{o}.lora_B.weight"
+                )
+
+        # FFN
+        for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
+            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
+                f"blocks.{i}.{o}.lora_A.weight"
+            )
+            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
+                f"blocks.{i}.{o}.lora_B.weight"
+            )
+
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
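
As a usage note (not part of this commit), the conversion above is what lets LoRAs in the original Wan key layout be loaded through the standard LoRA API. A minimal sketch follows; it assumes `WanPipeline` exposes `load_lora_weights` via the newly added `WanLoraLoaderMixin`, and the checkpoint path and adapter name are hypothetical placeholders.

```python
# Minimal sketch, not from this commit: loading a LoRA in the original Wan key layout
# (keys prefixed with "diffusion_model.") into WanPipeline. The path and adapter name
# below are hypothetical placeholders.
import mindspore as ms
from mindone.diffusers import WanPipeline

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", mindspore_dtype=ms.bfloat16)

# Non-diffusers Wan LoRA checkpoints are remapped by
# _convert_non_diffusers_wan_lora_to_diffusers before being applied.
pipe.load_lora_weights("path/to/wan_lora.safetensors", adapter_name="wan-lora")
```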
