Skip to content

Commit

Permalink
allow text transformer to be externally initialized
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 25, 2021
1 parent d263446 commit 89ed952
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 10 deletions.
55 changes: 54 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ base_vit = ViT(
emb_dropout = 0.1
)

vit = Extractor(base_vit, return_embeddings_only = True)
vit = Extractor(
base_vit,
return_embeddings_only = True
)

clip = CLIP(
image_encoder = vit,
Expand All @@ -105,6 +108,56 @@ loss = clip(text, images, text_mask = mask, return_loss = True)
loss.backward()
```

Finally, one can also have the text transformer be externally defined. It will need to return the embeddings including the CLS token, for now.

```python
import torch
from x_clip import CLIP, TextTransformer

from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

base_vit = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 512,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)

image_encoder = Extractor(
base_vit,
return_embeddings_only = True
)

text_encoder = TextTransformer(
dim = 512,
num_tokens = 10000,
max_seq_len = 256 + 1,
depth = 6,
heads = 8
)

clip = CLIP(
image_encoder = image_encoder,
text_encoder = text_encoder,
dim_image = 512,
dim_text = 512,
dim_latent = 512
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()

loss = clip(text, images, text_mask = mask, return_loss = True)
loss.backward()
```

## Citations

```bibtex
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'x-clip',
packages = find_packages(exclude=[]),
version = '0.0.10',
version = '0.0.11',
license='MIT',
description = 'X-CLIP',
author = 'Phil Wang',
Expand Down
2 changes: 1 addition & 1 deletion x_clip/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from x_clip.x_clip import CLIP
from x_clip.x_clip import CLIP, TextTransformer
18 changes: 11 additions & 7 deletions x_clip/x_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def __init__(
self,
*,
image_encoder = None,
text_encoder = None,
dim_text = 512,
dim_image = 512,
dim_latent = 512,
Expand Down Expand Up @@ -240,13 +241,16 @@ def __init__(

# instantiate text transformer

self.text_transformer = TextTransformer(
dim = dim_text,
num_tokens = num_text_tokens + (1 if use_mlm else 0),
max_seq_len = text_seq_len,
depth = text_enc_depth,
heads = text_heads
)
if exists(text_encoder):
self.text_transformer = text_encoder
else:
self.text_transformer = TextTransformer(
dim = dim_text,
num_tokens = num_text_tokens + (1 if use_mlm else 0),
max_seq_len = text_seq_len + 1,
depth = text_enc_depth,
heads = text_heads
)

# instantiate image transformer

Expand Down

0 comments on commit 89ed952

Please sign in to comment.