Skip to content

Commit 923cce5

Browse files
authored
Fix BruteForce serialize test (#1568)
`test_serialization.py::test_save_load_brute_force` can sometimes fail due to the ordering of the neighbors. To fix this we can sort the neighbors before comparison, and if there's a small discrepancy due to two neighbors being at the same distance on the `k` cutoff we check that the recall is over 99.8%. I ran this test on a loop for 60k iterations without it failing. Authors: - Micka (https://github.com/lowener) Approvers: - Gil Forsyth (https://github.com/gforsyth) - Divye Gala (https://github.com/divyegala) - Corey J. Nolet (https://github.com/cjnolet) URL: #1568
1 parent 19dd922 commit 923cce5

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

python/cuvs/cuvs/tests/test_serialization.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
22
# SPDX-License-Identifier: Apache-2.0
33
#
44

@@ -7,7 +7,7 @@
77
from pylibraft.common import device_ndarray
88

99
from cuvs.neighbors import brute_force, cagra, ivf_flat, ivf_pq
10-
from cuvs.tests.ann_utils import generate_data
10+
from cuvs.tests.ann_utils import calc_recall, generate_data
1111

1212

1313
@pytest.mark.parametrize("dtype", [np.float32, np.int8, np.ubyte])
@@ -77,5 +77,17 @@ def run_save_load(ann_module, dtype):
7777
neighbors2 = neighbors_dev.copy_to_host()
7878
dist2 = distance_dev.copy_to_host()
7979

80-
assert np.all(neighbors == neighbors2)
8180
assert np.allclose(dist, dist2, rtol=1e-6)
81+
82+
# Sort the neighbors to avoid ordering issues
83+
sorted_neighbors = np.argsort(neighbors, axis=-1)
84+
sorted_neighbors2 = np.argsort(neighbors2, axis=-1)
85+
neighbors = np.take_along_axis(neighbors, sorted_neighbors, axis=-1)
86+
neighbors2 = np.take_along_axis(neighbors2, sorted_neighbors2, axis=-1)
87+
all_match = np.all(neighbors == neighbors2)
88+
# If the neighbors are not the same, there might be a cutoff between the k
89+
# and k+1 neighbors at the same distance.
90+
# Calculate that the recall is at least 99.8%
91+
if not all_match:
92+
recall = calc_recall(neighbors, neighbors2)
93+
assert recall >= 0.998

0 commit comments

Comments
 (0)