diff --git a/src/pn.cu b/src/pn.cu
index d32db9a..4c50acf 100644
--- a/src/pn.cu
+++ b/src/pn.cu
@@ -19,7 +19,6 @@ cudaError_t generate_permutation_polynomials(const generate_permutation_polynomi
   cudaMemPool_t pool = cfg.mem_pool;
   cudaStream_t stream = cfg.stream;
   unsigned int columns_count = cfg.columns_count;
-  assert(columns_count == 4);
   unsigned int log_rows_count = cfg.log_rows_count;
   const unsigned cells_count = columns_count << log_rows_count;
   const unsigned bits_count = log2_ceiling(columns_count) + log_rows_count;
@@ -31,7 +30,19 @@ cudaError_t generate_permutation_polynomials(const generate_permutation_polynomi
   unsigned_ints sorted_values;
 
   HANDLE_CUDA_ERROR(allocate(unsorted_keys, cells_count, pool, stream));
-  HANDLE_CUDA_ERROR(transpose<4>(unsorted_keys, cfg.indexes, log_rows_count, stream));
+  // transpose takes the column count as a template argument, so the runtime value is dispatched to an explicit instantiation.
+  switch (columns_count) {
+  case 3:
+    HANDLE_CUDA_ERROR(transpose<3>(unsorted_keys, cfg.indexes, log_rows_count, stream));
+    break;
+  case 4:
+    HANDLE_CUDA_ERROR(transpose<4>(unsorted_keys, cfg.indexes, log_rows_count, stream));
+    break;
+  default:
+    // assert() is compiled out under NDEBUG; also return an error so release builds never continue with unsorted_keys left unwritten.
+    assert(columns_count == 3 || columns_count == 4);
+    return cudaErrorInvalidValue;
+  }
   HANDLE_CUDA_ERROR(allocate(unsorted_values, cells_count, pool, stream));
   HANDLE_CUDA_ERROR(fill_transposed_range(unsorted_values, columns_count, log_rows_count, stream));
   HANDLE_CUDA_ERROR(allocate(sorted_keys, cells_count, pool, stream));
diff --git a/src/pn_kernels.cu b/src/pn_kernels.cu
index 77d2fb4..5c51510 100644
--- a/src/pn_kernels.cu
+++ b/src/pn_kernels.cu
@@ -44,6 +44,7 @@ template cudaError_t transpose(unsigned *dst, const unsigne
   return cudaGetLastError();
 }
 
+template cudaError_t transpose<3>(unsigned *dst, const unsigned *src, unsigned log_rows_count, cudaStream_t stream);
 template cudaError_t transpose<4>(unsigned *dst, const unsigned *src, unsigned log_rows_count, cudaStream_t stream);
 
 #undef BLOCK_SIZE