17
17
from turbine_models .custom_models .sd_inference import utils
18
18
from turbine_models .model_runner import vmfbRunner
19
19
from transformers import CLIPTokenizer
20
+ from diffusers import FlowMatchEulerDiscreteScheduler
20
21
21
22
from PIL import Image
22
23
import os
@@ -44,10 +45,8 @@ class SharkSD3Pipeline:
44
45
def __init__ (
45
46
self ,
46
47
hf_model_name : str ,
47
- # scheduler_id: str,
48
48
height : int ,
49
49
width : int ,
50
- shift : float ,
51
50
precision : str ,
52
51
max_length : int ,
53
52
batch_size : int ,
@@ -60,9 +59,12 @@ def __init__(
60
59
pipeline_dir : str = "./shark_vmfbs" ,
61
60
external_weights_dir : str = "./shark_weights" ,
62
61
external_weights : str = "safetensors" ,
63
- vae_decomp_attn : bool = True ,
64
- custom_vae : str = "" ,
62
+ vae_decomp_attn : bool = False ,
65
63
cpu_scheduling : bool = False ,
64
+ vae_precision : str = "fp32" ,
65
+ scheduler_id : str = None , #compatibility only, always uses EulerFlowScheduler
66
+ shift : float = 1.0 ,
67
+
66
68
):
67
69
self .hf_model_name = hf_model_name
68
70
# self.scheduler_id = scheduler_id
@@ -120,10 +122,11 @@ def __init__(
120
122
self .external_weights_dir = external_weights_dir
121
123
self .external_weights = external_weights
122
124
self .vae_decomp_attn = vae_decomp_attn
123
- self .custom_vae = custom_vae
125
+ self .custom_vae = None
124
126
self .cpu_scheduling = cpu_scheduling
125
127
self .torch_dtype = torch .float32 if self .precision == "fp32" else torch .float16
126
- self .vae_dtype = torch .float32
128
+ self .vae_precision = vae_precision if vae_precision else self .precision
129
+ self .vae_dtype = torch .float32 if vae_precision == "fp32" else torch .float16
127
130
# TODO: set this based on user-inputted guidance scale and negative prompt.
128
131
self .do_classifier_free_guidance = True # False if any(x in hf_model_name for x in ["turbo", "lightning"]) else True
129
132
@@ -206,7 +209,12 @@ def is_prepared(self, vmfbs, weights):
206
209
)
207
210
if w_key == "clip" :
208
211
default_name = os .path .join (
209
- self .external_weights_dir , f"sd3_clip_fp16.irpa"
212
+ self .external_weights_dir , f"sd3_text_encoders_{ self .precision } .irpa"
213
+ )
214
+ if w_key == "mmdit" :
215
+ default_name = os .path .join (
216
+ self .external_weights_dir ,
217
+ f"sd3_mmdit_{ self .precision } ." + self .external_weights ,
210
218
)
211
219
if weights [w_key ] is None and os .path .exists (default_name ):
212
220
weights [w_key ] = os .path .join (default_name )
@@ -357,7 +365,7 @@ def export_submodel(
357
365
self .batch_size ,
358
366
self .height ,
359
367
self .width ,
360
- "fp32" ,
368
+ self . vae_precision ,
361
369
"vmfb" ,
362
370
self .external_weights ,
363
371
vae_external_weight_path ,
@@ -419,10 +427,16 @@ def load_pipeline(
419
427
unet_loaded = time .time ()
420
428
print ("\n [LOG] MMDiT loaded in " , unet_loaded - load_start , "sec" )
421
429
422
- runners ["scheduler" ] = sd3_schedulers .SharkSchedulerWrapper (
423
- self .devices ["mmdit" ]["driver" ],
424
- vmfbs ["scheduler" ],
425
- )
430
+ if not self .cpu_scheduling :
431
+ runners ["scheduler" ] = sd3_schedulers .SharkSchedulerWrapper (
432
+ self .devices ["mmdit" ]["driver" ],
433
+ vmfbs ["scheduler" ],
434
+ )
435
+ else :
436
+ print ("Using torch CPU scheduler." )
437
+ runners ["scheduler" ] = FlowMatchEulerDiscreteScheduler .from_pretrained (
438
+ self .hf_model_name , subfolder = "scheduler"
439
+ )
426
440
427
441
sched_loaded = time .time ()
428
442
print ("\n [LOG] Scheduler loaded in " , sched_loaded - unet_loaded , "sec" )
@@ -495,11 +509,12 @@ def generate_images(
495
509
)
496
510
)
497
511
498
- guidance_scale = ireert .asdevicearray (
499
- self .runners ["pipe" ].config .device ,
500
- np .asarray ([guidance_scale ]),
501
- dtype = iree_dtype ,
502
- )
512
+ if not self .cpu_scheduling :
513
+ guidance_scale = ireert .asdevicearray (
514
+ self .runners ["pipe" ].config .device ,
515
+ np .asarray ([guidance_scale ]),
516
+ dtype = iree_dtype ,
517
+ )
503
518
504
519
tokenize_start = time .time ()
505
520
text_input_ids_dict = self .tokenizer .tokenize_with_weights (prompt )
@@ -533,12 +548,23 @@ def generate_images(
533
548
"clip"
534
549
].ctx .modules .compiled_text_encoder ["encode_tokens" ](* text_encoders_inputs )
535
550
encode_prompts_end = time .time ()
551
+ if self .cpu_scheduling :
552
+ timesteps , num_inference_steps = sd3_schedulers .retrieve_timesteps (
553
+ self .runners ["scheduler" ],
554
+ num_inference_steps = self .num_inference_steps ,
555
+ timesteps = None ,
556
+ )
557
+ steps = num_inference_steps
558
+
536
559
537
560
for i in range (batch_count ):
538
561
unet_start = time .time ()
539
- sample , steps , timesteps = self .runners ["scheduler" ].initialize (samples [i ])
562
+ if not self .cpu_scheduling :
563
+ latents , steps , timesteps = self .runners ["scheduler" ].initialize (samples [i ])
564
+ else :
565
+ latents = torch .tensor (samples [i ].to_host (), dtype = self .torch_dtype )
540
566
iree_inputs = [
541
- sample ,
567
+ latents ,
542
568
ireert .asdevicearray (
543
569
self .runners ["pipe" ].config .device , prompt_embeds , dtype = iree_dtype
544
570
),
@@ -553,40 +579,71 @@ def generate_images(
553
579
# print(f"step {s}")
554
580
if self .cpu_scheduling :
555
581
step_index = s
582
+ t = timesteps [s ]
583
+ if self .do_classifier_free_guidance :
584
+ latent_model_input = torch .cat ([latents ] * 2 )
585
+ timestep = ireert .asdevicearray (
586
+ self .runners ["pipe" ].config .device ,
587
+ t .expand (latent_model_input .shape [0 ]),
588
+ dtype = iree_dtype ,
589
+ )
590
+ latent_model_input = ireert .asdevicearray (
591
+ self .runners ["pipe" ].config .device ,
592
+ latent_model_input ,
593
+ dtype = iree_dtype ,
594
+ )
556
595
else :
557
596
step_index = ireert .asdevicearray (
558
597
self .runners ["scheduler" ].runner .config .device ,
559
598
torch .tensor ([s ]),
560
599
"int64" ,
561
600
)
562
- latents , t = self .runners ["scheduler" ].prep (
563
- sample ,
564
- step_index ,
565
- timesteps ,
566
- )
601
+ latent_model_input , timestep = self .runners ["scheduler" ].prep (
602
+ latents ,
603
+ step_index ,
604
+ timesteps ,
605
+ )
606
+ t = ireert .asdevicearray (
607
+ self .runners ["scheduler" ].runner .config .device ,
608
+ timestep .to_host ()[0 ]
609
+ )
567
610
noise_pred = self .runners ["pipe" ].ctx .modules .compiled_mmdit [
568
611
"run_forward"
569
612
](
570
- latents ,
613
+ latent_model_input ,
571
614
iree_inputs [1 ],
572
615
iree_inputs [2 ],
573
- t ,
616
+ timestep ,
574
617
)
575
- sample = self .runners ["scheduler" ].step (
576
- noise_pred ,
577
- t ,
578
- sample ,
579
- guidance_scale ,
580
- step_index ,
581
- )
582
- if isinstance (sample , torch .Tensor ):
618
+ if not self .cpu_scheduling :
619
+ latents = self .runners ["scheduler" ].step (
620
+ noise_pred ,
621
+ t ,
622
+ latents ,
623
+ guidance_scale ,
624
+ step_index ,
625
+ )
626
+ else :
627
+ noise_pred = torch .tensor (noise_pred .to_host (), dtype = self .torch_dtype )
628
+ if self .do_classifier_free_guidance :
629
+ noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
630
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond )
631
+ latents = self .runners ["scheduler" ].step (
632
+ noise_pred ,
633
+ t ,
634
+ latents ,
635
+ return_dict = False ,
636
+ )[0 ]
637
+
638
+ if isinstance (latents , torch .Tensor ):
639
+ latents = latents .type (self .vae_dtype )
583
640
latents = ireert .asdevicearray (
584
641
self .runners ["vae" ].config .device ,
585
- sample ,
586
- dtype = self .vae_dtype ,
642
+ latents ,
587
643
)
588
644
else :
589
- latents = sample .astype ("float32" )
645
+ vae_numpy_dtype = np .float32 if self .vae_precision == "fp32" else np .float16
646
+ latents = latents .astype (vae_numpy_dtype )
590
647
591
648
vae_start = time .time ()
592
649
vae_out = self .runners ["vae" ].ctx .modules .compiled_vae ["decode" ](latents )
@@ -634,7 +691,7 @@ def generate_images(
634
691
out_image = Image .fromarray (image )
635
692
images .extend ([[out_image ]])
636
693
if return_imgs :
637
- return images
694
+ return images [ 0 ]
638
695
for idx_batch , image_batch in enumerate (images ):
639
696
for idx , image in enumerate (image_batch ):
640
697
img_path = (
@@ -767,7 +824,6 @@ def run_diffusers_cpu(
767
824
args .hf_model_name ,
768
825
args .height ,
769
826
args .width ,
770
- args .shift ,
771
827
args .precision ,
772
828
args .max_length ,
773
829
args .batch_size ,
@@ -779,16 +835,15 @@ def run_diffusers_cpu(
779
835
args .decomp_attn ,
780
836
args .pipeline_dir ,
781
837
args .external_weights_dir ,
782
- args .external_weights ,
783
- args .vae_decomp_attn ,
784
- custom_vae = None ,
838
+ external_weights = args .external_weights ,
839
+ vae_decomp_attn = args .vae_decomp_attn ,
785
840
cpu_scheduling = args .cpu_scheduling ,
786
841
vae_precision = args .vae_precision ,
787
842
)
788
- vmfbs , weights = sd3_pipe .check_prepared (mlirs , vmfbs , weights )
789
843
if args .cpu_scheduling :
790
844
vmfbs .pop ("scheduler" )
791
845
weights .pop ("scheduler" )
846
+ vmfbs , weights = sd3_pipe .check_prepared (mlirs , vmfbs , weights )
792
847
if args .npu_delegate_path :
793
848
extra_device_args = {"npu_delegate_path" : args .npu_delegate_path }
794
849
else :
0 commit comments