Gemma3 text keras hf checkpoint conversion #2433
Open
kharshith-k wants to merge 11 commits into keras-team:master from kharshith-k:gemma3-text-keras-hf-checkpoint-conversion
+578 −2
This view shows changes from 9 of the 11 commits.

Commits (11)
All commits are by kharshith-k:

938b52b  Added Safetensors export feature for Text models in Gemma3 Checkpoint…
1f06acb  Modified convert_gemma3_checkpoints.py for passing pre-commit
24c9573  tested safetensors conversion for all the Gemma3 text models
71bb3af  Added export and test scripts for gemma3 and updated hf_exporter acco…
85f9498  Fixed pre commit error
69a7137  updated config to support torch backend
525da45  Merge branch 'keras-team:master' into gemma3-text-keras-hf-checkpoint…
06ed2ad  Updated doctsrings
ab1bde1  Updated doctsring
1ec7222  Updated gemma3 files according to code-assist suggestions
10b1439  Updated gemma3 files according to code-assist suggestions
The first new file (177 lines added) defines the Keras to Hugging Face conversion utilities for Gemma3 text models:

import keras.ops as ops


def get_gemma3_config(backbone):
    """Convert Keras Gemma3 config to Hugging Face config dictionary."""
    token_embedding_layer = backbone.get_layer("token_embedding")
    hf_config = {
        "architectures": ["Gemma3ForCausalLM"],
        "model_type": "gemma3_text",
        "vocab_size": backbone.vocabulary_size,
        "num_hidden_layers": backbone.num_layers,
        "num_attention_heads": backbone.num_query_heads,
        "num_key_value_heads": backbone.num_key_value_heads,
        "hidden_size": backbone.hidden_dim,
        "intermediate_size": backbone.intermediate_dim,
        "head_dim": backbone.head_dim,
        "max_position_embeddings": 32768,
        "tie_word_embeddings": token_embedding_layer.tie_weights,
        "rms_norm_eps": 1e-6,
        "rope_theta": 10000.0,
        "attention_bias": False,
        "attention_dropout": 0.0,
        "hidden_activation": "gelu_pytorch_tanh",
    }
    return hf_config


def get_gemma3_weights_map(backbone, include_lm_head=False):
    """Convert a Keras Gemma3 model to Hugging Face format.

    include_lm_head: If True, exports for CausalLM (with "model." prefix).
        If False, exports for backbone only (without prefix).
    """
    weights_dict = {}

    # For CausalLM export, use "model." prefix.
    # For backbone export, use no prefix.
    prefix = "model." if include_lm_head else ""

    # Token embeddings - use .weights[0] to get backend tensor.
    token_embedding_layer = backbone.get_layer("token_embedding")
    token_embedding = token_embedding_layer.weights[0]
    weights_dict[f"{prefix}embed_tokens.weight"] = token_embedding

    for i in range(backbone.num_layers):
        block = backbone.get_layer(f"decoder_block_{i}")

        # Attention query projection
        q_kernel = block.attention.query_dense.weights[0]
        q_kernel = ops.transpose(q_kernel, axes=(1, 0, 2))  # permute(1, 0, 2)
        q_kernel = ops.reshape(q_kernel, (backbone.hidden_dim, -1))
        q_kernel = ops.transpose(q_kernel)  # .T
        weights_dict[f"{prefix}layers.{i}.self_attn.q_proj.weight"] = q_kernel

        # Attention key projection
        k_kernel = block.attention.key_dense.weights[0]
        k_kernel = ops.transpose(k_kernel, axes=(1, 0, 2))  # permute(1, 0, 2)
        k_kernel = ops.reshape(k_kernel, (backbone.hidden_dim, -1))
        k_kernel = ops.transpose(k_kernel)  # .T
        weights_dict[f"{prefix}layers.{i}.self_attn.k_proj.weight"] = k_kernel

        # Attention value projection
        v_kernel = block.attention.value_dense.weights[0]
        v_kernel = ops.transpose(v_kernel, axes=(1, 0, 2))  # permute(1, 0, 2)
        v_kernel = ops.reshape(v_kernel, (backbone.hidden_dim, -1))
        v_kernel = ops.transpose(v_kernel)  # .T
        weights_dict[f"{prefix}layers.{i}.self_attn.v_proj.weight"] = v_kernel

        # Attention output projection
        o_kernel = block.attention.output_dense.weights[0]
        o_kernel = ops.transpose(o_kernel, axes=(2, 0, 1))  # permute(2, 0, 1)
        o_kernel = ops.reshape(o_kernel, (backbone.hidden_dim, -1))
        weights_dict[f"{prefix}layers.{i}.self_attn.o_proj.weight"] = o_kernel

        # Query and key normalization
        q_norm = block.attention.query_norm.weights[0]
        weights_dict[f"{prefix}layers.{i}.self_attn.q_norm.weight"] = q_norm

        k_norm = block.attention.key_norm.weights[0]
        weights_dict[f"{prefix}layers.{i}.self_attn.k_norm.weight"] = k_norm

        # MLP gate projection
        gate_kernel = block.gating_ffw.weights[0]
        gate_kernel = ops.transpose(gate_kernel)  # .T
        weights_dict[f"{prefix}layers.{i}.mlp.gate_proj.weight"] = gate_kernel

        # MLP up projection
        up_kernel = block.gating_ffw_2.weights[0]
        up_kernel = ops.transpose(up_kernel)  # .T
        weights_dict[f"{prefix}layers.{i}.mlp.up_proj.weight"] = up_kernel

        # MLP down projection
        down_kernel = block.ffw_linear.weights[0]
        down_kernel = ops.transpose(down_kernel)  # .T
        weights_dict[f"{prefix}layers.{i}.mlp.down_proj.weight"] = down_kernel

        # Pre-attention normalization
        input_layer_norm = block.pre_attention_norm.weights[0]
        weights_dict[f"{prefix}layers.{i}.input_layernorm.weight"] = (
            input_layer_norm
        )

        # Post-attention normalization
        if hasattr(block, "post_attention_norm"):
            post_attn_norm = block.post_attention_norm.weights[0]
        else:
            # Fall back to pre_ffw_norm if post_attention_norm doesn't exist.
            post_attn_norm = block.pre_ffw_norm.weights[0]
        weights_dict[f"{prefix}layers.{i}.post_attention_layernorm.weight"] = (
            post_attn_norm
        )

        # Pre-feedforward normalization
        pre_feedforward_layernorm = block.pre_ffw_norm.weights[0]
        weights_dict[f"{prefix}layers.{i}.pre_feedforward_layernorm.weight"] = (
            pre_feedforward_layernorm
        )

        # Post-feedforward normalization
        if hasattr(block, "post_ffw_norm"):
            post_feedforward_layernorm = block.post_ffw_norm.weights[0]
        else:
            # Fall back to pre_ffw_norm if post_ffw_norm doesn't exist.
            post_feedforward_layernorm = block.pre_ffw_norm.weights[0]
        weights_dict[
            f"{prefix}layers.{i}.post_feedforward_layernorm.weight"
        ] = post_feedforward_layernorm

    # Final normalization
    final_norm = backbone.get_layer("final_normalization").weights[0]
    weights_dict[f"{prefix}norm.weight"] = final_norm

    if include_lm_head and not token_embedding_layer.tie_weights:
        weights_dict["lm_head.weight"] = ops.transpose(
            token_embedding_layer.reverse_embeddings
        )

    return weights_dict


def get_gemma3_tokenizer_config(tokenizer):
    tokenizer_config = {
        "tokenizer_class": "GemmaTokenizer",
        "clean_up_tokenization_spaces": False,
        "bos_token": "<bos>",
        "eos_token": "<eos>",
        "pad_token": "<pad>",
        "unk_token": "<unk>",
        "add_bos_token": True,
        "add_eos_token": False,
        "model_max_length": 32768,
    }
    # Add added_tokens_decoder.
    added_tokens_decoder = {}
    special_tokens = [
        "<pad>",
        "<bos>",
        "<eos>",
        "<unk>",
        "<start_of_image>",
        "<end_of_image>",
        "<img>",
    ]
    for token in special_tokens:
        token_id = tokenizer.token_to_id(token)
        if token_id is not None:
            added_tokens_decoder[str(token_id)] = {
                "content": token,
                "special": True,
                "single_word": False,
                "lstrip": False,
                "rstrip": False,
                "normalized": False,
            }
    tokenizer_config["added_tokens_decoder"] = added_tokens_decoder
    return tokenizer_config
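
For illustration only, not part of the diff: the attention kernels above are stored in Keras as three-dimensional per-head tensors, and the permute/reshape/transpose sequence flattens them into the two-dimensional (out_features, in_features) layout that Hugging Face linear weights use. A shape-only sketch of the query projection, assuming the Keras query kernel is laid out as (num_query_heads, hidden_dim, head_dim):

# Shape-only sketch; assumes a (num_query_heads, hidden_dim, head_dim) kernel
# layout. Not part of the pull request.
import numpy as np
import keras.ops as ops

num_heads, hidden_dim, head_dim = 4, 512, 64
q_kernel = np.zeros((num_heads, hidden_dim, head_dim), dtype="float32")

q_kernel = ops.transpose(q_kernel, axes=(1, 0, 2))  # (hidden_dim, num_heads, head_dim)
q_kernel = ops.reshape(q_kernel, (hidden_dim, -1))  # (hidden_dim, num_heads * head_dim)
q_kernel = ops.transpose(q_kernel)  # (num_heads * head_dim, hidden_dim)

# torch.nn.Linear stores weights as (out_features, in_features), so this
# matches the expected q_proj.weight shape.
print(ops.shape(q_kernel))  # (256, 512)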
The second new file (161 lines added) adds a test that exports a small text-only Gemma3 model and checks it against Hugging Face Transformers:
import os

import numpy as np
from transformers import AutoModel
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM
from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import (
    Gemma3CausalLMPreprocessor,
)
from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer
from keras_hub.src.tests.test_case import TestCase


class TestGemma3Export(TestCase):
    def test_export_to_hf(self):
        proto = os.path.join(self.get_test_data_dir(), "gemma3_test_vocab.spm")
        tokenizer = Gemma3Tokenizer(proto=proto)

        # Create a small backbone (text-only, no vision encoder).
        backbone = Gemma3Backbone(
            vocabulary_size=tokenizer.vocabulary_size(),
            image_size=896,  # Default value even for text-only
            num_layers=2,
            num_query_heads=4,
            num_key_value_heads=1,
            hidden_dim=512,
            intermediate_dim=1028,
            head_dim=128,
            query_head_dim_normalize=True,
            use_query_key_norm=True,
            use_post_ffw_norm=True,  # Real Gemma3 models have these
            use_post_attention_norm=True,  # Real Gemma3 models have these
            attention_logit_soft_cap=None,
            final_logit_soft_cap=None,
            use_sliding_window_attention=False,
            sliding_window_size=4096,
            vision_encoder=None,  # Text-only model for testing
            layer_norm_epsilon=1e-6,
            dropout=0,
        )

        # Create preprocessor.
        preprocessor = Gemma3CausalLMPreprocessor(tokenizer=tokenizer)

        # Create the causal LM model.
        keras_model = Gemma3CausalLM(
            backbone=backbone, preprocessor=preprocessor
        )

        # Set all weights to random values.
        rng = np.random.default_rng(42)
        weights = keras_model.get_weights()
        for i in range(len(weights)):
            weights[i] = rng.random(weights[i].shape).astype(weights[i].dtype)
        keras_model.set_weights(weights)

        # Export to Hugging Face format using the new methods.
        export_path_backbone = os.path.join(
            self.get_temp_dir(), "export_backbone"
        )
        backbone.export_to_transformers(export_path_backbone)

        export_path_tokenizer = os.path.join(
            self.get_temp_dir(), "export_tokenizer"
        )
        preprocessor.tokenizer.export_to_transformers(export_path_tokenizer)

        export_path_task = os.path.join(self.get_temp_dir(), "export_task")
        keras_model.export_to_transformers(export_path_task)

        # Load Hugging Face models and tokenizer.
        # Note: we only test the slow tokenizer because the test vocab file
        # may not be compatible with fast tokenizer conversion.
        hf_backbone = AutoModel.from_pretrained(export_path_backbone)
        hf_tokenizer_slow = AutoTokenizer.from_pretrained(
            export_path_tokenizer, use_fast=False
        )
        hf_full_model = AutoModelForCausalLM.from_pretrained(export_path_task)

        # Verify configuration.
        hf_config = hf_backbone.config
        self.assertEqual(
            hf_config.vocab_size,
            backbone.vocabulary_size,
            "Vocabulary sizes do not match",
        )
        self.assertEqual(
            hf_config.num_hidden_layers,
            backbone.num_layers,
            "Number of layers do not match",
        )
        self.assertEqual(
            hf_config.num_attention_heads,
            backbone.num_query_heads,
            "Number of query heads do not match",
        )
        self.assertEqual(
            hf_config.num_key_value_heads,
            backbone.num_key_value_heads,
            "Number of key value heads do not match",
        )
        self.assertEqual(
            hf_config.hidden_size,
            backbone.hidden_dim,
            "Hidden dimensions do not match",
        )
        self.assertEqual(
            hf_config.intermediate_size,
            backbone.intermediate_dim,
            "Intermediate sizes do not match",
        )
        self.assertEqual(
            hf_config.head_dim,
            backbone.head_dim,
            "Head dimensions do not match",
        )
        self.assertEqual(
            hf_config.max_position_embeddings,
            32768,
            "Max position embeddings do not match",
        )
        self.assertEqual(
            hf_config.tie_word_embeddings,
            backbone.token_embedding.tie_weights,
            "Tie word embeddings do not match",
        )

        # Verify tokenizer compatibility (using slow tokenizer).
        self.assertEqual(
            hf_tokenizer_slow.vocab_size,
            tokenizer.vocabulary_size(),
            "Tokenizer vocabulary sizes do not match",
        )

        # Compare generated outputs using the full model.
        prompt = "the quick"

        # Generate with the Keras model.
        keras_output = keras_model.generate(prompt, max_length=20)

        # Generate with the Hugging Face model using the slow tokenizer.
        input_ids_slow = hf_tokenizer_slow.encode(prompt, return_tensors="pt")
        output_ids_slow = hf_full_model.generate(
            input_ids_slow, max_length=20, do_sample=False
        )
        hf_slow_output = hf_tokenizer_slow.decode(
            output_ids_slow[0], skip_special_tokens=True
        )

        # Debug print to see the actual outputs.
        print(f"Keras output: '{keras_output}'")
        print(f"HF slow output: '{hf_slow_output}'")

        self.assertEqual(
            keras_output,
            hf_slow_output,
            "Generated outputs do not match",
        )
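
Outside the test, the same flow would look roughly like the sketch below. This is a hedged example, not code from the pull request: the preset name "gemma3_instruct_1b" and the export directories are assumptions, and it relies on the export_to_transformers methods this PR wires up for Gemma3.

# Hedged usage sketch (not from the PR). Assumes a Gemma3 text preset such as
# "gemma3_instruct_1b" is available; directory names are placeholders.
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

from keras_hub.models import Gemma3CausalLM

keras_model = Gemma3CausalLM.from_preset("gemma3_instruct_1b")
keras_model.export_to_transformers("./gemma3_hf")  # config + safetensors weights
keras_model.preprocessor.tokenizer.export_to_transformers("./gemma3_hf_tokenizer")

# Reload the exported checkpoint with Hugging Face Transformers and generate.
hf_model = AutoModelForCausalLM.from_pretrained("./gemma3_hf")
hf_tokenizer = AutoTokenizer.from_pretrained("./gemma3_hf_tokenizer", use_fast=False)
inputs = hf_tokenizer("the quick", return_tensors="pt")
output_ids = hf_model.generate(**inputs, max_length=20, do_sample=False)
print(hf_tokenizer.decode(output_ids[0], skip_special_tokens=True))

The test above exports the model and the tokenizer to separate directories and loads each with the matching transformers Auto class, so the sketch does the same rather than assuming a single export directory covers both.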
(The remaining changed files in this pull request failed to load.)