             "decomp_attn": None,
         },
     },
-    "unetloop": {
-        "module_name": "sdxl_compiled_pipeline",
-        "load": False,
-        "keywords": ["unetloop"],
-        "wraps": ["unet", "scheduler"],
-        "export_args": {
-            "batch_size": 1,
-            "height": 1024,
-            "width": 1024,
-            "max_length": 64,
-        },
-    },
     "fullpipeline": {
         "module_name": "sdxl_compiled_pipeline",
-        "load": False,
+        "load": True,
         "keywords": ["fullpipeline"],
-        "wraps": ["text_encoder", "unet", "scheduler", "vae"],
+        "wraps": ["unet", "scheduler", "vae"],
         "export_args": {
             "batch_size": 1,
             "height": 1024,
@@ -190,6 +178,7 @@ def get_sd_model_map(hf_model_name):
         "stabilityai/sdxl-turbo",
         "stabilityai/stable-diffusion-xl-base-1.0",
         "/models/SDXL/official_pytorch/fp16/stable_diffusion_fp16/checkpoint_pipe",
+        "/models/SDXL/official_pytorch/fp16/stable_diffusion_fp16//checkpoint_pipe",
     ]:
         return sdxl_model_map
     elif "stabilityai/stable-diffusion-3" in name:
@@ -233,6 +222,7 @@ def __init__(
         benchmark: bool | dict[bool] = False,
         verbose: bool = False,
         batch_prompts: bool = False,
+        compiled_pipeline: bool = False,
     ):
         common_export_args = {
             "hf_model_name": None,
@@ -243,11 +233,11 @@ def __init__(
             "exit_on_vmfb": False,
             "pipeline_dir": pipeline_dir,
             "input_mlir": None,
-            "attn_spec": None,
+            "attn_spec": attn_spec,
             "external_weights": None,
             "external_weight_path": None,
         }
-        sd_model_map = get_sd_model_map(hf_model_name)
+        sd_model_map = copy.deepcopy(get_sd_model_map(hf_model_name))
         for submodel in sd_model_map:
             if "load" not in sd_model_map[submodel]:
                 sd_model_map[submodel]["load"] = True
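The deepcopy above matters because get_sd_model_map returns a module-level dict, and the loop that follows mutates it (sd_model_map[submodel]["load"] = True); without the copy, those mutations would leak into every later pipeline instance. A minimal standalone illustration of the hazard (hypothetical names, not from this file):

import copy

_MODEL_MAP = {"unet": {"load": False}}

def get_map_shared():
    return _MODEL_MAP  # caller mutations corrupt the shared global

def get_map_isolated():
    return copy.deepcopy(_MODEL_MAP)  # caller gets a private copy

m = get_map_shared()
m["unet"]["load"] = True
print(_MODEL_MAP["unet"]["load"])  # True: the global changed underneath us

m2 = get_map_isolated()
m2["unet"]["load"] = False
print(_MODEL_MAP["unet"]["load"])  # still True: the copy is isolated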
@@ -311,6 +301,7 @@ def __init__(
         self.scheduler = None

         self.split_scheduler = True
+        self.compiled_pipeline = compiled_pipeline

         self.base_model_name = (
             hf_model_name
@@ -321,11 +312,6 @@ def __init__(
         self.is_sdxl = "xl" in self.base_model_name.lower()
         self.is_sd3 = "stable-diffusion-3" in self.base_model_name
         if self.is_sdxl:
-            if self.split_scheduler:
-                if self.map.get("unetloop"):
-                    self.map.pop("unetloop")
-                if self.map.get("fullpipeline"):
-                    self.map.pop("fullpipeline")
             self.tokenizers = [
                 CLIPTokenizer.from_pretrained(
                     self.base_model_name, subfolder="tokenizer"
@@ -339,6 +325,20 @@ def __init__(
             self.scheduler_device = self.map["unet"]["device"]
             self.scheduler_driver = self.map["unet"]["driver"]
             self.scheduler_target = self.map["unet"]["target"]
+            if not self.compiled_pipeline:
+                if self.map.get("unetloop"):
+                    self.map.pop("unetloop")
+                if self.map.get("fullpipeline"):
+                    self.map.pop("fullpipeline")
+            elif self.compiled_pipeline:
+                self.map["unet"]["load"] = False
+                self.map["vae"]["load"] = False
+                self.load_scheduler(
+                    scheduler_id,
+                    num_inference_steps,
+                )
+                self.map["scheduler"]["runner"].unload()
+                self.map["scheduler"]["load"] = False
         elif not self.is_sd3:
             self.tokenizer = CLIPTokenizer.from_pretrained(
                 self.base_model_name, subfolder="tokenizer"
@@ -351,23 +351,27 @@ def __init__(

         self.latents_dtype = torch_dtypes[self.latents_precision]
         self.use_i8_punet = self.use_punet = use_i8_punet
+        if self.use_punet:
+            self.setup_punet()
+        else:
+            self.map["unet"]["keywords"].append("!punet")
+            self.map["unet"]["function_name"] = "run_forward"
+
+    def setup_punet(self):
         if self.use_i8_punet:
             self.map["unet"]["export_args"]["precision"] = "i8"
-            self.map["unet"]["export_args"]["use_punet"] = True
-            self.map["unet"]["use_weights_for_export"] = True
-            self.map["unet"]["keywords"].append("punet")
-            self.map["unet"]["module_name"] = "compiled_punet"
-            self.map["unet"]["function_name"] = "main"
             self.map["unet"]["export_args"]["external_weight_path"] = (
                 utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa"
             )
             for idx, word in enumerate(self.map["unet"]["keywords"]):
                 if word in ["fp32", "fp16"]:
                     self.map["unet"]["keywords"][idx] = "i8"
                     break
-        else:
-            self.map["unet"]["keywords"].append("!punet")
-            self.map["unet"]["function_name"] = "run_forward"
+        self.map["unet"]["export_args"]["use_punet"] = True
+        self.map["unet"]["use_weights_for_export"] = True
+        self.map["unet"]["keywords"].append("punet")
+        self.map["unet"]["module_name"] = "compiled_punet"
+        self.map["unet"]["function_name"] = "main"

     # LOAD
@@ -376,10 +380,6 @@ def load_scheduler(
         scheduler_id: str,
         steps: int = 30,
     ):
-        if self.is_sd3:
-            scheduler_device = self.mmdit.device
-        else:
-            scheduler_device = self.unet.device
         if not self.cpu_scheduling:
             self.map["scheduler"] = {
                 "module_name": "compiled_scheduler",
@@ -425,7 +425,11 @@ def load_scheduler(
             except:
                 print("JIT export of scheduler failed. Loading CPU scheduler.")
                 self.cpu_scheduling = True
-        if self.cpu_scheduling:
+        elif self.cpu_scheduling:
+            if self.is_sd3:
+                scheduler_device = self.mmdit.device
+            else:
+                scheduler_device = self.unet.device
             scheduler = schedulers.get_scheduler(self.base_model_name, scheduler_id)
             self.scheduler = schedulers.SharkSchedulerCPUWrapper(
                 scheduler,
@@ -461,13 +465,10 @@ def encode_prompts_sdxl(self, prompt, negative_prompt):
             text_input_ids_list += text_inputs.input_ids.unsqueeze(0)
             uncond_input_ids_list += uncond_input.input_ids.unsqueeze(0)

-        if self.compiled_pipeline:
-            return text_input_ids_list, uncond_input_ids_list
-        else:
-            prompt_embeds, add_text_embeds = self.text_encoder(
-                "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list]
-            )
-            return prompt_embeds, add_text_embeds
+        prompt_embeds, add_text_embeds = self.text_encoder(
+            "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list]
+        )
+        return prompt_embeds, add_text_embeds

     def prepare_latents(
         self,
@@ -565,9 +566,11 @@ def _produce_latents_sdxl(
             [guidance_scale],
             dtype=self.map["unet"]["np_dtype"],
         )
+        # Disable progress bar if we aren't in verbose mode or if we're printing
+        # benchmark latencies for unet.
         for i, t in tqdm(
             enumerate(timesteps),
-            disable=(self.map["unet"].get("benchmark") and self.verbose),
+            disable=(self.map["unet"].get("benchmark") or not self.verbose),
         ):
             if self.cpu_scheduling:
                 latent_model_input, t = self.scheduler.scale_model_input(
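The old disable predicate only suppressed the progress bar when benchmarking and verbose at the same time, which is the opposite of the intent stated in the new comment. A quick standalone check of both predicates (hypothetical snippet, not part of the diff):

for benchmark in (False, True):
    for verbose in (False, True):
        old = benchmark and verbose     # bar shown during quiet runs (wrong)
        new = benchmark or not verbose  # bar only when verbose and not benchmarking
        print(f"benchmark={benchmark} verbose={verbose} old_disable={old} new_disable={new}")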
@@ -608,6 +611,75 @@ def _produce_latents_sdxl(
             latents = self.scheduler("run_step", [noise_pred, t, latents])
         return latents

+    def produce_images_compiled(
+        self,
+        sample,
+        prompt_embeds,
+        text_embeds,
+        guidance_scale,
+    ):
+        pipe_inputs = [
+            sample,
+            prompt_embeds,
+            text_embeds,
+            torch.as_tensor([guidance_scale], dtype=sample.dtype),
+        ]
+        # image = self.compiled_pipeline("produce_img_latents", pipe_inputs)
+        image = self.map["fullpipeline"]["runner"]("produce_image_latents", pipe_inputs)
+        return image
+
+    def prepare_sampling_inputs(
+        self,
+        prompt: str,
+        negative_prompt: str = "",
+        steps: int = 30,
+        batch_count: int = 1,
+        guidance_scale: float = 7.5,
+        seed: float = -1,
+        cpu_scheduling: bool = True,
+        scheduler_id: str = "EulerDiscrete",
+        return_imgs: bool = False,
+    ):
+        needs_new_scheduler = (
+            (steps and steps != self.num_inference_steps)
+            or (cpu_scheduling != self.cpu_scheduling)
+            and self.split_scheduler
+        )
+        if not self.scheduler and not self.compiled_pipeline:
+            needs_new_scheduler = True
+
+        if guidance_scale == 0:
+            negative_prompt = prompt
+            prompt = ""
+
+        self.cpu_scheduling = cpu_scheduling
+        if steps and needs_new_scheduler:
+            self.num_inference_steps = steps
+            self.load_scheduler(scheduler_id, steps)
+
+        pipe_start = time.time()
+        numpy_images = []
+
+        samples = self.get_rand_latents(seed, batch_count)
+
+        # Tokenize prompt and negative prompt.
+        if self.is_sdxl:
+            prompt_embeds, negative_embeds = self.encode_prompts_sdxl(
+                prompt, negative_prompt
+            )
+        else:
+            prompt_embeds, negative_embeds = encode_prompt(
+                self, prompt, negative_prompt
+            )
+        produce_latents_input = [
+            samples[0],
+            prompt_embeds,
+            negative_embeds,
+            steps,
+            guidance_scale,
+        ]
+        return produce_latents_input
+
     def generate_images(
         self,
         prompt: str,
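One thing to watch in the new prepare_sampling_inputs: Python's `and` binds tighter than `or`, so the needs_new_scheduler expression parses as below, meaning self.split_scheduler only gates the cpu-scheduling mismatch, not the step-count change. A parenthesized restatement (same semantics as the diff, shown for clarity):

needs_new_scheduler = (steps and steps != self.num_inference_steps) or (
    (cpu_scheduling != self.cpu_scheduling) and self.split_scheduler
)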
@@ -653,18 +725,23 @@ def generate_images(
         )

         for i in range(batch_count):
-            produce_latents_input = [
-                samples[i],
-                prompt_embeds,
-                negative_embeds,
-                steps,
-                guidance_scale,
-            ]
-            if self.is_sdxl:
-                latents = self._produce_latents_sdxl(*produce_latents_input)
+            if self.compiled_pipeline:
+                image = self.produce_images_compiled(
+                    samples[i], prompt_embeds, negative_embeds, guidance_scale
+                ).to_host()
             else:
-                latents = self._produce_latents_sd(*produce_latents_input)
-            image = self.vae("decode", [latents])
+                produce_latents_input = [
+                    samples[i],
+                    prompt_embeds,
+                    negative_embeds,
+                    steps,
+                    guidance_scale,
+                ]
+                if self.is_sdxl:
+                    latents = self._produce_latents_sdxl(*produce_latents_input)
+                else:
+                    latents = self._produce_latents_sd(*produce_latents_input)
+                image = self.vae("decode", [latents])
             numpy_images.append(image)
         pipe_end = time.time()
@@ -750,8 +827,10 @@ def numpy_to_pil_image(images):
         args.use_i8_punet,
         benchmark,
         args.verbose,
+        False,
+        args.compiled_pipeline,
     )
-    sd_pipe.prepare_all()
+    sd_pipe.prepare_all(num_steps=args.num_inference_steps)
     sd_pipe.load_map()
     sd_pipe.generate_images(
         args.prompt,
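Taken together, the __main__ changes drive the new mode roughly like this (a hypothetical condensed driver; only calls shown in this diff are real, and the leading constructor arguments are elided because this hunk does not show them):

def run(args, benchmark):
    sd_pipe = SharkSDPipeline(
        # ... leading model/precision/device arguments as in __main__ ...
        args.use_i8_punet,
        benchmark,
        args.verbose,
        False,                   # batch_prompts
        args.compiled_pipeline,  # new flag: route through the fused VMFB
    )
    # Steps are threaded through prepare_all now that the compiled path
    # pre-exports the scheduler with num_inference_steps in __init__.
    sd_pipe.prepare_all(num_steps=args.num_inference_steps)
    sd_pipe.load_map()
    sd_pipe.generate_images(args.prompt)  # remaining arguments elided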