How to use nn.DenseGeneral #3033
-
I have the following network:
But I just don't understand how… The images I want to pass in are…
Beta Was this translation helpful? Give feedback.
Replies: 1 comment
-
Hey @Peter-Vincent, to fix this just don't specify the `batch_dims` argument. Here is an example:
import jax
import jax.numpy as jnp
import flax.linen as nn
class Foo(nn.Module):
    """Small CNN from the discussion answer.

    Pipeline: a DenseGeneral layer that contracts axes 1-3 and produces an
    output with the same trailing shape; then two Conv -> ReLU -> 2x2
    average-pool stages; then flattening of all but the leading axis; then a
    Dense(2000) -> ReLU -> Dense(502) -> ReLU head.

    NOTE(review): input is presumably (batch, height, width, channels). Axes
    1-3 are contracted by DenseGeneral and the Conv layers use 2-D kernels.
    Confirm this against callers.

    Returns ``jnp.squeeze`` of the (batch, 502) head output. Every size-1 axis
    is dropped, including the batch axis when batch == 1, which gives shape
    (502,).
    """

    @nn.compact
    def __call__(self, x):
        # Dense mixing over every non-batch axis; the output keeps x's trailing shape.
        trailing_shape = x.shape[1:]
        h = nn.DenseGeneral(features=trailing_shape, axis=(1, 2, 3))(x)

        # Two conv stages, created in order so they are auto-named Conv_0, Conv_1.
        for n_filters, kernel in ((64, (4, 4)), (128, (3, 3))):
            h = nn.Conv(features=n_filters, kernel_size=kernel)(h)
            h = nn.avg_pool(nn.relu(h), window_shape=(2, 2), strides=(2, 2))

        # Flatten everything except the leading axis before the dense head.
        h = h.reshape((h.shape[0], -1))
        for width in (2000, 502):
            h = nn.relu(nn.Dense(features=width)(h))

        # Drops all size-1 axes (including batch when batch == 1).
        return jnp.squeeze(h)
# Initialise the parameters from a single-example batch and print a layer summary.
x = jax.random.normal(jax.random.PRNGKey(0), (1, 51, 51, 1))
module = Foo()
print(module.tabulate(jax.random.PRNGKey(1), x))
variables = module.init(jax.random.PRNGKey(1), x)

# Reuse the same variables on a batch of 32. DenseGeneral only contracts axes
# 1-3 (batch_dims left unset), so no parameter is tied to the batch size.
x = jax.random.normal(jax.random.PRNGKey(0), (32, 51, 51, 1))
y = module.apply(variables, x)

# jax.tree_map is deprecated and removed in newer JAX releases;
# jax.tree_util.tree_map is the stable spelling.
print(jax.tree_util.tree_map(lambda leaf: leaf.shape, variables))
print(y.shape)
Beta Was this translation helpful? Give feedback.
Hey @Peter-Vincent, to fix this just don't specify the `batch_dims` argument. Here is an example: