diff --git a/flax/serialization.py b/flax/serialization.py
index 99da0862..d25462cd 100644
--- a/flax/serialization.py
+++ b/flax/serialization.py
@@ -117,11 +117,14 @@ def _restore_list(xs, state_dict: Dict[str, Any]) -> List[Any]:
 
 
 def _dict_state_dict(xs: Dict[str, Any]) -> Dict[str, Any]:
-  return {key: to_state_dict(value) for key, value in xs.items()}
+  str_keys = set(str(k) for k in xs.keys())
+  if len(str_keys) != len(xs):
+    raise ValueError(f'Dict keys do not have a unique string representation: {str_keys}')
+  return {str(key): to_state_dict(value) for key, value in xs.items()}
 
 
 def _restore_dict(xs, states: Dict[str, Any]) -> Dict[str, Any]:
-  return {key: from_state_dict(value, states[key])
+  return {key: from_state_dict(value, states[str(key)])
           for key, value in xs.items()}
 
 
diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py
index 8f5cc738..c1117a47 100644
--- a/tests/linen/linen_module_test.py
+++ b/tests/linen/linen_module_test.py
@@ -159,6 +159,22 @@ def __call__(self, x):
       {'lyrs1_a': {'kernel': (10, 3)},
        'lyrs1_b': {'kernel': (3, 3)}})
 
+  def test_setup_dict_nonstring_keys(self):
+    class Foo(nn.Module):
+      def setup(self):
+        self.a = {(1, 2): nn.Dense(2)}  # tuple key: exercises non-string dict keys
+
+      @nn.compact
+      def __call__(self, x):
+        return self.a[(1, 2)](x)
+
+    foo = Foo()
+    x = jnp.ones(shape=(1, 3))
+    params = foo.init(random.PRNGKey(0), x)['params']
+    param_shape = jax.tree_map(jnp.shape, params)
+    self.assertEqual(param_shape,
+      {'a_(1, 2)': {'kernel': (3, 2), 'bias': (2,)}})
+
   def test_setup_cloning(self):
     class MLP(nn.Module):
       def setup(self):