-
Notifications
You must be signed in to change notification settings - Fork 56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Operator] Add batch_norm #362
base: master
Are you sure you want to change the base?
Conversation
@triton.heuristics( | ||
{ | ||
"BLOCK_SIZE_BATCH": lambda args: next_power_of_2(args["batch_dim"]), | ||
"BLOCK_SIZE_SPATIAL": BLOCK_SIZE_SPATIAL_heuristic, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we are working on changing one-tile algorithm into loop tiling. please refer to max/min/lof_softmax and update the strategy. ;)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Assuming the input to batch_norm has the shape (batch, channel, spatial).
The implementation of batch_norm is already based on loop tiling, where tiling occurs along the spatial dimension while fully loading the batch dimension. This approach differs from the operators mentioned, such as max, min, which only require computations along a specified dimension without involving others. In the case of batch_norm, both the batch and spatial dimensions need to be fully loaded.
My previous consideration was that introducing tiling along the batch dimension would result in a nested loop structure, which may not be meaningful when batch is not large.
You might be suggesting that applying loop tiling on the batch dimension could indeed improve performance when the batch size is large, as it would allow for more contiguous memory access in the spatial dimension. For instance, if the batch size is 16,384 and the spatial size loaded per loop iteration is only 1, this could lead to inefficient memory access patterns. I will try implementing loop tiling on both the batch and spatial dimensions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
got it
(curr_input - mean) * (curr_input - prev_mean), | ||
0.0, | ||
) | ||
var += tl.sum(deltas) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this method cannot fully utilize vectorization/tensorization and leads to sequential computation.
BLOCK_SIZE_SPATIAL, spatial_dim - block_ind * BLOCK_SIZE_SPATIAL | ||
) | ||
curr_count = spatial_count * batch_dim | ||
count += curr_count |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can set reasonable default value of tl.load to avoid keeping counter.
src/flag_gems/ops/batch_norm.py
Outdated
) | ||
|
||
if affine: | ||
weight_grad += tl.sum(curr_pre_lin * curr_output_grad) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
src/flag_gems/ops/batch_norm.py
Outdated
weight_grad = bias_grad = None | ||
|
||
# Launches 1D grid where each program operates over one feature. | ||
grid = lambda _: (feat_dim,) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
grid hasn't to be a lambda expression. initialize it as a tuple.
please fix bug in accuracy test:) |
PR Category
Operator
Type of Change
New Feature
Description
Implement batch_norm operator
Issue
Progress