Skip to content

Commit

Permalink
Fix cases simple data movement need to trip into MPSGraph.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Jan 2, 2024
1 parent e6dd1ef commit 5575884
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions lib/nnc/cmd/util/mps/ccv_nnc_util_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,27 @@ static int _ccv_nnc_format_transform(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint
for (i = 0; i < output_size; i++)
{
const ccv_nnc_tensor_view_t* const a = (const ccv_nnc_tensor_view_t*)inputs[i];
const ccv_nnc_tensor_view_t* const b = (const ccv_nnc_tensor_view_t*)outputs[i];
// If this is just normal data transfer, do this.
if (a->info.format == b->info.format && CCV_IS_TENSOR_CONTIGUOUS(a) && CCV_IS_TENSOR_CONTIGUOUS(b) && ccv_nnc_tensor_count(a->info) == ccv_nnc_tensor_count(b->info) && CCV_GET_DATA_TYPE_SIZE(a->info.datatype) == CCV_GET_DATA_TYPE_SIZE(b->info.datatype))
{
const size_t size = (ssize_t)ccv_nnc_tensor_count(a->info) * CCV_GET_DATA_TYPE_SIZE(a->info.datatype);
if (size == 0)
continue;
const int device_a = CCV_TENSOR_GET_DEVICE_ID(a->info.type);
const int device_b = CCV_TENSOR_GET_DEVICE_ID(b->info.type);
assert(device_a == device_b);
id<MTLBuffer> buffer_a = mpgetbuffer((const ccv_nnc_tensor_t*)a);
id<MTLBuffer> buffer_b = mpgetbuffer((const ccv_nnc_tensor_t*)b);
const off_t offset_a = mpgetoffset((const ccv_nnc_tensor_t*)a);
const off_t offset_b = mpgetoffset((const ccv_nnc_tensor_t*)b);
@autoreleasepool {
id<MTLBlitCommandEncoder> encoder = [command_buffer blitCommandEncoder];
[encoder copyFromBuffer:buffer_a sourceOffset:offset_a toBuffer:buffer_b destinationOffset:offset_b size:size];
[encoder endEncoding];
}
continue;
}
ccv_nnc_tensor_view_t bt = ccv_nnc_get_tensor_view(outputs[i]);
MPSGraph *graph = [MPSGraph new];
graph.options = MPSGraphOptionsSynchronizeResults;
Expand Down

0 comments on commit 5575884

Please sign in to comment.