
Commit c1e9195

compiled_pipeline general support and split inference methods
1 parent 02705a9 commit c1e9195

3 files changed: +173 -43 lines changed

models/turbine_models/custom_models/pipeline_base.py

Lines changed: 9 additions & 3 deletions
@@ -670,7 +670,8 @@ def export_submodel(
                     self.map[submodel]["export_args"]["precision"],
                     self.map[submodel]["export_args"]["batch_size"],
                     self.map[submodel]["export_args"]["max_length"],
-                    "tokens_to_image",
+                    "produce_img_split",
+                    unet_module_name=self.map["unet"]["module_name"],
                 )
                 dims = [
                     self.map[submodel]["export_args"]["width"],
@@ -699,8 +700,8 @@ def export_submodel(
                     return_path=True,
                     mlir_source="str",
                 )
-                self.map[submodel]["vmfb"] = vmfb_path
-                self.map[submodel]["weights"] = None
+                self.map[submodel]["vmfb"] = [vmfb_path]
+                self.map[submodel]["weights"] = []
             case _:
                 export_args = self.map[submodel].get("export_args", {})
                 if weights_only:
@@ -725,6 +726,11 @@ def load_map(self):
             if not self.map[submodel]["load"]:
                 self.printer.print(f"Skipping load for {submodel}")
                 continue
+            elif self.map[submodel].get("wraps"):
+                for wrapped in self.map[submodel]["wraps"]:
+                    self.map[submodel]["vmfb"].append(self.map[wrapped]["vmfb"])
+                    self.map[submodel]["weights"].append(self.map[wrapped]["weights"])
+
             self.load_submodel(submodel)

     def load_submodel(self, submodel):
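
The load_map() hunk above is the heart of the general wrapper support: a submodel that declares "wraps" now collects the vmfb and weight artifacts of every submodel it wraps before loading, which is why export_submodel seeds "vmfb" as a one-element list. A minimal sketch of that aggregation (not code from the commit; the file names are hypothetical):

# Hypothetical map state; the shapes follow the diff above, the values are made up.
model_map = {
    "unet": {"vmfb": "unet.vmfb", "weights": "unet.irpa"},
    "scheduler": {"vmfb": "sched.vmfb", "weights": None},
    "fullpipeline": {
        "wraps": ["unet", "scheduler"],
        "vmfb": ["pipe.vmfb"],  # seeded as a list by export_submodel
        "weights": [],
    },
}

for wrapped in model_map["fullpipeline"]["wraps"]:
    model_map["fullpipeline"]["vmfb"].append(model_map[wrapped]["vmfb"])
    model_map["fullpipeline"]["weights"].append(model_map[wrapped]["weights"])

assert model_map["fullpipeline"]["vmfb"] == ["pipe.vmfb", "unet.vmfb", "sched.vmfb"]

The wrapper's runner can then be constructed from its own module plus everything it wraps.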

models/turbine_models/custom_models/sd_inference/sd_pipeline.py

Lines changed: 110 additions & 35 deletions
@@ -118,23 +118,11 @@
             "decomp_attn": None,
         },
     },
-    "unetloop": {
-        "module_name": "sdxl_compiled_pipeline",
-        "load": False,
-        "keywords": ["unetloop"],
-        "wraps": ["unet", "scheduler"],
-        "export_args": {
-            "batch_size": 1,
-            "height": 1024,
-            "width": 1024,
-            "max_length": 64,
-        },
-    },
     "fullpipeline": {
         "module_name": "sdxl_compiled_pipeline",
-        "load": False,
+        "load": True,
         "keywords": ["fullpipeline"],
-        "wraps": ["text_encoder", "unet", "scheduler", "vae"],
+        "wraps": ["unet", "scheduler", "vae"],
         "export_args": {
             "batch_size": 1,
             "height": 1024,
@@ -234,6 +222,7 @@ def __init__(
         benchmark: bool | dict[bool] = False,
         verbose: bool = False,
         batch_prompts: bool = False,
+        compiled_pipeline: bool = False,
     ):
         common_export_args = {
             "hf_model_name": None,
@@ -312,6 +301,7 @@ def __init__(
         self.scheduler = None

         self.split_scheduler = True
+        self.compiled_pipeline = compiled_pipeline

         self.base_model_name = (
             hf_model_name
@@ -322,11 +312,6 @@ def __init__(
         self.is_sdxl = "xl" in self.base_model_name.lower()
         self.is_sd3 = "stable-diffusion-3" in self.base_model_name
         if self.is_sdxl:
-            if self.split_scheduler:
-                if self.map.get("unetloop"):
-                    self.map.pop("unetloop")
-                if self.map.get("fullpipeline"):
-                    self.map.pop("fullpipeline")
             self.tokenizers = [
                 CLIPTokenizer.from_pretrained(
                     self.base_model_name, subfolder="tokenizer"
@@ -340,6 +325,20 @@ def __init__(
             self.scheduler_device = self.map["unet"]["device"]
             self.scheduler_driver = self.map["unet"]["driver"]
             self.scheduler_target = self.map["unet"]["target"]
+            if not self.compiled_pipeline:
+                if self.map.get("unetloop"):
+                    self.map.pop("unetloop")
+                if self.map.get("fullpipeline"):
+                    self.map.pop("fullpipeline")
+            elif self.compiled_pipeline:
+                self.map["unet"]["load"] = False
+                self.map["vae"]["load"] = False
+                self.load_scheduler(
+                    scheduler_id,
+                    num_inference_steps,
+                )
+                self.map["scheduler"]["runner"].unload()
+                self.map["scheduler"]["load"] = False
         elif not self.is_sd3:
             self.tokenizer = CLIPTokenizer.from_pretrained(
                 self.base_model_name, subfolder="tokenizer"
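
A note on the branch above: with compiled_pipeline=False the wrapper entries are dropped and each submodel loads on its own, while with compiled_pipeline=True the unet and vae are still exported but excluded from direct loading, and load_scheduler runs once so a scheduler vmfb exists for the wrapper to link against before its runner is unloaded again. The per-submodel load flags this should leave behind (an assumption inferred from the diff, not asserted by the commit):

# Assumed SDXL map state after __init__ with compiled_pipeline=True:
# everything is still exported, but only the wrapper loads directly.
expected_load_flags = {
    "unet": False,         # wrapped; vmfb still produced
    "vae": False,          # wrapped; vmfb still produced
    "scheduler": False,    # exported via load_scheduler, runner then unloaded
    "fullpipeline": True,  # the single compiled module wrapping the rest
}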
@@ -381,10 +380,6 @@ def load_scheduler(
         scheduler_id: str,
         steps: int = 30,
     ):
-        if self.is_sd3:
-            scheduler_device = self.mmdit.device
-        else:
-            scheduler_device = self.unet.device
         if not self.cpu_scheduling:
             self.map["scheduler"] = {
                 "module_name": "compiled_scheduler",
@@ -430,7 +425,11 @@ def load_scheduler(
             except:
                 print("JIT export of scheduler failed. Loading CPU scheduler.")
                 self.cpu_scheduling = True
-        if self.cpu_scheduling:
+        elif self.cpu_scheduling:
+            if self.is_sd3:
+                scheduler_device = self.mmdit.device
+            else:
+                scheduler_device = self.unet.device
             scheduler = schedulers.get_scheduler(self.base_model_name, scheduler_id)
             self.scheduler = schedulers.SharkSchedulerCPUWrapper(
                 scheduler,
@@ -615,6 +614,74 @@ def _produce_latents_sdxl(
             latents = self.scheduler("run_step", [noise_pred, t, latents])
         return latents

+    def produce_images_compiled(
+        self,
+        sample,
+        prompt_embeds,
+        text_embeds,
+        guidance_scale,
+    ):
+        pipe_inputs = [
+            sample,
+            prompt_embeds,
+            text_embeds,
+            guidance_scale,
+        ]
+        image = self.compiled_pipeline("produce_img_latents", pipe_inputs)
+        return image
+
+    def prepare_sampling_inputs(
+        self,
+        prompt: str,
+        negative_prompt: str = "",
+        steps: int = 30,
+        batch_count: int = 1,
+        guidance_scale: float = 7.5,
+        seed: float = -1,
+        cpu_scheduling: bool = True,
+        scheduler_id: str = "EulerDiscrete",
+        return_imgs: bool = False,
+    ):
+        needs_new_scheduler = (
+            (steps and steps != self.num_inference_steps)
+            or (cpu_scheduling != self.cpu_scheduling)
+            and self.split_scheduler
+        )
+        if not self.scheduler and not self.compiled_pipeline:
+            needs_new_scheduler = True
+
+        if guidance_scale == 0:
+            negative_prompt = prompt
+            prompt = ""
+
+        self.cpu_scheduling = cpu_scheduling
+        if steps and needs_new_scheduler:
+            self.num_inference_steps = steps
+            self.load_scheduler(scheduler_id, steps)
+
+        pipe_start = time.time()
+        numpy_images = []
+
+        samples = self.get_rand_latents(seed, batch_count)
+
+        # Tokenize prompt and negative prompt.
+        if self.is_sdxl:
+            prompt_embeds, negative_embeds = self.encode_prompts_sdxl(
+                prompt, negative_prompt
+            )
+        else:
+            prompt_embeds, negative_embeds = encode_prompt(
+                self, prompt, negative_prompt
+            )
+        produce_latents_input = [
+            samples[0],
+            prompt_embeds,
+            negative_embeds,
+            steps,
+            guidance_scale,
+        ]
+        return produce_latents_input
+
     def generate_images(
         self,
         prompt: str,
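
These two methods are the split inference half of the commit title: prepare_sampling_inputs performs the host-side work (scheduler setup, latent noise, prompt encoding) and returns the tensors a denoising loop would consume, while produce_images_compiled passes them to the compiled module's produce_img_latents entry point in a single call. A hedged usage sketch, assuming a pipeline built with compiled_pipeline=True and already loaded:

# Hypothetical two-step use of the methods added above. Note that
# prepare_sampling_inputs returns [sample, prompt_embeds, negative_embeds,
# steps, guidance_scale], and the compiled path does not take `steps`.
sample, prompt_embeds, negative_embeds, steps, guidance_scale = (
    sd_pipe.prepare_sampling_inputs("a photo of a corgi", steps=30)
)
image = sd_pipe.produce_images_compiled(
    sample, prompt_embeds, negative_embeds, guidance_scale
)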
@@ -660,18 +727,26 @@ def generate_images(
         )

         for i in range(batch_count):
-            produce_latents_input = [
-                samples[i],
-                prompt_embeds,
-                negative_embeds,
-                steps,
-                guidance_scale,
-            ]
-            if self.is_sdxl:
-                latents = self._produce_latents_sdxl(*produce_latents_input)
+            if self.compiled_pipeline:
+                image = self.produce_images_compiled(
+                    samples[i],
+                    prompt_embeds,
+                    negative_embeds,
+                    guidance_scale,
+                )
             else:
-                latents = self._produce_latents_sd(*produce_latents_input)
-            image = self.vae("decode", [latents])
+                produce_latents_input = [
+                    samples[i],
+                    prompt_embeds,
+                    negative_embeds,
+                    steps,
+                    guidance_scale,
+                ]
+                if self.is_sdxl:
+                    latents = self._produce_latents_sdxl(*produce_latents_input)
+                else:
+                    latents = self._produce_latents_sd(*produce_latents_input)
+                image = self.vae("decode", [latents])
             numpy_images.append(image)
         pipe_end = time.time()

@@ -757,6 +832,8 @@ def numpy_to_pil_image(images):
         args.use_i8_punet,
         benchmark,
         args.verbose,
+        False,
+        args.compiled_pipeline,
     )
     sd_pipe.prepare_all()
     sd_pipe.load_map()
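
For context, a hedged end-to-end sketch of the new flag from a caller's perspective. The class name SharkSDPipeline and every constructor value here are illustrative assumptions; only prepare_all(), load_map(), generate_images(), and the compiled_pipeline flag appear in this diff:

# Hypothetical driver mirroring the CLI wiring in the last hunk.
sd_pipe = SharkSDPipeline(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed hf_model_name
    height=1024,
    width=1024,
    compiled_pipeline=True,  # new flag: run the single fullpipeline module
)
sd_pipe.prepare_all()  # exports and compiles submodels, including the wrapper
sd_pipe.load_map()     # loads the wrapper plus the vmfbs it wraps
images = sd_pipe.generate_images("a photo of a corgi", guidance_scale=7.5)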
