Skip to content

Commit d3c1430

Browse files
tidied up neural ODE example
1 parent 482de90 commit d3c1430

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

docs/examples/neural_ode.ipynb

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@
198198
"def main(\n",
199199
" dataset_size=256,\n",
200200
" batch_size=32,\n",
201-
" lr_strategy=(3e-3, 3e-3),\n",
201+
" lr=3e-3,\n",
202202
" steps_strategy=(500, 500),\n",
203203
" length_strategy=(0.1, 1),\n",
204204
" width_size=64,\n",
@@ -214,6 +214,7 @@
214214
" _, length_size, data_size = ys.shape\n",
215215
"\n",
216216
" model = NeuralODE(data_size, width_size, depth, key=model_key)\n",
217+
" optim = optax.adabelief(lr)\n",
217218
"\n",
218219
" # Training loop like normal.\n",
219220
" #\n",
@@ -233,8 +234,7 @@
233234
" model = eqx.apply_updates(model, updates)\n",
234235
" return loss, model, opt_state\n",
235236
"\n",
236-
" for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):\n",
237-
" optim = optax.adabelief(lr)\n",
237+
" for steps, length in zip(steps_strategy, length_strategy):\n",
238238
" opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))\n",
239239
" _ts = ts[: int(length_size * length)]\n",
240240
" _ys = ys[:, : int(length_size * length)]\n",
@@ -326,9 +326,9 @@
326326
],
327327
"metadata": {
328328
"kernelspec": {
329-
"display_name": "jax0227",
329+
"display_name": "Python 3 (ipykernel)",
330330
"language": "python",
331-
"name": "jax0227"
331+
"name": "python3"
332332
},
333333
"language_info": {
334334
"codemirror_mode": {
@@ -340,7 +340,7 @@
340340
"name": "python",
341341
"nbconvert_exporter": "python",
342342
"pygments_lexer": "ipython3",
343-
"version": "3.9.7"
343+
"version": "3.11.8"
344344
}
345345
},
346346
"nbformat": 4,

0 commit comments

Comments
 (0)