diff --git a/flax/linen/combinators.py b/flax/linen/combinators.py index b97a8560fc..ca34d90e3d 100644 --- a/flax/linen/combinators.py +++ b/flax/linen/combinators.py @@ -38,9 +38,10 @@ class Foo(nn.Module): @nn.compact def __call__(self, x): - return nn.Sequential([nn.Dense(layer_size, name=f'layers_{idx}') - for idx, layer_size - in enumerate(self.feature_sizes)])(x) + return nn.Sequential([nn.Dense(4), + nn.relu, + nn.Dense(2), + nn.log_softmax])(x) """ layers: Sequence[Callable[..., Any]]