8 changes: 8 additions & 0 deletions optimum/exporters/onnx/config.py
@@ -32,6 +32,8 @@
DummySeq2SeqDecoderTextInputGenerator,
DummySeq2SeqPastKeyValuesGenerator,
DummyTextInputGenerator,
DummyTimestepInputGenerator,
DummyVideoInputGenerator,
DummyVisionInputGenerator,
logging,
)
@@ -410,3 +412,9 @@ def post_process_exported_models(
models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1]._decoder_onnx_config.is_merged = True

return models_and_onnx_configs, onnx_files_subpaths


class VideoOnnxConfig(OnnxConfig):
"""Handles video architectures."""

DUMMY_INPUT_GENERATOR_CLASSES = (DummyVideoInputGenerator, DummyTimestepInputGenerator)
Comment on lines +417 to +420

@IlyasMoutawwakil (Member) commented on Dec 17, 2025:

I don't think an abstract video onnx config is needed as it doesn't really abstract much here
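For context, the flattening the reviewer suggests would presumably look like the sketch below: each concrete config declares its generator tuple directly instead of inheriting it from a thin base class (class and generator names are taken from this diff; the exact combination shown is illustrative).

class UNet3DOnnxConfig(OnnxConfig):
    # Sketch of the reviewer's suggestion: no abstract VideoOnnxConfig;
    # the dummy generators are listed on the concrete config itself.
    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyVideoInputGenerator,
        DummyTimestepInputGenerator,
        DummyTransformerTextInputGenerator,
    )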

164 changes: 164 additions & 0 deletions optimum/exporters/onnx/model_configs.py
@@ -31,6 +31,7 @@
TextDecoderWithPositionIdsOnnxConfig,
TextEncoderOnnxConfig,
TextSeq2SeqOnnxConfig,
VideoOnnxConfig,
VisionOnnxConfig,
)
from optimum.exporters.onnx.input_generators import (
@@ -83,9 +84,11 @@
DummyTransformerTextInputGenerator,
DummyTransformerTimestepInputGenerator,
DummyTransformerVisionInputGenerator,
DummyVideoInputGenerator,
DummyVisionEmbeddingsGenerator,
DummyVisionEncoderDecoderPastKeyValuesGenerator,
DummyVisionInputGenerator,
DummyWanTimestepInputGenerator,
DummyXPathSeqInputGenerator,
FalconDummyPastKeyValuesGenerator,
GemmaDummyPastKeyValuesGenerator,
@@ -1385,6 +1388,40 @@ class SiglipVisionModelOnnxConfig(CLIPVisionModelOnnxConfig):
pass


@register_tasks_manager_onnx("unet-3d-condition", *["semantic-segmentation"], library_name="diffusers")
class UNet3DOnnxConfig(VideoOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
in_channels="in_channels",
hidden_size="text_encoder_projection_dim",
vocab_size="vocab_size",
allow_new=True,
)
DUMMY_INPUT_GENERATOR_CLASSES = (
*VideoOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES,
DummyTransformerTextInputGenerator,
)

@property
def inputs(self) -> dict[str, dict[int, str]]:
return {
"sample": {0: "batch_size", 2: "num_frames", 3: "height", 4: "width"},
"timestep": {}, # a scalar with no dimension
"encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
}

@property
def outputs(self) -> dict[str, dict[int, str]]:
return {
"out_sample": {0: "batch_size", 2: "num_frames", 3: "height", 4: "width"},
}

@property
def torch_to_onnx_output_map(self) -> dict[str, str]:
return {
"sample": "out_sample",
}
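For illustration (not part of the diff), this is roughly how these properties would feed torch.onnx.export; the config constructor call and the dummy_inputs dict below are assumptions, not the exporter's actual internals.

import torch

# Hypothetical exporter step: the inputs/outputs properties become the
# dynamic_axes, and torch's `sample` output is exported as `out_sample`
# via torch_to_onnx_output_map.
config = UNet3DOnnxConfig(model.config, task="semantic-segmentation")  # assumed constructor
torch.onnx.export(
    model,
    tuple(dummy_inputs.values()),  # dummy inputs built by the generator classes
    "unet3d.onnx",
    input_names=list(config.inputs),  # ["sample", "timestep", "encoder_hidden_states"]
    output_names=list(config.outputs),  # ["out_sample"]
    dynamic_axes={**config.inputs, **config.outputs},
)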


@register_tasks_manager_onnx("unet-2d-condition", *["semantic-segmentation"], library_name="diffusers")
class UNetOnnxConfig(VisionOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
@@ -2852,3 +2889,130 @@ def outputs(self) -> dict[str, dict[int, str]]:
3: f"latent_width * {up_sampling_factor}",
}
}


@register_tasks_manager_onnx("umt5-encoder", *["feature-extraction"], library_name="diffusers")
class UMT5EncoderOnnxConfig(TextEncoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
MIN_TRANSFORMERS_VERSION = version.parse("4.46.0")

@property
def inputs(self):
return {
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
}

@property
def outputs(self):
return {"last_hidden_state": {0: "batch_size", 1: "sequence_length"}}


@register_tasks_manager_onnx("wan-transformer-3d", *["semantic-segmentation"], library_name="diffusers")
class WanTransformer3DOnnxConfig(VideoOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
in_channels="in_channels",
out_channels="out_channels",
hidden_size="text_dim",
z_dim="z_dim",
expand_timesteps="expand_timesteps",
scale_factor_temporal="vae_scale_factor_temporal",
scale_factor_spatial="vae_scale_factor_spatial",
vocab_size="vocab_size",
allow_new=True,
)
MIN_TRANSFORMERS_VERSION = version.parse("4.46.0")
DUMMY_INPUT_GENERATOR_CLASSES = (
DummyTransformerTextInputGenerator,
DummyWanTimestepInputGenerator,
DummyVideoInputGenerator,
)

@property
def inputs(self) -> dict[str, dict[int, str]]:
        if self._normalized_config.expand_timesteps:
return {
"latent_sample": {0: "batch_size", 2: "num_frames", 3: "height", 4: "width"},
"encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
"timestep": {0: "batch_size", 1: "seq_len"},
}
return {
"latent_sample": {0: "batch_size", 2: "num_frames", 3: "height", 4: "width"},
"encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
"timestep": {0: "batch_size"},
}

@property
def outputs(self) -> dict[str, dict[int, str]]:
return {
"sample": {0: "batch_size", 2: "num_frames", 3: "height", 4: "width"},
}

    def rename_ambiguous_inputs(self, inputs):
        # The model's forward signature expects `hidden_states`, hence the export input name is updated.
        model_inputs = inputs
        model_inputs["hidden_states"] = model_inputs.pop("latent_sample")

        return model_inputs
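A quick illustration of the rename (the tensors in the dict below are made up):

# Hypothetical call: `latent_sample` is remapped to the `hidden_states` key
# the model's forward expects; other entries pass through unchanged.
inputs = {"latent_sample": latents, "encoder_hidden_states": text_emb, "timestep": t}
model_inputs = config.rename_ambiguous_inputs(inputs)
assert "hidden_states" in model_inputs and "latent_sample" not in model_inputs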


@register_tasks_manager_onnx("vae-encoder-video", *["semantic-segmentation"], library_name="diffusers")
class VaeEncoderVideoOnnxConfig(VaeEncoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
in_channels="in_channels",
z_dim="z_dim",
scale_factor_temporal="scale_factor_temporal",
scale_factor_spatial="scale_factor_spatial",
allow_new=True,
)
MIN_TRANSFORMERS_VERSION = version.parse("4.46.0")
DUMMY_INPUT_GENERATOR_CLASSES = (DummyVideoInputGenerator,)

@property
def inputs(self) -> dict[str, dict[int, str]]:
return {
"sample": {0: "batch_size", 2: "num_frames", 3: "height", 4: "width"},
}

@property
def outputs(self) -> dict[str, dict[int, str]]:
return {
"latent_parameters": {
0: "batch_size",
2: f"1 + ( num_frames - 1 ) // {self._normalized_config.scale_factor_temporal}",
3: f"height / {self._normalized_config.scale_factor_spatial}",
4: f"width / {self._normalized_config.scale_factor_spatial}",
}
}
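To make the axis formulas concrete, a quick sanity check with illustrative factors (scale_factor_temporal=4, scale_factor_spatial=8; typical for Wan-style VAEs but not taken from this diff):

# Encoder shapes: 81 frames at 480x832 compress to 21 latent frames at 60x104.
num_frames, height, width = 81, 480, 832
latent_num_frames = 1 + (num_frames - 1) // 4  # -> 21
latent_height, latent_width = height // 8, width // 8  # -> 60, 104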


@register_tasks_manager_onnx("vae-decoder-video", *["semantic-segmentation"], library_name="diffusers")
class VaeDecoderVideoOnnxConfig(VaeEncoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
in_channels="in_channels",
out_channels="out_channels",
z_dim="z_dim",
scale_factor_temporal="scale_factor_temporal",
scale_factor_spatial="scale_factor_spatial",
allow_new=True,
)
MIN_TRANSFORMERS_VERSION = version.parse("4.46.0")
DUMMY_INPUT_GENERATOR_CLASSES = (DummyVideoInputGenerator,)

@property
def inputs(self) -> dict[str, dict[int, str]]:
return {
"latent_sample": {0: "batch_size", 2: "latent_num_frames", 3: "latent_height", 4: "latent_width"},
}

@property
def outputs(self) -> dict[str, dict[int, str]]:
return {
"sample": {
0: "batch_size",
2: f"1 + ( latent_num_frames - 1 ) * {self._normalized_config.scale_factor_temporal}",
3: f"latent_height * {self._normalized_config.scale_factor_spatial}",
4: f"latent_width * {self._normalized_config.scale_factor_spatial}",
}
}
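The decoder formula inverts the encoder's, so an encode/decode round trip restores the original frame count (same illustrative factors as above):

# Decoder shapes: 21 latent frames expand back to 1 + (21 - 1) * 4 = 81 frames.
latent_num_frames = 21
decoded_frames = 1 + (latent_num_frames - 1) * 4
assert decoded_frames == 81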
24 changes: 24 additions & 0 deletions optimum/exporters/onnx/model_patcher.py
@@ -76,6 +76,30 @@ def __ior_(g, self: torch._C.Value, other: torch._C.Value) -> torch._C.Value:

torch.onnx.register_custom_op_symbolic("aten::__ior__", __ior_, 14)


@symbolic_helper.parse_args("v", "v", "v")
def upsample_nearest_exact_symbolic(g, input, output_size, scale_h=None) -> torch._C.Value:
    # Build the full scales tensor: 1.0 for the batch and channel dims, followed by the spatial scale factors carried in `scale_h`
scales = g.op("Concat", g.op("Constant", value_t=torch.tensor([1.0, 1.0], dtype=torch.float32)), scale_h, axis_i=0)
empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))

return g.op(
"Resize",
input,
empty_roi, # roi (unused for nearest)
scales,
mode_s="nearest",
coordinate_transformation_mode_s="half_pixel",
nearest_mode_s="round_prefer_floor",
)


torch.onnx.register_custom_op_symbolic(
    "aten::_upsample_nearest_exact2d",  # PyTorch op name
    upsample_nearest_exact_symbolic,  # custom symbolic function
    18,  # target ONNX opset
)
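A minimal way to exercise this registration (a sketch; the module and file name are assumptions): export an interpolate call in nearest-exact mode, which lowers to aten::_upsample_nearest_exact2d and should come out as an ONNX Resize node.

import torch
import torch.nn.functional as F

class Upsample2x(torch.nn.Module):
    def forward(self, x):
        # `nearest-exact` dispatches to aten::_upsample_nearest_exact2d on 4D
        # inputs, which the symbolic above maps to ONNX Resize.
        return F.interpolate(x, scale_factor=2.0, mode="nearest-exact")

torch.onnx.export(Upsample2x(), torch.randn(1, 3, 8, 8), "upsample.onnx", opset_version=18)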

if is_torch_version("<", "2.9"):
# this was fixed in torch in 2.9 https://github.com/pytorch/pytorch/pull/159973
from torch.onnx import JitScalarType
5 changes: 5 additions & 0 deletions optimum/onnxruntime/constants.py
@@ -23,3 +23,8 @@
DECODER_WITH_PAST_ONNX_FILE_PATTERN = r"(.*)?decoder(.*)?with_past(.*)?\.onnx"
DECODER_MERGED_ONNX_FILE_PATTERN = r"(.*)?decoder(.*)?merged(.*)?\.onnx"
ONNX_FILE_PATTERN = r".*\.onnx$"

# Some newer text-to-video pipelines, such as Wan, handle the encoder-decoder scaling at the model level instead of the pipeline level.
ENCODER_DECODER_HANDLES_SCALING_FACTOR = [
"AutoencoderKLWan",
]
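A hedged sketch of how a consumer might use this list (the surrounding names are hypothetical, not from this diff): skip the pipeline-level latent scaling when the VAE class already applies it internally.

# Hypothetical consumer: apply the pipeline-level scaling factor only when the
# autoencoder class is not known to handle scaling at the model level.
if vae_class_name not in ENCODER_DECODER_HANDLES_SCALING_FACTOR:
    latents = latents * scaling_factor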