115 changes: 86 additions & 29 deletions optimum/exporters/utils.py
@@ -77,14 +77,16 @@
"FluxTransformer2DModel": "flux-transformer-2d",
"SD3Transformer2DModel": "sd3-transformer-2d",
"UNet2DConditionModel": "unet-2d-condition",
"UNet3DConditionModel": "unet-3d-condition",
"T5EncoderModel": "t5-encoder",
"UMT5EncoderModel": "umt5-encoder",
"WanTransformer3DModel": "wan-transformer-3d",
}


def _get_diffusers_submodel_type(submodel):
return _DIFFUSERS_CLASS_NAME_TO_SUBMODEL_TYPE.get(submodel.__class__.__name__)


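For context, a minimal sketch of the lookup above (the stand-in class below is hypothetical; only the mapping keys come from the diff):

class WanTransformer3DModel:  # hypothetical stand-in for the diffusers class
    ...

assert _get_diffusers_submodel_type(WanTransformer3DModel()) == "wan-transformer-3d"
assert _get_diffusers_submodel_type(object()) is None  # unknown classes fall through to None
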
def _get_submodels_for_export_diffusion(
pipeline: "DiffusionPipeline",
) -> Dict[str, Union["PreTrainedModel", "ModelMixin"]]:
@@ -96,6 +98,7 @@ def _get_submodels_for_export_diffusion(

is_sdxl = pipeline.__class__.__name__.startswith("StableDiffusionXL")
is_sd3 = pipeline.__class__.__name__.startswith("StableDiffusion3")
is_wan = pipeline.__class__.__name__.startswith("Wan")

# Text encoder
text_encoder = getattr(pipeline, "text_encoder", None)
@@ -105,6 +108,10 @@
text_encoder.text_model.config.output_hidden_states = True

text_encoder.config.export_model_type = _get_diffusers_submodel_type(text_encoder)
if text_encoder.__class__.__name__ == "UMT5EncoderModel":
# wrap forward so the exported graph returns only the last hidden state
orig_forward = text_encoder.forward
text_encoder.forward = lambda input_ids, attention_mask: \
orig_forward(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
models_for_export["text_encoder"] = text_encoder

# Text encoder 2
@@ -135,6 +142,10 @@ def _get_submodels_for_export_diffusion(
if not is_sdxl
else pipeline.text_encoder_2.config.projection_dim
)
hidden_size = getattr(pipeline.text_encoder.config, "hidden_size", None)
if unet.__class__.__name__ == "UNet3DConditionModel":
unet.config.text_encoder_projection_dim = hidden_size
unet.config.vocab_size = getattr(pipeline.text_encoder.config, "vocab_size", 4096)
unet.config.export_model_type = _get_diffusers_submodel_type(unet)
models_for_export["unet"] = unet

@@ -143,21 +154,59 @@ def _get_submodels_for_export_diffusion(
if transformer is not None:
transformer.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
transformer.config.time_cond_proj_dim = getattr(pipeline.transformer.config, "time_cond_proj_dim", None)
transformer.config.text_encoder_projection_dim = pipeline.text_encoder.config.projection_dim
if is_wan:
# Wan reads the text encoder width from the transformer's own `text_dim`
transformer.config.text_encoder_projection_dim = getattr(pipeline.transformer.config, "text_dim", None)
transformer.config.expand_timesteps = getattr(pipeline.config, "expand_timesteps", False)
transformer.config.vae_scale_factor_temporal = getattr(pipeline.vae.config, "scale_factor_temporal", 4)
transformer.config.vae_scale_factor_spatial = getattr(pipeline.vae.config, "scale_factor_spatial", 8)
vace_layers = getattr(pipeline.transformer.config, "vace_layers", None)
if vace_layers is not None:
transformer.config.vace_num_layers = len(vace_layers)
transformer.config.vace_in_channels = getattr(pipeline.transformer.config, "vace_in_channels", 96)
transformer.config.vocab_size = 256384  # UMT5 vocabulary size
else:
transformer.config.text_encoder_projection_dim = pipeline.text_encoder.config.projection_dim
transformer.config.export_model_type = _get_diffusers_submodel_type(transformer)
models_for_export["transformer"] = transformer

# Transformer_2
transformer_2 = getattr(pipeline, "transformer_2", None)
if transformer_2 is not None:
transformer_2.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
transformer_2.config.time_cond_proj_dim = getattr(pipeline.transformer.config, "time_cond_proj_dim", None)
if is_wan:
transformer_2.config.text_encoder_projection_dim = getattr(pipeline.transformer_2.config, "text_dim", None)
transformer_2.config.expand_timesteps = getattr(pipeline.config, "expand_timesteps", False)
transformer_2.config.vae_scale_factor_temporal = getattr(pipeline.vae.config, "scale_factor_temporal", 4)
transformer_2.config.vae_scale_factor_spatial = getattr(pipeline.vae.config, "scale_factor_spatial", 8)
vace_layers = getattr(pipeline.transformer_2.config, "vace_layers", None)
if vace_layers is not None:
transformer_2.config.vace_num_layers = len(vace_layers)
transformer_2.config.vace_in_channels = getattr(pipeline.transformer_2.config, "vace_in_channels", 96)
transformer_2.config.vocab_size = 256384  # UMT5 vocabulary size
else:
transformer_2.config.text_encoder_projection_dim = pipeline.text_encoder.config.projection_dim
transformer_2.config.export_model_type = _get_diffusers_submodel_type(transformer_2)
models_for_export["transformer_2"] = transformer_2

# VAE Encoder
vae_encoder = copy.deepcopy(pipeline.vae)

# we return the distribution parameters so the distribution can be recreated on the decoder side
if vae_encoder.__class__.__name__ == "AutoencoderKLWan":
# Wan's video VAE does not expose distribution parameters, so export the mode of the latent distribution instead
vae_encoder.forward = lambda sample: {"latent_parameters": vae_encoder.encode(x=sample).latent_dist.mode()}
if transformer is not None:
transformer.config.z_dim = getattr(vae_encoder.config, "z_dim", None)
if transformer_2 is not None:
transformer_2.config.z_dim = getattr(vae_encoder.config, "z_dim", None)
else:
vae_encoder.forward = lambda sample: {"latent_parameters": vae_encoder.encode(x=sample)["latent_dist"].parameters}

models_for_export["vae_encoder"] = vae_encoder

# VAE Decoder
vae_decoder = copy.deepcopy(pipeline.vae)

vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample)
models_for_export["vae_decoder"] = vae_decoder

return models_for_export
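
A minimal usage sketch of this collection step for a Wan pipeline (the checkpoint id is a hypothetical example, not taken from the diff):

from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")  # hypothetical checkpoint id
submodels = _get_submodels_for_export_diffusion(pipeline)
# a Wan pipeline should yield its UMT5 text encoder, the 3D transformer and both VAE halves
print(sorted(submodels))  # ['text_encoder', 'transformer', 'vae_decoder', 'vae_encoder']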
@@ -332,32 +381,40 @@ def get_diffusion_models_for_export(
models_for_export["transformer"] = (models_for_export["transformer"], transformer_export_config)

# VAE Encoder
vae_encoder = models_for_export["vae_encoder"]
vae_config_constructor = TasksManager.get_exporter_config_constructor(
model=vae_encoder,
exporter=exporter,
library_name="diffusers",
task="semantic-segmentation",
model_type="vae-encoder",
)
vae_encoder_export_config = vae_config_constructor(
vae_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["vae_encoder"] = (vae_encoder, vae_encoder_export_config)
if "vae_encoder" in models_for_export:
vae_encoder = models_for_export["vae_encoder"]
encoder_model_type = "vae-encoder"
if vae_encoder.__class__.__name__ == "AutoencoderKLWan":
encoder_model_type = "vae-encoder-video"
vae_config_constructor = TasksManager.get_exporter_config_constructor(
model=vae_encoder,
exporter=exporter,
library_name="diffusers",
task="semantic-segmentation",
model_type=encoder_model_type,
)
vae_encoder_export_config = vae_config_constructor(
vae_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["vae_encoder"] = (vae_encoder, vae_encoder_export_config)

# VAE Decoder
vae_decoder = models_for_export["vae_decoder"]
vae_config_constructor = TasksManager.get_exporter_config_constructor(
model=vae_decoder,
exporter=exporter,
library_name="diffusers",
task="semantic-segmentation",
model_type="vae-decoder",
)
vae_decoder_export_config = vae_config_constructor(
vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["vae_decoder"] = (vae_decoder, vae_decoder_export_config)
if "vae_decoder" in models_for_export:
vae_decoder = models_for_export["vae_decoder"]
decoder_model_type = "vae-decoder"
if vae_decoder.__class__.__name__ == "AutoencoderKLWan":
decoder_model_type = "vae-decoder-video"
vae_config_constructor = TasksManager.get_exporter_config_constructor(
model=vae_decoder,
exporter=exporter,
library_name="diffusers",
task="semantic-segmentation",
model_type=decoder_model_type,
)
vae_decoder_export_config = vae_config_constructor(
vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["vae_decoder"] = (vae_decoder, vae_decoder_export_config)

return models_for_export

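A sketch of calling the function above once the submodels are collected (the keyword names come from the constructors in the diff; the pipeline being the first argument and "onnx" as the exporter are assumptions):

models_for_export = get_diffusion_models_for_export(pipeline, int_dtype="int64", float_dtype="fp32", exporter="onnx")
vae_encoder, vae_encoder_export_config = models_for_export["vae_encoder"]
# with AutoencoderKLWan, the encoder/decoder resolve to the "vae-encoder-video"/"vae-decoder-video" configs
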
2 changes: 2 additions & 0 deletions optimum/utils/__init__.py
@@ -92,7 +92,9 @@
DummyVideoInputGenerator,
DummyVisionEmbeddingsGenerator,
DummyVisionEncoderDecoderPastKeyValuesGenerator,
DummyVisionInputGenerator,
DummyWanTimestepInputGenerator,
DummyXPathSeqInputGenerator,
FalconDummyPastKeyValuesGenerator,
GemmaDummyPastKeyValuesGenerator,
GPTBigCodeDummyPastKeyValuesGenerator,
91 changes: 91 additions & 0 deletions optimum/utils/input_generators.py
@@ -56,6 +56,10 @@ def wrapper(*args, **kwargs):
"feature_size": 80,
"nb_max_frames": 3000,
"audio_sequence_length": 16000,
# video
"num_frames": 2,
"video_width": 128,
"video_height": 128,
}


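Worked through with these defaults (a batch_size of 2 is assumed from the surrounding DEFAULT_DUMMY_SHAPES): a pixel-space video sample comes out as [2, in_channels, 2, 128, 128], and a Wan latent with temporal scale 4 and spatial scale 8 as [2, z_dim, 1 + (2 - 1) // 4, 128 // 8, 128 // 8] = [2, z_dim, 1, 16, 16].
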
@@ -923,6 +927,93 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
)


class DummyVideoInputGenerator(DummyInputGenerator):
SUPPORTED_INPUT_NAMES = ("hidden_states", "sample", "latent_sample")

def __init__(
self,
task: str,
normalized_config: NormalizedConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
num_frames: int = DEFAULT_DUMMY_SHAPES["num_frames"],
video_height: int = DEFAULT_DUMMY_SHAPES["video_height"],
video_width: int = DEFAULT_DUMMY_SHAPES["video_width"],
**kwargs,
):
self.task = task
self.normalized_config = normalized_config

self.in_channels = self.normalized_config.in_channels
self.latent_channels = getattr(self.normalized_config, "z_dim", None)

self.batch_size = batch_size
self.num_frames = num_frames
self.video_height = video_height
self.video_width = video_width

self.scale_factor_temporal = getattr(self.normalized_config, "scale_factor_temporal", None)
self.scale_factor_spatial = getattr(self.normalized_config, "scale_factor_spatial", None)

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "latent_sample":
return self.random_float_tensor(
shape=[
self.batch_size,
self.latent_channels,
1 + ((self.num_frames - 1) // self.scale_factor_temporal),
self.video_height // self.scale_factor_spatial,
self.video_width // self.scale_factor_spatial,
],
framework=framework,
dtype=float_dtype,
)
return self.random_float_tensor(
shape=[self.batch_size, self.in_channels, self.num_frames, self.video_height, self.video_width],
framework=framework,
dtype=float_dtype,
)

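A usage sketch for the generator above (SimpleNamespace stands in for the real NormalizedConfig, and the default batch size of 2 is an assumption):

from types import SimpleNamespace

config = SimpleNamespace(in_channels=16, z_dim=16, scale_factor_temporal=4, scale_factor_spatial=8)
generator = DummyVideoInputGenerator(task="text-to-video", normalized_config=config)
print(generator.generate("sample").shape)         # torch.Size([2, 16, 2, 128, 128])
print(generator.generate("latent_sample").shape)  # torch.Size([2, 16, 1, 16, 16])
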
class DummyWanTimestepInputGenerator(DummyInputGenerator):
SUPPORTED_INPUT_NAMES = ("timestep",)  # trailing comma keeps this a tuple rather than a plain string

def __init__(
self,
task: str,
normalized_config: NormalizedConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
num_frames: int = DEFAULT_DUMMY_SHAPES["num_frames"],
video_height: int = DEFAULT_DUMMY_SHAPES["video_height"],
video_width: int = DEFAULT_DUMMY_SHAPES["video_width"],
**kwargs,
):
self.task = task
self.normalized_config = normalized_config

self.in_channels = self.normalized_config.in_channels
self.expand_timesteps = self.normalized_config.expand_timesteps
self.vae_scale_factor_temporal = self.normalized_config.vae_scale_factor_temporal
self.vae_scale_factor_spatial = self.normalized_config.vae_scale_factor_spatial

self.batch_size = batch_size
self.num_frames = num_frames
self.video_height = video_height
self.video_width = video_width

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if self.expand_timesteps:  # Wan 2.2 expects one timestep value per latent patch token
num_latent_frames = (self.num_frames - 1) // self.vae_scale_factor_temporal + 1
num_latent_height = self.video_height // self.vae_scale_factor_spatial
num_latent_width = self.video_width // self.vae_scale_factor_spatial
return self.random_float_tensor(
shape=[self.batch_size, num_latent_frames * (num_latent_height // 2) * (num_latent_width // 2)],
framework=framework,
dtype=float_dtype,
)
return self.random_float_tensor(
shape=[self.batch_size],
framework=framework,
dtype=float_dtype,
)

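Worked out with the video defaults (a sketch; a batch size of 2 and a SimpleNamespace stand-in config are assumed): 1 latent frame and 16x16 latents give 1 * (16 // 2) * (16 // 2) = 64 expanded timesteps.

from types import SimpleNamespace

config = SimpleNamespace(in_channels=16, expand_timesteps=True, vae_scale_factor_temporal=4, vae_scale_factor_spatial=8)
timesteps = DummyWanTimestepInputGenerator(task="text-to-video", normalized_config=config).generate("timestep")
print(timesteps.shape)  # torch.Size([2, 64]); with expand_timesteps=False it would be torch.Size([2])
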
class DummyTimestepInputGenerator(DummyInputGenerator):
"""
Generates dummy time step inputs.