diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 8cc272f8eb..c0100517d6 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -446,7 +446,8 @@ def _graph_flatten( if isinstance(value, (jax.Array, np.ndarray)): path_str = '/'.join(map(str, (*path, key))) raise ValueError( - f'Arrays leaves are not supported, at {path_str!r}: {value}' + f'The variable at the path {path_str!r} with value = {value} is of type:' \ + f'{type(value)}. Leaf values of this type are not supported for nnx.Modules.' ) # static_fields.append((key, value)) attributes.append(StaticAttribute(key, value))