Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support TMA with 64-bit indexing #3599

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open

Support TMA with 64-bit indexing #3599

wants to merge 10 commits into from

Conversation

jacobhinkle
Copy link
Collaborator

@jacobhinkle jacobhinkle commented Dec 16, 2024

This PR removes the check that IndexType is Int32 when using TMA.

When TMA is enabled we have a coordinate array specifying the data pointer of the global memory array and the box coordinates. The global array might be larger enough to require an int64_t to linearly index its last element, in which case we would use int64_t indexing. However, since it is possibly a multidimensional array, it is still possible that the multidimensional coords of all TMA boxes can still be expressed with int32_t. These are the cases we want to cover for matmul since the input is 2D and we don't commonly expect each individual dimension to be larger than 2^31.

Currently I am just static casting the box dims to int32_t in memory.cu. We should additionally verify that none of these coords will be larger than the capacity of int32_t. We could handle that more broadly actually by writing a more sophisticated analysis that takes every indexing expression (including TMA box coords) in the Fusion and proves bounds for it. This is trickier than what we currently do which is just look at all TensorViews and determine the position of the furthest strided element within that tensor, but it might reveal that we use int64_t in many cases where it's unnecessary (such as matmuls).

Note that it's still possible to have a box coordinate that cannot be expressed with a 32-bit index, in which case we should not allow TMA.

TODO: we should check the above condition in the core heuristic where we have access to problem sizes.

Fixes #3595

@jacobhinkle
Copy link
Collaborator Author

!test

Copy link
Collaborator

@rdspring1 rdspring1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@jacobhinkle jacobhinkle marked this pull request as ready for review December 17, 2024 14:14
@jacobhinkle
Copy link
Collaborator Author

!build

@jacobhinkle
Copy link
Collaborator Author

!test

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support TMA with Int64 indexing
2 participants