mnist example flatten #2533
deadsoul44 asked this question in Q&A.
-
Line 47 in 13fbace: the line above claims to flatten the array, but it does nothing.
Forgive me if I am missing something very obvious.
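A quick way to reproduce the observation, assuming the input already has shape `(batch, features)`. The shapes below are hypothetical and not taken from the example itself:

```python
import jax.numpy as jnp

# Hypothetical input that is already 2D: 8 samples of 784 features each.
x2d = jnp.ones((8, 784))

# Keep axis 0 (the batch) and merge all remaining axes into one.
y2d = x2d.reshape((x2d.shape[0], -1))

# Only one axis follows the batch axis, so there is nothing to merge.
print(x2d.shape, y2d.shape)  # (8, 784) (8, 784)
```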
Answered by cgarciae on Oct 14, 2022 (1 comment, 2 replies).
-
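Hey @deadsoul44, it works if the input has more than 2 dimensions, e.g.:

```python
import jax.numpy as jnp

x = jnp.ones((8, 28, 28, 1))
y = x.reshape((x.shape[0], -1))
print(y.shape)  # (8, 784)
```

Otherwise the shape won't change, as you note, but that is all correct behaviour: the reshape keeps the first (batch) axis and merges every remaining axis into one, so a 2D input has nothing left to merge.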
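For comparison, here is a small sketch (not part of the original thread) of why the reshape passes `x.shape[0]` instead of flattening everything. `jnp.ravel` would also merge the batch axis, while the reshape keeps one row per example. The helper name `flatten_batch` is hypothetical:

```python
import jax.numpy as jnp


def flatten_batch(x):
    """Merge all axes after the first (batch) axis into one (hypothetical helper)."""
    return x.reshape((x.shape[0], -1))


images = jnp.ones((8, 28, 28, 1))  # NHWC batch, same shape as in the answer
features = jnp.ones((8, 784))      # already flat, one row per example

print(flatten_batch(images).shape)    # (8, 784)
print(flatten_batch(features).shape)  # (8, 784): unchanged, already 2D
print(jnp.ravel(images).shape)        # (6272,): batch axis merged too
```

Because the batch axis is preserved, the same line is safe to apply whether the input arrives as image-shaped or already-flat arrays.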
Answer selected by deadsoul44.