Add TextTextCLIP #323

Open
wants to merge 41 commits into base: main
Changes from 11 commits
Commits (41)
3367f27  add texttext-clip (Dec 21, 2022)
5089d57  fix loss (Dec 21, 2022)
d414da5  change main.py (Dec 26, 2022)
949834e  add arguments (Dec 26, 2022)
479aa07  add factory.py test (Dec 27, 2022)
8fcc3aa  test main.py (Jan 4, 2023)
db81bbb  Merge branch 'main' into main (lingjzhu, Jan 4, 2023)
e68f4f8  fix main.py (Jan 5, 2023)
fef6721  Merge branch 'main' of https://github.com/lingjzhu/open_clip (Jan 5, 2023)
f6eba4f  rename variables (Jan 5, 2023)
5d331d9  rename variables (Jan 5, 2023)
9d620f5  Merge branch 'mlfoundations:main' into main (lingjzhu, Jan 9, 2023)
4a00ea0  add hf datasets (Jan 24, 2023)
588e8ba  fix Siamese network (Jan 26, 2023)
1e0d6aa  fix some typos (lingjzhu, Jan 30, 2023)
516176c  fix typos (Jan 30, 2023)
2241ce9  Merge branch 'main' into main (lingjzhu, Jan 31, 2023)
71d46ea  resolve conflicts (Jan 31, 2023)
73ab4d7  resolve conflicts (Jan 31, 2023)
d210534  resolve conflicts (Jan 31, 2023)
26d677b  resolve conflicts (Jan 31, 2023)
a3029f4  resolve conflicts (Jan 31, 2023)
633f53f  resolve conflicts in loss.py (Jan 31, 2023)
634709e  resolve conflicts in loss.py (Jan 31, 2023)
a9710b1  add output_dict (Feb 6, 2023)
4084147  add webdataset loader (Feb 19, 2023)
d430974  Merge branch 'main' into main (lingjzhu, Mar 21, 2023)
9167976  Update loss.py (lingjzhu, Mar 21, 2023)
579a591  add sts evaluation code (Mar 23, 2023)
646d517  fix dependencies (Mar 23, 2023)
9e5ed73  add weighted mean pooling for decoder models (Apr 2, 2023)
2ede8fc  fix tokenizers (Apr 2, 2023)
5415590  enable freezing all weights but biases (Apr 2, 2023)
4f89a44  fixed a typo (Apr 2, 2023)
3dddc3f  add contriever training (Apr 30, 2023)
ab6c0b5  fix import (May 10, 2023)
edef673  fix import (May 10, 2023)
a495a8a  fix agumentation script (May 21, 2023)
f7c7f19  add MTEB evaluation (May 22, 2023)
330d4d9  MTEB benchmark (Jun 20, 2023)
0fa9e22  add script example (Jun 20, 2023)
2 changes: 1 addition & 1 deletion src/open_clip/__init__.py
@@ -2,7 +2,7 @@
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\
from .model import CLIP, CustomTextCLIP, TextTextCLIP, CLIPTextCfg, CLIPVisionCfg,\
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
90 changes: 57 additions & 33 deletions src/open_clip/factory.py
@@ -10,7 +10,7 @@
import torch

from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
from .model import CLIP, CustomTextCLIP, TextTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
resize_pos_embed, get_cast_dtype
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
@@ -41,7 +41,7 @@ def _rescan_model_configs():
for cf in config_files:
with open(cf, 'r') as f:
model_cfg = json.load(f)
if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')) or all(a in model_cfg for a in ('embed_dim', 'tower_a_cfg', 'tower_b_cfg')):
_MODEL_CONFIGS[cf.stem] = model_cfg

_MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
@@ -72,7 +72,11 @@ def get_model_config(model_name):

def get_tokenizer(model_name):
config = get_model_config(model_name)
tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
if 'text_cfg' in config.keys():
key = 'text_cfg'
elif 'tower_a_cfg' in config.keys():
key = 'tower_a_cfg'
tokenizer = HFTokenizer(config[key]['hf_tokenizer_name']) if 'hf_tokenizer_name' in config[key] else tokenize
return tokenizer


@@ -109,6 +113,7 @@ def create_model(
pretrained_image: bool = False,
pretrained_hf: bool = True,
cache_dir: Optional[str] = None,
text_to_text: Optional[bool] = False,
):
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
if isinstance(device, str):
@@ -148,13 +153,22 @@

cast_dtype = get_cast_dtype(precision)
custom_text = model_cfg.pop('custom_text', False) or force_custom_text or ('hf_model_name' in model_cfg.get('text_cfg', {}))

if custom_text:
if 'hf_model_name' in model_cfg.get('text_cfg', {}):
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)

# switch to TextTextCLIP
if text_to_text:
if 'hf_model_name' in model_cfg.get('tower_a_cfg', {}):
model_cfg['tower_a_cfg']['hf_model_pretrained'] = pretrained_hf
if 'hf_model_name' in model_cfg.get('tower_b_cfg', {}):
model_cfg['tower_b_cfg']['hf_model_pretrained'] = pretrained_hf

model = TextTextCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
if custom_text:
if 'hf_model_name' in model_cfg.get('text_cfg', {}):
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)

pretrained_cfg = {}
if pretrained:
@@ -179,9 +193,10 @@
if precision in ("fp16", "bf16"):
convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)

if not text_to_text:
[Collaborator] what about checking if model.visual exists instead?

# set image / mean metadata from pretrained_cfg if available, or use default
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD

if jit:
model = torch.jit.script(model)
@@ -203,6 +218,7 @@ def create_model_and_transforms(
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
text_to_text: Optional[bool] = False,
):
model = create_model(
model_name,
@@ -216,22 +232,27 @@
pretrained_image=pretrained_image,
pretrained_hf=pretrained_hf,
cache_dir=cache_dir,
text_to_text=text_to_text,
)

image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
preprocess_train = image_transform(
model.visual.image_size,
is_train=True,
mean=image_mean,
std=image_std
)
preprocess_val = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std
)
if not text_to_text:
[Collaborator] what about checking if model.visual exists instead?

image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
preprocess_train = image_transform(
model.visual.image_size,
is_train=True,
mean=image_mean,
std=image_std
)
preprocess_val = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std
)
else:
preprocess_val = None
preprocess_train = None

return model, preprocess_train, preprocess_val

@@ -248,6 +269,7 @@ def create_model_from_pretrained(
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
text_to_text: Optional[bool] = False,
):
if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
raise RuntimeError(
@@ -263,18 +285,20 @@
force_quick_gelu=force_quick_gelu,
force_custom_text=force_custom_text,
cache_dir=cache_dir,
text_to_text=text_to_text,
)

if not return_transform:
return model

image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
preprocess = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std
)
if not text_to_text:
[Collaborator] what about checking if model.visual exists instead?
[Author] If we merge TextTextCLIP into CustomTextCLIP to get a more general model, then model.visual might not exist at all?

image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
preprocess = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std
)

return model, preprocess
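
For reference, a minimal usage sketch of the new text_to_text factory path (not part of this diff; it assumes the roberta-roberta config below is discoverable and that HFTokenizer accepts a list of strings, as elsewhere in open_clip):

import torch
from open_clip import create_model_and_transforms, get_tokenizer

# With text_to_text=True the factory builds a TextTextCLIP from the two tower
# configs and returns None for both image preprocessing transforms.
model, preprocess_train, preprocess_val = create_model_and_transforms(
    'roberta-roberta',
    precision='fp32',
    device='cpu',
    text_to_text=True,
)
assert preprocess_train is None and preprocess_val is None

# Tokenizer is resolved via tower_a_cfg['hf_tokenizer_name'] in the config.
tokenizer = get_tokenizer('roberta-roberta')
text_a = tokenizer(["a query sentence", "another query"])
text_b = tokenizer(["a matching document", "an unrelated document"])

with torch.no_grad():
    features_a, features_b, logit_scale = model(text_a, text_b)
print(features_a.shape, features_b.shape, logit_scale.item())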
72 changes: 36 additions & 36 deletions src/open_clip/loss.py
@@ -16,8 +16,8 @@


def gather_features(
image_features,
text_features,
features_a,
features_b,
local_loss=False,
gather_with_grad=False,
rank=0,
@@ -28,38 +28,38 @@
if use_horovod:
assert hvd is not None, 'Please install horovod'
if gather_with_grad:
all_image_features = hvd.allgather(image_features)
all_text_features = hvd.allgather(text_features)
all_features_a = hvd.allgather(features_a)
all_features_b = hvd.allgather(features_b)
else:
with torch.no_grad():
all_image_features = hvd.allgather(image_features)
all_text_features = hvd.allgather(text_features)
all_features_a = hvd.allgather(features_a)
all_features_b = hvd.allgather(features_b)
if not local_loss:
# ensure grads for local rank when all_* features don't have a gradient
gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
gathered_image_features[rank] = image_features
gathered_text_features[rank] = text_features
all_image_features = torch.cat(gathered_image_features, dim=0)
all_text_features = torch.cat(gathered_text_features, dim=0)
gathered_features_a = list(all_features_a.chunk(world_size, dim=0))
gathered_features_b = list(all_features_b.chunk(world_size, dim=0))
gathered_features_a[rank] = features_a
gathered_features_b[rank] = features_b
all_features_a = torch.cat(gathered_features_a, dim=0)
all_features_b = torch.cat(gathered_features_b, dim=0)
else:
# We gather tensors from all gpus
if gather_with_grad:
all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
all_features_a = torch.cat(torch.distributed.nn.all_gather(features_a), dim=0)
all_features_b = torch.cat(torch.distributed.nn.all_gather(features_b), dim=0)
else:
gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
dist.all_gather(gathered_image_features, image_features)
dist.all_gather(gathered_text_features, text_features)
gathered_features_a = [torch.zeros_like(features_a) for _ in range(world_size)]
gathered_features_b = [torch.zeros_like(features_b) for _ in range(world_size)]
dist.all_gather(gathered_features_a, features_a)
dist.all_gather(gathered_features_b, features_b)
if not local_loss:
# ensure grads for local rank when all_* features don't have a gradient
gathered_image_features[rank] = image_features
gathered_text_features[rank] = text_features
all_image_features = torch.cat(gathered_image_features, dim=0)
all_text_features = torch.cat(gathered_text_features, dim=0)
gathered_features_a[rank] = features_a
gathered_features_b[rank] = features_b
all_features_a = torch.cat(gathered_features_a, dim=0)
all_features_b = torch.cat(gathered_features_b, dim=0)

return all_image_features, all_text_features
return all_features_a, all_features_b


class ClipLoss(nn.Module):
@@ -85,25 +85,25 @@ def __init__(
self.prev_num_logits = 0
self.labels = {}

def forward(self, image_features, text_features, logit_scale):
device = image_features.device
def forward(self, features_a, features_b, logit_scale):
device = features_a.device
if self.world_size > 1:
all_image_features, all_text_features = gather_features(
image_features, text_features,
all_features_a, all_features_b = gather_features(
features_a, features_b,
self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

if self.local_loss:
logits_per_image = logit_scale * image_features @ all_text_features.T
logits_per_text = logit_scale * text_features @ all_image_features.T
logits_per_feature_a = logit_scale * features_a @ all_features_b.T
logits_per_feature_b = logit_scale * features_b @ all_features_a.T
else:
logits_per_image = logit_scale * all_image_features @ all_text_features.T
logits_per_text = logits_per_image.T
logits_per_feature_a = logit_scale * all_features_a @ all_features_b.T
logits_per_feature_b = logits_per_feature_a.T
else:
logits_per_image = logit_scale * image_features @ text_features.T
logits_per_text = logit_scale * text_features @ image_features.T
logits_per_feature_a = logit_scale * features_a @ features_b.T
logits_per_feature_b = logit_scale * features_b @ features_a.T

# calculated ground-truth and cache if enabled
num_logits = logits_per_image.shape[0]
num_logits = logits_per_feature_a.shape[0]
if self.prev_num_logits != num_logits or device not in self.labels:
labels = torch.arange(num_logits, device=device, dtype=torch.long)
if self.world_size > 1 and self.local_loss:
@@ -115,7 +115,7 @@ def forward(self, image_features, text_features, logit_scale):
labels = self.labels[device]

total_loss = (
F.cross_entropy(logits_per_image, labels) +
F.cross_entropy(logits_per_text, labels)
F.cross_entropy(logits_per_feature_a, labels) +
F.cross_entropy(logits_per_feature_b, labels)
) / 2
return total_loss
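
Because the renamed loss only sees two batches of aligned features, it is modality-agnostic; a small single-process sanity check (illustrative only, using random tensors in place of tower outputs) could look like:

import torch
import torch.nn.functional as F
from open_clip.loss import ClipLoss

loss_fn = ClipLoss()  # defaults: world_size=1, local_loss=False

batch, dim = 8, 512
features_a = F.normalize(torch.randn(batch, dim), dim=-1)  # stand-in for tower_a output
features_b = F.normalize(torch.randn(batch, dim), dim=-1)  # stand-in for tower_b output
logit_scale = torch.tensor(1 / 0.07)  # already exponentiated, matching model.logit_scale.exp()

# Symmetric cross-entropy over the batch of pairwise similarities.
loss = loss_fn(features_a, features_b, logit_scale)
print(loss.item())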
41 changes: 41 additions & 0 deletions src/open_clip/model.py
@@ -248,6 +248,47 @@ def forward(self, image, text):
return image_features, text_features, self.logit_scale.exp()



class TextTextCLIP(nn.Module):
[Collaborator] this is completely duplicated from the class above; I wonder if we could reconcile it.
[Author] Maybe we could use this one as the general model? It's not specific about modality. We can refer to image_features as features_a and text_features as features_b.

def __init__(
self,
embed_dim: int,
tower_a_cfg: CLIPTextCfg,
tower_b_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.tower_a = _build_text_tower(embed_dim, tower_a_cfg, quick_gelu, cast_dtype)
self.tower_b = _build_text_tower(embed_dim, tower_b_cfg, quick_gelu, cast_dtype)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

def lock_tower_a(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
self.tower_a.lock(unlocked_layers, freeze_layer_norm)

def lock_tower_b(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
self.tower_b.lock(unlocked_layers, freeze_layer_norm)

@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.tower_a.set_grad_checkpointing(enable)
self.tower_b.set_grad_checkpointing(enable)

def encode_text_a(self, text, normalize: bool = False):
features = self.tower_a(text)
return F.normalize(features, dim=-1) if normalize else features

def encode_text_b(self, text, normalize: bool = False):
features = self.tower_b(text)
return F.normalize(features, dim=-1) if normalize else features

def forward(self, text_a, text_b):
features_a = self.encode_text_a(text_a, normalize=True)
features_b = self.encode_text_b(text_b, normalize=True)
return features_a, features_b, self.logit_scale.exp()



def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
"""Convert applicable model parameters to low-precision (bf16 or fp16)"""

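
For clarity, a sketch of constructing TextTextCLIP directly from tower config dicts, mirroring what the factory does. It assumes _build_text_tower converts dicts to CLIPTextCfg (as it does in open_clip today), that the hf_* fields below are valid CLIPTextCfg fields, and that the resulting HF text tower supports lock(); the roberta-base weights are downloaded on first use.

import torch
from open_clip.model import TextTextCLIP

tower_cfg = {
    "hf_model_name": "roberta-base",
    "hf_tokenizer_name": "roberta-base",
    "proj": "mlp",
    "pooler_type": "mean_pooler",
}
model = TextTextCLIP(embed_dim=512, tower_a_cfg=dict(tower_cfg), tower_b_cfg=dict(tower_cfg))

# Example: freeze tower_b entirely for an asymmetric query/document setup.
model.lock_tower_b(unlocked_layers=0, freeze_layer_norm=True)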
16 changes: 16 additions & 0 deletions src/open_clip/model_configs/roberta-roberta.json
@@ -0,0 +1,16 @@
{
"embed_dim": 512,
"quick_gelu": true,
"tower_a_cfg": {
"hf_model_name": "roberta-base",
"hf_tokenizer_name": "roberta-base",
"proj": "mlp",
"pooler_type": "mean_pooler"
},
"tower_b_cfg": {
"hf_model_name": "roberta-base",
"hf_tokenizer_name": "roberta-base",
"proj": "mlp",
"pooler_type": "mean_pooler"
}
}
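
With the relaxed check in _rescan_model_configs (either text_cfg/vision_cfg or tower_a_cfg/tower_b_cfg), this config should surface through the usual factory helpers. A quick, illustrative check (assuming the JSON is installed on the model-config path):

import open_clip

print('roberta-roberta' in open_clip.list_models())     # True once the config is discovered
tokenizer = open_clip.get_tokenizer('roberta-roberta')  # HFTokenizer('roberta-base') via tower_a_cfg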