import copy
from datetime import datetime as dt

-device_list = [
-    "cpu",
-    "vulkan",
-    "cuda",
-    "rocm",
-]
-
-rt_device_list = [
-    "local-task",
-    "local-sync",
-    "vulkan",
-    "cuda",
-    "rocm",
-    "hip",
-]
-
empty_pipe_dict = {
-    "vae": None,
-    "text_encoders": None,
+    "clip": None,
    "mmdit": None,
    "scheduler": None,
+    "vae": None,
}

EMPTY_FLAGS = {
@@ -90,24 +74,40 @@ def __init__(
        self.batch_size = batch_size
        self.num_inference_steps = num_inference_steps
        self.devices = {}
-        if isinstance(self.device, dict):
+        if isinstance(device, dict):
            assert isinstance(iree_target_triple, dict), "Device and target triple must be both dicts or both strings."
            self.devices["clip"] = {
                "device": device["clip"],
+                "driver": utils.iree_device_map(device["clip"]),
                "target": iree_target_triple["clip"]
            }
            self.devices["mmdit"] = {
                "device": device["mmdit"],
+                "driver": utils.iree_device_map(device["mmdit"]),
                "target": iree_target_triple["mmdit"]
            }
            self.devices["vae"] = {
                "device": device["vae"],
+                "driver": utils.iree_device_map(device["vae"]),
                "target": iree_target_triple["vae"]
            }
        else:
-            self.devices["clip"] = device
-            self.devices["mmdit"] = device
-            self.devices["vae"] = device
+            assert isinstance(iree_target_triple, str), "Device and target triple must be both dicts or both strings."
+            self.devices["clip"] = {
+                "device": device,
+                "driver": utils.iree_device_map(device),
+                "target": iree_target_triple
+            }
+            self.devices["mmdit"] = {
+                "device": device,
+                "driver": utils.iree_device_map(device),
+                "target": iree_target_triple
+            }
+            self.devices["vae"] = {
+                "device": device,
+                "driver": utils.iree_device_map(device),
+                "target": iree_target_triple
+            }
        self.iree_target_triple = iree_target_triple
        self.ireec_flags = ireec_flags if ireec_flags else EMPTY_FLAGS
        self.attn_spec = attn_spec
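Note: a rough sketch of the shapes this constructor now accepts, per the two branches above. The driver strings come from utils.iree_device_map at runtime; the device names and target triples below are placeholder examples, not values mandated by this change.

# String form: one device/target shared by clip, mmdit and vae, e.g.
#   self.devices["mmdit"] == {"device": "rocm", "driver": utils.iree_device_map("rocm"), "target": "gfx1100"}
# Dict form: device and iree_target_triple must both be dicts keyed per submodel:
device = {"clip": "cpu", "mmdit": "rocm", "vae": "rocm"}
iree_target_triple = {"clip": "x86_64-linux-gnu", "mmdit": "gfx1100", "vae": "gfx1100"}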
@@ -176,6 +176,9 @@ def is_prepared(self, vmfbs, weights):
                val = None
                default_filepath = None
                continue
+            elif key == "clip":
+                val = "text_encoders"
+                default_filepath = os.path.join(self.pipeline_dir, val + ".vmfb")
            else:
                val = vmfbs[key]
                default_filepath = os.path.join(self.pipeline_dir, key + ".vmfb")
@@ -197,7 +200,7 @@ def is_prepared(self, vmfbs, weights):
            default_name = os.path.join(
                self.external_weights_dir, w_key + "." + self.external_weights
            )
-            if w_key == "text_encoders":
+            if w_key == "clip":
                default_name = os.path.join(
                    self.external_weights_dir, f"sd3_clip_fp16.irpa"
                )
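Note: renaming the submodel key to "clip" does not rename the artifacts on disk; per the two hunks above, the vmfb and weight files keep their existing names. A small illustration with hypothetical directories:

import os

pipeline_dir = "./sd3_pipeline"       # hypothetical path, for illustration only
external_weights_dir = "./weights"    # hypothetical path, for illustration only

vmfb_path = os.path.join(pipeline_dir, "text_encoders" + ".vmfb")     # key "clip" -> text_encoders.vmfb
irpa_path = os.path.join(external_weights_dir, "sd3_clip_fp16.irpa")  # clip weights keep the sd3_clip name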
@@ -287,7 +290,7 @@ def export_submodel(
        if weights_only:
            input_mlir = {
                "vae": None,
-                "text_encoders": None,
+                "clip": None,
                "mmdit": None,
                "scheduler": None,
            }
@@ -366,7 +369,7 @@ def export_submodel(
                )
                del vae_torch
                return vae_vmfb, vae_external_weight_path
-            case "text_encoders":
+            case "clip":
                _, text_encoders_vmfb = sd3_text_encoders.export_text_encoders(
                    self.hf_model_name,
                    None,
@@ -380,7 +383,7 @@ def export_submodel(
                    self.ireec_flags["clip"],
                    exit_on_vmfb=False,
                    pipeline_dir=self.pipeline_dir,
-                    input_mlir=input_mlir["text_encoders"],
+                    input_mlir=input_mlir["clip"],
                    attn_spec=self.attn_spec,
                    output_batchsize=self.batch_size,
                )
@@ -392,7 +395,6 @@ def load_pipeline(
        self,
        vmfbs: dict,
        weights: dict,
-        rt_device: str | dict[str],
        compiled_pipeline: bool = False,
        split_scheduler: bool = True,
        extra_device_args: dict = {},
@@ -401,35 +403,37 @@ def load_pipeline(
            delegate = extra_device_args["npu_delegate_path"]
        else:
            delegate = None
+
        self.runners = {}
        runners = {}
        load_start = time.time()
        runners["pipe"] = vmfbRunner(
-            rt_device,
+            self.devices["mmdit"]["driver"],
            vmfbs["mmdit"],
            weights["mmdit"],
        )
        unet_loaded = time.time()
        print("\n[LOG] MMDiT loaded in ", unet_loaded - load_start, "sec")

        runners["scheduler"] = sd3_schedulers.SharkSchedulerWrapper(
-            rt_device,
+            self.devices["mmdit"]["driver"],
            vmfbs["scheduler"],
        )

        sched_loaded = time.time()
        print("\n[LOG] Scheduler loaded in ", sched_loaded - unet_loaded, "sec")
        runners["vae"] = vmfbRunner(
-            rt_device,
+            self.devices["vae"]["driver"],
            vmfbs["vae"],
-            weights["vae"],
+            weights["vae"],
+            extra_plugin=delegate,
        )
        vae_loaded = time.time()
        print("\n[LOG] VAE Decode loaded in ", vae_loaded - sched_loaded, "sec")
-        runners["text_encoders"] = vmfbRunner(
-            rt_device,
-            vmfbs["text_encoders"],
-            weights["text_encoders"],
+        runners["clip"] = vmfbRunner(
+            self.devices["clip"]["driver"],
+            vmfbs["clip"],
+            weights["clip"],
        )
        clip_loaded = time.time()
        print("\n[LOG] Text Encoders loaded in ", clip_loaded - vae_loaded, "sec")
@@ -500,29 +504,29 @@ def generate_images(
        uncond_input_ids_list = list(uncond_input_ids_dict.values())
        text_encoders_inputs = [
            ireert.asdevicearray(
-                self.runners["text_encoders"].config.device, text_input_ids_list[0]
+                self.runners["clip"].config.device, text_input_ids_list[0]
            ),
            ireert.asdevicearray(
-                self.runners["text_encoders"].config.device, text_input_ids_list[1]
+                self.runners["clip"].config.device, text_input_ids_list[1]
            ),
            ireert.asdevicearray(
-                self.runners["text_encoders"].config.device, text_input_ids_list[2]
+                self.runners["clip"].config.device, text_input_ids_list[2]
            ),
            ireert.asdevicearray(
-                self.runners["text_encoders"].config.device, uncond_input_ids_list[0]
+                self.runners["clip"].config.device, uncond_input_ids_list[0]
            ),
            ireert.asdevicearray(
-                self.runners["text_encoders"].config.device, uncond_input_ids_list[1]
+                self.runners["clip"].config.device, uncond_input_ids_list[1]
            ),
            ireert.asdevicearray(
-                self.runners["text_encoders"].config.device, uncond_input_ids_list[2]
+                self.runners["clip"].config.device, uncond_input_ids_list[2]
            ),
        ]

        # Tokenize prompt and negative prompt.
        encode_prompts_start = time.time()
        prompt_embeds, pooled_prompt_embeds = self.runners[
-            "text_encoders"
+            "clip"
        ].ctx.modules.compiled_text_encoder["encode_tokens"](*text_encoders_inputs)
        encode_prompts_end = time.time()
@@ -690,6 +694,34 @@ def run_diffusers_cpu(
    mlirs = copy.deepcopy(map)
    vmfbs = copy.deepcopy(map)
    weights = copy.deepcopy(map)
+
+    if any(x for x in [args.clip_device, args.mmdit_device, args.vae_device]):
+        assert all(
+            x for x in [args.clip_device, args.mmdit_device, args.vae_device]
+        ), "Please specify device for all submodels or pass --device for all submodels."
+        assert all(
+            x for x in [args.clip_target, args.mmdit_target, args.vae_target]
+        ), "Please specify target triple for all submodels or pass --iree_target_triple for all submodels."
+        args.device = "hybrid"
+        args.iree_target_triple = "_".join([args.clip_target, args.mmdit_target, args.vae_target])
+    else:
+        args.clip_device = args.device
+        args.mmdit_device = args.device
+        args.vae_device = args.device
+        args.clip_target = args.iree_target_triple
+        args.mmdit_target = args.iree_target_triple
+        args.vae_target = args.iree_target_triple
+
+    devices = {
+        "clip": args.clip_device,
+        "mmdit": args.mmdit_device,
+        "vae": args.vae_device,
+    }
+    targets = {
+        "clip": args.clip_target,
+        "mmdit": args.mmdit_target,
+        "vae": args.vae_target,
+    }
    ireec_flags = {
        "clip": args.ireec_flags + args.clip_flags,
        "mmdit": args.ireec_flags + args.unet_flags,
@@ -705,6 +737,7 @@ def run_diffusers_cpu(
        str(args.max_length),
        args.precision,
        args.device,
+        args.iree_target_triple,
    ]
    if args.decomp_attn:
        pipe_id_list.append("decomp")
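Note: when per-submodel devices are passed, args.device collapses to "hybrid" and args.iree_target_triple becomes the underscore-joined submodel targets; both now feed pipe_id_list, so artifacts for different device/target mixes land in distinct pipeline directories. A quick illustration with placeholder targets:

# Placeholder target strings; real values come from --clip_target / --mmdit_target / --vae_target.
clip_target, mmdit_target, vae_target = "x86_64-linux-gnu", "gfx1100", "gfx1100"
device = "hybrid"
iree_target_triple = "_".join([clip_target, mmdit_target, vae_target])
print(device, iree_target_triple)  # hybrid x86_64-linux-gnu_gfx1100_gfx1100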
@@ -730,8 +763,8 @@ def run_diffusers_cpu(
        args.max_length,
        args.batch_size,
        args.num_inference_steps,
-        args.device,
-        args.iree_target_triple,
+        devices,
+        targets,
        ireec_flags,
        args.attn_spec,
        args.decomp_attn,
@@ -747,7 +780,7 @@ def run_diffusers_cpu(
    vmfbs.pop("scheduler")
    weights.pop("scheduler")
    sd3_pipe.load_pipeline(
-        vmfbs, weights, args.rt_device, args.compiled_pipeline, args.split_scheduler
+        vmfbs, weights, args.compiled_pipeline, args.split_scheduler
    )
    sd3_pipe.generate_images(
        args.prompt,