Add/remove some model configs. Add profiler. Add support for layer_scale. No bias on timm model projection by default. Fix null pretrained arg handling.
rwightman committed Nov 3, 2022
1 parent c4190d2 commit 90a890f
Showing 19 changed files with 340 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/open_clip/__init__.py
@@ -1,6 +1,6 @@
 from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
 from .factory import create_model, create_model_and_transforms, create_model_from_pretrained
-from .factory import list_models, add_model_config, load_checkpoint
+from .factory import list_models, add_model_config, get_model_config, load_checkpoint
 from .loss import ClipLoss
 from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\
     convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
15 changes: 11 additions & 4 deletions src/open_clip/factory.py
@@ -62,6 +62,13 @@ def add_model_config(path):
     _rescan_model_configs()


+def get_model_config(model_name):
+    if model_name in _MODEL_CONFIGS:
+        return deepcopy(_MODEL_CONFIGS[model_name])
+    else:
+        return None
+
+
 def load_state_dict(checkpoint_path: str, map_location='cpu'):
     checkpoint = torch.load(checkpoint_path, map_location=map_location)
     if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
@@ -98,7 +105,7 @@ def create_model(
     if isinstance(device, str):
         device = torch.device(device)

-    if pretrained.lower() == 'openai':
+    if pretrained and pretrained.lower() == 'openai':
         logging.info(f'Loading pretrained {model_name} from OpenAI.')
         model = load_openai_model(
             model_name,
@@ -108,9 +115,9 @@
             cache_dir=cache_dir,
         )
     else:
-        if model_name in _MODEL_CONFIGS:
-            logging.info(f'Loading {model_name} model config.')
-            model_cfg = deepcopy(_MODEL_CONFIGS[model_name])
+        model_cfg = get_model_config(model_name)
+        if model_cfg is not None:
+            logging.info(f'Loaded {model_name} model config.')
         else:
             logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
             raise RuntimeError(f'Model config for {model_name} not found.')
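
Note: together, the factory changes make two failure modes graceful: an unknown model name (via the new get_model_config helper) and a null pretrained argument. A minimal sketch of the resulting behavior, an illustration assuming this commit of open_clip is installed, not part of the diff:

import open_clip

# get_model_config returns a deep copy of a registered config dict,
# or None for an unknown name, so callers can mutate the result safely.
cfg = open_clip.get_model_config('ViT-M-16')
assert cfg is not None and cfg['embed_dim'] == 512
assert open_clip.get_model_config('no-such-model') is None

# pretrained=None no longer crashes on pretrained.lower(); the model is
# built from the config with randomly initialized weights instead.
model = open_clip.create_model('ViT-M-16', pretrained=None)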
7 changes: 6 additions & 1 deletion src/open_clip/model.py
@@ -27,10 +27,12 @@ class CLIPVisionCfg:
     mlp_ratio: float = 4.0
     patch_size: int = 16
     image_size: Union[Tuple[int, int], int] = 224
+    ls_init_value: Optional[float] = None  # layer scale initial value
     timm_model_name: str = None  # a valid model name overrides layers, width, patch_size
     timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
     timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
     timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
+    timm_proj_bias: bool = False  # enable bias final projection


 @dataclass
@@ -40,6 +42,7 @@ class CLIPTextCfg:
     width: int = 512
     heads: int = 8
     layers: int = 12
+    ls_init_value: Optional[float] = None  # layer scale initial value


 def get_cast_dtype(precision: str):
@@ -71,6 +74,7 @@ def _build_vision_tower(
             pretrained=vision_cfg.timm_model_pretrained,
             pool=vision_cfg.timm_pool,
             proj=vision_cfg.timm_proj,
+            proj_bias=vision_cfg.timm_proj_bias,
             embed_dim=embed_dim,
             image_size=vision_cfg.image_size
         )
@@ -94,6 +98,7 @@ def _build_vision_tower(
             layers=vision_cfg.layers,
             heads=vision_heads,
             mlp_ratio=vision_cfg.mlp_ratio,
+            ls_init_value=vision_cfg.ls_init_value,
             output_dim=embed_dim,
             act_layer=act_layer,
             norm_layer=norm_layer,
@@ -120,6 +125,7 @@ def _build_text_tower(
         width=text_cfg.width,
         heads=text_cfg.heads,
         layers=text_cfg.layers,
+        ls_init_value=text_cfg.ls_init_value,
         output_dim=embed_dim,
         act_layer=act_layer,
         norm_layer=norm_layer,
@@ -176,7 +182,6 @@ def encode_text(self, text, normalize: bool = False):
         x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
         # take features from the eot embedding (eot_token is the highest number in each sequence)
         x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
-
         return F.normalize(x, dim=-1) if normalize else x

     def forward(self, image, text):
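
Note: the new ls_init_value field threads layer scale (the learnable per-channel residual-branch scaling introduced by CaiT and also used in ConvNeXt) through both towers; None leaves it disabled. A hedged sketch of opting in via the dataclasses above, with illustrative values not taken from the diff:

from open_clip.model import CLIPVisionCfg, CLIPTextCfg

# None (the default) disables layer scale; a small float such as 1e-4
# enables it and sets the initial value of each residual branch's gain.
vision_cfg = CLIPVisionCfg(layers=12, width=512, patch_size=16, ls_init_value=1e-4)
text_cfg = CLIPTextCfg(width=384, heads=6, layers=12, ls_init_value=1e-4)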
17 changes: 17 additions & 0 deletions src/open_clip/model_configs/ViT-M-16-alt.json
@@ -0,0 +1,17 @@
+{
+  "embed_dim": 384,
+  "vision_cfg": {
+    "image_size": 224,
+    "layers": 12,
+    "width": 512,
+    "patch_size": 16,
+    "ls_init_value": 1e-4
+  },
+  "text_cfg": {
+    "context_length": 77,
+    "vocab_size": 49408,
+    "width": 384,
+    "heads": 6,
+    "layers": 12
+  }
+}
16 changes: 16 additions & 0 deletions src/open_clip/model_configs/ViT-M-16.json
@@ -0,0 +1,16 @@
+{
+  "embed_dim": 512,
+  "vision_cfg": {
+    "image_size": 224,
+    "layers": 12,
+    "width": 512,
+    "patch_size": 16
+  },
+  "text_cfg": {
+    "context_length": 77,
+    "vocab_size": 49408,
+    "width": 512,
+    "heads": 8,
+    "layers": 12
+  }
+}
16 changes: 16 additions & 0 deletions src/open_clip/model_configs/ViT-M-32-alt.json
@@ -0,0 +1,16 @@
+{
+  "embed_dim": 384,
+  "vision_cfg": {
+    "image_size": 224,
+    "layers": 12,
+    "width": 512,
+    "patch_size": 32
+  },
+  "text_cfg": {
+    "context_length": 77,
+    "vocab_size": 49408,
+    "width": 384,
+    "heads": 6,
+    "layers": 12
+  }
+}
16 changes: 16 additions & 0 deletions src/open_clip/model_configs/ViT-M-32.json
@@ -0,0 +1,16 @@
+{
+  "embed_dim": 512,
+  "vision_cfg": {
+    "image_size": 224,
+    "layers": 12,
+    "width": 512,
+    "patch_size": 32
+  },
+  "text_cfg": {
+    "context_length": 77,
+    "vocab_size": 49408,
+    "width": 512,
+    "heads": 8,
+    "layers": 12
+  }
+}
16 changes: 16 additions & 0 deletions src/open_clip/model_configs/ViT-S-16-alt.json
@@ -0,0 +1,16 @@
+{
+  "embed_dim": 256,
+  "vision_cfg": {
+    "image_size": 224,
+    "layers": 12,
+    "width": 384,
+    "patch_size": 16
+  },
+  "text_cfg": {
+    "context_length": 77,
+    "vocab_size": 49408,
+    "width": 256,
+    "heads": 4,
+    "layers": 10
+  }
+}
16 changes: 16 additions & 0 deletions src/open_clip/model_configs/ViT-S-16.json
@@ -0,0 +1,16 @@
+{
+  "embed_dim": 384,
+  "vision_cfg": {
+    "image_size": 224,
+    "layers": 12,
+    "width": 384,
+    "patch_size": 16
+  },
+  "text_cfg": {
+    "context_length": 77,
+    "vocab_size": 49408,
+    "width": 384,
+    "heads": 6,
+    "layers": 12
+  }
+}
16 changes: 16 additions & 0 deletions src/open_clip/model_configs/ViT-S-32.json
@@ -0,0 +1,16 @@
+{
+  "embed_dim": 384,
+  "vision_cfg": {
+    "image_size": 224,
+    "layers": 12,
+    "width": 384,
+    "patch_size": 32
+  },
+  "text_cfg": {
+    "context_length": 77,
+    "vocab_size": 49408,
+    "width": 384,
+    "heads": 6,
+    "layers": 12
+  }
+}
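
Note: the seven new S/M configs above follow the existing schema, scaling width, heads, and embed_dim down together (the -alt variants shrink embed_dim and the text tower further). A quick usage sketch, assuming this commit is installed; output is illustrative:

import open_clip

# new names are picked up by the config rescan, so they appear in list_models()
print([m for m in open_clip.list_models() if m.startswith(('ViT-S', 'ViT-M'))])

# create_model_and_transforms returns (model, train_preprocess, val_preprocess)
model, _, preprocess = open_clip.create_model_and_transforms('ViT-S-32')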
@@ -1,7 +1,7 @@
 {
   "embed_dim": 512,
   "vision_cfg": {
-    "timm_model_name": "vit_base_patch16_224",
+    "timm_model_name": "convnext_base",
     "timm_model_pretrained": false,
     "timm_pool": "",
     "timm_proj": "linear",
@@ -1,17 +1,17 @@
 {
-  "embed_dim": 1024,
+  "embed_dim": 768,
   "vision_cfg": {
-    "timm_model_name": "resnet50d",
+    "timm_model_name": "convnext_large",
     "timm_model_pretrained": false,
-    "timm_pool": "abs_attn",
-    "timm_proj": "",
+    "timm_pool": "",
+    "timm_proj": "linear",
     "image_size": 224
   },
   "text_cfg": {
     "context_length": 77,
     "vocab_size": 49408,
-    "width": 512,
-    "heads": 8,
+    "width": 768,
+    "heads": 12,
     "layers": 12
   }
 }
@@ -1,17 +1,17 @@
 {
   "embed_dim": 1024,
   "vision_cfg": {
-    "timm_model_name": "resnetblur50",
+    "timm_model_name": "convnext_xlarge",
     "timm_model_pretrained": false,
-    "timm_pool": "abs_attn",
-    "timm_proj": "",
+    "timm_pool": "",
+    "timm_proj": "linear",
     "image_size": 224
   },
   "text_cfg": {
     "context_length": 77,
     "vocab_size": 49408,
-    "width": 512,
-    "heads": 8,
-    "layers": 12
+    "width": 1024,
+    "heads": 16,
+    "layers": 16
   }
 }
@@ -1,7 +1,7 @@
 {
   "embed_dim": 512,
   "vision_cfg": {
-    "timm_model_name": "vit_small_patch16_224",
+    "timm_model_name": "vit_medium_patch16_gap_256",
     "timm_model_pretrained": false,
     "timm_pool": "",
     "timm_proj": "linear",
@@ -1,7 +1,7 @@
 {
   "embed_dim": 512,
   "vision_cfg": {
-    "timm_model_name": "vit_base_patch32_224",
+    "timm_model_name": "vit_relpos_medium_patch16_cls_224",
     "timm_model_pretrained": false,
     "timm_pool": "",
     "timm_proj": "linear",
5 changes: 3 additions & 2 deletions src/open_clip/timm_model.py
@@ -29,6 +29,7 @@ def __init__(
             image_size=224,
             pool='avg',
             proj='linear',
+            proj_bias=False,
             drop=0.,
             pretrained=False):
         super().__init__()
@@ -62,9 +63,9 @@ def __init__(
         # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
         if proj == 'linear':
            head_layers['drop'] = nn.Dropout(drop)
-            head_layers['proj'] = nn.Linear(prev_chs, embed_dim)
+            head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
         elif proj == 'mlp':
-            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
+            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))

         self.head = nn.Sequential(head_layers)
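
Note: with proj_bias defaulting to False, a timm vision tower's final linear projection is now created without a bias term, and an 'mlp' projection keeps a bias only on its hidden layer (bias=(True, proj_bias)). A sketch of the effect; this assumes timm is installed, and no weights are downloaded since pretrained defaults to False:

from open_clip.timm_model import TimmModel

# default proj_bias=False yields nn.Linear(prev_chs, embed_dim, bias=False)
tower = TimmModel('convnext_base', embed_dim=512, pool='', proj='linear')
assert tower.head.proj.bias is None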