Extends penzai to support gemma3 models. #119
```diff
@@ -34,7 +34,7 @@
 import dataclasses
 import functools
 from typing import Any, Literal
-
+from absl import logging
 import jax
 import jax.numpy as jnp
 from penzai import pz
```
```diff
@@ -102,6 +102,12 @@ class LlamalikeTransformerConfig:
     parameter_dtype: Floating dtype to use for all parameters.
     activation_dtype: Floating dtype to use for activations and KV cache tables.
     use_layer_stack: Whether to stack the blocks together using a LayerStack.
+    # NOTE: Gemma3 specific parameters
+    use_qk_norm: Whether to use QK normalization.
+    local_scale_factor: Scale factor for the local RoPE layers.
+    global_scale_factor: Scale factor for the global RoPE layers.
+    local_rope_wavelength: Wavelength for the local RoPE layers.
+    global_rope_wavelength: Wavelength for the global RoPE layers.
   """
```
Collaborator:
Minor, but can we make it so that

Contributor (Author):
Because

Collaborator:
Sorry, I don't think I understand what you mean. Are you saying there's some constraint on what works here? Actually, though, I think the simplest thing to do would be to say that
We could annotate it as
```diff
   num_kv_heads: int
```
```diff
@@ -126,6 +132,12 @@ class LlamalikeTransformerConfig:
   parameter_dtype: jax.typing.DTypeLike = jnp.float32
   activation_dtype: jax.typing.DTypeLike = jnp.float32
   use_layer_stack: bool = False
+  # NOTE: Gemma3 specific parameters
+  use_qk_norm: bool = False
+  local_scale_factor: float | None = None
+  global_scale_factor: float | None = None
+  local_rope_wavelength: float | None = None
+  global_rope_wavelength: float | None = None


 def build_llamalike_feedforward(
```
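As context for these fields, here is a minimal sketch of how a Gemma-3-style configuration might set them. `base_config` and the numeric values (RoPE bases of 10,000 locally and 1,000,000 globally, an 8x global scale factor) are illustrative assumptions, not values taken from this PR:

```python
import dataclasses

# Hypothetical sketch: start from an existing LlamalikeTransformerConfig and
# override only the Gemma3-specific fields added in this diff. The numeric
# values are assumptions for illustration.
gemma3_like_config = dataclasses.replace(
    base_config,                          # an assumed existing config
    use_qk_norm=True,                     # RMS-normalize queries and keys
    local_rope_wavelength=10_000.0,       # RoPE base for sliding-window layers
    global_rope_wavelength=1_000_000.0,   # RoPE base for global layers
    local_scale_factor=1.0,               # no positional rescaling locally
    global_scale_factor=8.0,              # rescale positions for long context
)
```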
```diff
@@ -261,10 +273,30 @@ def build_llamalike_attention(
         sliding_window_size=attention_type.window_size,
         masked_out_value=masked_out_value,
     )
+    # Decide which wavelength to use for local RoPE.
+    if config.local_rope_wavelength is not None:
+      wavelength = config.local_rope_wavelength
+    else:
+      wavelength = config.rope_wavelength
+    # Decide which scale factor to use for local RoPE.
+    if config.local_scale_factor is not None:
+      scale_factor = config.local_scale_factor
+    else:
+      scale_factor = 1.0
   elif isinstance(attention_type, AttentionTypeGlobalCausal):
     attn_masker = pz.nn.ApplyCausalAttentionMask(
         masked_out_value=masked_out_value,
     )
+    # Decide which wavelength to use for global RoPE.
+    if config.global_rope_wavelength is not None:
+      wavelength = config.global_rope_wavelength
+    else:
+      wavelength = config.rope_wavelength
+    # Decide which scale factor to use for global RoPE.
+    if config.global_scale_factor is not None:
+      scale_factor = config.global_scale_factor
+    else:
+      scale_factor = 1.0
   else:
     raise ValueError(f"Unsupported attention type {attention_type}")

```
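The wavelength and scale-factor fallbacks above are duplicated between the local and global branches. A hypothetical helper (not part of this PR) could centralize them; note that `is not None` is used rather than `or`, so an explicit `0.0` override would not be silently replaced by the fallback:

```python
def resolve_rope_params(config, is_local: bool) -> tuple[float, float]:
  """Hypothetical refactor: resolve (wavelength, scale_factor) with fallbacks."""
  if is_local:
    wavelength_override = config.local_rope_wavelength
    scale_override = config.local_scale_factor
  else:
    wavelength_override = config.global_rope_wavelength
    scale_override = config.global_scale_factor
  # Fall back to the shared rope_wavelength and a neutral scale of 1.0.
  wavelength = (
      wavelength_override
      if wavelength_override is not None
      else config.rope_wavelength
  )
  scale_factor = scale_override if scale_override is not None else 1.0
  return wavelength, scale_factor
```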
```diff
@@ -290,42 +322,74 @@ def build_llamalike_attention(
       pz.nn.Softmax("kv_seq"),
   ])

+  # Add QK norm if needed to the input_to_query sublayers.
+  input_to_query_sublayers = [
+      pz.nn.Linear.from_config(
+          name=f"{name}/query",
+          init_base_rng=init_base_rng,
+          input_axes={"embedding": embedding_dim},
+          output_axes={
+              **common_head_axes,
+              **query_only_head_axes,
+              "projection": projection_dim,
+          },
+          dtype=config.parameter_dtype,
+      ),
+  ]
+  if config.use_qk_norm:
+    input_to_query_sublayers.append(
+        pz.nn.RMSLayerNorm.from_config(
+            name=f"{name}/_query_norm",
```
Collaborator:
Let's remove the leading underscore? I'm not sure why the original parameters have an underscore here, but it seems nicer if the Penzai version doesn't have one. The parameter names are already not exactly the same as the Flax version. (Same comment for _key_norm)

Contributor (Author):
I have fixed it.
```diff
+            init_base_rng=init_base_rng,
+            across_axes={"projection": config.projection_dim},
+            dtype=config.parameter_dtype,
+            epsilon=config.rms_norm_eps,
+        ),
+    )
+  input_to_query_sublayers.extend([
+      pz.nn.ApplyRoPE(
+          positions_input_name="token_positions",
+          embedding_axis="projection",
+          max_wavelength=wavelength,
+          scale_factor=scale_factor,
+      ),
+      pz.nn.ConstantRescale(
+          by=jnp.array(query_scaling_factor, dtype=config.activation_dtype)
+      ),
+  ])
+
+  # Add QK norm if needed to the input_to_key sublayers.
+  input_to_key_sublayers = [
+      pz.nn.Linear.from_config(
+          name=f"{name}/key",
+          init_base_rng=init_base_rng,
+          input_axes={"embedding": embedding_dim},
+          output_axes={**common_head_axes, "projection": projection_dim},
+          dtype=config.parameter_dtype,
+      ),
+  ]
+  if config.use_qk_norm:
+    input_to_key_sublayers.append(
+        pz.nn.RMSLayerNorm.from_config(
+            name=f"{name}/_key_norm",
+            init_base_rng=init_base_rng,
+            across_axes={"projection": config.projection_dim},
+            dtype=config.parameter_dtype,
+            epsilon=config.rms_norm_eps,
+        ),
+    )
+  input_to_key_sublayers.append(
+      pz.nn.ApplyRoPE(
+          positions_input_name="token_positions",
+          embedding_axis="projection",
+          max_wavelength=wavelength,
+          scale_factor=scale_factor,
+      ),
+  )
```
```diff
   return pz.nn.Attention(
-      input_to_query=pz.nn.Sequential([
-          pz.nn.Linear.from_config(
-              name=f"{name}/query",
-              init_base_rng=init_base_rng,
-              input_axes={"embedding": embedding_dim},
-              output_axes={
-                  **common_head_axes,
-                  **query_only_head_axes,
-                  "projection": projection_dim,
-              },
-              dtype=config.parameter_dtype,
-          ),
-          pz.nn.ApplyRoPE(
-              positions_input_name="token_positions",
-              embedding_axis="projection",
-              max_wavelength=config.rope_wavelength,
-          ),
-          pz.nn.ConstantRescale(
-              by=jnp.array(query_scaling_factor, dtype=config.activation_dtype)
-          ),
-      ]),
-      input_to_key=pz.nn.Sequential([
-          pz.nn.Linear.from_config(
-              name=f"{name}/key",
-              init_base_rng=init_base_rng,
-              input_axes={"embedding": embedding_dim},
-              output_axes={**common_head_axes, "projection": projection_dim},
-              dtype=config.parameter_dtype,
-          ),
-          pz.nn.ApplyRoPE(
-              positions_input_name="token_positions",
-              embedding_axis="projection",
-              max_wavelength=config.rope_wavelength,
-          ),
-      ]),
+      input_to_query=pz.nn.Sequential(input_to_query_sublayers),
+      input_to_key=pz.nn.Sequential(input_to_key_sublayers),
       input_to_value=pz.nn.Sequential([
           pz.nn.Linear.from_config(
               name=f"{name}/value",
```
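For readers unfamiliar with QK normalization: per this hunk, each query and key path becomes Linear → RMSLayerNorm → ApplyRoPE (with an extra constant rescale for queries). Below is a standalone sketch of the standard RMSNorm computation the inserted sublayers perform; penzai's RMSLayerNorm parameterization may differ in detail (for example in its scale convention):

```python
import jax
import jax.numpy as jnp

def rms_norm(x: jax.Array, scale: jax.Array, eps: float = 1e-6) -> jax.Array:
  # Normalize over the last (projection) axis so each head's query/key vector
  # has unit root-mean-square, then apply a learned per-dimension scale.
  rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
  return (x / rms) * scale
```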
```diff
@@ -483,9 +547,10 @@ def build_llamalike_transformer(
   else:
     if not isinstance(config.attention_type, AttentionType):
       if config.num_decoder_blocks % len(config.attention_type) != 0:
-        raise ValueError(
-            "Per-layer attention types must have a length that divides the"
-            " number of blocks."
+        logging.warning(
+            "Please ensure that you are using Gemma3 models. "
+            "For other models, per-layer attention types must have a length "
+            "that divides the number of blocks."
         )
```
Comment on lines 549 to 543

Collaborator:
Hm, this seems less safe and also pretty confusing for users. I don't think we should bypass this check. Instead, can you do the adjustment in the

(Motivation here is that we don't want someone to accidentally mess up their config and end up with a different pattern of attention layers than they expected. It's pretty obvious what should happen when the number of attention types divides the number of blocks, but allowing e.g. off-by-one errors seems like it could be a footgun.)

Contributor (Author):
Thank you for your suggestions. I have retained the original check. Instead, I follow
```diff
     for block_index in range(config.num_decoder_blocks):
       sublayers.append(
```
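A sketch of the kind of adjustment the reviewer is asking for: keep the strict divisibility check in `build_llamalike_transformer`, and have the Gemma 3 preset expand its repeating attention pattern to one entry per block before building the config. The 5-local:1-global pattern, the 1024-token window, and the block count are assumptions based on Gemma 3's published architecture, not this PR's code (`AttentionTypeGlobalCausal` appears in the diff; the sliding-window counterpart's name is assumed):

```python
# Hypothetical preset-side expansion: len(attention_type) then equals
# num_decoder_blocks, which trivially divides it, so the generic check passes.
num_decoder_blocks = 26  # illustrative depth
pattern = [AttentionTypeSlidingWindowCausal(window_size=1024)] * 5 + [
    AttentionTypeGlobalCausal()
]
attention_type = tuple(
    pattern[i % len(pattern)] for i in range(num_decoder_blocks)
)
```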
Collaborator:
It is too bad that this is a breaking change in the function signature, since this means existing code will no longer work. Is there some way to do this in a backwards compatible way?

I think it's OK if "auto" does not allow loading gemma 3 models, but it would be nice if it was still possible for us to load gemma 1 and gemma 2 in "auto" mode. Maybe there are differences in the parameter names that we can use, like _query_norm?

Ideal solution would be something like:

(Probably long term it makes sense to just require the preset to be specified directly, but I'd prefer not to make breaking changes too often if possible.)

Contributor (Author):
Thank you for your suggestion; I have now written code so that "auto" mode loads Gemma 3 models by checking whether the model has QK norm.
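A minimal sketch of what that "auto" detection might look like, assuming a flat mapping from checkpoint parameter names to arrays; the key substring and preset names here are assumptions, and the PR's actual implementation may differ:

```python
# Hypothetical flat checkpoint mapping (names are illustrative only):
flat_params = {
    "transformer/layer_0/attn/_query_norm/scale": ...,
    "transformer/layer_0/attn/query/w": ...,
}

def looks_like_gemma3(params: dict) -> bool:
  # Gemma 3 checkpoints carry QK-norm scale parameters; Gemma 1 and 2 do not.
  return any("_query_norm" in name for name in params)

preset = "gemma3" if looks_like_gemma3(flat_params) else "gemma_v1_or_v2"
```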