Lowering of normalization calls #3706
Unanswered · pratnali-aws asked this question in Q&A
Hello all,

This is a newbie question. I have a simple model that uses BatchNormalization.
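A minimal sketch of such a model follows (the original snippet is not reproduced here, so the input shape, layer sizes, and `train` handling are assumptions chosen to match the convolution HLO quoted below):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool = True):
        # Lowers to a single HLO convolution op.
        x = nn.Conv(features=32, kernel_size=(3, 3), padding="SAME")(x)
        # Lowers to a sequence of reduce/multiply/add/rsqrt ops.
        x = nn.BatchNorm(use_running_average=not train)(x)
        return nn.relu(x)

model = CNN()
x = jnp.ones((1, 64, 64, 3))
variables = model.init(jax.random.PRNGKey(0), x)

# Inspect the compiled HLO to see how each layer is lowered.
apply_fn = lambda v, x: model.apply(v, x, mutable=["batch_stats"])
print(jax.jit(apply_fn).lower(variables, x).compile().as_text())
```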
When it is lowered to HLO, I see that `nn.Conv` becomes a single op:

```
convolution.23 = f32[1,64,64,32]{3,2,1,0} convolution(Arg_6.20, Arg_5.19), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f
```

but `nn.BatchNorm` is lowered into a sequence of ops. I get the sense that this is a direct translation of the implementation here: https://flax.readthedocs.io/en/latest/_modules/flax/linen/normalization.html#BatchNorm.
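For reference, that sequence of ops appears to correspond to computing the batch statistics explicitly, roughly like this sketch (NHWC layout and parameter broadcasting are assumptions; Flax's actual implementation differs in details):

```python
import jax
import jax.numpy as jnp

def batch_norm_train(x, scale, bias, eps=1e-5):
    # Reduce over all axes except the trailing feature axis.
    mean = jnp.mean(x, axis=(0, 1, 2), keepdims=True)                    # HLO reduce
    var = jnp.mean(jnp.square(x - mean), axis=(0, 1, 2), keepdims=True)  # reduce + elementwise
    inv = scale * jax.lax.rsqrt(var + eps)                               # rsqrt + multiply
    return x * inv + (bias - mean * inv)                                 # multiply/add/subtract
```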
Is there some way to avoid this, especially since HLO has a native `batch_norm_training` op?