diff --git a/libs/community/langchain_community/vectorstores/surrealdb.py b/libs/community/langchain_community/vectorstores/surrealdb.py index 3157f48..2f5aaf9 100644 --- a/libs/community/langchain_community/vectorstores/surrealdb.py +++ b/libs/community/langchain_community/vectorstores/surrealdb.py @@ -1,15 +1,76 @@ +from __future__ import annotations + import asyncio -from typing import Any, Dict, Iterable, List, Optional, Tuple +from dataclasses import KW_ONLY, dataclass, field +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + Optional, + Sequence, + Union, +) import numpy as np from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore +from surrealdb import ( + AsyncHttpSurrealConnection, + AsyncWsSurrealConnection, + BlockingHttpSurrealConnection, + BlockingWsSurrealConnection, + RecordID, +) from langchain_community.vectorstores.utils import maximal_marginal_relevance +if TYPE_CHECKING: + from surrealdb import AsyncSurreal, Surreal + DEFAULT_K = 4 # Number of Documents to return. +type SurrealConnection = Union[ + BlockingWsSurrealConnection, BlockingHttpSurrealConnection +] +type SurrealAsyncConnection = Union[ + AsyncWsSurrealConnection, AsyncHttpSurrealConnection +] + +GET_BY_ID_QUERY = """ + SELECT * + FROM type::table($table) + WHERE id IN array::combine([$table], $ids) + .map(|$v| type::thing($v[0], $v[1])) +""" + +# # Development commands: +# +# ```sh +# surreal start -u root -p root -l debug +# make integration_tests TEST_FILE=tests/integration_tests/vectorstores/test_surrealdb.py # noqa: E501 +# make format +# make lint +# ``` + + +@dataclass +class SurrealDocument: + _: KW_ONLY + id: RecordID = field(hash=False) + text: str + embedding: list[float] + similarity: float | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def into(self) -> Document: + return Document( + id=self.id.id, + page_content=self.text, + metadata=self.metadata, + ) + class SurrealDBStore(VectorStore): """ @@ -18,14 +79,9 @@ class SurrealDBStore(VectorStore): To use, you should have the ``surrealdb`` python package installed. Args: - embedding_function: Embedding function to use. - dburl: SurrealDB connection url - ns: surrealdb namespace for the vector store. (default: "langchain") - db: surrealdb database for the vector store. (default: "database") - collection: surrealdb collection for the vector store. - (default: "documents") - - (optional) db_user and db_pass: surrealdb credentials + embedding: The embedding function or model to use for generating embeddings. + table: SurrealDB table for the vector store (default: "documents"). + connection: SurrealDB connection Example: .. code-block:: python @@ -34,542 +90,540 @@ class SurrealDBStore(VectorStore): from langchain_community.embeddings import HuggingFaceEmbeddings model_name = "sentence-transformers/all-mpnet-base-v2" - embedding_function = HuggingFaceEmbeddings(model_name=model_name) - dburl = "ws://localhost:8000/rpc" - ns = "langchain" - db = "docstore" - collection = "documents" - db_user = "root" - db_pass = "root" - - sdb = SurrealDBStore.from_texts( - texts=texts, - embedding=embedding_function, - dburl, - ns, db, collection, - db_user=db_user, db_pass=db_pass) + embedding = HuggingFaceEmbeddings(model_name=model_name) + + conn = Surreal("ws://localhost:8000/rpc") + conn.signin({"username": "root", "password": "root"}) + conn.use("langchain", "test") + + connection = SurrealDBStore.from_texts( + texts=texts, + embedding=embedding, + connection=conn + ) """ def __init__( self, - embedding_function: Embeddings, - **kwargs: Any, + embedding: Embeddings, + connection: SurrealConnection | None, + table: str = "documents", + index_name: str = "documents_vector_index", + embedding_dimension: int | None = None, + async_connection: SurrealAsyncConnection | None = None, ) -> None: - try: - from surrealdb import Surreal - except ImportError as e: - raise ImportError( - """Cannot import from surrealdb. - please install with `pip install surrealdb`.""" - ) from e - - self.dburl = kwargs.pop("dburl", "ws://localhost:8000/rpc") - - if self.dburl[0:2] == "ws": - self.sdb = Surreal(self.dburl) + self.embedding = embedding + self.table = table + self.index_name = index_name + self.connection = connection + self.async_connection = async_connection + if embedding_dimension is not None: + self.embedding_dimension = embedding_dimension else: - raise ValueError("Only websocket connections are supported at this time.") - - self.ns = kwargs.pop("ns", "langchain") - self.db = kwargs.pop("db", "database") - self.collection = kwargs.pop("collection", "documents") - self.embedding_function = embedding_function - self.kwargs = kwargs + self.embedding_dimension = len(self.embedding.embed_query("foo")) + self._ensure_index() - async def initialize(self) -> None: - """ - Initialize connection to surrealdb database - and authenticate if credentials are provided + def _ensure_index(self) -> None: + query = f""" + DEFINE INDEX IF NOT EXISTS {self.index_name} + ON TABLE {self.table} + FIELDS embedding + MTREE DIMENSION {self.embedding_dimension} DIST COSINE TYPE F32 + CONCURRENTLY; """ - await self.sdb.connect() - if "db_user" in self.kwargs and "db_pass" in self.kwargs: - user = self.kwargs.get("db_user") - password = self.kwargs.get("db_pass") - await self.sdb.signin({"user": user, "pass": password}) - await self.sdb.use(self.ns, self.db) + if self.async_connection is not None: + loop = asyncio.get_event_loop() + loop.create_task(self.async_connection.query(query)) + elif self.connection is not None: + self.connection.query(query) + else: + raise ValueError("No connection provided") - @property - def embeddings(self) -> Optional[Embeddings]: - return ( - self.embedding_function - if isinstance(self.embedding_function, Embeddings) - else None + def _build_text_data( + self, + text: str, + embedding: list[float], + metadata: dict | None, + with_id: str | None, + ) -> tuple[RecordID | None, dict]: + preferred_id = None + data = {"text": text, "embedding": embedding, "metadata": {}} + if metadata is not None: + data["metadata"] = metadata + preferred_id = metadata.get("id") + if with_id is not None: + preferred_id = with_id + record_id = ( + RecordID(self.table, preferred_id) if preferred_id is not None else None ) - - async def aadd_texts( + return record_id, data + + @staticmethod + def _parse_documents(ids: Sequence[str], results: list[dict]) -> list[Document]: + docs = {} + for x in results: + doc = SurrealDocument(**x).into() + docs[doc.id] = doc + # sort docs in the same order as the passed in IDs + result: list[Document] = [] + for key in ids: + d = docs.get(str(key)) + if d is not None: + result.append(d) + return result + + def _build_search_query( self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ) -> List[str]: - """Add list of text along with embeddings to the vector store asynchronously + embedding: list[float], + k: int, + score_threshold: float, + custom_filter: dict[str, str] | None, + ) -> tuple[str, dict]: + args = { + "table": self.table, + "embedding": embedding, + "k": k, + "score_threshold": score_threshold, + } - Args: - texts (Iterable[str]): collection of text to add to the database + # build additional filter criteria + custom_filter_str = "" + if custom_filter: + for key in custom_filter: + # check value type + if type(custom_filter[key]) in [str, bool]: + filter_value = f"'{custom_filter[key]}'" + else: + filter_value = f"{custom_filter[key]}" + + custom_filter_str += f"and metadata.{key} = {filter_value} " - Returns: - List of ids for the newly inserted documents + query = f""" + SELECT + id, + text, + metadata, + embedding, + similarity + FROM ( + SELECT + id, + text, + metadata, + embedding, + vector::similarity::cosine(embedding, $embedding) as similarity + FROM type::table($table) + WHERE embedding <|{k}|> $embedding + {custom_filter_str} + ) + WHERE similarity >= $score_threshold + ORDER BY similarity DESC """ - embeddings = self.embedding_function.embed_documents(list(texts)) - ids = [] - for idx, text in enumerate(texts): - data = {"text": text, "embedding": embeddings[idx]} - if metadatas is not None and idx < len(metadatas): - data["metadata"] = metadatas[idx] # type: ignore[assignment] - else: - data["metadata"] = [] - record = await self.sdb.create( - self.collection, - data, + + return query, args + + @staticmethod + def _parse_results( + results: list[dict], + ) -> list[tuple[Document, float, list[float]]]: + parsed = [] + for raw in results: + parsed.append( + ( + SurrealDocument(**raw).into(), + raw["similarity"], + raw["embedding"], + ), ) - ids.append(record[0]["id"]) - return ids + return parsed - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ) -> List[str]: - """Add list of text along with embeddings to the vector store + @staticmethod + def _filter_documents_from_result( + search_result: list[tuple[Document, float, list[float]]], + embedding: list[float], + k: int = DEFAULT_K, + lambda_mult: float = 0.5, + ) -> list[Document]: + # extract only document from result + docs = [sub[0] for sub in search_result] + # extract only embedding from result + embeddings = [sub[-1] for sub in search_result] - Args: - texts (Iterable[str]): collection of text to add to the database + mmr_selected = maximal_marginal_relevance( + np.array(embedding, dtype=np.float32), + embeddings, + k=k, + lambda_mult=lambda_mult, + ) - Returns: - List of ids for the newly inserted documents - """ + return [docs[i] for i in mmr_selected] - async def _add_texts( - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ) -> List[str]: - await self.initialize() - return await self.aadd_texts(texts, metadatas, **kwargs) + def _aux( + self, + embedding: list[float], + *, + k: int = DEFAULT_K, + score_threshold: float = -1, + custom_filter: dict[str, str] | None = None, + ) -> list[tuple[Document, float, list[float]]]: + if self.connection is None: + raise ValueError("No connection provided") + query, args = self._build_search_query( + embedding, k, score_threshold, custom_filter + ) + results = self.connection.query(query, args) + return self._parse_results(results) - return asyncio.run(_add_texts(texts, metadatas, **kwargs)) + # ========================================================================= + # == Extended methods + # ========================================================================= - async def adelete( + @property + def embeddings(self) -> Embeddings | None: + return self.embedding if isinstance(self.embedding, Embeddings) else None + + def add_texts( self, - ids: Optional[List[str]] = None, + texts: Iterable[str], + metadatas: list[dict] | None = None, + *, + ids: list[str] | None = None, **kwargs: Any, - ) -> Optional[bool]: - """Delete by document ID asynchronously. - - Args: - ids: List of ids to delete. - **kwargs: Other keyword arguments that subclasses might use. - - Returns: - Optional[bool]: True if deletion is successful, - False otherwise. - """ + ) -> list[str]: + if self.connection is None: + raise ValueError("No connection provided") + embeddings = self.embedding.embed_documents(list(texts)) + result_ids = [] + for idx, text in enumerate(texts): + record_id, data = self._build_text_data( + text, + embeddings[idx], + metadatas[idx] if metadatas is not None else None, + ids[idx] if ids is not None else None, + ) + if record_id is not None: + inserted = self.connection.upsert(record_id, data) + else: + inserted = self.connection.insert(self.table, data) + if isinstance(inserted, list): + for record in inserted: + result_ids.append(record["id"].id) + else: + result_ids.append(inserted["id"].id) + return result_ids - if ids is None: - await self.sdb.delete(self.collection) - return True - else: - if isinstance(ids, str): - await self.sdb.delete(ids) - return True + async def aadd_texts( + self, + texts: Iterable[str], + metadatas: list[dict] | None = None, + *, + ids: list[str] | None = None, + **kwargs: Any, + ) -> list[str]: + if self.async_connection is None: + raise ValueError("No async connection provided") + embeddings = self.embedding.embed_documents(list(texts)) + result_ids = [] + coroutines = [] + for idx, text in enumerate(texts): + record_id, data = self._build_text_data( + text, + embeddings[idx], + metadatas[idx] if metadatas is not None else None, + ids[idx] if ids is not None else None, + ) + if record_id is not None: + coroutines.append(self.async_connection.upsert(record_id, data)) else: - if isinstance(ids, list) and len(ids) > 0: - _ = [await self.sdb.delete(id) for id in ids] - return True - return False + coroutines.append(self.async_connection.insert(self.table, data)) + results = await asyncio.gather(*coroutines) + for inserted in results: + if isinstance(inserted, list): + for record in inserted: + result_ids.append(record["id"].id) + elif isinstance(inserted, dict): + result_ids.append(inserted["id"].id) + return result_ids def delete( self, - ids: Optional[List[str]] = None, + ids: list[str] | None = None, **kwargs: Any, - ) -> Optional[bool]: - """Delete by document ID. - - Args: - ids: List of ids to delete. - **kwargs: Other keyword arguments that subclasses might use. - - Returns: - Optional[bool]: True if deletion is successful, - False otherwise. - """ - - async def _delete(ids: Optional[List[str]], **kwargs: Any) -> Optional[bool]: - await self.initialize() - return await self.adelete(ids=ids, **kwargs) + ) -> bool | None: + if self.connection is None: + raise ValueError("No connection provided") + try: + if ids is not None: + for id in ids: + self.connection.delete(RecordID(self.table, id)) + else: + self.connection.delete(self.table) + except Exception as _e: + return False + return True - return asyncio.run(_delete(ids, **kwargs)) + async def adelete( + self, ids: Optional[list[str]] = None, **kwargs: Any + ) -> Optional[bool]: + if self.async_connection is None: + raise ValueError("No async connection provided") + try: + if ids is not None: + coroutines = [ + self.async_connection.delete(RecordID(self.table, id)) for id in ids + ] + await asyncio.gather(*coroutines) + else: + await self.async_connection.delete(self.table) + except Exception as _e: + return False + return True + + def get_by_ids(self, ids: Sequence[str], /) -> list[Document]: + if self.connection is None: + raise ValueError("No connection provided") + query_results = self.connection.query( + GET_BY_ID_QUERY, + {"table": self.table, "ids": ids}, + ) + return self._parse_documents(ids, query_results) + + async def aget_by_ids(self, ids: Sequence[str], /) -> list[Document]: + if self.async_connection is None: + raise ValueError("No async connection provided") + query_results = await self.async_connection.query( + GET_BY_ID_QUERY, + {"table": self.table, "ids": ids}, + ) + return self._parse_documents(ids, query_results) - async def _asimilarity_search_by_vector_with_score( + def similarity_search( self, - embedding: List[float], - k: int = DEFAULT_K, + query: str, + k: int = 4, *, - filter: Optional[Dict[str, str]] = None, + custom_filter: dict[str, str] | None = None, **kwargs: Any, - ) -> List[Tuple[Document, float, Any]]: - """Run similarity search for query embedding asynchronously - and return documents and scores - - Args: - embedding (List[float]): Query embedding. - k (int): Number of results to return. Defaults to 4. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List of Documents most similar along with scores - """ - args = { - "collection": self.collection, - "embedding": embedding, - "k": k, - "score_threshold": kwargs.get("score_threshold", 0), - } - - # build additional filter criteria - custom_filter = "" - if filter: - for key in filter: - # check value type - if type(filter[key]) in [str, bool]: - filter_value = f"'{filter[key]}'" - else: - filter_value = f"{filter[key]}" - - custom_filter += f"and metadata.{key} = {filter_value} " - - query = f""" - select - id, - text, - metadata, - embedding, - vector::similarity::cosine(embedding, $embedding) as similarity - from ⟨{args["collection"]}⟩ - where vector::similarity::cosine(embedding, $embedding) >= $score_threshold - {custom_filter} - order by similarity desc LIMIT $k; - """ - results = await self.sdb.query(query, args) - - if len(results) == 0: - return [] - - result = results[0] - - if result["status"] != "OK": - from surrealdb.ws import SurrealException - - err = result.get("result", "Unknown Error") - raise SurrealException(err) + ) -> list[Document]: + query_embedding = self.embedding.embed_query(query) + return self.similarity_search_by_vector( + query_embedding, k, custom_filter=custom_filter, **kwargs + ) + def similarity_search_with_score( + self, + query: str, + *, + k: int = DEFAULT_K, + score_threshold: float = -1, + custom_filter: dict[str, str] | None = None, + ) -> list[tuple[Document, float]]: + embedding = self.embedding.embed_query(query) return [ - ( - Document( - page_content=doc["text"], - metadata={"id": doc["id"], **(doc.get("metadata") or {})}, - ), - doc["similarity"], - doc["embedding"], + (d, s) + for d, s, _ in self._aux( + embedding, + k=k, + score_threshold=score_threshold, + custom_filter=custom_filter, ) - for doc in result["result"] ] - async def asimilarity_search_with_relevance_scores( + def similarity_search_by_vector( self, - query: str, + embedding: list[float], k: int = DEFAULT_K, *, - filter: Optional[Dict[str, str]] = None, + custom_filter: dict[str, str] | None = None, **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Run similarity search asynchronously and return relevance scores - - Args: - query (str): Query - k (int): Number of results to return. Defaults to 4. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List of Documents most similar along with relevance scores - """ - query_embedding = self.embedding_function.embed_query(query) + ) -> list[Document]: return [ - (document, similarity) - for document, similarity, _ in ( - await self._asimilarity_search_by_vector_with_score( - query_embedding, k, filter=filter, **kwargs - ) + document + for document, _, _ in self._aux( + embedding=embedding, k=k, custom_filter=custom_filter ) ] - def similarity_search_with_relevance_scores( + def max_marginal_relevance_search( self, query: str, k: int = DEFAULT_K, + fetch_k: int = 20, + lambda_mult: float = 0.5, + *, + custom_filter: dict[str, str] | None = None, + **kwargs: Any, + ) -> list[Document]: + embedding = self.embedding.embed_query(query) + docs = self.max_marginal_relevance_search_by_vector( + embedding, k, fetch_k, lambda_mult, custom_filter=custom_filter, **kwargs + ) + return docs + + def max_marginal_relevance_search_by_vector( + self, + embedding: list[float], + k: int = DEFAULT_K, + fetch_k: int = 20, + lambda_mult: float = 0.5, *, - filter: Optional[Dict[str, str]] = None, + custom_filter: dict[str, str] | None = None, **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Run similarity search synchronously and return relevance scores + ) -> list[Document]: + result = self._similarity_search_by_vector_with_score( + embedding, fetch_k, custom_filter=custom_filter, **kwargs + ) + return self._filter_documents_from_result(result, embedding, k, lambda_mult) - Args: - query (str): Query - k (int): Number of results to return. Defaults to 4. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + @classmethod + def from_texts( + cls, + texts: list[str], + embedding: Embeddings, + metadatas: Optional[list[dict]] = None, + *, + ids: Optional[list[str]] = None, + connection: Surreal = None, + **kwargs: Any, + ) -> "SurrealDBStore": + store = SurrealDBStore(embedding, connection) + store.add_texts(texts, metadatas) + return store - Returns: - List of Documents most similar along with relevance scores - """ + @classmethod + async def afrom_texts( + cls, + texts: list[str], + embedding: Embeddings, + metadatas: Optional[list[dict]] = None, + *, + ids: Optional[list[str]] = None, + connection: AsyncSurreal = None, + **kwargs: Any, + ) -> "SurrealDBStore": + store = SurrealDBStore(embedding, None, async_connection=connection) + await store.aadd_texts(texts, metadatas) + return store - async def _similarity_search_with_relevance_scores() -> List[ - Tuple[Document, float] - ]: - await self.initialize() - return await self.asimilarity_search_with_relevance_scores( - query, k, filter=filter, **kwargs - ) + # TODO: implement + def _select_relevance_score_fn(self) -> Callable[[float], float]: + raise NotImplementedError - return asyncio.run(_similarity_search_with_relevance_scores()) + # ========================================================================= + # ========================================================================= + # ========================================================================= - async def asimilarity_search_with_score( + def _similarity_search_by_vector_with_score( + self, + embedding: list[float], + k: int = DEFAULT_K, + score_threshold: float = -1, + custom_filter: dict[str, str] | None = None, + ) -> list[tuple[Document, float, list[float]]]: + if self.connection is None: + raise ValueError("No connection provided") + query, args = self._build_search_query( + embedding, k, score_threshold, custom_filter + ) + results = self.connection.query(query, args) + return self._parse_results(results) + + async def _asimilarity_search_by_vector_with_score( + self, + embedding: list[float], + k: int = DEFAULT_K, + *, + custom_filter: dict[str, str] | None = None, + score_threshold: float = -1, + ) -> list[tuple[Document, float, list[float]]]: + if self.async_connection is None: + raise ValueError("No async connection provided") + query, args = self._build_search_query( + embedding, k, score_threshold, custom_filter + ) + results = await self.async_connection.query(query, args) + return self._parse_results(results) + + async def asimilarity_search_with_relevance_scores( self, query: str, k: int = DEFAULT_K, *, - filter: Optional[Dict[str, str]] = None, + custom_filter: dict[str, str] | None = None, **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Run similarity search asynchronously and return distance scores - - Args: - query (str): Query - k (int): Number of results to return. Defaults to 4. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List of Documents most similar along with relevance distance scores - """ - query_embedding = self.embedding_function.embed_query(query) + ) -> list[tuple[Document, float]]: + query_embedding = self.embedding.embed_query(query) + # TODO: improve using asyncio.gather return [ (document, similarity) for document, similarity, _ in ( await self._asimilarity_search_by_vector_with_score( - query_embedding, k, filter=filter, **kwargs + query_embedding, k, custom_filter=custom_filter, **kwargs ) ) ] - def similarity_search_with_score( + def similarity_search_with_relevance_scores( self, query: str, k: int = DEFAULT_K, *, - filter: Optional[Dict[str, str]] = None, + custom_filter: dict[str, str] | None = None, **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Run similarity search synchronously and return distance scores - - Args: - query (str): Query - k (int): Number of results to return. Defaults to 4. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List of Documents most similar along with relevance distance scores - """ - - async def _similarity_search_with_score() -> List[Tuple[Document, float]]: - await self.initialize() - return await self.asimilarity_search_with_score( - query, k, filter=filter, **kwargs + ) -> list[tuple[Document, float]]: + query_embedding = self.embedding.embed_query(query) + return [ + (document, similarity) + for document, similarity, _ in ( + self._similarity_search_by_vector_with_score( + query_embedding, k, custom_filter=custom_filter, **kwargs + ) ) - - return asyncio.run(_similarity_search_with_score()) + ] async def asimilarity_search_by_vector( self, - embedding: List[float], + embedding: list[float], k: int = DEFAULT_K, *, - filter: Optional[Dict[str, str]] = None, + custom_filter: dict[str, str] | None = None, **kwargs: Any, - ) -> List[Document]: - """Run similarity search on query embedding asynchronously - - Args: - embedding (List[float]): Query embedding - k (int): Number of results to return. Defaults to 4. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List of Documents most similar to the query - """ + ) -> list[Document]: + # TODO: improve using asyncio.gather return [ document for document, _, _ in await self._asimilarity_search_by_vector_with_score( - embedding, k, filter=filter, **kwargs + embedding, k, custom_filter=custom_filter, **kwargs ) ] - def similarity_search_by_vector( - self, - embedding: List[float], - k: int = DEFAULT_K, - *, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, - ) -> List[Document]: - """Run similarity search on query embedding - - Args: - embedding (List[float]): Query embedding - k (int): Number of results to return. Defaults to 4. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List of Documents most similar to the query - """ - - async def _similarity_search_by_vector() -> List[Document]: - await self.initialize() - return await self.asimilarity_search_by_vector( - embedding, k, filter=filter, **kwargs - ) - - return asyncio.run(_similarity_search_by_vector()) - async def asimilarity_search( self, query: str, k: int = DEFAULT_K, *, - filter: Optional[Dict[str, str]] = None, + custom_filter: dict[str, str] | None = None, **kwargs: Any, - ) -> List[Document]: - """Run similarity search on query asynchronously - - Args: - query (str): Query - k (int): Number of results to return. Defaults to 4. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List of Documents most similar to the query - """ - query_embedding = self.embedding_function.embed_query(query) + ) -> list[Document]: + query_embedding = self.embedding.embed_query(query) return await self.asimilarity_search_by_vector( - query_embedding, k, filter=filter, **kwargs + query_embedding, k, custom_filter=custom_filter, **kwargs ) - def similarity_search( - self, - query: str, - k: int = DEFAULT_K, - *, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, - ) -> List[Document]: - """Run similarity search on query - - Args: - query (str): Query - k (int): Number of results to return. Defaults to 4. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List of Documents most similar to the query - """ - - async def _similarity_search() -> List[Document]: - await self.initialize() - return await self.asimilarity_search(query, k, filter=filter, **kwargs) - - return asyncio.run(_similarity_search()) - async def amax_marginal_relevance_search_by_vector( self, - embedding: List[float], + embedding: list[float], k: int = DEFAULT_K, fetch_k: int = 20, lambda_mult: float = 0.5, *, - filter: Optional[Dict[str, str]] = None, + custom_filter: dict[str, str] | None = None, **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List of Documents selected by maximal marginal relevance. - """ - + ) -> list[Document]: result = await self._asimilarity_search_by_vector_with_score( - embedding, fetch_k, filter=filter, **kwargs + embedding, fetch_k, custom_filter=custom_filter, **kwargs ) - - # extract only document from result - docs = [sub[0] for sub in result] - # extract only embedding from result - embeddings = [sub[-1] for sub in result] - - mmr_selected = maximal_marginal_relevance( - np.array(embedding, dtype=np.float32), - embeddings, - k=k, - lambda_mult=lambda_mult, - ) - - return [docs[i] for i in mmr_selected] - - def max_marginal_relevance_search_by_vector( - self, - embedding: List[float], - k: int = DEFAULT_K, - fetch_k: int = 20, - lambda_mult: float = 0.5, - *, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List of Documents selected by maximal marginal relevance. - """ - - async def _max_marginal_relevance_search_by_vector() -> List[Document]: - await self.initialize() - return await self.amax_marginal_relevance_search_by_vector( - embedding, k, fetch_k, lambda_mult, filter=filter, **kwargs - ) - - return asyncio.run(_max_marginal_relevance_search_by_vector()) + return self._filter_documents_from_result(result, embedding, k, lambda_mult) async def amax_marginal_relevance_search( self, @@ -578,126 +632,11 @@ async def amax_marginal_relevance_search( fetch_k: int = 20, lambda_mult: float = 0.5, *, - filter: Optional[Dict[str, str]] = None, + custom_filter: dict[str, str] | None = None, **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List of Documents selected by maximal marginal relevance. - """ - - embedding = self.embedding_function.embed_query(query) + ) -> list[Document]: + embedding = self.embedding.embed_query(query) docs = await self.amax_marginal_relevance_search_by_vector( - embedding, k, fetch_k, lambda_mult, filter=filter, **kwargs + embedding, k, fetch_k, lambda_mult, custom_filter=custom_filter, **kwargs ) return docs - - def max_marginal_relevance_search( - self, - query: str, - k: int = DEFAULT_K, - fetch_k: int = 20, - lambda_mult: float = 0.5, - *, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. - - Returns: - List of Documents selected by maximal marginal relevance. - """ - - async def _max_marginal_relevance_search() -> List[Document]: - await self.initialize() - return await self.amax_marginal_relevance_search( - query, k, fetch_k, lambda_mult, filter=filter, **kwargs - ) - - return asyncio.run(_max_marginal_relevance_search()) - - @classmethod - async def afrom_texts( - cls, - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ) -> "SurrealDBStore": - """Create SurrealDBStore from list of text asynchronously - - Args: - texts (List[str]): list of text to vectorize and store - embedding (Optional[Embeddings]): Embedding function. - dburl (str): SurrealDB connection url - (default: "ws://localhost:8000/rpc") - ns (str): surrealdb namespace for the vector store. - (default: "langchain") - db (str): surrealdb database for the vector store. - (default: "database") - collection (str): surrealdb collection for the vector store. - (default: "documents") - - (optional) db_user and db_pass: surrealdb credentials - - Returns: - SurrealDBStore object initialized and ready for use.""" - - sdb = cls(embedding, **kwargs) - await sdb.initialize() - await sdb.aadd_texts(texts, metadatas, **kwargs) - return sdb - - @classmethod - def from_texts( - cls, - texts: List[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ) -> "SurrealDBStore": - """Create SurrealDBStore from list of text - - Args: - texts (List[str]): list of text to vectorize and store - embedding (Optional[Embeddings]): Embedding function. - dburl (str): SurrealDB connection url - ns (str): surrealdb namespace for the vector store. - (default: "langchain") - db (str): surrealdb database for the vector store. - (default: "database") - collection (str): surrealdb collection for the vector store. - (default: "documents") - - (optional) db_user and db_pass: surrealdb credentials - - Returns: - SurrealDBStore object initialized and ready for use.""" - sdb = asyncio.run(cls.afrom_texts(texts, embedding, metadatas, **kwargs)) - return sdb diff --git a/libs/community/tests/integration_tests/vectorstores/test_surrealdb.py b/libs/community/tests/integration_tests/vectorstores/test_surrealdb.py new file mode 100644 index 0000000..35b38c2 --- /dev/null +++ b/libs/community/tests/integration_tests/vectorstores/test_surrealdb.py @@ -0,0 +1,66 @@ +from typing import Generator + +import pytest +from langchain_tests.integration_tests.vectorstores import VectorStoreIntegrationTests + +from langchain_community.vectorstores.surrealdb import SurrealDBStore + + +class TestSurrealDB(VectorStoreIntegrationTests): + @property + def has_async(self) -> bool: + return False + + @pytest.fixture + def vectorstore(self) -> Generator[SurrealDBStore, None, None]: # type: ignore[override] + try: + from surrealdb import Surreal + except ImportError as e: + raise ImportError( + """Cannot import from surrealdb. + please install with `pip install surrealdb`.""" + ) from e + + conn = Surreal("ws://localhost:8000/rpc") + conn.signin({"username": "root", "password": "root"}) + conn.use("langchain", "test") + store = SurrealDBStore(self.get_embeddings(), conn) + store.delete() + try: + yield store + finally: + store.delete() + + +# FIXME: async test throws "got Future attached to a different +# loop" error +# class TestSurrealDBAsync(VectorStoreIntegrationTests): +# @property +# def has_sync(self) -> bool: +# return False +# +# @pytest.fixture +# def vectorstore(self) -> Generator[SurrealDBStore, None, None]: +# try: +# from surrealdb import AsyncSurreal +# except ImportError as e: +# raise ImportError( +# """Cannot import from surrealdb. +# please install with `pip install surrealdb`.""" +# ) from e +# +# async def _connect() -> AsyncSurreal: +# conn = AsyncSurreal("ws://localhost:8000/rpc") +# await conn.signin({"username": "root", "password": "root"}) +# await conn.use("langchain", "test") +# return conn +# +# async_conn = asyncio.run(_connect()) +# store = SurrealDBStore( +# self.get_embeddings(), None, async_connection=async_conn +# ) +# asyncio.run(store.adelete()) +# try: +# yield store +# finally: +# asyncio.run(store.adelete())