diff --git a/.gitignore b/.gitignore index 74ac6d8558..5648751009 100644 --- a/.gitignore +++ b/.gitignore @@ -12,4 +12,7 @@ build/ .pytype .vscode/* /.devcontainer -docs/**/tmp \ No newline at end of file +docs/**/tmp + +# used by direnv +.envrc \ No newline at end of file diff --git a/flax/core/frozen_dict.py b/flax/core/frozen_dict.py index 48707fa046..8c5c5646f1 100644 --- a/flax/core/frozen_dict.py +++ b/flax/core/frozen_dict.py @@ -15,7 +15,8 @@ """Frozen Dictionary.""" import collections -from typing import Any, TypeVar, Mapping, Dict, Tuple, Union, Hashable +from typing import Any, Dict, Hashable, Optional, Mapping, Tuple, TypeVar, Union +from types import MappingProxyType from flax import serialization import jax @@ -111,7 +112,9 @@ def __hash__(self): self._hash = h return self._hash - def copy(self, add_or_replace: Mapping[K, V]) -> 'FrozenDict[K, V]': + def copy( + self, add_or_replace: Mapping[K, V] = MappingProxyType({}) + ) -> 'FrozenDict[K, V]': """Create a new FrozenDict with additional or replaced entries.""" return type(self)({**self, **unfreeze(add_or_replace)}) # type: ignore[arg-type] @@ -223,7 +226,9 @@ def unfreeze(x: Union[FrozenDict, Dict[str, Any]]) -> Dict[Any, Any]: def copy( x: Union[FrozenDict, Dict[str, Any]], - add_or_replace: Union[FrozenDict, Dict[str, Any]], + add_or_replace: Union[FrozenDict[str, Any], Dict[str, Any]] = FrozenDict( + {} + ), ) -> Union[FrozenDict, Dict[str, Any]]: """Create a new dict with additional and/or replaced entries. This is a utility function that can act on either a FrozenDict or regular dict and mimics the diff --git a/tests/core/core_frozen_dict_test.py b/tests/core/core_frozen_dict_test.py index a44405ffb9..70f3554318 100644 --- a/tests/core/core_frozen_dict_test.py +++ b/tests/core/core_frozen_dict_test.py @@ -121,6 +121,18 @@ def test_utility_copy(self, x, add_or_replace, actual_new_x): new_x == actual_new_x and isinstance(new_x, type(actual_new_x)) ) + @parameterized.parameters( + { + 'x': {'a': 1, 'b': {'c': 2}}, + }, + { + 'x': FrozenDict({'a': 1, 'b': {'c': 2}}), + }, + ) + def test_utility_copy_singlearg(self, x): + new_x = copy(x) + self.assertTrue(new_x == x and isinstance(new_x, type(x))) + @parameterized.parameters( { 'x': {'a': 1, 'b': {'c': 2}},