Extends penzai to support gemma3 models. #119
File: README.md

@@ -217,9 +217,9 @@ You can read more about Penzai's conventions for layers in ["How to Think in Pen

 ## Loading Pretrained Models

-### Loading Gemma or Gemma 2
+### Loading Gemma or Gemma 2 or Gemma 3

-Penzai's Gemma implementation includes a conversion utility that converts the "Flax" model weights from Kaggle ([Gemma 1](https://www.kaggle.com/models/google/gemma), [Gemma 2](https://www.kaggle.com/models/google/gemma-2)) into the correct form. You can load it using:
+Penzai's Gemma implementation includes a conversion utility that converts the "Flax" model weights from Kaggle ([Gemma 1](https://www.kaggle.com/models/google/gemma), [Gemma 2](https://www.kaggle.com/models/google/gemma-2), [Gemma 3](https://www.kaggle.com/models/google/gemma-3)) into the correct form. You can load it using:

 ```python
 import kagglehub
@@ -236,13 +236,20 @@ flax_params_dict = checkpointer.restore(ckpt_path)
 model = variants.gemma.gemma_from_pretrained_checkpoint(flax_params_dict)
 ```

-To load Gemma 2, you can substitute the corresponding Kaggle model name and checkpoint path. For instance, to load the Gemma 2 9B model, you can use:
+To load Gemma 2 or Gemma 3, you can substitute the corresponding Kaggle model name and checkpoint path. For instance, to load the Gemma 2 9B model, you can use:

 ```python
 weights_dir = kagglehub.model_download('google/gemma-2/flax/gemma2-9b')
 ckpt_path = os.path.join(weights_dir, 'gemma2_9b_pt')
 ```

+Similarly, to load the Gemma 3 4B model, you can use:
Collaborator: nit: can you make this just

Contributor (Author): I have fixed it.
+
+```python
+weights_dir = kagglehub.model_download('google/gemma-3/flax/gemma3-4b')
+ckpt_path = os.path.join(weights_dir, 'gemma3_4b_pt')
+```
+
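Putting these steps together, a minimal end-to-end sketch for Gemma 3 4B could look like the following. The `PyTreeCheckpointer` setup mirrors the earlier README snippet; the exact `gemma` import path is an assumption and is not part of this diff:

```python
import os

import kagglehub
import orbax.checkpoint
from penzai.models.transformer.variants import gemma  # assumed import path

# Download the Gemma 3 4B Flax checkpoint from Kaggle (requires model access).
weights_dir = kagglehub.model_download('google/gemma-3/flax/gemma3-4b')
ckpt_path = os.path.join(weights_dir, 'gemma3_4b_pt')

# Restore the raw Flax parameter dict, then convert it into a Penzai model.
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
flax_params_dict = checkpointer.restore(ckpt_path)
model = gemma.gemma_from_pretrained_checkpoint(flax_params_dict)
```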
| See the "Model Variations" section on the Kaggle model pages for details about the names and paths for each checkpoint. (You may also need to create a Kaggle account and request access to each model before you can download the checkpoints.) | ||
|
|
||
| If you are using multiple accelerator devices (e.g. for a TPU v2 Colab kernel), you may want to shard the parameters over the devices while loading them. To do so, you can pass a sharding specification to `orbax.checkpoint`. For instance, to shard over the last axis of every parameter, you can use | ||
|
|
||
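The README's concrete example continues past the end of this hunk. Purely as an illustrative sketch (assuming orbax's `ArrayRestoreArgs` API and parameters whose last dimension divides evenly across the local devices), such a specification might look like:

```python
import numpy as np
import jax
import orbax.checkpoint

# One-dimensional mesh over all local devices; the axis name is arbitrary.
mesh = jax.sharding.Mesh(np.array(jax.devices()), ("devices",))

def _last_axis_sharding(leaf_metadata):
  # Shard only the final dimension of each parameter across the mesh.
  ndim = len(leaf_metadata.shape)
  spec = jax.sharding.PartitionSpec(*([None] * (ndim - 1)), "devices")
  return orbax.checkpoint.ArrayRestoreArgs(
      sharding=jax.sharding.NamedSharding(mesh, spec)
  )

checkpointer = orbax.checkpoint.PyTreeCheckpointer()
restore_args = jax.tree_util.tree_map(
    _last_axis_sharding, checkpointer.metadata(ckpt_path)
)
flax_params_dict = checkpointer.restore(ckpt_path, restore_args=restore_args)
```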
File: penzai/models/transformer/variants/gemma.py

@@ -14,13 +14,14 @@

 """The Gemma architecture transformer variant.

-Supports both the Gemma 1 and Gemma 2 architectures. Based on the Flax
-reference implementation at https://github.com/google-deepmind/gemma.
+Supports the Gemma 1, Gemma 2, and Gemma 3 architectures. Based on the
+Flax reference implementation at https://github.com/google-deepmind/gemma.

 See the Gemma technical reports for more information:

 * Gemma 1: https://arxiv.org/abs/2403.08295
 * Gemma 2: https://arxiv.org/abs/2408.00118
+* Gemma 3: https://arxiv.org/abs/2503.19786
 """

 from __future__ import annotations
@@ -33,6 +34,20 @@
 from penzai.models.transformer.variants import llamalike_common


+def make_attention_layers_types(
Collaborator: nit: can you add an underscore at the beginning to make this private (

Contributor (Author): I have fixed it.
+    pattern: tuple[llamalike_common.AttentionType, ...],
+    *,
+    num_layers: int,
+) -> tuple[llamalike_common.AttentionType, ...]:
+  """Returns the list of attention types for every layer."""
+  pattern_size = len(pattern)
+  out = pattern * (num_layers // pattern_size)
+  if num_layers % pattern_size != 0:
+    out += pattern[: num_layers % pattern_size]
+  return tuple(out)
+
+
 _GEMMA_PRESETS = {
     "gemma_2b": dict(
         num_decoder_blocks=18,
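For illustration, here is how this helper tiles the Gemma 3 attention pattern (five sliding-window layers followed by one global layer) used by the presets below. This sketch is not part of the diff; it uses the helper's name as it appears in this revision and assumes the attention-type classes referenced above are in scope:

```python
from penzai.models.transformer.variants import llamalike_common

# Five local (sliding-window) attention layers followed by one global layer,
# matching the gemma3_1b preset below.
pattern = (
    llamalike_common.AttentionTypeSlidingWindowCausal(512),
) * 5 + (llamalike_common.AttentionTypeGlobalCausal(),)

layer_types = make_attention_layers_types(pattern=pattern, num_layers=26)
assert len(layer_types) == 26
# Four full repeats cover 24 layers; the remaining 26 % 6 == 2 layers reuse
# the start of the pattern, so the final two layers are sliding-window layers.
```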
@@ -105,13 +120,101 @@
         final_logit_softcap=30.0,
         attn_logits_soft_cap=50.0,
     ),
+    "gemma3_1b": dict(
+        num_decoder_blocks=26,
+        vocab_size=262_144,
+        num_kv_heads=1,
+        query_head_multiplier=4,
+        embedding_dim=1152,
+        projection_dim=256,
+        mlp_hidden_dim=6 * 1152,
+        attention_type=make_attention_layers_types(
+            pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(512),)
+            * 5 + (llamalike_common.AttentionTypeGlobalCausal(),),
+            num_layers=26,
+        ),
+        use_qk_norm=True,
+        use_post_attn_norm=True,
+        use_post_ffw_norm=True,
+        local_rope_wavelength=10_000,
+        global_rope_wavelength=1_000_000,
+    ),
+    "gemma3_4b": dict(
+        num_decoder_blocks=34,
+        vocab_size=262_144,
+        num_kv_heads=4,
+        query_head_multiplier=2,
+        embedding_dim=2560,
+        projection_dim=256,
+        mlp_hidden_dim=2560 * 8 // 2,
+        attention_type=make_attention_layers_types(
+            pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),)
+            * 5 + (llamalike_common.AttentionTypeGlobalCausal(),),
+            num_layers=34,
+        ),
+        use_qk_norm=True,
+        use_post_attn_norm=True,
+        use_post_ffw_norm=True,
+        local_scale_factor=1.0,
+        global_scale_factor=8.0,
+        local_rope_wavelength=10_000,
+        global_rope_wavelength=1_000_000,
+    ),
+    "gemma3_12b": dict(
+        num_decoder_blocks=48,
+        vocab_size=262_144,
+        num_kv_heads=8,
+        query_head_multiplier=2,
+        embedding_dim=30 * 128,
+        projection_dim=256,
+        mlp_hidden_dim=8 * 30 * 128 // 2,
+        attention_type=make_attention_layers_types(
+            pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),)
+            * 5 + (llamalike_common.AttentionTypeGlobalCausal(),),
+            num_layers=48,
+        ),
+        use_qk_norm=True,
+        use_post_attn_norm=True,
+        use_post_ffw_norm=True,
+        local_scale_factor=1.0,
+        global_scale_factor=8.0,
+        local_rope_wavelength=10_000,
+        global_rope_wavelength=1_000_000,
+    ),
+    "gemma3_27b": dict(
+        num_decoder_blocks=62,
+        vocab_size=262_144,
+        num_kv_heads=16,
+        query_head_multiplier=2,
+        embedding_dim=5376,
+        projection_dim=128,
+        mlp_hidden_dim=5376 * 8 // 2,
+        # query scaling factor: 1/sqrt(embedding_dim / num_query_heads)
+        query_scaling_factor=(5376 // 32) ** -0.5,
+        attention_type=make_attention_layers_types(
+            pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),)
+            * 5 + (llamalike_common.AttentionTypeGlobalCausal(),),
+            num_layers=62,
+        ),
+        use_qk_norm=True,
+        use_post_attn_norm=True,
+        use_post_ffw_norm=True,
+        local_scale_factor=1.0,
+        global_scale_factor=8.0,
+        local_rope_wavelength=10_000,
+        global_rope_wavelength=1_000_000,
+    ),
 }
 _NEEDS_GATING_TRANSPOSE = {
     "gemma_2b": False,
     "gemma_7b": False,
     "gemma2_2b": False,
     "gemma2_9b": True,
     "gemma2_27b": True,
+    "gemma3_1b": True,
+    "gemma3_4b": True,
+    "gemma3_12b": True,
+    "gemma3_27b": True,
 }
@@ -120,7 +223,8 @@ def gemma_from_pretrained_checkpoint(
     upcast_activations_to_float32: bool = False,
     use_layer_stack: bool = False,
     preset_name: Literal[
-        "gemma_2b", "gemma_7b", "gemma2_2b", "gemma2_9b", "gemma2_27b", "auto"
+        "gemma_2b", "gemma_7b", "gemma2_2b", "gemma2_9b", "gemma2_27b",
+        "gemma3_1b", "gemma3_4b", "gemma3_12b", "gemma3_27b", "auto"
     ] = "auto",
 ) -> model_parts.TransformerLM:
   """Builds a Gemma model from a pretrained checkpoint.
@@ -144,7 +248,8 @@ def gemma_from_pretrained_checkpoint(
       without consuming additional memory for parameters.
     use_layer_stack: Whether to use a layer stack for the decoder blocks.
     preset_name: Preset name, used to determine model config. If "auto", uses
-      the number of layers in the checkpoint to determine the configuration.
+      the number of layers in the checkpoint and whether it contains qk-norm
+      parameters to determine the configuration.

   Returns:
     A Transformer model containing the loaded parameters.
@@ -155,15 +260,30 @@
   num_layers = 0
   while f"layer_{num_layers}/mlp/linear" in params:
     num_layers += 1
-  preset_by_num_layers = {
-      kwargs["num_decoder_blocks"]: preset_name
-      for preset_name, kwargs in _GEMMA_PRESETS.items()
-  }
-  if num_layers not in preset_by_num_layers:
+  if (
+      "layer_0/attn/_query_norm" in params
+      and "layer_0/attn/_key_norm" in params
+  ):
+    qk_norm = True
+  else:
+    qk_norm = False
+  is_match = False
+  for gemma_preset_name, kwargs in _GEMMA_PRESETS.items():
+    if kwargs["num_decoder_blocks"] == num_layers:
+      if qk_norm and "use_qk_norm" in kwargs:
+        if kwargs["use_qk_norm"]:
+          is_match = True
+          preset_name = gemma_preset_name
+          break
+      if (not qk_norm) and ("use_qk_norm" not in kwargs):
+        is_match = True
+        preset_name = gemma_preset_name
+        break
+  if not is_match:
     raise ValueError(
-        f"Could not determine preset for model with {num_layers} layers."
+        f"Could not determine preset for model with {num_layers} layers and"
+        f" qk norm {qk_norm}."
     )
-  preset_name = preset_by_num_layers[num_layers]

   preset_kwargs = _GEMMA_PRESETS[preset_name]
   preset_needs_gating_transpose = _NEEDS_GATING_TRANSPOSE[preset_name]
@@ -207,6 +327,19 @@
               1 + params[f"layer_{i}/pre_attention_norm"]["scale"]
           ).tag("embedding")
       )
+      # Add qk norm if needed
+      if config.use_qk_norm:
+        cur_block_params["attention/query_norm/scale.weights"] = (
+            pz.nx.NamedArray.wrap(
+                1 + params[f"layer_{i}/attn/_query_norm"]["scale"]
+            ).tag("projection")
+        )
+        cur_block_params["attention/key_norm/scale.weights"] = (
+            pz.nx.NamedArray.wrap(
+                1 + params[f"layer_{i}/attn/_key_norm"]["scale"]
+            ).tag("projection")
+        )
+
       if config.use_post_attn_norm:
         cur_block_params["post_attention_norm/scale.weights"] = (
             pz.nx.NamedArray.wrap(
Review comment: nit: can you make this "Loading Gemma (1, 2, or 3)" (referring to the "### Loading Gemma or Gemma 2 or Gemma 3" heading)