177 changes: 177 additions & 0 deletions keras_hub/src/utils/transformers/export/gemma3.py
@@ -0,0 +1,177 @@
import keras.ops as ops


def get_gemma3_config(backbone):
"""Convert Keras Gemma3 config to Hugging Face config dictionary."""
token_embedding_layer = backbone.get_layer("token_embedding")
hf_config = {
"architectures": ["Gemma3ForCausalLM"],
"model_type": "gemma3_text",
"vocab_size": backbone.vocabulary_size,
"num_hidden_layers": backbone.num_layers,
"num_attention_heads": backbone.num_query_heads,
"num_key_value_heads": backbone.num_key_value_heads,
"hidden_size": backbone.hidden_dim,
"intermediate_size": backbone.intermediate_dim,
"head_dim": backbone.head_dim,
"max_position_embeddings": 32768,
"tie_word_embeddings": token_embedding_layer.tie_weights,
"rms_norm_eps": 1e-6,
"rope_theta": 10000.0,
"attention_bias": False,
"attention_dropout": 0.0,
"hidden_activation": "gelu_pytorch_tanh",
}
return hf_config
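# `get_gemma3_config` is looked up via `MODEL_CONFIGS` in hf_exporter.py; the
# exporter presumably serializes the returned dict as the exported config.json.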


def get_gemma3_weights_map(backbone, include_lm_head=False):
"""Convert a Keras Gemma3 model to Hugging Face format.

include_lm_head: If True, exports for CausalLM (with "model." prefix).
If False, exports for backbone only (without prefix).
"""

weights_dict = {}

# For CausalLM export, use "model." prefix
# For backbone export, use no prefix
prefix = "model." if include_lm_head else ""

# Token embeddings - use .weights[0] to get backend tensor
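    # The embedding kernel is (vocab_size, hidden_dim), which already matches
    # HF's `embed_tokens.weight`, so no transpose is needed.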
token_embedding_layer = backbone.get_layer("token_embedding")
token_embedding = token_embedding_layer.weights[0]
weights_dict[f"{prefix}embed_tokens.weight"] = token_embedding

for i in range(backbone.num_layers):
block = backbone.get_layer(f"decoder_block_{i}")

# Attention query projection
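        # Keras stores the q/k/v kernels as (num_heads, hidden_dim, head_dim);
        # HF's `nn.Linear` weights are (out_features, in_features), i.e.
        # (num_heads * head_dim, hidden_dim), hence the permute/reshape/T.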
q_kernel = block.attention.query_dense.weights[0]
q_kernel = ops.transpose(q_kernel, axes=(1, 0, 2)) # permute(1, 0, 2)
q_kernel = ops.reshape(q_kernel, (backbone.hidden_dim, -1))
q_kernel = ops.transpose(q_kernel) # .T
weights_dict[f"{prefix}layers.{i}.self_attn.q_proj.weight"] = q_kernel

# Attention key projection
k_kernel = block.attention.key_dense.weights[0]
k_kernel = ops.transpose(k_kernel, axes=(1, 0, 2)) # permute(1, 0, 2)
k_kernel = ops.reshape(k_kernel, (backbone.hidden_dim, -1))
k_kernel = ops.transpose(k_kernel) # .T
weights_dict[f"{prefix}layers.{i}.self_attn.k_proj.weight"] = k_kernel

# Attention value projection
v_kernel = block.attention.value_dense.weights[0]
v_kernel = ops.transpose(v_kernel, axes=(1, 0, 2)) # permute(1, 0, 2)
v_kernel = ops.reshape(v_kernel, (backbone.hidden_dim, -1))
v_kernel = ops.transpose(v_kernel) # .T
weights_dict[f"{prefix}layers.{i}.self_attn.v_proj.weight"] = v_kernel

# Attention output projection
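        # The output kernel is (num_heads, head_dim, hidden_dim); HF's o_proj
        # weight is (hidden_dim, num_heads * head_dim), so only a permute and
        # reshape are needed (no final transpose).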
o_kernel = block.attention.output_dense.weights[0]
o_kernel = ops.transpose(o_kernel, axes=(2, 0, 1)) # permute(2, 0, 1)
o_kernel = ops.reshape(o_kernel, (backbone.hidden_dim, -1))
weights_dict[f"{prefix}layers.{i}.self_attn.o_proj.weight"] = o_kernel

# Query and key normalization
q_norm = block.attention.query_norm.weights[0]
weights_dict[f"{prefix}layers.{i}.self_attn.q_norm.weight"] = q_norm

k_norm = block.attention.key_norm.weights[0]
weights_dict[f"{prefix}layers.{i}.self_attn.k_norm.weight"] = k_norm

# MLP gate projection
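        # The MLP kernels are 2-D, stored as (in_features, out_features);
        # HF's `nn.Linear` expects (out_features, in_features), so a simple
        # transpose suffices for gate/up/down projections.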
gate_kernel = block.gating_ffw.weights[0]
gate_kernel = ops.transpose(gate_kernel) # .T
weights_dict[f"{prefix}layers.{i}.mlp.gate_proj.weight"] = gate_kernel

# MLP up projection
up_kernel = block.gating_ffw_2.weights[0]
up_kernel = ops.transpose(up_kernel) # .T
weights_dict[f"{prefix}layers.{i}.mlp.up_proj.weight"] = up_kernel

# MLP down projection
down_kernel = block.ffw_linear.weights[0]
down_kernel = ops.transpose(down_kernel) # .T
weights_dict[f"{prefix}layers.{i}.mlp.down_proj.weight"] = down_kernel

# Pre-attention normalization
input_layer_norm = block.pre_attention_norm.weights[0]
weights_dict[f"{prefix}layers.{i}.input_layernorm.weight"] = (
input_layer_norm
)

# Post-attention normalization
if hasattr(block, "post_attention_norm"):
post_attn_norm = block.post_attention_norm.weights[0]
else:
# Fallback to pre_ffw_norm if post_attention_norm doesn't exist
post_attn_norm = block.pre_ffw_norm.weights[0]
weights_dict[f"{prefix}layers.{i}.post_attention_layernorm.weight"] = (
post_attn_norm
)

# Pre-feedforward normalization
pre_feedforward_layernorm = block.pre_ffw_norm.weights[0]
weights_dict[f"{prefix}layers.{i}.pre_feedforward_layernorm.weight"] = (
pre_feedforward_layernorm
)

# Post-feedforward normalization
if hasattr(block, "post_ffw_norm"):
post_feedforward_layernorm = block.post_ffw_norm.weights[0]
else:
# Fallback to pre_ffw_norm if post_ffw_norm doesn't exist
post_feedforward_layernorm = block.pre_ffw_norm.weights[0]
weights_dict[
f"{prefix}layers.{i}.post_feedforward_layernorm.weight"
] = post_feedforward_layernorm

# Final normalization
final_norm = backbone.get_layer("final_normalization").weights[0]
weights_dict[f"{prefix}norm.weight"] = final_norm

if include_lm_head and not token_embedding_layer.tie_weights:
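        # `reverse_embeddings` is stored as (hidden_dim, vocab_size); HF's
        # `lm_head.weight` is (vocab_size, hidden_dim), hence the transpose.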
weights_dict["lm_head.weight"] = ops.transpose(
token_embedding_layer.reverse_embeddings
)

return weights_dict


def get_gemma3_tokenizer_config(tokenizer):
    """Convert a Keras Gemma3 tokenizer to a Hugging Face tokenizer config."""
    tokenizer_config = {
"tokenizer_class": "GemmaTokenizer",
"clean_up_tokenization_spaces": False,
"bos_token": "<bos>",
"eos_token": "<eos>",
"pad_token": "<pad>",
"unk_token": "<unk>",
"add_bos_token": True,
"add_eos_token": False,
"model_max_length": 32768,
}
    # Populate `added_tokens_decoder`, keyed by token id (as a string).
added_tokens_decoder = {}
special_tokens = [
"<pad>",
"<bos>",
"<eos>",
"<unk>",
"<start_of_image>",
"<end_of_image>",
"<img>",
]
for token in special_tokens:
token_id = tokenizer.token_to_id(token)
if token_id is not None:
added_tokens_decoder[str(token_id)] = {
"content": token,
"special": True,
"single_word": False,
"lstrip": False,
"rstrip": False,
"normalized": False,
}
tokenizer_config["added_tokens_decoder"] = added_tokens_decoder
return tokenizer_config
161 changes: 161 additions & 0 deletions keras_hub/src/utils/transformers/export/gemma3_test.py
@@ -0,0 +1,161 @@
import os

import numpy as np
from transformers import AutoModel
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM
from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import (
Gemma3CausalLMPreprocessor,
)
from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer
from keras_hub.src.tests.test_case import TestCase


class TestGemma3Export(TestCase):
def test_export_to_hf(self):
proto = os.path.join(self.get_test_data_dir(), "gemma3_test_vocab.spm")
tokenizer = Gemma3Tokenizer(proto=proto)

# Create a small backbone (text-only, no vision encoder)
backbone = Gemma3Backbone(
vocabulary_size=tokenizer.vocabulary_size(),
image_size=896, # Default value even for text-only
num_layers=2,
num_query_heads=4,
num_key_value_heads=1,
hidden_dim=512,
intermediate_dim=1028,
head_dim=128,
query_head_dim_normalize=True,
use_query_key_norm=True,
use_post_ffw_norm=True, # Real Gemma3 models have these
use_post_attention_norm=True, # Real Gemma3 models have these
attention_logit_soft_cap=None,
final_logit_soft_cap=None,
use_sliding_window_attention=False,
sliding_window_size=4096,
vision_encoder=None, # Text-only model for testing
layer_norm_epsilon=1e-6,
dropout=0,
)

# Create preprocessor
preprocessor = Gemma3CausalLMPreprocessor(tokenizer=tokenizer)

# Create the causal LM model
keras_model = Gemma3CausalLM(
backbone=backbone, preprocessor=preprocessor
)

# Set all weights to random values
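        # (seeded so the Keras vs. HF generation comparison below is
        # reproducible)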
rng = np.random.default_rng(42)
weights = keras_model.get_weights()
for i in range(len(weights)):
weights[i] = rng.random(weights[i].shape).astype(weights[i].dtype)
keras_model.set_weights(weights)

# Export to Hugging Face format using the new methods
export_path_backbone = os.path.join(
self.get_temp_dir(), "export_backbone"
)
backbone.export_to_transformers(export_path_backbone)

export_path_tokenizer = os.path.join(
self.get_temp_dir(), "export_tokenizer"
)
preprocessor.tokenizer.export_to_transformers(export_path_tokenizer)

export_path_task = os.path.join(self.get_temp_dir(), "export_task")
keras_model.export_to_transformers(export_path_task)

# Load Hugging Face models and tokenizer
# Note: We only test the slow tokenizer because the test vocab file
# may not be compatible with fast tokenizer conversion
hf_backbone = AutoModel.from_pretrained(export_path_backbone)
hf_tokenizer_slow = AutoTokenizer.from_pretrained(
export_path_tokenizer, use_fast=False
)
hf_full_model = AutoModelForCausalLM.from_pretrained(export_path_task)

# Verify configuration
hf_config = hf_backbone.config
self.assertEqual(
hf_config.vocab_size,
backbone.vocabulary_size,
"Vocabulary sizes do not match",
)
self.assertEqual(
hf_config.num_hidden_layers,
backbone.num_layers,
"Number of layers do not match",
)
self.assertEqual(
hf_config.num_attention_heads,
backbone.num_query_heads,
"Number of query heads do not match",
)
self.assertEqual(
hf_config.num_key_value_heads,
backbone.num_key_value_heads,
"Number of key value heads do not match",
)
self.assertEqual(
hf_config.hidden_size,
backbone.hidden_dim,
"Hidden dimensions do not match",
)
self.assertEqual(
hf_config.intermediate_size,
backbone.intermediate_dim,
"Intermediate sizes do not match",
)
self.assertEqual(
hf_config.head_dim,
backbone.head_dim,
"Head dimensions do not match",
)
self.assertEqual(
hf_config.max_position_embeddings,
32768,
"Max position embeddings do not match",
)
self.assertEqual(
hf_config.tie_word_embeddings,
backbone.token_embedding.tie_weights,
"Tie word embeddings do not match",
)

# Verify tokenizer compatibility (using slow tokenizer)
self.assertEqual(
hf_tokenizer_slow.vocab_size,
tokenizer.vocabulary_size(),
"Tokenizer vocabulary sizes do not match",
)

# Compare generated outputs using full model
prompt = "the quick"

# Generate with Keras model
keras_output = keras_model.generate(prompt, max_length=20)

# Generate with HuggingFace model using slow tokenizer
input_ids_slow = hf_tokenizer_slow.encode(prompt, return_tensors="pt")
output_ids_slow = hf_full_model.generate(
input_ids_slow, max_length=20, do_sample=False
)
hf_slow_output = hf_tokenizer_slow.decode(
output_ids_slow[0], skip_special_tokens=True
)

# Debug print to see the actual outputs
print(f"Keras output: '{keras_output}'")
print(f"HF slow output: '{hf_slow_output}'")

self.assertEqual(
keras_output,
hf_slow_output,
"Generated outputs do not match",
)
10 changes: 10 additions & 0 deletions keras_hub/src/utils/transformers/export/hf_exporter.py
@@ -10,19 +10,29 @@
get_gemma_tokenizer_config,
)
from keras_hub.src.utils.transformers.export.gemma import get_gemma_weights_map
from keras_hub.src.utils.transformers.export.gemma3 import get_gemma3_config
from keras_hub.src.utils.transformers.export.gemma3 import (
get_gemma3_tokenizer_config,
)
from keras_hub.src.utils.transformers.export.gemma3 import (
get_gemma3_weights_map,
)

MODEL_CONFIGS = {
"GemmaBackbone": get_gemma_config,
"Gemma3Backbone": get_gemma3_config,
# Add for future models, e.g., "MistralBackbone": get_mistral_config
}

MODEL_EXPORTERS = {
"GemmaBackbone": get_gemma_weights_map,
"Gemma3Backbone": get_gemma3_weights_map,
# Add for future models, e.g., "MistralBackbone": get_mistral_weights_map
}

MODEL_TOKENIZER_CONFIGS = {
"GemmaTokenizer": get_gemma_tokenizer_config,
"Gemma3Tokenizer": get_gemma3_tokenizer_config,
# Add for future models, e.g., "MistralTokenizer":
# get_mistral_tokenizer_config
}