Commit

improve demo p3
cgarciae committed Nov 3, 2023
1 parent 1fbd63e commit c535237
Showing 1 changed file with 76 additions and 71 deletions.
147 changes: 76 additions & 71 deletions flax/experimental/nnx/docs/demo.ipynb
@@ -42,9 +42,16 @@
},
{
"cell_type": "code",
"execution_count": 57,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
@@ -57,9 +64,9 @@
" dtype=None,\n",
" param_dtype=<class 'jax.numpy.float32'>,\n",
" precision=None,\n",
" kernel_init=<function variance_scaling.<locals>.init at 0x7efed4734ee0>,\n",
" bias_init=<function zeros at 0x7eff772a0310>,\n",
" dot_general=<function dot_general at 0x7eff779594c0>\n",
" kernel_init=<function variance_scaling.<locals>.init at 0x7facfdf27e50>,\n",
" bias_init=<function zeros at 0x7fada0a0b280>,\n",
" dot_general=<function dot_general at 0x7fada1141430>\n",
" )\n",
")\n"
]
@@ -97,7 +104,7 @@
},
{
"cell_type": "code",
"execution_count": 58,
"execution_count": 2,
"metadata": {},
"outputs": [
{
@@ -128,7 +135,7 @@
},
{
"cell_type": "code",
"execution_count": 59,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -149,7 +156,7 @@
},
{
"cell_type": "code",
"execution_count": 60,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -169,9 +176,9 @@
"\n",
"y = model(jnp.ones((2, 28, 28, 1)))\n",
"\n",
"for i, layer in enumerate(model):\n",
" if isinstance(layer, nnx.Conv):\n",
" model[i] = nnx.Linear(layer.in_features, layer.out_features, rngs=rngs)\n",
"for i, model in enumerate(model):\n",
" if isinstance(model, nnx.Conv):\n",
" model[i] = nnx.Linear(model.in_features, model.out_features, rngs=rngs)\n",
"\n",
"y = model(jnp.ones((2, 28, 28, 1)))"
]
@@ -187,12 +194,16 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Interacting with JAX is easy"
"### Interacting with JAX is easy\n",
"\n",
"While NNX Modules inherently follow reference semantics, they can be easily converted into a pure functional representation that can be used with JAX transformations. NNX has two very simple APIs to interact with JAX: `split` and `merge`.\n",
"\n",
"The `Module.split` method allows you to convert into a `State` dict-like object that contains the dynamic state of the Module, and a `ModuleDef` object that contains the static structure of the Module."
]
},
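A minimal sketch of the round trip described above (assuming the `CounterLinear` Module and imports from the earlier cells):

```python
# Split a Module into its dynamic State and static ModuleDef,
# then merge them back into an equivalent Module.
model = CounterLinear(4, 4, rngs=nnx.Rngs(0))

state, static = model.split()  # State holds the arrays, ModuleDef the structure
model = static.merge(state)    # reconstructs the Module from the two parts
```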
{
"cell_type": "code",
"execution_count": 61,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -206,28 +217,7 @@
" [ 0.38802508, 0.5655534 , 0.4870657 , 0.2267774 ],\n",
" [-0.9015767 , 0.24465278, -0.5844707 , 0.18421966],\n",
" [-0.06992685, -0.64693886, 0.20232596, 1.1200062 ]], dtype=float32)\n",
"})\n",
"static = ModuleDef(\n",
" type=CounterLinear,\n",
" index=0,\n",
" static_fields=(),\n",
" variables=(('count', Count(\n",
" value=Empty\n",
" )),),\n",
" submodules=(\n",
" ('linear', ModuleDef(\n",
" type=Linear,\n",
" index=1,\n",
" static_fields=(('bias_init', <function zeros at 0x7eff772a0310>), ('dot_general', <function dot_general at 0x7eff779594c0>), ('dtype', None), ('in_features', 4), ('kernel_init', <function variance_scaling.<locals>.init at 0x7efed4734ee0>), ('out_features', 4), ('param_dtype', <class 'jax.numpy.float32'>), ('precision', None), ('use_bias', True)),\n",
" variables=(('bias', Param(\n",
" value=Empty\n",
" )), ('kernel', Param(\n",
" value=Empty\n",
" ))),\n",
" submodules=()\n",
" ))\n",
" )\n",
")\n"
"})\n"
]
}
],
@@ -236,108 +226,123 @@
"\n",
"state, static = model.split()\n",
"\n",
"print(f'{state = }')\n",
"print(f'{static = }')"
"print(f'{state = }')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `ModuleDef.merge` method allows you to take a `ModuleDef` and one or more `State` objects and merge them back into a `Module` object. \n",
"\n",
"Using `split` and `merge` in conjunction allows you to carry your Module in and out of any JAX transformation. Here is a simple jitted `forward` function as an example:"
]
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"y.shape = (2, 4)\n",
"state[\"counter\"] = Array(1, dtype=int32)\n"
"state[\"count\"] = Array(1, dtype=int32)\n"
]
}
],
"source": [
"@jax.jit\n",
"def forward(\n",
" state: nnx.State, static: nnx.ModuleDef[CounterLinear], x: jax.Array\n",
") -> tuple[jax.Array, nnx.State]:\n",
"def forward(state: nnx.State, x: jax.Array):\n",
" model = static.merge(state)\n",
" y = model(x)\n",
" state, _ = model.split()\n",
" return y, state\n",
"\n",
"\n",
"x = jnp.ones((2, 4))\n",
"y, state = forward(state, static, x)\n",
"y, state = forward(state, x)\n",
"\n",
"print(f'{y.shape = }')\n",
"print(f'{state[\"count\"] = }')"
]
},
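Since `split` returns the updated `State`, the output of one call can be threaded into the next. A small sketch continuing from the cell above (note that `forward` closes over `static` from the enclosing scope):

```python
# Threading the State through repeated jitted calls keeps the counter advancing.
y, state = forward(state, x)
y, state = forward(state, x)

print(state["count"])  # expected: Array(3, dtype=int32), one increment per call
```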
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Custom lifted Modules\n",
"\n",
"By using the same mechanism inside Module methods you can implement lifted Modules, that is, Modules that use a JAX transformation to have a distinct behavior. One of Linen's current pain points is that it is not easy to interact with JAX transformations that are not currently supported by the framework. NNX makes this so easy that its realistic to implement custom lifted Modules for specific use cases.\n",
"\n",
"As an example here we will create a `LinearEnsemble` Module that uses `jax.vmap` both during `__init__` and `__call__` to vectorize the computation over multiple `CounterLinear` models (defined above). The example is a little bit longer, but notice how each method conceptually very simple:"
]
},
{
"cell_type": "code",
"execution_count": 47,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"y.shape = (8, 4)\n",
"vmodel.layers.counter = Array(2, dtype=int32)\n",
"ensemble.models.count = Array(1, dtype=int32)\n",
"state = State({\n",
" 'layers/counter': (),\n",
" 'layers/linear/bias': (8, 4),\n",
" 'layers/linear/kernel': (8, 4, 4)\n",
" 'models/count': (),\n",
" 'models/linear/bias': (8, 4),\n",
" 'models/linear/kernel': (8, 4, 4)\n",
"})\n"
]
}
],
"source": [
"class MLP(nnx.Module):\n",
" def __init__(self, din, dout, *, nbatch, rngs: nnx.Rngs):\n",
"class LinearEnsemble(nnx.Module):\n",
" def __init__(self, din, dout, *, num_models, rngs: nnx.Rngs):\n",
" # get raw rng seeds\n",
" rng_keys, rngs_def = rngs.split()\n",
" vmapped_keys = jax.tree_map(lambda k: random.split(k, nbatch), rng_keys)\n",
" vmapped_keys = jax.tree_map(lambda k: random.split(k, num_models), rng_keys)\n",
"\n",
" # define pure init fn and vmap\n",
" def vmap_init(key):\n",
" return CounterLinear(din, dout, rngs=rngs_def.merge(key)).split(\n",
" nnx.Param, Count\n",
" )\n",
"\n",
" params, counters, static = jax.vmap(\n",
" params, counts, static = jax.vmap(\n",
" vmap_init, in_axes=(0,), out_axes=(0, None, None)\n",
" )(vmapped_keys)\n",
" # update wrapped submodule reference\n",
" self.layers = static.merge(params, counters)\n",
" self.models = static.merge(params, counts)\n",
"\n",
" def __call__(self, x):\n",
" # get module values, define pure fn\n",
" params, counters, static = self.layers.split(nnx.Param, Count)\n",
" params, counts, static = self.models.split(nnx.Param, Count)\n",
"\n",
" def vmap_apply(x, params, counters, static):\n",
" layer = static.merge(params, counters)\n",
" y = layer(x)\n",
" params, counters, static = layer.split(nnx.Param, Count)\n",
" return y, params, counters, static\n",
" def vmap_apply(x, params, counts, static):\n",
" model = static.merge(params, counts)\n",
" y = model(x)\n",
" params, counts, static = model.split(nnx.Param, Count)\n",
" return y, params, counts, static\n",
"\n",
" # vmap and call\n",
" y, params, counters, static = jax.vmap(\n",
" vmap_apply, in_axes=(0, 0, None, None), out_axes=(0, 0, None, None)\n",
" )(x, params, counters, static)\n",
" y, params, counts, static = jax.vmap(\n",
" vmap_apply, in_axes=(None, 0, None, None), out_axes=(0, 0, None, None)\n",
" )(x, params, counts, static)\n",
" # update wrapped module\n",
" self.layers.update(params, counters, static)\n",
" self.models.update(params, counts, static) # use `update` to integrate the new state\n",
" return y\n",
"\n",
"x = jnp.ones((4,))\n",
"ensemble = LinearEnsemble(4, 4, num_models=8, rngs=nnx.Rngs(0))\n",
"\n",
"x = jnp.ones((8, 4))\n",
"vmodel = MLP(4, 4, nbatch=8, rngs=nnx.Rngs(0))\n",
"y = vmodel(x)\n",
"y = vmodel(x) # call twice to increment count\n",
"\n",
"# forward pass\n",
"y = ensemble(x)\n",
"\n",
"print(f'{y.shape = }')\n",
"print(f'{vmodel.layers.count = }')\n",
"print(f'state = {jax.tree_map(jnp.shape, vmodel.get_state())}')"
"print(f'{ensemble.models.count = }')\n",
"print(f'state = {jax.tree_map(jnp.shape, ensemble.get_state())}')"
]
},
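The same recipe works outside the class: `split` the Module into pytrees, run a pure function over them under the transform, and `update` the Module with the results. A condensed sketch, assuming the `ensemble` and the `Count` Variable type defined above:

```python
# Lifted-Module recipe applied externally to the wrapped submodule.
params, counts, static = ensemble.models.split(nnx.Param, Count)

def pure_apply(x, params, counts):
    model = static.merge(params, counts)  # rebuild one CounterLinear
    y = model(x)
    params, counts, _ = model.split(nnx.Param, Count)
    return y, params, counts

y, params, counts = jax.vmap(
    pure_apply, in_axes=(None, 0, None), out_axes=(0, 0, None)
)(x, params, counts)

ensemble.models.update(params, counts, static)  # write the new state back
```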
{
Expand All @@ -350,7 +355,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [