How to read multiple dictionaries from a single binary file using flax.serialization #2462
abhiroop513 asked this question in Q&A
Replies: 1 comment
-
Can you please give a code example? I'm not sure how you are running into this. That said, you can create a dictionary of dictionaries and serialize all models together, e.g.:
import jax.numpy as jnp
from flax import serialization
import flax.linen as nn
import jax
x = jnp.ones((1, 4))
module_a = nn.Dense(10)
variables_a = module_a.init(jax.random.PRNGKey(1), x)
module_b = nn.Dense(8)
variables_b = module_b.init(jax.random.PRNGKey(2), x)
# save: serialize both models into a single bytes object
variables_dict = {'a': variables_a, 'b': variables_b}
variable_bytes = serialization.to_bytes(variables_dict)
...
# load: restore into a template with the same structure (e.g. freshly initialized variables)
variables_dict = serialization.from_bytes(variables_dict, variable_bytes)
variables_a = variables_dict['a']
variables_b = variables_dict['b']
-
When I try to read multiple dictionaries that were previously serialized with 'flax.serialization.to_bytes' and appended to a single binary file, it fails. Flax decodes the bytes with msgpack.unpackb, which expects exactly one msgpack message, so I get the error:
"msgpack.exceptions.ExtraData: unpack(b) received extra data"
Please let me know how to read multiple dictionaries from a single binary file.