
Commit

update tiny nnx
cgarciae committed Nov 15, 2023
1 parent 2f60e22 commit f30ab47
Showing 1 changed file with 48 additions and 47 deletions.
95 changes: 48 additions & 47 deletions flax/experimental/nnx/docs/tiny_nnx.ipynb
@@ -15,21 +15,22 @@
},
{
"cell_type": "code",
"execution_count": 67,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import dataclasses\n",
"import hashlib\n",
"import typing as tp\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jax import random\n",
"import dataclasses\n",
"\n",
"A = tp.TypeVar(\"A\")\n",
"M = tp.TypeVar(\"M\", bound=\"Module\")\n",
"Sharding = tp.Tuple[tp.Optional[str], ...]\n",
"Array = random.Array\n",
"Array = jax.Array\n",
"\n",
"\n",
"class Variable(tp.Generic[A]):\n",
@@ -53,13 +54,13 @@
" jax.tree_util.register_pytree_node(\n",
" cls,\n",
" lambda x: ((x.value,), (x.sharding,)),\n",
" lambda metadata, value: Variable(value[0], sharding=metadata[0]),\n",
" lambda metadata, value: cls(value[0], sharding=metadata[0]),\n",
" )\n",
"\n",
"\n",
"class State(dict[str, Variable[tp.Any]]):\n",
"\n",
" def filter(self, variable_type: tp.Type[Variable]) -> \"State\":\n",
" def extract(self, variable_type: tp.Type[Variable]) -> \"State\":\n",
" return State(\n",
" {\n",
" path: variable\n",
@@ -98,23 +99,23 @@
"\n",
" @staticmethod\n",
" def _build_module_recursive(\n",
" static.Union[\"GraphDef[M]\", int],\n",
" graphdef: tp.Union[\"GraphDef[M]\", int],\n",
" index_to_module: dict[int, \"Module\"],\n",
" ) -> M:\n",
" if isinstance(statict):\n",
" return index_to_module[static type: ignore\n",
" if isinstance(graphdef, int):\n",
" return index_to_module[graphdef] # type: ignore\n",
"\n",
" assert staticex not in index_to_module\n",
" assert graphdef.index not in index_to_module\n",
"\n",
" # add a dummy module to the index to avoid infinite recursion\n",
" module = object.__new__(statice)\n",
" index_to_module[staticex] = module\n",
" module = object.__new__(graphdef.type)\n",
" index_to_module[graphdef.index] = module\n",
"\n",
" submodules = {\n",
" name: GraphDef._build_module_recursive(submodule, index_to_module)\n",
" for name, submodule in staticmodules.items()\n",
" for name, submodule in graphdef.submodules.items()\n",
" }\n",
" vars(module).update(statictic_fields)\n",
" vars(module).update(graphdef.static_fields)\n",
" vars(module).update(submodules)\n",
" return module\n",
"\n",
@@ -133,11 +134,11 @@
"\n",
" def split(self: M) -> tp.Tuple[State, GraphDef[M]]:\n",
" state = State()\n",
" staticodule._partition_recursive(\n",
" graphdef = Module._partition_recursive(\n",
" module=self, module_id_to_index={}, path_parts=(), state=state\n",
" )\n",
" assert isinstance(staticaphDef)\n",
" return state, static\n",
" assert isinstance(graphdef, GraphDef)\n",
" return state, graphdef\n",
"\n",
" @staticmethod\n",
" def _partition_recursive(\n",
@@ -167,7 +168,7 @@
" # if value is a Variable, add to state\n",
" elif isinstance(value, Variable):\n",
" state[\"/\".join(value_path)] = value\n",
" else: # otherwise, add to static fields\n",
" else: # otherwise, add to graphdef fields\n",
" static_fields[name] = value\n",
"\n",
" return GraphDef(\n",
@@ -177,7 +178,7 @@
" static_fields=static_fields,\n",
" )\n",
"\n",
" def update_state(self, state: State) -> None:\n",
" def update(self, state: State) -> None:\n",
" for path, value in state.items():\n",
" path_parts = path.split(\"/\")\n",
" Module._set_value_at_path(self, path_parts, value)\n",
@@ -242,7 +243,7 @@
},
{
"cell_type": "code",
"execution_count": 84,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -314,7 +315,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -340,27 +341,27 @@
"\n",
" # lift init\n",
" key = random.split(rngs.make_rng(), n_layers - 1)\n",
" staticaphDef[Block] = None # type: ignore\n",
" graphdef: GraphDef[Block] = None # type: ignore\n",
"\n",
" def init_fn(key):\n",
" nonlocal static\n",
" state, staticlock(\n",
" nonlocal graphdef\n",
" state, graphdef = Block(\n",
" hidden_size, hidden_size, rngs=Rngs(key)\n",
" ).split()\n",
" return state\n",
"\n",
" state = jax.vmap(init_fn)(key)\n",
" self.layers = staticge(state)\n",
" self.layers = graphdef.merge(state)\n",
" self.linear = Linear(hidden_size, hidden_size, rngs=rngs)\n",
"\n",
" def __call__(self, x: jax.Array, *, train: bool, rngs: Rngs) -> jax.Array:\n",
" # lift call\n",
" key: jax.Array = random.split(rngs.make_rng(), self.n_layers - 1) # type: ignore\n",
" state, staticelf.layers.split()\n",
" state, graphdef = self.layers.split()\n",
"\n",
" def scan_fn(x, inputs: tuple[jax.Array, State]):\n",
" key, state = inputs\n",
" x, (state, _) = staticly(state)(x, train=train, rngs=Rngs(key))\n",
" x, (state, _) = graphdef.apply(state)(x, train=train, rngs=Rngs(key))\n",
" return x, state\n",
"\n",
" x, state = jax.lax.scan(scan_fn, x, (key, state))\n",
@@ -371,22 +372,22 @@
},
{
"cell_type": "code",
"execution_count": 87,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"state = State({\n",
" 'layers/bn/bias': Variable(value=(4, 10), collection=params, sharding=None),\n",
" 'layers/bn/mean': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None),\n",
" 'layers/bn/scale': Variable(value=(4, 10), collection=params, sharding=None),\n",
" 'layers/bn/var': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None),\n",
" 'layers/linear/b': Variable(value=(4, 10), collection=params, sharding=None),\n",
" 'layers/linear/w': Variable(value=(4, 10, 10), collection=params, sharding=None),\n",
" 'linear/b': Variable(value=(10,), collection=params, sharding=None),\n",
" 'linear/w': Variable(value=(10, 10), collection=params, sharding=None)\n",
" 'layers/bn/bias': Param(value=(4, 10), sharding=None),\n",
" 'layers/bn/mean': BatchStat(value=(4, 10), sharding=None),\n",
" 'layers/bn/scale': Param(value=(4, 10), sharding=None),\n",
" 'layers/bn/var': BatchStat(value=(4, 10), sharding=None),\n",
" 'layers/linear/b': Param(value=(4, 10), sharding=None),\n",
" 'layers/linear/w': Param(value=(4, 10, 10), sharding=None),\n",
" 'linear/b': Param(value=(10,), sharding=None),\n",
" 'linear/w': Param(value=(10, 10), sharding=None)\n",
"})\n",
"graphdef = GraphDef(type=<class '__main__.ScanMLP'>, index=0, submodules={'layers': GraphDef(type=<class '__main__.Block'>, index=1, submodules={'bn': GraphDef(type=<class '__main__.BatchNorm'>, index=2, submodules={}, static_fields={'mu': 0.95}), 'dropout': GraphDef(type=<class '__main__.Dropout'>, index=3, submodules={}, static_fields={'rate': 0.1}), 'linear': GraphDef(type=<class '__main__.Linear'>, index=4, submodules={}, static_fields={'din': 10, 'dout': 10})}, static_fields={}), 'linear': GraphDef(type=<class '__main__.Linear'>, index=5, submodules={}, static_fields={'din': 10, 'dout': 10})}, static_fields={'n_layers': 5})\n"
]
@@ -397,9 +398,9 @@
"x = jax.random.normal(random.key(0), (2, 10))\n",
"y = module(x, train=True, rngs=Rngs(random.key(1)))\n",
"\n",
"state, staticodule.split()\n",
"state, graphdef = module.split()\n",
"print(\"state =\", jax.tree_map(jnp.shape, state))\n",
"print(\"static stststatic"
"print(\"graphdef =\", graphdef)"
]
},
{
@@ -412,24 +413,24 @@
},
{
"cell_type": "code",
"execution_count": 89,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"params = State({\n",
" 'layers/bn/bias': Variable(value=(4, 10), collection=params, sharding=None),\n",
" 'layers/bn/scale': Variable(value=(4, 10), collection=params, sharding=None),\n",
" 'layers/linear/b': Variable(value=(4, 10), collection=params, sharding=None),\n",
" 'layers/linear/w': Variable(value=(4, 10, 10), collection=params, sharding=None),\n",
" 'linear/b': Variable(value=(10,), collection=params, sharding=None),\n",
" 'linear/w': Variable(value=(10, 10), collection=params, sharding=None)\n",
" 'layers/bn/bias': Param(value=(4, 10), sharding=None),\n",
" 'layers/bn/scale': Param(value=(4, 10), sharding=None),\n",
" 'layers/linear/b': Param(value=(4, 10), sharding=None),\n",
" 'layers/linear/w': Param(value=(4, 10, 10), sharding=None),\n",
" 'linear/b': Param(value=(10,), sharding=None),\n",
" 'linear/w': Param(value=(10, 10), sharding=None)\n",
"})\n",
"batch_stats = State({\n",
" 'layers/bn/mean': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None),\n",
" 'layers/bn/var': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None)\n",
" 'layers/bn/mean': BatchStat(value=(4, 10), sharding=None),\n",
" 'layers/bn/var': BatchStat(value=(4, 10), sharding=None)\n",
"})\n"
]
}
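The params/batch_stats split printed above comes from filtering the state by Variable subclass. A sketch of what the collapsed cell presumably runs, using the extract method this commit renames filter to, and assuming Param and BatchStat are the subclasses named in the output:

    params = state.extract(Param)
    batch_stats = state.extract(BatchStat)
    print("params =", jax.tree_map(jnp.shape, params))
    print("batch_stats =", jax.tree_map(jnp.shape, batch_stats))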
@@ -457,7 +458,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.10.13"
}
},
"nbformat": 4,
