Skip to content

Commit 72bab3e

Browse files
authored
Merge branch 'release/25.12' into fix-nnd-recall-fp32
2 parents b7ec9cc + 923cce5 commit 72bab3e

File tree

4 files changed

+48
-22
lines changed

4 files changed

+48
-22
lines changed

c/include/cuvs/core/c_api.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
#include <cuda_runtime.h>
99
#include <dlpack/dlpack.h>
10-
#include <rapids_logger/log_levels.h>
1110
#include <stdbool.h>
1211
#include <stdint.h>
1312

@@ -48,13 +47,13 @@ void cuvsSetLastErrorText(const char* error);
4847
*
4948
*/
5049
typedef enum {
51-
CUVS_LOG_LEVEL_TRACE = RAPIDS_LOGGER_LOG_LEVEL_TRACE,
52-
CUVS_LOG_LEVEL_DEBUG = RAPIDS_LOGGER_LOG_LEVEL_DEBUG,
53-
CUVS_LOG_LEVEL_INFO = RAPIDS_LOGGER_LOG_LEVEL_INFO,
54-
CUVS_LOG_LEVEL_WARN = RAPIDS_LOGGER_LOG_LEVEL_WARN,
55-
CUVS_LOG_LEVEL_ERROR = RAPIDS_LOGGER_LOG_LEVEL_ERROR,
56-
CUVS_LOG_LEVEL_CRITICAL = RAPIDS_LOGGER_LOG_LEVEL_CRITICAL,
57-
CUVS_LOG_LEVEL_OFF = RAPIDS_LOGGER_LOG_LEVEL_OFF
50+
CUVS_LOG_LEVEL_TRACE = 0,
51+
CUVS_LOG_LEVEL_DEBUG = 1,
52+
CUVS_LOG_LEVEL_INFO = 2,
53+
CUVS_LOG_LEVEL_WARN = 3,
54+
CUVS_LOG_LEVEL_ERROR = 4,
55+
CUVS_LOG_LEVEL_CRITICAL = 5,
56+
CUVS_LOG_LEVEL_OFF = 6
5857
} cuvsLogLevel_t;
5958

6059
/** @brief Returns the current log level

c/src/core/c_api.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,32 @@ extern "C" const char* cuvsGetLastErrorText()
214214

215215
extern "C" void cuvsSetLogLevel(cuvsLogLevel_t verbosity)
216216
{
217-
raft::default_logger().set_level(static_cast<rapids_logger::level_enum>(verbosity));
217+
rapids_logger::level_enum level = rapids_logger::level_enum::trace;
218+
switch (verbosity) {
219+
case CUVS_LOG_LEVEL_TRACE:
220+
level = rapids_logger::level_enum::trace;
221+
break;
222+
case CUVS_LOG_LEVEL_DEBUG:
223+
level = rapids_logger::level_enum::debug;
224+
break;
225+
case CUVS_LOG_LEVEL_INFO:
226+
level = rapids_logger::level_enum::info;
227+
break;
228+
case CUVS_LOG_LEVEL_WARN:
229+
level = rapids_logger::level_enum::warn;
230+
break;
231+
case CUVS_LOG_LEVEL_ERROR:
232+
level = rapids_logger::level_enum::error;
233+
break;
234+
case CUVS_LOG_LEVEL_CRITICAL:
235+
level = rapids_logger::level_enum::critical;
236+
break;
237+
case CUVS_LOG_LEVEL_OFF:
238+
level = rapids_logger::level_enum::off;
239+
break;
240+
default: RAFT_FAIL("Unsupported cuvsLogLevel_t value provided");
241+
}
242+
raft::default_logger().set_level(level);
218243
}
219244

220245
extern "C" cuvsLogLevel_t cuvsGetLogLevel()

java/panama-bindings/generate-bindings.sh

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,6 @@ else
2020
exit 1
2121
fi
2222

23-
if [ -n "${RAPIDS_LOGGER_INCLUDE_DIR:-}" ]; then
24-
echo "Using user-defined RAPIDS_LOGGER_INCLUDE_DIR"
25-
elif [ -n "${CONDA_PREFIX:-}" ]; then
26-
RAPIDS_LOGGER_INCLUDE_DIR="${CONDA_PREFIX}/include"
27-
else
28-
echo "Couldn't find a suitable CUDA include directory."
29-
exit 1
30-
fi
31-
3223
PATH="$(pwd)/jextract-22/bin/:${PATH}"
3324
export PATH
3425

@@ -46,7 +37,6 @@ fi
4637
jextract \
4738
--include-dir "${REPODIR}"/java/internal/build/bindings/include/ \
4839
--include-dir "${CUDA_INCLUDE_DIR}" \
49-
--include-dir "${RAPIDS_LOGGER_INCLUDE_DIR}" \
5040
--output "${REPODIR}/java/cuvs-java/src/main/java22/" \
5141
--target-package ${TARGET_PACKAGE} \
5242
--library cuvs_c \

python/cuvs/cuvs/tests/test_serialization.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
22
# SPDX-License-Identifier: Apache-2.0
33
#
44

@@ -7,7 +7,7 @@
77
from pylibraft.common import device_ndarray
88

99
from cuvs.neighbors import brute_force, cagra, ivf_flat, ivf_pq
10-
from cuvs.tests.ann_utils import generate_data
10+
from cuvs.tests.ann_utils import calc_recall, generate_data
1111

1212

1313
@pytest.mark.parametrize("dtype", [np.float32, np.int8, np.ubyte])
@@ -77,5 +77,17 @@ def run_save_load(ann_module, dtype):
7777
neighbors2 = neighbors_dev.copy_to_host()
7878
dist2 = distance_dev.copy_to_host()
7979

80-
assert np.all(neighbors == neighbors2)
8180
assert np.allclose(dist, dist2, rtol=1e-6)
81+
82+
# Sort the neighbors to avoid ordering issues
83+
sorted_neighbors = np.argsort(neighbors, axis=-1)
84+
sorted_neighbors2 = np.argsort(neighbors2, axis=-1)
85+
neighbors = np.take_along_axis(neighbors, sorted_neighbors, axis=-1)
86+
neighbors2 = np.take_along_axis(neighbors2, sorted_neighbors2, axis=-1)
87+
all_match = np.all(neighbors == neighbors2)
88+
# If the neighbors are not the same, there might be a cutoff between the k
89+
# and k+1 neighbors at the same distance.
90+
# Calculate that the recall is at least 99.8%
91+
if not all_match:
92+
recall = calc_recall(neighbors, neighbors2)
93+
assert recall >= 0.998

0 commit comments

Comments
 (0)