Skip to content
Open
Show file tree
Hide file tree
Changes from 53 commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
087b9b2
Update backbone.py
pctablet505 Sep 1, 2025
de830b1
Update backbone.py
pctablet505 Sep 1, 2025
62d2484
Update task.py
pctablet505 Sep 1, 2025
3b71125
Revert "Update task.py"
pctablet505 Sep 2, 2025
3d453ff
Revert "Update backbone.py"
pctablet505 Sep 2, 2025
92b1254
export
pctablet505 Sep 9, 2025
e46241d
refactoring
pctablet505 Sep 10, 2025
6e970e2
refactor
pctablet505 Sep 10, 2025
15ad9f3
Update registry.py
pctablet505 Sep 10, 2025
02ca0d9
Refactor export logic and improve error handling
pctablet505 Sep 15, 2025
901c233
Merge branch 'keras-team:master' into export
pctablet505 Sep 17, 2025
442fdd3
reformat
pctablet505 Sep 22, 2025
5446e2a
Add export submodule to keras_hub API
pctablet505 Sep 22, 2025
5c31d88
reformat
pctablet505 Sep 22, 2025
3290d42
now supporting export for objectDetectors
pctablet505 Sep 23, 2025
8b1024f
Add and refine image model exporter configs
pctablet505 Sep 23, 2025
8df5a75
Refactor: move keras import to module level
pctablet505 Sep 24, 2025
759d223
Remove debug_object_detection.py script
pctablet505 Sep 24, 2025
0737c93
Rename LiteRT to Litert and update exporter configs
pctablet505 Oct 3, 2025
c733e18
Refactor InputSpec formatting and fix import path
pctablet505 Oct 6, 2025
5ab911f
Refactor exporter configs and model building logic
pctablet505 Oct 9, 2025
c1e26dd
Refactor export initialization and improve warnings
pctablet505 Oct 9, 2025
6fa8379
Improve dtype handling and verbose output in exporters
pctablet505 Oct 9, 2025
81c6ed5
Remove get_dummy_inputs methods from exporter configs
pctablet505 Oct 13, 2025
d6a8dfd
Rename LitertExporter to LiteRTExporter
pctablet505 Oct 21, 2025
663c190
Update registry.py
pctablet505 Oct 24, 2025
e0d02ee
Refactor exporter registry to use model classes
pctablet505 Oct 24, 2025
6c98400
Remove conditional import for keras
pctablet505 Oct 24, 2025
b9e3789
Add comprehensive export test suites for Keras Hub
pctablet505 Oct 24, 2025
9f63b2a
Refactor LiteRT exporter model adapters
pctablet505 Oct 24, 2025
b4ce293
Merge branch 'keras-team:master' into export
pctablet505 Oct 24, 2025
4ebc701
Clarify type annotations in docstrings for export modules
pctablet505 Oct 24, 2025
298967e
testing refactor
pctablet505 Oct 25, 2025
bc0a8b7
refactor test
pctablet505 Oct 25, 2025
ab99186
Fix LiteRT export filepath and mask argument usage
pctablet505 Oct 25, 2025
1c06c46
Refactor LiteRTExporter model adapter calls
pctablet505 Oct 25, 2025
22587f1
Add warning for private TensorFlow API usage
pctablet505 Oct 27, 2025
765d55c
Merge branch 'export' of https://github.com/pctablet505/keras-hub int…
pctablet505 Oct 27, 2025
70f712a
Refactor exporter configs and remove TextModelExporterConfig
pctablet505 Oct 27, 2025
e47545d
Refactor trackable children filtering logic
pctablet505 Oct 27, 2025
0a266b4
Refactor ExporterRegistry model config lookup
pctablet505 Oct 27, 2025
21f6b2c
Update litert.py
pctablet505 Oct 27, 2025
efa25ae
Refactor tests to remove try/except and improve clarity
pctablet505 Oct 27, 2025
ec37ac4
Fix docstring in TextClassifierExporterConfig
pctablet505 Oct 27, 2025
911eb96
Update base.py
pctablet505 Oct 27, 2025
51b99b1
Create litert_export_design.md
pctablet505 Oct 27, 2025
7ef9348
Refactor LiteRT export tests for consistency and efficiency
pctablet505 Oct 28, 2025
2295181
Refactor LiteRT export tests to support per-output thresholds
pctablet505 Oct 28, 2025
4adeadf
Update litert_models_test.py
pctablet505 Oct 28, 2025
00f49ca
Delete litert_export_design.md
pctablet505 Oct 28, 2025
5fa0498
Update litert_models_test.py
pctablet505 Oct 28, 2025
052669d
Refactor LiteRT export tests to use pytest parametrization
pctablet505 Oct 28, 2025
a273e42
Refactor export registry and add direct export to Task
pctablet505 Oct 28, 2025
519c3b6
Update litert.py
pctablet505 Oct 28, 2025
2dcbf23
Merge branch 'keras-team:master' into export
pctablet505 Oct 28, 2025
0136c34
Update task.py
pctablet505 Oct 29, 2025
f8bd6fa
Merge branch 'export' of https://github.com/pctablet505/keras-hub int…
pctablet505 Oct 29, 2025
14cffe0
Update test_case.py
pctablet505 Oct 29, 2025
9267b51
Enable dynamic input shapes for LiteRT export
pctablet505 Oct 31, 2025
c622d8d
Improve SignatureDef handling in LiteRT export tests
pctablet505 Nov 4, 2025
ca6056b
Refactor LiteRT test utilities for clarity and robustness
pctablet505 Nov 4, 2025
d43de36
Refactor TFLite inference to use signature runner
pctablet505 Nov 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras_hub/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
since your modifications would be overwritten.
"""

from keras_hub import export as export
from keras_hub import layers as layers
from keras_hub import metrics as metrics
from keras_hub import models as models
Expand Down
25 changes: 25 additions & 0 deletions keras_hub/api/export/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""DO NOT EDIT.
This file was autogenerated. Do not edit it by hand,
since your modifications would be overwritten.
"""

from keras_hub.src.export.configs import (
CausalLMExporterConfig as CausalLMExporterConfig,
)
from keras_hub.src.export.configs import (
ImageClassifierExporterConfig as ImageClassifierExporterConfig,
)
from keras_hub.src.export.configs import (
ImageSegmenterExporterConfig as ImageSegmenterExporterConfig,
)
from keras_hub.src.export.configs import (
ObjectDetectorExporterConfig as ObjectDetectorExporterConfig,
)
from keras_hub.src.export.configs import (
Seq2SeqLMExporterConfig as Seq2SeqLMExporterConfig,
)
from keras_hub.src.export.configs import (
TextClassifierExporterConfig as TextClassifierExporterConfig,
)
from keras_hub.src.export.litert import LiteRTExporter as LiteRTExporter
9 changes: 9 additions & 0 deletions keras_hub/src/export/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Export base classes and configurations for advanced usage
from keras_hub.src.export.base import KerasHubExporter
from keras_hub.src.export.base import KerasHubExporterConfig
from keras_hub.src.export.configs import CausalLMExporterConfig
from keras_hub.src.export.configs import Seq2SeqLMExporterConfig
from keras_hub.src.export.configs import TextClassifierExporterConfig
from keras_hub.src.export.configs import get_exporter_config
from keras_hub.src.export.litert import LiteRTExporter
from keras_hub.src.export.litert import export_litert
120 changes: 120 additions & 0 deletions keras_hub/src/export/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Base classes for Keras-Hub model exporters.

This module provides the foundation for exporting Keras-Hub models to various
formats. It defines the abstract base classes that all exporters must implement.
"""

from abc import ABC
from abc import abstractmethod


class KerasHubExporterConfig(ABC):
"""Base configuration class for Keras-Hub model exporters.

This class defines the interface for exporter configurations that specify
how different types of Keras-Hub models should be exported.
"""

# Model type this exporter handles (e.g., "causal_lm", "text_classifier")
MODEL_TYPE = None

# Expected input structure for this model type
EXPECTED_INPUTS = []

def __init__(self, model, **kwargs):
"""Initialize the exporter configuration.

Args:
model: `keras.Model`. The Keras-Hub model to export.
**kwargs: Additional configuration parameters.
"""
self.model = model
self.config_kwargs = kwargs
self._validate_model()

def _validate_model(self):
"""Validate that the model is compatible with this exporter."""
if not self._is_model_compatible():
raise ValueError(
f"Model {self.model.__class__.__name__} is not compatible "
f"with {self.__class__.__name__}"
)

@abstractmethod
def _is_model_compatible(self):
"""Check if the model is compatible with this exporter.

Returns:
`bool`. True if compatible, False otherwise
"""
pass

@abstractmethod
def get_input_signature(self, sequence_length=None):
"""Get the input signature for this model type.

Args:
sequence_length: `int` or `None`. Optional sequence length for
input tensors.

Returns:
`dict`. Dictionary mapping input names to tensor specifications.
"""
pass


class KerasHubExporter(ABC):
"""Base class for Keras-Hub model exporters.

This class provides the common interface for exporting Keras-Hub models
to different formats (LiteRT, ONNX, etc.).
"""

def __init__(self, config, **kwargs):
"""Initialize the exporter.

Args:
config: `KerasHubExporterConfig`. Exporter configuration specifying
model type and parameters.
**kwargs: Additional exporter-specific parameters.
"""
self.config = config
self.model = config.model
self.export_kwargs = kwargs

@abstractmethod
def export(self, filepath):
"""Export the model to the specified filepath.

Args:
filepath: `str`. Path where to save the exported model.
"""
pass

def _ensure_model_built(self, param=None):
"""Ensure the model is properly built with correct input structure.

This method builds the model using model.build() with input shapes.
This creates the necessary variables and initializes the model structure
for export without needing dummy data.

Args:
param: `int` or `None`. Optional parameter for input signature
(e.g., sequence_length for text models, image_size for vision
models).
"""
# Get input signature (returns dict of InputSpec objects)
input_signature = self.config.get_input_signature(param)

# Extract shapes from InputSpec objects
input_shapes = {}
for name, spec in input_signature.items():
if hasattr(spec, "shape"):
input_shapes[name] = spec.shape
else:
# Fallback for unexpected formats
input_shapes[name] = spec

# Build the model using shapes only (no actual data allocation)
# This creates variables and initializes the model structure
self.model.build(input_shape=input_shapes)
197 changes: 197 additions & 0 deletions keras_hub/src/export/base_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
"""Tests for base export classes."""

import keras

from keras_hub.src.export.base import KerasHubExporter
from keras_hub.src.export.base import KerasHubExporterConfig
from keras_hub.src.tests.test_case import TestCase


class DummyExporterConfig(KerasHubExporterConfig):
"""Dummy configuration for testing."""

MODEL_TYPE = "test_model"
EXPECTED_INPUTS = ["input_ids", "attention_mask"]
DEFAULT_SEQUENCE_LENGTH = 128

def __init__(self, model, compatible=True, **kwargs):
self.is_compatible = compatible
super().__init__(model, **kwargs)

def _is_model_compatible(self):
return self.is_compatible

def get_input_signature(self, sequence_length=None):
seq_len = sequence_length or self.DEFAULT_SEQUENCE_LENGTH
return {
"input_ids": keras.layers.InputSpec(
shape=(None, seq_len), dtype="int32"
),
"attention_mask": keras.layers.InputSpec(
shape=(None, seq_len), dtype="int32"
),
}


class DummyExporter(KerasHubExporter):
"""Dummy exporter for testing."""

def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
self.exported = False
self.export_path = None

def export(self, filepath):
self.exported = True
self.export_path = filepath
return filepath


class KerasHubExporterConfigTest(TestCase):
"""Tests for KerasHubExporterConfig base class."""

def test_init_with_compatible_model(self):
"""Test initialization with a compatible model."""
model = keras.Sequential([keras.layers.Dense(10)])
config = DummyExporterConfig(model, compatible=True)

self.assertEqual(config.model, model)
self.assertEqual(config.MODEL_TYPE, "test_model")
self.assertEqual(
config.EXPECTED_INPUTS, ["input_ids", "attention_mask"]
)

def test_init_with_incompatible_model_raises_error(self):
"""Test that incompatible model raises ValueError."""
model = keras.Sequential([keras.layers.Dense(10)])

with self.assertRaisesRegex(ValueError, "not compatible"):
DummyExporterConfig(model, compatible=False)

def test_get_input_signature_default_sequence_length(self):
"""Test get_input_signature with default sequence length."""
model = keras.Sequential([keras.layers.Dense(10)])
config = DummyExporterConfig(model)

signature = config.get_input_signature()

self.assertIn("input_ids", signature)
self.assertIn("attention_mask", signature)
self.assertEqual(signature["input_ids"].shape, (None, 128))
self.assertEqual(signature["attention_mask"].shape, (None, 128))

def test_get_input_signature_custom_sequence_length(self):
"""Test get_input_signature with custom sequence length."""
model = keras.Sequential([keras.layers.Dense(10)])
config = DummyExporterConfig(model)

signature = config.get_input_signature(sequence_length=256)

self.assertEqual(signature["input_ids"].shape, (None, 256))
self.assertEqual(signature["attention_mask"].shape, (None, 256))

def test_config_kwargs_stored(self):
"""Test that additional kwargs are stored."""
model = keras.Sequential([keras.layers.Dense(10)])
config = DummyExporterConfig(
model, custom_param="value", another_param=42
)

self.assertEqual(config.config_kwargs["custom_param"], "value")
self.assertEqual(config.config_kwargs["another_param"], 42)


class KerasHubExporterTest(TestCase):
"""Tests for KerasHubExporter base class."""

def test_init_stores_config_and_model(self):
"""Test that initialization stores config and model."""
model = keras.Sequential([keras.layers.Dense(10)])
config = DummyExporterConfig(model)
exporter = DummyExporter(config, verbose=True, custom_param="test")

self.assertEqual(exporter.config, config)
self.assertEqual(exporter.model, model)
self.assertEqual(exporter.export_kwargs["verbose"], True)
self.assertEqual(exporter.export_kwargs["custom_param"], "test")

def test_export_method_called(self):
"""Test that export method can be called."""
model = keras.Sequential([keras.layers.Dense(10)])
config = DummyExporterConfig(model)
exporter = DummyExporter(config)

result = exporter.export("/tmp/test_model")

self.assertTrue(exporter.exported)
self.assertEqual(exporter.export_path, "/tmp/test_model")
self.assertEqual(result, "/tmp/test_model")

def test_ensure_model_built(self):
"""Test _ensure_model_built method."""

class TestModel(keras.Model):
def __init__(self):
super().__init__()
self.dense = keras.layers.Dense(10)

def call(self, inputs):
return self.dense(inputs["input_ids"])

model = TestModel()
config = DummyExporterConfig(model)
exporter = DummyExporter(config)

# Model should not be built initially
self.assertFalse(model.built)

# Call _ensure_model_built
exporter._ensure_model_built()

# Model should now be built
self.assertTrue(model.built)

def test_ensure_model_built_with_custom_param(self):
"""Test _ensure_model_built with custom sequence length."""

class TestModel(keras.Model):
def __init__(self):
super().__init__()
self.dense = keras.layers.Dense(10)

def call(self, inputs):
return self.dense(inputs["input_ids"])

model = TestModel()
config = DummyExporterConfig(model)
exporter = DummyExporter(config)

# Call with custom sequence length
exporter._ensure_model_built(param=512)

# Verify model is built
self.assertTrue(model.built)

def test_ensure_model_built_already_built_model(self):
"""Test _ensure_model_built with already built model."""

class TestModel(keras.Model):
def __init__(self):
super().__init__()
self.dense = keras.layers.Dense(10)

def call(self, inputs):
return self.dense(inputs["input_ids"])

model = TestModel()
# Pre-build the model
model.build(input_shape={"input_ids": (None, 128)})

config = DummyExporterConfig(model)
exporter = DummyExporter(config)

# Should not raise an error for already built model
exporter._ensure_model_built()

# Model should still be built
self.assertTrue(model.built)
Loading
Loading