diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index e269ae69..187f0e6a 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -222,6 +222,78 @@ class HierarchicalNSW : public AlgorithmInterface { return num_deleted_; } + std::unordered_map getNodeDegree(size_t label) const { + std::unordered_map degreePerLevel; + auto search = label_lookup_.find(label); + if (search == label_lookup_.end()) { + throw std::runtime_error("Label not found"); + } + tableint internalId = search->second; + if (internalId < 0 || internalId >= cur_element_count) { + throw std::invalid_argument("Node ID does not exist."); + } + + else if (isMarkedDeleted(internalId)){ + throw std::invalid_argument("This node ID has been previosly deleted!"); + } + else{ + // Degree at level 0 + linklistsizeint* ll_0 = get_linklist0(internalId); + size_t degreeLevel0 = getListCount(ll_0); + degreePerLevel[0] = degreeLevel0; + // Check higher levels + int node_level = element_levels_[internalId]; + for (int level = 1; level <= node_level; ++level) { + linklistsizeint* ll = get_linklist(internalId, level); + size_t degree = getListCount(ll); + degreePerLevel[level] = degree; + } + } + return degreePerLevel; + } + + float getAverageDegreeAtLayer(int targetLevel) const { + + if (targetLevel < 0 || targetLevel > maxlevel_){ + throw std::invalid_argument("Layer not found!"); + } + + if (targetLevel == 0) { + return getAverageDegreeAtLayer0(); + } + size_t totalDegree = 0; + size_t nodesAtLevelCount = 0; + + for (size_t node_id = 0; node_id < cur_element_count; ++node_id) { + if (element_levels_[node_id] == targetLevel) { + nodesAtLevelCount++; + linklistsizeint* ll = get_linklist(node_id, targetLevel); + size_t degree = getListCount(ll); + totalDegree += degree; + } + } + + if (nodesAtLevelCount == 0) return 0.0; // Avoid division by zero + return static_cast(totalDegree) / nodesAtLevelCount; + } + + // --------------------- // + float getAverageDegreeAtLayer0() const { + size_t totalDegree = 0; + size_t nodesAtLevelCount = 0; + + for (size_t node_id = 0; node_id < cur_element_count; ++node_id) { + if (element_levels_[node_id] == 0) { + nodesAtLevelCount++; + linklistsizeint* ll = get_linklist0(node_id); + size_t degree = getListCount(ll); + totalDegree += degree; + } + } + if (nodesAtLevelCount == 0) return 0.0; // Avoid division by zero + return static_cast(totalDegree) / nodesAtLevelCount; + } + std::priority_queue, std::vector>, CompareByFirst> searchBaseLayer(tableint ep_id, const void *data_point, int layer) { VisitedList *vl = visited_list_pool_->getFreeVisitedList(); diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index dd09e80a..30adda99 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -720,6 +720,10 @@ class Index { size_t getCurrentCount() const { return appr_alg->cur_element_count; } + + int getMaxLayer() const { + return appr_alg->maxlevel_; + } }; template @@ -911,6 +915,21 @@ PYBIND11_PLUGIN(hnswlib) { py::module m("hnswlib"); py::class_>(m, "Index") + .def("get_max_layer", &Index::getMaxLayer) + .def("get_node_degree", [](const Index &index, hnswlib::labeltype label) { + auto degrees = index.appr_alg->getNodeDegree(label); // Call the getNodeDegree method from the C++ class + py::dict result; // Create a Python dictionary to hold the results + for (const auto °ree : degrees) { // Convert the C++ unordered_map to a Python dictionary + result[py::cast(degree.first)] = py::cast(degree.second); + } + return result; + }, py::arg("label"), "Retrieves the degree of a node in all levels where it exists.") + + .def("get_average_degree_at_layer", [](const Index &index, int layer) { + float averageDegree = index.appr_alg->getAverageDegreeAtLayer(layer); // Call the getAverageDegreeAtLayer method from the C++ class + return averageDegree; + }, py::arg("layer"), "Calculates the average node degree at the specified layer.") + .def(py::init(&Index::createFromParams), py::arg("params")) /* WARNING: Index::createFromIndex is not thread-safe with Index::addItems */ .def(py::init(&Index::createFromIndex), py::arg("index"))