-
I've been experimenting with JAX via Flax Linen and Objax. I have a Flax Linen EfficientNet model definition working now, and the next step was to run basic validation on the model. I got stuck figuring out a reasonable approach for loading the weights. There are examples for initializing the model, training, etc., and I feel I have a handle on the general concepts around params, pytrees, etc. But I'm not seeing a quick and easy path to load pretrained weights into a variables collection's params/batch_stats just for running inference. I've got the model weights in a numpy dict; it's flat like a PyTorch state dict ('module.module.conv.weight'). I can separate batch_stats and params easily enough. Are there any helpers that would map an existing variables struct into an assignable form?
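For concreteness, the flat dict looks roughly like this (layer names are illustrative), and splitting out the batch stats is just a filter on key names:

```python
import numpy as np

# Roughly the shape of what I have (layer names are illustrative):
weights = {
    'module.stem.conv.weight': np.zeros((3, 3, 3, 32), np.float32),
    'module.stem.bn.weight': np.ones((32,), np.float32),
    'module.stem.bn.running_mean': np.zeros((32,), np.float32),
    'module.stem.bn.running_var': np.ones((32,), np.float32),
}

# Separating the BatchNorm statistics from the trainable params
# is just a filter on the key names.
is_stat = lambda k: k.endswith(('running_mean', 'running_var'))
batch_stats = {k: v for k, v in weights.items() if is_stat(k)}
params = {k: v for k, v in weights.items() if not is_stat(k)}
```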
-
I think the easiest way to do what you describe is by using flax.traverse_util.unflatten_dict. You pass it the flattened dictionary structure, so you can just write a small conversion function. Do you think that works? We don't have a utility for traversing keys in creation order, and I'm also not sure that would be flexible enough: if one of our internal layers created, say, a bias before a kernel (or the other way around), then that approach would stop working, I guess.
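A minimal sketch of what I mean (untested; it assumes the flat keys use '.' as a separator, that batch_stats have already been split out, and that the dotted names already match the Flax module/param names):

```python
from flax import traverse_util
from flax.core import freeze

def to_variables(flat_params, flat_batch_stats):
    # unflatten_dict expects tuple keys, so split the dotted names first.
    def nest(flat):
        return traverse_util.unflatten_dict(
            {tuple(k.split('.')): v for k, v in flat.items()})
    # Assemble the two collections into a variables-style structure.
    return freeze({'params': nest(flat_params),
                   'batch_stats': nest(flat_batch_stats)})
```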
-
@jheek thanks, the flatten/unflatten utils were helpful. I wrote my own that essentially did the same thing, but there's no sense in keeping that... they do actually traverse in the correct order. My current hack is roughly the following (a sketch; `model` is the EfficientNet definition and `pretrained` is the flat numpy dict, which I'm assuming enumerates in matching creation order):
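```python
import jax
import jax.numpy as jnp
from flax import traverse_util
from flax.core import freeze, unfreeze

# Initialize to get the target structure, then overwrite the leaves
# with the pretrained values, relying on matching traversal order.
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 224, 224, 3)))
flat_target = traverse_util.flatten_dict(unfreeze(variables))
assert len(flat_target) == len(pretrained)
flat_loaded = {k: jnp.asarray(v)
               for k, v in zip(flat_target, pretrained.values())}
# (Real code also has to fix up layouts, e.g. transpose PyTorch OIHW
# conv weights to Flax HWIO, before assigning.)
variables = freeze(traverse_util.unflatten_dict(flat_loaded))
```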