Add C++ API for speaker embedding models. (#706)
csukuangfj authored Jan 22, 2025
1 parent 2a88cf8 commit d3c953b
Showing 20 changed files with 732 additions and 1 deletion.
109 changes: 109 additions & 0 deletions python-api-examples/compute-speaker-simiarlity.py
@@ -0,0 +1,109 @@
#!/usr/bin/env python3
# Copyright (c) 2025 Xiaomi Corporation

"""
Please download model files from
https://github.com/k2-fsa/sherpa/releases/
E.g.
wget https://github.com/k2-fsa/sherpa/releases/download/speaker-recognition-models/3d_speaker-speech_eres2netv2_sv_zh-cn_16k-common.pt
Please download test files from
https://github.com/csukuangfj/sr-data/tree/main/test/3d-speaker
"""

import time
from typing import Tuple

import librosa
import numpy as np
import soundfile as sf
import torch

import sherpa


def load_audio(filename: str) -> Tuple[np.ndarray, int]:
data, sample_rate = sf.read(
filename,
always_2d=True,
dtype="float32",
)
data = data[:, 0] # use only the first channel
samples = np.ascontiguousarray(data)
return samples, sample_rate


def create_extractor():
config = sherpa.SpeakerEmbeddingExtractorConfig(
model="./3d_speaker-speech_eres2netv2_sv_zh-cn_16k-common.pt",
)
print(config)
return sherpa.SpeakerEmbeddingExtractor(config)


def main():
extractor = create_extractor()

file1 = "./speaker1_a_cn_16k.wav"
file2 = "./speaker1_b_cn_16k.wav"
file3 = "./speaker2_a_cn_16k.wav"

samples1, sample_rate1 = load_audio(file1)
if sample_rate1 != 16000:
samples1 = librosa.resample(samples1, orig_sr=sample_rate1, target_sr=16000)
sample_rate1 = 16000

samples2, sample_rate2 = load_audio(file2)
if sample_rate2 != 16000:
samples2 = librosa.resample(samples2, orig_sr=sample_rate2, target_sr=16000)
sample_rate2 = 16000

samples3, sample_rate3 = load_audio(file3)
if sample_rate3 != 16000:
samples3 = librosa.resample(samples3, orig_sr=sample_rate3, target_sr=16000)
sample_rate3 = 16000

start = time.time()
stream1 = extractor.create_stream()
stream2 = extractor.create_stream()
stream3 = extractor.create_stream()

stream1.accept_waveform(samples1)
stream2.accept_waveform(samples2)
stream3.accept_waveform(samples3)

embeddings = extractor.compute([stream1, stream2, stream3])
# embeddings: (batch_size, dim)
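    # Note (comment added for clarity, not in the original commit):
    # cosine_similarity computes dot(u, v) / (||u|| * ||v||), a score in
    # [-1, 1]; higher values mean the two embeddings are more similar,
    # i.e. more likely to come from the same speaker.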

x12 = torch.nn.functional.cosine_similarity(embeddings[0], embeddings[1], dim=0)
x13 = torch.nn.functional.cosine_similarity(embeddings[0], embeddings[2], dim=0)
x23 = torch.nn.functional.cosine_similarity(embeddings[1], embeddings[2], dim=0)

end = time.time()

elapsed_seconds = end - start

print(x12, x13, x23)

audio_duration = (
len(samples1) / sample_rate1
+ len(samples2) / sample_rate2
+ len(samples3) / sample_rate3
)
real_time_factor = elapsed_seconds / audio_duration
print(f"Elapsed seconds: {elapsed_seconds:.3f}")
print(f"Audio duration in seconds: {audio_duration:.3f}")
print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")


if __name__ == "__main__":
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)

torch.set_num_threads(1)
torch.set_num_interop_threads(1)
main()
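In practice, the raw scores printed above are compared against a threshold to decide whether two utterances share a speaker. Below is a minimal sketch of that step; the 0.5 cutoff is an illustrative assumption, not a value prescribed by this commit, and should be tuned on held-out data.

import torch

# Illustrative helper (assumed, not part of the commit): decide whether two
# utterances come from the same speaker by thresholding their cosine score.
# The 0.5 threshold is a made-up starting point; tune it on validation data.
def same_speaker(score: torch.Tensor, threshold: float = 0.5) -> bool:
    return score.item() > threshold

# With the scores computed above, same_speaker(x12) should come out True
# (speaker1 vs speaker1) and same_speaker(x13) False (speaker1 vs speaker2).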
1 change: 1 addition & 0 deletions python-api-examples/vad-with-sense-voice.py
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
# Copyright (c) 2025 Xiaomi Corporation

"""
Please download sense voice model from
8 changes: 8 additions & 0 deletions sherpa/csrc/CMakeLists.txt
@@ -44,6 +44,10 @@ set(sherpa_srcs
vad-model-config.cc
voice-activity-detector-impl.cc
voice-activity-detector.cc
#
speaker-embedding-extractor-model.cc
speaker-embedding-extractor.cc
speaker-embedding-extractor-impl.cc
)

add_library(sherpa_core ${sherpa_srcs})
@@ -129,6 +133,9 @@ target_include_directories(sherpa-version PRIVATE ${CMAKE_BINARY_DIR})
add_executable(sherpa-vad sherpa-vad.cc)
target_link_libraries(sherpa-vad sherpa_core)

add_executable(sherpa-compute-speaker-similarity sherpa-compute-speaker-similarity.cc)
target_link_libraries(sherpa-compute-speaker-similarity sherpa_core)

install(TARGETS
sherpa_core
DESTINATION lib
@@ -138,5 +145,6 @@ install(
TARGETS
sherpa-version
sherpa-vad
sherpa-compute-speaker-similarity
DESTINATION bin
)
93 changes: 93 additions & 0 deletions sherpa/csrc/sherpa-compute-speaker-similarity.cc
@@ -0,0 +1,93 @@
// sherpa/csrc/sherpa-compute-speaker-similarity.cc
//
// Copyright (c) 2025 Xiaomi Corporation

#include <chrono> // NOLINT
#include <iostream>

#include "sherpa/cpp_api/parse-options.h"
#include "sherpa/csrc/fbank-features.h"
#include "sherpa/csrc/speaker-embedding-extractor.h"

int32_t main(int32_t argc, char *argv[]) {
  const char *kUsageMessage = R"usage(
This program uses a speaker embedding model to compute the
similarity between two wave files.

sherpa-compute-speaker-similarity \
  --model=/path/to/model.pt \
  ./foo.wav \
  ./bar.wav
)usage";

int32_t num_threads = 1;
sherpa::ParseOptions po(kUsageMessage);
sherpa::SpeakerEmbeddingExtractorConfig config;
config.Register(&po);
po.Register("num-threads", &num_threads, "Number of threads for PyTorch");
po.Read(argc, argv);

if (po.NumArgs() != 2) {
std::cerr << "Please provide only 2 test waves\n";
exit(-1);
}

std::cerr << config.ToString() << "\n";
if (!config.Validate()) {
std::cerr << "Please check your config\n";
return -1;
}

int32_t sr = 16000;
sherpa::SpeakerEmbeddingExtractor extractor(config);

const auto begin = std::chrono::steady_clock::now();

torch::Tensor samples1 = sherpa::ReadWave(po.GetArg(1), sr).first;

auto stream1 = extractor.CreateStream();
stream1->AcceptSamples(samples1.data_ptr<float>(), samples1.numel());

torch::Tensor samples2 = sherpa::ReadWave(po.GetArg(2), sr).first;

auto stream2 = extractor.CreateStream();
stream2->AcceptSamples(samples2.data_ptr<float>(), samples2.numel());

torch::Tensor embedding1;
torch::Tensor embedding2;
  // Flip this to true to exercise the single-stream Compute() overload;
  // the batched overload in the else branch is what runs by default.
  if (false) {
embedding1 = extractor.Compute(stream1.get()).squeeze(0);
embedding2 = extractor.Compute(stream2.get()).squeeze(0);
} else {
std::vector<sherpa::OfflineStream *> ss{stream1.get(), stream2.get()};
auto embeddings = extractor.Compute(ss.data(), ss.size());

embedding1 = embeddings.index({0});
embedding2 = embeddings.index({1});
}

auto score =
torch::nn::functional::cosine_similarity(
embedding1, embedding2,
torch::nn::functional::CosineSimilarityFuncOptions{}.dim(0).eps(1e-6))
.item()
.toFloat();

const auto end = std::chrono::steady_clock::now();

const float elapsed_seconds =
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
.count() /
1000.;
float duration = (samples1.size(0) + samples2.size(0)) / 16000.0f;
const float rtf = elapsed_seconds / duration;

std::cout << "score: " << score << "\n";

fprintf(stderr, "Elapsed seconds: %.3f\n", elapsed_seconds);
fprintf(stderr, "Audio duration: %.3f s\n", duration);
fprintf(stderr, "Real time factor (RTF): %.3f/%.3f = %.3f\n", elapsed_seconds,
duration, rtf);

return 0;
}
2 changes: 1 addition & 1 deletion sherpa/csrc/sherpa-vad.cc
@@ -17,7 +17,7 @@ This program uses a VAD model to add timestamps to an audio file
sherpa-vad \
--silero-vad-model=/path/to/model.pt \
-  --use-gpu=false \
+  --vad-use-gpu=false \
--num-threads=1 \
./foo.wav
98 changes: 98 additions & 0 deletions sherpa/csrc/speaker-embedding-extractor-general-impl.h
@@ -0,0 +1,98 @@
// sherpa/csrc/speaker-embedding-extractor-general-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation

#ifndef SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
#define SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "sherpa/cpp_api/feature-config.h"
#include "sherpa/cpp_api/macros.h"
#include "sherpa/cpp_api/offline-stream.h"
#include "sherpa/csrc/speaker-embedding-extractor-impl.h"
#include "sherpa/csrc/speaker-embedding-extractor-model.h"

namespace sherpa {

class SpeakerEmbeddingExtractorGeneralImpl
: public SpeakerEmbeddingExtractorImpl {
public:
explicit SpeakerEmbeddingExtractorGeneralImpl(
const SpeakerEmbeddingExtractorConfig &config)
: model_(config) {
// TODO(fangjun): make it configurable
feat_config_.fbank_opts.frame_opts.dither = 0;
feat_config_.fbank_opts.frame_opts.snip_edges = true;
feat_config_.fbank_opts.frame_opts.samp_freq = 16000;
feat_config_.fbank_opts.mel_opts.num_bins = 80;
feat_config_.normalize_samples = true;

fbank_ = std::make_unique<kaldifeat::Fbank>(feat_config_.fbank_opts);

WarmUp();
}

int32_t Dim() const override { return model_.GetModelMetadata().output_dim; }

std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(fbank_.get(), feat_config_);
}

torch::Tensor Compute(OfflineStream *s) const override {
InferenceMode no_grad;
auto features = s->GetFeatures();
features -= features.mean(0, true);
features = features.unsqueeze(0);
auto device = model_.Device();
return model_.Compute(features.to(device));
}

torch::Tensor Compute(OfflineStream **ss, int32_t n) const override {
InferenceMode no_grad;
if (n == 1) {
return Compute(ss[0]);
}

std::vector<torch::Tensor> features_vec(n);
for (int32_t i = 0; i != n; ++i) {
auto f = ss[i]->GetFeatures();
f -= f.mean(0, true);
features_vec[i] = f;
}

auto device = model_.Device();

auto features =
torch::nn::utils::rnn::pad_sequence(features_vec, true, 0).to(device);

return model_.Compute(features);
}

private:
  // Runs a short dummy utterance through the model once, so that the first
  // real call does not pay one-time initialization costs, and records the
  // embedding dimension reported by the model.
  void WarmUp() {
InferenceMode no_grad;
SHERPA_LOG(INFO) << "WarmUp begins";
auto s = CreateStream();
float sample_rate = fbank_->GetFrameOptions().samp_freq;
std::vector<float> samples(2 * sample_rate, 0);
s->AcceptSamples(samples.data(), samples.size());

auto embedding = Compute(s.get());

model_.GetModelMetadata().output_dim = embedding.size(1);

SHERPA_LOG(INFO) << "WarmUp ended";
}

private:
SpeakerEmbeddingExtractorModel model_;
std::unique_ptr<kaldifeat::Fbank> fbank_;
FeatureConfig feat_config_;
};

} // namespace sherpa

#endif // SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
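The batched Compute() above zero-pads variable-length feature matrices into a single tensor via pad_sequence. For reference, here is a small standalone Python sketch of the same batching step (illustrative only; the shapes are made up):

import torch
from torch.nn.utils.rnn import pad_sequence

# Two utterances with different frame counts, 80-dim fbank features each.
a = torch.randn(150, 80)
b = torch.randn(100, 80)

# batch_first=True yields (batch, max_frames, feat_dim) == (2, 150, 80);
# the shorter utterance is zero-padded at the end, matching the C++ call
# torch::nn::utils::rnn::pad_sequence(features_vec, true, 0).
batch = pad_sequence([a, b], batch_first=True)
print(batch.shape)  # torch.Size([2, 150, 80])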
17 changes: 17 additions & 0 deletions sherpa/csrc/speaker-embedding-extractor-impl.cc
@@ -0,0 +1,17 @@
// sherpa/csrc/speaker-embedding-extractor-impl.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa/csrc/speaker-embedding-extractor-impl.h"

#include "sherpa/csrc/speaker-embedding-extractor-general-impl.h"

namespace sherpa {

std::unique_ptr<SpeakerEmbeddingExtractorImpl>
SpeakerEmbeddingExtractorImpl::Create(
const SpeakerEmbeddingExtractorConfig &config) {
  // Supports only 3D-Speaker models for now.
return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config);
}

} // namespace sherpa
34 changes: 34 additions & 0 deletions sherpa/csrc/speaker-embedding-extractor-impl.h
@@ -0,0 +1,34 @@
// sherpa/csrc/speaker-embedding-extractor-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation

#ifndef SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
#define SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_

#include <memory>
#include <string>
#include <vector>

#include "sherpa/csrc/speaker-embedding-extractor.h"

namespace sherpa {

class SpeakerEmbeddingExtractorImpl {
public:
virtual ~SpeakerEmbeddingExtractorImpl() = default;

static std::unique_ptr<SpeakerEmbeddingExtractorImpl> Create(
const SpeakerEmbeddingExtractorConfig &config);

virtual int32_t Dim() const = 0;

virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;

virtual torch::Tensor Compute(OfflineStream *s) const = 0;

  virtual torch::Tensor Compute(OfflineStream **ss, int32_t n) const = 0;
};

} // namespace sherpa

#endif // SHERPA_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_