142 changes: 121 additions & 21 deletions penzai/models/transformer/variants/gemma.py
@@ -14,13 +14,14 @@

"""The Gemma architecture transformer variant.

Supports both the Gemma 1 and Gemma 2 architectures. Based on the Flax
reference implementation at https://github.com/google-deepmind/gemma.
Supports the Gemma 1, Gemma 2, and Gemma 3 architectures. Based on the
Flax reference implementation at https://github.com/google-deepmind/gemma.

See the Gemma technical reports for more information:

* Gemma 1: https://arxiv.org/abs/2403.08295
* Gemma 2: https://arxiv.org/abs/2408.00118
* Gemma 3: https://arxiv.org/abs/2503.19786
"""

from __future__ import annotations
@@ -105,23 +106,124 @@
final_logit_softcap=30.0,
attn_logits_soft_cap=50.0,
),
"gemma3_1b": dict(
num_decoder_blocks=26,
vocab_size=262_144,
num_kv_heads=1,
query_head_multiplier=4,
embedding_dim=1152,
projection_dim=256,
mlp_hidden_dim=6 * 1152,
attention_type=(
llamalike_common.AttentionTypeSlidingWindowCausal(512),
llamalike_common.AttentionTypeSlidingWindowCausal(512),
llamalike_common.AttentionTypeSlidingWindowCausal(512),
llamalike_common.AttentionTypeSlidingWindowCausal(512),
llamalike_common.AttentionTypeSlidingWindowCausal(512),
llamalike_common.AttentionTypeGlobalCausal(),
),
use_qk_norm=True,
use_post_attn_norm=True,
use_post_ffw_norm=True,
local_rope_wavelength=10_000,
global_rope_wavelength=1_000_000,
),
"gemma3_4b": dict(
num_decoder_blocks=34,
vocab_size=262_144,
num_kv_heads=4,
query_head_multiplier=2,
embedding_dim=2560,
projection_dim=256,
mlp_hidden_dim=2560 * 8 // 2,
attention_type=(
llamalike_common.AttentionTypeSlidingWindowCausal(1024),
llamalike_common.AttentionTypeSlidingWindowCausal(1024),
llamalike_common.AttentionTypeSlidingWindowCausal(1024),
llamalike_common.AttentionTypeSlidingWindowCausal(1024),
llamalike_common.AttentionTypeSlidingWindowCausal(1024),
llamalike_common.AttentionTypeGlobalCausal(),
),
use_qk_norm=True,
use_post_attn_norm=True,
use_post_ffw_norm=True,
local_scale_factor=1.0,
global_scale_factor=8.0,
local_rope_wavelength=10_000,
global_rope_wavelength=1_000_000,
),
"gemma3_12b": dict(
num_decoder_blocks=48,
vocab_size=262_144,
num_kv_heads=8,
query_head_multiplier=2,
embedding_dim=30 * 128,
projection_dim=256,
mlp_hidden_dim=8 * 30 * 128 // 2,
attention_type=(
llamalike_common.AttentionTypeSlidingWindowCausal(1024),
llamalike_common.AttentionTypeSlidingWindowCausal(1024),
llamalike_common.AttentionTypeSlidingWindowCausal(1024),
llamalike_common.AttentionTypeSlidingWindowCausal(1024),
llamalike_common.AttentionTypeSlidingWindowCausal(1024),
llamalike_common.AttentionTypeGlobalCausal(),
),
use_qk_norm=True,
use_post_attn_norm=True,
use_post_ffw_norm=True,
local_scale_factor=1.0,
global_scale_factor=8.0,
local_rope_wavelength=10_000,
global_rope_wavelength=1_000_000,
),
"gemma3_27b": dict(
num_decoder_blocks=62,
vocab_size=262_144,
num_kv_heads=16,
query_head_multiplier=2,
embedding_dim=5376,
projection_dim=128,
mlp_hidden_dim=5376 * 8 // 2,
# query scaling factor: 1/sqrt(embedding_dim / num_query_heads)
query_scaling_factor=(5376 // 32) ** -0.5,
attention_type=(
llamalike_common.AttentionTypeSlidingWindowCausal(1024),
llamalike_common.AttentionTypeSlidingWindowCausal(1024),
llamalike_common.AttentionTypeSlidingWindowCausal(1024),
llamalike_common.AttentionTypeSlidingWindowCausal(1024),
llamalike_common.AttentionTypeSlidingWindowCausal(1024),
llamalike_common.AttentionTypeGlobalCausal(),
),
use_qk_norm=True,
use_post_attn_norm=True,
use_post_ffw_norm=True,
local_scale_factor=1.0,
global_scale_factor=8.0,
local_rope_wavelength=10_000,
global_rope_wavelength=1_000_000,
),
}
_NEEDS_GATING_TRANSPOSE = {
"gemma_2b": False,
"gemma_7b": False,
"gemma2_2b": False,
"gemma2_9b": True,
"gemma2_27b": True,
"gemma3_1b": True,
"gemma3_4b": True,
"gemma3_12b": True,
"gemma3_27b": True,
}


def gemma_from_pretrained_checkpoint(
ckpt_params: dict[str, Any],
preset_name: Literal[
"gemma_2b", "gemma_7b", "gemma2_2b", "gemma2_9b", "gemma2_27b",
"gemma3_1b", "gemma3_4b", "gemma3_12b", "gemma3_27b",
],
upcast_activations_to_float32: bool = False,
use_layer_stack: bool = False,
preset_name: Literal[
"gemma_2b", "gemma_7b", "gemma2_2b", "gemma2_9b", "gemma2_27b", "auto"
] = "auto",
) -> model_parts.TransformerLM:
"""Builds a Gemma model from a pretrained checkpoint.
Collaborator

It is too bad that this is a breaking change in the function signature, since this means existing code will no longer work. Is there some way to do this in a backwards compatible way?

I think it's OK if "auto" does not allow loading gemma 3 models, but it would be nice if it was still possible for us to load gemma 1 and gemma 2 in "auto" mode. Maybe there are differences in the parameter names that we can use, like _query_norm?

Ideal solution would be something like:

  • keep preset name where it is with "auto" as the default argument
  • check if this is gemma 3 by looking at something about the params
  • if it is gemma 3, raise a ValueError and say that you need to specify preset_name
  • if it is gemma 1 or 2, emit a warning saying you should specify preset name, but then infer it like it is being inferred now

(Probably long term it makes sense to just require the preset to be specified directly, but I'd prefer not to make breaking changes too often if possible.)

Contributor Author

Thank you for your suggestion. I have now written code so that "auto" can load Gemma 3 models by checking whether the checkpoint has QK norm parameters.
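A minimal sketch of what such auto-detection could look like (hypothetical helper, not the exact code in this PR; it assumes the stripped checkpoint keys used elsewhere in this file, e.g. "layer_0/attn/_query_norm" being present only in Gemma 3 checkpoints):

  def _infer_preset_name(params: dict[str, Any]) -> str:
    # Count decoder layers by probing the per-layer MLP weights.
    num_layers = 0
    while f"layer_{num_layers}/mlp/linear" in params:
      num_layers += 1
    # Gemma 3 checkpoints carry per-layer query/key norm scales; Gemma 1/2 do
    # not, which disambiguates presets that share a layer count.
    has_qk_norm = "layer_0/attn/_query_norm" in params
    for name, kwargs in _GEMMA_PRESETS.items():
      if kwargs["num_decoder_blocks"] != num_layers:
        continue
      if bool(kwargs.get("use_qk_norm", False)) == has_qk_norm:
        return name
    raise ValueError(
        f"Could not determine preset for model with {num_layers} layers;"
        " please pass preset_name explicitly."
    )

Under the reviewer's variant, the Gemma 3 branch would instead raise a ValueError asking for an explicit preset_name, while Gemma 1/2 checkpoints would only emit a warning before being inferred as above.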


@@ -139,32 +241,17 @@ def gemma_from_pretrained_checkpoint(

Args:
ckpt_params: Nested dictionary of weights from the Gemma checkpoint.
preset_name: The name of the Gemma preset to use.
upcast_activations_to_float32: Whether to cast activations to float32 when
the model runs. This allows analyzing activations at higher precision
without consuming additional memory for parameters.
use_layer_stack: Whether to use a layer stack for the decoder blocks.
preset_name: Preset name, used to determine model config. If "auto", uses
the number of layers in the checkpoint to determine the configuration.

Returns:
A Transformer model containing the loaded parameters.
"""
params = {k.removeprefix("transformer/"): v for k, v in ckpt_params.items()}

if preset_name == "auto":
num_layers = 0
while f"layer_{num_layers}/mlp/linear" in params:
num_layers += 1
preset_by_num_layers = {
kwargs["num_decoder_blocks"]: preset_name
for preset_name, kwargs in _GEMMA_PRESETS.items()
}
if num_layers not in preset_by_num_layers:
raise ValueError(
f"Could not determine preset for model with {num_layers} layers."
)
preset_name = preset_by_num_layers[num_layers]

preset_kwargs = _GEMMA_PRESETS[preset_name]
preset_needs_gating_transpose = _NEEDS_GATING_TRANSPOSE[preset_name]

@@ -207,6 +294,19 @@ def gemma_from_pretrained_checkpoint(
1 + params[f"layer_{i}/pre_attention_norm"]["scale"]
).tag("embedding")
)
# Add qk norm if needed
if config.use_qk_norm:
cur_block_params["attention/_query_norm/scale.weights"] = (
pz.nx.NamedArray.wrap(
1 + params[f"layer_{i}/attn/_query_norm"]["scale"]
).tag("projection")
)
cur_block_params["attention/_key_norm/scale.weights"] = (
pz.nx.NamedArray.wrap(
1 + params[f"layer_{i}/attn/_key_norm"]["scale"]
).tag("projection")
)

if config.use_post_attn_norm:
cur_block_params["post_attention_norm/scale.weights"] = (
pz.nx.NamedArray.wrap(
143 changes: 104 additions & 39 deletions penzai/models/transformer/variants/llamalike_common.py
@@ -34,7 +34,7 @@
import dataclasses
import functools
from typing import Any, Literal

from absl import logging
import jax
import jax.numpy as jnp
from penzai import pz
@@ -102,6 +102,12 @@ class LlamalikeTransformerConfig:
parameter_dtype: Floating dtype to use for all parameters.
activation_dtype: Floating dtype to use for activations and KV cache tables.
use_layer_stack: Whether to stack the blocks together using a LayerStack.
# NOTE: Gemma3 specific parameters
use_qk_norm: Whether to use QK normalization.
local_scale_factor: Scale factor for the local RoPE layers.
global_scale_factor: Scale factor for the global RoPE layers.
local_rope_wavelength: Wavelength for the local RoPE layers.
global_rope_wavelength: Wavelength for the global RoPE layers.
"""
Collaborator

@danieldjohnson Jun 8, 2025

Minor, but can we make it so that rope_wavelength can be None, and build_llamalike_attention checks to make sure either rope_wavelength is set OR both local_rope_wavelength and global_rope_wavelength are set, but not both?

Contributor Author

Because LlamalikeTransformerConfig is used to pass the parameters into build_llamalike_attention, we first construct the config object from the Gemma 3 preset dictionary, and at that point the config may already need both local_rope_wavelength and global_rope_wavelength set. I really appreciate the idea of making it simpler.

Collaborator

Sorry, I don't think I understand what you mean. Are you saying there's some constraint on what works here?

Actually, though, I think the simplest thing to do would be to say that rope_wavelength always means the global RoPE wavelength, and just add local_rope_wavelength: float | None = None. Then, for local RoPE, if config.local_rope_wavelength is not None we use config.local_rope_wavelength and otherwise we use config.rope_wavelength. For global RoPE, we always use config.rope_wavelength.

We could annotate it as

rope_wavelength: Wavelength for global RoPE layers (and for local RoPE layers if local_rope_wavelength is not set).
...
local_rope_wavelength: Wavelength for the local RoPE layers. If None, local RoPE layers will use the same wavelength as global RoPE layers (config.rope_wavelength)


num_kv_heads: int
@@ -126,6 +132,12 @@ class LlamalikeTransformerConfig:
parameter_dtype: jax.typing.DTypeLike = jnp.float32
activation_dtype: jax.typing.DTypeLike = jnp.float32
use_layer_stack: bool = False
# NOTE: Gemma3 specific parameters
use_qk_norm: bool = False
local_scale_factor: float | None = None
global_scale_factor: float | None = None
local_rope_wavelength: float | None = None
global_rope_wavelength: float | None = None


def build_llamalike_feedforward(
@@ -261,10 +273,30 @@ def build_llamalike_attention(
sliding_window_size=attention_type.window_size,
masked_out_value=masked_out_value,
)
# Decide which wavelength to use for local RoPE.
if config.local_rope_wavelength is not None:
wavelength = config.local_rope_wavelength
else:
wavelength = config.rope_wavelength
# Decide which scale factor to use for local RoPE.
if config.local_scale_factor is not None:
scale_factor = config.local_scale_factor
else:
scale_factor = 1.0
elif isinstance(attention_type, AttentionTypeGlobalCausal):
attn_masker = pz.nn.ApplyCausalAttentionMask(
masked_out_value=masked_out_value,
)
# Decide which wavelength to use for global RoPE.
if config.global_rope_wavelength is not None:
wavelength = config.global_rope_wavelength
else:
wavelength = config.rope_wavelength
# Decide which scale factor to use for global RoPE.
if config.global_scale_factor is not None:
scale_factor = config.global_scale_factor
else:
scale_factor = 1.0
else:
raise ValueError(f"Unsupported attention type {attention_type}")

@@ -290,42 +322,74 @@
pz.nn.Softmax("kv_seq"),
])

# Build the input_to_query sublayers, adding QK norm when configured.
input_to_query_sublayers = [
pz.nn.Linear.from_config(
name=f"{name}/query",
init_base_rng=init_base_rng,
input_axes={"embedding": embedding_dim},
output_axes={
**common_head_axes,
**query_only_head_axes,
"projection": projection_dim,
},
dtype=config.parameter_dtype,
),
]
if config.use_qk_norm:
input_to_query_sublayers.append(
pz.nn.RMSLayerNorm.from_config(
name=f"{name}/_query_norm",
Collaborator

Let's remove the leading underscore? I'm not sure why the original parameters have an underscore here, but it seems nicer if the Penzai version doesn't have one. The parameter names are already not exactly the same as the Flax version. (Same comment for _key_norm)

Contributor Author

I have fixed it.

init_base_rng=init_base_rng,
across_axes={"projection": config.projection_dim},
dtype=config.parameter_dtype,
epsilon=config.rms_norm_eps,
),
)
input_to_query_sublayers.extend([
pz.nn.ApplyRoPE(
positions_input_name="token_positions",
embedding_axis="projection",
max_wavelength=wavelength,
scale_factor=scale_factor,
),
pz.nn.ConstantRescale(
by=jnp.array(query_scaling_factor, dtype=config.activation_dtype)
),
])

# Build the input_to_key sublayers, adding QK norm when configured.
input_to_key_sublayers = [
pz.nn.Linear.from_config(
name=f"{name}/key",
init_base_rng=init_base_rng,
input_axes={"embedding": embedding_dim},
output_axes={**common_head_axes, "projection": projection_dim},
dtype=config.parameter_dtype,
),
]
if config.use_qk_norm:
input_to_key_sublayers.append(
pz.nn.RMSLayerNorm.from_config(
name=f"{name}/_key_norm",
init_base_rng=init_base_rng,
across_axes={"projection": config.projection_dim},
dtype=config.parameter_dtype,
epsilon=config.rms_norm_eps,
),
)
input_to_key_sublayers.append(
pz.nn.ApplyRoPE(
positions_input_name="token_positions",
embedding_axis="projection",
max_wavelength=wavelength,
scale_factor=scale_factor,
),
)

return pz.nn.Attention(
input_to_query=pz.nn.Sequential([
pz.nn.Linear.from_config(
name=f"{name}/query",
init_base_rng=init_base_rng,
input_axes={"embedding": embedding_dim},
output_axes={
**common_head_axes,
**query_only_head_axes,
"projection": projection_dim,
},
dtype=config.parameter_dtype,
),
pz.nn.ApplyRoPE(
positions_input_name="token_positions",
embedding_axis="projection",
max_wavelength=config.rope_wavelength,
),
pz.nn.ConstantRescale(
by=jnp.array(query_scaling_factor, dtype=config.activation_dtype)
),
]),
input_to_key=pz.nn.Sequential([
pz.nn.Linear.from_config(
name=f"{name}/key",
init_base_rng=init_base_rng,
input_axes={"embedding": embedding_dim},
output_axes={**common_head_axes, "projection": projection_dim},
dtype=config.parameter_dtype,
),
pz.nn.ApplyRoPE(
positions_input_name="token_positions",
embedding_axis="projection",
max_wavelength=config.rope_wavelength,
),
]),
input_to_query=pz.nn.Sequential(input_to_query_sublayers),
input_to_key=pz.nn.Sequential(input_to_key_sublayers),
input_to_value=pz.nn.Sequential([
pz.nn.Linear.from_config(
name=f"{name}/value",
@@ -483,9 +547,10 @@ def build_llamalike_transformer(
else:
if not isinstance(config.attention_type, AttentionType):
if config.num_decoder_blocks % len(config.attention_type) != 0:
raise ValueError(
"Per-layer attention types must have a length that divides the"
" number of blocks."
logging.warning(
"Please ensure that you are using Gemma3 models."
"For other models, per-layer attention types must have a length "
"that divides the number of blocks."
)
Comment on lines 549 to 543
Collaborator

Hm, this seems less safe and also pretty confusing for users. I don't think we should bypass this check.

Instead, can you do the adjustment in the _GEMMA_PRESETS constant? So, e.g., for "gemma3_1b", the "attention_type" field should be a tuple of length 26. You can do something like ((...,) * 5 + (...,)) to avoid typing it all out.

(Motivation here is that we don't want someone to accidentally mess up their config and end up with a different pattern of attention layers than they expected. It's pretty obvious what should happen when attention types divides number of blocks, but allowing e.g. off-by-one errors seems like it could be a footgun.)

Contributor Author

Thank you for your suggestions. I have restored the original check. Instead, following the gemma package, I added a make_attention_layers_types function in gemma.py and simplified the attention_type argument.
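
A minimal sketch of what such a helper might look like (hypothetical; loosely modeled on the gemma package, so the signature and the version added to gemma.py may differ):

  def make_attention_layers_types(
      pattern: tuple[llamalike_common.AttentionType, ...],
      num_layers: int,
  ) -> tuple[llamalike_common.AttentionType, ...]:
    """Repeats `pattern` across `num_layers` layers, truncating the tail."""
    repeats, remainder = divmod(num_layers, len(pattern))
    return pattern * repeats + pattern[:remainder]

A preset such as gemma3_1b could then pass attention_type=make_attention_layers_types(pattern, num_layers=26), where pattern is the five-sliding-window-plus-one-global tuple shown in the presets above; the remainder handling covers layer counts like 26 that are not multiples of six.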

for block_index in range(config.num_decoder_blocks):
sublayers.append(