
mnist example flatten #2533

Answered by cgarciae
deadsoul44 asked this question in Q&A
Oct 14, 2022 · 1 comment · 2 replies

Hey @deadsoul44, it works if the input has more than 2 dimensions, e.g.:

import jax.numpy as jnp

x = jnp.ones((8, 28, 28, 1))
y = x.reshape((x.shape[0], -1))
print(y.shape) # (8, 784)

Otherwise, as you noted, the shape won't change, but this is all correct behaviour.
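To make the behaviour concrete, here is a small sketch (the `flatten` helper name is hypothetical, just a wrapper around the same `reshape` call). `x.reshape((x.shape[0], -1))` keeps the leading batch axis and lets `-1` infer the size of one combined axis from everything that remains. For inputs with 3 or more dimensions that collapses the trailing axes; for a 2-D `(batch, features)` input there is nothing left to collapse, so the shape comes back unchanged.

```python
import jax.numpy as jnp


def flatten(x):
    # Keep the leading (batch) axis; -1 tells reshape to infer the size
    # of the single remaining axis from the total number of elements.
    return x.reshape((x.shape[0], -1))


# 4-D image batch (N, H, W, C): spatial and channel axes are collapsed.
images = jnp.ones((8, 28, 28, 1))
print(flatten(images).shape)  # (8, 784)

# 3-D input (N, H, W): still flattened into (N, H * W).
gray = jnp.ones((8, 28, 28))
print(flatten(gray).shape)  # (8, 784)

# 2-D input (N, D) is already flat, so the shape is unchanged.
features = jnp.ones((8, 784))
print(flatten(features).shape)  # (8, 784)
```

Because the batch size is read from `x.shape[0]` rather than hard-coded, the same line works for any batch size, which is why it is harmless (a shape no-op) when the data is already 2-D.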

Answer selected by deadsoul44
Converted from issue

This discussion was converted from issue #2532 on October 14, 2022 16:47.