diff --git a/keras/src/export/litert_test.py b/keras/src/export/litert_test.py
index c0c06d70a167..bf3e6dfe4b0d 100644
--- a/keras/src/export/litert_test.py
+++ b/keras/src/export/litert_test.py
@@ -470,7 +470,7 @@ def test_export_with_optimizations_default(self):
         model.export(
             temp_filepath,
             format="litert",
-            optimizations=[tensorflow.lite.Optimize.DEFAULT],
+            litert_kwargs={"optimizations": [tensorflow.lite.Optimize.DEFAULT]},
         )
 
         self.assertTrue(os.path.exists(temp_filepath))
@@ -501,7 +501,11 @@ def test_export_with_optimizations_sparsity(self):
         model.export(
             temp_filepath,
             format="litert",
-            optimizations=[tensorflow.lite.Optimize.EXPERIMENTAL_SPARSITY],
+            litert_kwargs={
+                "optimizations": [
+                    tensorflow.lite.Optimize.EXPERIMENTAL_SPARSITY
+                ]
+            },
         )
 
         self.assertTrue(os.path.exists(temp_filepath))
@@ -532,7 +536,9 @@ def test_export_with_optimizations_size(self):
         model.export(
             temp_filepath,
             format="litert",
-            optimizations=[tensorflow.lite.Optimize.OPTIMIZE_FOR_SIZE],
+            litert_kwargs={
+                "optimizations": [tensorflow.lite.Optimize.OPTIMIZE_FOR_SIZE]
+            },
         )
 
         self.assertTrue(os.path.exists(temp_filepath))
@@ -562,7 +568,9 @@ def test_export_with_optimizations_latency(self):
         model.export(
             temp_filepath,
             format="litert",
-            optimizations=[tensorflow.lite.Optimize.OPTIMIZE_FOR_LATENCY],
+            litert_kwargs={
+                "optimizations": [tensorflow.lite.Optimize.OPTIMIZE_FOR_LATENCY]
+            },
         )
 
         self.assertTrue(os.path.exists(temp_filepath))
@@ -592,10 +600,12 @@ def test_export_with_multiple_optimizations(self):
         model.export(
             temp_filepath,
             format="litert",
-            optimizations=[
-                tensorflow.lite.Optimize.DEFAULT,
-                tensorflow.lite.Optimize.EXPERIMENTAL_SPARSITY,
-            ],
+            litert_kwargs={
+                "optimizations": [
+                    tensorflow.lite.Optimize.DEFAULT,
+                    tensorflow.lite.Optimize.EXPERIMENTAL_SPARSITY,
+                ]
+            },
         )
 
         self.assertTrue(os.path.exists(temp_filepath))
@@ -627,8 +637,10 @@ def representative_dataset():
         model.export(
             temp_filepath,
             format="litert",
-            optimizations=[tensorflow.lite.Optimize.DEFAULT],
-            representative_dataset=representative_dataset,
+            litert_kwargs={
+                "optimizations": [tensorflow.lite.Optimize.DEFAULT],
+                "representative_dataset": representative_dataset,
+            },
         )
 
         self.assertTrue(os.path.exists(temp_filepath))
@@ -671,9 +683,11 @@ def representative_dataset():
         model.export(
             temp_filepath,
             format="litert",
-            optimizations=[tensorflow.lite.Optimize.DEFAULT],
-            representative_dataset=representative_dataset,
-            experimental_new_quantizer=True,
+            litert_kwargs={
+                "optimizations": [tensorflow.lite.Optimize.DEFAULT],
+                "representative_dataset": representative_dataset,
+                "experimental_new_quantizer": True,
+            },
         )
 
         self.assertTrue(os.path.exists(temp_filepath))
@@ -709,7 +723,7 @@ def test_export_optimization_file_size_comparison(self):
         model.export(
             filepath_with_opt,
             format="litert",
-            optimizations=[tensorflow.lite.Optimize.DEFAULT],
+            litert_kwargs={"optimizations": [tensorflow.lite.Optimize.DEFAULT]},
         )
 
         # Optimized model should be smaller
diff --git a/keras/src/models/model.py b/keras/src/models/model.py
index 37f4b3bef7ef..f358316a4c30 100644
--- a/keras/src/models/model.py
+++ b/keras/src/models/model.py
@@ -622,7 +622,10 @@ def export(
         format="tf_saved_model",
         verbose=None,
         input_signature=None,
-        **kwargs,
+        saved_model_kwargs=None,
+        onnx_kwargs=None,
+        litert_kwargs=None,
+        openvino_kwargs=None,
     ):
         """Export the model as an artifact for inference.
 
@@ -640,27 +643,29 @@ def export(
                 `tf.TensorSpec`, `backend.KerasTensor`, or backend tensor.
                 If not provided, it will be automatically computed.
                 Defaults to `None`.
-            **kwargs: Additional keyword arguments.
-                - `is_static`: Optional `bool`. Specific to the JAX backend and
-                  `format="tf_saved_model"`. Indicates whether `fn` is static.
-                  Set to `False` if `fn` involves state updates (e.g., RNG
-                  seeds and counters).
-                - `jax2tf_kwargs`: Optional `dict`. Specific to the JAX backend
-                  and `format="tf_saved_model"`. Arguments for
-                  `jax2tf.convert`. See the documentation for
+            saved_model_kwargs: Optional `dict`. Keyword arguments specific to
+                `format="tf_saved_model"`. Supported options:
+                - `is_static`: Optional `bool`. Specific to the JAX backend.
+                  Indicates whether `fn` is static. Set to `False` if `fn`
+                  involves state updates (e.g., RNG seeds and counters).
+                - `jax2tf_kwargs`: Optional `dict`. Specific to the JAX backend.
+                  Arguments for `jax2tf.convert`. See the documentation for
                   [`jax2tf.convert`](
                   https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
                   If `native_serialization` and `polymorphic_shapes` are not
                   provided, they will be automatically computed.
-                - `opset_version`: Optional `int`. Specific to `format="onnx"`.
-                  An integer value that specifies the ONNX opset version.
-                - LiteRT-specific options: Optional keyword arguments specific
-                  to `format="litert"`. These are passed directly to the
-                  TensorFlow Lite converter and include options like
-                  `optimizations`, `representative_dataset`,
-                  `experimental_new_quantizer`, `allow_custom_ops`,
-                  `enable_select_tf_ops`, etc. See TensorFlow Lite
-                  documentation for all available options.
+            onnx_kwargs: Optional `dict`. Keyword arguments specific to
+                `format="onnx"`. Supported options:
+                - `opset_version`: Optional `int`. An integer value that
+                  specifies the ONNX opset version.
+            litert_kwargs: Optional `dict`. Keyword arguments specific to
+                `format="litert"`. These are passed directly to the TensorFlow
+                Lite converter and include options like `optimizations`,
+                `representative_dataset`, `experimental_new_quantizer`,
+                `allow_custom_ops`, `enable_select_tf_ops`, etc. See the
+                TensorFlow Lite documentation for all available options.
+            openvino_kwargs: Optional `dict`. Keyword arguments specific to
+                `format="openvino"`.
 
         **Note:** This feature is currently supported only with
         TensorFlow, JAX and Torch backends.
@@ -682,12 +687,34 @@ def export(
         predictions = reloaded_artifact.serve(input_data)
         ```
 
+        With the JAX backend, you can pass additional options via
+        `saved_model_kwargs`:
+
+        ```python
+        # Export with JAX-specific options
+        model.export(
+            "path/to/location",
+            format="tf_saved_model",
+            saved_model_kwargs={
+                "is_static": True,
+                "jax2tf_kwargs": {"enable_xla": True}
+            }
+        )
+        ```
+
         Here's how to export an ONNX for inference.
 
         ```python
         # Export the model as a ONNX artifact
         model.export("path/to/location", format="onnx")
 
+        # Export with a specific ONNX opset version
+        model.export(
+            "path/to/location",
+            format="onnx",
+            onnx_kwargs={"opset_version": 18}
+        )
+
         # Load the artifact in a different process/environment
         ort_session = onnxruntime.InferenceSession("path/to/location")
         ort_inputs = {
@@ -702,6 +729,20 @@ def export(
         # Export the model as a LiteRT artifact
         model.export("path/to/location", format="litert")
 
+        # Export with quantization options
+        def representative_dataset():
+            for _ in range(100):
+                yield [sample_input_data]
+
+        model.export(
+            "path/to/location",
+            format="litert",
+            litert_kwargs={
+                "optimizations": [tf.lite.Optimize.DEFAULT],
+                "representative_dataset": representative_dataset
+            }
+        )
+
         # Load the artifact in a different process/environment
         interpreter = tf.lite.Interpreter(model_path="path/to/location")
         interpreter.allocate_tensors()
@@ -736,7 +777,7 @@ def export(
                 filepath,
                 verbose,
                 input_signature=input_signature,
-                **kwargs,
+                **(saved_model_kwargs or {}),
             )
         elif format == "onnx":
             export_onnx(
@@ -744,7 +785,7 @@ def export(
                 filepath,
                 verbose,
                 input_signature=input_signature,
-                **kwargs,
+                **(onnx_kwargs or {}),
             )
         elif format == "openvino":
             export_openvino(
@@ -752,14 +793,14 @@ def export(
                 filepath,
                 verbose,
                 input_signature=input_signature,
-                **kwargs,
+                **(openvino_kwargs or {}),
             )
         elif format == "litert":
             export_litert(
                 self,
                 filepath,
                 input_signature=input_signature,
-                **kwargs,
+                **(litert_kwargs or {}),
            )
 
     @classmethod
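
In short, this patch moves exporter options from loose `**kwargs` on `Model.export` into one `dict` per format. Below is a minimal before/after usage sketch for review context, based only on the docstring examples above; the tiny model, file paths, and TensorFlow-backend assumption are illustrative and not part of the patch (the ONNX call additionally requires Keras's ONNX export dependencies).

```python
import tensorflow as tf
import keras

# Tiny illustrative model; any Keras model exports the same way.
model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(4)])

# Old (removed by this patch): converter options passed as loose **kwargs.
#   model.export("model.tflite", format="litert",
#                optimizations=[tf.lite.Optimize.DEFAULT])

# New: each format's options are grouped in a dedicated dict argument.
model.export(
    "model.tflite",
    format="litert",
    litert_kwargs={"optimizations": [tf.lite.Optimize.DEFAULT]},
)

# Same pattern for the other formats, e.g. ONNX opset selection:
model.export("model.onnx", format="onnx", onnx_kwargs={"opset_version": 18})
```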