12 changes: 9 additions & 3 deletions src/transformers/models/siglip/modeling_siglip.py
@@ -510,7 +510,7 @@ def _init_weights(self, module):
nn.init.xavier_uniform_(module.fc2.weight)
nn.init.normal_(module.fc1.bias, std=1e-6)
nn.init.normal_(module.fc2.bias, std=1e-6)
elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
elif "MultiheadAttentionPoolingHead" in module.__class__.__name__:
Contributor:
Seems unrelated, no?

Contributor Author:
you would believe that, right 👀
it wasn't AFAIK, but it's related to when I shuffled PreTrainedModel around.

Contributor:
Dang wtf, no biggie but interesting

nn.init.xavier_uniform_(module.probe.data)
nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
nn.init.zeros_(module.attention.in_proj_bias.data)
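
A hedged sketch of why the name-based check can stand in for `isinstance` here, following the thread above: once the Siglip2 classes are produced by the modular converter, the Siglip2 pooling head no longer inherits from the Siglip one, so an `isinstance` check in a shared `_init_weights` would miss it, while the class-name match covers both. This is an assumption read off the diff, and the classes below are stand-ins, not the real modules.

```python
# Illustrative stand-in classes only; the real modules live in modeling_siglip.py / modeling_siglip2.py.
class SiglipMultiheadAttentionPoolingHead:
    pass


class Siglip2MultiheadAttentionPoolingHead:  # generated class, not a subclass of the Siglip head
    pass


for module in (SiglipMultiheadAttentionPoolingHead(), Siglip2MultiheadAttentionPoolingHead()):
    # isinstance(module, SiglipMultiheadAttentionPoolingHead) is False for the Siglip2 head,
    # but the name-based check matches both.
    print("MultiheadAttentionPoolingHead" in module.__class__.__name__)  # True, True
```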
@@ -678,9 +678,14 @@ def forward(
)


class SiglipVisionTransformer(nn.Module):
class SiglipVisionTransformer(SiglipPreTrainedModel):
_can_record_outputs = {
"hidden_states": SiglipEncoderLayer,
"attentions": SiglipAttention,
}

def __init__(self, config: SiglipVisionConfig):
super().__init__()
super().__init__(config)
self.config = config
embed_dim = config.hidden_size

@@ -691,6 +696,7 @@ def __init__(self, config: SiglipVisionConfig):
if self.use_head:
self.head = SiglipMultiheadAttentionPoolingHead(config)

@check_model_inputs(tie_last_hidden_states=False)
@auto_docstring
def forward(
self,
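For context, a minimal usage sketch of the vision tower after this change. As I read the diff, `@check_model_inputs` together with `_can_record_outputs` collects per-layer hidden states and attentions declaratively instead of threading `output_hidden_states` through every module; the tiny config values below are arbitrary, and the exact lengths of the returned tuples are not asserted.

```python
import torch
from transformers import SiglipVisionConfig, SiglipVisionModel

# Deliberately tiny, randomly initialized config; values are arbitrary.
config = SiglipVisionConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    image_size=32,
    patch_size=16,
)
model = SiglipVisionModel(config).eval()

pixel_values = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    out = model(pixel_values, output_hidden_states=True, output_attentions=True)

# hidden_states / attentions are recorded from SiglipEncoderLayer / SiglipAttention,
# as declared in `_can_record_outputs`.
print(out.last_hidden_state.shape)
print(len(out.hidden_states), len(out.attentions))
```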
194 changes: 100 additions & 94 deletions src/transformers/models/siglip2/modeling_siglip2.py
@@ -349,99 +349,6 @@ def forward(
return hidden_states


class Siglip2Encoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`Siglip2EncoderLayer`].

Args:
config: Siglip2Config
"""

def __init__(self, config: Siglip2Config):
super().__init__()
self.config = config
self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False

# Ignore copy
@auto_docstring
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutput:
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(
hidden_states,
attention_mask,
**kwargs,
)

return BaseModelOutput(last_hidden_state=hidden_states)


class Siglip2VisionTransformer(nn.Module):
def __init__(self, config: Siglip2VisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size

self.embeddings = Siglip2VisionEmbeddings(config)
self.encoder = Siglip2Encoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
if self.use_head:
self.head = Siglip2MultiheadAttentionPoolingHead(config)

@auto_docstring
def forward(
self,
pixel_values: torch.FloatTensor,
attention_mask: torch.Tensor,
spatial_shapes: torch.LongTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutputWithPooling:
r"""
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
Tensor containing the spatial dimensions (height, width) of the input images.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

hidden_states = self.embeddings(pixel_values, spatial_shapes)

if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
else:
encoder_attention_mask = attention_mask

encoder_outputs: BaseModelOutput = self.encoder(
inputs_embeds=hidden_states,
attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

last_hidden_state = encoder_outputs.last_hidden_state
last_hidden_state = self.post_layernorm(last_hidden_state)

pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None

return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooler_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)


def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
@@ -585,7 +492,7 @@ def _init_weights(self, module):
nn.init.xavier_uniform_(module.fc2.weight)
nn.init.normal_(module.fc1.bias, std=1e-6)
nn.init.normal_(module.fc2.bias, std=1e-6)
elif isinstance(module, Siglip2MultiheadAttentionPoolingHead):
elif "MultiheadAttentionPoolingHead" in module.__class__.__name__:
nn.init.xavier_uniform_(module.probe.data)
nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
nn.init.zeros_(module.attention.in_proj_bias.data)
@@ -607,6 +514,105 @@ def _init_weights(self, module):
module.weight.data.fill_(1.0)


class Siglip2Encoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`Siglip2EncoderLayer`].

Args:
config: Siglip2Config
"""

def __init__(self, config: Siglip2Config):
super().__init__()
self.config = config
self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False

# Ignore copy
@auto_docstring
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutput:
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(
hidden_states,
attention_mask,
**kwargs,
)

return BaseModelOutput(last_hidden_state=hidden_states)


class Siglip2VisionTransformer(Siglip2PreTrainedModel):
_can_record_outputs = {
"hidden_states": Siglip2EncoderLayer,
"attentions": Siglip2Attention,
}

def __init__(self, config: Siglip2VisionConfig):
super().__init__(config)
self.config = config
embed_dim = config.hidden_size

self.embeddings = Siglip2VisionEmbeddings(config)
self.encoder = Siglip2Encoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
if self.use_head:
self.head = Siglip2MultiheadAttentionPoolingHead(config)

@check_model_inputs(tie_last_hidden_states=False)
@auto_docstring
def forward(
self,
pixel_values: torch.FloatTensor,
attention_mask: torch.Tensor,
spatial_shapes: torch.LongTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutputWithPooling:
r"""
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
Tensor containing the spatial dimensions (height, width) of the input images.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

hidden_states = self.embeddings(pixel_values, spatial_shapes)

if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
else:
encoder_attention_mask = attention_mask

encoder_outputs: BaseModelOutput = self.encoder(
inputs_embeds=hidden_states,
attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

last_hidden_state = encoder_outputs.last_hidden_state
last_hidden_state = self.post_layernorm(last_hidden_state)

pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None

return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooler_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)


class Siglip2TextEmbeddings(nn.Module):
def __init__(self, config: Siglip2TextConfig):
super().__init__()
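As a side note on the `_prepare_4d_attention_mask` call in the forward above, here is a rough approximation of what that expansion does; this is a sketch, not the library implementation. The 2D padding mask is broadcast to `[batch, 1, tgt_len, src_len]`, and masked positions are filled with the dtype's most negative value so they vanish after softmax.

```python
from typing import Optional

import torch


def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
    # mask: [batch, src_len] with 1 = keep, 0 = pad
    bsz, src_len = mask.shape
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    # padded positions become a large negative additive bias
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)


mask = torch.tensor([[1, 1, 1, 0]])
print(expand_mask(mask, torch.float32).shape)  # torch.Size([1, 1, 4, 4])
```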
11 changes: 7 additions & 4 deletions src/transformers/models/siglip2/modular_siglip2.py
@@ -37,6 +37,7 @@

from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...utils import auto_docstring, filter_out_non_signature_kwargs
from ...utils.generic import check_model_inputs


class Siglip2TextConfig(SiglipTextConfig):
@@ -230,6 +231,10 @@ def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTen
return embeddings


class Siglip2PreTrainedModel(SiglipPreTrainedModel):
pass


class Siglip2VisionTransformer(SiglipVisionTransformer):
def __init__(self, config: Siglip2VisionConfig):
super().__init__(config)
@@ -280,10 +285,6 @@ def forward(
)


class Siglip2PreTrainedModel(SiglipPreTrainedModel):
pass


class Siglip2TextModel(SiglipTextModel):
pass

@@ -314,6 +315,8 @@ def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Ten

class Siglip2VisionModel(SiglipVisionModel):
# Update: add `spatial_shapes` and `pixel_attention_mask`
@check_model_inputs(tie_last_hidden_states=False)
@auto_docstring
def forward(
self,
pixel_values: torch.FloatTensor,
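To make the new `spatial_shapes` / `pixel_attention_mask` signature concrete, here is a hedged usage sketch of the Siglip2 vision tower with packed, flattened patches. The shapes follow my reading of the NaFlex-style interface (`pixel_values` as `[batch, max_num_patches, channels * patch_size**2]`), and the tiny config is arbitrary, so treat it as an illustration rather than a reference.

```python
import torch
from transformers import Siglip2VisionConfig, Siglip2VisionModel

# Tiny, randomly initialized config; values are arbitrary.
config = Siglip2VisionConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    patch_size=16,
    num_patches=64,
)
model = Siglip2VisionModel(config).eval()

num_patches = 64
pixel_values = torch.randn(1, num_patches, 3 * 16 * 16)             # flattened patches
pixel_attention_mask = torch.ones(1, num_patches, dtype=torch.long)  # all patches are real
spatial_shapes = torch.tensor([[8, 8]])                              # 8 x 8 patches per image

with torch.no_grad():
    out = model(
        pixel_values,
        pixel_attention_mask=pixel_attention_mask,
        spatial_shapes=spatial_shapes,
        output_hidden_states=True,
    )
print(out.last_hidden_state.shape, out.pooler_output.shape)
```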
4 changes: 4 additions & 0 deletions tests/models/siglip/test_modeling_siglip.py
@@ -85,6 +85,10 @@ def test_sdpa_can_dispatch_composite_models(self):
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_eager.config._attn_implementation == "eager")

@unittest.skip(reason="This test is broken on A10 multi runners for now")
def test_multi_gpu_data_parallel_forward(self):
pass


class SiglipVisionModelTester:
def __init__(
4 changes: 4 additions & 0 deletions tests/models/siglip2/test_modeling_siglip2.py
@@ -166,6 +166,10 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
def test_sdpa_can_dispatch_on_flash(self):
pass

@unittest.skip(reason="This test is broken on A10 multi runners for now")
def test_multi_gpu_data_parallel_forward(self):
pass
Contributor:
We shouldn't skip these; it will be hard to revert because everyone will forget, imo.

Contributor Author:
well it's entirely broken, so we need to write it down somewhere 😅

Contributor:
Fair enough, but the issue is bigger than this PR and model. Imo we should isolate this in a different PR and directly in the common tests. Wdyt?

I've encountered these issues with all recent models I've interacted with.

Contributor Author:
reverted for now, let's remember that slow tests do not pass but that's ok :P



class Siglip2VisionModelTester:
def __init__(
4 changes: 4 additions & 0 deletions utils/check_repo.py
@@ -90,6 +90,8 @@
"Kosmos2_5TextForCausalLM",
"Kosmos2_5VisionModel",
"SmolVLMVisionTransformer",
"SiglipVisionTransformer",
"Siglip2VisionTransformer",
"AriaTextForCausalLM",
"AriaTextModel",
"Phi4MultimodalAudioModel",
@@ -358,7 +360,9 @@
"SegGptForImageSegmentation",
"SiglipVisionModel",
"SiglipTextModel",
"SiglipVisionTransformer",
"Siglip2VisionModel",
"Siglip2VisionTransformer",
"Siglip2TextModel",
"ChameleonVQVAE", # no autoclass for VQ-VAE models
"VitPoseForPoseEstimation",