
How to reduce compiling cost of a deep NN? #2003

Answered by jheek
YouJiacheng asked this question in Q&A

You can loop over similar layers with flax.linen.scan(Body, length=num_layers, variable_axes={"params": 0}, split_rngs={"params": True}). The body is then traced and compiled once rather than once per layer, which is where the compile-time saving comes from. Usually Body combines a few different types of layers, for example Body = MLP() -> Attention(), and that block is repeated a number of times.

If you have an irregular sequence of layers, you would indeed need something like a switch. We don't yet have a Linen version of that which can switch between layers.
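To illustrate what "something like a switch" could look like outside Linen, here is a rough sketch in plain JAX that combines `jax.lax.scan` with `jax.lax.switch`. It is not a Flax API. It also assumes every layer type can share one parameter shape, which is a simplification. The helper names `relu_layer`, `tanh_layer`, `BRANCHES` and `forward` are hypothetical.

```python
import jax
import jax.numpy as jnp


def relu_layer(w, h):
    return jax.nn.relu(h @ w)


def tanh_layer(w, h):
    return jnp.tanh(h @ w)


# All branches must take the same operands and return the same shape/dtype.
BRANCHES = (relu_layer, tanh_layer)


def forward(ws, kinds, x):
    """ws: (L, d, d) stacked weights; kinds: (L,) int index into BRANCHES."""
    def step(h, layer):
        w, kind = layer
        h = jax.lax.switch(kind, BRANCHES, w, h)
        return h, None

    h, _ = jax.lax.scan(step, x, (ws, kinds))
    return h


ws = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (6, 16, 16))
kinds = jnp.array([0, 1, 0, 0, 1, 1])  # irregular layer order
x = jnp.ones((4, 16))
y = jax.jit(forward)(ws, kinds, x)
```

In this sketch each branch is traced once inside the scan body, so the amount of code to compile should scale with the number of distinct layer types rather than with the depth.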


jheek Apr 4, 2022
Maintainer


Answer selected by marcvanzee