Inference testing using random data #199

Merged 2 commits on Nov 12, 2022
4 changes: 4 additions & 0 deletions README.md
@@ -117,10 +117,14 @@ make install

Install PyTorch with pip as per https://pytorch.org/get-started/locally/


#### Tests

Tests can be run with `make install-dev` followed by `make test`

Use `python -m pytest -x -s -v tests -k "training"` to run a specific subset of tests

When introducing new models, `python3 tests/util_test.py` can generate the new expected-output data.
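For example, `python3 tests/util_test.py --model RN50 --overwrite` regenerates the expected outputs for a single model, while `--all` covers every model config.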

#### Other dependencies

Expand Down
1 change: 1 addition & 0 deletions requirements-test.txt
@@ -1,3 +1,4 @@
pytest-xdist==2.5.0
pytest==7.0.1
transformers
timm==0.6.11
Binary file added tests/data/input/random_image_224_224.pt
Binary file added tests/data/input/random_image_240_240.pt
Binary file added tests/data/input/random_image_256_256.pt
Binary file added tests/data/input/random_image_280_280.pt
Binary file added tests/data/input/random_image_288_288.pt
Binary file added tests/data/input/random_image_320_320.pt
Binary file added tests/data/input/random_image_336_336.pt
Binary file added tests/data/input/random_image_384_384.pt
Binary file added tests/data/input/random_image_448_448.pt
Binary file added tests/data/input/random_text.pt
Binary file added tests/data/output/RN101_None_fp32_random_image.pt
Binary file added tests/data/output/RN101_None_fp32_random_text.pt
Binary file added tests/data/output/RN50_None_fp32_random_image.pt
Binary file added tests/data/output/RN50_None_fp32_random_text.pt
Binary file added tests/data/output/RN50x4_None_fp32_random_text.pt
Binary files added for the remaining expected outputs under tests/data/output/ (one random-image and one random-text file per model configuration; names not rendered in this view)
66 changes: 66 additions & 0 deletions tests/test_inference.py
@@ -0,0 +1,66 @@

import os
import random
import pytest
import numpy
import torch
from PIL import Image
import open_clip
import util_test

os.environ['CUDA_VISIBLE_DEVICES'] = ''

# test all models, with some exceptions
models_to_test = set(open_clip.list_models()).difference({
    # not available with timm yet
    # see https://github.com/mlfoundations/open_clip/issues/219
    'timm-convnext_xlarge',
    'timm-vit_medium_patch16_gap_256',
    # exceeds GH runner memory limit
    'ViT-G-14',
    'ViT-e-14',
})

@pytest.mark.parametrize('model_name', models_to_test)
def test_inference_with_data(
Collaborator:

How slow is this?
Let's try to keep the tests running in less than 5 min.
Either remove redundant tests or use the (matrix) parallel feature of GH Actions
(and possibly also the parallel feature of pytest).
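
(For reference: the pytest parallel feature mentioned here is pytest-xdist, which is pinned in requirements-test.txt above and distributes tests across worker processes, e.g. `python -m pytest -n auto tests`.)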

Contributor Author (@lopho), Nov 6, 2022:

======================================================= test session starts =======================================================
platform linux -- Python 3.9.2, pytest-7.0.1, pluggy-1.0.0 -- /REDACTED/open_clip/.venv/bin/python
cachedir: .pytest_cache
rootdir: /REDACTED/open_clip
plugins: xdist-2.5.0, forked-1.4.0
collected 62 items                                                                                                                

tests/test_inference.py::test_inference_with_data[RN50-openai-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[RN50-openai-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[RN50-yfcc15m-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[RN50-cc12m-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[RN50-quickgelu-openai-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[RN50-quickgelu-openai-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[RN50-quickgelu-yfcc15m-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[RN50-quickgelu-cc12m-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[RN101-openai-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[RN101-openai-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[RN101-yfcc15m-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[RN101-quickgelu-openai-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[RN101-quickgelu-openai-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[RN101-quickgelu-yfcc15m-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[RN50x4-openai-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[RN50x4-openai-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[RN50x16-openai-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[RN50x16-openai-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[RN50x64-openai-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[RN50x64-openai-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-32-openai-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-32-openai-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-32-laion400m_e31-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-32-laion400m_e31-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-32-laion400m_e32-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-32-laion400m_e32-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-32-laion2b_e16-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-32-laion2b_e16-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-32-laion2b_s34b_b79k-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-32-laion2b_s34b_b79k-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-32-quickgelu-openai-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-32-quickgelu-openai-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-32-quickgelu-laion400m_e31-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-32-quickgelu-laion400m_e31-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-32-quickgelu-laion400m_e32-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-32-quickgelu-laion400m_e32-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-16-openai-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-16-openai-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-16-laion400m_e31-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-16-laion400m_e31-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-16-laion400m_e32-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-16-laion400m_e32-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-16-plus-240-laion400m_e31-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-16-plus-240-laion400m_e31-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-16-plus-240-laion400m_e32-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-B-16-plus-240-laion400m_e32-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-L-14-openai-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-L-14-openai-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-L-14-laion400m_e31-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-L-14-laion400m_e31-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-L-14-laion400m_e32-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-L-14-laion400m_e32-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-L-14-laion2b_s32b_b82k-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-L-14-laion2b_s32b_b82k-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-L-14-336-openai-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-L-14-336-openai-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-H-14-laion2b_s32b_b79k-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-H-14-laion2b_s32b_b79k-fp32-True] PASSED
tests/test_inference.py::test_inference_with_data[ViT-g-14-laion2b_s12b_b42k-fp32-False] PASSED
tests/test_inference.py::test_inference_with_data[ViT-g-14-laion2b_s12b_b42k-fp32-True] PASSED
tests/test_simple.py::test_inference[False] PASSED
tests/test_simple.py::test_inference[True] PASSED

======================================================== warnings summary =========================================================
tests/test_inference.py::test_inference_with_data[RN50-openai-fp32-True]
  /REDACTED/open_clip/tests/test_inference.py:38: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
  To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  ../aten/src/ATen/native/BinaryOps.cpp:600.)
    image_features = model.encode_image(prepped)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================ 62 passed, 1 warning in 374.16s (0:06:14) ============================================

0:06:14 on i7-4790K

Contributor Author:

batch size 1, single sample

Contributor Author:

Testing all configs in `list_models()`, the full unit test suite (training, inference, hf, ...) now takes about 8:30-12:00 minutes, or 10-14 minutes with setup overhead.

Collaborator:

Seems to be 3 min now; what changed?

    model_name,
    pretrained = None,
    precision = 'fp32',
    jit = False,
    force_quick_gelu = False,
    # experimentally determined between author machine and GH runner
    tolerance = torch.finfo(torch.float32).resolution * 4
):
    util_test.seed_all()
    model, _, preprocess_val = open_clip.create_model_and_transforms(
        model_name,
        pretrained = pretrained,
        precision = precision,
        jit = jit,
        force_quick_gelu = force_quick_gelu
    )
    model_id = f'{model_name}_{pretrained}_{precision}'
    input_dir, output_dir = util_test.get_data_dirs()
    # text
    input_text_path = os.path.join(input_dir, 'random_text.pt')
    gt_text_path = os.path.join(output_dir, f'{model_id}_random_text.pt')
    assert os.path.isfile(input_text_path), f"missing test data, expected at {input_text_path}"
    assert os.path.isfile(gt_text_path), f"missing test data, expected at {gt_text_path}"
    input_text = torch.load(input_text_path)
    gt_text = torch.load(gt_text_path)
    y_text = util_test.inference_text(model, model_name, input_text)
    assert torch.allclose(y_text, gt_text, atol=tolerance), f"text output differs @ {input_text_path}"
    # image
    image_size = model.visual.image_size
    if not isinstance(image_size, tuple):
        image_size = (image_size, image_size)
    input_image_path = os.path.join(input_dir, f'random_image_{image_size[0]}_{image_size[1]}.pt')
    gt_image_path = os.path.join(output_dir, f'{model_id}_random_image.pt')
    assert os.path.isfile(input_image_path), f"missing test data, expected at {input_image_path}"
    assert os.path.isfile(gt_image_path), f"missing test data, expected at {gt_image_path}"
    input_image = torch.load(input_image_path)
    gt_image = torch.load(gt_image_path)
    y_image = util_test.inference_image(model, preprocess_val, input_image)
    assert torch.allclose(y_image, gt_image, atol=tolerance), f"image output differs @ {input_image_path}"
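
A note on the tolerance above: `torch.finfo(torch.float32).resolution` is 1e-06, so the absolute tolerance works out to 4e-06 per output element (`torch.allclose` also applies its default relative tolerance on top). A minimal sketch of the comparison, with made-up tensor values:

import torch

tolerance = torch.finfo(torch.float32).resolution * 4  # 1e-06 * 4 = 4e-06
y_text = torch.tensor([0.1234560])   # hypothetical model output
gt_text = torch.tensor([0.1234592])  # hypothetical stored expected output, ~3.2e-06 away
assert torch.allclose(y_text, gt_text, atol=tolerance)  # within tolerance, test passes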


221 changes: 221 additions & 0 deletions tests/util_test.py
@@ -0,0 +1,221 @@

import os
import random
import numpy as np
from PIL import Image
import torch
import open_clip
import argparse

os.environ['CUDA_VISIBLE_DEVICES'] = ''

def seed_all(seed = 0):
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

def inference_text(model, model_name, batches):
    y = []
    tokenizer = open_clip.get_tokenizer(model_name)
    with torch.no_grad():
        for x in batches:
            x = tokenizer(x)
            y.append(model.encode_text(x))
    return torch.stack(y)

def inference_image(model, preprocess_val, batches):
    y = []
    with torch.no_grad():
        for x in batches:
            x = torch.stack([preprocess_val(img) for img in x])
            y.append(model.encode_image(x))
    return torch.stack(y)

def random_image_batch(batch_size, size):
    h, w = size
    data = np.random.randint(255, size = (batch_size, h, w, 3), dtype = np.uint8)
    return [ Image.fromarray(d) for d in data ]

def random_text_batch(batch_size, min_length = 75, max_length = 75):
    t = open_clip.tokenizer.SimpleTokenizer()
    # every token decoded as string, excluding SOT and EOT, with EOW replaced by a space
    token_words = [
        x[1].replace('</w>', ' ')
        for x in t.decoder.items()
        if x[0] not in t.all_special_ids
    ]
    # strings of randomly chosen token words
    return [
        ''.join(random.choices(
            token_words,
            k = random.randint(min_length, max_length)
        ))
        for _ in range(batch_size)
    ]

def create_random_text_data(
    path,
    min_length = 75,
    max_length = 75,
    batches = 1,
    batch_size = 1
):
    text_batches = [
        random_text_batch(batch_size, min_length, max_length)
        for _ in range(batches)
    ]
    print(f"{path}")
    torch.save(text_batches, path)

def create_random_image_data(path, size, batches = 1, batch_size = 1):
    image_batches = [
        random_image_batch(batch_size, size)
        for _ in range(batches)
    ]
    print(f"{path}")
    torch.save(image_batches, path)

def get_data_dirs(make_dir = True):
    data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data')
    input_dir = os.path.join(data_dir, 'input')
    output_dir = os.path.join(data_dir, 'output')
    if make_dir:
        os.makedirs(input_dir, exist_ok = True)
        os.makedirs(output_dir, exist_ok = True)
    assert os.path.isdir(input_dir), f"data directory missing, expected at {input_dir}"
    assert os.path.isdir(output_dir), f"data directory missing, expected at {output_dir}"
    return input_dir, output_dir

def create_test_data_for_model(
    model_name,
    pretrained = None,
    precision = 'fp32',
    jit = False,
    force_quick_gelu = False,
    create_missing_input_data = True,
    batches = 1,
    batch_size = 1,
    overwrite = False
):
    model_id = f'{model_name}_{pretrained}_{precision}'
    input_dir, output_dir = get_data_dirs()
    output_file_text = os.path.join(output_dir, f'{model_id}_random_text.pt')
    output_file_image = os.path.join(output_dir, f'{model_id}_random_image.pt')
    text_exists = os.path.exists(output_file_text)
    image_exists = os.path.exists(output_file_image)
    if not overwrite and text_exists and image_exists:
        return
    seed_all()
    model, _, preprocess_val = open_clip.create_model_and_transforms(
        model_name,
        pretrained = pretrained,
        precision = precision,
        jit = jit,
        force_quick_gelu = force_quick_gelu
    )
    # text
    if overwrite or not text_exists:
        input_file_text = os.path.join(input_dir, 'random_text.pt')
        if create_missing_input_data and not os.path.exists(input_file_text):
            create_random_text_data(
                input_file_text,
                batches = batches,
                batch_size = batch_size
            )
        assert os.path.isfile(input_file_text), f"missing input data, expected at {input_file_text}"
        input_data_text = torch.load(input_file_text)
        output_data_text = inference_text(model, model_name, input_data_text)
        print(f"{output_file_text}")
        torch.save(output_data_text, output_file_text)
    # image
    if overwrite or not image_exists:
        size = model.visual.image_size
        if not isinstance(size, tuple):
            size = (size, size)
        input_file_image = os.path.join(input_dir, f'random_image_{size[0]}_{size[1]}.pt')
        if create_missing_input_data and not os.path.exists(input_file_image):
            create_random_image_data(
                input_file_image,
                size,
                batches = batches,
                batch_size = batch_size
            )
        assert os.path.isfile(input_file_image), f"missing input data, expected at {input_file_image}"
        input_data_image = torch.load(input_file_image)
        output_data_image = inference_image(model, preprocess_val, input_data_image)
        print(f"{output_file_image}")
        torch.save(output_data_image, output_file_image)

def create_test_data(
    models,
    batches = 1,
    batch_size = 1,
    overwrite = False
):
    models = set(models).difference({
        # not available with timm
        # see https://github.com/mlfoundations/open_clip/issues/219
        'timm-convnext_xlarge',
        'timm-vit_medium_patch16_gap_256'
    })
    for model_name in models:
        create_test_data_for_model(
            model_name,
            batches = batches,
            batch_size = batch_size,
            overwrite = overwrite
        )


def main(args):
    parser = argparse.ArgumentParser(description="Populate test data directory")
    parser.add_argument(
        "--all",
        default=False,
        action='store_true',
        help="create test data for all models"
    )
    parser.add_argument(
        "--model",
        default=None,
        type=str,
        help="model to create test data for (default: None)"
    )
    parser.add_argument(
        "--overwrite",
        default=False,
        action='store_true',
        help="overwrite existing data"
    )
    parser.add_argument(
        "--num_batches",
        default=1,
        type=int,
        help="number of data batches to create (default: 1)"
    )
    parser.add_argument(
        "--batch_size",
        default=1,
        type=int,
        help="test data batch size (default: 1)"
    )
    args = parser.parse_args(args)
    if not args.all and args.model is None:
        parser.print_help()
        parser.exit()
    models = open_clip.list_models() if args.all else [args.model]
    print(f"generating test data for:\n{models}")
    create_test_data(
        models,
        batches = args.num_batches,
        batch_size = args.batch_size,
        overwrite = args.overwrite
    )


if __name__ == '__main__':
    import sys
    main(sys.argv[1:])
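
As a quick usage sketch, the data generation can also be driven from Python directly rather than through the CLI above (the model name here is an arbitrary example; run from the tests/ directory so `util_test` is importable):

import util_test

# roughly equivalent to `python3 tests/util_test.py --model RN50 --overwrite`
util_test.create_test_data_for_model(
    'RN50',
    batches = 1,
    batch_size = 1,
    overwrite = True
)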