From cf0783240c7df15af7ced7f1a18cf8c1e0d47d28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20R=C3=B8d?= Date: Wed, 1 Jan 2025 18:54:45 +0100 Subject: [PATCH] Add specific model typing for nnx.Optimizer --- flax/nnx/training/optimizer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/flax/nnx/training/optimizer.py b/flax/nnx/training/optimizer.py index 4b85d5a3d4..6339f0392c 100644 --- a/flax/nnx/training/optimizer.py +++ b/flax/nnx/training/optimizer.py @@ -13,6 +13,8 @@ # limitations under the License. from __future__ import annotations +import typing as tp + import jax import jax.numpy as jnp import optax @@ -23,6 +25,8 @@ from flax.nnx.object import Object from flax.nnx.variablelib import Variable, VariableState +M = tp.TypeVar('M', bound=nnx.Module) + # TODO: add tests and docstrings @@ -101,7 +105,7 @@ def optimizer_update_variables(x, update): return jax.tree.map(optimizer_update_variables, opt_state, updates) -class Optimizer(Object): +class Optimizer(Object, tp.Generic[M]): """Simple train state for the common case with a single Optax optimizer. Example usage:: @@ -168,7 +172,7 @@ class Optimizer(Object): def __init__( self, - model: nnx.Module, + model: M, tx: optax.GradientTransformation, wrt: filterlib.Filter = nnx.Param, ):