10 changes: 8 additions & 2 deletions src/transformers/models/siglip/modeling_siglip.py
@@ -678,9 +678,14 @@ def forward(
)


class SiglipVisionTransformer(nn.Module):
class SiglipVisionTransformer(SiglipPreTrainedModel):
_can_record_outputs = {
"hidden_states": SiglipEncoderLayer,
"attentions": SiglipAttention,
}

def __init__(self, config: SiglipVisionConfig):
super().__init__()
super().__init__(config)
self.config = config
embed_dim = config.hidden_size

@@ -691,6 +696,7 @@ def __init__(self, config: SiglipVisionConfig):
if self.use_head:
self.head = SiglipMultiheadAttentionPoolingHead(config)

@check_model_inputs(tie_last_hidden_states=False)
@auto_docstring
def forward(
self,
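For context: the change above lets the `@check_model_inputs` decorator, driven by the `_can_record_outputs` mapping, collect `hidden_states` and `attentions` from the relevant submodules instead of threading `output_hidden_states`/`output_attentions` through the encoder by hand. A minimal usage sketch, not part of the PR (the tiny config values and the random input are made up for illustration):

import torch
from transformers import SiglipVisionConfig, SiglipVisionModel

# Tiny, randomly initialized model purely for illustration.
config = SiglipVisionConfig(
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=2,
    image_size=32,
    patch_size=16,
)
model = SiglipVisionModel(config).eval()

pixel_values = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    outputs = model(pixel_values, output_hidden_states=True, output_attentions=True)

# hidden_states / attentions entries are recorded from SiglipEncoderLayer / SiglipAttention modules.
print(len(outputs.hidden_states), len(outputs.attentions))
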
192 changes: 99 additions & 93 deletions src/transformers/models/siglip2/modeling_siglip2.py
@@ -349,99 +349,6 @@ def forward(
return hidden_states


class Siglip2Encoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`Siglip2EncoderLayer`].

Args:
config: Siglip2Config
"""

def __init__(self, config: Siglip2Config):
super().__init__()
self.config = config
self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False

# Ignore copy
@auto_docstring
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutput:
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(
hidden_states,
attention_mask,
**kwargs,
)

return BaseModelOutput(last_hidden_state=hidden_states)


class Siglip2VisionTransformer(nn.Module):
def __init__(self, config: Siglip2VisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size

self.embeddings = Siglip2VisionEmbeddings(config)
self.encoder = Siglip2Encoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
if self.use_head:
self.head = Siglip2MultiheadAttentionPoolingHead(config)

@auto_docstring
def forward(
self,
pixel_values: torch.FloatTensor,
attention_mask: torch.Tensor,
spatial_shapes: torch.LongTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutputWithPooling:
r"""
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
Tensor containing the spatial dimensions (height, width) of the input images.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

hidden_states = self.embeddings(pixel_values, spatial_shapes)

if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
else:
encoder_attention_mask = attention_mask

encoder_outputs: BaseModelOutput = self.encoder(
inputs_embeds=hidden_states,
attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

last_hidden_state = encoder_outputs.last_hidden_state
last_hidden_state = self.post_layernorm(last_hidden_state)

pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None

return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooler_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)


def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
@@ -607,6 +514,105 @@ def _init_weights(self, module):
module.weight.data.fill_(1.0)


class Siglip2Encoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`Siglip2EncoderLayer`].

Args:
config: Siglip2Config
"""

def __init__(self, config: Siglip2Config):
super().__init__()
self.config = config
self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False

# Ignore copy
@auto_docstring
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutput:
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(
hidden_states,
attention_mask,
**kwargs,
)

return BaseModelOutput(last_hidden_state=hidden_states)


class Siglip2VisionTransformer(Siglip2PreTrainedModel):
_can_record_outputs = {
"hidden_states": Siglip2EncoderLayer,
"attentions": Siglip2Attention,
}

def __init__(self, config: Siglip2VisionConfig):
super().__init__(config)
self.config = config
embed_dim = config.hidden_size

self.embeddings = Siglip2VisionEmbeddings(config)
self.encoder = Siglip2Encoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
if self.use_head:
self.head = Siglip2MultiheadAttentionPoolingHead(config)

@check_model_inputs(tie_last_hidden_states=False)
@auto_docstring
def forward(
self,
pixel_values: torch.FloatTensor,
attention_mask: torch.Tensor,
spatial_shapes: torch.LongTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutputWithPooling:
r"""
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
Tensor containing the spatial dimensions (height, width) of the input images.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

hidden_states = self.embeddings(pixel_values, spatial_shapes)

if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
else:
encoder_attention_mask = attention_mask

encoder_outputs: BaseModelOutput = self.encoder(
inputs_embeds=hidden_states,
attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

last_hidden_state = encoder_outputs.last_hidden_state
last_hidden_state = self.post_layernorm(last_hidden_state)

pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None

return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooler_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)


class Siglip2TextEmbeddings(nn.Module):
def __init__(self, config: Siglip2TextConfig):
super().__init__()
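Similarly, a sketch of how the relocated Siglip2VisionTransformer is exercised through Siglip2VisionModel after this change; the tiny config and the hand-built 8x8 patch grid below are assumptions for illustration only, not taken from the PR:

import torch
from transformers import Siglip2VisionConfig, Siglip2VisionModel

config = Siglip2VisionConfig(
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=2,
    patch_size=16,
)
model = Siglip2VisionModel(config).eval()

# Siglip2 takes pre-patchified inputs: here, an 8x8 grid of flattened 16x16 RGB patches.
num_patches = 64
pixel_values = torch.randn(1, num_patches, config.num_channels * config.patch_size**2)
pixel_attention_mask = torch.ones(1, num_patches, dtype=torch.long)
spatial_shapes = torch.tensor([[8, 8]], dtype=torch.long)

with torch.no_grad():
    outputs = model(
        pixel_values=pixel_values,
        pixel_attention_mask=pixel_attention_mask,
        spatial_shapes=spatial_shapes,
        output_hidden_states=True,
        output_attentions=True,
    )

print(outputs.last_hidden_state.shape, len(outputs.hidden_states), len(outputs.attentions))
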
11 changes: 7 additions & 4 deletions src/transformers/models/siglip2/modular_siglip2.py
@@ -37,6 +37,7 @@

from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...utils import auto_docstring, filter_out_non_signature_kwargs
from ...utils.generic import check_model_inputs


class Siglip2TextConfig(SiglipTextConfig):
@@ -230,6 +231,10 @@ def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTen
return embeddings


class Siglip2PreTrainedModel(SiglipPreTrainedModel):
pass


class Siglip2VisionTransformer(SiglipVisionTransformer):
def __init__(self, config: Siglip2VisionConfig):
super().__init__(config)
@@ -280,10 +285,6 @@ def forward(
)


class Siglip2PreTrainedModel(SiglipPreTrainedModel):
pass


class Siglip2TextModel(SiglipTextModel):
pass

@@ -314,6 +315,8 @@ def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Ten

class Siglip2VisionModel(SiglipVisionModel):
# Update: add `spatial_shapes` and `pixel_attention_mask`
@check_model_inputs(tie_last_hidden_states=False)
@auto_docstring
def forward(
self,
pixel_values: torch.FloatTensor,
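As background on the pattern adopted above, here is a toy, self-contained sketch of the general idea behind recording outputs by module class with forward hooks. It is not transformers' actual `check_model_inputs` implementation, only an illustration of the mechanism:

import torch
from torch import nn

class Block(nn.Module):
    def forward(self, x):
        return x + 1

class TinyModel(nn.Module):
    # Map an output name to the module class whose outputs should be recorded.
    _can_record_outputs = {"hidden_states": Block}

    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([Block() for _ in range(3)])

    def forward(self, x, output_hidden_states=False):
        recorded, handles = [], []
        if output_hidden_states:
            record_cls = self._can_record_outputs["hidden_states"]
            for module in self.modules():
                if isinstance(module, record_cls):
                    handles.append(
                        module.register_forward_hook(lambda mod, args, out: recorded.append(out))
                    )
        for block in self.blocks:
            x = block(x)
        for handle in handles:
            handle.remove()
        return (x, tuple(recorded)) if output_hidden_states else (x,)

out, hidden = TinyModel()(torch.zeros(2), output_hidden_states=True)
print(len(hidden))  # 3: one recorded output per Block
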
4 changes: 4 additions & 0 deletions utils/check_repo.py
@@ -90,6 +90,8 @@
"Kosmos2_5TextForCausalLM",
"Kosmos2_5VisionModel",
"SmolVLMVisionTransformer",
"SiglipVisionTransformer",
"Siglip2VisionTransformer",
"AriaTextForCausalLM",
"AriaTextModel",
"Phi4MultimodalAudioModel",
@@ -358,7 +360,9 @@
"SegGptForImageSegmentation",
"SiglipVisionModel",
"SiglipTextModel",
"SiglipVisionTransformer",
"Siglip2VisionModel",
"Siglip2VisionTransformer",
"Siglip2TextModel",
"ChameleonVQVAE", # no autoclass for VQ-VAE models
"VitPoseForPoseEstimation",