Skip to content

Commit

Permalink
Improve error message for when nnx.Modules use jax or numpy arrays as…
Browse files Browse the repository at this point in the history
… leaf values
  • Loading branch information
RaghuSpaceRajan authored Jan 20, 2025
1 parent e4418e2 commit 7bb4dc4
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,8 @@ def _graph_flatten(
if isinstance(value, (jax.Array, np.ndarray)):
path_str = '/'.join(map(str, (*path, key)))
raise ValueError(
f'Arrays leaves are not supported, at {path_str!r}: {value}'
f'The variable at the path {path_str!r} with value = {value} is of type:' \
f'{type(value)}. Leaf values of this type are not supported for nnx.Modules.'
)
# static_fields.append((key, value))
attributes.append(StaticAttribute(key, value))
Expand Down

0 comments on commit 7bb4dc4

Please sign in to comment.