-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #15 from ittia-research/dev
add wiki_dpr retriever for DSPy compile
- Loading branch information
Showing
9 changed files
with
231 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Image for building the ColBERT index (run indexing.py inside it).
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
# apt-get is the interface intended for scripted use; the package lists are
# removed in the same layer so they do not end up in the image.
RUN apt-get update \
    && apt-get install -y python3 python3-pip git \
    && rm -rf /var/lib/apt/lists/*
# Quoted so the shell never treats the extras brackets as a glob pattern.
RUN pip install --no-cache-dir "colbert-ai[torch,faiss-gpu]"
# Keep NumPy on the 1.x series.
RUN pip install --no-cache-dir "numpy<2"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# CPU image for the retrieval server.
FROM python:3.11-bookworm
# apt-get is the interface intended for scripted use; the package lists are
# removed in the same layer so they do not end up in the image.
RUN apt-get update \
    && apt-get install -y python3 python3-pip git \
    && rm -rf /var/lib/apt/lists/*
# Quoted so the shell never treats the extras brackets as a glob pattern.
RUN pip install --no-cache-dir "colbert-ai[torch,faiss-cpu]"
# Keep NumPy on the 1.x series.
RUN pip install --no-cache-dir "numpy<2"

WORKDIR /app

# COPY is the Dockerfile instruction for adding build-context files
# ("CP" is not a valid instruction and aborts the build).
COPY prepare_files.py .
COPY server.py .
COPY start.sh .

# Exec form must be a JSON array with double quotes. Invoking bash explicitly
# also avoids depending on the script's executable bit.
CMD ["/bin/bash", "/app/start.sh"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# GPU image for the retrieval server.
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
# apt-get is the interface intended for scripted use; the package lists are
# removed in the same layer so they do not end up in the image.
RUN apt-get update \
    && apt-get install -y python3 python3-pip git \
    && rm -rf /var/lib/apt/lists/*
# Quoted so the shell never treats the extras brackets as a glob pattern.
RUN pip install --no-cache-dir "colbert-ai[torch,faiss-gpu]"
# Keep NumPy on the 1.x series.
RUN pip install --no-cache-dir "numpy<2"

WORKDIR /app

# COPY is the Dockerfile instruction for adding build-context files
# ("CP" is not a valid instruction and aborts the build).
COPY prepare_files.py .
COPY server.py .
COPY start.sh .

# Exec form must be a JSON array with double quotes. Invoking bash explicitly
# also avoids depending on the script's executable bit.
CMD ["/bin/bash", "/app/start.sh"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
## About | ||
- Dataset: https://github.com/facebookresearch/DPR/blob/main/dpr/data/download_data.py | ||
- direct download link: `https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz` | ||
- Generate index for the ColBERTv2 model | ||
- Download the generated index: https://huggingface.co/datasets/ittia/wiki_dpr | ||
- Start a retrieval server | ||
|
||
## How-to | ||
### Indexing | ||
1. Run container via Dockerfile.indexing; | ||
2. Add the .tsv dataset to `/data/datasets/wiki/psgs_w100.tsv`; | ||
3. Run `python3 indexing.py`. | ||
|
||
### Serve | ||
1. Run the GPU or CPU container via docker-compose.yml based on your hardware; | ||
2. Add required files: .tsv dataset, model checkpoint, index; | ||
* The default locations of these files are defined in `prepare_files.py` | ||
* You may add existing files or download from HuggingFace via `python3 prepare_files.py` | ||
3. Start the server: `python3 server.py`. | ||
|
||
Test the server: `curl "http://localhost:8893/api/search?query=Who%20won%20the%202022%20FIFA%20world%20cup&k=3"` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Start only one of the two services at a time: both publish host port 8893.
services:
  wiki-dpr-serve-gpu:
    build:
      context: .
      dockerfile: Dockerfile.serve.gpu
    ports:
      - "8893:8893"
    volumes:
      - /data:/data
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]
    restart: unless-stopped

  wiki-dpr-serve-cpu:
    build:
      context: .
      dockerfile: Dockerfile.serve.cpu
    ports:
      - "8893:8893"
    volumes:
      - /data:/data
    restart: unless-stopped
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from colbert.infra import Run, RunConfig, ColBERTConfig | ||
from colbert import Indexer | ||
|
||
# Passed to RunConfig as `nranks`.
GPU_NUMBER = 1
# Used both as the experiment name and as the index name.
PROJECT_NAME = "wiki"
TSV_PATH = "/data/datasets/wiki/psgs_w100.tsv"
# NOTE(review): prepare_files.py stores the checkpoint under
# "/data/checkpoints/colbertv2.0" (plural) - confirm which path is intended.
CHECKPOINT_PATH = "/data/checkpoint/colbertv2.0"

if __name__ == "__main__":
    with Run().context(RunConfig(nranks=GPU_NUMBER, experiment=PROJECT_NAME)):
        config = ColBERTConfig(
            nbits=2,
            doc_maxlen=220,
        )
        indexer = Indexer(checkpoint=CHECKPOINT_PATH, config=config)
        indexer.index(name=PROJECT_NAME, collection=TSV_PATH, overwrite=True)

        # Print the index path so the result is not silently discarded.
        print(f"Index path: {indexer.get_index()}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import os, subprocess | ||
import shutil | ||
from huggingface_hub import snapshot_download | ||
from tenacity import retry, stop_after_attempt, wait_fixed | ||
|
||
# Hugging Face repository to download from, and the revision to use.
repo_id = "ittia/wiki_dpr"
repo_type = "dataset"
revision = "main"

# Local folder that receives the repository's "datasets" directory.
dataset_folder = "/data/datasets/wiki"

# Each entry pairs a directory inside the repository ("repo_dir") with the
# local directory its contents are moved to ("local_dir").
dir_map = [
    {"repo_dir": repo_path, "local_dir": local_path}
    for repo_path, local_path in (
        ("checkpoints/colbertv2.0", "/data/checkpoints/colbertv2.0"),
        ("datasets", dataset_folder),
        ("indexes/wiki", "/data/indexes/wiki"),
    )
]
|
||
import os | ||
|
||
def check_exists(folder_path):
    """Return True when *folder_path* is a directory holding a visible entry.

    Entries whose names start with ``.`` are not counted, so a directory that
    contains only dot-files or dot-folders is treated as empty.

    Args:
        folder_path: Path to inspect.

    Returns:
        True if the path is a directory with at least one entry whose name
        does not start with ``.``; False otherwise (including when the path
        is missing or is not a directory).
    """
    # isdir() is already False for a missing path, so no separate exists().
    if not os.path.isdir(folder_path):
        return False
    # any() stops at the first visible entry instead of building a full list.
    return any(not name.startswith('.') for name in os.listdir(folder_path))
|
||
def move_files_subfolders(source_folder, destination_folder):
    """Move every entry of *source_folder* into *destination_folder*.

    The destination is created first if it does not exist. Each file or
    sub-folder is moved under the same name with ``shutil.move``.
    """
    os.makedirs(destination_folder, exist_ok=True)

    # Collect all (source, target) pairs before moving anything, so the
    # directory is not modified while it is being listed.
    with os.scandir(source_folder) as entries:
        pending = [
            (entry.path, os.path.join(destination_folder, entry.name))
            for entry in entries
        ]

    for src, dst in pending:
        shutil.move(src, dst)
|
||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2), reraise=True)
def download_hf_folder(repo_dir, local_dir):
    """Download one repository directory into a staging folder under *local_dir*.

    Files matching ``{repo_dir}/*`` are fetched from the module-level
    ``repo_id`` / ``repo_type`` / ``revision`` with ``snapshot_download``.
    The retry decorator is configured for up to three attempts with a fixed
    two-second wait, re-raising the last error.

    Args:
        repo_dir: Directory inside the repository to fetch.
        local_dir: Local directory that will contain the staging folder.

    Returns:
        Path of the staging folder the snapshot was downloaded into.
    """
    # The on-disk folder name is kept exactly as-is (including its spelling).
    staging_dir = os.path.join(local_dir, '.downlaod')
    os.makedirs(staging_dir, exist_ok=True)

    request = {
        "repo_id": repo_id,
        "repo_type": repo_type,
        "revision": revision,
        "allow_patterns": f"{repo_dir}/*",
        "local_dir": staging_dir,
    }
    snapshot_download(**request)

    return staging_dir
|
||
# Download every mapped repository directory that is not already present.
# The loop variable is named `mapping` so the builtin `map` is not shadowed.
for mapping in dir_map:
    repo_dir = mapping['repo_dir']
    local_dir = mapping['local_dir']

    if check_exists(local_dir):
        print(f"local dir '{local_dir}' exists and not empty, skip download")
        continue

    # The snapshot keeps the repository layout, so the wanted files sit under
    # <staging dir>/<repo_dir>; move them up into local_dir.
    downlaod_dir = download_hf_folder(repo_dir, local_dir)
    _source_dir = os.path.join(downlaod_dir, repo_dir)
    move_files_subfolders(_source_dir, local_dir)

    print(f"Downloaded: {repo_dir} to {local_dir}")

# extract dataset
_file_path = os.path.join(dataset_folder, "psgs_w100.tsv.gz")
if os.path.isfile(_file_path):
    try:
        subprocess.run(['gunzip', _file_path], check=True)
    except subprocess.CalledProcessError as e:
        # Best effort: report the failure and carry on.
        print(f"An error occurred: {e}")
    else:
        print(f"File {_file_path} extracted and replaced successfully.")

print("All folders have been downloaded and processed.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# source: https://raw.githubusercontent.com/stanford-futuredata/ColBERT/main/server.py | ||
|
||
from flask import Flask, render_template, request | ||
from functools import lru_cache | ||
import math | ||
import os | ||
|
||
from colbert.infra import Run, RunConfig, ColBERTConfig | ||
from colbert import Searcher | ||
|
||
# Index location, overridable through the environment. The defaults point at
# "/data/indexes/wiki", where prepare_files.py places the index.
# NOTE(review): assumes Searcher resolves the index as INDEX_ROOT/INDEX_NAME
# - confirm against the installed ColBERT version.
INDEX_NAME = os.getenv("INDEX_NAME", "wiki")
INDEX_ROOT = os.getenv("INDEX_ROOT", "/data/indexes")
app = Flask(__name__)

# Created once at import time and shared by all requests.
searcher = Searcher(index=INDEX_NAME, index_root=INDEX_ROOT)
# Running count of requests handled by the search endpoint.
counter = {"api": 0}
|
||
@lru_cache(maxsize=1000000)
def api_search_query(query, k):
    """Search the index for *query* and return the top *k* passages.

    Results are memoised per ``(query, k)`` pair by ``lru_cache``.

    Args:
        query: Query text passed to ``searcher.search``.
        k: Number of passages to return, as an int, a numeric string, or
            None (meaning 10). The value is clamped to the range 0..100.

    Returns:
        A dict ``{"query": query, "topk": [...]}`` where each item has the
        keys ``text``, ``pid``, ``rank``, ``score`` and ``prob``. Items are
        ordered by descending score, then ascending pid.

    Raises:
        ValueError: If *k* is neither None nor convertible to int.
    """
    print(f"Query={query}")
    k = 10 if k is None else int(k)
    # Clamp below at 0 too: a negative k would otherwise slice from the end
    # and return an unrelated number of results.
    k = max(0, min(k, 100))

    # Always retrieve 100 candidates, then keep the first k.
    pids, ranks, scores = searcher.search(query, k=100)
    pids, ranks, scores = pids[:k], ranks[:k], scores[:k]

    # Normalise exp(score) over the kept results; the sum is computed once
    # rather than once per element.
    weights = [math.exp(score) for score in scores]
    total = sum(weights)
    probs = [weight / total for weight in weights]

    topk = [
        {
            'text': searcher.collection[pid],
            'pid': pid,
            'rank': rank,
            'score': score,
            'prob': prob,
        }
        for pid, rank, score, prob in zip(pids, ranks, scores, probs)
    ]
    topk.sort(key=lambda p: (-1 * p['score'], p['pid']))
    return {"query": query, "topk": topk}
|
||
@app.route("/api/search", methods=["GET"])
def api_search():
    """Handle ``GET /api/search?query=...&k=...``.

    Returns the result of ``api_search_query`` for a valid request, a 400
    response with an ``error`` message when ``query`` is missing or ``k`` is
    not an integer, and an empty 405 response for any non-GET method.
    """
    if request.method != "GET":
        return ('', 405)

    counter["api"] += 1
    print("API request count:", counter["api"])

    query = request.args.get("query")
    k = request.args.get("k")

    # Reject malformed requests here so they produce a 400 instead of an
    # unhandled exception further down.
    if query is None:
        return ({"error": "missing required parameter 'query'"}, 400)
    if k is not None:
        try:
            int(k)
        except ValueError:
            return ({"error": "parameter 'k' must be an integer"}, 400)

    # k is forwarded unchanged (string or None), as before.
    return api_search_query(query, k)
|
||
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT, defaulting to 8893.
    listen_port = int(os.getenv("PORT", 8893))
    app.run("0.0.0.0", listen_port)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
#!/bin/bash
# Prepare the data files, then launch the search server.

# python3 is the interpreter name the Dockerfiles install; a bare `python`
# command is not guaranteed to exist on the Ubuntu-based image.
python3 ./prepare_files.py
# exec replaces the shell with the server process so that signals sent to the
# container reach the server directly.
exec python3 ./server.py