diff --git a/tools/load_embeddings.py b/tools/load_embeddings.py
index 21da0026..f3bcd83c 100644
--- a/tools/load_embeddings.py
+++ b/tools/load_embeddings.py
@@ -23,6 +23,11 @@ def insert(connection: Connection, embeddings: npt.NDArray):
     movie_id_rows = cursor.fetchall_sync()
     all_movie_ids = [row[0] for row in movie_id_rows]
 
+    # Ceiling division: a trailing partial batch still counts as one batch.
+    total_batches = embeddings.shape[0] // BATCH_SIZE
+    if embeddings.shape[0] % BATCH_SIZE != 0:
+        total_batches += 1
+
     # Insert batches
     batch = 0
     while batch * BATCH_SIZE < embeddings.shape[0]:
@@ -37,6 +42,8 @@ def insert(connection: Connection, embeddings: npt.NDArray):
             for id, e in zip(movie_ids_batch, np_embeddings_batch)
         ]
 
+        # `batch` is zero-based; report a one-based position to the user.
+        print(f"Loading batch {batch + 1} of {total_batches}...")
         cursor.executemany_sync(
             "INSERT INTO embeddings (movie_id, embedding) VALUES (?,?);", insert_batch
         )