Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Operator] Add batch_norm #362

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open

[Operator] Add batch_norm #362

wants to merge 18 commits into from

Conversation

2niuhe
Copy link
Contributor

@2niuhe 2niuhe commented Dec 13, 2024

PR Category

Operator

Type of Change

New Feature

Description

Implement batch_norm operator

  • forward
  • backward

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change is responded to an issue.
  • Change is fully covered by a UT.

@StrongSpoon StrongSpoon self-assigned this Dec 16, 2024
@StrongSpoon StrongSpoon self-requested a review December 16, 2024 05:36
@triton.heuristics(
{
"BLOCK_SIZE_BATCH": lambda args: next_power_of_2(args["batch_dim"]),
"BLOCK_SIZE_SPATIAL": BLOCK_SIZE_SPATIAL_heuristic,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are working on changing one-tile algorithm into loop tiling. please refer to max/min/lof_softmax and update the strategy. ;)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming the input to batch_norm has the shape (batch, channel, spatial).

The implementation of batch_norm is already based on loop tiling, where tiling occurs along the spatial dimension while fully loading the batch dimension. This approach differs from the operators mentioned, such as max, min, which only require computations along a specified dimension without involving others. In the case of batch_norm, both the batch and spatial dimensions need to be fully loaded.

My previous consideration was that introducing tiling along the batch dimension would result in a nested loop structure, which may not be meaningful when batch is not large.

You might be suggesting that applying loop tiling on the batch dimension could indeed improve performance when the batch size is large, as it would allow for more contiguous memory access in the spatial dimension. For instance, if the batch size is 16,384 and the spatial size loaded per loop iteration is only 1, this could lead to inefficient memory access patterns. I will try implementing loop tiling on both the batch and spatial dimensions.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it

benchmark/test_norm_perf.py Show resolved Hide resolved
@2niuhe
Copy link
Contributor Author

2niuhe commented Dec 16, 2024

co-author: @zhangboyue https://github.com/2niuhe/FlagGems/tree/dev_batch_norm

(curr_input - mean) * (curr_input - prev_mean),
0.0,
)
var += tl.sum(deltas)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this method cannot fully utilize vectorization/tensorization and leads to sequential computation.

BLOCK_SIZE_SPATIAL, spatial_dim - block_ind * BLOCK_SIZE_SPATIAL
)
curr_count = spatial_count * batch_dim
count += curr_count
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can set reasonable default value of tl.load to avoid keeping counter.

)

if affine:
weight_grad += tl.sum(curr_pre_lin * curr_output_grad)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

weight_grad = bias_grad = None

# Launches 1D grid where each program operates over one feature.
grid = lambda _: (feat_dim,)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grid hasn't to be a lambda expression. initialize it as a tuple.

@StrongSpoon
Copy link
Collaborator

please fix bug in accuracy test:)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Code Contribution: 【Lv3】【Operator Development】batch_norm_backward
2 participants