Skip to content

Commit

Permalink
Added dependency check at beginning of each run script
Browse files Browse the repository at this point in the history
  • Loading branch information
Sarina Meyer committed Feb 15, 2024
1 parent 90006b7 commit 846abdc
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 3 deletions.
3 changes: 2 additions & 1 deletion run_anonymization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import torch

from anonymization.pipelines.sttts_pipeline import STTTSPipeline
from utils import parse_yaml, get_datasets
from utils import parse_yaml, get_datasets, check_dependencies

PIPELINES = {
'sttts': STTTSPipeline
}

if __name__ == '__main__':
check_dependencies('requirements.txt')
parser = ArgumentParser()
parser.add_argument('--config', default='anon_config.yaml')
parser.add_argument('--gpu_ids', default='0')
Expand Down
3 changes: 2 additions & 1 deletion run_anonymization_dsp.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from pathlib import Path
from argparse import ArgumentParser
from anonymization.pipelines.dsp_pipeline import DSPPipeline
from utils import parse_yaml, get_datasets
from utils import parse_yaml, get_datasets, check_dependencies

PIPELINES = {
'dsp': DSPPipeline
}

if __name__ == '__main__':
check_dependencies('requirements.txt')
parser = ArgumentParser()
parser.add_argument('--config', default='anon_config.yaml')
args = parser.parse_args()
Expand Down
3 changes: 2 additions & 1 deletion run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from evaluation import evaluate_asv, train_asv_eval, evaluate_asr, train_asr_eval, evaluate_gvd
from utils import (parse_yaml, scan_checkpoint, combine_asr_data, get_datasets,
prepare_evaluation_data, get_anon_wav_scps, save_yaml)
prepare_evaluation_data, get_anon_wav_scps, save_yaml, check_dependencies)

def get_evaluation_steps(params):
eval_steps = {}
Expand Down Expand Up @@ -125,6 +125,7 @@ def save_result_summary(out_dir, results_dict, config):


if __name__ == '__main__':
check_dependencies('requirements.txt')
multiprocessing.set_start_method("fork",force=True)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s- %(levelname)s - %(message)s')

Expand Down
1 change: 1 addition & 0 deletions utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
find_asv_model_checkpoint, scan_checkpoint)
from .prepare_results_in_kaldi_format import (prepare_evaluation_data, combine_asr_data,
split_vctk_into_common_and_diverse, get_anon_wav_scps)
from .dependencies import check_dependencies
34 changes: 34 additions & 0 deletions utils/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from importlib.metadata import version, PackageNotFoundError
from packaging.requirements import Requirement
from packaging.version import parse

def check_dependencies(requirements_file):
    """Verify that every requirement listed in *requirements_file* is installed.

    Each non-empty, non-comment line is parsed as a PEP 508 requirement
    (e.g. ``torch>=1.13``). All problems are collected before reporting so
    the user sees the complete list in one error instead of one at a time.

    Args:
        requirements_file: Path to a pip-style requirements file.

    Raises:
        ModuleNotFoundError: If any requirement is not installed, or is
            installed with a version that does not satisfy its specifier.
    """
    missing_dependencies = []
    nonmatching_versions = []

    with open(requirements_file) as f:
        for line in f:
            line = line.strip()
            # Skip blank lines and comments; Requirement() would raise
            # InvalidRequirement on a '# ...' line.
            if not line or line.startswith('#'):
                continue
            requirement = Requirement(line)

            try:
                installed_version = version(requirement.name)
            except PackageNotFoundError:
                missing_dependencies.append(line)
                # Bug fix: without this `continue`, the specifier check below
                # ran with `installed_version` unbound (NameError on the first
                # missing package) or stale from a previous loop iteration.
                continue

            if parse(installed_version) not in requirement.specifier:
                nonmatching_versions.append((requirement, installed_version))

    error_msg = ''
    if missing_dependencies:
        error_msg += f'Missing dependencies: {" ".join(missing_dependencies)}.\n'
    if nonmatching_versions:
        error_msg += 'The following packages are installed with a version that does not match the requirement:\n'
        for req, installed_version in nonmatching_versions:
            error_msg += f'Package: {req.name}, installed: {installed_version}, required: {str(req.specifier)}\n'

    if len(error_msg) > 0:
        raise ModuleNotFoundError(f'{error_msg}--Make sure to install {requirements_file} to run this code!--')

0 comments on commit 846abdc

Please sign in to comment.