How to count the parameters in a Flax model? #1224
Asked by marcvanzee in Q&A
Original question by @ameya98.
Answered by marcvanzee on Apr 8, 2021
Replies: 1 comment
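Answer by @jheek:

sum(p.size for p in jax.tree_leaves(params))

Here params is the parameter pytree of the model (for a Flax Linen module, the "params" entry of the dict returned by model.init). The expression flattens the pytree into its array leaves and adds up p.size, the number of elements in each array. The jax.tree_leaves alias is deprecated in newer JAX releases; jax.tree_util.tree_leaves(params) is the equivalent call.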
Answer selected by marcvanzee