Skip to content

Commit

Permalink
Merge pull request #4272 from zinccat:rnn_nnx
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686670119
  • Loading branch information
Flax Authors committed Oct 16, 2024
2 parents ae5c662 + 8c86f8e commit 3bf732c
Show file tree
Hide file tree
Showing 3 changed files with 1,473 additions and 0 deletions.
7 changes: 7 additions & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@
from .nn.attention import dot_product_attention as dot_product_attention
from .nn.attention import make_attention_mask as make_attention_mask
from .nn.attention import make_causal_mask as make_causal_mask
from .nn.recurrent import RNNCellBase as RNNCellBase
from .nn.recurrent import LSTMCell as LSTMCell
from .nn.recurrent import GRUCell as GRUCell
from .nn.recurrent import OptimizedLSTMCell as OptimizedLSTMCell
from .nn.recurrent import SimpleCell as SimpleCell
from .nn.recurrent import RNN as RNN
from .nn.recurrent import Bidirectional as Bidirectional
from .nn.linear import Conv as Conv
from .nn.linear import ConvTranspose as ConvTranspose
from .nn.linear import Embed as Embed
Expand Down
Loading

0 comments on commit 3bf732c

Please sign in to comment.