16 changes: 11 additions & 5 deletions keras_hub/src/models/gemma/gemma_backbone.py
@@ -10,6 +10,16 @@
 from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
 
 
+def _gemma_embedding_initializer():
+    from keras_hub.src.utils import dist_initializer
+
+    return dist_initializer.DistributedVarianceScaling(
+        scale=1.0,
+        mode="fan_in",
+        distribution="untruncated_normal",
+    )
+
+
 @keras_hub_export("keras_hub.models.GemmaBackbone")
 class GemmaBackbone(Backbone):
     """Gemma core network with hyperparameters.
@@ -110,11 +120,7 @@ def __init__(
             input_dim=vocabulary_size,
             output_dim=hidden_dim,
             tie_weights=True,
-            embeddings_initializer=keras.initializers.VarianceScaling(
-                scale=1.0,
-                mode="fan_in",
-                distribution="untruncated_normal",
-            ),
+            embeddings_initializer=_gemma_embedding_initializer(),
             dtype=dtype,
             logit_soft_cap=final_logit_soft_cap,
             name="token_embedding",
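For reviewers who want to exercise the new initializer: it only matters when the backbone is built under a Keras distribution, since that is when the embedding table would otherwise be fully materialized before sharding. A minimal usage sketch, not part of this diff; `DeviceMesh`, `ModelParallel`, and `GemmaBackbone.get_layout_map` are existing public APIs in recent Keras 3 / KerasHub, and the preset name is just an example:

```python
import keras
from keras_hub.models import GemmaBackbone

# Build a 1 x N mesh over all available accelerators.
devices = keras.distribution.list_devices()
mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)), axis_names=("batch", "model"), devices=devices
)

# Per-weight sharding rules published by the backbone itself.
layout_map = GemmaBackbone.get_layout_map(mesh)
keras.distribution.set_distribution(
    keras.distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")
)

# With DistributedVarianceScaling, the (vocab_size, hidden_dim) embedding
# table is initialized shard-by-shard under this distribution rather than
# being created in full on a single device first.
backbone = GemmaBackbone.from_preset("gemma_2b_en")
```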
30 changes: 25 additions & 5 deletions keras_hub/src/models/gemma3/gemma3_backbone.py
@@ -13,6 +13,26 @@
 )
 
 
+# Note: For LoRA or quantization, apply after model loading using
+# model.backbone.enable_lora() or model.quantize() respectively.
+# Distributed initialization targets the initial text-only model load.
+def _gemma3_embedding_initializer(text_only_model):
+    if text_only_model:
+        from keras_hub.src.utils import dist_initializer
+
+        return dist_initializer.DistributedVarianceScaling(
+            scale=1.0,
+            mode="fan_in",
+            distribution="untruncated_normal",
+        )
+    else:
+        return keras.initializers.VarianceScaling(
+            scale=1.0,
+            mode="fan_in",
+            distribution="untruncated_normal",
+        )
+
+
 @keras_hub_export("keras_hub.models.Gemma3Backbone")
 class Gemma3Backbone(Backbone):
     """Gemma3 core network with hyperparameters.
@@ -202,22 +222,22 @@ def __init__(
         **kwargs,
     ):
         # === Layers ===
+        text_only_model = vision_encoder is None
+
         self.token_embedding = ReversibleEmbedding(
             input_dim=vocabulary_size,
             output_dim=hidden_dim,
             tie_weights=True,
-            embeddings_initializer=keras.initializers.VarianceScaling(
-                scale=1.0,
-                mode="fan_in",
-                distribution="untruncated_normal",
+            embeddings_initializer=_gemma3_embedding_initializer(
+                text_only_model
             ),
             dtype=dtype,
             logit_soft_cap=final_logit_soft_cap,
             name="token_embedding",
         )
 
         self.vision_encoder = vision_encoder
-        text_only_model = True if vision_encoder is None else False
 
         if not text_only_model:
             self.interleave_embeddings = Gemma3InterleaveEmbeddings(
                 num_vision_tokens_per_image=self.vision_encoder.num_vision_tokens_per_image,
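The `text_only_model` assignment is hoisted above the embedding construction because the initializer choice now depends on it. Both branches use the same `scale=1.0`, `fan_in`, untruncated-normal configuration, so the initialization statistics are unchanged; only the placement strategy differs. A quick sanity check of that claim, assuming `DistributedVarianceScaling` matches `VarianceScaling`'s call interface and statistics when no distribution scope is active:

```python
import keras
import numpy as np

from keras_hub.src.models.gemma3.gemma3_backbone import (
    _gemma3_embedding_initializer,
)

# Text-only path returns the shard-aware initializer; the multimodal path
# returns plain keras.initializers.VarianceScaling.
dist_init = _gemma3_embedding_initializer(text_only_model=True)
plain_init = _gemma3_embedding_initializer(text_only_model=False)

a = keras.ops.convert_to_numpy(dist_init(shape=(4096, 256)))
b = keras.ops.convert_to_numpy(plain_init(shape=(4096, 256)))
np.testing.assert_allclose(a.std(), b.std(), rtol=0.1)  # loose statistical check
```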
8 changes: 7 additions & 1 deletion keras_hub/src/models/llama/llama_backbone.py
@@ -14,6 +14,12 @@ def _llama_kernel_initializer(stddev=0.02):
     return keras.initializers.RandomNormal(stddev=stddev)
 
 
+def _llama_embedding_initializer(stddev=0.01):
+    from keras_hub.src.utils import dist_initializer
+
+    return dist_initializer.DistributedRandomNormal(stddev=stddev)
+
+
 @keras_hub_export("keras_hub.models.LlamaBackbone")
 class LlamaBackbone(Backbone):
     """
@@ -111,7 +117,7 @@ def __init__(
             input_dim=vocabulary_size,
             output_dim=hidden_dim,
             tie_weights=tie_word_embeddings,
-            embeddings_initializer=_llama_kernel_initializer(stddev=0.01),
+            embeddings_initializer=_llama_embedding_initializer(stddev=0.01),
             dtype=dtype,
             name="token_embedding",
         )
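Note that `_llama_embedding_initializer` must forward its `stddev` argument rather than hard-code `0.01` (corrected above); otherwise callers passing a different value would silently get the default. A small check of the fix, assuming `DistributedRandomNormal` mirrors `keras.initializers.RandomNormal`'s call interface:

```python
import keras
import numpy as np

from keras_hub.src.models.llama.llama_backbone import (
    _llama_embedding_initializer,
)

init = _llama_embedding_initializer(stddev=0.02)
sample = keras.ops.convert_to_numpy(init(shape=(2048, 512)))
# With ~1M samples, the empirical stddev should land very near 0.02.
np.testing.assert_allclose(sample.std(), 0.02, rtol=0.05)
```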
8 changes: 7 additions & 1 deletion keras_hub/src/models/qwen/qwen_backbone.py
@@ -14,6 +14,12 @@ def _qwen_kernel_initializer(stddev=0.02):
     return keras.initializers.RandomNormal(stddev=stddev)
 
 
+def _qwen_embedding_initializer(stddev=0.01):
+    from keras_hub.src.utils import dist_initializer
+
+    return dist_initializer.DistributedRandomNormal(stddev=stddev)
+
+
 @keras_hub_export(
     [
         "keras_hub.models.QwenBackbone",
@@ -114,7 +120,7 @@ def __init__(
             input_dim=vocabulary_size,
             output_dim=hidden_dim,
             tie_weights=tie_word_embeddings,
-            embeddings_initializer=_qwen_kernel_initializer(stddev=0.01),
+            embeddings_initializer=_qwen_embedding_initializer(stddev=0.01),
             dtype=dtype,
             name="token_embedding",
         )
8 changes: 7 additions & 1 deletion keras_hub/src/models/qwen3/qwen3_backbone.py
@@ -14,6 +14,12 @@ def _qwen3_kernel_initializer(stddev=0.02):
     return keras.initializers.RandomNormal(stddev=stddev)
 
 
+def _qwen3_embedding_initializer(stddev=0.01):
+    from keras_hub.src.utils import dist_initializer
+
+    return dist_initializer.DistributedRandomNormal(stddev=stddev)
+
+
 @keras_hub_export("keras_hub.models.Qwen3Backbone")
 class Qwen3Backbone(Backbone):
     """The Qwen3 Transformer core architecture with hyperparameters.
@@ -105,7 +111,7 @@ def __init__(
             input_dim=vocabulary_size,
             output_dim=hidden_dim,
             tie_weights=tie_word_embeddings,
-            embeddings_initializer=_qwen3_kernel_initializer(stddev=0.01),
+            embeddings_initializer=_qwen3_embedding_initializer(stddev=0.01),
             dtype=dtype,
             name="token_embedding",
         )
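Review note: the Llama, Qwen, and Qwen3 embedding-initializer helpers are identical apart from their names, and each repeats the lazy import of `dist_initializer` (presumably to avoid an import cycle at module load time). A shared factory would remove the triplication; a hypothetical sketch, not part of this diff:

```python
# Hypothetical shared helper, e.g. alongside keras_hub/src/utils/; each
# backbone would then call it instead of defining its own private copy.
def distributed_random_normal_initializer(stddev=0.01):
    # Lazy import, matching the pattern used in the backbones above.
    from keras_hub.src.utils import dist_initializer

    return dist_initializer.DistributedRandomNormal(stddev=stddev)
```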