Text Tower Refactor #185

Merged
28 commits merged on Nov 4, 2022
Commits
4ce3926
POC interface for create fn with mandatory pretrained arg and inferen…
rwightman Sep 17, 2022
1c8844b
Text Tower Support
iejMac Sep 25, 2022
1e348c8
add model.py changes
iejMac Sep 25, 2022
213e218
modified_resnet.py
iejMac Sep 25, 2022
b4df97a
update visual transformer to main version
iejMac Sep 25, 2022
9cf309e
init_params
iejMac Sep 25, 2022
f454a18
update to main
iejMac Sep 25, 2022
f474a2a
import math
iejMac Sep 25, 2022
65a14a9
remove print
iejMac Sep 25, 2022
b3a3718
convert state dict
iejMac Sep 25, 2022
9955855
comment above conver_state_dict
iejMac Sep 25, 2022
2415600
update to main
iejMac Sep 25, 2022
dc9fc01
comment in factory
Sep 25, 2022
15daeb7
Add a note on how to do smaller epochs. Fix #135
rom1504 Sep 24, 2022
a618008
Recommend img2dataset in readme, fix #148
rom1504 Sep 24, 2022
75a009b
filter examples with no images, in addition to those with no captions…
mehdidc Sep 26, 2022
aa71712
Add jit=True to check we don't break torchscript
rom1504 Sep 26, 2022
c849dee
Test both jit True and False
rom1504 Sep 26, 2022
044c30d
Merge branch 'text_tower' of https://github.com/iejMac/open_clip into…
rwightman Sep 29, 2022
2c3d86e
Refactor custom text tower into separate model w/ code re-use reduced.
rwightman Sep 29, 2022
d79c5c2
Remove text tower conversion for OpenAI weight loads
rwightman Oct 31, 2022
3e76bca
Merge remote-tracking branch 'origin/from_pretrained' into text_tower…
rwightman Oct 31, 2022
c4190d2
Fixing float16/bfloat16 (pure) modes, adding flag to force use of cus…
rwightman Nov 1, 2022
90a890f
Add/remove some model configs. Add profiler. Add support for layer_sc…
rwightman Nov 3, 2022
0fd8534
Remove save that was for openai checkpoint tests
rwightman Nov 3, 2022
7e5546d
Fix grad checkpoing for timm models (bug). Change grad clipping arg n…
rwightman Nov 4, 2022
e5a92e2
Tweak profile script, fix a bug for resnet models, add G/e/S-32-alt c…
rwightman Nov 4, 2022
c119c01
Bump version to 2.1.0
rwightman Nov 4, 2022
2 changes: 1 addition & 1 deletion src/open_clip/__init__.py
@@ -1,6 +1,6 @@
 from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
 from .factory import create_model, create_model_and_transforms, create_model_from_pretrained
-from .factory import list_models, add_model_config, load_checkpoint
+from .factory import list_models, add_model_config, get_model_config, load_checkpoint
 from .loss import ClipLoss
 from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\
     convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
15 changes: 11 additions & 4 deletions src/open_clip/factory.py
@@ -62,6 +62,13 @@ def add_model_config(path):
     _rescan_model_configs()


+def get_model_config(model_name):
+    if model_name in _MODEL_CONFIGS:
+        return deepcopy(_MODEL_CONFIGS[model_name])
+    else:
+        return None
+
+
 def load_state_dict(checkpoint_path: str, map_location='cpu'):
     checkpoint = torch.load(checkpoint_path, map_location=map_location)
     if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
@@ -98,7 +105,7 @@ def create_model(
     if isinstance(device, str):
         device = torch.device(device)

-    if pretrained.lower() == 'openai':
+    if pretrained and pretrained.lower() == 'openai':
         logging.info(f'Loading pretrained {model_name} from OpenAI.')
         model = load_openai_model(
             model_name,
@@ -108,9 +115,9 @@
             cache_dir=cache_dir,
         )
     else:
-        if model_name in _MODEL_CONFIGS:
-            logging.info(f'Loading {model_name} model config.')
-            model_cfg = deepcopy(_MODEL_CONFIGS[model_name])
+        model_cfg = get_model_config(model_name)
+        if model_cfg is not None:
+            logging.info(f'Loaded {model_name} model config.')
         else:
             logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
             raise RuntimeError(f'Model config for {model_name} not found.')
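For reference, a minimal sketch of how the refactored factory can be exercised. It assumes the package-root exports shown in the __init__.py diff above and uses a model name added later in this PR:

import open_clip

# get_model_config returns a deep copy of the registered config,
# so callers may mutate the result without touching the global registry.
cfg = open_clip.get_model_config('ViT-M-16')
if cfg is not None:
    print(cfg['embed_dim'])                 # 512
    print(cfg['vision_cfg']['patch_size'])  # 16

# With the new `pretrained and pretrained.lower() == 'openai'` guard,
# pretrained=None (or '') now skips the OpenAI branch instead of
# raising AttributeError on None.lower().
model = open_clip.create_model('ViT-M-16', pretrained=None)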
7 changes: 6 additions & 1 deletion src/open_clip/model.py
@@ -27,10 +27,12 @@ class CLIPVisionCfg:
     mlp_ratio: float = 4.0
     patch_size: int = 16
     image_size: Union[Tuple[int, int], int] = 224
+    ls_init_value: Optional[float] = None  # layer scale initial value
     timm_model_name: str = None  # a valid model name overrides layers, width, patch_size
     timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
     timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
     timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
+    timm_proj_bias: bool = False  # enable bias for final projection


 @dataclass
@@ -40,6 +42,7 @@ class CLIPTextCfg:
     width: int = 512
     heads: int = 8
     layers: int = 12
+    ls_init_value: Optional[float] = None  # layer scale initial value


 def get_cast_dtype(precision: str):
@@ -71,6 +74,7 @@ def _build_vision_tower(
             pretrained=vision_cfg.timm_model_pretrained,
             pool=vision_cfg.timm_pool,
             proj=vision_cfg.timm_proj,
+            proj_bias=vision_cfg.timm_proj_bias,
             embed_dim=embed_dim,
             image_size=vision_cfg.image_size
         )
@@ -94,6 +98,7 @@ def _build_vision_tower(
             layers=vision_cfg.layers,
             heads=vision_heads,
             mlp_ratio=vision_cfg.mlp_ratio,
+            ls_init_value=vision_cfg.ls_init_value,
             output_dim=embed_dim,
             act_layer=act_layer,
             norm_layer=norm_layer,
@@ -120,6 +125,7 @@ def _build_text_tower(
             width=text_cfg.width,
             heads=text_cfg.heads,
             layers=text_cfg.layers,
+            ls_init_value=text_cfg.ls_init_value,
             output_dim=embed_dim,
             act_layer=act_layer,
             norm_layer=norm_layer,
@@ -176,7 +182,6 @@ def encode_text(self, text, normalize: bool = False):
         x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
         # take features from the eot embedding (eot_token is the highest number in each sequence)
         x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
-
         return F.normalize(x, dim=-1) if normalize else x

     def forward(self, image, text):
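For context, ls_init_value configures layer scale (as in CaiT, Touvron et al. 2021): the output of each residual branch is multiplied by a learnable per-channel vector initialized to a small constant, which helps deeper towers train stably. The module itself lives in the transformer code, which this diff does not show; the snippet below is only a minimal sketch of the technique, not the repository's implementation:

import torch
from torch import nn

class LayerScale(nn.Module):
    # Sketch: learnable per-channel gain applied to a residual branch.
    def __init__(self, dim: int, init_value: float = 1e-4):
        super().__init__()
        # initialized small so each block starts close to an identity mapping
        self.gamma = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma

# Used roughly as: x = x + layer_scale(attn(x)). A config with
# ls_init_value = None (the default) applies no scaling at all.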
17 changes: 17 additions & 0 deletions src/open_clip/model_configs/ViT-M-16-alt.json
@@ -0,0 +1,17 @@
{
    "embed_dim": 384,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 512,
        "patch_size": 16,
        "ls_init_value": 1e-4
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 384,
        "heads": 6,
        "layers": 12
    }
}
16 changes: 16 additions & 0 deletions src/open_clip/model_configs/ViT-M-16.json
@@ -0,0 +1,16 @@
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 512,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}
16 changes: 16 additions & 0 deletions src/open_clip/model_configs/ViT-M-32-alt.json
@@ -0,0 +1,16 @@
{
    "embed_dim": 384,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 512,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 384,
        "heads": 6,
        "layers": 12
    }
}
16 changes: 16 additions & 0 deletions src/open_clip/model_configs/ViT-M-32.json
@@ -0,0 +1,16 @@
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 512,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}
16 changes: 16 additions & 0 deletions src/open_clip/model_configs/ViT-S-16-alt.json
@@ -0,0 +1,16 @@
{
    "embed_dim": 256,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 384,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 256,
        "heads": 4,
        "layers": 10
    }
}
16 changes: 16 additions & 0 deletions src/open_clip/model_configs/ViT-S-16.json
@@ -0,0 +1,16 @@
{
    "embed_dim": 384,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 384,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 384,
        "heads": 6,
        "layers": 12
    }
}
16 changes: 16 additions & 0 deletions src/open_clip/model_configs/ViT-S-32.json
@@ -0,0 +1,16 @@
{
    "embed_dim": 384,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 384,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 384,
        "heads": 6,
        "layers": 12
    }
}
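The S/M configs above extend the model ladder below ViT-B. Once registered they are loadable by name; a sketch of building one from scratch (no pretrained weights are referenced for these configs, so parameters are randomly initialized):

import torch
import open_clip

# Build ViT-M-32 from its JSON config; the factory returns
# (model, train_preprocess, val_preprocess).
model, _, preprocess = open_clip.create_model_and_transforms('ViT-M-32')

text = open_clip.tokenize(['a diagram', 'a dog'])
with torch.no_grad():
    text_features = model.encode_text(text, normalize=True)
print(text_features.shape)  # torch.Size([2, 512]); 512 is embed_dim in ViT-M-32.json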
@@ -1,7 +1,7 @@
 {
     "embed_dim": 512,
     "vision_cfg": {
-        "timm_model_name": "vit_base_patch16_224",
+        "timm_model_name": "convnext_base",
         "timm_model_pretrained": false,
         "timm_pool": "",
         "timm_proj": "linear",
@@ -1,17 +1,17 @@
 {
-    "embed_dim": 1024,
+    "embed_dim": 768,
     "vision_cfg": {
-        "timm_model_name": "resnet50d",
+        "timm_model_name": "convnext_large",
         "timm_model_pretrained": false,
-        "timm_pool": "abs_attn",
-        "timm_proj": "",
+        "timm_pool": "",
+        "timm_proj": "linear",
         "image_size": 224
     },
     "text_cfg": {
         "context_length": 77,
         "vocab_size": 49408,
-        "width": 512,
-        "heads": 8,
+        "width": 768,
+        "heads": 12,
         "layers": 12
     }
 }
@@ -1,17 +1,17 @@
 {
     "embed_dim": 1024,
     "vision_cfg": {
-        "timm_model_name": "resnetblur50",
+        "timm_model_name": "convnext_xlarge",
         "timm_model_pretrained": false,
-        "timm_pool": "abs_attn",
-        "timm_proj": "",
+        "timm_pool": "",
+        "timm_proj": "linear",
         "image_size": 224
     },
     "text_cfg": {
         "context_length": 77,
         "vocab_size": 49408,
-        "width": 512,
-        "heads": 8,
-        "layers": 12
+        "width": 1024,
+        "heads": 16,
+        "layers": 16
     }
 }
@@ -1,7 +1,7 @@
 {
     "embed_dim": 512,
     "vision_cfg": {
-        "timm_model_name": "vit_small_patch16_224",
+        "timm_model_name": "vit_medium_patch16_gap_256",
         "timm_model_pretrained": false,
         "timm_pool": "",
         "timm_proj": "linear",
@@ -1,7 +1,7 @@
 {
     "embed_dim": 512,
     "vision_cfg": {
-        "timm_model_name": "vit_base_patch32_224",
+        "timm_model_name": "vit_relpos_medium_patch16_cls_224",
         "timm_model_pretrained": false,
         "timm_pool": "",
         "timm_proj": "linear",
5 changes: 3 additions & 2 deletions src/open_clip/timm_model.py
@@ -29,6 +29,7 @@ def __init__(
             image_size=224,
             pool='avg',
             proj='linear',
+            proj_bias=False,
             drop=0.,
             pretrained=False):
         super().__init__()
@@ -62,9 +63,9 @@ def __init__(
         # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
         if proj == 'linear':
             head_layers['drop'] = nn.Dropout(drop)
-            head_layers['proj'] = nn.Linear(prev_chs, embed_dim)
+            head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
         elif proj == 'mlp':
-            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
+            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))

         self.head = nn.Sequential(head_layers)
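The proj_bias flag flows from CLIPVisionCfg.timm_proj_bias through _build_vision_tower into TimmModel, so a timm-backed tower can now opt into a biased projection head for both the 'linear' and 'mlp' variants (the latter via timm's Mlp bias=(True, proj_bias) tuple). A hedged sketch of direct construction; keyword names follow the diff above, the positional model name is an assumption, and timm must be installed:

from open_clip.timm_model import TimmModel

tower = TimmModel(
    'convnext_base',   # assumed first positional arg; one of the timm IDs used in the new configs
    embed_dim=512,
    image_size=224,
    proj='linear',     # head becomes nn.Linear(prev_chs, embed_dim, bias=proj_bias)
    proj_bias=True,    # new flag; the False default preserves prior behavior
    pretrained=False,
)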