
allow for using an external vision transformer
lucidrains committed Dec 25, 2021
1 parent 699a491 commit d263446
Showing 3 changed files with 64 additions and 11 deletions.
47 changes: 46 additions & 1 deletion README.md
````diff
@@ -26,7 +26,6 @@ clip = CLIP(
     text_enc_depth = 6,
     text_seq_len = 256,
     text_heads = 8,
-    num_visual_tokens = 512,
     visual_enc_depth = 6,
     visual_image_size = 256,
     visual_patch_size = 32,
@@ -60,6 +59,52 @@ loss = clip(
 loss.backward()
 ```
 
+You can also pass in an external visual transformer / residual net. Just make sure your image encoder returns a set of embeddings in the shape of `batch x seq x dim`, and that `dim_image` is set to the dimension of the returned embeddings. Below is an example using the vision transformer from `vit_pytorch`.
+
+```bash
+$ pip install "vit_pytorch>=0.25.6"
+```
+
+```python
+import torch
+from x_clip import CLIP
+
+from vit_pytorch import ViT
+from vit_pytorch.extractor import Extractor
+
+base_vit = ViT(
+    image_size = 256,
+    patch_size = 32,
+    num_classes = 1000,
+    dim = 512,
+    depth = 6,
+    heads = 16,
+    mlp_dim = 2048,
+    dropout = 0.1,
+    emb_dropout = 0.1
+)
+
+vit = Extractor(base_vit, return_embeddings_only = True)
+
+clip = CLIP(
+    image_encoder = vit,
+    dim_image = 512,  # must match the dimension of the vision transformer above
+    dim_text = 512,
+    dim_latent = 512,
+    num_text_tokens = 10000,
+    text_enc_depth = 6,
+    text_seq_len = 256,
+    text_heads = 8
+)
+
+text = torch.randint(0, 10000, (4, 256))
+images = torch.randn(4, 3, 256, 256)
+mask = torch.ones_like(text).bool()
+
+loss = clip(text, images, text_mask = mask, return_loss = True)
+loss.backward()
+```
+
 ## Citations
 
 ```bibtex
````
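Per the README text above, anything that maps images to embeddings of shape `batch x seq x dim` should be usable as the `image_encoder`, including residual nets. Below is a minimal hypothetical sketch of a non-ViT encoder; `SimplePatchEncoder` is an illustrative name and not part of `x-clip` or `vit_pytorch`.

```python
import torch
from torch import nn
from x_clip import CLIP

# hypothetical encoder: splits the image into non-overlapping patches and
# linearly projects each one, yielding embeddings of shape batch x seq x dim
class SimplePatchEncoder(nn.Module):
    def __init__(self, dim = 512, patch_size = 32, channels = 3):
        super().__init__()
        self.to_patch_emb = nn.Conv2d(channels, dim, kernel_size = patch_size, stride = patch_size)

    def forward(self, images):
        x = self.to_patch_emb(images)        # batch x dim x h x w
        return x.flatten(2).transpose(1, 2)  # batch x (h * w) x dim

clip = CLIP(
    image_encoder = SimplePatchEncoder(dim = 512),
    dim_image = 512,  # matches the encoder's output dimension
    dim_text = 512,
    dim_latent = 512,
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8
)
```

The training call from the README example above would then apply unchanged.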
2 changes: 1 addition & 1 deletion setup.py
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'x-clip',
   packages = find_packages(exclude=[]),
-  version = '0.0.9',
+  version = '0.0.10',
   license='MIT',
   description = 'X-CLIP',
   author = 'Phil Wang',
```
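With the version bumped to `0.0.10`, a release carrying the new `image_encoder` argument can be installed; assuming the package is published to PyPI under the `x-clip` name given in `setup.py`:

```bash
$ pip install "x-clip>=0.0.10"
```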
26 changes: 17 additions & 9 deletions x_clip/x_clip.py
```diff
@@ -211,14 +211,14 @@ class CLIP(nn.Module):
     def __init__(
         self,
         *,
+        image_encoder = None,
         dim_text = 512,
         dim_image = 512,
         dim_latent = 512,
         num_text_tokens = 10000,
         text_enc_depth = 6,
         text_seq_len = 256,
         text_heads = 8,
-        num_visual_tokens = 512,
         visual_enc_depth = 6,
         visual_heads = 8,
         visual_image_size = 256,
@@ -237,6 +237,9 @@ def __init__(
         image_ssl_loss_weight = 0.05
     ):
         super().__init__()
+
+        # instantiate text transformer
+
         self.text_transformer = TextTransformer(
             dim = dim_text,
             num_tokens = num_text_tokens + (1 if use_mlm else 0),
@@ -245,14 +248,19 @@
             heads = text_heads
         )
 
-        self.visual_transformer = VisionTransformer(
-            dim = dim_image,
-            image_size = visual_image_size,
-            patch_size = visual_patch_size,
-            channels = channels,
-            depth = visual_enc_depth,
-            heads = visual_heads
-        )
+        # instantiate image transformer
+
+        if exists(image_encoder):
+            self.visual_transformer = image_encoder
+        else:
+            self.visual_transformer = VisionTransformer(
+                dim = dim_image,
+                image_size = visual_image_size,
+                patch_size = visual_patch_size,
+                channels = channels,
+                depth = visual_enc_depth,
+                heads = visual_heads
+            )
 
         # text ssl
 
```
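The branch above relies on an `exists` helper defined earlier in `x_clip/x_clip.py`. Assuming it follows the usual convention in this codebase, it is just a `None` check; a sketch:

```python
# assumed definition of the `exists` helper - a plain None check
def exists(val):
    return val is not None
```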
