diff --git a/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp b/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
index 87ebb9c4e..0e602cbdc 100644
--- a/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
+++ b/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
@@ -98,11 +98,11 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
   pool->drain();
   auto kernel = pipelineValue->kernel;
   auto pipeline = pipelineValue->pipeline;
-  // Allocate a new command. 
+  // Allocate a new command.
   auto encoder = command_batch->startCommand();
   encoder->setComputePipelineState(pipeline.get());
   encoder->setThreadgroupMemoryLength(kernel->threadgroupMemoryAllocation, 0);
-  
+
   // Bind the function arguments.
   encoder->useResource(tensors[0], MTL::ResourceUsageRead);
   encoder->useResource(tensors[1], MTL::ResourceUsageRead);
@@ -146,17 +146,37 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
       encoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::L).bufferIndex());
     }
   }
-  
+
   MTL::Size gridSize
-    (ceilDivide(int64_t(hash.R), kernel->blockDimensions[0]),
-     hash.Hq,
-     attentionDesc.batchDimension);
+    (ceilDivide(int64_t(hash.R), kernel->blockDimensions[0]), 1, 1);
   MTL::Size groupSize
     (int64_t(kernel->threadgroupSize), 1, 1);
-  
-  // Dispatch the required number of threads.
-  encoder->dispatchThreadgroups(gridSize, groupSize);
-  
+
+  const size_t bytesPerElement = attentionDesc.lowPrecisionInputs ? sizeof(uint16_t) : sizeof(float);
+  for (int i = 0; i < attentionDesc.batchDimension; i++) {
+    for (int j = 0; j < hash.Hq; j++) {
+      encoder->setBufferOffset(tensor_offsets[0] + bytesPerElement * (i * hash.R * hash.D * hash.Hq + j * hash.D), AttentionOperand(AttentionOperand::Q).bufferIndex());
+      encoder->setBufferOffset(tensor_offsets[1] + bytesPerElement * (i * hash.C * hash.D * hash.Hk + j * hash.D), AttentionOperand(AttentionOperand::K).bufferIndex());
+      encoder->setBufferOffset(tensor_offsets[2] + bytesPerElement * (i * hash.C * hash.D * hash.Hk + j * hash.D), AttentionOperand(AttentionOperand::V).bufferIndex());
+      if (attentionDesc.lowPrecisionInputs) {
+        encoder->setBufferOffset(sizeof(float) * (i * hash.R * hash.D * hash.Hq + j * hash.D), AttentionOperand(AttentionOperand::O).bufferIndex());
+        if (tensors[5]) {
+          encoder->setBufferOffset(tensor_offsets[5] + sizeof(float) * (i * hash.R * hash.Hq + j * hash.R), AttentionOperand(AttentionOperand::L).bufferIndex());
+        } else {
+          encoder->setBufferOffset(sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension + sizeof(float) * (i * hash.R * hash.Hq + j * hash.R), AttentionOperand(AttentionOperand::L).bufferIndex());
+        }
+      } else {
+        encoder->setBufferOffset(tensor_offsets[3] + sizeof(float) * (i * hash.R * hash.D * hash.Hq + j * hash.D), AttentionOperand(AttentionOperand::O).bufferIndex());
+        if (tensors[5]) {
+          encoder->setBufferOffset(tensor_offsets[5] + sizeof(float) * (i * hash.R * hash.Hq + j * hash.R), AttentionOperand(AttentionOperand::L).bufferIndex());
+        } else {
+          encoder->setBufferOffset(sizeof(float) * (i * hash.R * hash.Hq + j * hash.R), AttentionOperand(AttentionOperand::L).bufferIndex());
+        }
+      }
+      // Dispatch the required number of threads.
+      encoder->dispatchThreadgroups(gridSize, groupSize);
+    }
+  }
   // Finish the command.
   command_batch->finishCommand(encoder);
   if (attentionDesc.lowPrecisionInputs) {
diff --git a/lib/nnc/mfa/v2/AttentionKernel.cpp b/lib/nnc/mfa/v2/AttentionKernel.cpp
index fe749b509..64cecbef0 100644
--- a/lib/nnc/mfa/v2/AttentionKernel.cpp
+++ b/lib/nnc/mfa/v2/AttentionKernel.cpp
@@ -520,7 +520,11 @@ std::string AttentionKernel::operandLocationWithHeadOffsetValue(AttentionOperand
   CodeWriter source;
   source.SetValue("OPERAND", operand.name());
   if (operand.value == AttentionOperand::L || operand.value == AttentionOperand::D) {
-    source += "{{OPERAND}} + (gid.z * Hq + gid.y) * R\\";
+    if (Hq > 1) {
+      source += "{{OPERAND}} + (gid.z * Hq + gid.y) * R\\";
+    } else {
+      source += "{{OPERAND}} + gid.z * R\\";
+    }
   } else if (Hq > 1) {
     source.SetValue("HEAD_DIMENSION", std::to_string(headDimension));
     if (!transposed(operand)) {
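
Note on the offset arithmetic (a reviewer sketch, not part of the patch): the loop above implies Q is packed as [batch, R, Hq, D] and K/V as [batch, C, Hk, D], so the slice for batch i, head j starts i * R * Hq * D + j * D elements into the buffer, while L (one FP32 logsumexp value per row per head) is packed as [batch, Hq, R]. The remaining R rows are covered by gid.x inside the kernel, which is why gridSize collapses its y and z dimensions to 1 and per-head addressing moves to setBufferOffset on the CPU side. A minimal standalone illustration of the indexing, with hypothetical helper names:

    #include <cstddef>

    // Byte offset of the (batch i, head j) slice of Q, assuming the
    // [batch, R, Hq, D] packing implied by the patch. bytesPerElement is
    // sizeof(uint16_t) for FP16 inputs, sizeof(float) otherwise.
    static size_t qSliceByteOffset(size_t i, size_t j, size_t R, size_t Hq,
                                   size_t D, size_t bytesPerElement) {
      return bytesPerElement * (i * R * Hq * D + j * D);
    }

    // Byte offset of the (batch i, head j) slice of L, which is always FP32
    // and packed as [batch, Hq, R].
    static size_t lSliceByteOffset(size_t i, size_t j, size_t R, size_t Hq) {
      return sizeof(float) * (i * R * Hq + j * R);
    }

The AttentionKernel.cpp change is consistent with this: when Hq == 1 the generated MSL addresses L/D with gid.z * R instead of (gid.z * Hq + gid.y) * R, which evaluates to the same offset whenever the grid's y dimension is 1.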