Explanation for weird pytree behaviour with grad? #1255
aniquetahir asked this question in Q&A
Consider the following code:

```python
from collections import namedtuple
import jax
import jax.numpy as np

Test = namedtuple('Test', ['a', 'b'])
test = Test(a=np.ones((10, 2)), b=np.ones((10, 5)))

weird_function = jax.grad(lambda x: 0.)
other_function = jax.jit(lambda x: 0.)

weird_function(test)
other_function(test)
```

Running …

What is happening here, exactly? How can I make …
Answered by jheek on Apr 19, 2021 (1 reply)
Answer selected by marcvanzee
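jax.grad returns the gradient with respect to the function's input, so its result has the same pytree structure as that input, while jax.jit simply compiles the function and returns its output unchanged. Here the argument is a Test namedtuple while the output is a scalar, which explains why the two return types differ. Because the function ignores its argument, every leaf of the gradient is zero.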
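To make the answer concrete, here is a minimal sketch (assuming a reasonably recent JAX release; the names f, grads and out are illustrative, not from the thread) that reuses the Test namedtuple from the question and compares what the two transformations return. jax.grad hands back a Test whose leaves match the input shapes (all zeros, since f ignores its argument), while jax.jit returns f's scalar output.

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

Test = namedtuple('Test', ['a', 'b'])
test = Test(a=jnp.ones((10, 2)), b=jnp.ones((10, 5)))

# Constant scalar function, as in the question (f is an illustrative name).
f = lambda x: 0.

grads = jax.grad(f)(test)  # gradient w.r.t. x: same pytree structure as x
out = jax.jit(f)(test)     # jit only compiles f, so this is f's scalar output

# grad's result mirrors the input: a Test namedtuple with matching leaf shapes.
print(type(grads).__name__)
print(jax.tree_util.tree_map(lambda g: g.shape, grads))

# f does not depend on x, so every gradient entry is zero.
print(all(bool(jnp.all(g == 0)) for g in jax.tree_util.tree_leaves(grads)))

# jit's result is just the 0-d array returned by f.
print(jnp.shape(out))
```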