Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error message for when nnx.Modules use jax or numpy arrays as leaf values #4492

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

RaghuSpaceRajan
Copy link

What does this PR do?

Fixes #4480

Checklist

  • [✔] This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other
    checks if that's the case).

@@ -446,7 +446,8 @@ def _graph_flatten(
if isinstance(value, (jax.Array, np.ndarray)):
path_str = '/'.join(map(str, (*path, key)))
raise ValueError(
f'Arrays leaves are not supported, at {path_str!r}: {value}'
f'The variable at the path {path_str!r} with value = {value} is of type:' \
f'{type(value)}. Leaf values of this type are not supported for nnx.Modules.'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f'{type(value)}. Leaf values of this type are not supported for nnx.Modules.'
f'{type(value)}. Leaf values of this type are not supported for nnx.Modules or nnx.Object in general.'

@cgarciae
Copy link
Collaborator

@RaghuSpaceRajan updated main today to fix the broken tests due to JAX's recent update. Can you please rebase?

cpgaffney1 and others added 4 commits January 24, 2025 12:25
…does nothing (it used to control a checkpointing behavior that has since been optimized away).

PiperOrigin-RevId: 718571228
PiperOrigin-RevId: 718899289
@RaghuSpaceRajan
Copy link
Author

@cgarciae Done.

@RaghuSpaceRajan
Copy link
Author

Actually, for some reason, it shows me as author for some of the rebased commits. Not sure if that's normal, it's the first time I have rebased across forks. Let me know if should just redo the thing.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Improve error message when user mistakenly holds a jax Array in an nnx.Module
3 participants