Skip to content

Commit 35ee147

Browse files
committed
fix knn failure on windows due to mismatch in types and precision
1 parent 45cb63f commit 35ee147

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

dpbench/benchmarks/default/knn/knn_initialize.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ def _gen_data_x(ip_size, data_dim, seed, dtype):
2424

2525
def _gen_data_y(ip_size, classes_num, seed):
2626
default_rng.seed(seed)
27-
data = default_rng.randint(classes_num, size=ip_size)
27+
data = default_rng.randint(
28+
classes_num, size=ip_size, dtype=types_dict["int"]
29+
)
2830
return data
2931

3032
def _gen_train_data(train_size, data_dim, classes_num, seed_train, dtype):

dpbench/benchmarks/default/knn/knn_sycl_native_ext/knn_sycl/_knn_kernel.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77

88
template <typename FpTy, typename IntTy> class theKernel;
99

10-
template <typename FpTy> struct neighbors
10+
template <typename FpTy, typename IntTy> struct neighbors
1111
{
1212
FpTy dist;
13-
size_t label;
13+
IntTy label;
1414
};
1515

1616
template <typename FpTy, typename IntTy>
1717
sycl::event knn_impl(sycl::queue q,
1818
FpTy *d_train,
19-
size_t *d_train_labels,
19+
IntTy *d_train_labels,
2020
FpTy *d_test,
2121
size_t k,
2222
size_t classes_num,
@@ -33,7 +33,7 @@ sycl::event knn_impl(sycl::queue q,
3333

3434
// here k has to be 5 in order to match with numpy no. of
3535
// neighbors
36-
struct neighbors<FpTy> queue_neighbors[5];
36+
struct neighbors<FpTy, IntTy> queue_neighbors[5];
3737

3838
// count distances
3939
for (size_t j = 0; j < k; ++j) {
@@ -54,7 +54,7 @@ sycl::event knn_impl(sycl::queue q,
5454
for (size_t j = 0; j < k; ++j) {
5555
// push queue
5656
FpTy new_distance = queue_neighbors[j].dist;
57-
FpTy new_neighbor_label = queue_neighbors[j].label;
57+
IntTy new_neighbor_label = queue_neighbors[j].label;
5858
size_t index = j;
5959
while (index > 0 &&
6060
new_distance < queue_neighbors[index - 1].dist)
@@ -83,7 +83,7 @@ sycl::event knn_impl(sycl::queue q,
8383

8484
// push queue
8585
FpTy new_distance = queue_neighbors[k - 1].dist;
86-
FpTy new_neighbor_label = queue_neighbors[k - 1].label;
86+
IntTy new_neighbor_label = queue_neighbors[k - 1].label;
8787
size_t index = k - 1;
8888

8989
while (index > 0 &&

dpbench/benchmarks/default/knn/knn_sycl_native_ext/knn_sycl/_knn_sycl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ void knn_sync(dpctl::tensor::usm_ndarray x_train,
4343
if (typenum == UAR_FLOAT) {
4444
sycl::event res_ev = knn_impl<float, unsigned int>(
4545
x_train.get_queue(), x_train.get_data<float>(),
46-
y_train.get_data<size_t>(), x_test.get_data<float>(), k,
46+
y_train.get_data<unsigned int>(), x_test.get_data<float>(), k,
4747
classes_num, train_size, test_size,
4848
predictions.get_data<unsigned int>(),
4949
votes_to_classes.get_data<float>(), data_dim);

0 commit comments

Comments
 (0)