
Commit 5a9aaa0

Merge branch 'nod-ai:ean-unify-sd' into ean-unify-sd
2 parents: b45a6c5 + 618d01f

File tree: 11 files changed, +192 -68 lines


models/requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 protobuf
-sentencepiece
 shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
 transformers==4.37.1
 torchsde

models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py

Lines changed: 13 additions & 1 deletion
@@ -247,6 +247,12 @@ def is_valid_file(arg):
     default="fp16",
     help="Precision of Stable Diffusion weights and graph.",
 )
+p.add_argument(
+    "--vae_precision",
+    type=str,
+    default=None,
+    help="Precision of Stable Diffusion VAE weights and graph.",
+)
 p.add_argument(
     "--max_length", type=int, default=77, help="Sequence Length of Stable Diffusion"
 )
@@ -257,7 +263,7 @@ def is_valid_file(arg):
 p.add_argument(
     "--vae_decomp_attn",
     type=bool,
-    default=True,
+    default=False,
     help="Decompose attention for VAE decode only at fx graph level",
 )
 p.add_argument(
@@ -340,6 +346,12 @@ def is_valid_file(arg):
     action="store_true",
     help="Just compile attention reproducer for mmdit.",
 )
+p.add_argument(
+    "--vae_input_path",
+    type=str,
+    default=None,
+    help="Path to input latents for VAE inference numerics validation.",
+)


 ##############################################################################
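The new --vae_precision flag defaults to None so the VAE can be overridden independently of --precision; when left unset, the pipeline falls back to the main precision (see the sd3_pipeline.py change below). A minimal sketch of that fallback, with illustrative values:

# Sketch of the None-means-inherit behavior of --vae_precision.
# Mirrors `vae_precision if vae_precision else self.precision` in sd3_pipeline.py.
def resolve_vae_precision(vae_precision, precision):
    return vae_precision if vae_precision else precision

assert resolve_vae_precision(None, "fp16") == "fp16"    # inherits main precision
assert resolve_vae_precision("fp32", "fp16") == "fp32"  # explicit override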

models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py

Lines changed: 1 addition & 1 deletion
@@ -207,7 +207,7 @@ def export_mmdit_model(
         torch.empty(hidden_states_shape, dtype=dtype),
         torch.empty(encoder_hidden_states_shape, dtype=dtype),
         torch.empty(pooled_projections_shape, dtype=dtype),
-        torch.empty(1, dtype=dtype),
+        torch.empty(init_batch_dim, dtype=dtype),
     ]

     decomp_list = []
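The timestep example input now has shape (init_batch_dim,) rather than a single element, giving one timestep value per batch entry; the runner change below doubles this under classifier-free guidance. A sketch of the shape logic, assuming init_batch_dim is derived from the batch size:

import torch

# Assumption: the CFG-doubled batch determines init_batch_dim (illustrative values).
batch_size = 1
do_classifier_free_guidance = True
init_batch_dim = batch_size * 2 if do_classifier_free_guidance else batch_size

timestep = torch.empty(init_batch_dim, dtype=torch.float16)  # one timestep per batch entry
assert timestep.shape == (2,)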

models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py

Lines changed: 2 additions & 1 deletion
@@ -154,7 +154,7 @@ def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]):
         (batch_size, args.max_length * 2, 4096), dtype=dtype
     )
     pooled_projections = torch.randn((batch_size, 2048), dtype=dtype)
-    timestep = torch.tensor([0], dtype=dtype)
+    timestep = torch.tensor([0, 0], dtype=dtype)

     turbine_output = run_mmdit_turbine(
         hidden_states,
@@ -180,6 +180,7 @@ def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]):
         timestep,
         args,
     )
+    np.save("torch_mmdit_output.npy", torch_output.astype(np.float16))
     print("torch OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)

     print("\n(torch (comfy) image latents to iree image latents): ")

models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py

Lines changed: 98 additions & 43 deletions
@@ -17,6 +17,7 @@
 from turbine_models.custom_models.sd_inference import utils
 from turbine_models.model_runner import vmfbRunner
 from transformers import CLIPTokenizer
+from diffusers import FlowMatchEulerDiscreteScheduler

 from PIL import Image
 import os
@@ -44,10 +45,8 @@ class SharkSD3Pipeline:
     def __init__(
         self,
         hf_model_name: str,
-        # scheduler_id: str,
         height: int,
         width: int,
-        shift: float,
         precision: str,
         max_length: int,
         batch_size: int,
@@ -60,9 +59,12 @@ def __init__(
         pipeline_dir: str = "./shark_vmfbs",
         external_weights_dir: str = "./shark_weights",
         external_weights: str = "safetensors",
-        vae_decomp_attn: bool = True,
-        custom_vae: str = "",
+        vae_decomp_attn: bool = False,
         cpu_scheduling: bool = False,
+        vae_precision: str = "fp32",
+        scheduler_id: str = None,  # compatibility only, always uses EulerFlowScheduler
+        shift: float = 1.0,
+
     ):
         self.hf_model_name = hf_model_name
         # self.scheduler_id = scheduler_id
@@ -120,10 +122,11 @@ def __init__(
         self.external_weights_dir = external_weights_dir
         self.external_weights = external_weights
         self.vae_decomp_attn = vae_decomp_attn
-        self.custom_vae = custom_vae
+        self.custom_vae = None
         self.cpu_scheduling = cpu_scheduling
         self.torch_dtype = torch.float32 if self.precision == "fp32" else torch.float16
-        self.vae_dtype = torch.float32
+        self.vae_precision = vae_precision if vae_precision else self.precision
+        self.vae_dtype = torch.float32 if vae_precision == "fp32" else torch.float16
         # TODO: set this based on user-inputted guidance scale and negative prompt.
         self.do_classifier_free_guidance = True  # False if any(x in hf_model_name for x in ["turbo", "lightning"]) else True

@@ -206,7 +209,12 @@ def is_prepared(self, vmfbs, weights):
             )
             if w_key == "clip":
                 default_name = os.path.join(
-                    self.external_weights_dir, f"sd3_clip_fp16.irpa"
+                    self.external_weights_dir, f"sd3_text_encoders_{self.precision}.irpa"
+                )
+            if w_key == "mmdit":
+                default_name = os.path.join(
+                    self.external_weights_dir,
+                    f"sd3_mmdit_{self.precision}." + self.external_weights,
                 )
             if weights[w_key] is None and os.path.exists(default_name):
                 weights[w_key] = os.path.join(default_name)
@@ -357,7 +365,7 @@ def export_submodel(
                 self.batch_size,
                 self.height,
                 self.width,
-                "fp32",
+                self.vae_precision,
                 "vmfb",
                 self.external_weights,
                 vae_external_weight_path,
@@ -419,10 +427,16 @@ def load_pipeline(
         unet_loaded = time.time()
         print("\n[LOG] MMDiT loaded in ", unet_loaded - load_start, "sec")

-        runners["scheduler"] = sd3_schedulers.SharkSchedulerWrapper(
-            self.devices["mmdit"]["driver"],
-            vmfbs["scheduler"],
-        )
+        if not self.cpu_scheduling:
+            runners["scheduler"] = sd3_schedulers.SharkSchedulerWrapper(
+                self.devices["mmdit"]["driver"],
+                vmfbs["scheduler"],
+            )
+        else:
+            print("Using torch CPU scheduler.")
+            runners["scheduler"] = FlowMatchEulerDiscreteScheduler.from_pretrained(
+                self.hf_model_name, subfolder="scheduler"
+            )

         sched_loaded = time.time()
         print("\n[LOG] Scheduler loaded in ", sched_loaded - unet_loaded, "sec")
@@ -495,11 +509,12 @@ def generate_images(
             )
         )

-        guidance_scale = ireert.asdevicearray(
-            self.runners["pipe"].config.device,
-            np.asarray([guidance_scale]),
-            dtype=iree_dtype,
-        )
+        if not self.cpu_scheduling:
+            guidance_scale = ireert.asdevicearray(
+                self.runners["pipe"].config.device,
+                np.asarray([guidance_scale]),
+                dtype=iree_dtype,
+            )

         tokenize_start = time.time()
         text_input_ids_dict = self.tokenizer.tokenize_with_weights(prompt)
@@ -533,12 +548,23 @@ def generate_images(
             "clip"
         ].ctx.modules.compiled_text_encoder["encode_tokens"](*text_encoders_inputs)
         encode_prompts_end = time.time()
+        if self.cpu_scheduling:
+            timesteps, num_inference_steps = sd3_schedulers.retrieve_timesteps(
+                self.runners["scheduler"],
+                num_inference_steps=self.num_inference_steps,
+                timesteps=None,
+            )
+            steps = num_inference_steps
+

         for i in range(batch_count):
             unet_start = time.time()
-            sample, steps, timesteps = self.runners["scheduler"].initialize(samples[i])
+            if not self.cpu_scheduling:
+                latents, steps, timesteps = self.runners["scheduler"].initialize(samples[i])
+            else:
+                latents = torch.tensor(samples[i].to_host(), dtype=self.torch_dtype)
             iree_inputs = [
-                sample,
+                latents,
                 ireert.asdevicearray(
                     self.runners["pipe"].config.device, prompt_embeds, dtype=iree_dtype
                 ),
@@ -553,40 +579,71 @@ def generate_images(
                 # print(f"step {s}")
                 if self.cpu_scheduling:
                     step_index = s
+                    t = timesteps[s]
+                    if self.do_classifier_free_guidance:
+                        latent_model_input = torch.cat([latents] * 2)
+                    timestep = ireert.asdevicearray(
+                        self.runners["pipe"].config.device,
+                        t.expand(latent_model_input.shape[0]),
+                        dtype=iree_dtype,
+                    )
+                    latent_model_input = ireert.asdevicearray(
+                        self.runners["pipe"].config.device,
+                        latent_model_input,
+                        dtype=iree_dtype,
+                    )
                 else:
                     step_index = ireert.asdevicearray(
                         self.runners["scheduler"].runner.config.device,
                         torch.tensor([s]),
                         "int64",
                     )
-                latents, t = self.runners["scheduler"].prep(
-                    sample,
-                    step_index,
-                    timesteps,
-                )
+                    latent_model_input, timestep = self.runners["scheduler"].prep(
+                        latents,
+                        step_index,
+                        timesteps,
+                    )
+                    t = ireert.asdevicearray(
+                        self.runners["scheduler"].runner.config.device,
+                        timestep.to_host()[0]
+                    )
                 noise_pred = self.runners["pipe"].ctx.modules.compiled_mmdit[
                     "run_forward"
                 ](
-                    latents,
+                    latent_model_input,
                     iree_inputs[1],
                     iree_inputs[2],
-                    t,
+                    timestep,
                 )
-                sample = self.runners["scheduler"].step(
-                    noise_pred,
-                    t,
-                    sample,
-                    guidance_scale,
-                    step_index,
-                )
-            if isinstance(sample, torch.Tensor):
+                if not self.cpu_scheduling:
+                    latents = self.runners["scheduler"].step(
+                        noise_pred,
+                        t,
+                        latents,
+                        guidance_scale,
+                        step_index,
+                    )
+                else:
+                    noise_pred = torch.tensor(noise_pred.to_host(), dtype=self.torch_dtype)
+                    if self.do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    latents = self.runners["scheduler"].step(
+                        noise_pred,
+                        t,
+                        latents,
+                        return_dict=False,
+                    )[0]
+
+            if isinstance(latents, torch.Tensor):
+                latents = latents.type(self.vae_dtype)
                 latents = ireert.asdevicearray(
                     self.runners["vae"].config.device,
-                    sample,
-                    dtype=self.vae_dtype,
+                    latents,
                 )
             else:
-                latents = sample.astype("float32")
+                vae_numpy_dtype = np.float32 if self.vae_precision == "fp32" else np.float16
+                latents = latents.astype(vae_numpy_dtype)

             vae_start = time.time()
             vae_out = self.runners["vae"].ctx.modules.compiled_vae["decode"](latents)
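On the CPU-scheduling path the MMDiT receives the CFG-doubled batch, so the unconditional and text-conditioned halves are split and recombined on the host before the scheduler step. The standard guidance formula as a standalone sketch, with illustrative shapes:

import torch

noise_pred = torch.randn(2, 16, 128, 128)  # stacked [uncond, text] halves
guidance_scale = 5.0

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
# Move from the unconditional prediction toward the text-conditioned direction.
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
assert guided.shape == (1, 16, 128, 128)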
@@ -634,7 +691,7 @@ def generate_images(
             out_image = Image.fromarray(image)
             images.extend([[out_image]])
         if return_imgs:
-            return images
+            return images[0]
         for idx_batch, image_batch in enumerate(images):
             for idx, image in enumerate(image_batch):
                 img_path = (
@@ -767,7 +824,6 @@ def run_diffusers_cpu(
         args.hf_model_name,
         args.height,
         args.width,
-        args.shift,
         args.precision,
         args.max_length,
         args.batch_size,
@@ -779,16 +835,15 @@ def run_diffusers_cpu(
         args.decomp_attn,
         args.pipeline_dir,
         args.external_weights_dir,
-        args.external_weights,
-        args.vae_decomp_attn,
-        custom_vae=None,
+        external_weights=args.external_weights,
+        vae_decomp_attn=args.vae_decomp_attn,
         cpu_scheduling=args.cpu_scheduling,
         vae_precision=args.vae_precision,
     )
-    vmfbs, weights = sd3_pipe.check_prepared(mlirs, vmfbs, weights)
     if args.cpu_scheduling:
         vmfbs.pop("scheduler")
         weights.pop("scheduler")
+    vmfbs, weights = sd3_pipe.check_prepared(mlirs, vmfbs, weights)
     if args.npu_delegate_path:
         extra_device_args = {"npu_delegate_path": args.npu_delegate_path}
     else:
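Moving check_prepared after the scheduler pop matters because with --cpu_scheduling the scheduler vmfb is never compiled; validating the dicts before removing that entry would demand an artifact that cannot exist. A sketch of the ordering, with a simplified stand-in for the real check_prepared method:

def check_prepared(vmfbs):
    # Stand-in: every remaining entry must point at an artifact.
    missing = [k for k, v in vmfbs.items() if v is None]
    if missing:
        raise RuntimeError(f"missing artifacts: {missing}")
    return vmfbs

vmfbs = {"clip": "clip.vmfb", "mmdit": "mmdit.vmfb", "vae": "vae.vmfb", "scheduler": None}
cpu_scheduling = True
if cpu_scheduling:
    vmfbs.pop("scheduler")  # torch scheduler runs on host; no vmfb needed
vmfbs = check_prepared(vmfbs)  # succeeds without a scheduler artifact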

models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py

Lines changed: 40 additions & 1 deletion
@@ -5,9 +5,11 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

 import os
+import inspect
 from typing import List

 import torch
+from typing import Any, Callable, Dict, List, Optional, Union
 from shark_turbine.aot import *
 import shark_turbine.ops.iree as ops
 from iree.compiler.ir import Context
@@ -75,11 +77,12 @@ def initialize(self, sample):

     def prepare_model_input(self, sample, t, timesteps):
         t = timesteps[t]
-        t = t.expand(sample.shape[0])
+
         if self.do_classifier_free_guidance:
             latent_model_input = torch.cat([sample] * 2)
         else:
             latent_model_input = sample
+        t = t.expand(latent_model_input.shape[0])
         return latent_model_input.type(self.dtype), t.type(self.dtype)

     def step(self, noise_pred, t, sample, guidance_scale, i):
@@ -146,6 +149,42 @@ def step(self, noise_pred, t, latents, guidance_scale, i):
             return_dict=False,
         )[0]

+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+# Only used for cpu scheduling.
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps

 @torch.no_grad()
 def export_scheduler_model(
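For the CPU path, sd3_pipeline.py calls this helper with a diffusers scheduler instance. A usage sketch, assuming diffusers is installed and the standard SD3 repo layout on Hugging Face:

from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="scheduler"
)
timesteps, num_inference_steps = retrieve_timesteps(
    scheduler, num_inference_steps=28, timesteps=None
)
print(num_inference_steps, timesteps[:3])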
