diff --git a/csrc/index_compute.cpp b/csrc/index_compute.cpp
index 0bcfe11a733..25a03b6bb23 100644
--- a/csrc/index_compute.cpp
+++ b/csrc/index_compute.cpp
@@ -29,6 +29,7 @@
 #include
 #include
+#include "ir/builder.h"
 
 namespace nvfuser {
 
@@ -2672,6 +2673,13 @@ std::pair<Val*, Val*> Index::getCpAsyncBulkGmemIndex(
   const TensorIndexer& indexer = GpuLower::current()->tensorIndexer();
   auto indices_inner_to_outer =
       indexer.getIndexFor(ldst, !is_load, ids_to_index, loops);
+  // These are the box coordinates of the TMA box, which must be of type
+  // int32_t. Possible overflow in each of these dims should be checked
+  // elsewhere.
+  for (size_t i : c10::irange(indices_inner_to_outer.size())) {
+    indices_inner_to_outer[i] = SimplifyingIrBuilder::maybeCastExpr(
+        DataType::Int32, indices_inner_to_outer[i]);
+  }
 
   int64_t dim = (int64_t)tma_info.dims().size();
   auto coordinate = IrBuilder::arrayExpr(indices_inner_to_outer);
diff --git a/csrc/runtime/executor.cpp b/csrc/runtime/executor.cpp
index 04f86b1edd0..7e39865360f 100644
--- a/csrc/runtime/executor.cpp
+++ b/csrc/runtime/executor.cpp
@@ -390,12 +390,6 @@ void KernelExecutor::compile(
         !(compile_params.index_type.value() == PrimDataType::Int32 &&
           arg_index_type == PrimDataType::Int),
         "Compilation with int32 is requested but int64 is required for the arguments");
-    NVF_ERROR(
-        !has_cp_async_bulk ||
-            (compile_params.index_type.value() == PrimDataType::Int32),
-        "Compilation with int64 is requested but int32 is required because ",
-        "of TMA operations.");
-
   } else if (arg_index_type == PrimDataType::Int) {
     // If the given compile option doesn't specify the index type, and
     // the arguments require 64-bit indexing, we need to use 64-bit
@@ -403,15 +397,31 @@
     // it's safe to use 32-bit for the whole kernel, so unless it's
     // specified through CompileParams, we do not use 32-bit indexing.
     compile_params.index_type = arg_index_type;
-    NVF_ERROR(
-        !has_cp_async_bulk,
-        "Compilation with int64 is required based on input arguments, but ",
-        "int32 is required because of TMA operations.");
   } else if (has_cp_async_bulk) {
     // TMA operations require 32-bit indexing.
     compile_params.index_type = PrimDataType::Int32;
   }
 
+  if (has_cp_async_bulk && compile_params.index_type == PrimDataType::Int) {
+    // Check that each dimension of the inputs can be indexed with a 32-bit int.
+    // There might be some gmem tensors that are linearly indexed (i.e. not read
+    // or written using TMA). This analysis only checks the extents of fusion
+    // inputs since it is assumed that if TMA is used for an output tensor it
+    // will share dims with some inputs.
+    for (const std::shared_ptr<PolymorphicValue>& arg : args) {
+      if (arg->is<at::Tensor>()) {
+        const at::Tensor& tensor = arg->as<at::Tensor>();
+        for (const size_t dim_i : c10::irange(tensor.ndimension())) {
+          int64_t size = tensor.size((int64_t)dim_i);
+          NVF_CHECK(
+              size < (1LL << 31),
+              "TMA enabled with 64-bit indexing. Expected all input dims to be < 2^31 but found ",
+              size);
+        }
+      }
+    }
+  }
+
   c10::DeviceGuard dg(options_.device);
 
   NVF_ERROR(
diff --git a/tests/cpp/test_memory.cpp b/tests/cpp/test_memory.cpp
index 991fe732b72..c2f63d7e8e6 100644
--- a/tests/cpp/test_memory.cpp
+++ b/tests/cpp/test_memory.cpp
@@ -648,6 +648,40 @@ TEST_F(TMAIndexingTest, Load2DTensorWith1DTMA) {
   testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
 }
 
+// This test is the same as above but checks that we can use 64-bit indexing
+TEST_F(TMAIndexingTest, Int64Indexing) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigTensor(2);
+  fusion.addInput(tv0);
+  auto tv1 = set(tv0);
+  auto tv2 = set(tv1);
+  fusion.addOutput(tv2);
+
+  tv1->setMemoryType(MemoryType::Shared);
+  tv1->definition()->as<LoadStoreOp>()->setOpType(
+      LoadStoreOpType::CpAsyncBulkTensorTile);
+
+  for (auto tv : {tv1, tv2}) {
+    tv->merge(0);
+    tv->split(0, 32);
+    tv->axis(0)->parallelize(ParallelType::BIDx);
+  }
+  tv1->axis(1)->parallelize(ParallelType::Bulk);
+  tv2->axis(1)->parallelize(ParallelType::TIDx);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({30000, 50000}, options);
+  CompileParams int64_cparams = matmul_cparams;
+  int64_cparams.index_type = DataType::Int;
+  KernelExecutor ke;
+  ke.compile(&fusion, {t0}, {}, int64_cparams);
+
+  auto cg_outputs = ke.run({t0});
+  testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
+}
+
 TEST_F(TMAIndexingTest, Load1DTensorWith2DTMA) {
   Fusion fusion;
   FusionGuard fg(&fusion);