From 55758844dc1fae5454dca8a7afe222b4f46b6a58 Mon Sep 17 00:00:00 2001 From: Liu Liu Date: Tue, 2 Jan 2024 12:52:42 -0500 Subject: [PATCH] Fix cases simple data movement need to trip into MPSGraph. --- lib/nnc/cmd/util/mps/ccv_nnc_util_mps.m | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/lib/nnc/cmd/util/mps/ccv_nnc_util_mps.m b/lib/nnc/cmd/util/mps/ccv_nnc_util_mps.m index b64c74479..083185444 100644 --- a/lib/nnc/cmd/util/mps/ccv_nnc_util_mps.m +++ b/lib/nnc/cmd/util/mps/ccv_nnc_util_mps.m @@ -215,6 +215,27 @@ static int _ccv_nnc_format_transform(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint for (i = 0; i < output_size; i++) { const ccv_nnc_tensor_view_t* const a = (const ccv_nnc_tensor_view_t*)inputs[i]; + const ccv_nnc_tensor_view_t* const b = (const ccv_nnc_tensor_view_t*)outputs[i]; + // If this is just normal data transfer, do this. + if (a->info.format == b->info.format && CCV_IS_TENSOR_CONTIGUOUS(a) && CCV_IS_TENSOR_CONTIGUOUS(b) && ccv_nnc_tensor_count(a->info) == ccv_nnc_tensor_count(b->info) && CCV_GET_DATA_TYPE_SIZE(a->info.datatype) == CCV_GET_DATA_TYPE_SIZE(b->info.datatype)) + { + const size_t size = (ssize_t)ccv_nnc_tensor_count(a->info) * CCV_GET_DATA_TYPE_SIZE(a->info.datatype); + if (size == 0) + continue; + const int device_a = CCV_TENSOR_GET_DEVICE_ID(a->info.type); + const int device_b = CCV_TENSOR_GET_DEVICE_ID(b->info.type); + assert(device_a == device_b); + id buffer_a = mpgetbuffer((const ccv_nnc_tensor_t*)a); + id buffer_b = mpgetbuffer((const ccv_nnc_tensor_t*)b); + const off_t offset_a = mpgetoffset((const ccv_nnc_tensor_t*)a); + const off_t offset_b = mpgetoffset((const ccv_nnc_tensor_t*)b); + @autoreleasepool { + id encoder = [command_buffer blitCommandEncoder]; + [encoder copyFromBuffer:buffer_a sourceOffset:offset_a toBuffer:buffer_b destinationOffset:offset_b size:size]; + [encoder endEncoding]; + } + continue; + } ccv_nnc_tensor_view_t bt = ccv_nnc_get_tensor_view(outputs[i]); MPSGraph *graph = [MPSGraph new]; graph.options = MPSGraphOptionsSynchronizeResults;