Merge pull request #106 from mlfoundations/resize_resample_more
PosEmbed resizing, image pre-processing flexibility, WDS resampled shard option, worker epoch sync improvements
rwightman authored Jun 28, 2022
2 parents fb58d4b + 9078293 commit c933765
Showing 9 changed files with 285 additions and 46 deletions.
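
For context, the most visible API change in this diff is the optional mean/std override on create_model_and_transforms. A minimal sketch of how it might be called; the model name and pretrained tag below are illustrative, not part of this diff:

import open_clip

# Hypothetical model/tag names; the point is the new mean/std keywords, which
# fall back to the OpenAI dataset statistics when left as None.
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    'ViT-B-32',
    pretrained='laion400m_e32',
    mean=(0.5, 0.5, 0.5),  # override normalization mean
    std=(0.5, 0.5, 0.5),   # override normalization std
)
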
18 changes: 14 additions & 4 deletions src/open_clip/factory.py
@@ -5,10 +5,11 @@
import re
from copy import deepcopy
from pathlib import Path
from typing import Optional, Tuple

import torch

from .model import CLIP, convert_weights_to_fp16
from .model import CLIP, convert_weights_to_fp16, resize_pos_embed
from .openai import load_openai_model
from .pretrained import get_pretrained_url, download_pretrained
from .transform import image_transform
@@ -57,6 +58,13 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'):
return state_dict


def load_checkpoint(model, checkpoint_path, strict=True):
state_dict = load_state_dict(checkpoint_path)
resize_pos_embed(state_dict, model)
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
return incompatible_keys


def create_model(
model_name: str,
pretrained: str = '',
@@ -105,7 +113,7 @@ def create_model(

if checkpoint_path:
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
model.load_state_dict(load_state_dict(checkpoint_path))
load_checkpoint(model, checkpoint_path)
else:
logging.warning(f'Pretrained weights ({pretrained}) not found for model {model_name}.')
raise RuntimeError(f'Pretrained weights ({pretrained}) not found for model {model_name}.')
@@ -129,13 +137,15 @@ def create_model_and_transforms(
jit: bool = False,
force_quick_gelu: bool = False,
pretrained_image: bool = False,
mean: Optional[Tuple[float, ...]] = None,
std: Optional[Tuple[float, ...]] = None,
):
model = create_model(
model_name, pretrained, precision, device, jit,
force_quick_gelu=force_quick_gelu,
pretrained_image=pretrained_image)
preprocess_train = image_transform(model.visual.image_size, is_train=True)
preprocess_val = image_transform(model.visual.image_size, is_train=False)
preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=mean, std=std)
preprocess_val = image_transform(model.visual.image_size, is_train=False, mean=mean, std=std)
return model, preprocess_train, preprocess_val


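The new load_checkpoint helper runs resize_pos_embed before load_state_dict, so a ViT checkpoint trained at one resolution can be loaded into a model configured for another. A rough sketch; the checkpoint path and the resolution mismatch are hypothetical:

from open_clip.factory import create_model, load_checkpoint

# Placeholder checkpoint path; assumes the model config yields a different
# visual grid size than the checkpoint, which triggers resize_pos_embed().
model = create_model('ViT-B-32', precision='fp32', device='cpu')
incompatible = load_checkpoint(model, '/path/to/vit_b_32_hi_res.pt', strict=False)
print(incompatible.missing_keys, incompatible.unexpected_keys)
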
11 changes: 9 additions & 2 deletions src/open_clip/loss.py
@@ -1,8 +1,14 @@
import torch
import torch.distributed.nn
from torch import distributed as dist, nn as nn
import torch.nn as nn
from torch.nn import functional as F

try:
import torch.distributed.nn
from torch import distributed as dist
has_distributed = True
except ImportError:
has_distributed = False

try:
import horovod.torch as hvd
except ImportError:
@@ -18,6 +24,7 @@ def gather_features(
world_size=1,
use_horovod=False
):
assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
if use_horovod:
assert hvd is not None, 'Please install horovod'
if gather_with_grad:
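The guarded import means single-process training no longer requires a torch build with distributed support; the new assert only fires if gather_features is actually reached. A standalone sketch of the pattern; the gather_if_available helper is illustrative, not repo code:

import torch

try:
    import torch.distributed.nn
    from torch import distributed as dist
    has_distributed = True
except ImportError:
    has_distributed = False

def gather_if_available(features: torch.Tensor, world_size: int = 1) -> torch.Tensor:
    if world_size <= 1:
        return features  # single process: nothing to gather, no distributed needed
    assert has_distributed, 'torch.distributed did not import correctly'
    # differentiable all_gather, as used on the gather_with_grad path above
    return torch.cat(torch.distributed.nn.all_gather(features), dim=0)
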
44 changes: 40 additions & 4 deletions src/open_clip/model.py
@@ -2,9 +2,10 @@
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""

from collections import OrderedDict
from dataclasses import dataclass
import logging
import math
from typing import Tuple, Union, Callable, Optional

import numpy as np
@@ -14,7 +15,7 @@
from torch.utils.checkpoint import checkpoint

from .timm_model import TimmModel
from .utils import freeze_batch_norm_2d
from .utils import freeze_batch_norm_2d, to_2tuple


class Bottleneck(nn.Module):
@@ -255,13 +256,15 @@ def __init__(
self, image_size: int, patch_size: int, width: int, layers: int, heads: int, mlp_ratio: float,
output_dim: int, act_layer: Callable = nn.GELU):
super().__init__()
self.image_size = image_size
self.image_size = to_2tuple(image_size)
self.patch_size = to_2tuple(patch_size)
self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn((image_size // patch_size) ** 2 + 1, width))
self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
self.ln_pre = LayerNorm(width)

self.transformer = Transformer(width, layers, heads, mlp_ratio, act_layer=act_layer)
@@ -556,3 +559,36 @@ def trace_model(model, batch_size=256, device=torch.device('cpu')):
))
model.visual.image_size = image_size
return model


def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
# Rescale the grid of position embeddings when loading from state_dict
old_pos_embed = state_dict.get('visual.positional_embedding', None)
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
return
grid_size = to_2tuple(model.visual.grid_size)
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
if new_seq_len == old_pos_embed.shape[0]:
return

if extra_tokens:
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
else:
pos_emb_tok, pos_emb_img = None, old_pos_embed
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))

logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
pos_emb_img = F.interpolate(
pos_emb_img,
size=grid_size,
mode=interpolation,
align_corners=True,
)
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
if pos_emb_tok is not None:
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
else:
new_pos_embed = pos_emb_img
state_dict['visual.positional_embedding'] = new_pos_embed
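
As a self-contained illustration (toy sizes, not repo code), the reshape, interpolate, and flatten sequence above maps a 7x7 grid plus class token onto a 14x14 grid:

import torch
import torch.nn.functional as F

width = 64
old_grid, new_grid = (7, 7), (14, 14)
old_pos = torch.randn(old_grid[0] * old_grid[1] + 1, width)  # 50 x 64, incl. class token

tok, img = old_pos[:1], old_pos[1:]                                       # split off class token
img = img.reshape(1, old_grid[0], old_grid[1], -1).permute(0, 3, 1, 2)    # 1 x C x 7 x 7
img = F.interpolate(img, size=new_grid, mode='bicubic', align_corners=True)
img = img.permute(0, 2, 3, 1).reshape(1, new_grid[0] * new_grid[1], -1)[0]  # 196 x 64
new_pos = torch.cat([tok, img], dim=0)
print(new_pos.shape)  # torch.Size([197, 64])
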
59 changes: 54 additions & 5 deletions src/open_clip/transform.py
@@ -1,17 +1,58 @@
from typing import Optional, Sequence, Tuple

import torch
import torch.nn as nn
import torchvision.transforms.functional as F


from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
CenterCrop


class ResizeMaxSize(nn.Module):

def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
super().__init__()
if not isinstance(max_size, int):
raise TypeError(f"Size should be int. Got {type(max_size)}")
self.max_size = max_size
self.interpolation = interpolation
self.fn = min if fn == 'min' else max
self.fill = fill

def forward(self, img):
if isinstance(img, torch.Tensor):
height, width = img.shape[-2:]
else:
width, height = img.size
scale = self.max_size / float(self.fn(height, width))
if scale != 1.0:
new_size = tuple(round(dim * scale) for dim in (height, width))
img = F.resize(img, new_size, self.interpolation)
pad_h = self.max_size - new_size[0]
pad_w = self.max_size - new_size[1]
img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
return img


def _convert_to_rgb(image):
return image.convert('RGB')


def image_transform(
image_size: int,
is_train: bool,
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)
mean: Optional[Tuple[float, ...]] = None,
std: Optional[Tuple[float, ...]] = None,
resize_longest_max: bool = False,
fill_color: int = 0,
):
mean = mean or (0.48145466, 0.4578275, 0.40821073) # OpenAI dataset mean
std = std or (0.26862954, 0.26130258, 0.27577711) # OpenAI dataset std
if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
# for square size, pass size as int so that Resize() uses aspect preserving shortest edge
image_size = image_size[0]

normalize = Normalize(mean=mean, std=std)
if is_train:
return Compose([
@@ -21,10 +62,18 @@ def image_transform(
normalize,
])
else:
return Compose([
Resize(image_size, interpolation=InterpolationMode.BICUBIC),
CenterCrop(image_size),
if resize_longest_max:
transforms = [
ResizeMaxSize(image_size, fill=fill_color)
]
else:
transforms = [
Resize(image_size, interpolation=InterpolationMode.BICUBIC),
CenterCrop(image_size),
]
transforms.extend([
_convert_to_rgb,
ToTensor(),
normalize,
])
return Compose(transforms)
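
A sketch of the new aspect-preserving eval path, with the arithmetic worked by hand: a 640x480 image processed with resize_longest_max=True at size 224 is scaled by 224/640 = 0.35 to 224x168 (w x h), then padded vertically with fill_color to a square 224x224 before ToTensor and normalization:

from PIL import Image
from open_clip.transform import image_transform

preprocess = image_transform(224, is_train=False, resize_longest_max=True, fill_color=0)
img = Image.new('RGB', (640, 480), color=(128, 128, 128))  # synthetic test image
tensor = preprocess(img)
print(tensor.shape)  # torch.Size([3, 224, 224])
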
21 changes: 20 additions & 1 deletion src/open_clip/utils.py
@@ -1,3 +1,6 @@
from itertools import repeat
import collections.abc

from torch import nn as nn
from torchvision.ops.misc import FrozenBatchNorm2d

@@ -38,4 +41,20 @@ def freeze_batch_norm_2d(module, module_match={}, name=''):
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
if new_child is not child:
res.add_module(child_name, new_child)
return res
return res


# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = lambda n, x: _ntuple(n)(x)
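
For reference, to_2tuple (taken from PyTorch internals, per the comment above) repeats scalars and passes iterables through unchanged, which is what lets image_size and patch_size in model.py accept either an int or an (h, w) pair:

from open_clip.utils import to_2tuple

print(to_2tuple(224))         # (224, 224) - scalar repeated
print(to_2tuple((224, 320)))  # (224, 320) - iterable passed through
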