
How to use nn.DenseGeneral #3033

Answered by cgarciae
Peter-Vincent asked this question in Q&A

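Hey @Peter-Vincent, to fix this, just don't specify the batch_dims argument. Any input axis that is not listed in axis (here the batch axis 0) is passed through to the output unchanged, so no extra batch handling is needed. Here is an example: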

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Contract over axes (1, 2, 3) and map back to the same (H, W, C) shape;
        # batch_dims is left unset, so the batch axis 0 passes through.
        x = nn.DenseGeneral(features=x.shape[1:], axis=(1, 2, 3))(x)
        x = nn.Conv(features=64, kernel_size=(4, 4))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=128, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten everything except the batch axis.
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=2000)(x)
        x = nn.relu(x)
        x = nn.Dense(fea…
```

Answer selected by Peter-Vincent