Keras offers a very handy way to serialise and deserialise a model's 'hyper-parameters' to JSON.
This makes it possible to reconstruct a model (the equivalent of a Flax Module) without any code, just by loading its configuration from a file.
This is what lets Keras load a model from a file without having to store an accompanying script alongside the model weights.
See, for example, this link.
One thing I have consistently wanted over the years is for Flax to offer something similar.
While Flax offers a good way to serialise parameters (through `flax.serialization` or Orbax), serialising the structure is not easy:
- If using Linen, we can call `dataclasses.asdict(LinenModule)`, but the result contains several types that are not JSON-compatible. In principle we could pickle it, which is not ideal but better than nothing; however, the widespread use of lambdas in Flax initialisers breaks pickling.
- It is possible to use cloudpickle to serialise this dictionary, but cloudpickle output is not compatible across Python versions, so it is not a good solution for storing this kind of metadata.
- If using the new NNX API, this approach does not work, and I'm unsure what alternative one could use.
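The Linen problem can be reproduced without Flax. The sketch below assumes, for illustration only, that a Linen module behaves like a plain dataclass whose fields include a dtype-like type object and a lambda initialiser; `DenseLike` and its fields are hypothetical and are not Flax's actual `nn.Dense`. `dataclasses.asdict` copies such values unchanged, so `json.dumps` rejects them.

```python
import dataclasses
import json
from typing import Any, Callable


# Hypothetical stand-in for a Linen module: a dataclass of hyper-parameters.
@dataclasses.dataclass
class DenseLike:
    features: int
    use_bias: bool = True
    dtype: Any = float  # a type object, standing in for a dtype
    kernel_init: Callable = lambda shape: [0.0] * shape[0]  # lambda initialiser


def non_json_fields(module) -> list:
    """Return the names of fields whose values json.dumps rejects."""
    bad = []
    for name, value in dataclasses.asdict(module).items():
        try:
            json.dumps(value)
        except TypeError:
            bad.append(name)
    return bad


print(non_json_fields(DenseLike(features=4)))
```

Plain fields such as `features` and `use_bias` serialise fine; it is the type objects and callables that would need a dedicated encoding.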
It would be a huge addition if something similar were supported in Flax.
Even a minor improvement would be for the initialisers in Flax to be `partial(init_fun, **kwargs)` instead of lambdas, so as to make them picklable.
More generally, though, a Keras-like mechanism that makes all Modules serialisable, and that also supports NNX, would be an important addition.
In my opinion, this should be supported by Flax itself rather than by an external package, because it should be standardised: that way, when users define custom modules, they can optionally implement the methods needed to make them work with serialisation.
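To illustrate why `partial` helps, here is a stdlib-only sketch that does not use Flax itself: `normal_init` is a hypothetical initialiser, and the comparison only shows how `pickle` treats the two styles. A lambda has no importable name, so `pickle` cannot store a reference to it, whereas a `functools.partial` over a module-level function is pickled as a reference to that function plus its bound arguments.

```python
import functools
import pickle
import random


# Hypothetical initialiser defined at module level, so pickle can find it by name.
def normal_init(seed: int, shape: tuple, stddev: float = 1.0) -> list:
    rng = random.Random(seed)
    n = 1
    for dim in shape:
        n *= dim
    return [rng.gauss(0.0, stddev) for _ in range(n)]


# Lambda-style initialiser: its qualified name is "<lambda>", which pickle cannot look up.
lambda_init = lambda seed, shape: normal_init(seed, shape, stddev=0.02)

# partial-style initialiser: pickled as (reference to normal_init, args, keywords).
partial_init = functools.partial(normal_init, stddev=0.02)


def is_picklable(obj) -> bool:
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


restored = pickle.loads(pickle.dumps(partial_init))
print(is_picklable(lambda_init), is_picklable(partial_init))
print(restored(0, (2,)) == partial_init(0, (2,)))
```

Note that the `partial` only pickles because `normal_init` is itself importable by name; wrapping a closure or another lambda in `partial` would not help.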