
Commit 121a409

Fix compiled pipeline
1 parent c1e9195 commit 121a409

File tree

6 files changed: +55 / -45 lines

models/turbine_models/custom_models/pipeline_base.py

Lines changed: 19 additions & 5 deletions

@@ -437,6 +437,7 @@ def prepare_all(
         vmfbs: dict = {},
         weights: dict = {},
         interactive: bool = False,
+        num_steps: int = 20,
     ):
         ready = self.is_prepared(vmfbs, weights)
         match ready:
@@ -463,7 +464,9 @@ def prepare_all(
             if not self.map[submodel].get("weights") and self.map[submodel][
                 "export_args"
             ].get("external_weights"):
-                self.export_submodel(submodel, weights_only=True)
+                self.export_submodel(
+                    submodel, weights_only=True, num_steps=num_steps
+                )
         return self.prepare_all(mlirs, vmfbs, weights, interactive)
 
     def is_prepared(self, vmfbs, weights):
@@ -581,6 +584,7 @@ def export_submodel(
         submodel: str,
         input_mlir: str = None,
         weights_only: bool = False,
+        num_steps: int = 20,
     ):
         if not os.path.exists(self.pipeline_dir):
             os.makedirs(self.pipeline_dir)
@@ -671,7 +675,8 @@ def export_submodel(
                 self.map[submodel]["export_args"]["batch_size"],
                 self.map[submodel]["export_args"]["max_length"],
                 "produce_img_split",
-                unet_module_name = self.map["unet"]["module_name"],
+                unet_module_name=self.map["unet"]["module_name"],
+                num_steps=num_steps,
             )
             dims = [
                 self.map[submodel]["export_args"]["width"],
@@ -722,15 +727,24 @@ def export_submodel(
 
     # LOAD
     def load_map(self):
-        for submodel in self.map.keys():
+        # Make sure fullpipeline is imported last
+        submodels = list(self.map.keys() - {"fullpipeline"})
+        submodels += ["fullpipeline"] if "fullpipeline" in self.map.keys() else []
+        for submodel in submodels:
             if not self.map[submodel]["load"]:
                 self.printer.print(f"Skipping load for {submodel}")
                 continue
             elif self.map[submodel].get("wraps"):
+                vmfbs = []
+                weights = []
                 for wrapped in self.map[submodel]["wraps"]:
-                    self.map[submodel]["vmfb"].append(self.map[wrapped]["vmfb"])
-                    self.map[submodel]["weights"].append(self.map[wrapped]["weights"])
+                    vmfbs.append(self.map[wrapped]["vmfb"])
+                    if "weights" in self.map[wrapped]:
+                        weights.append(self.map[wrapped]["weights"])
+                self.map[submodel]["vmfb"] = vmfbs + self.map[submodel]["vmfb"]
+                self.map[submodel]["weights"] = weights + self.map[submodel]["weights"]
 
+            print(f"Loading {submodel}")
             self.load_submodel(submodel)
 
     def load_submodel(self, submodel):
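
The load_map hunk fixes two issues: the wrapper's artifact lists were grown by .append() on every load (and the old code assumed every wrapped submodel had external weights), and "fullpipeline" could be loaded before the submodels it wraps. A minimal sketch of the new ordering and aggregation logic against a toy map (names and paths illustrative):

# Toy stand-in for self.map; the real entries carry many more fields.
pipeline_map = {
    "unet": {"vmfb": ["unet.vmfb"], "weights": ["unet.safetensors"]},
    "vae": {"vmfb": ["vae.vmfb"]},  # no external weights for this one
    "fullpipeline": {"vmfb": ["pipe.vmfb"], "weights": [], "wraps": ["unet", "vae"]},
}

# dict.keys() behaves like a set, so subtraction drops "fullpipeline"...
submodels = list(pipeline_map.keys() - {"fullpipeline"})
# ...and it is re-appended last, after every submodel it wraps.
submodels += ["fullpipeline"] if "fullpipeline" in pipeline_map else []

for submodel in submodels:
    if pipeline_map[submodel].get("wraps"):
        vmfbs, weights = [], []
        for wrapped in pipeline_map[submodel]["wraps"]:
            vmfbs.append(pipeline_map[wrapped]["vmfb"])
            if "weights" in pipeline_map[wrapped]:  # guard missing weights
                weights.append(pipeline_map[wrapped]["weights"])
        # Prepend by assignment; the old .append() grew the lists on every call.
        pipeline_map[submodel]["vmfb"] = vmfbs + pipeline_map[submodel]["vmfb"]
        pipeline_map[submodel]["weights"] = weights + pipeline_map[submodel]["weights"]
    print(f"Loading {submodel}")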

models/turbine_models/custom_models/sd_inference/sd_pipeline.py

Lines changed: 13 additions & 16 deletions

@@ -465,13 +465,10 @@ def encode_prompts_sdxl(self, prompt, negative_prompt):
             text_input_ids_list += text_inputs.input_ids.unsqueeze(0)
             uncond_input_ids_list += uncond_input.input_ids.unsqueeze(0)
 
-        if self.compiled_pipeline:
-            return text_input_ids_list, uncond_input_ids_list
-        else:
-            prompt_embeds, add_text_embeds = self.text_encoder(
-                "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list]
-            )
-            return prompt_embeds, add_text_embeds
+        prompt_embeds, add_text_embeds = self.text_encoder(
+            "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list]
+        )
+        return prompt_embeds, add_text_embeds
 
     def prepare_latents(
         self,
@@ -615,6 +612,7 @@ def _produce_latents_sdxl(
         return latents
 
     def produce_images_compiled(
+        self,
         sample,
         prompt_embeds,
         text_embeds,
@@ -624,9 +622,11 @@ def produce_images_compiled(
             sample,
             prompt_embeds,
             text_embeds,
-            guidance_scale,
+            torch.as_tensor([guidance_scale], dtype=sample.dtype),
         ]
-        image = self.compiled_pipeline("produce_img_latents", pipe_inputs)
+        # image = self.compiled_pipeline("produce_img_latents", pipe_inputs)
+        image = self.map["fullpipeline"]["runner"]("produce_image_latents", pipe_inputs)
+        return image
 
     def prepare_sampling_inputs(
         self,
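
The compiled produce_image_latents entry point takes guidance_scale as a rank-1 tensor of the latent dtype (tensor<1x{precision}> in the pipeline_ir.py template), so a bare Python float no longer matches the ABI. A minimal sketch of the conversion, with an illustrative latent shape:

import torch

sample = torch.randn(1, 4, 128, 128, dtype=torch.float16)  # illustrative latents
guidance_scale = 7.5  # CLI-style Python float

# torch.as_tensor wraps the scalar as tensor([7.5], dtype=torch.float16),
# matching the tensor<1xf16> argument the compiled pipeline expects.
gs_tensor = torch.as_tensor([guidance_scale], dtype=sample.dtype)
assert gs_tensor.shape == (1,) and gs_tensor.dtype == sample.dtype
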
@@ -726,12 +726,9 @@ def generate_images(
 
         for i in range(batch_count):
             if self.compiled_pipeline:
-                image = produce_images_compiled(
-                    samples[i],
-                    prompt_embeds,
-                    negative_embeds,
-                    guidance_scale
-                )
+                image = self.produce_images_compiled(
+                    samples[i], prompt_embeds, negative_embeds, guidance_scale
+                ).to_host()
             else:
                 produce_latents_input = [
                     samples[i],
@@ -833,7 +830,7 @@ def numpy_to_pil_image(images):
         False,
         args.compiled_pipeline,
     )
-    sd_pipe.prepare_all()
+    sd_pipe.prepare_all(num_steps=args.num_inference_steps)
     sd_pipe.load_map()
     sd_pipe.generate_images(
         args.prompt,
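
Two of these fixes are call-site corrections: produce_images_compiled is a method, so the old bare-name call raised NameError, and the device result is now copied back with .to_host() (an IREE runtime DeviceArray method) before the NumPy/PIL conversion. A minimal sketch of the scoping fix:

class Pipeline:
    def produce_images_compiled(self, sample):
        return sample  # stand-in for the compiled-pipeline invocation

    def generate_images(self, sample):
        # image = produce_images_compiled(sample)  # NameError: name not defined
        return self.produce_images_compiled(sample)  # bound-method call works

print(Pipeline().generate_images("latents"))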

models/turbine_models/custom_models/sd_inference/utils.py

Lines changed: 3 additions & 3 deletions

@@ -36,13 +36,13 @@
     ],
     "unet": [
         "--iree-flow-enable-aggressive-fusion",
-        "--iree-flow-enable-fuse-horizontal-contractions=true",
+        "--iree-global-opt-enable-fuse-horizontal-contractions=true",
         "--iree-opt-aggressively-propagate-transposes=true",
         "--iree-codegen-llvmgpu-use-vector-distribution=true",
     ],
     "clip": [
         "--iree-flow-enable-aggressive-fusion",
-        "--iree-flow-enable-fuse-horizontal-contractions=true",
+        "--iree-global-opt-enable-fuse-horizontal-contractions=true",
         "--iree-opt-aggressively-propagate-transposes=true",
     ],
     "vae": [
@@ -61,7 +61,7 @@
         "--iree-opt-const-eval=false",
         "--iree-opt-aggressively-propagate-transposes=true",
         "--iree-flow-enable-aggressive-fusion",
-        "--iree-flow-enable-fuse-horizontal-contractions=true",
+        "--iree-global-opt-enable-fuse-horizontal-contractions=true",
         "--iree-codegen-gpu-native-math-precision=true",
         "--iree-codegen-llvmgpu-use-vector-distribution=true",
         "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",

models/turbine_models/custom_models/sdxl_inference/pipeline_ir.py

Lines changed: 17 additions & 18 deletions

@@ -38,7 +38,7 @@
     %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<{batch_size}x4x{lw}x{lh}x{precision}>) {{
       %step_64 = arith.index_cast %arg0 : index to i64
       %this_step = tensor.from_elements %step_64 : tensor<1xi64>
-      %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>, tensor<1xi64>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}>
+      %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %this_step, %p_embeds, %t_embeds, %time_ids, %guidance_scale) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<1x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}>
       scf.yield %inner : tensor<{batch_size}x4x{lw}x{lh}x{precision}>
     }}
     return %res : tensor<{batch_size}x4x{lw}x{lh}x{precision}>
@@ -48,28 +48,27 @@
 
 produce_img_split = r"""
 module @sdxl_compiled_pipeline {{
-  func.func private @{scheduler_module}.run_initialize(%arg0: !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) -> (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, !torch.vtensor<[{bd},6],{precision}>, !torch.vtensor<[1],f16>, !torch.vtensor<[{num_steps}],f32>) attributes {{torch.assume_strict_symbolic_shapes}}
-  func.func private @{scheduler_module}.run_scale(%arg0: !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, %arg1: !torch.vtensor<[1],si64>, %arg2: !torch.vtensor<[{num_steps}],f32>) -> (!torch.vtensor<[{bd},4,{lh},{lw}],{precision}>, !torch.vtensor<[1],{precision}>) attributes {{torch.assume_strict_symbolic_shapes}}
-  func.func private @{scheduler_module}.run_step(%arg0: !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, %arg1: !torch.vtensor<[1],{precision}>, %arg2: !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) -> !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}> attributes {{torch.assume_strict_symbolic_shapes}}
-  func.func private @{unet_module}.{unet_function}(%arg0: !torch.vtensor<[{bd},4,{lh},{lw}],{precision}>, %arg1: !torch.vtensor<[1],{precision}>, %arg2: !torch.vtensor<[{bd},{max_length},2048],{precision}>, %arg3: !torch.vtensor<[{bd},1280],{precision}>, %arg4: !torch.vtensor<[{bd},6],{precision}>, %arg5: !torch.vtensor<[1],{precision}>) -> !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}> attributes {{torch.assume_strict_symbolic_shapes}}
-  func.func private @{vae_module}.decode(%arg0: !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) -> !torch.vtensor<[{batch_size},3,{height},{width}],{precision}> attributes {{torch.assume_strict_symbolic_shapes}}
-
-  func.func @produce_image_latents(%sample: !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, %p_embeds: !torch.vtensor<[{bd},{max_length},2048],{precision}>, %t_embeds: !torch.vtensor<[{bd},1280],{precision}>, %guidance_scale: !torch.vtensor<[1],{precision}>) -> !torch.vtensor<[{batch_size},3,{height},{width}],{precision}> {{
-    %noisy_sample, %time_ids, %delete, %timesteps = func.call @{scheduler_module}.run_initialize(%sample) : (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) -> (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, !torch.vtensor<[{bd},6],{precision}>, !torch.vtensor<[1],{precision}>, !torch.vtensor<[{num_steps}],f32>)
+  func.func private @{scheduler_module}.run_initialize(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<{bd}x6x{precision}>, tensor<1xf16>, tensor<{num_steps}xf32>) attributes {{torch.assume_strict_symbolic_shapes}}
+  func.func private @{scheduler_module}.run_scale(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1xi64>, %arg2: tensor<{num_steps}xf32>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>) attributes {{torch.assume_strict_symbolic_shapes}}
+  func.func private @{scheduler_module}.run_step(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1x{precision}>, %arg2: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> attributes {{torch.assume_strict_symbolic_shapes}}
+  func.func private @{unet_module}.{unet_function}(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1x{precision}>, %arg2: tensor<{bd}x{max_length}x2048x{precision}>, %arg3: tensor<{bd}x1280x{precision}>, %arg4: tensor<{bd}x6x{precision}>, %arg5: tensor<1x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> attributes {{torch.assume_strict_symbolic_shapes}}
+  func.func private @{vae_module}.decode(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x3x{height}x{width}x{precision}> attributes {{torch.assume_strict_symbolic_shapes}}
+
+  func.func @produce_image_latents(%sample: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %p_embeds: tensor<{bd}x{max_length}x2048x{precision}>, %t_embeds: tensor<{bd}x1280x{precision}>, %guidance_scale: tensor<1x{precision}>) -> tensor<{batch_size}x3x{height}x{width}x{precision}> {{
+    %noisy_sample, %time_ids, %delete, %timesteps = func.call @{scheduler_module}.run_initialize(%sample) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>, tensor<{num_steps}xf32>)
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
     %n_steps = arith.constant {num_steps} : index
-    %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) {{
+    %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>) {{
       %step_64 = arith.index_cast %arg0 : index to i64
       %this_step = tensor.from_elements %step_64 : tensor<1xi64>
-      %step_torch = torch_c.from_builtin_tensor %this_step : tensor<1xi64> -> !torch.vtensor<[1],si64>
-      %scaled, %timestep = func.call @{scheduler_module}.run_scale(%arg, %step_torch, %timesteps) : (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, !torch.vtensor<[1],si64>, !torch.vtensor<[{num_steps}],f32>) -> (!torch.vtensor<[{bd},4,{lh},{lw}],{precision}>, !torch.vtensor<[1],{precision}>)
-      %inner = func.call @{unet_module}.{unet_function}(%scaled, %timestep, %p_embeds, %t_embeds, %time_ids, %guidance_scale) : (!torch.vtensor<[{bd},4,{lh},{lw}],{precision}>, !torch.vtensor<[1],{precision}>, !torch.vtensor<[{bd},{max_length},2048],{precision}>, !torch.vtensor<[{bd},1280],{precision}>, !torch.vtensor<[{bd},6],{precision}>, !torch.vtensor<[1],{precision}>) -> !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>
-      %pred = func.call @{scheduler_module}.run_step(%inner, %timestep, %arg) : (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, !torch.vtensor<[1],{precision}>, !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) -> !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>
-      scf.yield %pred : !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>
+      %scaled, %timestep = func.call @{scheduler_module}.run_scale(%arg, %this_step, %timesteps) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1xi64>, tensor<{num_steps}xf32>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>)
+      %inner = func.call @{unet_module}.{unet_function}(%scaled, %timestep, %p_embeds, %t_embeds, %time_ids, %guidance_scale) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}>
+      %pred = func.call @{scheduler_module}.run_step(%inner, %timestep, %arg) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>, tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}>
+      scf.yield %pred : tensor<{batch_size}x4x{lh}x{lw}x{precision}>
     }}
-    %image = func.call @{vae_module}.decode(%res): (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) -> !torch.vtensor<[{batch_size},3,{height},{width}],{precision}>
-    return %image : !torch.vtensor<[{batch_size},3,{height},{width}],{precision}>
+    %image = func.call @{vae_module}.decode(%res): (tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x3x{height}x{width}x{precision}>
+    return %image : tensor<{batch_size}x3x{height}x{width}x{precision}>
   }}
 }}
 """
@@ -128,4 +127,4 @@ def get_pipeline_ir(
         scheduler_module=scheduler_module_name,
         vae_module=vae_module_name,
         num_steps=num_steps,
-    )
+    )
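
Both templates in this file are Python format strings: literal MLIR braces are doubled ({{ }}) and single-brace fields such as {num_steps} and {scheduler_module} are filled in by get_pipeline_ir. The rewrite above also swaps the private declarations from !torch.vtensor types to builtin tensor types and drops the torch_c.from_builtin_tensor cast, so the glue module no longer needs a torch-dialect conversion for the step index. A stripped-down sketch of the instantiation:

# Doubled braces survive str.format as literal { }; single-brace fields are filled in.
produce_img_split = r"""
module @sdxl_compiled_pipeline {{
  %n_steps = arith.constant {num_steps} : index
  func.func private @{scheduler_module}.run_step() -> ()
}}
"""

pipeline_ir = produce_img_split.format(
    scheduler_module="compiled_scheduler",  # illustrative module name
    num_steps=20,
)
print(pipeline_ir)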

models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py

Lines changed: 1 addition & 1 deletion

@@ -480,6 +480,7 @@ def export_submodel(
             self.hf_model_name,
             None,
             self.max_length,
+            self.batch_size,
             self.precision,
             "vmfb",
             self.external_weights,
@@ -494,7 +495,6 @@ def export_submodel(
             input_mlir=input_mlir["prompt_encoder"],
             attn_spec=self.attn_spec,
             weights_only=weights_only,
-            batchsize=self.batch_size,
             batch_input=self.batch_prompt_input,
         )
         return prompt_encoder_vmfb, prompt_encoder_external_weight_path
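
batch_size now rides in its positional slot right after max_length instead of the misspelled keyword batchsize=, which export_prompt_encoder presumably has no parameter for. A toy sketch of the failure mode (the real signature is longer; names here are illustrative):

def export_prompt_encoder(hf_model_name, hf_auth_token, max_length, batch_size,
                          precision, compile_to, external_weights):
    return f"exporting {hf_model_name}: bs={batch_size}, len={max_length}"

args = ("stabilityai/sdxl", None, 64, "fp16", "vmfb", "safetensors")

# Old call: batch_size is missing positionally and "batchsize" is not a
# parameter name, so Python raises TypeError before any export runs.
try:
    export_prompt_encoder(*args, batchsize=1)
except TypeError as e:
    print(e)

# New call: batch_size passed in order, after max_length.
print(export_prompt_encoder("stabilityai/sdxl", None, 64, 1, "fp16", "vmfb", "safetensors"))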

models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py

Lines changed: 2 additions & 2 deletions

@@ -231,7 +231,7 @@ def export_prompt_encoder(
         )
 
     if weights_only:
-        return None, external_weight_path
+        return external_weight_path
 
     class CompiledClip(CompiledModule):
         if external_weights:
@@ -277,7 +277,7 @@ def encode_prompts_turbo(
     module_str = str(module)
 
     if compile_to != "vmfb":
-        return module_str
+        return module_str, None
     else:
         vmfb_path = utils.compile_to_vmfb(
             module_str,
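
Both return tweaks are about arity: the weights-only path now yields a single weight path, and the non-vmfb path yields a (module_str, vmfb_path) pair like the vmfb branch, so callers see one shape per mode. A small stub mirroring the fixed contract (assumed from the call sites; the real function performs the actual export):

def export_prompt_encoder(compile_to="vmfb", weights_only=False):
    external_weight_path = "prompt_encoder.safetensors"  # illustrative path
    module_str = "module {}"                             # illustrative IR
    if weights_only:
        return external_weight_path   # one value, no longer (None, path)
    if compile_to != "vmfb":
        return module_str, None       # two values, matching the vmfb branch
    return module_str, "prompt_encoder.vmfb"

# Callers can now unpack uniformly per mode:
weight_path = export_prompt_encoder(weights_only=True)
module_str, vmfb_path = export_prompt_encoder(compile_to="mlir")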
