diff --git a/optimum/exporters/utils.py b/optimum/exporters/utils.py
index 99c527ea28..d60b3e7d20 100644
--- a/optimum/exporters/utils.py
+++ b/optimum/exporters/utils.py
@@ -77,14 +77,16 @@
     "FluxTransformer2DModel": "flux-transformer-2d",
     "SD3Transformer2DModel": "sd3-transformer-2d",
     "UNet2DConditionModel": "unet-2d-condition",
+    "UNet3DConditionModel": "unet-3d-condition",
     "T5EncoderModel": "t5-encoder",
+    "UMT5EncoderModel": "umt5-encoder",
+    "WanTransformer3DModel": "wan-transformer-3d",
 }
 
 
 def _get_diffusers_submodel_type(submodel):
     return _DIFFUSERS_CLASS_NAME_TO_SUBMODEL_TYPE.get(submodel.__class__.__name__)
 
-
 def _get_submodels_for_export_diffusion(
     pipeline: "DiffusionPipeline",
 ) -> Dict[str, Union["PreTrainedModel", "ModelMixin"]]:
@@ -96,6 +98,7 @@ def _get_submodels_for_export_diffusion(
 
     is_sdxl = pipeline.__class__.__name__.startswith("StableDiffusionXL")
     is_sd3 = pipeline.__class__.__name__.startswith("StableDiffusion3")
+    is_wan = pipeline.__class__.__name__.startswith("Wan")
 
     # Text encoder
     text_encoder = getattr(pipeline, "text_encoder", None)
@@ -105,6 +108,10 @@ def _get_submodels_for_export_diffusion(
             text_encoder.text_model.config.output_hidden_states = True
         text_encoder.config.export_model_type = _get_diffusers_submodel_type(text_encoder)
+        if text_encoder.__class__.__name__ == "UMT5EncoderModel":
+            orig_forward = text_encoder.forward
+            text_encoder.forward = lambda input_ids, attention_mask: \
+                orig_forward(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
         models_for_export["text_encoder"] = text_encoder
 
     # Text encoder 2
@@ -135,6 +142,10 @@ def _get_submodels_for_export_diffusion(
             if not is_sdxl
             else pipeline.text_encoder_2.config.projection_dim
         )
+        hidden_size = getattr(pipeline.text_encoder.config, "hidden_size", None)
+        if unet.__class__.__name__ == "UNet3DConditionModel":
+            unet.config.text_encoder_projection_dim = hidden_size
+            unet.config.vocab_size = getattr(pipeline.text_encoder.config, "vocab_size", 4096)
         unet.config.export_model_type = _get_diffusers_submodel_type(unet)
         models_for_export["unet"] = unet
 
@@ -143,21 +154,59 @@ def _get_submodels_for_export_diffusion(
     if transformer is not None:
         transformer.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
         transformer.config.time_cond_proj_dim = getattr(pipeline.transformer.config, "time_cond_proj_dim", None)
-        transformer.config.text_encoder_projection_dim = pipeline.text_encoder.config.projection_dim
+        if is_wan:
+            # Wan: record the text encoder width and the VAE temporal/spatial scale factors needed at export time
+            transformer.config.text_encoder_projection_dim = getattr(pipeline.transformer.config, "text_dim", None)
+            transformer.config.expand_timesteps = getattr(pipeline.config, "expand_timesteps", False)
+            transformer.config.vae_scale_factor_temporal = getattr(pipeline.vae.config, "scale_factor_temporal", 4)
+            transformer.config.vae_scale_factor_spatial = getattr(pipeline.vae.config, "scale_factor_spatial", 8)
+            vace_layers = getattr(pipeline.transformer.config, "vace_layers", None)
+            if vace_layers is not None:
+                transformer.config.vace_num_layers = len(vace_layers)
+                transformer.config.vace_in_channels = getattr(pipeline.transformer.config, "vace_in_channels", 96)
+            transformer.config.vocab_size = 256384
+        else:
+            transformer.config.text_encoder_projection_dim = pipeline.text_encoder.config.projection_dim
         transformer.config.export_model_type = _get_diffusers_submodel_type(transformer)
         models_for_export["transformer"] = transformer
 
+    # Transformer_2
+    transformer_2 = getattr(pipeline, "transformer_2", None)
+    if transformer_2 is not None:
+        transformer_2.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
+        transformer_2.config.time_cond_proj_dim = getattr(pipeline.transformer_2.config, "time_cond_proj_dim", None)
+        if is_wan:
+            transformer_2.config.text_encoder_projection_dim = getattr(pipeline.transformer_2.config, "text_dim", None)
+            transformer_2.config.expand_timesteps = getattr(pipeline.config, "expand_timesteps", False)
+            transformer_2.config.vae_scale_factor_temporal = getattr(pipeline.vae.config, "scale_factor_temporal", 4)
+            transformer_2.config.vae_scale_factor_spatial = getattr(pipeline.vae.config, "scale_factor_spatial", 8)
+            vace_layers = getattr(pipeline.transformer_2.config, "vace_layers", None)
+            if vace_layers is not None:
+                transformer_2.config.vace_num_layers = len(vace_layers)
+                transformer_2.config.vace_in_channels = getattr(pipeline.transformer_2.config, "vace_in_channels", 96)
+            transformer_2.config.vocab_size = 256384
+        else:
+            transformer_2.config.text_encoder_projection_dim = pipeline.text_encoder.config.projection_dim
+        transformer_2.config.export_model_type = _get_diffusers_submodel_type(transformer_2)
+        models_for_export["transformer_2"] = transformer_2
+
     # VAE Encoder
     vae_encoder = copy.deepcopy(pipeline.vae)
-    # we return the distribution parameters to be able to recreate it in the decoder
-    vae_encoder.forward = lambda sample: {"latent_parameters": vae_encoder.encode(x=sample)["latent_dist"].parameters}
+    if vae_encoder.__class__.__name__ == "AutoencoderKLWan":
+        vae_encoder.forward = lambda sample: {"latent_parameters": vae_encoder.encode(x=sample).latent_dist.mode()}
+        transformer.config.z_dim = getattr(vae_encoder.config, "z_dim", None)
+        if transformer_2 is not None:
+            transformer_2.config.z_dim = getattr(vae_encoder.config, "z_dim", None)
+    else:
+        vae_encoder.forward = lambda sample: {"latent_parameters": vae_encoder.encode(x=sample)["latent_dist"].parameters}
+
     models_for_export["vae_encoder"] = vae_encoder
 
     # VAE Decoder
     vae_decoder = copy.deepcopy(pipeline.vae)
-    vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample)
+    vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample)
     models_for_export["vae_decoder"] = vae_decoder
 
     return models_for_export
@@ -332,32 +381,40 @@ def get_diffusion_models_for_export(
         models_for_export["transformer"] = (models_for_export["transformer"], transformer_export_config)
 
     # VAE Encoder
-    vae_encoder = models_for_export["vae_encoder"]
-    vae_config_constructor = TasksManager.get_exporter_config_constructor(
-        model=vae_encoder,
-        exporter=exporter,
-        library_name="diffusers",
-        task="semantic-segmentation",
-        model_type="vae-encoder",
-    )
-    vae_encoder_export_config = vae_config_constructor(
-        vae_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
-    )
-    models_for_export["vae_encoder"] = (vae_encoder, vae_encoder_export_config)
+    if "vae_encoder" in models_for_export:
+        vae_encoder = models_for_export["vae_encoder"]
+        encoder_model_type = "vae-encoder"
+        if vae_encoder.__class__.__name__ == "AutoencoderKLWan":
+            encoder_model_type = "vae-encoder-video"
+        vae_config_constructor = TasksManager.get_exporter_config_constructor(
+            model=vae_encoder,
+            exporter=exporter,
+            library_name="diffusers",
+            task="semantic-segmentation",
+            model_type=encoder_model_type,
+        )
+        vae_encoder_export_config = vae_config_constructor(
+            vae_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
+        )
+        models_for_export["vae_encoder"] = (vae_encoder, vae_encoder_export_config)
 
     # VAE Decoder
-    vae_decoder = models_for_export["vae_decoder"]
-    vae_config_constructor = TasksManager.get_exporter_config_constructor(
-        model=vae_decoder,
-        exporter=exporter,
-        library_name="diffusers",
-        task="semantic-segmentation",
-        model_type="vae-decoder",
-    )
-    vae_decoder_export_config = vae_config_constructor(
-        vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype
-    )
-    models_for_export["vae_decoder"] = (vae_decoder, vae_decoder_export_config)
+    if "vae_decoder" in models_for_export:
+        vae_decoder = models_for_export["vae_decoder"]
+        decoder_model_type = "vae-decoder"
+        if vae_decoder.__class__.__name__ == "AutoencoderKLWan":
+            decoder_model_type = "vae-decoder-video"
+        vae_config_constructor = TasksManager.get_exporter_config_constructor(
+            model=vae_decoder,
+            exporter=exporter,
+            library_name="diffusers",
+            task="semantic-segmentation",
+            model_type=decoder_model_type,
+        )
+        vae_decoder_export_config = vae_config_constructor(
+            vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype
+        )
+        models_for_export["vae_decoder"] = (vae_decoder, vae_decoder_export_config)
 
     return models_for_export
diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py
index 5019ea47f4..4afb3bbf37 100644
--- a/optimum/utils/__init__.py
+++ b/optimum/utils/__init__.py
@@ -92,7 +92,9 @@
+    DummyVideoInputGenerator,
     DummyVisionEmbeddingsGenerator,
     DummyVisionEncoderDecoderPastKeyValuesGenerator,
     DummyVisionInputGenerator,
+    DummyWanTimestepInputGenerator,
     DummyXPathSeqInputGenerator,
     FalconDummyPastKeyValuesGenerator,
     GemmaDummyPastKeyValuesGenerator,
     GPTBigCodeDummyPastKeyValuesGenerator,
diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py
index 1e0a647c1b..0c3d7a557b 100644
--- a/optimum/utils/input_generators.py
+++ b/optimum/utils/input_generators.py
@@ -56,6 +56,10 @@ def wrapper(*args, **kwargs):
     "feature_size": 80,
     "nb_max_frames": 3000,
     "audio_sequence_length": 16000,
+    # video
+    "num_frames": 2,
+    "video_width": 128,
+    "video_height": 128,
 }
 
 
@@ -923,6 +927,93 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
         )
 
 
+class DummyVideoInputGenerator(DummyInputGenerator):
+    SUPPORTED_INPUT_NAMES = ("hidden_states", "sample", "latent_sample")
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        num_frames: int = DEFAULT_DUMMY_SHAPES["num_frames"],
+        video_height: int = DEFAULT_DUMMY_SHAPES["video_height"],
+        video_width: int = DEFAULT_DUMMY_SHAPES["video_width"],
+        **kwargs,
+    ):
+        self.task = task
+        self.normalized_config = normalized_config
+
+        self.in_channels = self.normalized_config.in_channels
+        self.latent_channels = getattr(self.normalized_config, "z_dim", None)
+
+        self.batch_size = batch_size
+        self.num_frames = num_frames
+        self.video_height = video_height
+        self.video_width = video_width
+
+        self.scale_factor_temporal = getattr(self.normalized_config, "scale_factor_temporal", None)
+        self.scale_factor_spatial = getattr(self.normalized_config, "scale_factor_spatial", None)
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        if input_name == "latent_sample":
+            return self.random_float_tensor(
+                shape=[
+                    self.batch_size,
+                    self.latent_channels,
+                    1 + ((self.num_frames - 1) // self.scale_factor_temporal),
+                    self.video_height // self.scale_factor_spatial,
+                    self.video_width // self.scale_factor_spatial,
+                ],
+                framework=framework,
+                dtype=float_dtype,
+            )
+        return self.random_float_tensor(
+            shape=[self.batch_size, self.in_channels, self.num_frames, self.video_height, self.video_width],
+            framework=framework,
+            dtype=float_dtype,
+        )
+
+
+class DummyWanTimestepInputGenerator(DummyInputGenerator):
+    SUPPORTED_INPUT_NAMES = ("timestep",)
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        num_frames: int = DEFAULT_DUMMY_SHAPES["num_frames"],
+        video_height: int = DEFAULT_DUMMY_SHAPES["video_height"],
+        video_width: int = DEFAULT_DUMMY_SHAPES["video_width"],
+        **kwargs,
+    ):
+        self.task = task
+        self.normalized_config = normalized_config
+
+        self.in_channels = self.normalized_config.in_channels
+        self.expand_timesteps = self.normalized_config.expand_timesteps
+        self.vae_scale_factor_temporal = self.normalized_config.vae_scale_factor_temporal
+        self.vae_scale_factor_spatial = self.normalized_config.vae_scale_factor_spatial
+
+        self.batch_size = batch_size
+        self.num_frames = num_frames
+        self.video_height = video_height
+        self.video_width = video_width
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        if self.expand_timesteps:  # Wan 2.2 expects one timestep value per latent patch
+            num_latent_frames = (self.num_frames - 1) // self.vae_scale_factor_temporal + 1
+            num_latent_height = self.video_height // self.vae_scale_factor_spatial
+            num_latent_width = self.video_width // self.vae_scale_factor_spatial
+            return self.random_float_tensor(
+                shape=[self.batch_size, num_latent_frames * (num_latent_height // 2) * (num_latent_width // 2)],
+                framework=framework,
+                dtype=float_dtype,
+            )
+        return self.random_float_tensor(
+            shape=[self.batch_size],
+            framework=framework,
+            dtype=float_dtype,
+        )
+
+
 class DummyTimestepInputGenerator:
     """
     Generates dummy time step inputs.