How to pass a submodule (as a pure function) to another JAX library that might use JAX function transforms and lax control flow internally? #2256

Answered by jheek
YouJiacheng asked this question in Q&A

You can use flax.linen.scan for a scan within a Module.

Replies: 2 comments 5 replies

3 replies
@YouJiacheng
@jheek
jheek Jul 1, 2022
Maintainer

@YouJiacheng
Answer selected by YouJiacheng
2 replies
@cgarciae
@YouJiacheng
Category
Q&A
Labels
None yet
3 participants