
NNX + mutable state + JIT = "Cannot mutate <module name> from a different trace level" #3998

Closed Answered by cgarciae
dfdx asked this question in Q&A

Hey! The issue is that because you are jitting counter.__call__, self is captured by the closure instead of being passed in as an input, and NNX doesn't allow mutating objects from outside the current transform's context. To fix this, create an auxiliary function that accepts counter as an input and jit that:

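from flax import nnx

def increment(counter: Counter):
  counter()  # mutates counter's state in place

increment(counter)             # works fine (eager call)
nnx.jit(increment)(counter)    # should work: counter is now an explicit input, not a capture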

Replies: 2 comments

Answer selected by dfdx
Category
Q&A
Labels
None yet
2 participants

This discussion was converted from issue #3997 on June 15, 2024 09:08.