diff --git a/src/dope/utils.py b/src/dope/utils.py index 7aa75179..a3ccf5da 100644 --- a/src/dope/utils.py +++ b/src/dope/utils.py @@ -2,14 +2,10 @@ import torch -irange = range - def make_grid(tensor, nrow=8, padding=2, - normalize=False, range=None, scale_each=False, pad_value=0): - """ - Make a grid of images. - + normalize=False, range_=None, scale_each=False, pad_value=0): + """Make a grid of images. Args: tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W) or a list of images all of the same size. @@ -24,6 +20,8 @@ def make_grid(tensor, nrow=8, padding=2, scale_each (bool, optional): If True, scale each image in the batch of images separately rather than the (min, max) over all images. pad_value (float, optional): Value for the padded pixels. + Example: + See this notebook `here `_ """ if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): @@ -43,23 +41,23 @@ def make_grid(tensor, nrow=8, padding=2, if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images tensor = torch.cat((tensor, tensor, tensor), 1) - if normalize == True: + if normalize is True: tensor = tensor.clone() # avoid modifying tensor in-place - if range is not None: - assert isinstance(range, tuple), \ + if range_ is not None: + assert isinstance(range_, tuple), \ "range has to be a tuple (min, max) if specified. min and max are numbers" def norm_ip(img, min, max): img.clamp_(min=min, max=max) img.add_(-min).div_(max - min + 1e-5) - def norm_range(t, range): - if range is not None: - norm_ip(t, range[0], range[1]) + def norm_range(t, range_): + if range_ is not None: + norm_ip(t, range_[0], range_[1]) else: norm_ip(t, float(t.min()), float(t.max())) - if scale_each == True: + if scale_each is True: for t in tensor: # loop over mini-batch dimension norm_range(t, range) else: @@ -75,8 +73,8 @@ def norm_range(t, range): height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding) grid = tensor.new(3, height * ymaps + padding, width * xmaps + padding).fill_(pad_value) k = 0 - for y in irange(ymaps): - for x in irange(xmaps): + for y in range(ymaps): + for x in range(xmaps): if k >= nmaps: break grid.narrow(1, y * height + padding, height - padding) \