diff --git a/setup.py b/setup.py index a6cd932..754339c 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ name = 'x-clip', packages = find_packages(exclude=[]), include_package_data = True, - version = '0.7.1', + version = '0.7.2', license='MIT', description = 'X-CLIP', author = 'Phil Wang', diff --git a/x_clip/visual_ssl.py b/x_clip/visual_ssl.py index 5115c30..ae9e376 100644 --- a/x_clip/visual_ssl.py +++ b/x_clip/visual_ssl.py @@ -21,12 +21,16 @@ def forward(self, x): return x return self.fn(x) -def get_default_aug(image_size, is_rgb = True): +def get_default_aug(image_size, channels = 3): + is_rgb = channels == 3 + is_greyscale = channels == 1 + rgb_or_greyscale = is_rgb or is_greyscale + return torch.nn.Sequential( RandomApply( T.ColorJitter(0.8, 0.8, 0.8, 0.2), p = 0.3 - ), + ) if rgb_or_greyscale else nn.Identity(), T.RandomGrayscale(p = 0.2) if is_rgb else nn.Identity(), T.RandomHorizontalFlip(), RandomApply( @@ -217,7 +221,7 @@ def __init__( # default SimCLR augmentation - self.augment1 = default(augment_fn, get_default_aug(image_size, is_rgb = channels == 3)) + self.augment1 = default(augment_fn, get_default_aug(image_size, channels)) self.augment2 = default(augment_fn2, self.augment1) self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer) @@ -272,7 +276,7 @@ def __init__( ): super().__init__() self.net = NetWrapper(net, project_dim, layer = hidden_layer) - self.augment = default(augment_fn, get_default_aug(image_size, is_rgb = channels == 3)) + self.augment = default(augment_fn, get_default_aug(image_size, channels)) self.augment_both = augment_both self.temperature = temperature