[feat] Add ContAccum support to GradCache MNRL #3612

Open

KimJaehee0725 wants to merge 3 commits into huggingface:main from KimJaehee0725:contaccum-gradcache

Conversation

@KimJaehee0725
Summary

This PR adds Contrastive Accumulation (ContAccum) support to CachedMultipleNegativesRankingLoss, enabling cached embeddings from previous steps to be reused as in-batch negatives while keeping the GradCache memory profile.

Motivation

ContAccum ("A Gradient Accumulation Method for Dense Retriever under Memory Constraint") increases the effective in-batch negative pool without increasing GPU memory, by reusing cached embeddings across steps. This makes large-batch contrastive training more accessible in memory-constrained environments.

Changes

  • Add ContAccum parameters to CachedMultipleNegativesRankingLoss (use_cont_accum, cache_size, prev_cache).
  • Cache anchors and candidate embeddings and include them in loss computation during training.
  • Add on_optimizer_step() to reset caches after each optimizer step when prev_cache=False.
  • Hook cache reset in SentenceTransformerTrainer via a callback.
  • Add paper reference to the loss docstring.

Usage

loss = losses.CachedMultipleNegativesRankingLoss(
    model,
    mini_batch_size=32,
    use_cont_accum=True,
    cache_size=256,
    prev_cache=False,
)

Parameters

  • use_cont_accum: Enable ContAccum (reuse cached embeddings from previous steps as in-batch negatives).
  • cache_size: Maximum number of cached embeddings per column; set it to the effective batch size to mimic gradient accumulation.
  • prev_cache: If True, caches persist across optimizer steps; if False, they are reset after each optimizer step.
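The caching behavior these parameters describe can be sketched as a FIFO buffer capped at cache_size. This is an illustrative sketch only; the class name EmbeddingCache and its methods are hypothetical and not the PR's actual implementation:

```python
from collections import deque


class EmbeddingCache:
    """Illustrative FIFO cache of embeddings (hypothetical; not the PR's actual class)."""

    def __init__(self, cache_size: int):
        # deque(maxlen=...) evicts the oldest rows once cache_size is exceeded
        self.buffer: deque = deque(maxlen=cache_size)

    def add(self, embeddings) -> None:
        # In the real loss these would be detached tensors, so no gradient
        # flows back through embeddings computed in earlier steps
        for row in embeddings:
            self.buffer.append(row)

    def negatives(self) -> list:
        # All cached rows extend the in-batch negative pool of the current step
        return list(self.buffer)


cache = EmbeddingCache(cache_size=4)
cache.add([[0.1] * 8 for _ in range(3)])  # step 1: 3 candidate "embeddings"
cache.add([[0.2] * 8 for _ in range(3)])  # step 2: the 2 oldest rows are evicted
print(len(cache.negatives()))  # 4 -- capped at cache_size
```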

Notes

  • If you use a custom training loop, call loss.on_optimizer_step() after each optimizer step when prev_cache=False.
  • Cache is only used during training (model.train()), not evaluation.
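To show where the on_optimizer_step() call from the first note belongs, here is a minimal runnable sketch. DummyContAccumLoss is a stand-in I made up to mimic the reset semantics; only the placement of the reset call reflects this PR:

```python
class DummyContAccumLoss:
    """Stand-in for CachedMultipleNegativesRankingLoss with use_cont_accum=True (illustrative)."""

    def __init__(self):
        self.cache = []

    def __call__(self, batch):
        self.cache.append(batch)  # embeddings would be cached here during training
        return sum(self.cache)    # stand-in for a loss over current + cached embeddings

    def on_optimizer_step(self):
        self.cache.clear()        # prev_cache=False: drop caches after each optimizer step


loss = DummyContAccumLoss()
for batch in [1.0, 2.0, 3.0]:
    value = loss(batch)          # backward() and optimizer.step() would go here in a real loop
    loss.on_optimizer_step()     # call this after every optimizer step when prev_cache=False

print(len(loss.cache))  # 0 -- caches are empty after the reset
```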

References

  • ContAccum paper: "A Gradient Accumulation Method for Dense Retriever under Memory Constraint"

@tomaarsen
Member

Hello!

This is very cool, I was not aware of this paper. I'm currently assisting with a PR on transformers for something that should help strengthen Sentence Transformers' multi-modality support, but afterwards I'll have a fresh look at this, as well as the paper. The premise looks quite solid.

  • Tom Aarsen

@KimJaehee0725
Author

Hi Tom,

Thanks a lot for the kind words! Glad you find the idea interesting.
No rush at all — whenever you get a chance to take a look, I’d really appreciate any feedback.
(I fixed the linter error in the latest commit.)

In the meantime, I’m happy to address any questions or adjust the implementation if needed.

Best,
Jaehee



def test_push_to_hub(monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture) -> None:
def build_commit_info(**kwargs):
@tomaarsen
Member
Jan 27, 2026
Apologies for the inconvenience here, this was caused by this: huggingface/huggingface_hub#3737

It's resolved and released now, so we're good either way.

(I'm going through the paper etc. now)

@tomaarsen
Member

I ran some tests last night to experiment with this, with the following settings:

  • GradCache only:
    • CachedMultipleNegativesRankingLoss with mini_batch_size=64
    • SentenceTransformerTrainingArguments with per_device_train_batch_size=1024
  • GradCache with ContAccum:
    • CachedMultipleNegativesRankingLoss with mini_batch_size=64 (also tested 32), use_cont_accum=True, cache_size=1028, prev_cache=True (also tested False)
    • SentenceTransformerTrainingArguments with per_device_train_batch_size=128
Exact GradCache script
import random
import logging
from datasets import load_dataset, Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import InformationRetrievalEvaluator

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)

# 1. Load a model to finetune with 2. (Optional) model card data
model = SentenceTransformer(
    "microsoft/mpnet-base",
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="MPNet base trained on GooAQ triplets using CachedMultipleNegativesRankingLoss with GradCache",
    ),
)

# 3. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
dataset = dataset.add_column("id", range(len(dataset)))
dataset_dict = dataset.train_test_split(test_size=10_000, seed=12)
train_dataset: Dataset = dataset_dict["train"]
# train_dataset = train_dataset.add_column("negative", train_dataset["answer"])  # Dummy negatives for MNRBL
# train_dataset = train_dataset.add_column("negative_1", train_dataset["answer"])  # Dummy negatives for MNRBL
# train_dataset = train_dataset.add_column("negative_2", train_dataset["answer"])  # Dummy negatives for MNRBL

eval_dataset: Dataset = dataset_dict["test"]

# 4. Define a loss function
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=64)

# 5. (Optional) Specify training arguments
run_name = "mpnet-base-gooaq-cmnrl-1024bs-GradCache"
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=1024,
    per_device_eval_batch_size=1024,
    learning_rate=2e-5 * 4,
    warmup_ratio=0.1,
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=True,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # CachedMultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=0.1,
    save_strategy="steps",
    save_steps=0.1,
    save_total_limit=2,
    logging_steps=0.05,
    logging_first_step=True,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
# The full corpus, but only the evaluation queries
corpus = dict(zip(dataset["id"], dataset["answer"]))
random.seed(12)
queries = dict(zip(eval_dataset["id"], eval_dataset["question"]))
corpus = (
    {qid: dataset[qid]["answer"] for qid in queries}
    # {qid: dataset[qid]["answer"] for qid in queries} |
    # {qid: dataset[qid]["answer"] for qid in random.sample(range(len(dataset)), 20_000)}
)
relevant_docs = {qid: {qid} for qid in eval_dataset["id"]}
dev_evaluator = InformationRetrievalEvaluator(
    corpus=corpus,
    queries=queries,
    relevant_docs=relevant_docs,
    show_progress_bar=True,
    name="gooaq-dev",
)
dev_evaluator(model)

# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset.remove_columns("id"),
    eval_dataset=eval_dataset.remove_columns("id"),
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# (Optional) Evaluate the trained model on the evaluator after training
dev_evaluator(model)

# # 8. Save the trained model
model.save_pretrained(f"models/{run_name}/final")

# # 9. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)
Exact GradCache + ContAccum script
import random
import logging
from datasets import load_dataset, Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import InformationRetrievalEvaluator

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)

# 1. Load a model to finetune with 2. (Optional) model card data
model = SentenceTransformer(
    "microsoft/mpnet-base",
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="MPNet base trained on GooAQ triplets using CachedMultipleNegativesRankingLoss with ContAccum",
    ),
)

# 3. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
dataset = dataset.add_column("id", range(len(dataset)))
dataset_dict = dataset.train_test_split(test_size=10_000, seed=12)
train_dataset: Dataset = dataset_dict["train"]
# train_dataset = train_dataset.add_column("negative", train_dataset["answer"])  # Dummy negatives for MNRBL
# train_dataset = train_dataset.add_column("negative_1", train_dataset["answer"])  # Dummy negatives for MNRBL
# train_dataset = train_dataset.add_column("negative_2", train_dataset["answer"])  # Dummy negatives for MNRBL

eval_dataset: Dataset = dataset_dict["test"]

# 4. Define a loss function
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=32, use_cont_accum=True, cache_size=1028, prev_cache=True)

# 5. (Optional) Specify training arguments
run_name = "mpnet-base-gooaq-cmnrl-128bs-ContAccum"
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    learning_rate=2e-5 * 2,
    warmup_ratio=0.1,
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=True,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # CachedMultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=0.1,
    save_strategy="steps",
    save_steps=0.1,
    save_total_limit=2,
    logging_steps=0.05,
    logging_first_step=True,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
# The full corpus, but only the evaluation queries
corpus = dict(zip(dataset["id"], dataset["answer"]))
random.seed(12)
queries = dict(zip(eval_dataset["id"], eval_dataset["question"]))
corpus = (
    {qid: dataset[qid]["answer"] for qid in queries}
    # {qid: dataset[qid]["answer"] for qid in queries} |
    # {qid: dataset[qid]["answer"] for qid in random.sample(range(len(dataset)), 20_000)}
)
relevant_docs = {qid: {qid} for qid in eval_dataset["id"]}
dev_evaluator = InformationRetrievalEvaluator(
    corpus=corpus,
    queries=queries,
    relevant_docs=relevant_docs,
    show_progress_bar=True,
    name="gooaq-dev",
)
dev_evaluator(model)

# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset.remove_columns("id"),
    eval_dataset=eval_dataset.remove_columns("id"),
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# (Optional) Evaluate the trained model on the evaluator after training
dev_evaluator(model)

# # 8. Save the trained model
model.save_pretrained(f"models/{run_name}/final")

# # 9. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)

In short, I wanted to compare the two methods for "extending the batch size" to an effective batch size of 1028. From my perspective, GradCache already kind of solves this issue, but I'm open to alternatives that speed up the training. My understanding is that these two setups should be somewhere between similar and equivalent.

The GradCache model trained in 0.27 hours according to the model's Environmental Impact section and was uploaded here: https://huggingface.co/tomaarsen/mpnet-base-gooaq-cmnrl-1024bs-GradCache. It reached a 0.8652 NDCG@10 on the in-domain evaluation test.

The GradCache + ContAccum model with prev_cache=True would continuously increase VRAM usage over time, resulting in significant slowdowns, until I cancelled it when the expected training completion would hit ~11.5 hours.
With prev_cache=False, results seem more promising. The memory seems static now, but the cache will always get reset with only 1 element in it.
I reran the script with gradient_accumulation_steps=8 (is that the intended approach then?), then the cache size is 8 (shouldn't it be 8 * batch_size? Should I then set cache_size to something lower than 1028?). The gradient accumulation version completed here: https://huggingface.co/tomaarsen/mpnet-base-gooaq-cmnrl-128bs-ContAccum. It took 0.281 hours (a.k.a. roughly the same time) and scored a 0.8304 NDCG@10.
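To make the cache-size question above concrete, here is a rough tally of the extra negative pool under prev_cache=False with gradient accumulation. The assumptions are mine and not verified against the PR: the cache is filled batch-by-batch within one optimizer step and cleared by on_optimizer_step():

```python
# Assumed semantics: each accumulation sub-batch adds batch_size rows to the cache,
# the cache is capped at cache_size, and it is cleared on every optimizer step.
batch_size = 128
grad_accum_steps = 8
cache_size = 1028

cached = 0
for accum_step in range(grad_accum_steps):
    visible_negatives = batch_size + cached  # this sub-batch's pool: own batch + cache so far
    cached = min(cached + batch_size, cache_size)

print(cached)             # 1024: full cache after one accumulation cycle (8 * 128 < 1028)
print(visible_negatives)  # 1024: the last sub-batch sees 128 own + 896 cached negatives
```

Under these assumptions the cache would hold 8 * 128 = 1024 embeddings (not 8), which stays just under cache_size=1028, so a lower cache_size would start evicting within a single accumulation cycle.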

I'm struggling a bit to see where this has an edge over pure GradCache currently. The memory overhead for caching makes sense if it e.g. results in speedups or allows some batch_size setting that's out of reach with GradCache, but I think GradCache already allows for effectively any batch_size, and I haven't encountered any speedups.

Do you have suggestions I can use to get more benefit out of ContAccum?

  • Tom Aarsen
