Merge branch 'main' into distributed_validation
rom1504 authored Nov 7, 2022
2 parents 84b2dc7 + 84617b0 commit 03839c5
Showing 34 changed files with 1,458 additions and 617 deletions.
7 changes: 6 additions & 1 deletion README.md
@@ -66,6 +66,7 @@ This repository is focused on training CLIP models. To fine-tune a *trained* zer

## Data

To download datasets as webdatasets, we recommend [img2dataset](https://github.com/rom1504/img2dataset).

### Conceptual Captions

@@ -164,6 +165,10 @@ the logit matrix. Using a naïve all-gather scheme, space complexity will be
`--gather-with-grad` and `--local-loss` are used. This alteration yields numerical results
identical to the naïve method.
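As a rough sketch of how these flags map onto the loss object (the `ClipLoss` arguments below are assumed from `src/open_clip/loss.py`; during training they are set via the `--local-loss` and `--gather-with-grad` CLI flags):

```python
import open_clip

# Sketch only: argument names assumed from the ClipLoss constructor.
# local_loss=True       -> each rank computes the loss over its local block of the logit matrix.
# gather_with_grad=True -> features are all-gathered with gradients instead of detached copies.
loss_fn = open_clip.ClipLoss(
    local_loss=True,
    gather_with_grad=True,
    rank=0,         # torch.distributed.get_rank() in a real distributed job
    world_size=1,   # torch.distributed.get_world_size() in a real distributed job
)
```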

#### Epochs

For larger datasets (e.g. Laion2B), we recommend setting `--train-num-samples` to a value lower than the full epoch, for example `--train-num-samples 135646078` for 1/16 of an epoch, in conjunction with `--dataset-resampled` to do sampling with replacement. This allows frequent checkpointing, and hence more frequent evaluation.

#### Single-Node

We make use of `torchrun` to launch distributed jobs. The following launches a
@@ -384,7 +389,7 @@ Below are checkpoints of models trained on YFCC-15M, along with their zero-shot

We offer a simple model interface to instantiate both pre-trained and untrained models.

NOTE: Many existing checkpoints use the QuickGELU activation from the original OpenAI models. This activation is actually less efficient that native torch.nn.GELU in recent versions of PyTorch. The model defaults are now nn.GELU, so one should use model definitions with `-quickgelu` postfix for the OpenCLIP pretrained weights. All OpenAI pretrained weights will always default to QuickGELU. One can also use the non `-quickgelu` model definitions with pretrained weights using QuickGELU but there will be an accuracy drop, for fine-tune that will likely vanish for longer runs.
NOTE: Many existing checkpoints use the QuickGELU activation from the original OpenAI models. This activation is actually less efficient than native torch.nn.GELU in recent versions of PyTorch. The model defaults are now nn.GELU, so one should use model definitions with the `-quickgelu` postfix for the OpenCLIP pretrained weights. All OpenAI pretrained weights will always default to QuickGELU. One can also use the non `-quickgelu` model definitions with pretrained weights that used QuickGELU, but there will be an accuracy drop; when fine-tuning, the drop will likely vanish over longer runs.

Future trained models will use nn.GELU.
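
For example, a minimal sketch of instantiating a QuickGELU model definition with matching OpenCLIP weights (the pretrained tag is illustrative; `open_clip.list_pretrained()` lists valid combinations):

```python
import open_clip

# Illustrative names: '-quickgelu' definitions pair with OpenCLIP weights trained using QuickGELU.
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    'ViT-B-32-quickgelu', pretrained='laion400m_e32')

# OpenAI weights always load with QuickGELU, regardless of the model name used.
openai_model, _, _ = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')
```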

8 changes: 5 additions & 3 deletions src/open_clip/__init__.py
@@ -1,9 +1,11 @@
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import list_models, create_model, create_model_and_transforms, add_model_config
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss
from .model import CLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_fp16, trace_model
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from .tokenizer import SimpleTokenizer, tokenize
from .transform import image_transform
154 changes: 120 additions & 34 deletions src/open_clip/factory.py
@@ -5,14 +5,15 @@
import re
from copy import deepcopy
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import torch

from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, convert_weights_to_fp16, resize_pos_embed
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
resize_pos_embed, get_cast_dtype
from .openai import load_openai_model
from .pretrained import get_pretrained_cfg, download_pretrained
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
from .transform import image_transform


@@ -48,6 +49,26 @@ def _rescan_model_configs():
_rescan_model_configs() # initial populate of model config registry


def list_models():
""" enumerate available model architectures based on config files """
return list(_MODEL_CONFIGS.keys())


def add_model_config(path):
""" add model config path or file and update registry """
if not isinstance(path, Path):
path = Path(path)
_MODEL_CONFIG_PATHS.append(path)
_rescan_model_configs()


def get_model_config(model_name):
if model_name in _MODEL_CONFIGS:
return deepcopy(_MODEL_CONFIGS[model_name])
else:
return None
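
A hedged usage sketch of these registry helpers (the extra config directory below is hypothetical):

```python
import open_clip

print(open_clip.list_models())                    # architectures found in the bundled JSON configs
cfg = open_clip.get_model_config('ViT-B-32')      # deep copy of the config dict, or None if unknown
open_clip.add_model_config('./my_model_configs')  # hypothetical directory of extra *.json configs
```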


def load_state_dict(checkpoint_path: str, map_location='cpu'):
checkpoint = torch.load(checkpoint_path, map_location=map_location)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
@@ -61,33 +82,42 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'):

def load_checkpoint(model, checkpoint_path, strict=True):
state_dict = load_state_dict(checkpoint_path)
# detect old format and make compatible with new format
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
state_dict = convert_to_custom_text_state_dict(state_dict)
resize_pos_embed(state_dict, model)
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
return incompatible_keys


def create_model(
model_name: str,
pretrained: str = '',
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: torch.device = torch.device('cpu'),
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
pretrained_image: bool = False,
cache_dir: Optional[str] = None,
):
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
if isinstance(device, str):
device = torch.device(device)

if pretrained.lower() == 'openai':
if pretrained and pretrained.lower() == 'openai':
logging.info(f'Loading pretrained {model_name} from OpenAI.')
model = load_openai_model(model_name, device=device, jit=jit, cache_dir=cache_dir)
# See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
if precision == "amp" or precision == "fp32":
model = model.float()
model = load_openai_model(
model_name,
precision=precision,
device=device,
jit=jit,
cache_dir=cache_dir,
)
else:
if model_name in _MODEL_CONFIGS:
logging.info(f'Loading {model_name} model config.')
model_cfg = deepcopy(_MODEL_CONFIGS[model_name])
model_cfg = get_model_config(model_name)
if model_cfg is not None:
logging.info(f'Loaded {model_name} model config.')
else:
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
raise RuntimeError(f'Model config for {model_name} not found.')
@@ -103,7 +133,13 @@ def create_model(
else:
assert False, 'pretrained image towers currently only supported for timm models'

model = CLIP(**model_cfg)
cast_dtype = get_cast_dtype(precision)
custom_text = model_cfg.pop('custom_text', False) or force_custom_text

if custom_text:
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)

pretrained_cfg = {}
if pretrained:
@@ -118,13 +154,15 @@ def create_model(
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
load_checkpoint(model, checkpoint_path)
else:
logging.warning(f'Pretrained weights ({pretrained}) not found for model {model_name}.')
raise RuntimeError(f'Pretrained weights ({pretrained}) not found for model {model_name}.')
error_str = (
    f'Pretrained weights ({pretrained}) not found for model {model_name}.'
    f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}).')
logging.warning(error_str)
raise RuntimeError(error_str)

model.to(device=device)
if precision == "fp16":
assert device.type != 'cpu'
convert_weights_to_fp16(model)
if precision in ("fp16", "bf16"):
convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)

# set image / mean metadata from pretrained_cfg if available, or use default
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
@@ -138,38 +176,86 @@

def create_model_and_transforms(
model_name: str,
pretrained: str = '',
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: torch.device = torch.device('cpu'),
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
pretrained_image: bool = False,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
):
model = create_model(
model_name, pretrained, precision, device, jit,
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_text=force_custom_text,
pretrained_image=pretrained_image,
cache_dir=cache_dir)
cache_dir=cache_dir,
)

image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=image_mean, std=image_std)
preprocess_val = image_transform(model.visual.image_size, is_train=False, mean=image_mean, std=image_std)
preprocess_train = image_transform(
model.visual.image_size,
is_train=True,
mean=image_mean,
std=image_std
)
preprocess_val = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std
)

return model, preprocess_train, preprocess_val


def list_models():
""" enumerate available model architectures based on config files """
return list(_MODEL_CONFIGS.keys())
def create_model_from_pretrained(
model_name: str,
pretrained: str,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
return_transform: bool = True,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
):
if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
raise RuntimeError(
f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'
f' Use open_clip.list_pretrained() to find one.')

model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_text=force_custom_text,
cache_dir=cache_dir,
)

if not return_transform:
return model

def add_model_config(path):
""" add model config path or file and update registry """
if not isinstance(path, Path):
path = Path(path)
_MODEL_CONFIG_PATHS.append(path)
_rescan_model_configs()
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
preprocess = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std
)

return model, preprocess
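
A minimal usage sketch of the new `create_model_from_pretrained` entry point (model name, pretrained tag, and image path are illustrative; see `open_clip.list_pretrained()` for valid pairs):

```python
import torch
from PIL import Image
import open_clip

# Sketch only: any (model, pretrained) pair known to open_clip, or a local checkpoint path, should work.
model, preprocess = open_clip.create_model_from_pretrained('ViT-B-32', pretrained='laion2b_e16')

image = preprocess(Image.open('example.jpg')).unsqueeze(0)  # placeholder image path
with torch.no_grad():
    features = model.encode_image(image)
print(features.shape)
```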