Skip to content

add sparse to gpu tests #434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pipeline {
sh 'docker rmi -f bscwdc/dislib &> /dev/null || true'
sh 'docker build --pull --no-cache --tag bscwdc/dislib .'
sh '''#!/bin/bash
docker run $(bash <(curl -s https://codecov.io/env)) -d --name dislib bscwdc/dislib'''
docker run $(bash <(curl -s https://codecov.io/env)) -v /home/jenkins/.ssh:/root/.ssh:ro -d --name dislib bscwdc/dislib'''
}
}
stage('test') {
Expand Down
1 change: 1 addition & 0 deletions dislib/classification/knn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class KNeighborsClassifier(BaseEstimator):
but different labels, the results will depend on the ordering of the
training data.
https://en.wikipedia.org/wiki/K-nearest_neighbor_algorithm

Examples
--------
>>> import dislib as ds
Expand Down
64 changes: 15 additions & 49 deletions dislib/cluster/kmeans/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,14 +345,24 @@ def _decode_helper(obj):
@task(blocks={Type: COLLECTION_IN, Depth: 2}, returns=np.array)
def _partial_sum_gpu(blocks, centers):
import cupy as cp
from scipy.sparse import issparse

partials = np.zeros((centers.shape[0], 2), dtype=object)
arr = Array._merge_blocks(blocks).astype(np.float32)
arr_gpu = cp.asarray(arr)
centers_gpu = cp.asarray(centers).astype(cp.float32)
arr = Array._merge_blocks(blocks)

if issparse(arr):
arr = arr.todense()
if issparse(centers):
centers = centers.todense()

arr_gpu = cp.expand_dims(cp.asarray(arr), axis=1)
centers_gpu = cp.expand_dims(cp.asarray(centers), axis=0)

close_centers_gpu = cp.argmin(distance_gpu(arr_gpu, centers_gpu), axis=1)
arr_gpu, centers_gpu = None, None
diff = arr_gpu - centers_gpu
del arr_gpu, centers_gpu

dist_gpu = cp.linalg.norm(diff, axis=2)
close_centers_gpu = cp.argmin(dist_gpu, axis=1)

close_centers = cp.asnumpy(close_centers_gpu)

Expand Down Expand Up @@ -396,47 +406,3 @@ def _merge(*data):
def _predict(blocks, centers):
    """Label each sample with the index of its closest center.

    The blocks are merged into a single array, the distance from every
    sample to every row of ``centers`` is computed with
    ``pairwise_distances``, and the position of the smallest distance per
    sample is returned as a single-column array.
    """
    samples = Array._merge_blocks(blocks)
    distances = pairwise_distances(samples, centers)
    nearest = distances.argmin(axis=1)
    return nearest.reshape(-1, 1)


def distance_gpu(a_gpu, b_gpu):
    """Return the squared Euclidean distances between rows of two GPU arrays.

    The result is built from the expansion
    ``|a|^2 + |b|^2 - 2 * a.b``: the per-row squared sums are computed with
    the reduction kernel, broadcast-added into a
    ``(len(a_gpu), len(b_gpu))`` float32 matrix by the element-wise kernel,
    and the cross term is subtracted afterwards. No square root is taken.
    """
    import cupy as cp

    n_a, n_b = a_gpu.shape[0], b_gpu.shape[0]
    row_sq_sum = get_sq_sum_kernel()

    # Per-row sum of squares for each operand.
    a_sq = cp.empty(n_a, dtype=cp.float32)
    b_sq = cp.empty(n_b, dtype=cp.float32)
    row_sq_sum(a_gpu, a_sq, axis=1)
    row_sq_sum(b_gpu, b_sq, axis=1)

    # sq_dist[r, c] = a_sq[r] + b_sq[c], filled over the flattened matrix.
    sq_dist = cp.empty((n_a, n_b), dtype=cp.float32)
    add_mix_kernel(n_b)(a_sq, b_sq, sq_dist, size=n_a * n_b)
    # Drop the temporaries before the matrix product allocates its result.
    del a_sq, b_sq

    sq_dist += -2.0 * cp.dot(a_gpu, b_gpu.T)
    return sq_dist


def get_sq_sum_kernel():
    """Create a CuPy reduction kernel named ``sqsum`` that sums ``x * x``."""
    import cupy as cp

    in_params = 'T x'
    out_params = 'T y'
    map_expr = 'x * x'        # applied to every input element
    reduce_expr = 'a + b'     # combines two mapped values
    post_map_expr = 'y = a'   # writes the reduced value to the output
    identity = '0'            # starting value of the reduction

    return cp.ReductionKernel(in_params, out_params, map_expr, reduce_expr,
                              post_map_expr, identity, 'sqsum')


def add_mix_kernel(y_len):
    """Create the ``add_mix`` element-wise kernel for a row length of y_len.

    For each flat index ``i`` the kernel stores
    ``x[i / y_len] + y[i % y_len]`` into ``z[i]``; ``y_len`` is baked into
    the generated kernel source.
    """
    import cupy as cp

    inputs = 'raw T x, raw T y'
    outputs = 'raw T z'
    operation = f'z[i] = x[i / {y_len}] + y[i % {y_len}]'
    return cp.ElementwiseKernel(inputs, outputs, operation, 'add_mix')
22 changes: 20 additions & 2 deletions dislib/data/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1624,11 +1624,20 @@ def _matmul_with_transpose(a, b, transpose_a, transpose_b):
@task(returns=np.array)
def _add_gpu(block1, block2):
import cupy as cp
from scipy.sparse import issparse, csr_matrix

sparse = False
if issparse(block1):
block1 = block1.todense()
sparse = True
if issparse(block2):
block2 = block2.todense()
sparse = True

block1_gpu, block2_gpu = cp.asarray(block1), cp.asarray(block2)
res = cp.asnumpy(cp.add(block1_gpu, block2_gpu))
del block1_gpu, block2_gpu
return res
return res if not sparse else csr_matrix(res)


@constraint(computing_units="${ComputingUnits}")
Expand All @@ -1644,6 +1653,15 @@ def _add_cpu(block1, block2):
@task(returns=np.array)
def _matmul_gpu(a, b, transpose_a, transpose_b):
import cupy as cp
from scipy.sparse import issparse, csr_matrix

sparse = False
if issparse(a):
a = a.todense()
sparse = True
if issparse(b):
b = b.todense()
sparse = True

a_gpu, b_gpu = cp.asarray(a), cp.asarray(b)

Expand All @@ -1654,7 +1672,7 @@ def _matmul_gpu(a, b, transpose_a, transpose_b):

res = cp.asnumpy(cp.matmul(a_gpu, b_gpu))
del a_gpu, b_gpu
return res
return res if not sparse else csr_matrix(res)


def _multiply_block_groups(hblock, vblock, transpose_a=False,
Expand Down
4 changes: 2 additions & 2 deletions dislib/decomposition/qr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,9 +451,9 @@ def _dot_task_gpu(a, b, transpose_result=False, transpose_a=False,

a_gpu, b_gpu = cp.asarray(a), cp.asarray(b)
if transpose_a:
a_gpu = np.transpose(a_gpu)
a_gpu = cp.transpose(a_gpu)
if transpose_b:
b_gpu = np.transpose(b_gpu)
b_gpu = cp.transpose(b_gpu)

dot_gpu = cp.dot(a_gpu, b_gpu)

Expand Down
2 changes: 1 addition & 1 deletion dislib/math/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def _compute_u_block_sorted_gpu(a_block, index, bsize, sorting, u_block):
a_col_gpu = cp.asarray(Array._merge_blocks(a_block))
norm_gpu = cp.linalg.norm(a_col_gpu, axis=0)

zero_idx = cp.where(norm_gpu == 0)
zero_idx = cp.where(norm_gpu == 0)[0]
a_col_gpu[0, zero_idx] = 1
norm_gpu[zero_idx] = 1

Expand Down
73 changes: 20 additions & 53 deletions dislib/neighbors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ def kneighbors(self, x, n_neighbors=None, return_distance=True):

if dislib.__gpu_available__:
for x_row in self._fit_data._iterator(axis=0):
q = _get_kneighbors_gpu(x_row._blocks,
q_row._blocks,
q = _get_kneighbors_gpu(q_row._blocks,
x_row._blocks,
n_neighbors,
offset)
queries.append(q)
offset += len(x_row._blocks)
offset += x_row.shape[0]
else:
for sknnstruct, n_samples in self._fit_data:
queries.append(_get_kneighbors(sknnstruct, q_row._blocks,
Expand Down Expand Up @@ -160,18 +160,31 @@ def _get_kneighbors(sknnstruct, q_blocks, n_neighbors, offset):
returns=tuple)
def _get_kneighbors_gpu(x_blocks, q_blocks, n_neighbors, offset):
import cupy as cp
from scipy.sparse import issparse

x_samples = Array._merge_blocks(x_blocks)
q_samples = Array._merge_blocks(q_blocks)

x_samples_gpu = cp.asarray(x_samples).astype(cp.float64)
q_samples_gpu = cp.asarray(q_samples).astype(cp.float64)
if issparse(x_samples):
x_samples = x_samples.todense()
if issparse(q_samples):
q_samples = q_samples.todense()

x_samples_gpu = cp.expand_dims(cp.asarray(x_samples), axis=1)
del x_samples
q_samples_gpu = cp.expand_dims(cp.asarray(q_samples), axis=0)
del q_samples

diff = x_samples_gpu - q_samples_gpu
dist_gpu = cp.linalg.norm(diff, axis=2)

dist_gpu = distance_gpu(q_samples_gpu, x_samples_gpu)
ind_gpu = cp.argsort(dist_gpu, axis=1)[:, :n_neighbors]
dist_gpu = cp.take_along_axis(dist_gpu, ind_gpu, axis=1)

return cp.asnumpy(dist_gpu), cp.asnumpy(ind_gpu) + offset
dist = cp.asnumpy(dist_gpu)
inds = cp.asnumpy(ind_gpu) + offset

return dist, inds


@constraint(computing_units="${ComputingUnits}")
Expand All @@ -190,49 +203,3 @@ def _merge_kqueries(k, *queries):
final_ind = np.take_along_axis(aggr_ind, final_ii, 1)

return final_dist, final_ind


def distance_gpu(a_gpu, b_gpu):
    """Return the Euclidean distances between rows of two GPU arrays.

    Parameters
    ----------
    a_gpu, b_gpu : cupy arrays
        Two-dimensional arrays whose rows are compared; both are reduced
        along ``axis=1`` and multiplied as ``a_gpu @ b_gpu.T``.

    Returns
    -------
    cupy array of shape ``(len(a_gpu), len(b_gpu))`` and dtype float64
        Entry ``[i, j]`` is the distance between ``a_gpu[i]`` and
        ``b_gpu[j]``.

    Notes
    -----
    The squared distance is computed through the expansion
    ``|a|^2 + |b|^2 - 2 * a.b``. Rounding can make that expression slightly
    negative for identical or nearly identical rows, which would turn into
    NaN under the square root, so negative values are clamped to zero
    before the root is taken.
    """
    import cupy as cp

    sq_sum_ker = get_sq_sum_kernel()

    # Per-row sum of squares of each operand.
    aa_gpu = cp.empty(a_gpu.shape[0], dtype=cp.float64)
    bb_gpu = cp.empty(b_gpu.shape[0], dtype=cp.float64)

    sq_sum_ker(a_gpu, aa_gpu, axis=1)
    sq_sum_ker(b_gpu, bb_gpu, axis=1)

    dist_gpu = cp.empty((len(aa_gpu), len(bb_gpu)), dtype=cp.float64)

    # dist[i, j] = aa[i] + bb[j], written over the flattened matrix.
    add_mix_kernel(len(b_gpu))(aa_gpu, bb_gpu, dist_gpu, size=dist_gpu.size)
    # Release the temporaries before the matrix product allocates its result.
    del aa_gpu, bb_gpu

    dist_gpu += -2.0 * cp.dot(a_gpu, b_gpu.T)

    # Guard against tiny negative values caused by floating-point
    # cancellation; sqrt of those would yield NaN distances.
    cp.maximum(dist_gpu, 0, out=dist_gpu)

    return cp.sqrt(dist_gpu)


def get_sq_sum_kernel():
    """Return a CuPy ``ReductionKernel`` (``sqsum``) adding up squared inputs."""
    import cupy as cp

    # Positional arguments, in order: input params, output params, map
    # expression, reduce expression, post-reduction map, identity value.
    kernel_spec = ('T x', 'T y', 'x * x', 'a + b', 'y = a', '0')
    kernel_name = 'sqsum'
    return cp.ReductionKernel(*kernel_spec, kernel_name)


def add_mix_kernel(y_len):
    """Return the ``add_mix`` element-wise kernel specialised for ``y_len``.

    The generated kernel writes ``x[i / y_len] + y[i % y_len]`` to ``z[i]``
    for every flat index ``i``; all three arguments are declared ``raw``.
    """
    import cupy as cp

    body = f'z[i] = x[i / {y_len}] + y[i % {y_len}]'
    kernel = cp.ElementwiseKernel('raw T x, raw T y', 'raw T z', body,
                                  'add_mix')
    return kernel
8 changes: 4 additions & 4 deletions run_ci_checks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ export PYTHONPATH=$PYTHONPATH:${root_path}
echo "Running flake8 style check"
./run_style.sh

echo "Running tests"
# Run the tests in ./tests with PyCOMPSs
./run_tests.sh

echo "Running code coverage"
./run_coverage.sh

# echo "Running tests"
# Run the tests in ./tests with PyCOMPSs
# ./run_tests.sh
22 changes: 17 additions & 5 deletions run_coverage.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
#!/bin/bash -e
#!/bin/bash

rm -rf .git
ssh [email protected] rm -rf /scratch/tmp/dislib-gpu-test

coverage3 run --data-file=cpu_cov --source dislib tests &
cpu_cov=$!

ssh [email protected] mkdir -p /scratch/tmp/dislib-gpu-test
scp -r . [email protected]:/scratch/tmp/dislib-gpu-test/
ssh [email protected] "cd /scratch/tmp/dislib-gpu-test;./run_gpu_cov.sh"
scp [email protected]:/scratch/tmp/dislib-gpu-test/gpu_cov .
ssh [email protected] rm -rf /scratch/tmp/dislib-gpu-test

wait $cpu_cov

coverage3 combine cpu_cov gpu_cov

# Run the coverage of the dislib using the tests in ./tests (sequential)
coverage3 run --source dislib tests
# Report coverage results to the CLI.
coverage3 report -m
# Upload coverage report to codecov.io
bash <(curl -s https://codecov.io/bash) -t 629589cf-e257-4262-8ec0-314dfd98f003
13 changes: 13 additions & 0 deletions run_gpu_cov.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/bin/bash -e
# Runs the test suite in ./tests under coverage with the GPU flag set and
# stores the coverage data in the file "gpu_cov" in the current directory.
# The -e flag on the shebang aborts the script on the first failing command.
# NOTE(review): the module names and paths below are specific to one
# cluster installation -- confirm they still exist before reusing.

# Reset the module environment, then load the MPI/COMPSs and the
# compiler/CUDA/MKL/Anaconda modules this run depends on.
module purge
module load bullxmpi/bullxmpi-1.2.9.1 COMPSs/TrunkCT
module load gcc/9.2.0 cuda/10.1 mkl/2018.1 ANACONDA/2021.05
# Presumably drops a python module pulled in by the loads above so that the
# conda interpreter is used instead -- TODO confirm.
module unload python

# Enable `conda activate` in this non-interactive shell, then switch to the
# environment named cupy-cuda101.
eval "$(conda shell.bash hook)"
conda activate cupy-cuda101

# Replaces (does not append to) any existing PYTHONPATH with the COMPSs
# Python 3 bindings directory.
export PYTHONPATH=/apps/COMPSs/3.1/Bindings/python/3/

# DISLIB_GPU_AVAILABLE is set only for this command; presumably dislib reads
# it to enable its GPU code paths -- verify against the package.
DISLIB_GPU_AVAILABLE=True python3 -m coverage run --data-file=gpu_cov --source dislib tests
2 changes: 1 addition & 1 deletion tests/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_kneighbors_sparse(self):
X, Y = csr_matrix(X), Y
x, y = ds.array(X, (50, 5)), ds.array(Y, (50, 1))

knn = KNeighborsClassifier(n_neighbors=3, weights='')
knn = KNeighborsClassifier(n_neighbors=3, weights='distance')
knn.fit(x, y)
ds_y_hat = knn.predict(x)

Expand Down