Support TMA with 64-bit indexing #3599

Status: Open. Wants to merge 10 commits into base: main.
csrc/index_compute.cpp (8 additions, 0 deletions)

@@ -29,6 +29,7 @@
#include <transform_replay.h>

#include <memory>
#include "ir/builder.h"

namespace nvfuser {

@@ -2672,6 +2673,13 @@ std::pair<Val*, Val*> Index::getCpAsyncBulkGmemIndex(
const TensorIndexer& indexer = GpuLower::current()->tensorIndexer();
auto indices_inner_to_outer =
indexer.getIndexFor(ldst, !is_load, ids_to_index, loops);
// These are the coordinates of the TMA box, which must be of type
// int32_t. Possible overflow in each of these dims should be checked
// elsewhere.
for (size_t i : c10::irange(indices_inner_to_outer.size())) {
indices_inner_to_outer[i] = SimplifyingIrBuilder::maybeCastExpr(
DataType::Int32, indices_inner_to_outer[i]);
}

int64_t dim = (int64_t)tma_info.dims().size();
auto coordinate = IrBuilder::arrayExpr(indices_inner_to_outer);
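For intuition, here is a standalone sketch (a hypothetical helper, not nvFuser code) of the narrowing contract the cast above relies on: a 64-bit coordinate expression may only become a TMA box coordinate if its value fits in int32_t, with the overflow check living elsewhere.

#include <cstdint>
#include <limits>
#include <stdexcept>

// Hypothetical host-side equivalent of the int32 narrowing done by
// SimplifyingIrBuilder::maybeCastExpr plus the deferred overflow check.
int32_t toTmaCoordinate(int64_t index) {
  if (index < 0 || index > std::numeric_limits<int32_t>::max()) {
    throw std::out_of_range("TMA box coordinate does not fit in int32_t");
  }
  return static_cast<int32_t>(index);
}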
csrc/runtime/executor.cpp (20 additions, 10 deletions)

@@ -390,28 +390,38 @@ void KernelExecutor::compile(
!(compile_params.index_type.value() == PrimDataType::Int32 &&
arg_index_type == PrimDataType::Int),
"Compilation with int32 is requested but int64 is required for the arguments");
NVF_ERROR(
!has_cp_async_bulk ||
(compile_params.index_type.value() == PrimDataType::Int32),
"Compilation with int64 is requested but int32 is required because ",
"of TMA operations.");

} else if (arg_index_type == PrimDataType::Int) {
// If the given compile option doesn't specify the index type, and
// the arguments require 64-bit indexing, we need to use 64-bit
// indexing. Note that if the arg type is 32-bit, it doesn't mean
// it's safe to use 32-bit for the whole kernel, so unless it's
// specified through CompileParams, we do not use 32-bit indexing.
compile_params.index_type = arg_index_type;
NVF_ERROR(
!has_cp_async_bulk,
"Compilation with int64 is required based on input arguments, but ",
"int32 is required because of TMA operations.");
} else if (has_cp_async_bulk) {
// TMA operations require 32-bit indexing.
compile_params.index_type = PrimDataType::Int32;
}

if (has_cp_async_bulk && compile_params.index_type == PrimDataType::Int) {
// Check that each dimension of the inputs can be indexed with a 32-bit int.
// There might be some gmem tensors that are linearly indexed (i.e., not read
// or written using TMA). This analysis only checks the extents of fusion
// inputs, since it is assumed that if TMA is used for an output tensor it
// will share dims with some inputs.
for (const std::shared_ptr<PolymorphicValue>& arg : args) {
if (arg->is<at::Tensor>()) {
const at::Tensor& tensor = arg->as<at::Tensor>();
for (const size_t dim_i : c10::irange(tensor.ndimension())) {
int64_t size = tensor.size((int64_t)dim_i);
NVF_CHECK(
size < (1LL << 31),
"TMA enabled with 64-bit indexing. Expected all input dims to be < 2^31 but found ",
size);
}
}
}
}

c10::DeviceGuard dg(options_.device);

NVF_ERROR(
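As a rough standalone sketch of the guard added above (plain ATen, with a hypothetical helper name), every fusion-input extent must fit in a 32-bit TMA box coordinate even when the kernel itself compiles with 64-bit indexing:

#include <ATen/ATen.h>

// Hypothetical free-standing version of the per-dimension check: TMA box
// coordinates are int32_t, so each extent must stay below 2^31 regardless
// of the kernel's linear index type.
void checkDimsFitInt32(const at::Tensor& tensor) {
  for (int64_t i = 0; i < tensor.ndimension(); ++i) {
    TORCH_CHECK(
        tensor.size(i) < (1LL << 31),
        "TMA requires input dim extents < 2^31 but found ", tensor.size(i));
  }
}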
tests/cpp/test_memory.cpp (34 additions, 0 deletions)

@@ -648,6 +648,40 @@ TEST_F(TMAIndexingTest, Load2DTensorWith1DTMA) {
testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
}

// This test is the same as Load2DTensorWith1DTMA above, but checks that we
// can use 64-bit indexing.
TEST_F(TMAIndexingTest, Int64Indexing) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeContigTensor(2);
fusion.addInput(tv0);
auto tv1 = set(tv0);
auto tv2 = set(tv1);
fusion.addOutput(tv2);

tv1->setMemoryType(MemoryType::Shared);
tv1->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::CpAsyncBulkTensorTile);

for (auto tv : {tv1, tv2}) {
tv->merge(0);
tv->split(0, 32);
tv->axis(0)->parallelize(ParallelType::BIDx);
}
tv1->axis(1)->parallelize(ParallelType::Bulk);
tv2->axis(1)->parallelize(ParallelType::TIDx);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
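// Note on sizes (an assumption about the test's intent): each extent
// (30000 and 50000) is well below 2^31, so the per-dimension TMA guard
// passes, while the full tensor holds 1.5e9 floats (6 GB), so byte
// offsets exceed 2^31 and the 64-bit indexing path is exercised.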
auto t0 = at::randn({30000, 50000}, options);
CompileParams int64_cparams = matmul_cparams;
int64_cparams.index_type = DataType::Int;
KernelExecutor ke;
ke.compile(&fusion, {t0}, {}, int64_cparams);

auto cg_outputs = ke.run({t0});
testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
}

TEST_F(TMAIndexingTest, Load1DTensorWith2DTMA) {
Fusion fusion;
FusionGuard fg(&fusion);