-
Thanks for posting this update! I remember you commenting on this performance consideration somewhere months back, but I couldn't find it. I've been doing this split/merge with standard JAX transforms in my code just in case (and to stay as pure JAX as possible). If a PR goes through to address this, I'll try switching to the NNX transforms! 👍
-
Thanks for posting this! Is there a way to do the update with metrics? Or is the only option something like this:

```python
graphdef, state = nnx.split((model, optimizer, metrics))
...
nnx.update((model, optimizer, metrics), state)
```

Because right now it raises an error.

It would also be great to have a page with speed-up tips for the NNX API!
-
@cgarciae As I mentioned above, I've been sticking to split/merge + JAX transforms to future-proof against any performance hits. However, I would consider switching to NNX transforms for my current development if the Rust extension is expected to fully close the performance gap. Can you comment on the expected gains with `flaxlib`?
-
@cgarciae in your example, at the end,
-
Big fan of NNX! Personally, I think there are reasons other than performance to use split/merge and standard JAX transforms. It's "closer to the metal," if you will: once you understand the split/merge API and JAX's core APIs, you're empowered to do pretty much anything, at the cost of a little more boilerplate (holding on to the graphdef), which is not too bad in my opinion (especially since y'all have done such a great job with the static typing!). You can mix NNX's mutable reference semantics with JAX's pure functional semantics to write code that is both convenient and bug-free.

I worry that encouraging NNX transforms only, while sweeping split/merge under the rug, would be especially bad for newer JAX users. NNX transforms add a layer of abstraction that completely hides the underlying JAX abstractions, which may make it harder to pick up important concepts like tracing/staging out, PyTrees, sharding, etc. As a more experienced JAX user, I've definitely found split/merge with explicit state management more comfortable and legible. Another argument for encouraging this pattern is that, at least right now, you must understand split/merge and explicit state management to save and load checkpoints.

I realize not everyone will agree with me! My vote would be to document split/merge and NNX transforms side by side as equivalent ways of doing things, even after flaxlib is complete. That way, even people who use NNX transforms to save on boilerplate can still build a mental model of what is happening under the hood.
-
FYI: pinning this discussion to google/flax/discussions/ @cgarciae #nnx
-
Currently `nnx.jit` traverses the object graph in Python. This is slow and mainly affects small models: as the model's width grows, the per-call Python overhead becomes small relative to the actual computation. To solve this in general, we will be developing a Rust extension called `flaxlib` (see the first steps in #4196) to speed up some of the traversal logic in `graph.py`, similar to how JAX solved the same issue with `jaxlib` for standard pytrees.

Meanwhile, there is a pattern you can use to remove the Python overhead: use regular `jax.jit` together with `nnx.split`/`nnx.merge` to stage out the traversal logic. Take a training step that uses `nnx.jit` as an example. To speed it up, call `nnx.split` before starting the training loop to create a `graphdef` and a `state` for the NNX objects (unlike the objects themselves, these are fast to traverse), and then call `nnx.merge` + `nnx.split` inside the `jax.jit`-decorated function so they only run once, during tracing.

After the training loop is done (or whenever needed), `nnx.update` can be used to update `model`, `optimizer`, and `metrics` to the new `state`.