Wrappers/Decorators for flax modules #2065
-
Hi there, I've been trying to implement a decorator for functions that should only act on the 'maps' entry of a dict input:
import functools
import flax.linen as nn
def input_separate(func):
    """
    Decorator to have functions ignore dict keys other than 'maps'.

    :param func: callable applied to the 'maps' entry; it is called as
        ``func(maps, *args, **kwargs)``.
    :returns: a function f that only acts on x['maps'] when called on x and
        passes every other key through unchanged, i.e.
        f({'a': a, 'b': b, 'maps': maps}, *args, **kwargs)
        == {'a': a, 'b': b, 'maps': func(maps, *args, **kwargs)}
    :raises KeyError: if the input dict has no 'maps' key.
    """
    @functools.wraps(func)
    def non_hp_decorator(inputs, *args, **kwargs):
        # Build a new dict rather than listing keys by hand: every non-'maps'
        # key survives (as the docstring promises), and the caller's dict is
        # not mutated.
        return {**inputs, 'maps': func(inputs['maps'], *args, **kwargs)}
    return non_hp_decorator
nhprelu = input_separate(nn.relu)
# ^ works as expected, but I cannot quite figure out how to construct a
# similar decorator for modules directly. I naively tried:
def input_separate_mod(module):
    """
    Decorator to have linen modules ignore inputs['nside'] and inputs['indices']
    :param module: a linen module with a __call__() method.
    returns a function that will only act on x['maps'] when called on x.
    i.e module(*args, **kwargs)({'a': a, 'b': b, 'maps': maps})
    = {'a': a, 'b': b, 'maps': module(*args, **kwargs)(maps)}
    """
    # NOTE(review): despite the 'a'/'b' example above, the output below is
    # always built from exactly 'nside', 'indices' and 'maps'; any other key
    # is dropped, and a missing 'nside'/'indices' raises KeyError.
    @functools.wraps(module)
    def non_hp_decorator(inputs, *args, **kwargs):
        x = inputs['maps']
        mod = module(*args, **kwargs)
        # NOTE(review): this is the failing step. init/apply are run by hand
        # with a fixed PRNGKey(0), and the resulting `variables` live only in
        # this local name: they are used for apply and then discarded, never
        # handed back to the enclosing @nn.compact module. The enclosing
        # module's init therefore ends up with no variables (FrozenDict({})).
        # Inside a compact module, call the layer directly instead:
        # module(*args, **kwargs)(x).
        # `random` is not imported in this snippet; presumably jax.random.
        variables = mod.init(random.PRNGKey(0), x) ###<--- I know this is not the way to go.
        x = mod.apply(variables, x)
        output = {'nside': inputs['nside'], 'indices': inputs['indices'], 'maps': x}
        return output
    return non_hp_decorator
class test(nn.Module):
    """Small model that routes a dict input through input_separate_mod-wrapped layers."""

    @nn.compact
    def __call__(self, inputs):
        # Wrap the layer classes once; every call below constructs a fresh
        # layer instance through the wrapper.
        dense_on_maps = input_separate_mod(nn.Dense)
        batch_norm_on_maps = input_separate_mod(nn.BatchNorm)

        hidden = dense_on_maps(inputs, 5)
        hidden = batch_norm_on_maps(hidden, use_running_average=True)
        return dense_on_maps(hidden, 10)
# The snippet used `jnp` and `random` without importing them (NameError).
import jax.numpy as jnp
from jax import random

# Example batch: 'maps' holds the array the layers act on; 'nside' and
# 'indices' are plain values the wrappers are meant to carry along unchanged.
inputs = {'nside': 2, 'indices': 3, 'maps': 2 * jnp.ones((2, 4, 8))}
mod = test()
variables = mod.init(random.PRNGKey(0), inputs)
print(variables)
Here `mod.init` returns `FrozenDict({})`, i.e. the submodules are not initialized and do not act properly. Is there some built-in method/transformation in Flax for this? I would appreciate it a lot if anybody could point me in the right direction.
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 2 replies
-
# The problem is that you are explicitly using init/apply inside a compact
# module. Call the submodule directly instead, as in this corrected wrapper:
def input_separate_mod(module):
    """
    Decorator to have linen modules ignore every input key except 'maps'.

    Calling the returned function as ``wrapper(inputs, *args, **kwargs)``
    constructs ``module(*args, **kwargs)``, applies it to ``inputs['maps']``
    and passes every other key through unchanged, i.e.
    wrapper({'a': a, 'b': b, 'maps': maps}, *args, **kwargs)
    == {'a': a, 'b': b, 'maps': module(*args, **kwargs)(maps)}

    Use the wrapper inside an ``@nn.compact`` method, where the layer is
    defined inline and its variables are managed by the enclosing module's
    ``init``/``apply``. Do not call ``init``/``apply`` on the layer yourself.

    :param module: a linen module class (e.g. ``nn.Dense``), or any callable
        whose result is callable on the 'maps' value.
    :returns: a function ``(inputs, *args, **kwargs) -> dict``.
    :raises KeyError: if the input dict has no 'maps' key.
    """
    # `module` is usually a class. The default `updated=('__dict__',)` would
    # copy the class's entire namespace (methods, dataclass fields, ...) into
    # the wrapper function's __dict__, so copy only the name/doc metadata.
    @functools.wraps(module, updated=())
    def non_hp_decorator(inputs, *args, **kwargs):
        # Bug fix: the layer output used to be bound to an unused name while
        # the untouched input was returned as 'maps'.
        x = module(*args, **kwargs)(inputs['maps'])
        # New dict: the caller's dict is not mutated and extra keys survive.
        return {**inputs, 'maps': x}
    return non_hp_decorator
class test(nn.Module):
    """Small model that routes a dict input through input_separate_mod-wrapped layers."""

    @nn.compact
    def __call__(self, inputs):
        # Wrap the layer classes once; every call below constructs a fresh
        # layer instance through the wrapper.
        dense_on_maps = input_separate_mod(nn.Dense)
        batch_norm_on_maps = input_separate_mod(nn.BatchNorm)

        hidden = dense_on_maps(inputs, 5)
        hidden = batch_norm_on_maps(hidden, use_running_average=True)
        return dense_on_maps(hidden, 10)
# The snippet used `jnp` and `random` without importing them (NameError),
# and a stray trailing "|" after print(...) was a syntax error.
import jax.numpy as jnp
from jax import random

# Example batch: 'maps' holds the array the layers act on; 'nside' and
# 'indices' are plain values the wrappers are meant to carry along unchanged.
inputs = {'nside': 2, 'indices': 3, 'maps': 2 * jnp.ones((2, 4, 8))}
mod = test()
variables = mod.init(random.PRNGKey(0), inputs)
print(variables)
Beta Was this translation helpful? Give feedback.
The problem is that you are explicitly using `init` and `apply` in `input_separate_mod`, while that function is being called inside a compact module. Instead, you can simply call the module directly, as follows: