How to use (jax) pytrees inside of nnx modules? #4497
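So to say, is this the recommended way?

```python
from __future__ import annotations  # keeps the SimplePytree annotation unevaluated

import jax
from flax import nnx


class SimpleModule(nnx.Module):
    pytree: SimplePytree  # the user's own pytree type (definition not shown)

    def __init__(
        self, N, pt, rngs: nnx.Rngs, visible_bias: bool = True, param_dtype=complex
    ):
        self.linear = nnx.Linear(N, 1, param_dtype=param_dtype, rngs=rngs)
        # Wrap every array leaf of `pt` in an nnx.Variable
        self.pytree = jax.tree.map(nnx.Variable, pt)

    def __call__(self, x):
        # Unwrap the Variables back into plain arrays before using them
        pt = jax.tree.map(lambda x: x.value, self.pytree)
        return self.linear(pt * x)
```

Or is there some better approach? I'm not sure I love this because it breaks …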
Hi @PhilipVinc.
Can you clarify what you mean by this?
Well, it breaks any …
I see. That is probably a case we don't want to support.
I have some objects that are JAX pytrees, and I would like to store them inside an nnx module. In general, I would like a way to easily tag them (or, better, the arrays they contain) as `Param`s or as non-trainable `Variable`s. However, this does not seem to work out of the box, as I get the error

ValueError: Arrays leaves are not supported, at 'pytree/0': 2.0

(see MWE below). Is there a way to support this? Is there an easy way to wrap/unwrap all the leaves of such a field into `Param`s or `Variable`s?