diff --git a/Cargo.toml b/Cargo.toml index 4638678..26fedb7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,8 +39,8 @@ thiserror = "2.0.9" tinyvec = { version = "1.9.0", features = ["rustc_1_55"] } tracing = "0.1.41" steppe = { version = "0.4", default-features = false } -pyo3 = { version = "0.25.1", optional = true } -pyo3-stub-gen = { version = "0.13.1", optional = true } +pyo3 = { version = "0.26.0", optional = true } +pyo3-stub-gen = { version = "0.22.1", optional = true } once_cell = { version = "1.21.3", optional = true } tempfile = { version = "3.21.0", optional = true } parking_lot = { version = "0.12.4", optional = true } diff --git a/hannoy.pyi b/hannoy.pyi index 08692a7..c9fb7bc 100644 --- a/hannoy.pyi +++ b/hannoy.pyi @@ -1,30 +1,38 @@ # This file is automatically generated by pyo3_stub_gen -# ruff: noqa: E501, F401 +# ruff: noqa: E501, F401, F403, F405 import builtins +import enum import os import pathlib import typing -from enum import Enum +__all__ = [ + "Database", + "Metric", + "Reader", + "Writer", +] +@typing.final class Database: r""" - An LMDB-backed vector database for vector search. + An LMDB-backed database for vector search. """ - def __new__(cls, path:builtins.str | os.PathLike | pathlib.Path, distance:Metric=..., name:typing.Optional[builtins.str]=None, env_size:typing.Optional[builtins.int]=None) -> Database: ... - def writer(self, dimensions:builtins.int, index:builtins.int=0, m:builtins.int=16, ef:builtins.int=96) -> Writer: + def __new__(cls, path: builtins.str | os.PathLike | pathlib.Path, distance: Metric = ..., name: typing.Optional[builtins.str] = None, env_size: typing.Optional[builtins.int] = None) -> Database: ... + def writer(self, dimensions: builtins.int, index: builtins.int = 0, m: builtins.int = 16, ef: builtins.int = 96) -> Writer: r""" Get a writer for a specific index and dimensions. 
""" - def reader(self, index:builtins.int=0) -> Reader: + def reader(self, index: builtins.int = 0) -> Reader: r""" - Get a reader for a specific index and dimensions + Open a reader for a specific index. """ @staticmethod def commit_rw_txn() -> builtins.bool: ... @staticmethod def abort_rw_txn() -> builtins.bool: ... +@typing.final class Reader: r""" A thread-local Database reader holding its own `RoTxn`. It is safe to spawn multiple readers in @@ -38,11 +46,17 @@ class Reader: reader.by_vec([1.0, 0.0], n = 1) ``` """ - def by_vec(self, query:typing.Sequence[builtins.float], n:builtins.int=10, ef_search:builtins.int=200) -> builtins.list[tuple[builtins.int, builtins.float]]: + def by_vec(self, query: typing.Sequence[builtins.float], n: builtins.int = 10, ef_search: builtins.int = 200) -> builtins.list[tuple[builtins.int, builtins.float]]: r""" Retrieve similar items from the db given a query. """ + def by_item(self, item: builtins.int, n: builtins.int = 10, ef_search: builtins.int = 200) -> typing.Optional[builtins.list[tuple[builtins.int, builtins.float]]]: + r""" + Retrieve similar items from the db given an item ID. + Returns `None` if the item is not in the database. + """ +@typing.final class Writer: r""" A struct for configuring the HNSW build and performing transactional insertions/deletions from @@ -60,15 +74,16 @@ class Writer: ``` """ def __enter__(self) -> Writer: ... - def __exit__(self, _exc_type:typing.Optional[type], _exc_value:typing.Optional[typing.Any], _traceback:typing.Optional[typing.Any]) -> None: ... - def add_item(self, item:builtins.int, vector:typing.Sequence[builtins.float]) -> None: + def __exit__(self, _exc_type: typing.Optional[type], _exc_value: typing.Optional[typing.Any], _traceback: typing.Optional[typing.Any]) -> None: ... + def add_item(self, item: builtins.int, vector: typing.Sequence[builtins.float]) -> None: r""" Store a vector associated with an item ID in the database. 
""" -class Metric(Enum): +@typing.final +class Metric(enum.Enum): r""" - Supported distance metrics in hannoy. + The supported distance metrics in hannoy. """ COSINE = ... EUCLIDEAN = ... diff --git a/pyproject.toml b/pyproject.toml index 2af3724..f635292 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,6 @@ [build-system] requires = ["maturin>=1.9.1,<2"] build-backend = "maturin" -description = "Python bindings for Hannoy; a KV-backed HNSW implementation in Rust using LMDB." -requires-python = ">=3.10" [project] name = "hannoy" diff --git a/src/python.rs b/src/python.rs index fa3d172..5387a9f 100644 --- a/src/python.rs +++ b/src/python.rs @@ -353,6 +353,21 @@ enum DynReader { Hamming(Reader), } + +macro_rules! hnsw_search { + ($reader:expr, |r| r . $($q:tt)*) => { + match $reader { + DynReader::Cosine(reader) => reader . $($q)*, + DynReader::Euclidean(reader) => reader . $($q)*, + DynReader::Manhattan(reader) => reader . $($q)*, + DynReader::BqCosine(reader) => reader . $($q)*, + DynReader::BqEuclidean(reader) => reader . $($q)*, + DynReader::BqManhattan(reader) => reader . $($q)*, + DynReader::Hamming(reader) => reader . $($q)*, + } + }; +} + /// A thread-local Database reader holding its own `RoTxn`. It is safe to spawn multiple readers in /// different threads. /// @@ -377,24 +392,18 @@ impl PyReader { #[pyo3(signature = (query, n=10, ef_search=200))] fn by_vec(&self, query: Vec<f32>, n: usize, ef_search: usize) -> PyResult<Vec<(ItemId, f32)>> { let rtxn = &self.rtxn; - - macro_rules! 
hnsw_search { - ($read:expr, $q:expr) => { - $read.nns(n).ef_search(ef_search).by_vector(&rtxn, $q).map_err(h2py_err) - }; - } - - let found = match &self.dyn_reader { - DynReader::Cosine(reader) => hnsw_search!(reader, &query)?, - DynReader::Euclidean(reader) => hnsw_search!(reader, &query)?, - DynReader::Manhattan(reader) => hnsw_search!(reader, &query)?, - DynReader::BqCosine(reader) => hnsw_search!(reader, &query)?, - DynReader::BqEuclidean(reader) => hnsw_search!(reader, &query)?, - DynReader::BqManhattan(reader) => hnsw_search!(reader, &query)?, - DynReader::Hamming(reader) => hnsw_search!(reader, &query)?, - }; + let found = hnsw_search!(&self.dyn_reader, |r| r.nns(n).ef_search(ef_search).by_vector(&rtxn, &query)).map_err(h2py_err)?; Ok(found.into_nns()) } + + /// Retrieve similar items from the db given an item ID. + /// Returns `None` if the item is not in the database. + #[pyo3(signature = (item, n=10, ef_search=200))] + fn by_item(&self, item: ItemId, n: usize, ef_search: usize) -> PyResult<Option<Vec<(ItemId, f32)>>> { + let rtxn = &self.rtxn; + let found = hnsw_search!(&self.dyn_reader, |r| r.nns(n).ef_search(ef_search).by_item(&rtxn, item)).map_err(h2py_err)?; + Ok(found.map(|s| s.into_nns())) + } } fn h2py_err>(e: E) -> PyErr { diff --git a/tests/test_basic.py b/tests/test_basic.py index 7000b5d..3a6fdf0 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,8 +1,10 @@ from pathlib import Path -from typing import List +from typing import Literal + import pytest + import hannoy -from hannoy import Metric, Reader, Writer +from hannoy import Metric, Reader @pytest.fixture(scope="function", autouse=False) @@ -26,7 +28,6 @@ def test_read(db: hannoy.Database) -> None: query = [0.0, 1.0, 0.0] res = reader.by_vec(query, n=2) - print(res) assert len(res) == 2 (item_id, dist) = res[0] @@ -34,10 +35,21 @@ def test_read(db: hannoy.Database) -> None: assert dist == 0.0 +def test_read_by_item(db: hannoy.Database) -> None: + reader: Reader = db.reader(0) + + res = 
reader.by_item(1, n=2) + assert res is not None + assert len(res) == 2 + + assert {item_id for item_id, _ in res} == {0, 2} + assert not any(d == 0 for _, d in res) + + def test_multithreaded_reads(db) -> None: import threading - def _read(db: hannoy.Database, query: List[float]): + def _read(db: hannoy.Database, query: list[float]): reader = db.reader(0) t_id = threading.get_ident() print(f"nns from thread {t_id}: {reader.by_vec(query, 1)}")