Skip to content

Commit c309bb8

Browse files
committed
review + fix for window test
1 parent 1079b1f commit c309bb8

File tree

1 file changed

+27
-20
lines changed

1 file changed

+27
-20
lines changed

cpp/open3d/core/nns/KnnSearchOps.cu

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -50,26 +50,28 @@ __global__ void MarkSelectedIndices(
5050
// points.
5151
template <typename T>
5252
__global__ void ApplyMaskToDistances(
53-
T* distances, // Shape: (num_queries_tile, tile_cols)
53+
T* distances, // Shape: (num_queries_tile, distance_row_stride)
5454
const uint8_t* mask, // Shape: (num_queries, num_points)
5555
int num_queries_tile,
56-
int tile_cols,
57-
int query_offset, // Starting query index in the mask
58-
int point_offset, // Starting point index in the mask
59-
int num_points) { // Total number of points
56+
int64_t distance_row_stride, // Stride between rows
57+
int num_points_tile, // Actual number of valid points in this tile
58+
int query_offset, // Starting query index in the mask
59+
int point_offset, // Starting point index in the mask
60+
int num_points) { // Total number of points
6061
int query_local = blockIdx.y;
6162
int point_local = blockIdx.x * blockDim.x + threadIdx.x;
6263

63-
if (query_local >= num_queries_tile || point_local >= tile_cols) return;
64+
if (query_local >= num_queries_tile || point_local >= num_points_tile)
65+
return;
6466

6567
int query_global = query_offset + query_local;
6668
int point_global = point_offset + point_local;
6769

6870
if (point_global >= num_points) return;
6971

7072
if (mask[query_global * num_points + point_global]) {
71-
distances[query_local * tile_cols + point_local] =
72-
static_cast<T>(INFINITY);
73+
distances[query_local * distance_row_stride + point_local] =
74+
static_cast<T>(std::numeric_limits<float>::max());
7375
}
7476
}
7577

@@ -165,8 +167,10 @@ void KnnSearchCUDASinglePass(const Tensor& points,
165167
int num_cols = utility::DivUp(num_points, tile_cols);
166168

167169
// Get pointers from allocator for use in runBlockSelectPair
168-
TIndex* indices_ptr = const_cast<TIndex*>(output_allocator.IndicesPtr());
169-
T* distances_ptr = const_cast<T*>(output_allocator.DistancesPtr());
170+
TIndex* indices_ptr = static_cast<TIndex*>(
171+
output_allocator.NeighborsIndex_().GetDataPtr());
172+
T* distances_ptr =
173+
static_cast<T*>(output_allocator.NeighborsDistance_().GetDataPtr());
170174

171175
// Allocate temporary memory space.
172176
Tensor temp_distances =
@@ -338,15 +342,16 @@ void KnnSearchCUDAMultiPass(const Tensor& points,
338342

339343
// Apply mask: set already-selected distances to infinity
340344
if (total_found > 0) {
341-
int64_t temp_stride = temp_distances_view.GetStride(0);
345+
int64_t distance_row_stride =
346+
temp_distances_view.GetStride(0);
342347
int block_size = 256;
343348
dim3 block(block_size);
344349
dim3 grid(utility::DivUp(num_points_j, block_size),
345350
num_queries_i);
346351
ApplyMaskToDistances<T><<<grid, block, 0, cur_stream>>>(
347352
temp_distances_view.GetDataPtr<T>(),
348353
mask.GetDataPtr<uint8_t>(), num_queries_i,
349-
static_cast<int>(temp_stride), i, j,
354+
distance_row_stride, num_points_j, i, j,
350355
num_points);
351356
}
352357

@@ -357,6 +362,8 @@ void KnnSearchCUDAMultiPass(const Tensor& points,
357362
cur_stream, temp_distances_view, point_norms_j,
358363
chunk_out_distances, chunk_out_indices, chunk_k,
359364
1, tile_cols);
365+
chunk_out_distances.Add_(
366+
query_norms_i.View({num_queries_i, 1}));
360367
} else {
361368
// Multi-tile case: output to buffer
362369
Tensor buf_distances_col_view =
@@ -397,10 +404,14 @@ void KnnSearchCUDAMultiPass(const Tensor& points,
397404

398405
// Copy to final output
399406
TIndex* indices_ptr =
400-
const_cast<TIndex*>(output_allocator.IndicesPtr()) +
407+
static_cast<TIndex*>(
408+
output_allocator.NeighborsIndex_()
409+
.GetDataPtr()) +
401410
(i * knn + total_found);
402411
T* distances_ptr =
403-
const_cast<T*>(output_allocator.DistancesPtr()) +
412+
static_cast<T*>(
413+
output_allocator.NeighborsDistance_()
414+
.GetDataPtr()) +
404415
(i * knn + total_found);
405416

406417
for (int q = 0; q < num_queries_i; ++q) {
@@ -439,10 +450,6 @@ void KnnSearchCUDAMultiPass(const Tensor& points,
439450
output_allocator.NeighborsDistance_().View(
440451
{num_queries, knn});
441452

442-
// Add query norms
443-
chunk_out_distances.Add_(
444-
query_norms_i.View({num_queries_i, 1}));
445-
446453
for (int q = 0; q < num_queries_i; ++q) {
447454
int global_query_idx = i + q;
448455

@@ -457,15 +464,15 @@ void KnnSearchCUDAMultiPass(const Tensor& points,
457464
.Slice(1, total_found,
458465
total_found + chunk_k)
459466
.Flatten();
460-
dst_dist.AsRvalue() = src_dist.To(device);
467+
dst_dist.AsRvalue() = src_dist;
461468

462469
Tensor dst_idx = out_indices_full
463470
.Slice(0, global_query_idx,
464471
global_query_idx + 1)
465472
.Slice(1, total_found,
466473
total_found + chunk_k)
467474
.Flatten();
468-
dst_idx.AsRvalue() = src_idx.To(device);
475+
dst_idx.AsRvalue() = src_idx;
469476
}
470477

471478
// Update mask for next pass

0 commit comments

Comments
 (0)