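Problem you have encountered:

Updating an `nnx.Variable` inside a module works when the module is called directly, but fails when the module is wrapped with `nnx.jit`.

Steps to reproduce:

import jax.numpy as jnp
from flax import nnx

class Count(nnx.Variable):
    pass

class Counter(nnx.Module):
    def __init__(self):
        self.count = Count(jnp.array(0))

    def __call__(self):
        self.count.value += 1

counter = Counter()
counter()  # works fine
nnx.jit(counter)()  # fails

Logs, error messages, etc: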
Any pointers as to why this may be happening, or what the trace level is in this case?

System information

Disclaimer: I get the following (seemingly unrelated) warning at the start:
Replies: 2 comments
Hey! The issue is that because you are jitting `counter.__call__`, `self` is being passed as a capture, and we don't allow mutating objects outside the current transform's context. To fix this, simply create an auxiliary function that accepts `counter` as an input and jit that:

def increment(counter: Counter):
    counter()

increment(counter)  # works fine
nnx.jit(increment)(counter)  # should work

Huh! I didn't think about capturing the object. Thank you! As a mnemonic for myself:

nnx.jit(counter)()  # fails: JIT is applied to the instance (its bound __call__), so the instance itself is captured
nnx.jit(Counter.__call__)(counter)  # OK: JIT is applied to the unbound method; the instance is passed as an argument to the jitted function and is NOT captured
nnx.jit(increment)(counter)  # OK: same as above, but with a helper function