Run inference on a single wav file example? #1

Open
youssefabdelm opened this issue Sep 16, 2022 · 1 comment

@youssefabdelm

Hey! Thanks so much for making this. Is there some example code on how to run inference on a single wav file? I want to plot STOI predictions, but I'm not sure how to do this for one file.


youssefabdelm commented Sep 20, 2022

Hi! I managed to put something together that works on Colab, but I just want to make sure I'm not doing anything stupid and messing up the results. I'd naturally want the most accurate results possible (i.e. use the best-performing model), but I'm not sure which of the checkpoints is actually best. I tried the stoi model (not the small one) but couldn't get it to work; I was getting an error, which I can share. The small one seemed to work, though, and is what I used here.

Here is the code that seems to work for me on Google Colab:

!git clone https://github.com/vvvm23/stoi-vqcpc
%cd stoi-vqcpc
!pip install torch_audiomentations

import torch
import torchaudio  # needed for torchaudio.load below

from vqcpc import WaveVQCPC
from stoi import STOIPredictor
from data import WaveDataset
from utils import get_device

import matplotlib.pyplot as plt
import argparse
import toml
import random
from pathlib import Path
from types import SimpleNamespace



device = 'cuda'
stoi_path = '/content/stoi-vqcpc/checkpoints/stoi-gru128-small-0050000.pt'
stoi_cfg_path = '/content/stoi-vqcpc/config/vqcpc/stoi-gru128-small-kmean.toml'
stoi_cfg = SimpleNamespace(**toml.load(stoi_cfg_path))
chk = torch.load(stoi_path, map_location=device)
stoi = STOIPredictor(**stoi_cfg.model).to(device)
stoi.load_state_dict(chk['net'])
stoi.eval()

@torch.no_grad()
def get_embeddings(net, wav):
    z = net.encoder(wav).squeeze(0)
    c = net.aggregator(z.unsqueeze(0)).squeeze(0)
    return z, c
@torch.no_grad()
def get_score(net, c):
    frame_scores = net(c)
    return frame_scores



vqcpc_path = '/content/stoi-vqcpc/checkpoints/gru128-kmeans-0095000.pt'
vqcpc_cfg_path = '/content/stoi-vqcpc/config/vqcpc/vqcpc-gru128-kmean.toml'
vqcpc_cfg = SimpleNamespace(**toml.load(vqcpc_cfg_path))
vqcpc_chk = torch.load(vqcpc_path, map_location=device)
device = 'cuda'
net = WaveVQCPC(**vqcpc_cfg.vqcpc).to(device)
net.load_state_dict(vqcpc_chk['net'])
net.eval()


wav_path = '/content/test_16k_stereo.wav'
wav, rate = torchaudio.load(wav_path)


#Is this correct? It seems to work
wav = wav.unsqueeze(0)

rate = 16000
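# A sketch of an alternative, assuming (and it is only my assumption, not something
# from the repo) that the encoder expects a mono batch of shape (1, 1, samples):
# torchaudio.load returns (channels, samples), so for this stereo file the
# wav.unsqueeze(0) above gives shape (1, 2, T). A downmixed copy could be built
# and passed to get_embeddings instead of wav:
wav_mono = wav.mean(dim=1, keepdim=True)  # average the two channels -> (1, 1, T)
# If the file were not already 16 kHz, the sample rate returned by torchaudio.load
# could also be checked (rather than hard-coding 16000 above) and the audio
# resampled, e.g. with torchaudio.functional.resample.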


#Since I want the best performance, I'll assume it's best to go with cuda
z, c = get_embeddings(net, wav = wav.to('cuda'))
frame_scores = get_score(stoi, c)
print(f"> predicted intelligibility score: {frame_scores.mean()}")



# And now we plot the waveform, frame scores, and latents:

def plot_waveform(waveform, sample_rate, xlim=None, ylim=None, axes=None):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    if axes is None:
        figure, axes = plt.subplots(1, 1)
    for c in range(num_channels):
        axes.plot(time_axis, waveform[c], linewidth=1, color='blue' if c else 'red', alpha=0.5)
        axes.set_xlim(left=0, right=num_frames/sample_rate)
        axes.grid(True)


#fig, axes = plt.subplots(3, 1)
fig, axes = plt.subplots(3, 1, dpi=300)
fig.set_size_inches(18.5, 10.5)
fig.patch.set_facecolor('white')
plot_waveform(wav.squeeze().cpu(), 16_000, axes=axes[0])
axes[0].set_xlabel("time (s)")
axes[0].set_ylabel("amplitude (s)")

axes[1].plot(frame_scores.squeeze().cpu())
axes[1].set_xlim(left=0, right=frame_scores.shape[-1])
axes[1].set_xlabel("frames")
axes[1].set_ylabel("intelligibility score")

axes[2].imshow(c.squeeze().transpose(0, 1).cpu(), cmap='viridis', aspect='auto')
axes[2].set_xlabel("frames")
axes[2].set_yticks([])

fig.tight_layout()
plt.show()
fig.savefig('/content/test.png')

Would you say these models are safe / ready to use in production? I know you mentioned "This public repository is a work in progress! Results here bear no resemblance to results in the paper!"

But I have no idea to what degree the results differ, or whether they're still useful enough for real-world scenarios or would be misleading.
