I'm fairly new to Flax and have mostly played with JAX. I see that Flax provides a fairly "platform-agnostic" serialization mechanism via MessagePack. I was wondering whether there are existing tools to convert these checkpoints to "standard" formats for deploying trained models, like ONNX or tflite. Understandably, Flax doesn't want to take on any of these "heavyweight" libraries as dependencies. But I'm considering using Flax for my next production model and am looking for workflow recommendations. I can certainly take the trained weights and plumb them into protos or whatnot, but if tools already exist, I'd rather use those. Thanks!
Hi jiawen, thanks for your interest in Flax.
For serialization we make a distinction between two different strategies: checkpointing and exporting.
The former is what the serialization API is for. ONNX and tflite are what we consider exporting formats: they store not just the state but also the actual model, so that it can be run standalone.
Currently, JAX, and by extension Flax, doesn't have a go-to solution for exporting a model. However, there is an experimental API in JAX to convert a model to TensorFlow, https://github.com/google/jax/tree/master/jax/experimental/jax2tf, which can then be saved as a tf.SavedModel. A SavedModel is what goes into the tflite converter. I don't think t…
Please let me know if you have any more questions.