Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a way to obtain internal statistics and parameters of an HNSW index #594

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 108 additions & 9 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,31 @@ namespace hnswlib {
typedef unsigned int tableint;
typedef unsigned int linklistsizeint;

// Graph counters produced by HierarchicalNSW::getStats(), either for one
// level of the graph or summed over several levels.
struct Stats {
    // Number of elements counted.
    size_t nodes{0};
    // Number of links read from the counted elements' link lists.
    size_t edges{0};
    // Bytes attributed to the counted elements' link storage.
    size_t allocated_bytes{0};
    // Upper bound on `edges`: node count times the per-node link limit.
    size_t max_edges{0};
};

// By-value snapshot of an index's internal fields, as filled in by
// HierarchicalNSW::getInternalParameters(). Every member defaults to zero.
struct InternalParameters {
    // Copied from the index's max_elements_.
    size_t max_elements{0};
    // Copied from the index's num_deleted_.
    size_t num_deleted{0};
    // Copied from the index's M_.
    size_t M{0};
    // Copied from the index's maxM_.
    size_t maxM{0};
    // Copied from the index's maxM0_.
    size_t maxM0{0};
    // Copied from the index's ef_construction_.
    size_t ef_construction{0};
    // Copied from the index's ef_.
    size_t ef{0};
    // Copied from the index's mult_.
    double mult{0.0};
    // Copied from the index's maxlevel_.
    size_t maxlevel{0};

    // Copied from the index's size_data_per_element_.
    size_t size_data_per_element{0};
    // Copied from the index's size_links_per_element_.
    size_t size_links_per_element{0};
    // Copied from the index's size_links_level0_.
    size_t size_links_level0{0};

    // Copied from the index's data_size_ (size of each vector in bytes).
    size_t bytes_per_vector{0};
};

template<typename dist_t>
class HierarchicalNSW : public AlgorithmInterface<dist_t> {
public:
Expand Down Expand Up @@ -51,6 +76,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
char **linkLists_{nullptr};
std::vector<int> element_levels_; // keeps level of each element

// Size of each vector in bytes.
size_t data_size_{0};

DISTFUNC<dist_t> fstdistfunc_;
Expand Down Expand Up @@ -92,7 +118,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
size_t M = 16,
size_t ef_construction = 200,
size_t random_seed = 100,
bool allow_replace_deleted = false)
bool allow_replace_deleted = false,
size_t ef = 10)
: label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
link_list_locks_(max_elements),
element_levels_(max_elements),
Expand All @@ -112,7 +139,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
maxM_ = M_;
maxM0_ = M_ * 2;
ef_construction_ = std::max(ef_construction, M_);
ef_ = 10;
ef_ = ef;

level_generator_.seed(random_seed);
update_probability_generator_.seed(random_seed + 1);
Expand Down Expand Up @@ -322,7 +349,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;

dist_t lowerBound;
if (bare_bone_search ||
if (bare_bone_search ||
(!isMarkedDeleted(ep_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id))))) {
char* ep_data = getDataByInternalId(ep_id);
dist_t dist = fstdistfunc_(data_point, ep_data, dist_func_param_);
Expand Down Expand Up @@ -403,7 +430,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
_MM_HINT_T0); ////////////////////////
#endif

if (bare_bone_search ||
if (bare_bone_search ||
(!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) {
top_candidates.emplace(dist, candidate_id);
if (!bare_bone_search && stop_condition) {
Expand Down Expand Up @@ -682,7 +709,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
return size;
}

void saveIndex(const std::string &location) {
void saveIndex(const std::string &location) override {
std::ofstream output(location, std::ios::binary);
std::streampos position;

Expand Down Expand Up @@ -826,7 +853,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
std::vector<data_t> getDataByLabel(labeltype label) const {
// lock all operations with element by label
std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));

std::unique_lock <std::mutex> lock_table(label_lookup_lock);
auto search = label_lookup_.find(label);
if (search == label_lookup_.end() || isMarkedDeleted(search->second)) {
Expand Down Expand Up @@ -888,7 +915,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {

/*
* Removes the deleted mark of the node, does NOT really change the current graph.
*
*
* Note: the method is not safe to use when replacement of deleted elements is enabled,
* because elements marked as deleted can be completely removed by addPoint
*/
Expand Down Expand Up @@ -951,7 +978,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
* Adds point. Updates the point if it is already in the index.
* If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point
*/
void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) {
void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) override {
if ((allow_replace_deleted_ == false) && (replace_deleted == true)) {
throw std::runtime_error("Replacement of deleted elements is disabled in constructor");
}
Expand Down Expand Up @@ -1268,7 +1295,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {


std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const override {
std::priority_queue<std::pair<dist_t, labeltype >> result;
if (cur_element_count == 0) return result;

Expand Down Expand Up @@ -1377,6 +1404,27 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
return result;
}

// Implements AlgorithmInterface::getMaxLevel(): returns the index's current
// maxlevel_ value unchanged.
int getMaxLevel() const override { return maxlevel_; }

// Returns a by-value snapshot of the index's internal parameters.
//
// Each InternalParameters member is copied from the index member of the same
// name (with a trailing underscore); bytes_per_vector is copied from
// data_size_. The method only reads members, so it is const-qualified and can
// be called on a const index, like getMaxLevel() and getStats(). No lock is
// taken while the members are read.
InternalParameters getInternalParameters() const {
    InternalParameters params;

    // Capacity and occupancy.
    params.max_elements = max_elements_;
    params.num_deleted = num_deleted_;

    // Graph construction and search parameters.
    params.M = M_;
    params.maxM = maxM_;
    params.maxM0 = maxM0_;
    params.ef_construction = ef_construction_;
    params.ef = ef_;
    params.mult = mult_;
    // NOTE(review): maxlevel_ is an int (see getMaxLevel()); if it can be
    // negative, e.g. for an empty index, this conversion wraps to a huge
    // value. Confirm, and consider a signed field in InternalParameters.
    params.maxlevel = static_cast<size_t>(maxlevel_);

    // Per-element memory layout, in bytes.
    params.size_data_per_element = size_data_per_element_;
    params.size_links_per_element = size_links_per_element_;
    params.size_links_level0 = size_links_level0_;
    params.bytes_per_vector = data_size_;

    return params;
}

void checkIntegrity() {
int connections_checked = 0;
Expand Down Expand Up @@ -1408,5 +1456,56 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}
std::cout << "integrity ok, checked " << connections_checked << " connections\n";
}

// Collects per-level and aggregate graph statistics for the index.
//
// Parameters:
//   stats_per_level  Caller-owned output array with one Stats entry per
//                    level. Entries 0..T are reset and then filled, where
//                    T = max(0, min(maxlevel_, max_level)); the array must
//                    hold at least T + 1 entries. Entries above T are not
//                    touched.
//   max_level        Highest level to report. Levels above it are ignored.
//
// Returns the field-wise sum of entries 0..T. Note that `nodes` in the
// returned value counts an element once for every reported level it is
// present on. Returns an empty Stats, writing nothing, when
// stats_per_level is null or max_level is negative.
//
// No lock is taken while link lists are read.
Stats getStats(Stats* stats_per_level, int max_level) const {
    if (stats_per_level == nullptr || max_level < 0) {
        return {};
    }

    // Level 0 is always reported, so the cap is never allowed below 0. Keeping
    // the cap as an int avoids wrapping a negative value to SIZE_MAX when it
    // is later compared with a node's level.
    const int top_level = std::max(0, std::min(maxlevel_, max_level));

    // Start from a clean slate so that a reused array does not carry stale
    // counters into this call's results.
    for (int level = 0; level <= top_level; ++level) {
        stats_per_level[level] = Stats{};
    }

    // Bytes attributed to every element at level 0: its level-0 link block
    // plus its label.
    const size_t level0_bytes = size_links_level0_ + sizeof(labeltype);

    const auto num_elements = cur_element_count.load(std::memory_order_seq_cst);
    for (size_t i = 0; i != num_elements; ++i) {
        const tableint internal_id = static_cast<tableint>(i);
        const int node_level = element_levels_[internal_id];
        if (node_level < 0) {
            // This should not happen in practice.
            continue;
        }

        // Every element is present on the base level.
        Stats& base = stats_per_level[0];
        ++base.nodes;
        base.edges += getListCount(get_linklist_at_level(internal_id, 0));
        base.allocated_bytes += level0_bytes;

        // Upper levels, capped so we never write past entry top_level.
        const int node_top = std::min(node_level, top_level);
        for (int level = 1; level <= node_top; ++level) {
            Stats& upper = stats_per_level[level];
            ++upper.nodes;
            upper.edges += getListCount(get_linklist_at_level(internal_id, level));
            upper.allocated_bytes += size_links_per_element_;
        }
    }

    // Derive max_edges from the node counts (maxM0_ links per node on level 0,
    // maxM_ on every other level) and fold each level into the aggregate.
    Stats result{};
    for (int level = 0; level <= top_level; ++level) {
        Stats& entry = stats_per_level[level];
        entry.max_edges = entry.nodes * (level == 0 ? maxM0_ : maxM_);

        result.nodes += entry.nodes;
        result.edges += entry.edges;
        result.allocated_bytes += entry.allocated_bytes;
        result.max_edges += entry.max_edges;
    }

    return result;
}

};
} // namespace hnswlib
3 changes: 3 additions & 0 deletions hnswlib/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ class AlgorithmInterface {
searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const;

virtual void saveIndex(const std::string &location) = 0;

virtual int getMaxLevel() const = 0;

virtual ~AlgorithmInterface(){
}
};
Expand Down