From cd446181a00b87e2d5119c8477d395e1e3716efd Mon Sep 17 00:00:00 2001 From: Liu Liu Date: Tue, 23 Jan 2024 15:11:26 -0500 Subject: [PATCH] Add and validate test cases for CUDA ConvTranspose support. --- lib/nnc/ccv_nnc.h | 2 +- lib/nnc/cmd/ccv_nnc_cmd.inc | 4 + lib/nnc/cmd/config.mk | 2 +- lib/nnc/cmd/convolution/ccv_nnc_convolution.c | 4 +- .../gpu/ccv_nnc_conv_transpose_gpu_cudnn.cu | 201 ++++++++++++++++++ test/int/nnc/cudnn.tests.c | 170 +++++++++++++++ 6 files changed, 379 insertions(+), 4 deletions(-) create mode 100644 lib/nnc/cmd/convolution/gpu/ccv_nnc_conv_transpose_gpu_cudnn.cu diff --git a/lib/nnc/ccv_nnc.h b/lib/nnc/ccv_nnc.h index e5e33fe66..1e3da20c6 100644 --- a/lib/nnc/ccv_nnc.h +++ b/lib/nnc/ccv_nnc.h @@ -110,8 +110,8 @@ typedef struct { struct { int count; /**< [convolution_transpose.count] The number of filters for convolutional layer. */ int groups; /**< [convolution_transpose.groups] The number of groups for convolutional layer. */ - int output_padding; /**< [convolution_transpose.output_padding] The output padding to resolve ambiguity when treat this as inverse of convolution. */ int dilation[CCV_NNC_MAX_DIM_ALLOC]; /**< [convolution_transpose.dilation[]] The dilation factor for convolutional layer. Default to 1. */ + int output_padding; /**< [convolution_transpose.output_padding] The output padding to resolve ambiguity when treat this as inverse of convolution. */ } convolution_transpose; struct { int hidden_size; /**< [rnn.hidden_size] The number of features in the hidden state h. */ diff --git a/lib/nnc/cmd/ccv_nnc_cmd.inc b/lib/nnc/cmd/ccv_nnc_cmd.inc index fb0208958..9d85c1089 100644 --- a/lib/nnc/cmd/ccv_nnc_cmd.inc +++ b/lib/nnc/cmd/ccv_nnc_cmd.inc @@ -484,6 +484,8 @@ void _register_command_CCV_NNC_COMPRESSION_LSSC_FORWARD_backend_CCV_NNC_BACKEND_ void _register_command_CCV_NNC_COMPRESSION_LSSC_BACKWARD_backend_CCV_NNC_BACKEND_GPU_REF(ccv_nnc_cmd_backend_registry_t* const registry); void _register_command_CCV_NNC_CONVOLUTION_FORWARD_backend_CCV_NNC_BACKEND_GPU_CUDNN(ccv_nnc_cmd_backend_registry_t* const registry); void _register_command_CCV_NNC_CONVOLUTION_BACKWARD_backend_CCV_NNC_BACKEND_GPU_CUDNN(ccv_nnc_cmd_backend_registry_t* const registry); +void _register_command_CCV_NNC_CONVOLUTION_TRANSPOSE_FORWARD_backend_CCV_NNC_BACKEND_GPU_CUDNN(ccv_nnc_cmd_backend_registry_t* const registry); +void _register_command_CCV_NNC_CONVOLUTION_TRANSPOSE_BACKWARD_backend_CCV_NNC_BACKEND_GPU_CUDNN(ccv_nnc_cmd_backend_registry_t* const registry); void _register_command_CCV_NNC_DROPOUT_FORWARD_backend_CCV_NNC_BACKEND_GPU_CUDNN(ccv_nnc_cmd_backend_registry_t* const registry); void _register_command_CCV_NNC_DROPOUT_BACKWARD_backend_CCV_NNC_BACKEND_GPU_CUDNN(ccv_nnc_cmd_backend_registry_t* const registry); void _register_command_CCV_NNC_EWSUM_FORWARD_backend_CCV_NNC_BACKEND_GPU_CUDNN(ccv_nnc_cmd_backend_registry_t* const registry); @@ -950,6 +952,8 @@ static inline void _ccv_nnc_cmd_init(void) _register_command_CCV_NNC_COMPRESSION_LSSC_BACKWARD_backend_CCV_NNC_BACKEND_GPU_REF(&(init_map[13].backends[5])); _register_command_CCV_NNC_CONVOLUTION_FORWARD_backend_CCV_NNC_BACKEND_GPU_CUDNN(&(init_map[116].backends[3])); _register_command_CCV_NNC_CONVOLUTION_BACKWARD_backend_CCV_NNC_BACKEND_GPU_CUDNN(&(init_map[117].backends[3])); + _register_command_CCV_NNC_CONVOLUTION_TRANSPOSE_FORWARD_backend_CCV_NNC_BACKEND_GPU_CUDNN(&(init_map[48].backends[3])); + _register_command_CCV_NNC_CONVOLUTION_TRANSPOSE_BACKWARD_backend_CCV_NNC_BACKEND_GPU_CUDNN(&(init_map[49].backends[3]));
_register_command_CCV_NNC_DROPOUT_FORWARD_backend_CCV_NNC_BACKEND_GPU_CUDNN(&(init_map[4].backends[3])); _register_command_CCV_NNC_DROPOUT_BACKWARD_backend_CCV_NNC_BACKEND_GPU_CUDNN(&(init_map[5].backends[3])); _register_command_CCV_NNC_EWSUM_FORWARD_backend_CCV_NNC_BACKEND_GPU_CUDNN(&(init_map[98].backends[3])); diff --git a/lib/nnc/cmd/config.mk b/lib/nnc/cmd/config.mk index b21612037..657832da0 100644 --- a/lib/nnc/cmd/config.mk +++ b/lib/nnc/cmd/config.mk @@ -1,3 +1,3 @@ CMD_SRCS := ./adam/ccv_nnc_adam_cpu_ref.c ./adam/ccv_nnc_adamw_cpu_ref.c ./blas/ccv_nnc_gemm_cpu_ref.c ./blas/ccv_nnc_gemm_cpu_opt.c ./blas/ccv_nnc_add_cpu_ref.c ./blas/ccv_nnc_mul_cpu_ref.c ./blas/ccv_nnc_cmul_cpu_ref.c ./compare/ccv_nnc_min_cpu_ref.c ./compare/ccv_nnc_max_cpu_ref.c ./compression/ccv_nnc_lssc_cpu_ref.c ./convolution/ccv_nnc_conv_cpu_ref.c ./convolution/ccv_nnc_conv_cpu_opt.c ./convolution/ccv_nnc_conv_transpose_cpu_ref.c ./dropout/ccv_nnc_dropout_cpu_ref.c ./ew/ccv_nnc_ew_cpu_ref.c ./gelu/ccv_nnc_gelu_cpu_ref.c ./histogram/ccv_nnc_histogram_cpu_ref.c ./index/ccv_nnc_index_select_cpu_ref.c ./isnan/ccv_nnc_reduce_isnan_cpu_ref.c ./lamb/ccv_nnc_lamb_cpu_ref.c ./leaky_relu/ccv_nnc_leaky_relu_cpu_ref.c ./loss/ccv_nnc_binary_crossentropy_cpu_ref.c ./loss/ccv_nnc_categorical_crossentropy_cpu_ref.c ./loss/ccv_nnc_mse_cpu_ref.c ./loss/ccv_nnc_smooth_l1_cpu_ref.c ./nms/ccv_nnc_nms_cpu_ref.c ./norm/ccv_nnc_batch_norm_cpu_ref.c ./norm/ccv_nnc_layer_norm_cpu_ref.c ./norm/ccv_nnc_group_norm_cpu_ref.c ./norm/ccv_nnc_rmsnorm_cpu_ref.c ./pool/ccv_nnc_max_pool_cpu_ref.c ./pool/ccv_nnc_avg_pool_cpu_ref.c ./rand/ccv_nnc_rand_uniform_cpu_ref.c ./rand/ccv_nnc_rand_normal_cpu_ref.c ./reduce/ccv_nnc_reduce_sum_cpu_ref.c ./reduce/ccv_nnc_reduce_mean_cpu_ref.c ./reduce/ccv_nnc_reduce_max_cpu_ref.c ./reduce/ccv_nnc_reduce_min_cpu_ref.c ./reduce/ccv_nnc_reduce_norm2_cpu_ref.c ./reduce/ccv_nnc_argmax_cpu_ref.c ./reduce/ccv_nnc_argmin_cpu_ref.c ./relu/ccv_nnc_relu_cpu_ref.c ./rmsprop/ccv_nnc_rmsprop_cpu_ref.c ./roi/ccv_nnc_roi_align_cpu_ref.c ./scaled_dot_product_attention/ccv_nnc_scaled_dot_product_attention_cpu_ref.c ./sgd/ccv_nnc_sgd_cpu_ref.c ./sigmoid/ccv_nnc_sigmoid_cpu_ref.c ./sigmoid_loss/ccv_nnc_sigmoid_binary_crossentropy_cpu_ref.c ./softmax/ccv_nnc_softmax_cpu_ref.c ./softmax_loss/ccv_nnc_softmax_crossentropy_cpu_ref.c ./swish/ccv_nnc_swish_cpu_ref.c ./tanh/ccv_nnc_tanh_cpu_ref.c ./upsample/ccv_nnc_upsample_cpu_ref.c ./util/ccv_nnc_util_cpu_ref.c ./adam/ccv_nnc_adam.c ./blas/ccv_nnc_blas.c ./blas/cpu_opt/_ccv_nnc_gemm_cpu_opt.c ./blas/cpu_sys/_ccv_nnc_gemm_cpu_sys.c ./comm/ccv_nnc_comm.c ./compare/ccv_nnc_cmp.c ./compression/ccv_nnc_compression.c ./convolution/cpu_opt/_ccv_nnc_conv_cpu_4x4_3x3_winograd.c ./convolution/cpu_opt/_ccv_nnc_conv_cpu_fft.c ./convolution/cpu_opt/_ccv_nnc_conv_cpu_gemm.c ./convolution/cpu_opt/_ccv_nnc_conv_cpu_opt.c ./convolution/ccv_nnc_convolution.c ./dropout/ccv_nnc_dropout.c ./ew/ccv_nnc_ew.c ./gelu/ccv_nnc_gelu.c ./histogram/ccv_nnc_histogram.c ./index/ccv_nnc_index_select.c ./isnan/ccv_nnc_reduce_isnan.c ./lamb/ccv_nnc_lamb.c ./leaky_relu/ccv_nnc_leaky_relu.c ./loss/ccv_nnc_binary_crossentropy.c ./loss/ccv_nnc_categorical_crossentropy.c ./loss/ccv_nnc_mse.c ./loss/ccv_nnc_smooth_l1.c ./nms/ccv_nnc_nms.c ./norm/ccv_nnc_norm.c ./pool/ccv_nnc_pool.c ./rand/ccv_nnc_rand.c ./reduce/ccv_nnc_reduce.c ./relu/ccv_nnc_relu.c ./rmsprop/ccv_nnc_rmsprop.c ./rnn/ccv_nnc_lstm.c ./roi/ccv_nnc_roi_align.c ./scaled_dot_product_attention/ccv_nnc_scaled_dot_product_attention.c ./sgd/ccv_nnc_sgd.c 
./sigmoid/ccv_nnc_sigmoid.c ./sigmoid_loss/ccv_nnc_sigmoid_binary_crossentropy.c ./softmax/ccv_nnc_softmax.c ./softmax_loss/ccv_nnc_softmax_crossentropy.c ./swish/ccv_nnc_swish.c ./tanh/ccv_nnc_tanh.c ./upsample/ccv_nnc_upsample.c ./util/ccv_nnc_util.c -CUDA_CMD_SRCS := ./adam/gpu/ccv_nnc_adam_gpu_ref.cu ./adam/gpu/ccv_nnc_adamw_gpu_ref.cu ./blas/gpu/ccv_nnc_gemm_gpu_cublas.cu ./blas/gpu/ccv_nnc_add_gpu_cudnn.cu ./blas/gpu/ccv_nnc_mul_gpu_cudnn.cu ./blas/gpu/ccv_nnc_cmul_gpu_ref.cu ./comm/gpu/ccv_nnc_comm_gpu_nccl.cu ./compare/gpu/ccv_nnc_min_gpu_ref.cu ./compare/gpu/ccv_nnc_max_gpu_ref.cu ./compression/gpu/ccv_nnc_lssc_gpu_ref.cu ./convolution/gpu/ccv_nnc_conv_gpu_cudnn.cu ./dropout/gpu/ccv_nnc_dropout_gpu_cudnn.cu ./ew/gpu/ccv_nnc_ew_gpu_cudnn.cu ./ew/gpu/ccv_nnc_ew_gpu_ref.cu ./gelu/gpu/ccv_nnc_gelu_gpu_ref.cu ./index/gpu/ccv_nnc_index_select_gpu_ref.cu ./isnan/gpu/ccv_nnc_reduce_isnan_gpu_cudnn.cu ./lamb/gpu/ccv_nnc_lamb_gpu_ref.cu ./leaky_relu/gpu/ccv_nnc_leaky_relu_gpu_ref.cu ./loss/gpu/ccv_nnc_binary_crossentropy_gpu_ref.cu ./loss/gpu/ccv_nnc_categorical_crossentropy_gpu_ref.cu ./loss/gpu/ccv_nnc_mse_gpu_ref.cu ./loss/gpu/ccv_nnc_smooth_l1_gpu_ref.cu ./nms/gpu/ccv_nnc_nms_gpu_ref.cu ./norm/gpu/ccv_nnc_batch_norm_gpu_cudnn.cu ./norm/gpu/ccv_nnc_layer_norm_gpu_cudnn.cu ./norm/gpu/ccv_nnc_group_norm_gpu_cudnn.cu ./norm/gpu/ccv_nnc_rmsnorm_gpu_cudnn.cu ./pool/gpu/ccv_nnc_max_pool_gpu_cudnn.cu ./pool/gpu/ccv_nnc_avg_pool_gpu_cudnn.cu ./rand/gpu/ccv_nnc_rand_uniform_gpu_ref.cu ./rand/gpu/ccv_nnc_rand_normal_gpu_ref.cu ./reduce/gpu/ccv_nnc_reduce_sum_gpu_cudnn.cu ./reduce/gpu/ccv_nnc_reduce_mean_gpu_cudnn.cu ./reduce/gpu/ccv_nnc_reduce_norm2_gpu_cudnn.cu ./reduce/gpu/ccv_nnc_argmax_gpu_ref.cu ./reduce/gpu/ccv_nnc_argmin_gpu_ref.cu ./relu/gpu/ccv_nnc_relu_gpu_cudnn.cu ./rmsprop/gpu/ccv_nnc_rmsprop_gpu_ref.cu ./rnn/gpu/ccv_nnc_lstm_gpu_cudnn.cu ./roi/gpu/ccv_nnc_roi_align_gpu_ref.cu ./scaled_dot_product_attention/gpu/ccv_nnc_scaled_dot_product_attention_flash_attn.cu ./sgd/gpu/ccv_nnc_sgd_gpu_ref.cu ./sigmoid/gpu/ccv_nnc_sigmoid_gpu_cudnn.cu ./sigmoid_loss/gpu/ccv_nnc_sigmoid_binary_crossentropy_gpu_ref.cu ./softmax/gpu/ccv_nnc_softmax_gpu_cudnn.cu ./softmax_loss/gpu/ccv_nnc_softmax_crossentropy_gpu_cudnn.cu ./swish/gpu/ccv_nnc_swish_gpu_ref.cu ./tanh/gpu/ccv_nnc_tanh_gpu_cudnn.cu ./upsample/gpu/ccv_nnc_upsample_gpu_ref.cu ./util/gpu/ccv_nnc_util_gpu_cudnn.cu ./util/gpu/ccv_nnc_util_gpu_ref.cu +CUDA_CMD_SRCS := ./adam/gpu/ccv_nnc_adam_gpu_ref.cu ./adam/gpu/ccv_nnc_adamw_gpu_ref.cu ./blas/gpu/ccv_nnc_gemm_gpu_cublas.cu ./blas/gpu/ccv_nnc_add_gpu_cudnn.cu ./blas/gpu/ccv_nnc_mul_gpu_cudnn.cu ./blas/gpu/ccv_nnc_cmul_gpu_ref.cu ./comm/gpu/ccv_nnc_comm_gpu_nccl.cu ./compare/gpu/ccv_nnc_min_gpu_ref.cu ./compare/gpu/ccv_nnc_max_gpu_ref.cu ./compression/gpu/ccv_nnc_lssc_gpu_ref.cu ./convolution/gpu/ccv_nnc_conv_gpu_cudnn.cu ./convolution/gpu/ccv_nnc_conv_transpose_gpu_cudnn.cu ./dropout/gpu/ccv_nnc_dropout_gpu_cudnn.cu ./ew/gpu/ccv_nnc_ew_gpu_cudnn.cu ./ew/gpu/ccv_nnc_ew_gpu_ref.cu ./gelu/gpu/ccv_nnc_gelu_gpu_ref.cu ./index/gpu/ccv_nnc_index_select_gpu_ref.cu ./isnan/gpu/ccv_nnc_reduce_isnan_gpu_cudnn.cu ./lamb/gpu/ccv_nnc_lamb_gpu_ref.cu ./leaky_relu/gpu/ccv_nnc_leaky_relu_gpu_ref.cu ./loss/gpu/ccv_nnc_binary_crossentropy_gpu_ref.cu ./loss/gpu/ccv_nnc_categorical_crossentropy_gpu_ref.cu ./loss/gpu/ccv_nnc_mse_gpu_ref.cu ./loss/gpu/ccv_nnc_smooth_l1_gpu_ref.cu ./nms/gpu/ccv_nnc_nms_gpu_ref.cu ./norm/gpu/ccv_nnc_batch_norm_gpu_cudnn.cu ./norm/gpu/ccv_nnc_layer_norm_gpu_cudnn.cu 
./norm/gpu/ccv_nnc_group_norm_gpu_cudnn.cu ./norm/gpu/ccv_nnc_rmsnorm_gpu_cudnn.cu ./pool/gpu/ccv_nnc_max_pool_gpu_cudnn.cu ./pool/gpu/ccv_nnc_avg_pool_gpu_cudnn.cu ./rand/gpu/ccv_nnc_rand_uniform_gpu_ref.cu ./rand/gpu/ccv_nnc_rand_normal_gpu_ref.cu ./reduce/gpu/ccv_nnc_reduce_sum_gpu_cudnn.cu ./reduce/gpu/ccv_nnc_reduce_mean_gpu_cudnn.cu ./reduce/gpu/ccv_nnc_reduce_norm2_gpu_cudnn.cu ./reduce/gpu/ccv_nnc_argmax_gpu_ref.cu ./reduce/gpu/ccv_nnc_argmin_gpu_ref.cu ./relu/gpu/ccv_nnc_relu_gpu_cudnn.cu ./rmsprop/gpu/ccv_nnc_rmsprop_gpu_ref.cu ./rnn/gpu/ccv_nnc_lstm_gpu_cudnn.cu ./roi/gpu/ccv_nnc_roi_align_gpu_ref.cu ./scaled_dot_product_attention/gpu/ccv_nnc_scaled_dot_product_attention_flash_attn.cu ./sgd/gpu/ccv_nnc_sgd_gpu_ref.cu ./sigmoid/gpu/ccv_nnc_sigmoid_gpu_cudnn.cu ./sigmoid_loss/gpu/ccv_nnc_sigmoid_binary_crossentropy_gpu_ref.cu ./softmax/gpu/ccv_nnc_softmax_gpu_cudnn.cu ./softmax_loss/gpu/ccv_nnc_softmax_crossentropy_gpu_cudnn.cu ./swish/gpu/ccv_nnc_swish_gpu_ref.cu ./tanh/gpu/ccv_nnc_tanh_gpu_cudnn.cu ./upsample/gpu/ccv_nnc_upsample_gpu_ref.cu ./util/gpu/ccv_nnc_util_gpu_cudnn.cu ./util/gpu/ccv_nnc_util_gpu_ref.cu MPS_CMD_SRCS := ./adam/mps/ccv_nnc_adam_mps.m ./adam/mps/ccv_nnc_adamw_mps.m ./blas/mps/ccv_nnc_gemm_mps.m ./blas/mps/ccv_nnc_add_mps.m ./blas/mps/ccv_nnc_mul_mps.m ./blas/mps/ccv_nnc_cmul_mps.m ./convolution/mps/ccv_nnc_conv_mps.m ./ew/mps/ccv_nnc_ew_mps.m ./gelu/mps/ccv_nnc_gelu_mps.m ./index/mps/ccv_nnc_index_select_mps.m ./isnan/mps/ccv_nnc_reduce_isnan_mps.m ./leaky_relu/mps/ccv_nnc_leaky_relu_mps.m ./loss/mps/ccv_nnc_mse_mps.m ./norm/mps/ccv_nnc_layer_norm_mps.m ./norm/mps/ccv_nnc_group_norm_mps.m ./norm/mps/ccv_nnc_rmsnorm_mps.m ./pool/mps/ccv_nnc_max_pool_mps.m ./pool/mps/ccv_nnc_avg_pool_mps.m ./rand/mps/ccv_nnc_rand_uniform_mps.m ./rand/mps/ccv_nnc_rand_normal_mps.m ./reduce/mps/ccv_nnc_reduce_sum_mps.m ./reduce/mps/ccv_nnc_reduce_mean_mps.m ./reduce/mps/ccv_nnc_reduce_max_mps.m ./reduce/mps/ccv_nnc_reduce_min_mps.m ./reduce/mps/ccv_nnc_argmax_mps.m ./reduce/mps/ccv_nnc_argmin_mps.m ./relu/mps/ccv_nnc_relu_mps.m ./scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m ./sigmoid/mps/ccv_nnc_sigmoid_mps.m ./softmax/mps/ccv_nnc_softmax_mps.m ./swish/mps/ccv_nnc_swish_mps.m ./upsample/mps/ccv_nnc_upsample_mps.m ./util/mps/ccv_nnc_util_mps.m diff --git a/lib/nnc/cmd/convolution/ccv_nnc_convolution.c b/lib/nnc/cmd/convolution/ccv_nnc_convolution.c index 87b1625c8..a45dee8f5 100644 --- a/lib/nnc/cmd/convolution/ccv_nnc_convolution.c +++ b/lib/nnc/cmd/convolution/ccv_nnc_convolution.c @@ -102,14 +102,14 @@ static void _ccv_nnc_conv_transpose_tensor_auto_forw(const ccv_nnc_cmd_param_t c } REGISTER_COMMAND(CCV_NNC_CONVOLUTION_TRANSPOSE_FORWARD)(ccv_nnc_cmd_registry_t* const registry) - FIND_BACKEND(ccv_nnc_conv_transpose_cpu_ref.c) + FIND_BACKEND(ccv_nnc_conv_transpose_cpu_ref.c, gpu/ccv_nnc_conv_transpose_gpu_cudnn.cu) { registry->bitmask = _ccv_nnc_conv_forw_bitmask; registry->tensor_auto = _ccv_nnc_conv_transpose_tensor_auto_forw; } REGISTER_COMMAND(CCV_NNC_CONVOLUTION_TRANSPOSE_BACKWARD)(ccv_nnc_cmd_registry_t* const registry) - FIND_BACKEND(ccv_nnc_conv_transpose_cpu_ref.c) + FIND_BACKEND(ccv_nnc_conv_transpose_cpu_ref.c, gpu/ccv_nnc_conv_transpose_gpu_cudnn.cu) { registry->bitmask = _ccv_nnc_conv_back_bitmask; registry->tensor_auto = ccv_nnc_hint_tensor_auto_backward_from_inputs; diff --git a/lib/nnc/cmd/convolution/gpu/ccv_nnc_conv_transpose_gpu_cudnn.cu b/lib/nnc/cmd/convolution/gpu/ccv_nnc_conv_transpose_gpu_cudnn.cu new file mode 100644 
index 000000000..21dbdac76 --- /dev/null +++ b/lib/nnc/cmd/convolution/gpu/ccv_nnc_conv_transpose_gpu_cudnn.cu @@ -0,0 +1,201 @@ +extern "C" { +#include +#include +#include +#include +#include +} +#include + +#ifdef HAVE_CUDNN + +enum { + CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_0, // CUDNN_CONVOLUTION_BWD_DATA_ALGO_0 + CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_1, // CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 + CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_FFT, // CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT + CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_FFT_TILING, // CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING + CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_WINOGRAD, // CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD + CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_WINOGRAD_NONFUSED, // CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED + CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_COUNT +}; + +static int _ccv_nnc_conv_transpose_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context) +{ + assert(input_size >= 2); + assert(output_size == 1); + cudnnHandle_t cudnn = ccv_nnc_stream_context_get_cudnn(stream_context); + const ccv_nnc_cudnn_tensor_view_descriptor_t a = ccv_nnc_cudnn_get_tensor_view_descriptor(stream_context, (const ccv_nnc_tensor_view_t*)inputs[0]); + const ccv_nnc_cudnn_filter_descriptor_t w = ccv_nnc_cudnn_get_filter_descriptor(stream_context, (const ccv_nnc_tensor_t*)inputs[1]); + const ccv_nnc_cudnn_tensor_view_descriptor_t b = ccv_nnc_cudnn_get_tensor_view_descriptor(stream_context, (const ccv_nnc_tensor_view_t*)outputs[0]); + const int is_w_nhwc = inputs[1]->info.format == CCV_TENSOR_FORMAT_NHWC; + const int w_datatype = inputs[1]->info.datatype; + const ccv_nnc_cudnn_convolution_descriptor_t conv = ccv_nnc_cudnn_get_convolution_descriptor(stream_context, cmd.info, hint, (is_w_nhwc && w_datatype == CCV_16F) ? CCV_32F : w_datatype); + cudnnSetConvolutionGroupCount(conv.descriptor, cmd.info.convolution_transpose.groups); + + cudnnConvolutionBwdDataAlgo_t data_algo; + const int data_algorithm = cmd.algorithm < 0 ? 
-1 : cmd.algorithm; + switch (data_algorithm) + { + case CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_0: + data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0; + break; + case CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_1: + data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + break; + case CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_FFT: + data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT; + break; + case CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_FFT_TILING: + data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING; + break; + case CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_WINOGRAD: + data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD; + break; + case CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_WINOGRAD_NONFUSED: + data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED; + break; + default: // -1: Using preferences to find a suitable algorithm +#if CUDNN_VERSION >= 7000 + int data_algo_count; + cudnnConvolutionBwdDataAlgoPerf_t data_perf; + CUDNN_ENFORCE(cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnn, w.descriptor, a.descriptor, conv.descriptor, b.descriptor, 1, &data_algo_count, &data_perf)); + assert(data_algo_count > 0); + data_algo = data_perf.algo; +#else + CUDNN_ENFORCE(cudnnGetConvolutionBackwardDataAlgorithm(cudnn, w.descriptor, a.descriptor, conv.descriptor, b.descriptor, CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, 0, &data_algo)); +#endif + } + size_t workspace_size = 0; + CUDNN_ENFORCE(cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn, w.descriptor, a.descriptor, conv.descriptor, b.descriptor, data_algo, &workspace_size)); + void* workspace = 0; + void* weight_data = w.data.u8; + if (CCV_GET_DATA_TYPE(inputs[1]->info.datatype) == CCV_QX) + { + ccv_nnc_tensor_param_t weight_params = inputs[1]->info; + const size_t count = ccv_nnc_tensor_count(weight_params); + const int palette_datatype = (weight_params.datatype & 0xff) << 12; + const int qbits = (weight_params.datatype & 0xf00) >> 8; + const int number_in_blocks = weight_params.reserved; + ccv_nnc_tensor_param_t depalettize_weight_params = weight_params; + depalettize_weight_params.datatype = palette_datatype; + depalettize_weight_params.reserved = 0; + const size_t data_size = ccv_nnc_tensor_data_size(depalettize_weight_params); + workspace_size = ((ssize_t)workspace_size + 1023) & -1024; // Somehow the workspace size is not padded. We need to pad it for weight_data to be aligned.
+ workspace = ccv_nnc_stream_context_get_workspace(stream_context, workspace_size + data_size, CCV_TENSOR_GPU_MEMORY); + weight_data = (uint8_t*)workspace + workspace_size; + ccv_nnc_compat_depalettize(w.data.u8, palette_datatype, ccv_nnc_tensor_data_size_without_padding(weight_params), qbits, number_in_blocks, weight_data, count, stream_context); + if (workspace_size == 0) + workspace = 0; + } else { + // TODO: If error, return OOM + if (workspace_size) + workspace = ccv_nnc_stream_context_get_workspace(stream_context, workspace_size, CCV_TENSOR_GPU_MEMORY); + } + static const float one = 1, zero = 0; + CUDNN_ENFORCE(cudnnConvolutionBackwardData(cudnn, &one, w.descriptor, weight_data, a.descriptor, a.data.u8, conv.descriptor, data_algo, workspace, workspace_size, &zero, b.descriptor, b.data.u8)); + if (input_size > 2 && inputs[2]) + { + const ccv_nnc_cudnn_tensor_view_descriptor_t bias = ccv_nnc_cudnn_get_tensor_view_descriptor(stream_context, (const ccv_nnc_tensor_view_t*)inputs[2]); + CUDNN_ENFORCE(cudnnAddTensor(cudnn, &one, bias.descriptor, bias.data.u8, &one, b.descriptor, b.data.u8)); + ccv_nnc_cudnn_deinit_tensor_view_descriptor(bias); + } + ccv_nnc_cudnn_deinit_tensor_view_descriptor(a); + ccv_nnc_cudnn_deinit_filter_descriptor(w); + ccv_nnc_cudnn_deinit_tensor_view_descriptor(b); + ccv_nnc_cudnn_deinit_convolution_descriptor(conv); + return CCV_NNC_EXEC_SUCCESS; +} + +static int _ccv_nnc_conv_transpose_forw_autotune(const ccv_nnc_cmd_t cmd, size_t max_workspace_size, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context) +{ + assert(input_size >= 2); + assert(output_size == 1); + cudnnHandle_t cudnn = ccv_nnc_stream_context_get_cudnn(stream_context); + void* workmem = ccv_nnc_stream_context_get_workspace(stream_context, max_workspace_size, CCV_TENSOR_GPU_MEMORY); + if (max_workspace_size && !workmem) + return -1; + const ccv_nnc_cudnn_tensor_view_descriptor_t a = ccv_nnc_cudnn_get_tensor_view_descriptor(stream_context, (const ccv_nnc_tensor_view_t*)inputs[0]); + const ccv_nnc_cudnn_filter_descriptor_t w = ccv_nnc_cudnn_get_filter_descriptor(stream_context, (const ccv_nnc_tensor_t*)inputs[1]); + const ccv_nnc_cudnn_tensor_view_descriptor_t b = ccv_nnc_cudnn_get_tensor_view_descriptor(stream_context, (const ccv_nnc_tensor_view_t*)outputs[0]); + const int is_w_nhwc = inputs[1]->info.format == CCV_TENSOR_FORMAT_NHWC; + const int w_datatype = inputs[1]->info.datatype; + const ccv_nnc_cudnn_convolution_descriptor_t conv = ccv_nnc_cudnn_get_convolution_descriptor(stream_context, cmd.info, hint, (is_w_nhwc && w_datatype == CCV_16F) ? CCV_32F : w_datatype); + cudnnSetConvolutionGroupCount(conv.descriptor, cmd.info.convolution_transpose.groups); + int count = 0; + cudnnConvolutionBwdDataAlgoPerf_t data_perfs[CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_COUNT]; + void* weight_data = w.data.u8; + if (CCV_GET_DATA_TYPE(inputs[1]->info.datatype) == CCV_QX) + { + ccv_nnc_tensor_param_t weight_params = inputs[1]->info; + const int palette_datatype = (weight_params.datatype & 0xff) << 12; + ccv_nnc_tensor_param_t depalettize_weight_params = weight_params; + depalettize_weight_params.datatype = palette_datatype; + depalettize_weight_params.reserved = 0; + const size_t data_size = ccv_nnc_tensor_data_size(depalettize_weight_params); + max_workspace_size = ((ssize_t)max_workspace_size + 1023) & -1024; // Somehow the workspace size is not padded.
We need to pad it for weight_data to be aligned. + workmem = ccv_nnc_stream_context_get_workspace(stream_context, max_workspace_size + data_size, CCV_TENSOR_GPU_MEMORY); + weight_data = (uint8_t*)workmem + max_workspace_size; + if (max_workspace_size == 0) + workmem = 0; + } + CUDNN_ENFORCE(cudnnFindConvolutionBackwardDataAlgorithmEx(cudnn, w.descriptor, weight_data, a.descriptor, a.data.u8, conv.descriptor, b.descriptor, b.data.u8, CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_COUNT, &count, data_perfs, workmem, max_workspace_size)); + int i; + cudnnConvolutionBwdDataAlgo_t data_algorithm = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; + for(i = 0; i < count; i++) + if ((size_t)data_perfs[i].memory <= max_workspace_size && data_perfs[i].status == CUDNN_STATUS_SUCCESS) + { + data_algorithm = data_perfs[i].algo; + break; + } + ccv_nnc_cudnn_deinit_tensor_view_descriptor(a); + ccv_nnc_cudnn_deinit_filter_descriptor(w); + ccv_nnc_cudnn_deinit_tensor_view_descriptor(b); + ccv_nnc_cudnn_deinit_convolution_descriptor(conv); + switch (data_algorithm) + { + case CUDNN_CONVOLUTION_BWD_DATA_ALGO_0: + return CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_0; + case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1: + return CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_1; + case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT: + return CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_FFT; + case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING: + return CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_FFT_TILING; + case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD: + return CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_WINOGRAD; + case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED: + return CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_WINOGRAD_NONFUSED; + case CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT: + break; + } + return -1; // Return the most efficient algorithm, return -1 if cannot find one. 
+} + +static int _ccv_nnc_conv_transpose_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context) +{ + return CCV_NNC_EXEC_INVALID; +} + +#endif + +REGISTER_COMMAND_BACKEND(CCV_NNC_CONVOLUTION_TRANSPOSE_FORWARD, CCV_NNC_BACKEND_GPU_CUDNN)(ccv_nnc_cmd_backend_registry_t* const registry) +{ +#ifdef HAVE_CUDNN + registry->tensor_formats = CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_NHWC; + registry->tensor_datatypes = CCV_32F | CCV_16F | CCV_QX; + registry->tensor_memory = CCV_TENSOR_GPU_MEMORY; + registry->algorithms = CCV_NNC_CMD_CUDNN_CONV_BWD_DATA_ALGO_COUNT; + registry->exec = _ccv_nnc_conv_transpose_forw; + registry->autotune = _ccv_nnc_conv_transpose_forw_autotune; +#endif +} + +REGISTER_COMMAND_BACKEND(CCV_NNC_CONVOLUTION_TRANSPOSE_BACKWARD, CCV_NNC_BACKEND_GPU_CUDNN)(ccv_nnc_cmd_backend_registry_t* const registry) +{ +#ifdef HAVE_CUDNN + registry->tensor_formats = CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_NHWC; + registry->tensor_datatypes = CCV_32F | CCV_16F | CCV_QX; + registry->tensor_memory = CCV_TENSOR_GPU_MEMORY; + registry->exec = _ccv_nnc_conv_transpose_back; +#endif +} diff --git a/test/int/nnc/cudnn.tests.c b/test/int/nnc/cudnn.tests.c index 24fa55dfb..f193426e5 100644 --- a/test/int/nnc/cudnn.tests.c +++ b/test/int/nnc/cudnn.tests.c @@ -4789,4 +4789,174 @@ TEST_CASE("broadcasting semantics for mul backward (no input grad) for a") ccv_nnc_tensor_free(gdb); } +TEST_CASE("cudnn forward convolution transpose") +{ + GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_CONVOLUTION_TRANSPOSE_FORWARD, CCV_NNC_BACKEND_GPU_CUDNN)); + ccv_nnc_tensor_t* a = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, OUTPUT_SIZE, OUTPUT_SIZE, OUTPUT_DIM), 0); + ccv_nnc_tensor_t* b = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0); + ccv_nnc_cmd_t cmd = CMD_CONVOLUTION_TRANSPOSE_FORWARD(1, INPUT_DIM, 0, KERNEL_SIZE, KERNEL_SIZE, OUTPUT_DIM); + cmd.backend = CCV_NNC_BACKEND_CPU_REF; + assert(cmd.backend >= 0); + ccv_nnc_hint_t hint = ccv_nnc_hint_auto(cmd.info, b->info, a->info); + ccv_nnc_tensor_t* w = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM), 0); + ccv_nnc_tensor_t* bias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, INPUT_DIM), 0); + // configure the inlets. + dsfmt_t dsfmt; + dsfmt_init_gen_rand(&dsfmt, 0); + int i; + for (i = 0; i < INPUT_DIM * KERNEL_SIZE * KERNEL_SIZE * OUTPUT_DIM; i++) + w->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) / (INPUT_DIM * KERNEL_SIZE * KERNEL_SIZE); + for (i = 0; i < OUTPUT_SIZE * OUTPUT_SIZE * OUTPUT_DIM * ccv_max(1, BATCH_SIZE); i++) + a->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); + for (i = 0; i < INPUT_DIM; i++) + bias->data.f32[i] = (float)i / INPUT_DIM; + // Copy generated matrix values over to GPU. 
+ ccv_nnc_tensor_t* ga = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, BATCH_SIZE, OUTPUT_SIZE, OUTPUT_SIZE, OUTPUT_DIM), 0); + ccv_nnc_tensor_t* gw = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM), 0); + ccv_nnc_tensor_t* gbias = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, INPUT_DIM), 0); + ccv_nnc_cmd_t move = CMD_DATA_TRANSFER_FORWARD(); + move.backend = CCV_NNC_BACKEND_GPU_REF; + assert(move.backend >= 0); + ccv_nnc_cmd_exec(move, ccv_nnc_no_hint, 0, TENSOR_LIST(a, w, bias), TENSOR_LIST(ga, gw, gbias), 0); + ccv_nnc_cmd_exec(cmd, hint, 0, TENSOR_LIST(a, w, bias), TENSOR_LIST(b), 0); + ccv_nnc_tensor_t* gc = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0); + + ccv_nnc_stream_context_t* stream_context = ccv_nnc_stream_context_new(CCV_STREAM_CONTEXT_GPU); + cmd.backend = CCV_NNC_BACKEND_GPU_CUDNN; + assert(cmd.backend >= 0); + cmd.algorithm = -1; + cmd = ccv_nnc_cmd_autotune(cmd, 1 * 1024 * 1024 * 1024, hint, 0, TENSOR_LIST(ga, gw, gbias), TENSOR_LIST(gc), stream_context); + assert(CCV_NNC_EXEC_SUCCESS == ccv_nnc_cmd_exec(cmd, hint, 0, TENSOR_LIST(ga, gw, gbias), TENSOR_LIST(gc), stream_context)); + ccv_nnc_stream_context_wait(stream_context); + ccv_nnc_stream_context_free(stream_context); + ccv_nnc_tensor_t* c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0); + ccv_nnc_cmd_exec(move, ccv_nnc_no_hint, 0, TENSOR_LIST(gc), TENSOR_LIST(c), 0); + REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, b->data.f32, c->data.f32, BATCH_SIZE * INPUT_DIM * INPUT_SIZE * INPUT_SIZE, 1e-4, "output from cudnn should match from CPU"); + ccv_nnc_tensor_free(c); + ccv_nnc_tensor_free(gc); + ccv_nnc_tensor_free(bias); + ccv_nnc_tensor_free(w); + ccv_nnc_tensor_free(b); + ccv_nnc_tensor_free(a); + ccv_nnc_tensor_free(gbias); + ccv_nnc_tensor_free(gw); + ccv_nnc_tensor_free(ga); +} + +TEST_CASE("cudnn forward convolution transpose in nchw format") +{ + GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_CONVOLUTION_TRANSPOSE_FORWARD, CCV_NNC_BACKEND_GPU_CUDNN)); + ccv_nnc_tensor_t* a = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, BATCH_SIZE, OUTPUT_DIM, OUTPUT_SIZE, OUTPUT_SIZE), 0); + ccv_nnc_tensor_t* b = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, BATCH_SIZE, INPUT_DIM, INPUT_SIZE, INPUT_SIZE), 0); + ccv_nnc_cmd_t cmd = CMD_CONVOLUTION_TRANSPOSE_FORWARD(1, INPUT_DIM, 0, KERNEL_SIZE, KERNEL_SIZE, OUTPUT_DIM); + cmd.backend = CCV_NNC_BACKEND_CPU_REF; + assert(cmd.backend >= 0); + ccv_nnc_hint_t hint = ccv_nnc_hint_auto(cmd.info, b->info, a->info); + ccv_nnc_tensor_t* w = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, OUTPUT_DIM, INPUT_DIM, KERNEL_SIZE, KERNEL_SIZE), 0); + ccv_nnc_tensor_t* bias = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, INPUT_DIM), 0); + // configure the inlets. + dsfmt_t dsfmt; + dsfmt_init_gen_rand(&dsfmt, 0); + int i; + for (i = 0; i < INPUT_DIM * KERNEL_SIZE * KERNEL_SIZE * OUTPUT_DIM; i++) + w->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) / (INPUT_DIM * KERNEL_SIZE * KERNEL_SIZE); + for (i = 0; i < OUTPUT_SIZE * OUTPUT_SIZE * OUTPUT_DIM * ccv_max(1, BATCH_SIZE); i++) + a->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); + for (i = 0; i < INPUT_DIM; i++) + bias->data.f32[i] = (float)i / INPUT_DIM; + // Copy generated matrix values over to GPU. 
+ ccv_nnc_tensor_t* ga = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 32F, BATCH_SIZE, OUTPUT_DIM, OUTPUT_SIZE, OUTPUT_SIZE), 0); + ccv_nnc_tensor_t* gw = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 32F, OUTPUT_DIM, INPUT_DIM, KERNEL_SIZE, KERNEL_SIZE), 0); + ccv_nnc_tensor_t* gbias = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 32F, INPUT_DIM), 0); + ccv_nnc_cmd_t move = CMD_DATA_TRANSFER_FORWARD(); + move.backend = CCV_NNC_BACKEND_GPU_REF; + assert(move.backend >= 0); + ccv_nnc_cmd_exec(move, ccv_nnc_no_hint, 0, TENSOR_LIST(a, w, bias), TENSOR_LIST(ga, gw, gbias), 0); + ccv_nnc_cmd_exec(cmd, hint, 0, TENSOR_LIST(a, w, bias), TENSOR_LIST(b), 0); + ccv_nnc_tensor_t* gc = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 32F, BATCH_SIZE, INPUT_DIM, INPUT_SIZE, INPUT_SIZE), 0); + + ccv_nnc_cmd_t transform = CMD_FORMAT_TRANSFORM_FORWARD(); + transform.backend = CCV_NNC_BACKEND_GPU_CUDNN; + assert(transform.backend >= 0); + cmd.backend = CCV_NNC_BACKEND_GPU_CUDNN; + assert(cmd.backend >= 0); + cmd.algorithm = -1; + cmd = ccv_nnc_cmd_autotune(cmd, 1 * 1024 * 1024 * 1024, hint, 0, TENSOR_LIST(ga, gw, gbias), TENSOR_LIST(gc), 0); + assert(CCV_NNC_EXEC_SUCCESS == ccv_nnc_cmd_exec(cmd, hint, 0, TENSOR_LIST(ga, gw, gbias), TENSOR_LIST(gc), 0)); + ccv_nnc_tensor_t* c = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, BATCH_SIZE, INPUT_DIM, INPUT_SIZE, INPUT_SIZE), 0); + ccv_nnc_cmd_exec(move, ccv_nnc_no_hint, 0, TENSOR_LIST(gc), TENSOR_LIST(c), 0); + REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, b->data.f32, c->data.f32, BATCH_SIZE * INPUT_DIM * INPUT_SIZE * INPUT_SIZE, 1e-5, "output from cudnn should match from CPU"); + ccv_nnc_tensor_free(c); + ccv_nnc_tensor_free(gc); + ccv_nnc_tensor_free(bias); + ccv_nnc_tensor_free(w); + ccv_nnc_tensor_free(b); + ccv_nnc_tensor_free(a); + ccv_nnc_tensor_free(gbias); + ccv_nnc_tensor_free(gw); + ccv_nnc_tensor_free(ga); +} + +TEST_CASE("cudnn forward convolution transpose in half precision") +{ + GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_CONVOLUTION_TRANSPOSE_FORWARD, CCV_NNC_BACKEND_GPU_CUDNN)); + ccv_nnc_tensor_t* a = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, OUTPUT_SIZE, OUTPUT_SIZE, OUTPUT_DIM), 0); + ccv_nnc_tensor_t* b = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0); + ccv_nnc_cmd_t cmd = CMD_CONVOLUTION_TRANSPOSE_FORWARD(1, INPUT_DIM, 0, KERNEL_SIZE, KERNEL_SIZE, OUTPUT_DIM); + cmd.backend = CCV_NNC_BACKEND_CPU_REF; + assert(cmd.backend >= 0); + ccv_nnc_hint_t hint = ccv_nnc_hint_auto(cmd.info, b->info, a->info); + ccv_nnc_tensor_t* w = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM), 0); + ccv_nnc_tensor_t* bias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, INPUT_DIM), 0); + // configure the inlets. 
+ dsfmt_t dsfmt; + dsfmt_init_gen_rand(&dsfmt, 0); + int i; + for (i = 0; i < INPUT_DIM * KERNEL_SIZE * KERNEL_SIZE * OUTPUT_DIM; i++) + w->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) / (INPUT_DIM * KERNEL_SIZE * KERNEL_SIZE); + for (i = 0; i < OUTPUT_SIZE * OUTPUT_SIZE * OUTPUT_DIM * ccv_max(1, BATCH_SIZE); i++) + a->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); + for (i = 0; i < INPUT_DIM; i++) + bias->data.f32[i] = (float)i / INPUT_DIM; + ccv_nnc_tensor_t* a1 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, BATCH_SIZE, OUTPUT_SIZE, OUTPUT_SIZE, OUTPUT_DIM), 0); + ccv_nnc_tensor_t* w1 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM), 0); + ccv_nnc_tensor_t* bias1 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, INPUT_DIM), 0); + ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(a, w, bias), TENSOR_LIST(a1, w1, bias1), 0); + // Copy generated matrix values over to GPU. + ccv_nnc_tensor_t* ga = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, BATCH_SIZE, OUTPUT_SIZE, OUTPUT_SIZE, OUTPUT_DIM), 0); + ccv_nnc_tensor_t* gw = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM), 0); + ccv_nnc_tensor_t* gbias = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, INPUT_DIM), 0); + ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(a1, w1, bias1), TENSOR_LIST(ga, gw, gbias), 0); + ccv_nnc_cmd_exec(cmd, hint, 0, TENSOR_LIST(a, w, bias), TENSOR_LIST(b), 0); + ccv_nnc_tensor_t* gc = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0); + + ccv_nnc_stream_context_t* stream_context = ccv_nnc_stream_context_new(CCV_STREAM_CONTEXT_GPU); + + cmd.backend = CCV_NNC_BACKEND_GPU_CUDNN; + assert(cmd.backend >= 0); + cmd.algorithm = -1; + cmd = ccv_nnc_cmd_autotune(cmd, 512 * 1024 * 1024, hint, 0, TENSOR_LIST(ga, gw, gbias), TENSOR_LIST(gc), stream_context); + assert(CCV_NNC_EXEC_SUCCESS == ccv_nnc_cmd_exec(cmd, hint, 0, TENSOR_LIST(ga, gw, gbias), TENSOR_LIST(gc), stream_context)); + ccv_nnc_stream_context_wait(stream_context); + ccv_nnc_stream_context_free(stream_context); + ccv_nnc_tensor_t* c1 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0); + ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(gc), TENSOR_LIST(c1), 0); + ccv_nnc_tensor_t* c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0); + ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(c1), TENSOR_LIST(c), 0); + REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, b->data.f32, c->data.f32, BATCH_SIZE * INPUT_DIM * INPUT_SIZE * INPUT_SIZE, 5e-3, "output from cudnn should match from CPU"); + ccv_nnc_tensor_free(c); + ccv_nnc_tensor_free(gc); + ccv_nnc_tensor_free(bias); + ccv_nnc_tensor_free(w); + ccv_nnc_tensor_free(b); + ccv_nnc_tensor_free(a); + ccv_nnc_tensor_free(c1); + ccv_nnc_tensor_free(bias1); + ccv_nnc_tensor_free(w1); + ccv_nnc_tensor_free(a1); + ccv_nnc_tensor_free(gbias); + ccv_nnc_tensor_free(gw); + ccv_nnc_tensor_free(ga); +} + #include "case_main.h"
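
Note on the output_padding field whose declaration is reordered in lib/nnc/ccv_nnc.h above: a strided convolution maps more than one input extent to the same output extent, so inverting it is ambiguous, and output_padding selects which candidate extent the transposed convolution reconstructs. That is also why this backend dispatches the forward pass through cudnnConvolutionBackwardData: the transposed convolution's output plays the role of the data gradient of an ordinary convolution with the same filter and convolution descriptor. Below is a minimal sketch of the size relationship, assuming the conventional ConvTranspose formula; conv_transpose_extent and its parameters are illustrative names only, not part of the nnc API.

#include <stdio.h>

/* Hypothetical helper (not part of nnc): one spatial extent of a transposed-convolution
 * output, following the usual ConvTranspose size relationship. */
static int conv_transpose_extent(const int in, const int kernel, const int stride, const int pad_begin, const int pad_end, const int dilation, const int output_padding)
{
	return (in - 1) * stride - (pad_begin + pad_end) + dilation * (kernel - 1) + 1 + output_padding;
}

int main(void)
{
	/* A stride-2, 3x3, pad-1 convolution maps both 27 and 28 down to 14, so inverting
	 * from 14 is ambiguous; output_padding picks which extent to reconstruct. */
	printf("%d\n", conv_transpose_extent(14, 3, 2, 1, 1, 1, 0)); /* prints 27 */
	printf("%d\n", conv_transpose_extent(14, 3, 2, 1, 1, 1, 1)); /* prints 28 */
	return 0;
}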