Skip to content

Commit 7754609

Browse files
committed
Fixes for multi-device (SD3)
1 parent b793686 commit 7754609

File tree

3 files changed: +144, -45 lines changed

models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,48 @@ def is_valid_file(arg):
177177
help="Do one-shot inference from tokens to image in a shrink-wrapped pipeline binary.",
178178
)
179179

180+
# Per-submodel device / IREE-target overrides for multi-device ("hybrid") runs.
# Each flag defaults to None, in which case the pipeline falls back to the
# global --device / --iree_target_triple values.
# NOTE(review): loop replaces six copy-pasted add_argument calls and normalizes
# the help-text capitalization (original mixed "CLIP compilation" with
# "mmdit compilation" / "vae compilation").
for _submodel, _label in [("clip", "CLIP"), ("mmdit", "MMDiT"), ("vae", "VAE")]:
    p.add_argument(
        f"--{_submodel}_device",
        default=None,
        type=str,
        help=f"Device to run {_label} on. If None, defaults to the device specified in args.device.",
    )
    p.add_argument(
        f"--{_submodel}_target",
        default=None,
        type=str,
        help=f"IREE target for {_label} compilation. If None, defaults to the target specified by --iree_target_triple.",
    )
221+
180222
##############################################################################
181223
# SD3 Modelling Options
182224
# These options are used to control model defining parameters for SD3.

models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py

Lines changed: 78 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -25,27 +25,11 @@
2525
import copy
2626
from datetime import datetime as dt
2727

28-
device_list = [
29-
"cpu",
30-
"vulkan",
31-
"cuda",
32-
"rocm",
33-
]
34-
35-
rt_device_list = [
36-
"local-task",
37-
"local-sync",
38-
"vulkan",
39-
"cuda",
40-
"rocm",
41-
"hip",
42-
]
43-
4428
# Skeleton mapping of SD3 submodel names to their compiled artifacts
# (vmfb paths / weights); values are filled in by the pipeline as each
# submodel is prepared. Keys must match the submodel names used throughout
# the pipeline ("clip", "mmdit", "scheduler", "vae").
empty_pipe_dict = {
    "clip": None,
    "mmdit": None,
    "scheduler": None,
    "vae": None,
}
5034

5135
EMPTY_FLAGS = {
@@ -90,24 +74,40 @@ def __init__(
9074
self.batch_size = batch_size
9175
self.num_inference_steps = num_inference_steps
9276
self.devices = {}
93-
if isinstance(self.device, dict):
77+
if isinstance(device, dict):
9478
assert isinstance(iree_target_triple, dict), "Device and target triple must be both dicts or both strings."
9579
self.devices["clip"] = {
9680
"device": device["clip"],
81+
"driver": utils.iree_device_map(device["clip"]),
9782
"target": iree_target_triple["clip"]
9883
}
9984
self.devices["mmdit"] = {
10085
"device": device["mmdit"],
86+
"driver": utils.iree_device_map(device["mmdit"]),
10187
"target": iree_target_triple["mmdit"]
10288
}
10389
self.devices["vae"] = {
10490
"device": device["vae"],
91+
"driver": utils.iree_device_map(device["vae"]),
10592
"target": iree_target_triple["vae"]
10693
}
10794
else:
108-
self.devices["clip"] = device
109-
self.devices["mmdit"] = device
110-
self.devices["vae"] = device
95+
assert isinstance(iree_target_triple, str), "Device and target triple must be both dicts or both strings."
96+
self.devices["clip"] = {
97+
"device": device,
98+
"driver": utils.iree_device_map(device),
99+
"target": iree_target_triple
100+
}
101+
self.devices["mmdit"] = {
102+
"device": device,
103+
"driver": utils.iree_device_map(device),
104+
"target": iree_target_triple
105+
}
106+
self.devices["vae"] = {
107+
"device": device,
108+
"driver": utils.iree_device_map(device),
109+
"target": iree_target_triple
110+
}
111111
self.iree_target_triple = iree_target_triple
112112
self.ireec_flags = ireec_flags if ireec_flags else EMPTY_FLAGS
113113
self.attn_spec = attn_spec
@@ -176,6 +176,9 @@ def is_prepared(self, vmfbs, weights):
176176
val = None
177177
default_filepath = None
178178
continue
179+
elif key == "clip":
180+
val = "text_encoders"
181+
default_filepath = os.path.join(self.pipeline_dir, val + ".vmfb")
179182
else:
180183
val = vmfbs[key]
181184
default_filepath = os.path.join(self.pipeline_dir, key + ".vmfb")
@@ -197,7 +200,7 @@ def is_prepared(self, vmfbs, weights):
197200
default_name = os.path.join(
198201
self.external_weights_dir, w_key + "." + self.external_weights
199202
)
200-
if w_key == "text_encoders":
203+
if w_key == "clip":
201204
default_name = os.path.join(
202205
self.external_weights_dir, f"sd3_clip_fp16.irpa"
203206
)
@@ -287,7 +290,7 @@ def export_submodel(
287290
if weights_only:
288291
input_mlir = {
289292
"vae": None,
290-
"text_encoders": None,
293+
"clip": None,
291294
"mmdit": None,
292295
"scheduler": None,
293296
}
@@ -366,7 +369,7 @@ def export_submodel(
366369
)
367370
del vae_torch
368371
return vae_vmfb, vae_external_weight_path
369-
case "text_encoders":
372+
case "clip":
370373
_, text_encoders_vmfb = sd3_text_encoders.export_text_encoders(
371374
self.hf_model_name,
372375
None,
@@ -380,7 +383,7 @@ def export_submodel(
380383
self.ireec_flags["clip"],
381384
exit_on_vmfb=False,
382385
pipeline_dir=self.pipeline_dir,
383-
input_mlir=input_mlir["text_encoders"],
386+
input_mlir=input_mlir["clip"],
384387
attn_spec=self.attn_spec,
385388
output_batchsize=self.batch_size,
386389
)
@@ -392,7 +395,6 @@ def load_pipeline(
392395
self,
393396
vmfbs: dict,
394397
weights: dict,
395-
rt_device: str | dict[str],
396398
compiled_pipeline: bool = False,
397399
split_scheduler: bool = True,
398400
extra_device_args: dict = {},
@@ -401,35 +403,37 @@ def load_pipeline(
401403
delegate = extra_device_args["npu_delegate_path"]
402404
else:
403405
delegate = None
406+
404407
self.runners = {}
405408
runners = {}
406409
load_start = time.time()
407410
runners["pipe"] = vmfbRunner(
408-
rt_device,
411+
self.devices["mmdit"]["driver"],
409412
vmfbs["mmdit"],
410413
weights["mmdit"],
411414
)
412415
unet_loaded = time.time()
413416
print("\n[LOG] MMDiT loaded in ", unet_loaded - load_start, "sec")
414417

415418
runners["scheduler"] = sd3_schedulers.SharkSchedulerWrapper(
416-
rt_device,
419+
self.devices["mmdit"]["driver"],
417420
vmfbs["scheduler"],
418421
)
419422

420423
sched_loaded = time.time()
421424
print("\n[LOG] Scheduler loaded in ", sched_loaded - unet_loaded, "sec")
422425
runners["vae"] = vmfbRunner(
423-
rt_device,
426+
self.devices["vae"]["driver"],
424427
vmfbs["vae"],
425-
weights["vae"],
428+
weights["vae"],
429+
extra_plugin=delegate,
426430
)
427431
vae_loaded = time.time()
428432
print("\n[LOG] VAE Decode loaded in ", vae_loaded - sched_loaded, "sec")
429-
runners["text_encoders"] = vmfbRunner(
430-
rt_device,
431-
vmfbs["text_encoders"],
432-
weights["text_encoders"],
433+
runners["clip"] = vmfbRunner(
434+
self.devices["clip"]["driver"],
435+
vmfbs["clip"],
436+
weights["clip"],
433437
)
434438
clip_loaded = time.time()
435439
print("\n[LOG] Text Encoders loaded in ", clip_loaded - vae_loaded, "sec")
@@ -500,29 +504,29 @@ def generate_images(
500504
uncond_input_ids_list = list(uncond_input_ids_dict.values())
501505
text_encoders_inputs = [
502506
ireert.asdevicearray(
503-
self.runners["text_encoders"].config.device, text_input_ids_list[0]
507+
self.runners["clip"].config.device, text_input_ids_list[0]
504508
),
505509
ireert.asdevicearray(
506-
self.runners["text_encoders"].config.device, text_input_ids_list[1]
510+
self.runners["clip"].config.device, text_input_ids_list[1]
507511
),
508512
ireert.asdevicearray(
509-
self.runners["text_encoders"].config.device, text_input_ids_list[2]
513+
self.runners["clip"].config.device, text_input_ids_list[2]
510514
),
511515
ireert.asdevicearray(
512-
self.runners["text_encoders"].config.device, uncond_input_ids_list[0]
516+
self.runners["clip"].config.device, uncond_input_ids_list[0]
513517
),
514518
ireert.asdevicearray(
515-
self.runners["text_encoders"].config.device, uncond_input_ids_list[1]
519+
self.runners["clip"].config.device, uncond_input_ids_list[1]
516520
),
517521
ireert.asdevicearray(
518-
self.runners["text_encoders"].config.device, uncond_input_ids_list[2]
522+
self.runners["clip"].config.device, uncond_input_ids_list[2]
519523
),
520524
]
521525

522526
# Tokenize prompt and negative prompt.
523527
encode_prompts_start = time.time()
524528
prompt_embeds, pooled_prompt_embeds = self.runners[
525-
"text_encoders"
529+
"clip"
526530
].ctx.modules.compiled_text_encoder["encode_tokens"](*text_encoders_inputs)
527531
encode_prompts_end = time.time()
528532

@@ -690,6 +694,34 @@ def run_diffusers_cpu(
690694
mlirs = copy.deepcopy(map)
691695
vmfbs = copy.deepcopy(map)
692696
weights = copy.deepcopy(map)
697+
698+
if any(x for x in [args.clip_device, args.mmdit_device, args.vae_device]):
699+
assert all(
700+
x for x in [args.clip_device, args.mmdit_device, args.vae_device]
701+
), "Please specify device for all submodels or pass --device for all submodels."
702+
assert all(
703+
x for x in [args.clip_target, args.mmdit_target, args.vae_target]
704+
), "Please specify target triple for all submodels or pass --iree_target_triple for all submodels."
705+
args.device = "hybrid"
706+
args.iree_target_triple = "_".join([args.clip_target, args.mmdit_target, args.vae_target])
707+
else:
708+
args.clip_device = args.device
709+
args.mmdit_device = args.device
710+
args.vae_device = args.device
711+
args.clip_target = args.iree_target_triple
712+
args.mmdit_target = args.iree_target_triple
713+
args.vae_target = args.iree_target_triple
714+
715+
devices = {
716+
"clip": args.clip_device,
717+
"mmdit": args.mmdit_device,
718+
"vae": args.vae_device,
719+
}
720+
targets = {
721+
"clip": args.clip_target,
722+
"mmdit": args.mmdit_target,
723+
"vae": args.vae_target,
724+
}
693725
ireec_flags = {
694726
"clip": args.ireec_flags + args.clip_flags,
695727
"mmdit": args.ireec_flags + args.unet_flags,
@@ -705,6 +737,7 @@ def run_diffusers_cpu(
705737
str(args.max_length),
706738
args.precision,
707739
args.device,
740+
args.iree_target_triple,
708741
]
709742
if args.decomp_attn:
710743
pipe_id_list.append("decomp")
@@ -730,8 +763,8 @@ def run_diffusers_cpu(
730763
args.max_length,
731764
args.batch_size,
732765
args.num_inference_steps,
733-
args.device,
734-
args.iree_target_triple,
766+
devices,
767+
targets,
735768
ireec_flags,
736769
args.attn_spec,
737770
args.decomp_attn,
@@ -747,7 +780,7 @@ def run_diffusers_cpu(
747780
vmfbs.pop("scheduler")
748781
weights.pop("scheduler")
749782
sd3_pipe.load_pipeline(
750-
vmfbs, weights, args.rt_device, args.compiled_pipeline, args.split_scheduler
783+
vmfbs, weights, args.compiled_pipeline, args.split_scheduler
751784
)
752785
sd3_pipe.generate_images(
753786
args.prompt,

models/turbine_models/custom_models/sd_inference/utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,17 @@
1212
# DPMSolverSDEScheduler,
1313
)
1414

15+
_IREE_DEVICE_MAP = {
16+
"cpu": "local-task",
17+
"cpu-task": "local-task",
18+
"cpu-sync": "local-sync",
19+
"cuda": "cuda",
20+
"vulkan": "vulkan",
21+
"metal": "metal",
22+
"rocm": "rocm",
23+
"hip": "hip",
24+
"intel-gpu": "level_zero",
25+
}
1526
# If flags are verified to work on a specific model and improve performance without regressing numerics, add them to this dictionary. If you are working with bleeding edge flags, please add them manually with the --ireec_flags argument.
1627
MI_flags = {
1728
"all": [
@@ -81,6 +92,19 @@
8192
],
8293
}
8394

95+
def iree_device_map(device, device_map=None):
    """Translate a turbine device id into an IREE runtime driver URI.

    Args:
        device: Device identifier such as ``"cpu"``, ``"cuda://1"`` or
            ``"rocm://0"``. An optional ``"://<path>"`` suffix selects a
            specific device instance.
        device_map: Optional mapping of device names to IREE driver names;
            defaults to the module-level ``_IREE_DEVICE_MAP``. Added as a
            keyword with a default so existing single-argument callers are
            unaffected.

    Returns:
        The IREE driver name, with any ``"://<path>"`` suffix preserved —
        except for rocm, where only the bare driver name is returned.
    """
    if device_map is None:
        device_map = _IREE_DEVICE_MAP
    uri_parts = device.split("://", 2)
    # Unknown device names pass through unchanged so valid driver names
    # (or drivers not listed in the map) still work.
    iree_driver = device_map.get(uri_parts[0], uri_parts[0])
    if len(uri_parts) == 1:
        return iree_driver
    elif "rocm" in uri_parts:
        # NOTE(review): rocm deliberately drops the "://<path>" suffix here —
        # presumably the rocm driver does not accept a device URI; confirm
        # against the IREE runtime before changing.
        return "rocm"
    else:
        return f"{iree_driver}://{uri_parts[1]}"
84108

85109
def compile_to_vmfb(
86110
module_str,

0 commit comments

Comments (0)