From 2a88cf896607af03fca8aaa07018f94833b2148e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 22 Jan 2025 12:57:52 +0800 Subject: [PATCH] Support specifying number of threads in sherpa-vad (#705) --- sherpa/csrc/sherpa-vad.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sherpa/csrc/sherpa-vad.cc b/sherpa/csrc/sherpa-vad.cc index ad9e57a4..ce7244fb 100644 --- a/sherpa/csrc/sherpa-vad.cc +++ b/sherpa/csrc/sherpa-vad.cc @@ -8,6 +8,7 @@ #include "sherpa/cpp_api/parse-options.h" #include "sherpa/csrc/fbank-features.h" #include "sherpa/csrc/voice-activity-detector.h" +#include "torch/torch.h" int32_t main(int32_t argc, char *argv[]) { const char *kUsageMessage = R"usage( @@ -17,13 +18,16 @@ This program uses a VAD models to add timestamps to a audio file sherpa-vad \ --silero-vad-model=/path/to/model.pt \ --use-gpu=false \ + --num-threads=1 \ ./foo.wav )usage"; + int32_t num_threads = 1; sherpa::ParseOptions po(kUsageMessage); sherpa::VoiceActivityDetectorConfig config; config.Register(&po); + po.Register("num-threads", &num_threads, "Number of threads for PyTorch"); po.Read(argc, argv); if (po.NumArgs() != 1) { @@ -34,6 +38,9 @@ sherpa-vad \ std::cerr << config.ToString() << "\n"; config.Validate(); + torch::set_num_threads(num_threads); + torch::set_num_interop_threads(num_threads); + sherpa::VoiceActivityDetector vad(config); torch::Tensor samples = sherpa::ReadWave(po.GetArg(1), 16000).first; @@ -55,6 +62,7 @@ sherpa-vad \ fprintf(stderr, "%.3f -- %.3f\n", s.start, s.end); } + fprintf(stderr, "Number of threads: %d\n", num_threads); fprintf(stderr, "Elapsed seconds: %.3f\n", elapsed_seconds); fprintf(stderr, "Audio duration: %.3f s\n", duration); fprintf(stderr, "Real time factor (RTF): %.3f/%.3f = %.3f\n", elapsed_seconds,