Skip to content

Commit

Permalink
Fix a compilation issue on cmul.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Dec 31, 2023
1 parent 4daeabd commit 75f4fd9
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions lib/nnc/cmd/blas/gpu/ccv_nnc_cmul_gpu_ref.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ __global__ void _ccv_nnc_cmul_kernel(const size_t count, const NUM1* const a, co
}

template<typename NUM1, typename NUM2, typename NUM3>
__global__ void _ccv_nnc_cmul_kernel_4d_0(const int astride2, const int astride1, const int astride0, const int bstride2, const int bstride1, const int bstride0, const int cstride2, const int cstride1, const int cstride0, const int dim1, const int dim0, const NUM1* const a, const NUM2* const b, NUM3* const c)
__global__ void _ccv_nnc_cmul_kernel_4d_0(const int astride2, const int astride1, const int astride0, const int bstride2, const int bstride1, const int bstride0, const int cstride2, const int cstride1, const int cstride0, const int dim2, const int dim1, const int dim0, const NUM1* const a, const NUM2* const b, NUM3* const c)
{
const int z = blockIdx.z * blockDim.z + threadIdx.z;
const int y = blockIdx.y * blockDim.y + threadIdx.y;
Expand Down Expand Up @@ -161,13 +161,13 @@ static int _ccv_nnc_cmul_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
{
if (a->info.datatype == CCV_32F && c->info.datatype == CCV_32F)
{
_ccv_nnc_cmul_kernel_4d_0<<<dim3((cdim[0] / 2 + 63) / 64, (cdim[1] + 7) / 8, cdim[2] * cdim[3]), dim3(64, 8, 1), 0, stream>>>(astride[3], astride[2], astride[1], bstride[3], bstride[2], bstride[1], cstride[3], cstride[2], cstride[1], cdim[1], cdim[0] / 2, a->data.f32, b->data.f32, c->data.f32);
_ccv_nnc_cmul_kernel_4d_0<<<dim3((cdim[0] / 2 + 63) / 64, (cdim[1] + 7) / 8, cdim[2] * cdim[3]), dim3(64, 8, 1), 0, stream>>>(astride[3], astride[2], astride[1], bstride[3], bstride[2], bstride[1], cstride[3], cstride[2], cstride[1], cdim[2], cdim[1], cdim[0] / 2, a->data.f32, b->data.f32, c->data.f32);
} else if (a->info.datatype == CCV_32F && c->info.datatype == CCV_16F) {
_ccv_nnc_cmul_kernel_4d_0<<<dim3((cdim[0] / 2 + 63) / 64, (cdim[1] + 7) / 8, cdim[2] * cdim[3]), dim3(64, 8, 1), 0, stream>>>(astride[3], astride[2], astride[1], bstride[3], bstride[2], bstride[1], cstride[3], cstride[2], cstride[1], cdim[1], cdim[0] / 2, a->data.f32, b->data.f32, (__half*)c->data.f16);
_ccv_nnc_cmul_kernel_4d_0<<<dim3((cdim[0] / 2 + 63) / 64, (cdim[1] + 7) / 8, cdim[2] * cdim[3]), dim3(64, 8, 1), 0, stream>>>(astride[3], astride[2], astride[1], bstride[3], bstride[2], bstride[1], cstride[3], cstride[2], cstride[1], cdim[2], cdim[1], cdim[0] / 2, a->data.f32, b->data.f32, (__half*)c->data.f16);
} else if (a->info.datatype == CCV_16F && c->info.datatype == CCV_32F) {
_ccv_nnc_cmul_kernel_4d_0<<<dim3((cdim[0] / 2 + 63) / 64, (cdim[1] + 7) / 8, cdim[2] * cdim[3]), dim3(64, 8, 1), 0, stream>>>(astride[3], astride[2], astride[1], bstride[3], bstride[2], bstride[1], cstride[3], cstride[2], cstride[1], cdim[1], cdim[0] / 2, (__half*)a->data.f16, (__half*)b->data.f16, c->data.f32);
_ccv_nnc_cmul_kernel_4d_0<<<dim3((cdim[0] / 2 + 63) / 64, (cdim[1] + 7) / 8, cdim[2] * cdim[3]), dim3(64, 8, 1), 0, stream>>>(astride[3], astride[2], astride[1], bstride[3], bstride[2], bstride[1], cstride[3], cstride[2], cstride[1], cdim[2], cdim[1], cdim[0] / 2, (__half*)a->data.f16, (__half*)b->data.f16, c->data.f32);
} else if (a->info.datatype == CCV_16F && c->info.datatype == CCV_16F) {
_ccv_nnc_cmul_kernel_4d_0<<<dim3((cdim[0] / 2 + 63) / 64, (cdim[1] + 7) / 8, cdim[2] * cdim[3]), dim3(64, 8, 1), 0, stream>>>(astride[3], astride[2], astride[1], bstride[3], bstride[2], bstride[1], cstride[3], cstride[2], cstride[1], cdim[1], cdim[0] / 2, (__half*)a->data.f16, (__half*)b->data.f16, (__half*)c->data.f16);
_ccv_nnc_cmul_kernel_4d_0<<<dim3((cdim[0] / 2 + 63) / 64, (cdim[1] + 7) / 8, cdim[2] * cdim[3]), dim3(64, 8, 1), 0, stream>>>(astride[3], astride[2], astride[1], bstride[3], bstride[2], bstride[1], cstride[3], cstride[2], cstride[1], cdim[2], cdim[1], cdim[0] / 2, (__half*)a->data.f16, (__half*)b->data.f16, (__half*)c->data.f16);
}
} else if (nd == 3) {
if (a->info.datatype == CCV_32F && c->info.datatype == CCV_32F)
Expand Down

0 comments on commit 75f4fd9

Please sign in to comment.