wrong result in matmul kernel #3665
In the generated code, one section looks like this:

`btype` is marked as `f16`, but `T5` is of dtype `float`.
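One plausible way a dtype mismatch like this turns into garbage rather than a small precision error: if a kernel treats a `float` buffer as `f16` operands, each 4-byte value is read as two unrelated 2-byte halves. The NumPy sketch below only illustrates that effect; `t5` is a hypothetical stand-in for `T5`, and this is an assumption about the failure mode, not nvFuser's actual codegen.

```python
import numpy as np

# Hypothetical stand-in for T5, which has dtype float (fp32) in the issue.
t5 = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)

# Assumed failure mode: the same bytes interpreted as f16 (the declared btype).
# Each 4-byte float becomes two 2-byte halves, so the element count doubles
# and the values no longer match the original data.
misread = t5.view(np.float16)

print("as float32:", t5)
print("read as f16:", misread)
```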
Interesting. We should be rejecting this in the matmul scheduler, since the input to the matmul pattern is neither fp16 nor bf16.
What happens here is that we pattern-match this to a matmul when the broadcasts come immediately before the mul+sum. If a cast sits between the broadcast and the mul instead, we fall back to the reduction scheduler, which realizes the whole product. That's slow but accurate. In the first case, however, we don't recognize the matmul as an SSS matmul rather than HSS, hence the error. Ideally, I think we should accept this pattern and simply remove the cast during translation to matmul.
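As a rough illustration of why accepting both orderings is safe, the NumPy sketch below shows that broadcast → cast → mul → sum over K produces the same values as cast → broadcast → mul → sum, and that both match a single matmul with fp32 accumulation (the HSS case: half inputs, single-precision compute and output). NumPy stands in for the fusion IR here; the shapes and variable names are illustrative assumptions, not nvFuser API.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 8)).astype(np.float16)  # [M, K] half operand
b = rng.standard_normal((8, 5)).astype(np.float16)  # [K, N] half operand

# Ordering 1: broadcast to [M, K, N], then cast to fp32, then mul + sum over K.
bcast_then_cast = (a[:, :, None].astype(np.float32)
                   * b[None, :, :].astype(np.float32)).sum(axis=1)

# Ordering 2: cast to fp32 first, then broadcast, then mul + sum over K.
a32, b32 = a.astype(np.float32), b.astype(np.float32)
cast_then_bcast = (a32[:, :, None] * b32[None, :, :]).sum(axis=1)

# Reference: a single matmul with fp16 inputs upcast to fp32 (HSS semantics).
reference = a32 @ b32

print(np.abs(bcast_then_cast - reference).max())
```

Because fp16 to fp32 is an exact widening, moving the cast across the broadcast does not change any value, so dropping it when translating to a matmul only has to preserve the fp32 accumulation.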
That would be ideal. In #3644, I'm moving casts closer to the inputs, past meta operations such as broadcast/squeeze/set/view. That would break the matmul scheduler unless we accept the new pattern.
The issue emerged here.

If we reorder the broadcast/cast before the multiply and sum, our codegen gives a completely wrong result. Here's the repro: