Skip to content

Commit

Permalink
Fix type errors in PixelCNN example
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 346288055
  • Loading branch information
jheek authored and Flax Authors committed Dec 8, 2020
1 parent afe0df1 commit 840a99d
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions linen_examples/pixelcnn/pixelcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
# pytype: disable=wrong-arg-count

from functools import partial
from typing import Any, Callable, Tuple
from typing import Any, Callable, Iterable, Tuple, Optional, Union

import flax.linen as nn
import jax
Expand Down Expand Up @@ -186,8 +186,8 @@ class ConvWeightNorm(nn.Module):
"""2D convolution Modules with weightnorm."""
features: int
kernel_size: Tuple[int, int]
strides: Tuple[int, int] = None
padding: str = 'VALID'
strides: Optional[Tuple[int, int]] = None
padding: Union[str, Iterable[Iterable[int]]] = 'VALID'
transpose: bool = False
init_scale: float = 1.
dtype: Any = jnp.float32
Expand Down Expand Up @@ -232,7 +232,7 @@ class ConvDown(nn.Module):
"""Convolution with padding so that information cannot flow upwards."""
features: int
kernel_size: Tuple[int, int] = (2, 3)
strides: Tuple[int, int] = None
strides: Optional[Tuple[int, int]] = None
init_scale: float = 1.

@nn.compact
Expand All @@ -251,7 +251,7 @@ class ConvDownRight(nn.Module):
"""Convolution with padding so that information cannot flow left/upwards."""
features: Any
kernel_size: Tuple[int, int] = (2, 2)
strides: Tuple[int, int] = None
strides: Optional[Tuple[int, int]] = None
init_scale: float = 1.0

@nn.compact
Expand All @@ -272,7 +272,7 @@ class ConvTransposeDown(nn.Module):
"""
features: Any
kernel_size: Tuple[int, int] = (2, 3)
strides: Tuple[int, int] = (2, 2)
strides: Optional[Tuple[int, int]] = (2, 2)

@nn.compact
def __call__(self, inputs):
Expand All @@ -289,7 +289,7 @@ class ConvTransposeDownRight(nn.Module):
"""
features: Any
kernel_size: Tuple[int, int] = (2, 2)
strides: Tuple[int, int] = (2, 2)
strides: Optional[Tuple[int, int]] = (2, 2)

@nn.compact
def __call__(self, inputs):
Expand Down

0 comments on commit 840a99d

Please sign in to comment.