diff --git a/README.md b/README.md
index ad9bb9e..26c7e32 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,6 @@ clip = CLIP(
     text_enc_depth = 6,
     text_seq_len = 256,
     text_heads = 8,
-    num_visual_tokens = 512,
     visual_enc_depth = 6,
     visual_image_size = 256,
     visual_patch_size = 32,
@@ -60,6 +59,52 @@ loss = clip(
 loss.backward()
 ```
 
+You can also pass in an external visual transformer / residual net. You simply have to make sure your image encoder returns a set of embeddings in the shape of `batch x seq x dim`, and make sure `dim_image` is properly specified as the dimension of the returned embeddings. Below is an example using the vision transformer from `vit_pytorch`
+
+```bash
+$ pip install "vit_pytorch>=0.25.6"
+```
+
+```python
+import torch
+from x_clip import CLIP
+
+from vit_pytorch import ViT
+from vit_pytorch.extractor import Extractor
+
+base_vit = ViT(
+    image_size = 256,
+    patch_size = 32,
+    num_classes = 1000,
+    dim = 512,
+    depth = 6,
+    heads = 16,
+    mlp_dim = 2048,
+    dropout = 0.1,
+    emb_dropout = 0.1
+)
+
+vit = Extractor(base_vit, return_embeddings_only = True)
+
+clip = CLIP(
+    image_encoder = vit,
+    dim_image = 512,            # must be set as the same dimensions as the vision transformer above
+    dim_text = 512,
+    dim_latent = 512,
+    num_text_tokens = 10000,
+    text_enc_depth = 6,
+    text_seq_len = 256,
+    text_heads = 8
+)
+
+text = torch.randint(0, 10000, (4, 256))
+images = torch.randn(4, 3, 256, 256)
+mask = torch.ones_like(text).bool()
+
+loss = clip(text, images, text_mask = mask, return_loss = True)
+loss.backward()
+```
+
 ## Citations
 
 ```bibtex
diff --git a/setup.py b/setup.py
index 150469f..524d6be 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-clip',
   packages = find_packages(exclude=[]),
-  version = '0.0.9',
+  version = '0.0.10',
   license='MIT',
   description = 'X-CLIP',
   author = 'Phil Wang',
diff --git a/x_clip/x_clip.py b/x_clip/x_clip.py
index c3774b8..4fd3592 100644
--- a/x_clip/x_clip.py
+++ b/x_clip/x_clip.py
@@ -211,6 +211,7 @@ class CLIP(nn.Module):
     def __init__(
         self,
         *,
+        image_encoder = None,
         dim_text = 512,
         dim_image = 512,
         dim_latent = 512,
@@ -218,7 +219,6 @@ def __init__(
         text_enc_depth = 6,
         text_seq_len = 256,
         text_heads = 8,
-        num_visual_tokens = 512,
         visual_enc_depth = 6,
         visual_heads = 8,
         visual_image_size = 256,
@@ -237,6 +237,9 @@ def __init__(
         image_ssl_loss_weight = 0.05
     ):
         super().__init__()
+
+        # instantiate text transformer
+
         self.text_transformer = TextTransformer(
             dim = dim_text,
             num_tokens = num_text_tokens + (1 if use_mlm else 0),
@@ -245,14 +248,19 @@ def __init__(
             heads = text_heads
         )
 
-        self.visual_transformer = VisionTransformer(
-            dim = dim_image,
-            image_size = visual_image_size,
-            patch_size = visual_patch_size,
-            channels = channels,
-            depth = visual_enc_depth,
-            heads = visual_heads
-        )
+        # instantiate image transformer
+
+        if exists(image_encoder):
+            self.visual_transformer = image_encoder
+        else:
+            self.visual_transformer = VisionTransformer(
+                dim = dim_image,
+                image_size = visual_image_size,
+                patch_size = visual_patch_size,
+                channels = channels,
+                depth = visual_enc_depth,
+                heads = visual_heads
+            )
 
         # text ssl
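
The new `image_encoder` hook in the diff above only assumes the module maps images to embeddings of shape `batch x seq x dim`, with `dim_image` set to that feature dimension. Below is a minimal sketch, not part of the patch, of a hypothetical convolutional encoder satisfying that contract; the `TinyConvEncoder` name and its architecture are illustrative assumptions, and only the output shape and the matching `dim_image` actually matter.

```python
import torch
from torch import nn
from x_clip import CLIP

class TinyConvEncoder(nn.Module):
    """Hypothetical image encoder returning embeddings of shape (batch, seq, dim)."""
    def __init__(self, dim = 512):
        super().__init__()
        # downsample 256 x 256 x 3 images to an 8 x 8 grid of `dim`-dimensional features
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride = 4), nn.ReLU(),
            nn.Conv2d(64, 256, 4, stride = 4), nn.ReLU(),
            nn.Conv2d(256, dim, 2, stride = 2)
        )

    def forward(self, images):
        feats = self.net(images)                   # (batch, dim, 8, 8)
        return feats.flatten(2).transpose(1, 2)    # (batch, seq = 64, dim)

clip = CLIP(
    image_encoder = TinyConvEncoder(dim = 512),
    dim_image = 512,    # must match the encoder's output feature dimension
    dim_text = 512,
    dim_latent = 512,
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()

loss = clip(text, images, text_mask = mask, return_loss = True)
loss.backward()
```

The usage mirrors the README example added in this diff, just with a plain residual-style CNN in place of the `vit_pytorch` Extractor.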