Skip to content

Commit 0fde3f7

Browse files
committed
Resolves Gemini comments.
1 parent 9550d79 commit 0fde3f7

File tree

4 files changed

+91
-12
lines changed

4 files changed

+91
-12
lines changed

keras_hub/src/models/dinov3/dinov3_backbone.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,48 @@ class DINOV3Backbone(FeaturePyramidBackbone):
6464
for the model's computations and weights. Note that some
6565
computations, such as softmax and layer normalization will always
6666
be done at float32 precision regardless of dtype.
67+
68+
Example:
69+
```python
70+
# Pretrained DINOV3 model.
71+
input_data = {
72+
"pixel_values": np.ones(shape=(1, 518, 518, 3), dtype="float32"),
73+
}
74+
model = keras_hub.models.DINOV3Backbone.from_preset(
75+
"dinov3_vit_small_lvd1689m"
76+
)
77+
model(input_data)
78+
79+
# Pretrained DINOV3 model with custom image shape.
80+
input_data = {
81+
"pixel_values": np.ones(shape=(1, 224, 224, 3), dtype="float32"),
82+
}
83+
model = keras_hub.models.DINOV3Backbone.from_preset(
84+
"dinov3_vit_small_lvd1689m", image_shape=(224, 224, 3)
85+
)
86+
model(input_data)
87+
88+
# Randomly initialized DINOV3 model with custom config.
89+
model = keras_hub.models.DINOV3Backbone(
90+
patch_size=14,
91+
num_layers=2,
92+
hidden_dim=32,
93+
num_heads=2,
94+
intermediate_dim=128,
95+
image_shape=(224, 224, 3),
96+
)
97+
model(input_data)
98+
99+
# Accessing feature pyramid outputs.
100+
backbone = keras_hub.models.DINOV3Backbone.from_preset(
101+
"dinov3_vit_small_lvd1689m", image_shape=(224, 224, 3)
102+
)
103+
model = keras.Model(
104+
inputs=backbone.inputs,
105+
outputs=backbone.pyramid_outputs,
106+
)
107+
features = model(input_data)
108+
```
67109
"""
68110

69111
def __init__(
@@ -141,7 +183,7 @@ def __init__(
141183

142184
# === Functional Model ===
143185
pyramid_outputs = {}
144-
image_input = layers.Input(shape=image_shape, name="images")
186+
image_input = layers.Input(shape=image_shape, name="pixel_values")
145187
x = self.embeddings(image_input)
146188
pyramid_outputs["stem"] = x
147189

@@ -160,7 +202,7 @@ def __init__(
160202
pyramid_outputs[key] = self.layernorm(pyramid_outputs[key])
161203
outputs = x
162204
super().__init__(
163-
inputs={"images": image_input},
205+
inputs={"pixel_values": image_input},
164206
outputs=outputs,
165207
dtype=dtype,
166208
name=name,

keras_hub/src/models/dinov3/dinov3_backbone_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def setUp(self):
2323
"name": "dinov3_backbone",
2424
}
2525
self.input_data = {
26-
"images": ops.ones((2, 64, 64, 3)),
26+
"pixel_values": ops.ones((2, 64, 64, 3)),
2727
}
2828

2929
def test_backbone_basics(self):
@@ -73,7 +73,7 @@ def test_position_embedding_interpolation(self):
7373
image_shape=(128, 128, 3), # From 64 to 128.
7474
)
7575
input_data = {
76-
"images": ops.ones((2, 128, 128, 3)),
76+
"pixel_values": ops.ones((2, 128, 128, 3)),
7777
}
7878
restored_output = restored_model(input_data)
7979
self.assertNotEqual(model_output.shape, restored_output.shape)

keras_hub/src/models/dinov3/dinov3_layers.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,7 @@ def call(
449449
is_causal=False,
450450
)
451451
attn_output = ops.reshape(attn_output, (batch_size, seq_len, -1))
452+
attn_output = self.dropout(attn_output, training=training)
452453
return self.output_dense(attn_output, training=training)
453454

454455
def get_config(self):
@@ -815,6 +816,7 @@ def call(
815816
attention_mask=None,
816817
position_embeddings=None,
817818
num_prefix_tokens=0,
819+
training=None,
818820
):
819821
residual = inputs
820822
hidden_states = self.norm1(inputs)
@@ -823,17 +825,18 @@ def call(
823825
attention_mask=attention_mask,
824826
position_embeddings=position_embeddings,
825827
num_prefix_tokens=num_prefix_tokens,
828+
training=training,
829+
)
830+
hidden_states = self.layer_scale1(hidden_states, training=training)
831+
hidden_states = (
832+
self.drop_path(hidden_states, training=training) + residual
826833
)
827-
hidden_states = self.layer_scale1(hidden_states)
828-
hidden_states = self.drop_path(hidden_states) + residual
829834

830835
residual = hidden_states
831-
hidden_states = self.norm2(hidden_states)
832-
hidden_states = self.mlp(hidden_states)
833-
hidden_states = self.layer_scale2(hidden_states)
834-
hidden_states = self.drop_path(hidden_states) + residual
835-
836-
return hidden_states
836+
hidden_states = self.norm2(hidden_states, training=training)
837+
hidden_states = self.mlp(hidden_states, training=training)
838+
hidden_states = self.layer_scale2(hidden_states, training=training)
839+
return self.drop_path(hidden_states, training=training) + residual
837840

838841
def get_config(self):
839842
config = super().get_config()
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import numpy as np
2+
import pytest
3+
4+
from keras_hub.src.models.dinov3.dinov3_backbone import DINOV3Backbone
5+
from keras_hub.src.tests.test_case import TestCase
6+
7+
8+
class TestTask(TestCase):
9+
@pytest.mark.large
10+
def test_convert_tiny_preset(self):
11+
model = DINOV3Backbone.from_preset(
12+
"hf://facebook/dinov3-vits16-pretrain-lvd1689m",
13+
image_shape=(224, 224, 3),
14+
)
15+
dummy_input = {
16+
"pixel_values": np.ones((1, 224, 224, 3), dtype="float32")
17+
}
18+
output = model.predict(dummy_input)
19+
self.assertAllClose(
20+
output[0, 0, :10],
21+
[
22+
-0.2769,
23+
0.5487,
24+
0.2501,
25+
-1.2269,
26+
0.5886,
27+
0.0762,
28+
0.6251,
29+
0.1874,
30+
-0.4259,
31+
-0.4362,
32+
],
33+
atol=1e-2,
34+
)

0 commit comments

Comments
 (0)