diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index b5ddcd87..1ec2bd80 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -50,7 +50,11 @@ Carry = Any CarryHistory = Any Output = Any -NEVER = object() + +class _Never: + pass + +NEVER = _Never() LEGACY_UPDATE_MESSAGE = ( "The RNNCellBase API has changed, " @@ -649,8 +653,6 @@ class RNN(Module): Attributes: cell: an instance of :class:`RNNCellBase`. - cell_size: the size of the cell as requested by :meth:`RNNCellBase.initialize_carry`, - it can be an integer or a tuple of integers. time_major: if ``time_major=False`` (default) it will expect inputs with shape ``(*batch, time, *features)``, else it will expect inputs with shape ``(time, *batch, *features)``. return_carry: if ``return_carry=False`` (default) only the output sequence is returned,