
How to get access to the parameters within a model? #1846

Answered by andsteing
wztdream asked this question in Q&A

Inside your __call__() method the module is in a "bound" state, so once a submodule has been called, its variables can be accessed directly via sub_module.variables:

```python
import flax.linen as nn


class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), padding="SAME")(x)
        conv = nn.Conv(
            features=1,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="SAME",
            name="need_kernel_here",
        )
        x = conv(x)
        # `conv` has been called, so it is bound and its variables exist.
        w = conv.variables['params']['kernel']
        b = conv.variables['params']['bias']
        my_calculation = (
            w.mean() + b.mean()
        )  # just a simple example, in my re…
        # Hypothetical return: the original snippet is cut off above.
        return x, my_calculation
```

Answer selected by wztdream