Skip to content

ONNX export of BSRNN model #29

@sangeet2020

Description

@sangeet2020

Hi,
I am struggling to export the averaged model to ONNX.

class OnnxBSRNN(nn.Module):
    """Adapter module used as the export entry point for a BSRNN model.

    ``forward`` passes both inputs straight through to the wrapped model and
    returns only the first element of whatever the model produces, so the
    exported graph has a single output.
    """

    def __init__(self, model):
        super().__init__()
        # Assigning an nn.Module here registers it as a submodule, so its
        # parameters belong to the wrapper as well.
        self.model = model

    def forward(self, mix_features, enroll_embeddings):
        # Same two inputs, same order as the wrapped model expects.
        model_outputs = self.model(mix_features, enroll_embeddings)
        # Keep element 0 only; anything else the model returns is discarded.
        # NOTE(review): assumes index 0 is the extracted-speaker output —
        # confirm against BSRNN.forward's return value.
        return model_outputs[0]


model = get_model(config_data["model"]["tse_model"])(**config_data["model_args"]["tse_model"])

load_pretrained_model(model, args.model_checkpoint)
# Switch to inference behaviour before export.
model.eval()

onnx_model_wrapper = OnnxBSRNN(model)

# torch.onnx.export traces an eager nn.Module itself, so a separate
# torch.jit.trace pass is redundant: it traces the model a second time and
# repeats every TracerWarning. Export the eager wrapper directly.
#
# NOTE(review): the "STFT does not currently support complex types" error is
# raised by the exporter's aten::stft symbolic (symbolic_opset17.stft) for the
# torch.stft call inside BSRNN.forward (bsrnn.py:309 in the traceback). It does
# not depend on how export is invoked here. The TorchScript exporter rejects
# torch.stft calls that produce complex output (return_complex=True). Fixing it
# needs a model-side change, e.g. computing the STFT/iSTFT outside the exported
# graph, or using a real-valued STFT. Confirm against bsrnn.py.
torch.onnx.export(
    onnx_model_wrapper,
    (mix_features, enroll_embeddings),
    onnx_filename,
    opset_version=17,
    input_names=["mix_features", "enroll_embeddings"],
    output_names=["output"],
    dynamic_axes={
        "mix_features": {0: "batch_size", 1: "sequence_length"},
        "enroll_embeddings": {0: "batch_size", 1: "num_frames"},
        # Output axis 1 shares the "sequence_length" symbol with the mixture,
        # which declares that both have the same runtime length.
        "output": {0: "batch_size", 1: "sequence_length"},
    },
    do_constant_folding=True,
    export_params=True,
    verbose=True,
)

But I consistently run into this error:

Traceback (most recent call last):
  File "/mnt/local/sangeet/workncode/wesep/examples/librimix/tse/v2/wesep/bin/model_export_onnx.py", line 141, in <module>
    main()
  File "/mnt/local/sangeet/workncode/wesep/examples/librimix/tse/v2/wesep/bin/model_export_onnx.py", line 133, in main
    export_to_onnx(onnx_model_wrapper, (mix_features, enroll_embeddings), onnx_filepath)
  File "/mnt/local/sangeet/workncode/wesep/examples/librimix/tse/v2/wesep/bin/model_export_onnx.py", line 71, in export_to_onnx
    torch.onnx.export(
  File "/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/onnx/utils.py", line 1596, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/onnx/utils.py", line 1139, in _model_to_graph
    graph = _optimize_graph(
            ^^^^^^^^^^^^^^^^
  File "/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/onnx/utils.py", line 677, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/onnx/utils.py", line 1940, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/onnx/symbolic_helper.py", line 306, in wrapper
    return fn(g, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/onnx/symbolic_opset17.py", line 105, in stft
    raise errors.SymbolicValueError(
torch.onnx.errors.SymbolicValueError: STFT does not currently support complex types  [Caused by the value 'input.3 defined in (%input.3 : Float(*, *, device=cpu) = onnx::Reshape[allowzero=0](%input, %849), scope: OnnxBSRNN::/wesep.models.bsrnn.BSRNN::model # /mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/functional.py:649:0
)' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx::Reshape'.]
    (node defined in /mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/functional.py(649): stft
/mnt/local/sangeet/workncode/wesep/examples/librimix/tse/v2/wesep/models/bsrnn.py(309): forward
/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/nn/modules/module.py(1508): _slow_forward
/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/nn/modules/module.py(1527): _call_impl
/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/nn/modules/module.py(1518): _wrapped_call_impl
/mnt/local/sangeet/workncode/wesep/examples/librimix/tse/v2/wesep/bin/model_export_onnx.py(51): forward
/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/nn/modules/module.py(1508): _slow_forward
/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/nn/modules/module.py(1527): _call_impl
/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/nn/modules/module.py(1518): _wrapped_call_impl
/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/jit/_trace.py(1065): trace_module
/mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/jit/_trace.py(798): trace
/mnt/local/sangeet/workncode/wesep/examples/librimix/tse/v2/wesep/bin/model_export_onnx.py(69): export_to_onnx
/mnt/local/sangeet/workncode/wesep/examples/librimix/tse/v2/wesep/bin/model_export_onnx.py(133): main
/mnt/local/sangeet/workncode/wesep/examples/librimix/tse/v2/wesep/bin/model_export_onnx.py(141): <module>
)

    Inputs:
        #0: input defined in (%input : Float(*, *, *, device=cpu) = onnx::Pad[mode="reflect"](%815, %839), scope: OnnxBSRNN::/wesep.models.bsrnn.BSRNN::model # /mnt/users/sagarst/envs/wenet/lib/python3.11/site-packages/torch/functional.py:648:0
    )  (type 'Tensor')
        #1: 849 defined in (%849 : int[] = prim::ListConstruct(%844, %848), scope: OnnxBSRNN::/wesep.models.bsrnn.BSRNN::model
    )  (type 'List[int]')
    Outputs:
        #0: input.3 defined in (%input.3 : Float(*, *, device=cpu) = onnx::Reshape[allowzero=0](%input, %849), scope: OnnxBSRNN::/wesep.models.bsrnn.BSRNN::model # /mnt/users/sagarst/envs/wenet/lib/pyth

I tried exporting with other opset versions and with an older PyTorch version, but encountered the same error.
I also tried `torch.onnx.dynamo_export`, but that led to a different error.

How can this be fixed? Any help would be appreciated.

Thank you

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions