Skip to content

Commit

Permalink
readme and bin
Browse files Browse the repository at this point in the history
  • Loading branch information
pchampio committed Nov 26, 2024
1 parent 88d7a54 commit 16e6c6f
Show file tree
Hide file tree
Showing 4 changed files with 283 additions and 0 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ wav_conv = model.convert(torch.rand((1, 77040)), target="1069")
asr_bn = model.get_bn(torch.rand((1, 77040))) # (ASR-BN extraction for disentangled linguistic features (best with hifigan_bn_tdnnf_wav2vec2_vq_48_v1))
```

## Anonymize bin
Once the `install.sh` script has been run (`INSTALL_KALDI=false` can be set for a faster installation), the
[`./satools/satools/bin/anonymize`](./satools/satools/bin/anonymize) script is available in your `PATH`. Use it together
with a config (example: [here](./egs/vc/libritts/configs/anon_any_to_one_for_train)) to anonymize a Kaldi-like data directory.

## Quick JIT anonymization example

This version does not rely on any dependencies using [TorchScript](https://pytorch.org/docs/stable/jit.html).
Expand Down
1 change: 1 addition & 0 deletions satools/satools/bin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import pipeline
99 changes: 99 additions & 0 deletions satools/satools/bin/anonymize
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#!/usr/bin/env python3

# Help text for the command-line interface; passed as `description=` to
# argparse.ArgumentParser in the `__main__` section of this script.
description = """
This script anonymize a kaldi/wav.scp formated dataset
It takes a config file and a directory
"""

import os
os.environ["SA_JIT_TWEAK"] = "true"
import sys
import time
from dataclasses import dataclass
import configparser
import argparse
import logging

from multiprocessing import Process, Value
from tqdm import tqdm

import satools.script_utils as script_utils

@dataclass
class Pipeline(script_utils.ConfigParser):
    """Anonymization settings loaded from the ``pipeline`` section of the config file."""
    model: str = "large"  # handed to load_model() by the pipeline
    f0_modification: str = ""  # NOTE(review): not read by the code visible here - confirm where it is consumed
    # One of: "constant", "bad_for_evaluation", "random_per_utt",
    # "random_per_spk_uniq", "random_per_spk" (anything else raises ValueError
    # in the pipeline). The "?" default is therefore a placeholder that must be overridden.
    target_selection_algorithm: str = "?"
    target_constant_spkid: str = "?"  # target speaker used when the algorithm is "constant"
    # Sub-directory of the output data dir receiving the anonymized wavs
    # (./data/XXXX/wav). It is a path component, hence `str` (was annotated `int`).
    results_dir: str = "wav"
    batch_size: int = 8  # DataLoader batch size
    data_loader_nj: int = 5  # DataLoader num_workers (the pipeline caps it at 18)
    new_datadir_suffix: str = "_anon"  # appended to the input directory to name the output directory

@dataclass
class Cmd(script_utils.ConfigParser):
    """Launcher settings loaded from the ``cmd`` section of the config file."""
    device: str = "cuda"  # copied onto the pipeline config and used there as the torch device
    # The launcher takes len() of this value and iterates over it, exporting each
    # element as CUDA_VISIBLE_DEVICES for the jobs it starts.
    # NOTE(review): the default is the int 1 - presumably script_utils.ngpu /
    # load_from_config turns it into a collection of GPU ids; confirm.
    ngpu: script_utils.ngpu = 1
    jobs_per_compute_device: int = 1 # number of jobs per gpus/cpus


def update_progress_bar(progress, total):
    """Mirror a shared counter onto a tqdm bar until it reaches ``total``.

    Args:
        progress: object exposing the current count as ``.value`` (this script
            passes a ``multiprocessing.Value``).
        total: count at which the bar is complete.
    """
    poll_interval = 0.5  # seconds between two refreshes; adjust as needed
    with tqdm(total=total) as bar:
        while progress.value < total:
            bar.n = progress.value
            bar.refresh()
            time.sleep(poll_interval)
        # The loop ends once the counter reached (or passed) total: show a full bar.
        bar.n = total
        bar.refresh()

def compute_pipeline(cfg_cmd, cfg_pipeline, directory, wavscp, progress):
    """Worker entry point: anonymize one share of the wav.scp entries.

    Args:
        cfg_cmd: launcher configuration (not used by this function).
        cfg_pipeline: pipeline configuration, forwarded as the settings object.
        directory: data directory being anonymized.
        wavscp: the subset of wav.scp entries assigned to this worker.
        progress: shared counter updated by the pipeline.
    """
    # Imported inside the function rather than at module level.
    # NOTE(review): presumably so the heavy torch import happens in the worker,
    # after the launcher has set CUDA_VISIBLE_DEVICES - confirm.
    from satools.bin.pipeline import process_data

    algorithm = cfg_pipeline.target_selection_algorithm
    process_data(directory, algorithm, wavscp, cfg_pipeline, progress)


if __name__ == "__main__":
    # The root logger defaults to WARNING, which would drop the info messages
    # below. basicConfig does nothing if handlers were already configured.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(description=description)
    # Both options are required, so argparse never falls back to the defaults.
    parser.add_argument("--config", default="configs/default", required=True)
    parser.add_argument("--directory", default="data/default", required=True)
    args = parser.parse_args()

    logging.info("Reading config")
    cfg_parse = configparser.ConfigParser()
    cfg_parse.read(args.config)
    cfg_parse = script_utils.vartoml(cfg_parse)

    cfg_cmd = Cmd().load_from_config(cfg_parse["cmd"])
    cfg_pipeline = Pipeline().load_from_config(cfg_parse["pipeline"])
    # The pipeline reads its device from its own settings object.
    cfg_pipeline.device = cfg_cmd.device

    wavscp = script_utils.read_wav_scp(os.path.join(args.directory, "wav.scp"))

    # One share of the utterances per job (jobs = compute devices * jobs per device).
    wavscp_for_jobs = list(script_utils.split_dict(wavscp, len(cfg_cmd.ngpu) * cfg_cmd.jobs_per_compute_device))
    # Shared integer counter the workers increment as they process utterances.
    progress = Value('i', 0)

    processes = []
    index = 0
    for gpu_id in cfg_cmd.ngpu:
        # Set before starting the children so that they inherit it.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        for _ in range(cfg_cmd.jobs_per_compute_device):
            p = Process(target=compute_pipeline, args=(cfg_cmd, cfg_pipeline, args.directory, wavscp_for_jobs[index], progress))
            index += 1
            processes.append(p)
            p.start()

    # Start a separate process (not a thread, despite the name) to update the progress bar
    progress_thread = Process(target=update_progress_bar, args=(progress, len(wavscp)))
    progress_thread.start()

    for p in processes:
        p.join()
        if p.exitcode != 0:
            # One failed job aborts the whole run: stop the remaining workers.
            print(f"Process {p.pid} exited with code {p.exitcode}. Terminating.")
            for proc in processes:
                if proc.is_alive():
                    proc.terminate()
            progress_thread.terminate()
            progress_thread.join()  # reap the terminated process
            sys.exit(1)

    progress_thread.terminate()
    progress_thread.join()  # reap the terminated process
    logging.info('Done')
178 changes: 178 additions & 0 deletions satools/satools/bin/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
#!/usr/bin/env python3.0
# -*- coding: utf-8 -*-

import os
import shutil
import multiprocessing
from pathlib import Path
import torchaudio
import random
import glob
import logging

import torch

import satools.script_utils as script_utils
from satools.infer_helper import load_model
from satools.utils.kaldi import load_wav_from_scp

def copy_data_dir(dataset_path, output_path):
    """Copy the regular files at the top level of ``dataset_path`` into ``output_path``.

    Copies utt2spk, wav.scp and so on, but not the directories inside (they may
    contain clear or anonymized *.wav). ``output_path`` is created if needed.
    As with the ``*`` glob pattern, names starting with a dot are not matched.

    Args:
        dataset_path: source directory (str or path-like).
        output_path: destination directory (str or path-like).
    """
    os.makedirs(output_path, exist_ok=True)
    # Escape the directory part so that characters such as '[', '*' or '?' in
    # the path are matched literally instead of being read as glob syntax.
    pattern = os.path.join(glob.escape(str(Path(dataset_path))), '*')
    for entry in glob.glob(pattern, recursive=False):
        if os.path.isfile(entry):
            shutil.copy(entry, output_path)

class Wav:
    """Minimal holder exposing a waveform as ``.wav``; used as the argument of the f0 extraction function."""

    def __init__(self, w):
        self.wav = w

class Dataset(torch.utils.data.Dataset):
    """Dataset yielding, per utterance, the loaded audio together with its f0.

    Args:
        id_wavs: mapping of utterance id -> wav.scp entry.
        get_f0_func: callable applied to ``Wav(audio)`` to obtain the f0.
    """

    def __init__(self, id_wavs, get_f0_func):
        # Keys and values are listed from the same mapping, so position i in
        # one list corresponds to position i in the other.
        self.all_keys = list(id_wavs.keys())
        self.all_wavs = list(id_wavs.values())
        self.get_f0_func = get_f0_func

    def __len__(self):
        return len(self.all_wavs)

    def __getitem__(self, index):
        """Load utterance ``index`` and return a dict with utid, audio, f0 and freq."""
        entry = str(self.all_wavs[index])
        audio, freq = load_wav_from_scp(entry)
        pitch = self.get_f0_func(Wav(audio))
        item = {"utid": self.all_keys[index], "audio": audio}
        item["f0"] = pitch
        item["freq"] = freq
        return item

def _pad_last_dim(tensors):
    """Stack tensors into one zero-padded 2-D tensor, one row per input.

    Returns the padded tensor (rows as long as the longest input's last
    dimension) and a tensor holding each input's original last-dimension size.
    """
    lengths = torch.tensor([t.shape[-1] for t in tensors])
    widest = torch.max(lengths).item()
    padded = torch.zeros([len(tensors), widest])
    for row, t in enumerate(tensors):
        padded[row, :t.shape[-1]] = t.squeeze()
    return padded, lengths


def collate_fn(item_list):
    """Collate dataset items into a zero-padded batch.

    Args:
        item_list: list of dicts with the keys 'audio', 'f0', 'utid' and 'freq'.

    Returns:
        Tuple (padded audio, padded f0, original audio lengths, utterance ids,
        sampling frequencies); the last two are plain lists in item order.
    """
    output_audio, lengths_tensor_audio = _pad_last_dim([item['audio'] for item in item_list])
    # The f0 lengths are not part of the returned batch.
    output_f0, _ = _pad_last_dim([item['f0'] for item in item_list])

    utids = [item['utid'] for item in item_list]
    freqs = [item['freq'] for item in item_list]
    return output_audio, output_f0, lengths_tensor_audio, utids, freqs

def process_data(dataset_path: str, target_selection_algorithm: str, wavscp: dict, settings: dict, progress):
    """Anonymize the utterances of ``wavscp`` into a new data directory.

    The top-level files of ``dataset_path`` are copied to
    ``dataset_path + settings.new_datadir_suffix``. Converted audio is saved as
    ``<output>/<settings.results_dir>/<utt-id>.wav`` (16-bit signed PCM) and
    ``<output>/wav.scp`` is rewritten to list those files.

    Args:
        dataset_path: source data directory; its ``utt2spk`` file is read.
        target_selection_algorithm: one of "constant", "bad_for_evaluation",
            "random_per_utt", "random_per_spk_uniq" or "random_per_spk".
        wavscp: mapping of utterance id -> wav.scp entry to process.
        settings: configuration object read by attribute (despite the ``dict``
            annotation): ``model``, ``device``, ``batch_size``, ``results_dir``,
            ``new_datadir_suffix``, ``data_loader_nj`` and, for the "constant"
            algorithm, ``target_constant_spkid``.
        progress: shared counter exposing ``get_lock()`` and ``.value``; it is
            increased by the number of utterances of every processed batch.

    Raises:
        ValueError: if ``target_selection_algorithm`` is not one of the values
            above (raised when the first batch is processed).
    """
    dataset_path = Path(str(dataset_path))
    output_path = Path(str(dataset_path) + settings.new_datadir_suffix)
    device = settings.device
    batch_size = settings.batch_size

    copy_data_dir(dataset_path, output_path)
    results_dir = output_path / settings.results_dir
    os.makedirs(results_dir, exist_ok=True)

    utt2spk = dataset_path / 'utt2spk'
    wav_scp_out = output_path / 'wav.scp'

    model = load_model(settings.model)
    model.to(device)
    model.eval()
    possible_targets = model.spk.copy()  # For spk and utt target_selection_algorithm random choice

    # utt2spk is parsed with the wav.scp reader.
    # NOTE(review): presumably both are "key value" files, giving {utt-id: spk-id} - confirm.
    source_utt2spk = script_utils.read_wav_scp(utt2spk)
    out_spk2target = {}  # For spk target_selection_algorithm

    @torch.no_grad()
    def process_wav(utid, freq, audio, f0, original_len):
        """Convert one batch and start a process that saves it; return that process."""
        freq = freq[0]  # assume all freq are equal within a batch (and so within the dataset)
        audio = audio.to(device)

        # Anonymize function
        model.set_f0(f0.to(device))  # f0 was extracted on CPU by the DataLoader workers (num_workers)
        # Batch select target spks from the available model list depending on target_selection_algorithm
        target_spks = []
        if target_selection_algorithm == "constant":
            # The best/most secure way to evaluate privacy when applied to the whole dataset (train included)
            target_spks = [settings.target_constant_spkid] * audio.shape[0]
        elif target_selection_algorithm == "bad_for_evaluation":
            # This target selection algorithm is bad for evaluation as it does
            # not generate suitable training data for the ASV eval training
            # procedure. Use it with caution.
            for ut in utid:
                source_spk = source_utt2spk[ut]
                if source_spk not in out_spk2target:
                    out_spk2target[source_spk] = random.sample(possible_targets, 2)
                target_spks.append(random.choice(out_spk2target[source_spk]))
        elif target_selection_algorithm == "random_per_utt":
            for ut in utid:
                target_spks.append(random.choice(possible_targets))
        elif target_selection_algorithm == "random_per_spk_uniq":
            for ut in utid:
                source_spk = source_utt2spk[ut]
                if source_spk not in out_spk2target:
                    out_spk2target[source_spk] = random.choice(possible_targets)
                    # Each target is used once: the number of source speakers must not
                    # exceed len(possible_targets) (==247), otherwise targets must be
                    # allowed to overlap.
                    possible_targets.remove(out_spk2target[source_spk])
                target_spks.append(out_spk2target[source_spk])
        elif target_selection_algorithm == "random_per_spk":
            for ut in utid:
                source_spk = source_utt2spk[ut]
                if source_spk not in out_spk2target:
                    out_spk2target[source_spk] = random.choice(possible_targets)
                target_spks.append(out_spk2target[source_spk])
        else:
            raise ValueError(f"{target_selection_algorithm} not implemented")
        # Batch conversion
        wav_conv = model.convert(audio, target=target_spks)
        wav_conv = wav_conv.cpu()

        def parallel_write():
            for i in range(wav_conv.shape[0]):
                wav = wav_conv[i]
                if len(wav.shape) == 1:
                    wav = wav.unsqueeze(0)  # torchaudio.save is given a 2-D tensor
                wav = wav[:, :original_len[i]]  # drop the batch padding
                u = utid[i]
                output_file = results_dir / f'{u}.wav'
                torchaudio.save(str(output_file), wav, freq, encoding='PCM_S', bits_per_sample=16)

        # Files are written in a separate process while the next batch is converted.
        # NOTE(review): a nested function as Process target presumably relies on
        # the "fork" start method - confirm for other platforms.
        p = multiprocessing.Process(target=parallel_write, args=())
        p.start()
        return p

    nj = min(settings.data_loader_nj, 18)
    writers = []  # writer processes that may still be running

    # NOTE(review): 'wt' truncates wav.scp on every call; the launcher starts
    # several jobs on the same directory, which then write to the same file -
    # confirm how the per-job lists are meant to be merged.
    with open(wav_scp_out, 'wt', encoding='utf-8') as writer:
        data_loader = torch.utils.data.DataLoader(Dataset(wavscp, model.get_f0), batch_size=batch_size, num_workers=nj, collate_fn=collate_fn)
        for audio, f0, original_len, utid, freq in data_loader:
            writers.append(process_wav(utid, freq, audio, f0, original_len))
            # Forget finished writers so the batch they captured can be freed.
            writers = [w for w in writers if w.is_alive()]
            for u in utid:
                output_file = results_dir / f'{u}.wav'
                writer.write(f"{u} {output_file}\n")
            with progress.get_lock():
                # Count the real batch length: the last batch may be shorter than batch_size.
                progress.value += len(utid)
            if device.startswith("cuda"):
                torch.cuda.empty_cache()
    # wait for every writer still running, so all anonymized audios are on disk on return
    for w in writers:
        w.join()

0 comments on commit 16e6c6f

Please sign in to comment.