Hi,
I am trying to reduce compile time for a custom model.
This model requires iterating 100-1000 times over all layers; unrolling those iterations with Python for loops leads to very slow compilation (up to tens of minutes).
My first attempt at a solution uses jax.lax.fori_loop, which reduces compile time to a few seconds with standard Dense layers but fails with the following error when I use custom layers (inheriting from flax.linen.Module):
"Exception has occurred: JaxTransformError
Jax transforms and Flax models cannot be mixed."
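For reference, here is a minimal sketch of the kind of pattern that seems to trigger this; Layer is a hypothetical stand-in for my custom layers:

import jax
import flax.linen as nn

class Layer(nn.Module):  # hypothetical stand-in for a custom layer
    @nn.compact
    def __call__(self, h):
        return nn.Dense(h.shape[-1])(h)

class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        layer = Layer()
        # Calling the submodule inside a raw JAX control-flow primitive
        # crosses the transform boundary, which Flax rejects with
        # JaxTransformError.
        return jax.lax.fori_loop(0, 100, lambda t, h: layer(h), x)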
I read that the solution is to use "lifted transformations" (https://flax.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html). However, there is no lifted variant of fori_loop, so I think I need to use flax.linen.scan instead. Is this correct?

I reduced my problem to the following minimal form (in practice update_hidden would be more complicated, calling custom module methods), but it is unclear to me how or whether this problem can be rephrased in a way that is compatible with flax.linen.scan. Any help is greatly appreciated.
import jax
from jax import random

state_init_keys = [random.key(42), random.key(54), random.key(81), random.key(11)]
W_init_keys = [random.key(13), random.key(92), random.key(91)]

batchsize = 64
N = [20, 10, 5, 2]  # input dim, hidden1 dim, hidden2 dim, output dim

# Some random initial neuron states and weights
s = [random.normal(key_i, (N_i, batchsize)) for (key_i, N_i) in zip(state_init_keys, N)]
W = [0.1 * random.normal(W_init_keys[i], (N[i + 1], N[i])) for i in range(len(N) - 1)]

def update_hidden(x, y, W1, W2):
    # A hidden layer receives bottom-up input from the layer below
    # and top-down input from the layer above.
    si = W1 @ x + W2.T @ y
    return si

def update_output(sL, WL):
    s_out = WL @ sL
    return s_out

@jax.jit
def infer_states(t, inputs):
    s, W = inputs
    for i in range(1, len(s)):
        if i == len(N) - 1:
            s[i] = update_output(s[i - 1], W[i - 1])
        else:
            s[i] = update_hidden(s[i - 1], s[i + 1], W[i - 1], W[i])
    return s, W

inputs = (s, W)
T = 100  # time steps
s, W = jax.lax.fori_loop(0, T, infer_states, inputs)
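For concreteness, here is a rough sketch of the direction I am considering, reusing update_hidden and update_output from above. The module structure and parameter names are placeholders, and I am not sure this is the intended usage of flax.linen.scan:

import flax.linen as nn

class InferenceStep(nn.Module):
    # One relaxation step over all layer states; a placeholder for the
    # custom module methods in the real model.
    N: tuple  # layer sizes, e.g. (20, 10, 5, 2)

    @nn.compact
    def __call__(self, s, _):
        # Keeping the weights inside the module lets nn.scan broadcast
        # them across iterations instead of carrying them in the loop.
        W = [self.param(f"W{i}", nn.initializers.normal(0.1),
                        (self.N[i + 1], self.N[i]))
             for i in range(len(self.N) - 1)]
        s = list(s)
        for i in range(1, len(s)):
            if i == len(self.N) - 1:
                s[i] = update_output(s[i - 1], W[i - 1])
            else:
                s[i] = update_hidden(s[i - 1], s[i + 1], W[i - 1], W[i])
        return tuple(s), None  # (carry, per-step output)

class Inference(nn.Module):
    N: tuple
    T: int = 100

    @nn.compact
    def __call__(self, s0):
        # nn.scan is the lifted variant of jax.lax.scan: 'params' are
        # broadcast (shared) across all T steps, and no per-step RNG
        # splitting is needed.
        Scan = nn.scan(InferenceStep,
                       variable_broadcast="params",
                       split_rngs={"params": False},
                       length=self.T)
        s_final, _ = Scan(self.N)(tuple(s0), None)
        return s_final

model = Inference(N=tuple(N), T=T)
variables = model.init(random.key(0), s)
s_scanned = model.apply(variables, s)

The idea is that the weights become broadcast parameters rather than part of the carry, but I am not sure this translation is correct.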