diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index aacee7818e..58afa65934 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -93,6 +93,9 @@ from keras_hub.src.models.dinov2.dinov2_image_converter import ( DINOV2ImageConverter as DINOV2ImageConverter, ) +from keras_hub.src.models.dinov3.dinov3_image_converter import ( + DINOV3ImageConverter as DINOV3ImageConverter, +) from keras_hub.src.models.efficientnet.efficientnet_image_converter import ( EfficientNetImageConverter as EfficientNetImageConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index b90dde2cc7..72eebbf64a 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -184,6 +184,9 @@ from keras_hub.src.models.dinov2.dinov2_backbone import ( DINOV2Backbone as DINOV2Backbone, ) +from keras_hub.src.models.dinov3.dinov3_backbone import ( + DINOV3Backbone as DINOV3Backbone, +) from keras_hub.src.models.distil_bert.distil_bert_backbone import ( DistilBertBackbone as DistilBertBackbone, ) diff --git a/keras_hub/src/models/dinov2/dinov2_layers.py b/keras_hub/src/models/dinov2/dinov2_layers.py index 1124b57a50..ce040ae266 100644 --- a/keras_hub/src/models/dinov2/dinov2_layers.py +++ b/keras_hub/src/models/dinov2/dinov2_layers.py @@ -502,7 +502,9 @@ def call(self, inputs, training=None): def get_config(self): config = super().get_config() - config.update({"hidden_dim": self.hidden_dim}) + config.update( + {"hidden_dim": self.hidden_dim, "init_values": self.init_values} + ) return config def compute_output_shape(self, input_shape): diff --git a/keras_hub/src/models/dinov3/__init__.py b/keras_hub/src/models/dinov3/__init__.py new file mode 100644 index 0000000000..1752b3c838 --- /dev/null +++ b/keras_hub/src/models/dinov3/__init__.py @@ -0,0 +1,5 @@ +from keras_hub.src.models.dinov3.dinov3_backbone import DINOV3Backbone +from keras_hub.src.models.dinov3.dinov3_presets import backbone_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, DINOV3Backbone) diff --git a/keras_hub/src/models/dinov3/dinov3_backbone.py b/keras_hub/src/models/dinov3/dinov3_backbone.py new file mode 100644 index 0000000000..5f52d9a509 --- /dev/null +++ b/keras_hub/src/models/dinov3/dinov3_backbone.py @@ -0,0 +1,263 @@ +from keras import layers + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.dinov3.dinov3_layers import DINOV3Embedding +from keras_hub.src.models.dinov3.dinov3_layers import DINOV3Encoder +from keras_hub.src.models.dinov3.dinov3_layers import ( + DINOV3RopePositionEmbedding, +) +from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone +from keras_hub.src.utils.keras_utils import standardize_data_format + + +@keras_hub_export("keras_hub.models.DINOV3Backbone") +class DINOV3Backbone(FeaturePyramidBackbone): + """DINOV3 core network with hyperparameters. + + Args: + patch_size: int. The size of each square patch in the input image. + num_layers: int. The number of transformer layers. + hidden_dim: int. The size of the transformer hidden state at the end + of each transformer layer. + num_heads: int. The number of attention heads for each transformer. + intermediate_dim: int. The output dimension of the first Dense layer in + a two-layer feedforward network for each transformer. + layer_scale_init_value: float. The initial value for the layer scale in + the transformer layers. Defaults to `1.0`. 
+        num_register_tokens: int. The number of register tokens to use in the
+            embedding layer. Defaults to `4`.
+        use_mask_token: bool. Whether to use a mask token in the embedding
+            layer. Defaults to `True`.
+        hidden_activation: str or callable. Activation to use in the MLP.
+            Defaults to `"gelu"`.
+        use_gated_mlp: bool. Whether to use Gated MLP layers. Defaults to
+            `False`.
+        use_query_bias: bool. Whether to use a bias for the query projection.
+            Defaults to `True`.
+        use_key_bias: bool. Whether to use a bias for the key projection.
+            Defaults to `True`.
+        use_value_bias: bool. Whether to use a bias for the value projection.
+            Defaults to `True`.
+        use_proj_bias: bool. Whether to use a bias for the output projection.
+            Defaults to `True`.
+        use_mlp_bias: bool. Whether to use a bias for the dense layers in MLP.
+            Defaults to `True`.
+        attention_dropout: float. The dropout rate for the attention
+            probabilities. Defaults to `0.0`.
+        drop_path_rate: float. The drop path rate to use. Defaults to `0.0`.
+        layer_norm_eps: float. The epsilon for layer normalization. Defaults to
+            `1e-5`.
+        image_shape: tuple. The input shape without the batch size. Defaults to
+            `(518, 518, 3)`.
+        rope_theta: float. The base period of the rotary position embeddings.
+            Defaults to `100.0`.
+        apply_layernorm: bool. Whether to apply layer normalization to the
+            outputs of each stage in the feature pyramid. Defaults to `False`.
+        data_format: `None` or str. If specified, either `"channels_last"` or
+            `"channels_first"`. The ordering of the dimensions in the
+            inputs. `"channels_last"` corresponds to inputs with shape
+            `(batch_size, height, width, channels)`
+            while `"channels_first"` corresponds to inputs with shape
+            `(batch_size, channels, height, width)`. It defaults to the
+            `image_data_format` value found in your Keras config file at
+            `~/.keras/keras.json`. If you never set it, then it will be
+            `"channels_last"`.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for the model's computations and weights. Note that some
+            computations, such as softmax and layer normalization, will always
+            be done in float32 precision regardless of dtype.
+
+    Example:
+    ```python
+    # Pretrained DINOV3 model.
+    input_data = {
+        "pixel_values": np.ones(shape=(1, 518, 518, 3), dtype="float32"),
+    }
+    model = keras_hub.models.DINOV3Backbone.from_preset(
+        "dinov3_vit_small_lvd1689m"
+    )
+    model(input_data)
+
+    # Pretrained DINOV3 model with custom image shape.
+    input_data = {
+        "pixel_values": np.ones(shape=(1, 224, 224, 3), dtype="float32"),
+    }
+    model = keras_hub.models.DINOV3Backbone.from_preset(
+        "dinov3_vit_small_lvd1689m", image_shape=(224, 224, 3)
+    )
+    model(input_data)
+
+    # Randomly initialized DINOV3 model with custom config.
+    model = keras_hub.models.DINOV3Backbone(
+        patch_size=14,
+        num_layers=2,
+        hidden_dim=32,
+        num_heads=2,
+        intermediate_dim=128,
+        image_shape=(224, 224, 3),
+    )
+    model(input_data)
+
+    # Accessing feature pyramid outputs.
+ backbone = keras_hub.models.DINOV3Backbone.from_preset( + "dinov3_vit_small_lvd1689m", image_shape=(224, 224, 3) + ) + model = keras.Model( + inputs=backbone.inputs, + outputs=backbone.pyramid_outputs, + ) + features = model(input_data) + ``` + """ + + def __init__( + self, + patch_size, + num_layers, + hidden_dim, + num_heads, + intermediate_dim, + layer_scale_init_value=1.0, + num_register_tokens=4, + use_mask_token=True, + hidden_activation="gelu", + use_gated_mlp=False, + use_query_bias=True, + use_key_bias=True, + use_value_bias=True, + use_proj_bias=True, + use_mlp_bias=True, + attention_dropout=0.0, + drop_path_rate=0.0, + layer_norm_eps=1e-5, + image_shape=(518, 518, 3), + rope_theta=100.0, + apply_layernorm=False, + data_format=None, + dtype=None, + name=None, + **kwargs, + ): + data_format = standardize_data_format(data_format) + + prefix = str(name) + "_" if name is not None else "" + + # === Layers === + self.embeddings = DINOV3Embedding( + hidden_dim=hidden_dim, + patch_size=patch_size, + num_register_tokens=num_register_tokens, + use_mask_token=use_mask_token, + data_format=data_format, + dtype=dtype, + name=f"{prefix}embeddings", + ) + self.rope_embedding = DINOV3RopePositionEmbedding( + hidden_dim=hidden_dim, + num_heads=num_heads, + rope_theta=rope_theta, + patch_size=patch_size, + dtype=dtype, + name=f"{prefix}rope_embedding", + ) + self.encoder = DINOV3Encoder( + num_layers=num_layers, + hidden_dim=hidden_dim, + num_heads=num_heads, + intermediate_dim=intermediate_dim, + layer_scale_init_value=layer_scale_init_value, + hidden_activation=hidden_activation, + use_gated_mlp=use_gated_mlp, + use_query_bias=use_query_bias, + use_key_bias=use_key_bias, + use_value_bias=use_value_bias, + use_proj_bias=use_proj_bias, + use_mlp_bias=use_mlp_bias, + attention_dropout=attention_dropout, + drop_path_rate=drop_path_rate, + layer_norm_eps=layer_norm_eps, + dtype=dtype, + name=f"{prefix}encoder", + ) + self.layernorm = layers.LayerNormalization( + epsilon=layer_norm_eps, dtype=dtype, name=f"{prefix}layernorm" + ) + + # === Functional Model === + pyramid_outputs = {} + image_input = layers.Input(shape=image_shape, name="pixel_values") + x = self.embeddings(image_input) + pyramid_outputs["stem"] = x + + position_embeddings = self.rope_embedding(image_input) + num_prefix_tokens = 1 + num_register_tokens + + x, encoder_pyramid_outputs = self.encoder( + x, + position_embeddings=position_embeddings, + num_prefix_tokens=num_prefix_tokens, + ) + pyramid_outputs.update(encoder_pyramid_outputs) + x = self.layernorm(x) + if apply_layernorm: + for key in pyramid_outputs: + pyramid_outputs[key] = self.layernorm(pyramid_outputs[key]) + outputs = x + super().__init__( + inputs={"pixel_values": image_input}, + outputs=outputs, + dtype=dtype, + name=name, + **kwargs, + ) + + # === Config === + self.patch_size = int(patch_size) + self.num_layers = int(num_layers) + self.hidden_dim = int(hidden_dim) + self.num_heads = int(num_heads) + self.intermediate_dim = int(intermediate_dim) + self.layer_scale_init_value = float(layer_scale_init_value) + self.num_register_tokens = int(num_register_tokens) + self.use_mask_token = bool(use_mask_token) + self.hidden_activation = hidden_activation + self.use_gated_mlp = bool(use_gated_mlp) + self.use_query_bias = bool(use_query_bias) + self.use_key_bias = bool(use_key_bias) + self.use_value_bias = bool(use_value_bias) + self.use_proj_bias = bool(use_proj_bias) + self.use_mlp_bias = bool(use_mlp_bias) + self.attention_dropout = float(attention_dropout) + 
self.drop_path_rate = float(drop_path_rate) + self.layer_norm_eps = float(layer_norm_eps) + self.image_shape = image_shape + self.rope_theta = rope_theta + self.apply_layernorm = apply_layernorm + self.pyramid_outputs = pyramid_outputs + + def get_config(self): + config = super().get_config() + config.update( + { + "patch_size": self.patch_size, + "num_layers": self.num_layers, + "hidden_dim": self.hidden_dim, + "num_heads": self.num_heads, + "intermediate_dim": self.intermediate_dim, + "num_register_tokens": self.num_register_tokens, + "use_mask_token": self.use_mask_token, + "layer_scale_init_value": self.layer_scale_init_value, + "hidden_activation": self.hidden_activation, + "use_gated_mlp": self.use_gated_mlp, + "use_query_bias": self.use_query_bias, + "use_key_bias": self.use_key_bias, + "use_value_bias": self.use_value_bias, + "use_proj_bias": self.use_proj_bias, + "use_mlp_bias": self.use_mlp_bias, + "attention_dropout": self.attention_dropout, + "drop_path_rate": self.drop_path_rate, + "layer_norm_eps": self.layer_norm_eps, + "image_shape": self.image_shape, + "rope_theta": self.rope_theta, + "apply_layernorm": self.apply_layernorm, + } + ) + return config diff --git a/keras_hub/src/models/dinov3/dinov3_backbone_test.py b/keras_hub/src/models/dinov3/dinov3_backbone_test.py new file mode 100644 index 0000000000..b8fdd9a0c6 --- /dev/null +++ b/keras_hub/src/models/dinov3/dinov3_backbone_test.py @@ -0,0 +1,101 @@ +import os + +import keras +import pytest +from keras import ops + +from keras_hub.src.models.dinov3.dinov3_backbone import DINOV3Backbone +from keras_hub.src.tests.test_case import TestCase + + +class DINOV3BackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "patch_size": 16, + "num_layers": 2, + "hidden_dim": 16, + "num_heads": 2, + "intermediate_dim": 16 * 4, + "layer_scale_init_value": 1.0, + "num_register_tokens": 4, + "use_gated_mlp": False, + "image_shape": (64, 64, 3), + "name": "dinov3_backbone", + } + self.input_data = { + "pixel_values": ops.ones((2, 64, 64, 3)), + } + + def test_backbone_basics(self): + patch_size = self.init_kwargs["patch_size"] + image_size = self.init_kwargs["image_shape"][0] + hidden_dim = self.init_kwargs["hidden_dim"] + num_register_tokens = self.init_kwargs["num_register_tokens"] + sequence_length = ( + (image_size // patch_size) ** 2 + 1 + num_register_tokens + ) + self.run_vision_backbone_test( + cls=DINOV3Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, sequence_length, hidden_dim), + expected_pyramid_output_keys=["stem", "stage1", "stage2"], + expected_pyramid_image_sizes=[(sequence_length, hidden_dim)] * 3, + run_data_format_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=DINOV3Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.large + def test_position_embedding_interpolation(self): + model = DINOV3Backbone(**self.init_kwargs) + model_output = model(self.input_data) + + # Test not using interpolation in `save` and `load_model`. + path = os.path.join(self.get_temp_dir(), "model.keras") + model.save(path) + restored_model = keras.models.load_model(path) + restored_output = restored_model(self.input_data) + self.assertAllClose(model_output, restored_output, atol=1e-5, rtol=1e-5) + + # Test using interpolation in `save_to_preset` and `from_preset` if + # image_shape is different. 
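+        # Note: DINOV3 has no learned position-embedding table to resize; the
+        # RoPE coordinates in `DINOV3RopePositionEmbedding` are recomputed from
+        # the input size at call time, so restoring with a different
+        # `image_shape` only changes the patch grid.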
+ path = os.path.join(self.get_temp_dir(), "model") + model.save_to_preset(path) + restored_model = DINOV3Backbone.from_preset( + path, + image_shape=(128, 128, 3), # From 64 to 128. + ) + input_data = { + "pixel_values": ops.ones((2, 128, 128, 3)), + } + restored_output = restored_model(input_data) + self.assertNotEqual(model_output.shape, restored_output.shape) + + @pytest.mark.kaggle_key_required + @pytest.mark.extra_large + def test_smallest_preset(self): + self.skipTest("Presets are not uploaded yet.") + self.run_preset_test( + cls=DINOV3Backbone, + preset="dinov3_vit_small_lvd1689m", + input_data=self.input_data, + expected_output_shape=(2, 1374, 768), + ) + + @pytest.mark.kaggle_key_required + @pytest.mark.extra_large + def test_all_presets(self): + self.skipTest("Presets are not uploaded yet.") + for preset in DINOV3Backbone.presets: + self.run_preset_test( + cls=DINOV3Backbone, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/dinov3/dinov3_image_converter.py b/keras_hub/src/models/dinov3/dinov3_image_converter.py new file mode 100644 index 0000000000..54b08eacf3 --- /dev/null +++ b/keras_hub/src/models/dinov3/dinov3_image_converter.py @@ -0,0 +1,8 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.dinov3.dinov3_backbone import DINOV3Backbone + + +@keras_hub_export("keras_hub.layers.DINOV3ImageConverter") +class DINOV3ImageConverter(ImageConverter): + backbone_cls = DINOV3Backbone diff --git a/keras_hub/src/models/dinov3/dinov3_layers.py b/keras_hub/src/models/dinov3/dinov3_layers.py new file mode 100644 index 0000000000..edc46006a8 --- /dev/null +++ b/keras_hub/src/models/dinov3/dinov3_layers.py @@ -0,0 +1,1013 @@ +import math + +from keras import initializers +from keras import layers +from keras import ops +from keras import random + +from keras_hub.src.utils.keras_utils import standardize_data_format + + +class DINOV3PatchEmbedding(layers.Layer): + """A layer that converts images into patches. + + Args: + hidden_dim: int. The number of units in the hidden layers. + patch_size: int. The size of one side of each patch. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. 
+ """ + + def __init__(self, hidden_dim, patch_size, data_format=None, **kwargs): + super().__init__(**kwargs) + self.hidden_dim = int(hidden_dim) + self.patch_size = int(patch_size) + self.data_format = standardize_data_format(data_format) + + self.projection = layers.Conv2D( + hidden_dim, + kernel_size=patch_size, + strides=patch_size, + data_format=data_format, + kernel_initializer=initializers.TruncatedNormal(stddev=0.02), + dtype=self.dtype_policy, + name="projection", + ) + + def build(self, input_shape): + self.projection.build(input_shape) + + def call(self, inputs, training=None): + batch_size = ops.shape(inputs)[0] + embeddings = self.projection(inputs, training=training) + if self.data_format == "channels_last": + embeddings = ops.reshape( + embeddings, (batch_size, -1, self.hidden_dim) + ) + else: + embeddings = ops.reshape( + embeddings, (batch_size, self.hidden_dim, -1) + ) + embeddings = ops.transpose(embeddings, (0, 2, 1)) + return embeddings + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "patch_size": self.patch_size, + } + ) + return config + + def compute_output_shape(self, input_shape): + output_shape = [input_shape[0], None, self.hidden_dim] + if self.data_format == "channels_last": + if input_shape[1] is not None and input_shape[2] is not None: + patch_num = input_shape[1] // self.patch_size + output_shape[1] = patch_num**2 + else: + if input_shape[2] is not None and input_shape[3] is not None: + patch_num = input_shape[2] // self.patch_size + output_shape[1] = patch_num**2 + return output_shape + + +class DINOV3Embedding(layers.Layer): + """A layer that converts images into patches. + + This layer adds all the necessary tokens to the embeddings, inlcuding + the class token, register tokens and mask token if specified. + + Args: + hidden_dim: int. The number of units in the hidden layers. + patch_size: int. The size of one side of each patch. + num_register_tokens: int. The number of register tokens to add to the + embeddings. Defaults to `0`. + use_mask_token: bool. Whether to use a mask token. Defaults to `True`. + initializer_range: float. The standard deviation of the truncated + normal initializer. Defaults to `0.02`. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. 
+ """ + + def __init__( + self, + hidden_dim, + patch_size, + num_register_tokens=0, + use_mask_token=True, + initializer_range=0.02, + data_format=None, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_dim = int(hidden_dim) + self.patch_size = int(patch_size) + self.num_register_tokens = int(num_register_tokens) + self.use_mask_token = bool(use_mask_token) + self.initializer_range = float(initializer_range) + self.data_format = standardize_data_format(data_format) + + self.patch_embeddings = DINOV3PatchEmbedding( + hidden_dim, + patch_size, + data_format=data_format, + dtype=self.dtype_policy, + name="patch_embeddings", + ) + + def build(self, input_shape): + self.cls_token = self.add_weight( + shape=(1, 1, self.hidden_dim), + initializer=initializers.TruncatedNormal( + stddev=self.initializer_range + ), + trainable=True, + name="cls_token", + ) + if self.use_mask_token: + self.mask_token = self.add_weight( + shape=(1, 1, self.hidden_dim), + initializer="zeros", + trainable=True, + name="mask_token", + ) + if self.num_register_tokens > 0: + self.register_tokens = self.add_weight( + shape=(1, self.num_register_tokens, self.hidden_dim), + initializer=initializers.TruncatedNormal( + stddev=self.initializer_range + ), + trainable=True, + name="register_tokens", + ) + self.patch_embeddings.build(input_shape) + + def call(self, inputs, masks=None, training=None): + batch_size = ops.shape(inputs)[0] + embeddings = self.patch_embeddings(inputs, training=training) + + if masks is not None and self.use_mask_token: + mask_token = ops.cast(self.mask_token, embeddings.dtype) + embeddings = ops.where( + ops.expand_dims(masks, axis=-1), + mask_token, + embeddings, + ) + + cls_tokens = ops.tile(self.cls_token, (batch_size, 1, 1)) + embeddings = ops.concatenate((cls_tokens, embeddings), axis=1) + + if self.num_register_tokens > 0: + register_tokens = ops.tile(self.register_tokens, (batch_size, 1, 1)) + embeddings = ops.concatenate( + ( + embeddings[:, :1, ...], + register_tokens, + embeddings[:, 1:, ...], + ), + axis=1, + ) + return embeddings + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "patch_size": self.patch_size, + "num_register_tokens": self.num_register_tokens, + "use_mask_token": self.use_mask_token, + "initializer_range": self.initializer_range, + } + ) + return config + + def compute_output_shape(self, input_shape): + output_shape = [input_shape[0], None, self.hidden_dim] + if self.data_format == "channels_last": + if input_shape[1] is not None and input_shape[2] is not None: + patch_num = input_shape[1] // self.patch_size + output_shape[1] = 1 + self.num_register_tokens + patch_num**2 + else: + if input_shape[2] is not None and input_shape[3] is not None: + patch_num = input_shape[2] // self.patch_size + output_shape[1] = 1 + self.num_register_tokens + patch_num**2 + return output_shape + + +class DINOV3RopePositionEmbedding(layers.Layer): + """A layer that implements Rotary Position Embedding. + + Args: + hidden_dim: int. The number of units in the hidden layers. + num_heads: int. Number of attention heads. + rope_theta: float. The base period of the rotary position embeddings. + patch_size: int. The size of one side of each patch. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. 
`"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ + + def __init__( + self, + hidden_dim, + num_heads, + rope_theta, + patch_size, + data_format=None, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_dim = int(hidden_dim) + self.num_heads = int(num_heads) + self.rope_theta = float(rope_theta) + self.patch_size = int(patch_size) + self.data_format = standardize_data_format(data_format) + self.head_dim = hidden_dim // num_heads + self.inv_freq = 1.0 / ( + rope_theta ** (ops.arange(0, 1, 4 / self.head_dim, dtype="float32")) + ) + + def _get_patches_center_coordinates( + self, num_patches_h, num_patches_w, dtype="float32" + ): + """A helper function to get the center coordinates of the patches.""" + coords_h = ops.arange(0.5, num_patches_h, dtype=dtype) + coords_w = ops.arange(0.5, num_patches_w, dtype=dtype) + + coords_h = coords_h / num_patches_h + coords_w = coords_w / num_patches_w + + coords_h = ops.expand_dims(coords_h, axis=1) + coords_w = ops.expand_dims(coords_w, axis=0) + + coords_h = ops.repeat(coords_h, num_patches_w, axis=1) + coords_w = ops.repeat(coords_w, num_patches_h, axis=0) + + coords = ops.stack([coords_h, coords_w], axis=-1) + coords = ops.reshape(coords, (-1, 2)) + coords = 2.0 * coords - 1.0 + return coords + + def call(self, inputs): + shape = ops.shape(inputs) + if self.data_format == "channels_last": + height, width = shape[1], shape[2] + else: + height, width = shape[2], shape[3] + num_patches_h = height // self.patch_size + num_patches_w = width // self.patch_size + + patch_coords = self._get_patches_center_coordinates( + num_patches_h, num_patches_w, dtype="float32" + ) + angles = ( + 2 + * math.pi + * ops.expand_dims(patch_coords, axis=-1) + * ops.expand_dims(ops.expand_dims(self.inv_freq, axis=0), axis=0) + ) + angles = ops.reshape(angles, (ops.shape(angles)[0], -1)) + angles = ops.tile(angles, (1, 2)) + + cos = ops.cast(ops.cos(angles), inputs.dtype) + sin = ops.cast(ops.sin(angles), inputs.dtype) + return cos, sin + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "num_heads": self.num_heads, + "rope_theta": self.rope_theta, + "patch_size": self.patch_size, + } + ) + return config + + def compute_output_shape(self, input_shape): + output_shape = input_shape + if self.data_format == "channels_last": + height, width = input_shape[1], input_shape[2] + else: + height, width = input_shape[2], input_shape[3] + num_patches_h = height // self.patch_size + num_patches_w = width // self.patch_size + seq_len = num_patches_h * num_patches_w + output_shape = (seq_len, self.head_dim) + return output_shape, output_shape + + +class DINOV3Attention(layers.Layer): + """A multi-head attention layer with dropout. + + Args: + hidden_dim: int. The number of units in the hidden layers. + num_heads: int. Number of attention heads. + dropout_rate: float. The dropout rate to use. Defaults to `0.0`. + use_query_bias: bool. Whether to use a bias for the query projection. + use_key_bias: bool. Whether to use a bias for the key projection. + use_value_bias: bool. 
Whether to use a bias for the value projection. + use_proj_bias: bool. Whether to use a bias for the output projection. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ + + def __init__( + self, + hidden_dim, + num_heads, + dropout_rate=0.0, + use_query_bias=True, + use_key_bias=True, + use_value_bias=True, + use_proj_bias=True, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_dim = int(hidden_dim) + self.num_heads = int(num_heads) + self.dropout_rate = float(dropout_rate) + self.use_query_bias = bool(use_query_bias) + self.use_key_bias = bool(use_key_bias) + self.use_value_bias = bool(use_value_bias) + self.use_proj_bias = bool(use_proj_bias) + self.head_dim = hidden_dim // num_heads + self.scale = self.head_dim**-0.5 + + self.query_dense = layers.Dense( + hidden_dim, + use_bias=use_query_bias, + dtype=self.dtype_policy, + name="q_proj", + ) + self.key_dense = layers.Dense( + hidden_dim, + use_bias=use_key_bias, + dtype=self.dtype_policy, + name="k_proj", + ) + self.value_dense = layers.Dense( + hidden_dim, + use_bias=use_value_bias, + dtype=self.dtype_policy, + name="v_proj", + ) + self.output_dense = layers.Dense( + hidden_dim, + use_bias=use_proj_bias, + dtype=self.dtype_policy, + name="o_proj", + ) + self.dropout = layers.Dropout( + dropout_rate, dtype=self.dtype_policy, name="dropout" + ) + + def build(self, input_shape): + self.query_dense.build(input_shape) + self.key_dense.build(input_shape) + self.value_dense.build(input_shape) + self.output_dense.build(input_shape) + + def _apply_rotary(self, q, k, cos, sin, num_prefix_tokens): + """Apply rotary position embedding to query and key.""" + + def _rotate_half(x): + """A helper function to rotate half of the features.""" + x1 = x[..., : ops.shape(x)[-1] // 2] + x2 = x[..., ops.shape(x)[-1] // 2 :] + return ops.concatenate([-x2, x1], axis=-1) + + q_prefix_tokens = q[:, :num_prefix_tokens, :, :] + q_patches = q[:, num_prefix_tokens:, :, :] + k_prefix_tokens = k[:, :num_prefix_tokens, :, :] + k_patches = k[:, num_prefix_tokens:, :, :] + cos = ops.expand_dims(ops.expand_dims(cos, axis=0), axis=2) + sin = ops.expand_dims(ops.expand_dims(sin, axis=0), axis=2) + + q_patches = (q_patches * cos) + (_rotate_half(q_patches) * sin) + k_patches = (k_patches * cos) + (_rotate_half(k_patches) * sin) + q = ops.concatenate([q_prefix_tokens, q_patches], axis=-3) + k = ops.concatenate([k_prefix_tokens, k_patches], axis=-3) + return q, k + + def call( + self, + inputs, + attention_mask=None, + position_embeddings=None, + num_prefix_tokens=0, + training=None, + ): + batch_size, seq_len, _ = ops.shape(inputs) + q = self.query_dense(inputs, training=training) + k = self.key_dense(inputs, training=training) + v = self.value_dense(inputs, training=training) + q = ops.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)) + k = ops.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)) + v = ops.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)) + if position_embeddings is not None: + cos, sin = position_embeddings + q, k = self._apply_rotary(q, k, cos, sin, num_prefix_tokens) + + attn_output = ops.nn.dot_product_attention( + q, + k, + v, + mask=attention_mask, + scale=self.scale, + is_causal=False, + ) + attn_output = ops.reshape(attn_output, (batch_size, seq_len, -1)) + attn_output = self.dropout(attn_output, training=training) + return self.output_dense(attn_output, training=training) + + def get_config(self): + config = super().get_config() + 
config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "num_heads": self.num_heads,
+                "dropout_rate": self.dropout_rate,
+                "use_query_bias": self.use_query_bias,
+                "use_key_bias": self.use_key_bias,
+                "use_value_bias": self.use_value_bias,
+                "use_proj_bias": self.use_proj_bias,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+
+class DINOV3LayerScale(layers.Layer):
+    """A layer scale.
+
+    Args:
+        hidden_dim: int. The number of units in the hidden layers.
+        init_values: float. The initial value for the scale. Defaults to `1.0`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+    """
+
+    def __init__(self, hidden_dim, init_values=1.0, **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+        self.init_values = float(init_values)
+
+    def build(self, input_shape):
+        self.lambda1 = self.add_weight(
+            shape=(self.hidden_dim,),
+            initializer=initializers.Constant(self.init_values),
+            trainable=True,
+            name="lambda1",
+        )
+
+    def call(self, inputs, training=None):
+        return ops.multiply(inputs, self.lambda1)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {"hidden_dim": self.hidden_dim, "init_values": self.init_values}
+        )
+        return config
+
+
+class DINOV3DropPath(layers.Layer):
+    """A drop path layer.
+
+    Args:
+        rate: float. The drop path rate to use. Defaults to `0.0`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+    """
+
+    def __init__(self, rate=0.0, **kwargs):
+        super().__init__(**kwargs)
+        self.rate = float(rate)
+
+    def build(self, input_shape):
+        self.noise_shape = (input_shape[0],) + (1,) * (len(input_shape) - 1)
+
+    def call(self, inputs, training=None):
+        if not training or self.rate == 0.0:
+            return inputs
+
+        keep_prob = 1.0 - self.rate
+        random_tensor = random.uniform(self.noise_shape, dtype=inputs.dtype)
+        random_tensor = ops.add(random_tensor, keep_prob)
+        return ops.multiply(ops.divide(inputs, keep_prob), random_tensor)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"rate": self.rate})
+        return config
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+
+class DINOV3MLP(layers.Layer):
+    """A DINOV3 MLP block.
+
+    Args:
+        hidden_dim: int. The number of units in the output layer.
+        intermediate_dim: int. The output dimension of the first Dense layer.
+        activation: str or callable. Activation to use in the intermediate
+            layer. Defaults to `"gelu"`.
+        use_bias: bool. Whether to use a bias for the dense layers.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+ """ + + def __init__( + self, + hidden_dim, + intermediate_dim, + activation="gelu", + use_bias=True, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_dim = int(hidden_dim) + self.intermediate_dim = int(intermediate_dim) + self.activation = activation + self.use_bias = bool(use_bias) + + self.up_proj = layers.Dense( + intermediate_dim, + activation=activation, + use_bias=use_bias, + dtype=self.dtype_policy, + name="up_proj", + ) + self.down_proj = layers.Dense( + hidden_dim, + use_bias=use_bias, + dtype=self.dtype_policy, + name="down_proj", + ) + + def build(self, input_shape): + self.up_proj.build(input_shape) + input_shape = self.up_proj.compute_output_shape(input_shape) + self.down_proj.build(input_shape) + + def call(self, inputs, training=None): + x = self.up_proj(inputs, training=training) + return self.down_proj(x, training=training) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "activation": self.activation, + "use_bias": self.use_bias, + } + ) + return config + + def compute_output_shape(self, input_shape): + output_shape = list(input_shape) + output_shape[-1] = self.hidden_dim + return output_shape + + +class DINOV3GatedMLP(layers.Layer): + """A DINOV3 Gated MLP block. + + Args: + hidden_dim: int. The number of units in the output layer. + intermediate_dim: int. The output dimension of the first Dense layer. + activation: str of callable. Activation to use in the intermediate + layer. Defaults to `"gelu"`. + use_bias: bool. Whether to use a bias for the dense layers. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ + + def __init__( + self, + hidden_dim, + intermediate_dim, + activation="gelu", + use_bias=True, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_dim = int(hidden_dim) + self.intermediate_dim = int(intermediate_dim) + self.activation = activation + self.use_bias = bool(use_bias) + + self.gate_proj = layers.Dense( + intermediate_dim, + activation=activation, + use_bias=use_bias, + dtype=self.dtype_policy, + name="gate_proj", + ) + self.up_proj = layers.Dense( + intermediate_dim, + use_bias=use_bias, + dtype=self.dtype_policy, + name="up_proj", + ) + self.down_proj = layers.Dense( + hidden_dim, + use_bias=use_bias, + dtype=self.dtype_policy, + name="down_proj", + ) + + def build(self, input_shape): + self.gate_proj.build(input_shape) + self.up_proj.build(input_shape) + input_shape = self.up_proj.compute_output_shape(input_shape) + self.down_proj.build(input_shape) + + def call(self, inputs, training=None): + x = ops.multiply( + self.gate_proj(inputs, training=training), + self.up_proj(inputs, training=training), + ) + return self.down_proj(x, training=training) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "activation": self.activation, + "use_bias": self.use_bias, + } + ) + return config + + def compute_output_shape(self, input_shape): + output_shape = list(input_shape) + output_shape[-1] = self.hidden_dim + return output_shape + + +class DINOV3Layer(layers.Layer): + """A DINOV3 encoder layer. + + Args: + hidden_dim: int. The number of units in the hidden layers. + num_heads: int. Number of attention heads. + intermediate_dim: int. The output dimension of the first Dense layer in + a two-layer feedforward network for each transformer. 
+ layer_scale_init_value: float. The initial value for the scale. + Defaults to `1.0`. + hidden_activation: str or callable. Activation to use in the MLP. + Defaults to `"gelu"`. + use_gated_mlp: bool. Whether to use Gated MLP layers. Defaults to + `False`. + use_query_bias: bool. Whether to use a bias for the query projection. + use_key_bias: bool. Whether to use a bias for the key projection. + use_value_bias: bool. Whether to use a bias for the value projection. + use_proj_bias: bool. Whether to use a bias for the output projection. + use_mlp_bias: bool. Whether to use a bias for the MLP layers. + attention_dropout: float. The dropout rate for the attention + probabilities. Defaults to `0.0`. + drop_path_rate: float. The drop path rate to use. Defaults to `0.0`. + layer_norm_eps: float. The epsilon for layer normalization. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ + + def __init__( + self, + hidden_dim, + num_heads, + intermediate_dim, + layer_scale_init_value=1.0, + hidden_activation="gelu", + use_gated_mlp=False, + use_query_bias=True, + use_key_bias=True, + use_value_bias=True, + use_proj_bias=True, + use_mlp_bias=True, + attention_dropout=0.0, + drop_path_rate=0.0, + layer_norm_eps=1e-6, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_dim = int(hidden_dim) + self.num_heads = int(num_heads) + self.intermediate_dim = int(intermediate_dim) + self.layer_scale_init_value = float(layer_scale_init_value) + self.hidden_activation = hidden_activation + self.use_gated_mlp = bool(use_gated_mlp) + self.use_query_bias = bool(use_query_bias) + self.use_key_bias = bool(use_key_bias) + self.use_value_bias = bool(use_value_bias) + self.use_proj_bias = bool(use_proj_bias) + self.use_mlp_bias = bool(use_mlp_bias) + self.attention_dropout = float(attention_dropout) + self.drop_path_rate = float(drop_path_rate) + self.layer_norm_eps = float(layer_norm_eps) + + self.norm1 = layers.LayerNormalization( + epsilon=layer_norm_eps, dtype=self.dtype_policy, name="norm1" + ) + self.attention = DINOV3Attention( + hidden_dim=hidden_dim, + num_heads=num_heads, + dropout_rate=attention_dropout, + use_query_bias=use_query_bias, + use_key_bias=use_key_bias, + use_value_bias=use_value_bias, + use_proj_bias=use_proj_bias, + dtype=self.dtype_policy, + name="attention", + ) + self.layer_scale1 = DINOV3LayerScale( + hidden_dim, + init_values=layer_scale_init_value, + dtype=self.dtype_policy, + name="layer_scale1", + ) + self.drop_path = ( + DINOV3DropPath(drop_path_rate, dtype=self.dtype_policy) + if drop_path_rate > 0.0 + else layers.Identity(dtype=self.dtype_policy) + ) + self.norm2 = layers.LayerNormalization( + epsilon=layer_norm_eps, dtype=self.dtype_policy, name="norm2" + ) + if use_gated_mlp: + self.mlp = DINOV3GatedMLP( + hidden_dim, + intermediate_dim, + activation=hidden_activation, + use_bias=use_mlp_bias, + dtype=self.dtype_policy, + name="mlp", + ) + else: + self.mlp = DINOV3MLP( + hidden_dim, + intermediate_dim, + activation=hidden_activation, + use_bias=use_mlp_bias, + dtype=self.dtype_policy, + name="mlp", + ) + self.layer_scale2 = DINOV3LayerScale( + hidden_dim, + init_values=layer_scale_init_value, + dtype=self.dtype_policy, + name="layer_scale2", + ) + + def build(self, input_shape): + self.norm1.build(input_shape) + self.attention.build(input_shape) + input_shape = self.attention.compute_output_shape(input_shape) + self.layer_scale1.build(input_shape) + self.drop_path.build(input_shape) + self.norm2.build(input_shape) + 
self.mlp.build(input_shape) + input_shape = self.mlp.compute_output_shape(input_shape) + self.layer_scale2.build(input_shape) + + def call( + self, + inputs, + attention_mask=None, + position_embeddings=None, + num_prefix_tokens=0, + training=None, + ): + residual = inputs + hidden_states = self.norm1(inputs) + hidden_states = self.attention( + hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + num_prefix_tokens=num_prefix_tokens, + training=training, + ) + hidden_states = self.layer_scale1(hidden_states, training=training) + hidden_states = ( + self.drop_path(hidden_states, training=training) + residual + ) + + residual = hidden_states + hidden_states = self.norm2(hidden_states, training=training) + hidden_states = self.mlp(hidden_states, training=training) + hidden_states = self.layer_scale2(hidden_states, training=training) + return self.drop_path(hidden_states, training=training) + residual + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "num_heads": self.num_heads, + "intermediate_dim": self.intermediate_dim, + "layer_scale_init_value": self.layer_scale_init_value, + "hidden_activation": self.hidden_activation, + "use_gated_mlp": self.use_gated_mlp, + "use_query_bias": self.use_query_bias, + "use_key_bias": self.use_key_bias, + "use_value_bias": self.use_value_bias, + "use_proj_bias": self.use_proj_bias, + "use_mlp_bias": self.use_mlp_bias, + "attention_dropout": self.attention_dropout, + "drop_path_rate": self.drop_path_rate, + "layer_norm_eps": self.layer_norm_eps, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape + + +class DINOV3Encoder(layers.Layer): + """A DINOV3 encoder. + + Args: + num_layers: int. The number of transformer layers. + hidden_dim: int. The number of units in the hidden layers. + num_heads: int. Number of attention heads. + intermediate_dim: int. The output dimension of the first Dense layer in + a two-layer feedforward network for each transformer. + layer_scale_init_value: float. The initial value for the scale. + Defaults to `1.0`. + hidden_activation: str or callable. Activation to use in the MLP. + Defaults to `"gelu"`. + use_gated_mlp: bool. Whether to use Gated MLP layers. Defaults to + `False`. + use_query_bias: bool. Whether to use a bias for the query projection. + Defaults to `True`. + use_key_bias: bool. Whether to use a bias for the key projection. + Defaults to `True`. + use_value_bias: bool. Whether to use a bias for the value projection. + Defaults to `True`. + use_proj_bias: bool. Whether to use a bias for the output projection. + Defaults to `True`. + use_mlp_bias: bool. Whether to use a bias for the dense layers in MLP. + Defaults to `True`. + attention_dropout: float. The dropout rate for the attention + probabilities. Defaults to `0.0`. + drop_path_rate: float. The drop path rate to use. Defaults to `0.0`. + layer_norm_eps: float. The epsilon for layer normalization. Defaults to + `1e-5`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. 
+ """ + + def __init__( + self, + num_layers, + hidden_dim, + num_heads, + intermediate_dim, + layer_scale_init_value=1.0, + hidden_activation="gelu", + use_gated_mlp=False, + use_query_bias=True, + use_key_bias=True, + use_value_bias=True, + use_proj_bias=True, + use_mlp_bias=True, + attention_dropout=0.0, + drop_path_rate=0.0, + layer_norm_eps=1e-5, + **kwargs, + ): + super().__init__(**kwargs) + self.num_layers = int(num_layers) + self.hidden_dim = int(hidden_dim) + self.num_heads = int(num_heads) + self.intermediate_dim = int(intermediate_dim) + self.layer_scale_init_value = float(layer_scale_init_value) + self.hidden_activation = hidden_activation + self.use_gated_mlp = bool(use_gated_mlp) + self.use_query_bias = bool(use_query_bias) + self.use_key_bias = bool(use_key_bias) + self.use_value_bias = bool(use_value_bias) + self.use_proj_bias = bool(use_proj_bias) + self.use_mlp_bias = bool(use_mlp_bias) + self.attention_dropout = float(attention_dropout) + self.drop_path_rate = float(drop_path_rate) + self.layer_norm_eps = float(layer_norm_eps) + + dpr = [x for x in ops.linspace(0.0, drop_path_rate, num_layers)] + self.layers = [ + DINOV3Layer( + hidden_dim=hidden_dim, + num_heads=num_heads, + intermediate_dim=intermediate_dim, + layer_scale_init_value=layer_scale_init_value, + hidden_activation=hidden_activation, + use_gated_mlp=use_gated_mlp, + use_query_bias=use_query_bias, + use_key_bias=use_key_bias, + use_value_bias=use_value_bias, + use_proj_bias=use_proj_bias, + use_mlp_bias=use_mlp_bias, + attention_dropout=attention_dropout, + drop_path_rate=dpr[i], + layer_norm_eps=layer_norm_eps, + dtype=self.dtype_policy, + name=f"layers_{i}", + ) + for i in range(num_layers) + ] + + def build(self, input_shape): + for layer in self.layers: + layer.build(input_shape) + input_shape = layer.compute_output_shape(input_shape) + + def call( + self, + inputs, + attention_mask=None, + position_embeddings=None, + num_prefix_tokens=0, + training=None, + ): + pyramid_outputs = {} + x = inputs + for layer_index, layer in enumerate(self.layers, start=1): + x = layer( + x, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + num_prefix_tokens=num_prefix_tokens, + training=training, + ) + pyramid_outputs[f"stage{str(layer_index)}"] = x + return x, pyramid_outputs + + def get_config(self): + config = super().get_config() + config.update( + { + "num_layers": self.num_layers, + "hidden_dim": self.hidden_dim, + "num_heads": self.num_heads, + "intermediate_dim": self.intermediate_dim, + "layer_scale_init_value": self.layer_scale_init_value, + "hidden_activation": self.hidden_activation, + "use_gated_mlp": self.use_gated_mlp, + "use_query_bias": self.use_query_bias, + "use_key_bias": self.use_key_bias, + "use_value_bias": self.use_value_bias, + "use_proj_bias": self.use_proj_bias, + "use_mlp_bias": self.use_mlp_bias, + "attention_dropout": self.attention_dropout, + "drop_path_rate": self.drop_path_rate, + "layer_norm_eps": self.layer_norm_eps, + } + ) + return config + + def compute_output_shape(self, input_shape): + pyramid_outputs = {} + for layer_index in range(1, len(self.layers) + 1): + pyramid_outputs[f"stage{str(layer_index)}"] = input_shape + return input_shape, pyramid_outputs diff --git a/keras_hub/src/models/dinov3/dinov3_presets.py b/keras_hub/src/models/dinov3/dinov3_presets.py new file mode 100644 index 0000000000..077663f11b --- /dev/null +++ b/keras_hub/src/models/dinov3/dinov3_presets.py @@ -0,0 +1,4 @@ +"""DINOV3 model preset configurations.""" + +# Metadata for 
loading pretrained model weights. +backbone_presets = {} diff --git a/keras_hub/src/utils/transformers/convert_dinov3.py b/keras_hub/src/utils/transformers/convert_dinov3.py new file mode 100644 index 0000000000..7d51eeb118 --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_dinov3.py @@ -0,0 +1,106 @@ +import numpy as np + +from keras_hub.src.models.dinov3.dinov3_backbone import DINOV3Backbone + +backbone_cls = DINOV3Backbone + + +def convert_backbone_config(transformers_config): + image_size = transformers_config["image_size"] + return { + "patch_size": transformers_config["patch_size"], + "num_layers": transformers_config["num_hidden_layers"], + "hidden_dim": transformers_config["hidden_size"], + "num_heads": transformers_config["num_attention_heads"], + "intermediate_dim": transformers_config["intermediate_size"], + "layer_scale_init_value": transformers_config["layerscale_value"], + "num_register_tokens": transformers_config["num_register_tokens"], + "use_mask_token": True, + "hidden_activation": transformers_config["hidden_act"], + "use_gated_mlp": transformers_config["use_gated_mlp"], + "use_query_bias": transformers_config["query_bias"], + "use_key_bias": transformers_config["key_bias"], + "use_value_bias": transformers_config["value_bias"], + "use_proj_bias": transformers_config["proj_bias"], + "use_mlp_bias": transformers_config["mlp_bias"], + "attention_dropout": transformers_config["attention_dropout"], + "drop_path_rate": transformers_config["drop_path_rate"], + "layer_norm_eps": transformers_config["layer_norm_eps"], + "image_shape": (image_size, image_size, 3), + "rope_theta": transformers_config["rope_theta"], + "apply_layernorm": False, + } + + +def convert_weights(backbone, loader, transformers_config): + if not isinstance(backbone, DINOV3Backbone): + raise ValueError( + "The provided backbone must be an instance of DINOV3Backbone. " + f"Received: {type(backbone)}" + ) + + def port_ln(keras_variable, weight_key): + loader.port_weight(keras_variable.gamma, f"{weight_key}.weight") + loader.port_weight(keras_variable.beta, f"{weight_key}.bias") + + def port_dense(keras_variable, weight_key): + loader.port_weight( + keras_variable.kernel, + f"{weight_key}.weight", + hook_fn=lambda x, _: x.T, + ) + if keras_variable.bias is not None: + loader.port_weight(keras_variable.bias, f"{weight_key}.bias") + + # Embedding. + loader.port_weight( + keras_variable=backbone.embeddings.cls_token, + hf_weight_key="embeddings.cls_token", + ) + if backbone.use_mask_token: + loader.port_weight( + keras_variable=backbone.embeddings.mask_token, + hf_weight_key="embeddings.mask_token", + ) + if backbone.num_register_tokens > 0: + loader.port_weight( + keras_variable=backbone.embeddings.register_tokens, + hf_weight_key="embeddings.register_tokens", + ) + loader.port_weight( + keras_variable=backbone.embeddings.patch_embeddings.projection.kernel, + hf_weight_key="embeddings.patch_embeddings.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + loader.port_weight( + keras_variable=backbone.embeddings.patch_embeddings.projection.bias, + hf_weight_key="embeddings.patch_embeddings.bias", + ) + + # Encoder. 
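+    # `port_dense` (defined above) transposes each kernel because PyTorch
+    # `nn.Linear` stores weights as `(out_features, in_features)`, while Keras
+    # `Dense` kernels are `(in_features, out_features)`.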
+ for i, layer in enumerate(backbone.encoder.layers): + prefix = f"layer.{i}" + port_ln(layer.norm1, f"{prefix}.norm1") + port_dense(layer.attention.query_dense, f"{prefix}.attention.q_proj") + port_dense(layer.attention.key_dense, f"{prefix}.attention.k_proj") + port_dense(layer.attention.value_dense, f"{prefix}.attention.v_proj") + port_dense(layer.attention.output_dense, f"{prefix}.attention.o_proj") + + loader.port_weight( + keras_variable=layer.layer_scale1.lambda1, + hf_weight_key=f"{prefix}.layer_scale1.lambda1", + ) + port_ln(layer.norm2, f"{prefix}.norm2") + if backbone.use_gated_mlp: + port_dense(layer.mlp.gate_proj, f"{prefix}.mlp.gate_proj") + port_dense(layer.mlp.up_proj, f"{prefix}.mlp.up_proj") + port_dense(layer.mlp.down_proj, f"{prefix}.mlp.down_proj") + else: + port_dense(layer.mlp.up_proj, f"{prefix}.mlp.up_proj") + port_dense(layer.mlp.down_proj, f"{prefix}.mlp.down_proj") + loader.port_weight( + keras_variable=layer.layer_scale2.lambda1, + hf_weight_key=f"{prefix}.layer_scale2.lambda1", + ) + + port_ln(backbone.layernorm, "norm") diff --git a/keras_hub/src/utils/transformers/convert_dinov3_test.py b/keras_hub/src/utils/transformers/convert_dinov3_test.py new file mode 100644 index 0000000000..81e3775dc8 --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_dinov3_test.py @@ -0,0 +1,35 @@ +import numpy as np +import pytest + +from keras_hub.src.models.dinov3.dinov3_backbone import DINOV3Backbone +from keras_hub.src.tests.test_case import TestCase + + +class TestTask(TestCase): + @pytest.mark.large + def test_convert_tiny_preset(self): + pytest.skip(reason="TODO: enable after HF token is available in CI") + model = DINOV3Backbone.from_preset( + "hf://facebook/dinov3-vits16-pretrain-lvd1689m", + image_shape=(224, 224, 3), + ) + dummy_input = { + "pixel_values": np.ones((1, 224, 224, 3), dtype="float32") + } + output = model.predict(dummy_input) + self.assertAllClose( + output[0, 0, :10], + [ + -0.2769, + 0.5487, + 0.2501, + -1.2269, + 0.5886, + 0.0762, + 0.6251, + 0.1874, + -0.4259, + -0.4362, + ], + atol=1e-2, + ) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index 73f6a27717..f98007b438 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -8,6 +8,7 @@ from keras_hub.src.utils.transformers import convert_bert from keras_hub.src.utils.transformers import convert_deit from keras_hub.src.utils.transformers import convert_dinov2 +from keras_hub.src.utils.transformers import convert_dinov3 from keras_hub.src.utils.transformers import convert_distilbert from keras_hub.src.utils.transformers import convert_esm from keras_hub.src.utils.transformers import convert_gemma @@ -42,6 +43,8 @@ def __init__(self, preset, config): self.converter = convert_distilbert elif model_type in ("dinov2", "dinov2_with_registers"): self.converter = convert_dinov2 + elif model_type == "dinov3_vit": + self.converter = convert_dinov3 elif model_type == "esm": self.converter = convert_esm elif model_type in ("gemma", "gemma2"): diff --git a/tools/checkpoint_conversion/convert_dinov3_checkpoints.py b/tools/checkpoint_conversion/convert_dinov3_checkpoints.py new file mode 100644 index 0000000000..36d06e3fb4 --- /dev/null +++ b/tools/checkpoint_conversion/convert_dinov3_checkpoints.py @@ -0,0 +1,177 @@ +"""Convert DINOV3 checkpoints. 
+ +export KAGGLE_USERNAME=xxx KAGGLE_KEY=xxx + +python tools/checkpoint_conversion/convert_dinov3_checkpoints.py \ + --preset dinov3_vit_small_lvd1689m --upload_uri kaggle://kerashub/dinov3/keras/dinov3_vit_small_lvd1689m +python tools/checkpoint_conversion/convert_dinov3_checkpoints.py \ + --preset dinov3_vit_small_plus_lvd1689m --upload_uri kaggle://kerashub/dinov3/keras/dinov3_vit_small_plus_lvd1689m +python tools/checkpoint_conversion/convert_dinov3_checkpoints.py \ + --preset dinov3_vit_base_lvd1689m --upload_uri kaggle://kerashub/dinov3/keras/dinov3_vit_base_lvd1689m +python tools/checkpoint_conversion/convert_dinov3_checkpoints.py \ + --preset dinov3_vit_large_lvd1689m --upload_uri kaggle://kerashub/dinov3/keras/dinov3_vit_large_lvd1689m +python tools/checkpoint_conversion/convert_dinov3_checkpoints.py \ + --preset dinov3_vit_huge_lvd1689m --upload_uri kaggle://kerashub/dinov3/keras/dinov3_vit_huge_lvd1689m +python tools/checkpoint_conversion/convert_dinov3_checkpoints.py \ + --preset dinov3_vit_huge_plus_lvd1689m --upload_uri kaggle://kerashub/dinov3/keras/dinov3_vit_huge_plus_lvd1689m +python tools/checkpoint_conversion/convert_dinov3_checkpoints.py \ + --preset dinov3_vit_7b_lvd1689m --upload_uri kaggle://kerashub/dinov3/keras/dinov3_vit_7b_lvd1689m + +python tools/checkpoint_conversion/convert_dinov3_checkpoints.py \ + --preset dinov3_vit_large_sat493m --upload_uri kaggle://kerashub/dinov3/keras/dinov3_vit_large_sat493m +python tools/checkpoint_conversion/convert_dinov3_checkpoints.py \ + --preset dinov3_vit_7b_sat493m --upload_uri kaggle://kerashub/dinov3/keras/dinov3_vit_7b_sat493m +""" + +import keras +import numpy as np +import torch +from absl import app +from absl import flags +from PIL import Image +from transformers import AutoImageProcessor +from transformers import AutoModel + +import keras_hub + +PRESET_MAP = { + # ViT lvd1689m variants. + "dinov3_vit_small_lvd1689m": "facebook/dinov3-vits16-pretrain-lvd1689m", + "dinov3_vit_small_plus_lvd1689m": ( + "facebook/dinov3-vits16plus-pretrain-lvd1689m" + ), + "dinov3_vit_base_lvd1689m": "facebook/dinov3-vitb16-pretrain-lvd1689m", + "dinov3_vit_large_lvd1689m": "facebook/dinov3-vitl16-pretrain-lvd1689m", + "dinov3_vit_huge_plus_lvd1689m": ( + "facebook/dinov3-vith16plus-pretrain-lvd1689m" + ), + "dinov3_vit_7b_lvd1689m": "facebook/dinov3-vit7b16-pretrain-lvd1689m", + # ViT sat493m variants. + "dinov3_vit_large_sat493m": "facebook/dinov3-vitl16-pretrain-sat493m", + "dinov3_vit_7b_sat493m": "facebook/dinov3-vit7b16-pretrain-sat493m", +} + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", + None, + f"Must be one of {','.join(PRESET_MAP.keys())}", + required=True, +) +flags.DEFINE_string( + "upload_uri", + None, + 'Could be "kaggle://keras/{variant}/keras/{preset}"', + required=False, +) + + +def convert_image_converter(image_size, hf_image_processor): + config = hf_image_processor.to_dict() + image_size = (image_size, image_size) + std = config["image_std"] + mean = config["image_mean"] + return keras_hub.layers.DINOV3ImageConverter( + image_size=image_size, + scale=[1.0 / 255.0 / s for s in std], + offset=[-m / s for m, s in zip(mean, std)], + crop_to_aspect_ratio=False, + interpolation="bilinear", + antialias=True, + ) + + +def validate_output( + keras_hub_model, + keras_hub_image_converter, + hf_model, + hf_image_processor, +): + file = keras.utils.get_file( + origin=("http://images.cocodataset.org/val2017/000000039769.jpg") + ) + image = Image.open(file) + + # Preprocess with hf. 
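+    # The HF processor returns channels-first tensors while the Keras
+    # converter is channels-last, so the comparisons later in this function
+    # transpose one of the two layouts before diffing.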
+    hf_inputs = hf_image_processor(images=image, return_tensors="pt")
+    hf_preprocessed = hf_inputs["pixel_values"].detach().cpu().numpy()
+
+    # Preprocess with keras.
+    images = np.expand_dims(np.array(image).astype("float32"), axis=0)
+    images = keras_hub_image_converter(images)
+    keras_preprocessed = keras.ops.convert_to_numpy(images)
+
+    print("🔶 Keras preprocessor output:", keras_preprocessed[0, 0, :10, 0])
+    print("🔶 HF preprocessor output:", hf_preprocessed[0, 0, 0, :10])
+
+    # Call with hf. Use the keras preprocessed image so we can keep modeling
+    # and preprocessing comparisons independent.
+    hf_inputs["pixel_values"] = torch.from_numpy(
+        keras.ops.convert_to_numpy(
+            keras.ops.transpose(keras_preprocessed, (0, 3, 1, 2))
+        )
+    )
+    hf_outputs = hf_model(**hf_inputs)
+    hf_outputs = hf_outputs[0].detach().cpu().numpy()
+
+    # Call with keras.
+    keras_outputs = keras_hub_model.predict({"pixel_values": images}, verbose=0)
+    keras_outputs = keras.ops.convert_to_numpy(keras_outputs)
+
+    print("🔶 Keras output:", keras_outputs[0, 0, :10])
+    print("🔶 HF output:", hf_outputs[0, 0, :10])
+    modeling_diff = np.mean(np.abs(keras_outputs - hf_outputs))
+    print("🔶 Modeling difference:", modeling_diff)
+    preprocessing_diff = np.mean(
+        np.abs(keras_preprocessed - np.transpose(hf_preprocessed, (0, 2, 3, 1)))
+    )
+    print("🔶 Preprocessing difference:", preprocessing_diff)
+
+
+def main(_):
+    # === Get the preset name ===
+    if FLAGS.preset not in PRESET_MAP.keys():
+        raise ValueError(
+            f"Invalid preset {FLAGS.preset}. Must be one "
+            f"of {','.join(PRESET_MAP.keys())}"
+        )
+    preset = FLAGS.preset
+    hf_preset = PRESET_MAP[preset]
+
+    # Load the HF model.
+    hf_model = AutoModel.from_pretrained(hf_preset)
+    hf_model.eval()
+    image_size = int(hf_model.config.image_size)
+    hf_image_processor = AutoImageProcessor.from_pretrained(hf_preset)
+
+    # Load the KerasHub model.
+    keras_hub_backbone = keras_hub.models.DINOV3Backbone.from_preset(
+        f"hf://{hf_preset}"
+    )
+    keras_hub_backbone.summary()
+    keras_hub_image_converter = convert_image_converter(
+        image_size, hf_image_processor
+    )
+    print("✅ KerasHub model loaded.")
+    print("✅ Weights converted.")
+
+    validate_output(
+        keras_hub_backbone,
+        keras_hub_image_converter,
+        hf_model,
+        hf_image_processor,
+    )
+    print("✅ Output validated.")
+
+    keras_hub_backbone.save_to_preset(f"./{preset}")
+    keras_hub_image_converter.save_to_preset(f"./{preset}")
+    print(f"🏁 Preset saved to ./{preset}.")
+
+    upload_uri = FLAGS.upload_uri
+    if upload_uri:
+        keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}")
+        print(f"🏁 Preset uploaded to {upload_uri}")
+
+
+if __name__ == "__main__":
+    app.run(main)
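Once the conversion script has been run for a preset, the converted artifacts can be exercised end to end. The sketch below is illustrative only: it assumes the script above was run with `--preset dinov3_vit_small_lvd1689m` so that a local `./dinov3_vit_small_lvd1689m` preset directory exists (the hosted presets are not uploaded yet). It loads the backbone and image converter from that directory and splits the output sequence into the class token, register tokens, and patch tokens, following the layout produced by `DINOV3Embedding`.

```python
import numpy as np

import keras_hub

# Hypothetical local preset directory written by convert_dinov3_checkpoints.py.
preset_dir = "./dinov3_vit_small_lvd1689m"

backbone = keras_hub.models.DINOV3Backbone.from_preset(preset_dir)
converter = keras_hub.layers.DINOV3ImageConverter.from_preset(preset_dir)

# Any RGB batch in channels-last layout; the converter resizes and normalizes
# it to the resolution stored in the preset.
images = np.random.uniform(0.0, 255.0, size=(1, 512, 512, 3)).astype("float32")
pixel_values = converter(images)

# The output sequence is [class token, register tokens, patch tokens].
features = backbone({"pixel_values": pixel_values})
num_prefix = 1 + backbone.num_register_tokens
cls_token = features[:, 0, :]
register_tokens = features[:, 1:num_prefix, :]
patch_tokens = features[:, num_prefix:, :]
print(cls_token.shape, register_tokens.shape, patch_tokens.shape)
```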