Skip to content

Commit

Permalink
Some progress on build script; added multi-cuda install script.
Browse files Browse the repository at this point in the history
  • Loading branch information
TimDettmers committed Jul 26, 2022
1 parent 1e88edd commit 9268dc9
Show file tree
Hide file tree
Showing 8 changed files with 192 additions and 99 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,6 @@ Features:
- Added 8-bit Linear layers with 8-bit Params that perform memory efficient inference with an option for 8-bit mixed precision matrix decomposition for inference without performance degradation
- Added quantization methods for "fake" quantization as well as optimized kernels vector-wise quantization and equalization as well as optimized cuBLASLt transformations
- CPU only build now available (Thank you, @mryab)

Deprecated:
- Pre-compiled release for CUDA 9.2, 10.0, 10.2 no longer available
7 changes: 4 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@ COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
COMPUTE_CAPABILITY := -gencode arch=compute_75,code=sm_75 # Volta

# CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
CC_CUDA92 := -gencode arch=compute_30,code=sm_30

# Later versions of CUDA support the new architectures
CC_CUDA10x := -gencode arch=compute_30,code=sm_30
CC_CUDA10x += -gencode arch=compute_75,code=sm_75
CC_CUDA10x := -gencode arch=compute_75,code=sm_75

CC_CUDA110 := -gencode arch=compute_75,code=sm_75
CC_CUDA110 += -gencode arch=compute_80,code=sm_80
Expand All @@ -43,12 +44,12 @@ CC_CUDA11x += -gencode arch=compute_80,code=sm_80
CC_CUDA11x += -gencode arch=compute_86,code=sm_86

all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)

cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)

Expand Down
17 changes: 8 additions & 9 deletions csrc/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2166,7 +2166,6 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
__shared__ char smem_data[32*33*ITEMS_PER_THREAD];
char local_data[ITEMS_PER_THREAD];
typedef cub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange;
__shared__ typename BlockExchange::TempStorage temp_storage;

// we load row after row from the base_position
// Load data row by row
Expand Down Expand Up @@ -2446,7 +2445,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
#define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256
template <typename T, int SPMM_ITEMS, int BITS>
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
{

// 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block
Expand Down Expand Up @@ -2500,7 +2499,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
{
for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x)
if((idx_col_B+i-local_idx_col_B_offset) < colsB)
smem_dequant_stats[i] = __ldg(&dequant_stats[idx_col_B+i-local_idx_col_B_offset]);
smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset];

__syncthreads();
}
Expand Down Expand Up @@ -2596,12 +2595,12 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
// TEMPLATE DEFINITIONS
//==============================================================

template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);

template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
Expand Down
2 changes: 1 addition & 1 deletion csrc/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileCl
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);


template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);

template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(
int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
Expand Down
22 changes: 21 additions & 1 deletion csrc/ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ int roundoff(int v, int d) {
}


#ifdef NO_CUBLASLT
#else
template<int ORDER> cublasLtOrder_t get_order()
{
switch(ORDER)
Expand All @@ -266,14 +268,19 @@ template<int ORDER> cublasLtOrder_t get_order()
case COL_AMPERE:
return CUBLASLT_ORDER_COL32_2R_4R4;
break;
default:
break;
}

return CUBLASLT_ORDER_ROW;
}

template cublasLtOrder_t get_order<ROW>();
template cublasLtOrder_t get_order<COL>();
template cublasLtOrder_t get_order<COL32>();
template cublasLtOrder_t get_order<COL_TURING>();
template cublasLtOrder_t get_order<COL_AMPERE>();
#endif


template<int ORDER> int get_leading_dim(int dim1, int dim2)
Expand All @@ -297,6 +304,9 @@ template<int ORDER> int get_leading_dim(int dim1, int dim2)
// 32*32 tiles
return 32*roundoff(dim1, 32);
break;
default:
return 0;
break;
}
}

Expand All @@ -306,7 +316,8 @@ template int get_leading_dim<COL32>(int dim1, int dim2);

template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
{

#ifdef NO_CUBLASLT
#else
cublasLtOrder_t orderA = get_order<SRC>();
cublasLtOrder_t orderOut = get_order<TARGET>();
int ldA = get_leading_dim<SRC>(dim1, dim2);
Expand Down Expand Up @@ -345,6 +356,7 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void trans
if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
#endif
}

template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
Expand All @@ -358,6 +370,9 @@ template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandl

template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{
#ifdef NO_CUBLASLT
return 0;
#else
int has_error = 0;
cublasLtMatmulDesc_t matmulDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
Expand Down Expand Up @@ -412,6 +427,7 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle
printf("error detected");

return has_error;
#endif
}

int fill_up_to_nearest_multiple(int value, int multiple)
Expand Down Expand Up @@ -523,6 +539,9 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{

#ifdef NO_CUBLASLT
#else

cusparseSpMatDescr_t descA;
cusparseDnMatDescr_t descB, descC;

Expand Down Expand Up @@ -569,6 +588,7 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val
CHECK_CUSPARSE( cusparseDestroyDnMat(descB) );
CHECK_CUSPARSE( cusparseDestroyDnMat(descC) );
CUDA_CHECK_RETURN( cudaFree(dBuffer) );
#endif
}

template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
Expand Down
77 changes: 77 additions & 0 deletions cuda_install.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
URL92=https://developer.nvidia.com/compute/cuda/9.2/Prod2/local_installers/cuda_9.2.148_396.37_linux
URL100=https://developer.nvidia.com/compute/cuda/10.0/Prod/local_installers/cuda_10.0.130_410.48_linux
URL101=https://developer.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.105_418.39_linux.run
URL102=https://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda_10.2.89_440.33.01_linux.run
URL110=https://developer.download.nvidia.com/compute/cuda/11.0.3/local_installers/cuda_11.0.3_450.51.06_linux.run
URL111=https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run
URL112=https://developer.download.nvidia.com/compute/cuda/11.2.2/local_installers/cuda_11.2.2_460.32.03_linux.run
URL113=https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.19.01_linux.run
URL114=https://developer.download.nvidia.com/compute/cuda/11.4.4/local_installers/cuda_11.4.4_470.82.01_linux.run
URL115=https://developer.download.nvidia.com/compute/cuda/11.5.2/local_installers/cuda_11.5.2_495.29.05_linux.run
URL116=https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installers/cuda_11.6.2_510.47.03_linux.run
URL117=https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run


CUDA_VERSION=$1
BASE_PATH=$2

if [[ -n "$CUDA_VERSION" ]]; then
if [[ "$CUDA_VERSION" -eq "92" ]]; then
URL=$URL92
FOLDER=cuda-9.2
elif [[ "$CUDA_VERSION" -eq "100" ]]; then
URL=$URL100
FOLDER=cuda-10.0
elif [[ "$CUDA_VERSION" -eq "101" ]]; then
URL=$URL101
FOLDER=cuda-10.1
elif [[ "$CUDA_VERSION" -eq "102" ]]; then
URL=$URL102
FOLDER=cuda-10.2
elif [[ "$CUDA_VERSION" -eq "110" ]]; then
URL=$URL110
FOLDER=cuda-11.0
elif [[ "$CUDA_VERSION" -eq "111" ]]; then
URL=$URL111
FOLDER=cuda-11.1
elif [[ "$CUDA_VERSION" -eq "112" ]]; then
URL=$URL112
FOLDER=cuda-11.2
elif [[ "$CUDA_VERSION" -eq "113" ]]; then
URL=$URL113
FOLDER=cuda-11.3
elif [[ "$CUDA_VERSION" -eq "114" ]]; then
URL=$URL114
FOLDER=cuda-11.4
elif [[ "$CUDA_VERSION" -eq "115" ]]; then
URL=$URL115
FOLDER=cuda-11.5
elif [[ "$CUDA_VERSION" -eq "116" ]]; then
URL=$URL116
FOLDER=cuda-11.6
elif [[ "$CUDA_VERSION" -eq "117" ]]; then
URL=$URL117
FOLDER=cuda-11.7
else
echo "argument error: No cuda version passed as input. Choose among: {111, 115}"
fi
else
echo "argument error: No cuda version passed as input. Choose among: {111, 115}"
fi

FILE=$(basename $URL)

if [[ -n "$CUDA_VERSION" ]]; then
echo $URL
echo $FILE
wget $URL
bash $FILE --no-drm --no-man-page --override --installpath=~/local --librarypath=$BASE_PATH/lib --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent
echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64/" >> ~/.bashrc
echo "export PATH=$PATH:$BASE_PATH/$FOLDER/bin/" >> ~/.bashrc
source ~/.bashrc
else
echo ""
fi



38 changes: 0 additions & 38 deletions cuda_install_111.sh

This file was deleted.

Loading

0 comments on commit 9268dc9

Please sign in to comment.