[feat] Add ContAccum support to GradCache MNRL #3612
KimJaehee0725 wants to merge 3 commits into huggingface:main
Conversation
Hello! This is very cool, I was not aware of this paper. I'm currently assisting with a PR on
Hi Tom, thanks a lot for the kind words! Glad you find the idea interesting. In the meantime, I'm happy to address any questions or adjust the implementation if needed. Best,
def test_push_to_hub(monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture) -> None:
    def build_commit_info(**kwargs):
Apologies for the inconvenience here, this was caused by this: huggingface/huggingface_hub#3737
It's resolved and released now, so we're good either way.
(I'm going through the paper etc. now)
I ran some tests last night to experiment with this, with the following settings:
Exact GradCache script:
import random
import logging
from datasets import load_dataset, Dataset
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
SentenceTransformerModelCardData,
)
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import InformationRetrievalEvaluator
logging.basicConfig(
format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
# 1. Load a model to finetune with 2. (Optional) model card data
model = SentenceTransformer(
"microsoft/mpnet-base",
model_card_data=SentenceTransformerModelCardData(
language="en",
license="apache-2.0",
model_name="MPNet base trained on GooAQ triplets using CachedMultipleNegativesRankingLoss with GradCache",
),
)
# 3. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
dataset = dataset.add_column("id", range(len(dataset)))
dataset_dict = dataset.train_test_split(test_size=10_000, seed=12)
train_dataset: Dataset = dataset_dict["train"]
# train_dataset = train_dataset.add_column("negative", train_dataset["answer"]) # Dummy negatives for MNRBL
# train_dataset = train_dataset.add_column("negative_1", train_dataset["answer"]) # Dummy negatives for MNRBL
# train_dataset = train_dataset.add_column("negative_2", train_dataset["answer"]) # Dummy negatives for MNRBL
eval_dataset: Dataset = dataset_dict["test"]
# 4. Define a loss function
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=64)
# 5. (Optional) Specify training arguments
run_name = "mpnet-base-gooaq-cmnrl-1024bs-GradCache"
args = SentenceTransformerTrainingArguments(
# Required parameter:
output_dir=f"models/{run_name}",
# Optional training parameters:
num_train_epochs=1,
per_device_train_batch_size=1024,
per_device_eval_batch_size=1024,
learning_rate=2e-5 * 4,
warmup_ratio=0.1,
fp16=False, # Set to False if you get an error that your GPU can't run on FP16
bf16=True, # Set to True if you have a GPU that supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # CachedMultipleNegativesRankingLoss benefits from no duplicate samples in a batch
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=0.1,
save_strategy="steps",
save_steps=0.1,
save_total_limit=2,
logging_steps=0.05,
logging_first_step=True,
run_name=run_name, # Will be used in W&B if `wandb` is installed
)
# 6. (Optional) Create an evaluator & evaluate the base model
# The full corpus, but only the evaluation queries
corpus = dict(zip(dataset["id"], dataset["answer"]))
random.seed(12)
queries = dict(zip(eval_dataset["id"], eval_dataset["question"]))
corpus = (
    {qid: dataset[qid]["answer"] for qid in queries}
    # {qid: dataset[qid]["answer"] for qid in random.sample(range(len(dataset)), 20_000)}
)
relevant_docs = {qid: {qid} for qid in eval_dataset["id"]}
dev_evaluator = InformationRetrievalEvaluator(
corpus=corpus,
queries=queries,
relevant_docs=relevant_docs,
show_progress_bar=True,
name="gooaq-dev",
)
dev_evaluator(model)
# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset.remove_columns("id"),
eval_dataset=eval_dataset.remove_columns("id"),
loss=loss,
evaluator=dev_evaluator,
)
trainer.train()
# (Optional) Evaluate the trained model on the evaluator after training
dev_evaluator(model)
# # 8. Save the trained model
model.save_pretrained(f"models/{run_name}/final")
# # 9. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)
Exact GradCache + ContAccum script:
import random
import logging
from datasets import load_dataset, Dataset
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
SentenceTransformerModelCardData,
)
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import InformationRetrievalEvaluator
logging.basicConfig(
format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
# 1. Load a model to finetune with 2. (Optional) model card data
model = SentenceTransformer(
"microsoft/mpnet-base",
model_card_data=SentenceTransformerModelCardData(
language="en",
license="apache-2.0",
model_name="MPNet base trained on GooAQ triplets using CachedMultipleNegativesRankingLoss with ContAccum",
),
)
# 3. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
dataset = dataset.add_column("id", range(len(dataset)))
dataset_dict = dataset.train_test_split(test_size=10_000, seed=12)
train_dataset: Dataset = dataset_dict["train"]
# train_dataset = train_dataset.add_column("negative", train_dataset["answer"]) # Dummy negatives for MNRBL
# train_dataset = train_dataset.add_column("negative_1", train_dataset["answer"]) # Dummy negatives for MNRBL
# train_dataset = train_dataset.add_column("negative_2", train_dataset["answer"]) # Dummy negatives for MNRBL
eval_dataset: Dataset = dataset_dict["test"]
# 4. Define a loss function
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=32, use_cont_accum=True, cache_size=1028, prev_cache=True)
# 5. (Optional) Specify training arguments
run_name = "mpnet-base-gooaq-cmnrl-128bs-ContAccum"
args = SentenceTransformerTrainingArguments(
# Required parameter:
output_dir=f"models/{run_name}",
# Optional training parameters:
num_train_epochs=1,
per_device_train_batch_size=128,
per_device_eval_batch_size=128,
learning_rate=2e-5 * 2,
warmup_ratio=0.1,
fp16=False, # Set to False if you get an error that your GPU can't run on FP16
bf16=True, # Set to True if you have a GPU that supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # CachedMultipleNegativesRankingLoss benefits from no duplicate samples in a batch
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=0.1,
save_strategy="steps",
save_steps=0.1,
save_total_limit=2,
logging_steps=0.05,
logging_first_step=True,
run_name=run_name, # Will be used in W&B if `wandb` is installed
)
# 6. (Optional) Create an evaluator & evaluate the base model
# The full corpus, but only the evaluation queries
corpus = dict(zip(dataset["id"], dataset["answer"]))
random.seed(12)
queries = dict(zip(eval_dataset["id"], eval_dataset["question"]))
corpus = (
    {qid: dataset[qid]["answer"] for qid in queries}
    # {qid: dataset[qid]["answer"] for qid in random.sample(range(len(dataset)), 20_000)}
)
relevant_docs = {qid: {qid} for qid in eval_dataset["id"]}
dev_evaluator = InformationRetrievalEvaluator(
corpus=corpus,
queries=queries,
relevant_docs=relevant_docs,
show_progress_bar=True,
name="gooaq-dev",
)
dev_evaluator(model)
# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset.remove_columns("id"),
eval_dataset=eval_dataset.remove_columns("id"),
loss=loss,
evaluator=dev_evaluator,
)
trainer.train()
# (Optional) Evaluate the trained model on the evaluator after training
dev_evaluator(model)
# # 8. Save the trained model
model.save_pretrained(f"models/{run_name}/final")
# # 9. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)

In short, I wanted to compare the two methods for "extending the batch size" to an effective batch size of 1028. From my perspective, GradCache already more or less solves this issue, but I'm open to alternatives that speed up the training. My understanding is that these two setups should be somewhere between similar and equivalent. The GradCache model trained in 0.27 hours according to the model's Environmental Impact section and was uploaded here: https://huggingface.co/tomaarsen/mpnet-base-gooaq-cmnrl-1024bs-GradCache. It reached a 0.8652 NDCG@10 on the in-domain evaluation test. The GradCache + ContAccum model with
I'm struggling a bit to see where this has a nice edge over pure GradCache currently. The memory overhead for caching makes sense if it e.g. results in speedups or allows some
Do you have suggestions I can use to get more benefit out of ContAccum?
Summary
This PR adds Contrastive Accumulation (ContAccum) support to `CachedMultipleNegativesRankingLoss`, enabling cached embeddings from previous steps to be reused as in-batch negatives while keeping the GradCache memory profile.
Motivation
ContAccum ("A Gradient Accumulation Method for Dense Retriever under Memory Constraint") increases the effective in-batch negative pool without increasing GPU memory, by reusing cached embeddings across steps. This makes large-batch contrastive training more accessible in memory-constrained environments.
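The mechanism described above can be sketched as follows. This is a minimal NumPy illustration, not the PR's actual code: `EmbeddingCache` and `scores_with_cache` are hypothetical names, and the real implementation works on gradient-cached torch tensors inside the loss.

```python
from collections import deque

import numpy as np


class EmbeddingCache:
    """FIFO cache holding up to `cache_size` previously seen embeddings."""

    def __init__(self, cache_size: int) -> None:
        self.buffer = deque(maxlen=cache_size)

    def extend(self, embeddings: np.ndarray) -> None:
        for row in embeddings:
            self.buffer.append(row)

    def as_array(self) -> np.ndarray:
        return np.stack(list(self.buffer)) if self.buffer else np.empty((0, 0))


def scores_with_cache(queries: np.ndarray, docs: np.ndarray, cache: EmbeddingCache) -> np.ndarray:
    """Score current queries against current docs plus all cached docs."""
    cached = cache.as_array()
    candidates = docs if cached.size == 0 else np.concatenate([docs, cached])
    sims = queries @ candidates.T  # shape: (batch, batch + n_cached)
    cache.extend(docs)  # cache current docs as negatives for later steps
    return sims


rng = np.random.default_rng(0)
cache = EmbeddingCache(cache_size=256)
step1 = scores_with_cache(rng.normal(size=(8, 32)), rng.normal(size=(8, 32)), cache)
step2 = scores_with_cache(rng.normal(size=(8, 32)), rng.normal(size=(8, 32)), cache)
print(step1.shape, step2.shape)  # (8, 8) then (8, 16): the negative pool grew
```

The key point is that the extra negatives cost no additional encoding: the cached rows were already computed in previous steps, so only the similarity matrix grows.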
Changes
- Added `use_cont_accum`, `cache_size`, and `prev_cache` parameters to `CachedMultipleNegativesRankingLoss`.
- Added `on_optimizer_step()` to reset caches after each optimizer step when `prev_cache=False`.
- Hooked into `SentenceTransformerTrainer` via a callback.
Usage
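The reset contract can be illustrated with a small stand-in class. Only `on_optimizer_step()` and the parameter names come from this PR; the stub's internals are hypothetical and stand in for the real loss:

```python
class ContAccumLossStub:
    """Stand-in for the ContAccum-enabled loss; illustrates cache resets only."""

    def __init__(self, cache_size: int, prev_cache: bool) -> None:
        self.cache_size = cache_size
        self.prev_cache = prev_cache
        self.cached_embeddings: list = []

    def forward(self, batch_embeddings: list) -> None:
        # Reuse cached embeddings as extra negatives, then cache the new batch,
        # keeping at most `cache_size` entries (FIFO eviction).
        self.cached_embeddings = (self.cached_embeddings + batch_embeddings)[-self.cache_size:]

    def on_optimizer_step(self) -> None:
        # Called by the trainer callback after each optimizer step.
        if not self.prev_cache:
            self.cached_embeddings = []


loss = ContAccumLossStub(cache_size=4, prev_cache=False)
loss.forward([1, 2])
loss.on_optimizer_step()
print(loss.cached_embeddings)  # [] -- cache was reset
```

With `prev_cache=True` the same call would leave the cache untouched, so negatives persist across optimizer steps.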
Parameters
- `use_cont_accum`: Enable ContAccum (reuse cached embeddings from previous steps as in-batch negatives).
- `cache_size`: Maximum number of cached embeddings per column; set it to the effective batch size to mimic gradient accumulation.
- `prev_cache`: If True, caches are kept across optimizer steps; if False, caches are reset after each optimizer step.
Notes
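A back-of-envelope view of how `cache_size` relates to the per-query negative pool, assuming the pool is the current batch plus whatever is cached (the exact accounting in the implementation may differ):

```python
def negative_pool(batch_size: int, cache_size: int = 0, cached_so_far: int = 0) -> int:
    """Rough estimate: current batch plus cached embeddings, capped at cache_size."""
    return batch_size + min(cached_so_far, cache_size)


# Plain GradCache with a 1024 batch: pool of 1024 per query.
print(negative_pool(1024))  # 1024
# ContAccum with a 128 batch and a warm cache capped at 1028: pool of 1156.
print(negative_pool(128, cache_size=1028, cached_so_far=4096))  # 1156
```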
- Call `loss.on_optimizer_step()` after each optimizer step when `prev_cache=False`.
- ContAccum is only active during training (`model.train()`), not evaluation.
References