17 changes: 15 additions & 2 deletions screenpipe-audio/Cargo.toml
@@ -43,6 +43,9 @@ hf-hub = "0.3.2"
symphonia = { version = "0.5.4", features = ["aac", "isomp4", "opt-simd"] }
rand = "0.8.5"
rubato = "0.15.0"
whisper-rs = { git = "https://github.com/tazz4843/whisper-rs.git", rev = "e0597486400ec436669e6ee3d8cc94b3859355f5", features = [
"tracing_backend",
] } # pin revision to avoid breaking changes

# Log
log = { workspace = true }
@@ -107,8 +110,18 @@ tracing-subscriber = "0.3.16"


[features]
metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = [
"candle/metal",
"candle-nn/metal",
"candle-transformers/metal",
"whisper-rs/metal",
]
cuda = [
"candle/cuda",
"candle-nn/cuda",
"candle-transformers/cuda",
"whisper-rs/cuda",
]
mkl = ["candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]

[[bin]]
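Note: the `whisper-rs/metal` and `whisper-rs/cuda` features are now forwarded through the crate's own `metal` and `cuda` flags, so a single feature enables GPU acceleration in both the candle and whisper.cpp backends. A minimal build sketch (package flag assumed from the directory name):

    cargo build -p screenpipe-audio --features metal   # Apple GPUs
    cargo build -p screenpipe-audio --features cuda    # NVIDIA GPUs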
4 changes: 4 additions & 0 deletions screenpipe-audio/src/core.rs
@@ -30,6 +30,7 @@ pub enum AudioTranscriptionEngine {
WhisperDistilLargeV3,
#[default]
WhisperLargeV3Turbo,
WhisperLargeV3TurboQuantized,
WhisperLargeV3,
}

@@ -40,6 +41,9 @@ impl fmt::Display for AudioTranscriptionEngine {
AudioTranscriptionEngine::WhisperTiny => write!(f, "WhisperTiny"),
AudioTranscriptionEngine::WhisperDistilLargeV3 => write!(f, "WhisperLarge"),
AudioTranscriptionEngine::WhisperLargeV3Turbo => write!(f, "WhisperLargeV3Turbo"),
AudioTranscriptionEngine::WhisperLargeV3TurboQuantized => {
write!(f, "WhisperLargeV3TurboQuantized")
}
AudioTranscriptionEngine::WhisperLargeV3 => write!(f, "WhisperLargeV3"),
}
}
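The new variant follows the same `Display` convention as its neighbors, so it round-trips through the string form used elsewhere. A minimal sketch (import path assumed):

    use screenpipe_audio::AudioTranscriptionEngine; // re-export assumed

    let engine = AudioTranscriptionEngine::WhisperLargeV3TurboQuantized;
    assert_eq!(engine.to_string(), "WhisperLargeV3TurboQuantized");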
21 changes: 21 additions & 0 deletions screenpipe-audio/src/multilingual.rs
@@ -1,12 +1,15 @@
use crate::whisper::{token_id, Model};
use anyhow::anyhow;
use candle::IndexOp;
use candle::{Result, Tensor, D};
use candle_transformers::models::whisper::SOT_TOKEN;
use clap::ValueEnum;
use log::debug;
use screenpipe_core::Language;
use tokenizers::Tokenizer;
use whisper_rs::get_lang_str_full;

pub const LANGUAGES: [(&str, &str); 99] = [

[CI annotation, line 12: constant `LANGUAGES` is never used (test-ubuntu, test-macos, test-linux)]
("en", "english"),
("zh", "chinese"),
("de", "german"),
@@ -161,3 +164,21 @@
debug!("detected language: {:?}", probabilities[0].0);
Ok(language)
}

pub fn get_lang_token(tokens: Vec<f32>, languages: Vec<Language>) -> anyhow::Result<i32> {
if languages.is_empty() {
return Ok(tokens[0] as i32);
}
for token in tokens {
let token = token as i32;
if let Some(lang) = get_lang_str_full(token) {
let l =
Language::from_str(lang, true).map_err(|_| anyhow!("language token not found"))?;
if languages.contains(&l) {
return Ok(token);
}
}
}

Err(anyhow::anyhow!("Language not identified"))
}
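`get_lang_token` walks the candidate language-token ids in ranked order and returns the first whose language (resolved via whisper.cpp's table) is on the allow-list; an empty allow-list short-circuits to the top candidate. A minimal call sketch, with hypothetical token ids and an assumed `Language` variant name:

    let ranked: Vec<f32> = vec![50259.0, 50261.0]; // hypothetical ranked language-token ids
    let allowed = vec![Language::English];         // variant name assumed
    let token = get_lang_token(ranked, allowed)?;  // first candidate on the allow-list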
142 changes: 77 additions & 65 deletions screenpipe-audio/src/stt.rs
@@ -3,10 +3,11 @@
use crate::pyannote::models::{get_or_download_model, PyannoteModel};
use crate::pyannote::segment::SpeechSegment;
pub use crate::segments::prepare_segments;
use crate::whisper::download_quantized_whisper;
use crate::{
pyannote::{embedding::EmbeddingExtractor, identify::EmbeddingManager},
vad_engine::{SileroVad, VadEngine, VadEngineEnum, VadSensitivity, WebRtcVad},
whisper::{process_with_whisper, WhisperModel},

[CI annotation, line 10: unused import `WhisperModel` (test-ubuntu, test-macos, test-linux, test-windows)]
AudioDevice, AudioTranscriptionEngine,
};
use crate::{resample, DeviceControl};
Expand All @@ -25,81 +26,79 @@
time::{SystemTime, UNIX_EPOCH},
};
use tokio::sync::Mutex;
use whisper_rs::{
FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters, WhisperState,

[CI annotation, line 30: unused imports `FullParams`, `SamplingStrategy`, and `WhisperState` (test-ubuntu, test-macos, test-linux, test-windows)]
};

pub fn stt_sync(
pub async fn stt_sync(
audio: &[f32],
sample_rate: u32,
device: &str,
whisper_model: &mut WhisperModel,
whisper_model: &WhisperContext,
audio_transcription_engine: Arc<AudioTranscriptionEngine>,
deepgram_api_key: Option<String>,
languages: Vec<Language>,
) -> Result<String> {
let mut whisper_model = whisper_model.clone();
let mut whisper_model = whisper_model;

[CI annotation, line 42: variable does not need to be mutable (test-ubuntu, test-macos, test-linux, test-windows)]
let audio = audio.to_vec();

let device = device.to_string();
let handle = std::thread::spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();

rt.block_on(stt(
&audio,
sample_rate,
&device,
&mut whisper_model,
audio_transcription_engine,
deepgram_api_key,
languages,
))
});

handle.join().unwrap()
stt(
&audio,
sample_rate,
&device,
whisper_model,
audio_transcription_engine,
deepgram_api_key,
languages,
)
.await
}

#[allow(clippy::too_many_arguments)]
pub async fn stt(
audio: &[f32],
sample_rate: u32,
device: &str,
whisper_model: &mut WhisperModel,
whisper_model: &WhisperContext,
audio_transcription_engine: Arc<AudioTranscriptionEngine>,
deepgram_api_key: Option<String>,
languages: Vec<Language>,
) -> Result<String> {
let model = &whisper_model.model;

debug!("Loading mel filters");
let mel_bytes = match model.config().num_mel_bins {
80 => include_bytes!("../models/whisper/melfilters.bytes").as_slice(),
128 => include_bytes!("../models/whisper/melfilters128.bytes").as_slice(),
nmel => anyhow::bail!("unexpected num_mel_bins {nmel}"),
};
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);

let transcription: Result<String> = if audio_transcription_engine
== AudioTranscriptionEngine::Deepgram.into()
{
// Deepgram implementation
let api_key = deepgram_api_key.unwrap_or_default();

match transcribe_with_deepgram(&api_key, audio, device, sample_rate, languages.clone())
.await
{
Ok(transcription) => Ok(transcription),
Err(e) => {
error!(
"device: {}, deepgram transcription failed, falling back to Whisper: {:?}",
device, e
);
// Fallback to Whisper
process_with_whisper(&mut *whisper_model, audio, &mel_filters, languages.clone())
// let model = &whisper_model.model;

// debug!("Loading mel filters");
// let mel_bytes = match model.config().num_mel_bins {
// 80 => include_bytes!("../models/whisper/melfilters.bytes").as_slice(),
// 128 => include_bytes!("../models/whisper/melfilters128.bytes").as_slice(),
// nmel => anyhow::bail!("unexpected num_mel_bins {nmel}"),
// };
// let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
// <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters)

let transcription: Result<String> =
if audio_transcription_engine == AudioTranscriptionEngine::Deepgram.into() {
// Deepgram implementation
let api_key = deepgram_api_key.unwrap_or_default();

match transcribe_with_deepgram(&api_key, audio, device, sample_rate, languages.clone())
.await
{
Ok(transcription) => Ok(transcription),
Err(e) => {
error!(
"device: {}, deepgram transcription failed, falling back to Whisper: {:?}",
device, e
);
// Fallback to Whisper
process_with_whisper(whisper_model, audio, languages.clone())
}
}
}
} else {
// Existing Whisper implementation
process_with_whisper(&mut *whisper_model, audio, &mel_filters, languages)
};
} else {
// Existing Whisper implementation
process_with_whisper(whisper_model, audio, languages)
};

transcription
}
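With the spawned thread and nested `tokio::runtime::Runtime` gone, `stt_sync` is now a thin async wrapper over `stt`, awaited on the caller's runtime. A call-site sketch (variable names hypothetical):

    let text = stt_sync(
        &samples,       // &[f32], 16 kHz mono as expected by whisper
        16_000,
        "device-name",
        &ctx,           // whisper_rs::WhisperContext built in create_whisper_channel
        engine.clone(), // Arc<AudioTranscriptionEngine>
        None,           // no Deepgram key: go straight to whisper
        vec![],         // empty = no language restriction
    )
    .await?;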
@@ -162,7 +161,18 @@
crossbeam::channel::Receiver<TranscriptionResult>,
Arc<AtomicBool>, // Shutdown flag
)> {
let mut whisper_model = WhisperModel::new(&audio_transcription_engine)?;
// let mut whisper_model = WhisperModel::new(&audio_transcription_engine)?;
whisper_rs::install_logging_hooks();
let mut context_param = WhisperContextParameters::default();
context_param.dtw_parameters.mode = whisper_rs::DtwMode::ModelPreset {
model_preset: whisper_rs::DtwModelPreset::LargeV3Turbo,
};
context_param.use_gpu(true);

let quantized_path = download_quantized_whisper()?;
let ctx = WhisperContext::new_with_params(&quantized_path.to_string_lossy(), context_param)
.expect("failed to load model");

let (input_sender, input_receiver): (
crossbeam::channel::Sender<AudioInput>,
crossbeam::channel::Receiver<AudioInput>,
@@ -273,16 +283,16 @@
{
let timestamp = timestamp + segment.start.round() as u64;
autoreleasepool(|| {
run_stt(segment, audio.device.clone(), &mut whisper_model, audio_transcription_engine.clone(), deepgram_api_key.clone(), languages.clone(), path, timestamp)
})
}
#[cfg(not(target_os = "macos"))]
{
unreachable!("This code should not be reached on non-macOS platforms")
}
} else {
run_stt(segment, audio.device.clone(), &mut whisper_model, audio_transcription_engine.clone(), deepgram_api_key.clone(), languages.clone(), path, timestamp)
};
run_stt(segment, audio.device.clone(), &ctx, audio_transcription_engine.clone(), deepgram_api_key.clone(), languages.clone(), path, timestamp)
}).await
}
#[cfg(not(target_os = "macos"))]
{
unreachable!("This code should not be reached on non-macOS platforms")
}
} else {
run_stt(segment, audio.device.clone(), &ctx, audio_transcription_engine.clone(), deepgram_api_key.clone(), languages.clone(), path, timestamp).await
};

if output_sender.send(transcription_result).is_err() {
break;
@@ -306,10 +316,10 @@
}

#[allow(clippy::too_many_arguments)]
pub fn run_stt(
pub async fn run_stt(
segment: SpeechSegment,
device: Arc<AudioDevice>,
whisper_model: &mut WhisperModel,
whisper_model: &WhisperContext,
audio_transcription_engine: Arc<AudioTranscriptionEngine>,
deepgram_api_key: Option<String>,
languages: Vec<Language>,
@@ -326,7 +336,9 @@
audio_transcription_engine.clone(),
deepgram_api_key.clone(),
languages.clone(),
) {
)
.await
{
Ok(transcription) => TranscriptionResult {
input: AudioInput {
data: Arc::new(audio),
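The diff does not show the reworked body of `process_with_whisper`, but with a `WhisperContext` in hand the standard whisper-rs flow is: create a state, configure `FullParams`, run `full`, then collect segment text. A minimal sketch under that assumption:

    use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};

    fn transcribe(ctx: &WhisperContext, audio: &[f32]) -> anyhow::Result<String> {
        let mut state = ctx.create_state()?; // one state per inference
        let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
        params.set_language(Some("auto")); // or a token chosen via get_lang_token
        state.full(params, audio)?;        // expects 16 kHz mono f32 samples
        let mut text = String::new();
        for i in 0..state.full_n_segments()? {
            text.push_str(&state.full_get_segment_text(i)?);
        }
        Ok(text)
    }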
14 changes: 14 additions & 0 deletions screenpipe-audio/src/whisper/model.rs
@@ -1,3 +1,5 @@
use std::path::PathBuf;

use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
@@ -111,3 +113,15 @@ impl Model {
}
}
}

pub fn download_quantized_whisper() -> Result<PathBuf> {
let api = Api::new()?;
let repo = Repo::with_revision(
"ggerganov/whisper.cpp".to_string(),
RepoType::Model,
"main".to_string(),
);
let api_repo = api.repo(repo);
let model = api_repo.get("ggml-large-v3-turbo-q8_0.bin")?;
Ok(model)
}
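hf-hub caches downloads, so only the first call fetches the quantized model (roughly 0.9 GB); later calls return the cached path. A minimal usage sketch, mirroring the context setup added in `create_whisper_channel`:

    let path = download_quantized_whisper()?; // cached under the hf-hub cache dir
    let ctx = whisper_rs::WhisperContext::new_with_params(
        &path.to_string_lossy(),
        whisper_rs::WhisperContextParameters::default(),
    )?;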