PyTorch nn.Conv2D to Flax linen.Conv #1680

Answered by marcvanzee
kamalkraj asked this question in Q&A

Have you checked the documentation? I think it is actually quite clear:

  • From the PyTorch documentation: "padding controls the amount of padding applied to the input. It can be either a string {‘valid’, ‘same’} or a tuple of ints giving the amount of implicit padding applied on both sides."
  • From the Flax documentation: "either the string ‘SAME’, the string ‘VALID’, the string ‘CIRCULAR’ (periodic boundary conditions), or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension."

So the torch Conv2d you provided above can be translated to Flax as follows:

 nn.Conv(features=32, kernel_size=(41, 11), strides=(2, 2), padding=((20, 20), (
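The rule implied by the two documentation excerpts can be sketched in plain Python: PyTorch gives one integer per spatial dimension, applied to both sides, while Flax expects an explicit (low, high) pair per spatial dimension. The helper name `torch_padding_to_flax` and the padding values below are hypothetical and chosen only for illustration; they are not taken from the original question. String modes are mapped by name only, without checking that both libraries pad identically in every case.

```python
from typing import Sequence, Tuple, Union

# Hypothetical type aliases, introduced here for readability only.
TorchPadding = Union[str, int, Sequence[int]]
FlaxPadding = Union[str, Tuple[Tuple[int, int], ...]]


def torch_padding_to_flax(padding: TorchPadding, ndim: int = 2) -> FlaxPadding:
    """Convert a PyTorch-style `padding` argument into the Flax convention.

    Hypothetical helper (not part of PyTorch or Flax).
    """
    if isinstance(padding, str):
        # PyTorch accepts 'valid' / 'same'; Flax spells them in upper case.
        if padding.lower() not in ("valid", "same"):
            raise ValueError(f"unsupported padding mode: {padding!r}")
        return padding.upper()
    if isinstance(padding, int):
        # A single int means the same padding for every spatial dimension.
        padding = (padding,) * ndim
    if len(padding) != ndim:
        raise ValueError(f"expected {ndim} padding values, got {len(padding)}")
    # PyTorch pads both sides equally, so each value p becomes the pair (p, p).
    return tuple((int(p), int(p)) for p in padding)


# Illustrative values only.
print(torch_padding_to_flax((3, 1)))
print(torch_padding_to_flax("same"))
```

The returned value is in the form the Flax documentation describes for the `padding` argument of `nn.Conv`, so it could be passed there directly.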

Answer selected by kamalkraj