-
Notifications
You must be signed in to change notification settings - Fork 306
Model Export to liteRT #2405
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Model Export to liteRT #2405
Changes from 48 commits
087b9b2
de830b1
62d2484
3b71125
3d453ff
92b1254
e46241d
6e970e2
15ad9f3
02ca0d9
901c233
442fdd3
5446e2a
5c31d88
3290d42
8b1024f
8df5a75
759d223
0737c93
c733e18
5ab911f
c1e26dd
6fa8379
81c6ed5
d6a8dfd
663c190
e0d02ee
6c98400
b9e3789
9f63b2a
b4ce293
4ebc701
298967e
bc0a8b7
ab99186
1c06c46
22587f1
765d55c
70f712a
e47545d
0a266b4
21f6b2c
efa25ae
ec37ac4
911eb96
51b99b1
7ef9348
2295181
4adeadf
00f49ca
5fa0498
052669d
a273e42
519c3b6
2dcbf23
0136c34
f8bd6fa
14cffe0
9267b51
c622d8d
ca6056b
d43de36
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| """DO NOT EDIT. | ||
| This file was autogenerated. Do not edit it by hand, | ||
| since your modifications would be overwritten. | ||
| """ | ||
|
|
||
| from keras_hub.src.export.configs import ( | ||
| CausalLMExporterConfig as CausalLMExporterConfig, | ||
| ) | ||
| from keras_hub.src.export.configs import ( | ||
| ImageClassifierExporterConfig as ImageClassifierExporterConfig, | ||
| ) | ||
| from keras_hub.src.export.configs import ( | ||
| ImageSegmenterExporterConfig as ImageSegmenterExporterConfig, | ||
| ) | ||
| from keras_hub.src.export.configs import ( | ||
| ObjectDetectorExporterConfig as ObjectDetectorExporterConfig, | ||
| ) | ||
| from keras_hub.src.export.configs import ( | ||
| Seq2SeqLMExporterConfig as Seq2SeqLMExporterConfig, | ||
| ) | ||
| from keras_hub.src.export.configs import ( | ||
| TextClassifierExporterConfig as TextClassifierExporterConfig, | ||
| ) | ||
| from keras_hub.src.export.litert import LiteRTExporter as LiteRTExporter | ||
pctablet505 marked this conversation as resolved.
Show resolved
Hide resolved
Comment on lines
+7
to
+25
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The Here's a cleaner way to write these imports: from keras_hub.src.export.configs import (
CausalLMExporterConfig,
ImageClassifierExporterConfig,
ImageSegmenterExporterConfig,
ObjectDetectorExporterConfig,
Seq2SeqLMExporterConfig,
TextClassifierExporterConfig,
)
from keras_hub.src.export.litert import LiteRTExporter |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| # Import registry to trigger initialization and export method extension | ||
| from keras_hub.src.export import registry # noqa: F401 | ||
| from keras_hub.src.export.base import ExporterRegistry | ||
| from keras_hub.src.export.base import KerasHubExporter | ||
| from keras_hub.src.export.base import KerasHubExporterConfig | ||
| from keras_hub.src.export.configs import CausalLMExporterConfig | ||
| from keras_hub.src.export.configs import Seq2SeqLMExporterConfig | ||
| from keras_hub.src.export.configs import TextClassifierExporterConfig | ||
| from keras_hub.src.export.litert import LiteRTExporter | ||
| from keras_hub.src.export.litert import export_litert | ||
| from keras_hub.src.export.registry import export_model |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,199 @@ | ||
| """Base classes for Keras-Hub model exporters. | ||
|
|
||
| This module provides the foundation for exporting Keras-Hub models to various | ||
| formats. It follows the Optimum pattern of having different exporters for | ||
| different model types and formats. | ||
| """ | ||
|
|
||
| from abc import ABC | ||
| from abc import abstractmethod | ||
|
|
||
| # Import model classes for registry | ||
|
|
||
|
|
||
| class KerasHubExporterConfig(ABC): | ||
| """Base configuration class for Keras-Hub model exporters. | ||
|
|
||
| This class defines the interface for exporter configurations that specify | ||
| how different types of Keras-Hub models should be exported. | ||
| """ | ||
|
|
||
| # Model type this exporter handles (e.g., "causal_lm", "text_classifier") | ||
| MODEL_TYPE = None | ||
|
|
||
| # Expected input structure for this model type | ||
| EXPECTED_INPUTS = [] | ||
|
|
||
| def __init__(self, model, **kwargs): | ||
| """Initialize the exporter configuration. | ||
|
|
||
| Args: | ||
| model: `keras.Model`. The Keras-Hub model to export. | ||
| **kwargs: Additional configuration parameters. | ||
| """ | ||
| self.model = model | ||
| self.config_kwargs = kwargs | ||
| self._validate_model() | ||
|
|
||
| def _validate_model(self): | ||
| """Validate that the model is compatible with this exporter.""" | ||
| if not self._is_model_compatible(): | ||
| raise ValueError( | ||
| f"Model {self.model.__class__.__name__} is not compatible " | ||
| f"with {self.__class__.__name__}" | ||
| ) | ||
|
|
||
| @abstractmethod | ||
| def _is_model_compatible(self): | ||
| """Check if the model is compatible with this exporter. | ||
|
|
||
| Returns: | ||
| `bool`. True if compatible, False otherwise | ||
| """ | ||
| pass | ||
|
|
||
| @abstractmethod | ||
| def get_input_signature(self, sequence_length=None): | ||
| """Get the input signature for this model type. | ||
|
|
||
| Args: | ||
| sequence_length: `int` or `None`. Optional sequence length for | ||
| input tensors. | ||
|
|
||
| Returns: | ||
| `dict`. Dictionary mapping input names to tensor specifications. | ||
| """ | ||
| pass | ||
|
|
||
|
|
||
| class KerasHubExporter(ABC): | ||
| """Base class for Keras-Hub model exporters. | ||
|
|
||
| This class provides the common interface for exporting Keras-Hub models | ||
| to different formats (LiteRT, ONNX, etc.). | ||
| """ | ||
|
|
||
| def __init__(self, config, **kwargs): | ||
| """Initialize the exporter. | ||
|
|
||
| Args: | ||
| config: `KerasHubExporterConfig`. Exporter configuration specifying | ||
| model type and parameters. | ||
| **kwargs: Additional exporter-specific parameters. | ||
| """ | ||
| self.config = config | ||
| self.model = config.model | ||
| self.export_kwargs = kwargs | ||
|
|
||
| @abstractmethod | ||
| def export(self, filepath): | ||
| """Export the model to the specified filepath. | ||
|
|
||
| Args: | ||
| filepath: `str`. Path where to save the exported model. | ||
| """ | ||
| pass | ||
|
|
||
| def _ensure_model_built(self, param=None): | ||
| """Ensure the model is properly built with correct input structure. | ||
|
|
||
| This method builds the model using model.build() with input shapes. | ||
| This creates the necessary variables and initializes the model structure | ||
| for export without needing dummy data. | ||
|
|
||
| Args: | ||
| param: `int` or `None`. Optional parameter for input signature | ||
| (e.g., sequence_length for text models, image_size for vision | ||
| models). | ||
| """ | ||
| # Get input signature (returns dict of InputSpec objects) | ||
| input_signature = self.config.get_input_signature(param) | ||
|
|
||
| # Extract shapes from InputSpec objects | ||
| input_shapes = {} | ||
| for name, spec in input_signature.items(): | ||
| if hasattr(spec, "shape"): | ||
| input_shapes[name] = spec.shape | ||
| else: | ||
| # Fallback for unexpected formats | ||
| input_shapes[name] = spec | ||
|
|
||
| # Build the model using shapes only (no actual data allocation) | ||
| # This creates variables and initializes the model structure | ||
| self.model.build(input_shape=input_shapes) | ||
|
|
||
|
|
||
| class ExporterRegistry: | ||
| """Registry for mapping model types to their appropriate exporters.""" | ||
|
|
||
| _configs = {} | ||
| _exporters = {} | ||
|
|
||
| @classmethod | ||
| def register_config(cls, model_class, config_class): | ||
| """Register a configuration class for a model type. | ||
|
|
||
| Args: | ||
| model_class: `type`. The model class (e.g., CausalLM) | ||
| config_class: `type`. The configuration class | ||
| """ | ||
| cls._configs[model_class] = config_class | ||
|
|
||
| @classmethod | ||
| def register_exporter(cls, format_name, exporter_class): | ||
| """Register an exporter class for a format. | ||
|
|
||
| Args: | ||
| format_name: `str`. The export format (e.g., "litert") | ||
| exporter_class: `type`. The exporter class | ||
| """ | ||
| cls._exporters[format_name] = exporter_class | ||
|
|
||
| @classmethod | ||
| def get_config_for_model(cls, model): | ||
| """Get the appropriate configuration for a model. | ||
|
|
||
| Args: | ||
| model: `keras.Model`. The Keras-Hub model | ||
|
|
||
| Returns: | ||
| `KerasHubExporterConfig`. An appropriate exporter configuration | ||
| instance | ||
|
|
||
| Raises: | ||
| ValueError: If no configuration is found for the model type | ||
| """ | ||
| # Iterate through registered configs to find a match | ||
| # This approach is more maintainable and extensible than a | ||
| # hardcoded list | ||
| for model_class, config_class in cls._configs.items(): | ||
| if isinstance(model, model_class): | ||
| return config_class(model) | ||
|
|
||
| # If we get here, model type is not recognized | ||
| raise ValueError( | ||
| f"Could not detect model type for {model.__class__.__name__}. " | ||
| "Supported types: CausalLM, TextClassifier, Seq2SeqLM, " | ||
| "ImageClassifier, ObjectDetector, ImageSegmenter" | ||
| ) | ||
|
|
||
| @classmethod | ||
| def get_exporter(cls, format_name, config, **kwargs): | ||
| """Get an exporter for the specified format. | ||
|
|
||
| Args: | ||
| format_name: `str`. The export format | ||
| config: `KerasHubExporterConfig`. The exporter configuration | ||
| **kwargs: `dict`. Additional parameters for the exporter | ||
|
|
||
| Returns: | ||
| `KerasHubExporter`. An appropriate exporter instance | ||
|
|
||
| Raises: | ||
| ValueError: If no exporter is found for the format | ||
| """ | ||
| if format_name not in cls._exporters: | ||
| raise ValueError(f"No exporter found for format: {format_name}") | ||
|
|
||
| exporter_class = cls._exporters[format_name] | ||
| return exporter_class(config, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The alias
as exportis redundant here. For better readability and to follow common Python conventions, you can simplify this import. While this pattern exists in the file, it's a good opportunity to correct it for the new addition.