
Commit

Whether to just average pool the output embeddings post-transformer in the vision transformer; the CLS token is unnecessary based on many follow-up vision transformer works as well as from Beyer himself
lucidrains authored and rwightman committed Dec 8, 2022
1 parent 7fe5b87 commit db8a924
Showing 2 changed files with 10 additions and 1 deletion.
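In short, the commit adds an option to take the mean of all output tokens as the image embedding rather than reading out the CLS token, following https://arxiv.org/abs/2205.01580. A minimal sketch of the two readout strategies on a [batch, tokens, width] tensor (standalone illustration; the tensor shape and variable names are assumptions, not code from this repository):

import torch

tokens = torch.randn(2, 197, 768)    # [batch, CLS + patch tokens, width], e.g. ViT-B/16 at 224px

cls_embedding = tokens[:, 0]          # CLS-token readout -> [2, 768]
gap_embedding = tokens.mean(dim=1)    # global average pooling over all tokens -> [2, 768]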
2 changes: 2 additions & 0 deletions src/open_clip/model.py
@@ -29,6 +29,7 @@ class CLIPVisionCfg:
     patch_size: int = 16
     image_size: Union[Tuple[int, int], int] = 224
     ls_init_value: Optional[float] = None  # layer scale initial value
+    global_average_pool: bool = False  # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
     timm_model_name: str = None  # a valid model name overrides layers, width, patch_size
     timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
     timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
@@ -105,6 +106,7 @@ def _build_vision_tower(
             heads=vision_heads,
             mlp_ratio=vision_cfg.mlp_ratio,
             ls_init_value=vision_cfg.ls_init_value,
+            global_average_pool=vision_cfg.global_average_pool,
             output_dim=embed_dim,
             act_layer=act_layer,
             norm_layer=norm_layer,
9 changes: 8 additions & 1 deletion src/open_clip/transformer.py
@@ -241,6 +241,7 @@ def __init__(
             heads: int,
             mlp_ratio: float,
             ls_init_value: float = None,
+            global_average_pool: bool = False,
             output_dim: int = 512,
             act_layer: Callable = nn.GELU,
             norm_layer: Callable = LayerNorm,
@@ -266,6 +267,7 @@ def __init__(
             norm_layer=norm_layer,
         )
 
+        self.global_average_pool = global_average_pool
         self.ln_post = norm_layer(width)
         self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
@@ -342,7 +344,12 @@ def forward(self, x: torch.Tensor):
         x = self.transformer(x)
         x = x.permute(1, 0, 2)  # LND -> NLD
 
-        x = self.ln_post(x[:, 0, :])
+        if self.global_average_pool:
+            x = x.mean(dim=1)
+        else:
+            x = x[:, 0]
+
+        x = self.ln_post(x)
 
         if self.proj is not None:
             x = x @ self.proj
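After this change, the pooling behavior can be selected from the vision config. A hedged usage sketch, assuming the package is importable as open_clip and using only fields visible in the diff above (all other constructor arguments keep their defaults):

from open_clip.model import CLIPVisionCfg

# request mean pooling of the final token embeddings instead of the CLS token
vision_cfg = CLIPVisionCfg(
    patch_size=16,
    image_size=224,
    global_average_pool=True,
)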
