Skip to content

Commit

Permalink
Merge pull request #4249 from IvyZX:ckpt-guide
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 683162125
  • Loading branch information
Flax Authors committed Oct 7, 2024
2 parents 551942e + 742c926 commit 9a3a1fc
Show file tree
Hide file tree
Showing 6 changed files with 664 additions and 72 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ repos:
hooks:
- id: check-toml
- id: trailing-whitespace
exclude: ^docs*/.*\.md$
exclude: ^docs.*\.md$
- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
hooks:
Expand Down
382 changes: 382 additions & 0 deletions docs_nnx/guides/checkpointing.ipynb

Large diffs are not rendered by default.

201 changes: 201 additions & 0 deletions docs_nnx/guides/checkpointing.md

Large diffs are not rendered by default.

97 changes: 52 additions & 45 deletions docs_nnx/guides/surgery.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -221,15 +221,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"This will throw error: <class 'KeyError'>: 'layer1'\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/ivyzheng/envs/py310/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:1401: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n",
" warnings.warn(\n"
"This will throw error: <class 'ValueError'>: Dict key mismatch; expected keys: ['linear1', 'linear2']; dict: {'layer1': {'bias': {'value': RestoreArgs(restore_type=None, dtype=None)}, 'kernel': {'value': RestoreArgs(restore_type=None, dtype=None)}}, 'layer2': {'bias': {'value': RestoreArgs(restore_type=None, dtype=None)}, 'kernel': {'value': RestoreArgs(restore_type=None, dtype=None)}}}.\n"
]
}
],
Expand Down Expand Up @@ -267,45 +259,46 @@
"name": "stdout",
"output_type": "stream",
"text": [
"{'linear1': {'bias': {'raw_value': Array([0., 0., 0., 0.], dtype=float32)},\n",
" 'kernel': {'raw_value': Array([[-0.80345297, -0.34071913, -0.9408296 , 0.01005968],\n",
"{'linear1': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},\n",
" 'kernel': {'value': Array([[-0.80345297, -0.34071913, -0.9408296 , 0.01005968],\n",
" [ 0.26146442, 1.1247735 , 0.54563737, -0.374164 ],\n",
" [ 1.0281805 , -0.6798804 , -0.1488401 , 0.05694951],\n",
" [-0.44308168, -0.60587114, 0.434087 , -0.40541083]], dtype=float32)}},\n",
" 'linear2': {'bias': {'raw_value': Array([0., 0., 0., 0.], dtype=float32)},\n",
" 'kernel': {'raw_value': Array([[ 0.21010089, 0.8289361 , 0.04589564, 0.5422644 ],\n",
" 'linear2': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},\n",
" 'kernel': {'value': Array([[ 0.21010089, 0.8289361 , 0.04589564, 0.5422644 ],\n",
" [ 0.41914317, 0.84359694, -0.47937787, -0.49135214],\n",
" [-0.46072108, 0.4630125 , 0.39276958, -0.9441406 ],\n",
" [-0.6690758 , -0.18474789, -0.57622856, 0.4821079 ]], dtype=float32)}}}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n",
" warnings.warn(\n"
]
}
],
"source": [
"def module_from_variables_dict(module_factory, variables, map_key_fn):\n",
" if map_key_fn is None:\n",
" map_key_fn = lambda path: path\n",
" mdl = nnx.eval_shape(module_factory)\n",
" graph_def, state = nnx.split(mdl)\n",
" state = state.flat_state()\n",
" for path, val in flax.traverse_util.flatten_dict(variables).items():\n",
" mapped_path = map_key_fn(path)\n",
" if mapped_path not in state:\n",
" raise ValueError(f\"{mapped_path} doesn't exist in {state.keys()}\")\n",
" state[mapped_path].value = val\n",
" state = nnx.State.from_flat_path(state)\n",
" return nnx.merge(graph_def, state)\n",
"\n",
"# Make your local change on the checkpoint.\n",
"raw = checkpointer.restore('/tmp/nnx-surgery-state')\n",
"pprint(raw)\n",
"raw['layer1'], raw['layer2'] = raw['linear1'], raw['linear2']\n",
"del raw['linear1'], raw['linear2']\n",
"\n",
"restored_model = module_from_variables_dict(\n",
" lambda: nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0))),\n",
" raw,\n",
" lambda path: path[:-1] if path[-1] == 'raw_value' else path\n",
")\n",
"def process_raw_dict(raw_state_dict):\n",
" flattened = nnx.traversals.flatten_mapping(raw_state_dict)\n",
" # Cut off the '.value' postfix on every leaf path.\n",
" flattened = {(path[:-1] if path[-1] == 'value' else path): value\n",
" for path, value in flattened.items()}\n",
" return nnx.traversals.unflatten_mapping(flattened)\n",
"\n",
"# Make your local change on the checkpoint dictionary.\n",
"raw_dict = checkpointer.restore('/tmp/nnx-surgery-state')\n",
"pprint(raw_dict)\n",
"raw_dict['layer1'] = raw_dict.pop('linear1')\n",
"raw_dict['layer2'] = raw_dict.pop('linear2')\n",
"\n",
"# Fit it into the model state.\n",
"abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))\n",
"graph_def, state = nnx.split(abs_model)\n",
"state.replace_by_pure_dict(process_raw_dict(raw_dict))\n",
"restored_model = nnx.merge(graph_def, state)\n",
"\n",
"np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))"
]
Expand Down Expand Up @@ -339,9 +332,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Number of jax arrays in memory at start: 34\n",
"Number of jax arrays in memory midway: 38 (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)\n",
"Number of jax arrays in memory at end: 36 (2 discarded - only lora_a & lora_b are used in model)\n"
"Number of jax arrays in memory at start: 38\n",
"Number of jax arrays in memory midway: 42 (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)\n",
"Number of jax arrays in memory at end: 40 (2 discarded - only lora_a & lora_b are used in model)\n"
]
}
],
Expand Down Expand Up @@ -379,8 +372,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Number of jax arrays in memory at start: 40\n",
"Number of jax arrays in memory at end: 42 (2 new created - lora_a and lora_b)\n"
"Number of jax arrays in memory at start: 44\n",
"Number of jax arrays in memory at end: 46 (2 new created - lora_a and lora_b)\n"
]
}
],
Expand All @@ -389,7 +382,7 @@
"old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n",
"\n",
"# Use `nnx.jit` (which wraps `jax.jit`) to automatically skip unused arrays - memory efficient!\n",
"@functools.partial(nnx.jit, donate_argnums=0, static_argnums=1)\n",
"@nnx.jit(donate_argnums=0)\n",
"def partial_init(old_state, rngs):\n",
" model = TwoLayerMLP(4, rngs=rngs)\n",
" # Create a new state.\n",
Expand All @@ -404,6 +397,20 @@
"print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}'\n",
" ' (2 new created - lora_a and lora_b)')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -420,7 +427,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down
53 changes: 27 additions & 26 deletions docs_nnx/guides/surgery.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,31 +152,24 @@ except Exception as e:
But you can load the parameter tree as a raw dictionary, make the renames, and generate a new state that is guaranteed to be compatible with your new model definition.

```{code-cell} ipython3
def module_from_variables_dict(module_factory, variables, map_key_fn):
if map_key_fn is None:
map_key_fn = lambda path: path
mdl = nnx.eval_shape(module_factory)
graph_def, state = nnx.split(mdl)
state = state.flat_state()
for path, val in flax.traverse_util.flatten_dict(variables).items():
mapped_path = map_key_fn(path)
if mapped_path not in state:
raise ValueError(f"{mapped_path} doesn't exist in {state.keys()}")
state[mapped_path].value = val
state = nnx.State.from_flat_path(state)
return nnx.merge(graph_def, state)
# Make your local change on the checkpoint.
raw = checkpointer.restore('/tmp/nnx-surgery-state')
pprint(raw)
raw['layer1'], raw['layer2'] = raw['linear1'], raw['linear2']
del raw['linear1'], raw['linear2']
restored_model = module_from_variables_dict(
lambda: nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0))),
raw,
lambda path: path[:-1] if path[-1] == 'raw_value' else path
)
def process_raw_dict(raw_state_dict):
flattened = nnx.traversals.flatten_mapping(raw_state_dict)
# Cut off the '.value' postfix on every leaf path.
flattened = {(path[:-1] if path[-1] == 'value' else path): value
for path, value in flattened.items()}
return nnx.traversals.unflatten_mapping(flattened)
# Make your local change on the checkpoint dictionary.
raw_dict = checkpointer.restore('/tmp/nnx-surgery-state')
pprint(raw_dict)
raw_dict['layer1'] = raw_dict.pop('linear1')
raw_dict['layer2'] = raw_dict.pop('linear2')
# Fit it into the model state.
abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
graph_def, state = nnx.split(abs_model)
state.replace_by_pure_dict(process_raw_dict(raw_dict))
restored_model = nnx.merge(graph_def, state)
np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))
```
Expand Down Expand Up @@ -218,7 +211,7 @@ Use `nnx.jit`'s efficiently compiled code to make sure only the state parameters
old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))
# Use `nnx.jit` (which wraps `jax.jit`) to automatically skip unused arrays - memory efficient!
@functools.partial(nnx.jit, donate_argnums=0, static_argnums=1)
@nnx.jit(donate_argnums=0)
def partial_init(old_state, rngs):
model = TwoLayerMLP(4, rngs=rngs)
# Create a new state.
Expand All @@ -233,3 +226,11 @@ good_model = partial_init(old_state, nnx.Rngs(42))
print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}'
' (2 new created - lora_a and lora_b)')
```

```{code-cell} ipython3
```

```{code-cell} ipython3
```
1 change: 1 addition & 0 deletions flax/nnx/nn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ def __call__(self, inputs: Array) -> Array:
(((inputs.ndim - 1,), (0,)), ((), ())),
precision=self.precision,
)
assert self.use_bias == (bias is not None)
if bias is not None:
y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
return y
Expand Down

0 comments on commit 9a3a1fc

Please sign in to comment.