import logging
import random

import tqdm
from datasets import Dataset, load_dataset

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MarginMSELoss
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments

# Just some code to print debug information to stdout
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
logging.getLogger("httpx").setLevel(logging.WARNING) # Quiet httpx logs


# The model we want to fine-tune
train_batch_size = 64  # Increasing the train batch size improves the model performance, but requires more GPU memory
max_seq_length = 300  # Max length for passages. Increasing it requires more GPU memory
model_name = "microsoft/mpnet-base"
num_epochs = 1
max_steps = -1
lr = 2e-5

# We used different systems to mine hard negatives. Number of hard negatives to add from each system
num_negs_per_system = 5
num_negatives = 5
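# `num_negs_per_system` caps how many mined negatives are kept per query and system when
# building the training data below; `num_negatives` is how many of those are then sampled
# per training example on the fly in `ids_to_text_transform`.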


logging.info("Using pretrained SBERT model")
model = SentenceTransformer(model_name)
model.max_seq_length = max_seq_length
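# Inputs longer than `max_seq_length` tokens are truncated by the model's tokenizer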

# Map PID -> text
corpus = load_dataset("sentence-transformers/msmarco-corpus", "passage", split="train")
corpus_dict = dict(zip(corpus["pid"], corpus["text"]))

# Map QID -> query text
queries = load_dataset("sentence-transformers/msmarco-corpus", "query", split="train")
query_dict = dict(zip(queries["qid"], queries["text"]))

# Map QID -> {PID: CE score}
scores = load_dataset("sentence-transformers/msmarco-scores-ms-marco-MiniLM-L6-v2", "list", split="train")
ce_scores = {
    qid: dict(zip(cids, sc)) for qid, cids, sc in zip(scores["query_id"], scores["corpus_id"], scores["score"])
}
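
# MarginMSELoss is a distillation objective: the bi-encoder (student) is trained so that
#   sim(query, positive) - sim(query, negative)
# approximates the cross-encoder (teacher) margin
#   ce_scores[qid][pos_pid] - ce_scores[qid][neg_pid],
# which is exactly the per-negative "label" computed below.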
logging.info("Load CrossEncoder scores dict")
with gzip.open(ce_scores_file, "rb") as fIn:
ce_scores = pickle.load(fIn)

# As training data we use hard-negatives that have been mined using various systems
hard_negatives_filepath = os.path.join(data_folder, "msmarco-hard-negatives.jsonl.gz")
if not os.path.exists(hard_negatives_filepath):
logging.info("Download cross-encoder scores file")
util.http_get(
"https://huggingface.co/datasets/sentence-transformers/msmarco-hard-negatives/resolve/main/msmarco-hard-negatives.jsonl.gz",
hard_negatives_filepath,
)


logging.info("Read hard negatives train file")
train_queries = {}
negs_to_use = None
with gzip.open(hard_negatives_filepath, "rt") as fIn:
for line in tqdm.tqdm(fIn):
if max_passages > 0 and len(train_queries) >= max_passages:
break
data = json.loads(line)

# Get the positive passage ids
pos_pids = data["pos"]

# Get the hard negatives
neg_pids = set()
if negs_to_use is None:
if args.negs_to_use is not None: # Use specific system for negatives
negs_to_use = args.negs_to_use.split(",")
else: # Use all systems
negs_to_use = list(data["neg"].keys())
logging.info("Using negatives from the following systems: {}".format(", ".join(negs_to_use)))

for system_name in negs_to_use:
if system_name not in data["neg"]:
# Datasets with 50 hard negatives mined per query using different models
SYSTEMS = {
    "bm25": "sentence-transformers/msmarco-bm25",
    "msmarco-distilbert-base-tas-b": "sentence-transformers/msmarco-msmarco-distilbert-base-tas-b",
    "msmarco-distilbert-base-v3": "sentence-transformers/msmarco-msmarco-distilbert-base-v3",
    "msmarco-MiniLM-L-6-v3": "sentence-transformers/msmarco-msmarco-MiniLM-L6-v3",
    "distilbert-margin_mse-cls-dot-v2": "sentence-transformers/msmarco-distilbert-margin-mse-cls-dot-v2",
    "distilbert-margin_mse-cls-dot-v1": "sentence-transformers/msmarco-distilbert-margin-mse-cls-dot-v1",
    "distilbert-margin_mse-mean-dot-v1": "sentence-transformers/msmarco-distilbert-margin-mse-mean-dot-v1",
    "mpnet-margin_mse-mean-v1": "sentence-transformers/msmarco-mpnet-margin-mse-mean-v1",
    "co-condenser-margin_mse-cls-v1": "sentence-transformers/msmarco-co-condenser-margin-mse-cls-v1",
    "distilbert-margin_mse-mnrl-mean-v1": "sentence-transformers/msmarco-distilbert-margin-mse-mnrl-mean-v1",
    "distilbert-margin_mse-sym_mnrl-mean-v1": "sentence-transformers/msmarco-distilbert-margin-mse-sym-mnrl-mean-v1",
    "distilbert-margin_mse-sym_mnrl-mean-v2": "sentence-transformers/msmarco-distilbert-margin-mse-sym-mnrl-mean-v2",
    "co-condenser-margin_mse-sym_mnrl-mean-v1": "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
}
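
# Each "triplet-50-ids" row is assumed to hold a "query" id column, a "positive" id column,
# and the mined negative ids in the remaining columns; the loop below pops query/positive
# and treats everything left in the row as negatives.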

train_data = {}
for system_key, repo_id in SYSTEMS.items():
    print(f"Loading {system_key}...")
    dataset = load_dataset(repo_id, "triplet-50-ids", split="train")

    for row in tqdm.tqdm(dataset, desc=f"Processing {system_key}"):
        qid = row.pop("query")
        pos_pid = row.pop("positive")
        neg_pids = list(row.values())  # All remaining columns are negatives
        existing_neg_pids = set(train_data[qid]["neg_pids"]) if qid in train_data else set()
        pos_ce_score = ce_scores[qid][pos_pid]
        valid_neg_pids = []
        valid_neg_labels = []

        for neg_pid in neg_pids:
            if neg_pid in existing_neg_pids or neg_pid not in ce_scores[qid]:
                continue
            valid_neg_pids.append(neg_pid)
            valid_neg_labels.append(pos_ce_score - ce_scores[qid][neg_pid])
            existing_neg_pids.add(neg_pid)
            if len(valid_neg_pids) >= num_negs_per_system:
                break

        if qid not in train_data:
            train_data[qid] = {"qid": qid, "pid": pos_pid, "neg_pids": valid_neg_pids, "neg_labels": valid_neg_labels}
        else:
            train_data[qid]["neg_pids"].extend(valid_neg_pids)
            train_data[qid]["neg_labels"].extend(valid_neg_labels)

train_data = {qid: data for qid, data in train_data.items() if data["neg_pids"]}
logging.info(f"Kept {len(train_data)} queries with negatives")

train_dataset = Dataset.from_list(list(train_data.values()))
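# At this point each row still holds ids and teacher margins, roughly like this
# (values made up for illustration):
#   {"qid": 1000, "pid": 2000, "neg_pids": [3000, 3001, ...], "neg_labels": [7.2, 5.9, ...]}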


def ids_to_text_transform(batch):
    sampled = [
        random.sample(list(zip(neg_pids, neg_labels)), num_negatives)
        for neg_pids, neg_labels in zip(batch["neg_pids"], batch["neg_labels"])
    ]
    neg_pid_lists, label_lists = zip(*[zip(*s) for s in sampled])
    return {
        "anchor": [query_dict[qid] for qid in batch["qid"]],
        "positive": [corpus_dict[pid] for pid in batch["pid"]],
        **{
            f"negative_{idx}": [corpus_dict[pid] for pid in neg_ids]
            for idx, neg_ids in enumerate(zip(*neg_pid_lists))
        },
        "label": list(label_lists),
    }


train_dataset.set_transform(ids_to_text_transform)
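# `set_transform` applies `ids_to_text_transform` lazily: ids are mapped to text and
# `num_negatives` negatives are (re-)sampled each time a batch is accessed. Every example
# then has columns anchor, positive, negative_0 .. negative_4, plus a list-valued "label"
# holding one teacher margin per sampled negative, which is what MarginMSELoss consumes here.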

# Loss function
loss = MarginMSELoss(model)

# Prepare training arguments
short_model_name = model_name.split("/")[-1] if "/" in model_name else model_name
run_name = f"{short_model_name}-msmarco-margin-mse"
args = SentenceTransformerTrainingArguments(
    output_dir=f"output/{run_name}",
    num_train_epochs=num_epochs,
    per_device_train_batch_size=train_batch_size,
    warmup_ratio=0.1,
    learning_rate=lr,
    max_steps=max_steps,
    save_strategy="steps",
    save_steps=0.1,
    logging_steps=0.01,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
)
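# Note: float `save_steps` / `logging_steps` values are interpreted as ratios of the total
# training steps (here: save ~10 checkpoints, log ~100 times), and BatchSamplers.NO_DUPLICATES
# ensures no duplicate samples end up in the same batch.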

# Train the model
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)

trainer.train()

final_model_path = f"output/{run_name}/final"
model.save_pretrained(final_model_path)

# (Optional) save the model to the Hugging Face Hub!
# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
try:
    model.push_to_hub(run_name)
except Exception:
    logging.error(
        f"Error uploading model to the Hugging Face Hub:\nTo upload it manually, you can run "
        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_model_path!r})` "
        f"and saving it using `model.push_to_hub('{run_name}')`."
    )
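
# A minimal usage sketch after training (query and passages are made up for illustration):
# model = SentenceTransformer(final_model_path)
# query_embedding = model.encode(["how is the margin mse label computed?"])
# passage_embeddings = model.encode(["MarginMSE distills cross-encoder margins ...", "An unrelated passage."])
# print(model.similarity(query_embedding, passage_embeddings))  # 1 x 2 similarity matrix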