Pass user-defined module into a parent module with nn.scan #2266
luweizheng asked this question in Q&A
Hi, I am working on a small framework built on Flax. Users of this framework can define their own `Step` module:

```python
from typing import Sequence
import flax.linen as nn

class Step(nn.Module):
    layers: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for i in range(len(self.layers)):
            x = nn.Dense(self.layers[i])(x)
        ...
```

Maybe another user would define a different one:

```python
class Step(nn.Module):
    layers: int  # assumed field: the original snippet used `layers` without declaring it

    @nn.compact
    def __call__(self, x):
        for i in range(self.layers):
            x = nn.Dense(20)(x)
            x = nn.Dense(30)(x)
        ...
```

My framework's code would wrap the above user-defined `Step` module.

What is the best way to pass the initialization parameters of `Step`? Maybe I should use a function as the first argument of `nn.scan`?
Answered by luweizheng on Jul 5, 2022
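I just found a solution. Users can define their own fields. On the framework side, use a `sde_step_kwargs` field to store the keyword arguments:

```python
from typing import Dict, Type
import flax.linen as nn

class ScanStep(nn.Module):
    step_mdl: Type[nn.Module]  # the user's Module class, not an instance
    sde_step_kwargs: Dict      # keyword arguments forwarded to step_mdl
    unroll: int = 1

    @nn.compact
    def __call__(self, x0, dW, dt, random_type):
        carry = (0, x0)
        # Scan step_mdl over the leading axis of dW, sharing params across steps.
        sdes = nn.scan(self.step_mdl,
                       variable_broadcast="params",
                       split_rngs={"params": False},
                       in_axes=0,
                       out_axes=0,
                       unroll=self.unroll)
        (carry, xs) = sdes(name="dynamic_sde", **self.sde_step_kwargs)(carry, dW)

sde_step_kwargs = {'layers': ...}
s_net = ScanStep(Step, sde_step_kwargs)
```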