-
Notifications
You must be signed in to change notification settings - Fork 111
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add C++ API for speaker embedding models. (#706)
- Loading branch information
1 parent
2a88cf8
commit d3c953b
Showing
20 changed files
with
732 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) 2025 Xiaomi Corporation | ||
|
||
""" | ||
Please download model files from | ||
https://github.com/k2-fsa/sherpa/releases/ | ||
E.g. | ||
wget https://github.com/k2-fsa/sherpa/releases/download/speaker-recognition-models/3d_speaker-speech_eres2netv2_sv_zh-cn_16k-common.pt | ||
Please download test files from | ||
https://github.com/csukuangfj/sr-data/tree/main/test/3d-speaker | ||
""" | ||
|
||
import time | ||
from typing import Tuple | ||
import torch | ||
|
||
import librosa | ||
import numpy as np | ||
import soundfile as sf | ||
|
||
import sherpa | ||
|
||
|
||
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    """Read an audio file and return the samples of its first channel.

    Args:
      filename: Path of the audio file to read.

    Returns:
      A tuple ``(samples, sample_rate)``. ``samples`` is a contiguous
      float32 array containing only the first channel; ``sample_rate`` is
      the value returned by ``sf.read``.
    """
    # always_2d=True keeps a channel axis even for mono files, so column 0
    # can be selected unconditionally below.
    audio, rate = sf.read(filename, always_2d=True, dtype="float32")
    first_channel = np.ascontiguousarray(audio[:, 0])
    return first_channel, rate
|
||
|
||
def create_extractor():
    """Build the speaker embedding extractor used by this example.

    The model path is hard-coded relative to the current directory. The
    config is printed before the extractor is constructed.
    """
    model_path = "./3d_speaker-speech_eres2netv2_sv_zh-cn_16k-common.pt"
    extractor_config = sherpa.SpeakerEmbeddingExtractorConfig(model=model_path)
    print(extractor_config)
    return sherpa.SpeakerEmbeddingExtractor(extractor_config)
|
||
|
||
def _load_audio_16k(filename: str) -> Tuple[np.ndarray, int]:
    """Load ``filename`` via load_audio() and resample it to 16 kHz if needed.

    Returns:
      A tuple ``(samples, sample_rate)`` where ``sample_rate`` is 16000.
    """
    target_sample_rate = 16000
    samples, sample_rate = load_audio(filename)
    if sample_rate != target_sample_rate:
        samples = librosa.resample(
            samples, orig_sr=sample_rate, target_sr=target_sample_rate
        )
        sample_rate = target_sample_rate
    return samples, sample_rate


def main():
    """Compare three test utterances with a speaker embedding extractor.

    Prints the pairwise cosine similarities (1-2, 1-3, 2-3) followed by
    timing statistics. The timed region covers stream creation, feeding the
    samples, embedding computation and the similarity computation; loading
    and resampling the audio files is not timed.
    """
    extractor = create_extractor()

    filenames = [
        "./speaker1_a_cn_16k.wav",
        "./speaker1_b_cn_16k.wav",
        "./speaker2_a_cn_16k.wav",
    ]
    # Each entry is (samples, sample_rate) at 16 kHz.
    waves = [_load_audio_16k(name) for name in filenames]

    start = time.time()

    # Create all streams first, then feed them, as one batch for compute().
    streams = [extractor.create_stream() for _ in waves]
    for stream, (samples, _) in zip(streams, waves):
        stream.accept_waveform(samples)

    embeddings = extractor.compute(streams)
    # embeddings: (batch_size, dim)

    cosine = torch.nn.functional.cosine_similarity
    x12 = cosine(embeddings[0], embeddings[1], dim=0)
    x13 = cosine(embeddings[0], embeddings[2], dim=0)
    x23 = cosine(embeddings[1], embeddings[2], dim=0)

    end = time.time()

    elapsed_seconds = end - start

    print(x12, x13, x23)

    audio_duration = sum(
        len(samples) / sample_rate for samples, sample_rate in waves
    )
    real_time_factor = elapsed_seconds / audio_duration
    print(f"Elapsed seconds: {elapsed_seconds:.3f}")
    print(f"Audio duration in seconds: {audio_duration:.3f}")
    print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")
|
||
|
||
if __name__ == "__main__":
    # Turn off the TorchScript profiling executor, profiling mode and
    # graph-executor optimization before running the example.
    for disable in (
        torch._C._jit_set_profiling_executor,
        torch._C._jit_set_profiling_mode,
        torch._C._set_graph_executor_optimize,
    ):
        disable(False)

    # Use one thread for both intra-op and inter-op parallelism.
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) 2025 Xiaomi Corporation | ||
|
||
""" | ||
Please download sense voice model from | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
// sherpa/csrc/sherpa-compute-speaker-similarity.cc | ||
// | ||
// Copyright (c) 2025 Xiaomi Corporation | ||
|
||
#include <chrono> // NOLINT | ||
#include <iostream> | ||
|
||
#include "sherpa/cpp_api/parse-options.h" | ||
#include "sherpa/csrc/fbank-features.h" | ||
#include "sherpa/csrc/speaker-embedding-extractor.h" | ||
|
||
int32_t main(int32_t argc, char *argv[]) { | ||
const char *kUsageMessage = R"usage( | ||
This program uses a speaker embedding model to compute | ||
similarity between two wave files. | ||
sherpa-compute-speaker-similarity \ | ||
--model=/path/to/model.pt \ | ||
./foo.wav \ | ||
./bar.wav \ | ||
)usage"; | ||
|
||
int32_t num_threads = 1; | ||
sherpa::ParseOptions po(kUsageMessage); | ||
sherpa::SpeakerEmbeddingExtractorConfig config; | ||
config.Register(&po); | ||
po.Register("num-threads", &num_threads, "Number of threads for PyTorch"); | ||
po.Read(argc, argv); | ||
|
||
if (po.NumArgs() != 2) { | ||
std::cerr << "Please provide only 2 test waves\n"; | ||
exit(-1); | ||
} | ||
|
||
std::cerr << config.ToString() << "\n"; | ||
if (!config.Validate()) { | ||
std::cerr << "Please check your config\n"; | ||
return -1; | ||
} | ||
|
||
int32_t sr = 16000; | ||
sherpa::SpeakerEmbeddingExtractor extractor(config); | ||
|
||
const auto begin = std::chrono::steady_clock::now(); | ||
|
||
torch::Tensor samples1 = sherpa::ReadWave(po.GetArg(1), sr).first; | ||
|
||
auto stream1 = extractor.CreateStream(); | ||
stream1->AcceptSamples(samples1.data_ptr<float>(), samples1.numel()); | ||
|
||
torch::Tensor samples2 = sherpa::ReadWave(po.GetArg(2), sr).first; | ||
|
||
auto stream2 = extractor.CreateStream(); | ||
stream2->AcceptSamples(samples2.data_ptr<float>(), samples2.numel()); | ||
|
||
torch::Tensor embedding1; | ||
torch::Tensor embedding2; | ||
if (false) { | ||
embedding1 = extractor.Compute(stream1.get()).squeeze(0); | ||
embedding2 = extractor.Compute(stream2.get()).squeeze(0); | ||
} else { | ||
std::vector<sherpa::OfflineStream *> ss{stream1.get(), stream2.get()}; | ||
auto embeddings = extractor.Compute(ss.data(), ss.size()); | ||
|
||
embedding1 = embeddings.index({0}); | ||
embedding2 = embeddings.index({1}); | ||
} | ||
|
||
auto score = | ||
torch::nn::functional::cosine_similarity( | ||
embedding1, embedding2, | ||
torch::nn::functional::CosineSimilarityFuncOptions{}.dim(0).eps(1e-6)) | ||
.item() | ||
.toFloat(); | ||
|
||
const auto end = std::chrono::steady_clock::now(); | ||
|
||
const float elapsed_seconds = | ||
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) | ||
.count() / | ||
1000.; | ||
float duration = (samples1.size(0) + samples2.size(0)) / 16000.0f; | ||
const float rtf = elapsed_seconds / duration; | ||
|
||
std::cout << "score: " << score << "\n"; | ||
|
||
fprintf(stderr, "Elapsed seconds: %.3f\n", elapsed_seconds); | ||
fprintf(stderr, "Audio duration: %.3f s\n", duration); | ||
fprintf(stderr, "Real time factor (RTF): %.3f/%.3f = %.3f\n", elapsed_seconds, | ||
duration, rtf); | ||
|
||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
// sherpa/csrc/speaker-embedding-extractor-general-impl.h | ||
// | ||
// Copyright (c) 2025 Xiaomi Corporation | ||
|
||
#ifndef SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_ | ||
#define SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_ | ||
#include <algorithm> | ||
#include <memory> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "sherpa/cpp_api/feature-config.h" | ||
#include "sherpa/cpp_api/macros.h" | ||
#include "sherpa/cpp_api/offline-stream.h" | ||
#include "sherpa/csrc/speaker-embedding-extractor-impl.h" | ||
#include "sherpa/csrc/speaker-embedding-extractor-model.h" | ||
|
||
namespace sherpa { | ||
|
||
class SpeakerEmbeddingExtractorGeneralImpl | ||
: public SpeakerEmbeddingExtractorImpl { | ||
public: | ||
explicit SpeakerEmbeddingExtractorGeneralImpl( | ||
const SpeakerEmbeddingExtractorConfig &config) | ||
: model_(config) { | ||
// TODO(fangjun): make it configurable | ||
feat_config_.fbank_opts.frame_opts.dither = 0; | ||
feat_config_.fbank_opts.frame_opts.snip_edges = true; | ||
feat_config_.fbank_opts.frame_opts.samp_freq = 16000; | ||
feat_config_.fbank_opts.mel_opts.num_bins = 80; | ||
feat_config_.normalize_samples = true; | ||
|
||
fbank_ = std::make_unique<kaldifeat::Fbank>(feat_config_.fbank_opts); | ||
|
||
WarmUp(); | ||
} | ||
|
||
int32_t Dim() const override { return model_.GetModelMetadata().output_dim; } | ||
|
||
std::unique_ptr<OfflineStream> CreateStream() const override { | ||
return std::make_unique<OfflineStream>(fbank_.get(), feat_config_); | ||
} | ||
|
||
torch::Tensor Compute(OfflineStream *s) const override { | ||
InferenceMode no_grad; | ||
auto features = s->GetFeatures(); | ||
features -= features.mean(0, true); | ||
features = features.unsqueeze(0); | ||
auto device = model_.Device(); | ||
return model_.Compute(features.to(device)); | ||
} | ||
|
||
torch::Tensor Compute(OfflineStream **ss, int32_t n) const override { | ||
InferenceMode no_grad; | ||
if (n == 1) { | ||
return Compute(ss[0]); | ||
} | ||
|
||
std::vector<torch::Tensor> features_vec(n); | ||
for (int32_t i = 0; i != n; ++i) { | ||
auto f = ss[i]->GetFeatures(); | ||
f -= f.mean(0, true); | ||
features_vec[i] = f; | ||
} | ||
|
||
auto device = model_.Device(); | ||
|
||
auto features = | ||
torch::nn::utils::rnn::pad_sequence(features_vec, true, 0).to(device); | ||
|
||
return model_.Compute(features); | ||
} | ||
|
||
private: | ||
void WarmUp() { | ||
InferenceMode no_grad; | ||
SHERPA_LOG(INFO) << "WarmUp begins"; | ||
auto s = CreateStream(); | ||
float sample_rate = fbank_->GetFrameOptions().samp_freq; | ||
std::vector<float> samples(2 * sample_rate, 0); | ||
s->AcceptSamples(samples.data(), samples.size()); | ||
|
||
auto embedding = Compute(s.get()); | ||
|
||
model_.GetModelMetadata().output_dim = embedding.size(1); | ||
|
||
SHERPA_LOG(INFO) << "WarmUp ended"; | ||
} | ||
|
||
private: | ||
SpeakerEmbeddingExtractorModel model_; | ||
std::unique_ptr<kaldifeat::Fbank> fbank_; | ||
FeatureConfig feat_config_; | ||
}; | ||
|
||
} // namespace sherpa | ||
|
||
#endif // SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
// sherpa/csrc/speaker-embedding-extractor-impl.cc | ||
// | ||
// Copyright (c) 2025 Xiaomi Corporation | ||
#include "sherpa/csrc/speaker-embedding-extractor-impl.h" | ||
|
||
#include "sherpa/csrc/speaker-embedding-extractor-general-impl.h" | ||
|
||
namespace sherpa { | ||
|
||
std::unique_ptr<SpeakerEmbeddingExtractorImpl> | ||
SpeakerEmbeddingExtractorImpl::Create( | ||
const SpeakerEmbeddingExtractorConfig &config) { | ||
// supports only 3-d speaker for now | ||
return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config); | ||
} | ||
|
||
} // namespace sherpa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
// sherpa/csrc/speaker-embedding-extractor-impl.h | ||
// | ||
// Copyright (c) 2025 Xiaomi Corporation | ||
|
||
#ifndef SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ | ||
#define SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include "sherpa/csrc/speaker-embedding-extractor.h" | ||
|
||
namespace sherpa { | ||
|
||
class SpeakerEmbeddingExtractorImpl { | ||
public: | ||
virtual ~SpeakerEmbeddingExtractorImpl() = default; | ||
|
||
static std::unique_ptr<SpeakerEmbeddingExtractorImpl> Create( | ||
const SpeakerEmbeddingExtractorConfig &config); | ||
|
||
virtual int32_t Dim() const = 0; | ||
|
||
virtual std::unique_ptr<OfflineStream> CreateStream() const = 0; | ||
|
||
virtual torch::Tensor Compute(OfflineStream *s) const = 0; | ||
|
||
virtual torch::Tensor Compute(OfflineStream **s, int32_t n) const = 0; | ||
}; | ||
|
||
} // namespace sherpa | ||
|
||
#endif // SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ |
Oops, something went wrong.