Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added image format conversion to dataset tool #46

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions dataset_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def is_image_ext(fname: Union[str, Path]) -> bool:

#----------------------------------------------------------------------------

def open_image_folder(source_dir, *, max_images: Optional[int]):
def open_image_folder(source_dir, *, max_images: Optional[int], img_format: str):
input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)]

# Load labels.
Expand All @@ -69,15 +69,18 @@ def iterate_images():
for idx, fname in enumerate(input_images):
arch_fname = os.path.relpath(fname, source_dir)
arch_fname = arch_fname.replace('\\', '/')
img = np.array(PIL.Image.open(fname))
img = PIL.Image.open(fname)
if img_format != 'keep':
img = img.convert(img_format)
img = np.array(img)
yield dict(img=img, label=labels.get(arch_fname))
if idx >= max_idx-1:
break
return max_idx, iterate_images()

#----------------------------------------------------------------------------

def open_image_zip(source, *, max_images: Optional[int]):
def open_image_zip(source, *, max_images: Optional[int], img_format: str):
with zipfile.ZipFile(source, mode='r') as z:
input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]

Expand All @@ -98,6 +101,8 @@ def iterate_images():
for idx, fname in enumerate(input_images):
with z.open(fname, 'r') as file:
img = PIL.Image.open(file) # type: ignore
if img_format != 'keep':
img = img.convert(img_format)
img = np.array(img)
yield dict(img=img, label=labels.get(fname))
if idx >= max_idx-1:
Expand Down Expand Up @@ -249,19 +254,19 @@ def center_crop_wide(width, height, img):

#----------------------------------------------------------------------------

def open_dataset(source, *, max_images: Optional[int]):
def open_dataset(source, *, max_images: Optional[int], img_format: str):
if os.path.isdir(source):
if source.rstrip('/').endswith('_lmdb'):
return open_lmdb(source, max_images=max_images)
else:
return open_image_folder(source, max_images=max_images)
return open_image_folder(source, max_images=max_images, img_format=img_format)
elif os.path.isfile(source):
if os.path.basename(source) == 'cifar-10-python.tar.gz':
return open_cifar10(source, max_images=max_images)
elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
return open_mnist(source, max_images=max_images)
elif file_ext(source) == 'zip':
return open_image_zip(source, max_images=max_images)
return open_image_zip(source, max_images=max_images, img_format=img_format)
else:
assert False, 'unknown archive type'
else:
Expand Down Expand Up @@ -310,6 +315,7 @@ def folder_write_bytes(fname: str, data: Union[bytes, str]):
@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
@click.option('--width', help='Output width', type=int)
@click.option('--height', help='Output height', type=int)
@click.option('--img-format', help='Forces images to be loaded as a specific file format.', type=click.Choice(['keep', 'L', 'RGB']), default='keep', show_default=True)
def convert_dataset(
ctx: click.Context,
source: str,
Expand All @@ -318,7 +324,8 @@ def convert_dataset(
transform: Optional[str],
resize_filter: str,
width: Optional[int],
height: Optional[int]
height: Optional[int],
img_format: str
):
"""Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.

Expand Down Expand Up @@ -377,14 +384,19 @@ def convert_dataset(
\b
python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
--transform=center-crop-wide --width 512 --height=384

For custom image folders and image zip files, images can be force-converted to a
specific color mode on load by using --img-format=L for grayscale or --img-format=RGB
for full color images. Defaults to --img-format=keep, which keeps each image's
original color mode.
"""

PIL.Image.init() # type: ignore

if dest == '':
ctx.fail('--dest output filename or directory must not be an empty string')

num_files, input_iter = open_dataset(source, max_images=max_images)
num_files, input_iter = open_dataset(source, max_images=max_images, img_format=img_format)
archive_root_dir, save_bytes, close_dest = open_dest(dest)

transform_image = make_transform(transform, width, height, resize_filter)
Expand Down