How to pass a submodule (as a pure function) to another JAX library that might use JAX function transforms and lax control flow internally? #2256
Answered by jheek
YouJiacheng asked this question in Q&A
-
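Simple example:

    import jax
    import flax.linen as nn

    # foo stands in for the external library, which uses lax control flow internally
    def foo(fun, ...):
        return jax.lax.scan(fun, ...)

    class Model(nn.Module):
        @nn.compact
        def __call__(self, ...):
            fun = submodule()  # a Flax submodule, not a pure function
            return foo(fun, ...)

Solution (assuming the submodule only has a params variable collection):

    class Model(nn.Module):
        @nn.compact
        def __call__(self, x):
            m = submodule()
            # Initialize the submodule's variables once and store them as a param of Model
            m_params = self.param('m', m.init, x)
            # Hand foo a pure function that closes over those variables
            return foo(lambda x: m.apply(m_params, x), ...)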
Answered by jheek, Jul 1, 2022
-
You can use flax.linen.scan for a scan within a Module.
Answer selected by YouJiacheng
-
@YouJiacheng The pattern is:

    class Model(nn.Module):
        @nn.compact
        def __call__(self, x):
            module = submodule()

            # The lifted function receives the module as its first argument
            def scan_fn(module, carry, x):
                return module(carry, x)

            scan = nn.scan(scan_fn, ...)
            carry = get_carry()
            carry, output = scan(module, carry, x)
            return output

You can simplify it to:

    class Model(nn.Module):
        @nn.compact
        def __call__(self, x):
            module = submodule()
            # The unbound __call__ already has the signature (module, carry, x)
            scan = nn.scan(type(module).__call__, ...)
            carry = get_carry()
            carry, output = scan(module, carry, x)
            return output