118
118
"decomp_attn" : None ,
119
119
},
120
120
},
121
- "unetloop" : {
122
- "module_name" : "sdxl_compiled_pipeline" ,
123
- "load" : False ,
124
- "keywords" : ["unetloop" ],
125
- "wraps" : ["unet" , "scheduler" ],
126
- "export_args" : {
127
- "batch_size" : 1 ,
128
- "height" : 1024 ,
129
- "width" : 1024 ,
130
- "max_length" : 64 ,
131
- },
132
- },
133
121
"fullpipeline" : {
134
122
"module_name" : "sdxl_compiled_pipeline" ,
135
- "load" : False ,
123
+ "load" : True ,
136
124
"keywords" : ["fullpipeline" ],
137
- "wraps" : ["text_encoder" , " unet" , "scheduler" , "vae" ],
125
+ "wraps" : ["unet" , "scheduler" , "vae" ],
138
126
"export_args" : {
139
127
"batch_size" : 1 ,
140
128
"height" : 1024 ,
@@ -234,6 +222,7 @@ def __init__(
234
222
benchmark : bool | dict [bool ] = False ,
235
223
verbose : bool = False ,
236
224
batch_prompts : bool = False ,
225
+ compiled_pipeline : bool = False ,
237
226
):
238
227
common_export_args = {
239
228
"hf_model_name" : None ,
@@ -312,6 +301,7 @@ def __init__(
312
301
self .scheduler = None
313
302
314
303
self .split_scheduler = True
304
+ self .compiled_pipeline = compiled_pipeline
315
305
316
306
self .base_model_name = (
317
307
hf_model_name
@@ -322,11 +312,6 @@ def __init__(
322
312
self .is_sdxl = "xl" in self .base_model_name .lower ()
323
313
self .is_sd3 = "stable-diffusion-3" in self .base_model_name
324
314
if self .is_sdxl :
325
- if self .split_scheduler :
326
- if self .map .get ("unetloop" ):
327
- self .map .pop ("unetloop" )
328
- if self .map .get ("fullpipeline" ):
329
- self .map .pop ("fullpipeline" )
330
315
self .tokenizers = [
331
316
CLIPTokenizer .from_pretrained (
332
317
self .base_model_name , subfolder = "tokenizer"
@@ -340,6 +325,20 @@ def __init__(
340
325
self .scheduler_device = self .map ["unet" ]["device" ]
341
326
self .scheduler_driver = self .map ["unet" ]["driver" ]
342
327
self .scheduler_target = self .map ["unet" ]["target" ]
328
+ if not self .compiled_pipeline :
329
+ if self .map .get ("unetloop" ):
330
+ self .map .pop ("unetloop" )
331
+ if self .map .get ("fullpipeline" ):
332
+ self .map .pop ("fullpipeline" )
333
+ elif self .compiled_pipeline :
334
+ self .map ["unet" ]["load" ] = False
335
+ self .map ["vae" ]["load" ] = False
336
+ self .load_scheduler (
337
+ scheduler_id ,
338
+ num_inference_steps ,
339
+ )
340
+ self .map ["scheduler" ]["runner" ].unload ()
341
+ self .map ["scheduler" ]["load" ] = False
343
342
elif not self .is_sd3 :
344
343
self .tokenizer = CLIPTokenizer .from_pretrained (
345
344
self .base_model_name , subfolder = "tokenizer"
@@ -381,10 +380,6 @@ def load_scheduler(
381
380
scheduler_id : str ,
382
381
steps : int = 30 ,
383
382
):
384
- if self .is_sd3 :
385
- scheduler_device = self .mmdit .device
386
- else :
387
- scheduler_device = self .unet .device
388
383
if not self .cpu_scheduling :
389
384
self .map ["scheduler" ] = {
390
385
"module_name" : "compiled_scheduler" ,
@@ -430,7 +425,11 @@ def load_scheduler(
430
425
except :
431
426
print ("JIT export of scheduler failed. Loading CPU scheduler." )
432
427
self .cpu_scheduling = True
433
- if self .cpu_scheduling :
428
+ elif self .cpu_scheduling :
429
+ if self .is_sd3 :
430
+ scheduler_device = self .mmdit .device
431
+ else :
432
+ scheduler_device = self .unet .device
434
433
scheduler = schedulers .get_scheduler (self .base_model_name , scheduler_id )
435
434
self .scheduler = schedulers .SharkSchedulerCPUWrapper (
436
435
scheduler ,
@@ -615,6 +614,72 @@ def _produce_latents_sdxl(
615
614
latents = self .scheduler ("run_step" , [noise_pred , t , latents ])
616
615
return latents
617
616
617
def produce_images_compiled(
    self,
    sample,
    prompt_embeds,
    text_embeds,
    guidance_scale,
):
    """Run the one-shot compiled SDXL pipeline module to produce an image.

    Packs the latent sample, prompt embeddings, pooled text embeddings and
    guidance scale into the input list expected by the compiled module's
    "produce_img_latents" entry point, invokes it, and returns the result.

    NOTE(review): the original version omitted ``self`` (so, as a method,
    the instance would have been bound to ``sample``) and never returned
    ``image`` — every caller received ``None``. Both fixed here.
    # assumes self.compiled_pipeline has been rebound from the bool flag
    # to a callable compiled-pipeline runner by load time — TODO confirm
    """
    pipe_inputs = [
        sample,
        prompt_embeds,
        text_embeds,
        guidance_scale,
    ]
    image = self.compiled_pipeline("produce_img_latents", pipe_inputs)
    return image
630
+
631
def prepare_sampling_inputs(
    self,
    prompt: str,
    negative_prompt: str = "",
    steps: int = 30,
    batch_count: int = 1,
    guidance_scale: float = 7.5,
    seed: float = -1,
    cpu_scheduling: bool = True,
    scheduler_id: str = "EulerDiscrete",
    return_imgs: bool = False,
):
    """Encode prompts and draw initial latents, returning the positional
    argument list for a produce-latents call.

    Reloads the scheduler first when the requested ``steps`` or
    ``cpu_scheduling`` differ from the pipeline's current configuration.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt (swapped in for ``prompt``
            when guidance is disabled).
        steps: Number of inference steps.
        batch_count: Number of latent samples to draw (only the first is
            returned — see note at the bottom).
        guidance_scale: Classifier-free guidance scale; 0 disables guidance.
        seed: RNG seed forwarded to ``get_rand_latents``.
        cpu_scheduling: Whether to run the scheduler on CPU.
        scheduler_id: Scheduler identifier, e.g. "EulerDiscrete".
        return_imgs: Unused here; kept for signature parity with
            ``generate_images``.

    Returns:
        ``[sample, prompt_embeds, negative_embeds, steps, guidance_scale]``.
    """
    # NOTE(review): `and` binds tighter than `or`, so the original
    # expression evaluated as A or (B and split_scheduler). Parenthesized
    # explicitly to preserve that behavior; confirm the grouping is
    # intentional rather than (A or B) and split_scheduler.
    needs_new_scheduler = (steps and steps != self.num_inference_steps) or (
        (cpu_scheduling != self.cpu_scheduling) and self.split_scheduler
    )
    if not self.scheduler and not self.compiled_pipeline:
        needs_new_scheduler = True

    # guidance_scale == 0 disables classifier-free guidance: run with the
    # negative prompt promoted to the (sole) unconditional prompt.
    if guidance_scale == 0:
        negative_prompt = prompt
        prompt = ""

    self.cpu_scheduling = cpu_scheduling
    if steps and needs_new_scheduler:
        self.num_inference_steps = steps
        self.load_scheduler(scheduler_id, steps)

    # (removed: unused `pipe_start = time.time()` and `numpy_images = []`
    # — this function returns before any timing or image accumulation.)
    samples = self.get_rand_latents(seed, batch_count)

    # Tokenize/encode the positive and negative prompts.
    if self.is_sdxl:
        prompt_embeds, negative_embeds = self.encode_prompts_sdxl(
            prompt, negative_prompt
        )
    else:
        prompt_embeds, negative_embeds = encode_prompt(
            self, prompt, negative_prompt
        )
    # Only the first latent sample is used, even when batch_count > 1.
    return [
        samples[0],
        prompt_embeds,
        negative_embeds,
        steps,
        guidance_scale,
    ]
682
+
618
683
def generate_images (
619
684
self ,
620
685
prompt : str ,
@@ -660,18 +725,26 @@ def generate_images(
660
725
)
661
726
662
727
for i in range (batch_count ):
663
- produce_latents_input = [
664
- samples [i ],
665
- prompt_embeds ,
666
- negative_embeds ,
667
- steps ,
668
- guidance_scale ,
669
- ]
670
- if self .is_sdxl :
671
- latents = self ._produce_latents_sdxl (* produce_latents_input )
728
+ if self .compiled_pipeline :
729
+ image = produce_images_compiled (
730
+ samples [i ],
731
+ prompt_embeds ,
732
+ negative_embeds ,
733
+ guidance_scale
734
+ )
672
735
else :
673
- latents = self ._produce_latents_sd (* produce_latents_input )
674
- image = self .vae ("decode" , [latents ])
736
+ produce_latents_input = [
737
+ samples [i ],
738
+ prompt_embeds ,
739
+ negative_embeds ,
740
+ steps ,
741
+ guidance_scale ,
742
+ ]
743
+ if self .is_sdxl :
744
+ latents = self ._produce_latents_sdxl (* produce_latents_input )
745
+ else :
746
+ latents = self ._produce_latents_sd (* produce_latents_input )
747
+ image = self .vae ("decode" , [latents ])
675
748
numpy_images .append (image )
676
749
pipe_end = time .time ()
677
750
@@ -757,6 +830,8 @@ def numpy_to_pil_image(images):
757
830
args .use_i8_punet ,
758
831
benchmark ,
759
832
args .verbose ,
833
+ False ,
834
+ args .compiled_pipeline ,
760
835
)
761
836
sd_pipe .prepare_all ()
762
837
sd_pipe .load_map ()
0 commit comments