Skip to content

Commit 41e67d5

Browse files
authored
Simplify testing and dummy data generation (#147)
This uses `tfds.testing.mock_data` to simplify our training integration tests.
1 parent 38f4ba7 commit 41e67d5

File tree

6 files changed

+15
-147
lines changed

6 files changed

+15
-147
lines changed

larq_zoo/training/data.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,13 @@ class Default(ImageClassification):
4949
input_shape = Field((IMAGE_SIZE, IMAGE_SIZE, 3))
5050

5151
def input(self, data, training):
52+
image = data["image"]
53+
# tfds.testing.mock_data currently doesn't correctly handle custom decoders.
54+
# See https://github.com/tensorflow/datasets/pull/1861
55+
if image.dtype.is_integer:
56+
image = tf.image.encode_jpeg(image)
5257
return preprocess_image_bytes(
53-
data["image"], is_training=training, image_size=IMAGE_SIZE
58+
image, is_training=training, image_size=IMAGE_SIZE
5459
)
5560

5661

tests/dummy_datasets.py

-57
This file was deleted.

tests/fixtures/dummy_datasets/dummy_oxford_flowers102/2.0.0/dataset_info.json

-85
This file was deleted.

tests/fixtures/dummy_datasets/dummy_oxford_flowers102/2.0.0/image.image.json

-1
This file was deleted.

tests/train_test.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,29 @@
1+
from unittest import mock
2+
13
import pytest
4+
import tensorflow_datasets as tfds
25
from click.testing import CliRunner
6+
from zookeeper.tf.dataset import TFDSDataset
37

48
from larq_zoo.training import basic_experiments
5-
from tests import dummy_datasets # noqa
69

710

811
@pytest.mark.parametrize("command", list(basic_experiments.cli.commands.keys()))
9-
def test_cli(command):
12+
@tfds.testing.mock_data(num_examples=2, data_dir="gs://tfds-data/dataset_info")
13+
@mock.patch.object(TFDSDataset, "num_examples", return_value=2)
14+
def test_cli(_, command):
1015
result = CliRunner().invoke(
1116
basic_experiments.cli,
1217
[
1318
command,
14-
"dataset=DummyOxfordFlowers",
19+
"dataset=ImageNet",
1520
"epochs=1",
1621
"batch_size=2",
1722
"validation_frequency=5",
1823
"--no-use_tensorboard",
1924
"--no-use_model_checkpointing",
2025
"--no-save_weights",
2126
],
27+
catch_exceptions=False,
2228
)
2329
assert result.exit_code == 0

0 commit comments

Comments
 (0)