How to use (jax) pytrees inside of nnx modules? #4497

Open
PhilipVinc opened this issue Jan 22, 2025 · 4 comments

@PhilipVinc
Contributor

PhilipVinc commented Jan 22, 2025

I have some objects that are JAX pytrees, and I would like to store them inside an nnx Module. More generally, I would like an easy way to tag them (or, better, the arrays they contain) as Params or as non-trainable Variables.

However, this does not work out of the box: nnx.split raises ValueError: Arrays leaves are not supported, at 'pytree/0': 2.0 (see the MWE below).

Is there a way to support this? Is there an easy way to wrap/unwrap all the fields of a pytree into Params or Variables?

import jax
import jax.numpy as jnp
import flax.nnx as nnx

# Define a JAX Pytree class
@jax.tree_util.register_pytree_node_class
class SimplePytree:
    def __init__(self, value: float):
        self.value = value

    def __mul__(self, other):
        return other * self.value

    # Pytree protocol: return (children, aux_data); the decorator above registers the class
    def tree_flatten(self):
        return ([self.value], None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

# Define the Flax nnx module
class SimpleModule(nnx.Module):
    pytree: SimplePytree

    def __init__(
        self, N, pt, rngs: nnx.Rngs, visible_bias: bool = True, param_dtype=complex
    ):
        self.linear = nnx.Linear(N, 1, param_dtype=param_dtype, rngs=rngs)
        self.pytree = pt

    def __call__(self, x):
        return self.linear(self.pytree * x)

# Instantiate the pytree and module
pytree = SimplePytree(jnp.array(2.0))

net = SimpleModule(2, pytree, nnx.Rngs(0), visible_bias=True, param_dtype=complex)

x = jnp.ones((10, 2))

nnx.split(net)

raises

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/flax/nnx/graph.py:1290, in split(node, *filters)
   1219 def split(
   1220   node: A, *filters: filterlib.Filter
   1221 ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
   1222   """Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is
   1223   a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef
   1224   contains all the static information needed to reconstruct a ``Module`` graph, it is analogous
   (...)
   1288     filters are passed, a single ``State`` is returned.
   1289   """
-> 1290   graphdef, state = flatten(node)
   1291   states = _split_state(state, filters)
   1292   return graphdef, *states

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/flax/nnx/graph.py:404, in flatten(node, ref_index)
    402   ref_index = RefMap()
    403 flat_state: dict[PathParts, StateLeaf] = {}
--> 404 graphdef = _graph_flatten((), ref_index, flat_state, node)
    405 return graphdef, GraphState.from_flat_path(flat_state)

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/flax/nnx/graph.py:436, in _graph_flatten(path, ref_index, flat_state, node)
    434 for key, value in values:
    435   if is_node(value):
--> 436     nodedef = _graph_flatten((*path, key), ref_index, flat_state, value)
    437     subgraphs.append((key, nodedef))
    438   elif isinstance(value, Variable):

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/flax/nnx/graph.py:451, in _graph_flatten(path, ref_index, flat_state, node)
    449     if isinstance(value, (jax.Array, np.ndarray)):
    450       path_str = '/'.join(map(str, (*path, key)))
--> 451       raise ValueError(
    452           f'Arrays leaves are not supported, at {path_str!r}: {value}'
    453       )
    454     static_fields.append((key, value))
    456 nodedef = NodeDef.create(
    457   type=node_impl.type,
    458   index=index,
   (...)
    464   index_mapping=None,
    465 )

ValueError: Arrays leaves are not supported, at 'pytree/0': 2.0
@PhilipVinc
Contributor Author

PhilipVinc commented Jan 22, 2025

In other words, is this the recommended approach?

class SimpleModule(nnx.Module):
    pytree: SimplePytree

    def __init__(
        self, N, pt, rngs: nnx.Rngs, visible_bias: bool = True, param_dtype=complex
    ):
        self.linear = nnx.Linear(N, 1, param_dtype=param_dtype, rngs=rngs)
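        # Wrap every array leaf of the pytree in an nnx.Variable
        self.pytree = jax.tree.map(nnx.Variable, pt)

    def __call__(self, x):
        # Unwrap each Variable back to its array before using the pytree
        pt = jax.tree.map(lambda v: v.value, self.pytree)
        return self.linear(pt * x)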

Or is there a better approach?

I'm not sure I love this because it breaks SimpleModule.pytree which won't work anymore by default now.

@cgarciae
Collaborator

Hi @PhilipVinc.

I'm not sure I love this because it breaks SimpleModule.pytree which won't work anymore by default now.

Can you clarify what you mean by this? Variable overloads all operators and implements __jax_array__, so wrapping arrays in it tends to work transparently. Is there something we can do to make your use case work?

@PhilipVinc
Contributor Author

Well, it breaks isinstance checks, for example.

@cgarciae
Collaborator

I see. That is probably a case we don't want to support.
