[feat] Add hardness-weighted contrastive learning to losses #3667
tomaarsen merged 9 commits into huggingface:main

Conversation
It might be better to work on the self-guide PR first (#3662), so that we can implement both self-guide and hardness weighting in the GIST-style losses.
Great addition!
This is super interesting! I'm definitely interested in adding hardness normalisation. The KaLM paper has a lot of interesting ideas as well; I believe they call it focal-style reweighting. They also do a contrastive distillation technique that I never 100% figured out. I'll try and review this soon!
(Heads up, I pulled main into this branch to update it) Can you run
I ran it. Thank you.
I initially misread the EmbeddingGemma paper and implemented it incorrectly, so I've updated it.
- Only add penalties to query_to_doc, following the papers
- Fix penalty masking
- Update modes to hard_negatives, in_batch_negatives, and all_negatives
- Add a warning if you're using a combination of parameters that allows the loss to go negative
- Add focal weighting from the KaLM-Embedding-v2 paper
- I didn't update (C)GIST, as that would require more research on my part first
I've made some changes, as I think there were a handful of issues. Specifically, I believe the papers only update the query -> doc similarities. Beyond that, I believe there were issues with the penalty masking, so I've resolved those. And I've updated the modes to hard_negatives, in_batch_negatives, and all_negatives in (C)MNRL. I also added a warning for parameter combinations that allow the loss to go negative.
I didn't update the (C)GIST, as that would require more research on my part: these are implemented a bit differently. I'm also okay with not adding the hardness options to (C)GIST at this time. I had Claude write me a script to perform hyperparameter optimization over all of the options. I'd like to know whether these new additions (hardness, directions, partition_mode, etc.) help over standard InfoNCE/MNRL. I'm training a Static Embedding model on https://huggingface.co/datasets/tomaarsen/natural-questions-hard-negatives/viewer/triplet-5 to help me figure that out. Please have a look at my changes; I hope they make sense to you.
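For reference, focal-style reweighting down-weights samples the model already classifies confidently. The sketch below is a rough, hypothetical illustration in the spirit of focal loss applied to per-sample InfoNCE; `focal_infonce` and the `gamma` exponent are illustrative names and choices, not necessarily the exact formulation in KaLM-Embedding-v2 or this PR:

```python
import torch
import torch.nn.functional as F


def focal_infonce(scores: torch.Tensor, gamma: float = 0.5) -> torch.Tensor:
    """Per-sample InfoNCE, reweighted focal-style by (1 - p_pos) ** gamma.

    scores: (batch, num_candidates) similarity logits, where column i is the
    positive candidate for row i.
    """
    labels = torch.arange(scores.size(0), device=scores.device)
    per_sample = F.cross_entropy(scores, labels, reduction="none")
    # Probability currently assigned to the positive, detached so the weight
    # acts as a constant (stop-gradient) rather than a gradient path.
    p_pos = torch.softmax(scores, dim=-1)[labels, labels].detach()
    weights = (1.0 - p_pos) ** gamma  # down-weight samples the model already gets right
    return (weights * per_sample).sum() / weights.sum().clamp_min(1e-8)
```

With `gamma=0` the weights are all 1 and this reduces to the plain mean InfoNCE loss.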
Hello! It's also okay for me to add hardness options to the MNRLs this time. Since I have some GPUs right now, I'll test all 5 variations with the default settings.

Training script:

```bash
#!/bin/bash
# Run all 5 hardness weight experiments with DDP on GPUs 4,5,6,7
#
# Usage: bash run_all.sh
set -e

export CUDA_VISIBLE_DEVICES=4,5,6,7
NPROC=4
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
TRAIN_SCRIPT="$SCRIPT_DIR/train.py"
MASTER_PORT=29501

echo "============================================"
echo "[1/5] Baseline (no hardness weighting)"
echo "============================================"
torchrun --nproc_per_node=$NPROC --master_port=$MASTER_PORT "$TRAIN_SCRIPT" \
    --hardness_mode none \
    --hardness_strength 0.0

echo "============================================"
echo "[2/5] in_batch_negatives, strength=9"
echo "============================================"
torchrun --nproc_per_node=$NPROC --master_port=$MASTER_PORT "$TRAIN_SCRIPT" \
    --hardness_mode in_batch_negatives \
    --hardness_strength 9

echo "============================================"
echo "[3/5] hard_negatives, strength=5"
echo "============================================"
torchrun --nproc_per_node=$NPROC --master_port=$MASTER_PORT "$TRAIN_SCRIPT" \
    --hardness_mode hard_negatives \
    --hardness_strength 5

echo "============================================"
echo "[4/5] all_negatives, strength=5"
echo "============================================"
torchrun --nproc_per_node=$NPROC --master_port=$MASTER_PORT "$TRAIN_SCRIPT" \
    --hardness_mode all_negatives \
    --hardness_strength 5

echo "============================================"
echo "[5/5] focal, strength=0.5"
echo "============================================"
torchrun --nproc_per_node=$NPROC --master_port=$MASTER_PORT "$TRAIN_SCRIPT" \
    --hardness_mode focal \
    --hardness_strength 0.5

echo "============================================"
echo "All 5 experiments complete!"
echo "============================================"
```

Training Python file:

```python
"""
Training script for hardness-weighted CachedMultipleNegativesRankingLoss experiments.

Dataset: tomaarsen/natural-questions-hard-negatives (triplet-5)
Loss: CachedMultipleNegativesRankingLoss with various hardness_mode settings

Usage (single GPU):
    python train.py --hardness_mode none

Usage (DDP):
    CUDA_VISIBLE_DEVICES=4,5,6,7 torchrun --nproc_per_node=4 train.py --hardness_mode in_batch_negatives --hardness_strength 9
"""

import argparse
import logging
import os

from datasets import load_dataset

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="microsoft/mpnet-base")
    parser.add_argument(
        "--hardness_mode",
        type=str,
        default="none",
        choices=["none", "in_batch_negatives", "hard_negatives", "all_negatives", "focal"],
    )
    parser.add_argument("--hardness_strength", type=float, default=0.0)
    parser.add_argument("--per_device_train_batch_size", type=int, default=256)
    parser.add_argument("--mini_batch_size", type=int, default=32)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--output_dir", type=str, default="models/hardness_weight")
    parser.add_argument("--wandb_project", type=str, default="Hardness_Weight")
    return parser.parse_args()


def main():
    args = parse_args()
    hardness_mode = None if args.hardness_mode == "none" else args.hardness_mode
    if hardness_mode is None:
        run_name = "baseline"
    else:
        run_name = f"{args.hardness_mode}_s{args.hardness_strength}"
    logger.info(f"Run: {run_name} | hardness_mode={hardness_mode}, hardness_strength={args.hardness_strength}")

    # 1. Load model
    model = SentenceTransformer(args.model_name)

    # 2. Load dataset
    dataset = load_dataset("tomaarsen/natural-questions-hard-negatives", "triplet-5", split="train")
    dataset = dataset.rename_columns({"query": "anchor", "answer": "positive"})
    logger.info(f"Dataset size: {len(dataset)}")

    # 3. Define loss
    loss = CachedMultipleNegativesRankingLoss(
        model,
        mini_batch_size=args.mini_batch_size,
        gather_across_devices=True,
        hardness_mode=hardness_mode,
        hardness_strength=args.hardness_strength,
    )

    # 4. Evaluator (subset of NanoBEIR for mid-training eval)
    dev_evaluator = NanoBEIREvaluator(
        dataset_names=["msmarco", "nq", "nfcorpus", "quoraretrieval"],
        show_progress_bar=True,
        batch_size=64,
    )

    # 5. Training arguments
    os.environ["WANDB_PROJECT"] = args.wandb_project
    training_args = SentenceTransformerTrainingArguments(
        output_dir=f"{args.output_dir}/{run_name}",
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        learning_rate=args.learning_rate,
        warmup_ratio=0.1,
        bf16=True,
        # batch_sampler=BatchSamplers.NO_DUPLICATES,
        eval_on_start=True,
        eval_strategy="steps",
        eval_steps=0.1,
        logging_steps=10,
        logging_first_step=True,
        save_strategy="no",
        report_to="wandb",
        run_name=run_name,
        ddp_find_unused_parameters=True,
    )

    # 6. Train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        loss=loss,
        evaluator=dev_evaluator,
    )
    trainer.train()

    # 7. Full NanoBEIR evaluation (all 13 subsets)
    logger.info("Running full NanoBEIR evaluation on all subsets...")
    full_evaluator = NanoBEIREvaluator(
        show_progress_bar=True,
        batch_size=64,
    )
    full_evaluator(model)

    # 8. Save
    model.save_pretrained(f"{args.output_dir}/{run_name}/final")
    logger.info(f"Done: {run_name}")


if __name__ == "__main__":
    main()
```

Thank you once again,
Here are the results on NanoBEIR (across all 13 subsets):
I ran a lot of benchmarks on my side as well. Then, with these 2 changes, we're ready, I think:
Also note that I just merged #3677, which shouldn't interfere with this PR too much, but it does create a merge conflict.
Thanks for the review! Applied the requested changes:
Oh, I forgot about that. Looks great otherwise!
Thanks a bunch, I think this should be good to go once the tests are green!
Thanks for all the details! I'll now work on the self-guide mode.
Hello, @tomaarsen!
Pull Request overview
- Add hardness-weighted contrastive learning to four losses (MultipleNegativesRankingLoss, CachedMultipleNegativesRankingLoss, GISTEmbedLoss, CachedGISTEmbedLoss), inspired by Lan et al. 2025 (LLaVE).

Details
This PR introduces a hardness-aware mechanism that up-weights harder negatives during contrastive learning, controlled by two new parameters: hardness_alpha and hardness_mode. The idea is that not all negatives are equally informative: harder negatives (those with higher similarity to the query) should contribute more to the loss. (This is actually one of the issues I mentioned about two years ago: #2910.)

The feature is off by default (hardness_alpha=0.0), so existing behavior is completely unchanged.

Two operating modes
- hardness_mode="all" (logit penalty on all negatives): adds alpha * sg(cos_sim(q, neg)) to each negative logit before the softmax, while keeping the positive logit unchanged. The stop-gradient (detach()) on the penalty prevents shortcut learning.
- hardness_mode="hard_only" (per-sample loss weighting from explicit hard negatives): computes w_i = exp(alpha * sg(max_cos_sim(q_i, hard_negs))), based on how similar the query is to its hardest explicit negative. This is the method mentioned by EmbeddingGemma (Lee et al. 2025). The final loss is loss = sum(w_i * L_i) / sum(w_i).

Implementation details
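The two modes can be sketched roughly as follows, assuming an in-batch InfoNCE setup where column i of the logits is the positive for row i and explicit hard negatives occupy the trailing columns; `infonce_with_hardness` is an illustrative name, not the PR's actual API:

```python
import torch
import torch.nn.functional as F


def infonce_with_hardness(scores: torch.Tensor, cos_sim: torch.Tensor,
                          alpha: float, mode: str) -> torch.Tensor:
    """scores: (batch, num_candidates) logits; cos_sim: raw cosine similarities.

    Column i is the positive for row i; all remaining columns are negatives,
    with explicit hard negatives assumed to sit after the in-batch columns.
    """
    batch = scores.size(0)
    labels = torch.arange(batch, device=scores.device)
    if mode == "all":
        # Add alpha * sg(cos_sim) to every negative logit before the softmax.
        penalty = alpha * cos_sim.detach()     # stop-gradient on the penalty
        penalty[labels, labels] = 0.0          # protect the positive positions
        return F.cross_entropy(scores + penalty, labels)
    if mode == "hard_only":
        # Weight each sample's loss by exp(alpha * sg(max sim to its hard negatives)).
        hard_cols = cos_sim[:, batch:]
        w = torch.exp(alpha * hard_cols.max(dim=1).values.detach())
        per_sample = F.cross_entropy(scores, labels, reduction="none")
        return (w * per_sample).sum() / w.sum()
    return F.cross_entropy(scores, labels)     # no hardness weighting
```

With alpha=0 both modes reduce to plain InfoNCE, which matches the "off by default" behavior described above.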
- For (C)MNRL, the similarity matrix is detached via (sim_matrix / self.scale).detach() before computing the penalty.
- For (C)GIST, the equivalent is (scores * self.temperature).detach(). The hardness penalty is applied after GIST's false-negative suppression, so samples already masked to -inf by the guide model remain suppressed (-inf + penalty = -inf).
- In "all" mode, positive positions are explicitly protected by zeroing out the penalty at diagonal/positive indices.
- "hard_only" weights are precomputed once before the mini-batch loop, while the "all" penalty is applied per mini-batch.

Experiment setup
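The interaction between the guide-model mask and the penalty can be sanity-checked in a few lines (the values below are made up for illustration):

```python
import torch

# Column 1 stands in for a sample suppressed to -inf by the guide model.
scores = torch.tensor([[2.0, float("-inf"), 0.5]])
penalty = torch.tensor([[0.0, 3.0, 3.0]])  # hypothetical finite hardness penalty
masked = scores + penalty

assert masked[0, 1] == float("-inf")               # -inf + finite penalty stays -inf
assert torch.softmax(masked, dim=-1)[0, 1] == 0.0  # contributes nothing after softmax
```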
I ran an experiment with the implemented losses.

- microsoft/mpnet-base
- rlhn/remove-250K (100K training samples, triplet format: query, positive, hard negative)
- CachedMultipleNegativesRankingLoss with mini_batch_size=64
- NanoBEIREvaluator on 6 datasets (MSMARCO, NFCorpus, FiQA2018, SciFact, NQ, HotpotQA), every 20% of training
- NanoBEIREvaluator on all 13 NanoBEIR datasets

Training test code
Results
NanoBEIR (all 13 datasets), nDCG@10 per task on final model, comparing hard_only α=5.0, all α=5.0, hard_only α=2.0, and all α=2.0.

Key observations:
- hard_only with alpha=2.0 is the best setting overall, achieving a +0.66% mean nDCG@10 improvement on the full NanoBEIR benchmark (13 datasets) with zero additional training cost.
- … hard_only-style weighting, but our experiments suggest that a more moderate alpha=2.0 generalizes better for this rlhn dataset and setting.
- hard_only mode outperforms all mode, likely because per-sample weighting from explicit hard negatives provides a cleaner signal than a uniform logit penalty across all in-batch negatives.
- hard_only α=2.0 wins or ties on 7 of 13 datasets (ClimateFEVER, HotpotQA, MSMARCO, NQ, QuoraRetrieval, ArguAna, SciFact), with particularly strong gains on NQ (+2.6pp) and MSMARCO (+1.3pp).
- all mode with alpha=5.0 matches baseline overall (0.5515 vs 0.5514) but shows different per-task trade-offs: stronger on FEVER (+1.8pp) and SciFact (+2.1pp), weaker on NQ (-2.8pp).

Usage example
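A sketch of how the new parameters might be used; the parameter names follow this description (hardness_alpha, hardness_mode), and the review later renamed the modes, so check the merged API before copying:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss

# hardness_alpha / hardness_mode follow this PR description; the review later
# renamed the modes, so the merged API may differ.
model = SentenceTransformer("microsoft/mpnet-base")
loss = CachedMultipleNegativesRankingLoss(
    model,
    mini_batch_size=64,
    hardness_alpha=2.0,         # 0.0 (default) disables hardness weighting
    hardness_mode="hard_only",  # per-sample weighting from explicit hard negatives
)
```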
Files changed
- sentence_transformers/losses/MultipleNegativesRankingLoss.py: add hardness_alpha and hardness_mode parameters
- sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py: same, with mini-batch compatible implementation
- sentence_transformers/losses/GISTEmbedLoss.py: same, adapted for temperature-based scoring
- sentence_transformers/losses/CachedGISTEmbedLoss.py: same, with mini-batch compatible implementation

TODO / Open questions
- … hardness_mode options
- Should the hardness_alpha default be non-zero (e.g., 2.0) for triplet data?