
Loading pretrained weights (Flax Linen) #544

Answered by jheek
rwightman asked this question in Q&A

I think the easiest way to do what you describe is to use flax.traverse_util.unflatten_dict. You pass it the flattened dictionary, whose keys are tuples of path segments, so you just write a function like this:

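import flax.traverse_util

flat_params = {}
for key, val in mynpweights.items():
  # Split a dotted name such as "layer1.conv.weight" into its path segments
  segments = key.split('.')
  # some logic on the segments (e.g. renaming them to match the Flax module)
  flat_params[tuple(segments)] = val  # keys must be hashable tuples, not lists
# Turn {('a', 'b'): x} into {'a': {'b': x}}
params = flax.traverse_util.unflatten_dict(flat_params)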

Do you think that works?
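To make the tuple-key convention concrete, here is a small stand-alone sketch. It assumes PyTorch-style dotted names such as encoder.dense.weight (hypothetical example keys with dummy values) and, so that it runs without Flax installed, uses a hypothetical unflatten helper written to mirror how flax.traverse_util.unflatten_dict nests tuple-keyed entries. In real code you would call the Flax function instead.

```python
from typing import Any, Dict, Tuple


def unflatten(flat: Dict[Tuple[str, ...], Any]) -> Dict[str, Any]:
    """Hypothetical stand-in mirroring flax.traverse_util.unflatten_dict:
    each tuple key is treated as a path of nested dict keys."""
    nested: Dict[str, Any] = {}
    for path, value in flat.items():
        cursor = nested
        for segment in path[:-1]:
            cursor = cursor.setdefault(segment, {})  # create inner dicts as needed
        cursor[path[-1]] = value
    return nested


# Hypothetical flat weights keyed by dotted names (dummy values).
mynpweights = {
    "encoder.dense.weight": [[1.0, 2.0]],
    "encoder.dense.bias": [0.5],
}

flat_params = {}
for key, val in mynpweights.items():
    segments = key.split(".")
    # Example renaming logic (an assumption): Flax Dense layers name their weight "kernel".
    if segments[-1] == "weight":
        segments[-1] = "kernel"
    flat_params[tuple(segments)] = val  # dict keys must be hashable, so use a tuple

params = unflatten(flat_params)
print(params)
# {'encoder': {'dense': {'kernel': [[1.0, 2.0]], 'bias': [0.5]}}}
```

For real PyTorch nn.Linear weights you would also transpose the array, since PyTorch stores them as (out_features, in_features) while a Flax Dense kernel is (in_features, features).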

We don't have a utility for traversing keys in creation order. I'm also not sure that would be flexible enough. What if our internal layers were to create, say, a bias before a kernel, or the other way around? Then this would stop working, I guess.
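The creation-order concern can be illustrated with a small sketch (hypothetical parameter paths and dummy values). Pairing source tensors with target parameters by position breaks as soon as the two sides list bias and kernel in different orders, while matching by name does not depend on order at all.

```python
from typing import Dict, List, Tuple

Path = Tuple[str, ...]

# Hypothetical target parameter paths, in the order a layer happens to create them.
target_paths: List[Path] = [("dense", "bias"), ("dense", "kernel")]

# Hypothetical source weights, listed in a different order (kernel first).
source: Dict[Path, str] = {
    ("dense", "kernel"): "kernel-values",
    ("dense", "bias"): "bias-values",
}


def match_by_position(paths: List[Path], src: Dict[Path, str]) -> Dict[Path, str]:
    # Pairs the i-th target path with the i-th source entry: order-sensitive.
    return dict(zip(paths, src.values()))


def match_by_name(paths: List[Path], src: Dict[Path, str]) -> Dict[Path, str]:
    # Looks each target path up by name: independent of creation order.
    return {path: src[path] for path in paths}


print(match_by_position(target_paths, source))
# {('dense', 'bias'): 'kernel-values', ('dense', 'kernel'): 'bias-values'}
print(match_by_name(target_paths, source))
# {('dense', 'bias'): 'bias-values', ('dense', 'kernel'): 'kernel-values'}
```

The name-based version is what the unflatten_dict approach amounts to: each value lands at the path its key spells out, regardless of the order in which layers create their parameters.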

Answer selected by rwightman