
Add NNX #3218

Merged (1 commit) on Nov 9, 2023

5 changes: 4 additions & 1 deletion .gitignore
@@ -15,4 +15,7 @@ build/
 docs/**/tmp
 
 # used by direnv
-.envrc
+.envrc
+
+# custom
+/tmp-files
3 changes: 2 additions & 1 deletion docs/conf.py
@@ -124,7 +124,7 @@

 # -- Options for myst ----------------------------------------------
 # uncomment line below to avoid running notebooks during development
-# nb_execution_mode = 'off'
+nb_execution_mode = 'off'
 # Notebook cell execution timeout; defaults to 30.
 nb_execution_timeout = 100
 # List of patterns, relative to source directory, that match notebook
@@ -133,6 +133,7 @@
 nb_execution_excludepatterns = [
     'quick_start.ipynb',  # <-- times out
     'transfer_learning.ipynb',  # <-- transformers requires flax<=0.7.0
+    'flax/experimental/nnx',  # exclude nnx
 ]
 # raise exceptions on execution so CI can catch errors
 nb_execution_allow_errors = False
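For context, the options touched in this hunk are myst-nb's notebook-execution settings. Below is a minimal sketch of how they fit together in a Sphinx `conf.py` — the values mirror the hunk above but the snippet is illustrative, not the repo's full configuration:

```python
# conf.py -- myst-nb execution options (illustrative sketch, not the full Flax config)
extensions = ['myst_nb']

# 'off' skips notebook execution during the docs build entirely; note the PR
# uncomments this line, so notebooks no longer execute at build time.
nb_execution_mode = 'off'

# Per-cell execution timeout in seconds (myst-nb's default is 30).
nb_execution_timeout = 100

# Notebooks matching these patterns (relative to the source dir) are skipped.
nb_execution_excludepatterns = [
    'quick_start.ipynb',      # times out
    'flax/experimental/nnx',  # exclude nnx, added by this PR
]

# Fail the build on notebook errors so CI can catch them.
nb_execution_allow_errors = False
```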
6 changes: 6 additions & 0 deletions docs/experimental/index.rst
@@ -0,0 +1,6 @@


.. toctree::
   :maxdepth: 2

   nnx
1 change: 1 addition & 0 deletions docs/index.rst
@@ -323,4 +323,5 @@ Notable examples in Flax include:
    developer_notes/index
    philosophy
    contributing
+   experimental
    api_reference/index
92 changes: 68 additions & 24 deletions flax/core/flax_functional_engine.ipynb
@@ -12,7 +12,7 @@
"import functools\n",
"import jax\n",
"from jax import numpy as jnp, random, lax\n",
"import numpy as np\n"
"import numpy as np"
]
},
{
@@ -78,26 +78,34 @@
 }
 ],
 "source": [
-"def dense(scope: Scope, inputs: Array, features: int, bias: bool = True,\n",
-"          kernel_init=nn.linear.default_kernel_init,\n",
-"          bias_init=nn.initializers.zeros_init()):\n",
+"def dense(\n",
+"    scope: Scope,\n",
+"    inputs: Array,\n",
+"    features: int,\n",
+"    bias: bool = True,\n",
+"    kernel_init=nn.linear.default_kernel_init,\n",
+"    bias_init=nn.initializers.zeros_init(),\n",
+"):\n",
 "  kernel = scope.param('kernel', kernel_init, (inputs.shape[-1], features))\n",
 "  y = jnp.dot(inputs, kernel)\n",
 "  if bias:\n",
 "    y += scope.param('bias', bias_init, (features,))\n",
 "  return y\n",
 "\n",
+"\n",
 "model_fn = functools.partial(dense, features=3)\n",
 "\n",
 "x = jnp.ones((1, 2))\n",
 "y, params = init(model_fn)(random.key(0), x)\n",
 "print(params)\n",
 "\n",
+"\n",
 "def mlp(scope: Scope, inputs: Array, features: int):\n",
 "  hidden = scope.child(dense, 'hidden')(inputs, features)\n",
 "  hidden = nn.relu(hidden)\n",
 "  return dense(scope.push('out'), hidden, 1)\n",
 "\n",
+"\n",
 "init(mlp)(random.key(0), x, features=3)"
 ]
 },
@@ -138,16 +146,31 @@
" def attend(self, query):\n",
" return jnp.dot(query, self.table.T)\n",
"\n",
"\n",
"# all the embedding module does is provide a convenient initializers\n",
"\n",
"def embedding(scope: Scope, num_embeddings: int, features: int, init_fn=nn.linear.default_embed_init) -> Embedding:\n",
"\n",
"def embedding(\n",
" scope: Scope,\n",
" num_embeddings: int,\n",
" features: int,\n",
" init_fn=nn.linear.default_embed_init,\n",
") -> Embedding:\n",
" table = scope.param('table', init_fn, (num_embeddings, features))\n",
" return Embedding(table)\n",
"\n",
"\n",
"embedding, _ = init(embedding)(random.key(0), num_embeddings=2, features=3)\n",
"print(embedding.table)\n",
"print(embedding.lookup(1))\n",
"print(embedding.attend(jnp.ones((1, 3,))))"
"print(\n",
" embedding.attend(\n",
" jnp.ones((\n",
" 1,\n",
" 3,\n",
" ))\n",
" )\n",
")"
]
},
{
@@ -177,11 +200,16 @@
 }
 ],
 "source": [
-"def lstm(scope, carry, inputs,\n",
-"         gate_fn=nn.activation.sigmoid, activation_fn=nn.activation.tanh,\n",
-"         kernel_init=nn.linear.default_kernel_init,\n",
-"         recurrent_kernel_init=nn.initializers.orthogonal(),\n",
-"         bias_init=nn.initializers.zeros_init()):\n",
+"def lstm(\n",
+"    scope,\n",
+"    carry,\n",
+"    inputs,\n",
+"    gate_fn=nn.activation.sigmoid,\n",
+"    activation_fn=nn.activation.tanh,\n",
+"    kernel_init=nn.linear.default_kernel_init,\n",
+"    recurrent_kernel_init=nn.initializers.orthogonal(),\n",
+"    bias_init=nn.initializers.zeros_init(),\n",
+"):\n",
 "  r\"\"\"A long short-term memory (LSTM) cell.\n",
 "\n",
 "  the mathematical definition of the cell is as follows\n",
@@ -217,11 +245,15 @@
" hidden_features = h.shape[-1]\n",
" # input and recurrent layers are summed so only one needs a bias.\n",
" dense_h = lambda name: scope.child(dense, name)(\n",
" h, features=hidden_features, bias=True,\n",
" kernel_init=recurrent_kernel_init, bias_init=bias_init)\n",
" h,\n",
" features=hidden_features,\n",
" bias=True,\n",
" kernel_init=recurrent_kernel_init,\n",
" bias_init=bias_init,\n",
" )\n",
" dense_i = lambda name: scope.child(dense, name)(\n",
" inputs, features=hidden_features, bias=False,\n",
" kernel_init=kernel_init)\n",
" inputs, features=hidden_features, bias=False, kernel_init=kernel_init\n",
" )\n",
" i = gate_fn(dense_i(name='ii') + dense_h(name='hi'))\n",
" f = gate_fn(dense_i(name='if') + dense_h(name='hf'))\n",
" g = activation_fn(dense_i(name='ig') + dense_h(name='hg'))\n",
@@ -230,10 +262,12 @@
" new_h = o * activation_fn(new_c)\n",
" return (new_c, new_h), new_h\n",
"\n",
"\n",
"def lstm_init_carry(batch_dims, size, init_fn=jnp.zeros):\n",
" shape = batch_dims + (size,)\n",
" return init_fn(shape), init_fn(shape)\n",
"\n",
"\n",
"x = jnp.ones((1, 2))\n",
"carry = lstm_init_carry((1,), 3)\n",
"y, variables = init(lstm)(random.key(0), carry, x)\n",
@@ -259,23 +293,33 @@
"source": [
"def simple_scan(scope: Scope, xs):\n",
" init_carry = lstm_init_carry(xs.shape[:1], xs.shape[-1])\n",
"# cell = scope.child(lstm, 'cell')\n",
"# ys = []\n",
"# for i in range(xs.shape[1]):\n",
"# x = xs[:, i]\n",
"# init_carry, y = cell(init_carry, x)\n",
"# ys.append(y)\n",
"# return init_carry, ys\n",
" lstm_scan = lift.scan(lstm, in_axes=1, out_axes=1, variable_broadcast='params', split_rngs={'params': False})\n",
" # cell = scope.child(lstm, 'cell')\n",
" # ys = []\n",
" # for i in range(xs.shape[1]):\n",
" # x = xs[:, i]\n",
" # init_carry, y = cell(init_carry, x)\n",
" # ys.append(y)\n",
" # return init_carry, ys\n",
" lstm_scan = lift.scan(\n",
" lstm,\n",
" in_axes=1,\n",
" out_axes=1,\n",
" variable_broadcast='params',\n",
" split_rngs={'params': False},\n",
" )\n",
" return lstm_scan(scope, init_carry, xs)\n",
"\n",
"\n",
"key1, key2 = random.split(random.key(0), 2)\n",
"xs = random.uniform(key1, (1, 5, 2))\n",
"\n",
"\n",
"y, init_variables = init(simple_scan)(key2, xs)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n"
"print(\n",
" 'initialized parameter shapes:\\n',\n",
" jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)),\n",
")"
]
},
{
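For readers skimming the notebook diff: the cells above exercise flax.core's functional engine, where plain functions take a `Scope` and request parameters by name, and `lift.scan` loops a cell over a sequence axis while broadcasting the `'params'` collection across steps. A condensed sketch of the core pattern, assuming `Scope` and `init` are imported from `flax.core` as the notebook's earlier (unshown) cells do:

```python
import jax
from jax import numpy as jnp, random
import flax.linen as nn
from flax.core import Scope, init  # functional engine demonstrated in the notebook


def dense(scope: Scope, inputs, features: int):
  # Parameters live in the scope; scope.param creates them during init.
  kernel = scope.param(
      'kernel', nn.linear.default_kernel_init, (inputs.shape[-1], features)
  )
  return jnp.dot(inputs, kernel)


x = jnp.ones((1, 2))
# init(fn) wraps fn so that, given an RNG key and inputs, it returns
# (outputs, variables) -- the same pattern as init(mlp)(...) in the diff.
y, variables = init(dense)(random.key(0), x, features=3)
print(jax.tree_util.tree_map(jnp.shape, variables))
```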
Empty file added flax/experimental/__init__.py
133 changes: 133 additions & 0 deletions flax/experimental/nnx/.gitignore
@@ -0,0 +1,133 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# project specific
.vscode
/tmp