def open_mnist(images_gz: str, *, max_images: Optional[int]):
    """Open the MNIST training set stored as a pair of gzipped IDX files.

    The label archive is located by substituting ``-labels-idx1-ubyte.gz``
    for ``-images-idx3-ubyte.gz`` in ``images_gz``. Every 28x28 digit is
    zero-padded by 2 pixels on each side, giving 32x32 images.

    Args:
        images_gz: Path to the gzipped IDX3 image file.
        max_images: Optional limit on the number of images; it is combined
            with the dataset size through ``maybe_min``.

    Returns:
        A tuple ``(max_idx, iterator)`` where ``max_idx`` is the number of
        samples the iterator produces and each sample is
        ``dict(img=<32x32 uint8 array>, label=<int>)``.

    Raises:
        AssertionError: If the path does not name an image archive, an IDX
            header is malformed, a payload does not match its header, or
            the contents are not the expected 60000-sample training set.
    """
    labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
    assert labels_gz != images_gz, 'expected a path containing "-images-idx3-ubyte.gz"'

    with gzip.open(images_gz, 'rb') as f:
        raw_images = f.read()

    # IDX3 header: four big-endian uint32 fields (magic, count, rows, cols).
    # Check it instead of skipping it blindly so that a wrong or truncated
    # download fails here with a readable message.
    magic, num_images, rows, cols = (
        int.from_bytes(raw_images[i:i + 4], 'big') for i in range(0, 16, 4))
    assert magic == 0x00000803, 'bad IDX magic number in image file: ' + images_gz
    assert (rows, cols) == (28, 28), 'unexpected image resolution in ' + images_gz
    assert len(raw_images) == 16 + num_images * rows * cols, \
        'image payload does not match the IDX header in ' + images_gz

    with gzip.open(labels_gz, 'rb') as f:
        raw_labels = f.read()

    # IDX1 header: two big-endian uint32 fields (magic, count).
    magic, num_labels = (
        int.from_bytes(raw_labels[i:i + 4], 'big') for i in range(0, 8, 4))
    assert magic == 0x00000801, 'bad IDX magic number in label file: ' + labels_gz
    assert len(raw_labels) == 8 + num_labels, \
        'label payload does not match the IDX header in ' + labels_gz

    images = np.frombuffer(raw_images, np.uint8, offset=16).reshape(-1, 28, 28)
    labels = np.frombuffer(raw_labels, np.uint8, offset=8)

    # Pad 28x28 -> 32x32 with black borders; np.pad returns a new array.
    images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
    assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (60000,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9

    max_idx = maybe_min(len(images), max_images)

    def iterate_images():
        # Produce exactly max_idx samples, so the iterator agrees with the
        # reported count even when the limit is 0.
        for idx in range(max_idx):
            yield dict(img=images[idx], label=int(labels[idx]))

    return max_idx, iterate_images()
os.path.isfile(source): - if source.endswith('cifar-10-python.tar.gz'): + if os.path.basename(source) == 'cifar-10-python.tar.gz': return open_cifar10(source, max_images=max_images) - ext = file_ext(source) - if ext == 'zip': + elif os.path.basename(source) == 'train-images-idx3-ubyte.gz': + return open_mnist(source, max_images=max_images) + elif file_ext(source) == 'zip': return open_image_zip(source, max_images=max_images) else: assert False, 'unknown archive type' @@ -293,17 +325,18 @@ def convert_dataset( The input dataset format is guessed from the --source argument: \b - --source *_lmdb/ - Load LSUN dataset - --source cifar-10-python.tar.gz - Load CIFAR-10 dataset - --source path/ - Recursively load all images from path/ - --source dataset.zip - Recursively load all images from dataset.zip + --source *_lmdb/ Load LSUN dataset + --source cifar-10-python.tar.gz Load CIFAR-10 dataset + --source train-images-idx3-ubyte.gz Load MNIST dataset + --source path/ Recursively load all images from path/ + --source dataset.zip Recursively load all images from dataset.zip - The output dataset format can be either an image folder or a zip archive. Specifying - the output format and path: + The output dataset format can be either an image folder or a zip archive. + Specifying the output format and path: \b - --dest /path/to/dir - Save output files under /path/to/dir - --dest /path/to/dataset.zip - Save output files into /path/to/dataset.zip archive + --dest /path/to/dir Save output files under /path/to/dir + --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip Images within the dataset archive will be stored as uncompressed PNG.