From 0c2252a304849ea31b7d619137c55d75f09efe37 Mon Sep 17 00:00:00 2001 From: TheDudeFromCI Date: Sun, 21 Feb 2021 10:28:20 +0000 Subject: [PATCH] Added image folder format conversion to dataset tool. Signed-off-by: TheDudeFromCI --- dataset_tool.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/dataset_tool.py b/dataset_tool.py index c59e62928..b8b265c73 100755 --- a/dataset_tool.py +++ b/dataset_tool.py @@ -49,7 +49,7 @@ def is_image_ext(fname: Union[str, Path]) -> bool: #---------------------------------------------------------------------------- -def open_image_folder(source_dir, *, max_images: Optional[int]): +def open_image_folder(source_dir, *, max_images: Optional[int], img_format: str): input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)] # Load labels. @@ -69,7 +69,10 @@ def iterate_images(): for idx, fname in enumerate(input_images): arch_fname = os.path.relpath(fname, source_dir) arch_fname = arch_fname.replace('\\', '/') - img = np.array(PIL.Image.open(fname)) + img = PIL.Image.open(fname) + if img_format != 'keep': + img = img.convert(img_format) + img = np.array(img) yield dict(img=img, label=labels.get(arch_fname)) if idx >= max_idx-1: break @@ -77,7 +80,7 @@ def iterate_images(): #---------------------------------------------------------------------------- -def open_image_zip(source, *, max_images: Optional[int]): +def open_image_zip(source, *, max_images: Optional[int], img_format: str): with zipfile.ZipFile(source, mode='r') as z: input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)] @@ -98,6 +101,8 @@ def iterate_images(): for idx, fname in enumerate(input_images): with z.open(fname, 'r') as file: img = PIL.Image.open(file) # type: ignore + if img_format != 'keep': + img = img.convert(img_format) img = np.array(img) yield dict(img=img, label=labels.get(fname)) if idx >= max_idx-1: @@ -249,19 +254,19 @@ def center_crop_wide(width, height, img): #---------------------------------------------------------------------------- -def open_dataset(source, *, max_images: Optional[int]): +def open_dataset(source, *, max_images: Optional[int], img_format: str): if os.path.isdir(source): if source.rstrip('/').endswith('_lmdb'): return open_lmdb(source, max_images=max_images) else: - return open_image_folder(source, max_images=max_images) + return open_image_folder(source, max_images=max_images, img_format=img_format) elif os.path.isfile(source): if os.path.basename(source) == 'cifar-10-python.tar.gz': return open_cifar10(source, max_images=max_images) elif os.path.basename(source) == 'train-images-idx3-ubyte.gz': return open_mnist(source, max_images=max_images) elif file_ext(source) == 'zip': - return open_image_zip(source, max_images=max_images) + return open_image_zip(source, max_images=max_images, img_format=img_format) else: assert False, 'unknown archive type' else: @@ -310,6 +315,7 @@ def folder_write_bytes(fname: str, data: Union[bytes, str]): @click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide'])) @click.option('--width', help='Output width', type=int) @click.option('--height', help='Output height', type=int) +@click.option('--img-format', help='Forces images to be loaded as a specific file format.', type=click.Choice(['keep', 'L', 'RGB']), default='keep', show_default=True) def convert_dataset( ctx: click.Context, source: str, @@ -318,7 +324,8 @@ def convert_dataset( transform: Optional[str], resize_filter: str, width: Optional[int], - height: Optional[int] + height: Optional[int], + img_format: str ): """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch. @@ -377,6 +384,11 @@ def convert_dataset( \b python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\ --transform=center-crop-wide --width 512 --height=384 + + For custom image folders and image zip files, the image format can be force converted to + a specific format on load by using --img-format=L for grayscale or --img-format=RGB + for full color images. Defaults to --img-format=keep which keeps the current image + color format. """ PIL.Image.init() # type: ignore @@ -384,7 +396,7 @@ def convert_dataset( if dest == '': ctx.fail('--dest output filename or directory must not be an empty string') - num_files, input_iter = open_dataset(source, max_images=max_images) + num_files, input_iter = open_dataset(source, max_images=max_images, img_format=img_format) archive_root_dir, save_bytes, close_dest = open_dest(dest) transform_image = make_transform(transform, width, height, resize_filter)