Skip to content

Commit cfda4de

Browse files
committed
add vad crate
1 parent 7460ae6 commit cfda4de

File tree

5 files changed

+162
-0
lines changed

5 files changed

+162
-0
lines changed

Cargo.lock

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/vad/Cargo.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
[package]
2+
name = "vad"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
[dependencies]
7+
serde = { workspace = true }
8+
thiserror = { workspace = true }
9+
tracing = { workspace = true }
10+
11+
ndarray = "0.16"
12+
ort = "=2.0.0-rc.9"

crates/vad/assets/model.onnx

1.72 MB
Binary file not shown.

crates/vad/src/error.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
use serde::{ser::Serializer, Serialize};
2+
3+
#[derive(Debug, thiserror::Error)]
4+
pub enum Error {
5+
#[error(transparent)]
6+
OrtError(#[from] ort::Error),
7+
#[error(transparent)]
8+
ShapeError(#[from] ndarray::ShapeError),
9+
#[error("Invalid or missing output from model")]
10+
InvalidOutput,
11+
}
12+
13+
impl Serialize for Error {
14+
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
15+
where
16+
S: Serializer,
17+
{
18+
serializer.serialize_str(self.to_string().as_ref())
19+
}
20+
}

crates/vad/src/lib.rs

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
mod error;
2+
use error::*;
3+
4+
use ndarray::{Array1, Array2, Array3, ArrayBase, Ix1, Ix3, OwnedRepr};
5+
use ort::session::{builder::GraphOptimizationLevel, Session};
6+
7+
const MODEL_BYTES: &[u8] =
8+
include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/assets/model.onnx"));
9+
10+
const SAMPLE_RATE: i64 = 16000;
11+
const CHUNK_SIZE_MS: usize = 30; // 30ms chunks for processing
12+
13+
pub struct Vad {
14+
session: Session,
15+
h_tensor: ArrayBase<OwnedRepr<f32>, Ix3>,
16+
c_tensor: ArrayBase<OwnedRepr<f32>, Ix3>,
17+
sample_rate_tensor: ArrayBase<OwnedRepr<i64>, Ix1>,
18+
}
19+
20+
impl Vad {
21+
pub fn new() -> Result<Self, crate::Error> {
22+
let session = Session::builder()?
23+
.with_optimization_level(GraphOptimizationLevel::Level3)?
24+
.with_intra_threads(4)?
25+
.commit_from_memory(MODEL_BYTES)?;
26+
27+
let h_tensor = Array3::<f32>::zeros((2, 1, 64));
28+
let c_tensor = Array3::<f32>::zeros((2, 1, 64));
29+
let sample_rate_tensor = Array1::from_vec(vec![SAMPLE_RATE]);
30+
31+
Ok(Self {
32+
session,
33+
h_tensor,
34+
c_tensor,
35+
sample_rate_tensor,
36+
})
37+
}
38+
39+
/// Process a chunk of audio samples through the model and return the speech probability
40+
fn forward(&mut self, audio_chunk: &[f32]) -> Result<f32, crate::Error> {
41+
let samples = audio_chunk.len();
42+
let audio_tensor = Array2::from_shape_vec((1, samples), audio_chunk.to_vec())?;
43+
44+
let mut result = self.session.run(ort::inputs![
45+
audio_tensor.view(),
46+
self.sample_rate_tensor.view(),
47+
self.h_tensor.view(),
48+
self.c_tensor.view()
49+
]?)?;
50+
51+
// Update internal state tensors
52+
self.h_tensor = result
53+
.get("hn")
54+
.ok_or(Error::InvalidOutput)?
55+
.try_extract_tensor::<f32>()?
56+
.to_owned()
57+
.into_shape_with_order((2, 1, 64))?;
58+
59+
self.c_tensor = result
60+
.get("cn")
61+
.ok_or(Error::InvalidOutput)?
62+
.try_extract_tensor::<f32>()?
63+
.to_owned()
64+
.into_shape_with_order((2, 1, 64))?;
65+
66+
let prob_tensor = result.remove("output").ok_or(Error::InvalidOutput)?;
67+
let prob = *prob_tensor
68+
.try_extract_tensor::<f32>()?
69+
.first()
70+
.ok_or(Error::InvalidOutput)?;
71+
72+
Ok(prob)
73+
}
74+
75+
/// For longer audio, this will process in 30ms chunks and return the maximum probability
76+
pub fn run(&mut self, audio_samples: &[f32]) -> Result<f32, crate::Error> {
77+
if audio_samples.len() < 480 {
78+
return self.forward(audio_samples);
79+
}
80+
81+
let chunk_size = (CHUNK_SIZE_MS * SAMPLE_RATE as usize) / 1000;
82+
let num_chunks = audio_samples.len() / chunk_size;
83+
84+
let mut max_prob = 0.0f32;
85+
86+
for i in 0..num_chunks {
87+
let start = i * chunk_size;
88+
let end = (start + chunk_size).min(audio_samples.len());
89+
let prob = self.forward(&audio_samples[start..end])?;
90+
max_prob = max_prob.max(prob);
91+
}
92+
93+
let remaining_start = num_chunks * chunk_size;
94+
if remaining_start < audio_samples.len() && audio_samples.len() - remaining_start >= 240 {
95+
let prob = self.forward(&audio_samples[remaining_start..])?;
96+
max_prob = max_prob.max(prob);
97+
}
98+
99+
Ok(max_prob)
100+
}
101+
102+
pub fn reset(&mut self) {
103+
self.h_tensor = Array3::<f32>::zeros((2, 1, 64));
104+
self.c_tensor = Array3::<f32>::zeros((2, 1, 64));
105+
}
106+
}
107+
108+
#[cfg(test)]
109+
mod tests {
110+
use super::*;
111+
112+
#[test]
113+
fn test_vad() {
114+
let mut vad = Vad::new().unwrap();
115+
let audio_samples = vec![0.0; 16000];
116+
let prob = vad.run(&audio_samples).unwrap();
117+
assert!(prob < 0.1);
118+
}
119+
}

0 commit comments

Comments
 (0)