Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Text Tower Refactor #185

Merged
merged 28 commits into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
4ce3926
POC interface for create fn with mandatory pretrained arg and inferen…
rwightman Sep 17, 2022
1c8844b
Text Tower Support
iejMac Sep 25, 2022
1e348c8
add model.py changes
iejMac Sep 25, 2022
213e218
modified_resnet.py
iejMac Sep 25, 2022
b4df97a
update visual transformer to main version
iejMac Sep 25, 2022
9cf309e
init_params
iejMac Sep 25, 2022
f454a18
update to main
iejMac Sep 25, 2022
f474a2a
import math
iejMac Sep 25, 2022
65a14a9
remove print
iejMac Sep 25, 2022
b3a3718
convert state dict
iejMac Sep 25, 2022
9955855
comment above convert_state_dict
iejMac Sep 25, 2022
2415600
update to main
iejMac Sep 25, 2022
dc9fc01
comment in factory
Sep 25, 2022
15daeb7
Add a note on how to do smaller epochs. Fix #135
rom1504 Sep 24, 2022
a618008
Recommend img2dataset in readme, fix #148
rom1504 Sep 24, 2022
75a009b
filter examples with no images, in addition to those with no captions…
mehdidc Sep 26, 2022
aa71712
Add jit=True to check we don't break torchscript
rom1504 Sep 26, 2022
c849dee
Test both jit True and False
rom1504 Sep 26, 2022
044c30d
Merge branch 'text_tower' of https://github.com/iejMac/open_clip into…
rwightman Sep 29, 2022
2c3d86e
Refactor custom text tower into separate model w/ code re-use reduced.
rwightman Sep 29, 2022
d79c5c2
Remove text tower conversion for OpenAI weight loads
rwightman Oct 31, 2022
3e76bca
Merge remote-tracking branch 'origin/from_pretrained' into text_tower…
rwightman Oct 31, 2022
c4190d2
Fixing float16/bfloat16 (pure) modes, adding flag to force use of cus…
rwightman Nov 1, 2022
90a890f
Add/remove some model configs. Add profiler. Add support for layer_sc…
rwightman Nov 3, 2022
0fd8534
Remove save that was for openai checkpoint tests
rwightman Nov 3, 2022
7e5546d
Fix grad checkpointing for timm models (bug). Change grad clipping arg n…
rwightman Nov 4, 2022
e5a92e2
Tweak profile script, fix a bug for resnet models, add G/e/S-32-alt c…
rwightman Nov 4, 2022
c119c01
Bump version to 2.1.0
rwightman Nov 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/open_clip/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import list_models, create_model, create_model_and_transforms, add_model_config
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss
from .model import CLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_fp16, trace_model
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from .tokenizer import SimpleTokenizer, tokenize
from .transform import image_transform
154 changes: 120 additions & 34 deletions src/open_clip/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import re
from copy import deepcopy
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import torch

from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, convert_weights_to_fp16, resize_pos_embed
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
resize_pos_embed, get_cast_dtype
from .openai import load_openai_model
from .pretrained import get_pretrained_cfg, download_pretrained
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
from .transform import image_transform


Expand Down Expand Up @@ -48,6 +49,26 @@ def _rescan_model_configs():
_rescan_model_configs() # initial populate of model config registry


def list_models():
    """Enumerate the model architectures registered from config files."""
    return [name for name in _MODEL_CONFIGS]


def add_model_config(path):
    """Register an extra model-config file or directory and refresh the registry."""
    config_path = path if isinstance(path, Path) else Path(path)
    _MODEL_CONFIG_PATHS.append(config_path)
    _rescan_model_configs()


def get_model_config(model_name):
    """Return a deep copy of the named model config, or None if unknown."""
    cfg = _MODEL_CONFIGS.get(model_name)
    return None if cfg is None else deepcopy(cfg)


def load_state_dict(checkpoint_path: str, map_location='cpu'):
checkpoint = torch.load(checkpoint_path, map_location=map_location)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
Expand All @@ -61,33 +82,42 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'):

def load_checkpoint(model, checkpoint_path, strict=True):
    """Load weights from a checkpoint file into `model`.

    A checkpoint saved in the old layout (text-tower params at the state-dict
    top level, e.g. a bare 'positional_embedding' key) is detected and
    converted to the custom text-tower layout when `model` does not expose a
    top-level `positional_embedding` attribute. Positional embeddings are
    resized to the model's grid before loading.

    Returns the incompatible-keys result of `model.load_state_dict`.
    """
    state_dict = load_state_dict(checkpoint_path)
    # detect old format and make compatible with new format
    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
        state_dict = convert_to_custom_text_state_dict(state_dict)
    resize_pos_embed(state_dict, model)
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    return incompatible_keys


def create_model(
model_name: str,
pretrained: str = '',
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: torch.device = torch.device('cpu'),
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
pretrained_image: bool = False,
cache_dir: Optional[str] = None,
):
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
if isinstance(device, str):
device = torch.device(device)

if pretrained.lower() == 'openai':
if pretrained and pretrained.lower() == 'openai':
logging.info(f'Loading pretrained {model_name} from OpenAI.')
model = load_openai_model(model_name, device=device, jit=jit, cache_dir=cache_dir)
# See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
if precision == "amp" or precision == "fp32":
model = model.float()
model = load_openai_model(
model_name,
precision=precision,
device=device,
jit=jit,
cache_dir=cache_dir,
)
else:
if model_name in _MODEL_CONFIGS:
logging.info(f'Loading {model_name} model config.')
model_cfg = deepcopy(_MODEL_CONFIGS[model_name])
model_cfg = get_model_config(model_name)
if model_cfg is not None:
logging.info(f'Loaded {model_name} model config.')
else:
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
raise RuntimeError(f'Model config for {model_name} not found.')
Expand All @@ -103,7 +133,13 @@ def create_model(
else:
assert False, 'pretrained image towers currently only supported for timm models'

model = CLIP(**model_cfg)
cast_dtype = get_cast_dtype(precision)
custom_text = model_cfg.pop('custom_text', False) or force_custom_text

if custom_text:
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)

pretrained_cfg = {}
if pretrained:
Expand All @@ -118,13 +154,15 @@ def create_model(
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
load_checkpoint(model, checkpoint_path)
else:
logging.warning(f'Pretrained weights ({pretrained}) not found for model {model_name}.')
raise RuntimeError(f'Pretrained weights ({pretrained}) not found for model {model_name}.')
error_str = (
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
logging.warning(error_str)
raise RuntimeError(error_str)

model.to(device=device)
if precision == "fp16":
assert device.type != 'cpu'
convert_weights_to_fp16(model)
if precision in ("fp16", "bf16"):
convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)

# set image / mean metadata from pretrained_cfg if available, or use default
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
Expand All @@ -138,38 +176,86 @@ def create_model(

def create_model_and_transforms(
    model_name: str,
    pretrained: Optional[str] = None,
    precision: str = 'fp32',
    device: Union[str, torch.device] = 'cpu',
    jit: bool = False,
    force_quick_gelu: bool = False,
    force_custom_text: bool = False,
    pretrained_image: bool = False,
    image_mean: Optional[Tuple[float, ...]] = None,
    image_std: Optional[Tuple[float, ...]] = None,
    cache_dir: Optional[str] = None,
):
    """Create a CLIP model together with its train and eval image transforms.

    Thin wrapper around `create_model` that also builds the train/val
    preprocessing pipelines sized to the model's visual input.

    Returns:
        (model, preprocess_train, preprocess_val) tuple.
    """
    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        pretrained_image=pretrained_image,
        cache_dir=cache_dir,
    )

    # explicit mean/std override the stats attached to the model's visual tower
    image_mean = image_mean or getattr(model.visual, 'image_mean', None)
    image_std = image_std or getattr(model.visual, 'image_std', None)
    preprocess_train = image_transform(
        model.visual.image_size,
        is_train=True,
        mean=image_mean,
        std=image_std
    )
    preprocess_val = image_transform(
        model.visual.image_size,
        is_train=False,
        mean=image_mean,
        std=image_std
    )

    return model, preprocess_train, preprocess_val


def list_models():
""" enumerate available model architectures based on config files """
return list(_MODEL_CONFIGS.keys())
def create_model_from_pretrained(
    model_name: str,
    pretrained: str,
    precision: str = 'fp32',
    device: Union[str, torch.device] = 'cpu',
    jit: bool = False,
    force_quick_gelu: bool = False,
    force_custom_text: bool = False,
    return_transform: bool = True,
    image_mean: Optional[Tuple[float, ...]] = None,
    image_std: Optional[Tuple[float, ...]] = None,
    cache_dir: Optional[str] = None,
):
    """Create a model from a pretrained tag or local checkpoint path.

    Unlike `create_model_and_transforms`, `pretrained` is mandatory and is
    validated up front against the known pretrained configs (or an existing
    file path) before any model construction happens.

    Returns:
        `model` when `return_transform` is False, else `(model, preprocess)`
        where `preprocess` is the eval-mode image transform.

    Raises:
        RuntimeError: if `pretrained` is neither a registered pretrained tag
            for `model_name` nor an existing checkpoint file.
    """
    if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
        raise RuntimeError(
            f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'
            f' Use open_clip.list_pretrained() to find one.')

    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        cache_dir=cache_dir,
    )

    if not return_transform:
        return model

    # explicit mean/std override the stats attached to the model's visual tower
    image_mean = image_mean or getattr(model.visual, 'image_mean', None)
    image_std = image_std or getattr(model.visual, 'image_std', None)
    preprocess = image_transform(
        model.visual.image_size,
        is_train=False,
        mean=image_mean,
        std=image_std
    )

    return model, preprocess
Loading