Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 240 additions & 0 deletions scripts/benchmark_decoders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
#!/usr/bin/env python
import argparse
import json
import time
from pathlib import Path

import numpy as np
import torch
from scipy.io import wavfile

import torchcrepe


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--audio',
type=Path,
help='Audio file for real-audio benchmarks')
parser.add_argument('--repeat-audio', type=int, default=1,
help='Repeat audio N times to create a longer benchmark clip')
parser.add_argument('--device', default='cpu',
help='Inference device, e.g. cpu or mps')
parser.add_argument('--model', default='full', choices=('tiny', 'full'))
parser.add_argument('--batch-size', type=int, default=None)
parser.add_argument('--frames', type=int, nargs='+', default=[512, 2048])
parser.add_argument('--warmup', type=int, default=2)
parser.add_argument('--iters', type=int, default=5)
parser.add_argument('--json', action='store_true')
return parser.parse_args()


def synchronize(device):
if isinstance(device, str) and device.startswith('cuda') and torch.cuda.is_available():
torch.cuda.synchronize(device)
elif device == 'mps' and torch.backends.mps.is_available():
torch.mps.synchronize()


def load_audio(path):
sample_rate, audio = wavfile.read(path)
if audio.ndim == 1:
audio = audio[None, :]
else:
audio = audio.T
if np.issubdtype(audio.dtype, np.integer):
max_value = max(np.iinfo(audio.dtype).max, 1)
audio = audio.astype(np.float32) / max_value
else:
audio = audio.astype(np.float32, copy=False)
return torch.from_numpy(audio), int(sample_rate)


def time_call(fn, *, warmup, iters, device):
for _ in range(warmup):
fn()
synchronize(device)
durations_ms = []
for _ in range(iters):
synchronize(device)
start = time.perf_counter()
fn()
synchronize(device)
durations_ms.append((time.perf_counter() - start) * 1000.0)
durations_ms.sort()
return durations_ms[len(durations_ms) // 2]


def benchmark_decoder_core(args):
results = []
generator = torch.Generator().manual_seed(0)
for frames in args.frames:
probabilities = torch.rand((1, torchcrepe.PITCH_BINS, frames),
generator=generator)
legacy_bins, _ = torchcrepe.decode.viterbi_legacy(probabilities, dither=False)
fast_bins, _ = torchcrepe.decode.viterbi_banded_fast(probabilities, dither=False)
results.append({
'frames': frames,
'parity_ok': bool(torch.equal(legacy_bins, fast_bins)),
'viterbi_legacy_ms': time_call(
lambda: torchcrepe.decode.viterbi_legacy(probabilities, dither=False),
warmup=args.warmup,
iters=args.iters,
device='cpu',
),
'viterbi_banded_fast_ms': time_call(
lambda: torchcrepe.decode.viterbi_banded_fast(probabilities, dither=False),
warmup=args.warmup,
iters=args.iters,
device='cpu',
),
})
return results


def benchmark_real_audio(args):
if args.audio is None:
raise ValueError('--audio is required for real-audio benchmarks')
audio_path = args.audio
audio, sample_rate = load_audio(str(audio_path))
if args.repeat_audio > 1:
audio = audio.repeat(1, args.repeat_audio)

infer_device = args.device

def predict(decoder_name):
return torchcrepe.predict(
audio,
sample_rate,
model=args.model,
decoder=torchcrepe.decode.get_decoder(decoder_name),
batch_size=args.batch_size,
device=infer_device,
)

preprocessed = list(torchcrepe.preprocess(audio,
sample_rate,
None,
args.batch_size,
infer_device,
True))
probabilities = []
with torch.no_grad():
for frames in preprocessed:
batch = torchcrepe.infer(frames, args.model, infer_device, embed=False)
batch = batch.reshape(audio.size(0), -1, torchcrepe.PITCH_BINS).transpose(1, 2)
probabilities.append(batch.detach().to('cpu'))
probabilities = torch.cat(probabilities, dim=2)

legacy_bins, _ = torchcrepe.decode.viterbi_legacy(probabilities, dither=False)
fast_bins, _ = torchcrepe.decode.viterbi_banded_fast(probabilities, dither=False)

results = {
'audio': str(audio_path),
'repeat_audio': args.repeat_audio,
'frames': int(probabilities.shape[-1]),
'decoder_parity_ok': bool(torch.equal(legacy_bins, fast_bins)),
'decode_cached_viterbi_legacy_ms': time_call(
lambda: torchcrepe.decode.viterbi_legacy(probabilities, dither=False),
warmup=args.warmup,
iters=args.iters,
device='cpu',
),
'decode_cached_viterbi_banded_fast_ms': time_call(
lambda: torchcrepe.decode.viterbi_banded_fast(probabilities, dither=False),
warmup=args.warmup,
iters=args.iters,
device='cpu',
),
'predict_viterbi_legacy_ms': time_call(
lambda: predict('viterbi_legacy'),
warmup=args.warmup,
iters=args.iters,
device=infer_device,
),
'predict_viterbi_banded_fast_ms': time_call(
lambda: predict('viterbi_banded_fast'),
warmup=args.warmup,
iters=args.iters,
device=infer_device,
),
'predict_viterbi_ms': time_call(
lambda: predict('viterbi'),
warmup=args.warmup,
iters=args.iters,
device=infer_device,
),
}
return results


def render_markdown(core_results, real_audio_result, args):
lines = [
'## Torchcrepe Decoder Benchmark',
'',
f'**Device:** `{args.device}`',
f'**Model:** `{args.model}`',
f'**Batch size:** `{args.batch_size}`',
'',
'---',
'',
'## Synthetic Decoder Core',
'',
'| Frames | Legacy | Fast | Speedup | Parity |',
'|------:|------:|------:|------:|:------:|',
]
for result in core_results:
speedup = result['viterbi_legacy_ms'] / result['viterbi_banded_fast_ms']
lines.append(
f"| {result['frames']} | **{result['viterbi_legacy_ms']:.3f} ms** | "
f"**{result['viterbi_banded_fast_ms']:.3f} ms** | "
f"`{speedup:.2f}x` | {'✅' if result['parity_ok'] else '❌'} |"
)

lines.extend([
'',
'---',
'',
'## Real Audio Cached Decoder',
'',
'| Decoder | Time |',
'|:--|--:|',
f"| `viterbi_legacy` | **{real_audio_result['decode_cached_viterbi_legacy_ms']:.3f} ms** |",
f"| `viterbi_banded_fast` | **{real_audio_result['decode_cached_viterbi_banded_fast_ms']:.3f} ms** |",
'',
'---',
'',
'## Real Audio Predict Path',
'',
f"**Audio:** `{real_audio_result['audio']}`",
f"**Frames:** `{real_audio_result['frames']}`",
f"**Repeated:** `{real_audio_result['repeat_audio']}`",
'',
'| Decoder | Time |',
'|:--|--:|',
f"| `viterbi_legacy` | **{real_audio_result['predict_viterbi_legacy_ms']:.3f} ms** |",
f"| `viterbi_banded_fast` | **{real_audio_result['predict_viterbi_banded_fast_ms']:.3f} ms** |",
f"| `viterbi` | **{real_audio_result['predict_viterbi_ms']:.3f} ms** |",
'',
f"> {'✅' if real_audio_result['decoder_parity_ok'] else '❌'} "
f"Decoder parity on cached model outputs: `{real_audio_result['decoder_parity_ok']}`",
])
return '\n'.join(lines)


def main():
args = parse_args()
core_results = benchmark_decoder_core(args)
real_audio_result = benchmark_real_audio(args)
payload = {
'core': core_results,
'real_audio': real_audio_result,
}
if args.json:
print(json.dumps(payload, indent=2))
return
print(render_markdown(core_results, real_audio_result, args))


if __name__ == '__main__':
main()
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Test package marker for unittest discovery.
18 changes: 18 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from torchcrepe import __main__
import torchcrepe


def test_parse_args_accepts_decoder_variants():
args = __main__.parse_args([
'--audio_files', 'in.wav',
'--output_files', 'out.pt',
'--decoder', 'viterbi_banded_fast',
])
assert args.decoder == 'viterbi_banded_fast'


def test_get_decoder_returns_expected_function():
assert (
torchcrepe.decode.get_decoder('viterbi_legacy')
is torchcrepe.decode.viterbi_legacy
)
26 changes: 25 additions & 1 deletion tests/test_decode.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import torch
import torchcrepe
import pytest


###############################################################################
Expand All @@ -13,3 +13,27 @@ def test_weighted_argmax_decode():
"""Tests that weighted argmax decode works without CUDA assertion error"""
fake_logits = torch.rand(8, 360, 128, device="cuda")
decoded = torchcrepe.decode.weighted_argmax(fake_logits)


def test_viterbi_banded_fast_matches_legacy():
generator = torch.Generator().manual_seed(0)
for frames in (16, 128, 1024):
probabilities = torch.rand((2, torchcrepe.PITCH_BINS, frames),
generator=generator)
legacy_bins, legacy_pitch = torchcrepe.decode.viterbi_legacy(
probabilities, dither=False)
fast_bins, fast_pitch = torchcrepe.decode.viterbi_banded_fast(
probabilities, dither=False)
assert torch.equal(legacy_bins, fast_bins)
assert torch.allclose(legacy_pitch, fast_pitch)


def test_viterbi_alias_preserves_legacy_path():
generator = torch.Generator().manual_seed(1)
probabilities = torch.rand((1, torchcrepe.PITCH_BINS, 256),
generator=generator)
legacy_bins, legacy_pitch = torchcrepe.decode.viterbi_legacy(
probabilities, dither=False)
bins, pitch = torchcrepe.decode.viterbi(probabilities, dither=False)
assert torch.equal(legacy_bins, bins)
assert torch.allclose(legacy_pitch, pitch)
19 changes: 7 additions & 12 deletions torchcrepe/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
###############################################################################


def parse_args():
def parse_args(argv=None):
"""Parse command-line arguments"""
parser = argparse.ArgumentParser()

Expand Down Expand Up @@ -65,8 +65,8 @@ def parse_args():
parser.add_argument(
'--decoder',
default='viterbi',
help='The decoder to use. One of "argmax", "viterbi", or ' +
'"weighted_argmax"')
choices=sorted(torchcrepe.decode.DECODERS),
help='The decoder to use')
parser.add_argument(
'--batch_size',
type=int,
Expand All @@ -80,7 +80,7 @@ def parse_args():
action='store_true',
help='Whether to pad the audio')

return parser.parse_args()
return parser.parse_args(argv)


def make_parent_directory(file):
Expand Down Expand Up @@ -113,12 +113,7 @@ def main():
device = 'cpu' if args.gpu is None else f'cuda:{args.gpu}'

# Get decoder
if args.decoder == 'argmax':
decoder = torchcrepe.decode.argmax
elif args.decoder == 'weighted_argmax':
decoder = torchcrepe.decode.weighted_argmax
elif args.decoder == 'viterbi':
decoder = torchcrepe.decode.viterbi
decoder = torchcrepe.decode.get_decoder(args.decoder)

# Infer pitch or embedding and save to disk
if args.embed:
Expand All @@ -144,5 +139,5 @@ def main():
not args.no_pad)


# Run module entry point
main()
if __name__ == '__main__':
main()
Loading