From db8a9246ec052a0ee378a808052ddffc3e03aecf Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 4 Dec 2022 07:55:04 -0800 Subject: [PATCH] add option to average pool the output embeddings after the transformer in the vision transformer; the CLS token is unnecessary, based on many follow-up vision transformer works as well as findings from Beyer himself --- src/open_clip/model.py | 2 ++ src/open_clip/transformer.py | 9 ++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/open_clip/model.py b/src/open_clip/model.py index d5f56e06c..682d384a1 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -29,6 +29,7 @@ class CLIPVisionCfg: patch_size: int = 16 image_size: Union[Tuple[int, int], int] = 224 ls_init_value: Optional[float] = None # layer scale initial value + global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) timm_model_name: str = None # a valid model name overrides layers, width, patch_size timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') @@ -105,6 +106,7 @@ def _build_vision_tower( heads=vision_heads, mlp_ratio=vision_cfg.mlp_ratio, ls_init_value=vision_cfg.ls_init_value, + global_average_pool=vision_cfg.global_average_pool, output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index a36fa5f5d..ebd826c5a 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -241,6 +241,7 @@ def __init__( heads: int, mlp_ratio: float, ls_init_value: float = None, + global_average_pool: bool = False, output_dim: int = 512, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, @@ -266,6 +267,7 @@ def __init__( norm_layer=norm_layer, ) + self.global_average_pool = global_average_pool self.ln_post = 
norm_layer(width) self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) @@ -342,7 +344,12 @@ def forward(self, x: torch.Tensor): x = self.transformer(x) x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_post(x[:, 0, :]) + if self.global_average_pool: + x = x.mean(dim=1) + else: + x = x[:, 0] + + x = self.ln_post(x) if self.proj is not None: x = x @ self.proj