Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm committed Jan 29, 2025
1 parent 2fe5579 commit e97be30
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 91 deletions.
4 changes: 1 addition & 3 deletions csrc/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3193,9 +3193,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
}
case MemoryType::Tensor: {
indent() << "TMemTensor " << genVariableName(tv) << "("
<< genInline(alloc->baseAddress()) << ", "
<< genInline(alloc->laneOffset()) << ", "
<< genInline(alloc->colOffset()) << ");\n";
<< genInline(alloc->address()) << ");\n";
break;
}
default:
Expand Down
4 changes: 1 addition & 3 deletions csrc/device_lower/pass/allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,9 +483,7 @@ class AllocationInserter : public kir::ExprMutator {
GpuLower::current()->tmemInfo().allocation_address;
auto address_ti = IrBuilder::create<kir::TensorIndex>(
allocation_address, allocation_address->fusion()->zeroVal());
alloc_expr->setBaseAddress(address_ti);
alloc_expr->setLaneOffset(allocation_address->fusion()->zeroVal());
alloc_expr->setColOffset(allocation_address->fusion()->zeroVal());
alloc_expr->setAddress(address_ti);
}

return alloc_expr;
Expand Down
2 changes: 0 additions & 2 deletions csrc/kernel_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,6 @@ Allocate::Allocate(
addAttribute(alias);
// Always initialize smem/tmem addresses to nullptr
addAttribute(nullptr);
addAttribute(nullptr);
addAttribute(nullptr);

for (auto s : shape) {
addAttribute(s);
Expand Down
75 changes: 7 additions & 68 deletions csrc/kernel_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,12 +309,9 @@ class Allocate final : public Expr {

//! Size of each dimension
std::vector<Val*> shape() const {
constexpr int64_t num_attributes_before_shape = 8;
std::vector<Val*> result;
result.reserve(attributes().size() - num_attributes_before_shape);
for (auto i = attributes().begin() + num_attributes_before_shape;
i != attributes().end();
++i) {
result.reserve(attributes().size() - 6);
for (auto i = attributes().begin() + 6; i != attributes().end(); ++i) {
result.emplace_back((*i)->as<Val>());
}
return result;
Expand Down Expand Up @@ -368,8 +365,9 @@ class Allocate final : public Expr {
// aligned address in bytes.
void setAddress(Val* addr) {
NVF_CHECK(
memoryType() == MemoryType::Shared,
"Allocation address may only be set for shared memory allocations. Memory type is ",
memoryType() == MemoryType::Shared ||
memoryType() == MemoryType::Tensor,
"Allocation address may only be set for shared/tensor memory allocations. Memory type is ",
memoryType());
NVF_CHECK(
address() == nullptr,
Expand All @@ -378,76 +376,17 @@ class Allocate final : public Expr {
attributes_[5] = addr;
}

// Records the base address of a tensor memory allocation in attribute slot 5.
// Both checks must pass before the slot is written:
//   1. the allocation's memory type is MemoryType::Tensor, and
//   2. no base address has been recorded yet (baseAddress() is still nullptr).
// NOTE(review): slot 5 is the same slot that setAddress() writes — confirm
// the two setters are never expected to be used on the same allocation.
void setBaseAddress(Val* addr) {
NVF_CHECK(
memoryType() == MemoryType::Tensor,
"Allocation base address may only be set for tensor memory allocations. Memory type is ",
memoryType());
NVF_CHECK(
baseAddress() == nullptr,
"Attempted to set base address twice for allocation ",
toString());
attributes_[5] = addr;
}

// Records the lane offset of a tensor memory allocation in attribute slot 6.
// Both checks must pass before the slot is written:
//   1. the allocation's memory type is MemoryType::Tensor, and
//   2. no lane offset has been recorded yet (laneOffset() is still nullptr).
void setLaneOffset(Val* lane_offset) {
NVF_CHECK(
memoryType() == MemoryType::Tensor,
"Lane offset may only be set for tensor memory allocations. Memory type is ",
memoryType());
NVF_CHECK(
laneOffset() == nullptr,
"Attempted to set lane offset twice for allocation ",
toString());
attributes_[6] = lane_offset;
}

// Records the column offset of a tensor memory allocation in attribute
// slot 7. Both checks must pass before the slot is written:
//   1. the allocation's memory type is MemoryType::Tensor, and
//   2. no column offset has been recorded yet (colOffset() is still nullptr).
void setColOffset(Val* col_offset) {
NVF_CHECK(
memoryType() == MemoryType::Tensor,
"Column offset may only be set for tensor memory allocations. Memory type is ",
memoryType());
NVF_CHECK(
colOffset() == nullptr,
"Attempted to set column offset twice for allocation ",
toString());
attributes_[7] = col_offset;
}

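// This is a Val describing the address of the allocation. For a shared memory
// allocation, it is an integer scalar giving the byte address within the
// dynamic shared memory array. For a tensor memory allocation, it is the
// address value that was passed to setAddress. The check below rejects every
// other memory type. Before an address has been set, this function might
// return nullptr.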
Val* address() const {
NVF_CHECK(
memoryType() == MemoryType::Shared,
memoryType() == MemoryType::Shared ||
memoryType() == MemoryType::Tensor,
"Allocation address may only be set for shared memory allocations. Memory type is ",
memoryType());
return attributeVal(5);
}

// Returns the value stored in attribute slot 5 for a tensor memory
// allocation, after checking that the memory type is MemoryType::Tensor.
// setBaseAddress() treats a nullptr result as "not set yet".
// NOTE(review): the failure message says "may only be set" although this is
// the getter.
Val* baseAddress() const {
NVF_CHECK(
memoryType() == MemoryType::Tensor,
"Base address may only be set for tensor memory allocations. Memory type is ",
memoryType());
return attributeVal(5);
}

// Returns the value stored in attribute slot 6 for a tensor memory
// allocation, after checking that the memory type is MemoryType::Tensor.
// setLaneOffset() treats a nullptr result as "not set yet".
// NOTE(review): the failure message says "may only be set" although this is
// the getter.
Val* laneOffset() const {
NVF_CHECK(
memoryType() == MemoryType::Tensor,
"Lane offset may only be set for tensor memory allocations. Memory type is ",
memoryType());
return attributeVal(6);
}

// Returns the value stored in attribute slot 7 for a tensor memory
// allocation, after checking that the memory type is MemoryType::Tensor.
// setColOffset() treats a nullptr result as "not set yet".
// NOTE(review): the failure message says "may only be set" although this is
// the getter.
Val* colOffset() const {
NVF_CHECK(
memoryType() == MemoryType::Tensor,
"Column offset may only be set for tensor memory allocations. Memory type is ",
memoryType());
return attributeVal(7);
}
};

// Allocate tensor memory tcgen05.alloc
Expand Down
17 changes: 2 additions & 15 deletions runtime/tensor_memory.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@
// manipulate tensor memory addresses. Example usage:
// TMemTensor T0(0x12345678):
// -> address (lane=0x1234, col=0x5678)
// TMemTensor T1(0x12345678, 32, 32):
// -> address (lane=0x1234+32, col=0x5678+32)
// TMemTensor T2 = T1 + {64, 64}:
// -> address (lane=T1.lane+64, col=T1.col+64)
// TMemTensor T1 = T0 + {64, 64}:
// -> address (lane=T0.lane+64, col=T0.col+64)
struct TMemTensor {
uint32_t raw_address;

Expand All @@ -31,24 +29,13 @@ struct TMemTensor {

TMemTensor(uint32_t raw_address) : raw_address(raw_address) {}

TMemTensor(uint32_t base_address, uint16_t lane_offset, uint16_t col_offset)
: raw_address(add(base_address, lane_offset, col_offset)) {}

operator uint32_t() const {
return raw_address;
}

uint32_t operator+(Array<uint16_t, 2> offset) const {
return add(raw_address, offset[0], offset[1]);
}

uint16_t lane() const {
return raw_address >> 16;
}

uint16_t col() const {
return raw_address & 0xFFFF;
}
};

static_assert(
Expand Down

0 comments on commit e97be30

Please sign in to comment.