From 963fa4aa177f9875a6391f5981452fb7ceb56469 Mon Sep 17 00:00:00 2001 From: Anselm Levskaya Date: Wed, 20 Jul 2022 22:50:06 -0700 Subject: [PATCH] Replace use of id() with global counter-based id. Historically we key'd on id() to record sharing relationships during lifting and outer module adoption. This was dumb, and after recently fixing one bad bug arising from id reuse, we should use something sound instead. --- flax/core/scope.py | 5 ++-- flax/ids.py | 57 +++++++++++++++++++++++++++++++++++++++ flax/linen/module.py | 14 +++++----- flax/linen/transforms.py | 31 ++++++++++----------- tests/linen/linen_test.py | 16 +++++++++++ 5 files changed, 99 insertions(+), 24 deletions(-) create mode 100644 flax/ids.py diff --git a/flax/core/scope.py b/flax/core/scope.py index 4d2b1194c..57599ac1a 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -22,7 +22,7 @@ from typing import (Any, Callable, Dict, Generic, Iterable, Mapping, Optional, Sequence, Set, Tuple, TypeVar, Union) -from . import tracers +from flax.ids import uuid from flax import config from flax import errors from flax import struct @@ -30,6 +30,7 @@ from .frozen_dict import freeze from .frozen_dict import FrozenDict from .frozen_dict import unfreeze +from . import tracers import jax from jax import config as jax_config from jax import numpy as jnp @@ -51,7 +52,6 @@ # When conditioning on filters we require explicit boolean comparisons. # pylint: disable=g-bool-id-comparison - @dataclasses.dataclass(frozen=True, eq=True) class DenyList: """DenyList represents an opt-out based mutability filter. @@ -343,6 +343,7 @@ def __init__(self, scope: 'Scope', collection: str, name: str): collection: The collection of the variable (e.g., "params"). name: The name of the variable (e.g., "dense"). """ + self._id = uuid() self.scope = scope self.collection = collection self.name = name diff --git a/flax/ids.py b/flax/ids.py new file mode 100644 index 000000000..973a09b55 --- /dev/null +++ b/flax/ids.py @@ -0,0 +1,57 @@ +# Copyright 2022 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""UUIDs for Flax internals.""" + +import threading + + +class UUIDManager: + """Globally unique counter-based id manager. + + We need globally unique key ids for Module and Variable object instances + to preserve and recreate sharing-by-reference relationship when lifting + transforms and adopting outside Modules. + - Use of id() is unacceptable because these identifiers are literally + pointers which can be recycled, so we rely on a globally unique counter id + instead. + - We need to handle copy/deepcopy uniqueness via a wrapped type. + """ + def __init__(self): + self._lock = threading.Lock() + self._id = 0 + + def __call__(self): + with self._lock: + self._id += 1 + return FlaxId(self._id) + +uuid = UUIDManager() + + +class FlaxId: + """Hashable wrapper for ids that handles uniqueness of copies.""" + def __init__(self, rawid): + self.id = rawid + def __eq__(self, other): + return isinstance(other, FlaxId) and other.id == self.id + def __hash__(self): + return hash(self.id) + def __repr__(self): + return f"FlaxId({self.id})" + def __deepcopy__(self, memo): + del memo + return uuid() + def __copy__(self): + return uuid() diff --git a/flax/linen/module.py b/flax/linen/module.py index 3f2561e7b..ca55c16b7 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -36,6 +36,7 @@ from flax.core.scope import ( # pylint: disable=g-multiple-import CollectionFilter, DenyList, FrozenVariableDict, Variable, VariableDict, union_filters) +from flax.ids import uuid from flax.linen import summary @@ -730,6 +731,7 @@ def __post_init__(self) -> None: # initialization, attach this Module as a submodule of a parent, or bind # this Module at the top-level to variables and rngs. + object.__setattr__(self, '_id', uuid()) object.__setattr__(self, '_state', _ModuleInternalState()) # Typically we set the parent based on the dynamic module context. @@ -827,14 +829,12 @@ def adopt_attr_modules(cache, queue, suffix, subvalue): # Module was passed from outside. It needs to be cloned. # Outside modules are named by attachment, not an outer name. object.__setattr__(subvalue, 'name', None) - key = id(subvalue) + # Preserve sharing-by-reference relationships during adoption + # via cache keyed on unique instance ids. + key = subvalue._id if key not in cache: - # since we use id() as key, we need to keep a reference to original - # subvalue to ensure it's lifetime is long enough for the entire - # model setup and the id() is not recycled. - # TODO(levskaya): consider switching to per-module UUIDs - cache[key] = (subvalue.clone(), subvalue) - subvalue = cache[key][0] + cache[key] = subvalue.clone() + subvalue = cache[key] if subvalue.name is None: object.__setattr__(subvalue, 'parent', self) object.__setattr__(subvalue, 'name', f'{name}{suffix}') diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py index 434187b4c..8dd9dd4b4 100644 --- a/flax/linen/transforms.py +++ b/flax/linen/transforms.py @@ -43,6 +43,7 @@ traceback_util.register_exclusion(__file__) +# pylint: disable=protected-access # Utils # ----------------------------------------------------------------------------- @@ -79,12 +80,12 @@ def wrapped_fn(x): nonlocal refs if isinstance(x, (VariablePlaceholder, InstancePlaceholder)): x_id = x.id + elif isinstance(x, (Variable, Module)): + x_id = x._id else: - x_id = id(x) + return fn(x) if x_id not in refs: refs[x_id] = fn(x) - else: - pass return refs[x_id] return wrapped_fn @@ -124,9 +125,9 @@ def get_arg_scope(x): nonlocal scopes if isinstance(x, Variable) and isinstance(x.scope, Scope): scopes.append(x.scope) - return VariablePlaceholder(x.collection, x.name, id(x)) + return VariablePlaceholder(x.collection, x.name, x._id) elif isinstance(x, Module) and isinstance(x.scope, Scope): - x._try_setup(shallow=True) # pylint: disable=protected-access + x._try_setup(shallow=True) scopes.append(x.scope) attrs = { f.name: getattr(x, f.name) @@ -134,7 +135,7 @@ def get_arg_scope(x): if f.name != 'parent' and f.init } attrs = jax.tree_util.tree_map(get_arg_scope, attrs) - return InstancePlaceholder(x.__class__, attrs, id(x)) + return InstancePlaceholder(x.__class__, attrs, x._id) return x new_args, new_kwargs = jax.tree_util.tree_map(get_arg_scope, (args, kwargs)) @@ -142,7 +143,7 @@ def get_arg_scope(x): @functools.partial(_memoize_by_id, refs=refs) def get_scopes(module): nonlocal scopes - module._try_setup(shallow=True) # pylint: disable=protected-access + module._try_setup(shallow=True) def get_scopes_inner(x): nonlocal scopes if isinstance(x, Module) and isinstance(x.scope, Scope): @@ -303,9 +304,9 @@ def core_fn(scopes, *args, **kwargs): # we reference module_class, not self.__class__ to avoid infinite loop cloned = module_class(parent=None, **attrs) cloned, args, kwargs = set_module_scopes(cloned, args, kwargs, scopes) - object.__setattr__(cloned, '_state', state.export()) # pylint: disable=protected-access + object.__setattr__(cloned, '_state', state.export()) res = fn(cloned, *args, **kwargs) - self._state.reimport(cloned._state) # pylint: disable=protected-access + self._state.reimport(cloned._state) _test_transformed_return_values(res, fn_name) return res # here we apply the given lifting transform to the scope-ingesting fn @@ -351,9 +352,9 @@ def core_fn(prewrapped_fn, class_fn, scopes, *args, **kwargs): if not multi_scope: scopes = [scopes] cloned, args, kwargs = set_module_scopes(self, args, kwargs, scopes) - object.__setattr__(cloned, '_state', state.export()) # pylint: disable=protected-access + object.__setattr__(cloned, '_state', state.export()) res = prewrapped_fn(cloned, *args, **kwargs) - self._state.reimport(cloned._state) # pylint: disable=protected-access + self._state.reimport(cloned._state) _test_transformed_return_values(res, getattr(class_fn, '__name__', None)) return res core_fns = [functools.partial(core_fn, prewrapped_fn, class_fn) @@ -1325,8 +1326,8 @@ def wrapped_fn(self, *args, **kwargs): prewrapped_fn = wrap_method_once(class_fn) @functools.wraps(prewrapped_fn) def wrapped_fn(self, *args, **kwargs): - if ((not force and not linen_module._use_named_call) # pylint: disable=protected-access - or self._state.in_setup): # pylint: disable=protected-access + if ((not force and not linen_module._use_named_call) + or self._state.in_setup): return prewrapped_fn(self, *args, **kwargs) fn_name = class_fn.__name__ method_suffix = f'.{fn_name}' if fn_name != '__call__' else '' @@ -1335,9 +1336,9 @@ def wrapped_fn(self, *args, **kwargs): # make a scope-function to transform def core_fn(scopes, *args, **kwargs): cloned, args, kwargs = set_module_scopes(self, args, kwargs, scopes) - object.__setattr__(cloned, '_state', self._state.export()) # pylint: disable=protected-access + object.__setattr__(cloned, '_state', self._state.export()) res = prewrapped_fn(cloned, *args, **kwargs) - self._state.reimport(cloned._state) # pylint: disable=protected-access + self._state.reimport(cloned._state) _test_transformed_return_values(res, fn_name) return res # here we apply the given lifting transform to the scope-ingesting fn diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index 1db346cce..f740b4b2f 100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -14,8 +14,10 @@ """Tests for flax.linen.""" +import copy from absl.testing import absltest, parameterized +from flax import ids from flax import linen as nn import jax @@ -361,5 +363,19 @@ def test_optimized_lstm_cell_matches_regular(self): jtu.check_eq(lstm_params, lstm_opt_params) +class IdsTest(absltest.TestCase): + + def test_hashable(self): + id1 = ids.uuid() + id2 = ids.uuid() + self.assertEqual(id1, id1) + self.assertNotEqual(id1, id2) + self.assertNotEqual(hash(id1), hash(id2)) + id1c = copy.copy(id1) + id1dc = copy.deepcopy(id1) + self.assertNotEqual(hash(id1), hash(id1c)) + self.assertNotEqual(hash(id1), hash(id1dc)) + + if __name__ == '__main__': absltest.main()