Hi,
Answered by andsteing on Feb 2, 2022
Inside your `__call__()` function the module is in a "bound" state, so the variables of submodules can simply be accessed via `sub_module.variables`:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), padding="SAME")(x)
        conv = nn.Conv(
            features=1,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="SAME",
            name="need_kernel_here",
        )
        x = conv(x)
        # `conv` is bound at this point, so its parameters can be read directly.
        w = conv.variables['params']['kernel']
        b = conv.variables['params']['bias']
        # Just a simple example: in my real project the calculation is more complex,
        # but it only needs to read the values of w and b, not change them.
        my_calculation = w.mean() + b.mean()
        return x * my_calculation


foo = Foo()
x = jnp.zeros([1, 2, 2, 3])
vs = foo.init(jax.random.PRNGKey(0), x)
foo.apply(vs, x)
```
Answer selected by wztdream