From 840a99da96839d792a7cb2de6af96513b1a83a56 Mon Sep 17 00:00:00 2001 From: Jonathan Heek Date: Tue, 8 Dec 2020 03:26:19 -0800 Subject: [PATCH] Fix type errors in PixelCNN example PiperOrigin-RevId: 346288055 --- linen_examples/pixelcnn/pixelcnn.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/linen_examples/pixelcnn/pixelcnn.py b/linen_examples/pixelcnn/pixelcnn.py index b36b48a5..bc93da7f 100644 --- a/linen_examples/pixelcnn/pixelcnn.py +++ b/linen_examples/pixelcnn/pixelcnn.py @@ -39,7 +39,7 @@ # pytype: disable=wrong-arg-count from functools import partial -from typing import Any, Callable, Tuple +from typing import Any, Callable, Iterable, Tuple, Optional, Union import flax.linen as nn import jax @@ -186,8 +186,8 @@ class ConvWeightNorm(nn.Module): """2D convolution Modules with weightnorm.""" features: int kernel_size: Tuple[int, int] - strides: Tuple[int, int] = None - padding: str = 'VALID' + strides: Optional[Tuple[int, int]] = None + padding: Union[str, Iterable[Iterable[int]]] = 'VALID' transpose: bool = False init_scale: float = 1. dtype: Any = jnp.float32 @@ -232,7 +232,7 @@ class ConvDown(nn.Module): """Convolution with padding so that information cannot flow upwards.""" features: int kernel_size: Tuple[int, int] = (2, 3) - strides: Tuple[int, int] = None + strides: Optional[Tuple[int, int]] = None init_scale: float = 1. @nn.compact @@ -251,7 +251,7 @@ class ConvDownRight(nn.Module): """Convolution with padding so that information cannot flow left/upwards.""" features: Any kernel_size: Tuple[int, int] = (2, 2) - strides: Tuple[int, int] = None + strides: Optional[Tuple[int, int]] = None init_scale: float = 1.0 @nn.compact @@ -272,7 +272,7 @@ class ConvTransposeDown(nn.Module): """ features: Any kernel_size: Tuple[int, int] = (2, 3) - strides: Tuple[int, int] = (2, 2) + strides: Optional[Tuple[int, int]] = (2, 2) @nn.compact def __call__(self, inputs): @@ -289,7 +289,7 @@ class ConvTransposeDownRight(nn.Module): """ features: Any kernel_size: Tuple[int, int] = (2, 2) - strides: Tuple[int, int] = (2, 2) + strides: Optional[Tuple[int, int]] = (2, 2) @nn.compact def __call__(self, inputs):