Naive register <-> tmem load/store support (#3786)
Extracted from #3755 to make code review easier.

This PR adds a new unit test `TMemTest.GmemRegTMemRegGmemCopy` that
schedules a copy kernel gmem -> register -> tmem -> register -> gmem,
and updates our system with the minimum required changes to make this
test pass.

The purpose of this PR is not to provide a good implementation of TMem
support, but only the absolute minimum required for us to start. Limitations are:
1. The index is hard-coded to zero, so this PR does not touch the
interesting topic of "how to schedule a TMem tensor?"
2. TMem is used without allocation. Using memory that is not allocated is
clearly the wrong way to program, but as described in the code comment, if a
fusion has only one TMem TensorView, it is guaranteed to work.

Generated code:
```CUDA
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 1, 1> T0, Tensor<float, 1, 1> T4) {
  nvfuser_index_t i0;
  i0 = ((nvfuser_index_t)threadIdx.x) + (32 * ((nvfuser_index_t)blockIdx.x));
  bool b1;
  b1 = i0 < T0.logical_size[0LL];
  Array<float, 1, 1> T1;
  T1[0] = 0;
  if (b1) {
    T1[0]
       = T0[((T0.alloc_stride[0LL] * ((nvfuser_index_t)threadIdx.x)) + ((32 * T0.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x)))];
  }
  asm volatile(
    "tcgen05.st.sync.aligned.32x32b.x1.b32 [%0], {%1};\n"
    :
    :"r"(0U),
     "f"((*reinterpret_cast<Array<float, 1, 1>*>(&T1[0]))[0])
  );
  asm volatile("tcgen05.wait::st.sync.aligned;\n");
  Array<float, 1, 1> T3;
  asm(
    "tcgen05.ld.sync.aligned.32x32b.x1.b32 {%0}, [%1];\n"
    :"=f"((*reinterpret_cast<Array<float, 1, 1>*>(&T3[0]))[0])
    :"r"(0U)
  );
  asm volatile("tcgen05.wait::ld.sync.aligned;\n");
  if (b1) {
    T4[i0]
       = T3[0];
  }
}
```
zasdfgbnm authored Jan 30, 2025
1 parent a3f6ba9 commit 149c163
Showing 14 changed files with 258 additions and 18 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -95,6 +95,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/device_lower/analysis/index_compute.cpp
${NVFUSER_SRCS_DIR}/device_lower/analysis/predicate_elimination.cpp
${NVFUSER_SRCS_DIR}/device_lower/analysis/sync_information.cpp
${NVFUSER_SRCS_DIR}/device_lower/analysis/tensor_memory.cpp
${NVFUSER_SRCS_DIR}/device_lower/analysis/thread_predicate.cpp
${NVFUSER_SRCS_DIR}/device_lower/analysis/tma.cpp
${NVFUSER_SRCS_DIR}/device_lower/analysis/trivial_broadcast.cpp
7 changes: 6 additions & 1 deletion csrc/codegen.cpp
@@ -683,6 +683,11 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
return;
}

if (ti->view()->getMemoryType() == MemoryType::Tensor) {
code_ << genInline(ti->index());
return;
}

if (ti->view()->getMemoryType() == MemoryType::Global &&
kernel_->summary().sync_map->needsRawSync(ti->view()).hasBID()) {
code_ << "*(volatile " << ti->getDataType().value() << "*)&";
@@ -3186,7 +3191,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
break;
}
case MemoryType::Tensor: {
NVF_THROW("Not implemented yet");
// Do nothing for now. This behavior will change soon.
break;
}
default:
26 changes: 26 additions & 0 deletions csrc/device_lower/analysis/tensor_memory.cpp
@@ -0,0 +1,26 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on

#include <device_lower/analysis/tensor_memory.h>
#include <fusion.h>
#include <ir/all_nodes.h>

namespace nvfuser {

TensorMemoryInfo computeTMemInfo(Fusion* fusion) {
bool found = false;
for (auto tv : fusion->allTvs()) {
if (tv->getMemoryType() == MemoryType::Tensor) {
NVF_ERROR(!found, "Only one tensor on TMem is supported");
found = true;
}
}
return {};
}

} // namespace nvfuser
65 changes: 65 additions & 0 deletions csrc/device_lower/analysis/tensor_memory.h
@@ -0,0 +1,65 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#pragma once

namespace nvfuser {

class Fusion;

// Information used to lower tensor memory. So far, there is no information
// needed; computeTMemInfo just checks that there is only one tensor on TMem
// in the fusion. This limitation is described in the note below; it exists
// only for incremental development and will be removed soon.
struct TensorMemoryInfo;
TensorMemoryInfo computeTMemInfo(Fusion* fusion);

// Note: [Tensor Memory Allocation]
//
// Tensor memory is a very special memory, so its allocation is also very
// different from other memory types.
//
// It is highly recommended to read the PTX documentation for tensor memory
// if you are not already familiar with it:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-memory
//
// The first thing to note is that TMem is not virtualized. This means we
// cannot just allocate starting from address 0 the way we allocate shared
// memory and rely on a page table to translate the same virtual address of
// different CTAs to different physical addresses. There is no virtual TMem
// address. All addresses are physical addresses.
//
// Because multiple CTAs can execute on the same SM simultaneously, there must
// be some handshaking mechanism for each CTA to know the region of TMem that it
// can use. This is done by using the PTX instruction tcgen05.alloc. To ensure
// safety, there is a mutex "I have the right to allocate TMem" in the
// hardware. At the beginning of each CTA, the CTA will try to acquire the mutex
// automatically. If it fails, the CTA will be blocked until the mutex is free.
// This means, only one CTA can allocate TMem at a time. Once the CTA has
// finished allocating TMem, it should release the mutex to relinquish the right
// to allocate. After the right to allocate is relinquished, this CTA can not
// allocate new TMem any more, but it can still access the TMem that it has
// allocated, and it can also free the TMem that it has allocated. Once one CTA
// relinquishes the right to allocate, the next CTA that is blocked will be
// unblocked and can acquire the mutex to allocate TMem.
//
// Currently, TMem allocation is not supported in nvFuser. We only allow one
// TensorView to be on TMem, and because we never relinquish the right to
// allocate TMem, CTAs will be serialized on the SM. A new CTA can be scheduled
// on an SM only after the previous CTA on that SM has completely finished
// executing. Thanks to this serialization, we can just skip allocating and
// pretend that our only TMem TensorView owns the entire TMem, because we are
// sure that no other CTA will be using that address. As a result, we can just
// provide address 0 to the instructions that access TMem. In principle, it is
// clearly wrong to write to an address that is not allocated, but because we
// are sure that it will work in practice for the specific unit test we are
// targeting, we do it anyway to enable incremental development.

struct TensorMemoryInfo {};

} // namespace nvfuser
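To make the handshake described in the note above concrete, a device-side sketch of the allocation sequence is shown below. The instruction names follow the tcgen05 section of the PTX ISA documentation; this PR intentionally emits none of this, and details such as the column count and warp-election requirements are assumptions glossed over here.

```CUDA
// Sketch only, not emitted by this PR: allocate 32 columns of TMem, give up
// the right to allocate so other CTAs can proceed, and free the allocation
// before the CTA exits. Warp election and finer synchronization are omitted.
__global__ void tmem_alloc_sketch() {
  // tcgen05.alloc writes the allocated TMem base address to shared memory.
  __shared__ uint32_t tmem_base;
  uint32_t smem_ptr =
      static_cast<uint32_t>(__cvta_generic_to_shared(&tmem_base));
  // Blocks until this CTA holds the allocation mutex, then allocates 32 columns.
  asm volatile(
      "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], 32;\n"
      :
      : "r"(smem_ptr)
      : "memory");
  // Relinquish the right to allocate so the next blocked CTA is unblocked.
  asm volatile("tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;\n");
  __syncthreads();
  // ... access TMem at tmem_base via tcgen05.st / tcgen05.ld ...
  // Free the allocated columns before exiting.
  asm volatile(
      "tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, 32;\n"
      :
      : "r"(tmem_base));
}
```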
3 changes: 3 additions & 0 deletions csrc/device_lower/lower2device.cpp
@@ -599,6 +599,9 @@ void GpuLower::analysis(Fusion* fusion) {

consumerToTMAInfo() = getConsumerToTMAInfoMap(fusion_);
dumpExprsIfEnabled(fusion_->exprs(), "getConsumerToTMAInfoMap");

tmemInfo() = computeTMemInfo(fusion_);
dumpExprsIfEnabled(fusion_->exprs(), "computeTMemInfo");
}

kir::Kernel* GpuLower::kernel() const {
12 changes: 12 additions & 0 deletions csrc/device_lower/lower2device.h
@@ -14,6 +14,7 @@
#include <device_lower/analysis/fused_reduction.h>
#include <device_lower/analysis/predicate_elimination.h>
#include <device_lower/analysis/sync_information.h>
#include <device_lower/analysis/tensor_memory.h>
#include <device_lower/analysis/thread_predicate.h>
#include <device_lower/analysis/tma.h>
#include <device_lower/analysis/trivial_broadcast.h>
@@ -268,6 +269,14 @@ class GpuLower : public NonCopyable {
return consumer_to_tma_info_;
}

const TensorMemoryInfo& tmemInfo() const {
return tmem_info_;
}

TensorMemoryInfo& tmemInfo() {
return tmem_info_;
}

// Register a boolean Val as a predicate to validate at the run time. Optional
// validation error messages can be given as args.
template <typename... Args>
@@ -365,6 +374,9 @@ class GpuLower : public NonCopyable {
// Keep track of the mbarrier used for each load/store operation
std::unordered_map<const Expr*, TensorView*> ldst_mbarrier_map_;

// Information about tensor memory usage
TensorMemoryInfo tmem_info_;

// Keep track of validations needed at runtime. For example, a pair of
//! "extent mod split_factor == 0" and an error message for divisibility check
//! for vectorization.
48 changes: 41 additions & 7 deletions csrc/device_lower/pass/index.cpp
@@ -2141,13 +2141,47 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
}

if (!ir_utils::isStMatrixOp(ldst)) {
in = lowerSrcIndex(
ldst->in(),
ldst->out(),
{},
ir_utils::isLdMatrixOp(ldst) || ir_utils::isCpAsyncOp(ldst));
out =
lowerDstIndex(ldst->out(), {}, ir_utils::isCpAsyncOp(ldst), as_type);
bool is_ldst_tmem = ldst->opType() == LoadStoreOpType::LdTMem ||
ldst->opType() == LoadStoreOpType::StTMem;
if (is_ldst_tmem) {
// TODO: support other types
NVF_ERROR(
dataTypeSize(ldst->in()->dtype()) == 4,
"For now, we only support 32-bit types in tmem");
NVF_ERROR(
dataTypeSize(ldst->out()->dtype()) == 4,
"For now, we only support 32-bit types in tmem");
// TODO: hard code size 1 for now.
// According to the specification of tcgen05.{ld,st}, the register
// operand must be viewed as a vector of 32-bit elements.
// See:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-memory-and-register-load-store-instructions
as_type = ArrayType{std::make_shared<DataType>(ldst->in()->dtype()), 1};
}
if (auto tv = dynamic_cast<TensorView*>(ldst->in());
tv != nullptr && tv->getMemoryType() == MemoryType::Tensor) {
// TODO: hard coded index zero for now.
auto index = IrBuilder::create<Val>(0, DataType::UInt32);
in = IrBuilder::create<kir::TensorIndex>(
tv, index, DataType::TMemAddress);
} else {
in = lowerSrcIndex(
ldst->in(),
ldst->out(),
{},
ir_utils::isLdMatrixOp(ldst) || ir_utils::isCpAsyncOp(ldst),
as_type);
}
if (auto tv = dynamic_cast<TensorView*>(ldst->out());
tv != nullptr && tv->getMemoryType() == MemoryType::Tensor) {
// TODO: hard coded index zero for now.
auto index = IrBuilder::create<Val>(0, DataType::UInt32);
out = IrBuilder::create<kir::TensorIndex>(
tv, index, DataType::TMemAddress);
} else {
out = lowerDstIndex(
ldst->out(), {}, ir_utils::isCpAsyncOp(ldst), as_type);
}
}
auto new_ldst =
IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in, ldst->cacheOp())
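To illustrate the register-operand constraint mentioned in the TODOs above: tcgen05.{ld,st} view the per-thread register operand as a vector of 32-bit elements, so a wider access would pair a larger `ArrayType` with a matching `.xN` PTX variant. A sketch of what a `.x2` load could look like in generated code (not produced by this PR, which hard-codes `.x1` and a vector size of 1):

```CUDA
// Sketch only: a .x2 TMem load reads two 32-bit registers per thread, so the
// destination is viewed as an array of two floats and the asm binds two "=f"
// outputs instead of one. The TMem address is still the hard-coded zero.
Array<float, 2, 2> T3;
asm(
    "tcgen05.ld.sync.aligned.32x32b.x2.b32 {%0, %1}, [%2];\n"
    :"=f"((*reinterpret_cast<Array<float, 2, 2>*>(&T3[0]))[0]),
     "=f"((*reinterpret_cast<Array<float, 2, 2>*>(&T3[0]))[1])
    :"r"(0U)
);
asm volatile("tcgen05.wait::ld.sync.aligned;\n");
```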
35 changes: 35 additions & 0 deletions csrc/device_lower/pass/inline_ptx.cpp
@@ -119,6 +119,41 @@ class LowerToInlinePtx : public kir::ExprMutator {
IrBuilder::create<Val>(vec_size),
invertedPredicate(ldst->predicate())},
kir::Asm::Options{/*volatile=*/true}));
} else if (ldst->opType() == LoadStoreOpType::LdTMem) {
// TODO: support other types of ld/st
auto ptx = "tcgen05.ld.sync.aligned.32x32b.x1.b32";
registerReplace(
ldst,
IrBuilder::create<kir::Asm>(
ptx,
std::vector<Val*>{ldst->out()},
std::vector<Val*>{ldst->in()}));
auto wait_ptx = "tcgen05.wait::ld.sync.aligned";
registerInsertAfter(
ldst,
IrBuilder::create<kir::Asm>(
wait_ptx,
std::vector<Val*>{},
std::vector<Val*>{},
kir::Asm::Options{/*volatile=*/true}));
} else if (ldst->opType() == LoadStoreOpType::StTMem) {
// TODO: support other types of ld/st
auto ptx = "tcgen05.st.sync.aligned.32x32b.x1.b32";
registerReplace(
ldst,
IrBuilder::create<kir::Asm>(
ptx,
std::vector<Val*>{},
std::vector<Val*>{ldst->out(), ldst->in()},
kir::Asm::Options{/*volatile=*/true}));
auto wait_ptx = "tcgen05.wait::st.sync.aligned";
registerInsertAfter(
ldst,
IrBuilder::create<kir::Asm>(
wait_ptx,
std::vector<Val*>{},
std::vector<Val*>{},
kir::Asm::Options{/*volatile=*/true}));
}
}

4 changes: 3 additions & 1 deletion csrc/kernel_ir.cpp
@@ -100,7 +100,9 @@ TensorIndex::TensorIndex(
isPointerType(index->dtype()) || index->dtype() == DataType::Index ||
isStructType(index->dtype()) ||
index->dtype() ==
DataType::UInt64 /*For matrix descriptor for hopper MMA*/,
DataType::UInt64 /*For matrix descriptor for hopper MMA*/
|| index->dtype() ==
DataType::UInt32 /*Temporarily enabled for TMem tensor*/,
"Cannot index with a value other than an int/pointer/struct.");
}

8 changes: 6 additions & 2 deletions csrc/type.cpp
@@ -243,11 +243,11 @@ static std::string data_type2string(DataType t) {
case DataType::UInt16:
return "uint16_t";
case DataType::UInt32:
case DataType::SMemAddress:
case DataType::TMemAddress:
return "uint32_t";
case DataType::UInt64:
return "uint64_t";
case DataType::SMemAddress:
return "unsigned";
case DataType::ComplexFloat:
return "std::complex<float>";
case DataType::ComplexDouble:
@@ -860,6 +860,10 @@ const char* load_store_type2string(LoadStoreOpType t) {
return "CpAsyncBulk";
case LoadStoreOpType::CpAsyncBulkTensorTile:
return "CpAsyncBulkTensorTile";
case LoadStoreOpType::LdTMem:
return "LdTMem";
case LoadStoreOpType::StTMem:
return "StTMem";
default:
NVF_THROW("Unexpected parallel type");
}
12 changes: 8 additions & 4 deletions csrc/type.h
@@ -88,6 +88,7 @@ enum class PrimDataType {
ComplexFloat,
// Pointers
SMemAddress,
TMemAddress,
// Null
Null
};
@@ -196,6 +197,7 @@ struct DataType {
static constexpr PrimDataType ComplexFloat = PrimDataType::ComplexFloat;
static constexpr PrimDataType ComplexDouble = PrimDataType::ComplexDouble;
static constexpr PrimDataType SMemAddress = PrimDataType::SMemAddress;
static constexpr PrimDataType TMemAddress = PrimDataType::TMemAddress;
static constexpr PrimDataType Null = PrimDataType::Null;
};

@@ -297,7 +299,7 @@ inline bool isUnsignedIntegralType(DataType dtype) {
// Returns if the datatype is a pointer type
inline bool isPointerType(DataType dtype) {
return std::holds_alternative<PointerType>(dtype.type) ||
dtype == DataType::SMemAddress;
dtype == DataType::SMemAddress || dtype == DataType::TMemAddress;
}

// Returns if the datatype is an integer or pointer type
@@ -801,7 +803,9 @@ enum class LoadStoreOpType {
CpAsync,
CpAsyncBulk,
CpAsyncBulkTensorTile,
StMatrix
StMatrix,
LdTMem,
StTMem
};

// Used to label what part of the circular buffered iterdomain
@@ -1055,11 +1059,11 @@ constexpr inline size_t primDataTypeSize(PrimDataType type) {
case DataType::UInt16:
return sizeof(uint16_t);
case DataType::UInt32:
case DataType::SMemAddress:
case DataType::TMemAddress:
return sizeof(uint32_t);
case DataType::UInt64:
return sizeof(uint64_t);
case DataType::SMemAddress:
return sizeof(unsigned);
default:
NVF_THROW("Size undefined for data type.");
}
6 changes: 3 additions & 3 deletions tests/cpp/test_loop_rotation.cpp
@@ -568,15 +568,15 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2>
const unsigned smem_offset = 0;
NVFUSER_DEFINE_MAGIC_ZERO;
float* T4 = reinterpret_cast<float*>(array + smem_offset + 0LL);
unsigned i0;
uint32_t i0;
i0 = toSmem(T4);
float* ptr1;
ptr1 = T0.data + (4LL * T0.alloc_stride[0LL]);
#pragma unroll 4
for(nvfuser_index_t i2 = 0LL; i2 < 4LL; ++i2) {
float* ptr3;
ptr3 = T0.data + (T0.alloc_stride[0LL] * i2);
unsigned i4;
uint32_t i4;
i4 = i0 + (12LL * i2);
bool b5;
b5 = (i2 + nvfuser_zero) < T0.logical_size[0LL];
@@ -608,7 +608,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2>
ptr8 = ptr1 + (T0.alloc_stride[0LL] * i7);
nvfuser_index_t i9;
i9 = 4LL + i7;
unsigned i10;
uint32_t i10;
i10 = i0 + (12LL * (i9 % 5LL));
nvfuser_index_t i11;
i11 = 1LL + (3LL * (i7 % 5LL));