[FR] Save/load of model parameters #4422

Open
PhilipVinc opened this issue Dec 8, 2024 · 2 comments

Comments

@PhilipVinc
Contributor

Keras offers a very handy way to serialise and deserialise model 'hyper-parameters' to JSON.
This allows one to reconstruct a model (the equivalent of a flax Module) without writing any code, just by loading its hyper-parameters from a file.

This is what powers Keras's ability to load a model from a file without having to store an accompanying script alongside the model weights.
See for example this link
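As a rough, stdlib-only illustration of the Keras pattern (Keras layers expose a get_config() method and a from_config() classmethod), here is a minimal sketch. The Dense class below is a hypothetical stand-in, not a flax or Keras class. Because its hyper-parameters are plain JSON types, its structure round-trips through a JSON string and can be rebuilt from that string alone, as long as the class itself is importable.

```python
import json


class Dense:
    """Hypothetical stand-in layer for illustration; not a flax or Keras class."""

    def __init__(self, features: int, use_bias: bool = True, activation: str = "relu"):
        self.features = features
        self.use_bias = use_bias
        # The activation is stored by name so that the config stays JSON-friendly.
        self.activation = activation

    def get_config(self) -> dict:
        # Only plain JSON types (int, bool, str) go into the config.
        return {
            "features": self.features,
            "use_bias": self.use_bias,
            "activation": self.activation,
        }

    @classmethod
    def from_config(cls, config: dict) -> "Dense":
        # Rebuild the object from its hyper-parameters alone; no weights involved.
        return cls(**config)


layer = Dense(3, activation="tanh")
text = json.dumps(layer.get_config())
restored = Dense.from_config(json.loads(text))
```

Weights would be stored separately (e.g. with orbax); the JSON only carries the structure.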

One thing that, over the years, I've consistently wished for is for flax to offer something similar.
While flax offers good ways to serialise parameters (through flax.serialization or orbax), serialising the structure is not easy:

  • If using Linen, we can call dataclasses.asdict(LinenModule), but the result contains several types that are not JSON-compatible. In principle we could pickle it, which is not ideal but better than nothing; however, the heavy use of lambdas and locally defined functions in flax initialisers breaks pickling:
```python
In [1]: import flax; import dataclasses; import pickle
In [2]: nn = flax.linen.Dense(3)
In [3]: dataclasses.asdict(nn)
Out[3]:
{'features': 3,
 'use_bias': True,
 'dtype': None,
 'param_dtype': jax.numpy.float32,
 'precision': None,
 'kernel_init': <function jax._src.nn.initializers.variance_scaling.<locals>.init(key: 'Array', shape: 'core.Shape', dtype: 'DTypeLikeInexact' = <class 'jax.numpy.float64'>) -> 'Array'>,
 'bias_init': <function jax.nn.initializers.zeros(key: 'Array', shape: 'core.Shape', dtype: 'DTypeLikeInexact' = <class 'jax.numpy.float64'>) -> 'Array'>,
....
}
In [4]: pickle.dumps(dataclasses.asdict(nn))
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[4], line 1
----> 1 pickle.dumps(dataclasses.asdict(nn))

AttributeError: Can't pickle local object 'variance_scaling.<locals>.init'
```
  • It is possible to use cloudpickle to serialise this dictionary, but cloudpickle output is not portable across Python versions, so it is not a good solution for storing this kind of metadata.
  • With the new NNX API this approach does not work at all, since NNX Modules are regular Python classes rather than dataclasses, and I'm unsure what alternative one could use.
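One possible direction for modules that are not dataclasses, such as NNX ones, sketched under stated assumptions: a hypothetical mixin (RecordsInitArgs, not a flax API) that records the arguments passed to the constructor, so the structure can be dumped even though there are no dataclass fields to read. The Linear class is a plain-Python stand-in; the sketch does not import flax, and whether such a mixin composes with nnx.Module's own machinery is not checked here.

```python
import inspect
import json


class RecordsInitArgs:
    """Hypothetical mixin (not part of flax): remembers constructor arguments."""

    def __new__(cls, *args, **kwargs):
        obj = super().__new__(cls)
        # Bind against the most-derived __init__ so positional arguments,
        # keyword arguments and defaults all end up in one name -> value dict.
        sig = inspect.signature(cls.__init__)
        bound = sig.bind(obj, *args, **kwargs)
        bound.apply_defaults()
        init_args = dict(bound.arguments)
        init_args.pop(next(iter(sig.parameters)))  # drop the `self` entry
        obj._init_args = init_args
        return obj

    def get_config(self) -> dict:
        return {"class_name": type(self).__name__, "config": dict(self._init_args)}


class Linear(RecordsInitArgs):
    """Stand-in for an NNX-style module: a plain class, not a dataclass."""

    def __init__(self, in_features: int, out_features: int, use_bias: bool = True):
        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias


layer = Linear(4, out_features=2)
payload = json.dumps(layer.get_config())
config = json.loads(payload)
rebuilt = Linear(**config["config"])
```

Design caveat: because the arguments are captured in __new__, copy.copy and unpickling (both recreate objects by calling cls.__new__(cls) without arguments) would raise a TypeError on these objects unless __reduce__ is customised too.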

It would be a huge addition if flax supported a similar feature.
Even a small change would already help: if the initialisers in flax were functools.partial(init_fun, **kwargs) objects instead of lambdas, they would be picklable.
More generally, a Keras-like mechanism that makes all Modules serialisable, and that also supports NNX, would be an important addition.
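The partial-versus-closure point can be illustrated without JAX. In this stdlib-only sketch, variance_scaling_closure and variance_scaling_partial are hypothetical stand-ins with dummy bodies, not real initialisers: the first returns a local function, which pickle rejects for the same reason as in the traceback above, while the second returns a functools.partial over a module-level function, which pickle stores by reference.

```python
import functools
import json
import pickle


def variance_scaling_closure(scale: float):
    # Closure pattern: `init` is a local object, like `variance_scaling.<locals>.init`.
    def init(n: int) -> list:
        return [scale] * n  # dummy body; a real initialiser would sample from a PRNG key

    return init


def _variance_scaling_init(n: int, *, scale: float) -> list:
    # Module-level function: pickle stores it by reference (module + qualified name).
    return [scale] * n


def variance_scaling_partial(scale: float):
    # Same behaviour, but returned as a partial over a module-level function.
    return functools.partial(_variance_scaling_init, scale=scale)


try:
    pickle.dumps(variance_scaling_closure(2.0))
    closure_picklable = True
except (AttributeError, pickle.PicklingError):
    # Older Pythons raise AttributeError here; newer ones may raise PicklingError.
    closure_picklable = False

partial_init = variance_scaling_partial(2.0)
restored_init = pickle.loads(pickle.dumps(partial_init))

# A partial also exposes its target and keyword arguments, so it can be
# described in plain JSON rather than pickled.
spec = json.dumps({"fn": partial_init.func.__qualname__, "kwargs": partial_init.keywords})
```

Pickling by reference still requires the target function to be importable under the same name when loading, so this makes initialisers picklable but does not by itself make them portable.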

In my opinion, this should be supported by flax itself rather than by an external package, because it should be standardised: when users define custom modules, they could then optionally implement the methods needed to make them work with serialisation.
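An opt-in mechanism of this kind could look roughly like the following sketch, loosely analogous to Keras's keras.saving.register_keras_serializable decorator. Everything in it (register_serializable, serialize, deserialize, MLP) is hypothetical and not a flax API: a registry maps class names to classes, custom modules opt in by implementing get_config() and optionally from_config(), and deserialisation looks the class up by name.

```python
import json

_REGISTRY: dict = {}


def register_serializable(cls):
    """Hypothetical decorator: makes a class discoverable by name."""
    _REGISTRY[cls.__qualname__] = cls
    return cls


def serialize(obj) -> str:
    # Each registered class opts in by implementing get_config().
    return json.dumps({"class_name": type(obj).__qualname__, "config": obj.get_config()})


def deserialize(text: str):
    data = json.loads(text)
    cls = _REGISTRY[data["class_name"]]
    # Classes may override from_config; otherwise fall back to the constructor.
    from_config = getattr(cls, "from_config", None)
    return from_config(data["config"]) if from_config else cls(**data["config"])


@register_serializable
class MLP:
    """Hypothetical user-defined module."""

    def __init__(self, hidden: list, activation: str = "gelu"):
        self.hidden = list(hidden)
        self.activation = activation

    def get_config(self) -> dict:
        return {"hidden": self.hidden, "activation": self.activation}


text = serialize(MLP([64, 64]))
model = deserialize(text)
```

Keying the registry by qualified name keeps the file free of code, at the cost of requiring the defining module to be imported (so the decorator runs) before loading.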

@cgarciae
Collaborator

cgarciae commented Dec 9, 2024

Hey @PhilipVinc, try using cloudpickle, which has better support for lambdas. Getting the standard pickle module to work would require a lot of work.

@PhilipVinc
Contributor Author

I know, but as I mentioned above, cloudpickle cannot be used to serialise across different Python versions.
