From 1f03bebe20331447410ee21d525b22785421519e Mon Sep 17 00:00:00 2001 From: Romain Beaumont Date: Mon, 7 Nov 2022 19:52:55 +0100 Subject: [PATCH] Implement simple training test. Only runs the training, no actual check except no crashes. For #198 --- Makefile | 3 +++ README.md | 3 +++ src/training/main.py | 7 +++--- src/training/params.py | 4 ++-- ...est_simple.py => test_inference_simple.py} | 0 tests/test_training_simple.py | 22 +++++++++++++++++++ 6 files changed, 34 insertions(+), 5 deletions(-) rename tests/{test_simple.py => test_inference_simple.py} (100%) create mode 100644 tests/test_training_simple.py diff --git a/Makefile b/Makefile index d4b5fc320..248e5e3b4 100644 --- a/Makefile +++ b/Makefile @@ -2,6 +2,9 @@ install: ## [Local development] Upgrade pip, install requirements, install packa python -m pip install -U pip python -m pip install -e . +install-training: + python -m pip install -r requirements-training.txt + install-dev: ## [Local development] Install test requirements python -m pip install -r requirements-test.txt diff --git a/README.md b/README.md index 1e8c9269e..1bbccf662 100644 --- a/README.md +++ b/README.md @@ -118,6 +118,9 @@ Install pip PyTorch as per https://pytorch.org/get-started/locally/ Test can be run with `make install-dev` then `make test` +`python -m pytest -x -s -v tests -k "training"` to run a specific test + + #### Other dependencies Install open_clip pacakge and remaining dependencies: diff --git a/src/training/main.py b/src/training/main.py index dd2ef9b06..5b7e04069 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -1,5 +1,6 @@ import logging import os +import sys import random from datetime import datetime @@ -38,8 +39,8 @@ def random_seed(seed=42, rank=0): random.seed(seed + rank) -def main(): - args = parse_args() +def main(args): + args = parse_args(args) if torch.cuda.is_available(): # This enables tf32 on Ampere GPUs which is only 8% slower than @@ -314,4 +315,4 @@ def copy_codebase(args): if __name__ == "__main__": - main() + main(sys.argv[1:]) diff --git a/src/training/params.py b/src/training/params.py index 29d9faf98..4c53cce67 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -10,7 +10,7 @@ def get_default_params(model_name): return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8} -def parse_args(): +def parse_args(args): parser = argparse.ArgumentParser() parser.add_argument( "--train-data", @@ -293,7 +293,7 @@ def parse_args(): parser.add_argument( "--grad-clip-norm", type=float, default=None, help="Gradient clip." ) - args = parser.parse_args() + args = parser.parse_args(args) # If some params are not passed, we use the default values based on model name. default_params = get_default_params(args.model) diff --git a/tests/test_simple.py b/tests/test_inference_simple.py similarity index 100% rename from tests/test_simple.py rename to tests/test_inference_simple.py diff --git a/tests/test_training_simple.py b/tests/test_training_simple.py new file mode 100644 index 000000000..9666a7e75 --- /dev/null +++ b/tests/test_training_simple.py @@ -0,0 +1,22 @@ + +import torch +from PIL import Image +from training.main import main +import pytest +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "" + +def test_training(): + main([ + '--save-frequency', '1', + '--zeroshot-frequency', '1', + '--dataset-type', "synthetic", + '--train-num-samples', '16', + '--warmup', '1', + '--batch-size', '4', + '--lr', '1e-3', + '--wd', '0.1', + '--epochs', '1', + '--workers', '2', + '--model', 'RN50' + ])