Skip to content

Commit

Permalink
fix type
Browse files Browse the repository at this point in the history
  • Loading branch information
zinccat committed Oct 10, 2024
1 parent b19772e commit 8d8073d
Showing 1 changed file with 14 additions and 19 deletions.
33 changes: 14 additions & 19 deletions flax/nnx/nn/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from typing import (
Any,
TypeVar,
TypeVar
)
from collections.abc import Callable
from functools import partial
Expand Down Expand Up @@ -170,7 +170,7 @@ def __init__(
self.hg = dense_h()
self.ho = dense_h()

def __call__(self, carry: tuple[Array, Array], inputs: Array) -> tuple[tuple[Array, Array], Array]:
def __call__(self, carry: tuple[Array, Array], inputs: Array) -> tuple[tuple[Array, Array], Array]: # type: ignore[override]
r"""A long short-term memory (LSTM) cell.
Args:
Expand All @@ -191,9 +191,7 @@ def __call__(self, carry: tuple[Array, Array], inputs: Array) -> tuple[tuple[Arr
new_h = o * self.activation_fn(new_c)
return (new_c, new_h), new_h

def initialize_carry(
self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None
) -> tuple[Array, Array]:
def initialize_carry(self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None) -> tuple[Array, Array]: # type: ignore[override]
"""Initialize the RNN cell carry.
Args:
Expand Down Expand Up @@ -300,7 +298,7 @@ def __init__(
rngs=rngs,
)

def __call__(self, carry: tuple[Array, Array], inputs: Array) -> tuple[tuple[Array, Array], Array]:
def __call__(self, carry: tuple[Array, Array], inputs: Array) -> tuple[tuple[Array, Array], Array]: # type: ignore[override]
r"""An optimized long short-term memory (LSTM) cell.
Args:
Expand Down Expand Up @@ -331,9 +329,7 @@ def __call__(self, carry: tuple[Array, Array], inputs: Array) -> tuple[tuple[Arr
new_h = o * self.activation_fn(new_c)
return (new_c, new_h), new_h

def initialize_carry(
self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None
) -> tuple[Array, Array]:
def initialize_carry(self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None) -> tuple[Array, Array]: # type: ignore[override]
"""Initialize the RNN cell carry.
Args:
Expand Down Expand Up @@ -427,14 +423,14 @@ def __init__(
rngs=rngs,
)

def __call__(self, carry, inputs):
def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]: # type: ignore[override]
new_carry = self.dense_i(inputs) + self.dense_h(carry)
if self.residual:
new_carry += carry
new_carry = self.activation_fn(new_carry)
return new_carry, new_carry

def initialize_carry(self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None) -> Array:
def initialize_carry(self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None) -> Array: # type: ignore[override]
"""Initialize the RNN cell carry.
Args:
Expand Down Expand Up @@ -535,7 +531,7 @@ def __init__(
rngs=rngs,
)

def __call__(self, carry, inputs):
def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]: # type: ignore[override]
"""Gated recurrent unit (GRU) cell.
Args:
Expand Down Expand Up @@ -568,9 +564,7 @@ def __call__(self, carry, inputs):
new_h = (1.0 - z) * n + z * h
return new_h, new_h

def initialize_carry(
self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None
) -> Array:
def initialize_carry(self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None) -> Array: # type: ignore[override]
"""Initialize the RNN cell carry.
Args:
Expand Down Expand Up @@ -676,12 +670,12 @@ def __call__(

slice_carry = seq_lengths is not None and return_carry

def scan_fn(cell: RNNCellBase, carry: Carry, x: Array) -> tuple[Carry, Array]:
def scan_fn(cell: RNNCellBase, carry: Carry, x: Array) -> tuple[Carry, Array] | tuple[Carry, tuple[Carry, Array]]:
carry, y = cell(carry, x)
if slice_carry:
return carry, (carry, y)
return carry, y
state_axes = nnx.StateAxes({...: Carry})
state_axes = nnx.StateAxes({...: Carry}) # type: ignore[arg-type]
scan = nnx.scan(
scan_fn,
in_axes=(state_axes, Carry, time_axis),
Expand Down Expand Up @@ -867,7 +861,7 @@ def __call__(
self,
inputs: Array,
*,
initial_carry: Carry | None = None,
initial_carry: tuple[Carry, Carry] | None = None,
rngs: rnglib.Rngs | None = None,
seq_lengths: Array | None = None,
return_carry: bool | None = None,
Expand All @@ -884,7 +878,8 @@ def __call__(
if initial_carry is not None:
initial_carry_forward, initial_carry_backward = initial_carry
else:
initial_carry_forward = initial_carry_backward = None
initial_carry_forward = None
initial_carry_backward = None
# Throw a warning in case the user accidentally re-uses the forward RNN
# for the backward pass and does not intend for them to share parameters.
if self.forward_rnn is self.backward_rnn:
Expand Down

0 comments on commit 8d8073d

Please sign in to comment.