Add EVA models (via timm backbone), torch.compile support, more (#500)
* Add EVA models (via timm backbone), torch.compile support, pure bf16/fp16 mode, safetensors push support

* Fix optional type refinement for torchscript

* Back torch.compile changes out of the factory; they need to sit closer to the point of use for various reasons

* Fix output_dict + jit regression; remove the native OpenAI jit load since it is not working reliably in PyTorch 2.0, and instead always extract the state dict, load the model, then re-jit (if enabled)
rwightman authored May 12, 2023
1 parent 6ee59e1 commit 43cf086
Showing 19 changed files with 354 additions and 153 deletions.
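
The features called out in the commit message can be exercised roughly as follows. This is a minimal sketch, not code from this commit, assuming timm (for the EVA backbones) and PyTorch >= 2.0 are installed; torch.compile is applied by the caller, since this commit backs the compile logic out of the factory.

```python
# Minimal usage sketch (assumptions: this commit installed, timm provides the
# EVA backbones, PyTorch >= 2.0 for torch.compile). Not code from the commit.
import torch
import open_clip

# 'EVA02-B-16' is one of the model configs added in this commit; 'pure_bf16'
# is one of the new pure low-precision modes (weights kept in bfloat16).
model, _, preprocess = open_clip.create_model_and_transforms(
    'EVA02-B-16',
    precision='pure_bf16',
)

# torch.compile is intentionally not applied inside create_model; wrap the
# model at the point of use instead.
model = torch.compile(model)
```
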
2 changes: 1 addition & 1 deletion src/open_clip/__init__.py
@@ -4,7 +4,7 @@
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
48 changes: 32 additions & 16 deletions src/open_clip/factory.py
@@ -15,7 +15,8 @@
from .coca_model import CoCa
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
list_pretrained_tags_by_model, download_pretrained_from_hf
from .transform import image_transform, AugmentationCfg
from .tokenizer import HFTokenizer, tokenize

@@ -145,13 +146,8 @@ def create_model(
model_name,
precision=precision,
device=device,
jit=jit,
cache_dir=cache_dir,
)

# to always output dict even if it is clip
if output_dict and hasattr(model, "output_dict"):
model.output_dict = True
else:
model_cfg = model_cfg or get_model_config(model_name)
if model_cfg is not None:
@@ -172,13 +168,15 @@
# override model config's image size
model_cfg["vision_cfg"]["image_size"] = force_image_size

is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
if pretrained_image:
if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
if is_timm_model:
# pretrained weight loading for timm models set via vision_cfg
model_cfg['vision_cfg']['timm_model_pretrained'] = True
else:
assert False, 'pretrained image towers currently only supported for timm models'

# cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
cast_dtype = get_cast_dtype(precision)
is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
@@ -193,6 +191,29 @@
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)

if precision in ("fp16", "bf16"):
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
# manual mixed precision that matches original OpenAI behaviour
if is_timm_model:
# FIXME this is a bit janky, create timm based model in low-precision and
# then cast only LayerNormFp32 instances back to float32 so they don't break.
# Why? The convert_weights_to_lp fn only works with native models.
model.to(device=device, dtype=dtype)
from .transformer import LayerNormFp32
def _convert_ln(m):
if isinstance(m, LayerNormFp32):
m.weight.data = m.weight.data.to(torch.float32)
m.bias.data = m.bias.data.to(torch.float32)
model.apply(_convert_ln)
else:
model.to(device=device)
convert_weights_to_lp(model, dtype=dtype)
elif precision in ("pure_fp16", "pure_bf16"):
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
model.to(device=device, dtype=dtype)
else:
model.to(device=device)

pretrained_loaded = False
if pretrained:
checkpoint_path = ''
@@ -222,20 +243,15 @@
raise RuntimeError(
f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')

model.to(device=device)
if precision in ("fp16", "bf16"):
convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)

# set image / mean metadata from pretrained_cfg if available, or use default
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD

# to always output dict even if it is clip
if output_dict and hasattr(model, "output_dict"):
model.output_dict = True
if output_dict and hasattr(model, "output_dict"):
model.output_dict = True

if jit:
model = torch.jit.script(model)
if jit:
model = torch.jit.script(model)

return model

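For the fp16/bf16 precision path added above, a timm-backed tower is moved to low precision as a whole and only LayerNormFp32 modules are cast back to float32, while native towers go through convert_weights_to_lp. A rough, illustrative way to observe this on a timm-backed config such as the new EVA02-B-16 (a sketch, not part of the commit):

```python
# Illustrative check, not part of the commit: count parameter dtypes after
# creating a timm-backed model with manual mixed precision ('fp16').
from collections import Counter

import open_clip

model = open_clip.create_model('EVA02-B-16', precision='fp16', device='cpu')

# Expectation under the factory code above: most parameters are float16, with
# the LayerNormFp32 weights/biases cast back to float32 so they don't break.
print(Counter(p.dtype for p in model.parameters()))
```
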
59 changes: 42 additions & 17 deletions src/open_clip/model.py
@@ -28,21 +28,23 @@ class CLIPVisionCfg:
mlp_ratio: float = 4.0
patch_size: int = 16
image_size: Union[Tuple[int, int], int] = 224

ls_init_value: Optional[float] = None # layer scale initial value
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
n_queries: int = 256 # n_queries for attentional pooler
attn_pooler_heads: int = 8 # n heads for attentional_pooling
attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
n_queries: int = 256 # n_queries for attentional pooler
attn_pooler_heads: int = 8 # n heads for attentional_pooling
output_tokens: bool = False

timm_model_name: str = None # a valid model name overrides layers, width, patch_size
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
timm_proj_bias: bool = False # enable bias final projection
timm_drop: float = 0. # head dropout
timm_drop_path: Optional[float] = None # backbone stochastic depth
output_tokens: bool = False


@dataclass
@@ -72,6 +74,15 @@ def get_cast_dtype(precision: str):
return cast_dtype


def get_input_dtype(precision: str):
input_dtype = None
if precision in ('bf16', 'pure_bf16'):
input_dtype = torch.bfloat16
elif precision in ('fp16', 'pure_fp16'):
input_dtype = torch.float16
return input_dtype


def _build_vision_tower(
embed_dim: int,
vision_cfg: CLIPVisionCfg,
@@ -95,10 +106,10 @@ def _build_vision_tower(
proj_bias=vision_cfg.timm_proj_bias,
drop=vision_cfg.timm_drop,
drop_path=vision_cfg.timm_drop_path,
patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
embed_dim=embed_dim,
image_size=vision_cfg.image_size,
)
act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
elif isinstance(vision_cfg.layers, (tuple, list)):
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
visual = ModifiedResNet(
@@ -228,9 +239,13 @@ def encode_text(self, text, normalize: bool = False):
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return F.normalize(x, dim=-1) if normalize else x

def forward(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
def forward(
self,
image: Optional[torch.Tensor] = None,
text: Optional[torch.Tensor] = None,
):
image_features = self.encode_image(image, normalize=True) if image is not None else None
text_features = self.encode_text(text, normalize=True) if text is not None else None
if self.output_dict:
return {
"image_features": image_features,
@@ -280,9 +295,13 @@ def encode_text(self, text, normalize: bool = False):
features = self.text(text)
return F.normalize(features, dim=-1) if normalize else features

def forward(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
def forward(
self,
image: Optional[torch.Tensor] = None,
text: Optional[torch.Tensor] = None,
):
image_features = self.encode_image(image, normalize=True) if image is not None else None
text_features = self.encode_text(text, normalize=True) if text is not None else None
if self.output_dict:
return {
"image_features": image_features,
@@ -307,11 +326,17 @@ def _convert_weights(l):
if tensor is not None:
tensor.data = tensor.data.to(dtype)

for name in ["text_projection", "proj"]:
if hasattr(l, name):
attr = getattr(l, name)
if attr is not None:
attr.data = attr.data.to(dtype)
if isinstance(l, (CLIP, TextTransformer)):
# convert text nn.Parameter projections
attr = getattr(l, "text_projection", None)
if attr is not None:
attr.data = attr.data.to(dtype)

if isinstance(l, VisionTransformer):
# convert vision nn.Parameter projections
attr = getattr(l, "proj", None)
if attr is not None:
attr.data = attr.data.to(dtype)

model.apply(_convert_weights)

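The new get_input_dtype helper and the Optional image/text arguments on forward are intended for callers along these lines; a sketch under the assumption that a pure low-precision mode is in use and that inputs are cast by the caller:

```python
# Sketch only: cast input images to the dtype implied by the precision mode,
# and use the new Optional forward args for an image-only pass.
import torch
import open_clip
from open_clip import get_input_dtype

precision = 'pure_bf16'
model = open_clip.create_model('EVA02-B-16', precision=precision)

input_dtype = get_input_dtype(precision)   # torch.bfloat16 for 'pure_bf16'
images = torch.randn(2, 3, 224, 224)
if input_dtype is not None:
    images = images.to(dtype=input_dtype)

with torch.no_grad():
    out = model(image=images)              # text=None -> text_features is None
```
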
18 changes: 18 additions & 0 deletions src/open_clip/model_configs/EVA01-g-14-plus.json
@@ -0,0 +1,18 @@
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "eva_giant_patch14_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
},
"custom_text": true
}
18 changes: 18 additions & 0 deletions src/open_clip/model_configs/EVA01-g-14.json
@@ -0,0 +1,18 @@
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "eva_giant_patch14_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"custom_text": true
}
18 changes: 18 additions & 0 deletions src/open_clip/model_configs/EVA02-B-16.json
@@ -0,0 +1,18 @@
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "eva02_base_patch16_clip_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
},
"custom_text": true
}
18 changes: 18 additions & 0 deletions src/open_clip/model_configs/EVA02-E-14-plus.json
@@ -0,0 +1,18 @@
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "eva02_enormous_patch14_clip_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 32
},
"custom_text": true
}
18 changes: 18 additions & 0 deletions src/open_clip/model_configs/EVA02-E-14.json
@@ -0,0 +1,18 @@
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "eva02_enormous_patch14_clip_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
},
"custom_text": true
}
18 changes: 18 additions & 0 deletions src/open_clip/model_configs/EVA02-L-14-336.json
@@ -0,0 +1,18 @@
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 336,
"timm_model_name": "eva02_large_patch14_clip_336",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"custom_text": true
}
18 changes: 18 additions & 0 deletions src/open_clip/model_configs/EVA02-L-14.json
@@ -0,0 +1,18 @@
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "eva02_large_patch14_clip_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"custom_text": true
}
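
The JSON files above register the EVA names with the factory, so the models can be listed and instantiated by name. A hedged example, assuming a timm version that provides the eva/eva02 architectures:

```python
# Quick check that the new EVA configs are picked up (illustrative, not from
# the commit): the names come from the JSON filenames above.
import open_clip

print([name for name in open_clip.list_models() if name.startswith('EVA')])
# expected to include: EVA01-g-14, EVA01-g-14-plus, EVA02-B-16, EVA02-E-14,
# EVA02-E-14-plus, EVA02-L-14, EVA02-L-14-336

model, _, preprocess = open_clip.create_model_and_transforms('EVA02-L-14')
```
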
(diffs for the remaining changed files are not shown)
