PyTorch nn.Conv2d to Flax linen.Conv #1680
Hi, I am trying to move a model from PyTorch to Flax. Is this the correct translation of the 2D convolution layer from PyTorch to Flax?

PyTorch:

import torch.nn as nn
conv = nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5))

Flax:

from flax import linen as nn
conv = nn.Conv(features=32, kernel_size=(41, 11), strides=(2, 2), padding=((0, 20), (0, 5)))

I am confused about the padding parameter.
Answered by marcvanzee, Nov 24, 2021
Answer selected by kamalkraj

Have you checked the documentation? I think it is actually quite clear. The PyTorch Conv2d docs say:

"padding controls the amount of padding applied to the input. It can be either a string {'valid', 'same'} or a tuple of ints giving the amount of implicit padding applied on both sides."

Flax's nn.Conv instead takes one (low, high) pair per spatial dimension, so PyTorch's padding=(20, 5) (20 on both sides of the height, 5 on both sides of the width) becomes ((20, 20), (5, 5)). The torch Conv2d you provided above can therefore be translated to Flax as follows:

nn.Conv(features=32, kernel_size=(41, 11), strides=(2, 2), padding=((20, 20), (5, 5)))
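A quick way to check a padding translation is to compare output shapes. The sketch below is plain Python with no PyTorch or Flax dependency. It assumes dilation 1 and integer per-dimension PyTorch padding (not the 'valid'/'same' strings), and the helper names and the 161 x 101 input size are hypothetical, chosen only for illustration. It converts PyTorch's symmetric padding into Flax-style (low, high) pairs and applies the standard output-size formula floor((n + low + high - k) / s) + 1.

```python
from typing import Sequence, Tuple

Pad = Tuple[int, int]


def torch_to_flax_padding(padding: Sequence[int]) -> Tuple[Pad, ...]:
    """Hypothetical helper: PyTorch pads each spatial dim by p on BOTH sides,
    so every p becomes a (low, high) pair (p, p) for Flax."""
    return tuple((p, p) for p in padding)


def conv_out_size(size: int, kernel: int, stride: int, pad: Pad) -> int:
    """Output length along one spatial dim (dilation 1, explicit padding)."""
    low, high = pad
    return (size + low + high - kernel) // stride + 1


def conv_out_shape(
    spatial: Sequence[int],
    kernel_size: Sequence[int],
    strides: Sequence[int],
    padding: Sequence[Pad],
) -> Tuple[int, ...]:
    """Output spatial shape, one dimension at a time."""
    return tuple(
        conv_out_size(n, k, s, p)
        for n, k, s, p in zip(spatial, kernel_size, strides, padding)
    )


if __name__ == "__main__":
    kernel_size, strides = (41, 11), (2, 2)
    spatial = (161, 101)  # hypothetical input height and width

    flax_padding = torch_to_flax_padding((20, 5))
    print(flax_padding)  # ((20, 20), (5, 5))

    # Matches PyTorch's Conv2d(..., padding=(20, 5)).
    print(conv_out_shape(spatial, kernel_size, strides, flax_padding))  # (81, 51)

    # The padding from the question pads only the high side, so the shape differs.
    print(conv_out_shape(spatial, kernel_size, strides, ((0, 20), (0, 5))))  # (71, 48)
```

The padding from the question gives a smaller output and also shifts where the kernel windows fall, because all of the padding is on one side. Two other differences are worth keeping in mind when porting: Flax infers the input channel count (the 1 in nn.Conv2d(1, 32, ...)) from the input at init time, and nn.Conv expects channels-last (NHWC) input, whereas PyTorch uses NCHW.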