Skip to content

Commit

Permalink
switch to nested State representation
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Nov 23, 2023
1 parent 70214f4 commit a8340bc
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 190 deletions.
120 changes: 78 additions & 42 deletions flax/experimental/nnx/examples/00_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,16 @@
"cells": [
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
Expand Down Expand Up @@ -48,32 +55,29 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"State({\n",
" 'b': Param(\n",
" value=Array([0., 0.], dtype=float32)\n",
" ),\n",
" 'w': Param(\n",
" value=Array([[0.31696808, 0.55285215],\n",
" [0.31418085, 0.7399571 ]], dtype=float32)\n",
" )\n",
" 'w': Array([[0.31696808, 0.55285215],\n",
" [0.31418085, 0.7399571 ]], dtype=float32),\n",
" 'b': Array([0., 0.], dtype=float32)\n",
"})\n",
"GraphDef(\n",
" type=Linear,\n",
" index=0,\n",
" subgraphs=(),\n",
" static_fields=(('din', 2), ('dout', 2)),\n",
" variables=(('b', Param(\n",
" variables=(('w', Param(\n",
" value=Empty\n",
" )), ('w', Param(\n",
" )), ('b', Param(\n",
" value=Empty\n",
" ))),\n",
" submodules=()\n",
" metadata=<class '__main__.Linear'>\n",
")\n"
]
}
Expand All @@ -87,7 +91,40 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"State({\n",
" 'linear': {\n",
" 'w': Array([[0.31696808, 0.55285215],\n",
" [0.31418085, 0.7399571 ]], dtype=float32),\n",
" 'b': Array([0., 0.], dtype=float32)\n",
" }\n",
"})"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class Nested(nnx.Module):\n",
" def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n",
" self.linear = Linear(din, dout, rngs=rngs)\n",
" \n",
"module = Nested(2, 2, rngs=nnx.Rngs(0))\n",
"\n",
"state, static = module.split()\n",
"state"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -129,34 +166,31 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"State({\n",
" 'b': Param(\n",
" value=Array([0., 0.], dtype=float32)\n",
" ),\n",
" 'w': Param(\n",
" value=Array([[0.31696808, 0.55285215],\n",
" [0.31418085, 0.7399571 ]], dtype=float32)\n",
" )\n",
" 'w': Array([[0.31696808, 0.55285215],\n",
" [0.31418085, 0.7399571 ]], dtype=float32),\n",
" 'b': Array([0., 0.], dtype=float32)\n",
"})\n",
"GraphDef(\n",
" type=Linear,\n",
" index=0,\n",
" subgraphs=(\n",
" ('submodule', 0)\n",
" ),\n",
" static_fields=(('din', 2), ('dout', 2)),\n",
" variables=(('b', Param(\n",
" variables=(('w', Param(\n",
" value=Empty\n",
" )), ('w', Param(\n",
" )), ('b', Param(\n",
" value=Empty\n",
" ))),\n",
" submodules=(\n",
" ('submodule', 0)\n",
" )\n",
" metadata=<class '__main__.Linear'>\n",
")\n"
]
}
Expand All @@ -170,7 +204,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand All @@ -179,7 +213,7 @@
"True"
]
},
"execution_count": 12,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -192,7 +226,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -235,27 +269,22 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"State({\n",
" 'y': Intermediate(\n",
" value=Array([[0.63114893, 1.2928092 ],\n",
" [0.63114893, 1.2928092 ]], dtype=float32)\n",
" )\n",
" 'y': Array([[0.63114893, 1.2928092 ],\n",
" [0.63114893, 1.2928092 ]], dtype=float32)\n",
"})\n",
"State({\n",
" 'b': Param(\n",
" value=Array([0., 0.], dtype=float32)\n",
" ),\n",
" 'w': Param(\n",
" value=Array([[0.31696808, 0.55285215],\n",
" [0.31418085, 0.7399571 ]], dtype=float32)\n",
" )\n",
" 'w': Array([[0.31696808, 0.55285215],\n",
" [0.31418085, 0.7399571 ]], dtype=float32),\n",
" 'b': Array([0., 0.], dtype=float32),\n",
" 'y': Empty\n",
"})\n"
]
}
Expand All @@ -267,6 +296,13 @@
"print(intermediates)\n",
"print(state)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -280,7 +316,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.9.18"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit a8340bc

Please sign in to comment.