How to assert two frozendicts' equality? #1766
franchesoni asked this question in Q&A
-
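Assuming all your leaves are jax arrays, you could do this:

import jax
import jax.numpy as jnp
import numpy as np
# jax.tree_util.tree_map accepts several trees with the same structure (it replaces the removed jax.tree_multimap)
assert np.all(jax.tree_util.tree_leaves(jax.tree_util.tree_map(lambda x, y: jnp.array_equal(x, y), paramsp, params)))

Hope that helps!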
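A slightly more defensive variant, sketched here under the assumption that both arguments are pytrees with array-like leaves (Flax registers FrozenDict as a pytree), first compares the tree structures and then each pair of leaves with np.array_equal. That way, mismatched keys or shapes give False instead of an error or a broadcast comparison. The helper name trees_equal and the example params are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def trees_equal(a, b) -> bool:
    """Hypothetical helper: True if two pytrees (e.g. Flax FrozenDicts) have the
    same structure and element-wise equal leaves with matching shapes."""
    # Different keys, nesting, or container types give different treedefs.
    # Checking this first guarantees the leaves below are paired up correctly.
    if jax.tree_util.tree_structure(a) != jax.tree_util.tree_structure(b):
        return False
    leaves_a = jax.tree_util.tree_leaves(a)
    leaves_b = jax.tree_util.tree_leaves(b)
    # np.array_equal returns False on a shape mismatch instead of broadcasting.
    return all(np.array_equal(x, y) for x, y in zip(leaves_a, leaves_b))


# Example with plain nested dicts standing in for model params.
params = {"dense": {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros(3)}}
paramsp = jax.tree_util.tree_map(lambda x: x + 0, params)  # an equal copy
assert trees_equal(params, paramsp)
```

Because the structure check also compares container types, a FrozenDict and a plain dict with the same contents should come out as different, so unfreeze one side first if that is not what you want. If you would rather get an assertion with a descriptive failure message, the chex library's chex.assert_trees_all_equal may also be worth a look.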
-
This does not work. How should I do this?