Skip to content

Commit

Permalink
Merge branch 'extract_outliers' into debug
Browse files Browse the repository at this point in the history
  • Loading branch information
TimDettmers committed Aug 4, 2022
2 parents 6101a8f + bd51532 commit cc5b323
Show file tree
Hide file tree
Showing 9 changed files with 235 additions and 55 deletions.
58 changes: 38 additions & 20 deletions bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,30 +203,30 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
# generate outlier index and subB
outlier_idx = torch.unique(coo_tensorA.colidx).long()
state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# do not use pool for 2nd FFN layer
state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
else:
state.idx = outlier_idx
state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()

if state.idx is not None:
# extract outliers
CA[:, state.idx] = 0
CAt[:, state.idx] = 0
subA = A[:, state.idx]
else:
subA = None
#state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
#if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
# # generate outlier index and subB
# outlier_idx = torch.unique(coo_tensorA.colidx).long()
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
# if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# # do not use pool for 2nd FFN layer
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
# else:
# state.idx = outlier_idx
# state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()

#if state.idx is not None:
# # extract outliers
# CA[:, state.idx] = 0
# CAt[:, state.idx] = 0
# subA = A[:, state.idx]
#else:
# subA = None
else:
if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
subA = None

C32A, SA = F.transform(CA, 'col32')

# 2. Quantize B
if state.has_fp16_weights:
Expand All @@ -241,6 +241,23 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
else:
has_grad = False

if coo_tensorA is not None and not state.has_fp16_weights:
# extract outliers

outlier_idx = torch.unique(coo_tensorA.colidx)
state.idx = outlier_idx
#state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
#if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# # do not use pool for 2nd FFN layer
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
#else:
# state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
state.subB = (outliers*state.SCB.view(-1, 1)/127.0).t().contiguous().half()
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()]

shapeB = state.SB[0]

if len(input_shape) == 3:
Expand All @@ -249,11 +266,12 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
output_shape = (input_shape[0], shapeB[0])

# 3. Matmul
C32A, SA = F.transform(CA, 'col32')
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
output = F.mm_dequant(out32, Sout32, SCA, state.SCB)

# 4. Mixed-precision decomposition matmul
if state.threshold > 0.0 and coo_tensorA is not None and subA is not None:
if coo_tensorA is not None and subA is not None:
output += torch.matmul(subA, state.subB)

# 5. Save state
Expand Down
26 changes: 26 additions & 0 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,3 +1435,29 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
x *= SA[1]/127
x +=offset
return x.to(dtype)

def extract_outliers(A, SA, idx):
shapeA = SA[0]
formatA = SA[1]
assert formatA in ['col_turing', 'col_ampere']
assert A.device.type == 'cuda'

out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)

idx_size = ct.c_int32(idx.numel())
rows = ct.c_int32(shapeA[0])
cols = ct.c_int32(shapeA[1])
ptrA = get_ptr(A)
ptrIdx = get_ptr(idx)
ptrOut = get_ptr(out)

if formatA == 'col_turing':
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
elif formatA == 'col_ampere':
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)

return out




78 changes: 72 additions & 6 deletions csrc/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2591,16 +2591,82 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
}
}

template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA)
{
int local_colidx = idx[blockIdx.x];

if(FORMAT==COL_TURING)
{
// TURING FORMAT:
// 8*32 tiles with 4*4 subtiles
// the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements)
// the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
// the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
// the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
// index increases by 32
// columns are grouped in increments of 4, meaning that one has the following rows and columns
// rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
// cols: [0 1 2 3, 0 1 2 4, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...]

// each thread reads 1 element = 1 row
for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
{
int offset_per_col_tile = ((rowsA+7)/8)*32*8;
int tile_offset_rows = (row/8)*32*8;
int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
int offset = 0;
int subtile_col_idx = local_colidx%32;
int subtile_row_idx = row % 8;
if(row % 2 == 1)
offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2);
else
// even
offset += 0 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2);

offset += tile_offset_rows + tile_offset_cols;

char val = A[offset];

int out_idx = (row*idx_size) + blockIdx.x;
out[out_idx] = val;
}
}
else if(FORMAT == COL_AMPERE)
{

for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
{
// we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element
// within each tile.
int offset_per_col_tile = ((rowsA+31)/32)*32*32;
int tile_offset_rows = (row/32)*32*32;
int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
int subtile_col_idx = local_colidx%32;
int subtile_row_idx = row % 32;
// this magic is taken from the cublasLt doc (search for COL32)
int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx;
offset += tile_offset_cols + tile_offset_rows;

char val = A[offset];
int out_idx = (row*idx_size) + blockIdx.x;
out[out_idx] = val;
}
}
}

//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================

template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);

template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);

template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
Expand Down
2 changes: 2 additions & 0 deletions csrc/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int S

template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);

template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);

#endif


26 changes: 26 additions & 0 deletions csrc/ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -618,10 +618,36 @@ template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count,
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}


template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols)
{
int threads = 256;
// we load 128 column values per warp
int tiledCols = tiledCols = fill_up_to_nearest_multiple(cols, 32);
int tiledRows = 0;

int num_blocks = idx_size;

if(FORMAT == COL_TURING)
{
tiledRows = fill_up_to_nearest_multiple(rows, 8);
}
else if(FORMAT == COL_AMPERE)
{
tiledRows = fill_up_to_nearest_multiple(rows, 32);
}

kExtractOutliers<FORMAT><<<num_blocks, threads>>>(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================

template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);

template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);

Expand Down
2 changes: 2 additions & 0 deletions csrc/ops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,6 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val

template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);

template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols);

#endif
6 changes: 6 additions & 0 deletions csrc/pythonInterface.c
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ void transform_row2turingT(char * A, char *out, int rows, int cols){ transformRo
void transform_row2ampere(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 0>(A, out, rows, cols); }
void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 1>(A, out, rows, cols); }

void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_TURING>(A, idx, out, idx_size, rows, cols); }
void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_AMPERE>(A, idx, out, idx_size, rows, cols); }

int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_TURING, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

Expand Down Expand Up @@ -280,6 +283,9 @@ extern "C"
void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }

void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }

#endif
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); }
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); }
Expand Down
66 changes: 37 additions & 29 deletions deploy_from_slurm.sh
Original file line number Diff line number Diff line change
@@ -1,28 +1,37 @@
#!/bin/bash
BASE_PATH=$1

echo "MAKE SURE LD_LIBRARY_PATH IS EMPTY!"
echo $LD_LIBRARY_PATH

if [[ ! -z "${LD_LIBRARY_PATH}" ]]; then
echo "Compilation unsuccessul!" 1>&2
exit 64
fi


module unload cuda
module unload gcc

#rm -rf dist build
#make clean
#make cleaneggs
#export CUDA_HOME=
#make cpuonly
#
#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
# # Control will enter here if $DIRECTORY doesn't exist.
# echo "Compilation unsuccessul!" 1>&2
# exit 64
#fi
#CUDA_VERSION=cpu python -m build
#python -m twine upload dist/* --verbose --repository testpypi
rm -rf dist build
make clean
make cleaneggs
export CUDA_HOME=
make cpuonly

if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
CUDA_VERSION=cpu python -m build
python -m twine upload dist/* --verbose --repository testpypi

rm -rf dist build
make clean
make cleaneggs
export CUDA_HOME=$BASE_PATH/cuda-11.0
make cuda110
make cuda110

if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
Expand Down Expand Up @@ -102,20 +111,20 @@ fi
CUDA_VERSION=115 python -m build
python -m twine upload dist/* --verbose --repository testpypi

#rm -rf dist build
#make clean
#make cleaneggs
#export CUDA_HOME=$BASE_PATH/cuda-11.6
#
#make cuda11x
#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
# # Control will enter here if $DIRECTORY doesn't exist.
# echo "Compilation unsuccessul!" 1>&2
# exit 64
#fi
#CUDA_VERSION=116 python -m build
#python -m twine upload dist/* --verbose --repository testpypi
#
rm -rf dist build
make clean
make cleaneggs
export CUDA_HOME=$BASE_PATH/cuda-11.6

make cuda11x
if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
CUDA_VERSION=116 python -m build
python -m twine upload dist/* --verbose --repository testpypi

rm -rf dist build
make clean
make cleaneggs
Expand Down Expand Up @@ -257,5 +266,4 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
exit 64
fi
CUDA_VERSION=117-nomatmul python -m build
python -m twine upload dist/* --verbose
python -m twine upload dist/* --verbose --repository testpypi
Loading

0 comments on commit cc5b323

Please sign in to comment.