From 12a1159c2be73443c86c1b76636025af13380de2 Mon Sep 17 00:00:00 2001 From: Marc van Zee Date: Thu, 27 Jan 2022 11:40:37 +0100 Subject: [PATCH 1/3] Removes deprecated API from docs --- docs/design_notes/arguments.md | 8 +- docs/flax.nn.rst | 115 --------------- docs/index.rst | 6 - flax/core/flax_functional_engine.ipynb | 191 ++++++++++++++----------- flax/linen/README.md | 4 +- 5 files changed, 113 insertions(+), 211 deletions(-) delete mode 100644 docs/flax.nn.rst diff --git a/docs/design_notes/arguments.md b/docs/design_notes/arguments.md index 25eb8f5c4e..40212e9aca 100644 --- a/docs/design_notes/arguments.md +++ b/docs/design_notes/arguments.md @@ -85,9 +85,9 @@ It also avoids a default value which would probably cause either the train step -## Functional Core and flax.nn +## Functional Core -The old NN api and functional core define functions rather than classes. -Therefore, there is no clear distinction between hyper parameters and call time arguments. -The only way to pre-determine the hyper parameters is by using `partial`. +Functional core defines functions rather than classes. +Therefore, there is no clear distinction between hyperparameters and call-time arguments. +The only way to pre-determine the hyperparameters is by using `partial`. On the upside, there are no ambiguous cases where method arguments could also be attributes. diff --git a/docs/flax.nn.rst b/docs/flax.nn.rst deleted file mode 100644 index d853712d96..0000000000 --- a/docs/flax.nn.rst +++ /dev/null @@ -1,115 +0,0 @@ - -.. warning:: - **This package is deprecated**. See :mod:`flax.linen` for our new module API. - -flax.nn package (deprecated) -================= - -.. currentmodule:: flax.nn - - -Core: Module abstraction ------------------------- - -.. autoclass:: Module - :members: init, init_by_shape, partial, shared, apply, param, get_param, state, is_stateful, is_initializing - -Core: Additional ------------------------- - -.. autosummary:: - :toctree: _autosummary - - module - Model - Collection - capture_module_outputs - stateful - get_state - module_method - - -Linear modules ------------------------- - -.. autosummary:: - :toctree: _autosummary - - Dense - DenseGeneral - Conv - Embed - - -Normalization ------------------------- - -.. autosummary:: - :toctree: _autosummary - - BatchNorm - LayerNorm - GroupNorm - - -Pooling ------------------------- - -.. autosummary:: - :toctree: _autosummary - - max_pool - avg_pool - - -Activation functions ------------------------- - -.. autosummary:: - :toctree: _autosummary - - celu - elu - gelu - glu - log_sigmoid - log_softmax - relu - sigmoid - soft_sign - softmax - softplus - swish - - -Stochastic functions ------------------------- - -.. autosummary:: - :toctree: _autosummary - - make_rng - stochastic - is_stochastic - dropout - - -Attention primitives ------------------------- - -.. autosummary:: - :toctree: _autosummary - - dot_product_attention - SelfAttention - - -RNN primitives ------------------------- - -.. autosummary:: - :toctree: _autosummary - - LSTMCell - OptimizedLSTMCell - GRUCell diff --git a/docs/index.rst b/docs/index.rst index fbb51bca08..e932af5cfa 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -74,9 +74,3 @@ For a quick introduction and short example snippets, see our `README flax.training flax.config flax.errors - -.. toctree:: - :maxdepth: 1 - :caption: (deprecated) - - flax.nn (deprecated) diff --git a/flax/core/flax_functional_engine.ipynb b/flax/core/flax_functional_engine.ipynb index b35bd6b534..08164cc837 100644 --- a/flax/core/flax_functional_engine.ipynb +++ b/flax/core/flax_functional_engine.ipynb @@ -2,68 +2,93 @@ "cells": [ { "cell_type": "code", + "execution_count": 1, "metadata": { - "id": "x0SPwYS9dtYA", + "colab": {}, "colab_type": "code", - "colab": {} + "id": "x0SPwYS9dtYA" }, + "outputs": [], "source": [ "import functools\n", "import jax\n", "from jax import numpy as jnp, random, lax\n", "import numpy as np\n" - ], - "execution_count": 1, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": 2, "metadata": { - "id": "7n9cxyCzluvI", + "colab": {}, "colab_type": "code", - "colab": {} + "id": "7n9cxyCzluvI" }, + "outputs": [], "source": [ - "from flax import nn, struct" - ], - "execution_count": 2, - "outputs": [] + "from flax import linen as nn, struct" + ] }, { "cell_type": "code", + "execution_count": 3, "metadata": { - "id": "0L7YCrobkfzU", + "colab": {}, "colab_type": "code", - "colab": {} + "id": "0L7YCrobkfzU" }, + "outputs": [], "source": [ "from flax.core import Scope, init, apply, Array, lift, unfreeze" - ], - "execution_count": 3, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": 4, "metadata": { - "id": "aDLGb3iGkjoL", + "colab": { + "height": 136 + }, "colab_type": "code", - "outputId": "2558605e-e485-407e-b062-74d31cc49f1e", "executionInfo": { + "elapsed": 1116, "status": "ok", "timestamp": 1590673431275, - "user_tz": -120, - "elapsed": 1116, "user": { "displayName": "Jonathan Heek", "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhqRcoo1w0woYaM99jSyWQaD-qfmHmeDpXHzHZd=s64", "userId": "00491914421152177709" - } - }, - "colab": { - "height": 136 + }, + "user_tz": -120 }, + "id": "aDLGb3iGkjoL", + "outputId": "2558605e-e485-407e-b062-74d31cc49f1e", "tags": [] }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "FrozenDict({'params': {'kernel': DeviceArray([[ 0.15374057, -0.6807397 , -1.3350962 ],\n", + " [ 0.59940743, -0.69430196, -0.7663768 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}})\n" + ] + }, + { + "data": { + "text/plain": [ + "(DeviceArray([[0.17045607]], dtype=float32),\n", + " FrozenDict({'params': {'hidden': {'bias': DeviceArray([0., 0., 0.], dtype=float32), 'kernel': DeviceArray([[-0.22119394, 0.22075175, -0.0925657 ],\n", + " [ 0.40571952, 0.27750877, 1.0542233 ]], dtype=float32)}, 'out': {'kernel': DeviceArray([[ 0.21448377],\n", + " [-0.01530595],\n", + " [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)}}}))" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "def dense(scope: Scope, inputs: Array, features: int, bias: bool = True,\n", " kernel_init=nn.linear.default_kernel_init,\n", @@ -86,46 +111,43 @@ " return dense(scope.push('out'), hidden, 1)\n", "\n", "init(mlp)(random.PRNGKey(0), x, features=3)" - ], - "execution_count": 4, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": "FrozenDict({'params': {'kernel': DeviceArray([[ 0.15374057, -0.6807397 , -1.3350962 ],\n [ 0.59940743, -0.69430196, -0.7663768 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}})\n" - }, - { - "output_type": "execute_result", - "data": { - "text/plain": "(DeviceArray([[0.17045607]], dtype=float32),\n FrozenDict({'params': {'hidden': {'bias': DeviceArray([0., 0., 0.], dtype=float32), 'kernel': DeviceArray([[-0.22119394, 0.22075175, -0.0925657 ],\n [ 0.40571952, 0.27750877, 1.0542233 ]], dtype=float32)}, 'out': {'kernel': DeviceArray([[ 0.21448377],\n [-0.01530595],\n [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)}}}))" - }, - "metadata": {}, - "execution_count": 4 - } ] }, { "cell_type": "code", + "execution_count": 5, "metadata": { - "id": "LTFjZbRmlqZh", + "colab": { + "height": 85 + }, "colab_type": "code", - "outputId": "5790b763-df4f-47c8-9f4e-53fd1e1eb1fd", "executionInfo": { + "elapsed": 526, "status": "ok", "timestamp": 1590672865722, - "user_tz": -120, - "elapsed": 526, "user": { "displayName": "Jonathan Heek", "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhqRcoo1w0woYaM99jSyWQaD-qfmHmeDpXHzHZd=s64", "userId": "00491914421152177709" - } - }, - "colab": { - "height": 85 + }, + "user_tz": -120 }, + "id": "LTFjZbRmlqZh", + "outputId": "5790b763-df4f-47c8-9f4e-53fd1e1eb1fd", "tags": [] }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ 0.11575121 -0.51936364 -1.113899 ]\n", + " [ 0.45569834 -0.5300623 -0.5873911 ]]\n", + "[ 0.45569834 -0.5300623 -0.5873911 ]\n", + "[[-1.5175114 -0.6617551]]\n" + ] + } + ], "source": [ "@struct.dataclass\n", "class Embedding:\n", @@ -147,37 +169,43 @@ "print(embedding.table)\n", "print(embedding.lookup(1))\n", "print(embedding.attend(jnp.ones((1, 3,))))" - ], - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": "[[ 0.11575121 -0.51936364 -1.113899 ]\n [ 0.45569834 -0.5300623 -0.5873911 ]]\n[ 0.45569834 -0.5300623 -0.5873911 ]\n[[-1.5175114 -0.6617551]]\n" - } ] }, { "cell_type": "code", + "execution_count": 6, "metadata": { - "id": "TMlae0hem0u5", + "colab": { + "height": 71 + }, "colab_type": "code", - "outputId": "dd9c5079-10e7-4944-e09a-e9f65573a733", "executionInfo": { + "elapsed": 342, "status": "ok", "timestamp": 1590673618925, - "user_tz": -120, - "elapsed": 342, "user": { "displayName": "Jonathan Heek", "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhqRcoo1w0woYaM99jSyWQaD-qfmHmeDpXHzHZd=s64", "userId": "00491914421152177709" - } + }, + "user_tz": -120 }, - "colab": { - "height": 71 - } + "id": "TMlae0hem0u5", + "outputId": "dd9c5079-10e7-4944-e09a-e9f65573a733" }, + "outputs": [ + { + "data": { + "text/plain": [ + "((((1, 3), (1, 3)), (1, 3)),\n", + " FrozenDict({'params': {'hf': {'bias': (3,), 'kernel': (3, 3)}, 'hg': {'bias': (3,), 'kernel': (3, 3)}, 'hi': {'bias': (3,), 'kernel': (3, 3)}, 'ho': {'bias': (3,), 'kernel': (3, 3)}, 'if': {'kernel': (2, 3)}, 'ig': {'kernel': (2, 3)}, 'ii': {'kernel': (2, 3)}, 'io': {'kernel': (2, 3)}}}))" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "def lstm(scope, carry, inputs,\n", " gate_fn=nn.activation.sigmoid, activation_fn=nn.activation.tanh,\n", @@ -240,17 +268,6 @@ "carry = lstm_init_carry((1,), 3)\n", "y, variables = init(lstm)(random.PRNGKey(0), carry, x)\n", "jax.tree_map(np.shape, (y, variables))" - ], - "execution_count": 6, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": "((((1, 3), (1, 3)), (1, 3)),\n FrozenDict({'params': {'hf': {'bias': (3,), 'kernel': (3, 3)}, 'hg': {'bias': (3,), 'kernel': (3, 3)}, 'hi': {'bias': (3,), 'kernel': (3, 3)}, 'ho': {'bias': (3,), 'kernel': (3, 3)}, 'if': {'kernel': (2, 3)}, 'ig': {'kernel': (2, 3)}, 'ii': {'kernel': (2, 3)}, 'io': {'kernel': (2, 3)}}}))" - }, - "metadata": {}, - "execution_count": 6 - } ] }, { @@ -261,9 +278,12 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", - "text": "initialized parameter shapes:\n {'params': {'hf': {'bias': (2,), 'kernel': (2, 2)}, 'hg': {'bias': (2,), 'kernel': (2, 2)}, 'hi': {'bias': (2,), 'kernel': (2, 2)}, 'ho': {'bias': (2,), 'kernel': (2, 2)}, 'if': {'kernel': (2, 2)}, 'ig': {'kernel': (2, 2)}, 'ii': {'kernel': (2, 2)}, 'io': {'kernel': (2, 2)}}}\n" + "output_type": "stream", + "text": [ + "initialized parameter shapes:\n", + " {'params': {'hf': {'bias': (2,), 'kernel': (2, 2)}, 'hg': {'bias': (2,), 'kernel': (2, 2)}, 'hi': {'bias': (2,), 'kernel': (2, 2)}, 'ho': {'bias': (2,), 'kernel': (2, 2)}, 'if': {'kernel': (2, 2)}, 'ig': {'kernel': (2, 2)}, 'ii': {'kernel': (2, 2)}, 'io': {'kernel': (2, 2)}}}\n" + ] } ], "source": [ @@ -296,9 +316,12 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", - "text": "output:\n (DeviceArray([[-0.35626447, 0.25178757]], dtype=float32), DeviceArray([[-0.17885922, 0.13063088]], dtype=float32))\n" + "output_type": "stream", + "text": [ + "output:\n", + " (DeviceArray([[-0.35626447, 0.25178757]], dtype=float32), DeviceArray([[-0.17885922, 0.13063088]], dtype=float32))\n" + ] } ], "source": [ @@ -316,17 +339,17 @@ ], "metadata": { "colab": { - "name": "flax functional engine.ipynb", - "provenance": [], "collapsed_sections": [], "last_runtime": { "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook", "kind": "private" - } + }, + "name": "flax functional engine.ipynb", + "provenance": [] }, "kernelspec": { - "name": "python3", - "display_name": "Python 3" + "display_name": "Python 3", + "name": "python3" } }, "nbformat": 4, diff --git a/flax/linen/README.md b/flax/linen/README.md index 8a064f501b..54c402711e 100644 --- a/flax/linen/README.md +++ b/flax/linen/README.md @@ -1,7 +1,7 @@ # Linen: A comfortable evolution of Flax -Linen is a rewrite of Flax Modules based on learning from our users and the broader JAX community. Linen improves on much of the former `flax.nn` API, such as submodule sharing and better support for non-trainable variables. -Moreover, Linen builds on a new "functional core", enabling direct usage of JAX transformations such as `vmap`, `remat` or `scan` inside your modules. +Linen is an neural network API based on learning from our users and the broader JAX community. Linen improves on much of the former APIs, such as submodule sharing and better support for non-trainable variables. +Moreover, Linen builds on a "functional core", enabling direct usage of JAX transformations such as `vmap`, `remat` or `scan` inside your modules. In Linen, Modules behave much closer to vanilla Python objects, while still letting you opt-in to the concise single-method pattern many of our users love. From 1c4fe91dae3a3df18eb72bc37b542fb2383079ad Mon Sep 17 00:00:00 2001 From: Marc van Zee Date: Thu, 27 Jan 2022 11:44:03 +0100 Subject: [PATCH 2/3] Updates core notebook to Linen --- flax/core/flax_functional_engine.ipynb | 189 +++++++++++-------------- 1 file changed, 83 insertions(+), 106 deletions(-) diff --git a/flax/core/flax_functional_engine.ipynb b/flax/core/flax_functional_engine.ipynb index 08164cc837..96732c79fe 100644 --- a/flax/core/flax_functional_engine.ipynb +++ b/flax/core/flax_functional_engine.ipynb @@ -2,93 +2,68 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, "metadata": { - "colab": {}, + "id": "x0SPwYS9dtYA", "colab_type": "code", - "id": "x0SPwYS9dtYA" + "colab": {} }, - "outputs": [], "source": [ "import functools\n", "import jax\n", "from jax import numpy as jnp, random, lax\n", "import numpy as np\n" - ] + ], + "execution_count": 1, + "outputs": [] }, { "cell_type": "code", - "execution_count": 2, "metadata": { - "colab": {}, + "id": "7n9cxyCzluvI", "colab_type": "code", - "id": "7n9cxyCzluvI" + "colab": {} }, - "outputs": [], "source": [ "from flax import linen as nn, struct" - ] + ], + "execution_count": 2, + "outputs": [] }, { "cell_type": "code", - "execution_count": 3, "metadata": { - "colab": {}, + "id": "0L7YCrobkfzU", "colab_type": "code", - "id": "0L7YCrobkfzU" + "colab": {} }, - "outputs": [], "source": [ "from flax.core import Scope, init, apply, Array, lift, unfreeze" - ] + ], + "execution_count": 3, + "outputs": [] }, { "cell_type": "code", - "execution_count": 4, "metadata": { - "colab": { - "height": 136 - }, + "id": "aDLGb3iGkjoL", "colab_type": "code", + "outputId": "2558605e-e485-407e-b062-74d31cc49f1e", "executionInfo": { - "elapsed": 1116, "status": "ok", "timestamp": 1590673431275, + "user_tz": -120, + "elapsed": 1116, "user": { "displayName": "Jonathan Heek", "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhqRcoo1w0woYaM99jSyWQaD-qfmHmeDpXHzHZd=s64", "userId": "00491914421152177709" - }, - "user_tz": -120 + } + }, + "colab": { + "height": 136 }, - "id": "aDLGb3iGkjoL", - "outputId": "2558605e-e485-407e-b062-74d31cc49f1e", "tags": [] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "FrozenDict({'params': {'kernel': DeviceArray([[ 0.15374057, -0.6807397 , -1.3350962 ],\n", - " [ 0.59940743, -0.69430196, -0.7663768 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}})\n" - ] - }, - { - "data": { - "text/plain": [ - "(DeviceArray([[0.17045607]], dtype=float32),\n", - " FrozenDict({'params': {'hidden': {'bias': DeviceArray([0., 0., 0.], dtype=float32), 'kernel': DeviceArray([[-0.22119394, 0.22075175, -0.0925657 ],\n", - " [ 0.40571952, 0.27750877, 1.0542233 ]], dtype=float32)}, 'out': {'kernel': DeviceArray([[ 0.21448377],\n", - " [-0.01530595],\n", - " [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)}}}))" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ "def dense(scope: Scope, inputs: Array, features: int, bias: bool = True,\n", " kernel_init=nn.linear.default_kernel_init,\n", @@ -111,43 +86,46 @@ " return dense(scope.push('out'), hidden, 1)\n", "\n", "init(mlp)(random.PRNGKey(0), x, features=3)" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": "FrozenDict({'params': {'kernel': DeviceArray([[ 0.15374057, -0.6807397 , -1.3350962 ],\n [ 0.59940743, -0.69430196, -0.7663768 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}})\n" + }, + { + "output_type": "execute_result", + "data": { + "text/plain": "(DeviceArray([[0.17045607]], dtype=float32),\n FrozenDict({'params': {'hidden': {'bias': DeviceArray([0., 0., 0.], dtype=float32), 'kernel': DeviceArray([[-0.22119394, 0.22075175, -0.0925657 ],\n [ 0.40571952, 0.27750877, 1.0542233 ]], dtype=float32)}, 'out': {'kernel': DeviceArray([[ 0.21448377],\n [-0.01530595],\n [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)}}}))" + }, + "metadata": {}, + "execution_count": 4 + } ] }, { "cell_type": "code", - "execution_count": 5, "metadata": { - "colab": { - "height": 85 - }, + "id": "LTFjZbRmlqZh", "colab_type": "code", + "outputId": "5790b763-df4f-47c8-9f4e-53fd1e1eb1fd", "executionInfo": { - "elapsed": 526, "status": "ok", "timestamp": 1590672865722, + "user_tz": -120, + "elapsed": 526, "user": { "displayName": "Jonathan Heek", "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhqRcoo1w0woYaM99jSyWQaD-qfmHmeDpXHzHZd=s64", "userId": "00491914421152177709" - }, - "user_tz": -120 + } + }, + "colab": { + "height": 85 }, - "id": "LTFjZbRmlqZh", - "outputId": "5790b763-df4f-47c8-9f4e-53fd1e1eb1fd", "tags": [] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 0.11575121 -0.51936364 -1.113899 ]\n", - " [ 0.45569834 -0.5300623 -0.5873911 ]]\n", - "[ 0.45569834 -0.5300623 -0.5873911 ]\n", - "[[-1.5175114 -0.6617551]]\n" - ] - } - ], "source": [ "@struct.dataclass\n", "class Embedding:\n", @@ -169,43 +147,37 @@ "print(embedding.table)\n", "print(embedding.lookup(1))\n", "print(embedding.attend(jnp.ones((1, 3,))))" + ], + "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": "[[ 0.11575121 -0.51936364 -1.113899 ]\n [ 0.45569834 -0.5300623 -0.5873911 ]]\n[ 0.45569834 -0.5300623 -0.5873911 ]\n[[-1.5175114 -0.6617551]]\n" + } ] }, { "cell_type": "code", - "execution_count": 6, "metadata": { - "colab": { - "height": 71 - }, + "id": "TMlae0hem0u5", "colab_type": "code", + "outputId": "dd9c5079-10e7-4944-e09a-e9f65573a733", "executionInfo": { - "elapsed": 342, "status": "ok", "timestamp": 1590673618925, + "user_tz": -120, + "elapsed": 342, "user": { "displayName": "Jonathan Heek", "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhqRcoo1w0woYaM99jSyWQaD-qfmHmeDpXHzHZd=s64", "userId": "00491914421152177709" - }, - "user_tz": -120 + } }, - "id": "TMlae0hem0u5", - "outputId": "dd9c5079-10e7-4944-e09a-e9f65573a733" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "((((1, 3), (1, 3)), (1, 3)),\n", - " FrozenDict({'params': {'hf': {'bias': (3,), 'kernel': (3, 3)}, 'hg': {'bias': (3,), 'kernel': (3, 3)}, 'hi': {'bias': (3,), 'kernel': (3, 3)}, 'ho': {'bias': (3,), 'kernel': (3, 3)}, 'if': {'kernel': (2, 3)}, 'ig': {'kernel': (2, 3)}, 'ii': {'kernel': (2, 3)}, 'io': {'kernel': (2, 3)}}}))" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" + "colab": { + "height": 71 } - ], + }, "source": [ "def lstm(scope, carry, inputs,\n", " gate_fn=nn.activation.sigmoid, activation_fn=nn.activation.tanh,\n", @@ -268,6 +240,17 @@ "carry = lstm_init_carry((1,), 3)\n", "y, variables = init(lstm)(random.PRNGKey(0), carry, x)\n", "jax.tree_map(np.shape, (y, variables))" + ], + "execution_count": 6, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": "((((1, 3), (1, 3)), (1, 3)),\n FrozenDict({'params': {'hf': {'bias': (3,), 'kernel': (3, 3)}, 'hg': {'bias': (3,), 'kernel': (3, 3)}, 'hi': {'bias': (3,), 'kernel': (3, 3)}, 'ho': {'bias': (3,), 'kernel': (3, 3)}, 'if': {'kernel': (2, 3)}, 'ig': {'kernel': (2, 3)}, 'ii': {'kernel': (2, 3)}, 'io': {'kernel': (2, 3)}}}))" + }, + "metadata": {}, + "execution_count": 6 + } ] }, { @@ -278,12 +261,9 @@ }, "outputs": [ { - "name": "stdout", "output_type": "stream", - "text": [ - "initialized parameter shapes:\n", - " {'params': {'hf': {'bias': (2,), 'kernel': (2, 2)}, 'hg': {'bias': (2,), 'kernel': (2, 2)}, 'hi': {'bias': (2,), 'kernel': (2, 2)}, 'ho': {'bias': (2,), 'kernel': (2, 2)}, 'if': {'kernel': (2, 2)}, 'ig': {'kernel': (2, 2)}, 'ii': {'kernel': (2, 2)}, 'io': {'kernel': (2, 2)}}}\n" - ] + "name": "stdout", + "text": "initialized parameter shapes:\n {'params': {'hf': {'bias': (2,), 'kernel': (2, 2)}, 'hg': {'bias': (2,), 'kernel': (2, 2)}, 'hi': {'bias': (2,), 'kernel': (2, 2)}, 'ho': {'bias': (2,), 'kernel': (2, 2)}, 'if': {'kernel': (2, 2)}, 'ig': {'kernel': (2, 2)}, 'ii': {'kernel': (2, 2)}, 'io': {'kernel': (2, 2)}}}\n" } ], "source": [ @@ -316,12 +296,9 @@ }, "outputs": [ { - "name": "stdout", "output_type": "stream", - "text": [ - "output:\n", - " (DeviceArray([[-0.35626447, 0.25178757]], dtype=float32), DeviceArray([[-0.17885922, 0.13063088]], dtype=float32))\n" - ] + "name": "stdout", + "text": "output:\n (DeviceArray([[-0.35626447, 0.25178757]], dtype=float32), DeviceArray([[-0.17885922, 0.13063088]], dtype=float32))\n" } ], "source": [ @@ -339,17 +316,17 @@ ], "metadata": { "colab": { + "name": "flax functional engine.ipynb", + "provenance": [], "collapsed_sections": [], "last_runtime": { "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook", "kind": "private" - }, - "name": "flax functional engine.ipynb", - "provenance": [] + } }, "kernelspec": { - "display_name": "Python 3", - "name": "python3" + "name": "python3", + "display_name": "Python 3" } }, "nbformat": 4, From a7cf071de4783c923fea629ba99fd50d360898de Mon Sep 17 00:00:00 2001 From: Marc van Zee Date: Thu, 27 Jan 2022 11:48:12 +0100 Subject: [PATCH 3/3] Updates Linen README --- flax/linen/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flax/linen/README.md b/flax/linen/README.md index 54c402711e..eacf3d4dbe 100644 --- a/flax/linen/README.md +++ b/flax/linen/README.md @@ -1,6 +1,6 @@ # Linen: A comfortable evolution of Flax -Linen is an neural network API based on learning from our users and the broader JAX community. Linen improves on much of the former APIs, such as submodule sharing and better support for non-trainable variables. +Linen is a neural network API developed based on learning from our users and the broader JAX community. Linen improves on much of the former `flax.nn` API (removed since v0.4.0), such as submodule sharing and better support for non-trainable variables. Moreover, Linen builds on a "functional core", enabling direct usage of JAX transformations such as `vmap`, `remat` or `scan` inside your modules. In Linen, Modules behave much closer to vanilla Python objects, while still letting you opt-in to the concise single-method pattern many of our users love.