diff --git a/lib/nnc/mfa/v2/AttentionDescriptor.cpp b/lib/nnc/mfa/v2/AttentionDescriptor.cpp index 850985bf8..6bda3c247 100644 --- a/lib/nnc/mfa/v2/AttentionDescriptor.cpp +++ b/lib/nnc/mfa/v2/AttentionDescriptor.cpp @@ -454,7 +454,7 @@ std::vector AttentionDescriptor::forwardMixed(MTL::Device if (device->supportsFamily(MTL::GPUFamily(1009))) { return { AttentionParameterRow(32, 16, 128, 16, { AttentionOperand::Q, AttentionOperand::O }), - AttentionParameterRow(96, 16, 128, 32, { AttentionOperand::Q, AttentionOperand::O }), + AttentionParameterRow(64, 16, 128, 32, { AttentionOperand::Q, AttentionOperand::O }), AttentionParameterRow(160, 32, 128, 32, { AttentionOperand::O }), AttentionParameterRow(224, 32, 128, 32, { AttentionOperand::Q }), AttentionParameterRow(384, 32, 128, 32, {})