
Wrappers/Decorators for flax modules #2065

Answered by marcvanzee
ozencgungor asked this question in Q&A

The problem is that input_separate_mod explicitly calls init and apply, but it is itself being called inside a compact module. Inside a compact module you don't need init/apply; you can simply call the module directly, as follows:

import functools


def input_separate_mod(module):
    """
    Decorator to have linen modules ignore inputs['nside'] and inputs['indices'].
    :param module: a linen module with a __call__() method.
    :returns: a module that only acts on x['maps'] when called on x,
    i.e. module(*args, **kwargs)({'a': a, 'b': b, 'maps': maps})
           = {'a': a, 'b': b, 'maps': module(*args, **kwargs)(maps)}
    """
    @functools.wraps(module)
    def non_hp_decorator(inputs, *args, **kwargs):
        x = inputs['maps']
        mod = m…


Answer selected by ozencgungor