Move some code around.
liuliu committed Aug 15, 2024
1 parent 6c65201 commit 5728976
Showing 1 changed file with 84 additions and 78 deletions.
162 changes: 84 additions & 78 deletions lib/nnc/mfa/v2/GEMMKernel.cpp
@@ -229,6 +229,8 @@ GEMMKernel::GEMMKernel(GEMMKernelDescriptor descriptor, MTL::Device *const devic
}
}

#pragma mark - Source

std::string GEMMKernel::createSource() const noexcept {
  CodeWriter source;

@@ -496,6 +498,8 @@ for (ushort m = 0; m < {{REGISTER_M}}; m += 8) {
*source += createMultiply;
}

#pragma mark - Caching

void GEMMKernel::createInitializeC(CodeWriter *source) const noexcept {
source->SetValue("REGISTER_M_8_REGISTER_N_8", std::to_string((registerM / 8) * (registerN / 8)));
*source += R"(
@@ -733,6 +737,86 @@ if ({{DIRECT_ACCESS_CONDITION}}) {
)";
}

void GEMMKernel::createStoreC(CodeWriter *source) const noexcept {
  if (memoryPrecisions.C == GEMMOperandPrecision::BF16 && registerPrecisions.C == GEMMOperandPrecision::FP32) {
    source->SetValue("STORE_FUNCTION_C", "store_bfloat");
  } else {
    source->SetValue("STORE_FUNCTION_C", "store");
  }

  *source += R"(
if ({{DIRECT_ACCESS_CONDITION}}) {
// Fast path for matrices that qualify.
uint2 C_offset(N_offset + offset_in_group.x,
M_offset + offset_in_group.y);
auto C_dst = simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset(
C, {{LEADING_DIMENSION_C}}, C_offset);
// Write the accumulator to device memory.
#pragma clang loop unroll(full)
for (ushort m = 0; m < {{REGISTER_M}}; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < {{REGISTER_N}}; n += 8) {
ushort2 origin(n, m);
auto C = get_sram(C_sram, {{REGISTER_N}}, origin);
C->{{STORE_FUNCTION_C}}(C_dst, {{LEADING_DIMENSION_C}}, origin);
}
}
} else {
// Slow path for when memory must be handled more carefully.
auto C_block = (threadgroup {{MEMORY_NAME_C}}*)(threadgroup_block);
auto C_block_dst =
simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset(
C_block, {{LEADING_BLOCK_DIMENSIONS_C}}, offset_in_group);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write the accumulator to threadgroup memory.
#pragma clang loop unroll(full)
for (ushort m = 0; m < {{REGISTER_M}}; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < {{REGISTER_N}}; n += 8) {
ushort2 origin(n, m);
auto C = get_sram(C_sram, {{REGISTER_N}}, origin);
C->{{STORE_FUNCTION_C}}(
C_block_dst, {{LEADING_BLOCK_DIMENSIONS_C}}, origin);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Launch the async copy from threadgroup to device memory.
if (sidx == 0) {
uint2 C_offset(gid.x * N_group, gid.y * M_group);
ushort2 C_tile(min(uint(N_group), N - C_offset.x),
min(uint(M_group), M - C_offset.y));
auto C_dst = simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset(
C, {{LEADING_DIMENSION_C}}, C_offset);
// If we shift successfully, the garbage zone moves from the bottom right
// to the top left.
if ((M_shift != 0) || (N_shift != 0)) {
ushort2 C_block_shift(0, 0);
if ((M_shift != 0) && (C_offset.y >= M_edge)) {
C_block_shift.y = M_shift;
}
if ((N_shift != 0) && (C_offset.x >= N_edge)) {
C_block_shift.x = N_shift;
}
C_block = simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset(
C_block, {{LEADING_BLOCK_DIMENSIONS_C}}, C_block_shift);
}
simdgroup_event event;
event.async_copy(
C_dst, {{LEADING_DIMENSION_C}}, C_tile,
C_block, {{LEADING_BLOCK_DIMENSIONS_C}}, C_tile);
}
}
)";
}

#pragma mark - Multiply

void GEMMKernel::createMultiplyIterations(CodeWriter *source) const noexcept {
  if (preferAsyncLoad) {
    source->SetValue("ASYNC_ITERATIONS_START", "0");
@@ -840,81 +924,3 @@ for (uint k = {{ASYNC_ITERATIONS_START}}; k < K; k += K_group) {
)";
}

void GEMMKernel::createStoreC(CodeWriter *source) const noexcept {
  if (memoryPrecisions.C == GEMMOperandPrecision::BF16 && registerPrecisions.C == GEMMOperandPrecision::FP32) {
    source->SetValue("STORE_FUNCTION_C", "store_bfloat");
  } else {
    source->SetValue("STORE_FUNCTION_C", "store");
  }

  *source += R"(
if ({{DIRECT_ACCESS_CONDITION}}) {
// Fast path for matrices that qualify.
uint2 C_offset(N_offset + offset_in_group.x,
M_offset + offset_in_group.y);
auto C_dst = simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset(
C, {{LEADING_DIMENSION_C}}, C_offset);
// Write the accumulator to device memory.
#pragma clang loop unroll(full)
for (ushort m = 0; m < {{REGISTER_M}}; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < {{REGISTER_N}}; n += 8) {
ushort2 origin(n, m);
auto C = get_sram(C_sram, {{REGISTER_N}}, origin);
C->{{STORE_FUNCTION_C}}(C_dst, {{LEADING_DIMENSION_C}}, origin);
}
}
} else {
// Slow path for when memory must be handled more carefully.
auto C_block = (threadgroup {{MEMORY_NAME_C}}*)(threadgroup_block);
auto C_block_dst =
simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset(
C_block, {{LEADING_BLOCK_DIMENSIONS_C}}, offset_in_group);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write the accumulator to threadgroup memory.
#pragma clang loop unroll(full)
for (ushort m = 0; m < {{REGISTER_M}}; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < {{REGISTER_N}}; n += 8) {
ushort2 origin(n, m);
auto C = get_sram(C_sram, {{REGISTER_N}}, origin);
C->{{STORE_FUNCTION_C}}(
C_block_dst, {{LEADING_BLOCK_DIMENSIONS_C}}, origin);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Launch the async copy from threadgroup to device memory.
if (sidx == 0) {
uint2 C_offset(gid.x * N_group, gid.y * M_group);
ushort2 C_tile(min(uint(N_group), N - C_offset.x),
min(uint(M_group), M - C_offset.y));
auto C_dst = simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset(
C, {{LEADING_DIMENSION_C}}, C_offset);
// If we shift successfully, the garbage zone moves from the bottom right
// to the top left.
if ((M_shift != 0) || (N_shift != 0)) {
ushort2 C_block_shift(0, 0);
if ((M_shift != 0) && (C_offset.y >= M_edge)) {
C_block_shift.y = M_shift;
}
if ((N_shift != 0) && (C_offset.x >= N_edge)) {
C_block_shift.x = N_shift;
}
C_block = simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset(
C_block, {{LEADING_BLOCK_DIMENSIONS_C}}, C_block_shift);
}
simdgroup_event event;
event.async_copy(
C_dst, {{LEADING_DIMENSION_C}}, C_tile,
C_block, {{LEADING_BLOCK_DIMENSIONS_C}}, C_tile);
}
}
)";
}

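The moved functions all follow the same code-generation pattern: plain C++ decides a handful of string parameters (the C store function, register tile sizes, leading dimensions), and the CodeWriter expands `{{PLACEHOLDER}}` tokens inside a raw-string Metal template. Below is a minimal sketch of that substitution step; `MiniCodeWriter` is a hypothetical stand-in written for illustration, not the repository's actual CodeWriter class, and the placeholder names are borrowed from the diff above.

```cpp
#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the repository's CodeWriter: SetValue() records a
// key, operator+= appends a template fragment with every {{KEY}} expanded.
class MiniCodeWriter {
public:
  void SetValue(const std::string &key, const std::string &value) {
    values_[key] = value;
  }
  MiniCodeWriter &operator+=(const std::string &fragment) {
    std::string expanded = fragment;
    for (const auto &kv : values_) {
      const std::string token = "{{" + kv.first + "}}";
      for (size_t pos = expanded.find(token); pos != std::string::npos;
           pos = expanded.find(token, pos + kv.second.size())) {
        expanded.replace(pos, token.size(), kv.second);
      }
    }
    out_ += expanded;
    return *this;
  }
  const std::string &str() const { return out_; }

private:
  std::unordered_map<std::string, std::string> values_;
  std::string out_;
};

int main() {
  MiniCodeWriter source;
  // Mirrors the branch at the top of createStoreC: BF16 in memory with FP32
  // accumulator registers selects the down-converting store.
  const bool bf16MemoryFp32Registers = true;  // stand-in for the precision check
  source.SetValue("STORE_FUNCTION_C",
                  bf16MemoryFp32Registers ? "store_bfloat" : "store");
  source.SetValue("REGISTER_N", "64");
  source += R"(
C->{{STORE_FUNCTION_C}}(C_dst, {{LEADING_DIMENSION_C}}, origin);
)";
  // Keys never set (e.g. {{LEADING_DIMENSION_C}}) are left untouched by this
  // sketch; the real writer resolves them from the kernel descriptor.
  std::cout << source.str();
  return 0;
}
```

With the flag set, the emitted line becomes `C->store_bfloat(C_dst, {{LEADING_DIMENSION_C}}, origin);`, which is the same expansion the moved createStoreC performs before the remaining placeholders are filled in.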