             "decomp_attn": None,
         },
     },
-    "unetloop": {
-        "module_name": "sdxl_compiled_pipeline",
-        "load": False,
-        "keywords": ["unetloop"],
-        "wraps": ["unet", "scheduler"],
-        "export_args": {
-            "batch_size": 1,
-            "height": 1024,
-            "width": 1024,
-            "max_length": 64,
-        },
-    },
     "fullpipeline": {
         "module_name": "sdxl_compiled_pipeline",
-        "load": False,
+        "load": True,
         "keywords": ["fullpipeline"],
-        "wraps": ["text_encoder", "unet", "scheduler", "vae"],
+        "wraps": ["unet", "scheduler", "vae"],
         "export_args": {
             "batch_size": 1,
             "height": 1024,
@@ -234,6 +222,7 @@ def __init__(
         benchmark: bool | dict[bool] = False,
         verbose: bool = False,
         batch_prompts: bool = False,
+        compiled_pipeline: bool = False,
     ):
         common_export_args = {
             "hf_model_name": None,
@@ -312,6 +301,7 @@ def __init__(
         self.scheduler = None

         self.split_scheduler = True
+        self.compiled_pipeline = compiled_pipeline

         self.base_model_name = (
             hf_model_name
@@ -322,11 +312,6 @@ def __init__(
         self.is_sdxl = "xl" in self.base_model_name.lower()
         self.is_sd3 = "stable-diffusion-3" in self.base_model_name
         if self.is_sdxl:
-            if self.split_scheduler:
-                if self.map.get("unetloop"):
-                    self.map.pop("unetloop")
-                if self.map.get("fullpipeline"):
-                    self.map.pop("fullpipeline")
             self.tokenizers = [
                 CLIPTokenizer.from_pretrained(
                     self.base_model_name, subfolder="tokenizer"
@@ -340,6 +325,20 @@ def __init__(
             self.scheduler_device = self.map["unet"]["device"]
             self.scheduler_driver = self.map["unet"]["driver"]
             self.scheduler_target = self.map["unet"]["target"]
+            if not self.compiled_pipeline:
+                if self.map.get("unetloop"):
+                    self.map.pop("unetloop")
+                if self.map.get("fullpipeline"):
+                    self.map.pop("fullpipeline")
+            elif self.compiled_pipeline:
+                self.map["unet"]["load"] = False
+                self.map["vae"]["load"] = False
+                self.load_scheduler(
+                    scheduler_id,
+                    num_inference_steps,
+                )
+                self.map["scheduler"]["runner"].unload()
+                self.map["scheduler"]["load"] = False
         elif not self.is_sd3:
             self.tokenizer = CLIPTokenizer.from_pretrained(
                 self.base_model_name, subfolder="tokenizer"
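
This is where the new flag actually takes effect at construction time: without `compiled_pipeline`, the fused modules are pruned from the map; with it, the standalone UNet and VAE (and, after a throwaway export, the scheduler) are marked not to load, since "fullpipeline" subsumes them. A minimal standalone sketch of that pruning decision (the `prune_map` helper is hypothetical, not part of this PR):

```python
# Hypothetical helper mirroring the __init__ branch above.
def prune_map(model_map: dict, compiled_pipeline: bool) -> dict:
    pruned = {name: dict(entry) for name, entry in model_map.items()}
    if not compiled_pipeline:
        # Split path: drive unet/scheduler/vae as separate runners.
        pruned.pop("unetloop", None)
        pruned.pop("fullpipeline", None)
    else:
        # Fused path: "fullpipeline" wraps these, so skip their loads.
        for name in ("unet", "vae", "scheduler"):
            if name in pruned:
                pruned[name]["load"] = False
    return pruned
```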
@@ -381,10 +380,6 @@ def load_scheduler(
         scheduler_id: str,
         steps: int = 30,
     ):
-        if self.is_sd3:
-            scheduler_device = self.mmdit.device
-        else:
-            scheduler_device = self.unet.device
         if not self.cpu_scheduling:
             self.map["scheduler"] = {
                 "module_name": "compiled_scheduler",
@@ -430,7 +425,11 @@ def load_scheduler(
             except:
                 print("JIT export of scheduler failed. Loading CPU scheduler.")
                 self.cpu_scheduling = True
-        if self.cpu_scheduling:
+        elif self.cpu_scheduling:
+            if self.is_sd3:
+                scheduler_device = self.mmdit.device
+            else:
+                scheduler_device = self.unet.device
             scheduler = schedulers.get_scheduler(self.base_model_name, scheduler_id)
             self.scheduler = schedulers.SharkSchedulerCPUWrapper(
                 scheduler,
@@ -466,13 +465,10 @@ def encode_prompts_sdxl(self, prompt, negative_prompt):
             text_input_ids_list += text_inputs.input_ids.unsqueeze(0)
             uncond_input_ids_list += uncond_input.input_ids.unsqueeze(0)

-        if self.compiled_pipeline:
-            return text_input_ids_list, uncond_input_ids_list
-        else:
-            prompt_embeds, add_text_embeds = self.text_encoder(
-                "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list]
-            )
-            return prompt_embeds, add_text_embeds
+        prompt_embeds, add_text_embeds = self.text_encoder(
+            "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list]
+        )
+        return prompt_embeds, add_text_embeds

     def prepare_latents(
         self,
@@ -615,6 +611,75 @@ def _produce_latents_sdxl(
             latents = self.scheduler("run_step", [noise_pred, t, latents])
         return latents

+    def produce_images_compiled(
+        self,
+        sample,
+        prompt_embeds,
+        text_embeds,
+        guidance_scale,
+    ):
+        pipe_inputs = [
+            sample,
+            prompt_embeds,
+            text_embeds,
+            torch.as_tensor([guidance_scale], dtype=sample.dtype),
+        ]
+        # image = self.compiled_pipeline("produce_img_latents", pipe_inputs)
+        image = self.map["fullpipeline"]["runner"]("produce_image_latents", pipe_inputs)
+        return image
+
+    def prepare_sampling_inputs(
+        self,
+        prompt: str,
+        negative_prompt: str = "",
+        steps: int = 30,
+        batch_count: int = 1,
+        guidance_scale: float = 7.5,
+        seed: float = -1,
+        cpu_scheduling: bool = True,
+        scheduler_id: str = "EulerDiscrete",
+        return_imgs: bool = False,
+    ):
+        needs_new_scheduler = (
+            (steps and steps != self.num_inference_steps)
+            or (cpu_scheduling != self.cpu_scheduling)
+            and self.split_scheduler
+        )
+        if not self.scheduler and not self.compiled_pipeline:
+            needs_new_scheduler = True
+
+        if guidance_scale == 0:
+            negative_prompt = prompt
+            prompt = ""
+
+        self.cpu_scheduling = cpu_scheduling
+        if steps and needs_new_scheduler:
+            self.num_inference_steps = steps
+            self.load_scheduler(scheduler_id, steps)
+
+        pipe_start = time.time()
+        numpy_images = []
+
+        samples = self.get_rand_latents(seed, batch_count)
+
+        # Tokenize prompt and negative prompt.
+        if self.is_sdxl:
+            prompt_embeds, negative_embeds = self.encode_prompts_sdxl(
+                prompt, negative_prompt
+            )
+        else:
+            prompt_embeds, negative_embeds = encode_prompt(
+                self, prompt, negative_prompt
+            )
+        produce_latents_input = [
+            samples[0],
+            prompt_embeds,
+            negative_embeds,
+            steps,
+            guidance_scale,
+        ]
+        return produce_latents_input
+
     def generate_images(
         self,
         prompt: str,
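
Together, the two new methods split the fused path into host-side preparation and a single device call: `prepare_sampling_inputs` handles scheduler setup, latent sampling, and prompt encoding, while `produce_images_compiled` feeds the result to the one `fullpipeline` invocation. A hedged usage sketch, assuming `pipe` is an already-prepared instance constructed with `compiled_pipeline=True`:

```python
# Usage sketch; `pipe` is assumed to be a prepared pipeline instance.
sample, prompt_embeds, negative_embeds, steps, guidance_scale = (
    pipe.prepare_sampling_inputs("a photo of a corgi at the beach", steps=30)
)
# One compiled call covers the full denoising loop plus VAE decode:
image = pipe.produce_images_compiled(
    sample, prompt_embeds, negative_embeds, guidance_scale
)
```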
@@ -660,18 +725,21 @@ def generate_images(
             )

         for i in range(batch_count):
-            produce_latents_input = [
-                samples[i],
-                prompt_embeds,
-                negative_embeds,
-                steps,
-                guidance_scale,
-            ]
-            if self.is_sdxl:
-                latents = self._produce_latents_sdxl(*produce_latents_input)
+            if self.compiled_pipeline:
+                image = self.produce_images_compiled(samples[i], prompt_embeds, negative_embeds, guidance_scale).to_host()
             else:
-                latents = self._produce_latents_sd(*produce_latents_input)
-            image = self.vae("decode", [latents])
+                produce_latents_input = [
+                    samples[i],
+                    prompt_embeds,
+                    negative_embeds,
+                    steps,
+                    guidance_scale,
+                ]
+                if self.is_sdxl:
+                    latents = self._produce_latents_sdxl(*produce_latents_input)
+                else:
+                    latents = self._produce_latents_sd(*produce_latents_input)
+                image = self.vae("decode", [latents])
             numpy_images.append(image)
         pipe_end = time.time()

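
The per-batch loop now dispatches between the two paths; restated as a standalone function for clarity (method names taken from this file, `pipe` standing in for `self`):

```python
# Simplified restatement of the loop body above.
def render_one(pipe, sample, prompt_embeds, negative_embeds, steps, guidance_scale):
    if pipe.compiled_pipeline:
        # Fused path: denoise and decode in one compiled call.
        return pipe.produce_images_compiled(
            sample, prompt_embeds, negative_embeds, guidance_scale
        ).to_host()
    # Split path: run the scheduler loop, then decode with the VAE.
    inputs = [sample, prompt_embeds, negative_embeds, steps, guidance_scale]
    if pipe.is_sdxl:
        latents = pipe._produce_latents_sdxl(*inputs)
    else:
        latents = pipe._produce_latents_sd(*inputs)
    return pipe.vae("decode", [latents])
```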
@@ -757,8 +825,10 @@ def numpy_to_pil_image(images):
         args.use_i8_punet,
         benchmark,
         args.verbose,
+        False,
+        args.compiled_pipeline,
     )
-    sd_pipe.prepare_all()
+    sd_pipe.prepare_all(num_steps=args.num_inference_steps)
     sd_pipe.load_map()
     sd_pipe.generate_images(
         args.prompt,