[BUG] Accumulator initialization for grouped GEMM (for SM100) #2674

@alexsamardzic

Description

Which component has the problem?

CUTLASS C++

Bug Report

cc: @mihir-awatramani

I believe there is a problem with grouped GEMM calculations when k_tile_count happens to be 0 here, i.e. when a given group's size along K is 0.

To reproduce, the following code should technically suffice (using the grouped GEMM implementation from PyTorch):

import torch

device="cuda"
dtype_AB = torch.bfloat16
dtype_offset = torch.int32

m = 16
n = 16

# Per-group K sizes are the differences between consecutive offsets:
# 0, 8, 8, 0, 16 -- so groups 0 and 3 are empty along K.
offs = torch.tensor([0, 8, 16, 16, 32], device=device, dtype=dtype_offset)
ngroups = offs.shape[0]
k = offs[-1]
A = torch.randn(m, k, device=device, dtype=dtype_AB)
B = torch.randn(n, k, device=device, dtype=dtype_AB)

out = torch._grouped_mm(A, B.transpose(-2, -1), offs=offs)
# Groups with zero extent along K should produce all-zero outputs.
assert torch.all(out[0, :, :] == 0)
assert torch.all(out[3, :, :] == 0)

The problem is that, for the first and fourth group, the input tensor slices will be 16x0 and 0x16, so the output should be a 16x16 tensor of all 0s. However, as k_tile_count happens to be 0 here, it seems to me that the accumulator never gets written, and whatever stale values were already in the accumulator are written to the output tensor.
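For reference, the expected semantics follow from a matmul over a zero-length K dimension being an empty sum, which is 0. The sketch below is a pure-Python illustration of the grouped-along-K semantics used in the repro; the name grouped_mm_ref is illustrative, not the PyTorch API:

```python
def grouped_mm_ref(A, B, offs):
    """Reference grouped GEMM along K.

    A: m x k, B: k x n (lists of lists); offs: cumulative end offsets
    along K. Returns one m x n output per group. A group whose K slice
    is empty reduces an empty sum, so its output is all zeros.
    """
    m, n = len(A), len(B[0])
    outs = []
    start = 0
    for end in offs:
        # sum(...) over an empty K range yields 0 for every element
        out = [[sum(A[i][p] * B[p][j] for p in range(start, end))
                for j in range(n)] for i in range(m)]
        outs.append(out)
        start = end
    return outs

outs = grouped_mm_ref([[1.0, 2.0]], [[3.0], [4.0]], [0, 2])
# group 0 spans K indices [0, 0): empty, so all zeros
assert outs[0] == [[0]]
assert outs[1] == [[11.0]]
```

This is what the CUTLASS kernel should match: an empty K range must still produce a zeroed output tile rather than leaving the accumulator uninitialized.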

It's hard to reproduce the problem deterministically though, as the accumulator may happen to contain all 0s, in which case the output values turn out correct. The only reliable way we found to reproduce the problem was to run the corresponding tests from PyTorch in a very specific way:

python test/test_matmul_cuda.py -k test_grouped_gemm_compiled --repeat 2

Some tests will always fail here, but not always the same ones; the failures always happen on the second run. So most probably it comes down to how accumulators are allocated: the storage used for an accumulator may happen to contain non-zero values left over from previous use.

In any case, if I force the accumulator values to all 0s when k_tile_count happens to be 0, by adding the following dummy code after this line:

        if (k_tile_count == 0) {
          // Issue a single MMA with ScaleOut::Zero so the accumulator is
          // zero-initialized even when there are no K tiles to process.
          auto [tiled_mma, tCrA, tCrB] = collective_mainloop.mma_init(tmem_storage, shared_storage.tensors.mainloop);
          tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
          cute::gemm(tiled_mma, tCrA(_,_,_0{},_0{}), tCrB(_,_,_0{},_0{}), accumulator);
        }

then the wrong results disappear from the tests.

So, would it be possible to apply a similar fix to CUTLASS? We haven't encountered the issue on SM90, but it would probably be good to check the grouped GEMM code there too for handling of this same case.
