Error during initialization when using flax.linen.vmap on a module with multiple called methods #3055
System information
Problem you have encountered: Suppose `SmallModule` is a Linen module with multiple methods defined, including
If `LargeModule` defines its
The error says that certain parameters defined in `SmallModule`'s
What you expected to happen: The vmapped version of `SmallModule` should consistently behave as if it has the defined parameters.
Logs, error messages, etc:
Steps to reproduce:
Hey @lucaslingle, first thing you can do is use the `methods` argument for `nn.vmap` to specify that you also want `call_transpose` to be lifted (by default only `__call__` is lifted):

```python
# Inside LargeModule.setup(): lift both methods, not just __call__.
self.smalls = nn.vmap(
    SmallModule,
    in_axes=0,
    out_axes=-2,
    variable_axes={"params": 0},
    split_rngs={"params": True},
    methods=["__call__", "call_transpose"],
)(input_dim=self.input_dim, output_dim=self.hidden_dim // self.num_small)
```

This fixes the original error, but now you have some logical errors. Hope this helps!