diff --git a/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m b/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m
index 48d02d2f7..c65bda63c 100644
--- a/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m
+++ b/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m
@@ -175,10 +175,9 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 	ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();
 	const int is_mfa_gemv = !is_batched && ((a_rows == 1 && is_transpose_w && (w_rows % 4) == 0) || (!is_transpose_a && w_cols == 1 && (a_cols % 4) == 0));
-	// v1 only supports the same precision of accumulator as the tensor.
-	int is_different_accumulator_precision = ((cmd.info.blas.flags & CCV_NNC_GEMM_32F) && a_datatype == CCV_16F) || ((cmd.info.blas.flags & CCV_NNC_GEMM_16F) && a_datatype == CCV_32F);
+	int is_upcast = ((cmd.info.blas.flags & CCV_NNC_GEMM_32F) && a_datatype == CCV_16F);
 	const int is_mfa_supported =
-		ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && (is_mfa_gemv || (!(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM) && !is_different_accumulator_precision));
+		ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && (is_mfa_gemv || !(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM));
 	size_t a_data_size = 0;
 	if (CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX)
@@ -364,11 +363,9 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 		.A_trans = (is_transpose_a ? 1 : 0),
 		.B_trans = (is_transpose_w ? 1 : 0),
 		.D_trans = 0,
-		.alpha = (float)1.0,
-		.beta = (float)0.0,
 		.batched = is_batched,
-		.fused_activation_function = 0,
 		.fused_bias = (bias ? 1 : 0),
+		.register_float = (is_upcast ? 1 : 0),
 
 		.batch_dims_a = { 0 },
 		.batch_dims_b = { 0 },
@@ -795,10 +792,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 			.A_trans = 1,
 			.B_trans = (is_transpose_w ? 1 : 0),
 			.D_trans = 0,
-			.alpha = (float)1.0,
-			.beta = (float)0.0,
 			.batched = is_batched,
-			.fused_activation_function = 0,
 			.fused_bias = 0,
 
 			.batch_dims_a = { 0 },
@@ -834,10 +828,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 			.A_trans = 0,
 			.B_trans = (is_transpose_w ? 0 : 1),
 			.D_trans = 0,
-			.alpha = (float)1.0,
-			.beta = (float)0.0,
 			.batched = is_batched,
-			.fused_activation_function = 0,
 			.fused_bias = 0,
 
 			.batch_dims_a = { 0 },
@@ -881,10 +872,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 			.A_trans = 1,
 			.B_trans = (is_transpose_a ? 1 : 0),
 			.D_trans = 0,
-			.alpha = (float)1.0,
-			.beta = (float)0.0,
 			.batched = is_batched,
-			.fused_activation_function = 0,
 			.fused_bias = 0,
 
 			.batch_dims_a = { 0 },
@@ -920,10 +908,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 			.A_trans = (is_transpose_a ? 0 : 1),
 			.B_trans = 0,
 			.D_trans = 0,
-			.alpha = (float)1.0,
-			.beta = (float)0.0,
 			.batched = is_batched,
-			.fused_activation_function = 0,
 			.fused_bias = 0,
 
 			.batch_dims_a = { 0 },
diff --git a/lib/nnc/cmd/convolution/mps/ccv_nnc_conv_mps.m b/lib/nnc/cmd/convolution/mps/ccv_nnc_conv_mps.m
index 114667feb..94bacfd3b 100644
--- a/lib/nnc/cmd/convolution/mps/ccv_nnc_conv_mps.m
+++ b/lib/nnc/cmd/convolution/mps/ccv_nnc_conv_mps.m
@@ -256,10 +256,7 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 				.A_trans = 0,
 				.B_trans = 1,
 				.D_trans = 0,
-				.alpha = (float)1.0,
-				.beta = (float)0.0,
 				.batched = is_batched,
-				.fused_activation_function = 0,
 				.fused_bias = (bias ? 1 : 0),
 
 				.batch_dims_a = { 0 },
@@ -275,10 +272,7 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 				.A_trans = 0,
 				.B_trans = 0,
 				.D_trans = 1,
-				.alpha = (float)1.0,
-				.beta = (float)0.0,
 				.batched = is_batched,
-				.fused_activation_function = 0,
 				.fused_bias = (bias ? 1 : 0),
 
 				.batch_dims_a = { 0 },
diff --git a/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m b/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m
index d49ef0b62..f7436938c 100644
--- a/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m
+++ b/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m
@@ -316,10 +316,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
 			.A_trans = false,
 			.B_trans = true,
 			.D_trans = false,
-			.alpha = (float)1.0,
-			.beta = (float)0.0,
 			.batched = 0,
-			.fused_activation_function = 0,
 			.fused_bias = (bias ? 1 : 0),
 
 			.batch_dims_a = { 0 },
diff --git a/lib/nnc/mfa/ccv_nnc_mfa.cpp b/lib/nnc/mfa/ccv_nnc_mfa.cpp
index b50937e25..5912ef654 100644
--- a/lib/nnc/mfa/ccv_nnc_mfa.cpp
+++ b/lib/nnc/mfa/ccv_nnc_mfa.cpp
@@ -90,12 +90,6 @@ void mfa::cache<mfa::attention::hash, mfa::attention::pipeline>::prepare(mfa::co
   _mfa_cache_prepare(&map, context, hash);
 }
 
-template <>
-void mfa::cache<mfa::gemm::hash, mfa::gemm::pipeline>::prepare(mfa::context* context, mfa::gemm::hash hash)
-{
-  _mfa_cache_prepare(&map, context, hash);
-}
-
 template <>
 void mfa::cache<mfa::normalization::hash, mfa::normalization::pipeline>::prepare(mfa::context* context, mfa::normalization::hash hash)
 {
diff --git a/lib/nnc/mfa/ccv_nnc_mfa.hpp b/lib/nnc/mfa/ccv_nnc_mfa.hpp
index e1d12fef6..1ba6f733a 100644
--- a/lib/nnc/mfa/ccv_nnc_mfa.hpp
+++ b/lib/nnc/mfa/ccv_nnc_mfa.hpp
@@ -4,11 +4,11 @@
 #include "nnc/ccv_nnc.h"
 #include "ccv_nnc_mfa_defines.hpp"
 #include "ccv_nnc_mfa_attention.hpp"
-#include "ccv_nnc_mfa_gemm.hpp"
 #include "ccv_nnc_mfa_normalization.hpp"
 #include "ccv_nnc_mfa_depalettize.hpp"
 #include "ccv_nnc_mfa_adam.hpp"
 #include "ccv_nnc_mfa_cmul.hpp"
+#include "ccv_nnc_mfa_gemm.hpp"
 #include "ccv_nnc_mfa_gemv.hpp"
 #include "ccv_nnc_mfa_cast.hpp"
 #include "ccv_nnc_mfa_add.hpp"
@@ -49,7 +49,6 @@ class context {
   context(MTL::Device* device);
 
   cache<attention::hash, attention::pipeline> attention_cache;
-  cache<gemm::hash, gemm::pipeline> gemm_cache;
   cache<normalization::hash, normalization::pipeline> normalization_cache;
   cache<depalettize::hash, depalettize::pipeline> depalettize_cache;
   cache<adam::hash, adam::pipeline> adam_cache;
diff --git a/lib/nnc/mfa/ccv_nnc_mfa_gemm.cpp b/lib/nnc/mfa/ccv_nnc_mfa_gemm.cpp
index 465d09bf8..8142aff5b 100644
--- a/lib/nnc/mfa/ccv_nnc_mfa_gemm.cpp
+++ b/lib/nnc/mfa/ccv_nnc_mfa_gemm.cpp
@@ -13,7 +13,7 @@ using namespace ccv::nnc;
 
 void ccv_nnc_mfa_prepare_gemm(mfa::context* context, ccv_nnc_mfa_gemm_params_t params)
 {
-  context->gemm_cache.prepare(context, mfa::gemm::hash(params));
+  // No-op.
 }
 
 void ccv_nnc_mfa_encode_gemm(mfa::context* context, ccv_nnc_mfa_gemm_params_t params, MTL::CommandBatch* command_batch, MTL::Buffer** tensors, size_t* tensor_offsets)
@@ -55,6 +55,8 @@ void ccv_nnc_mfa_encode_gemm(mfa::context* context, ccv_nnc_mfa_gemm_params_t pa
     break;
   }
   gemmDesc.transposeState = simd::uchar3 { params.A_trans, params.B_trans, params.D_trans };
+  gemmDesc.registerPrecisionC = (params.register_float) ? std::optional(GEMMOperandPrecision::FP32) : std::nullopt;
+  gemmDesc.leadingDimensions = std::nullopt;
   gemmDesc.loadPreviousC = false;
   gemmDesc.useBias = params.fused_bias;
   if (params.batched) {
@@ -71,7 +73,7 @@ void ccv_nnc_mfa_encode_gemm(mfa::context* context, ccv_nnc_mfa_gemm_params_t pa
       continue;
     } else if (operand == 3) {
       // Skip the D operand if unavailable.
-      if (!(params.fused_activation_function || params.fused_bias)) {
+      if (!params.fused_bias) {
         continue;
       }
       batch_dims = params.batch_dims_d;
@@ -161,196 +163,3 @@ void ccv_nnc_mfa_encode_gemm(mfa::context* context, ccv_nnc_mfa_gemm_params_t pa
   command_batch->finishCommand(encoder);
 }
-
-// MARK: - C++
-
-mfa::gemm::hash::hash(ccv_nnc_mfa_gemm_params_t params) {
-  data_type = params.data_type;
-  M = params.M;
-  N = params.N;
-  K = params.K;
-  A_trans = params.A_trans;
-  B_trans = params.B_trans;
-  D_trans = params.D_trans;
-  alpha = params.alpha;
-  beta = params.beta;
-  batched = params.batched;
-  fused_activation_function = params.fused_activation_function;
-  fused_bias = params.fused_bias;
-}
-
-bool mfa::gemm::hash::operator==(const mfa::gemm::hash& hash) const {
-  return
-  (data_type == hash.data_type) &&
-  (M == hash.M) &&
-  (N == hash.N) &&
-  (K == hash.K) &&
-  (A_trans == hash.A_trans) &&
-  (B_trans == hash.B_trans) &&
-  (D_trans == hash.D_trans) &&
-  (alpha == hash.alpha) &&
-  (beta == hash.beta) &&
-  (batched == hash.batched) &&
-  (fused_activation_function == hash.fused_activation_function) &&
-  (fused_bias == hash.fused_bias);
-}
-
-std::ostream& operator<<(std::ostream& os, const mfa::gemm::hash& hash) {
-  os << "mfa::gemm::hash {";
-  os << " .data_type = " << hash.data_type << ',';
-  os << " .M = " << hash.M << ',';
-  os << " .N = " << hash.N << ',';
-  os << " .K = " << hash.K << ',';
-  os << " .A_trans = " << bool(hash.A_trans) << ',';
-  os << " .B_trans = " << bool(hash.B_trans) << ',';
-  os << " .D_trans = " << bool(hash.D_trans) << ',';
-  os << " .alpha = " << double(hash.alpha) << ',';
-  os << " .beta = " << double(hash.beta) << ',';
-  os << " .batched = " << bool(hash.batched) << ',';
-  os << " .fused_activation_function = " << bool(hash.fused_activation_function) << ',';
-  os << " .fused_bias = " << bool(hash.fused_bias) << " ";
-  os << "}";
-  return os;
-}
-
-std::size_t std::hash<mfa::gemm::hash>::operator()(const mfa::gemm::hash& hash) const noexcept {
-  std::size_t seed = 0;
-  using namespace mfa::hash;
-  combine_64(seed, hash.data_type);
-  combine_64(seed, pack_64(simd::uint2 { hash.M, hash.N }));
-  combine_64(seed, pack_64(simd::uint2 { hash.K, pack_32(simd::uchar4 { hash.A_trans, hash.B_trans, hash.D_trans, 0 }) }));
-  combine_64(seed, pack_64(simd::uint2 { *reinterpret_cast<const uint32_t*>(&hash.alpha), *reinterpret_cast<const uint32_t*>(&hash.beta) }));
-  combine_32(seed, pack_32(simd::uchar4 { hash.batched, hash.fused_activation_function, hash.fused_bias, 0 }));
-  return seed;
-}
-
-mfa::gemm::pipeline::pipeline(mfa::context* context, mfa::gemm::hash hash) {
-  CCV_NNC_MFA_PRECONDITION((hash.data_type == MTL::DataTypeFloat) || (hash.data_type == MTL::DataTypeHalf))
-  CCV_NNC_MFA_PRECONDITION(hash.alpha == 1.0)
-  CCV_NNC_MFA_PRECONDITION(hash.beta == 0.0)
-  CCV_NNC_MFA_PRECONDITION(hash.fused_activation_function == false)
-
-  auto* pool = NS::AutoreleasePool::alloc()->init();
-
-  auto constants = NS::TransferPtr(MTL::FunctionConstantValues::alloc()->init());
-  constants->setConstantValue(&hash.M, MTL::DataTypeUInt, NS::UInteger(0));
-  constants->setConstantValue(&hash.N, MTL::DataTypeUInt, 1);
-  constants->setConstantValue(&hash.K, MTL::DataTypeUInt, 2);
-  constants->setConstantValue(&hash.A_trans, MTL::DataTypeBool, 10);
-  constants->setConstantValue(&hash.B_trans, MTL::DataTypeBool, 11);
-  constants->setConstantValue(&hash.D_trans, MTL::DataTypeBool, 13);
-  constants->setConstantValue(&hash.alpha, MTL::DataTypeFloat, 20);
-  constants->setConstantValue(&hash.beta, MTL::DataTypeFloat, 21);
-  constants->setConstantValue(&hash.batched, MTL::DataTypeBool, 100);
-  constants->setConstantValue(&hash.fused_activation_function, MTL::DataTypeBool, 101);
-  constants->setConstantValue(&hash.fused_bias, MTL::DataTypeBool, 50001);
-  simd::ulong4 garbage(0);
-  constants->setConstantValue(&garbage, MTL::DataTypeBool, 102);
-  constants->setConstantValue(&garbage, MTL::DataTypeBool, 103);
-  constants->setConstantValue(&garbage, MTL::DataTypeBool, 113);
-  constants->setConstantValue(&garbage, MTL::DataTypeBool, 50000);
-
-  // Eventually, this may incorporate the batch size.
-  // BxMxN > 1,000,000 -> 48x48, only if M >= 88 and N >= 88
-  // BxMxN > 4,000,000 -> 64x64, only if M >= 120 and N >= 120
-  uint64_t C_elements = uint64_t(hash.M) * uint64_t(hash.N);
-  if (hash.batched) {
-    C_elements *= 2;
-  }
-  int is_half = (hash.data_type == MTL::DataTypeHalf); // SD v1 attention
-  int is_float = (hash.data_type == MTL::DataTypeFloat); // SD v2 attention
-
-  uint16_t M_group = 32;
-  uint16_t N_group = 32;
-  uint16_t K_simd = 32;
-  if (C_elements > 1000 * 1000) {
-    M_group = 48;
-    N_group = 48;
-  }
-
-  // If K_simd is perfectly equal to matrix K, the compiler can elide a large
-  // amount of logic in the kernel.
-  if (hash.K >= 33 && hash.K <= 40) {
-    K_simd = 40; // 1 * 40
-  } else if (is_half && hash.K >= 73 && hash.K <= 80) {
-    K_simd = 40; // 2 * 40
-  } else if (C_elements > 1000 * 1000) {
-    if (hash.K <= 24) {
-      K_simd = 24; // 1 * 24
-    } else if (hash.K <= 32) {
-      K_simd = 32; // 1 * 32
-    } else if (hash.K <= 48) {
-      K_simd = 24;
-    } else if (hash.K <= 64) {
-      K_simd = 32;
-    } else if (is_float) {
-      K_simd = 24;
-    }
-  }
-
-  uint16_t M_splits = 2;
-  uint16_t N_splits = 2;
-  uint16_t M_simd = M_group / M_splits;
-  uint16_t N_simd = N_group / N_splits;
-
-  constants->setConstantValue(&M_simd, MTL::DataTypeUShort, 200);
-  constants->setConstantValue(&N_simd, MTL::DataTypeUShort, 201);
-  constants->setConstantValue(&K_simd, MTL::DataTypeUShort, 202);
-  constants->setConstantValue(&M_splits, MTL::DataTypeUShort, 210);
-  constants->setConstantValue(&N_splits, MTL::DataTypeUShort, 211);
-
-  std::string cpp_name;
-  uint16_t data_type_size = UINT16_MAX;
-  switch (hash.data_type) {
-    case MTL::DataTypeHalf: {
-      cpp_name = "hgemm";
-      data_type_size = 2;
-      break;
-    }
-    case MTL::DataTypeFloat: {
-      cpp_name = "sgemm";
-      data_type_size = 4;
-      break;
-    }
-    default: {
-      CCV_NNC_MFA_PRECONDITION(false)
-      break;
-    }
-  }
-  auto* swift_name = NS::String::string(cpp_name.c_str(), NS::UTF8StringEncoding);
-
-  uint16_t A_block_bytes = M_group * K_simd * data_type_size;
-  uint16_t B_block_bytes = K_simd * N_group * data_type_size;
-  uint16_t C_block_bytes = M_group * N_group * data_type_size;
-  threadgroup_memory_length = A_block_bytes + B_block_bytes;
-
-  if ((hash.M % 8 > 0) && (hash.N % 8 > 0)) {
-    if (C_block_bytes > threadgroup_memory_length) {
-      threadgroup_memory_length = C_block_bytes;
-    }
-  }
-  if (hash.fused_bias) {
-    uint16_t D_block_bytes = (hash.D_trans ? M_group : N_group) * data_type_size;
-    if (D_block_bytes > threadgroup_memory_length) {
-      threadgroup_memory_length = D_block_bytes;
-    }
-  }
-
-  std::function<size_t(size_t, uint16_t)> ceil_divide = [](size_t original, uint16_t granularity) {
-    return (original + size_t(granularity) - 1) / size_t(granularity);
-  };
-  grid_size = MTL::Size(ceil_divide(hash.N, N_group), ceil_divide(hash.M, M_group), 1);
-  group_size = MTL::Size(32 * M_splits * N_splits, 1, 1);
-
-  NS::Error* error = nullptr;
-  auto function = NS::TransferPtr(context->library->newFunction(swift_name, constants.get(), &error));
-  if (!function) {
-    CCV_NNC_MFA_CHECK_ERROR(error)
-  }
-
-  pso = NS::TransferPtr(context->device->newComputePipelineState(function.get(), &error));
-  if (!pso) {
-    CCV_NNC_MFA_CHECK_ERROR(error)
-  }
-
-  pool->drain();
-}
diff --git a/lib/nnc/mfa/ccv_nnc_mfa_gemm.hpp b/lib/nnc/mfa/ccv_nnc_mfa_gemm.hpp
index a2f701394..92a6e1f3c 100644
--- a/lib/nnc/mfa/ccv_nnc_mfa_gemm.hpp
+++ b/lib/nnc/mfa/ccv_nnc_mfa_gemm.hpp
@@ -9,11 +9,9 @@ typedef struct {
   uint8_t A_trans;
   uint8_t B_trans;
   uint8_t D_trans;
-  float alpha;
-  float beta;
   uint8_t batched;
-  uint8_t fused_activation_function;
   uint8_t fused_bias;
+  uint8_t register_float;
 
   // Fill these in the same order as the original shape, but null-terminated.
   // Both arrays must have the same length.
@@ -25,56 +23,6 @@ typedef struct {
 #ifdef __cplusplus
 #include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp"
 #include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp"
-#include <simd/simd.h>
-
-namespace ccv {
-namespace nnc {
-namespace mfa {
-namespace gemm {
-
-class hash {
-public:
-  uint64_t data_type;
-  uint32_t M;
-  uint32_t N;
-  uint32_t K;
-  uint8_t A_trans;
-  uint8_t B_trans;
-  uint8_t D_trans;
-  float alpha;
-  float beta;
-  uint8_t batched;
-  uint8_t fused_activation_function;
-  uint8_t fused_bias;
-
-  hash(ccv_nnc_mfa_gemm_params_t);
-
-  bool operator==(const hash& rhs) const;
-};
-
-class pipeline {
-public:
-  NS::SharedPtr<MTL::ComputePipelineState> pso;
-
-  uint16_t threadgroup_memory_length;
-  MTL::Size grid_size;
-  MTL::Size group_size;
-
-  pipeline(context* context, hash hash);
-};
-
-} // namespace gemm
-} // namespace mfa
-} // namespace nnc
-} // namespace ccv
-
-std::ostream& operator<<(std::ostream& os, const ccv::nnc::mfa::gemm::hash& hash);
-
-template<>
-struct std::hash<ccv::nnc::mfa::gemm::hash>
-{
-  std::size_t operator()(const ccv::nnc::mfa::gemm::hash& hash) const noexcept;
-};
 
 extern "C" {
 #endif // __cplusplus
diff --git a/lib/nnc/mfa/v2/GEMMKernel.cpp b/lib/nnc/mfa/v2/GEMMKernel.cpp
index 8baf0931a..ba96f69fd 100644
--- a/lib/nnc/mfa/v2/GEMMKernel.cpp
+++ b/lib/nnc/mfa/v2/GEMMKernel.cpp
@@ -357,8 +357,7 @@ source += R"(
 }
 
 std::string GEMMKernel::createConstants() const noexcept {
-  return R"(
-
+  std::string constants = R"(
 // Dimensions of each matrix.
 // - Limitations to matrix size:
 //   - 2^32 in each dimension (M/N/K).
@@ -409,7 +408,13 @@ constant uint bias_batch_stride [[function_constant(18)]];
 // Whether each matrix is transposed.
 constant bool A_trans = {{TRANSPOSE_STATE_A}};
 constant bool B_trans = {{TRANSPOSE_STATE_B}};
+)";
+  if (useBias) {
+    constants += R"(
 constant bool bias_trans = {{TRANSPOSE_STATE_BIAS}};
+)";
+  }
+  constants += R"(
 
 // Define the memory layout of the matrix block.
 constant ushort M_group = {{BLOCK_DIMENSIONS_M}};
@@ -436,6 +441,7 @@ constant ushort M_shift = (M < M_group) ? 0 : {{REGISTER_M}} - M_remainder;
 constant ushort N_shift = (N < N_group) ? 0 : {{REGISTER_N}} - N_remainder;
 
 )";
+  return constants;
 }
 
 void GEMMKernel::createUtilities(CodeWriter *const source) const noexcept {
@@ -534,7 +540,7 @@ void GEMMKernel::createInitializeC(CodeWriter *source) const noexcept {
 } else {
 )";
   if (useBias) {
-    if (preferAsyncLoad) {
+    if (true) { // TODO: figure why on M3 / M4 this is faster. preferAsyncLoad) {
       source->SetValue("DIRECT_BIAS_ACCESS_CONDITION", "false");
     } else {
       source->SetValue("DIRECT_BIAS_ACCESS_CONDITION", "(M >= M_group) && (N >= N_group)");