@@ -50,26 +50,28 @@ __global__ void MarkSelectedIndices(
5050// points.
template <typename T>
__global__ void ApplyMaskToDistances(
        T* distances,         // Shape: (num_queries_tile, distance_row_stride)
        const uint8_t* mask,  // Shape: (num_queries, num_points)
        int num_queries_tile,
        int64_t distance_row_stride,  // Stride between rows of `distances`
        int num_points_tile,  // Actual number of valid points in this tile
        int query_offset,     // Starting query index in the mask
        int point_offset,     // Starting point index in the mask
        int num_points) {     // Total number of points
    // Thread mapping: blockIdx.y selects the query row of the tile; the flat
    // x index selects the point column of the tile.
    const int query_local = blockIdx.y;
    const int point_local = blockIdx.x * blockDim.x + threadIdx.x;

    // Drop threads that fall outside the valid part of the tile. The column
    // bound is the number of valid points, not the row stride, so any padding
    // columns between num_points_tile and distance_row_stride are never
    // written.
    if (query_local >= num_queries_tile || point_local >= num_points_tile)
        return;

    const int query_global = query_offset + query_local;
    const int point_global = point_offset + point_local;

    // The tile may extend past the last point of the full point set.
    if (point_global >= num_points) return;

    // Index the full (num_queries, num_points) mask with 64-bit arithmetic:
    // the product query_global * num_points can exceed the range of int.
    const int64_t mask_index =
            static_cast<int64_t>(query_global) * num_points + point_global;

    if (mask[mask_index]) {
        // Overwrite a masked entry with the largest finite float value.
        // distance_row_stride is int64_t, so this index is already 64-bit.
        distances[query_local * distance_row_stride + point_local] =
                static_cast<T>(std::numeric_limits<float>::max());
    }
}
7577
@@ -165,8 +167,10 @@ void KnnSearchCUDASinglePass(const Tensor& points,
165167 int num_cols = utility::DivUp (num_points, tile_cols);
166168
167169 // Get pointers from allocator for use in runBlockSelectPair
168- TIndex* indices_ptr = const_cast <TIndex*>(output_allocator.IndicesPtr ());
169- T* distances_ptr = const_cast <T*>(output_allocator.DistancesPtr ());
170+ TIndex* indices_ptr = static_cast <TIndex*>(
171+ output_allocator.NeighborsIndex_ ().GetDataPtr ());
172+ T* distances_ptr =
173+ static_cast <T*>(output_allocator.NeighborsDistance_ ().GetDataPtr ());
170174
171175 // Allocate temporary memory space.
172176 Tensor temp_distances =
@@ -338,15 +342,16 @@ void KnnSearchCUDAMultiPass(const Tensor& points,
338342
339343 // Apply mask: set already-selected distances to infinity
340344 if (total_found > 0 ) {
341- int64_t temp_stride = temp_distances_view.GetStride (0 );
345+ int64_t distance_row_stride =
346+ temp_distances_view.GetStride (0 );
342347 int block_size = 256 ;
343348 dim3 block (block_size);
344349 dim3 grid (utility::DivUp (num_points_j, block_size),
345350 num_queries_i);
346351 ApplyMaskToDistances<T><<<grid, block, 0 , cur_stream>>> (
347352 temp_distances_view.GetDataPtr <T>(),
348353 mask.GetDataPtr <uint8_t >(), num_queries_i,
349- static_cast < int >(temp_stride) , i, j,
354+ distance_row_stride, num_points_j , i, j,
350355 num_points);
351356 }
352357
@@ -357,6 +362,8 @@ void KnnSearchCUDAMultiPass(const Tensor& points,
357362 cur_stream, temp_distances_view, point_norms_j,
358363 chunk_out_distances, chunk_out_indices, chunk_k,
359364 1 , tile_cols);
365+ chunk_out_distances.Add_ (
366+ query_norms_i.View ({num_queries_i, 1 }));
360367 } else {
361368 // Multi-tile case: output to buffer
362369 Tensor buf_distances_col_view =
@@ -397,10 +404,14 @@ void KnnSearchCUDAMultiPass(const Tensor& points,
397404
398405 // Copy to final output
399406 TIndex* indices_ptr =
400- const_cast <TIndex*>(output_allocator.IndicesPtr ()) +
407+ static_cast <TIndex*>(
408+ output_allocator.NeighborsIndex_ ()
409+ .GetDataPtr ()) +
401410 (i * knn + total_found);
402411 T* distances_ptr =
403- const_cast <T*>(output_allocator.DistancesPtr ()) +
412+ static_cast <T*>(
413+ output_allocator.NeighborsDistance_ ()
414+ .GetDataPtr ()) +
404415 (i * knn + total_found);
405416
406417 for (int q = 0 ; q < num_queries_i; ++q) {
@@ -439,10 +450,6 @@ void KnnSearchCUDAMultiPass(const Tensor& points,
439450 output_allocator.NeighborsDistance_ ().View (
440451 {num_queries, knn});
441452
442- // Add query norms
443- chunk_out_distances.Add_ (
444- query_norms_i.View ({num_queries_i, 1 }));
445-
446453 for (int q = 0 ; q < num_queries_i; ++q) {
447454 int global_query_idx = i + q;
448455
@@ -457,15 +464,15 @@ void KnnSearchCUDAMultiPass(const Tensor& points,
457464 .Slice (1 , total_found,
458465 total_found + chunk_k)
459466 .Flatten ();
460- dst_dist.AsRvalue () = src_dist. To (device) ;
467+ dst_dist.AsRvalue () = src_dist;
461468
462469 Tensor dst_idx = out_indices_full
463470 .Slice (0 , global_query_idx,
464471 global_query_idx + 1 )
465472 .Slice (1 , total_found,
466473 total_found + chunk_k)
467474 .Flatten ();
468- dst_idx.AsRvalue () = src_idx. To (device) ;
475+ dst_idx.AsRvalue () = src_idx;
469476 }
470477
471478 // Update mask for next pass
0 commit comments