Commit e58dea2

Merge pull request #491 from mbz:master
PiperOrigin-RevId: 334645160
Flax Authors committed Sep 30, 2020
2 parents 9015cc2 + 315fb99 commit e58dea2
Showing 3 changed files with 123 additions and 2 deletions.
2 changes: 1 addition & 1 deletion flax/linen/__init__.py
@@ -27,7 +27,7 @@
from .module import Module, compact, enable_named_call, disable_named_call
from .normalization import BatchNorm, GroupNorm, LayerNorm
from .pooling import avg_pool, max_pool
-from .recurrent import GRUCell, LSTMCell
+from .recurrent import GRUCell, LSTMCell, ConvLSTM
from .stochastic import Dropout
from .transforms import jit, named_call, remat, scan, vmap

105 changes: 104 additions & 1 deletion flax/linen/recurrent.py
@@ -25,13 +25,14 @@

import abc
from functools import partial
-from typing import (Any, Callable, Tuple)
+from typing import (Any, Callable, Sequence, Optional, Tuple, Union)

from .module import Module, compact
from . import activation
from . import initializers
from . import linear

from jax import numpy as jnp
+from jax import random


@@ -220,3 +221,105 @@ def initialize_carry(rng, batch_dims, size, init_fn=initializers.zeros):
"""
mem_shape = batch_dims + (size,)
return init_fn(rng, mem_shape)


class ConvLSTM(RNNCellBase):
  r"""A convolutional LSTM cell.

  The implementation is based on xingjian2015convolutional.
  Given x_t and the previous state (h_{t-1}, c_{t-1})
  the core computes

  .. math::

     \begin{array}{ll}
     i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\
     f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\
     g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\
     o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\
     c_t = f_t c_{t-1} + i_t g_t \\
     h_t = o_t \tanh(c_t)
     \end{array}

  where * denotes the convolution operator;
  i_t, f_t, o_t are the input, forget and output gate activations,
  and g_t is a vector of cell updates.

  Notes:
    Forget gate initialization:
      Following jozefowicz2015empirical we add 1.0 to b_f
      after initialization in order to reduce the scale of forgetting at
      the beginning of training.

  Args:
    features: number of convolution filters.
    kernel_size: shape of the convolutional kernel.
    strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
      of `n` `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    use_bias: whether to add a bias to the output (default: True).
    dtype: the dtype of the computation (default: float32).
  """

  features: int
  kernel_size: Sequence[int]
  strides: Optional[Sequence[int]] = None
  padding: Union[str, Sequence[Tuple[int, int]]] = 'SAME'
  use_bias: bool = True
  dtype: Dtype = jnp.float32

  @compact
  def __call__(self, carry, inputs):
    """Constructs a convolutional LSTM.

    Args:
      carry: the hidden state of the ConvLSTM cell,
        initialized using `ConvLSTM.initialize_carry`.
      inputs: input data with dimensions (batch, spatial_dims..., features).
    Returns:
      A tuple with the new carry and the output.
    """
    c, h = carry
    input_to_hidden = partial(linear.Conv,
                              features=4*self.features,
                              kernel_size=self.kernel_size,
                              strides=self.strides,
                              padding=self.padding,
                              use_bias=self.use_bias,
                              dtype=self.dtype,
                              name='ih')

    hidden_to_hidden = partial(linear.Conv,
                               features=4*self.features,
                               kernel_size=self.kernel_size,
                               strides=self.strides,
                               padding=self.padding,
                               use_bias=self.use_bias,
                               dtype=self.dtype,
                               name='hh')

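    # A single convolution per branch produces all four gate pre-activations
    # at once (hence features=4*self.features); they are split along the
    # feature axis below in the order i, g, f, o.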
    gates = input_to_hidden()(inputs) + hidden_to_hidden()(h)
    i, g, f, o = jnp.split(gates, indices_or_sections=4, axis=-1)

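    # The constant 1.0 added to the forget-gate pre-activation plays the role
    # of the b_f initialization described in the Notes above
    # (jozefowicz2015empirical).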
    f = activation.sigmoid(f + 1)
    new_c = f * c + activation.sigmoid(i) * jnp.tanh(g)
    new_h = activation.sigmoid(o) * jnp.tanh(new_c)
    return (new_c, new_h), new_h

  @staticmethod
  def initialize_carry(rng, batch_dims, size, init_fn=initializers.zeros):
    """Initialize the RNN cell carry.

    Args:
      rng: random number generator passed to the init_fn.
      batch_dims: a tuple providing the shape of the batch dimensions.
      size: a tuple giving the spatial dimensions plus the feature dimension
        of the carry, i.e. input_shape + (features,).
      init_fn: initializer function for the carry.
    Returns:
      An initialized carry for the given RNN cell.
    """
    key1, key2 = random.split(rng)
    mem_shape = batch_dims + size
    return init_fn(key1, mem_shape), init_fn(key2, mem_shape)
18 changes: 18 additions & 0 deletions tests/linen/linen_test.py
@@ -220,3 +220,21 @@ def test_gru(self):
        'hz': {'kernel': (4, 4)},
        'hn': {'kernel': (4, 4), 'bias': (4,)},
    })

  def test_convlstm(self):
    rng = random.PRNGKey(0)
    key1, key2 = random.split(rng)
    x = random.normal(key1, (2, 4, 4, 3))
    c0, h0 = nn.ConvLSTM.initialize_carry(rng, (2,), (4, 4, 6))
    self.assertEqual(c0.shape, (2, 4, 4, 6))
    self.assertEqual(h0.shape, (2, 4, 4, 6))
    lstm = nn.ConvLSTM(features=6, kernel_size=(3, 3))
    (carry, y), initial_params = lstm.init_with_output(key2, (c0, h0), x)
    self.assertEqual(carry[0].shape, (2, 4, 4, 6))
    self.assertEqual(carry[1].shape, (2, 4, 4, 6))
    np.testing.assert_allclose(y, carry[1])
    param_shapes = jax.tree_map(np.shape, initial_params['params'])
    self.assertEqual(param_shapes, {
        'hh': {'bias': (6*4,), 'kernel': (3, 3, 6, 6*4)},
        'ih': {'bias': (6*4,), 'kernel': (3, 3, 3, 6*4)},
    })
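
The test exercises a single step; because `ConvLSTM.__call__` already has the `(carry, x) -> (carry, y)` signature that `jax.lax.scan` expects, the cell composes directly with it to unroll over a leading time axis. A sketch, reusing `rng`, `cell`, `variables`, and `carry` from the sketch above with a hypothetical 10-step input:

def step(carry, x_t):
  return cell.apply(variables, carry, x_t)       # -> (new_carry, y_t)

xs = jax.random.normal(rng, (10, 2, 4, 4, 3))    # (time, batch, H, W, features)
final_carry, ys = jax.lax.scan(step, carry, xs)  # ys.shape == (10, 2, 4, 4, 6)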
