diff --git a/README.md b/README.md
index 5e704d6..f37ec1a 100644
--- a/README.md
+++ b/README.md
@@ -50,6 +50,11 @@ wav_conv = model.convert(torch.rand((1, 77040)), target="1069")
 asr_bn = model.get_bn(torch.rand((1, 77040))) # (ASR-BN extraction for disentangled linguistic features (best with hifigan_bn_tdnnf_wav2vec2_vq_48_v1))
 ```
 
+## Anonymize bin
+Once the install.sh script has been run (`INSTALL_KALDI=false` can be set for a faster installation), the
+[`./satools/satools/bin/anonymize`](./satools/satools/bin/anonymize) script is available in your path. Use it together
+with a config (example: [here](./egs/vc/libritts/configs/anon_any_to_one_for_train)) to anonymize a Kaldi-like directory.
+
 ## Quick JIT anonymization example
 
 This version does not rely on any dependencies using [TorchScript](https://pytorch.org/docs/stable/jit.html).
diff --git a/satools/satools/bin/__init__.py b/satools/satools/bin/__init__.py
new file mode 100644
index 0000000..79ba835
--- /dev/null
+++ b/satools/satools/bin/__init__.py
@@ -0,0 +1 @@
+from . import pipeline
diff --git a/satools/satools/bin/anonymize b/satools/satools/bin/anonymize
new file mode 100644
index 0000000..5fa3a60
--- /dev/null
+++ b/satools/satools/bin/anonymize
@@ -0,0 +1,121 @@
+#!/usr/bin/env python3
+
+description = """
+    This script anonymize a kaldi/wav.scp formated dataset
+    It takes a config file and a directory
+"""
+
+import os
+# Set before the satools imports below; presumably read when satools is imported (TODO confirm).
+os.environ["SA_JIT_TWEAK"] = "true"
+import sys
+import time
+from dataclasses import dataclass
+from pathlib import Path
+import configparser
+import argparse
+import logging
+
+from multiprocessing import Process, Value
+from tqdm import tqdm
+
+import satools.script_utils as script_utils
+
+@dataclass
+class Pipeline(script_utils.ConfigParser):
+    """Options of the [pipeline] config section, passed to process_data as settings."""
+    model: str = "large"  # handed to load_model by the pipeline
+    f0_modification: str = ""
+    # constant, bad_for_evaluation, random_per_utt, random_per_spk_uniq or
+    # random_per_spk; the "?" default is rejected by the pipeline.
+    target_selection_algorithm: str = "?"
+    target_constant_spkid: str = "?"  # target speaker of the "constant" algorithm
+    results_dir: str = "wav" # output of anonymize wavs ./data/XXXX/wav
+    batch_size: int = 8
+    data_loader_nj: int = 5  # DataLoader workers per job (the pipeline caps it at 18)
+    new_datadir_suffix: str = "_anon"  # appended to the input directory name
+
+@dataclass
+class Cmd(script_utils.ConfigParser):
+    """Options of the [cmd] config section: device and number of parallel jobs."""
+    device: str = "cuda"
+    ngpu: script_utils.ngpu = 1  # iterated below as a collection of GPU ids
+    jobs_per_compute_device: int = 1 # number of jobs per gpus/cpus
+
+
+def update_progress_bar(progress, total):
+    """Mirror the shared counter `progress` in a tqdm bar until it reaches `total`."""
+    with tqdm(total=total) as pbar:
+        while progress.value < total:
+            pbar.n = progress.value
+            pbar.refresh()
+            time.sleep(0.5) # Adjust the sleep time as needed
+        pbar.n = total
+        pbar.refresh()
+
+def compute_pipeline(cfg_cmd, cfg_pipeline, directory, wavscp, progress):
+    """Worker entry point: anonymize the `wavscp` share of `directory`."""
+    import satools.bin.pipeline  # imported in the worker, not at module level
+    satools.bin.pipeline.process_data(directory, cfg_pipeline.target_selection_algorithm, wavscp, cfg_pipeline, progress)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description=description)
+    parser.add_argument("--config", default="configs/default", required=True)
+    parser.add_argument("--directory", default="data/default", required=True)
+    args = parser.parse_args()
+
+    logging.info("Reading config")
+    cfg_parse = configparser.ConfigParser()
+    cfg_parse.read(args.config)
+    cfg_parse = script_utils.vartoml(cfg_parse)
+
+    cfg_cmd = Cmd().load_from_config(cfg_parse["cmd"])
+    cfg_pipeline = Pipeline().load_from_config(cfg_parse["pipeline"])
+    cfg_pipeline.device = cfg_cmd.device
+
+    wavscp = script_utils.read_wav_scp(os.path.join(args.directory, "wav.scp"))
+
+    # One share of the utterances per job
+    wavscp_for_jobs = list(script_utils.split_dict(wavscp, len(cfg_cmd.ngpu) * cfg_cmd.jobs_per_compute_device))
+    progress = Value('i', 0)  # utterances processed so far, shared by all jobs
+
+    processes = []
+    index = 0
+    for gpu_id in cfg_cmd.ngpu:
+        # Set in the parent so that the jobs started below inherit it
+        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+        for job_id in range(cfg_cmd.jobs_per_compute_device):
+            p = Process(target=compute_pipeline, args=(cfg_cmd, cfg_pipeline, args.directory, wavscp_for_jobs[index], progress))
+            index += 1
+            processes.append(p)
+            p.start()
+
+    # Start a process to update the progress bar
+    progress_thread = Process(target=update_progress_bar, args=(progress, len(wavscp)))
+    progress_thread.start()
+
+    for p in processes:
+        p.join()
+        if p.exitcode != 0:
+            print(f"Process {p.pid} exited with code {p.exitcode}. Terminating.")
+            for proc in processes:
+                if proc.is_alive():
+                    proc.terminate()
+            progress_thread.terminate()
+            sys.exit(1)
+
+    # Let the bar draw its final state before stopping it
+    progress_thread.join(timeout=1.0)
+    if progress_thread.is_alive():
+        progress_thread.terminate()
+
+    # Each job truncates the output wav.scp and writes only its own share,
+    # so rebuild the complete file once every job has finished.
+    output_dir = Path(str(Path(args.directory)) + cfg_pipeline.new_datadir_suffix)
+    wav_dir = output_dir / cfg_pipeline.results_dir
+    with open(output_dir / "wav.scp", "wt", encoding="utf-8") as writer:
+        for utid in wavscp:
+            wav_path = wav_dir / f"{utid}.wav"
+            writer.write(f"{utid} {wav_path}\n")
+    logging.info('Done')
diff --git a/satools/satools/bin/pipeline.py b/satools/satools/bin/pipeline.py
new file mode 100644
index 0000000..e107fcb
--- /dev/null
+++ b/satools/satools/bin/pipeline.py
@@ -0,0 +1,196 @@
+#!/usr/bin/env python3.0
+# -*- coding: utf-8 -*-
+
+import os
+import shutil
+import multiprocessing
+from pathlib import Path
+import torchaudio
+import random
+import glob
+import logging
+
+import torch
+
+import satools.script_utils as script_utils
+from satools.infer_helper import load_model
+from satools.utils.kaldi import load_wav_from_scp
+
+def copy_data_dir(dataset_path, output_path):
+    """Copy the regular files located directly in `dataset_path` to `output_path`."""
+    # Copy utt2spk wav.scp and so on, but not the directories inside (may contain clear or anonymized *.wav)
+    os.makedirs(output_path, exist_ok=True)
+    for p in glob.glob(str(Path(dataset_path) / '*'), recursive=False):
+        if os.path.isfile(p):
+            shutil.copy(p, output_path)
+
+class Wav(): # for f0 extraction
+    """Holder exposing a waveform as `.wav`; this is what get_f0_func receives."""
+    def __init__(self, w):
+        self.wav = w
+
+class Dataset(torch.utils.data.Dataset):
+    """Dataset returning the id, audio, F0 and sampling rate of each wav.scp entry."""
+    def __init__(self, id_wavs, get_f0_func):
+        self.all_wavs = list(id_wavs.values())
+        self.all_keys = list(id_wavs.keys())
+        self.get_f0_func = get_f0_func
+
+    def __len__(self):
+        return len(self.all_wavs)
+
+    def __getitem__(self, index):
+        audio, freq = load_wav_from_scp(str(self.all_wavs[index]))
+        f0 = self.get_f0_func(Wav(audio))
+        return {"utid": self.all_keys[index],
+                "audio": audio,
+                "f0": f0,
+                "freq": freq}
+
+def _pad_to_longest(tensors):
+    """Zero-pad `tensors` along their last axis into one (batch, max_len) tensor.
+
+    Returns the padded tensor and a tensor with the original last-axis lengths.
+    """
+    lengths = torch.tensor([t.shape[-1] for t in tensors])
+    padded = torch.zeros([len(tensors), torch.max(lengths).item()])
+    for row, tensor in enumerate(tensors):
+        padded[row, :tensor.shape[-1]] = tensor.squeeze()
+    return padded, lengths
+
+def collate_fn(item_list):
+    """Batch Dataset items as (audio, f0, audio lengths, utids, freqs)."""
+    output_audio, lengths_tensor_audio = _pad_to_longest([i['audio'] for i in item_list])
+    output_f0, _ = _pad_to_longest([i['f0'] for i in item_list])
+    utids = [i['utid'] for i in item_list]
+    freqs = [i['freq'] for i in item_list]
+    return output_audio, output_f0, lengths_tensor_audio, utids, freqs
+
+def process_data(dataset_path: str, target_selection_algorithm: str, wavscp: dict, settings: dict, progress):
+    """Anonymize the utterances of `wavscp` into a copy of the data directory.
+
+    The files of `dataset_path` are copied to `dataset_path + settings.new_datadir_suffix`;
+    the converted audio is saved there as 16-bit PCM under `settings.results_dir`, and
+    a wav.scp listing the new files is written.
+
+    Args:
+        dataset_path: kaldi-like data directory; its utt2spk file is read.
+        target_selection_algorithm: constant, bad_for_evaluation, random_per_utt,
+            random_per_spk_uniq or random_per_spk.
+        wavscp: utterance id -> wav.scp entry for the utterances to process.
+        settings: object read through the attributes results_dir, new_datadir_suffix,
+            device, batch_size, model, data_loader_nj and target_constant_spkid.
+        progress: shared counter (`get_lock()` / `value`), increased by the number
+            of utterances of each processed batch.
+
+    Raises:
+        ValueError: if `target_selection_algorithm` is not one of the names above.
+    """
+    results_dir = settings.results_dir
+    dataset_path = Path(str(dataset_path))
+    output_path = Path(str(dataset_path) + settings.new_datadir_suffix)
+    device = settings.device
+    batch_size = settings.batch_size
+
+    copy_data_dir(dataset_path, output_path)
+    results_dir = output_path / results_dir
+    os.makedirs(results_dir, exist_ok = True)
+
+    utt2spk = dataset_path / 'utt2spk'
+    wav_scp_out = output_path / 'wav.scp'
+
+    model = load_model(settings.model)
+    model.to(device)
+    model.eval()
+    possible_targets = model.spk.copy() # For spk and utt target_selection_algorithm random choice
+
+    source_utt2spk = script_utils.read_wav_scp(utt2spk)
+    out_spk2target = {} # For spk target_selection_algorithm
+
+
+    @torch.no_grad()
+    def process_wav(utid, freq, audio, f0, original_len):
+        """Convert one batch and return the started process that saves it to disk."""
+
+        freq = freq[0] # assume all freq = in same batch (and so dataset)
+        audio = audio.to(device)
+
+        # Anonymize function
+        model.set_f0(f0.to(device)) # CPU extracted by Dataloader (num_workers)
+        # Batch select target spks from the available model list depending on target_selection_algorithm
+        target_spks = []
+        if target_selection_algorithm == "constant": # The best way/most secure to evaluate privacy when applied to all dataset (train included)
+            target_constant_spkid = settings.target_constant_spkid # For constant target_selection_algorithm
+            target_spks = [target_constant_spkid]*audio.shape[0]
+        elif target_selection_algorithm == "bad_for_evaluation":
+            # This target selection algorithm is bad for evaluation as it does
+            # not generate suitable training data for the ASV eval training
+            # procedure. Use it with caution.
+            for ut in utid:
+                source_spk = source_utt2spk[ut]
+                if source_spk not in out_spk2target:
+                    out_spk2target[source_spk] = random.sample(possible_targets, 2)
+                target_spks.append(random.choice(out_spk2target[source_spk]))
+        elif target_selection_algorithm == "random_per_utt":
+            target_spks = [random.choice(possible_targets) for _ in utid]
+        elif target_selection_algorithm == "random_per_spk_uniq":
+            for ut in utid:
+                source_spk = source_utt2spk[ut]
+                if source_spk not in out_spk2target:
+                    out_spk2target[source_spk] = random.choice(possible_targets)
+                    # Remove target spk: size of possible source spk to anonymize == len(possible_targets) (==247) or you need to add spk target overlap)
+                    possible_targets.remove(out_spk2target[source_spk])
+                target_spks.append(out_spk2target[source_spk])
+        elif target_selection_algorithm == "random_per_spk":
+            for ut in utid:
+                source_spk = source_utt2spk[ut]
+                if source_spk not in out_spk2target:
+                    out_spk2target[source_spk] = random.choice(possible_targets)
+                target_spks.append(out_spk2target[source_spk])
+        else:
+            raise ValueError(f"{target_selection_algorithm} not implemented")
+        # Batch conversion
+        wav_conv = model.convert(audio, target=target_spks)
+        wav_conv = wav_conv.cpu()
+
+        def parallel_write():
+            for i in range(wav_conv.shape[0]):
+                wav = wav_conv[i]
+                if len(wav.shape) == 1:
+                    wav = wav.unsqueeze(0) # batch == 1 -> len(dst) % batch == 1
+                wav = wav[:, :original_len[i]]
+                # write to buffer
+                u = utid[i]
+                output_file = results_dir / f'{u}.wav'
+                torchaudio.save(str(output_file), wav, freq, encoding='PCM_S', bits_per_sample=16)
+        # Save in a separate process so that the next batch can be converted meanwhile
+        p = multiprocessing.Process(target=parallel_write, args=())
+        p.start()
+        return p
+
+    nj = settings.data_loader_nj
+    nj = min(nj, 18)
+    p = None
+
+    # NOTE: opened in truncating mode, so jobs sharing an output directory overwrite each
+    # other's entries; the anonymize launcher rebuilds the full wav.scp once all jobs end.
+    with open(wav_scp_out, 'wt', encoding='utf-8') as writer:
+        data_loader = torch.utils.data.DataLoader(Dataset(wavscp, model.get_f0), batch_size=batch_size, num_workers=nj, collate_fn=collate_fn)
+        for audio, f0, original_len, utid, freq in data_loader:
+            previous = p
+            p = process_wav(utid, freq, audio, f0, original_len)
+            if previous is not None:
+                # Join every writer, not just the last one, so no audio is still
+                # being saved (and no child is left unreaped) when we return
+                previous.join()
+            for u in utid:
+                output_file = results_dir / f'{u}.wav'
+                writer.write(f"{u} {output_file}\n")
+            with progress.get_lock():
+                # The last batch may be smaller than batch_size
+                progress.value += len(utid)
+            if device.startswith("cuda"):
+                torch.cuda.empty_cache()
+    # wait for last p to write the anonymized audios
+    if p is not None:
+        p.join()
+