Skip to content

Commit

Permalink
Implement simple training test.
Browse files Browse the repository at this point in the history
Only runs the training, no actual check except no crashes.

For #198
  • Loading branch information
rom1504 committed Nov 7, 2022
1 parent 7ab8dfb commit 1f03beb
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 5 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ install: ## [Local development] Upgrade pip, install requirements, install packa
python -m pip install -U pip
python -m pip install -e .

install-training:
python -m pip install -r requirements-training.txt

install-dev: ## [Local development] Install test requirements
python -m pip install -r requirements-test.txt

Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ Install pip PyTorch as per https://pytorch.org/get-started/locally/

Test can be run with `make install-dev` then `make test`

`python -m pytest -x -s -v tests -k "training"` to run a specific test


#### Other dependencies

Install open_clip pacakge and remaining dependencies:
Expand Down
7 changes: 4 additions & 3 deletions src/training/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import sys
import random
from datetime import datetime

Expand Down Expand Up @@ -38,8 +39,8 @@ def random_seed(seed=42, rank=0):
random.seed(seed + rank)


def main():
args = parse_args()
def main(args):
args = parse_args(args)

if torch.cuda.is_available():
# This enables tf32 on Ampere GPUs which is only 8% slower than
Expand Down Expand Up @@ -314,4 +315,4 @@ def copy_codebase(args):


if __name__ == "__main__":
main()
main(sys.argv[1:])
4 changes: 2 additions & 2 deletions src/training/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def get_default_params(model_name):
return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}


def parse_args():
def parse_args(args):
parser = argparse.ArgumentParser()
parser.add_argument(
"--train-data",
Expand Down Expand Up @@ -293,7 +293,7 @@ def parse_args():
parser.add_argument(
"--grad-clip-norm", type=float, default=None, help="Gradient clip."
)
args = parser.parse_args()
args = parser.parse_args(args)

# If some params are not passed, we use the default values based on model name.
default_params = get_default_params(args.model)
Expand Down
File renamed without changes.
22 changes: 22 additions & 0 deletions tests/test_training_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

import torch
from PIL import Image
from training.main import main
import pytest
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""

def test_training():
main([
'--save-frequency', '1',
'--zeroshot-frequency', '1',
'--dataset-type', "synthetic",
'--train-num-samples', '16',
'--warmup', '1',
'--batch-size', '4',
'--lr', '1e-3',
'--wd', '0.1',
'--epochs', '1',
'--workers', '2',
'--model', 'RN50'
])

0 comments on commit 1f03beb

Please sign in to comment.