
Commit

simplify / clarify code for random retrieval when no search has been previously performed
vyaivo committed Nov 14, 2024
1 parent e2acae1 commit 78ea246
Showing 1 changed file with 11 additions and 12 deletions.
retriever/run.py: 23 changes (11 additions & 12 deletions)
@@ -65,14 +65,14 @@ def dense_random_retrieval(

     import gc

-    if not load_search_results:
+    if load_search_results is None:
         index_path = os.path.join(INDEX_PATH, 'dense', embed_file.split(".fvecs")[0])
         vec_file = os.path.join(VEC_PATH, embed_file)
         # Optimal configuration is to set the number of threads to the batch size
         index_kwargs.update({'num_threads': num_threads})

         logger.info('Start indexing...')
-        search_index_og = index.dense_build_index(
+        search_index = index.dense_build_index(
             index_path,
             vec_file,
             index_fn,
@@ -83,11 +83,11 @@ def dense_random_retrieval(
         )
         logger.info('Done indexing')

-    if embed_model_type == 'st':
-        import sentence_transformers as st
-        embed_model = st.SentenceTransformer(embed_model_name)
-    else:
-        raise NotImplementedError('Need to implement alternate type of embedding model')
+        if embed_model_type == 'st':
+            import sentence_transformers as st
+            embed_model = st.SentenceTransformer(embed_model_name)
+        else:
+            raise NotImplementedError('Need to implement alternate type of embedding model')

     logger.info('Embedding and batching queries...')

@@ -102,20 +102,19 @@ def dense_random_retrieval(
         logger.info(f"Batch size: {len(queries)}")
         query_data = query_data_batches[batch_id]

-        if load_search_results:
-            batch_load_results = load_search_results.replace("*", str(batch_id))
-            k_neighbors, dist_neighbors = load_pickle(batch_load_results, logger)
-        else:
+        if load_search_results is None:
             query_embs = embed_model.encode(queries)

-            search_index = search_index_og
             logger.info(f"Start searching for {k} neighbors per query...")
             k_neighbors, dist_neighbors = search_index.search(query_embs, k)
             logger.info('Done searching')

             # Save direct search outputs before writing to JSON
             save_pickle([k_neighbors, dist_neighbors], f'{doc_dataset}_tmp_batch-{batch_id}.pkl', logger)
             gc.collect()
+        elif load_search_results:
+            batch_load_results = load_search_results.replace("*", str(batch_id))
+            k_neighbors, dist_neighbors = load_pickle(batch_load_results, logger)

         logger.info('Loading text corpus and document titles to associate with neighbors')
         corpus = datasets.load_dataset("json", data_files=corpus_file)
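For context, the branching pattern this commit settles on is: when no cached search results are supplied, embed the queries, search the dense index, and cache the raw output per batch; when a results path pattern is supplied, substitute the batch id into it and load the pickle. The Python sketch below illustrates that pattern only; retrieve_batch, the brute-force NumPy search, and the default model name are assumptions made for the example, not code from retriever/run.py.

# Illustrative sketch of the cache-or-search branching in this commit.
# retrieve_batch, the NumPy inner-product search, and the model name are
# assumptions for the example, not the repository's dense_random_retrieval.
import pickle

import numpy as np
from sentence_transformers import SentenceTransformer


def retrieve_batch(queries, doc_embs, k, batch_id, load_search_results=None,
                   model_name="sentence-transformers/all-MiniLM-L6-v2"):
    if load_search_results is None:
        # No previous search: embed the queries and run a brute-force dense search.
        # doc_embs is assumed to hold embeddings from the same model.
        model = SentenceTransformer(model_name)
        query_embs = np.asarray(model.encode(list(queries)))      # (n_queries, dim)
        scores = query_embs @ np.asarray(doc_embs).T               # inner-product scores
        k_neighbors = np.argsort(-scores, axis=1)[:, :k]           # top-k document ids
        dist_neighbors = np.take_along_axis(scores, k_neighbors, axis=1)
        # Cache the raw search output so later runs can skip the search step.
        with open(f"tmp_batch-{batch_id}.pkl", "wb") as f:
            pickle.dump([k_neighbors, dist_neighbors], f)
    else:
        # A previous search was saved: substitute the batch id into the supplied
        # pattern (e.g. "run_tmp_batch-*.pkl") and load its results.
        with open(load_search_results.replace("*", str(batch_id)), "rb") as f:
            k_neighbors, dist_neighbors = pickle.load(f)
    return k_neighbors, dist_neighbors

In the actual script the search goes through the prebuilt dense index rather than a NumPy matmul, but the None / non-None branching is the same.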
