Add TextTextCLIP #323
`src/open_clip/factory.py`

```diff
@@ -10,7 +10,7 @@
 import torch

 from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
-from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
+from .model import CLIP, CustomTextCLIP, TextTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
     resize_pos_embed, get_cast_dtype
 from .openai import load_openai_model
 from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
```
```diff
@@ -41,7 +41,7 @@ def _rescan_model_configs():
     for cf in config_files:
         with open(cf, 'r') as f:
             model_cfg = json.load(f)
-            if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
+            if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')) or all(a in model_cfg for a in ('embed_dim', 'query_cfg', 'doc_cfg')):
                 _MODEL_CONFIGS[cf.stem] = model_cfg

     _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
```
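The updated predicate accepts either config schema. A minimal self-contained check mirroring the condition above (sub-config contents elided for brevity):

```python
# Both minimal config shapes now pass the registration check
# in _rescan_model_configs:
image_text_cfg = {'embed_dim': 512, 'vision_cfg': {}, 'text_cfg': {}}
text_text_cfg = {'embed_dim': 512, 'query_cfg': {}, 'doc_cfg': {}}

for cfg in (image_text_cfg, text_text_cfg):
    assert (all(a in cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg'))
            or all(a in cfg for a in ('embed_dim', 'query_cfg', 'doc_cfg')))
```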
```diff
@@ -72,7 +72,11 @@ def get_model_config(model_name):

 def get_tokenizer(model_name):
     config = get_model_config(model_name)
-    tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
+    if 'text_cfg' in config.keys():
+        key = 'text_cfg'
+    elif 'query_cfg' in config.keys():
+        key = 'query_cfg'
+    tokenizer = HFTokenizer(config[key]['hf_tokenizer_name']) if 'hf_tokenizer_name' in config[key] else tokenize
     return tokenizer
```
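For illustration, a minimal sketch of what the new branch does, assuming a registered config whose JSON contains `query_cfg`/`doc_cfg` (the model name `roberta-roberta` is hypothetical):

```python
from open_clip import get_tokenizer

# Hypothetical model name; it must match the stem of a JSON file in
# model_configs/ that contains 'query_cfg' and 'doc_cfg'.
tokenizer = get_tokenizer('roberta-roberta')

# Because query_cfg sets 'hf_tokenizer_name': 'roberta-base', this returns
# an HFTokenizer wrapping the roberta-base tokenizer instead of the default
# CLIP BPE `tokenize` function.
```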
```diff
@@ -109,6 +113,7 @@ def create_model(
     pretrained_image: bool = False,
     pretrained_hf: bool = True,
     cache_dir: Optional[str] = None,
+    text_to_text: Optional[bool] = False,
 ):
     model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
     if isinstance(device, str):
```
```diff
@@ -148,13 +153,22 @@
     cast_dtype = get_cast_dtype(precision)
     custom_text = model_cfg.pop('custom_text', False) or force_custom_text or ('hf_model_name' in model_cfg.get('text_cfg', {}))

-    if custom_text:
-        if 'hf_model_name' in model_cfg.get('text_cfg', {}):
-            model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
-        model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
+    # switch to TextTextCLIP
+    if text_to_text:
+        if 'hf_model_name' in model_cfg.get('doc_cfg', {}):
+            model_cfg['doc_cfg']['hf_model_pretrained'] = pretrained_hf
+        if 'hf_model_name' in model_cfg.get('query_cfg', {}):
+            model_cfg['query_cfg']['hf_model_pretrained'] = pretrained_hf
+
+        model = TextTextCLIP(**model_cfg, cast_dtype=cast_dtype)
     else:
-        model = CLIP(**model_cfg, cast_dtype=cast_dtype)
+        if custom_text:
+            if 'hf_model_name' in model_cfg.get('text_cfg', {}):
+                model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
+            model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
+        else:
+            model = CLIP(**model_cfg, cast_dtype=cast_dtype)

     pretrained_cfg = {}
     if pretrained:
```
```diff
@@ -179,9 +193,10 @@
     if precision in ("fp16", "bf16"):
         convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)

-    # set image / mean metadata from pretrained_cfg if available, or use default
-    model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
-    model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
+    if not text_to_text:
+        # set image / mean metadata from pretrained_cfg if available, or use default
+        model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
+        model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD

     if jit:
         model = torch.jit.script(model)
```
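Putting the pieces together, a minimal sketch of the new construction path (model name hypothetical, as above):

```python
import open_clip

# text_to_text=True routes construction to TextTextCLIP, which has
# query/doc text towers and no model.visual attribute.
model = open_clip.create_model(
    'roberta-roberta',   # hypothetical query_cfg/doc_cfg config name
    precision='fp32',
    device='cpu',
    text_to_text=True,
)
```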
```diff
@@ -203,6 +218,7 @@ def create_model_and_transforms(
     image_mean: Optional[Tuple[float, ...]] = None,
     image_std: Optional[Tuple[float, ...]] = None,
     cache_dir: Optional[str] = None,
+    text_to_text: Optional[bool] = False,
 ):
     model = create_model(
         model_name,
```
```diff
@@ -216,22 +232,27 @@
         pretrained_image=pretrained_image,
         pretrained_hf=pretrained_hf,
         cache_dir=cache_dir,
+        text_to_text=text_to_text,
     )

-    image_mean = image_mean or getattr(model.visual, 'image_mean', None)
-    image_std = image_std or getattr(model.visual, 'image_std', None)
-    preprocess_train = image_transform(
-        model.visual.image_size,
-        is_train=True,
-        mean=image_mean,
-        std=image_std
-    )
-    preprocess_val = image_transform(
-        model.visual.image_size,
-        is_train=False,
-        mean=image_mean,
-        std=image_std
-    )
+    if not text_to_text:
+        image_mean = image_mean or getattr(model.visual, 'image_mean', None)
+        image_std = image_std or getattr(model.visual, 'image_std', None)
+        preprocess_train = image_transform(
+            model.visual.image_size,
+            is_train=True,
+            mean=image_mean,
+            std=image_std
+        )
+        preprocess_val = image_transform(
+            model.visual.image_size,
+            is_train=False,
+            mean=image_mean,
+            std=image_std
+        )
+    else:
+        preprocess_val = None
+        preprocess_train = None

     return model, preprocess_train, preprocess_val
```

Review comment on `if not text_to_text:`: what about checking if `model.visual` exists instead?
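For reference, the reviewer's alternative might look something like this sketch (not part of the PR), keying off the model itself rather than threading a flag through every factory function:

```python
# Sketch of the suggested alternative: detect text-only models by the
# absence of a visual tower rather than by an explicit text_to_text flag.
if getattr(model, 'visual', None) is not None:
    image_mean = image_mean or getattr(model.visual, 'image_mean', None)
    image_std = image_std or getattr(model.visual, 'image_std', None)
    preprocess_train = image_transform(
        model.visual.image_size, is_train=True, mean=image_mean, std=image_std)
    preprocess_val = image_transform(
        model.visual.image_size, is_train=False, mean=image_mean, std=image_std)
else:
    preprocess_train = preprocess_val = None
```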
```diff
@@ -248,6 +269,7 @@ def create_model_from_pretrained(
     image_mean: Optional[Tuple[float, ...]] = None,
     image_std: Optional[Tuple[float, ...]] = None,
     cache_dir: Optional[str] = None,
+    text_to_text: Optional[bool] = False,
 ):
     if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
         raise RuntimeError(
```
```diff
@@ -263,18 +285,20 @@
         force_quick_gelu=force_quick_gelu,
         force_custom_text=force_custom_text,
         cache_dir=cache_dir,
+        text_to_text=text_to_text,
     )

     if not return_transform:
         return model

-    image_mean = image_mean or getattr(model.visual, 'image_mean', None)
-    image_std = image_std or getattr(model.visual, 'image_std', None)
-    preprocess = image_transform(
-        model.visual.image_size,
-        is_train=False,
-        mean=image_mean,
-        std=image_std
-    )
+    if not text_to_text:
+        image_mean = image_mean or getattr(model.visual, 'image_mean', None)
+        image_std = image_std or getattr(model.visual, 'image_std', None)
+        preprocess = image_transform(
+            model.visual.image_size,
+            is_train=False,
+            mean=image_mean,
+            std=image_std
+        )

     return model, preprocess
```

Review comments on `if not text_to_text:`:
- what about checking if `model.visual` exists instead?
- If we merge …
`src/open_clip/model.py`

```diff
@@ -248,6 +248,47 @@ def forward(self, image, text):
         return image_features, text_features, self.logit_scale.exp()


+class TextTextCLIP(nn.Module):
+    def __init__(
+            self,
+            embed_dim: int,
+            query_cfg: CLIPTextCfg,
+            doc_cfg: CLIPTextCfg,
+            quick_gelu: bool = False,
+            cast_dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+        self.doc = _build_text_tower(embed_dim, doc_cfg, quick_gelu, cast_dtype)
+        self.query = _build_text_tower(embed_dim, query_cfg, quick_gelu, cast_dtype)
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+    def lock_query_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
+        self.query.lock(unlocked_layers, freeze_layer_norm)
+
+    def lock_doc_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
+        self.doc.lock(unlocked_layers, freeze_layer_norm)
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.doc.set_grad_checkpointing(enable)
+        self.query.set_grad_checkpointing(enable)
+
+    def encode_doc(self, text, normalize: bool = False):
+        features = self.doc(text)
+        return F.normalize(features, dim=-1) if normalize else features
+
+    def encode_query(self, text, normalize: bool = False):
+        features = self.query(text)
+        return F.normalize(features, dim=-1) if normalize else features
+
+    def forward(self, query, doc):
+        query_features = self.encode_query(query, normalize=True)
+        doc_features = self.encode_doc(doc, normalize=True)
+        return query_features, doc_features, self.logit_scale.exp()
+
+
 def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
     """Convert applicable model parameters to low-precision (bf16 or fp16)"""
```

Review comments on `class TextTextCLIP(nn.Module):`:
- this is completely duplicated from the class above, I wonder if we could reconcile it
- Maybe we could use this one as the general model? It's not specific about modality. We can refer to …
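A short usage sketch of the class as it stands (model and tokenizer names hypothetical, as above; the logits follow the usual CLIP contrastive pattern):

```python
import torch
import open_clip

tokenizer = open_clip.get_tokenizer('roberta-roberta')   # hypothetical config name
model = open_clip.create_model('roberta-roberta', text_to_text=True)

queries = tokenizer(["what is contrastive learning?"])
docs = tokenizer([
    "Contrastive learning trains encoders to align paired samples.",
    "A recipe for sourdough bread.",
])

with torch.no_grad():
    q, d, logit_scale = model(queries, docs)   # both outputs L2-normalized
    logits = logit_scale * q @ d.t()           # (n_queries, n_docs) similarities
    probs = logits.softmax(dim=-1)             # query-to-doc matching probabilities
```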
New model config JSON (added under `src/open_clip/model_configs/`; the filename is not shown in this view):

```json
{
    "embed_dim": 512,
    "quick_gelu": true,
    "query_cfg": {
        "hf_model_name": "roberta-base",
        "hf_tokenizer_name": "roberta-base",
        "proj": "mlp",
        "pooler_type": "mean_pooler"
    },
    "doc_cfg": {
        "hf_model_name": "roberta-base",
        "hf_tokenizer_name": "roberta-base",
        "proj": "mlp",
        "pooler_type": "mean_pooler"
    }
}
```
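Note that `_rescan_model_configs` keys configs by file stem, so whatever name this file is given becomes the model name. For example, assuming it were saved as `roberta-roberta.json` (hypothetical):

```python
import open_clip

# In text-to-text mode the image transforms come back as None.
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    'roberta-roberta',
    text_to_text=True,
)
assert preprocess_train is None and preprocess_val is None
```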