CoreMLTools PyTorch converter fails on rotary-attention einsum (BeatThis model) with rank/perm mismatch #2644

@tillt

Description

🐞Describing the bug

Converting the BeatThis model (rotary attention, ISMIR 2024) from PyTorch to CoreML fails during einsum lowering: the converter raises a ValueError from transpose type inference ("perm should have the same length as rank(x): 3 != 1"). PyTorch inference on the same model works fine, so this appears to be a coremltools converter bug for this einsum pattern.

Source: https://github.com/CPJKU/beat_this
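For reference, here is a stripped-down sketch of the kind of ellipsis-style einsum that rotary position embeddings typically use. The equation string and shapes are illustrative guesses, not taken from the BeatThis source, so this may or may not trigger the identical failure:

```python
import coremltools as ct
import torch


class RotaryEinsum(torch.nn.Module):
    """Toy module with a broadcast/ellipsis einsum of the flavour used by
    rotary position embeddings. The equation is a guess at the problematic
    pattern, not copied from BeatThis."""

    def forward(self, t, freqs):
        # Outer-product-style mixing of the input with rotation frequencies.
        return torch.einsum("..., f -> ... f", t, freqs)


t = torch.randn(1, 8, 96)   # illustrative shapes
freqs = torch.randn(32)     # rotation frequencies

traced_toy = torch.jit.trace(RotaryEinsum(), (t, freqs))
ct.convert(
    traced_toy,
    inputs=[
        ct.TensorType(name="t", shape=t.shape),
        ct.TensorType(name="freqs", shape=freqs.shape),
    ],
)
```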

Stack Trace

ERROR - converting 'einsum' op (located at: 'base/frontend/blocks/0/partial/attnF/x.13'):

Traceback (most recent call last):
  ...
  File ".../coremltools/converters/mil/frontend/_utils.py", line 573, in solve_generic_einsum
    parsed_vectors, vars = solve_diagonal_einsum(parsed_vectors, vars)
  File ".../coremltools/converters/mil/frontend/_utils.py", line 495, in solve_diagonal_einsum_one_step
    x = mb.transpose(x=x, perm=perm)
  File ".../coremltools/converters/mil/mil/ops/defs/iOS15/tensor_transformation.py", line 988, in type_inference
    raise ValueError(msg.format(len(perm), self.x.rank))
ValueError: perm should have the same length as rank(x): 3 != 1

To Reproduce

import coremltools as ct
import torch
from beat_this.model.beat_tracker import BeatThis
from beat_this.utils import replace_state_dict_key

# Load the published checkpoint and rebuild the model from its hyper-parameters.
checkpoint = torch.load("beat_this/beat_this-final0.ckpt", map_location="cpu", weights_only=True)
hparams = {
    k: v for k, v in checkpoint["hyper_parameters"].items()
    if k in set(BeatThis.__init__.__code__.co_varnames)
}
model = BeatThis(**hparams)
state_dict = replace_state_dict_key(checkpoint["state_dict"], "model.", "")
model.load_state_dict(state_dict)
model.eval()

# Trace with a dummy mel spectrogram (batch, frames, mel bins).
example = torch.randn(1, 1500, 128)
traced = torch.jit.trace(model, example, check_trace=False)

# Conversion fails with the ValueError shown in the stack trace above.
ct.convert(
    traced,
    inputs=[ct.TensorType(name="mel_spectrogram", shape=example.shape)],
)
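As a sanity check that the failure is conversion-specific rather than a tracing problem, the traced module evaluates without error in PyTorch (this makes no assumptions about the output structure):

```python
# The traced model runs fine in PyTorch, so the failure is in the
# coremltools conversion step rather than in tracing.
with torch.no_grad():
    output = traced(example)
print(type(output))
```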
System environment (please complete the following information):

  • coremltools version: 9.0
  • macOS 15.7.3 (24G407) (arm64)
  • torch 2.2.0

Additional context

The failing op is in the rotary-attention einsum pattern used by BeatThis. Exporting the model to ONNX with opset ≥ 12 works, but conversion through the legacy ONNX path fails because of how Slice takes its inputs.
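For reference, this is roughly the export call that succeeds; the file name and opset_version=14 are illustrative choices, the only constraint observed is opset ≥ 12:

```python
import torch

# ONNX export of the same model works with a recent opset (>= 12 per the note above).
torch.onnx.export(
    model,
    example,
    "beat_this.onnx",
    input_names=["mel_spectrogram"],
    opset_version=14,
)
```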

Labels

PyTorch (traced), bug (Unexpected behaviour that should be corrected)
