Add TextTextCLIP #323

Status: Open. lingjzhu wants to merge 41 commits into mlfoundations:main from lingjzhu:main (changes shown from 34 commits).

Commits:
- 3367f27 add texttext-clip (Dec 21, 2022)
- 5089d57 fix loss (Dec 21, 2022)
- d414da5 change main.py (Dec 26, 2022)
- 949834e add arguments (Dec 26, 2022)
- 479aa07 add factory.py test (Dec 27, 2022)
- 8fcc3aa test main.py (Jan 4, 2023)
- db81bbb Merge branch 'main' into main (lingjzhu, Jan 4, 2023)
- e68f4f8 fix main.py (Jan 5, 2023)
- fef6721 Merge branch 'main' of https://github.com/lingjzhu/open_clip (Jan 5, 2023)
- f6eba4f rename variables (Jan 5, 2023)
- 5d331d9 rename variables (Jan 5, 2023)
- 9d620f5 Merge branch 'mlfoundations:main' into main (lingjzhu, Jan 9, 2023)
- 4a00ea0 add hf datasets (Jan 24, 2023)
- 588e8ba fix Siamese network (Jan 26, 2023)
- 1e0d6aa fix some typos (lingjzhu, Jan 30, 2023)
- 516176c fix typos (Jan 30, 2023)
- 2241ce9 Merge branch 'main' into main (lingjzhu, Jan 31, 2023)
- 71d46ea resolve conflicts (Jan 31, 2023)
- 73ab4d7 resolve conflicts (Jan 31, 2023)
- d210534 resolve conflicts (Jan 31, 2023)
- 26d677b resolve conflicts (Jan 31, 2023)
- a3029f4 resolve conflicts (Jan 31, 2023)
- 633f53f resolve conflicts in loss.py (Jan 31, 2023)
- 634709e resolve conflicts in loss.py (Jan 31, 2023)
- a9710b1 add output_dict (Feb 6, 2023)
- 4084147 add webdataset loader (Feb 19, 2023)
- d430974 Merge branch 'main' into main (lingjzhu, Mar 21, 2023)
- 9167976 Update loss.py (lingjzhu, Mar 21, 2023)
- 579a591 add sts evaluation code (Mar 23, 2023)
- 646d517 fix dependencies (Mar 23, 2023)
- 9e5ed73 add weighted mean pooling for decoder models (Apr 2, 2023)
- 2ede8fc fix tokenizers (Apr 2, 2023)
- 5415590 enable freezing all weights but biases (Apr 2, 2023)
- 4f89a44 fixed a typo (Apr 2, 2023)
- 3dddc3f add contriever training (Apr 30, 2023)
- ab6c0b5 fix import (May 10, 2023)
- edef673 fix import (May 10, 2023)
- a495a8a fix agumentation script (May 21, 2023)
- f7c7f19 add MTEB evaluation (May 22, 2023)
- 330d4d9 MTEB benchmark (Jun 20, 2023)
- 0fa9e22 add script example (Jun 20, 2023)
59 changes: 59 additions & 0 deletions docs/script_examples/text_example.sh
@@ -0,0 +1,59 @@
#!/bin/bash
#SBATCH --partition=g80
#SBATCH --account=laion
#SBATCH --job-name=TextTextCLIP
#SBATCH --nodes=32
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=4
#SBATCH --output=logs/%x_%j.out
#SBATCH --comment=laion
#SBATCH --open-mode=append
#SBATCH --exclusive

module load openmpi
module load cuda/11.7

export PYTHONFAULTHANDLER=1
export CUDA_LAUNCH_BLOCKING=0
export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"`
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_PORT=12802
export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l`

echo go $COUNT_NODE
echo $HOSTNAMES


cd /admin/home-jianz/open_clip/src
export PYTHONPATH="$PYTHONPATH:/admin/home-jianz/open_clip/src"

EXP_NAME=""

#srun --comment laion --cpu_bind=v --accel-bind=gn torchrun --nproc_per_node 4 --max_restarts=3 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 -m training.main \
srun --comment laion --cpu_bind=v --accel-bind=gn python3 -m training.main \
--save-frequency 1 \
--dataset-type="webdataset" \
--text-a-key="text_a" \
--text-b-key="text_b" \
--report-to wandb \
--wandb-project-name="TextTextCLIP" \
--train-data="/fsx/home-jianz/mycache/huggingface/hub/datasets--lingjzhu--laion-multi-2B/snapshots/7ec0d572ac4d8da1e6997ed32e383ab63967e05d/laion-multi-2B-{000..127}.tar" \
--train-num-samples 135646078 \
--warmup 2000 \
--batch-size=2048 \
--precision amp_bfloat16 \
--lr=0.001 \
--wd=0.2 \
--epochs=97 \
--workers=1 \
--model="Siamese-xlm-roberta-large" \
--seed 0 \
--log-every-n-steps 5 \
--local-loss \
--gather-with-grad \
--ddp-static-graph \
--grad-checkpointing \
--model_type="text_siamese_encoder" \
--debug \
--sts-val-data="lingjzhu/sts17-crosslingual"

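For orientation, the effective global batch size implied by the SBATCH header above can be sanity-checked with a few lines of arithmetic (a sketch, not part of the PR):

    # Sanity check of the launch configuration above (illustrative only).
    nodes = 32                        # SBATCH --nodes
    ranks_per_node = 4                # SBATCH --ntasks-per-node, one rank per GPU
    per_gpu_batch = 2048              # --batch-size
    train_num_samples = 135_646_078   # --train-num-samples

    global_batch = nodes * ranks_per_node * per_gpu_batch  # 262,144 samples per step
    steps_per_epoch = train_num_samples // global_batch    # ~517 optimizer steps per epoch
    print(global_batch, steps_per_epoch)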
Empty file added models_gh_runner.txt
Empty file.
2 changes: 2 additions & 0 deletions requirements-training.txt
@@ -8,5 +8,7 @@ pandas
braceexpand
huggingface_hub
transformers
datasets
timm
fsspec
scipy
3 changes: 2 additions & 1 deletion src/open_clip/__init__.py
@@ -2,8 +2,9 @@
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
from .factory import list_models, add_model_config, get_model_config, load_checkpoint

from .loss import ClipLoss, DistillClipLoss, CoCaLoss
from .model import CLIP, CustomTextCLIP, TextTextCLIP, SiameseTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
123 changes: 86 additions & 37 deletions src/open_clip/factory.py
@@ -10,7 +10,7 @@
import torch

from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, CustomTextCLIP, TextTextCLIP, SiameseTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict, \
resize_pos_embed, get_cast_dtype
from .coca_model import CoCa
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
@@ -44,10 +44,16 @@ def _rescan_model_configs():
for cf in config_files:
with open(cf, 'r') as f:
model_cfg = json.load(f)
            is_clip = all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg'))
            is_text_only = all(a in model_cfg for a in ('embed_dim', 'tower_a_cfg', 'tower_b_cfg'))
            is_siamese_text_only = all(a in model_cfg for a in ('embed_dim', 'text_cfg'))
            is_multimodal = all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg', 'multimodal_cfg'))
            if is_clip or is_text_only or is_siamese_text_only or is_multimodal:
_MODEL_CONFIGS[cf.stem] = model_cfg

_MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}



_rescan_model_configs() # initial populate of model config registry
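The rescan now admits three additional config shapes besides plain CLIP. Note that is_siamese_text_only (just embed_dim plus text_cfg) also matches every standard CLIP config; that is harmless here because any single matching predicate registers the config. A hypothetical dual-tower config of the shape the new is_text_only branch accepts (field names follow the diff; the values are illustrative, not from this PR):

    # Hypothetical config matching the is_text_only branch
    # ('embed_dim' + 'tower_a_cfg' + 'tower_b_cfg'); values are made up.
    text_text_cfg = {
        "embed_dim": 1024,
        "tower_a_cfg": {"hf_model_name": "xlm-roberta-large",
                        "hf_tokenizer_name": "xlm-roberta-large"},
        "tower_b_cfg": {"hf_model_name": "xlm-roberta-large",
                        "hf_tokenizer_name": "xlm-roberta-large"},
    }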
@@ -74,12 +80,24 @@ def get_model_config(model_name):


def get_tokenizer(model_name):
    if model_name.startswith(HF_HUB_PREFIX):
        tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
    else:
        config = get_model_config(model_name)
        if 'text_cfg' in config:
            key = 'text_cfg'
        elif 'tower_a_cfg' in config:
            key = 'tower_a_cfg'
        else:
            key = None
        if key is not None and 'hf_tokenizer_name' in config[key]:
            tokenizer = HFTokenizer(config[key]['hf_tokenizer_name'])
            if "pythia" in config[key]['hf_tokenizer_name']:
                # Pythia checkpoints use the GPT-NeoX tokenizer, which defines no pad token by default
                tokenizer.tokenizer.pad_token_id = 1
        else:
            tokenizer = tokenize

    return tokenizer


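Usage stays the same as for CLIP models; the tokenizer is now resolved from whichever tower config is present. A sketch, assuming the Siamese-xlm-roberta-large config from the example script declares an hf_tokenizer_name:

    # Sketch: resolving the tokenizer for a text-only model config.
    from open_clip import get_tokenizer

    tokenizer = get_tokenizer("Siamese-xlm-roberta-large")
    token_ids = tokenizer(["a query", "a matching document"])  # padded id tensor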
@@ -117,8 +135,10 @@ def create_model(
pretrained_image: bool = False,
pretrained_hf: bool = True,
cache_dir: Optional[str] = None,
model_type: Optional[str] = "CLIP",
output_dict: Optional[bool] = None,
require_pretrained: bool = False,
):
has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
if has_hf_hub_prefix:
@@ -180,18 +200,35 @@
assert False, 'pretrained image towers currently only supported for timm models'

cast_dtype = get_cast_dtype(precision)

is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

    # switch to TextTextCLIP / SiameseTextCLIP based on model_type
    if model_type == "text_dual_encoder":
        if 'hf_model_name' in model_cfg.get('tower_a_cfg', {}):
            model_cfg['tower_a_cfg']['hf_model_pretrained'] = pretrained_hf
        if 'hf_model_name' in model_cfg.get('tower_b_cfg', {}):
            model_cfg['tower_b_cfg']['hf_model_pretrained'] = pretrained_hf
        model = TextTextCLIP(**model_cfg, cast_dtype=cast_dtype)

    elif model_type == "text_siamese_encoder":
        if 'hf_model_name' in model_cfg.get('tower_a_cfg', {}):
            model_cfg['tower_a_cfg']['hf_model_pretrained'] = pretrained_hf
        model = SiameseTextCLIP(**model_cfg, cast_dtype=cast_dtype)

    elif model_type == "CLIP":
        if custom_text:
            if is_hf_model:
                model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
            if "coca" in model_name:
                model = CoCa(**model_cfg, cast_dtype=cast_dtype)
            else:
                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CLIP(**model_cfg, cast_dtype=cast_dtype)


pretrained_loaded = False
if pretrained:
@@ -226,9 +263,10 @@
if precision in ("fp16", "bf16"):
convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)

if model_type=="CLIP":
# set image / mean metadata from pretrained_cfg if available, or use default
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD

# to always output dict even if it is clip
if output_dict and hasattr(model, "output_dict"):
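A minimal sketch of how the new dispatch is exercised; the config name comes from the example script above, and all other arguments keep their defaults:

    # Sketch: create_model() now routes on model_type rather than always building CLIP.
    from open_clip import create_model

    text_model = create_model(
        "Siamese-xlm-roberta-large",        # config must define embed_dim + tower/text cfgs
        model_type="text_siamese_encoder",  # or "text_dual_encoder"; the default stays "CLIP"
    )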
@@ -287,6 +325,7 @@ def create_model_and_transforms(
image_std: Optional[Tuple[float, ...]] = None,
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
cache_dir: Optional[str] = None,
model_type: Optional[str] = "CLIP",
output_dict: Optional[bool] = None,
):
model = create_model(
@@ -302,24 +341,29 @@
pretrained_image=pretrained_image,
pretrained_hf=pretrained_hf,
cache_dir=cache_dir,
model_type=model_type,
output_dict=output_dict,
)

    if model_type == "CLIP":
        image_mean = image_mean or getattr(model.visual, 'image_mean', None)
        image_std = image_std or getattr(model.visual, 'image_std', None)
        preprocess_train = image_transform(
            model.visual.image_size,
            is_train=True,
            mean=image_mean,
            std=image_std,
            aug_cfg=aug_cfg,
        )
        preprocess_val = image_transform(
            model.visual.image_size,
            is_train=False,
            mean=image_mean,
            std=image_std,
        )
    else:
        preprocess_train = None
        preprocess_val = None


return model, preprocess_train, preprocess_val

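For the text-only model types there is no image tower, so both transforms come back as None; callers unpack the triple as usual but must not use them (a sketch):

    # Sketch: text-only model types return None in place of image transforms.
    from open_clip import create_model_and_transforms

    model, preprocess_train, preprocess_val = create_model_and_transforms(
        "Siamese-xlm-roberta-large",
        model_type="text_siamese_encoder",
    )
    assert preprocess_train is None and preprocess_val is None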
@@ -337,6 +381,7 @@ def create_model_from_pretrained(
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
model_type: Optional[str] = "CLIP",
):
model = create_model(
model_name,
@@ -348,19 +393,23 @@
force_custom_text=force_custom_text,
force_image_size=force_image_size,
cache_dir=cache_dir,
model_type=model_type,
require_pretrained=True,
)

if not return_transform:
return model

    if model_type == "CLIP":
        image_mean = image_mean or getattr(model.visual, 'image_mean', None)
        image_std = image_std or getattr(model.visual, 'image_std', None)
        preprocess = image_transform(
            model.visual.image_size,
            is_train=False,
            mean=image_mean,
            std=image_std,
        )
    else:
        preprocess = None


return model, preprocess
14 changes: 14 additions & 0 deletions src/open_clip/hf_configs.py
@@ -42,4 +42,18 @@
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/main/en/model_doc/gpt_neox#transformers.GPTNeoXConfig
"gpt_neox": {
"config_names": {
# https://github.com/huggingface/transformers/blob/c612628045822f909020f7eb6784c79700813eda/src/transformers/models/gpt_neox/modeling_gpt_neox.py#L410
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layers",
"token_embeddings_attr": "embed_in"
},
"pooler": "weighted_mean_pooler",
},
}
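HFTextEncoder consumes this entry through arch_dict to size the tower and select the pooler; a sketch of the lookup:

    # Sketch: how hf_model.py reads the gpt_neox entry above.
    from open_clip.hf_configs import arch_dict

    entry = arch_dict["gpt_neox"]
    pooler_name = entry["pooler"]                 # "weighted_mean_pooler"
    width_attr = entry["config_names"]["width"]   # maps to hidden_size on the HF config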
51 changes: 42 additions & 9 deletions src/open_clip/hf_model.py
@@ -61,6 +61,29 @@ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
return masked_output.max(1).values


@register_pooler
class WeightedMeanPooler(nn.Module):
"""Weighted mean pooling for autoregressive models"""

def forward(self, x: BaseModelOutput, attention_mask: TensorType):

input_mask_expanded = attention_mask.unsqueeze(-1).expand(x.last_hidden_state.size()).float()
weights = (
torch.arange(start=1, end=x.last_hidden_state.shape[1] + 1)
.unsqueeze(0)
.unsqueeze(-1)
.expand(x.last_hidden_state.size())
.float().to(x.last_hidden_state.device)
)
assert weights.shape == x.last_hidden_state.shape == input_mask_expanded.shape
input_mask_expanded = input_mask_expanded * weights

sum_embeddings = torch.sum(x.last_hidden_state * input_mask_expanded, 1)
sum_mask = input_mask_expanded.sum(1)
sum_mask = torch.clamp(sum_mask, min=1e-9)
return sum_embeddings / sum_mask


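The pooler weights position i by i itself, so later tokens dominate; this suits autoregressive models, whose final hidden states are the only ones that attend to the whole sequence (the approach popularized by SGPT). A toy check of the arithmetic (a sketch, not part of the PR):

    # Toy check: weighted mean with weights 1..seq_len over a (1, 3, 1) tensor.
    import torch

    hidden = torch.tensor([[[1.0], [2.0], [4.0]]])          # last_hidden_state
    weights = torch.tensor([1.0, 2.0, 3.0]).view(1, 3, 1)   # positions 1..3
    pooled = (hidden * weights).sum(1) / weights.sum(1)     # (1*1 + 2*2 + 4*3) / 6
    print(pooled)  # tensor([[2.8333]])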
@register_pooler
class ClsPooler(nn.Module):
"""CLS token pooling"""
@@ -78,6 +101,8 @@ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
return x.pooler_output

return x.last_hidden_state[:, self.cls_token_position, :]




class HFTextEncoder(nn.Module):
@@ -111,6 +136,9 @@ def __init__(
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
self.transformer = create_func(model_args)
self.transformer = self.transformer.encoder
elif 'gpt' in self.config.model_type:
self.transformer = create_func(model_args)
self.config.pad_token_id = 1 # this is for GPT-NeoX. It might need to be changed if other models are needed.
else:
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
else:
@@ -139,23 +167,28 @@ def forward(self, x: TensorType):
out = self.transformer(input_ids=x, attention_mask=attn_mask)
pooled_out = self.pooler(out, attn_mask)
projected = self.proj(pooled_out)

        if self.output_tokens:
            seq_len = out.last_hidden_state.shape[1]
            tokens = (
                out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
                if type(self.pooler) == ClsPooler
                else out.last_hidden_state
            )
            return projected, tokens

        return projected

def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True, unlocked_biases: bool = False):
if not unlocked_layers: # full freezing
for n, p in self.transformer.named_parameters():
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
p.requires_grad = False
if "LayerNorm" in n.split("."):
p.requires_grad = (not freeze_layer_norm)
if 'bias' in n.split("."):
p.requires_grad = unlocked_biases
return


encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
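The new unlocked_biases flag turns full freezing into BitFit-style bias-only tuning. Because the bias check runs after the LayerNorm check, LayerNorm biases are also unfrozen when the flag is set. A sketch of verifying the resulting trainable set, assuming encoder is an HFTextEncoder instance:

    # Sketch: bias-only finetuning via lock(unlocked_biases=True).
    encoder.lock(unlocked_layers=0, freeze_layer_norm=True, unlocked_biases=True)

    trainable = [n for n, p in encoder.transformer.named_parameters() if p.requires_grad]
    assert all("bias" in n.split(".") for n in trainable)  # only bias terms remain trainable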