Run inference on a single wav file example? #1
Hi! So I managed to put something together that works on Colab, but I just want to make sure I'm not doing anything stupid and messing up the results. I'd naturally want the most accurate results possible (i.e. to use the best-performing model), but I'm not sure which of the checkpoints is actually best. I tried the stoi model (not the small one) but couldn't get it to work; I was getting an error which I can share. The small one seemed to work, though, and is what I used here. Here is the code that seems to work for me on Google Colab:

```python
!git clone https://github.com/vvvm23/stoi-vqcpc
%cd stoi-vqcpc
!pip install torch_audiomentations
```
```python
import torch
import torchaudio  # needed for torchaudio.load below
import matplotlib.pyplot as plt
import toml
from types import SimpleNamespace

from vqcpc import WaveVQCPC
from stoi import STOIPredictor

device = 'cuda'
stoi_path = '/content/stoi-vqcpc/checkpoints/stoi-gru128-small-0050000.pt'
stoi_cfg_path = '/content/stoi-vqcpc/config/vqcpc/stoi-gru128-small-kmean.toml'
stoi_cfg = SimpleNamespace(**toml.load(stoi_cfg_path))
chk = torch.load(stoi_path, map_location=device)
stoi = STOIPredictor(**stoi_cfg.model).to(device)
stoi.load_state_dict(chk['net'])
stoi.eval()
```
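On the error I hit with the non-small stoi checkpoint: my guess (purely a guess) is a mismatch between the config I paired it with and the checkpoint's state dict. A generic way to see what's wrong is to diff the key sets and tensor shapes before calling `load_state_dict`. Here is a toy sketch with a deliberately mismatched pair; the `nn.Linear` stand-ins are hypothetical, not the repo's models:

```python
import torch.nn as nn

# Toy stand-in: a model and a checkpoint built from incompatible configs
net = nn.Linear(4, 2)
saved = {'net': nn.Linear(4, 3).state_dict()}

model_sd = net.state_dict()
chk_sd = saved['net']

# Keys present in one but not the other cause "missing/unexpected key" errors
print('missing:', set(model_sd) - set(chk_sd))
print('unexpected:', set(chk_sd) - set(model_sd))

# Matching keys with different shapes cause "size mismatch" errors
for k in set(model_sd) & set(chk_sd):
    if model_sd[k].shape != chk_sd[k].shape:
        print(f'{k}: model {tuple(model_sd[k].shape)} vs checkpoint {tuple(chk_sd[k].shape)}')
```

If the shapes disagree like this, the fix is usually to load the checkpoint with the config it was trained under, not to force it with `strict=False`.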
```python
@torch.no_grad()
def get_embeddings(net, wav):
    z = net.encoder(wav).squeeze(0)
    c = net.aggregator(z.unsqueeze(0)).squeeze(0)
    return z, c

@torch.no_grad()
def get_score(net, c):
    frame_scores = net(c)
    return frame_scores
```
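For anyone else reading: the `@torch.no_grad()` decorator just runs the wrapped function without building the autograd graph, so outputs don't track gradients and inference uses less memory. A tiny illustration, unrelated to the repo's models:

```python
import torch

x = torch.ones(3, requires_grad=True)

@torch.no_grad()
def double(t):
    # Runs with gradient tracking disabled
    return t * 2

y = double(x)
print(y.requires_grad)  # False
```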
```python
vqcpc_path = '/content/stoi-vqcpc/checkpoints/gru128-kmeans-0095000.pt'
vqcpc_cfg_path = '/content/stoi-vqcpc/config/vqcpc/vqcpc-gru128-kmean.toml'
vqcpc_cfg = SimpleNamespace(**toml.load(vqcpc_cfg_path))
vqcpc_chk = torch.load(vqcpc_path, map_location=device)
net = WaveVQCPC(**vqcpc_cfg.vqcpc).to(device)
net.load_state_dict(vqcpc_chk['net'])
net.eval()
```
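One note on hardcoding `device = 'cuda'`: the device shouldn't change the predicted scores, only the speed, so a common guard (generic PyTorch, nothing specific to this repo) is to fall back to CPU when no GPU is attached:

```python
import torch

# Use the GPU when available, otherwise run (more slowly) on CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
```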
```python
wav_path = '/content/test_16k_stereo.wav'
wav, rate = torchaudio.load(wav_path)
# Is this correct? It seems to work
wav = wav.unsqueeze(0)  # add a batch dimension
rate = 16000
# Since I want the best performance, I'll assume it's best to go with cuda
z, c = get_embeddings(net, wav=wav.to(device))
frame_scores = get_score(stoi, c)
print(f"> predicted intelligibility score: {frame_scores.mean()}")
```
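One thing I wasn't sure about here: the file is stereo, and I don't know whether the model expects mono input. If it does, a simple channel average before the `unsqueeze` would be (purely my assumption, not something from the repo):

```python
import torch

# Stand-in for a loaded stereo clip: (channels, samples)
wav = torch.randn(2, 16_000)

# Downmix to mono by averaging channels, keeping the channel dim
mono = wav.mean(dim=0, keepdim=True)
print(mono.shape)  # torch.Size([1, 16000])
```

It would also be safer to check the `rate` returned by `torchaudio.load` against 16000 rather than hardcoding it.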
And now we plot the latents:

```python
def plot_waveform(waveform, sample_rate, xlim=None, ylim=None, axes=None):
    waveform = waveform.numpy()
    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate
    if axes is None:
        figure, axes = plt.subplots(1, 1)
    for c in range(num_channels):
        axes.plot(time_axis, waveform[c], linewidth=1, color='blue' if c else 'red', alpha=0.5)
    axes.set_xlim(left=0, right=num_frames / sample_rate)
    axes.grid(True)

fig, axes = plt.subplots(3, 1, dpi=300)
fig.set_size_inches(18.5, 10.5)
fig.patch.set_facecolor('white')
plot_waveform(wav.squeeze().cpu(), 16_000, axes=axes[0])
axes[0].set_xlabel("time (s)")
axes[0].set_ylabel("amplitude")
axes[1].plot(frame_scores.squeeze().cpu())
axes[1].set_xlim(left=0, right=frame_scores.shape[-1])
axes[1].set_xlabel("frames")
axes[1].set_ylabel("intelligibility score")
axes[2].imshow(c.squeeze().transpose(0, 1).cpu(), cmap='viridis', aspect='auto')
axes[2].set_xlabel("frames")
axes[2].set_yticks([])
fig.tight_layout()
plt.show()
fig.savefig('/content/test.png')
```

Would you say these models are safe / ready to use in production? I know you mentioned "This public repository is a work in progress! Results here bear no resemblance to results in the paper!", but I have no idea to what degree the results differ, and whether the models are still useful enough for real-world scenarios or would be misleading.
Hey! Thanks so much for making this. Is there some example code for running inference on a single wav file? I want to plot STOI predictions, but I'm not sure how to do this for one file.