Hey @minqi, you can set the name for any Module, as it would appear in the variables structure, with the `name` argument:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        # The Dense submodule's params are stored under 'bar'
        # instead of an auto-generated key such as 'Dense_0'.
        return nn.Dense(10, name='bar')(x)


foo = Foo()
variables = foo.init(jax.random.PRNGKey(0), jnp.ones((1, 10)))
print(jax.tree_util.tree_map(jnp.shape, variables))
```
When using `@nn.compact` with `__call__`, is it possible to have the initialized parameters nested under a provided name key inside the `self.variables` dictionary? I noticed that when initializing a module that is bound to `self`, the resulting parameters are nested under `self.variables['params'][<variable name of module>]`, but this is not the case for a module that is first defined inline inside `__call__`.