-
My code is:

But I get the following errors:
-
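Update: for a more in-depth explanation, please take a look at our Batch normalization guide.

Hey @fmscole, to use `BatchNorm` you usually follow a pattern like this:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        # Training: normalize with batch statistics and update the running averages.
        # Inference: normalize with the stored running averages.
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x


x = jnp.ones((16, 28, 28, 1))
module = CNN()

# Initialize weights; `variables` holds both 'params' and 'batch_stats'.
key = jax.random.PRNGKey(0)
variables = module.init(key, x, training=False)

# Forward during training: 'batch_stats' must be mutable so BatchNorm can
# update its running mean/variance; apply then returns (output, updates).
y, updates = module.apply(variables, x, training=True, mutable=['batch_stats'])
# Merge the updated batch_stats back in (works for both FrozenDict and plain dict).
variables = {**variables, **updates}

# Forward during testing/inference: running averages are only read, nothing is mutable.
y = module.apply(variables, x, training=False)
```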
-
When I use the
So I abandon the
Is this normal?
-
I already know how to use BatchNorm:
-
Or like this: