-
Notifications
You must be signed in to change notification settings - Fork 7k
Expand file tree
/
Copy pathbefore_denoise.py
More file actions
1041 lines (869 loc) · 39.8 KB
/
before_denoise.py
File metadata and controls
1041 lines (869 loc) · 39.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ...models import QwenImageControlNetModel, QwenImageMultiControlNetModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils.torch_utils import randn_tensor, unwrap_module
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly map an image sequence length to a shift value ``mu``.

    The mapping is the straight line through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``; no clamping is applied outside that range.

    Args:
        image_seq_len: Sequence length to evaluate the line at.
        base_seq_len: Sequence length at which the result equals ``base_shift``.
        max_seq_len: Sequence length at which the result equals ``max_shift``.
        base_shift: Shift value at ``base_seq_len``.
        max_shift: Shift value at ``max_seq_len``.

    Returns:
        The interpolated (or extrapolated) shift value.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure the scheduler through its `set_timesteps` method and return the resulting schedule.

    Exactly one scheduling mode is used: a custom `timesteps` list, a custom `sigmas` list, or a plain
    `num_inference_steps` count. Extra keyword arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of steps to request when neither `timesteps` nor `sigmas` is given. When a custom schedule is
            given, the returned step count is the length of the resulting schedule instead.
        device (`str` or `torch.device`, *optional*):
            Passed through to `scheduler.set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after the call, and the number of inference steps.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` signature
            does not accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # No custom schedule: let the scheduler derive one from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Exactly one custom schedule was supplied; pick its keyword and the noun used in the error text.
    if timesteps is not None:
        custom_key, custom_values, noun = "timesteps", timesteps, "timestep"
    else:
        custom_key, custom_values, noun = "sigmas", sigmas, "sigmas"

    if custom_key not in inspect.signature(scheduler.set_timesteps).parameters:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" {noun} schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(**{custom_key: custom_values}, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)
# modified from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
def get_timesteps(scheduler, num_inference_steps, strength):
    """Trim the scheduler's timesteps according to ``strength``.

    Args:
        scheduler: Object exposing ``timesteps`` and ``order``; ``set_begin_index`` is called when present.
        num_inference_steps: Number of steps in the full schedule.
        strength: Fraction of the schedule to keep; values that would keep more than the full
            schedule are capped at ``num_inference_steps``.

    Returns:
        A tuple of the remaining timesteps (a slice of ``scheduler.timesteps``) and the number of
        steps left after skipping.
    """
    # Number of steps that will actually be run, never more than the full schedule.
    kept_steps = min(num_inference_steps * strength, num_inference_steps)
    skipped_steps = int(max(num_inference_steps - kept_steps, 0))

    # The scheduler holds `order` entries per step, so the slice start is scaled accordingly.
    begin_index = skipped_steps * scheduler.order
    remaining = scheduler.timesteps[begin_index:]

    if hasattr(scheduler, "set_begin_index"):
        scheduler.set_begin_index(begin_index)

    return remaining, num_inference_steps - skipped_steps
# ====================
# 1. PREPARE LATENTS
# ====================
class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
    """Block that samples starting noise and packs it with the pachifier when no `latents` are supplied."""

    model_name = "qwenimage"

    @property
    def description(self) -> str:
        return "Prepare initial random noise for the generation process"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        pachifier_spec = ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config")
        return [pachifier_spec]

    @property
    def inputs(self) -> List[InputParam]:
        batch_size_param = InputParam(
            "batch_size",
            type_hint=int,
            required=True,
            description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
        )
        dtype_param = InputParam(
            "dtype",
            type_hint=torch.dtype,
            required=True,
            description="The dtype of the model inputs, can be generated in input step.",
        )
        return [
            InputParam.latents(),
            InputParam.height(),
            InputParam.width(),
            InputParam.num_images_per_prompt(),
            InputParam.generator(),
            batch_size_param,
            dtype_param,
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        latents_output = OutputParam(
            "latents",
            type_hint=torch.Tensor,
            description="The initial latents to use for the denoising process",
        )
        return [latents_output]

    @staticmethod
    def check_inputs(height, width, vae_scale_factor):
        """Raise `ValueError` if a provided height/width is not a multiple of `vae_scale_factor * 2`."""
        # Height is validated before width; `None` values are skipped.
        for label, value in (("Height", height), ("Width", width)):
            if value is not None and value % (vae_scale_factor * 2) != 0:
                raise ValueError(f"{label} must be divisible by {vae_scale_factor * 2} but is {value}")

    @torch.no_grad()
    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        """Validate the requested size, resolve defaults, and create packed noise latents if needed."""
        block_state = self.get_block_state(state)

        self.check_inputs(
            vae_scale_factor=components.vae_scale_factor,
            height=block_state.height,
            width=block_state.width,
        )

        execution_device = components._execution_device
        total_batch = block_state.batch_size * block_state.num_images_per_prompt

        # Unset (falsy) sizes fall back to the pipeline defaults; the resolved values are stored back.
        if not block_state.height:
            block_state.height = components.default_height
        if not block_state.width:
            block_state.width = components.default_width

        # Packing needs even latent height/width, hence floor-divide by (vae_scale_factor * 2) and double.
        latent_height = (int(block_state.height) // (components.vae_scale_factor * 2)) * 2
        latent_width = (int(block_state.width) // (components.vae_scale_factor * 2)) * 2
        noise_shape = (total_batch, components.num_channels_latents, 1, latent_height, latent_width)

        generators = block_state.generator
        if isinstance(generators, list) and len(generators) != total_batch:
            raise ValueError(
                f"You have passed a list of generators of length {len(generators)}, but requested an effective batch"
                f" size of {total_batch}. Make sure the batch size matches the length of the generators."
            )

        # Caller-supplied latents are used as-is; only freshly sampled noise is packed here.
        if block_state.latents is None:
            noise = randn_tensor(
                noise_shape, generator=generators, device=execution_device, dtype=block_state.dtype
            )
            block_state.latents = components.pachifier.pack_latents(noise)

        self.set_block_state(state, block_state)
        return components, state
class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks):
    """Block that samples `(batch, layers + 1, channels, height, width)` noise and packs it when no `latents` are supplied."""

    model_name = "qwenimage-layered"

    @property
    def description(self) -> str:
        return "Prepare initial random noise (B, layers+1, C, H, W) for the generation process"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        pachifier_spec = ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config")
        return [pachifier_spec]

    @property
    def inputs(self) -> List[InputParam]:
        layers_param = InputParam(
            "layers", default=4, type_hint=int, description="Number of layers to extract from the image"
        )
        batch_size_param = InputParam(
            "batch_size",
            type_hint=int,
            required=True,
            description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
        )
        dtype_param = InputParam(
            "dtype",
            type_hint=torch.dtype,
            required=True,
            description="The dtype of the model inputs, can be generated in input step.",
        )
        return [
            InputParam.latents(),
            InputParam.height(),
            InputParam.width(),
            layers_param,
            InputParam.num_images_per_prompt(),
            InputParam.generator(),
            batch_size_param,
            dtype_param,
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        latents_output = OutputParam(
            "latents",
            type_hint=torch.Tensor,
            description="The initial latents to use for the denoising process",
        )
        return [latents_output]

    @staticmethod
    def check_inputs(height, width, vae_scale_factor):
        """Raise `ValueError` if a provided height/width is not a multiple of `vae_scale_factor * 2`."""
        # Height is validated before width; `None` values are skipped.
        for label, value in (("Height", height), ("Width", width)):
            if value is not None and value % (vae_scale_factor * 2) != 0:
                raise ValueError(f"{label} must be divisible by {vae_scale_factor * 2} but is {value}")

    @torch.no_grad()
    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        """Validate the requested size, resolve defaults, and create packed layered noise latents if needed."""
        block_state = self.get_block_state(state)

        self.check_inputs(
            vae_scale_factor=components.vae_scale_factor,
            height=block_state.height,
            width=block_state.width,
        )

        execution_device = components._execution_device
        total_batch = block_state.batch_size * block_state.num_images_per_prompt

        # Unset (falsy) sizes fall back to the pipeline defaults; the resolved values are stored back.
        if not block_state.height:
            block_state.height = components.default_height
        if not block_state.width:
            block_state.width = components.default_width

        # Packing needs even latent height/width, hence floor-divide by (vae_scale_factor * 2) and double.
        latent_height = (int(block_state.height) // (components.vae_scale_factor * 2)) * 2
        latent_width = (int(block_state.width) // (components.vae_scale_factor * 2)) * 2

        # One extra slot on top of the requested number of layers.
        noise_shape = (
            total_batch,
            block_state.layers + 1,
            components.num_channels_latents,
            latent_height,
            latent_width,
        )

        generators = block_state.generator
        if isinstance(generators, list) and len(generators) != total_batch:
            raise ValueError(
                f"You have passed a list of generators of length {len(generators)}, but requested an effective batch"
                f" size of {total_batch}. Make sure the batch size matches the length of the generators."
            )

        # Caller-supplied latents are used as-is; only freshly sampled noise is packed here.
        if block_state.latents is None:
            noise = randn_tensor(
                noise_shape, generator=generators, device=execution_device, dtype=block_state.dtype
            )
            block_state.latents = components.pachifier.pack_latents(noise)

        self.set_block_state(state, block_state)
        return components, state
class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
    """Block that mixes noise into image latents via the scheduler's `scale_noise` at the first timestep."""

    model_name = "qwenimage"

    @property
    def description(self) -> str:
        return "Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, prepare_latents. Both noise and image latents should alreadybe patchified."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        scheduler_spec = ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)
        return [scheduler_spec]

    @property
    def inputs(self) -> List[InputParam]:
        noise_param = InputParam(
            "latents",
            type_hint=torch.Tensor,
            required=True,
            description="The initial random noised, can be generated in prepare latent step.",
        )
        image_latents_param = InputParam(
            "image_latents",
            type_hint=torch.Tensor,
            required=True,
            description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.",
        )
        timesteps_param = InputParam(
            "timesteps",
            type_hint=torch.Tensor,
            required=True,
            description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
        )
        return [noise_param, image_latents_param, timesteps_param]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        initial_noise_output = OutputParam(
            "initial_noise",
            type_hint=torch.Tensor,
            description="The initial random noised used for inpainting denoising.",
        )
        return [initial_noise_output]

    @staticmethod
    def check_inputs(image_latents, latents):
        """Raise `ValueError` on a batch-size mismatch or when `image_latents` is not 3-dimensional."""
        image_batch = image_latents.shape[0]
        noise_batch = latents.shape[0]
        if image_batch != noise_batch:
            raise ValueError(
                f"`image_latents` must have have same batch size as `latents`, but got {image_batch} and {noise_batch}"
            )

        rank = image_latents.ndim
        if rank != 3:
            raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {rank}")

    @torch.no_grad()
    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        """Store the incoming noise as `initial_noise` and replace `latents` with the noised image latents."""
        block_state = self.get_block_state(state)

        self.check_inputs(latents=block_state.latents, image_latents=block_state.image_latents)

        noise = block_state.latents

        # The first scheduled timestep, repeated once per batch element.
        first_timestep = block_state.timesteps[:1].repeat(noise.shape[0])

        # Keep a reference to the untouched noise (same tensor object, not a clone).
        block_state.initial_noise = noise

        block_state.latents = components.scheduler.scale_noise(block_state.image_latents, first_timestep, noise)

        self.set_block_state(state, block_state)
        return components, state
class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
    """Block that resizes the processed mask to the latent grid, repeats it per latent channel, and packs it."""

    model_name = "qwenimage"

    @property
    def description(self) -> str:
        return "Step that creates mask latents from preprocessed mask_image by interpolating to latent space."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        pachifier_spec = ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config")
        return [pachifier_spec]

    @property
    def inputs(self) -> List[InputParam]:
        mask_param = InputParam(
            "processed_mask_image",
            type_hint=torch.Tensor,
            required=True,
            description="The processed mask to use for the inpainting process.",
        )
        return [
            mask_param,
            InputParam("height", required=True),
            InputParam("width", required=True),
            InputParam("dtype", required=True),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        mask_output = OutputParam(
            "mask", description="The mask to use for the inpainting process.", type_hint=torch.Tensor
        )
        return [mask_output]

    @torch.no_grad()
    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        """Build the packed `mask` from `processed_mask_image` and store it on the block state."""
        block_state = self.get_block_state(state)
        execution_device = components._execution_device

        # Packing needs even latent height/width, hence floor-divide by (vae_scale_factor * 2) and double.
        latent_height = (int(block_state.height) // (components.vae_scale_factor * 2)) * 2
        latent_width = (int(block_state.width) // (components.vae_scale_factor * 2)) * 2

        resized = torch.nn.functional.interpolate(
            block_state.processed_mask_image,
            size=(latent_height, latent_width),
        )

        # Insert a singleton axis at position 2, then repeat along axis 1 once per latent channel.
        expanded = resized.unsqueeze(2).repeat(1, components.num_channels_latents, 1, 1, 1)
        expanded = expanded.to(device=execution_device, dtype=block_state.dtype)

        block_state.mask = components.pachifier.pack_latents(expanded)

        self.set_block_state(state, block_state)
        return components, state
# ====================
# 2. SET TIMESTEPS
# ====================
class QwenImageSetTimestepsStep(ModularPipelineBlocks):
    """Block that configures the scheduler's timesteps using a shift derived from the latent sequence length."""

    model_name = "qwenimage"

    @property
    def description(self) -> str:
        return "Step that sets the the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        scheduler_spec = ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)
        return [scheduler_spec]

    @property
    def inputs(self) -> List[InputParam]:
        latents_param = InputParam(
            "latents",
            type_hint=torch.Tensor,
            required=True,
            description="The latents to use for the denoising process, used to calculate the image sequence length.",
        )
        return [
            InputParam.num_inference_steps(),
            InputParam.sigmas(),
            latents_param,
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        timesteps_output = OutputParam(
            "timesteps", description="The timesteps to use for the denoising process", type_hint=torch.Tensor
        )
        return [timesteps_output]

    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        """Compute sigmas and `mu`, call `retrieve_timesteps`, and reset the scheduler's begin index to 0."""
        block_state = self.get_block_state(state)
        execution_device = components._execution_device

        # Without user sigmas, use a linear ramp from 1.0 down to 1 / num_inference_steps.
        if block_state.sigmas is None:
            sigma_schedule = np.linspace(
                1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps
            )
        else:
            sigma_schedule = block_state.sigmas

        # Dimension 1 of the latents is used as the image sequence length.
        image_seq_len = block_state.latents.shape[1]
        scheduler = components.scheduler
        mu = calculate_shift(
            image_seq_len=image_seq_len,
            base_seq_len=scheduler.config.get("base_image_seq_len", 256),
            max_seq_len=scheduler.config.get("max_image_seq_len", 4096),
            base_shift=scheduler.config.get("base_shift", 0.5),
            max_shift=scheduler.config.get("max_shift", 1.15),
        )

        schedule, step_count = retrieve_timesteps(
            scheduler=scheduler,
            num_inference_steps=block_state.num_inference_steps,
            device=execution_device,
            sigmas=sigma_schedule,
            mu=mu,
        )
        block_state.timesteps = schedule
        block_state.num_inference_steps = step_count

        scheduler.set_begin_index(0)

        self.set_block_state(state, block_state)
        return components, state
class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):
    """Block that configures the scheduler's timesteps with `mu` computed from the `image_latents` length."""

    model_name = "qwenimage-layered"

    @property
    def description(self) -> str:
        return "Set timesteps step for QwenImage Layered with custom mu calculation based on image_latents."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        scheduler_spec = ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)
        return [scheduler_spec]

    @property
    def inputs(self) -> List[InputParam]:
        image_latents_param = InputParam("image_latents", type_hint=torch.Tensor, required=True)
        return [
            InputParam.num_inference_steps(),
            InputParam.sigmas(),
            image_latents_param,
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        timesteps_output = OutputParam("timesteps", type_hint=torch.Tensor)
        return [timesteps_output]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        """Compute `mu` and sigmas, call `retrieve_timesteps`, and reset the scheduler's begin index to 0."""
        block_state = self.get_block_state(state)
        execution_device = components._execution_device

        # mu is the square root of (image_latents.shape[1] / reference length), reference = 256.0.
        reference_seq_len = 256 * 256 / 16 / 16
        seq_len_ratio = block_state.image_latents.shape[1] / reference_seq_len
        mu = seq_len_ratio**0.5

        # Without user sigmas, use a linear ramp from 1.0 down to 1 / num_inference_steps.
        if block_state.sigmas is None:
            sigma_schedule = np.linspace(
                1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps
            )
        else:
            sigma_schedule = block_state.sigmas

        schedule, step_count = retrieve_timesteps(
            scheduler=components.scheduler,
            num_inference_steps=block_state.num_inference_steps,
            device=execution_device,
            sigmas=sigma_schedule,
            mu=mu,
        )
        block_state.timesteps = schedule
        block_state.num_inference_steps = step_count

        components.scheduler.set_begin_index(0)

        self.set_block_state(state, block_state)
        return components, state
class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
    """Block that configures the scheduler's timesteps and then trims them according to `strength`."""

    model_name = "qwenimage"

    @property
    def description(self) -> str:
        return "Step that sets the the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare latents step."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        scheduler_spec = ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)
        return [scheduler_spec]

    @property
    def inputs(self) -> List[InputParam]:
        latents_param = InputParam(
            "latents",
            type_hint=torch.Tensor,
            required=True,
            description="The latents to use for the denoising process, used to calculate the image sequence length.",
        )
        return [
            InputParam.num_inference_steps(),
            InputParam.sigmas(),
            latents_param,
            InputParam.strength(0.9),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        timesteps_output = OutputParam(
            "timesteps",
            description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
            type_hint=torch.Tensor,
        )
        return [timesteps_output]

    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        """Build the full schedule with `retrieve_timesteps`, then shorten it with `get_timesteps`."""
        block_state = self.get_block_state(state)
        execution_device = components._execution_device

        # Without user sigmas, use a linear ramp from 1.0 down to 1 / num_inference_steps.
        if block_state.sigmas is None:
            sigma_schedule = np.linspace(
                1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps
            )
        else:
            sigma_schedule = block_state.sigmas

        # Dimension 1 of the latents is used as the image sequence length.
        image_seq_len = block_state.latents.shape[1]
        scheduler = components.scheduler
        mu = calculate_shift(
            image_seq_len=image_seq_len,
            base_seq_len=scheduler.config.get("base_image_seq_len", 256),
            max_seq_len=scheduler.config.get("max_image_seq_len", 4096),
            base_shift=scheduler.config.get("base_shift", 0.5),
            max_shift=scheduler.config.get("max_shift", 1.15),
        )

        full_schedule, full_step_count = retrieve_timesteps(
            scheduler=scheduler,
            num_inference_steps=block_state.num_inference_steps,
            device=execution_device,
            sigmas=sigma_schedule,
            mu=mu,
        )
        block_state.timesteps = full_schedule
        block_state.num_inference_steps = full_step_count

        # Drop the leading part of the schedule as dictated by `strength`.
        trimmed_schedule, trimmed_step_count = get_timesteps(
            scheduler=scheduler,
            num_inference_steps=block_state.num_inference_steps,
            strength=block_state.strength,
        )
        block_state.timesteps = trimmed_schedule
        block_state.num_inference_steps = trimmed_step_count

        self.set_block_state(state, block_state)
        return components, state
# ====================
# 3. OTHER INPUTS FOR DENOISER
# ====================
## RoPE inputs for denoiser
class QwenImageRoPEInputsStep(ModularPipelineBlocks):
    """Block that derives `img_shapes` and text sequence lengths from the target size and prompt masks."""

    model_name = "qwenimage"

    @property
    def description(self) -> str:
        return "Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("batch_size", required=True),
            InputParam("height", required=True),
            InputParam("width", required=True),
            InputParam("prompt_embeds_mask"),
            InputParam("negative_prompt_embeds_mask"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        def denoiser_field(name, type_hint, description):
            # Every output of this block is tagged with the same kwargs_type.
            return OutputParam(
                name=name,
                kwargs_type="denoiser_input_fields",
                type_hint=type_hint,
                description=description,
            )

        return [
            denoiser_field(
                "img_shapes",
                List[List[Tuple[int, int, int]]],
                "The shapes of the images latents, used for RoPE calculation",
            ),
            denoiser_field(
                "txt_seq_lens",
                List[int],
                "The sequence lengths of the prompt embeds, used for RoPE calculation",
            ),
            denoiser_field(
                "negative_txt_seq_lens",
                List[int],
                "The sequence lengths of the negative prompt embeds, used for RoPE calculation",
            ),
        ]

    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        """Fill `img_shapes`, `txt_seq_lens` and `negative_txt_seq_lens` on the block state."""
        block_state = self.get_block_state(state)

        # (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2) for the target image.
        target_shape = (
            1,
            block_state.height // components.vae_scale_factor // 2,
            block_state.width // components.vae_scale_factor // 2,
        )
        per_sample_shapes = [target_shape]
        # List multiplication: every batch entry refers to the same inner list object.
        block_state.img_shapes = [per_sample_shapes] * block_state.batch_size

        # Per-row sums of the masks; `None` masks produce `None` lengths.
        prompt_mask = block_state.prompt_embeds_mask
        block_state.txt_seq_lens = None if prompt_mask is None else prompt_mask.sum(dim=1).tolist()

        negative_mask = block_state.negative_prompt_embeds_mask
        block_state.negative_txt_seq_lens = (
            None if negative_mask is None else negative_mask.sum(dim=1).tolist()
        )

        self.set_block_state(state, block_state)
        return components, state
class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
    """Block that derives `img_shapes` for the target and one input image, plus text sequence lengths."""

    model_name = "qwenimage"

    @property
    def description(self) -> str:
        return "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be placed after prepare_latents step"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("batch_size", required=True),
            InputParam("image_height", required=True),
            InputParam("image_width", required=True),
            InputParam("height", required=True),
            InputParam("width", required=True),
            InputParam("prompt_embeds_mask"),
            InputParam("negative_prompt_embeds_mask"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        def denoiser_field(name, type_hint, description):
            # Every output of this block is tagged with the same kwargs_type.
            return OutputParam(
                name=name,
                kwargs_type="denoiser_input_fields",
                type_hint=type_hint,
                description=description,
            )

        return [
            denoiser_field(
                "img_shapes",
                List[List[Tuple[int, int, int]]],
                "The shapes of the images latents, used for RoPE calculation",
            ),
            denoiser_field(
                "txt_seq_lens",
                List[int],
                "The sequence lengths of the prompt embeds, used for RoPE calculation",
            ),
            denoiser_field(
                "negative_txt_seq_lens",
                List[int],
                "The sequence lengths of the negative prompt embeds, used for RoPE calculation",
            ),
        ]

    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        """Fill `img_shapes`, `txt_seq_lens` and `negative_txt_seq_lens` on the block state."""
        block_state = self.get_block_state(state)

        # The input image may differ in size from the target (height/width), so both shapes are listed:
        # target first, input image second.
        target_shape = (
            1,
            block_state.height // components.vae_scale_factor // 2,
            block_state.width // components.vae_scale_factor // 2,
        )
        image_shape = (
            1,
            block_state.image_height // components.vae_scale_factor // 2,
            block_state.image_width // components.vae_scale_factor // 2,
        )
        per_sample_shapes = [target_shape, image_shape]
        # List multiplication: every batch entry refers to the same inner list object.
        block_state.img_shapes = [per_sample_shapes] * block_state.batch_size

        # Per-row sums of the masks; `None` masks produce `None` lengths.
        prompt_mask = block_state.prompt_embeds_mask
        block_state.txt_seq_lens = None if prompt_mask is None else prompt_mask.sum(dim=1).tolist()

        negative_mask = block_state.negative_prompt_embeds_mask
        block_state.negative_txt_seq_lens = (
            None if negative_mask is None else negative_mask.sum(dim=1).tolist()
        )

        self.set_block_state(state, block_state)
        return components, state
class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
    """Block that derives `img_shapes` for the target plus a list of reference images, and text sequence lengths."""

    model_name = "qwenimage-edit-plus"

    @property
    def description(self) -> str:
        summary_lines = (
            "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit Plus.\n",
            "Unlike Edit, Edit Plus handles lists of image_height/image_width for multiple reference images.\n",
            "Should be placed after prepare_latents step.",
        )
        return "".join(summary_lines)

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("batch_size", required=True),
            InputParam("image_height", type_hint=List[int], required=True),
            InputParam("image_width", type_hint=List[int], required=True),
            InputParam("height", required=True),
            InputParam("width", required=True),
            InputParam("prompt_embeds_mask"),
            InputParam("negative_prompt_embeds_mask"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        def denoiser_field(name, type_hint, description):
            # Every output of this block is tagged with the same kwargs_type.
            return OutputParam(
                name=name,
                kwargs_type="denoiser_input_fields",
                type_hint=type_hint,
                description=description,
            )

        return [
            denoiser_field(
                "img_shapes",
                List[List[Tuple[int, int, int]]],
                "The shapes of the image latents, used for RoPE calculation",
            ),
            denoiser_field(
                "txt_seq_lens",
                List[int],
                "The sequence lengths of the prompt embeds, used for RoPE calculation",
            ),
            denoiser_field(
                "negative_txt_seq_lens",
                List[int],
                "The sequence lengths of the negative prompt embeds, used for RoPE calculation",
            ),
        ]

    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        """Fill `img_shapes`, `txt_seq_lens` and `negative_txt_seq_lens` on the block state."""
        block_state = self.get_block_state(state)
        scale = components.vae_scale_factor

        # Target shape first, followed by one shape per (image_height, image_width) pair.
        target_shape = (1, block_state.height // scale // 2, block_state.width // scale // 2)
        reference_shapes = [
            (1, ref_height // scale // 2, ref_width // scale // 2)
            for ref_height, ref_width in zip(block_state.image_height, block_state.image_width)
        ]
        per_sample_shapes = [target_shape] + reference_shapes
        # List multiplication: every batch entry refers to the same inner list object.
        block_state.img_shapes = [per_sample_shapes] * block_state.batch_size

        # Per-row sums of the masks; `None` masks produce `None` lengths.
        prompt_mask = block_state.prompt_embeds_mask
        block_state.txt_seq_lens = None if prompt_mask is None else prompt_mask.sum(dim=1).tolist()

        negative_mask = block_state.negative_prompt_embeds_mask
        block_state.negative_txt_seq_lens = (
            None if negative_mask is None else negative_mask.sum(dim=1).tolist()
        )

        self.set_block_state(state, block_state)
        return components, state
class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):
    """Block that derives `img_shapes`, text sequence lengths and a zero `additional_t_cond` for layered runs."""

    model_name = "qwenimage-layered"

    @property
    def description(self) -> str:
        return "Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("batch_size", required=True),
            InputParam("layers", description="Number of layers to extract from the image", default=4),
            InputParam("height", required=True),
            InputParam("width", required=True),
            InputParam("prompt_embeds_mask"),
            InputParam("negative_prompt_embeds_mask"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        def denoiser_field(name, type_hint, description):
            # Every output of this block is tagged with the same kwargs_type.
            return OutputParam(
                name=name,
                type_hint=type_hint,
                kwargs_type="denoiser_input_fields",
                description=description,
            )

        return [
            denoiser_field(
                "img_shapes",
                List[List[Tuple[int, int, int]]],
                "The shapes of the image latents, used for RoPE calculation",
            ),
            denoiser_field(
                "txt_seq_lens",
                List[int],
                "The sequence lengths of the prompt embeds, used for RoPE calculation",
            ),
            denoiser_field(
                "negative_txt_seq_lens",
                List[int],
                "The sequence lengths of the negative prompt embeds, used for RoPE calculation",
            ),
            denoiser_field(
                "additional_t_cond",
                torch.Tensor,
                "The additional t cond, used for RoPE calculation",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        """Fill `img_shapes`, the text sequence lengths and `additional_t_cond` on the block state."""
        block_state = self.get_block_state(state)
        execution_device = components._execution_device

        # A single shape is reused for every entry.
        common_shape = (
            1,
            block_state.height // components.vae_scale_factor // 2,
            block_state.width // components.vae_scale_factor // 2,
        )

        # layers + 1 output shapes plus 1 condition shape, all identical.
        per_sample_shapes = [common_shape] * (block_state.layers + 2)
        # List multiplication: every batch entry refers to the same inner list object.
        block_state.img_shapes = [per_sample_shapes] * block_state.batch_size

        # Per-row sums of the masks; `None` masks produce `None` lengths.
        prompt_mask = block_state.prompt_embeds_mask
        block_state.txt_seq_lens = None if prompt_mask is None else prompt_mask.sum(dim=1).tolist()

        negative_mask = block_state.negative_prompt_embeds_mask
        block_state.negative_txt_seq_lens = (
            None if negative_mask is None else negative_mask.sum(dim=1).tolist()
        )

        # One zero per batch element, as a long tensor on the execution device.
        zero_cond = torch.tensor([0] * block_state.batch_size)
        block_state.additional_t_cond = zero_cond.to(device=execution_device, dtype=torch.long)

        self.set_block_state(state, block_state)
        return components, state
## ControlNet inputs for denoiser
class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("controlnet", QwenImageControlNetModel),
]
@property
def description(self) -> str:
return "step that prepare inputs for controlnet. Insert before the Denoise Step, after set_timesteps step."
@property
def inputs(self) -> List[InputParam]:
return [
InputParam.control_guidance_start(),
InputParam.control_guidance_end(),
InputParam.controlnet_conditioning_scale(),
InputParam("control_image_latents", required=True),
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"),
]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
controlnet = unwrap_module(components.controlnet)
# control_guidance_start/control_guidance_end (align format)