Skip to content

Commit

Permalink
Merge pull request #2007 from jheek:fix-non-string-keys-in-statedict
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 436496587
  • Loading branch information
Flax Authors committed Mar 22, 2022
2 parents 08f4c53 + 6226a9c commit 049096b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
7 changes: 5 additions & 2 deletions flax/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,14 @@ def _restore_list(xs, state_dict: Dict[str, Any]) -> List[Any]:


def _dict_state_dict(xs: Dict[str, Any]) -> Dict[str, Any]:
return {key: to_state_dict(value) for key, value in xs.items()}
str_keys = set(str(k) for k in xs.keys())
if len(str_keys) != len(xs):
raise ValueError(f'Dict keys do not have a unique string representation: {str_keys}')
return {str(key): to_state_dict(value) for key, value in xs.items()}


def _restore_dict(xs, states: Dict[str, Any]) -> Dict[str, Any]:
return {key: from_state_dict(value, states[key])
return {key: from_state_dict(value, states[str(key)])
for key, value in xs.items()}


Expand Down
16 changes: 16 additions & 0 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,22 @@ def __call__(self, x):
{'lyrs1_a': {'kernel': (10, 3)},
'lyrs1_b': {'kernel': (3, 3)}})

def test_setup_dict_nonstring_keys(self):
class Foo(nn.Module):
def setup(self):
self.a = {(1, 2): nn.Dense(2)} # here the dict using tuple as key

@nn.compact
def __call__(self, x):
return self.a[(1, 2)](x)

foo = Foo()
x = jnp.ones(shape=(1, 3))
params = foo.init(random.PRNGKey(0), x)['params']
param_shape = jax.tree_map(jnp.shape, params)
self.assertEqual(param_shape,
{'a_(1, 2)': {'kernel': (3, 2), 'bias': (2,)}})

def test_setup_cloning(self):
class MLP(nn.Module):
def setup(self):
Expand Down

0 comments on commit 049096b

Please sign in to comment.