
Pass a user-defined module into a parent module with nn.scan #2266

Answered by luweizheng
luweizheng asked this question in Q&A


I just found a solution.

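Users can define their own step module with whatever fields it needs. On the framework side, the parent module takes the step module's class plus a sde_step_kwargs field that stores the keyword arguments used to construct it inside nn.scan.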

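from typing import Any, Dict, Type

import flax.linen as nn


class ScanStep(nn.Module):
    # The user's step module *class* (not an instance); nn.scan lifts the class.
    step_mdl: Type[nn.Module]
    # Keyword arguments forwarded to step_mdl's constructor.
    sde_step_kwargs: Dict[str, Any]
    unroll: int = 1

    @nn.compact
    def __call__(self, x0, dW, dt, random_type):
        # dt and random_type are not used in this excerpt.
        carry = (0, x0)

        # Scan over the leading axis of dW with one set of params shared by all
        # steps. step_mdl.__call__ must map (carry, dW_t) -> (carry, x_t).
        sdes = nn.scan(self.step_mdl,
                       variable_broadcast="params",
                       split_rngs={"params": False},
                       in_axes=0,
                       out_axes=0,
                       unroll=self.unroll)
        carry, xs = sdes(name="dynamic_sde", **self.sde_step_kwargs)(carry, dW)
        return carry, xs


sde_step_kwargs = {'layers': ...}
s_net = ScanStep(Step, sde_step_kwargs)  # Step: the user's own step module class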

Answer selected by luweizheng
Category
Q&A
1 participant