From 940ff5d9557a9f8826f10b499b273735a60353fa Mon Sep 17 00:00:00 2001 From: Filippo Vicentini Date: Mon, 31 Jul 2023 23:06:26 +0200 Subject: [PATCH 1/4] make flax.core.copy `add_or_replace` optional ignore .envrc (direnv files) format add test fix --- .gitignore | 5 ++++- flax/core/frozen_dict.py | 8 ++++++-- tests/core/core_frozen_dict_test.py | 12 ++++++++++++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 74ac6d8558..5648751009 100644 --- a/.gitignore +++ b/.gitignore @@ -12,4 +12,7 @@ build/ .pytype .vscode/* /.devcontainer -docs/**/tmp \ No newline at end of file +docs/**/tmp + +# used by direnv +.envrc \ No newline at end of file diff --git a/flax/core/frozen_dict.py b/flax/core/frozen_dict.py index 48707fa046..3fcdba6d05 100644 --- a/flax/core/frozen_dict.py +++ b/flax/core/frozen_dict.py @@ -15,7 +15,7 @@ """Frozen Dictionary.""" import collections -from typing import Any, TypeVar, Mapping, Dict, Tuple, Union, Hashable +from typing import Any, Dict, Hashable, Optional, Mapping, Tuple, TypeVar, Union from flax import serialization import jax @@ -111,7 +111,11 @@ def __hash__(self): self._hash = h return self._hash - def copy(self, add_or_replace: Mapping[K, V]) -> 'FrozenDict[K, V]': + def copy( + self, add_or_replace: Optional[Mapping[K, V]] = None + ) -> 'FrozenDict[K, V]': + if add_or_replace is None: + add_or_replace = {} """Create a new FrozenDict with additional or replaced entries.""" return type(self)({**self, **unfreeze(add_or_replace)}) # type: ignore[arg-type] diff --git a/tests/core/core_frozen_dict_test.py b/tests/core/core_frozen_dict_test.py index a44405ffb9..70f3554318 100644 --- a/tests/core/core_frozen_dict_test.py +++ b/tests/core/core_frozen_dict_test.py @@ -121,6 +121,18 @@ def test_utility_copy(self, x, add_or_replace, actual_new_x): new_x == actual_new_x and isinstance(new_x, type(actual_new_x)) ) + @parameterized.parameters( + { + 'x': {'a': 1, 'b': {'c': 2}}, + }, + { + 'x': FrozenDict({'a': 1, 'b': {'c': 2}}), + }, + ) + def test_utility_copy_singlearg(self, x): + new_x = copy(x) + self.assertTrue(new_x == x and isinstance(new_x, type(x))) + @parameterized.parameters( { 'x': {'a': 1, 'b': {'c': 2}}, From 62efe1271aab98b3a6f71258054436f3a9d41135 Mon Sep 17 00:00:00 2001 From: Filippo Vicentini Date: Mon, 31 Jul 2023 23:17:06 +0200 Subject: [PATCH 2/4] fix --- flax/core/frozen_dict.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/flax/core/frozen_dict.py b/flax/core/frozen_dict.py index 3fcdba6d05..3b34c6d80f 100644 --- a/flax/core/frozen_dict.py +++ b/flax/core/frozen_dict.py @@ -227,7 +227,7 @@ def unfreeze(x: Union[FrozenDict, Dict[str, Any]]) -> Dict[Any, Any]: def copy( x: Union[FrozenDict, Dict[str, Any]], - add_or_replace: Union[FrozenDict, Dict[str, Any]], + add_or_replace: Optional[Union[FrozenDict, Dict[str, Any]]] = None, ) -> Union[FrozenDict, Dict[str, Any]]: """Create a new dict with additional and/or replaced entries. This is a utility function that can act on either a FrozenDict or regular dict and mimics the @@ -248,7 +248,8 @@ def copy( return x.copy(add_or_replace) elif isinstance(x, dict): new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x - new_dict.update(add_or_replace) + if add_or_replace is not None: + new_dict.update(add_or_replace) return new_dict raise TypeError(f'Expected FrozenDict or dict, got {type(x)}') From a30f09f15aa0f7e6bc383e0d78a969a4d01e7aea Mon Sep 17 00:00:00 2001 From: Filippo Vicentini Date: Tue, 1 Aug 2023 10:15:30 +0200 Subject: [PATCH 3/4] Apply suggestions from code review Co-authored-by: Cristian Garcia --- flax/core/frozen_dict.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/flax/core/frozen_dict.py b/flax/core/frozen_dict.py index 3b34c6d80f..006efbc8ed 100644 --- a/flax/core/frozen_dict.py +++ b/flax/core/frozen_dict.py @@ -112,10 +112,8 @@ def __hash__(self): return self._hash def copy( - self, add_or_replace: Optional[Mapping[K, V]] = None + self, add_or_replace: Mapping[K, V] = MappingProxyType({}) ) -> 'FrozenDict[K, V]': - if add_or_replace is None: - add_or_replace = {} """Create a new FrozenDict with additional or replaced entries.""" return type(self)({**self, **unfreeze(add_or_replace)}) # type: ignore[arg-type] @@ -227,7 +225,7 @@ def unfreeze(x: Union[FrozenDict, Dict[str, Any]]) -> Dict[Any, Any]: def copy( x: Union[FrozenDict, Dict[str, Any]], - add_or_replace: Optional[Union[FrozenDict, Dict[str, Any]]] = None, + add_or_replace: Union[FrozenDict[str, Any], Dict[str, Any]] = FrozenDict({}), ) -> Union[FrozenDict, Dict[str, Any]]: """Create a new dict with additional and/or replaced entries. This is a utility function that can act on either a FrozenDict or regular dict and mimics the From 81c274af780f52dcdecbfcb06e436caf3e2ce8ea Mon Sep 17 00:00:00 2001 From: Filippo Vicentini Date: Tue, 1 Aug 2023 10:18:50 +0200 Subject: [PATCH 4/4] fixes pyink --- flax/core/frozen_dict.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/flax/core/frozen_dict.py b/flax/core/frozen_dict.py index 006efbc8ed..8c5c5646f1 100644 --- a/flax/core/frozen_dict.py +++ b/flax/core/frozen_dict.py @@ -16,6 +16,7 @@ import collections from typing import Any, Dict, Hashable, Optional, Mapping, Tuple, TypeVar, Union +from types import MappingProxyType from flax import serialization import jax @@ -225,7 +226,9 @@ def unfreeze(x: Union[FrozenDict, Dict[str, Any]]) -> Dict[Any, Any]: def copy( x: Union[FrozenDict, Dict[str, Any]], - add_or_replace: Union[FrozenDict[str, Any], Dict[str, Any]] = FrozenDict({}), + add_or_replace: Union[FrozenDict[str, Any], Dict[str, Any]] = FrozenDict( + {} + ), ) -> Union[FrozenDict, Dict[str, Any]]: """Create a new dict with additional and/or replaced entries. This is a utility function that can act on either a FrozenDict or regular dict and mimics the @@ -246,8 +249,7 @@ def copy( return x.copy(add_or_replace) elif isinstance(x, dict): new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x - if add_or_replace is not None: - new_dict.update(add_or_replace) + new_dict.update(add_or_replace) return new_dict raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')