Hi,
I am struggling with ONNX export of the averaged model.
class OnnxBSRNN(nn.Module):
    """Thin adapter around a BSRNN model for ONNX export.

    The wrapped model returns a tuple; ONNX export is simplest with a single
    tensor output, so this wrapper forwards both inputs and keeps only the
    first element of the result.
    """

    def __init__(self, model):
        super().__init__()
        # Underlying TSE model; called as-is in forward().
        self.model = model

    def forward(self, mix_features, enroll_embeddings):
        # Keep only the primary estimate from the model's tuple output.
        outputs = self.model(mix_features, enroll_embeddings)
        return outputs[0]
# Build the TSE model from the config and restore the (averaged) checkpoint.
model = get_model(config_data["model"]["tse_model"])(**config_data["model_args"]["tse_model"])
load_pretrained_model(model, args.model_checkpoint)
model.eval()

# Wrap so that only the primary output tensor is exported.
onnx_model_wrapper = OnnxBSRNN(model)

# NOTE(review): torch.onnx.export performs its own tracing internally, so the
# previous explicit torch.jit.trace pass was redundant and only added another
# layer where tracing could fail. Export the eager wrapper directly.
# The STFT/complex-type SymbolicValueError originates from torch.stft inside
# the model's forward (opset 17's STFT symbolic rejects complex intermediates);
# that has to be fixed in the model itself (e.g. a real-valued STFT front end
# with return_complex=False and manual real/imag handling), not in this script.
with torch.no_grad():
    torch.onnx.export(
        onnx_model_wrapper,
        (mix_features, enroll_embeddings),
        onnx_filename,
        opset_version=17,
        input_names=["mix_features", "enroll_embeddings"],
        output_names=["output"],
        dynamic_axes={
            # Batch and time axes are left symbolic so the exported graph
            # accepts variable-length inputs.
            "mix_features": {0: "batch_size", 1: "sequence_length"},
            "enroll_embeddings": {0: "batch_size", 1: "num_frames"},
            "output": {0: "batch_size", 1: "sequence_length"},
        },
        do_constant_folding=True,
        export_params=True,
        verbose=True,
    )
But I consistently run into this issue:
Traceback (most recent call last):
File "/mnt/local/sangeet/workncode/wesep/examples/librimix/tse/v2/wesep/bin/model_export_onnx.py", line 141, in <module>
main()
File "/mnt/local/sangeet/workncode/wesep/examples/librimix/tse/v2/wesep/bin/model_export_onnx.py", line 133, in main
export_to_onnx(onnx_model_wrapper, (mix_features, enroll_embeddings), onnx_filepath)
File "/mnt/local/sangeet/workncode/wesep/examples/librimix/tse/v2/wesep/bin/model_export_onnx.py", line 71, in export_to_onnx
torch.onnx.export(
File "/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/onnx/utils.py", line 516, in export
_export(
File "/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/onnx/utils.py", line 1596, in _export
graph, params_dict, torch_out = _model_to_graph(
^^^^^^^^^^^^^^^^
File "/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/onnx/utils.py", line 1139, in _model_to_graph
graph = _optimize_graph(
^^^^^^^^^^^^^^^^
File "/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/onnx/utils.py", line 677, in _optimize_graph
graph = _C._jit_pass_onnx(graph, operator_export_type)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/onnx/utils.py", line 1940, in _run_symbolic_function
return symbolic_fn(graph_context, *inputs, **attrs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/onnx/symbolic_helper.py", line 306, in wrapper
return fn(g, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/onnx/symbolic_opset17.py", line 105, in stft
raise errors.SymbolicValueError(
torch.onnx.errors.SymbolicValueError: STFT does not currently support complex types [Caused by the value 'input.3 defined in (%input.3 : Float(*, *, device=cpu) = onnx::Reshape[allowzero=0](%input, %849), scope: OnnxBSRNN::/wesep.models.bsrnn.BSRNN::model # /mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/functional.py:649:0
)' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx::Reshape'.]
(node defined in /mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/functional.py(649): stft
/mnt/local/sangeet/workncode/wesep/examples/librimix/tse/v2/wesep/models/bsrnn.py(309): forward
/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/nn/modules/module.py(1508): _slow_forward
/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/nn/modules/module.py(1527): _call_impl
/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/nn/modules/module.py(1518): _wrapped_call_impl
/mnt/local/sangeet/workncode/wesep/examples/librimix/tse/v2/wesep/bin/model_export_onnx.py(51): forward
/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/nn/modules/module.py(1508): _slow_forward
/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/nn/modules/module.py(1527): _call_impl
/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/nn/modules/module.py(1518): _wrapped_call_impl
/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/jit/_trace.py(1065): trace_module
/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/jit/_trace.py(798): trace
/mnt/local/sangeet/workncode/wesep/examples/librimix/tse/v2/wesep/bin/model_export_onnx.py(69): export_to_onnx
/mnt/local/sangeet/workncode/wesep/examples/librimix/tse/v2/wesep/bin/model_export_onnx.py(133): main
/mnt/local/sangeet/workncode/wesep/examples/librimix/tse/v2/wesep/bin/model_export_onnx.py(141): <module>
)
Inputs:
#0: input defined in (%input : Float(*, *, *, device=cpu) = onnx::Pad[mode="reflect"](%815, %839), scope: OnnxBSRNN::/wesep.models.bsrnn.BSRNN::model # /mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/functional.py:648:0
) (type 'Tensor')
#1: 849 defined in (%849 : int[] = prim::ListConstruct(%844, %848), scope: OnnxBSRNN::/wesep.models.bsrnn.BSRNN::model
) (type 'List[int]')
Outputs:
#0: input.3 defined in (%input.3 : Float(*, *, device=cpu) = onnx::Reshape[allowzero=0](%input, %849), scope: OnnxBSRNN::/wesep.models.bsrnn.BSRNN::model # /mnt/users/sagarst/envs/wenet/lib/pyth
I tried exporting with other opset versions and an older PyTorch version, but I encountered the same error.
I also tried torch.onnx.dynamo_export, but that led to another error.
Any help on how this can be fixed?
Thank you
Hi,
I am struggling with onnx export of the averaged model.
But I consistently run into this issue:
I tried exporting with other opset versions, older Pytorch version, but encountered the same error.
I also tried torch.onnx.dynamo_export, but that led to another error. Any help on how this can be fixed?
Thank you