Hi, I'm working on a model that requires a lot of memory, so I'm trying to make a few memory optimizations. The code is structured as functions that update a field, which I've implemented as a simple flax dataclass consisting of the field and some other stuff, like this:
When I modify the field, I use the replace method to update it, since it's immutable, i.e.:
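A minimal sketch of this kind of setup, assuming a plain frozen `dataclasses.dataclass` as a stand-in for `flax.struct.dataclass` (which likewise produces a frozen dataclass with a `.replace` method). The names `State`, `step` and `update` are hypothetical. It shows that `replace` on its own is shallow: a new array appears only because of the arithmetic, not because of `replace`.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class State:
    field: jax.Array  # the large array that the update functions modify
    step: int         # hypothetical stand-in for the "other stuff" stored next to it

    def replace(self, **updates):
        # Like flax's .replace: build a new State and leave self untouched.
        # This is a shallow copy; unchanged fields keep the same array object.
        return dataclasses.replace(self, **updates)


def update(state: State) -> State:
    # Outside jit, the multiplication allocates a new array;
    # replace only stores the new reference in a new State.
    return state.replace(field=state.field * 2.0, step=state.step + 1)


s0 = State(field=jnp.ones(4), step=0)
s1 = update(s0)
same = s0.replace(step=5)  # field is not copied here
```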
If I understood correctly, this code makes a copy of …
This is guaranteed to be in-place after jitting, and actually works with flax dataclasses too:
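A minimal sketch of an indexed update on a dataclass field inside `jax.jit`, assuming the pattern in question is `x.at[idx].set(y)`. `State` and `set_entry` are hypothetical names, and the dataclass is registered as a pytree by hand, which is what `flax.struct.dataclass` does for you.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class State:
    field: jax.Array


# Register State as a pytree so it can be passed into and out of jax.jit.
jax.tree_util.register_pytree_node(
    State,
    lambda s: ((s.field,), None),         # flatten: (children, aux data)
    lambda _, children: State(*children),  # unflatten
)


@jax.jit
def set_entry(state: State, idx, value) -> State:
    # Functional update; under jit XLA can lower this to an in-place
    # write into a buffer instead of building a second full copy.
    new_field = state.field.at[idx].set(value)
    return dataclasses.replace(state, field=new_field)


s = State(field=jnp.zeros(5))
s = set_entry(s, 2, 7.0)
```

Note that `state.field` is an input to the jitted function, so unless its buffer is donated, XLA still has to produce a separate output buffer; the in-place optimization removes extra copies inside the computation.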
but I suspect this isn't in-place, since when I run something like this
What's the recommended way of dealing with in-place updates and dataclasses? Should I make the dataclass unfrozen? Or is there something I can do by updating the replace method with the new …? Thanks!
-
When you … In your example (where it seems you aren't jitting at all), in the …
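A minimal sketch (not from the thread; function names are hypothetical) contrasting the same loop of indexed updates run eagerly and under `jax.jit` with `jax.lax.fori_loop`. Eagerly, each `.at[i].set(...)` is dispatched on its own and returns a freshly allocated array; under `jit`, XLA compiles the whole loop and is free to update a single buffer in place. The values are identical; only the memory behaviour can differ.

```python
import jax
import jax.numpy as jnp


def bump_eager(x: jax.Array) -> jax.Array:
    # No jit: every .at[].set runs immediately and returns a new array.
    for i in range(x.shape[0]):
        x = x.at[i].set(x[i] + 1.0)
    return x


@jax.jit
def bump_jit(x: jax.Array) -> jax.Array:
    # Traced once; XLA sees the whole loop and can reuse one buffer.
    def body(i, acc):
        return acc.at[i].set(acc[i] + 1.0)

    return jax.lax.fori_loop(0, x.shape[0], body, x)


x = jnp.zeros(8)
a = bump_eager(x)
b = bump_jit(x)
```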
-
The dataclass itself only exists at the Python level. JAX only has immutable arrays, but XLA optimizes memory allocation and can do in-place updates. There are some things you should keep in mind, though:
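A minimal sketch (not from the thread) of one related mechanism, buffer donation: `jax.jit`'s `donate_argnums` tells XLA that the caller will not use an input again, so its memory may be reused for the output. `State` and `step` are hypothetical names. On backends that cannot honour donation, JAX emits a warning and allocates a new buffer as usual.

```python
import dataclasses
from functools import partial

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class State:
    field: jax.Array


# Same pytree registration that flax.struct.dataclass performs automatically.
jax.tree_util.register_pytree_node(
    State,
    lambda s: ((s.field,), None),
    lambda _, children: State(*children),
)


@partial(jax.jit, donate_argnums=0)
def step(state: State) -> State:
    # donate_argnums=0 donates every array leaf of `state`, so XLA
    # may write the result into the input's buffer.
    return dataclasses.replace(state, field=state.field.at[0].add(1.0))


state = State(field=jnp.zeros(1024))
for _ in range(3):
    # Rebind the name each time; a donated input must not be used again.
    state = step(state)
```

Donated arrays may be invalidated by the call, which is why the loop rebinds `state` instead of keeping the previous value around.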