Proper use of flax.scan #2059
ozencgungor asked this question in Q&A
Hi there,
I am trying to implement an MLP-Mixer-type model on nearest neighbor graphs (HEALPix to be specific; my "images" live on a sphere) using Chebyshev convolutions. For clarity of discussion, suppose my input is of shape `(N, M, F)`, where `N` is the batch dimension, `M` is the number of pixels and `F` is the channel dimension. After performing a Chebyshev transformation of order `K`, I have an array of shape `(N, K, M, F)`; in code this looks roughly like the sketch below.

After this I would simply act on the output of the vectorized Chebyshev transform with a usual `nn.Conv(kernel_size=(K, 1))` and squeeze the leftover dimension out to get back an output of shape `(N, M, F)`.
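(A simplified sketch; names like `chebyshev_transform`, `ChebConv` and `laplacian` are placeholders, and I'm writing the Laplacian as a dense `(M, M)` array here only to keep the sketch short.)

```python
import jax.numpy as jnp
import flax.linen as nn


def chebyshev_transform(x, laplacian, K):
    """Stack T_0(L)x, ..., T_{K-1}(L)x along a new axis.

    x:         (N, M, F) input signal on the graph
    laplacian: (M, M)    (rescaled) graph Laplacian of the pixel neighbor graph
    returns:   (N, K, M, F)
    """
    t_prev, t_curr = x, jnp.einsum('mp,npf->nmf', laplacian, x)
    polys = [t_prev, t_curr]
    for _ in range(2, K):
        # Chebyshev recursion: T_k = 2 L T_{k-1} - T_{k-2}
        t_prev, t_curr = t_curr, 2.0 * jnp.einsum('mp,npf->nmf', laplacian, t_curr) - t_prev
        polys.append(t_curr)
    return jnp.stack(polys[:K], axis=1)  # (N, K, M, F)


class ChebConv(nn.Module):
    features: int
    K: int

    @nn.compact
    def __call__(self, x_cheb):
        # x_cheb: (N, K, M, F); treat (K, M) as the two spatial dims of a 2D conv.
        # padding='VALID' collapses the K axis to size 1 -- this is the leftover
        # dimension that gets squeezed out afterwards.
        y = nn.Conv(features=self.features, kernel_size=(self.K, 1), padding='VALID')(x_cheb)
        return jnp.squeeze(y, axis=1)  # (N, M, features)
```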
My images are masked, so to avoid acting on the masked pixels I carry around an array of pixel indices of shape `(M,)` telling me which pixels are valid, and I mask before I act with the `nn.Conv`.

Now suppose I break my images up into `N_s` patches to get a shape of `(N, N_s, K, M/N_s, F)`, and I do the same with the array of indices to get an array of indices of shape `(N_s, M/N_s)`. What I would like is a `flax.scan`-ned `MaskedConv` module that would effectively do a `for i` loop over the `N_s` dimension.
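Concretely, the per-patch module I have in mind is something like the following (simplified: here I pretend the index array has already been converted into a 0/1 validity mask of shape `(M/N_s,)`, and `MaskedChebConv` is just an illustrative name):

```python
import jax.numpy as jnp
import flax.linen as nn


class MaskedChebConv(nn.Module):
    """Chebyshev conv on a single patch, ignoring masked-out pixels."""
    features: int
    K: int

    @nn.compact
    def __call__(self, x_cheb, valid_mask):
        # x_cheb:     (N, K, M/N_s, F)  Chebyshev-transformed pixels of one patch
        # valid_mask: (M/N_s,)          1.0 for valid pixels, 0.0 for masked ones
        x_cheb = x_cheb * valid_mask[None, None, :, None]
        y = nn.Conv(features=self.features, kernel_size=(self.K, 1), padding='VALID')(x_cheb)
        y = jnp.squeeze(y, axis=1)              # (N, M/N_s, features)
        return y * valid_mask[None, :, None]    # keep masked pixels at zero
```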
But I cannot for the life of me understand how `nn.scan` works. How would one go about implementing the `for i` loop described above using `nn.scan`, such that one would have the option of not sharing parameters between different patches? If there were no masking, I could imagine reshaping to `(N*N_s, K, M/N_s, F)` and acting with an `nn.Conv`, but the masking forces me to implement some kind of looping over the patch dimension.

I'm relatively new to `jax`/`flax`, so any help would be greatly appreciated. Thanks in advance.
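For reference, here is roughly what I have been trying: adapting the module above to the `(carry, x) -> (carry, y)` convention that `nn.scan` seems to expect and scanning it over the patch axis. All names and shapes are illustrative, and I'm not at all sure that the `variable_axes`/`split_rngs` settings below are the right way to get separate parameters per patch:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class MaskedChebConvStep(nn.Module):
    """One scan step: the masked conv from above, plus a dummy carry."""
    features: int
    K: int

    @nn.compact
    def __call__(self, carry, x_cheb, valid_mask):
        # x_cheb: (N, K, M/N_s, F) for one patch, valid_mask: (M/N_s,)
        x_cheb = x_cheb * valid_mask[None, None, :, None]
        y = nn.Conv(features=self.features, kernel_size=(self.K, 1), padding='VALID')(x_cheb)
        y = jnp.squeeze(y, axis=1) * valid_mask[None, :, None]
        return carry, y


class PatchMixerBlock(nn.Module):
    features: int
    K: int

    @nn.compact
    def __call__(self, x_cheb, valid_masks):
        # x_cheb:      (N, N_s, K, M/N_s, F)
        # valid_masks: (N_s, M/N_s)
        # Move the patch axis to the front so both arguments are scanned over axis 0.
        xs = jnp.moveaxis(x_cheb, 1, 0)       # (N_s, N, K, M/N_s, F)

        ScannedConv = nn.scan(
            MaskedChebConvStep,
            variable_axes={'params': 0},      # one set of params per patch...
            split_rngs={'params': True},      # ...each initialized with its own RNG.
            # To *share* parameters across patches instead, I believe this becomes
            # variable_broadcast='params' and split_rngs={'params': False}.
            in_axes=0,
            out_axes=0,
        )
        _, ys = ScannedConv(features=self.features, K=self.K)(None, xs, valid_masks)
        # ys: (N_s, N, M/N_s, features) -> back to (N, N_s, M/N_s, features)
        return jnp.moveaxis(ys, 0, 1)


# Toy shapes just to check that it runs:
N, N_s, K, M_patch, F = 2, 4, 3, 16, 8
x = jnp.ones((N, N_s, K, M_patch, F))
masks = jnp.ones((N_s, M_patch))
model = PatchMixerBlock(features=F, K=K)
variables = model.init(jax.random.PRNGKey(0), x, masks)
out = model.apply(variables, x, masks)      # (N, N_s, M_patch, F)
```

Is this (or something close to it) the intended way to use `nn.scan` here, or is there a more idiomatic pattern for this kind of per-patch loop?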