Skip to content

Commit

Permalink
Implement simple training test. (#203)
Browse files Browse the repository at this point in the history
Only runs the training, no actual check except no crashes.

For #198
  • Loading branch information
rom1504 authored Nov 7, 2022
1 parent 7ab8dfb commit 933e1a9
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 5 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
source .env/bin/activate
make install
make install-dev
make install-training
- name: Unit tests
run: |
source .env/bin/activate
Expand Down
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ install: ## [Local development] Upgrade pip, install requirements, install packa
python -m pip install -U pip
python -m pip install -e .

install-training:
python -m pip install -r requirements-training.txt

install-dev: ## [Local development] Install test requirements
python -m pip install -r requirements-test.txt

Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ Install pip PyTorch as per https://pytorch.org/get-started/locally/

Test can be run with `make install-dev` then `make test`

`python -m pytest -x -s -v tests -k "training"` to run a specific test


#### Other dependencies

Install open_clip pacakge and remaining dependencies:
Expand Down
7 changes: 4 additions & 3 deletions src/training/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import sys
import random
from datetime import datetime

Expand Down Expand Up @@ -38,8 +39,8 @@ def random_seed(seed=42, rank=0):
random.seed(seed + rank)


def main():
args = parse_args()
def main(args):
args = parse_args(args)

if torch.cuda.is_available():
# This enables tf32 on Ampere GPUs which is only 8% slower than
Expand Down Expand Up @@ -314,4 +315,4 @@ def copy_codebase(args):


if __name__ == "__main__":
main()
main(sys.argv[1:])
4 changes: 2 additions & 2 deletions src/training/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def get_default_params(model_name):
return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}


def parse_args():
def parse_args(args):
parser = argparse.ArgumentParser()
parser.add_argument(
"--train-data",
Expand Down Expand Up @@ -293,7 +293,7 @@ def parse_args():
parser.add_argument(
"--grad-clip-norm", type=float, default=None, help="Gradient clip."
)
args = parser.parse_args()
args = parser.parse_args(args)

# If some params are not passed, we use the default values based on model name.
default_params = get_default_params(args.model)
Expand Down
File renamed without changes.
22 changes: 22 additions & 0 deletions tests/test_training_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

import torch
from PIL import Image
from training.main import main
import pytest
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""

def test_training():
main([
'--save-frequency', '1',
'--zeroshot-frequency', '1',
'--dataset-type', "synthetic",
'--train-num-samples', '16',
'--warmup', '1',
'--batch-size', '4',
'--lr', '1e-3',
'--wd', '0.1',
'--epochs', '1',
'--workers', '2',
'--model', 'RN50'
])

0 comments on commit 933e1a9

Please sign in to comment.