Replies: 1 comment
-
Hey @OhadRubin, let's say you have this MLP that defines its submodules in `setup`:

```python
import flax.linen as nn
import jax.numpy as jnp
import jax


class MLPBlock(nn.Module):
  features: int

  def setup(self):
    self.dense = nn.Dense(self.features)

  def __call__(self, x):
    return nn.relu(self.dense(x))


class MLP(nn.Module):
  n_layers: int
  features: int

  def setup(self):
    self.layers = [MLPBlock(self.features) for _ in range(self.n_layers)]

  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)
    return x


x = jnp.ones((3, 5))
module = MLP(10, 5)
y, variables = module.init_with_output(jax.random.PRNGKey(0), x)
print("Regular MLP")
print(jax.tree_map(jnp.shape, variables))
print(y.shape)
print()
```

You can refactor it to use `nn.remat_scan` while keeping `setup`:

```python
class MLPScan(nn.Module):
  n_layers: int
  features: int

  def setup(self):
    Layers = nn.remat_scan(
        MLPBlock, variable_axes={'params': 0},
        split_rngs={'params': True}, lengths=(self.n_layers,))
    self.layers = Layers(self.features)

  def __call__(self, x):
    return self.layers(x)


x = jnp.ones((3, 5))
module = MLPScan(10, 5)
y, variables = module.init_with_output(jax.random.PRNGKey(0), x)
print("MLPScan")
print(jax.tree_map(jnp.shape, variables))
print(y.shape)
print()
```
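One practical difference to be aware of (a minimal sketch, not part of the original reply): the scanned module stores its parameters stacked along axis 0 under a single `layers` entry rather than separate `layers_0 ... layers_9` entries, so an existing setup-style checkpoint would need to be restacked before `MLPScan.apply` can consume it. The key names below (`layers_0`, ..., `layers`, `dense`) are assumptions based on the definitions above; verify them against the parameter trees printed by the two snippets.

```python
# Hedged sketch: restack per-layer params from the setup-style MLP into the
# stacked layout MLPScan expects, so an already-trained checkpoint can be reused.
mlp, mlp_scan = MLP(10, 5), MLPScan(10, 5)
x = jnp.ones((3, 5))
regular_vars = mlp.init(jax.random.PRNGKey(0), x)  # stands in for a trained checkpoint

# Stack each per-layer leaf along axis 0, matching variable_axes={'params': 0}.
per_layer = [regular_vars['params'][f'layers_{i}'] for i in range(10)]
stacked = jax.tree_map(lambda *leaves: jnp.stack(leaves, axis=0), *per_layer)

scan_vars = {'params': {'layers': stacked}}
y_regular = mlp.apply(regular_vars, x)
y_scan = mlp_scan.apply(scan_vars, x)
print(jnp.allclose(y_regular, y_scan))  # expected True if the layouts line up
```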
-
Hey,
My code uses `self.setup` to define parameters, and compilation is really slow, but it isn't clear how to refactor it to use `remat_scan` since all the examples use `nn.compact`.
I'd also like to emphasize that I want to keep the `self.setup` declaration so that I can still run models I've already trained.
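For context, the examples I'm referring to look roughly like this (a sketch along the lines of the `nn.remat_scan` docstring, where everything lives inside an `nn.compact` `__call__`):

```python
import flax.linen as nn


class BigModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    # Stack of Dense layers scanned with rematerialization; all submodules are
    # created inside __call__ rather than in setup.
    DenseStack = nn.remat_scan(nn.Dense, lengths=(10, 10))
    return DenseStack(8, name='dense_stack')(x)
```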