Skip to content

Commit

Permalink
Skip cudnn for data movement in data transform.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Jan 2, 2024
1 parent 5575884 commit c55f36f
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions lib/nnc/cmd/util/gpu/ccv_nnc_util_gpu_cudnn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,28 @@ static int _ccv_nnc_format_transform(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint
{
if (inputs[i]->info.dim[0] == 0 || outputs[i]->info.dim[0] == 0)
continue;
if (inputs[i]->info.format == outputs[i]->info.format && CCV_IS_TENSOR_CONTIGUOUS(inputs[i]) && CCV_IS_TENSOR_CONTIGUOUS(outputs[i]) && ccv_nnc_tensor_count(inputs[i]->info) == ccv_nnc_tensor_count(outputs[i]->info) && CCV_GET_DATA_TYPE_SIZE(inputs[i]->info.datatype) == CCV_GET_DATA_TYPE_SIZE(outputs[i]->info.datatype))
{
const ccv_nnc_tensor_t* const a = (ccv_nnc_tensor_t*)inputs[i];
ccv_nnc_tensor_t* const b = (ccv_nnc_tensor_t*)outputs[i];
const size_t size = (ssize_t)ccv_nnc_tensor_count(a->info) * CCV_GET_DATA_TYPE_SIZE(a->info.datatype);
const int device_a = CCV_TENSOR_GET_DEVICE_ID(a->info.type);
const int device_b = CCV_TENSOR_GET_DEVICE_ID(b->info.type);
if (stream_context)
{
cudaStream_t stream = ccv_nnc_stream_context_get_stream(stream_context);
if (device_a == device_b)
CUDA_ENFORCE(cudaMemcpyAsync(b->data.u8, a->data.u8, size, cudaMemcpyDeviceToDevice, stream));
else
CUDA_ENFORCE(cudaMemcpyPeerAsync(b->data.u8, device_b, a->data.u8, device_a, size, stream));
} else {
if (device_a == device_b)
CUDA_ENFORCE(cudaMemcpy(b->data.u8, a->data.u8, size, cudaMemcpyDeviceToDevice));
else
CUDA_ENFORCE(cudaMemcpyPeer(b->data.u8, device_b, a->data.u8, device_a, size));
}
continue;
}
const ccv_nnc_cudnn_tensor_view_descriptor_t a = ccv_nnc_cudnn_get_tensor_view_descriptor(stream_context, (const ccv_nnc_tensor_view_t*)inputs[i]);
const ccv_nnc_cudnn_tensor_view_descriptor_t b = ccv_nnc_cudnn_get_tensor_view_descriptor(stream_context, (const ccv_nnc_tensor_view_t*)outputs[i]);
assert(inputs[i]->info.datatype == outputs[i]->info.datatype);
Expand Down

0 comments on commit c55f36f

Please sign in to comment.