bottleneck in mel segment calculation #7

Open
wants to merge 4 commits into main
26 changes: 19 additions & 7 deletions whisper/audio.py
@@ -1,7 +1,8 @@
 import os
 from functools import lru_cache
-from typing import Union
+from typing import Union, List
 
+import multiprocessing as mp
 import ffmpeg
 import numpy as np
 import torch
@@ -69,7 +70,6 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
             pad_widths = [(0, 0)] * array.ndim
             pad_widths[axis] = (0, length - array.shape[axis])
             array = np.pad(array, pad_widths)
-
     return array


@@ -88,8 +88,15 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
     with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
         return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
 
+def audio_load_helper(audio):
+    if not torch.is_tensor(audio):
+        if isinstance(audio, str):
+            audio = load_audio(audio)
+        audio = torch.from_numpy(audio)
+    return audio
+
 
-def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
+def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor, List[str], List[np.ndarray], List[torch.Tensor]], n_mels: int = N_MELS):
     """
     Compute the log-Mel spectrogram of
 
@@ -106,10 +113,15 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int
     torch.Tensor, shape = (80, n_frames)
         A Tensor that contains the Mel spectrogram
     """
-    if not torch.is_tensor(audio):
-        if isinstance(audio, str):
-            audio = load_audio(audio)
-        audio = torch.from_numpy(audio)
+    if type(audio) == list:
+        with mp.Pool() as p:
+            audio_files = p.map(audio_load_helper, audio)
+        return [log_mel_spectrogram(audio_file, n_mels) for audio_file in audio_files]
+    else:
+        if not torch.is_tensor(audio):
+            if isinstance(audio, str):
+                audio = load_audio(audio)
+            audio = torch.from_numpy(audio)
 
     window = torch.hann_window(N_FFT).to(audio.device)
     stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
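With this change, log_mel_spectrogram also accepts a list of inputs and fans the audio loading out across a multiprocessing.Pool before computing each spectrogram serially. A minimal sketch of the intended call pattern (the file names here are hypothetical):

    # batched call: inputs are decoded in parallel worker processes,
    # then one mel spectrogram is computed per input
    from whisper.audio import log_mel_spectrogram

    files = ["a.wav", "b.wav", "c.wav"]   # hypothetical paths
    mels = log_mel_spectrogram(files)     # returns a list of (80, n_frames) tensors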
10 changes: 5 additions & 5 deletions whisper/decoding.py
@@ -490,7 +490,7 @@ def __init__(self, model: "Whisper", options: DecodingOptions):
 
         # logit filters: applies various rules to suppress or penalize certain tokens
         self.decoder = []
-        self.logit_filters = [[]]*len(self.initial_tokens)
+        self.logit_filters = [[] for _ in range(len(self.initial_tokens))]
         for i in range(len(self.initial_tokens)):
             # decoder: implements how to select the next tokens, given the autoregressive distribution
             if options.beam_size is not None:
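This fixes a classic aliasing bug: [[]] * n fills every slot with a reference to the same list, so a filter appended for one group would show up in all of them, whereas the comprehension creates an independent list per group. The pitfall in isolation:

    # [[]] * n aliases one list n times; a comprehension makes n separate lists
    shared = [[]] * 3
    shared[0].append("f")
    print(shared)        # [['f'], ['f'], ['f']]

    independent = [[] for _ in range(3)]
    independent[0].append("f")
    print(independent)   # [['f'], [], []]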
@@ -677,7 +677,7 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor):
         assert audio_features.shape[0] == tokens.shape[0]
         n_batch = tokens.shape[0]
         sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
-        no_speech_probs = [np.nan] * n_batch
+        no_speech_probs = [np.nan]*n_batch
 
         try:
             for i in range(self.sample_len):
@@ -816,19 +816,19 @@ def run(self, mel: Tensor) -> List[DecodingResult]:
         fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
         if len(set(map(len, fields))) != 1:
             raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
 
         return [
             DecodingResult(
                 audio_features=features,
                 language=language,
                 tokens=tokens,
                 text=text,
                 avg_logprob=avg_logprob,
-                no_speech_prob=no_speech_prob,
+                no_speech_prob=no_speech_prob[i],
                 temperature=self.options.temperature,
                 compression_ratio=compression_ratio(text),
             )
-            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
+            for i, (text, language, tokens, features, avg_logprob, no_speech_prob) in enumerate(zip(*fields))
         ]
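Indexing no_speech_prob with [i] implies that each entry of no_speech_probs is itself a sequence with one value per position (an assumption read off this diff; the PR does not state it elsewhere). The enumerate(zip(...)) pattern supplies that index alongside the unzipped fields:

    # enumerate(zip(...)) yields the index plus one element from each field
    texts = ["a", "b"]
    probs = [[0.1, 0.2], [0.3, 0.4]]   # hypothetical per-item sequences
    for i, (text, prob) in enumerate(zip(texts, probs)):
        print(i, text, prob[i])        # prob[i] mirrors no_speech_prob[i] above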


4 changes: 2 additions & 2 deletions whisper/transcribe.py
@@ -332,7 +332,7 @@ def batch_transcribe(
     if dtype == torch.float32:
         decode_options["fp16"] = False
 
-    mels = [log_mel_spectrogram(audio_file) for audio_file in audio]
+    mels = log_mel_spectrogram(audio)
     segments = [pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) for mel in mels]
 
    if decode_options.get("language", None) is None:
@@ -475,7 +475,7 @@ def check_cursors(seekers: List[int], num_frames: List[int]) -> bool:
         if no_speech_threshold is not None:
             for i, result in enumerate(results):
                 # no voice activity check
-                should_skip = result.no_speech_prob[i] > no_speech_threshold
+                should_skip = result.no_speech_prob > no_speech_threshold
                 if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
                     # don't skip if the logprob is high enough, despite the no_speech_prob
                     should_skip = False
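The two thresholds interact: a no_speech_prob above no_speech_threshold marks the segment to be skipped, but a sufficiently confident avg_logprob overrides that verdict. The decision rule isolated as a sketch (the helper name is ours, not the repo's):

    # hypothetical helper: the skip decision from the diff, in isolation
    def should_skip_segment(no_speech_prob, avg_logprob,
                            no_speech_threshold, logprob_threshold):
        skip = no_speech_prob > no_speech_threshold
        if logprob_threshold is not None and avg_logprob > logprob_threshold:
            skip = False   # confident decoding overrides the no-speech signal
        return skip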