
Error during initialization when using flax.linen.vmap on a module with multiple called methods #3055

Answered by cgarciae
lucaslingle asked this question in Q&A

Hey @lucaslingle, the first thing you can do is use the methods argument of nn.vmap to specify that call_transpose should also be lifted (by default, only __call__ is lifted):

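self.smalls = nn.vmap(
    SmallModule,
    in_axes=0,
    out_axes=-2,
    variable_axes={"params": 0},  # stack each copy's params along a new leading axis
    split_rngs={"params": True},  # give each copy its own init RNG
    methods=["__call__", "call_transpose"],  # lift both methods, not just __call__
)(input_dim=self.input_dim, output_dim=self.hidden_dim // self.num_small)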

This fixes the original error, but your code still has some logical errors. Hope this helps!

Answer selected by lucaslingle

This discussion was converted from issue #3049 on April 25, 2023 13:06.