How to quantize the pre-trained JAX model based on flax.linen #2239
-
I have a pre-trained
Going for the second option, there's this function
Looking at this official example, the code block
it seems to be using the predict:
and the params
My question is: how can I get these two required
-
Hi @deshwalmahesh, jax2tf provides an example of converting a Flax model to TFLite, which seems closer to your use case: https://github.com/google/jax/tree/main/jax/experimental/jax2tf/examples/tflite/mnist. I would recommend converting via this path, since we provide more support for it than for the … This path first converts the JAX graph to a TF graph using … If you encounter any non-Flax-related problems with this conversion path, please file an issue against the JAX repo with a title starting with "[jax2tf] ...", and we will try to resolve it there.
-
Hi @deshwalmahesh, did you succeed in converting MAXIM to ONNX? Could you please share your conversion script or converted model?