
Commit 1d71f8c

Fix compiled pipeline
1 parent 5eb013d commit 1d71f8c

6 files changed: +239, -81 lines changed

models/turbine_models/custom_models/pipeline_base.py

Lines changed: 30 additions & 6 deletions
@@ -437,6 +437,7 @@ def prepare_all(
         vmfbs: dict = {},
         weights: dict = {},
         interactive: bool = False,
+        num_steps: int = 20,
     ):
         ready = self.is_prepared(vmfbs, weights)
         match ready:
@@ -463,7 +464,9 @@ def prepare_all(
                         if not self.map[submodel].get("weights") and self.map[submodel][
                             "export_args"
                         ].get("external_weights"):
-                            self.export_submodel(submodel, weights_only=True)
+                            self.export_submodel(
+                                submodel, weights_only=True, num_steps=num_steps
+                            )
                 return self.prepare_all(mlirs, vmfbs, weights, interactive)

     def is_prepared(self, vmfbs, weights):
@@ -581,6 +584,7 @@ def export_submodel(
         submodel: str,
         input_mlir: str = None,
         weights_only: bool = False,
+        num_steps: int = 20,
     ):
         if not os.path.exists(self.pipeline_dir):
             os.makedirs(self.pipeline_dir)
@@ -670,7 +674,9 @@ def export_submodel(
                     self.map[submodel]["export_args"]["precision"],
                     self.map[submodel]["export_args"]["batch_size"],
                     self.map[submodel]["export_args"]["max_length"],
-                    "tokens_to_image",
+                    "produce_img_split",
+                    unet_module_name=self.map["unet"]["module_name"],
+                    num_steps=num_steps,
                 )
                 dims = [
                     self.map[submodel]["export_args"]["width"],
@@ -699,8 +705,8 @@ def export_submodel(
                     return_path=True,
                     mlir_source="str",
                 )
-                self.map[submodel]["vmfb"] = vmfb_path
-                self.map[submodel]["weights"] = None
+                self.map[submodel]["vmfb"] = [vmfb_path]
+                self.map[submodel]["weights"] = []
             case _:
                 export_args = self.map[submodel].get("export_args", {})
                 if weights_only:
@@ -721,10 +727,24 @@ def export_submodel(

     # LOAD
     def load_map(self):
-        for submodel in self.map.keys():
+        # Make sure fullpipeline is imported last
+        submodels = list(self.map.keys() - {"fullpipeline"})
+        submodels += ["fullpipeline"] if "fullpipeline" in self.map.keys() else []
+        for submodel in submodels:
             if not self.map[submodel]["load"]:
-                self.printer.print("Skipping load for ", submodel)
+                self.printer.print(f"Skipping load for {submodel}")
                 continue
+            elif self.map[submodel].get("wraps"):
+                vmfbs = []
+                weights = []
+                for wrapped in self.map[submodel]["wraps"]:
+                    vmfbs.append(self.map[wrapped]["vmfb"])
+                    if "weights" in self.map[wrapped]:
+                        weights.append(self.map[wrapped]["weights"])
+                self.map[submodel]["vmfb"] = vmfbs + self.map[submodel]["vmfb"]
+                self.map[submodel]["weights"] = weights + self.map[submodel]["weights"]
+
+            print(f"Loading {submodel}")
             self.load_submodel(submodel)

     def load_submodel(self, submodel):
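The load_map() change above makes the ordering explicit: wrapper entries such as "fullpipeline" are loaded last, after the submodels they wrap, so their vmfbs and weights can be gathered and prepended to the wrapper's own lists. A minimal standalone sketch of that ordering follows; the dict shape mirrors self.map, but the entries are made-up placeholders rather than real artifacts:

# Placeholder map; shapes mirror self.map, values are illustrative only.
example_map = {
    "unet": {"load": True, "vmfb": ["unet.vmfb"], "weights": ["unet.irpa"]},
    "scheduler": {"load": True, "vmfb": ["scheduler.vmfb"], "weights": []},
    "fullpipeline": {
        "load": True,
        "wraps": ["unet", "scheduler"],
        "vmfb": ["fullpipeline.vmfb"],
        "weights": [],
    },
}

# Wrapper entries go last so the submodels they wrap are already populated.
order = list(example_map.keys() - {"fullpipeline"})
order += ["fullpipeline"] if "fullpipeline" in example_map else []

for name in order:
    entry = example_map[name]
    if entry.get("wraps"):
        entry["vmfb"] = [example_map[w]["vmfb"] for w in entry["wraps"]] + entry["vmfb"]
        entry["weights"] = [
            example_map[w]["weights"] for w in entry["wraps"] if "weights" in example_map[w]
        ] + entry["weights"]
    print(f"Loading {name}: vmfb={entry['vmfb']}")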
@@ -751,6 +771,10 @@ def load_submodel(self, submodel):

     def unload_submodel(self, submodel):
         self.map[submodel]["runner"].unload()
+        self.map[submodel]["vmfb"] = None
+        self.map[submodel]["mlir"] = None
+        self.map[submodel]["weights"] = None
+        self.map[submodel]["export_args"]["input_mlir"] = None
         setattr(self, submodel, None)


models/turbine_models/custom_models/sd_inference/sd_pipeline.py

Lines changed: 133 additions & 54 deletions
@@ -118,23 +118,11 @@
             "decomp_attn": None,
         },
     },
-    "unetloop": {
-        "module_name": "sdxl_compiled_pipeline",
-        "load": False,
-        "keywords": ["unetloop"],
-        "wraps": ["unet", "scheduler"],
-        "export_args": {
-            "batch_size": 1,
-            "height": 1024,
-            "width": 1024,
-            "max_length": 64,
-        },
-    },
     "fullpipeline": {
         "module_name": "sdxl_compiled_pipeline",
-        "load": False,
+        "load": True,
         "keywords": ["fullpipeline"],
-        "wraps": ["text_encoder", "unet", "scheduler", "vae"],
+        "wraps": ["unet", "scheduler", "vae"],
         "export_args": {
             "batch_size": 1,
             "height": 1024,
@@ -190,6 +178,7 @@ def get_sd_model_map(hf_model_name):
         "stabilityai/sdxl-turbo",
         "stabilityai/stable-diffusion-xl-base-1.0",
         "/models/SDXL/official_pytorch/fp16/stable_diffusion_fp16/checkpoint_pipe",
+        "/models/SDXL/official_pytorch/fp16/stable_diffusion_fp16//checkpoint_pipe",
     ]:
         return sdxl_model_map
     elif "stabilityai/stable-diffusion-3" in name:
@@ -233,6 +222,7 @@ def __init__(
         benchmark: bool | dict[bool] = False,
         verbose: bool = False,
         batch_prompts: bool = False,
+        compiled_pipeline: bool = False,
     ):
         common_export_args = {
             "hf_model_name": None,
@@ -243,11 +233,11 @@ def __init__(
             "exit_on_vmfb": False,
             "pipeline_dir": pipeline_dir,
             "input_mlir": None,
-            "attn_spec": None,
+            "attn_spec": attn_spec,
             "external_weights": None,
             "external_weight_path": None,
         }
-        sd_model_map = get_sd_model_map(hf_model_name)
+        sd_model_map = copy.deepcopy(get_sd_model_map(hf_model_name))
         for submodel in sd_model_map:
             if "load" not in sd_model_map[submodel]:
                 sd_model_map[submodel]["load"] = True
@@ -311,6 +301,7 @@ def __init__(
         self.scheduler = None

         self.split_scheduler = True
+        self.compiled_pipeline = compiled_pipeline

         self.base_model_name = (
             hf_model_name
@@ -321,11 +312,6 @@ def __init__(
         self.is_sdxl = "xl" in self.base_model_name.lower()
         self.is_sd3 = "stable-diffusion-3" in self.base_model_name
         if self.is_sdxl:
-            if self.split_scheduler:
-                if self.map.get("unetloop"):
-                    self.map.pop("unetloop")
-                if self.map.get("fullpipeline"):
-                    self.map.pop("fullpipeline")
             self.tokenizers = [
                 CLIPTokenizer.from_pretrained(
                     self.base_model_name, subfolder="tokenizer"
@@ -339,6 +325,20 @@ def __init__(
             self.scheduler_device = self.map["unet"]["device"]
             self.scheduler_driver = self.map["unet"]["driver"]
             self.scheduler_target = self.map["unet"]["target"]
+            if not self.compiled_pipeline:
+                if self.map.get("unetloop"):
+                    self.map.pop("unetloop")
+                if self.map.get("fullpipeline"):
+                    self.map.pop("fullpipeline")
+            elif self.compiled_pipeline:
+                self.map["unet"]["load"] = False
+                self.map["vae"]["load"] = False
+                self.load_scheduler(
+                    scheduler_id,
+                    num_inference_steps,
+                )
+                self.map["scheduler"]["runner"].unload()
+                self.map["scheduler"]["load"] = False
         elif not self.is_sd3:
             self.tokenizer = CLIPTokenizer.from_pretrained(
                 self.base_model_name, subfolder="tokenizer"
@@ -351,23 +351,27 @@ def __init__(

         self.latents_dtype = torch_dtypes[self.latents_precision]
         self.use_i8_punet = self.use_punet = use_i8_punet
+        if self.use_punet:
+            self.setup_punet()
+        else:
+            self.map["unet"]["keywords"].append("!punet")
+            self.map["unet"]["function_name"] = "run_forward"
+
+    def setup_punet(self):
         if self.use_i8_punet:
             self.map["unet"]["export_args"]["precision"] = "i8"
-            self.map["unet"]["export_args"]["use_punet"] = True
-            self.map["unet"]["use_weights_for_export"] = True
-            self.map["unet"]["keywords"].append("punet")
-            self.map["unet"]["module_name"] = "compiled_punet"
-            self.map["unet"]["function_name"] = "main"
             self.map["unet"]["export_args"]["external_weight_path"] = (
                 utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa"
             )
             for idx, word in enumerate(self.map["unet"]["keywords"]):
                 if word in ["fp32", "fp16"]:
                     self.map["unet"]["keywords"][idx] = "i8"
                     break
-        else:
-            self.map["unet"]["keywords"].append("!punet")
-            self.map["unet"]["function_name"] = "run_forward"
+        self.map["unet"]["export_args"]["use_punet"] = True
+        self.map["unet"]["use_weights_for_export"] = True
+        self.map["unet"]["keywords"].append("punet")
+        self.map["unet"]["module_name"] = "compiled_punet"
+        self.map["unet"]["function_name"] = "main"

     # LOAD

@@ -376,10 +380,6 @@ def load_scheduler(
         scheduler_id: str,
         steps: int = 30,
     ):
-        if self.is_sd3:
-            scheduler_device = self.mmdit.device
-        else:
-            scheduler_device = self.unet.device
         if not self.cpu_scheduling:
             self.map["scheduler"] = {
                 "module_name": "compiled_scheduler",
@@ -425,7 +425,11 @@ def load_scheduler(
             except:
                 print("JIT export of scheduler failed. Loading CPU scheduler.")
                 self.cpu_scheduling = True
-        if self.cpu_scheduling:
+        elif self.cpu_scheduling:
+            if self.is_sd3:
+                scheduler_device = self.mmdit.device
+            else:
+                scheduler_device = self.unet.device
             scheduler = schedulers.get_scheduler(self.base_model_name, scheduler_id)
             self.scheduler = schedulers.SharkSchedulerCPUWrapper(
                 scheduler,
@@ -461,13 +465,10 @@ def encode_prompts_sdxl(self, prompt, negative_prompt):
             text_input_ids_list += text_inputs.input_ids.unsqueeze(0)
             uncond_input_ids_list += uncond_input.input_ids.unsqueeze(0)

-        if self.compiled_pipeline:
-            return text_input_ids_list, uncond_input_ids_list
-        else:
-            prompt_embeds, add_text_embeds = self.text_encoder(
-                "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list]
-            )
-            return prompt_embeds, add_text_embeds
+        prompt_embeds, add_text_embeds = self.text_encoder(
+            "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list]
+        )
+        return prompt_embeds, add_text_embeds

     def prepare_latents(
         self,
@@ -565,9 +566,11 @@ def _produce_latents_sdxl(
             [guidance_scale],
             dtype=self.map["unet"]["np_dtype"],
         )
+        # Disable progress bar if we aren't in verbose mode or if we're printing
+        # benchmark latencies for unet.
         for i, t in tqdm(
             enumerate(timesteps),
-            disable=(self.map["unet"].get("benchmark") and self.verbose),
+            disable=(self.map["unet"].get("benchmark") or not self.verbose),
         ):
             if self.cpu_scheduling:
                 latent_model_input, t = self.scheduler.scale_model_input(
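The corrected disable flag above hides the tqdm progress bar whenever unet benchmarking is enabled or verbose mode is off; the previous "and" form only hid it when both were set. A quick standalone truth table (plain Python, no tqdm required) shows the difference:

# Compare the old and new progress-bar disable conditions.
for benchmark in (False, True):
    for verbose in (False, True):
        old = benchmark and verbose
        new = benchmark or not verbose
        print(f"benchmark={benchmark} verbose={verbose} old_disable={old} new_disable={new}")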
@@ -608,6 +611,75 @@ def _produce_latents_sdxl(
             latents = self.scheduler("run_step", [noise_pred, t, latents])
         return latents

+    def produce_images_compiled(
+        self,
+        sample,
+        prompt_embeds,
+        text_embeds,
+        guidance_scale,
+    ):
+        pipe_inputs = [
+            sample,
+            prompt_embeds,
+            text_embeds,
+            torch.as_tensor([guidance_scale], dtype=sample.dtype),
+        ]
+        # image = self.compiled_pipeline("produce_img_latents", pipe_inputs)
+        image = self.map["fullpipeline"]["runner"]("produce_image_latents", pipe_inputs)
+        return image
+
+    def prepare_sampling_inputs(
+        self,
+        prompt: str,
+        negative_prompt: str = "",
+        steps: int = 30,
+        batch_count: int = 1,
+        guidance_scale: float = 7.5,
+        seed: float = -1,
+        cpu_scheduling: bool = True,
+        scheduler_id: str = "EulerDiscrete",
+        return_imgs: bool = False,
+    ):
+        needs_new_scheduler = (
+            (steps and steps != self.num_inference_steps)
+            or (cpu_scheduling != self.cpu_scheduling)
+            and self.split_scheduler
+        )
+        if not self.scheduler and not self.compiled_pipeline:
+            needs_new_scheduler = True
+
+        if guidance_scale == 0:
+            negative_prompt = prompt
+            prompt = ""
+
+        self.cpu_scheduling = cpu_scheduling
+        if steps and needs_new_scheduler:
+            self.num_inference_steps = steps
+            self.load_scheduler(scheduler_id, steps)
+
+        pipe_start = time.time()
+        numpy_images = []
+
+        samples = self.get_rand_latents(seed, batch_count)
+
+        # Tokenize prompt and negative prompt.
+        if self.is_sdxl:
+            prompt_embeds, negative_embeds = self.encode_prompts_sdxl(
+                prompt, negative_prompt
+            )
+        else:
+            prompt_embeds, negative_embeds = encode_prompt(
+                self, prompt, negative_prompt
+            )
+        produce_latents_input = [
+            samples[0],
+            prompt_embeds,
+            negative_embeds,
+            steps,
+            guidance_scale,
+        ]
+        return produce_latents_input
+
     def generate_images(
         self,
         prompt: str,
@@ -653,18 +725,23 @@ def generate_images(
         )

         for i in range(batch_count):
-            produce_latents_input = [
-                samples[i],
-                prompt_embeds,
-                negative_embeds,
-                steps,
-                guidance_scale,
-            ]
-            if self.is_sdxl:
-                latents = self._produce_latents_sdxl(*produce_latents_input)
+            if self.compiled_pipeline:
+                image = self.produce_images_compiled(
+                    samples[i], prompt_embeds, negative_embeds, guidance_scale
+                ).to_host()
             else:
-                latents = self._produce_latents_sd(*produce_latents_input)
-            image = self.vae("decode", [latents])
+                produce_latents_input = [
+                    samples[i],
+                    prompt_embeds,
+                    negative_embeds,
+                    steps,
+                    guidance_scale,
+                ]
+                if self.is_sdxl:
+                    latents = self._produce_latents_sdxl(*produce_latents_input)
+                else:
+                    latents = self._produce_latents_sd(*produce_latents_input)
+                image = self.vae("decode", [latents])
             numpy_images.append(image)
         pipe_end = time.time()

@@ -750,8 +827,10 @@ def numpy_to_pil_image(images):
         args.use_i8_punet,
         benchmark,
         args.verbose,
+        False,
+        args.compiled_pipeline,
     )
-    sd_pipe.prepare_all()
+    sd_pipe.prepare_all(num_steps=args.num_inference_steps)
     sd_pipe.load_map()
     sd_pipe.generate_images(
         args.prompt,
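Taken together, the __main__ changes above imply the following driver flow. This is a rough sketch only: it assumes "pipe" is an already-constructed pipeline object exposing the three entry points touched by this commit, and the constructor and its argument order are deliberately not reproduced here.

# Hypothetical driver; `pipe` is assumed to be a constructed SD pipeline object
# with compiled_pipeline=True.
def run_compiled_pipeline(pipe, prompt: str, num_inference_steps: int = 20):
    # The step count is passed down to the fullpipeline export, so prepare_all()
    # now takes num_steps and forwards it to export_submodel().
    pipe.prepare_all(num_steps=num_inference_steps)
    # Wrapped submodels (unet, scheduler, vae) load first; the fullpipeline
    # wrapper loads last with their vmfbs and weights attached.
    pipe.load_map()
    # With compiled_pipeline=True, each batch goes through produce_images_compiled()
    # instead of the per-step scheduler loop.
    return pipe.generate_images(prompt)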

0 commit comments
