diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 367fe7e092..f6ef81a9ad 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -86,6 +86,13 @@ from .nn.attention import dot_product_attention as dot_product_attention from .nn.attention import make_attention_mask as make_attention_mask from .nn.attention import make_causal_mask as make_causal_mask +from .nn.recurrent import RNNCellBase as RNNCellBase +from .nn.recurrent import LSTMCell as LSTMCell +from .nn.recurrent import GRUCell as GRUCell +from .nn.recurrent import OptimizedLSTMCell as OptimizedLSTMCell +from .nn.recurrent import SimpleCell as SimpleCell +from .nn.recurrent import RNN as RNN +from .nn.recurrent import Bidirectional as Bidirectional from .nn.linear import Conv as Conv from .nn.linear import ConvTranspose as ConvTranspose from .nn.linear import Embed as Embed diff --git a/flax/nnx/nn/recurrent.py b/flax/nnx/nn/recurrent.py new file mode 100644 index 0000000000..e659144afb --- /dev/null +++ b/flax/nnx/nn/recurrent.py @@ -0,0 +1,923 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RNN modules for Flax.""" + +from typing import ( + Any, + TypeVar +) +from collections.abc import Callable +from functools import partial +from typing_extensions import Protocol +from absl import logging + +import jax +import jax.numpy as jnp + +from flax import nnx +from flax.nnx import rnglib +from flax.nnx.module import Module +from flax.nnx.nn import initializers +from flax.nnx.nn.linear import Linear +from flax.nnx.nn.activations import sigmoid +from flax.nnx.nn.activations import tanh +from flax.nnx.transforms.iteration import Carry +from flax.typing import ( + Dtype, + Initializer, + Shape +) + +default_kernel_init = initializers.lecun_normal() +default_bias_init = initializers.zeros_init() + +A = TypeVar("A") +Array = jax.Array +Output = Any + + +class RNNCellBase(Module): + """RNN cell base class.""" + + def initialize_carry( + self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None + ) -> Carry: + """Initialize the RNN cell carry. + + Args: + rng: random number generator passed to the init_fn. + input_shape: a tuple providing the shape of the input to the cell. + + Returns: + An initialized carry for the given RNN cell. + """ + raise NotImplementedError + + def __call__( + self, + carry: Carry, + inputs: Array + ) -> tuple[Carry, Array]: + """Run the RNN cell. + + Args: + carry: the hidden state of the RNN cell. + inputs: an ndarray with the input for the current time step. + All dimensions except the final are considered batch dimensions. + + Returns: + A tuple with the new carry and the output. 
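
    Example (an illustrative sketch of the cell protocol; ``LSTMCell`` below is
    one concrete subclass, and the shapes are arbitrary)::

      from flax import nnx
      import jax.numpy as jnp

      cell = nnx.LSTMCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0))
      x = jnp.ones((8, 2, 3))  # (time, batch, features)
      carry = cell.initialize_carry(x.shape[1:])  # uses the cell's own rngs
      for t in range(x.shape[0]):
        carry, y = cell(carry, x[t])  # y has shape (2, 4)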
+ """ + raise NotImplementedError + + @property + def num_feature_axes(self) -> int: + """Returns the number of feature axes of the RNN cell.""" + raise NotImplementedError + +def modified_orthogonal(key: Array, shape: Shape, dtype: Dtype = jnp.float32) -> Array: + """Modified orthogonal initializer for compatibility with half precision.""" + initializer = initializers.orthogonal() + return initializer(key, shape).astype(dtype) + +class LSTMCell(RNNCellBase): + r"""LSTM cell. + + The mathematical definition of the cell is as follows + + .. math:: + \begin{array}{ll} + i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ + f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ + g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ + o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ + c' = f * c + i * g \\ + h' = o * \tanh(c') \\ + \end{array} + + where x is the input, h is the output of the previous time step, and c is + the memory. + """ + + def __init__( + self, + in_features: int, + hidden_features: int, + *, + gate_fn: Callable[..., Any] = sigmoid, + activation_fn: Callable[..., Any] = tanh, + kernel_init: Initializer = default_kernel_init, + recurrent_kernel_init: Initializer = modified_orthogonal, + bias_init: Initializer = initializers.zeros_init(), + dtype: Dtype | None = None, + param_dtype: Dtype = jnp.float32, + carry_init: Initializer = initializers.zeros_init(), + rngs: rnglib.Rngs, + ): + self.in_features = in_features + self.hidden_features = hidden_features + self.gate_fn = gate_fn + self.activation_fn = activation_fn + self.kernel_init = kernel_init + self.recurrent_kernel_init = recurrent_kernel_init + self.bias_init = bias_init + self.dtype = dtype + self.param_dtype = param_dtype + self.carry_init = carry_init + self.rngs = rngs + + # input and recurrent layers are summed so only one needs a bias. + dense_i = partial( + Linear, + in_features=in_features, + out_features=hidden_features, + use_bias=False, + kernel_init=self.kernel_init, + dtype=self.dtype, + param_dtype=self.param_dtype, + rngs=rngs, + ) + + dense_h = partial( + Linear, + in_features=hidden_features, + out_features=hidden_features, + use_bias=True, + kernel_init=self.recurrent_kernel_init, + bias_init=self.bias_init, + dtype=self.dtype, + param_dtype=self.param_dtype, + rngs=rngs, + ) + + self.ii = dense_i() + self.if_ = dense_i() + self.ig = dense_i() + self.io = dense_i() + self.hi = dense_h() + self.hf = dense_h() + self.hg = dense_h() + self.ho = dense_h() + + def __call__(self, carry: tuple[Array, Array], inputs: Array) -> tuple[tuple[Array, Array], Array]: # type: ignore[override] + r"""A long short-term memory (LSTM) cell. + + Args: + carry: the hidden state of the LSTM cell, + initialized using ``LSTMCell.initialize_carry``. + inputs: an ndarray with the input for the current time step. + All dimensions except the final are considered batch dimensions. + + Returns: + A tuple with the new carry and the output. + """ + c, h = carry + i = self.gate_fn(self.ii(inputs) + self.hi(h)) + f = self.gate_fn(self.if_(inputs) + self.hf(h)) + g = self.activation_fn(self.ig(inputs) + self.hg(h)) + o = self.gate_fn(self.io(inputs) + self.ho(h)) + new_c = f * c + i * g + new_h = o * self.activation_fn(new_c) + return (new_c, new_h), new_h + + def initialize_carry(self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None) -> tuple[Array, Array]: # type: ignore[override] + """Initialize the RNN cell carry. + + Args: + rng: random number generator passed to the init_fn. + input_shape: a tuple providing the shape of the input to the cell. 
+ Returns: + An initialized carry for the given RNN cell. + """ + batch_dims = input_shape[:-1] + if rngs is None: + rngs = self.rngs + mem_shape = batch_dims + (self.hidden_features,) + c = self.carry_init(rngs.carry(), mem_shape, self.param_dtype) + h = self.carry_init(rngs.carry(), mem_shape, self.param_dtype) + return (c, h) + + @property + def num_feature_axes(self) -> int: + return 1 + + +class OptimizedLSTMCell(RNNCellBase): + r"""More efficient LSTM Cell that concatenates state components before matmul. + + The parameters are compatible with ``LSTMCell``. Note that this cell is often + faster than ``LSTMCell`` as long as the hidden size is roughly <= 2048 units. + + The mathematical definition of the cell is the same as ``LSTMCell`` and as + follows: + + .. math:: + + \begin{array}{ll} + i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ + f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ + g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ + o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ + c' = f * c + i * g \\ + h' = o * \tanh(c') \\ + \end{array} + + where x is the input, h is the output of the previous time step, and c is + the memory. + + Attributes: + gate_fn: activation function used for gates (default: sigmoid). + activation_fn: activation function used for output and memory update + (default: tanh). + kernel_init: initializer function for the kernels that transform + the input (default: lecun_normal). + recurrent_kernel_init: initializer function for the kernels that transform + the hidden state (default: initializers.orthogonal()). + bias_init: initializer for the bias parameters (default: initializers.zeros_init()). + dtype: the dtype of the computation (default: infer from inputs and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + """ + + def __init__( + self, + in_features: int, + hidden_features: int, + *, + gate_fn: Callable[..., Any] = sigmoid, + activation_fn: Callable[..., Any] = tanh, + kernel_init: Initializer = default_kernel_init, + recurrent_kernel_init: Initializer = initializers.orthogonal(), + bias_init: Initializer = initializers.zeros_init(), + dtype: Dtype | None = None, + param_dtype: Dtype = jnp.float32, + carry_init: Initializer = initializers.zeros_init(), + rngs: rnglib.Rngs, + ): + self.in_features = in_features + self.hidden_features = hidden_features + self.gate_fn = gate_fn + self.activation_fn = activation_fn + self.kernel_init = kernel_init + self.recurrent_kernel_init = recurrent_kernel_init + self.bias_init = bias_init + self.dtype = dtype + self.param_dtype = param_dtype + self.carry_init = carry_init + self.rngs = rngs + + # input and recurrent layers are summed so only one needs a bias. + self.dense_i = Linear( + in_features=in_features, + out_features=4 * hidden_features, + use_bias=False, + kernel_init=self.kernel_init, + dtype=self.dtype, + param_dtype=self.param_dtype, + rngs=rngs, + ) + + self.dense_h = Linear( + in_features=hidden_features, + out_features=4 * hidden_features, + use_bias=True, + kernel_init=self.recurrent_kernel_init, + bias_init=self.bias_init, + dtype=self.dtype, + param_dtype=self.param_dtype, + rngs=rngs, + ) + + def __call__(self, carry: tuple[Array, Array], inputs: Array) -> tuple[tuple[Array, Array], Array]: # type: ignore[override] + r"""An optimized long short-term memory (LSTM) cell. + + Args: + carry: the hidden state of the LSTM cell, initialized using + ``LSTMCell.initialize_carry``. + inputs: an ndarray with the input for the current time step. 
+ All dimensions except the final are considered batch dimensions. + + Returns: + A tuple with the new carry and the output. + """ + c, h = carry + + # Compute combined transformations for inputs and hidden state + y = self.dense_i(inputs) + self.dense_h(h) + + # Split the combined transformations into individual gates + i, f, g, o = jnp.split(y, indices_or_sections=4, axis=-1) + + # Apply gate activations + i = self.gate_fn(i) + f = self.gate_fn(f) + g = self.activation_fn(g) + o = self.gate_fn(o) + + # Update cell state and hidden state + new_c = f * c + i * g + new_h = o * self.activation_fn(new_c) + return (new_c, new_h), new_h + + def initialize_carry(self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None) -> tuple[Array, Array]: # type: ignore[override] + """Initialize the RNN cell carry. + + Args: + rngs: random number generator passed to the init_fn. + input_shape: a tuple providing the shape of the input to the cell. + + Returns: + An initialized carry for the given RNN cell. + """ + batch_dims = input_shape[:-1] + if rngs is None: + rngs = self.rngs + mem_shape = batch_dims + (self.hidden_features,) + c = self.carry_init(rngs.carry(), mem_shape, self.param_dtype) + h = self.carry_init(rngs.carry(), mem_shape, self.param_dtype) + return (c, h) + + @property + def num_feature_axes(self) -> int: + return 1 + + +class SimpleCell(RNNCellBase): + r"""Simple cell. + + The mathematical definition of the cell is as follows + + .. math:: + + \begin{array}{ll} + h' = \tanh(W_i x + b_i + W_h h) + \end{array} + + where x is the input and h is the output of the previous time step. + + If `residual` is `True`, + + .. math:: + + \begin{array}{ll} + h' = \tanh(W_i x + b_i + W_h h + h) + \end{array} + """ + + def __init__( + self, + in_features: int, + hidden_features: int, # not inferred from carry for now + *, + dtype: Dtype = jnp.float32, + param_dtype: Dtype = jnp.float32, + carry_init: Initializer = initializers.zeros_init(), + residual: bool = False, + activation_fn: Callable[..., Any] = tanh, + kernel_init: Initializer = initializers.lecun_normal(), + recurrent_kernel_init: Initializer = initializers.orthogonal(), + bias_init: Initializer = initializers.zeros_init(), + rngs: rnglib.Rngs, + ): + self.in_features = in_features + self.hidden_features = hidden_features + self.dtype = dtype + self.param_dtype = param_dtype + self.carry_init = carry_init + self.residual = residual + self.activation_fn = activation_fn + self.kernel_init = kernel_init + self.recurrent_kernel_init = recurrent_kernel_init + self.bias_init = bias_init + self.rngs = rngs + + # self.hidden_features = carry.shape[-1] + # input and recurrent layers are summed so only one needs a bias. 
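    # dense_h applies the recurrent kernel to the carry (no bias);
    # dense_i applies the input kernel and holds the single bias term.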
+ self.dense_h = Linear( + in_features=self.hidden_features, + out_features=self.hidden_features, + use_bias=False, + dtype=self.dtype, + param_dtype=self.param_dtype, + kernel_init=self.recurrent_kernel_init, + rngs=rngs, + ) + self.dense_i = Linear( + in_features=self.in_features, + out_features=self.hidden_features, + use_bias=True, + dtype=self.dtype, + param_dtype=self.param_dtype, + kernel_init=self.kernel_init, + bias_init=self.bias_init, + rngs=rngs, + ) + + def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]: # type: ignore[override] + new_carry = self.dense_i(inputs) + self.dense_h(carry) + if self.residual: + new_carry += carry + new_carry = self.activation_fn(new_carry) + return new_carry, new_carry + + def initialize_carry(self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None) -> Array: # type: ignore[override] + """Initialize the RNN cell carry. + + Args: + rng: random number generator passed to the init_fn. + input_shape: a tuple providing the shape of the input to the cell. + + Returns: + An initialized carry for the given RNN cell. + """ + if rngs is None: + rngs = self.rngs + batch_dims = input_shape[:-1] + mem_shape = batch_dims + (self.hidden_features,) + return self.carry_init(rngs.carry(), mem_shape, self.param_dtype) + + @property + def num_feature_axes(self) -> int: + return 1 + + +class GRUCell(RNNCellBase): + r"""GRU cell. + + The mathematical definition of the cell is as follows + + .. math:: + + \begin{array}{ll} + r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\ + z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\ + n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ + h' = (1 - z) * n + z * h \\ + \end{array} + + where x is the input and h is the output of the previous time step. + + Attributes: + in_features: number of input features. + hidden_features: number of output features. + gate_fn: activation function used for gates (default: sigmoid). + activation_fn: activation function used for output and memory update + (default: tanh). + kernel_init: initializer function for the kernels that transform + the input (default: lecun_normal). + recurrent_kernel_init: initializer function for the kernels that transform + the hidden state (default: initializers.orthogonal()). + bias_init: initializer for the bias parameters (default: initializers.zeros_init()). + dtype: the dtype of the computation (default: None). + param_dtype: the dtype passed to parameter initializers (default: float32). 
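
  Example (an illustrative sketch, stepping the cell once; shapes are arbitrary)::

    from flax import nnx
    import jax.numpy as jnp

    cell = nnx.GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0))
    x = jnp.ones((2, 3))                    # (batch, features)
    carry = cell.initialize_carry(x.shape)  # hidden state of shape (2, 4)
    carry, y = cell(carry, x)               # y has shape (2, 4)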
+ """ + + def __init__( + self, + in_features: int, + hidden_features: int, + *, + gate_fn: Callable[..., Any] = sigmoid, + activation_fn: Callable[..., Any] = tanh, + kernel_init: Initializer = default_kernel_init, + recurrent_kernel_init: Initializer = initializers.orthogonal(), + bias_init: Initializer = initializers.zeros_init(), + dtype: Dtype | None = None, + param_dtype: Dtype = jnp.float32, + carry_init: Initializer = initializers.zeros_init(), + rngs: rnglib.Rngs, + ): + self.in_features = in_features + self.hidden_features = hidden_features + self.gate_fn = gate_fn + self.activation_fn = activation_fn + self.kernel_init = kernel_init + self.recurrent_kernel_init = recurrent_kernel_init + self.bias_init = bias_init + self.dtype = dtype + self.param_dtype = param_dtype + self.carry_init = carry_init + self.rngs = rngs + + # Combine input transformations into a single linear layer + self.dense_i = Linear( + in_features=in_features, + out_features=3 * hidden_features, # r, z, n + use_bias=True, + kernel_init=self.kernel_init, + bias_init=self.bias_init, + dtype=self.dtype, + param_dtype=self.param_dtype, + rngs=rngs, + ) + + self.dense_h = Linear( + in_features=hidden_features, + out_features=3 * hidden_features, # r, z, n + use_bias=False, + kernel_init=self.recurrent_kernel_init, + dtype=self.dtype, + param_dtype=self.param_dtype, + rngs=rngs, + ) + + def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]: # type: ignore[override] + """Gated recurrent unit (GRU) cell. + + Args: + carry: the hidden state of the GRU cell, + initialized using ``GRUCell.initialize_carry``. + inputs: an ndarray with the input for the current time step. + All dimensions except the final are considered batch dimensions. + + Returns: + A tuple with the new carry and the output. + """ + h = carry + + # Compute combined transformations for inputs and hidden state + x_transformed = self.dense_i(inputs) + h_transformed = self.dense_h(h) + + # Split the combined transformations into individual components + xi_r, xi_z, xi_n = jnp.split(x_transformed, 3, axis=-1) + hh_r, hh_z, hh_n = jnp.split(h_transformed, 3, axis=-1) + + # Compute gates + r = self.gate_fn(xi_r + hh_r) + z = self.gate_fn(xi_z + hh_z) + + # Compute n with an additional linear transformation on h + n = self.activation_fn(xi_n + r * hh_n) + + # Update hidden state + new_h = (1.0 - z) * n + z * h + return new_h, new_h + + def initialize_carry(self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None) -> Array: # type: ignore[override] + """Initialize the RNN cell carry. + + Args: + rngs: random number generator passed to the init_fn. + input_shape: a tuple providing the shape of the input to the cell. + + Returns: + An initialized carry for the given RNN cell. + """ + batch_dims = input_shape[:-1] + if rngs is None: + rngs = self.rngs + mem_shape = batch_dims + (self.hidden_features,) + h = self.carry_init(rngs.carry(), mem_shape, self.param_dtype) + return h + + @property + def num_feature_axes(self) -> int: + return 1 + + +class RNN(Module): + """The ``RNN`` module takes any :class:`RNNCellBase` instance and applies it over a sequence + + using :func:`flax.linen.scan`. 
+ """ + + def __init__( + self, + cell: RNNCellBase, + time_major: bool = False, + return_carry: bool = False, + reverse: bool = False, + keep_order: bool = False, + unroll: int = 1, + rngs: rnglib.Rngs | None = None, + ): + self.cell = cell + self.time_major = time_major + self.return_carry = return_carry + self.reverse = reverse + self.keep_order = keep_order + self.unroll = unroll + if rngs is None: + rngs = rnglib.Rngs(0) + self.rngs = rngs + + def __call__( + self, + inputs: Array, + *, + initial_carry: Carry | None = None, + seq_lengths: Array | None = None, + return_carry: bool | None = None, + time_major: bool | None = None, + reverse: bool | None = None, + keep_order: bool | None = None, + rngs: rnglib.Rngs | None = None, + ): + if return_carry is None: + return_carry = self.return_carry + if time_major is None: + time_major = self.time_major + if reverse is None: + reverse = self.reverse + if keep_order is None: + keep_order = self.keep_order + + # Infer the number of batch dimensions from the input shape. + # Cells like ConvLSTM have additional spatial dimensions. + time_axis = 0 if time_major else inputs.ndim - (self.cell.num_feature_axes + 1) + + # make time_axis positive + if time_axis < 0: + time_axis += inputs.ndim + + if time_major: + # we add +1 because we moved the time axis to the front + batch_dims = inputs.shape[1 : -self.cell.num_feature_axes] + else: + batch_dims = inputs.shape[:time_axis] + + # maybe reverse the sequence + if reverse: + inputs = jax.tree_util.tree_map( + lambda x: flip_sequences( + x, + seq_lengths, + num_batch_dims=len(batch_dims), + time_major=time_major, # type: ignore + ), + inputs, + ) + if rngs is None: + rngs = self.rngs + carry: Carry = ( + self.cell.initialize_carry( + inputs.shape[:time_axis] + inputs.shape[time_axis + 1 :], rngs + ) + if initial_carry is None + else initial_carry + ) + + slice_carry = seq_lengths is not None and return_carry + + def scan_fn(cell: RNNCellBase, carry: Carry, x: Array) -> tuple[Carry, Array] | tuple[Carry, tuple[Carry, Array]]: + carry, y = cell(carry, x) + if slice_carry: + return carry, (carry, y) + return carry, y + state_axes = nnx.StateAxes({...: Carry}) # type: ignore[arg-type] + scan = nnx.scan( + scan_fn, + in_axes=(state_axes, Carry, time_axis), + out_axes=(Carry, (0, time_axis)) if slice_carry else (Carry, time_axis), + unroll=self.unroll, + ) + scan_output = scan(self.cell, carry, inputs) + + # Next we select the final carry. If a segmentation mask was provided and + # return_carry is True we slice the carry history and select the last valid + # carry for each sequence. Otherwise we just use the last carry. + if slice_carry: + assert seq_lengths is not None + _, (carries, outputs) = scan_output + # seq_lengths[None] expands the shape of the mask to match the + # number of dimensions of the carry. 
+ carry = _select_last_carry(carries, seq_lengths) + else: + carry, outputs = scan_output + + if reverse and keep_order: + outputs = jax.tree_util.tree_map( + lambda x: flip_sequences( + x, + seq_lengths, + num_batch_dims=len(batch_dims), + time_major=time_major, # type: ignore + ), + outputs, + ) + + if return_carry: + return carry, outputs + else: + return outputs + + +def _select_last_carry(sequence: A, seq_lengths: jnp.ndarray) -> A: + last_idx = seq_lengths - 1 + + def _slice_array(x: jnp.ndarray): + return x[last_idx, jnp.arange(x.shape[1])] + + return jax.tree_util.tree_map(_slice_array, sequence) + + +def _expand_dims_like(x, target): + """Expands the shape of `x` to match `target`'s shape by adding singleton dimensions.""" + return x.reshape(list(x.shape) + [1] * (target.ndim - x.ndim)) + + +def flip_sequences( + inputs: Array, + seq_lengths: Array | None, + num_batch_dims: int, + time_major: bool, +) -> Array: + """Flips a sequence of inputs along the time axis. + + This function can be used to prepare inputs for the reverse direction of a + bidirectional LSTM. It solves the issue that, when naively flipping multiple + padded sequences stored in a matrix, the first elements would be padding + values for those sequences that were padded. This function keeps the padding + at the end, while flipping the rest of the elements. + + Example: + ```python + inputs = [[1, 0, 0], + [2, 3, 0] + [4, 5, 6]] + lengths = [1, 2, 3] + flip_sequences(inputs, lengths) = [[1, 0, 0], + [3, 2, 0], + [6, 5, 4]] + ``` + + Args: + inputs: An array of input IDs [batch_size, seq_length]. + lengths: The length of each sequence [batch_size]. + + Returns: + An ndarray with the flipped inputs. + """ + # Compute the indices to put the inputs in flipped order as per above example. + time_axis = 0 if time_major else num_batch_dims + max_steps = inputs.shape[time_axis] + + if seq_lengths is None: + # reverse inputs and return + inputs = jnp.flip(inputs, axis=time_axis) + return inputs + + seq_lengths = jnp.expand_dims(seq_lengths, axis=time_axis) + + # create indexes + idxs = jnp.arange(max_steps - 1, -1, -1) # [max_steps] + if time_major: + idxs = jnp.reshape(idxs, [max_steps] + [1] * num_batch_dims) + else: + idxs = jnp.reshape( + idxs, [1] * num_batch_dims + [max_steps] + ) # [1, ..., max_steps] + idxs = (idxs + seq_lengths) % max_steps # [*batch, max_steps] + idxs = _expand_dims_like(idxs, target=inputs) # [*batch, max_steps, *features] + # Select the inputs in flipped order. + outputs = jnp.take_along_axis(inputs, idxs, axis=time_axis) + + return outputs + + +def _concatenate(a: Array, b: Array) -> Array: + """Concatenates two arrays along the last dimension.""" + return jnp.concatenate([a, b], axis=-1) + + +class RNNBase(Protocol): + def __call__( + self, + inputs: Array, + *, + initial_carry: Carry | None = None, + rngs: rnglib.Rngs | None = None, + seq_lengths: Array | None = None, + return_carry: bool | None = None, + time_major: bool | None = None, + reverse: bool | None = None, + keep_order: bool | None = None, + ) -> Output | tuple[Carry, Output]: ... + + +class Bidirectional(Module): + """Processes the input in both directions and merges the results. 
+ + Example usage: + + ```python + import nnx + import jax + import jax.numpy as jnp + + # Define forward and backward RNNs + forward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0))) + backward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0))) + + # Create Bidirectional layer + layer = Bidirectional(forward_rnn=forward_rnn, backward_rnn=backward_rnn) + + # Input data + x = jnp.ones((2, 3, 3)) + + # Apply the layer + out = layer(x) + print(out.shape) + ``` + """ + + forward_rnn: RNNBase + backward_rnn: RNNBase + merge_fn: Callable[[Array, Array], Array] = _concatenate + time_major: bool = False + return_carry: bool = False + + def __init__( + self, + forward_rnn: RNNBase, + backward_rnn: RNNBase, + *, + merge_fn: Callable[[Array, Array], Array] = _concatenate, + time_major: bool = False, + return_carry: bool = False, + rngs: rnglib.Rngs | None = None, + ): + self.forward_rnn = forward_rnn + self.backward_rnn = backward_rnn + self.merge_fn = merge_fn + self.time_major = time_major + self.return_carry = return_carry + if rngs is None: + rngs = rnglib.Rngs(0) + self.rngs = rngs + + def __call__( + self, + inputs: Array, + *, + initial_carry: tuple[Carry, Carry] | None = None, + rngs: rnglib.Rngs | None = None, + seq_lengths: Array | None = None, + return_carry: bool | None = None, + time_major: bool | None = None, + reverse: bool | None = None, # unused + keep_order: bool | None = None, # unused + ) -> Output | tuple[tuple[Carry, Carry], Output]: + if time_major is None: + time_major = self.time_major + if return_carry is None: + return_carry = self.return_carry + if rngs is None: + rngs = self.rngs + if initial_carry is not None: + initial_carry_forward, initial_carry_backward = initial_carry + else: + initial_carry_forward = None + initial_carry_backward = None + # Throw a warning in case the user accidentally re-uses the forward RNN + # for the backward pass and does not intend for them to share parameters. + if self.forward_rnn is self.backward_rnn: + logging.warning( + "forward_rnn and backward_rnn is the same object, so " + "they will share parameters." + ) + + # Encode in the forward direction. + carry_forward, outputs_forward = self.forward_rnn( + inputs, + initial_carry=initial_carry_forward, + rngs=rngs, + seq_lengths=seq_lengths, + return_carry=True, + time_major=time_major, + reverse=False, + ) + + # Encode in the backward direction. + carry_backward, outputs_backward = self.backward_rnn( + inputs, + initial_carry=initial_carry_backward, + rngs=rngs, + seq_lengths=seq_lengths, + return_carry=True, + time_major=time_major, + reverse=True, + keep_order=True, + ) + + carry = (carry_forward, carry_backward) if return_carry else None + outputs = jax.tree_util.tree_map( + self.merge_fn, outputs_forward, outputs_backward + ) + + if return_carry: + return carry, outputs + else: + return outputs \ No newline at end of file diff --git a/tests/nnx/nn/recurrent_test.py b/tests/nnx/nn/recurrent_test.py new file mode 100644 index 0000000000..b724b69d7b --- /dev/null +++ b/tests/nnx/nn/recurrent_test.py @@ -0,0 +1,543 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax, jax.numpy as jnp +from jax import random + +from flax import linen +from flax import nnx +from flax.nnx.nn import initializers + +import numpy as np + +from absl.testing import absltest + +class TestLSTMCell(absltest.TestCase): + def test_basic(self): + module = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(0), + ) + x = jnp.ones((2, 3)) + carry = module.initialize_carry(x.shape, module.rngs) + new_carry, y = module(carry, x) + self.assertEqual(y.shape, (2, 4)) + + def test_lstm_sequence(self): + """Test LSTMCell over a sequence of inputs.""" + module = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(0), + ) + x = random.normal(random.PRNGKey(1), (5, 2, 3)) # seq_len, batch, feature + carry = module.initialize_carry(x.shape[1:], module.rngs) + outputs = [] + for t in range(x.shape[0]): + carry, y = module(carry, x[t]) + outputs.append(y) + outputs = jnp.stack(outputs) + self.assertEqual(outputs.shape, (5, 2, 4)) + + def test_lstm_with_different_dtypes(self): + """Test LSTMCell with different data types.""" + module = nnx.LSTMCell( + in_features=3, + hidden_features=4, + dtype=jnp.bfloat16, + param_dtype=jnp.bfloat16, + rngs=nnx.Rngs(0), + ) + x = jnp.ones((2, 3), dtype=jnp.bfloat16) + carry = module.initialize_carry(x.shape, module.rngs) + new_carry, y = module(carry, x) + self.assertEqual(y.dtype, jnp.bfloat16) + self.assertEqual(y.shape, (2, 4)) + + def test_lstm_with_custom_activations(self): + """Test LSTMCell with custom activation functions.""" + module = nnx.LSTMCell( + in_features=3, + hidden_features=4, + gate_fn=jax.nn.relu, + activation_fn=jax.nn.elu, + rngs=nnx.Rngs(0), + ) + x = jnp.ones((1, 3)) + carry = module.initialize_carry(x.shape, module.rngs) + new_carry, y = module(carry, x) + self.assertEqual(y.shape, (1, 4)) + + def test_lstm_initialize_carry(self): + """Test the initialize_carry method.""" + module = nnx.LSTMCell( + in_features=3, + hidden_features=4, + carry_init=initializers.ones, + rngs=nnx.Rngs(0), + ) + x_shape = (1, 3) + carry = module.initialize_carry(x_shape, module.rngs) + c, h = carry + self.assertTrue(jnp.all(c == 1.0)) + self.assertTrue(jnp.all(h == 1.0)) + self.assertEqual(c.shape, (1, 4)) + self.assertEqual(h.shape, (1, 4)) + + def test_lstm_with_variable_sequence_length(self): + """Test LSTMCell with variable sequence lengths.""" + module = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(0) + ) + + # Simulate a batch with variable sequence lengths + x = jnp.array([ + [[1, 2, 3], [4, 5, 6], [0, 0, 0]], # Sequence length 2 + [[7, 8, 9], [10, 11, 12], [13, 14, 15]], # Sequence length 3 + ]) # Shape: (batch_size=2, max_seq_length=3, features=3) + + seq_lengths = jnp.array([2, 3]) # Actual lengths for each sequence + batch_size = x.shape[0] + max_seq_length = x.shape[1] + carry = module.initialize_carry((batch_size, 3), module.rngs) + outputs = [] + for t in range(max_seq_length): + input_t = x[:, t, :] + carry, y = module(carry, input_t) + outputs.append(y) + outputs = jnp.stack(outputs, axis=1) # Shape: (batch_size, max_seq_length, hidden_features) + + # Zero out 
outputs beyond the actual sequence lengths + mask = (jnp.arange(max_seq_length)[None, :] < seq_lengths[:, None]) + outputs = outputs * mask[:, :, None] + self.assertEqual(outputs.shape, (2, 3, 4)) + + def test_lstm_stateful(self): + """Test that LSTMCell maintains state across calls.""" + module = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(0), + ) + x1 = jnp.ones((1, 3)) + x2 = jnp.ones((1, 3)) * 2 + carry = module.initialize_carry(x1.shape) + carry, y1 = module(carry, x1) + carry, y2 = module(carry, x2) + self.assertEqual(y1.shape, (1, 4)) + self.assertEqual(y2.shape, (1, 4)) + + def test_lstm_equivalence_with_flax_linen(self): + """Test that nnx.LSTMCell produces the same outputs as flax.linen.LSTMCell.""" + in_features = 3 + hidden_features = 4 + key = random.PRNGKey(42) + x = random.normal(key, (1, in_features)) + + # Initialize nnx.LSTMCell + rngs_nnx = nnx.Rngs(0) + module_nnx = nnx.LSTMCell( + in_features=in_features, + hidden_features=hidden_features, + rngs=rngs_nnx, + ) + carry_nnx = module_nnx.initialize_carry(x.shape, rngs_nnx) + # Initialize flax.linen.LSTMCell + module_linen = linen.LSTMCell( + features=hidden_features, + ) + carry_linen = module_linen.initialize_carry(random.PRNGKey(0), x.shape) + variables_linen = module_linen.init(random.PRNGKey(1), carry_linen, x) + + # Copy parameters from flax.linen.LSTMCell to nnx.LSTMCell + params_linen = variables_linen['params'] + # Map the parameters from linen to nnx + # Assuming the parameter names and shapes are compatible + # For a precise mapping, you might need to adjust parameter names + # Get the parameters from nnx module + nnx_params = module_nnx.__dict__ + + # Map parameters from linen to nnx + for gate in ['i', 'f', 'g', 'o']: + # Input kernels (input to gate) + if gate == 'f': + nnx_layer = getattr(module_nnx, f'if_') + else: + nnx_layer = getattr(module_nnx, f'i{gate}') + linen_params = params_linen[f'i{gate}'] + nnx_layer.kernel.value = linen_params['kernel'] + if nnx_layer.use_bias: + nnx_layer.bias.value = linen_params['bias'] + # Hidden kernels (hidden state to gate) + nnx_layer = getattr(module_nnx, f'h{gate}') + linen_params = params_linen[f'h{gate}'] + nnx_layer.kernel.value = linen_params['kernel'] + if nnx_layer.use_bias: + nnx_layer.bias.value = linen_params['bias'] + + # Run both modules + new_carry_nnx, y_nnx = module_nnx(carry_nnx, x) + new_carry_linen, y_linen = module_linen.apply(variables_linen, carry_linen, x) + + # Compare outputs + np.testing.assert_allclose(y_nnx, y_linen, atol=1e-5) + # Compare carries + for c_nnx, c_linen in zip(new_carry_nnx, new_carry_linen): + np.testing.assert_allclose(c_nnx, c_linen, atol=1e-5) + +class TestRNN(absltest.TestCase): + + def test_rnn_with_lstm_cell(self): + """Test RNN module using LSTMCell.""" + # Initialize the LSTMCell + cell = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(0), + ) + + # Initialize the RNN module with the LSTMCell + rnn = nnx.RNN(cell) + + # Create input data (batch_size=2, seq_length=5, features=3) + x = jnp.ones((2, 5, 3)) + + # Initialize the carry + carry = cell.initialize_carry((2, 3), cell.rngs) + + # Run the RNN module + outputs = rnn(x, initial_carry=carry) + + self.assertEqual(outputs.shape, (2, 5, 4)) # Output features should match hidden_features + + def test_rnn_with_gru_cell(self): + """Test RNN module using GRUCell.""" + # Initialize the GRUCell + cell = nnx.GRUCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(1), + ) + + # Initialize the RNN module with the GRUCell + rnn 
= nnx.RNN(cell) + + # Create input data (batch_size=2, seq_length=5, features=3) + x = jnp.ones((2, 5, 3)) + + # Initialize the carry + carry = cell.initialize_carry((2, 3), cell.rngs) + + # Run the RNN module + outputs = rnn(x, initial_carry=carry) + + self.assertEqual(outputs.shape, (2, 5, 4)) # Output features should match hidden_features + + def test_rnn_time_major(self): + """Test RNN module with time_major=True.""" + # Initialize the LSTMCell + cell = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(2), + ) + + # Initialize the RNN module with time_major=True + rnn = nnx.RNN(cell, time_major=True) + + # Create input data (seq_length=5, batch_size=2, features=3) + x = jnp.ones((5, 2, 3)) + + # Initialize the carry + carry = cell.initialize_carry(x.shape[1:2] + x.shape[2:], cell.rngs) + + # Run the RNN module + outputs = rnn(x, initial_carry=carry) + + self.assertEqual(outputs.shape, (5, 2, 4)) # Output features should match hidden_features + + def test_rnn_reverse(self): + """Test RNN module with reverse=True.""" + # Initialize the LSTMCell + cell = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(3), + ) + + # Initialize the RNN module with reverse=True + rnn = nnx.RNN(cell, reverse=True) + + # Create input data (batch_size=2, seq_length=5, features=3) + x = jnp.tile(jnp.arange(5), (2, 1)).reshape(2, 5, 1) # Distinct values to check reversal + x = jnp.concatenate([x, x, x], axis=-1) # Shape: (2, 5, 3) + + # Run the RNN module + outputs = rnn(x) + + # Check if the outputs are in reverse order + outputs_reversed = outputs[:, ::-1, :] + # Since we used distinct input values, we can compare outputs to check reversal + # For simplicity, just check the shapes here + self.assertEqual(outputs.shape, (2, 5, 4)) + self.assertEqual(outputs_reversed.shape, (2, 5, 4)) + + def test_rnn_with_seq_lengths(self): + """Test RNN module with variable sequence lengths.""" + # Initialize the LSTMCell + cell = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(4), + ) + + # Initialize the RNN module + rnn = nnx.RNN(cell, return_carry=True) + + # Create input data with padding (batch_size=2, seq_length=5, features=3) + x = jnp.array([ + [[1, 1, 1], [2, 2, 2], [3, 3, 3], [0, 0, 0], [0, 0, 0]], # Sequence length 3 + [[4, 4, 4], [5, 5, 5], [6, 6, 6], [7, 7, 7], [8, 8, 8]], # Sequence length 5 + ]) # Shape: (2, 5, 3) + + seq_lengths = jnp.array([3, 5]) # Actual lengths for each sequence + + # Initialize the carry + carry = cell.initialize_carry((2, 3), cell.rngs) + + # Run the RNN module + final_carry, outputs = rnn(x, initial_carry=carry, seq_lengths=seq_lengths) + + self.assertEqual(outputs.shape, (2, 5, 4)) + + self.assertEqual(final_carry[0].shape, (2, 4)) # c: (batch_size, hidden_features) + self.assertEqual(final_carry[1].shape, (2, 4)) # h: (batch_size, hidden_features) + + # Todo: a better test by matching the outputs with the expected values + + def test_rnn_with_keep_order(self): + """Test RNN module with reverse=True and keep_order=True.""" + # Initialize the LSTMCell + cell = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(5), + ) + + # Initialize the RNN module with reverse=True and keep_order=True + rnn = nnx.RNN(cell, reverse=True, keep_order=True) + + # Create input data (batch_size=2, seq_length=5, features=3) + x = jnp.tile(jnp.arange(5), (2, 1)).reshape(2, 5, 1) # Distinct values to check reversal + x = jnp.concatenate([x, x, x], axis=-1) # Shape: (2, 5, 3) + + # Initialize the carry + carry = cell.initialize_carry((2, 
3), cell.rngs) + + # Run the RNN module + outputs = rnn(x, initial_carry=carry) + + # Check if the outputs are in the original order despite processing in reverse + self.assertEqual(outputs.shape, (2, 5, 4)) + + def test_rnn_equivalence_with_flax_linen(self): + """Test that nnx.RNN produces the same outputs as flax.linen.RNN.""" + in_features = 3 + hidden_features = 4 + seq_length = 5 + batch_size = 2 + key = random.PRNGKey(42) + + # Create input data + x = random.normal(key, (batch_size, seq_length, in_features)) + + # Initialize nnx.LSTMCell and RNN + rngs_nnx = nnx.Rngs(0) + cell_nnx = nnx.LSTMCell( + in_features=in_features, + hidden_features=hidden_features, + rngs=rngs_nnx, + ) + rnn_nnx = nnx.RNN(cell_nnx) + + # Initialize flax.linen.LSTMCell and RNN + cell_linen = linen.LSTMCell(features=hidden_features) + rnn_linen = linen.RNN(cell_linen) + carry_linen = cell_linen.initialize_carry(random.PRNGKey(0), x[:, 0].shape) + variables_linen = rnn_linen.init(random.PRNGKey(1), x) + + # Copy parameters from flax.linen to nnx + params_linen = variables_linen['params']['cell'] + # Copy cell parameters + for gate in ['i', 'f', 'g', 'o']: + # Input kernels + if gate == 'f': + nnx_layer = getattr(cell_nnx, f'if_') + else: + nnx_layer = getattr(cell_nnx, f'i{gate}') + linen_params = params_linen[f'i{gate}'] + nnx_layer.kernel.value = linen_params['kernel'] + if nnx_layer.use_bias: + nnx_layer.bias.value = linen_params['bias'] + # Hidden kernels + nnx_layer = getattr(cell_nnx, f'h{gate}') + linen_params = params_linen[f'h{gate}'] + nnx_layer.kernel.value = linen_params['kernel'] + if nnx_layer.use_bias: + nnx_layer.bias.value = linen_params['bias'] + + # Initialize carries + carry_nnx = cell_nnx.initialize_carry((batch_size, in_features), rngs_nnx) + + # Run nnx.RNN + outputs_nnx = rnn_nnx(x, initial_carry=carry_nnx) + + # Run flax.linen.RNN + outputs_linen = rnn_linen.apply(variables_linen, x, initial_carry=carry_linen) + + # Compare outputs + np.testing.assert_allclose(outputs_nnx, outputs_linen, atol=1e-5) + + def test_rnn_with_unroll(self): + """Test RNN module with unroll parameter.""" + # Initialize the LSTMCell + cell = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(6) + ) + + # Initialize the RNN module with unroll=2 + rnn = nnx.RNN(cell, unroll=2) + + # Create input data (batch_size=2, seq_length=6, features=3) + x = jnp.ones((2, 6, 3)) + + # Initialize the carry + carry = cell.initialize_carry((2, 3), cell.rngs) + + # Run the RNN module + outputs = rnn(x, initial_carry=carry) + + self.assertEqual(outputs.shape, (2, 6, 4)) # Output features should match hidden_features + + def test_rnn_with_custom_cell(self): + """Test RNN module with a custom RNN cell.""" + class CustomRNNCell(nnx.Module): + """A simple custom RNN cell.""" + + in_features: int + hidden_features: int + + def __init__(self, in_features, hidden_features, rngs): + self.in_features = in_features + self.hidden_features = hidden_features + self.rngs = rngs + self.dense = nnx.Linear( + in_features=in_features + hidden_features, + out_features=hidden_features, + rngs=rngs, + ) + + def __call__(self, carry, inputs): + h = carry + x = jnp.concatenate([inputs, h], axis=-1) + new_h = jax.nn.tanh(self.dense(x)) + return new_h, new_h + + def initialize_carry(self, input_shape, rngs): + batch_size = input_shape[0] + h = jnp.zeros((batch_size, self.hidden_features)) + return h + + @property + def num_feature_axes(self) -> int: + return 1 + + # Initialize the custom RNN cell + cell = CustomRNNCell( + in_features=3, + 
hidden_features=4, + rngs=nnx.Rngs(7) + ) + + # Initialize the RNN module + rnn = nnx.RNN(cell) + + # Create input data (batch_size=2, seq_length=5, features=3) + x = jnp.ones((2, 5, 3)) + + # Initialize the carry + carry = cell.initialize_carry((2, 3), cell.rngs) + + # Run the RNN module + outputs = rnn(x, initial_carry=carry) + + self.assertEqual(outputs.shape, (2, 5, 4)) # Output features should match hidden_features + + def test_rnn_with_different_dtypes(self): + """Test RNN module with different data types.""" + # Initialize the LSTMCell with float16 + cell = nnx.LSTMCell( + in_features=3, + hidden_features=4, + dtype=jnp.float16, + param_dtype=jnp.float16, + rngs=nnx.Rngs(8), + ) + + # Initialize the RNN module + rnn = nnx.RNN(cell) + + # Create input data (batch_size=2, seq_length=5, features=3) + x = jnp.ones((2, 5, 3), dtype=jnp.float16) + + # Initialize the carry + carry = cell.initialize_carry((2, 3), cell.rngs) + + # Run the RNN module + outputs = rnn(x, initial_carry=carry) + + self.assertEqual(outputs.dtype, jnp.float16) + self.assertEqual(outputs.shape, (2, 5, 4)) + + def test_rnn_with_variable_batch_size(self): + """Test RNN module with variable batch sizes.""" + # Initialize the LSTMCell + cell = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(9), + ) + + # Initialize the RNN module + rnn = nnx.RNN(cell) + + for batch_size in [1, 2, 5]: + # Create input data (batch_size, seq_length=5, features=3) + x = jnp.ones((batch_size, 5, 3)) + + # Initialize the carry + carry = cell.initialize_carry((batch_size, 3), cell.rngs) + + # Run the RNN module + outputs = rnn(x, initial_carry=carry) + + self.assertEqual(outputs.shape, (batch_size, 5, 4)) + +if __name__ == '__main__': + absltest.main()
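A minimal shape check for the new ``Bidirectional`` wrapper, sketched in the style of the tests above; this is an illustrative sketch rather than part of the patch, and it relies only on the default ``merge_fn`` (concatenation along the feature axis):

```python
import jax.numpy as jnp
from flax import nnx
from absl.testing import absltest


class TestBidirectional(absltest.TestCase):
  def test_output_shape(self):
    # Use distinct Rngs so the two directions do not share parameters.
    forward_rnn = nnx.RNN(
        nnx.GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
    backward_rnn = nnx.RNN(
        nnx.GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(1)))
    layer = nnx.Bidirectional(forward_rnn=forward_rnn, backward_rnn=backward_rnn)

    x = jnp.ones((2, 5, 3))  # (batch, time, features)
    out = layer(x)

    # Forward and backward outputs (4 features each) are concatenated to 8.
    self.assertEqual(out.shape, (2, 5, 8))
```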