diff --git a/src/open_clip/__init__.py b/src/open_clip/__init__.py index 55b720f40..6ee4b8dd9 100644 --- a/src/open_clip/__init__.py +++ b/src/open_clip/__init__.py @@ -1,9 +1,11 @@ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD -from .factory import list_models, create_model, create_model_and_transforms, add_model_config +from .factory import create_model, create_model_and_transforms, create_model_from_pretrained +from .factory import list_models, add_model_config, get_model_config, load_checkpoint from .loss import ClipLoss -from .model import CLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_fp16, trace_model +from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\ + convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype from .openai import load_openai_model, list_openai_models -from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\ +from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained from .tokenizer import SimpleTokenizer, tokenize from .transform import image_transform diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 029616c1e..bf9c9588b 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -5,14 +5,15 @@ import re from copy import deepcopy from pathlib import Path -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD -from .model import CLIP, convert_weights_to_fp16, resize_pos_embed +from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ + resize_pos_embed, get_cast_dtype from .openai import load_openai_model -from .pretrained import get_pretrained_cfg, download_pretrained +from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model from .transform import image_transform @@ -48,6 +49,26 @@ def _rescan_model_configs(): _rescan_model_configs() # initial populate of model config registry +def list_models(): + """ enumerate available model architectures based on config files """ + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """ add model config path or file and update registry """ + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() + + +def get_model_config(model_name): + if model_name in _MODEL_CONFIGS: + return deepcopy(_MODEL_CONFIGS[model_name]) + else: + return None + + def load_state_dict(checkpoint_path: str, map_location='cpu'): checkpoint = torch.load(checkpoint_path, map_location=map_location) if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: @@ -61,6 +82,9 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'): def load_checkpoint(model, checkpoint_path, strict=True): state_dict = load_state_dict(checkpoint_path) + # detect old format and make compatible with new format + if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): + state_dict = convert_to_custom_text_state_dict(state_dict) resize_pos_embed(state_dict, model) incompatible_keys = model.load_state_dict(state_dict, strict=strict) return incompatible_keys @@ -68,26 +92,32 @@ def load_checkpoint(model, checkpoint_path, strict=True): def create_model( model_name: str, - pretrained: str = 
'', + pretrained: Optional[str] = None, precision: str = 'fp32', - device: torch.device = torch.device('cpu'), + device: Union[str, torch.device] = 'cpu', jit: bool = False, force_quick_gelu: bool = False, + force_custom_text: bool = False, pretrained_image: bool = False, cache_dir: Optional[str] = None, ): model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names + if isinstance(device, str): + device = torch.device(device) - if pretrained.lower() == 'openai': + if pretrained and pretrained.lower() == 'openai': logging.info(f'Loading pretrained {model_name} from OpenAI.') - model = load_openai_model(model_name, device=device, jit=jit, cache_dir=cache_dir) - # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 - if precision == "amp" or precision == "fp32": - model = model.float() + model = load_openai_model( + model_name, + precision=precision, + device=device, + jit=jit, + cache_dir=cache_dir, + ) else: - if model_name in _MODEL_CONFIGS: - logging.info(f'Loading {model_name} model config.') - model_cfg = deepcopy(_MODEL_CONFIGS[model_name]) + model_cfg = get_model_config(model_name) + if model_cfg is not None: + logging.info(f'Loaded {model_name} model config.') else: logging.error(f'Model config for {model_name} not found; available models {list_models()}.') raise RuntimeError(f'Model config for {model_name} not found.') @@ -103,7 +133,13 @@ def create_model( else: assert False, 'pretrained image towers currently only supported for timm models' - model = CLIP(**model_cfg) + cast_dtype = get_cast_dtype(precision) + custom_text = model_cfg.pop('custom_text', False) or force_custom_text + + if custom_text: + model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) pretrained_cfg = {} if pretrained: @@ -118,13 +154,15 @@ def create_model( logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') load_checkpoint(model, checkpoint_path) else: - logging.warning(f'Pretrained weights ({pretrained}) not found for model {model_name}.') - raise RuntimeError(f'Pretrained weights ({pretrained}) not found for model {model_name}.') + error_str = ( + f'Pretrained weights ({pretrained}) not found for model {model_name}.' 
+ f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') + logging.warning(error_str) + raise RuntimeError(error_str) model.to(device=device) - if precision == "fp16": - assert device.type != 'cpu' - convert_weights_to_fp16(model) + if precision in ("fp16", "bf16"): + convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16) # set image / mean metadata from pretrained_cfg if available, or use default model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN @@ -138,38 +176,86 @@ def create_model( def create_model_and_transforms( model_name: str, - pretrained: str = '', + pretrained: Optional[str] = None, precision: str = 'fp32', - device: torch.device = torch.device('cpu'), + device: Union[str, torch.device] = 'cpu', jit: bool = False, force_quick_gelu: bool = False, + force_custom_text: bool = False, pretrained_image: bool = False, image_mean: Optional[Tuple[float, ...]] = None, image_std: Optional[Tuple[float, ...]] = None, cache_dir: Optional[str] = None, ): model = create_model( - model_name, pretrained, precision, device, jit, + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, pretrained_image=pretrained_image, - cache_dir=cache_dir) + cache_dir=cache_dir, + ) image_mean = image_mean or getattr(model.visual, 'image_mean', None) image_std = image_std or getattr(model.visual, 'image_std', None) - preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=image_mean, std=image_std) - preprocess_val = image_transform(model.visual.image_size, is_train=False, mean=image_mean, std=image_std) + preprocess_train = image_transform( + model.visual.image_size, + is_train=True, + mean=image_mean, + std=image_std + ) + preprocess_val = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std + ) return model, preprocess_train, preprocess_val -def list_models(): - """ enumerate available model architectures based on config files """ - return list(_MODEL_CONFIGS.keys()) +def create_model_from_pretrained( + model_name: str, + pretrained: str, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + return_transform: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + cache_dir: Optional[str] = None, +): + if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained): + raise RuntimeError( + f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.' 
+ f' Use open_clip.list_pretrained() to find one.') + + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, + cache_dir=cache_dir, + ) + if not return_transform: + return model -def add_model_config(path): - """ add model config path or file and update registry """ - if not isinstance(path, Path): - path = Path(path) - _MODEL_CONFIG_PATHS.append(path) - _rescan_model_configs() + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std + ) + + return model, preprocess diff --git a/src/open_clip/model.py b/src/open_clip/model.py index ee74372ba..d5eb34461 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -2,11 +2,10 @@ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. """ -from collections import OrderedDict from dataclasses import dataclass import logging import math -from typing import Tuple, Union, Callable, Optional +from typing import Optional, Tuple, Union import numpy as np import torch @@ -14,402 +13,10 @@ from torch import nn from torch.utils.checkpoint import checkpoint +from .modified_resnet import ModifiedResNet from .timm_model import TimmModel -from .utils import freeze_batch_norm_2d, to_2tuple - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1): - super().__init__() - - # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 - self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.act1 = nn.ReLU(inplace=True) - - self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.act2 = nn.ReLU(inplace=True) - - self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() - - self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * self.expansion) - self.act3 = nn.ReLU(inplace=True) - - self.downsample = None - self.stride = stride - - if stride > 1 or inplanes != planes * Bottleneck.expansion: - # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 - self.downsample = nn.Sequential(OrderedDict([ - ("-1", nn.AvgPool2d(stride)), - ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), - ("1", nn.BatchNorm2d(planes * self.expansion)) - ])) - - def forward(self, x: torch.Tensor): - identity = x - - out = self.act1(self.bn1(self.conv1(x))) - out = self.act2(self.bn2(self.conv2(out))) - out = self.avgpool(out) - out = self.bn3(self.conv3(out)) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.act3(out) - return out - - -class AttentionPool2d(nn.Module): - def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): - super().__init__() - self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) - self.k_proj = nn.Linear(embed_dim, embed_dim) - self.q_proj = nn.Linear(embed_dim, embed_dim) - self.v_proj = nn.Linear(embed_dim, embed_dim) - self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) - self.num_heads = num_heads - - def forward(self, x): - x = x.reshape(x.shape[0], 
x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC - x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC - x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC - x, _ = F.multi_head_attention_forward( - query=x, key=x, value=x, - embed_dim_to_check=x.shape[-1], - num_heads=self.num_heads, - q_proj_weight=self.q_proj.weight, - k_proj_weight=self.k_proj.weight, - v_proj_weight=self.v_proj.weight, - in_proj_weight=None, - in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), - bias_k=None, - bias_v=None, - add_zero_attn=False, - dropout_p=0, - out_proj_weight=self.c_proj.weight, - out_proj_bias=self.c_proj.bias, - use_separate_proj_weight=True, - training=self.training, - need_weights=False - ) - - return x[0] - - -class ModifiedResNet(nn.Module): - """ - A ResNet class that is similar to torchvision's but contains the following changes: - - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. - - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 - - The final pooling layer is a QKV attention instead of an average pool - """ - - def __init__(self, layers, output_dim, heads, image_size=224, width=64): - super().__init__() - self.output_dim = output_dim - self.image_size = image_size - - # the 3-layer stem - self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(width // 2) - self.act1 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(width // 2) - self.act2 = nn.ReLU(inplace=True) - self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) - self.bn3 = nn.BatchNorm2d(width) - self.act3 = nn.ReLU(inplace=True) - self.avgpool = nn.AvgPool2d(2) - - # residual layers - self._inplanes = width # this is a *mutable* variable used during construction - self.layer1 = self._make_layer(width, layers[0]) - self.layer2 = self._make_layer(width * 2, layers[1], stride=2) - self.layer3 = self._make_layer(width * 4, layers[2], stride=2) - self.layer4 = self._make_layer(width * 8, layers[3], stride=2) - - embed_dim = width * 32 # the ResNet feature dimension - self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) - - self.init_parameters() - - def _make_layer(self, planes, blocks, stride=1): - layers = [Bottleneck(self._inplanes, planes, stride)] - - self._inplanes = planes * Bottleneck.expansion - for _ in range(1, blocks): - layers.append(Bottleneck(self._inplanes, planes)) - - return nn.Sequential(*layers) - - def init_parameters(self): - if self.attnpool is not None: - std = self.attnpool.c_proj.in_features ** -0.5 - nn.init.normal_(self.attnpool.q_proj.weight, std=std) - nn.init.normal_(self.attnpool.k_proj.weight, std=std) - nn.init.normal_(self.attnpool.v_proj.weight, std=std) - nn.init.normal_(self.attnpool.c_proj.weight, std=std) - - for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: - for name, param in resnet_block.named_parameters(): - if name.endswith("bn3.weight"): - nn.init.zeros_(param) - - def lock(self, unlocked_groups=0, freeze_bn_stats=False): - assert unlocked_groups == 0, 'partial locking not currently supported for this model' - for param in self.parameters(): - param.requires_grad = False - if freeze_bn_stats: - freeze_batch_norm_2d(self) - - @torch.jit.ignore - def set_grad_checkpointing(self, 
enable=True): - # FIXME support for non-transformer - pass - - def stem(self, x): - x = self.act1(self.bn1(self.conv1(x))) - x = self.act2(self.bn2(self.conv2(x))) - x = self.act3(self.bn3(self.conv3(x))) - x = self.avgpool(x) - return x - - def forward(self, x): - x = self.stem(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - x = self.attnpool(x) - - return x - - -class LayerNorm(nn.LayerNorm): - """Subclass torch's LayerNorm to handle fp16.""" - - def forward(self, x: torch.Tensor): - orig_type = x.dtype - x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - return x.to(orig_type) - - -class QuickGELU(nn.Module): - # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory - def forward(self, x: torch.Tensor): - return x * torch.sigmoid(1.702 * x) - - -class Attention(nn.Module): - def __init__( - self, - dim, - num_heads=8, - qkv_bias=True, - scaled_cosine=False, - scale_heads=False, - logit_scale_max=math.log(1. / 0.01), - attn_drop=0., - proj_drop=0. - ): - super().__init__() - self.scaled_cosine = scaled_cosine - self.scale_heads = scale_heads - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - self.logit_scale_max = logit_scale_max - - # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original - self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) - if qkv_bias: - self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) - else: - self.in_proj_bias = None - - if self.scaled_cosine: - self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) - else: - self.logit_scale = None - self.attn_drop = nn.Dropout(attn_drop) - if self.scale_heads: - self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) - else: - self.head_scale = None - self.out_proj = nn.Linear(dim, dim) - self.out_drop = nn.Dropout(proj_drop) - - def forward(self, x, attn_mask: Optional[torch.Tensor] = None): - L, N, C = x.shape - q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) - q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - - if self.logit_scale is not None: - attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) - logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() - attn = attn.view(N, self.num_heads, L, L) * logit_scale - attn = attn.view(-1, L, L) - else: - q = q * self.scale - attn = torch.bmm(q, k.transpose(-1, -2)) - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) - new_attn_mask.masked_fill_(attn_mask, float("-inf")) - attn_mask = new_attn_mask - attn += attn_mask - - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = torch.bmm(attn, v) - if self.head_scale is not None: - x = x.view(N, self.num_heads, L, C) * self.head_scale - x = x.view(-1, L, C) - x = x.transpose(0, 1).reshape(L, N, C) - x = self.out_proj(x) - x = self.out_drop(x) - return x - - -class ResidualAttentionBlock(nn.Module): - def __init__( - self, - d_model: int, - n_head: int, - mlp_ratio: float = 4.0, - act_layer: Callable = nn.GELU, - scale_cosine_attn: bool = False, - scale_heads: bool = False, - scale_attn: bool = False, - scale_fc: bool = 
False, - ): - super().__init__() - - self.ln_1 = LayerNorm(d_model) - # FIXME torchscript issues need to be resolved for custom attention - # if scale_cosine_attn or scale_heads: - # self.attn = Attention( - # d_model, n_head, - # scaled_cosine=scale_cosine_attn, - # scale_heads=scale_heads, - # ) - self.attn = nn.MultiheadAttention(d_model, n_head) - self.ln_attn = LayerNorm(d_model) if scale_attn else nn.Identity() - - self.ln_2 = LayerNorm(d_model) - mlp_width = int(d_model * mlp_ratio) - self.mlp = nn.Sequential(OrderedDict([ - ("c_fc", nn.Linear(d_model, mlp_width)), - ('ln', LayerNorm(mlp_width) if scale_fc else nn.Identity()), - ("gelu", act_layer()), - ("c_proj", nn.Linear(mlp_width, d_model)) - ])) - - def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] - # FIXME torchscript issues need resolving for custom attention option to work - # if self.use_torch_attn: - # return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] - # else: - # return self.attn(x, attn_mask=attn_mask) - - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - x = x + self.ln_attn(self.attention(self.ln_1(x), attn_mask=attn_mask)) - x = x + self.mlp(self.ln_2(x)) - return x - - -class Transformer(nn.Module): - def __init__(self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU): - super().__init__() - self.width = width - self.layers = layers - self.grad_checkpointing = False - - self.resblocks = nn.ModuleList([ - ResidualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer) - for _ in range(layers) - ]) - - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - for r in self.resblocks: - if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(r, x, attn_mask) - else: - x = r(x, attn_mask=attn_mask) - return x - - -class VisualTransformer(nn.Module): - def __init__( - self, - image_size: int, - patch_size: int, - width: int, - layers: int, - heads: int, - mlp_ratio: float, - output_dim: int, - act_layer: Callable = nn.GELU - ): - super().__init__() - self.image_size = to_2tuple(image_size) - self.patch_size = to_2tuple(patch_size) - self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1]) - self.output_dim = output_dim - self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) - - scale = width ** -0.5 - self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) - self.ln_pre = LayerNorm(width) - - self.transformer = Transformer(width, layers, heads, mlp_ratio, act_layer=act_layer) - - self.ln_post = LayerNorm(width) - self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) - - def lock(self, unlocked_groups=0, freeze_bn_stats=False): - assert unlocked_groups == 0, 'partial locking not currently supported for this model' - for param in self.parameters(): - param.requires_grad = False - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - self.transformer.grad_checkpointing = enable - - def forward(self, x: torch.Tensor): - x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat( - 
[self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), - x], dim=1) # shape = [*, grid ** 2 + 1, width] - x = x + self.positional_embedding.to(x.dtype) - x = self.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - - x = self.ln_post(x[:, 0, :]) - - if self.proj is not None: - x = x @ self.proj - - return x +from .transformer import LayerNormFp32, QuickGELU, Attention, VisionTransformer, TextTransformer +from .utils import to_2tuple @dataclass @@ -420,10 +27,12 @@ class CLIPVisionCfg: mlp_ratio: float = 4.0 patch_size: int = 16 image_size: Union[Tuple[int, int], int] = 224 + ls_init_value: Optional[float] = None # layer scale initial value timm_model_name: str = None # a valid model name overrides layers, width, patch_size timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') + timm_proj_bias: bool = False # enable bias final projection @dataclass @@ -433,6 +42,96 @@ class CLIPTextCfg: width: int = 512 heads: int = 8 layers: int = 12 + ls_init_value: Optional[float] = None # layer scale initial value + + +def get_cast_dtype(precision: str): + cast_dtype = None + if precision == 'bf16': + cast_dtype = torch.bfloat16 + elif precision == 'fp16': + cast_dtype = torch.float16 + return cast_dtype + + +def _build_vision_tower( + embed_dim: int, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None +): + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
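        # Illustrative sketch (not part of the diff): how the new precision handling is
        # expected to reach this builder. get_cast_dtype(), added earlier in this file,
        # returns torch.bfloat16 for 'bf16', torch.float16 for 'fp16', and None otherwise;
        # a half-precision cast_dtype selects LayerNormFp32 further below so norms run in
        # float32. The CLIPVisionCfg values here are only an assumed ViT-B/32-like example.
        #
        #   cast_dtype = get_cast_dtype('bf16')                    # -> torch.bfloat16
        #   visual = _build_vision_tower(
        #       embed_dim=512,
        #       vision_cfg=CLIPVisionCfg(layers=12, width=768, patch_size=32),
        #       quick_gelu=False,                                   # nn.GELU for non-OpenAI configs
        #       cast_dtype=cast_dtype,                              # -> LayerNormFp32 norm layer
        #   )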
+ act_layer = QuickGELU if quick_gelu else nn.GELU + + if vision_cfg.timm_model_name: + visual = TimmModel( + vision_cfg.timm_model_name, + pretrained=vision_cfg.timm_model_pretrained, + pool=vision_cfg.timm_pool, + proj=vision_cfg.timm_proj, + proj_bias=vision_cfg.timm_proj_bias, + embed_dim=embed_dim, + image_size=vision_cfg.image_size + ) + act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models + elif isinstance(vision_cfg.layers, (tuple, list)): + vision_heads = vision_cfg.width * 32 // vision_cfg.head_width + visual = ModifiedResNet( + layers=vision_cfg.layers, + output_dim=embed_dim, + heads=vision_heads, + image_size=vision_cfg.image_size, + width=vision_cfg.width + ) + else: + vision_heads = vision_cfg.width // vision_cfg.head_width + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else nn.LayerNorm + visual = VisionTransformer( + image_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + width=vision_cfg.width, + layers=vision_cfg.layers, + heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + ls_init_value=vision_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return visual + + +def _build_text_tower( + embed_dim: int, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(text_cfg, dict): + text_cfg = CLIPTextCfg(**text_cfg) + + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else nn.LayerNorm + + text = TextTransformer( + context_length=text_cfg.context_length, + vocab_size=text_cfg.vocab_size, + width=text_cfg.width, + heads=text_cfg.heads, + layers=text_cfg.layers, + ls_init_value=text_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return text class CLIP(nn.Module): @@ -442,97 +141,21 @@ def __init__( vision_cfg: CLIPVisionCfg, text_cfg: CLIPTextCfg, quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, ): super().__init__() - if isinstance(vision_cfg, dict): - vision_cfg = CLIPVisionCfg(**vision_cfg) - if isinstance(text_cfg, dict): - text_cfg = CLIPTextCfg(**text_cfg) - - self.context_length = text_cfg.context_length - - # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more - # memory efficient in recent PyTorch releases (>= 1.10). - # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
- act_layer = QuickGELU if quick_gelu else nn.GELU - - if vision_cfg.timm_model_name: - self.visual = TimmModel( - vision_cfg.timm_model_name, - pretrained=vision_cfg.timm_model_pretrained, - pool=vision_cfg.timm_pool, - proj=vision_cfg.timm_proj, - embed_dim=embed_dim, - image_size=vision_cfg.image_size - ) - act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models - elif isinstance(vision_cfg.layers, (tuple, list)): - vision_heads = vision_cfg.width * 32 // vision_cfg.head_width - self.visual = ModifiedResNet( - layers=vision_cfg.layers, - output_dim=embed_dim, - heads=vision_heads, - image_size=vision_cfg.image_size, - width=vision_cfg.width - ) - else: - vision_heads = vision_cfg.width // vision_cfg.head_width - self.visual = VisualTransformer( - image_size=vision_cfg.image_size, - patch_size=vision_cfg.patch_size, - width=vision_cfg.width, - layers=vision_cfg.layers, - heads=vision_heads, - mlp_ratio=vision_cfg.mlp_ratio, - output_dim=embed_dim, - act_layer=act_layer, - ) - - self.transformer = Transformer( - width=text_cfg.width, - layers=text_cfg.layers, - heads=text_cfg.heads, - act_layer=act_layer, - ) + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) - self.vocab_size = text_cfg.vocab_size - self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width) - self.positional_embedding = nn.Parameter(torch.empty(self.context_length, text_cfg.width)) - self.ln_final = LayerNorm(text_cfg.width) + text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.transformer = text.transformer + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.register_buffer('attn_mask', text.attn_mask, persistent=False) - self.text_projection = nn.Parameter(torch.empty(text_cfg.width, embed_dim)) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) - - self.init_parameters() - - def init_parameters(self): - nn.init.normal_(self.token_embedding.weight, std=0.02) - nn.init.normal_(self.positional_embedding, std=0.01) - nn.init.constant_(self.logit_scale, np.log(1 / 0.07)) - - if hasattr(self.visual, 'init_parameters'): - self.visual.init_parameters() - - proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) - attn_std = self.transformer.width ** -0.5 - fc_std = (2 * self.transformer.width) ** -0.5 - for block in self.transformer.resblocks: - nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - - if self.text_projection is not None: - nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) - - def build_attention_mask(self): - # lazily create causal attention mask, with full attention between the vision tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(self.context_length, self.context_length) - mask.fill_(float("-inf")) - mask.triu_(1) # zero out the lower diagonal - return mask def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 @@ -543,63 +166,118 @@ def set_grad_checkpointing(self, enable=True): 
self.visual.set_grad_checkpointing(enable) self.transformer.grad_checkpointing = enable - def encode_image(self, image): - return self.visual(image) + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() - def encode_text(self, text): - x = self.token_embedding(text) # [batch_size, n_ctx, d_model] + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] - x = x + self.positional_embedding + x = x + self.positional_embedding.to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x, attn_mask=self.attn_mask) x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x) - - # x.shape = [batch_size, n_ctx, transformer.width] + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection - - return x + return F.normalize(x, dim=-1) if normalize else x def forward(self, image, text): - if image is None: - return self.encode_text(text) - elif text is None: - return self.encode_image(image) - image_features = self.encode_image(image) - image_features = F.normalize(image_features, dim=-1) + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + return image_features, text_features, self.logit_scale.exp() + + +class CustomTextCLIP(nn.Module): + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - text_features = self.encode_text(text) - text_features = F.normalize(text_features, dim=-1) + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + features = self.text(text) + return F.normalize(features, dim=-1) if normalize else features + + def forward(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) return image_features, text_features, self.logit_scale.exp() -def convert_weights_to_fp16(model: nn.Module): - """Convert applicable model parameters to fp16""" +def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): + """Convert applicable model parameters to low-precision (bf16 or fp16)""" - def _convert_weights_to_fp16(l): + def _convert_weights(l): if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): - l.weight.data = l.weight.data.half() + l.weight.data = l.weight.data.to(dtype) if l.bias is not None: - l.bias.data = l.bias.data.half() + 
l.bias.data = l.bias.data.to(dtype) if isinstance(l, (nn.MultiheadAttention, Attention)): for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: tensor = getattr(l, attr) if tensor is not None: - tensor.data = tensor.data.half() + tensor.data = tensor.data.to(dtype) for name in ["text_projection", "proj"]: if hasattr(l, name): attr = getattr(l, name) if attr is not None: - attr.data = attr.data.half() - - model.apply(_convert_weights_to_fp16) - - -def build_model_from_openai_state_dict(state_dict: dict): + attr.data = attr.data.to(dtype) + + model.apply(_convert_weights) + + +convert_weights_to_fp16 = convert_weights_to_lp # backwards compat + + +# used to maintain checkpoint compatibility +def convert_to_custom_text_state_dict(state_dict: dict): + if 'text_projection' in state_dict: + # old format state_dict, move text tower -> .text + new_state_dict = {} + for k, v in state_dict.items(): + if any(k.startswith(p) for p in ( + 'text_projection', + 'positional_embedding', + 'token_embedding', + 'transformer', + 'ln_final', + )): + k = 'text.' + k + new_state_dict[k] = v + return new_state_dict + return state_dict + + +def build_model_from_openai_state_dict( + state_dict: dict, + quick_gelu=True, + cast_dtype=torch.float16, +): vit = "visual.proj" in state_dict if vit: @@ -643,13 +321,14 @@ def build_model_from_openai_state_dict(state_dict: dict): embed_dim, vision_cfg=vision_cfg, text_cfg=text_cfg, - quick_gelu=True, # OpenAI models were trained with QuickGELU + quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU + cast_dtype=cast_dtype, ) for key in ["input_resolution", "context_length", "vocab_size"]: state_dict.pop(key, None) - convert_weights_to_fp16(model) + convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 model.load_state_dict(state_dict) return model.eval() diff --git a/src/open_clip/model_configs/ViT-G-14.json b/src/open_clip/model_configs/ViT-G-14.json new file mode 100644 index 000000000..2cfba479a --- /dev/null +++ b/src/open_clip/model_configs/ViT-G-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 48, + "width": 1664, + "head_width": 104, + "mlp_ratio": 4.9231, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32 + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-M-16-alt.json b/src/open_clip/model_configs/ViT-M-16-alt.json new file mode 100644 index 000000000..1a317aad8 --- /dev/null +++ b/src/open_clip/model_configs/ViT-M-16-alt.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 16, + "ls_init_value": 1e-4 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-M-16.json b/src/open_clip/model_configs/ViT-M-16.json new file mode 100644 index 000000000..f2f3225a4 --- /dev/null +++ b/src/open_clip/model_configs/ViT-M-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-M-32-alt.json b/src/open_clip/model_configs/ViT-M-32-alt.json new 
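# Illustrative sketch (not part of the diff): the config JSONs added above are picked up
# by the config registry populated in factory.py, so the new architectures resolve by
# name through the factory helpers; 'ViT-M-16-alt' additionally sets ls_init_value,
# which enables LayerScale in the residual blocks (see transformer.py later in this diff).
#
#   import open_clip
#   assert 'ViT-M-16' in open_clip.list_models()
#   cfg = open_clip.get_model_config('ViT-M-16')     # deep copy of the parsed JSON, or None
#   model = open_clip.create_model('ViT-M-16')       # random init; no pretrained tag given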
file mode 100644 index 000000000..fd222aeac --- /dev/null +++ b/src/open_clip/model_configs/ViT-M-32-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-M-32.json b/src/open_clip/model_configs/ViT-M-32.json new file mode 100644 index 000000000..4f7186428 --- /dev/null +++ b/src/open_clip/model_configs/ViT-M-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-S-16-alt.json b/src/open_clip/model_configs/ViT-S-16-alt.json new file mode 100644 index 000000000..a8c056555 --- /dev/null +++ b/src/open_clip/model_configs/ViT-S-16-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 256, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 256, + "heads": 4, + "layers": 10 + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-S-16.json b/src/open_clip/model_configs/ViT-S-16.json new file mode 100644 index 000000000..1d8504e59 --- /dev/null +++ b/src/open_clip/model_configs/ViT-S-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-S-32-alt.json b/src/open_clip/model_configs/ViT-S-32-alt.json new file mode 100644 index 000000000..e1dfdec98 --- /dev/null +++ b/src/open_clip/model_configs/ViT-S-32-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 256, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 256, + "heads": 4, + "layers": 10 + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-S-32.json b/src/open_clip/model_configs/ViT-S-32.json new file mode 100644 index 000000000..9b8b4191b --- /dev/null +++ b/src/open_clip/model_configs/ViT-S-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-e-14.json b/src/open_clip/model_configs/ViT-e-14.json new file mode 100644 index 000000000..91a0fe14d --- /dev/null +++ b/src/open_clip/model_configs/ViT-e-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 56, + "width": 1792, + "head_width": 112, + "mlp_ratio": 8.5715, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 36 + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/timm-vit_base_patch16_224.json b/src/open_clip/model_configs/timm-convnext_base.json similarity index 86% rename from 
src/open_clip/model_configs/timm-vit_base_patch16_224.json rename to src/open_clip/model_configs/timm-convnext_base.json index 39341ce1b..4de9aa8a3 100644 --- a/src/open_clip/model_configs/timm-vit_base_patch16_224.json +++ b/src/open_clip/model_configs/timm-convnext_base.json @@ -1,7 +1,7 @@ { "embed_dim": 512, "vision_cfg": { - "timm_model_name": "vit_base_patch16_224", + "timm_model_name": "convnext_base", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", diff --git a/src/open_clip/model_configs/timm-resnet50d.json b/src/open_clip/model_configs/timm-convnext_large.json similarity index 54% rename from src/open_clip/model_configs/timm-resnet50d.json rename to src/open_clip/model_configs/timm-convnext_large.json index 7bb0957cd..72341b9a7 100644 --- a/src/open_clip/model_configs/timm-resnet50d.json +++ b/src/open_clip/model_configs/timm-convnext_large.json @@ -1,17 +1,17 @@ { - "embed_dim": 1024, + "embed_dim": 768, "vision_cfg": { - "timm_model_name": "resnet50d", + "timm_model_name": "convnext_large", "timm_model_pretrained": false, - "timm_pool": "abs_attn", - "timm_proj": "", + "timm_pool": "", + "timm_proj": "linear", "image_size": 224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, - "width": 512, - "heads": 8, + "width": 768, + "heads": 12, "layers": 12 } -} +} \ No newline at end of file diff --git a/src/open_clip/model_configs/timm-resnetblur50.json b/src/open_clip/model_configs/timm-convnext_xlarge.json similarity index 54% rename from src/open_clip/model_configs/timm-resnetblur50.json rename to src/open_clip/model_configs/timm-convnext_xlarge.json index 05d0b209a..5186dca08 100644 --- a/src/open_clip/model_configs/timm-resnetblur50.json +++ b/src/open_clip/model_configs/timm-convnext_xlarge.json @@ -1,17 +1,17 @@ { "embed_dim": 1024, "vision_cfg": { - "timm_model_name": "resnetblur50", + "timm_model_name": "convnext_xlarge", "timm_model_pretrained": false, - "timm_pool": "abs_attn", - "timm_proj": "", + "timm_pool": "", + "timm_proj": "linear", "image_size": 224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 + "width": 1024, + "heads": 16, + "layers": 16 } -} +} \ No newline at end of file diff --git a/src/open_clip/model_configs/timm-vit_small_patch16_224.json b/src/open_clip/model_configs/timm-vit_medium_patch16_gap_256.json similarity index 84% rename from src/open_clip/model_configs/timm-vit_small_patch16_224.json rename to src/open_clip/model_configs/timm-vit_medium_patch16_gap_256.json index 45863ab38..df511dad2 100644 --- a/src/open_clip/model_configs/timm-vit_small_patch16_224.json +++ b/src/open_clip/model_configs/timm-vit_medium_patch16_gap_256.json @@ -1,7 +1,7 @@ { "embed_dim": 512, "vision_cfg": { - "timm_model_name": "vit_small_patch16_224", + "timm_model_name": "vit_medium_patch16_gap_256", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", diff --git a/src/open_clip/model_configs/timm-vit_base_patch32_224.json b/src/open_clip/model_configs/timm-vit_relpos_medium_patch16_cls_224.json similarity index 83% rename from src/open_clip/model_configs/timm-vit_base_patch32_224.json rename to src/open_clip/model_configs/timm-vit_relpos_medium_patch16_cls_224.json index 39b845271..ed217b202 100644 --- a/src/open_clip/model_configs/timm-vit_base_patch32_224.json +++ b/src/open_clip/model_configs/timm-vit_relpos_medium_patch16_cls_224.json @@ -1,7 +1,7 @@ { "embed_dim": 512, "vision_cfg": { - "timm_model_name": "vit_base_patch32_224", + "timm_model_name": 
"vit_relpos_medium_patch16_cls_224", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", diff --git a/src/open_clip/modified_resnet.py b/src/open_clip/modified_resnet.py new file mode 100644 index 000000000..be07f2197 --- /dev/null +++ b/src/open_clip/modified_resnet.py @@ -0,0 +1,181 @@ +from collections import OrderedDict + +import torch +from torch import nn +from torch.nn import functional as F + +from open_clip.utils import freeze_batch_norm_2d + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.act2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.act3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.act1(self.bn1(self.conv1(x))) + out = self.act2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.act3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.act2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.act3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + # FIXME support for non-transformer + pass + + def stem(self, x): + x = self.act1(self.bn1(self.conv1(x))) + x = self.act2(self.bn2(self.conv2(x))) + x = self.act3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x diff --git a/src/open_clip/openai.py b/src/open_clip/openai.py index 156a4c0c5..cc4e13e87 100644 --- a/src/open_clip/openai.py +++ b/src/open_clip/openai.py @@ -5,26 +5,27 @@ import os import warnings -from typing import Union, List +from typing import List, Optional, Union import torch -from .model import build_model_from_openai_state_dict -from .pretrained import get_pretrained_url, list_pretrained_tag_models, download_pretrained_from_url +from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype +from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, 
download_pretrained_from_url __all__ = ["list_openai_models", "load_openai_model"] def list_openai_models() -> List[str]: """Returns the names of available CLIP models""" - return list_pretrained_tag_models('openai') + return list_pretrained_models_by_tag('openai') def load_openai_model( name: str, - device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", - jit=True, - cache_dir=None, + precision: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, + jit: bool = True, + cache_dir: Optional[str] = None, ): """Load a CLIP model @@ -32,6 +33,8 @@ def load_openai_model( ---------- name : str A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + precision: str + Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. device : Union[str, torch.device] The device to put the loaded model jit : bool @@ -46,6 +49,11 @@ def load_openai_model( preprocess : Callable[[PIL.Image], torch.Tensor] A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input """ + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + if precision is None: + precision = 'fp32' if device == 'cpu' else 'fp16' + if get_pretrained_url(name, 'openai'): model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) elif os.path.isfile(name): @@ -65,14 +73,21 @@ def load_openai_model( state_dict = torch.load(model_path, map_location="cpu") if not jit: + # Build a non-jit model from the OpenAI jitted model state dict + cast_dtype = get_cast_dtype(precision) try: - model = build_model_from_openai_state_dict(state_dict or model.state_dict()).to(device) + model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) except KeyError: sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} - model = build_model_from_openai_state_dict(sd).to(device) + model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) - if str(device) == "cpu": + # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use + model = model.to(device) + if precision.startswith('amp') or precision == 'fp32': model.float() + elif precision == 'bf16': + convert_weights_to_lp(model, dtype=torch.bfloat16) + return model # patch the device names @@ -97,8 +112,8 @@ def patch_device(module): patch_device(model.encode_image) patch_device(model.encode_text) - # patch dtype to float32 on CPU - if str(device) == "cpu": + # patch dtype to float32 (typically for CPU) + if precision == 'fp32': float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] float_node = float_input.node() diff --git a/src/open_clip/pretrained.py b/src/open_clip/pretrained.py index d69d1bf87..4f1e03581 100644 --- a/src/open_clip/pretrained.py +++ b/src/open_clip/pretrained.py @@ -168,7 +168,7 @@ def list_pretrained(as_str: bool = False): return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] -def list_pretrained_tag_models(tag: str): +def list_pretrained_models_by_tag(tag: str): """ return all models having the specified pretrain tag """ models = [] for k in _PRETRAINED.keys(): @@ -177,7 +177,7 @@ def list_pretrained_tag_models(tag: str): return models -def list_pretrained_model_tags(model: str): +def 
list_pretrained_tags_by_model(model: str): """ return all pretrain tags for the specified model architecture """ tags = [] if model in _PRETRAINED: diff --git a/src/open_clip/timm_model.py b/src/open_clip/timm_model.py index 071dd148c..637bce42c 100644 --- a/src/open_clip/timm_model.py +++ b/src/open_clip/timm_model.py @@ -2,8 +2,10 @@ Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. """ +import logging from collections import OrderedDict +import torch import torch.nn as nn try: @@ -29,6 +31,7 @@ def __init__( image_size=224, pool='avg', proj='linear', + proj_bias=False, drop=0., pretrained=False): super().__init__() @@ -62,9 +65,9 @@ def __init__( # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used if proj == 'linear': head_layers['drop'] = nn.Dropout(drop) - head_layers['proj'] = nn.Linear(prev_chs, embed_dim) + head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) elif proj == 'mlp': - head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) + head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias)) self.head = nn.Sequential(head_layers) @@ -100,6 +103,13 @@ def lock(self, unlocked_groups=0, freeze_bn_stats=False): gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} freeze_batch_norm_2d(self.trunk, gmodules) + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + try: + self.trunk.set_grad_checkpointing(enable) + except Exception as e: + logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') + def forward(self, x): x = self.trunk(x) x = self.head(x) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py new file mode 100644 index 000000000..0471a3e93 --- /dev/null +++ b/src/open_clip/transformer.py @@ -0,0 +1,395 @@ +from collections import OrderedDict +import math +from typing import Callable, Optional + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils.checkpoint import checkpoint + +from .utils import to_2tuple + + +class LayerNormFp32(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + logit_scale_max=math.log(1. / 0.01), + attn_drop=0., + proj_drop=0. 
+ ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + if self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value else nn.Identity() + + def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask)) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class CustomResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + 
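Editorial aside: a shape sketch for the custom `Attention` block above (the `open_clip.transformer` module is the file added in this diff). Inputs are sequence-first `(L, N, C)`, matching `nn.MultiheadAttention`'s default layout, and the mask is additive:

```python
import torch
from open_clip.transformer import Attention

attn = Attention(dim=512, num_heads=8, scaled_cosine=True)
x = torch.randn(77, 2, 512)                             # (seq_len, batch, width)
causal = torch.full((77, 77), float('-inf')).triu_(1)   # additive causal mask
out = attn(x, attn_mask=causal)
print(out.shape)                                         # torch.Size([77, 2, 512])
```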
ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + scale_cosine_attn: bool = False, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = Attention( + d_model, n_head, + scaled_cosine=scale_cosine_attn, + scale_heads=scale_heads, + ) + self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value else nn.Identity() + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + ): + super().__init__() + self.image_size = to_2tuple(image_size) + self.patch_size = to_2tuple(patch_size) + self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1]) + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + self.ln_pre = norm_layer(width) + self.transformer = Transformer( + width, + layers, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + self.ln_post = norm_layer(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + self.init_parameters() + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + + def init_parameters(self): + # FIXME OpenAI CLIP did not define an init for the VisualTransformer + # TODO experiment if default 
PyTorch init, below, or alternate init is best. + + # nn.init.normal_(self.class_embedding, std=self.scale) + # nn.init.normal_(self.positional_embedding, std=self.scale) + # + # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + # attn_std = self.transformer.width ** -0.5 + # fc_std = (2 * self.transformer.width) ** -0.5 + # for block in self.transformer.resblocks: + # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + # + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=self.scale) + pass + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat( + [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class TextTransformer(nn.Module): + + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + ): + super().__init__() + self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + + self.token_embedding = nn.Embedding(vocab_size, width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width)) + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.ln_final = norm_layer(width) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.init_parameters() + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = 
torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, text): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x diff --git a/src/open_clip/version.py b/src/open_clip/version.py index 668c3446e..a33997dd1 100644 --- a/src/open_clip/version.py +++ b/src/open_clip/version.py @@ -1 +1 @@ -__version__ = '2.0.2' +__version__ = '2.1.0' diff --git a/src/training/main.py b/src/training/main.py index 3da59020f..2f369082b 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -101,7 +101,6 @@ def main(): if args.copy_codebase: copy_codebase(args) - assert args.precision in ['amp', 'amp_bfloat16', 'fp16', 'fp32'] if args.precision == 'fp16': logging.warning( 'It is recommended to use AMP mixed-precision instead of FP16. ' @@ -126,6 +125,7 @@ def main(): device=device, jit=args.torchscript, force_quick_gelu=args.force_quick_gelu, + force_custom_text=args.force_custom_text, pretrained_image=args.pretrained_image, image_mean=args.image_mean, image_std=args.image_std, diff --git a/src/training/params.py b/src/training/params.py index 41dfb08c5..a5b2a94dc 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -147,7 +147,7 @@ def parse_args(): ) parser.add_argument( "--precision", - choices=["amp", "amp_bfloat16", "fp16", "fp32"], + choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"], default="amp", help="Floating point precision." ) @@ -217,6 +217,12 @@ def parse_args(): action='store_true', help="Force use of QuickGELU activation for non-OpenAI transformer models.", ) + parser.add_argument( + "--force-custom-text", + default=False, + action='store_true', + help="Force use of CustomTextCLIP model (separate text-tower).", + ) parser.add_argument( "--torchscript", default=False, @@ -285,7 +291,7 @@ def parse_args(): "--seed", type=int, default=0, help="Default random seed." ) parser.add_argument( - "--norm_gradient_clip", type=float, default=None, help="Gradient clip." + "--grad-clip-norm", type=float, default=None, help="Gradient clip." 
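Editorial aside, stepping back to `TextTransformer.forward` above: the pooling line relies on the end-of-text token having the highest id in the vocabulary, so `argmax` over the raw token ids locates the EOT position in each sequence. A small sketch with stand-in features:

```python
import torch
from open_clip import tokenize

texts = tokenize(["a photo of a cat", "a dog"])     # (2, 77) LongTensor
eot_pos = texts.argmax(dim=-1)                      # EOT id (49407) is the vocab maximum
features = torch.randn(2, 77, 512)                  # stand-in for transformer output
pooled = features[torch.arange(2), eot_pos]         # one vector per sequence
print(pooled.shape)                                 # torch.Size([2, 512])
```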
) args = parser.parse_args() diff --git a/src/training/precision.py b/src/training/precision.py index 3801b1078..a63b92256 100644 --- a/src/training/precision.py +++ b/src/training/precision.py @@ -1,11 +1,12 @@ import torch from contextlib import suppress -# amp_bfloat16 is more stable than amp float16 for clip training + def get_autocast(precision): if precision == 'amp': return torch.cuda.amp.autocast - elif precision == 'amp_bfloat16': + elif precision == 'amp_bfloat16' or precision == 'amp_bf16': + # amp_bfloat16 is more stable than amp float16 for clip training return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) else: - return suppress \ No newline at end of file + return suppress diff --git a/src/training/profile.py b/src/training/profile.py new file mode 100644 index 000000000..392d0cbd5 --- /dev/null +++ b/src/training/profile.py @@ -0,0 +1,155 @@ +import argparse + +import torch +import open_clip +import pandas as pd +from fvcore.nn import FlopCountAnalysis, flop_count_str, ActivationCountAnalysis + + +parser = argparse.ArgumentParser(description='OpenCLIP Profiler') + +# benchmark specific args +parser.add_argument('--model', metavar='NAME', default='', + help='model(s) to profile') +parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', + help='Output csv file for results') + + +def profile_fvcore( + model, + image_input_size=(3, 224, 224), + text_input_size=(77,), + batch_size=1, + detailed=False, + force_cpu=False +): + if force_cpu: + model = model.to('cpu') + device, dtype = next(model.parameters()).device, next(model.parameters()).dtype + example_image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) + example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) + fca = FlopCountAnalysis(model, (example_image_input, example_text_input)) + aca = ActivationCountAnalysis(model, (example_image_input, example_text_input)) + if detailed: + fcs = flop_count_str(fca) + print(fcs) + return fca.total(), aca.total() + + +def profile_fvcore_text( + model, + text_input_size=(77,), + batch_size=1, + detailed=False, + force_cpu=False +): + if force_cpu: + model = model.to('cpu') + device = next(model.parameters()).device + example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) + fca = FlopCountAnalysis(model, example_input) + aca = ActivationCountAnalysis(model, example_input) + if detailed: + fcs = flop_count_str(fca) + print(fcs) + return fca.total(), aca.total() + + +def profile_fvcore_image( + model, + image_input_size=(3, 224, 224), + batch_size=1, + detailed=False, + force_cpu=False +): + if force_cpu: + model = model.to('cpu') + device, dtype = next(model.parameters()).device, next(model.parameters()).dtype + example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) + fca = FlopCountAnalysis(model, example_input) + aca = ActivationCountAnalysis(model, example_input) + if detailed: + fcs = flop_count_str(fca) + print(fcs) + return fca.total(), aca.total() + + +def count_params(model): + return sum([m.numel() for m in model.parameters()]) + + +def profile_model(model_name): + model = open_clip.create_model(model_name, force_custom_text=True) + model.eval() + if torch.cuda.is_available(): + model = model.cuda() + + if isinstance(model.visual.image_size, (tuple, list)): + image_input_size = (3,) + tuple(model.visual.image_size[-2:]) + else: + image_input_size = (3, model.visual.image_size, model.visual.image_size) 
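Editorial aside on the precision helpers above: in the training and eval loops `get_autocast` picks the autocast context (or a no-op) while `get_cast_dtype` from `open_clip` returns a dtype only for the manually cast `fp16`/`bf16` modes. A hedged sketch of that pairing; the `training.precision` import path assumes `src/` is on `PYTHONPATH`:

```python
import torch
from open_clip import get_cast_dtype
from training.precision import get_autocast   # src/training/precision.py above

for precision in ('amp', 'amp_bf16', 'bf16', 'fp32'):
    autocast = get_autocast(precision)         # torch.cuda.amp.autocast or a no-op
    cast_dtype = get_cast_dtype(precision)     # torch dtype for bf16/fp16, else None
    images = torch.randn(2, 3, 224, 224)
    if cast_dtype is not None:
        images = images.to(dtype=cast_dtype)   # mirrors train.py / zero_shot.py
    with autocast():
        pass                                   # model forward would go here
    print(precision, images.dtype)
```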
+ text_input_size = (77,) + + results = {} + results['model'] = model_name + results['image_size'] = image_input_size[1] + + model_cfg = open_clip.get_model_config(model_name) + if model_cfg: + vision_cfg = open_clip.CLIPVisionCfg(**model_cfg['vision_cfg']) + text_cfg = open_clip.CLIPTextCfg(**model_cfg['text_cfg']) + results['image_width'] = int(vision_cfg.width) + results['text_width'] = int(text_cfg.width) + results['embed_dim'] = int(model_cfg['embed_dim']) + else: + results['image_width'] = 0 + results['text_width'] = 0 + results['embed_dim'] = 0 + + retries = 2 + while retries: + retries -= 1 + try: + macs, acts = profile_fvcore( + model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries) + + image_macs, image_acts = profile_fvcore_image( + model.visual, image_input_size=image_input_size, force_cpu=not retries) + + text_macs, text_acts = profile_fvcore_text( + model.text, text_input_size=text_input_size, force_cpu=not retries) + + results['gmacs'] = round(macs / 1e9, 2) + results['macts'] = round(acts / 1e6, 2) + results['mparams'] = round(count_params(model) / 1e6, 2) + results['image_gmacs'] = round(image_macs / 1e9, 2) + results['image_macts'] = round(image_acts / 1e6, 2) + results['image_mparams'] = round(count_params(model.visual) / 1e6, 2) + results['text_gmacs'] = round(text_macs / 1e9, 2) + results['text_macts'] = round(text_acts / 1e6, 2) + results['text_mparams'] = round(count_params(model.text) / 1e6, 2) + except RuntimeError as e: + pass + return results + + +def main(): + args = parser.parse_args() + + # FIXME accept a text file name to allow lists of models in txt/csv + parsed_model = args.model.split(',') + + results = [] + for m in parsed_model: + row = profile_model(m) + results.append(row) + + df = pd.DataFrame(results, columns=results[0].keys()) + df = df.sort_values('gmacs') + print(df) + if args.results_file: + df.to_csv(args.results_file, index=False) + + +if __name__ == '__main__': + main() diff --git a/src/training/train.py b/src/training/train.py index 028fa31d6..b8924427c 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -13,7 +13,7 @@ except ImportError: wandb = None -from open_clip import ClipLoss +from open_clip import ClipLoss, get_cast_dtype from .distributed import is_master from .zero_shot import zero_shot_eval from .precision import get_autocast @@ -47,6 +47,7 @@ def unwrap_model(model): def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None): device = torch.device(args.device) autocast = get_autocast(args.precision) + cast_dtype = get_cast_dtype(args.precision) model.train() loss = ClipLoss( @@ -71,7 +72,7 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w scheduler(step) images, texts = batch - images = images.to(device=device, non_blocking=True) + images = images.to(device=device, dtype=cast_dtype, non_blocking=True) texts = texts.to(device=device, non_blocking=True) data_time_m.update(time.time() - end) @@ -86,20 +87,20 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w if args.horovod: optimizer.synchronize() scaler.unscale_(optimizer) - if args.norm_gradient_clip is not None: - torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0) + if args.grad_clip_norm is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) with optimizer.skip_synchronize(): scaler.step(optimizer) else: - if args.norm_gradient_clip is not 
None: + if args.grad_clip_norm is not None: scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0) + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) scaler.step(optimizer) scaler.update() else: total_loss.backward() - if args.norm_gradient_clip is not None: - torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0) + if args.grad_clip_norm is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) optimizer.step() # Note: we clamp to 4.6052 = ln(100), as in the original paper. @@ -161,8 +162,8 @@ def evaluate(model, data, epoch, args, tb_writer=None): metrics.update(zero_shot_metrics) autocast = get_autocast(args.precision) + cast_dtype = get_cast_dtype(args.precision) - if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)): dataloader = data['val'].dataloader num_samples = 0 @@ -175,7 +176,7 @@ def evaluate(model, data, epoch, args, tb_writer=None): with torch.no_grad(): for i, batch in enumerate(dataloader): images, texts = batch - images = images.to(device=device, non_blocking=True) + images = images.to(device=device, dtype=cast_dtype, non_blocking=True) texts = texts.to(device=device, non_blocking=True) with autocast(): diff --git a/src/training/zero_shot.py b/src/training/zero_shot.py index 4bfe7b861..04fe3e3a3 100644 --- a/src/training/zero_shot.py +++ b/src/training/zero_shot.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from tqdm import tqdm -from open_clip import tokenize +from open_clip import tokenize, get_cast_dtype from .precision import get_autocast from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template @@ -34,10 +34,13 @@ def accuracy(output, target, topk=(1,)): def run(model, classifier, dataloader, args): autocast = get_autocast(args.precision) + cast_dtype = get_cast_dtype(args.precision) with torch.no_grad(): top1, top5, n = 0., 0., 0. for images, target in tqdm(dataloader, unit_scale=args.batch_size): images = images.to(args.device) + if cast_dtype is not None: + images = images.to(dtype=cast_dtype) target = target.to(args.device) with autocast():