|
| 1 | +mod error; |
| 2 | +use error::*; |
| 3 | + |
| 4 | +use ndarray::{Array1, Array2, Array3, ArrayBase, Ix1, Ix3, OwnedRepr}; |
| 5 | +use ort::session::{builder::GraphOptimizationLevel, Session}; |
| 6 | + |
/// ONNX model embedded into the binary at compile time from `assets/model.onnx`
/// under the crate's manifest directory.
const MODEL_BYTES: &[u8] =
    include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/assets/model.onnx"));

/// Sample rate, in samples per second, handed to the model as its sample-rate
/// input. Kept as `i64` because it is placed directly into an `i64` tensor.
const SAMPLE_RATE: i64 = 16000;
/// Duration, in milliseconds, of the chunks that longer audio is split into.
const CHUNK_SIZE_MS: usize = 30; // 30ms chunks for processing
| 12 | + |
/// Voice-activity detector: an ONNX Runtime session plus the tensors that are
/// fed to the model on every inference call.
///
/// The `h`/`c` tensors are replaced with the model's `hn`/`cn` outputs after each
/// forward pass, so state carries over between calls until `reset` is invoked.
pub struct Vad {
    /// Inference session built from the embedded `MODEL_BYTES`.
    session: Session,
    /// `h` state: passed as a model input and overwritten by the `hn` output.
    h_tensor: ArrayBase<OwnedRepr<f32>, Ix3>,
    /// `c` state: passed as a model input and overwritten by the `cn` output.
    c_tensor: ArrayBase<OwnedRepr<f32>, Ix3>,
    /// One-element tensor holding `SAMPLE_RATE`, built once and reused per call.
    sample_rate_tensor: ArrayBase<OwnedRepr<i64>, Ix1>,
}
| 19 | + |
| 20 | +impl Vad { |
| 21 | + pub fn new() -> Result<Self, crate::Error> { |
| 22 | + let session = Session::builder()? |
| 23 | + .with_optimization_level(GraphOptimizationLevel::Level3)? |
| 24 | + .with_intra_threads(4)? |
| 25 | + .commit_from_memory(MODEL_BYTES)?; |
| 26 | + |
| 27 | + let h_tensor = Array3::<f32>::zeros((2, 1, 64)); |
| 28 | + let c_tensor = Array3::<f32>::zeros((2, 1, 64)); |
| 29 | + let sample_rate_tensor = Array1::from_vec(vec![SAMPLE_RATE]); |
| 30 | + |
| 31 | + Ok(Self { |
| 32 | + session, |
| 33 | + h_tensor, |
| 34 | + c_tensor, |
| 35 | + sample_rate_tensor, |
| 36 | + }) |
| 37 | + } |
| 38 | + |
| 39 | + /// Process a chunk of audio samples through the model and return the speech probability |
| 40 | + fn forward(&mut self, audio_chunk: &[f32]) -> Result<f32, crate::Error> { |
| 41 | + let samples = audio_chunk.len(); |
| 42 | + let audio_tensor = Array2::from_shape_vec((1, samples), audio_chunk.to_vec())?; |
| 43 | + |
| 44 | + let mut result = self.session.run(ort::inputs![ |
| 45 | + audio_tensor.view(), |
| 46 | + self.sample_rate_tensor.view(), |
| 47 | + self.h_tensor.view(), |
| 48 | + self.c_tensor.view() |
| 49 | + ]?)?; |
| 50 | + |
| 51 | + // Update internal state tensors |
| 52 | + self.h_tensor = result |
| 53 | + .get("hn") |
| 54 | + .ok_or(Error::InvalidOutput)? |
| 55 | + .try_extract_tensor::<f32>()? |
| 56 | + .to_owned() |
| 57 | + .into_shape_with_order((2, 1, 64))?; |
| 58 | + |
| 59 | + self.c_tensor = result |
| 60 | + .get("cn") |
| 61 | + .ok_or(Error::InvalidOutput)? |
| 62 | + .try_extract_tensor::<f32>()? |
| 63 | + .to_owned() |
| 64 | + .into_shape_with_order((2, 1, 64))?; |
| 65 | + |
| 66 | + let prob_tensor = result.remove("output").ok_or(Error::InvalidOutput)?; |
| 67 | + let prob = *prob_tensor |
| 68 | + .try_extract_tensor::<f32>()? |
| 69 | + .first() |
| 70 | + .ok_or(Error::InvalidOutput)?; |
| 71 | + |
| 72 | + Ok(prob) |
| 73 | + } |
| 74 | + |
| 75 | + /// For longer audio, this will process in 30ms chunks and return the maximum probability |
| 76 | + pub fn run(&mut self, audio_samples: &[f32]) -> Result<f32, crate::Error> { |
| 77 | + if audio_samples.len() < 480 { |
| 78 | + return self.forward(audio_samples); |
| 79 | + } |
| 80 | + |
| 81 | + let chunk_size = (CHUNK_SIZE_MS * SAMPLE_RATE as usize) / 1000; |
| 82 | + let num_chunks = audio_samples.len() / chunk_size; |
| 83 | + |
| 84 | + let mut max_prob = 0.0f32; |
| 85 | + |
| 86 | + for i in 0..num_chunks { |
| 87 | + let start = i * chunk_size; |
| 88 | + let end = (start + chunk_size).min(audio_samples.len()); |
| 89 | + let prob = self.forward(&audio_samples[start..end])?; |
| 90 | + max_prob = max_prob.max(prob); |
| 91 | + } |
| 92 | + |
| 93 | + let remaining_start = num_chunks * chunk_size; |
| 94 | + if remaining_start < audio_samples.len() && audio_samples.len() - remaining_start >= 240 { |
| 95 | + let prob = self.forward(&audio_samples[remaining_start..])?; |
| 96 | + max_prob = max_prob.max(prob); |
| 97 | + } |
| 98 | + |
| 99 | + Ok(max_prob) |
| 100 | + } |
| 101 | + |
| 102 | + pub fn reset(&mut self) { |
| 103 | + self.h_tensor = Array3::<f32>::zeros((2, 1, 64)); |
| 104 | + self.c_tensor = Array3::<f32>::zeros((2, 1, 64)); |
| 105 | + } |
| 106 | +} |
| 107 | + |
#[cfg(test)]
mod tests {
    use super::*;

    /// 16 000 zero-valued samples (one second at `SAMPLE_RATE`) must score below 0.1.
    #[test]
    fn test_vad() {
        let silence = vec![0.0_f32; 16000];
        let mut detector = Vad::new().unwrap();

        let speech_prob = detector.run(&silence).unwrap();

        assert!(speech_prob < 0.1);
    }
}
0 commit comments