Skip to content

Commit

Permalink
make sure augmentations in visual ssl can work with non rgb or greysc…
Browse files Browse the repository at this point in the history
…ale images
  • Loading branch information
lucidrains committed Aug 3, 2022
1 parent 5cd3038 commit d9c2c52
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'x-clip',
packages = find_packages(exclude=[]),
include_package_data = True,
version = '0.7.1',
version = '0.7.2',
license='MIT',
description = 'X-CLIP',
author = 'Phil Wang',
Expand Down
12 changes: 8 additions & 4 deletions x_clip/visual_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,16 @@ def forward(self, x):
return x
return self.fn(x)

def get_default_aug(image_size, is_rgb = True):
def get_default_aug(image_size, channels = 3):
is_rgb = channels == 3
is_greyscale = channels == 1
rgb_or_greyscale = is_rgb or is_greyscale

return torch.nn.Sequential(
RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
) if rgb_or_greyscale else nn.Identity(),
T.RandomGrayscale(p = 0.2) if is_rgb else nn.Identity(),
T.RandomHorizontalFlip(),
RandomApply(
Expand Down Expand Up @@ -217,7 +221,7 @@ def __init__(

# default SimCLR augmentation

self.augment1 = default(augment_fn, get_default_aug(image_size, is_rgb = channels == 3))
self.augment1 = default(augment_fn, get_default_aug(image_size, channels))
self.augment2 = default(augment_fn2, self.augment1)

self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer)
Expand Down Expand Up @@ -272,7 +276,7 @@ def __init__(
):
super().__init__()
self.net = NetWrapper(net, project_dim, layer = hidden_layer)
self.augment = default(augment_fn, get_default_aug(image_size, is_rgb = channels == 3))
self.augment = default(augment_fn, get_default_aug(image_size, channels))
self.augment_both = augment_both
self.temperature = temperature

Expand Down

0 comments on commit d9c2c52

Please sign in to comment.