-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathrun_anonymization.py
71 lines (59 loc) · 3.04 KB
/
run_anonymization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# We need to set CUDA_VISIBLE_DEVICES before we import Pytorch, so we will read all arguments directly on startup
from argparse import ArgumentParser
import os
parser = ArgumentParser()
parser.add_argument('--config', default='anon_config.yaml')
parser.add_argument('--gpu_ids', default='0')
parser.add_argument('--force_compute', default=False, type=bool)
args = parser.parse_args()
if 'CUDA_VISIBLE_DEVICES' not in os.environ: # do not overwrite previously set devices
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids
else: # CUDA_VISIBLE_DEVICES more important than the gpu_ids arg
args.gpu_ids = ",".join([ str(i) for i, _ in enumerate(os.environ['CUDA_VISIBLE_DEVICES'].split(","))])
from pathlib import Path
import subprocess
import sys
import torch
from utils import parse_yaml, get_datasets, check_dependencies, setup_logger
logger = setup_logger(__name__)
def shell_run(cmd):
if subprocess.run(['bash', cmd]).returncode != 0:
logger.error(f'Failed to bash execute: {cmd}')
sys.exit(1)
if __name__ == '__main__':
config = parse_yaml(Path(args.config))
datasets = get_datasets(config)
gpus = args.gpu_ids.split(',')
devices = []
if torch.cuda.is_available():
for gpu in gpus:
devices.append(torch.device(f'cuda:{gpu}'))
else:
devices.append(torch.device('cpu'))
if config['pipeline'] == "mcadams":
from anonymization.pipelines.mcadams import McAdamsPipeline as pipeline
elif config['pipeline'] == "sttts":
shell_run('anonymization/pipelines/sttts/install.sh')
check_dependencies('anonymization/pipelines/sttts/requirements.txt')
if "download_precomputed_intermediate_repr" in config and config["download_precomputed_intermediate_repr"]:
shell_run('anonymization/pipelines/sttts/download_precomputed_intermediate_repr.sh')
from anonymization.pipelines.sttts import STTTSPipeline as pipeline
elif config['pipeline'] == "nac":
shell_run('anonymization/pipelines/nac/install.sh')
sys.path.append(os.path.join(os.path.dirname(__file__), 'anonymization/modules/nac/coqui_tts/'))
if devices[0] == torch.device('cpu'):
from anonymization.pipelines.nac.nac_pipeline import NACPipeline as pipeline
else:
from anonymization.pipelines.nac.nac_pipeline_accelerate import NACPipeline as pipeline
elif config['pipeline'] == "asrbn":
shell_run('anonymization/pipelines/asrbn/install.sh')
check_dependencies('anonymization/pipelines/asrbn/requirements.txt')
from anonymization.pipelines.asrbn import ASRBNPipeline as pipeline
elif config['pipeline'] == "template":
from anonymization.pipelines.template import TemplatePipeline as pipeline
else:
raise ValueError(f"Pipeline {config['pipeline']} not defined/imported")
logger.info(f'Running pipeline: {config["pipeline"]}')
p = pipeline(config=config, force_compute=args.force_compute, devices=devices)
p.run_anonymization_pipeline(datasets)