Skip to content

Commit

Permalink
Remove CoreCount-based logic for selecting block size.
Browse files Browse the repository at this point in the history
Enable correctness checks for BF16.

Add loadPreviousC into GEMMDescriptor.
  • Loading branch information
liuliu committed Aug 15, 2024
1 parent 0ab02c0 commit 6c65201
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 38 deletions.
29 changes: 28 additions & 1 deletion bin/nnc/laplacian_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ std::pair<int, int> profileProblemSize(GEMMDescriptor descriptor)
void* t = A_storage;
A_storage = A;
A = (float*)t;
} else if (descriptor.memoryPrecisions.A == GEMMOperandPrecision::BF16) {
A_storage = (uint16_t*)ccmalloc(sizeof(uint16_t) * problemSize * problemSize);
for (int i = 0; i < problemSize * problemSize; i++)
((uint16_t*)A_storage)[i] = ((uint16_t*)A)[i * 2 + 1];
void* t = A_storage;
A_storage = A;
A = (float*)t;
}
void* B_storage = nullptr;
if (descriptor.memoryPrecisions.B == GEMMOperandPrecision::FP16)
Expand All @@ -76,6 +83,13 @@ std::pair<int, int> profileProblemSize(GEMMDescriptor descriptor)
void* t = B_storage;
B_storage = B;
B = (float*)t;
} else if (descriptor.memoryPrecisions.B == GEMMOperandPrecision::BF16) {
B_storage = (uint16_t*)ccmalloc(sizeof(uint16_t) * problemSize * problemSize);
for (int i = 0; i < problemSize * problemSize; i++)
((uint16_t*)B_storage)[i] = ((uint16_t*)B)[i * 2 + 1];
void* t = B_storage;
B_storage = B;
B = (float*)t;
}
void* bias_storage = nullptr;
if (descriptor.memoryPrecisions.bias == GEMMOperandPrecision::FP16)
Expand All @@ -85,6 +99,13 @@ std::pair<int, int> profileProblemSize(GEMMDescriptor descriptor)
void* t = bias_storage;
bias_storage = bias;
bias = (float*)t;
} else if (descriptor.memoryPrecisions.bias == GEMMOperandPrecision::BF16) {
bias_storage = (uint16_t*)ccmalloc(sizeof(uint16_t) * problemSize);
for (int i = 0; i < problemSize; i++)
((uint16_t*)bias_storage)[i] = ((uint16_t*)bias)[i * 2 + 1];
void* t = bias_storage;
bias_storage = bias;
bias = (float*)t;
}

// Since the Laplacian is symmetric, we swap roles of the matrices to test
Expand Down Expand Up @@ -258,6 +279,12 @@ std::pair<int, int> profileProblemSize(GEMMDescriptor descriptor)
ccv_half_precision_to_float(&value, &entry32, 1);
break;
}
case GEMMOperandPrecision::BF16: {
uint16_t value[2];
value[0] = 0;
value[1] = ((uint16_t*)raw)[address];
entry32 = *(float*)value;
}
}
C[address] = entry32;
}
Expand Down Expand Up @@ -483,7 +510,7 @@ int main(int argc, char** argv)
for (int j = 0; j < sizeof(transposeStates) / (sizeof(bool) * 2); j++)
{
TestDescriptor testDescriptor = TestDescriptor();
testDescriptor.precision = GEMMOperandPrecision::FP16;
testDescriptor.precision = GEMMOperandPrecision::BF16;
testDescriptor.problemSize = problemSize;
testDescriptor.transposeState[0] = transposeStates[j * 2];
testDescriptor.transposeState[1] = transposeStates[j * 2 + 1];
Expand Down
5 changes: 3 additions & 2 deletions lib/nnc/mfa/v2/GEMMDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ std::size_t std::hash<GEMMDescriptor>::operator()(const GEMMDescriptor& hash) co
combine_32(seed, hash.leadingDimensions.value()[2]);
}
combine_64(seed, pack_64(simd::ushort4 { hash.memoryPrecisions.A.value, hash.memoryPrecisions.B.value, hash.memoryPrecisions.C.value, hash.memoryPrecisions.bias.value }));
combine_32(seed, pack_32(simd::uchar4 { hash.transposeState[0], hash.transposeState[1], hash.transposeState[2], hash.useBias }));
combine_32(seed, pack_32(simd::uchar4 { hash.transposeState[0], hash.transposeState[1], hash.transposeState[2], 0 }));
combine_32(seed, pack_32(simd::uchar4 { hash.loadPreviousC, hash.useBias, 0, 0 }));
if (hash.registerPrecisionC.has_value()) {
combine_32(seed, pack_32(simd::ushort2 { hash.registerPrecisionC.value().value, 0 }));
}
Expand Down Expand Up @@ -103,7 +104,7 @@ std::pair<GEMMKernelDescriptor, PipelineValue<GEMMKernel> *> GEMMDescriptor::fin
constants->setConstantValue(&leadingDimensionB, MTL::DataTypeUInt, 6);
constants->setConstantValue(&leadingDimensionC, MTL::DataTypeUInt, 7);

bool loadPreviousC = false;
bool loadPreviousC = this->loadPreviousC;
constants->setConstantValue(&loadPreviousC, MTL::DataTypeBool, 10);

NS::String* swiftName = NS::String::string("gemm", NS::UTF8StringEncoding);
Expand Down
2 changes: 2 additions & 0 deletions lib/nnc/mfa/v2/GEMMDescriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ struct GEMMDescriptor {

simd::uchar3 transposeState;

bool loadPreviousC;

bool useBias;

bool operator==(const GEMMDescriptor& rhs) const;
Expand Down
61 changes: 26 additions & 35 deletions lib/nnc/mfa/v2/GEMMKernelDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,45 +72,36 @@ std::pair<simd::ushort3, std::optional<simd::ushort3>> GEMMKernelDescriptor::get

// Branch on whether the allocation is large / target occupancy is low.
if (useLargeAllocation) {
auto idealGroups = coreCount * 6;
if (actualGroups <= idealGroups) {
return std::make_pair(simd::ushort3 { 32, 32, 32 }, std::nullopt);
} else {
auto blockDimensions = simd::ushort3 { 48, 48, 24 };
// Remove CoreCount based block size logic, per https://github.com/philipturner/ccv/commit/e8b0682b4344410eb43cdafb9a9c721ba7fdb726
auto blockDimensions = simd::ushort3 { 48, 48, 24 };

// This is verified to be optimal for:
// - (memA, memB, memC) = (FP32, FP32, FP32)
// - (memA, memB, memC) = (FP16, FP16, FP32)
// - (memA, memB, memC) = (FP16, FP32, FP32)
// - (memA, memB, memC) = (FP16, FP32, FP16)
if (!transposeState[0] && !transposeState[1]) {
return std::make_pair(blockDimensions, simd::ushort3 { 24, 48, 48 });
} else if (!transposeState[0] && transposeState[1]) {
if (memoryPrecisions.B == GEMMOperandPrecision::FP32) {
return std::make_pair(blockDimensions, simd::ushort3 { 24, 28, 48 });
} else {
return std::make_pair(blockDimensions, simd::ushort3 { 24, 24, 48 });
}
} else if (transposeState[0] && !transposeState[1]) {
if (memoryPrecisions.A == GEMMOperandPrecision::FP32) {
return std::make_pair(blockDimensions, simd::ushort3 { 52, 48, 48 });
} else {
return std::make_pair(blockDimensions, simd::ushort3 { 56, 48, 48 });
}
// This is verified to be optimal for:
// - (memA, memB, memC) = (FP32, FP32, FP32)
// - (memA, memB, memC) = (FP16, FP16, FP32)
// - (memA, memB, memC) = (FP16, FP32, FP32)
// - (memA, memB, memC) = (FP16, FP32, FP16)
if (!transposeState[0] && !transposeState[1]) {
return std::make_pair(blockDimensions, simd::ushort3 { 24, 48, 48 });
} else if (!transposeState[0] && transposeState[1]) {
if (memoryPrecisions.B == GEMMOperandPrecision::FP32) {
return std::make_pair(blockDimensions, simd::ushort3 { 24, 28, 48 });
} else {
if (memoryPrecisions.A == GEMMOperandPrecision::FP32) {
return std::make_pair(blockDimensions, simd::ushort3 { 52, 24, 48 });
} else {
return std::make_pair(blockDimensions, simd::ushort3 { 56, 24, 48 });
}
return std::make_pair(blockDimensions, simd::ushort3 { 24, 24, 48 });
}
} else if (transposeState[0] && !transposeState[1]) {
if (memoryPrecisions.A == GEMMOperandPrecision::FP32) {
return std::make_pair(blockDimensions, simd::ushort3 { 52, 48, 48 });
} else {
return std::make_pair(blockDimensions, simd::ushort3 { 56, 48, 48 });
}
}
} else {
auto idealGroups = coreCount * 9;
if (actualGroups <= idealGroups) {
return std::make_pair(simd::ushort3 { 32, 32, 32 }, std::nullopt);
} else {
return std::make_pair(simd::ushort3 { 48, 48, 32 }, std::nullopt);
if (memoryPrecisions.A == GEMMOperandPrecision::FP32) {
return std::make_pair(blockDimensions, simd::ushort3 { 52, 24, 48 });
} else {
return std::make_pair(blockDimensions, simd::ushort3 { 56, 24, 48 });
}
}
} else {
return std::make_pair(simd::ushort3 { 48, 48, 32 }, std::nullopt);
}
}

0 comments on commit 6c65201

Please sign in to comment.