
Behaviour in-place update of dataclasses #1758

Answered by jheek
GJBoth asked this question in Q&A

The dataclass itself exists only at the Python level. JAX only has immutable arrays, but XLA optimizes memory allocation and can perform updates in place.
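Here is a minimal sketch of what an "update" to a dataclass looks like in JAX. It assumes a plain frozen `dataclasses.dataclass` registered as a pytree by hand. The `State` class and `increment` function are hypothetical, and libraries such as Flax ship their own dataclass helpers. The update builds a new Python object with `dataclasses.replace`. Only the array operation `counts.at[idx].add(1)` is compiled by XLA, which decides how the memory is actually managed.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class State:
    # Hypothetical container: only this Python wrapper is a dataclass,
    # the fields are ordinary (immutable) JAX arrays.
    counts: jax.Array
    step: jax.Array


# Register State as a pytree so it can be passed into and out of jit.
jax.tree_util.register_pytree_node(
    State,
    lambda s: ((s.counts, s.step), None),       # flatten: (children, aux_data)
    lambda _aux, children: State(*children),     # unflatten
)


@jax.jit
def increment(state: State, idx) -> State:
    # Functional update: produces a new array and a new State object.
    # XLA may perform this write in place within the compiled computation,
    # but the caller's input buffer is only reused if it is donated.
    new_counts = state.counts.at[idx].add(1)
    return dataclasses.replace(state, counts=new_counts, step=state.step + 1)


s0 = State(counts=jnp.zeros(4, dtype=jnp.int32), step=jnp.array(0))
s1 = increment(s0, 2)
print(s1.counts)  # [0 0 1 0]
print(s0.counts)  # [0 0 0 0]  (the original State is untouched)
```

From Python's point of view nothing is mutated: `s0` and `s1` are separate objects. Any in-place behaviour happens inside the compiled XLA program.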

There are a few things you should keep in mind, though:

  • By default, XLA cannot reuse the memory of input arguments to jit/pmap unless you "donate" them using the donate_argnums kwarg. So, for example, even x.at[idx].add(1) will not be done in place if x is an argument to jit(f), unless you use donate_argnums.
  • The dataclasses themselves don't add significant memory overhead, but holding on to arrays unnecessarily (which only matters outside of jit) can of course increase memory consumption. XLA doesn't care about when the Python code drops the reference for …
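Here is a small sketch of the donation point, assuming a recent JAX version. The function `add_one_at` is hypothetical. `donate_argnums` is the real `jax.jit` keyword, and donating a pytree argument (such as a registered dataclass) donates all of its array leaves. On backends that do not support donation, JAX ignores the request and emits a warning instead.

```python
import jax
import jax.numpy as jnp


def add_one_at(x, idx):
    # A functional update; whether it can reuse x's memory depends on
    # whether XLA is allowed to overwrite the input buffer.
    return x.at[idx].add(1)


# No donation: x's buffer still belongs to the caller, so the result
# must be written to a fresh output buffer (a copy of x plus the update).
f_copy = jax.jit(add_one_at)

# Donation: the caller promises not to use x again, so XLA may alias the
# output to x's buffer and perform the update in place.
f_donate = jax.jit(add_one_at, donate_argnums=0)

x = jnp.zeros(1_000_000, dtype=jnp.float32)
y = f_copy(x, 0)    # x is still valid afterwards
x = f_donate(x, 0)  # rebind x: the donated buffer must not be used again
```

Rebinding the name (`x = f_donate(x, 0)`) is the usual pattern. It also drops the Python reference to the old array, which ties in with the second point above.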

Replies: 2 comments 4 replies

jheek (Maintainer), Jan 6, 2022

Answer selected by GJBoth