From 7bb4dc40f80d031ac5e16f79f3f1ea2c3dda6541 Mon Sep 17 00:00:00 2001 From: Raghu Rajan <15613406+RaghuSpaceRajan@users.noreply.github.com> Date: Mon, 20 Jan 2025 15:58:58 +0100 Subject: [PATCH] Improve error message for when nnx.Modules use jax or numpy arrays as leaf values --- flax/nnx/graph.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 8cc272f8eb..c0100517d6 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -446,7 +446,8 @@ def _graph_flatten( if isinstance(value, (jax.Array, np.ndarray)): path_str = '/'.join(map(str, (*path, key))) raise ValueError( - f'Arrays leaves are not supported, at {path_str!r}: {value}' + f'The variable at the path {path_str!r} with value = {value} is of type:' \ + f'{type(value)}. Leaf values of this type are not supported for nnx.Modules.' ) # static_fields.append((key, value)) attributes.append(StaticAttribute(key, value))