
How to use nn.BatchNorm? #2282

Answered by cgarciae
fmscole asked this question in Q&A
Jul 7, 2022 · 4 comments · 3 replies

Update: for a more in-depth explanation, please take a look at our Batch normalization guide.


Hey @fmscole, to use BatchNorm you usually follow a pattern like this:

import flax.linen as nn


class CNN(nn.Module):
  """A simple CNN model."""
  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    # batch statistics while training, stored running averages at eval time
    x = nn.BatchNorm(use_running_average=not training)(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.BatchNorm(use_running_average=not training)(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    # hypothetical classification head; layer sizes are illustrative
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x

Answer selected by marcvanzee

This discussion was converted from issue #2279 on July 08, 2022 14:23.