How to reduce compiling cost of a deep NN? #2003
It seems that only a Python for-loop can be used to scan over a sequence of Flax modules.
Replies: 1 comment 7 replies
You can create a loop over similar layers by using `flax.linen.scan(Body, length=num_layers, variable_axes={"params": 0}, split_rngs={"params": True})`. Usually `Body` contains a few different types of layers, for example `Body = MLP() -> Attention()`, which are then repeated a number of times.

If you have a non-regular sequence of layers, you would indeed need something like a switch. However, we don't yet have a Linen version of that which can switch between layers.