diff --git a/src/open_clip/__init__.py b/src/open_clip/__init__.py
index 73c73f390..fdb1199b8 100644
--- a/src/open_clip/__init__.py
+++ b/src/open_clip/__init__.py
@@ -4,7 +4,7 @@
 from .factory import list_models, add_model_config, get_model_config, load_checkpoint
 from .loss import ClipLoss, DistillClipLoss, CoCaLoss
 from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
-    convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
+    convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype
 from .openai import load_openai_model, list_openai_models
 from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
     get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index 14011f934..ac8596eab 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -15,7 +15,8 @@
 from .coca_model import CoCa
 from .loss import ClipLoss, DistillClipLoss, CoCaLoss
 from .openai import load_openai_model
-from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
+from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
+    list_pretrained_tags_by_model, download_pretrained_from_hf
 from .transform import image_transform, AugmentationCfg
 from .tokenizer import HFTokenizer, tokenize
 
@@ -145,13 +146,8 @@ def create_model(
             model_name,
             precision=precision,
             device=device,
-            jit=jit,
             cache_dir=cache_dir,
         )
-
-        # to always output dict even if it is clip
-        if output_dict and hasattr(model, "output_dict"):
-            model.output_dict = True
     else:
         model_cfg = model_cfg or get_model_config(model_name)
         if model_cfg is not None:
@@ -172,13 +168,15 @@ def create_model(
             # override model config's image size
             model_cfg["vision_cfg"]["image_size"] = force_image_size
 
+        is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
         if pretrained_image:
-            if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
+            if is_timm_model:
                 # pretrained weight loading for timm models set via vision_cfg
                 model_cfg['vision_cfg']['timm_model_pretrained'] = True
             else:
                 assert False, 'pretrained image towers currently only supported for timm models'
 
+        # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
         cast_dtype = get_cast_dtype(precision)
         is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
         custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
@@ -193,6 +191,29 @@ def create_model(
         else:
             model = CLIP(**model_cfg, cast_dtype=cast_dtype)
 
+        if precision in ("fp16", "bf16"):
+            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
+            # manual mixed precision that matches original OpenAI behaviour
+            if is_timm_model:
+                # FIXME this is a bit janky, create timm based model in low-precision and
+                # then cast only LayerNormFp32 instances back to float32 so they don't break.
+                # Why? The convert_weights_to_lp fn only works with native models.
+                model.to(device=device, dtype=dtype)
+                from .transformer import LayerNormFp32
+                def _convert_ln(m):
+                    if isinstance(m, LayerNormFp32):
+                        m.weight.data = m.weight.data.to(torch.float32)
+                        m.bias.data = m.bias.data.to(torch.float32)
+                model.apply(_convert_ln)
+            else:
+                model.to(device=device)
+                convert_weights_to_lp(model, dtype=dtype)
+        elif precision in ("pure_fp16", "pure_bf16"):
+            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
+            model.to(device=device, dtype=dtype)
+        else:
+            model.to(device=device)
+
         pretrained_loaded = False
         if pretrained:
             checkpoint_path = ''
@@ -222,20 +243,15 @@ def create_model(
             raise RuntimeError(
                 f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
 
-        model.to(device=device)
-        if precision in ("fp16", "bf16"):
-            convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)
-
         # set image / mean metadata from pretrained_cfg if available, or use default
         model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
         model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
 
-        # to always output dict even if it is clip
-        if output_dict and hasattr(model, "output_dict"):
-            model.output_dict = True
+    if output_dict and hasattr(model, "output_dict"):
+        model.output_dict = True
 
-        if jit:
-            model = torch.jit.script(model)
+    if jit:
+        model = torch.jit.script(model)
 
     return model
diff --git a/src/open_clip/model.py b/src/open_clip/model.py
index 8a628f9bc..f85b68ba2 100644
--- a/src/open_clip/model.py
+++ b/src/open_clip/model.py
@@ -28,13 +28,16 @@ class CLIPVisionCfg:
     mlp_ratio: float = 4.0
     patch_size: int = 16
     image_size: Union[Tuple[int, int], int] = 224
 
+    ls_init_value: Optional[float] = None  # layer scale initial value
     patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
-    input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
+    input_patchnorm: bool = False  # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
     global_average_pool: bool = False  # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
-    attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
-    n_queries: int = 256 # n_queries for attentional pooler
-    attn_pooler_heads: int = 8 # n heads for attentional_pooling
+    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer
+    n_queries: int = 256  # n_queries for attentional pooler
+    attn_pooler_heads: int = 8  # n heads for attentional_pooling
+    output_tokens: bool = False
+
     timm_model_name: str = None  # a valid model name overrides layers, width, patch_size
     timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
     timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
@@ -42,7 +45,6 @@ class CLIPVisionCfg:
     timm_proj_bias: bool = False  # enable bias final projection
     timm_drop: float = 0.  # head dropout
     timm_drop_path: Optional[float] = None  # backbone stochastic depth
-    output_tokens: bool = False
 
 
 @dataclass
@@ -72,6 +74,15 @@ def get_cast_dtype(precision: str):
     return cast_dtype
 
 
+def get_input_dtype(precision: str):
+    input_dtype = None
+    if precision in ('bf16', 'pure_bf16'):
+        input_dtype = torch.bfloat16
+    elif precision in ('fp16', 'pure_fp16'):
+        input_dtype = torch.float16
+    return input_dtype
+
+
 def _build_vision_tower(
         embed_dim: int,
         vision_cfg: CLIPVisionCfg,
@@ -95,10 +106,10 @@ def _build_vision_tower(
             proj_bias=vision_cfg.timm_proj_bias,
             drop=vision_cfg.timm_drop,
             drop_path=vision_cfg.timm_drop_path,
+            patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
            embed_dim=embed_dim,
            image_size=vision_cfg.image_size,
         )
-        act_layer = nn.GELU  # so that text transformer doesn't use QuickGELU w/ timm models
     elif isinstance(vision_cfg.layers, (tuple, list)):
         vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
         visual = ModifiedResNet(
@@ -228,9 +239,13 @@ def encode_text(self, text, normalize: bool = False):
         x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
         return F.normalize(x, dim=-1) if normalize else x
 
-    def forward(self, image, text):
-        image_features = self.encode_image(image, normalize=True)
-        text_features = self.encode_text(text, normalize=True)
+    def forward(
+            self,
+            image: Optional[torch.Tensor] = None,
+            text: Optional[torch.Tensor] = None,
+    ):
+        image_features = self.encode_image(image, normalize=True) if image is not None else None
+        text_features = self.encode_text(text, normalize=True) if text is not None else None
         if self.output_dict:
             return {
                 "image_features": image_features,
@@ -280,9 +295,13 @@ def encode_text(self, text, normalize: bool = False):
         features = self.text(text)
         return F.normalize(features, dim=-1) if normalize else features
 
-    def forward(self, image, text):
-        image_features = self.encode_image(image, normalize=True)
-        text_features = self.encode_text(text, normalize=True)
+    def forward(
+            self,
+            image: Optional[torch.Tensor] = None,
+            text: Optional[torch.Tensor] = None,
+    ):
+        image_features = self.encode_image(image, normalize=True) if image is not None else None
+        text_features = self.encode_text(text, normalize=True) if text is not None else None
         if self.output_dict:
             return {
                 "image_features": image_features,
@@ -307,11 +326,17 @@ def _convert_weights(l):
                 if tensor is not None:
                     tensor.data = tensor.data.to(dtype)
 
-        for name in ["text_projection", "proj"]:
-            if hasattr(l, name):
-                attr = getattr(l, name)
-                if attr is not None:
-                    attr.data = attr.data.to(dtype)
+        if isinstance(l, (CLIP, TextTransformer)):
+            # convert text nn.Parameter projections
+            attr = getattr(l, "text_projection", None)
+            if attr is not None:
+                attr.data = attr.data.to(dtype)
+
+        if isinstance(l, VisionTransformer):
+            # convert vision nn.Parameter projections
+            attr = getattr(l, "proj", None)
+            if attr is not None:
+                attr.data = attr.data.to(dtype)
 
     model.apply(_convert_weights)
diff --git a/src/open_clip/model_configs/EVA01-g-14-plus.json b/src/open_clip/model_configs/EVA01-g-14-plus.json
new file mode 100644
index 000000000..73f46a71e
--- /dev/null
+++ b/src/open_clip/model_configs/EVA01-g-14-plus.json
@@ -0,0 +1,18 @@
+{
+    "embed_dim": 1024,
+    "vision_cfg": {
+        "image_size": 224,
+        "timm_model_name": "eva_giant_patch14_224",
+        "timm_model_pretrained": false,
+        "timm_pool": "token",
+        "timm_proj": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 1024,
+        "heads": 16,
"layers": 24 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/EVA01-g-14.json b/src/open_clip/model_configs/EVA01-g-14.json new file mode 100644 index 000000000..9d0e80f29 --- /dev/null +++ b/src/open_clip/model_configs/EVA01-g-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva_giant_patch14_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/EVA02-B-16.json b/src/open_clip/model_configs/EVA02-B-16.json new file mode 100644 index 000000000..3f9235728 --- /dev/null +++ b/src/open_clip/model_configs/EVA02-B-16.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_base_patch16_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/EVA02-E-14-plus.json b/src/open_clip/model_configs/EVA02-E-14-plus.json new file mode 100644 index 000000000..e250c2a40 --- /dev/null +++ b/src/open_clip/model_configs/EVA02-E-14-plus.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_enormous_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/EVA02-E-14.json b/src/open_clip/model_configs/EVA02-E-14.json new file mode 100644 index 000000000..4b6648e25 --- /dev/null +++ b/src/open_clip/model_configs/EVA02-E-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_enormous_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/EVA02-L-14-336.json b/src/open_clip/model_configs/EVA02-L-14-336.json new file mode 100644 index 000000000..2bb07f3c0 --- /dev/null +++ b/src/open_clip/model_configs/EVA02-L-14-336.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "timm_model_name": "eva02_large_patch14_clip_336", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/EVA02-L-14.json b/src/open_clip/model_configs/EVA02-L-14.json new file mode 100644 index 000000000..b4c7f377b --- /dev/null +++ b/src/open_clip/model_configs/EVA02-L-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_large_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, 
+ "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/openai.py b/src/open_clip/openai.py index cc4e13e87..6c2c02352 100644 --- a/src/open_clip/openai.py +++ b/src/open_clip/openai.py @@ -9,6 +9,7 @@ import torch +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url @@ -24,7 +25,6 @@ def load_openai_model( name: str, precision: Optional[str] = None, device: Optional[Union[str, torch.device]] = None, - jit: bool = True, cache_dir: Optional[str] = None, ): """Load a CLIP model @@ -37,8 +37,6 @@ def load_openai_model( Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. device : Union[str, torch.device] The device to put the loaded model - jit : bool - Whether to load the optimized JIT model (default) or more hackable non-JIT model. cache_dir : Optional[str] The directory to cache the downloaded model weights @@ -63,82 +61,30 @@ def load_openai_model( try: # loading JIT archive - model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + model = torch.jit.load(model_path, map_location="cpu").eval() state_dict = None except RuntimeError: # loading saved state dict - if jit: - warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") - jit = False state_dict = torch.load(model_path, map_location="cpu") - if not jit: - # Build a non-jit model from the OpenAI jitted model state dict - cast_dtype = get_cast_dtype(precision) - try: - model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) - except KeyError: - sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} - model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) - - # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use - model = model.to(device) - if precision.startswith('amp') or precision == 'fp32': - model.float() - elif precision == 'bf16': - convert_weights_to_lp(model, dtype=torch.bfloat16) - - return model - - # patch the device names - device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) - device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] - - def patch_device(module): - try: - graphs = [module.graph] if hasattr(module, "graph") else [] - except RuntimeError: - graphs = [] - - if hasattr(module, "forward1"): - graphs.append(module.forward1.graph) - - for graph in graphs: - for node in graph.findAllNodes("prim::Constant"): - if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): - node.copyAttributes(device_node) - - model.apply(patch_device) - patch_device(model.encode_image) - patch_device(model.encode_text) - - # patch dtype to float32 (typically for CPU) - if precision == 'fp32': - float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) - float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] - float_node = float_input.node() - - def patch_float(module): - try: - graphs = [module.graph] if hasattr(module, "graph") else [] - except RuntimeError: - graphs = [] - - if hasattr(module, "forward1"): - graphs.append(module.forward1.graph) - - for graph in graphs: - for node in 
graph.findAllNodes("aten::to"): - inputs = list(node.inputs()) - for i in [1, 2]: # dtype can be the second or third argument to aten::to() - if inputs[i].node()["value"] == 5: - inputs[i].node().copyAttributes(float_node) - - model.apply(patch_float) - patch_float(model.encode_image) - patch_float(model.encode_text) + # Build a non-jit model from the OpenAI jitted model state dict + cast_dtype = get_cast_dtype(precision) + try: + model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) + + # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use + model = model.to(device) + # FIXME support pure fp16/bf16 precision modes + if precision != 'fp16': model.float() + if precision == 'bf16': + # for bf16, convert back to low-precision + convert_weights_to_lp(model, dtype=torch.bfloat16) - # ensure image_size attr available at consistent location for both jit and non-jit - model.visual.image_size = model.input_resolution.item() + # add mean / std attributes for consistency with OpenCLIP models + model.visual.image_mean = OPENAI_DATASET_MEAN + model.visual.image_std = OPENAI_DATASET_STD return model diff --git a/src/open_clip/pretrained.py b/src/open_clip/pretrained.py index a747933a6..1465a2325 100644 --- a/src/open_clip/pretrained.py +++ b/src/open_clip/pretrained.py @@ -118,12 +118,6 @@ def _pcfg(url='', hf_hub='', mean=None, std=None): "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), laion400m_e32=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), - # laion400m_32k=_pcfg( - # url="", - # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), - # laion400m_64k=_pcfg( - # url="", - # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), # DataComp-L models datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'), @@ -258,6 +252,34 @@ def _pcfg(url='', hf_hub='', mean=None, std=None): "convnext_xxlarge": _convnext_xxlarge, "coca_ViT-B-32": _coca_VITB32, "coca_ViT-L-14": _coca_VITL14, + "EVA01-g-14": dict( + # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt + laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'), + ), + "EVA01-g-14-plus": dict( + # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt + merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'), + ), + "EVA02-B-16": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt + merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'), + ), + "EVA02-L-14": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt + merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'), + ), + "EVA02-L-14-336": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt + merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'), + ), + "EVA02-E-14": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt + laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'), + ), + "EVA02-E-14-plus": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt + 
laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'), + ) } diff --git a/src/open_clip/push_to_hf_hub.py b/src/open_clip/push_to_hf_hub.py index 188fc4ea5..464dd41aa 100644 --- a/src/open_clip/push_to_hf_hub.py +++ b/src/open_clip/push_to_hf_hub.py @@ -1,8 +1,9 @@ import argparse import json +import os from pathlib import Path from tempfile import TemporaryDirectory -from typing import Optional, Tuple +from typing import Literal, Optional, Tuple, Union import torch @@ -14,15 +15,26 @@ hf_hub_url, repo_type_and_id_from_hf_id, upload_folder, + list_repo_files, ) from huggingface_hub.utils import EntryNotFoundError _has_hf_hub = True except ImportError: _has_hf_hub = False +try: + import safetensors.torch + _has_safetensors = True +except ImportError: + _has_safetensors = False + from .factory import create_model_from_pretrained, get_model_config, get_tokenizer from .tokenizer import HFTokenizer +# Default name for a weights file hosted on the Huggingface Hub. +HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl +HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version +HF_CONFIG_NAME = 'open_clip_config.json' def save_config_for_hf( model, @@ -47,14 +59,21 @@ def save_for_hf( tokenizer: HFTokenizer, model_config: dict, save_directory: str, - weights_filename='open_clip_pytorch_model.bin', - config_filename='open_clip_config.json', + safe_serialization: Union[bool, Literal["both"]] = False, + skip_weights : bool = False, ): + config_filename = HF_CONFIG_NAME + save_directory = Path(save_directory) save_directory.mkdir(exist_ok=True, parents=True) - weights_path = save_directory / weights_filename - torch.save(model.state_dict(), weights_path) + if not skip_weights: + tensors = model.state_dict() + if safe_serialization is True or safe_serialization == "both": + assert _has_safetensors, "`pip install safetensors` to use .safetensors" + safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME) + if safe_serialization is False or safe_serialization == "both": + torch.save(tensors, save_directory / HF_WEIGHTS_NAME) tokenizer.save_pretrained(save_directory) @@ -73,6 +92,7 @@ def push_to_hf_hub( private: bool = False, create_pr: bool = False, model_card: Optional[dict] = None, + safe_serialization: Union[bool, Literal["both"]] = False, ): if not isinstance(tokenizer, HFTokenizer): # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 @@ -86,7 +106,15 @@ def push_to_hf_hub( _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) repo_id = f"{repo_owner}/{repo_name}" - # Check if README file already exist in repo + # Check if repo already exists and determine what needs updating + repo_exists = False + repo_files = {} + try: + repo_files = set(list_repo_files(repo_id)) + repo_exists = True + except Exception as e: + print('Repo does not exist', e) + try: get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) has_readme = True @@ -101,6 +129,7 @@ def push_to_hf_hub( tokenizer=tokenizer, model_config=model_config, save_directory=tmpdir, + safe_serialization=safe_serialization, ) # Add readme if it does not exist @@ -125,6 +154,7 @@ def push_pretrained_to_hf_hub( model_name, pretrained: str, repo_id: str, + precision: str = 'fp32', image_mean: Optional[Tuple[float, ...]] = None, image_std: Optional[Tuple[float, ...]] = None, commit_message: str = 'Add model', @@ -137,6 +167,7 @@ def push_pretrained_to_hf_hub( model, preprocess_eval = 
create_model_from_pretrained( model_name, pretrained=pretrained, + precision=precision, image_mean=image_mean, image_std=image_std, ) @@ -157,13 +188,15 @@ def push_pretrained_to_hf_hub( private=private, create_pr=create_pr, model_card=model_card, + safe_serialization='both', ) def generate_readme(model_card: dict, model_name: str): readme_text = "---\n" - readme_text += "tags:\n- zero-shot-image-classification\n- clip\n" + readme_text += "tags:\n- clip\n" readme_text += "library_name: open_clip\n" + readme_text += "pipeline_tag: zero-shot-image-classification\n" readme_text += f"license: {model_card.get('license', 'mit')}\n" if 'details' in model_card and 'Dataset' in model_card['details']: readme_text += 'datasets:\n' @@ -220,6 +253,9 @@ def generate_readme(model_card: dict, model_name: str): "--repo-id", type=str, help="Destination HF Hub repo-id ie 'organization/model_id'.", ) + parser.add_argument( + "--precision", type=str, default='fp32', + ) parser.add_argument( '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', help='Override default image mean value of dataset') @@ -236,6 +272,7 @@ def generate_readme(model_card: dict, model_name: str): args.model, args.pretrained, args.repo_id, + precision=args.precision, image_mean=args.image_mean, # override image mean/std if trained w/ non defaults image_std=args.image_std, ) diff --git a/src/open_clip/timm_model.py b/src/open_clip/timm_model.py index dc71a693f..3d3f595d6 100644 --- a/src/open_clip/timm_model.py +++ b/src/open_clip/timm_model.py @@ -27,7 +27,6 @@ class TimmModel(nn.Module): """ timm model adapter - # FIXME this adapter is a work in progress, may change in ways that break weight compat """ def __init__( @@ -40,38 +39,59 @@ def __init__( proj_bias=False, drop=0., drop_path=None, + patch_drop=None, pretrained=False, ): super().__init__() if timm is None: raise RuntimeError("Please `pip install timm` to use timm models.") - self.image_size = to_2tuple(image_size) + + # setup kwargs that may not be common across all models timm_kwargs = {} if drop_path is not None: timm_kwargs['drop_path_rate'] = drop_path - self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs) - feat_size = self.trunk.default_cfg.get('pool_size', None) - feature_ndim = 1 if not feat_size else 2 - if pool in ('abs_attn', 'rot_attn'): - assert feature_ndim == 2 - # if attn pooling used, remove both classifier and default pool - self.trunk.reset_classifier(0, global_pool='') + if patch_drop is not None: + timm_kwargs['patch_drop_rate'] = patch_drop + + custom_pool = pool in ('abs_attn', 'rot_attn') + if not proj and not custom_pool: + # use network classifier head as projection if no proj specified and no custom pooling used + self.trunk = timm.create_model( + model_name, + num_classes=embed_dim, + global_pool=pool, + pretrained=pretrained, + **timm_kwargs, + ) + prev_chs = embed_dim else: - # reset global pool if pool config set, otherwise leave as network default - reset_kwargs = dict(global_pool=pool) if pool else {} - self.trunk.reset_classifier(0, **reset_kwargs) - prev_chs = self.trunk.num_features + self.trunk = timm.create_model( + model_name, + pretrained=pretrained, + **timm_kwargs, + ) + feat_size = self.trunk.default_cfg.get('pool_size', None) + feature_ndim = 1 if not feat_size else 2 + if custom_pool: + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool='') + else: + # reset global pool if pool config set, otherwise leave as 
network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features head_layers = OrderedDict() + + # Add custom pooling to head if pool == 'abs_attn': head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) prev_chs = embed_dim elif pool == 'rot_attn': head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) prev_chs = embed_dim - else: - assert proj, 'projection layer needed if non-attention pooling is used.' # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used if proj == 'linear': @@ -79,6 +99,8 @@ def __init__( head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) elif proj == 'mlp': head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) + else: + assert not proj, f'Unknown projection type {proj}.' self.head = nn.Sequential(head_layers) diff --git a/src/training/main.py b/src/training/main.py index 5f7db9f41..2929d0121 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -396,6 +396,10 @@ def main(args): wandb.save(params_file) logging.debug('Finished loading wandb.') + if args.torchcompile: + logging.info('Compiling model...') + model = torch.compile(model) + if 'train' not in data: # If using int8, convert to inference mode. if args.use_bnb_linear is not None: diff --git a/src/training/params.py b/src/training/params.py index 816a8e1cb..31c841791 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -188,7 +188,7 @@ def parse_args(args): ) parser.add_argument( "--precision", - choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"], + choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "pure_bf16", "pure_fp16", "fp32"], default="amp", help="Floating point precision." 
     )
@@ -281,6 +281,12 @@ def parse_args(args):
         action='store_true',
         help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",
     )
+    parser.add_argument(
+        "--torchcompile",
+        default=False,
+        action='store_true',
+        help="torch.compile() the model, requires pytorch 2.0 or later.",
+    )
     parser.add_argument(
         "--trace",
         default=False,
diff --git a/src/training/train.py b/src/training/train.py
index e0a140f9c..e93d9d370 100644
--- a/src/training/train.py
+++ b/src/training/train.py
@@ -14,7 +14,7 @@
 except ImportError:
     wandb = None
 
-from open_clip import get_cast_dtype, CLIP, CustomTextCLIP
+from open_clip import get_input_dtype, CLIP, CustomTextCLIP
 from .distributed import is_master
 from .zero_shot import zero_shot_eval
 from .precision import get_autocast
@@ -62,7 +62,7 @@ def backward(total_loss, scaler):
 def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=None):
     device = torch.device(args.device)
     autocast = get_autocast(args.precision)
-    cast_dtype = get_cast_dtype(args.precision)
+    input_dtype = get_input_dtype(args.precision)
 
     model.train()
@@ -89,7 +89,7 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist
             scheduler(step)
 
         images, texts = batch
-        images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
+        images = images.to(device=device, dtype=input_dtype, non_blocking=True)
         texts = texts.to(device=device, non_blocking=True)
 
         data_time_m.update(time.time() - end)
@@ -244,7 +244,7 @@ def evaluate(model, data, epoch, args, tb_writer=None):
     metrics.update(zero_shot_metrics)
 
     autocast = get_autocast(args.precision)
-    cast_dtype = get_cast_dtype(args.precision)
+    input_dtype = get_input_dtype(args.precision)
 
     if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)):
         dataloader = data['val'].dataloader
@@ -259,7 +259,7 @@ def evaluate(model, data, epoch, args, tb_writer=None):
         with torch.no_grad():
             for i, batch in enumerate(dataloader):
                 images, texts = batch
-                images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
+                images = images.to(device=device, dtype=input_dtype, non_blocking=True)
                 texts = texts.to(device=device, non_blocking=True)
 
                 with autocast():
diff --git a/src/training/zero_shot.py b/src/training/zero_shot.py
index 045ec3e10..8265b424b 100644
--- a/src/training/zero_shot.py
+++ b/src/training/zero_shot.py
@@ -4,7 +4,7 @@
 import torch.nn.functional as F
 from tqdm import tqdm
 
-from open_clip import get_cast_dtype, get_tokenizer, build_zero_shot_classifier, \
+from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \
     IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES
 from .precision import get_autocast
 
@@ -17,20 +17,18 @@ def accuracy(output, target, topk=(1,)):
 
 def run(model, classifier, dataloader, args):
     autocast = get_autocast(args.precision)
-    cast_dtype = get_cast_dtype(args.precision)
+    input_dtype = get_input_dtype(args.precision)
 
     with torch.no_grad():
         top1, top5, n = 0., 0., 0.
         for images, target in tqdm(dataloader, unit_scale=args.batch_size):
-            images = images.to(args.device)
-            if cast_dtype is not None:
-                images = images.to(dtype=cast_dtype)
+            images = images.to(device=args.device, dtype=input_dtype)
             target = target.to(args.device)
 
             with autocast():
                 # predict
-                image_features = model.encode_image(images)
-                image_features = F.normalize(image_features, dim=-1)
+                output = model(image=images)
+                image_features = output['image_features'] if isinstance(output, dict) else output[0]
                 logits = 100. * image_features @ classifier
 
             # measure accuracy
diff --git a/tests/util_test.py b/tests/util_test.py
index d09b09a2d..53380ddb1 100644
--- a/tests/util_test.py
+++ b/tests/util_test.py
@@ -209,10 +209,9 @@ def __init__(self, model, model_name, output_dict=True) -> None:
 
     def forward(self, image, text):
         x = self.model(image, text)
-        if self.output_dict:
-            out = self.head(x["image_features"])
-        else:
-            out = self.head(x[0])
+        x = x['image_features'] if self.output_dict else x[0]
+        assert x is not None  # remove Optional[], type refinement for torchscript
+        out = self.head(x)
         return {"test_output": out}
 
 def main(args):
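
Usage note (not part of the patch): a minimal sketch of how the pieces above fit together — the new EVA02-B-16 config with its merged2b_s8b_b131k pretrained tag, the pure_bf16 precision mode, get_input_dtype(), and the optional keyword-style forward(). It assumes open_clip with this patch applied plus torch/timm installed, a device/build where bfloat16 is supported, and network access to download the weights; the random input tensor is illustrative only.

import torch
import open_clip
from open_clip import get_input_dtype

# 'pure_bf16' casts the whole model to bfloat16 at creation time (see the
# precision handling added to create_model() above); weights are fetched
# from the EVA02-B-16 / merged2b_s8b_b131k entry added to pretrained.py.
model, _, preprocess = open_clip.create_model_and_transforms(
    'EVA02-B-16', pretrained='merged2b_s8b_b131k', precision='pure_bf16')
tokenizer = open_clip.get_tokenizer('EVA02-B-16')

# get_input_dtype() mirrors what the updated train/eval/zero-shot loops do:
# inputs are cast to bfloat16/float16 for the bf16/fp16 and pure_* modes.
input_dtype = get_input_dtype('pure_bf16')   # torch.bfloat16
images = torch.rand(2, 3, 224, 224).to(dtype=input_dtype)  # illustrative input
texts = tokenizer(['a diagram', 'a dog'])

with torch.no_grad():
    # forward() now accepts optional image/text keyword arguments
    out = model(image=images, text=texts)
    image_features = out['image_features'] if isinstance(out, dict) else out[0]
    print(image_features.shape, image_features.dtype)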