Hello @cgarciae, I am closely watching the progress of … AFAIK you can do something like this with `jax.experimental.attrs`:

```python
import jax
import jax.numpy as jnp
from jax.experimental.attrs import jax_getattr, jax_setattr, register


class Foo:
    def __init__(self):
        self.a = jnp.array([1, 2])
        self.b = jnp.array([3, 4])


register(Foo)
foo = Foo()


@jax.jit
def swap_ref(foo: Foo):
    # Read both attributes, then write them back swapped via the attrs API.
    a, b = jax_getattr(foo, "a"), jax_getattr(foo, "b")
    jax_setattr(foo, "a", b)
    jax_setattr(foo, "b", a)


print(foo.a, foo.b)  # [1 2] [3 4]
swap_ref(foo)
print(foo.a, foo.b)  # [3 4] [1 2]
```

I know you can achieve the same functionality under …
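For comparison, here is a minimal sketch of the same swap with plain `jax.jit` and no `attrs`. This is not from the thread. Because `jax.jit` functions must be pure, the state is made a pytree (here via `jax.tree_util.register_pytree_node_class`), passed in, and a new object is returned instead of being mutated. The class layout is illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Foo:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    # Expose the two arrays as pytree leaves so jax.jit can trace them.
    def tree_flatten(self):
        return (self.a, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def swap(foo: Foo) -> Foo:
    # Pure function: build and return a new Foo rather than mutating the input.
    return Foo(foo.b, foo.a)


foo = Foo(jnp.array([1, 2]), jnp.array([3, 4]))
foo = swap(foo)  # the caller rebinds the name to the returned state
print(foo.a, foo.b)  # [3 4] [1 2]
```

The difference is that the caller has to thread the updated state back explicitly. With `jax_setattr`, by contrast, the update lands on the existing object.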
Hey @ASEM000!

I had a local branch testing `attrs` support. It's very interesting, but it's not clear that it's ultimately the best approach, as it has certain limitations: you cannot create new attrs (which we kinda need for `sow`). I'm very excited about NNX's new graph update mechanism. It gives me much more confidence in recommending `nnx.jit`, as it now properly tracks all state updates, and it's now a superset of `jax.jit`.