|
198 | 198 | "def main(\n", |
199 | 199 | " dataset_size=256,\n", |
200 | 200 | " batch_size=32,\n", |
201 | | - " lr_strategy=(3e-3, 3e-3),\n", |
| 201 | + " lr=3e-3,\n", |
202 | 202 | " steps_strategy=(500, 500),\n", |
203 | 203 | " length_strategy=(0.1, 1),\n", |
204 | 204 | " width_size=64,\n", |
|
214 | 214 | " _, length_size, data_size = ys.shape\n", |
215 | 215 | "\n", |
216 | 216 | " model = NeuralODE(data_size, width_size, depth, key=model_key)\n", |
| 217 | + " optim = optax.adabelief(lr)\n", |
217 | 218 | "\n", |
218 | 219 | " # Training loop like normal.\n", |
219 | 220 | " #\n", |
|
233 | 234 | " model = eqx.apply_updates(model, updates)\n", |
234 | 235 | " return loss, model, opt_state\n", |
235 | 236 | "\n", |
236 | | - " for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):\n", |
237 | | - " optim = optax.adabelief(lr)\n", |
| 237 | + " for steps, length in zip(steps_strategy, length_strategy):\n", |
238 | 238 | " opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))\n", |
239 | 239 | " _ts = ts[: int(length_size * length)]\n", |
240 | 240 | " _ys = ys[:, : int(length_size * length)]\n", |
|
326 | 326 | ], |
327 | 327 | "metadata": { |
328 | 328 | "kernelspec": { |
329 | | - "display_name": "jax0227", |
| 329 | + "display_name": "Python 3 (ipykernel)", |
330 | 330 | "language": "python", |
331 | | - "name": "jax0227" |
| 331 | + "name": "python3" |
332 | 332 | }, |
333 | 333 | "language_info": { |
334 | 334 | "codemirror_mode": { |
|
340 | 340 | "name": "python", |
341 | 341 | "nbconvert_exporter": "python", |
342 | 342 | "pygments_lexer": "ipython3", |
343 | | - "version": "3.9.7" |
| 343 | + "version": "3.11.8" |
344 | 344 | } |
345 | 345 | }, |
346 | 346 | "nbformat": 4, |
|
0 commit comments