Skip to content

Commit

Permalink
Add script to load the embeddings (#384)
Browse files Browse the repository at this point in the history
* Add script to load the embeddings

Co-authored-by: Ferdi Kossmann <[email protected]>
Co-authored-by: Geoffrey Yu <[email protected]>

* Add progress indicator

* Fixes

* Add stub implementations

---------

Co-authored-by: Ferdi Kossmann <[email protected]>
  • Loading branch information
geoffxy and Ferdi Kossmann authored Nov 22, 2023
1 parent 5e4292b commit 303e785
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 4 deletions.
5 changes: 4 additions & 1 deletion src/brad/connection/cursor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Tuple, Optional, List, Iterator, AsyncIterator
from typing import Any, Tuple, Optional, List, Iterator, AsyncIterator, Iterable


Row = Tuple[Any, ...]
Expand Down Expand Up @@ -39,6 +39,9 @@ async def rollback(self) -> None:
def execute_sync(self, query: str) -> None:
    """Synchronously execute *query*. Engine-specific subclasses must override."""
    raise NotImplementedError

def executemany_sync(self, query: str, batch: Iterable[Any]) -> None:
    """Synchronously execute *query* once per parameter set in *batch*.

    Engine-specific subclasses must override (not all engines support it).
    """
    raise NotImplementedError

def fetchone_sync(self) -> Optional[Row]:
    """Synchronously fetch the next result row, or None when exhausted. Subclasses must override."""
    raise NotImplementedError

Expand Down
5 changes: 4 additions & 1 deletion src/brad/connection/odbc_cursor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import Any, Optional, List
from typing import Any, Optional, List, Iterable

from .cursor import Cursor, Row

Expand Down Expand Up @@ -32,6 +32,9 @@ async def rollback(self) -> None:
def execute_sync(self, query: str) -> None:
    """Synchronously execute *query* on the underlying ODBC cursor."""
    self._impl.execute(query)

def executemany_sync(self, query: str, batch: Iterable[Any]) -> None:
    """Synchronously execute *query* once per parameter set in *batch*.

    Delegates to the underlying ODBC cursor's `executemany`.
    """
    self._impl.executemany(query, batch)

def fetchone_sync(self) -> Optional[Row]:
    """Synchronously fetch the next result row from the underlying ODBC cursor (None when exhausted)."""
    return self._impl.fetchone()

Expand Down
5 changes: 4 additions & 1 deletion src/brad/connection/pyathena_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pyathena
import pyathena.connection
import pyathena.cursor
from typing import Optional, List
from typing import Any, Iterable, Optional, List

from .cursor import Cursor, Row

Expand Down Expand Up @@ -33,6 +33,9 @@ async def rollback(self) -> None:
def execute_sync(self, query: str) -> None:
    """Synchronously execute *query* on the underlying PyAthena cursor."""
    self._impl.execute(query)

def executemany_sync(self, query: str, batch: Iterable[Any]) -> None:
    """Batched execution is deliberately unsupported for Athena; always raises RuntimeError."""
    raise RuntimeError("Not supported on Athena.")

def fetchone_sync(self) -> Optional[Row]:
    """Synchronously fetch the next result row from the underlying PyAthena cursor (None when exhausted)."""
    return self._impl.fetchone()  # type: ignore

Expand Down
5 changes: 4 additions & 1 deletion src/brad/connection/redshift_cursor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import redshift_connector
from typing import Optional, List
from typing import Any, Iterable, Optional, List

from .cursor import Cursor, Row

Expand Down Expand Up @@ -36,6 +36,9 @@ async def rollback(self) -> None:
def execute_sync(self, query: str) -> None:
    """Synchronously execute *query* on the underlying redshift_connector cursor."""
    self._impl.execute(query)

def executemany_sync(self, query: str, batch: Iterable[Any]) -> None:
    """Batched execution is deliberately unsupported for Redshift; always raises RuntimeError."""
    raise RuntimeError("Not supported on Redshift.")

def fetchone_sync(self) -> Optional[Row]:
    """Synchronously fetch the next result row from the underlying redshift_connector cursor (None when exhausted)."""
    return self._impl.fetchone()

Expand Down
96 changes: 96 additions & 0 deletions tools/load_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import asyncio
import argparse
import numpy as np
import numpy.typing as npt

from brad.blueprint.manager import BlueprintManager
from brad.config.engine import Engine
from brad.config.file import ConfigFile
from brad.connection.connection import Connection
from brad.connection.factory import ConnectionFactory
from brad.asset_manager import AssetManager


# Inserting everything as a single batch caused the loader to be killed
# (presumably OOM), so the data is inserted in batches of this size.
BATCH_SIZE = 50_000


def insert(connection: Connection, embeddings: npt.NDArray):
    """Load `embeddings` into the `embeddings` table, one row per movie id.

    Movie ids are read from `aka_title` and paired positionally with the
    embedding rows; each embedding is stored as the string form of its list
    of floats. All batches are committed as one transaction at the end.

    NOTE(review): `SELECT DISTINCT` does not guarantee a stable ordering, so
    the id <-> embedding pairing assumes the ids come back in the order the
    embeddings were generated — confirm against the embedding pipeline.
    """
    cursor = connection.cursor_sync()

    # Get the ids that the embedding rows correspond to.
    cursor.execute_sync("SELECT DISTINCT id FROM aka_title")
    movie_id_rows = cursor.fetchall_sync()
    all_movie_ids = [row[0] for row in movie_id_rows]

    # Ceiling division so a trailing partial batch is counted.
    total_batches = -(-embeddings.shape[0] // BATCH_SIZE)

    # Insert batches.
    batch = 0
    while batch * BATCH_SIZE < embeddings.shape[0]:
        start = batch * BATCH_SIZE
        end = start + BATCH_SIZE
        np_embeddings_batch = embeddings[start:end]
        movie_ids_batch = all_movie_ids[start:end]

        # `tolist()` yields plain Python floats so the stored text is
        # "[0.1, 0.2, ...]" on every numpy version; `str(list(row))` would
        # embed "np.float64(...)" reprs under numpy >= 2.
        insert_batch = [
            (movie_id, str(row.tolist()))
            for movie_id, row in zip(movie_ids_batch, np_embeddings_batch)
        ]

        # Display 1-based progress ("batch 1 of N", not "batch 0 of N").
        print(f"Loading batch {batch + 1} of {total_batches}...")
        cursor.executemany_sync(
            "INSERT INTO embeddings (movie_id, embedding) VALUES (?,?);",
            insert_batch,
        )

        batch += 1

    cursor.commit_sync()


def inspect(connection: Connection):
    """Print the maximum id currently stored in the `embeddings` table.

    Serves as a quick sanity check after a load.
    """
    cursor = connection.cursor_sync()
    cursor.execute_sync("SELECT MAX(id) FROM embeddings;")
    for row in cursor.fetchall_sync():
        print(row)


def main():
    """Entry point: optionally load embeddings into Aurora, then inspect.

    With --load, reads the .npy file given by --embeddings-file and inserts
    its rows into the `embeddings` table; in all cases the current MAX(id)
    is printed afterwards as a sanity check.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config-file", type=str, required=True)
    parser.add_argument("--schema-name", type=str, required=True)
    parser.add_argument("--embeddings-file", type=str)
    parser.add_argument("--load", action="store_true")
    args = parser.parse_args()

    if args.load:
        # Fail fast with a usage error instead of letting np.load(None)
        # crash later with a confusing traceback.
        if args.embeddings_file is None:
            parser.error("--embeddings-file is required when --load is set")
        embeddings = np.load(args.embeddings_file)
    else:
        embeddings = None

    config = ConfigFile.load(args.config_file)
    assets = AssetManager(config)
    blueprint_mgr = BlueprintManager(config, assets, args.schema_name)
    asyncio.run(blueprint_mgr.load())
    # autocommit is off so the batched inserts commit as one transaction.
    aurora = ConnectionFactory.connect_to_sync(
        Engine.Aurora,
        args.schema_name,
        config,
        blueprint_mgr.get_directory(),
        autocommit=False,
    )

    # Ensure the connection is closed even if the load fails partway.
    try:
        if args.load:
            insert(aurora, embeddings)
        inspect(aurora)
    finally:
        aurora.close_sync()


if __name__ == "__main__":
    main()

0 comments on commit 303e785

Please sign in to comment.