Commit 878449e

Fix compiled pipeline
1 parent 02705a9 commit 878449e

6 files changed: 195 additions & 59 deletions

models/turbine_models/custom_models/pipeline_base.py

Lines changed: 23 additions & 5 deletions
@@ -437,6 +437,7 @@ def prepare_all(
         vmfbs: dict = {},
         weights: dict = {},
         interactive: bool = False,
+        num_steps: int = 20,
     ):
         ready = self.is_prepared(vmfbs, weights)
         match ready:
@@ -463,7 +464,7 @@ def prepare_all(
                 if not self.map[submodel].get("weights") and self.map[submodel][
                     "export_args"
                 ].get("external_weights"):
-                    self.export_submodel(submodel, weights_only=True)
+                    self.export_submodel(submodel, weights_only=True, num_steps=num_steps)
         return self.prepare_all(mlirs, vmfbs, weights, interactive)

     def is_prepared(self, vmfbs, weights):
@@ -581,6 +582,7 @@ def export_submodel(
         submodel: str,
         input_mlir: str = None,
         weights_only: bool = False,
+        num_steps: int = 20,
     ):
         if not os.path.exists(self.pipeline_dir):
             os.makedirs(self.pipeline_dir)
@@ -670,7 +672,9 @@ def export_submodel(
                     self.map[submodel]["export_args"]["precision"],
                     self.map[submodel]["export_args"]["batch_size"],
                     self.map[submodel]["export_args"]["max_length"],
-                    "tokens_to_image",
+                    "produce_img_split",
+                    unet_module_name=self.map["unet"]["module_name"],
+                    num_steps=num_steps,
                 )
                 dims = [
                     self.map[submodel]["export_args"]["width"],
@@ -699,8 +703,8 @@ def export_submodel(
                     return_path=True,
                     mlir_source="str",
                 )
-                self.map[submodel]["vmfb"] = vmfb_path
-                self.map[submodel]["weights"] = None
+                self.map[submodel]["vmfb"] = [vmfb_path]
+                self.map[submodel]["weights"] = []
             case _:
                 export_args = self.map[submodel].get("export_args", {})
                 if weights_only:
@@ -721,10 +725,24 @@ def export_submodel(

     # LOAD
     def load_map(self):
-        for submodel in self.map.keys():
+        # Make sure fullpipeline is imported last
+        submodels = list(self.map.keys() - {"fullpipeline"})
+        submodels += ["fullpipeline"] if "fullpipeline" in self.map.keys() else []
+        for submodel in submodels:
             if not self.map[submodel]["load"]:
                 self.printer.print(f"Skipping load for {submodel}")
                 continue
+            elif self.map[submodel].get("wraps"):
+                vmfbs = []
+                weights = []
+                for wrapped in self.map[submodel]["wraps"]:
+                    vmfbs.append(self.map[wrapped]["vmfb"])
+                    if "weights" in self.map[wrapped]:
+                        weights.append(self.map[wrapped]["weights"])
+                self.map[submodel]["vmfb"] = vmfbs + self.map[submodel]["vmfb"]
+                self.map[submodel]["weights"] = weights + self.map[submodel]["weights"]
+
+            print(f"Loading {submodel}")
             self.load_submodel(submodel)

     def load_submodel(self, submodel):
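
For orientation, a minimal sketch of how this reworked prepare/export/load flow is driven. The pipeline class and its construction are placeholders; only prepare_all(num_steps=...), load_map(), and the "wraps" aggregation behavior come from the diff above:

    # Hypothetical driver code, not part of this commit.
    pipe = MyTurbinePipeline(...)       # some subclass of the pipeline base class in this file
    pipe.prepare_all(num_steps=30)      # num_steps is now forwarded into export_submodel()
    pipe.load_map()                     # plain submodels load first; a wrapping module such as
                                        # "fullpipeline" loads last, with the wrapped modules'
                                        # vmfbs and weights prepended to its own lists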

models/turbine_models/custom_models/sd_inference/sd_pipeline.py

Lines changed: 113 additions & 43 deletions
@@ -118,23 +118,11 @@
             "decomp_attn": None,
         },
     },
-    "unetloop": {
-        "module_name": "sdxl_compiled_pipeline",
-        "load": False,
-        "keywords": ["unetloop"],
-        "wraps": ["unet", "scheduler"],
-        "export_args": {
-            "batch_size": 1,
-            "height": 1024,
-            "width": 1024,
-            "max_length": 64,
-        },
-    },
     "fullpipeline": {
         "module_name": "sdxl_compiled_pipeline",
-        "load": False,
+        "load": True,
         "keywords": ["fullpipeline"],
-        "wraps": ["text_encoder", "unet", "scheduler", "vae"],
+        "wraps": ["unet", "scheduler", "vae"],
         "export_args": {
             "batch_size": 1,
             "height": 1024,
@@ -234,6 +222,7 @@ def __init__(
         benchmark: bool | dict[bool] = False,
         verbose: bool = False,
         batch_prompts: bool = False,
+        compiled_pipeline: bool = False,
     ):
         common_export_args = {
             "hf_model_name": None,
@@ -312,6 +301,7 @@ def __init__(
         self.scheduler = None

         self.split_scheduler = True
+        self.compiled_pipeline = compiled_pipeline

         self.base_model_name = (
             hf_model_name
@@ -322,11 +312,6 @@ def __init__(
         self.is_sdxl = "xl" in self.base_model_name.lower()
         self.is_sd3 = "stable-diffusion-3" in self.base_model_name
         if self.is_sdxl:
-            if self.split_scheduler:
-                if self.map.get("unetloop"):
-                    self.map.pop("unetloop")
-                if self.map.get("fullpipeline"):
-                    self.map.pop("fullpipeline")
             self.tokenizers = [
                 CLIPTokenizer.from_pretrained(
                     self.base_model_name, subfolder="tokenizer"
@@ -340,6 +325,20 @@ def __init__(
             self.scheduler_device = self.map["unet"]["device"]
             self.scheduler_driver = self.map["unet"]["driver"]
             self.scheduler_target = self.map["unet"]["target"]
+            if not self.compiled_pipeline:
+                if self.map.get("unetloop"):
+                    self.map.pop("unetloop")
+                if self.map.get("fullpipeline"):
+                    self.map.pop("fullpipeline")
+            elif self.compiled_pipeline:
+                self.map["unet"]["load"] = False
+                self.map["vae"]["load"] = False
+                self.load_scheduler(
+                    scheduler_id,
+                    num_inference_steps,
+                )
+                self.map["scheduler"]["runner"].unload()
+                self.map["scheduler"]["load"] = False
         elif not self.is_sd3:
             self.tokenizer = CLIPTokenizer.from_pretrained(
                 self.base_model_name, subfolder="tokenizer"
@@ -381,10 +380,6 @@ def load_scheduler(
         scheduler_id: str,
         steps: int = 30,
     ):
-        if self.is_sd3:
-            scheduler_device = self.mmdit.device
-        else:
-            scheduler_device = self.unet.device
         if not self.cpu_scheduling:
             self.map["scheduler"] = {
                 "module_name": "compiled_scheduler",
@@ -430,7 +425,11 @@ def load_scheduler(
             except:
                 print("JIT export of scheduler failed. Loading CPU scheduler.")
                 self.cpu_scheduling = True
-        if self.cpu_scheduling:
+        elif self.cpu_scheduling:
+            if self.is_sd3:
+                scheduler_device = self.mmdit.device
+            else:
+                scheduler_device = self.unet.device
             scheduler = schedulers.get_scheduler(self.base_model_name, scheduler_id)
             self.scheduler = schedulers.SharkSchedulerCPUWrapper(
                 scheduler,
@@ -466,13 +465,10 @@ def encode_prompts_sdxl(self, prompt, negative_prompt):
             text_input_ids_list += text_inputs.input_ids.unsqueeze(0)
             uncond_input_ids_list += uncond_input.input_ids.unsqueeze(0)

-        if self.compiled_pipeline:
-            return text_input_ids_list, uncond_input_ids_list
-        else:
-            prompt_embeds, add_text_embeds = self.text_encoder(
-                "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list]
-            )
-            return prompt_embeds, add_text_embeds
+        prompt_embeds, add_text_embeds = self.text_encoder(
+            "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list]
+        )
+        return prompt_embeds, add_text_embeds

     def prepare_latents(
         self,
@@ -615,6 +611,75 @@ def _produce_latents_sdxl(
             latents = self.scheduler("run_step", [noise_pred, t, latents])
         return latents

+    def produce_images_compiled(
+        self,
+        sample,
+        prompt_embeds,
+        text_embeds,
+        guidance_scale,
+    ):
+        pipe_inputs = [
+            sample,
+            prompt_embeds,
+            text_embeds,
+            torch.as_tensor([guidance_scale], dtype=sample.dtype),
+        ]
+        #image = self.compiled_pipeline("produce_img_latents", pipe_inputs)
+        image = self.map["fullpipeline"]["runner"]("produce_image_latents", pipe_inputs)
+        return image
+
+    def prepare_sampling_inputs(
+        self,
+        prompt: str,
+        negative_prompt: str = "",
+        steps: int = 30,
+        batch_count: int = 1,
+        guidance_scale: float = 7.5,
+        seed: float = -1,
+        cpu_scheduling: bool = True,
+        scheduler_id: str = "EulerDiscrete",
+        return_imgs: bool = False,
+    ):
+        needs_new_scheduler = (
+            (steps and steps != self.num_inference_steps)
+            or (cpu_scheduling != self.cpu_scheduling)
+            and self.split_scheduler
+        )
+        if not self.scheduler and not self.compiled_pipeline:
+            needs_new_scheduler = True
+
+        if guidance_scale == 0:
+            negative_prompt = prompt
+            prompt = ""
+
+        self.cpu_scheduling = cpu_scheduling
+        if steps and needs_new_scheduler:
+            self.num_inference_steps = steps
+            self.load_scheduler(scheduler_id, steps)
+
+        pipe_start = time.time()
+        numpy_images = []
+
+        samples = self.get_rand_latents(seed, batch_count)
+
+        # Tokenize prompt and negative prompt.
+        if self.is_sdxl:
+            prompt_embeds, negative_embeds = self.encode_prompts_sdxl(
+                prompt, negative_prompt
+            )
+        else:
+            prompt_embeds, negative_embeds = encode_prompt(
+                self, prompt, negative_prompt
+            )
+        produce_latents_input = [
+            samples[0],
+            prompt_embeds,
+            negative_embeds,
+            steps,
+            guidance_scale,
+        ]
+        return produce_latents_input
+
     def generate_images(
         self,
         prompt: str,
@@ -660,18 +725,21 @@ def generate_images(
         )

         for i in range(batch_count):
-            produce_latents_input = [
-                samples[i],
-                prompt_embeds,
-                negative_embeds,
-                steps,
-                guidance_scale,
-            ]
-            if self.is_sdxl:
-                latents = self._produce_latents_sdxl(*produce_latents_input)
+            if self.compiled_pipeline:
+                image = self.produce_images_compiled(samples[i], prompt_embeds, negative_embeds, guidance_scale).to_host()
             else:
-                latents = self._produce_latents_sd(*produce_latents_input)
-            image = self.vae("decode", [latents])
+                produce_latents_input = [
+                    samples[i],
+                    prompt_embeds,
+                    negative_embeds,
+                    steps,
+                    guidance_scale,
+                ]
+                if self.is_sdxl:
+                    latents = self._produce_latents_sdxl(*produce_latents_input)
+                else:
+                    latents = self._produce_latents_sd(*produce_latents_input)
+                image = self.vae("decode", [latents])
             numpy_images.append(image)
         pipe_end = time.time()

@@ -757,8 +825,10 @@ def numpy_to_pil_image(images):
         args.use_i8_punet,
         benchmark,
         args.verbose,
+        False,
+        args.compiled_pipeline,
     )
-    sd_pipe.prepare_all()
+    sd_pipe.prepare_all(num_steps=args.num_inference_steps)
    sd_pipe.load_map()
    sd_pipe.generate_images(
        args.prompt,
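
A hedged usage sketch of the compiled-pipeline path enabled above. Only compiled_pipeline=True, prepare_all(num_steps=...), load_map(), and generate_images() appear in this diff; the class name, the remaining constructor arguments, and the prompt are assumptions for illustration:

    # Assumed invocation, mirroring the __main__ changes above.
    sd_pipe = SharkSDPipeline(..., compiled_pipeline=True)  # class name assumed from the file's entry point
    sd_pipe.prepare_all(num_steps=30)   # bakes the step count into the exported "fullpipeline" module
    sd_pipe.load_map()                  # unet/vae/scheduler runners stay unloaded; fullpipeline is loaded
    sd_pipe.generate_images("a photo of a corgi")  # dispatches to produce_images_compiled() per batch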

models/turbine_models/custom_models/sd_inference/utils.py

Lines changed: 3 additions & 3 deletions
@@ -36,13 +36,13 @@
     ],
     "unet": [
         "--iree-flow-enable-aggressive-fusion",
-        "--iree-flow-enable-fuse-horizontal-contractions=true",
+        "--iree-global-opt-enable-fuse-horizontal-contractions=true",
         "--iree-opt-aggressively-propagate-transposes=true",
         "--iree-codegen-llvmgpu-use-vector-distribution=true",
     ],
     "clip": [
         "--iree-flow-enable-aggressive-fusion",
-        "--iree-flow-enable-fuse-horizontal-contractions=true",
+        "--iree-global-opt-enable-fuse-horizontal-contractions=true",
         "--iree-opt-aggressively-propagate-transposes=true",
     ],
     "vae": [
@@ -61,7 +61,7 @@
         "--iree-opt-const-eval=false",
         "--iree-opt-aggressively-propagate-transposes=true",
         "--iree-flow-enable-aggressive-fusion",
-        "--iree-flow-enable-fuse-horizontal-contractions=true",
+        "--iree-global-opt-enable-fuse-horizontal-contractions=true",
         "--iree-codegen-gpu-native-math-precision=true",
         "--iree-codegen-llvmgpu-use-vector-distribution=true",
         "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
