-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
283 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from . import pipeline |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
#!/usr/bin/env python3 | ||
|
||
description = """ | ||
This script anonymize a kaldi/wav.scp formated dataset | ||
It takes a config file and a directory | ||
""" | ||
|
||
import os | ||
os.environ["SA_JIT_TWEAK"] = "true" | ||
import sys | ||
import time | ||
from dataclasses import dataclass | ||
import configparser | ||
import argparse | ||
import logging | ||
|
||
from multiprocessing import Process, Value | ||
from tqdm import tqdm | ||
|
||
import satools.script_utils as script_utils | ||
|
||
@dataclass | ||
class Pipeline(script_utils.ConfigParser): | ||
model: str = "large" | ||
f0_modification: str = "" | ||
target_selection_algorithm: str = "?" | ||
target_constant_spkid: str = "?" | ||
results_dir: int = "wav" # output of anonymize wavs ./data/XXXX/wav | ||
batch_size: int = 8 | ||
data_loader_nj: int = 5 | ||
new_datadir_suffix: str = "_anon" | ||
|
||
@dataclass | ||
class Cmd(script_utils.ConfigParser): | ||
device: str = "cuda" | ||
ngpu: script_utils.ngpu = 1 | ||
jobs_per_compute_device: int = 1 # number of jobs per gpus/cpus | ||
|
||
|
||
def update_progress_bar(progress, total): | ||
with tqdm(total=total) as pbar: | ||
while progress.value < total: | ||
pbar.n = progress.value | ||
pbar.refresh() | ||
time.sleep(0.5) # Adjust the sleep time as needed | ||
pbar.n = total | ||
pbar.refresh() | ||
|
||
def compute_pipeline(cfg_cmd, cfg_pipeline, directory, wavscp, progress): | ||
import satools.bin.pipeline | ||
satools.bin.pipeline.process_data(directory, cfg_pipeline.target_selection_algorithm, wavscp, cfg_pipeline, progress) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description=description) | ||
parser.add_argument("--config", default="configs/default", required=True) | ||
parser.add_argument("--directory", default="data/default", required=True) | ||
args = parser.parse_args() | ||
|
||
logging.info("Reading config") | ||
cfg_parse = configparser.ConfigParser() | ||
cfg_parse.read(args.config) | ||
cfg_parse = script_utils.vartoml(cfg_parse) | ||
|
||
cfg_cmd = Cmd().load_from_config(cfg_parse["cmd"]) | ||
cfg_pipeline = Pipeline().load_from_config(cfg_parse["pipeline"]) | ||
cfg_pipeline.device = cfg_cmd.device | ||
|
||
wavscp = script_utils.read_wav_scp(os.path.join(args.directory, "wav.scp")) | ||
|
||
wavscp_for_jobs = list(script_utils.split_dict(wavscp, len(cfg_cmd.ngpu) * cfg_cmd.jobs_per_compute_device)) | ||
progress = Value('i', 0) | ||
|
||
processes = [] | ||
index = 0 | ||
for gpu_id in cfg_cmd.ngpu: | ||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) | ||
for job_id in range(cfg_cmd.jobs_per_compute_device): | ||
p = Process(target=compute_pipeline, args=(cfg_cmd, cfg_pipeline, args.directory, wavscp_for_jobs[index], progress)) | ||
index += 1 | ||
processes.append(p) | ||
p.start() | ||
|
||
# Start a thread to update the progress bar | ||
progress_thread = Process(target=update_progress_bar, args=(progress, len(wavscp))) | ||
progress_thread.start() | ||
|
||
for p in processes: | ||
p.join() | ||
if p.exitcode != 0: | ||
print(f"Process {p.pid} exited with code {p.exitcode}. Terminating.") | ||
for proc in processes: | ||
if proc.is_alive(): | ||
proc.terminate() | ||
progress_thread.terminate() | ||
sys.exit(1) | ||
|
||
progress_thread.terminate() | ||
logging.info('Done') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
#!/usr/bin/env python3.0 | ||
# -*- coding: utf-8 -*- | ||
|
||
import os | ||
import shutil | ||
import multiprocessing | ||
from pathlib import Path | ||
import torchaudio | ||
import random | ||
import glob | ||
import logging | ||
|
||
import torch | ||
|
||
import satools.script_utils as script_utils | ||
from satools.infer_helper import load_model | ||
from satools.utils.kaldi import load_wav_from_scp | ||
|
||
def copy_data_dir(dataset_path, output_path): | ||
# Copy utt2spk wav.scp and so on, but not the directories inside (may contains clear or anonymzied *.wav) | ||
os.makedirs(output_path, exist_ok=True) | ||
for p in glob.glob(str(Path(dataset_path) / '*'), recursive=False): | ||
if os.path.isfile(p): | ||
shutil.copy(p, output_path) | ||
|
||
class Wav(): # for f0 extraction | ||
def __init__(self, w): | ||
self.wav = w | ||
|
||
class Dataset(torch.utils.data.Dataset): | ||
def __init__(self, id_wavs, get_f0_func): | ||
self.all_wavs = list(id_wavs.values()) | ||
self.all_keys = list(id_wavs.keys()) | ||
self.get_f0_func = get_f0_func | ||
|
||
def __len__(self): | ||
return len(self.all_wavs) | ||
|
||
def __getitem__(self, index): | ||
audio, freq = load_wav_from_scp(str(self.all_wavs[index])) | ||
f0 = self.get_f0_func(Wav(audio)) | ||
return {"utid": self.all_keys[index], | ||
"audio": audio, | ||
"f0": f0, | ||
"freq": freq} | ||
|
||
def collate_fn(item_list): | ||
batch_size = len(item_list) | ||
|
||
data_list_audio = [i['audio'] for i in item_list] | ||
lengths_tensor_audio = torch.tensor([i.shape[-1] for i in data_list_audio]) | ||
max_len_audio = torch.max(lengths_tensor_audio).item() | ||
output_audio = torch.zeros([batch_size, max_len_audio]) | ||
for i in range(batch_size): | ||
cur = data_list_audio[i] | ||
cur_len = data_list_audio[i].shape[-1] | ||
output_audio[i, :cur_len] = cur.squeeze() | ||
|
||
data_list_f0 = [i['f0'] for i in item_list] | ||
lengths_tensor_f0 = torch.tensor([i.shape[-1] for i in data_list_f0]) | ||
max_len_f0 = torch.max(lengths_tensor_f0).item() | ||
output_f0 = torch.zeros([batch_size, max_len_f0]) | ||
for i in range(batch_size): | ||
cur = data_list_f0[i] | ||
cur_len = data_list_f0[i].shape[-1] | ||
output_f0[i, :cur_len] = cur.squeeze() | ||
|
||
utids = [i['utid'] for i in item_list] | ||
freqs = [i['freq'] for i in item_list] | ||
return output_audio, output_f0, lengths_tensor_audio, utids, freqs | ||
|
||
def process_data(dataset_path: str, target_selection_algorithm: str, wavscp: dict, settings: dict, progress): | ||
results_dir = settings.results_dir | ||
dataset_path = Path(str(dataset_path)) | ||
output_path = Path(str(dataset_path) + settings.new_datadir_suffix) | ||
device = settings.device | ||
batch_size = settings.batch_size | ||
|
||
copy_data_dir(dataset_path, output_path) | ||
results_dir = output_path / results_dir | ||
os.makedirs(results_dir, exist_ok = True) | ||
|
||
wav_scp = dataset_path / 'wav.scp' | ||
utt2spk = dataset_path / 'utt2spk' | ||
wav_scp_out = output_path / 'wav.scp' | ||
|
||
model = load_model(settings.model) | ||
model.to(device) | ||
model.eval() | ||
possible_targets = model.spk.copy() # For spk and utt target_selection_algorithm random choice | ||
|
||
source_utt2spk = script_utils.read_wav_scp(utt2spk) | ||
out_spk2target = {} # For spk target_selection_algorithm | ||
|
||
|
||
@torch.no_grad() | ||
def process_wav(utid, freq, audio, f0, original_len): | ||
|
||
freq = freq[0] # assume all freq = in same batch (and so dataset) | ||
audio = audio.to(device) | ||
|
||
# Anonymize function | ||
model.set_f0(f0.to(device)) # CPU extracted by Dataloader (num_workers) | ||
# Batch select target spks from the available model list depending on target_selection_algorithm | ||
target_spks = [] | ||
if target_selection_algorithm == "constant": # The best way/most secure to evaluate privacy when applied to all dataset (train included) | ||
target_constant_spkid = settings.target_constant_spkid # For constant target_selection_algorithm | ||
target_spks = [target_constant_spkid]*audio.shape[0] | ||
elif target_selection_algorithm == "bad_for_evaluation": | ||
# This target selection algorithm is bad for evaluation as it does | ||
# not generate suitable training data for the ASV eval training | ||
# procedure. Use it with caution. | ||
for ut in utid: | ||
source_spk = source_utt2spk[ut] | ||
if source_spk not in out_spk2target: | ||
out_spk2target[source_spk] = random.sample(possible_targets, 2) | ||
target_spks.append(random.choice(out_spk2target[source_spk])) | ||
elif target_selection_algorithm == "random_per_utt": | ||
target_spks = [] | ||
for ut in utid: | ||
target_spks.append(random.choice(possible_targets)) | ||
elif target_selection_algorithm == "random_per_spk_uniq": | ||
for ut in utid: | ||
source_spk = source_utt2spk[ut] | ||
if source_spk not in out_spk2target: | ||
out_spk2target[source_spk] = random.choice(possible_targets) | ||
# Remove target spk: size of possible source spk to anonymize == len(possible_targets) (==247) or you need to add spk target overlap) | ||
possible_targets.remove(out_spk2target[source_spk]) | ||
target_spks.append(out_spk2target[source_spk]) | ||
elif target_selection_algorithm == "random_per_spk": | ||
for ut in utid: | ||
source_spk = source_utt2spk[ut] | ||
if source_spk not in out_spk2target: | ||
out_spk2target[source_spk] = random.choice(possible_targets) | ||
target_spks.append(out_spk2target[source_spk]) | ||
else: | ||
raise ValueError(f"{target_selection_algorithm} not implemented") | ||
# Batch conversion | ||
wav_conv = model.convert(audio, target=target_spks) | ||
wav_conv = wav_conv.cpu() | ||
|
||
def parallel_write(): | ||
for i in range(wav_conv.shape[0]): | ||
wav = wav_conv[i] | ||
if len(wav.shape) == 1: | ||
wav = wav.unsqueeze(0) # batch == 1 -> len(dst) % batch == 1 | ||
wav = wav[:, :original_len[i]] | ||
# write to buffer | ||
u = utid[i] | ||
output_file = results_dir / f'{u}.wav' | ||
torchaudio.save(str(output_file), wav, freq, encoding='PCM_S', bits_per_sample=16) | ||
p = multiprocessing.Process(target=parallel_write, args=()) | ||
p.start() | ||
return p | ||
|
||
nj = settings.data_loader_nj | ||
nj = min(nj, 18) | ||
p = None | ||
|
||
with open(wav_scp_out, 'wt', encoding='utf-8') as writer: | ||
filtered_wavs = {} | ||
for u, file in wavscp.items(): | ||
output_file = results_dir / f'{u}.wav' | ||
filtered_wavs[u] = file | ||
|
||
data_loader = torch.utils.data.DataLoader(Dataset(filtered_wavs, model.get_f0), batch_size=batch_size, num_workers=nj, collate_fn=collate_fn) | ||
for audio, f0, original_len, utid, freq in data_loader: | ||
p = process_wav(utid, freq, audio, f0, original_len) | ||
for u in utid: | ||
output_file = results_dir / f'{u}.wav' | ||
writer.writelines(f"{u} {output_file}\n") | ||
with progress.get_lock(): | ||
progress.value += batch_size | ||
if device.startswith("cuda"): | ||
torch.cuda.empty_cache() | ||
# wait for last p to write the anonymized audios | ||
if p: | ||
p.join() |