replace ms-marco datasets and migrate examples
#3649
base: main
@@ -1,254 +1,156 @@
-import argparse
 import gzip
 import json
 import logging
 import os
-import pickle
 import random
-import sys
-import tarfile
 from datetime import datetime
-from shutil import copyfile
 
-import tqdm
-from torch.utils.data import DataLoader, Dataset
+from datasets import Dataset, load_dataset
+from huggingface_hub import hf_hub_download
 
-from sentence_transformers import InputExample, LoggingHandler, SentenceTransformer, losses, models, util
+from sentence_transformers import LoggingHandler, SentenceTransformer
+from sentence_transformers.losses import MarginMSELoss
+from sentence_transformers.models import Pooling, Transformer
+from sentence_transformers.trainer import SentenceTransformerTrainer
+from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments
 
-#### Just some code to print debug information to stdout
+# Just some code to print debug information to stdout
 logging.basicConfig(
     format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
 )
-#### /print debug information to stdout
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument("--train_batch_size", default=64, type=int)
-parser.add_argument("--max_seq_length", default=300, type=int)
-parser.add_argument("--model_name", required=True)
-parser.add_argument("--max_passages", default=0, type=int)
-parser.add_argument("--epochs", default=30, type=int)
-parser.add_argument("--pooling", default="mean")
-parser.add_argument(
-    "--negs_to_use",
-    default=None,
-    help="From which systems should negatives be used? Multiple systems separated by comma. None = all",
-)
-parser.add_argument("--warmup_steps", default=1000, type=int)
-parser.add_argument("--lr", default=2e-5, type=float)
-parser.add_argument("--num_negs_per_system", default=5, type=int)
-parser.add_argument("--use_pre_trained_model", default=False, action="store_true")
-parser.add_argument("--use_all_queries", default=False, action="store_true")
-args = parser.parse_args()
-
-logging.info(str(args))
-
 
 # The model we want to fine-tune
-train_batch_size = (
-    args.train_batch_size
-)  # Increasing the train batch size improves the model performance, but requires more GPU memory
-model_name = args.model_name
-max_passages = args.max_passages
-max_seq_length = args.max_seq_length  # Max length for passages. Increasing it, requires more GPU memory
-
-num_negs_per_system = (
-    args.num_negs_per_system
-)  # We used different systems to mine hard negatives. Number of hard negatives to add from each system
-num_epochs = args.epochs  # Number of epochs we want to train
+train_batch_size = 64
+max_seq_length = 300  # Max length for passages. Increasing it, requires more GPU memory
+model_name = "microsoft/mpnet-base"
+max_passages = 0
+num_epochs = 1
+max_steps = 1e-7
+pooling_mode = "mean"
+negs_to_use = None
+lr = 2e-5
+# We used different systems to mine hard negatives. Number of hard negatives to add from each system
+num_negs_per_system = 5
+use_pretrained_model = False
+use_all_queries = False
 
 # Load our embedding model
-if args.use_pre_trained_model:
+if use_pretrained_model:
     logging.info("use pretrained SBERT model")
     model = SentenceTransformer(model_name)
     model.max_seq_length = max_seq_length
 else:
     logging.info("Create new SBERT model")
-    word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
-    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), args.pooling)
+    word_embedding_model = Transformer(model_name, max_seq_length=max_seq_length)
+    pooling_model = Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode)
     model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
 
 model_save_path = f"output/train_bi-encoder-margin_mse-{model_name.replace('/', '-')}-batch_size_{train_batch_size}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
-
-
-# Write self to path
 os.makedirs(model_save_path, exist_ok=True)
 
-train_script_path = os.path.join(model_save_path, "train_script.py")
-copyfile(__file__, train_script_path)
-with open(train_script_path, "a") as fOut:
-    fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
-
-
-### Now we read the MS Marco dataset
-data_folder = "msmarco-data"
-
-#### Read the corpus files, that contain all the passages. Store them in the corpus dict
-corpus = {}  # dict in the format: passage_id -> passage. Stores all existent passages
-collection_filepath = os.path.join(data_folder, "collection.tsv")
-if not os.path.exists(collection_filepath):
-    tar_filepath = os.path.join(data_folder, "collection.tar.gz")
-    if not os.path.exists(tar_filepath):
-        logging.info("Download collection.tar.gz")
-        util.http_get("https://msmarco.z22.web.core.windows.net/msmarcoranking/collection.tar.gz", tar_filepath)
-
-    with tarfile.open(tar_filepath, "r:gz") as tar:
-        tar.extractall(path=data_folder)
-
-logging.info("Read corpus: collection.tsv")
-with open(collection_filepath, encoding="utf8") as fIn:
-    for line in fIn:
-        pid, passage = line.strip().split("\t")
-        pid = int(pid)
-        corpus[pid] = passage
-
-
-### Read the train queries, store in queries dict
-queries = {}  # dict in the format: query_id -> query. Stores all training queries
-queries_filepath = os.path.join(data_folder, "queries.train.tsv")
-if not os.path.exists(queries_filepath):
-    tar_filepath = os.path.join(data_folder, "queries.tar.gz")
-    if not os.path.exists(tar_filepath):
-        logging.info("Download queries.tar.gz")
-        util.http_get("https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz", tar_filepath)
-
-    with tarfile.open(tar_filepath, "r:gz") as tar:
-        tar.extractall(path=data_folder)
-
-
-with open(queries_filepath, encoding="utf8") as fIn:
-    for line in fIn:
-        qid, query = line.strip().split("\t")
-        qid = int(qid)
-        queries[qid] = query
+corpus = load_dataset("sentence-transformers/msmarco-corpus", "passage", split="train")
+corpus_dict = dict(zip(corpus["pid"], corpus["text"]))
 
+queries = load_dataset("omkar334/msmarcoranking-queries", split="train")
+query_dict = dict(zip(queries["qid"], queries["text"]))
 
 # Load a dict (qid, pid) -> ce_score that maps query-ids (qid) and paragraph-ids (pid)
 # to the CrossEncoder score computed by the cross-encoder/ms-marco-MiniLM-L6-v2 model
-ce_scores_file = os.path.join(data_folder, "cross-encoder-ms-marco-MiniLM-L6-v2-scores.pkl.gz")
-if not os.path.exists(ce_scores_file):
-    logging.info("Download cross-encoder scores file")
-    util.http_get(
-        "https://huggingface.co/datasets/sentence-transformers/msmarco-hard-negatives/resolve/main/cross-encoder-ms-marco-MiniLM-L-6-v2-scores.pkl.gz",
-        ce_scores_file,
-    )
-
-logging.info("Load CrossEncoder scores dict")
-with gzip.open(ce_scores_file, "rb") as fIn:
-    ce_scores = pickle.load(fIn)
+scores = load_dataset("sentence-transformers/msmarco-scores-ms-marco-MiniLM-L6-v2", "list", split="train")
+ce_scores = {
+    qid: dict(zip(cids, sc))
+    for qid, cids, sc in zip(
+        scores["query_id"],
+        scores["corpus_id"],
+        scores["score"],
+    )
+}
 
 # As training data we use hard-negatives that have been mined using various systems
-hard_negatives_filepath = os.path.join(data_folder, "msmarco-hard-negatives.jsonl.gz")
-if not os.path.exists(hard_negatives_filepath):
-    logging.info("Download cross-encoder scores file")
-    util.http_get(
-        "https://huggingface.co/datasets/sentence-transformers/msmarco-hard-negatives/resolve/main/msmarco-hard-negatives.jsonl.gz",
-        hard_negatives_filepath,
-    )
+hard_negatives_filepath = hf_hub_download(
+    repo_id="sentence-transformers/msmarco-hard-negatives",
+    filename="msmarco-hard-negatives.jsonl.gz",
+    repo_type="dataset",
+)
 
 logging.info("Read hard negatives train file")
-train_queries = {}
-negs_to_use = None
-with gzip.open(hard_negatives_filepath, "rt") as fIn:
-    for line in tqdm.tqdm(fIn):
-        if max_passages > 0 and len(train_queries) >= max_passages:
-            break
-        data = json.loads(line)
-
-        # Get the positive passage ids
-        pos_pids = data["pos"]
-
-        # Get the hard negatives
-        neg_pids = set()
-        if negs_to_use is None:
-            if args.negs_to_use is not None:  # Use specific system for negatives
-                negs_to_use = args.negs_to_use.split(",")
-            else:  # Use all systems
-                negs_to_use = list(data["neg"].keys())
-            logging.info("Using negatives from the following systems: {}".format(", ".join(negs_to_use)))
-
-        for system_name in negs_to_use:
-            if system_name not in data["neg"]:
-                continue
-
-            system_negs = data["neg"][system_name]
-            negs_added = 0
-            for pid in system_negs:
-                if pid not in neg_pids:
-                    neg_pids.add(pid)
-                    negs_added += 1
-                    if negs_added >= num_negs_per_system:
-                        break
-
-        if args.use_all_queries or (len(pos_pids) > 0 and len(neg_pids) > 0):
-            train_queries[data["qid"]] = {
-                "qid": data["qid"],
-                "query": queries[data["qid"]],
-                "pos": pos_pids,
-                "neg": neg_pids,
-            }
-
-logging.info(f"Train queries: {len(train_queries)}")
-
-
-# We create a custom MSMARCO dataset that returns triplets (query, positive, negative)
-# on-the-fly based on the information from the mined-hard-negatives jsonl file.
-class MSMARCODataset(Dataset):
-    def __init__(self, queries, corpus, ce_scores):
-        self.queries = queries
-        self.queries_ids = list(queries.keys())
-        self.corpus = corpus
-        self.ce_scores = ce_scores
-
-        for qid in self.queries:
-            self.queries[qid]["pos"] = list(self.queries[qid]["pos"])
-            self.queries[qid]["neg"] = list(self.queries[qid]["neg"])
-            random.shuffle(self.queries[qid]["neg"])
-
-    def __getitem__(self, item):
-        query = self.queries[self.queries_ids[item]]
-        query_text = query["query"]
-        qid = query["qid"]
-
-        if len(query["pos"]) > 0:
-            pos_id = query["pos"].pop(0)  # Pop positive and add at end
-            pos_text = self.corpus[pos_id]
-            query["pos"].append(pos_id)
-        else:  # We only have negatives, use two negs
-            pos_id = query["neg"].pop(0)  # Pop negative and add at end
-            pos_text = self.corpus[pos_id]
-            query["neg"].append(pos_id)
-
-        # Get a negative passage
-        neg_id = query["neg"].pop(0)  # Pop negative and add at end
-        neg_text = self.corpus[neg_id]
-        query["neg"].append(neg_id)
-
-        pos_score = self.ce_scores[qid][pos_id]
-        neg_score = self.ce_scores[qid][neg_id]
-
-        return InputExample(texts=[query_text, pos_text, neg_text], label=pos_score - neg_score)
-
-    def __len__(self):
-        return len(self.queries)
-
-
-# For training the SentenceTransformer model, we need a dataset, a dataloader, and a loss used for training.
-train_dataset = MSMARCODataset(queries=train_queries, corpus=corpus, ce_scores=ce_scores)
-train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size, drop_last=True)
-train_loss = losses.MarginMSELoss(model=model)
+
+
+def build_samples():
+    with gzip.open(hard_negatives_filepath, "rt") as f:
+        for line in f:
+            data = json.loads(line)
+
+            pos_pids = data.get("pos", [])
+            neg_systems = data.get("neg", {})
+
+            # --- skip bad rows (required) ---
+            if not pos_pids or not neg_systems:
+                continue
+
+            qid = data["qid"]
+            query = query_dict.get(qid)
+            if query is None:
+                continue
+
+            pos = pos_pids[0]
+
+            negs = []
+            for system_negs in neg_systems.values():
+                negs.extend(system_negs[:num_negs_per_system])
+
+            if not negs:
+                continue
+
+            neg = random.choice(negs)
+
+            yield {
+                "anchor": query,
+                "positive": corpus_dict[pos],
+                "negative": corpus_dict[neg],
+                "label": ce_scores[qid][pos] - ce_scores[qid][neg],
+            }
+
+
+train_dataset = Dataset.from_generator(build_samples)
+
+logging.info(f"Training samples: {len(train_dataset)}")
+
+# Loss function
+train_loss = MarginMSELoss(model)
+
+# Prepare training arguments
+args = SentenceTransformerTrainingArguments(
+    output_dir=model_save_path,
+    num_train_epochs=num_epochs,
+    per_device_train_batch_size=train_batch_size,
+    warmup_ratio=0.1,
+    learning_rate=lr,
+    save_strategy="steps",
+    save_steps=0.001,
+    logging_steps=0.01,
+    batch_sampler=BatchSamplers.NO_DUPLICATES,
+)
 
 # Train the model
-model.fit(
-    train_objectives=[(train_dataloader, train_loss)],
-    epochs=num_epochs,
-    warmup_steps=args.warmup_steps,
-    use_amp=True,
-    checkpoint_path=model_save_path,
-    checkpoint_save_steps=10000,
-    optimizer_params={"lr": args.lr},
-)
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    loss=train_loss,
+)
 
-# Train latest model
-model.save(model_save_path)
+trainer.train()
+
+model.save_pretrained(model_save_path)
+
+# (Optional) save the model to the Hugging Face Hub!
+# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
+model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
+try:
+    model.push_to_hub(f"{model_name}-bi-encoder-margin-mse")
+except Exception:
+    logging.error(
+        f"Error uploading model to the Hugging Face Hub:\nTo upload it manually, you can run "
+        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({model_save_path!r})` "
+        f"and saving it using `model.push_to_hub('{model_name}-bi-encoder-margin-mse')`."
    )
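
For readers skimming the diff: the migration replaces the manual MS MARCO downloads (collection.tar.gz, queries.tar.gz, the pickled cross-encoder scores) with Hugging Face datasets, and swaps the InputExample/DataLoader/model.fit pipeline for a datasets generator plus SentenceTransformerTrainer. The sketch below condenses that new flow into a self-contained toy example; it is not part of this PR, and the in-memory rows, the placeholder margin label, and the reliance on default training arguments are illustrative assumptions.

# Minimal sketch of the Trainer-based MarginMSE flow (toy data, illustrative only).
from datasets import Dataset

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import MarginMSELoss

model = SentenceTransformer("microsoft/mpnet-base")

# Same column layout the migrated script generates: three text columns plus a
# float "label" holding ce_score(positive) - ce_score(negative).
train_dataset = Dataset.from_dict(
    {
        "anchor": ["what is the capital of france"],
        "positive": ["Paris is the capital of France."],
        "negative": ["France is known for its wine and cheese."],
        "label": [8.1],  # placeholder cross-encoder margin, not a real score
    }
)

trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=MarginMSELoss(model),
)
trainer.train()

In the actual script, the label comes from the msmarco-scores dataset (ce_scores[qid][pos] - ce_scores[qid][neg]), and SentenceTransformerTrainingArguments additionally configures the batch size, warmup ratio, learning rate, checkpointing, and the NO_DUPLICATES batch sampler, as shown in the diff above.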