-
I guess newer examples aren't really necessary, since all that really changes is where we import things from. And it does seem like, as long as the params are sharded in the variable dictionary, we can still reuse Flax modules?
-
I've seen the SIREN model parallelism example, but it uses an older version of JAX.
Are there any more recent examples that use the newer APIs?
Also, do we still need to redefine Flax modules with named axes like the SIREN colab does?
It seems like we can use existing Flax modules and apply parameter sharding in the initializers.