diff --git a/src/brad/connection/cursor.py b/src/brad/connection/cursor.py
index 6bb7b659..1174fa6c 100644
--- a/src/brad/connection/cursor.py
+++ b/src/brad/connection/cursor.py
@@ -1,4 +1,4 @@
-from typing import Any, Tuple, Optional, List, Iterator, AsyncIterator
+from typing import Any, Tuple, Optional, List, Iterator, AsyncIterator, Iterable
 
 
 Row = Tuple[Any, ...]
@@ -39,6 +39,9 @@ async def rollback(self) -> None:
     def execute_sync(self, query: str) -> None:
         raise NotImplementedError
 
+    def executemany_sync(self, query: str, batch: Iterable[Any]) -> None:
+        raise NotImplementedError
+
     def fetchone_sync(self) -> Optional[Row]:
         raise NotImplementedError
diff --git a/src/brad/connection/odbc_cursor.py b/src/brad/connection/odbc_cursor.py
index d1ed6d97..1e67633e 100644
--- a/src/brad/connection/odbc_cursor.py
+++ b/src/brad/connection/odbc_cursor.py
@@ -1,5 +1,5 @@
 import asyncio
-from typing import Any, Optional, List
+from typing import Any, Optional, List, Iterable
 
 from .cursor import Cursor, Row
 
@@ -32,6 +32,9 @@ async def rollback(self) -> None:
     def execute_sync(self, query: str) -> None:
         self._impl.execute(query)
 
+    def executemany_sync(self, query: str, batch: Iterable[Any]) -> None:
+        self._impl.executemany(query, batch)
+
     def fetchone_sync(self) -> Optional[Row]:
         return self._impl.fetchone()
diff --git a/src/brad/connection/pyathena_cursor.py b/src/brad/connection/pyathena_cursor.py
index 622f6a72..d065d27a 100644
--- a/src/brad/connection/pyathena_cursor.py
+++ b/src/brad/connection/pyathena_cursor.py
@@ -2,7 +2,7 @@
 import pyathena
 import pyathena.connection
 import pyathena.cursor
-from typing import Optional, List
+from typing import Any, Iterable, Optional, List
 
 from .cursor import Cursor, Row
@@ -33,6 +33,9 @@ async def rollback(self) -> None:
     def execute_sync(self, query: str) -> None:
         self._impl.execute(query)
 
+    def executemany_sync(self, query: str, batch: Iterable[Any]) -> None:
+        raise RuntimeError("Not supported on Athena.")
+
     def fetchone_sync(self) -> Optional[Row]:
         return self._impl.fetchone()  # type: ignore
diff --git a/src/brad/connection/redshift_cursor.py b/src/brad/connection/redshift_cursor.py
index 1c88ccbc..270de76a 100644
--- a/src/brad/connection/redshift_cursor.py
+++ b/src/brad/connection/redshift_cursor.py
@@ -1,6 +1,6 @@
 import asyncio
 import redshift_connector
-from typing import Optional, List
+from typing import Any, Iterable, Optional, List
 
 from .cursor import Cursor, Row
 
@@ -36,6 +36,9 @@ async def rollback(self) -> None:
     def execute_sync(self, query: str) -> None:
         self._impl.execute(query)
 
+    def executemany_sync(self, query: str, batch: Iterable[Any]) -> None:
+        raise RuntimeError("Not supported on Redshift.")
+
     def fetchone_sync(self) -> Optional[Row]:
         return self._impl.fetchone()
diff --git a/tools/load_embeddings.py b/tools/load_embeddings.py
new file mode 100644
index 00000000..ad2d158b
--- /dev/null
+++ b/tools/load_embeddings.py
@@ -0,0 +1,102 @@
+import asyncio
+import argparse
+import numpy as np
+import numpy.typing as npt
+
+from brad.blueprint.manager import BlueprintManager
+from brad.config.engine import Engine
+from brad.config.file import ConfigFile
+from brad.connection.connection import Connection
+from brad.connection.factory import ConnectionFactory
+from brad.asset_manager import AssetManager
+
+
+# Inserting all rows in a single executemany() call gets the process killed,
+# so load the data in chunks of this size.
+BATCH_SIZE = 50_000
+
+
+def insert(connection: Connection, embeddings: npt.NDArray):
+    """Insert `embeddings` into the `embeddings` table in BATCH_SIZE chunks."""
+    cursor = connection.cursor_sync()
+
+    # Get the ids. NOTE(review): there is no ORDER BY here — the id order must
+    # match the row order of the embeddings file; confirm how it was generated.
+    cursor.execute_sync("SELECT DISTINCT id FROM aka_title")
+    movie_id_rows = cursor.fetchall_sync()
+    all_movie_ids = [row[0] for row in movie_id_rows]
+
+    total_batches = embeddings.shape[0] // BATCH_SIZE
+    if embeddings.shape[0] % BATCH_SIZE != 0:
+        total_batches += 1
+
+    # Insert batches
+    batch = 0
+    while batch * BATCH_SIZE < embeddings.shape[0]:
+        np_embeddings_batch = embeddings[batch * BATCH_SIZE : (batch + 1) * BATCH_SIZE]
+        movie_ids_batch = all_movie_ids[batch * BATCH_SIZE : (batch + 1) * BATCH_SIZE]
+
+        insert_batch = [
+            (
+                movie_id,
+                str(list(e)),
+            )
+            for movie_id, e in zip(movie_ids_batch, np_embeddings_batch)
+        ]
+
+        print(f"Loading batch {batch} of {total_batches}...")
+        cursor.executemany_sync(
+            "INSERT INTO embeddings (movie_id, embedding) VALUES (?,?);", insert_batch
+        )
+
+        batch += 1
+
+    cursor.commit_sync()
+
+
+def inspect(connection: Connection):
+    """Print MAX(id) from the embeddings table as a quick post-load sanity check."""
+    cursor = connection.cursor_sync()
+    cursor.execute_sync("SELECT MAX(id) FROM embeddings;")
+
+    rows = cursor.fetchall_sync()
+    for row in rows:
+        print(row)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config-file", type=str, required=True)
+    parser.add_argument("--schema-name", type=str, required=True)
+    parser.add_argument("--embeddings-file", type=str)
+    parser.add_argument("--load", action="store_true")
+    args = parser.parse_args()
+    if args.load and args.embeddings_file is None:
+        parser.error("--embeddings-file is required when --load is set")
+
+    if args.load:
+        embeddings = np.load(args.embeddings_file)
+    else:
+        embeddings = None
+
+    config = ConfigFile.load(args.config_file)
+    assets = AssetManager(config)
+    blueprint_mgr = BlueprintManager(config, assets, args.schema_name)
+    asyncio.run(blueprint_mgr.load())
+    aurora = ConnectionFactory.connect_to_sync(
+        Engine.Aurora,
+        args.schema_name,
+        config,
+        blueprint_mgr.get_directory(),
+        autocommit=False,
+    )
+
+    if args.load:
+        insert(aurora, embeddings)
+    inspect(aurora)
+
+    aurora.close_sync()
+
+
+if __name__ == "__main__":
+    main()