[SmolLM3] Add Backbone, CausalLM + Converter for HuggingFace Weights (#2327)
* add first few utils
* add eager attention forward
* Add SmolLM3Attention
* Add SmolLM3MLP
* Add SmolLM3DecoderLayer
* remove unnecessary comments
* Add SmolLM3RotaryEmbedding
* add most of smollm3backbone
* Fix calls within causal model
* Move causal mask computation to forward call
* Fix rope and caching indexing
* Remove unnecessary trimming of cache padding
* Remove type hints, expand docstrings
* Add basic tests
* Run linter
1 parent dd7dc05 commit 165c023

12 files changed, +1928 -0 lines changed

keras_hub/api/models/__init__.py

Lines changed: 24 additions & 0 deletions

@@ -649,6 +649,30 @@
 from keras_hub.src.models.siglip.siglip_vision_encoder import (
     SigLIPVisionEncoder as SigLIPVisionEncoder,
 )
+from keras_hub.src.models.smollm3.smollm3_backbone import (
+    SmolLM3Backbone as SmolLM3Backbone,
+)
+from keras_hub.src.models.smollm3.smollm3_backbone import (
+    SmolLM3Backbone as SmolLMBackbone,
+)
+from keras_hub.src.models.smollm3.smollm3_causal_lm import (
+    SmolLM3CausalLM as SmolLM3CausalLM,
+)
+from keras_hub.src.models.smollm3.smollm3_causal_lm import (
+    SmolLM3CausalLM as SmolLMCausalLM,
+)
+from keras_hub.src.models.smollm3.smollm3_causal_lm_preprocessor import (
+    SmolLM3CausalLMPreprocessor as SmolLM3CausalLMPreprocessor,
+)
+from keras_hub.src.models.smollm3.smollm3_causal_lm_preprocessor import (
+    SmolLM3CausalLMPreprocessor as SmolLMCausalLMPreprocessor,
+)
+from keras_hub.src.models.smollm3.smollm3_tokenizer import (
+    SmolLM3Tokenizer as SmolLM3Tokenizer,
+)
+from keras_hub.src.models.smollm3.smollm3_tokenizer import (
+    SmolLM3Tokenizer as SmolLMTokenizer,
+)
 from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
     StableDiffusion3Backbone as StableDiffusion3Backbone,
 )
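The two imports per symbol above deliberately re-export each SmolLM3 class under a second public name without the "3". A minimal sketch of the effect, assuming the usual KerasHub layout where `keras_hub/api/models/__init__.py` backs the public `keras_hub.models` namespace:

```python
import keras_hub

# The duplicated imports register each SmolLM3 symbol under a second public
# name, so both spellings resolve to the same class object.
assert keras_hub.models.SmolLMBackbone is keras_hub.models.SmolLM3Backbone
assert keras_hub.models.SmolLMCausalLM is keras_hub.models.SmolLM3CausalLM
assert (
    keras_hub.models.SmolLMCausalLMPreprocessor
    is keras_hub.models.SmolLM3CausalLMPreprocessor
)
assert keras_hub.models.SmolLMTokenizer is keras_hub.models.SmolLM3Tokenizer
```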

keras_hub/api/tokenizers/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -93,6 +93,12 @@
 from keras_hub.src.models.siglip.siglip_tokenizer import (
     SigLIPTokenizer as SigLIPTokenizer,
 )
+from keras_hub.src.models.smollm3.smollm3_tokenizer import (
+    SmolLM3Tokenizer as SmolLM3Tokenizer,
+)
+from keras_hub.src.models.smollm3.smollm3_tokenizer import (
+    SmolLM3Tokenizer as SmolLMTokenizer,
+)
 from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer
 from keras_hub.src.models.t5gemma.t5gemma_tokenizer import (
     T5GemmaTokenizer as T5GemmaTokenizer,
 )
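The tokenizer gets the same alias treatment. Below is a hedged usage sketch; the `hf://HuggingFaceTB/SmolLM3-3B` handle mirrors the backbone docstring further down, and it is an assumption here that the HuggingFace converter added in this commit also maps the tokenizer assets:

```python
import keras_hub

# SmolLMTokenizer is a convenience alias for SmolLM3Tokenizer.
assert keras_hub.tokenizers.SmolLMTokenizer is keras_hub.tokenizers.SmolLM3Tokenizer

# Hypothetical usage: load the tokenizer straight from the Hugging Face Hub
# handle used in the backbone docstring (assumes the converter in this commit
# also covers the tokenizer vocabulary).
tokenizer = keras_hub.tokenizers.SmolLM3Tokenizer.from_preset(
    "hf://HuggingFaceTB/SmolLM3-3B"
)
token_ids = tokenizer("Hello from SmolLM3!")
```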
keras_hub/src/models/smollm3/smollm3_backbone.py (new file)

Lines changed: 211 additions & 0 deletions

import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.reversible_embedding import (
    ReversibleEmbedding,
)
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer


@keras_hub_export(
    [
        "keras_hub.models.SmolLM3Backbone",
        "keras_hub.models.SmolLMBackbone",
    ]
)
class SmolLM3Backbone(Backbone):
    """SmolLM3 core network with hyperparameters.

    This network implements a Transformer-based decoder network,
    SmolLM3, as described in the SmolLM3 model architecture.
    It includes the embedding lookups and transformer layers.

    The default constructor gives a fully customizable, randomly initialized
    SmolLM3 model with any number of layers, heads, and embedding
    dimensions. To load preset architectures and weights, use the `from_preset`
    constructor.

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        hidden_dim: int. The size of the transformer hidden state at the end
            of each transformer layer.
        intermediate_dim: int. The output dimension of the first Dense layer in
            the MLP network of each transformer layer.
        num_layers: int. The number of transformer layers.
        num_attention_heads: int. The number of attention heads for each
            transformer layer.
        num_key_value_heads: int. The number of key-value heads for grouped
            query attention in each transformer layer.
        attention_bias: bool. Whether to use bias in the query, key, value, and
            output projection layers in the attention blocks.
        attention_dropout: float. Dropout probability for the attention layers.
        rope_layer_enabled_list: list of bool. List indicating whether RoPE
            (Rotary Position Embedding) is enabled for each layer. Typically,
            some layers may disable RoPE for architectural variations.
        layer_types: list of str. List of layer types for each transformer
            layer (e.g., "attention" or other custom types).
        mlp_bias: bool. Whether to use bias in the MLP (feedforward) layers.
        layer_norm_epsilon: float. Epsilon value for layer normalization layers
            to prevent division by zero.
        max_position_embeddings: int. The maximum sequence length that this
            model might ever be used with.
        rope_theta: float. The base period of the RoPE embeddings.
        partial_rotary_factor: float. The percentage of hidden dimensions to
            rotate in RoPE. A value of 1.0 rotates all dimensions, while values
            less than 1.0 only rotate a subset.

    Examples:

    ```python
    input_data = {
        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
    }

    # Pretrained SmolLM3 decoder.
    model = keras_hub.models.SmolLM3Backbone.from_preset(
        "hf://HuggingFaceTB/SmolLM3-3B"
    )
    model(input_data)

    # Randomly initialized SmolLM3 decoder with custom config.
    model = keras_hub.models.SmolLM3Backbone(
        vocabulary_size=49152,
        hidden_dim=576,
        intermediate_dim=1536,
        num_layers=30,
        num_attention_heads=9,
        num_key_value_heads=3,
        attention_bias=False,
        attention_dropout=0.0,
        rope_layer_enabled_list=[True] * 30,
        layer_types=["attention"] * 30,
        mlp_bias=False,
        layer_norm_epsilon=1e-5,
        max_position_embeddings=2048,
        rope_theta=10000.0,
        partial_rotary_factor=1.0,
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        hidden_dim,
        intermediate_dim,
        num_layers,
        num_attention_heads,
        num_key_value_heads,
        attention_bias,
        attention_dropout,
        rope_layer_enabled_list,
        layer_types,
        mlp_bias,
        layer_norm_epsilon,
        max_position_embeddings,
        rope_theta,
        partial_rotary_factor,
        **kwargs,
    ):
        # === Layers ===
        self.token_embedding = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            name="token_embedding",
        )
        self.transformer_layers = []
        for i in range(num_layers):
            layer = SmolLM3DecoderLayer(
                hidden_size=hidden_dim,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                rope_layer_enabled_list=rope_layer_enabled_list,
                layer_types=layer_types,
                layer_idx=i,
                intermediate_size=intermediate_dim,
                mlp_bias=mlp_bias,
                layer_norm_epsilon=layer_norm_epsilon,
                max_position_embeddings=max_position_embeddings,
                rope_theta=rope_theta,
                partial_rotary_factor=partial_rotary_factor,
                name=f"transformer_layer_{i}",
            )
            self.transformer_layers.append(layer)

        self.norm = keras.layers.RMSNormalization(
            epsilon=layer_norm_epsilon,
            name="sequence_output_layernorm",
        )

        # === Functional Model ===
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )

        padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )

        x = self.token_embedding(token_id_input)

        for decoder_layer in self.transformer_layers:
            x = decoder_layer(
                x,
                decoder_padding_mask=padding_mask_input,
                **kwargs,
            )

        sequence_output = self.norm(x)
        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "padding_mask": padding_mask_input,
            },
            outputs=sequence_output,
            **kwargs,
        )

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.rope_layer_enabled_list = rope_layer_enabled_list
        self.layer_types = layer_types
        self.mlp_bias = mlp_bias
        self.layer_norm_epsilon = layer_norm_epsilon
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.partial_rotary_factor = partial_rotary_factor

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "num_layers": self.num_layers,
                "num_attention_heads": self.num_attention_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "attention_bias": self.attention_bias,
                "attention_dropout": self.attention_dropout,
                "rope_layer_enabled_list": self.rope_layer_enabled_list,
                "layer_types": self.layer_types,
                "mlp_bias": self.mlp_bias,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "max_position_embeddings": self.max_position_embeddings,
                "rope_theta": self.rope_theta,
                "partial_rotary_factor": self.partial_rotary_factor,
            }
        )
        return config
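The backbone stops at the final RMS norm and returns hidden states, not logits. Because the embedding is a `ReversibleEmbedding`, those states can be projected back onto the vocabulary with `reverse=True`, the usual KerasHub pattern for a weight-tied LM head; whether `SmolLM3CausalLM` from this commit ties its head exactly this way is an assumption here. A minimal sketch:

```python
import numpy as np

import keras_hub

# Load the converted HuggingFace weights, as in the docstring above.
backbone = keras_hub.models.SmolLM3Backbone.from_preset(
    "hf://HuggingFaceTB/SmolLM3-3B"
)
input_data = {
    "token_ids": np.ones((1, 12), dtype="int32"),
    "padding_mask": np.ones((1, 12), dtype="int32"),
}
# The backbone returns final hidden states of shape (1, 12, hidden_dim).
hidden_states = backbone(input_data)
# ReversibleEmbedding can run in reverse to map hidden states to vocabulary
# logits of shape (1, 12, vocabulary_size), i.e. a tied LM head.
logits = backbone.token_embedding(hidden_states, reverse=True)
```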
keras_hub/src/models/smollm3/smollm3_backbone_test.py (new file)

Lines changed: 53 additions & 0 deletions

import pytest
from keras import ops

from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone
from keras_hub.src.tests.test_case import TestCase


class SmolLM3BackboneTest(TestCase):
    def setUp(self):
        self.init_kwargs = {
            "vocabulary_size": 100,
            "hidden_dim": 64,
            "intermediate_dim": 128,
            "num_layers": 2,
            "num_attention_heads": 4,
            "num_key_value_heads": 2,
            "attention_bias": False,
            "attention_dropout": 0.0,
            "rope_layer_enabled_list": [True, True],
            "layer_types": ["attention", "attention"],
            "mlp_bias": False,
            "layer_norm_epsilon": 1e-5,
            "max_position_embeddings": 128,
            "rope_theta": 10000.0,
            "partial_rotary_factor": 1.0,
        }
        self.input_data = {
            "token_ids": ops.ones((2, 5), dtype="int32"),
            "padding_mask": ops.ones((2, 5), dtype="int32"),
        }

    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=SmolLM3Backbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(2, 5, 64),
            run_mixed_precision_check=False,
            run_quantization_check=False,
        )

    @pytest.mark.large
    def test_saved_model(self):
        self.run_model_saving_test(
            cls=SmolLM3Backbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )

    def test_num_parameters(self):
        model = SmolLM3Backbone(**self.init_kwargs)
        # Reference value calculated from the model architecture
        self.assertEqual(model.count_params(), 80464)
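For reference, a minimal sketch of what the basics test exercises, done by hand with the same tiny configuration: a forward pass producing `(batch, sequence_length, hidden_dim)` hidden states, and a `get_config` round trip that rebuilds an equivalent architecture:

```python
from keras import ops

from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone

# Tiny configuration copied from the test above.
init_kwargs = {
    "vocabulary_size": 100,
    "hidden_dim": 64,
    "intermediate_dim": 128,
    "num_layers": 2,
    "num_attention_heads": 4,
    "num_key_value_heads": 2,
    "attention_bias": False,
    "attention_dropout": 0.0,
    "rope_layer_enabled_list": [True, True],
    "layer_types": ["attention", "attention"],
    "mlp_bias": False,
    "layer_norm_epsilon": 1e-5,
    "max_position_embeddings": 128,
    "rope_theta": 10000.0,
    "partial_rotary_factor": 1.0,
}
model = SmolLM3Backbone(**init_kwargs)

# Forward pass: two sequences of length 5 in, hidden states of size 64 out.
outputs = model(
    {
        "token_ids": ops.ones((2, 5), dtype="int32"),
        "padding_mask": ops.ones((2, 5), dtype="int32"),
    }
)
print(outputs.shape)  # (2, 5, 64)

# get_config() records every constructor argument, so a config round trip
# rebuilds an architecturally identical (freshly initialized) backbone.
clone = SmolLM3Backbone.from_config(model.get_config())
assert clone.count_params() == model.count_params()
```

The `test_saved_model` case is marked `@pytest.mark.large`, so it is intended to run only when large tests are enabled.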
