Merge pull request #185 from mlfoundations/text_tower_refactor
Text Tower Refactor
rwightman authored Nov 4, 2022
2 parents b4cf926 + c119c01 commit 55174d7
Showing 30 changed files with 1,302 additions and 603 deletions.
src/open_clip/__init__.py (8 changes: 5 additions & 3 deletions)
@@ -1,9 +1,11 @@
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import list_models, create_model, create_model_and_transforms, add_model_config
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss
from .model import CLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_fp16, trace_model
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from .tokenizer import SimpleTokenizer, tokenize
from .transform import image_transform
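As a quick orientation for the reorganized exports above, a minimal usage sketch follows; the 'ViT-B-32' model name and the 'openai' pretrained tag are assumed examples and are not part of this diff.

import open_clip

# Architectures come from the JSON config registry; the renamed pretrained
# helpers report the tags known for a given architecture.
print(open_clip.list_models())
print(open_clip.list_pretrained_tags_by_model('ViT-B-32'))  # assumed model name

# The main factory entry points re-exported from the package root.
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    'ViT-B-32', pretrained='openai')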
src/open_clip/factory.py (154 changes: 120 additions & 34 deletions)
@@ -5,14 +5,15 @@
import re
from copy import deepcopy
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import torch

from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, convert_weights_to_fp16, resize_pos_embed
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
resize_pos_embed, get_cast_dtype
from .openai import load_openai_model
from .pretrained import get_pretrained_cfg, download_pretrained
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
from .transform import image_transform


@@ -48,6 +49,26 @@ def _rescan_model_configs():
_rescan_model_configs() # initial populate of model config registry


def list_models():
""" enumerate available model architectures based on config files """
return list(_MODEL_CONFIGS.keys())


def add_model_config(path):
""" add model config path or file and update registry """
if not isinstance(path, Path):
path = Path(path)
_MODEL_CONFIG_PATHS.append(path)
_rescan_model_configs()


def get_model_config(model_name):
if model_name in _MODEL_CONFIGS:
return deepcopy(_MODEL_CONFIGS[model_name])
else:
return None


def load_state_dict(checkpoint_path: str, map_location='cpu'):
checkpoint = torch.load(checkpoint_path, map_location=map_location)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
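The hunk above moves list_models, add_model_config, and get_model_config up next to the config registry. A small sketch of how they compose; the extra config path is hypothetical and the config keys are only the typical ones:

from open_clip.factory import add_model_config, get_model_config, list_models

print(list_models())                # architectures discovered from the bundled JSON configs

cfg = get_model_config('ViT-B-32')  # deep copy of the registered config, or None if unknown
if cfg is not None:
    print(cfg.keys())               # typically embed_dim, vision_cfg, text_cfg

add_model_config('/path/to/extra_configs')  # hypothetical dir or file; triggers a rescan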
@@ -61,33 +82,42 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'):

def load_checkpoint(model, checkpoint_path, strict=True):
state_dict = load_state_dict(checkpoint_path)
# detect old format and make compatible with new format
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
state_dict = convert_to_custom_text_state_dict(state_dict)
resize_pos_embed(state_dict, model)
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
return incompatible_keys


def create_model(
model_name: str,
pretrained: str = '',
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: torch.device = torch.device('cpu'),
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
pretrained_image: bool = False,
cache_dir: Optional[str] = None,
):
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
if isinstance(device, str):
device = torch.device(device)

if pretrained.lower() == 'openai':
if pretrained and pretrained.lower() == 'openai':
logging.info(f'Loading pretrained {model_name} from OpenAI.')
model = load_openai_model(model_name, device=device, jit=jit, cache_dir=cache_dir)
# See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
if precision == "amp" or precision == "fp32":
model = model.float()
model = load_openai_model(
model_name,
precision=precision,
device=device,
jit=jit,
cache_dir=cache_dir,
)
else:
if model_name in _MODEL_CONFIGS:
logging.info(f'Loading {model_name} model config.')
model_cfg = deepcopy(_MODEL_CONFIGS[model_name])
model_cfg = get_model_config(model_name)
if model_cfg is not None:
logging.info(f'Loaded {model_name} model config.')
else:
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
raise RuntimeError(f'Model config for {model_name} not found.')
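After this hunk, create_model accepts a plain device string and an optional pretrained tag, and hands the precision choice to load_openai_model instead of down-casting the model afterwards. A minimal sketch, with the model name assumed:

from open_clip import create_model

# The device string is converted to torch.device internally; pretrained may be None.
model = create_model('ViT-B-32', pretrained='openai', precision='fp32', device='cpu')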
Expand All @@ -103,7 +133,13 @@ def create_model(
else:
assert False, 'pretrained image towers currently only supported for timm models'

model = CLIP(**model_cfg)
cast_dtype = get_cast_dtype(precision)
custom_text = model_cfg.pop('custom_text', False) or force_custom_text

if custom_text:
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)

pretrained_cfg = {}
if pretrained:
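With this hunk, a custom_text flag in the model config (or force_custom_text=True) selects the new CustomTextCLIP variant with its separate text tower. A sketch under the assumption that a 'ViT-B-32' config exists; both models below are randomly initialized:

from open_clip import CLIP, CustomTextCLIP, create_model

default_model = create_model('ViT-B-32')
custom_model = create_model('ViT-B-32', force_custom_text=True)
assert isinstance(default_model, CLIP)
assert isinstance(custom_model, CustomTextCLIP)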
@@ -118,13 +154,15 @@ def create_model(
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
load_checkpoint(model, checkpoint_path)
else:
logging.warning(f'Pretrained weights ({pretrained}) not found for model {model_name}.')
raise RuntimeError(f'Pretrained weights ({pretrained}) not found for model {model_name}.')
error_str = (
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
logging.warning(error_str)
raise RuntimeError(error_str)

model.to(device=device)
if precision == "fp16":
assert device.type != 'cpu'
convert_weights_to_fp16(model)
if precision in ("fp16", "bf16"):
convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)

# set image / mean metadata from pretrained_cfg if available, or use default
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
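bf16 now goes through the same convert_weights_to_lp path as fp16, and the missing-weights error lists the pretrained tags known for the model. A sketch of the precision side; the model name is assumed, and the conv1 attribute refers to the ViT patch embedding, an internal detail not shown in this diff:

import torch
from open_clip import create_model

model = create_model('ViT-B-32', precision='bf16', device='cpu')
# Linear / conv weights are cast in place to the low-precision dtype.
print(model.visual.conv1.weight.dtype)  # expected: torch.bfloat16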
@@ -138,38 +176,86 @@

def create_model_and_transforms(
model_name: str,
pretrained: str = '',
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: torch.device = torch.device('cpu'),
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
pretrained_image: bool = False,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
):
model = create_model(
model_name, pretrained, precision, device, jit,
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_text=force_custom_text,
pretrained_image=pretrained_image,
cache_dir=cache_dir)
cache_dir=cache_dir,
)

image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=image_mean, std=image_std)
preprocess_val = image_transform(model.visual.image_size, is_train=False, mean=image_mean, std=image_std)
preprocess_train = image_transform(
model.visual.image_size,
is_train=True,
mean=image_mean,
std=image_std
)
preprocess_val = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std
)

return model, preprocess_train, preprocess_val


def list_models():
""" enumerate available model architectures based on config files """
return list(_MODEL_CONFIGS.keys())
def create_model_from_pretrained(
model_name: str,
pretrained: str,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
return_transform: bool = True,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
):
if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
raise RuntimeError(
f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'
f' Use open_clip.list_pretrained() to find one.')

model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_text=force_custom_text,
cache_dir=cache_dir,
)

if not return_transform:
return model

def add_model_config(path):
""" add model config path or file and update registry """
if not isinstance(path, Path):
path = Path(path)
_MODEL_CONFIG_PATHS.append(path)
_rescan_model_configs()
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
preprocess = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std
)

return model, preprocess
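Finally, create_model_from_pretrained is a new entry point that refuses to run without a resolvable pretrained tag or checkpoint path and, by default, also returns the inference transform. A usage sketch; the tag, prompts, and image path are assumed examples:

import torch
from PIL import Image
from open_clip import create_model_from_pretrained, tokenize

model, preprocess = create_model_from_pretrained('ViT-B-32', 'openai')
image = preprocess(Image.open('cat.jpg')).unsqueeze(0)       # hypothetical local image
text = tokenize(['a photo of a cat', 'a photo of a dog'])

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
print(image_features.shape, text_features.shape)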