Skip to content

Commit

Permalink
Update make_grid
Browse files Browse the repository at this point in the history
This is the version from this comment: NVlabs#141 (comment)
  • Loading branch information
mintar committed Jul 8, 2021
1 parent 08c7a35 commit 621e984
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions src/dope/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,10 @@

import torch

irange = range


def make_grid(tensor, nrow=8, padding=2,
normalize=False, range=None, scale_each=False, pad_value=0):
"""
Make a grid of images.
normalize=False, range_=None, scale_each=False, pad_value=0):
"""Make a grid of images.
Args:
tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
or a list of images all of the same size.
Expand All @@ -24,6 +20,8 @@ def make_grid(tensor, nrow=8, padding=2,
scale_each (bool, optional): If True, scale each image in the batch of
images separately rather than the (min, max) over all images.
pad_value (float, optional): Value for the padded pixels.
Example:
See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_
"""
if not (torch.is_tensor(tensor) or
(isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
Expand All @@ -43,23 +41,23 @@ def make_grid(tensor, nrow=8, padding=2,
if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
tensor = torch.cat((tensor, tensor, tensor), 1)

if normalize == True:
if normalize is True:
tensor = tensor.clone() # avoid modifying tensor in-place
if range is not None:
assert isinstance(range, tuple), \
if range_ is not None:
assert isinstance(range_, tuple), \
"range has to be a tuple (min, max) if specified. min and max are numbers"

def norm_ip(img, min, max):
img.clamp_(min=min, max=max)
img.add_(-min).div_(max - min + 1e-5)

def norm_range(t, range):
if range is not None:
norm_ip(t, range[0], range[1])
def norm_range(t, range_):
if range_ is not None:
norm_ip(t, range_[0], range_[1])
else:
norm_ip(t, float(t.min()), float(t.max()))

if scale_each == True:
if scale_each is True:
for t in tensor: # loop over mini-batch dimension
norm_range(t, range)
else:
Expand All @@ -75,8 +73,8 @@ def norm_range(t, range):
height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
grid = tensor.new(3, height * ymaps + padding, width * xmaps + padding).fill_(pad_value)
k = 0
for y in irange(ymaps):
for x in irange(xmaps):
for y in range(ymaps):
for x in range(xmaps):
if k >= nmaps:
break
grid.narrow(1, y * height + padding, height - padding) \
Expand Down

0 comments on commit 621e984

Please sign in to comment.