Skip to content

Commit ff6fc8a

Browse files
authored
add amplitude_floor in vad masking (#1883)
1 parent a56f760 commit ff6fc8a

File tree

1 file changed

+19
-29
lines changed

1 file changed

+19
-29
lines changed

crates/vad-ext/src/continuous2.rs

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,47 +7,34 @@ use futures_util::Stream;
77
use hypr_audio_utils::f32_to_i16_samples;
88
use hypr_vad3::earshot::{VoiceActivityDetector, VoiceActivityProfile};
99

10-
/// Wraps a `Stream<Item = Result<Vec<f32>, E>>` and zeroes out samples that are
11-
/// classified as non-speech by a 16 kHz VAD, with a configurable hangover.
12-
///
13-
/// - Expects 16 kHz mono PCM in `[-1.0, 1.0]`.
14-
/// - Never changes the length of any chunk.
15-
/// - Fails open: on VAD errors, audio is passed through as speech.
1610
pub struct ContinuousVadMaskStream<S> {
1711
inner: S,
1812
vad: VoiceActivityDetector,
19-
/// Number of consecutive non-speech frames to keep as speech
20-
/// after speech has ended (to avoid chopping word endings).
2113
hangover_frames: usize,
2214
trailing_non_speech: usize,
2315
in_speech: bool,
24-
/// Scratch buffer for padding partial frames, reused across calls.
2516
scratch_frame: Vec<f32>,
17+
amplitude_floor: f32,
2618
}
2719

2820
impl<S> ContinuousVadMaskStream<S> {
29-
/// Construct with the default QUALITY profile and conservative hangover.
3021
pub fn new(inner: S) -> Self {
3122
Self {
3223
inner,
3324
vad: VoiceActivityDetector::new(VoiceActivityProfile::QUALITY),
3425
hangover_frames: 3,
3526
trailing_non_speech: 0,
36-
// Start in speech to be conservative: better to leak a bit of
37-
// initial noise than to truncate the first spoken frames.
3827
in_speech: true,
3928
scratch_frame: Vec::new(),
29+
amplitude_floor: 0.001,
4030
}
4131
}
4232

43-
/// Override the number of hangover frames (measured in VAD frames).
4433
pub fn with_hangover_frames(mut self, frames: usize) -> Self {
4534
self.hangover_frames = frames;
4635
self
4736
}
4837

49-
/// Optional: allow callers to change the VAD profile without changing behavior
50-
/// for existing code.
5138
pub fn with_profile(mut self, profile: VoiceActivityProfile) -> Self {
5239
self.vad = VoiceActivityDetector::new(profile);
5340
self
@@ -61,28 +48,20 @@ impl<S> ContinuousVadMaskStream<S> {
6148
let frame_size = hypr_vad3::choose_optimal_frame_size(chunk.len());
6249
debug_assert!(frame_size > 0, "VAD frame size must be > 0");
6350

64-
// `chunks_mut` yields exactly one frame when `chunk.len() <= frame_size`,
65-
// and multiple frames otherwise. No need for a special-case branch.
6651
for frame in chunk.chunks_mut(frame_size) {
6752
self.process_frame(frame, frame_size);
6853
}
6954
}
7055

71-
/// Internal state machine that applies hangover smoothing to the raw VAD
72-
/// decision. Returns the final `is_speech` value.
7356
fn smooth_vad_decision(&mut self, raw_is_speech: bool) -> bool {
7457
if raw_is_speech {
75-
// Fresh speech: reset hangover.
7658
self.in_speech = true;
7759
self.trailing_non_speech = 0;
7860
true
7961
} else if self.in_speech && self.trailing_non_speech < self.hangover_frames {
80-
// Still treat as speech for a few frames after VAD says "no"
81-
// to avoid chopping word endings or short pauses.
8262
self.trailing_non_speech += 1;
8363
true
8464
} else {
85-
// Long enough non-speech: actually treat as silence.
8665
self.in_speech = false;
8766
self.trailing_non_speech = 0;
8867
false
@@ -94,13 +73,19 @@ impl<S> ContinuousVadMaskStream<S> {
9473
return;
9574
}
9675

97-
// Convert to i16, padding only when needed. We always feed `frame_size`
98-
// samples into the VAD (either directly or via `scratch_frame`).
76+
let rms = Self::calculate_rms(frame);
77+
if rms < self.amplitude_floor {
78+
let is_speech = self.smooth_vad_decision(false);
79+
if !is_speech {
80+
frame.fill(0.0);
81+
}
82+
return;
83+
}
84+
9985
let raw_is_speech = if frame.len() == frame_size {
10086
let i16_samples = f32_to_i16_samples(frame);
10187
self.vad.predict_16khz(&i16_samples).unwrap_or(true)
10288
} else {
103-
// Partial frame at the end of a chunk: copy + pad with zeros.
10489
self.scratch_frame.clear();
10590
self.scratch_frame.extend_from_slice(frame);
10691
self.scratch_frame.resize(frame_size, 0.0);
@@ -111,11 +96,18 @@ impl<S> ContinuousVadMaskStream<S> {
11196

11297
let is_speech = self.smooth_vad_decision(raw_is_speech);
11398

114-
// Mask non-speech frames by zeroing in-place.
11599
if !is_speech {
116100
frame.fill(0.0);
117101
}
118102
}
103+
104+
fn calculate_rms(samples: &[f32]) -> f32 {
105+
if samples.is_empty() {
106+
return 0.0;
107+
}
108+
let sum_sq: f32 = samples.iter().map(|&s| s * s).sum();
109+
(sum_sq / samples.len() as f32).sqrt()
110+
}
119111
}
120112

121113
impl<S, E> Stream for ContinuousVadMaskStream<S>
@@ -138,8 +130,6 @@ where
138130
}
139131

140132
pub trait VadMaskExt: Sized {
141-
/// Wrap this stream with a VAD-based mask that zeros non-speech samples
142-
/// but never drops or reorders audio.
143133
fn mask_with_vad(self) -> ContinuousVadMaskStream<Self>;
144134
}
145135

0 commit comments

Comments
 (0)