
Commit 329bc36

add siglip2 and llama3.1-8b
1 parent 71839f1 commit 329bc36

13 files changed (+1721, -2 lines)

llm2clip/eva_clip/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -8,4 +8,5 @@
 from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
     get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
 from .tokenizer import SimpleTokenizer, tokenize
-from .transform import image_transform
+from .transform import image_transform
+from .llm_model import LLM2VecTextTransformer

llm2clip/eva_clip/llm_model.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+import torch
+from torch import nn
+from llm2vec import LLM2Vec
+
+class LLM2VecTextTransformer(nn.Module):
+    def __init__(self, text_proj=None):
+        super().__init__()
+        enable_bidirectional = True
+        base_model_name_or_path = "meta-llama/Llama-3.1-8B-Instruct"
+        extra_model_name_or_path = None
+        peft_path = "checkpoints/LLM2CLIP-Llama-3.1-8B"
+        self.text = LLM2Vec.from_pretrained(
+            base_model_name_or_path,
+            peft_path,
+            merge_peft=True,
+            extra_model_name_or_path=extra_model_name_or_path,
+            enable_bidirectional=enable_bidirectional,
+            attn_implementation="flash_attention_2",
+            torch_dtype=torch.bfloat16
+        )
+        self.text_proj = text_proj
+
+    def lock(self, **kwargs):
+        for param in self.text.parameters():
+            param.requires_grad = False
+
+    def forward(self, text, batch_size=32):
+        with torch.autocast("cuda"):
+            x = self.text.encode(text, batch_size=batch_size).to(torch.float16)
+        if self.text_proj is not None:
+            x = self.text_proj(x, l2_norm=False)
+        return x
+
+    def set_grad_checkpointing(self, enable=True):
+        # Not implemented
+        pass
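
Note: the new tower hard-codes the Llama-3.1-8B-Instruct base weights, the PEFT adapter at checkpoints/LLM2CLIP-Llama-3.1-8B, flash_attention_2, and bfloat16, so using it requires those checkpoints plus a CUDA GPU with flash-attn installed. A minimal usage sketch, assuming the llm2clip directory is on PYTHONPATH; with text_proj=None the raw LLM2Vec embeddings are returned, since the projection is normally supplied by _build_text_tower in model.py:

import torch
from eva_clip import LLM2VecTextTransformer  # re-exported by eva_clip/__init__.py in this commit

# Loads Llama-3.1-8B-Instruct, merges the LLM2CLIP PEFT adapter, enables bidirectional attention.
text_tower = LLM2VecTextTransformer(text_proj=None).cuda()
text_tower.lock()  # freeze every LLM parameter (requires_grad = False)

captions = ["a photo of a cat", "an aerial view of a city at night"]
with torch.no_grad():
    feats = text_tower(captions, batch_size=2)  # float16 tensor, one 4096-d embedding per caption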

llm2clip/eva_clip/model.py

Lines changed: 4 additions & 1 deletion
@@ -20,7 +20,7 @@
 from .timm_model import TimmModel
 from .eva_vit_model import EVAVisionTransformer
 from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer, LayerNormFp32
-
+from .llm_model import LLM2VecTextTransformer
 try:
     from apex.normalization import FusedLayerNorm
 except:
@@ -191,6 +191,9 @@ def _build_text_tower(
         )
     elif text_cfg.use_embedding:
         text = TextProj(embedding_dim=text_cfg.llm_embedding_dim, output_dim=embed_dim)
+    elif not text_cfg.use_embedding and text_cfg.llm_embedding_dim:
+        text_proj = TextProj(embedding_dim=text_cfg.llm_embedding_dim, output_dim=embed_dim)
+        text = LLM2VecTextTransformer(text_proj)
     else:
         act_layer = QuickGELU if quick_gelu else nn.GELU
         norm_layer = LayerNorm
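
The added branch slots in between the existing use_embedding path (text tower is just a TextProj over presumably precomputed LLM embeddings) and the default TextTransformer path. A minimal sketch of the dispatch, with a SimpleNamespace standing in for the real text_cfg object; only the branches visible in this hunk are mirrored:

from types import SimpleNamespace

def pick_text_tower(text_cfg):
    # Mirrors the elif chain above, returning a label instead of building modules.
    if text_cfg.use_embedding:
        return "TextProj only (embeddings computed offline)"
    elif not text_cfg.use_embedding and text_cfg.llm_embedding_dim:
        return "LLM2VecTextTransformer(TextProj) - LLM encodes raw captions in the forward pass"
    else:
        return "standard TextTransformer"

print(pick_text_tower(SimpleNamespace(use_embedding=False, llm_embedding_dim=4096)))
# -> LLM2VecTextTransformer(TextProj) - LLM encodes raw captions in the forward pass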
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+{
+    "embed_dim": 1152,
+    "vision_cfg": {
+        "timm_model_name": "vit_so400m_patch14_siglip_224.v2_webli",
+        "timm_model_pretrained": true,
+        "timm_pool": "map",
+        "timm_proj": "none",
+        "image_size": 224,
+        "layers": 27,
+        "width": 1152,
+        "patch_size": 14
+    },
+    "text_cfg": {
+        "use_embedding": false,
+        "llm_embedding_dim": 4096,
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}
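
This new config pairs a SigLIP2 SO400M vision tower (width 1152) with the Llama text tower: use_embedding is false and llm_embedding_dim is 4096, so _build_text_tower takes the branch added in model.py and TextProj maps 4096-d LLM2Vec embeddings into the shared 1152-d space. A small sketch of reading those fields; the config's file name is not visible in this diff, so the path below is a placeholder:

import json

# Placeholder path - the actual config file name added by this commit is not shown here.
with open("model_configs/siglip2_llama3.1_8b.json") as f:
    cfg = json.load(f)

text_cfg = cfg["text_cfg"]
assert text_cfg["use_embedding"] is False and text_cfg["llm_embedding_dim"] == 4096
# Projection goes 4096 (Llama-3.1-8B hidden size) -> 1152 (embed_dim, matching the SigLIP2 vision width).
print(text_cfg["llm_embedding_dim"], "->", cfg["embed_dim"])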

llm2clip/llm2vec/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .llm2vec import LLM2Vec
