Skip to content

Commit

Permalink
Merge pull request #2020 from jheek:lifted-cond
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 438232016
  • Loading branch information
Flax Authors committed Mar 30, 2022
2 parents 1cb6935 + d244887 commit cd5c4d7
Show file tree
Hide file tree
Showing 10 changed files with 194 additions and 10 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ vNext
-
-
-
-
- Added lifted conditional `nn.cond`
-
-
-
Expand Down
1 change: 1 addition & 0 deletions docs/flax.linen.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ Transformations
vjp
custom_vjp
while_loop
cond


Linear modules
Expand Down
70 changes: 67 additions & 3 deletions flax/core/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,8 +827,7 @@ def body_fn(scope, c):
rng_groups, rng_splits = _unzip2(split_rngs.items())

def inner(scope_fn, repack_fn,
variable_groups, rng_groups, carry_init):
del carry_init # unused.
variable_groups, rng_groups):
carry_variables, broadcast_variables = variable_groups

def make_loop_rngs(i):
Expand Down Expand Up @@ -865,7 +864,72 @@ def body_wrapper(c):
(carry_variables, broadcast_variables),
(carry_variables,),
rng_groups,
name='while_loop')(scope, init)
name='while_loop')(scope)


def cond(pred: Any,
true_fun: Callable[..., C], false_fun: Callable[..., C],
scope: Scope, *operands,
variables: CollectionFilter = True,
rngs: PRNGSequenceFilter = True) -> C:
"""Lifted version of ``jax.lax.cond``.
The returned values from ``true_fun`` and ``false_fun``
must have the same Pytree structure, shapes, and dtypes.
The variables created or updated inside the
branches must also have the same structure.
Note that this constraint is violated when
creating variables or submodules in only one branch.
Because initializing variables in just one branch
causes the paramater structure to be different.
Example::
def cond_example(scope, x, pred):
scope.variable('state', 'true_count', lambda: 0)
scope.variable('state', 'false_count', lambda: 0)
def true_fn(scope, x):
scope.variable('state', 'true_count').value += 1
return scope.child(nn.dense)(x, 2)
def false_fn(scope, x):
scope.variable('state', 'false_count').value += 1
return -scope.child(nn.dense)(x, 2)
return lift.cond(pred, true_fn, false_fn, scope, x)
Args:
pred: determines if true_fun or false_fun is evaluated.
true_fun: The function evalauted when ``pred`` is `True`.
The signature is (Scope, *operands) -> T.
false_fun: The function evalauted when ``pred`` is `False`.
The signature is (Scope, *operands) -> T.
scope: A Scope or Pytree of scopes to pass
*operands: The arguments passed to ``true_fun`` and ``false_fun``
variables: The variable collections passed to the conditional
branches (default: all)
rngs: The PRNG sequences passed to the conditionals (default: all)
Returns:
The result of the evaluated branch (``true_fun`` or ``false_fun``).
"""
branches = [true_fun, false_fun]
def inner(scope_fn, repack_fn,
variable_groups, rng_groups):
def branch_wrapper(branch_fn, *operands):
scope = scope_fn(variable_groups, rng_groups)
y = branch_fn(scope, *operands)
return y, repack_fn(scope)
pure_branches = [
functools.partial(branch_wrapper, branch_fn)
for branch_fn in branches]
return jax.lax.cond(
pred, pure_branches[0], pure_branches[1], *operands)

return pack(
inner,
(variables,),
(variables,),
(rngs,),
name='cond')(scope)


def custom_vjp(fn: Callable[..., Any],
Expand Down
8 changes: 5 additions & 3 deletions flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,23 +672,25 @@ def put_variable(self, col: str, name: str, value: Any):
variables = self._mutable_collection(col)
variables[name] = value

def variable(self, col: str, name: str, init_fn: Callable[..., T],
def variable(self, col: str, name: str,
init_fn: Optional[Callable[..., T]] = None,
*init_args) -> Variable[T]:
"""Creates a variable if it doesn't exist yet in this scope and returns it.
Args:
col: the collection of the variable.
name: the name of the variable.
init_fn: a function taking a PRNGKey plus any other number of positional
arguments.
arguments. If None, the variable must already be initialized otherwise
an error is raised.
*init_args: the arguments to evaluate init_fn on lazily.
Returns:
The variable.
"""
self.reserve(name)
if not self.has_variable(col, name):
if not self.is_mutable_collection(col):
if not self.is_mutable_collection(col) or init_fn is None:
if self.is_collection_empty(col):
raise errors.ScopeCollectionNotFound(col, name, self.path_text)
raise errors.ScopeVariableNotFoundError(name, col, self.path_text)
Expand Down
2 changes: 1 addition & 1 deletion flax/linen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@
from .stochastic import Dropout
from .transforms import (checkpoint, custom_vjp, jit, jvp, map_variables,
named_call, remat, remat_scan, scan, vjp, vmap,
while_loop)
while_loop, cond)

# pylint: enable=g-multiple-import
7 changes: 5 additions & 2 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,9 @@ def clone(self, *,
attrs.update(parent=parent, **updates)
return self.__class__(**attrs)

def variable(self, col: str, name: str, init_fn, *init_args) -> Variable:
def variable(self, col: str, name: str,
init_fn: Optional[Callable[..., Any]] = None,
*init_args) -> Variable:
"""Declares and returns a variable in this Module.
See :mod:`flax.core.variables` for more information. See also :meth:`param`
Expand All @@ -928,7 +930,8 @@ def variable(self, col: str, name: str, init_fn, *init_args) -> Variable:
name: The variable name.
init_fn: The function that will be called to compute the initial value
of this variable. This function will only be called the first time
this variable is used in this module.
this variable is used in this module. If None, the variable must
already be initialized otherwise an error is raised.
*init_args: The arguments to pass to init_fn.
Returns:
Expand Down
57 changes: 57 additions & 0 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,6 +1057,63 @@ def body_fn(mdl, c):
split_rngs)


def _cond_wrapper(t_fn, f_fn, scope, pred, *ops, variables, rngs):
return lift.cond(pred, t_fn, f_fn, scope, *ops, variables=variables, rngs=rngs)


def cond(
pred: Any,
true_fun: Callable[..., C], false_fun: Callable[..., C],
mdl: Module, *operands,
variables: lift.CollectionFilter = True,
rngs: lift.PRNGSequenceFilter = True) -> C:
"""Lifted version of ``jax.lax.cond``.
The returned values from ``true_fun`` and ``false_fun``
must have the same Pytree structure, shapes, and dtypes.
The variables created or updated inside the
branches must also have the same structure.
Note that this constraint is violated when
creating variables or submodules in only one branch.
Because initializing variables in just one branch
causes the paramater structure to be different.
Example::
class CondExample(nn.Module):
@nn.compact
def __call__(self, x, pred):
self.variable('state', 'true_count', lambda: 0)
self.variable('state', 'false_count', lambda: 0)
def true_fn(mdl, x):
mdl.variable('state', 'true_count').value += 1
return nn.Dense(2, name='dense')(x)
def false_fn(mdl, x):
mdl.variable('state', 'false_count').value += 1
return -nn.Dense(2, name='dense')(x)
return nn.cond(pred, true_fn, false_fn, self, x)
Args:
pred: determines if true_fun or false_fun is evaluated.
true_fun: The function evalauted when ``pred`` is `True`.
The signature is (Scope, *operands) -> T.
false_fun: The function evalauted when ``pred`` is `False`.
The signature is (Scope, *operands) -> T.
scope: A Scope or Pytree of scopes to pass
*operands: The arguments passed to ``true_fun`` and ``false_fun``
variables: The variable collections passed to the conditional
branches (default: all)
rngs: The PRNG sequences passed to the conditionals (default: all)
Returns:
The result of the evaluated branch (``true_fun`` or ``false_fun``).
"""
return lift_direct_transform(
_cond_wrapper, (true_fun, false_fun), mdl,
pred, *operands,
variables=variables, rngs=rngs)


# a version of lift.custom_vjp with a single scope function
# this avoids having to lift multiple functions in
# lift_transform.
Expand Down
23 changes: 23 additions & 0 deletions tests/core/core_lift_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,29 @@ def body_fn(scope, c):
self.assertEqual(c, 2 * x)
np.testing.assert_array_equal(vars['state']['rng_params'][0], vars['state']['rng_params'][1])
np.testing.assert_array_compare(operator.__ne__, vars['state']['rng_loop'][0], vars['state']['rng_loop'][1])

def test_cond(self):
def f(scope, x, pred):
scope.variable('state', 'true_count', lambda: 0)
scope.variable('state', 'false_count', lambda: 0)
def true_fn(scope, x):
scope.variable('state', 'true_count').value += 1
return scope.child(nn.dense)(x, 2)

def false_fn(scope, x):
scope.variable('state', 'false_count').value += 1
return -scope.child(nn.dense)(x, 2)

return lift.cond(pred, true_fn, false_fn, scope, x)

x = jnp.ones((1, 3))
y1, vars = init(f)(random.PRNGKey(0), x, True)
self.assertEqual(vars['state'].unfreeze(), {'true_count': 1, 'false_count': 0})
y2, vars = apply(f, mutable="state")(vars, x, False)
self.assertEqual(vars['state'].unfreeze(), {'true_count': 1, 'false_count': 1})
np.testing.assert_allclose(y1, -y2)



if __name__ == '__main__':
absltest.main()
10 changes: 10 additions & 0 deletions tests/core/core_scope_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,16 @@ def test_empty_col_error(self):
root = Scope({'state': {}})
with self.assertRaises(errors.ScopeCollectionNotFound):
root.variable('state', 'test', jnp.zeros, ())

def test_variable_no_init(self):
root = Scope({}, mutable='state')
with self.assertRaises(errors.ScopeCollectionNotFound):
root.variable('state', 'test')
root = Scope({'state': {'abc': 1}}, mutable='state')
abc = root.variable('state', 'abc')
self.assertEqual(abc.value, 1)
with self.assertRaises(errors.ScopeVariableNotFoundError):
root.variable('state', 'test')


if __name__ == '__main__':
Expand Down
24 changes: 24 additions & 0 deletions tests/linen/linen_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,6 +1402,30 @@ def body_fn(mdl, c):
np.testing.assert_array_equal(vars['state']['rng_params'][0], vars['state']['rng_params'][1])
np.testing.assert_array_compare(operator.__ne__, vars['state']['rng_loop'][0], vars['state']['rng_loop'][1])

def test_cond(self):
class Foo(nn.Module):
@nn.compact
def __call__(self, x, pred):
self.variable('state', 'true_count', lambda: 0)
self.variable('state', 'false_count', lambda: 0)
def true_fn(mdl, x):
mdl.variable('state', 'true_count').value += 1
return nn.Dense(2, name='dense')(x)

def false_fn(mdl, x):
mdl.variable('state', 'false_count').value += 1
return -nn.Dense(2, name='dense')(x)

return nn.cond(pred, true_fn, false_fn, self, x)

x = jnp.ones((1, 3))
foo = Foo()
y1, vars = foo.init_with_output(random.PRNGKey(0), x, True)
self.assertEqual(vars['state'].unfreeze(), {'true_count': 1, 'false_count': 0})
y2, vars = foo.apply(vars, x, False, mutable="state")
self.assertEqual(vars['state'].unfreeze(), {'true_count': 1, 'false_count': 1})
np.testing.assert_allclose(y1, -y2)

def test_lift_instance_error(self):
class Foo(nn.Module):
@nn.compact
Expand Down

0 comments on commit cd5c4d7

Please sign in to comment.