
wrong result in matmul kernel #3665

Open
jjsjann123 opened this issue Jan 2, 2025 · 4 comments
jjsjann123 (Collaborator)

The issue emerged here

If we reorder the broadcast/cast before the multiply & sum, our codegen gives a totally wrong result. Here's the repro.

import torch
from nvfuser import FusionDefinition, DataType
    
DEBUG=False 
    
def nvfuser_fusion_id38(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    
    if DEBUG:
        # broadcast before cast: the cast sits between the broadcast and the
        # mul, so the matmul pattern is not matched and the result is correct
        T2 = fd.ops.broadcast(T0, [False, True, False])
        T3 = fd.ops.cast(T2, DataType.Float)

        T4 = fd.ops.broadcast(T1, [True, False, False])
        T5 = fd.ops.cast(T4, DataType.Float)
    else:
        # cast before broadcast: the broadcasts feed the mul + sum directly;
        # this ordering is pattern-matched as a matmul and gives the wrong result
        T2 = fd.ops.cast(T0, DataType.Float)
        T3 = fd.ops.broadcast(T2, [False, True, False])

        T4 = fd.ops.cast(T1, DataType.Float)
        T5 = fd.ops.broadcast(T4, [True, False, False])
    
    T6 = fd.ops.mul(T3, T5)
    T7 = fd.ops.sum(T6, [-1])
    fd.add_output(T7)
    
with FusionDefinition() as fd:
    nvfuser_fusion_id38(fd)
    
inputs = [
    torch.testing.make_tensor((32, 128), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((64, 128), dtype=torch.bfloat16, device='cuda:0'),
]
o = fd.execute(inputs)
    
o_ref = inputs[0].float() @ inputs[1].transpose(0, 1).float()
print(o[0] - o_ref)
assert o[0].allclose(o_ref, 1e-3, 1e-3)
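
For reference, cast and broadcast commute, so the two orderings define the same computation and the discrepancy must come from codegen rather than from the fusion definition. A minimal sketch in plain PyTorch (not nvFuser) illustrating the equivalence:

import torch

a = torch.testing.make_tensor((32, 128), dtype=torch.bfloat16, device='cpu')
# Broadcast only inserts a size-1 dimension and never touches the data,
# so cast-then-broadcast and broadcast-then-cast produce identical values.
x = a.float().unsqueeze(1)   # cast, then broadcast (the DEBUG=False ordering)
y = a.unsqueeze(1).float()   # broadcast, then cast (the DEBUG=True ordering)
assert torch.equal(x, y)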
jjsjann123 (Collaborator, Author)

In the generated code, one section looks like this. The mma instruction's btype is marked as f16, but T5 is of dtype float:

    float T5[32];
    // ...
       asm(
          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
          :"=f"((*reinterpret_cast<Array<Array<float, 4, 1>, 2, 1>*>(&((*reinterpret_cast<Array<float, 8, 1>*>(&T10[(i113 + i114)])))))[0LL][0]),
           "=f"((*reinterpret_cast<Array<Array<float, 4, 1>, 2, 1>*>(&((*reinterpret_cast<Array<float, 8, 1>*>(&T10[(i113 + i114)])))))[0LL][1]),
           "=f"((*reinterpret_cast<Array<Array<float, 4, 1>, 2, 1>*>(&((*reinterpret_cast<Array<float, 8, 1>*>(&T10[(i113 + i114)])))))[0LL][2]),
           "=f"((*reinterpret_cast<Array<Array<float, 4, 1>, 2, 1>*>(&((*reinterpret_cast<Array<float, 8, 1>*>(&T10[(i113 + i114)])))))[0LL][3])
          :"r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T3[i112]))[0]),
           "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T3[i112]))[1]),
           "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T3[i112]))[2]),
           "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T3[i112]))[3]),
           "r"((*reinterpret_cast<Array<Array<uint32_t, 2, 1>, 2, 1>*>(&((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T5[i114])))))[0LL][0]),
           "r"((*reinterpret_cast<Array<Array<uint32_t, 2, 1>, 2, 1>*>(&((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T5[i114])))))[0LL][1]),
           "f"((*reinterpret_cast<Array<Array<float, 4, 1>, 2, 1>*>(&((*reinterpret_cast<Array<float, 8, 1>*>(&T10[(i113 + i114)])))))[0LL][0]),
           "f"((*reinterpret_cast<Array<Array<float, 4, 1>, 2, 1>*>(&((*reinterpret_cast<Array<float, 8, 1>*>(&T10[(i113 + i114)])))))[0LL][1]),
           "f"((*reinterpret_cast<Array<Array<float, 4, 1>, 2, 1>*>(&((*reinterpret_cast<Array<float, 8, 1>*>(&T10[(i113 + i114)])))))[0LL][2]),
           "f"((*reinterpret_cast<Array<Array<float, 4, 1>, 2, 1>*>(&((*reinterpret_cast<Array<float, 8, 1>*>(&T10[(i113 + i114)])))))[0LL][3])
        );
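
Since the f16.f16 mma variant reads its operand registers as packed half-precision pairs, feeding it fp32 data amounts to a bit reinterpretation rather than a value conversion. A minimal PyTorch sketch (not nvFuser code) of that effect:

import torch

t5 = torch.randn(4, dtype=torch.float32)
# view(dtype) reinterprets the raw bits: each fp32 element becomes two
# unrelated fp16 values, which is effectively what the f16 mma does to T5.
as_half = t5.view(torch.float16)
print(t5)        # the intended fp32 values
print(as_half)   # garbage, hence the totally wrong matmul result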

jacobhinkle (Collaborator) commented Jan 2, 2025

Interesting. We should be rejecting this in the matmul scheduler since the input to the matmul pattern is not fp16 or bf16.
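
A rough sketch of the kind of guard that suggests, written as Python pseudocode (can_use_matmul_scheduler, a_dtype, and b_dtype are made-up names for illustration, not the real nvFuser API):

# Hypothetical dtype check for the matmul scheduler: only accept the pattern
# when both operands feeding the mma are half precision, otherwise fall back.
HALF_TYPES = {"Half", "BFloat16"}

def can_use_matmul_scheduler(a_dtype: str, b_dtype: str) -> bool:
    return a_dtype in HALF_TYPES and b_dtype in HALF_TYPES

assert not can_use_matmul_scheduler("Float", "Float")    # the repro's case
assert can_use_matmul_scheduler("BFloat16", "BFloat16")  # the supported case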

jacobhinkle (Collaborator) commented Jan 2, 2025

What happens here is that we pattern-match this to a matmul when the broadcasts are immediately before the mul + sum. Switching the order so that there is a cast between the broadcast and the mul means we just use the reduction scheduler, realizing the whole product instead; that's slow but accurate. In the first case, though, we fail to recognize the matmul as an SSS matmul rather than HSS, hence the error.

I think ideally we should actually accept this pattern and just remove this cast during translation to matmul.

jjsjann123 (Collaborator, Author)

> I think ideally we should actually accept this pattern and just remove this cast during translation to matmul.

That would be ideal.

In #3644, I'm moving casts closer to inputs, passing them through meta operations like broadcast/squeeze/set/view. That would break the matmul scheduler unless we accept the new pattern.
I can temporarily remove broadcast from that list so I'm not messing up the matmul stuff; we'll add it back once the scheduler is updated to handle the variations.
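
For illustration only, a toy sketch of that kind of pass (the helper and the tuple-based IR below are made up, not the actual #3644 implementation): a cast that consumes the output of a value-preserving meta op is swapped with it, which is exactly how broadcast-then-cast turns into the cast-then-broadcast pattern this issue hits.

# Toy IR as nested tuples, e.g. ("cast", ("broadcast", "T0")).
META_OPS = ("broadcast", "squeeze", "set", "view")

def push_cast_toward_input(expr):
    # Swap a cast with the meta op it consumes; meta ops do not change values,
    # so cast(meta(x)) == meta(cast(x)).
    if expr[0] == "cast" and isinstance(expr[1], tuple) and expr[1][0] in META_OPS:
        meta_op, inner = expr[1]
        return (meta_op, ("cast", inner))
    return expr

print(push_cast_toward_input(("cast", ("broadcast", "T0"))))
# -> ('broadcast', ('cast', 'T0')): the ordering the matmul scheduler mishandles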

jacobhinkle self-assigned this Jan 6, 2025