Skip to content

Commit

Permalink
Merge branch 'main' into copyright
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelbeale-IL authored Jan 9, 2025
2 parents ec8b290 + 78ea246 commit dc834d6
Show file tree
Hide file tree
Showing 24 changed files with 83 additions and 12 deletions.
3 changes: 3 additions & 0 deletions file_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import pickle

Expand Down
3 changes: 3 additions & 0 deletions plots/plot_gold-search-recall.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
Expand Down
3 changes: 3 additions & 0 deletions plots/plot_ndoc-recall.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
Expand Down
3 changes: 3 additions & 0 deletions plots/plot_noise_percentile.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import matplotlib.pyplot as plt
import plot_utils
Expand Down
3 changes: 3 additions & 0 deletions plots/plot_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import numpy as np
import os
Expand Down
3 changes: 3 additions & 0 deletions preprocessing/alce/convert_alce_colbert.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from file_utils import load_json
import convert_alce_utils

Expand Down
3 changes: 3 additions & 0 deletions preprocessing/alce/convert_alce_dense.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from file_utils import load_json, save_json, save_jsonl
import convert_alce_utils

Expand Down
3 changes: 3 additions & 0 deletions preprocessing/alce/convert_alce_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from tqdm import tqdm
import pandas as pd

Expand Down
3 changes: 3 additions & 0 deletions preprocessing/convert_nq_dense.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import argparse
import datasets
import json
Expand Down
3 changes: 3 additions & 0 deletions preprocessing/create_groundtruth_calibration.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""
After exhaustive search with a flat search index has been run, you will have files that contain the nearest neighbors
for every single query.
Expand Down
3 changes: 3 additions & 0 deletions preprocessing/sample_retrieved_neighbors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
import argparse
from tqdm import tqdm
Expand Down
3 changes: 3 additions & 0 deletions preprocessing/set_gold_recall.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
import argparse
from tqdm import tqdm
Expand Down
3 changes: 3 additions & 0 deletions reader/compute_ci.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import argparse
import numpy as np
import pathlib
Expand Down
3 changes: 3 additions & 0 deletions reader/eval.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# Some of this code is based on prior work under the MIT License:
# Copyright (c) 2023 Princeton Natural Language Processing
# Copyright (c) Carnegie Mellon University
Expand Down
3 changes: 3 additions & 0 deletions reader/eval_per_query.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import argparse
import collections
from collections import Counter
Expand Down
3 changes: 3 additions & 0 deletions reader/plot_per_k.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import argparse
import os
import logging
Expand Down
3 changes: 3 additions & 0 deletions reader/run.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# Some of this code is based on prior work under the MIT License:
# Copyright (c) 2023 Princeton Natural Language Processing

Expand Down
3 changes: 3 additions & 0 deletions reader/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# Some of this code is based on prior work under the MIT License:
# Copyright (c) 2023 Princeton Natural Language Processing

Expand Down
3 changes: 3 additions & 0 deletions retriever/eval.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# Some of this code is based on prior work under the MIT License:
# Copyright (c) 2023 Princeton Natural Language Processing

Expand Down
3 changes: 3 additions & 0 deletions retriever/index.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
import numpy as np
import os
Expand Down
3 changes: 3 additions & 0 deletions retriever/ret_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import torch
import os
Expand Down
26 changes: 14 additions & 12 deletions retriever/run.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# Some of this code is based on prior work under the MIT License:
# Copyright (c) 2023 Princeton Natural Language Processing

Expand Down Expand Up @@ -62,14 +65,14 @@ def dense_random_retrieval(

import gc

if not load_search_results:
if load_search_results is None:
index_path = os.path.join(INDEX_PATH, 'dense', embed_file.split(".fvecs")[0])
vec_file = os.path.join(VEC_PATH, embed_file)
# Optimal configuration is to set the number of threads to the batch size
index_kwargs.update({'num_threads': num_threads})

logger.info('Start indexing...')
search_index_og = index.dense_build_index(
search_index = index.dense_build_index(
index_path,
vec_file,
index_fn,
Expand All @@ -80,11 +83,11 @@ def dense_random_retrieval(
)
logger.info('Done indexing')

if embed_model_type == 'st':
import sentence_transformers as st
embed_model = st.SentenceTransformer(embed_model_name)
else:
raise NotImplementedError('Need to implement alternate type of embedding model')
if embed_model_type == 'st':
import sentence_transformers as st
embed_model = st.SentenceTransformer(embed_model_name)
else:
raise NotImplementedError('Need to implement alternate type of embedding model')

logger.info('Embedding and batching queries...')

Expand All @@ -99,20 +102,19 @@ def dense_random_retrieval(
logger.info(f"Batch size: {len(queries)}")
query_data = query_data_batches[batch_id]

if load_search_results:
batch_load_results = load_search_results.replace("*", str(batch_id))
k_neighbors, dist_neighbors = load_pickle(batch_load_results, logger)
else:
if load_search_results is None:
query_embs = embed_model.encode(queries)

search_index = search_index_og
logger.info(f"Start searching for {k} neighbors per query...")
k_neighbors, dist_neighbors = search_index.search(query_embs, k)
logger.info('Done searching')

# Save direct search outputs before writing to JSON
save_pickle([k_neighbors, dist_neighbors], f'{doc_dataset}_tmp_batch-{batch_id}.pkl', logger)
gc.collect()
elif load_search_results:
batch_load_results = load_search_results.replace("*", str(batch_id))
k_neighbors, dist_neighbors = load_pickle(batch_load_results, logger)

logger.info('Loading text corpus and document titles to associate with neighbors')
corpus = datasets.load_dataset("json", data_files=corpus_file)
Expand Down
3 changes: 3 additions & 0 deletions retriever/run_colbert.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import argparse
import sys
import logging
Expand Down
3 changes: 3 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

class InvalidArgument(Exception):
"""raise when user input arguments are invalid"""
pass
Expand Down

0 comments on commit dc834d6

Please sign in to comment.