Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
339 changes: 339 additions & 0 deletions egs/fluent_speech_commands/SLU/transducer/export-onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,339 @@
#!/usr/bin/env python3
#
# Copyright 2025 Xiaomi Corporation (Author: Fangjun Kuang)

import logging
from pathlib import Path
from typing import Dict, Tuple

import onnx
import torch
from decode import get_params, get_parser, get_transducer_model
from onnxruntime.quantization import QuantType, quantize_dynamic
from torch import nn
from transducer.conformer import Conformer

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)

    # Discard any pre-existing metadata so only our entries remain.
    del model.metadata_props[:]

    for key, value in meta_data.items():
        entry = model.metadata_props.add()
        entry.key = key
        entry.value = str(value)

    onnx.save(model, filename)


class OnnxEncoder(nn.Module):
    """A wrapper for Conformer"""

    def __init__(self, encoder: Conformer):
        """
        Args:
          encoder:
            A Conformer encoder.
        """
        super().__init__()
        self.encoder = encoder

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Please see the help information of Conformer.forward

        Args:
          x:
            A 3-D tensor of shape (N, T, C)
          x_lens:
            A 1-D tensor of shape (N,). Its dtype is torch.int64
        Returns:
          Return a tuple containing:
            - encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
            - encoder_out_lens, A 1-D tensor of shape (N,)
        """
        out, out_lens = self.encoder(x, x_lens)
        # At this point `out` has shape (N, T', joiner_dim).
        return out, out_lens


def export_encoder_model_onnx(
    encoder_model: OnnxEncoder,
    encoder_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the given encoder model to ONNX format.
    The exported model has two inputs:

        - x, a tensor of shape (N, T, C); dtype is torch.float32
        - x_lens, a tensor of shape (N,); dtype is torch.int64

    and it has two outputs:

        - encoder_out, a tensor of shape (N, T', C)
        - encoder_out_lens, a tensor of shape (N,)

    Args:
      encoder_model:
        The input encoder model
      encoder_filename:
        The filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    # Dummy tracing inputs; ONNX dynamic axes make N and T flexible later.
    dummy_x = torch.zeros(1, 100, 23, dtype=torch.float32)
    dummy_x_lens = torch.tensor([100], dtype=torch.int64)

    dynamic_axes = {
        "x": {0: "N", 1: "T"},
        "x_lens": {0: "N"},
        "encoder_out": {0: "N", 1: "T"},
        "encoder_out_lens": {0: "N"},
    }

    torch.onnx.export(
        encoder_model,
        (dummy_x, dummy_x_lens),
        encoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["x", "x_lens"],
        output_names=["encoder_out", "encoder_out_lens"],
        dynamic_axes=dynamic_axes,
    )

    meta_data = {
        "model_type": "conformer",
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "SLU_transducer",
        "note": "The decoder is an LSTM with states",
    }
    logging.info(f"meta_data: {meta_data}")

    add_meta_data(filename=encoder_filename, meta_data=meta_data)


def export_decoder_model_onnx(
    decoder_model: nn.Module,
    decoder_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the decoder model to ONNX format.

    The exported model has 3 inputs:

        - y: a torch.int64 tensor of shape (N, 1)
        - h: a float32 tensor of shape (num_layers, N, hidden_dim)
        - c: a float32 tensor of shape (num_layers, N, hidden_dim)

    and has 3 outputs:

        - decoder_out: a torch.float32 tensor of shape (N, 1, C)
        - next_h: a float32 tensor of shape (num_layers, N, hidden_dim)
        - next_c: a float32 tensor of shape (num_layers, N, hidden_dim)

    Args:
      decoder_model:
        The decoder model to be exported.
      decoder_filename:
        Filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    y = torch.zeros(1, 1, dtype=torch.int64)
    # Run the decoder once to obtain initial LSTM states (h, c) with the
    # correct shapes for tracing.
    _, (h, c) = decoder_model(y)

    torch.onnx.export(
        decoder_model,
        (y, (h, c)),
        decoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["y", "h", "c"],
        output_names=["decoder_out", "next_h", "next_c"],
        dynamic_axes={
            "y": {0: "N"},
            "h": {1: "N"},
            "c": {1: "N"},
        },
    )

    meta_data = {
        "num_layers": h.shape[0],
        "hidden_dim": h.shape[2],
    }
    # Use logging.info() (not print) for consistency with the rest of the file.
    logging.info(f"decoder meta_data: {meta_data}")

    add_meta_data(filename=decoder_filename, meta_data=meta_data)


def export_joiner_model_onnx(
    joiner_model: nn.Module,
    joiner_filename: str,
    hidden_dim: int,
    vocab_size: int,
    opset_version: int = 11,
) -> None:
    """Export the joiner model to ONNX format.

    The exported model has 2 inputs:

        - encoder_out: a float32 tensor of shape (N, 1, enc_dim)
        - decoder_out: a float32 tensor of shape (N, 1, dec_dim)

    and has 1 output:

        - logits: a torch.float32 tensor of shape (N, 1, vocab_size)

    Args:
      joiner_model:
        The joiner model to be exported.
      joiner_filename:
        Filename to save the exported ONNX model.
      hidden_dim:
        Dimension of the encoder/decoder outputs fed to the joiner; used to
        build the dummy tracing inputs and recorded in the model metadata.
      vocab_size:
        Output vocabulary size of the joiner; recorded in the model metadata.
      opset_version:
        The opset version to use.
    """
    enc_out = torch.zeros(1, 1, hidden_dim, dtype=torch.float32)
    dec_out = torch.zeros(1, 1, hidden_dim, dtype=torch.float32)

    torch.onnx.export(
        joiner_model,
        (enc_out, dec_out),
        joiner_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["encoder_out", "decoder_out"],
        output_names=["logits"],
        dynamic_axes={
            "encoder_out": {0: "N"},
            "decoder_out": {0: "N"},
        },
    )

    meta_data = {
        "vocab_size": vocab_size,
        "hidden_dim": hidden_dim,
    }
    # Use logging.info() (not print) for consistency with the rest of the file.
    logging.info(f"joiner meta_data: {meta_data}")

    add_meta_data(filename=joiner_filename, meta_data=meta_data)


@torch.no_grad()
def main():
    """Export the transducer encoder, decoder, and joiner to ONNX, then
    generate int8-quantized versions of each exported model."""
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))
    params["env_info"] = get_env_info()

    device = torch.device("cpu")
    logging.info(f"device: {device}")
    model = get_transducer_model(params)

    if params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = []
        # The condition is loop-invariant, so check it once outside the loop.
        if start >= 0:
            for i in range(start, params.epoch + 1):
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.load_state_dict(average_checkpoints(filenames))

    model.to(device)
    model.eval()
    model.device = device

    encoder_num_param = sum(p.numel() for p in model.encoder.parameters())
    decoder_num_param = sum(p.numel() for p in model.decoder.parameters())
    joiner_num_param = sum(p.numel() for p in model.joiner.parameters())
    total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
    logging.info(f"encoder parameters: {encoder_num_param}")
    logging.info(f"decoder parameters: {decoder_num_param}")
    logging.info(f"joiner parameters: {joiner_num_param}")
    logging.info(f"total parameters: {total_num_param}")

    encoder = OnnxEncoder(
        encoder=model.encoder,
    )

    opset_version = 13

    logging.info("Exporting encoder")
    encoder_filename = params.exp_dir / "encoder.onnx"
    export_encoder_model_onnx(
        encoder,
        encoder_filename,
        opset_version=opset_version,
    )
    logging.info(f"Exported encoder to {encoder_filename}")

    logging.info("Exporting decoder")
    decoder_filename = params.exp_dir / "decoder.onnx"
    export_decoder_model_onnx(
        model.decoder,
        decoder_filename,
        opset_version=opset_version,
    )
    logging.info(f"Exported decoder to {decoder_filename}")

    logging.info("Exporting joiner")
    joiner_filename = params.exp_dir / "joiner.onnx"
    export_joiner_model_onnx(
        model.joiner,
        joiner_filename,
        opset_version=opset_version,
        hidden_dim=params.hidden_dim,
        vocab_size=params.vocab_size,
    )
    # Fixed log message: it previously said "decoder" for the joiner file.
    logging.info(f"Exported joiner to {joiner_filename}")

    # Generate int8 quantization models
    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

    logging.info("Generate int8 quantization models")

    encoder_filename_int8 = params.exp_dir / "encoder.int8.onnx"
    quantize_dynamic(
        model_input=encoder_filename,
        model_output=encoder_filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )

    decoder_filename_int8 = params.exp_dir / "decoder.int8.onnx"
    quantize_dynamic(
        model_input=decoder_filename,
        model_output=decoder_filename_int8,
        op_types_to_quantize=["MatMul", "Gather"],
        weight_type=QuantType.QInt8,
    )

    joiner_filename_int8 = params.exp_dir / "joiner.int8.onnx"
    quantize_dynamic(
        model_input=joiner_filename,
        model_output=joiner_filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )


if __name__ == "__main__":
    # Log timestamp, level, and source location with every message.
    log_format = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
    main()
Loading