From b2a519d8c56078fe3475004afffcae30682d8e07 Mon Sep 17 00:00:00 2001
From: PrashantDixit-dev
Date: Sat, 19 Oct 2024 17:30:59 +0530
Subject: [PATCH] lancedb integration

---
 Makefile                                    |   4 +
 requirements/ingest/lancedb.in              |   3 +
 setup.py                                    |   1 +
 unstructured/ingest/connector/lancedb.py    | 119 ++++++++++++++++++
 .../ingest/v2/examples/examples_lancedb.py  |  52 ++++++++
 5 files changed, 179 insertions(+)
 create mode 100644 requirements/ingest/lancedb.in
 create mode 100644 unstructured/ingest/connector/lancedb.py
 create mode 100644 unstructured/ingest/v2/examples/examples_lancedb.py

diff --git a/Makefile b/Makefile
index d9a3e1803f..82b293663b 100644
--- a/Makefile
+++ b/Makefile
@@ -233,6 +233,10 @@ install-ingest-pinecone:
 install-ingest-qdrant:
 	${PYTHON} -m pip install -r requirements/ingest/qdrant.txt
 
+.PHONY: install-ingest-lancedb
+install-ingest-lancedb:
+	${PYTHON} -m pip install -r requirements/ingest/lancedb.txt
+
 .PHONY: install-ingest-chroma
 install-ingest-chroma:
 	${PYTHON} -m pip install -r requirements/ingest/chroma.txt
diff --git a/requirements/ingest/lancedb.in b/requirements/ingest/lancedb.in
new file mode 100644
index 0000000000..a73d461ee0
--- /dev/null
+++ b/requirements/ingest/lancedb.in
@@ -0,0 +1,3 @@
+-c ../deps/constraints.txt
+-c ../base.txt
+lancedb
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 89813f7c17..d5bac1d036 100644
--- a/setup.py
+++ b/setup.py
@@ -150,6 +150,7 @@ def load_requirements(file_list: Optional[Union[str, List[str]]] = None) -> List
         "opensearch": load_requirements("requirements/ingest/opensearch.in"),
         "outlook": load_requirements("requirements/ingest/outlook.in"),
         "pinecone": load_requirements("requirements/ingest/pinecone.in"),
+        "lancedb": load_requirements("requirements/ingest/lancedb.in"),
         "postgres": load_requirements("requirements/ingest/postgres.in"),
         "qdrant": load_requirements("requirements/ingest/qdrant.in"),
         "reddit": load_requirements("requirements/ingest/reddit.in"),
diff --git a/unstructured/ingest/connector/lancedb.py b/unstructured/ingest/connector/lancedb.py
new file mode 100644
index 0000000000..a4ace754c4
--- /dev/null
+++ b/unstructured/ingest/connector/lancedb.py
@@ -0,0 +1,119 @@
+import copy
+import json
+import multiprocessing as mp
+import typing as t
+import uuid
+from dataclasses import dataclass
+
+from unstructured.ingest.enhanced_dataclass import enhanced_field
+from unstructured.ingest.enhanced_dataclass.core import _asdict
+from unstructured.ingest.error import DestinationConnectionError, WriteError
+from unstructured.ingest.interfaces import (
+    AccessConfig,
+    BaseConnectorConfig,
+    BaseDestinationConnector,
+    ConfigSessionHandleMixin,
+    IngestDocSessionHandleMixin,
+    WriteConfig,
+)
+from unstructured.ingest.logger import logger
+from unstructured.ingest.utils.data_prep import batch_generator
+from unstructured.staging.base import flatten_dict
+from unstructured.utils import requires_dependencies
+
+if t.TYPE_CHECKING:
+    import lancedb
+
+@dataclass
+class LanceDBAccessConfig(AccessConfig):
+    uri: str = enhanced_field(sensitive=True)
+
+@dataclass
+class SimpleLanceDBConfig(ConfigSessionHandleMixin, BaseConnectorConfig):
+    table_name: str
+    access_config: LanceDBAccessConfig
+
+@dataclass
+class LanceDBWriteConfig(WriteConfig):
+    batch_size: int = 50
+    num_processes: int = 1
+
+@dataclass
+class LanceDBDestinationConnector(IngestDocSessionHandleMixin, BaseDestinationConnector):
+    write_config: LanceDBWriteConfig
+    connector_config: SimpleLanceDBConfig
+    _table: t.Optional["lancedb.Table"] = None
+
+    def to_dict(self, **kwargs):
+        self_cp = copy.copy(self)
+        if hasattr(self_cp, "_table"):
+            setattr(self_cp, "_table", None)
+        return _asdict(self_cp, **kwargs)
+
+    @property
+    def lancedb_table(self):
+        if self._table is None:
+            self._table = self.create_table()
+        return self._table
+
+    def initialize(self):
+        pass
+
+    @requires_dependencies(["lancedb"], extras="lancedb")
+    def create_table(self) -> "lancedb.Table":
+        import lancedb
+
+        db = lancedb.connect(self.connector_config.access_config.uri)
+        table = db.open_table(self.connector_config.table_name)
+        logger.debug(f"Connected to table: {table}")
+        return table
+
+    @DestinationConnectionError.wrap
+    def check_connection(self):
+        _ = self.lancedb_table
+
+    @DestinationConnectionError.wrap
+    @requires_dependencies(["lancedb"], extras="lancedb")
+    def add_batch(self, batch):
+        table = self.lancedb_table
+        try:
+            table.add(batch)
+        except Exception as error:
+            raise WriteError(f"LanceDB error: {error}") from error
+        logger.debug(f"Added {len(batch)} records to the table")
+
+    def write_dict(self, *args, elements_dict: t.List[t.Dict[str, t.Any]], **kwargs) -> None:
+        logger.info(
+            f"Adding {len(elements_dict)} elements to destination "
+            f"table {self.connector_config.table_name}",
+        )
+
+        lancedb_batch_size = self.write_config.batch_size
+
+        logger.info(f"using {self.write_config.num_processes} processes to upload")
+        if self.write_config.num_processes == 1:
+            for chunk in batch_generator(elements_dict, lancedb_batch_size):
+                self.add_batch(chunk)
+
+        else:
+            with mp.Pool(
+                processes=self.write_config.num_processes,
+            ) as pool:
+                pool.map(
+                    self.add_batch, list(batch_generator(elements_dict, lancedb_batch_size))
+                )
+
+    def normalize_dict(self, element_dict: dict) -> dict:
+        flattened = flatten_dict(
+            element_dict,
+            separator="_",
+            flatten_lists=True,
+            remove_none=True,
+        )
+        return {
+            "id": str(uuid.uuid4()),
+            "vector": flattened.pop("embeddings", None),
+            "text": flattened.pop("text", None),
+            "metadata": json.dumps(flattened),
+            **flattened,
+        }
\ No newline at end of file
diff --git a/unstructured/ingest/v2/examples/examples_lancedb.py b/unstructured/ingest/v2/examples/examples_lancedb.py
new file mode 100644
index 0000000000..42e988c449
--- /dev/null
+++ b/unstructured/ingest/v2/examples/examples_lancedb.py
@@ -0,0 +1,52 @@
+import os
+from pathlib import Path
+from unstructured.ingest.v2.interfaces import ProcessorConfig
+from unstructured.ingest.v2.logger import logger
+from unstructured.ingest.v2.pipeline.pipeline import Pipeline
+from unstructured.ingest.v2.processes.chunker import ChunkerConfig
+from unstructured.ingest.v2.processes.connectors.local import (
+    LocalConnectionConfig,
+    LocalDownloaderConfig,
+    LocalIndexerConfig,
+)
+from unstructured.ingest.v2.processes.embedder import EmbedderConfig
+from unstructured.ingest.v2.processes.partitioner import PartitionerConfig
+
+# Import the LanceDB-specific classes (assuming they've been created)
+from unstructured.ingest.v2.processes.connectors.lancedb import (
+    LanceDBConnectionConfig,
+    LanceDBUploaderConfig,
+    LanceDBUploadStagerConfig,
+)
+
+base_path = Path(__file__).parent.parent.parent.parent.parent
+docs_path = base_path / "example-docs"
+work_dir = base_path / "tmp_ingest"
+output_path = work_dir / "output"
+download_path = work_dir / "download"
+
+if __name__ == "__main__":
+    logger.info(f"Writing all content in: {work_dir.resolve()}")
+
+    Pipeline.from_configs(
+        context=ProcessorConfig(work_dir=str(work_dir.resolve())),
+        indexer_config=LocalIndexerConfig(
+            input_path=str(docs_path.resolve()) + "/book-war-and-peace-1p.txt"
+        ),
+        downloader_config=LocalDownloaderConfig(download_dir=download_path),
+        source_connection_config=LocalConnectionConfig(),
+        partitioner_config=PartitionerConfig(strategy="fast"),
+        chunker_config=ChunkerConfig(chunking_strategy="by_title"),
+        embedder_config=EmbedderConfig(embedding_provider="langchain-huggingface"),
+        destination_connection_config=LanceDBConnectionConfig(
+            # You'll need to set the LANCEDB_URI environment variable to run this example
+            uri=os.getenv("LANCEDB_URI", "data"),
+            table_name=os.getenv(
+                "LANCEDB_TABLE",
+                default="your table name here, e.g. my-table, "
+                "or set it via the LANCEDB_TABLE environment variable",
+            ),
+        ),
+        stager_config=LanceDBUploadStagerConfig(),
+        uploader_config=LanceDBUploaderConfig(batch_size=10, num_of_processes=2),
+    ).run()
\ No newline at end of file
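
Usage note (not part of the patch): create_table() in the connector opens an existing table via db.open_table(), so the destination table must exist before the pipeline writes to it. The following is a minimal, hypothetical sketch of pre-creating a compatible table with the LanceDB Python API; the "elements" fallback table name, the 384-dimension vector size, and the PyArrow schema are assumptions to be adjusted to the embedding model and metadata columns actually produced.

import os

import lancedb
import pyarrow as pa

# Hypothetical pre-creation of the table the connector later opens with open_table().
# LANCEDB_URI / LANCEDB_TABLE mirror the example script; "elements" and the 384-dim
# vector size are assumptions; match them to the embedder you configure.
uri = os.getenv("LANCEDB_URI", "data")
table_name = os.getenv("LANCEDB_TABLE", "elements")

schema = pa.schema(
    [
        pa.field("id", pa.string()),
        pa.field("vector", pa.list_(pa.float32(), 384)),  # fixed-size list sized to the embedding
        pa.field("text", pa.string()),
        pa.field("metadata", pa.string()),  # JSON string produced by normalize_dict()
    ]
)

db = lancedb.connect(uri)
db.create_table(table_name, schema=schema, mode="overwrite")

Because normalize_dict() also spreads the remaining flattened metadata keys into each record, those columns may need to be added to the schema as well for writes against a fixed schema to succeed.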