Clarify Gemma4 vision bidirectional mask behavior for full vs sliding attention #46711

Description

@wnma3mz

I noticed a potential mismatch between HuggingFace Transformers' Gemma4 vision bidirectional mask behavior and the official google-deepmind/gemma implementation.

In the official google-deepmind/gemma implementation, use_bidirectional_attention="vision" appears to apply the bidirectional image-token mask only to local sliding attention layers; global/full attention layers remain causal. For local sliding layers, the final mask is also intersected with the sliding window, so bidirectional visibility between image tokens does not bypass the sliding-window boundary.

In Transformers, block_sequence_ids is passed into create_masks_for_generate(...), and then applied to both full_attention and sliding_attention masks through blockwise_overlay(...).

This seems to result in:

Transformers:
  full_attention = causal OR vision_blockwise_overlay
  sliding_attention = sliding_causal OR vision_blockwise_overlay

Official google-deepmind/gemma:
  full/global = causal
  local_sliding = (causal OR vision_bidirectional) AND sliding_window

This creates two observable differences:

  1. In Transformers, full attention layers allow bidirectional visibility inside the image token block, while the official google-deepmind/gemma implementation keeps full/global layers causal.
  2. In Transformers, image tokens inside the same block can attend to each other bidirectionally across the sliding-window boundary, because the blockwise overlay is OR'ed in after the sliding causal mask is built and is never intersected with the window.

Relevant Transformers code

In Transformers, Gemma4 constructs block_sequence_ids for use_bidirectional_attention="vision":

if self.config.get_text_config().use_bidirectional_attention == "vision":
    block_sequence_ids = torch.full([*inputs_embeds.size()[:-1]], -1, device=inputs_embeds.device)
    if mm_token_type_ids is not None:
        block_sequence_ids = get_block_sequence_ids_for_mask(
            mm_token_type_ids, device=inputs_embeds.device
        )

    mask_kwargs["block_sequence_ids"] = block_sequence_ids

causal_mask_mapping = create_masks_for_generate(**mask_kwargs)

Then create_masks_for_generate(...) applies the same block_sequence_ids to all layer patterns:

for layer_pattern in layer_patterns:
    causal_masks[layer_pattern] = LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING[layer_pattern](**mask_kwargs)

Both create_causal_mask(...) and create_sliding_window_causal_mask(...) apply:

mask_factory_function = or_masks(mask_factory_function, blockwise_overlay(block_sequence_ids))
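
To make the effect of that OR concrete, here is a minimal pure-Python sketch. The helpers `toy_or_masks`, `toy_causal`, `toy_sliding_causal`, and `toy_blockwise_overlay` are hypothetical stand-ins written for this issue, not the Transformers functions: they use a simplified `(q_idx, kv_idx)` signature and only assume that non-image tokens carry block id -1, as in the snippet above.

```python
from typing import Callable, Sequence

# Simplified mask predicate: (q_idx, kv_idx) -> is kv visible to q?
MaskFn = Callable[[int, int], bool]


def toy_or_masks(*mask_fns: MaskFn) -> MaskFn:
    """Union of several mask predicates (toy stand-in, not the real or_masks)."""
    def combined(q_idx: int, kv_idx: int) -> bool:
        return any(fn(q_idx, kv_idx) for fn in mask_fns)
    return combined


def toy_causal(q_idx: int, kv_idx: int) -> bool:
    return kv_idx <= q_idx


def toy_sliding_causal(window: int) -> MaskFn:
    # Causal, and at most `window - 1` positions back.
    return lambda q_idx, kv_idx: kv_idx <= q_idx and q_idx - kv_idx < window


def toy_blockwise_overlay(block_ids: Sequence[int]) -> MaskFn:
    # Tokens in the same image block see each other; -1 marks non-image tokens.
    return lambda q_idx, kv_idx: (
        block_ids[q_idx] != -1 and block_ids[q_idx] == block_ids[kv_idx]
    )


# Two text tokens, four image tokens (block 0), two text tokens.
block_ids = [-1, -1, 0, 0, 0, 0, -1, -1]
window = 2

full = toy_or_masks(toy_causal, toy_blockwise_overlay(block_ids))
sliding = toy_or_masks(toy_sliding_causal(window), toy_blockwise_overlay(block_ids))


def render(mask_fn: MaskFn, n: int) -> list:
    """One string per query row; 'x' means visible."""
    return ["".join("x" if mask_fn(q, k) else "." for k in range(n)) for q in range(n)]


for row in render(sliding, len(block_ids)):
    print(row)
```

In this toy setup, image token 2 can see image token 5 in both mask types even though the distance (3) exceeds the window (2), because the overlay is unioned in last.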

Reference implementation: google-deepmind/gemma

The snippets below are from google-deepmind/gemma, checked locally at commit ab5a660.

The Gemma4 config explicitly documents the intended split between sliding and global layers:

# Bidirectional attention for image tokens in the text backbone.
# None: purely causal for all layers (used by E2B, E4B).
# 'vision': bidirectional for image tokens in sliding layers only,
#   causal for global layers (used by 26B_A4B, 31B).
use_bidirectional_attention: str | None = None

Source: gemma/gm/nn/gemma4/_config.py in google-deepmind/gemma.

The transformer creates the normal causal attention_mask, and creates a separate sliding_attention_mask only when use_bidirectional_attention == 'vision':

if attention_mask is None:
  attention_mask = (
      _attention_mask.make_causal_bidirectional_attention_mask(
          inputs_mask,
          bidirectional_mask=None,
      )
  )

# For models with use_bidirectional_attention='vision' (26B_A4B/31B),
# create a separate sliding attention mask with bidirectional attention
# for image tokens within the same image block.
sliding_attention_mask = None
if self.config.use_bidirectional_attention == 'vision':
  bidirectional_mask = tokens == _token_utils.SOFT_TOKEN_PLACEHOLDER
  sliding_attention_mask = (
      _attention_mask.make_causal_bidirectional_attention_mask(
          inputs_mask,
          bidirectional_mask=bidirectional_mask,
      )
  )

Source: gemma/gm/nn/gemma4/_transformer.py in google-deepmind/gemma.

Then the model uses that separate mask only for local sliding blocks:

attn_mask = inputs.attention_mask
if (
    inputs.sliding_attention_mask is not None
    and block.attn_type == _modules.AttentionType.LOCAL_SLIDING
):
  attn_mask = inputs.sliding_attention_mask

Source: gemma/gm/nn/gemma4/_transformer.py in google-deepmind/gemma.

Finally, local sliding attention intersects the selected attention mask with the sliding window:

if self.attn_type == AttentionType.LOCAL_SLIDING:
  if self.sliding_window_size is None:
    raise ValueError(
        'Sliding_window_size must be set if Local Sliding attention type'
    )
  sliding_mask = _create_sliding_mask(
      segment_pos,
      cache_positions=cache_positions,
      sliding_window_size=self.sliding_window_size,
  )
  # [batch_size, seq_len, cache_size]
  attn_mask *= sliding_mask

Source: gemma/gm/nn/gemma4/_modules.py in google-deepmind/gemma.

So this reference implementation seems to implement:

global/full: causal
local_sliding: (causal OR vision_bidirectional) AND sliding_window
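
As a rough illustration of that summary, the sketch below evaluates visibility for a single (query, key) pair per layer type. It is my own toy model, not code from google-deepmind/gemma: `reference_style_visible` and the `block_ids` encoding are hypothetical, and the symmetric window test `abs(q - k) < window` is an assumption about what `_create_sliding_mask` computes.

```python
from typing import Sequence


def reference_style_visible(
    layer_type: str,
    q_idx: int,
    kv_idx: int,
    block_ids: Sequence[int],
    window: int,
) -> bool:
    """Toy model of the per-layer mask selection described above.

    layer_type is "global" or "local_sliding" (hypothetical labels).
    block_ids[i] is -1 for text tokens, otherwise the image block index.
    """
    causal = kv_idx <= q_idx
    if layer_type == "global":
        # Global layers keep the plain causal mask.
        return causal

    same_block = block_ids[q_idx] != -1 and block_ids[q_idx] == block_ids[kv_idx]
    # Assumption: the sliding mask is symmetric around the query position.
    within_window = abs(q_idx - kv_idx) < window
    # The bidirectional mask is selected first, then intersected with the window.
    return (causal or same_block) and within_window


# One text token, five image tokens (block 0), one text token.
block_ids = [-1, 0, 0, 0, 0, 0, -1]
window = 3

near = reference_style_visible("local_sliding", 1, 3, block_ids, window)
far = reference_style_visible("local_sliding", 1, 5, block_ids, window)
global_future = reference_style_visible("global", 1, 3, block_ids, window)
print(near, far, global_future)
```

Under these assumptions, a future image token is visible in a local sliding layer only while it is inside the window, and never in a global layer.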

Reference implementation: vLLM

I also checked vllm-project/vllm, locally at commit 88d34c6409 on releases/v0.20.0.

vLLM appears to encode the same full-vs-sliding distinction explicitly. It precomputes non-sliding/full layer indices when use_bidirectional_attention == "vision":

# --- Precompute full-attention layer indices for bidi clearing ---
self._full_attn_layer_idxs: frozenset[int] = frozenset()
text_config = config.text_config
if getattr(text_config, "use_bidirectional_attention", None) == "vision":
    layer_types = getattr(text_config, "layer_types", None)
    if layer_types:
        self._full_attn_layer_idxs = frozenset(
            i for i, lt in enumerate(layer_types) if lt != "sliding_attention"
        )

Source: vllm/model_executor/models/gemma4_mm.py in vllm-project/vllm.

Before running decoder layers, vLLM clears the multimodal prefix range for full-attention layers:

# Gemma4 bidi: clear mm_prefix_range for full_attention layers.
# Must run here (outside @support_torch_compile boundary) because
# _run_decoder_layers is inside a compiled graph where Python
# side effects are eliminated.
self._clear_mm_prefix_for_full_attn_layers()

And the helper documents the intended behavior:

def _clear_mm_prefix_for_full_attn_layers(self) -> None:
    """Clear mm_prefix_range for non-sliding layers.

    Gemma4 with use_bidirectional_attention='vision' applies
    bidirectional attention only to sliding_attention layers.
    Full attention layers use plain causal masking.

    Uses _full_attn_layer_idxs (precomputed in __init__) for O(1)
    lookup instead of per-call regex parsing.
    """

Source: vllm/model_executor/models/gemma4_mm.py in vllm-project/vllm.

vLLM also filters oversized image ranges before installing the multimodal prefix/bidirectional range:

# Gemma4 bidi: skip ranges that exceed the sliding
# window. When image tokens > sliding_window, bidi causes
# early image tokens to attend to the entire image
# (e.g. 6 → 1092 targets), degrading spatial precision.
# Per-range filtering keeps bidi for small images/video
# frames while skipping oversized images.
hf_text_config = self.model_config.hf_text_config
_bidi_sw = getattr(hf_text_config, "sliding_window", None)

for req_id in self.input_batch.req_ids:
    image_doc_ranges = []
    req_state = self.requests[req_id]
    for mm_feature in req_state.mm_features:
        pos_info = mm_feature.mm_position
        img_doc_range = pos_info.extract_embeds_range()
        for r in img_doc_range:
            if _bidi_sw is not None and (r[1] - r[0] + 1) > _bidi_sw:
                continue
            image_doc_ranges.append(r)

Source: vllm/v1/worker/gpu_model_runner.py in vllm-project/vllm.

This makes vLLM's behavior closer to:

full_attention:
  causal

sliding_attention:
  (causal AND sliding_window) OR vision_bidirectional,
  but only for image ranges whose length is <= sliding_window
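
The per-range filter can be sketched in isolation as follows. `keep_bidi_ranges` is a hypothetical helper, not a vLLM function; it only mirrors the `(r[1] - r[0] + 1) > _bidi_sw` check quoted above and assumes each range is an inclusive `(start, end)` pair, which is what the `+ 1` suggests.

```python
from typing import List, Optional, Sequence, Tuple

Range = Tuple[int, int]  # assumed inclusive (start, end) token positions


def keep_bidi_ranges(ranges: Sequence[Range], sliding_window: Optional[int]) -> List[Range]:
    """Drop image ranges longer than the sliding window (toy version)."""
    if sliding_window is None:
        # No window configured: nothing is filtered.
        return list(ranges)
    return [(start, end) for start, end in ranges if (end - start + 1) <= sliding_window]


ranges = [(0, 1099), (2000, 2255), (3000, 4023)]
kept = keep_bidi_ranges(ranges, sliding_window=1024)
print(kept)
```

Here the 1100-token range is skipped, while the 256-token and exactly-1024-token ranges keep their bidirectional treatment.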

Minimal example

Assume:

sliding_window = 1024
image tokens = positions 0..1099

For an image query token at position 0 and image key token at position 1099:

  • In Transformers sliding_attention, they are visible to each other because they are in the same block_sequence_ids group.
  • In the official google-deepmind/gemma implementation, they should not be visible to each other in a local sliding layer, because the final mask is intersected with the sliding window.
  • In vLLM, this image range has length 1100 > 1024, so the multimodal bidirectional range is skipped and the pair is not visible through vision-bidi.

For full/global layers:

  • In Transformers, image tokens inside the same block are also bidirectionally visible.
  • In the official google-deepmind/gemma implementation, full/global layers appear to remain causal.
  • In vLLM, full attention layers explicitly clear mm_prefix_range and remain causal.
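
The two comparisons above can be checked with a small pure-Python sketch using the same numbers. All function names here are hypothetical and encode my reading of each implementation, not their actual code; the symmetric window test in `reference_sliding` is an assumption.

```python
SLIDING_WINDOW = 1024
IMAGE_START, IMAGE_END = 0, 1099  # inclusive image token positions


def causal(q: int, k: int) -> bool:
    return k <= q


def same_block(q: int, k: int) -> bool:
    return IMAGE_START <= q <= IMAGE_END and IMAGE_START <= k <= IMAGE_END


def sliding_causal(q: int, k: int) -> bool:
    return causal(q, k) and q - k < SLIDING_WINDOW


def transformers_sliding(q: int, k: int) -> bool:
    # Overlay is OR'ed in after the sliding causal mask.
    return sliding_causal(q, k) or same_block(q, k)


def transformers_full(q: int, k: int) -> bool:
    return causal(q, k) or same_block(q, k)


def reference_sliding(q: int, k: int) -> bool:
    # Assumption: symmetric window, applied after the bidirectional mask.
    return (causal(q, k) or same_block(q, k)) and abs(q - k) < SLIDING_WINDOW


def vllm_sliding(q: int, k: int) -> bool:
    # Bidirectional range is dropped entirely when the image exceeds the window.
    bidi_enabled = (IMAGE_END - IMAGE_START + 1) <= SLIDING_WINDOW
    return sliding_causal(q, k) or (bidi_enabled and same_block(q, k))


def causal_full(q: int, k: int) -> bool:
    # Reading of both google-deepmind/gemma and vLLM for full/global layers.
    return causal(q, k)


q, k = 0, 1099
results = {
    "transformers_sliding": transformers_sliding(q, k),
    "reference_sliding": reference_sliding(q, k),
    "vllm_sliding": vllm_sliding(q, k),
    "transformers_full": transformers_full(q, k),
    "reference_or_vllm_full": causal_full(q, k),
}
for name, visible in results.items():
    print(f"{name}: {visible}")
```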

For normal output text tokens, I believe all implementations still obey the causal sliding window in local sliding layers. The question here is specifically about the image-token blockwise bidirectional overlay and whether it should apply to full layers or bypass the sliding-window boundary.

Expected behavior / question

Could you confirm the intended Gemma4 behavior for use_bidirectional_attention="vision"?

Should Transformers match the official google-deepmind/gemma behavior:

full_attention:
  causal

sliding_attention:
  (causal OR vision_blockwise_overlay) AND sliding_window

or is the current Transformers behavior intentional?

Environment

  • Transformers version: current main branch
  • Models: Gemma4 / Gemma4Unified
  • Reference implementations checked: google-deepmind/gemma (commit ab5a660) and vllm-project/vllm (commit 88d34c6409, releases/v0.20.0)
