Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion clip_benchmark/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import torch
from .open_clip import load_open_clip
from .japanese_clip import load_japanese_clip
from .transformers_clip import load_transformers_clip

# Each loader must hand back the (model, transform, tokenizer) triple.
TYPE2FUNC = {
    "open_clip": load_open_clip,
    "ja_clip": load_japanese_clip,
    "transformers": load_transformers_clip,
}
MODEL_TYPES = list(TYPE2FUNC)

Expand Down
29 changes: 29 additions & 0 deletions clip_benchmark/models/transformers_clip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch
from torch import nn
from transformers import AutoModel, AutoProcessor
from functools import partial

class TransformerWrapper(nn.Module):
    """Adapt a HuggingFace CLIP-style model to the benchmark interface.

    Exposes ``encode_text`` / ``encode_image`` on top of the underlying
    model's ``get_text_features`` / ``get_image_features``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def encode_text(self, text):
        # ``text`` is the tokenizer output: a mapping of tensor kwargs.
        return self.model.get_text_features(**text)

    def encode_image(self, image):
        # The dataloader's collation adds a spurious size-1 dimension;
        # drop it from every entry before forwarding to the model.
        squeezed = {name: tensor.squeeze(1) for name, tensor in image.items()}
        return self.model.get_image_features(**squeezed)

def load_transformers_clip(model_name, pretrained, cache_dir, device):
    """Load a HuggingFace ``transformers`` CLIP-style checkpoint.

    The hub checkpoint id is assembled as ``{model_name}/{pretrained}``
    (e.g. model_name="openai", pretrained="clip-vit-base-patch32").

    Returns the ``(model, transform, tokenizer)`` triple that the loader
    registry expects: the wrapped model, an image-preprocessing callable,
    and a tokenizer callable, both pinned to emit PyTorch tensors.
    """
    checkpoint = f"{model_name}/{pretrained}"

    hf_model = AutoModel.from_pretrained(
        checkpoint, cache_dir=cache_dir, device_map=device
    )
    wrapped = TransformerWrapper(hf_model)

    processor = AutoProcessor.from_pretrained(checkpoint)
    transform = partial(processor.image_processor.preprocess, return_tensors="pt")
    # Fixed-length padding keeps text batches uniform in shape.
    tokenizer = partial(
        processor.tokenizer, return_tensors="pt", padding="max_length", max_length=64
    )
    return wrapped, transform, tokenizer