
[feat] Add hardness-weighted contrastive learning to losses#3667

Merged
tomaarsen merged 9 commits into huggingface:main from yjoonjang:feat/add_hardness_weight
Mar 2, 2026

Conversation

@yjoonjang
Contributor

@yjoonjang yjoonjang commented Feb 19, 2026

Hello, @tomaarsen !

Pull Request overview

  • Add optional hardness weighting to all four contrastive losses (MultipleNegativesRankingLoss, CachedMultipleNegativesRankingLoss, GISTEmbedLoss, CachedGISTEmbedLoss), inspired by Lan et al. 2025 (LLaVE).

Details

This PR introduces a hardness-aware mechanism that up-weights harder negatives during contrastive learning, controlled by two new parameters: hardness_alpha and hardness_mode. The idea is that not all negatives are equally informative — harder negatives (those with higher similarity to the query) should contribute more to the loss. (This is actually one of the issues I mentioned about two years ago: #2910.)

The feature is off by default (hardness_alpha=0.0), so existing behavior is completely unchanged.

Two operating modes

hardness_mode="all" (logit penalty on all negatives):

  • Implements Equation 5 from Lan et al. 2025: adds alpha * sg(cos_sim(q, neg)) to each negative logit before the softmax, while keeping the positive logit unchanged.
  • This amplifies the contribution of harder negatives across all negatives (both in-batch and explicit hard negatives).
  • Works with any data format including pairs-only (anchor, positive).
  • Stop-gradient (detach()) on the penalty prevents shortcut learning.
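The "all" mode described above can be sketched roughly as follows. This is a minimal illustration, not the PR's actual code; the helper name `apply_hardness_penalty_all` is hypothetical, and it assumes a square in-batch logit matrix with positives on the diagonal:

```python
import torch

def apply_hardness_penalty_all(logits: torch.Tensor, scale: float, alpha: float) -> torch.Tensor:
    # Recover raw cosine similarities from the scaled logits and detach,
    # so no gradient flows through the penalty (the stop-gradient, sg).
    penalty = alpha * (logits / scale).detach()
    # Keep the positive (diagonal) logits unchanged.
    penalty.fill_diagonal_(0.0)
    return logits + penalty
```

With alpha=2.0 and scale=20.0, a negative logit of 2.0 (cosine similarity 0.1) becomes 2.2, so harder negatives receive proportionally larger logits before the softmax.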

hardness_mode="hard_only" (per-sample loss weighting from explicit hard negatives):

  • Computes a per-sample weight w_i = exp(alpha * sg(max_cos_sim(q_i, hard_negs))) based on how similar the query is to its hardest explicit negative. This is the method described in EmbeddingGemma (Lee et al. 2025).
  • Applies weighted averaging: loss = sum(w_i * L_i) / sum(w_i).
  • Only active when explicit hard negatives are present (triplets or n-tuples); falls back to standard mean loss for pairs-only (query-positive only) data.
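The per-sample weighting above can be sketched as follows (a hypothetical helper, not the PR's actual code; tensor shapes are assumed):

```python
import torch

def hard_only_weighted_loss(per_sample_loss: torch.Tensor, hard_neg_sims: torch.Tensor, alpha: float) -> torch.Tensor:
    # hard_neg_sims: (batch, num_hard_negs) cosine similarities of each query
    # to its explicit hard negatives.
    # w_i = exp(alpha * sg(max_j cos_sim(q_i, neg_j))); detach() is the stop-gradient.
    w = torch.exp(alpha * hard_neg_sims.max(dim=1).values.detach())
    # Weighted average: samples with harder explicit negatives contribute more.
    return (w * per_sample_loss).sum() / w.sum()
```

If all queries have equally hard negatives, the weights cancel and this reduces to the standard mean loss.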

Implementation details

  • For MNRL-style losses, the raw cosine similarity is recovered via (sim_matrix / self.scale).detach() before computing the penalty.
  • For GISTEmbed-style losses, similarity is recovered via (scores * self.temperature).detach(). The hardness penalty is applied after GIST's false-negative suppression, so samples already masked to -inf by the guide model remain suppressed (-inf + penalty = -inf).
  • In the "all" mode, positive positions are explicitly protected by zeroing out the penalty at diagonal/positive indices.
  • For Cached variants, "hard_only" weights are precomputed once before the mini-batch loop, while "all" penalty is applied per mini-batch.
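The GIST interaction noted above can be verified with a toy example (values made up for illustration): a score already masked to -inf by the guide model stays suppressed after adding a finite penalty, so the false negative still receives zero softmax weight.

```python
import torch

scores = torch.tensor([[0.9, float("-inf"), 0.3]])  # index 1 masked as a false negative
penalty = torch.tensor([[0.0, 0.5, 0.6]])           # hardness penalty (positive kept at 0)
masked = scores + penalty                            # -inf + 0.5 is still -inf
probs = torch.softmax(masked, dim=1)                 # masked position gets probability 0
```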

Experiment setup

I ran an experiment with the implemented losses.

  • Model: microsoft/mpnet-base
  • Dataset: rlhn/remove-250K (100K training samples, triplet format: query, positive, hard negative)
  • Loss: CachedMultipleNegativesRankingLoss with mini_batch_size=64
  • Training: 1 epoch, batch size 128, lr 2e-5, warmup 10%, bf16 mixed precision
  • Evaluation (during training): NanoBEIREvaluator on 6 datasets (MSMARCO, NFCorpus, FiQA2018, SciFact, NQ, HotpotQA), every 20% of training
  • Evaluation (after training): NanoBEIREvaluator on all 13 NanoBEIR datasets
  • wandb: https://wandb.ai/yjoonjang/hardness_weight_test
Training test code
"""
CachedMNRL + Hardness Weight comparison test.

Compares standard CachedMultipleNegativesRankingLoss against:
  - hardness_mode="hard_only": per-sample loss weighting by explicit hard negative similarity
  - hardness_mode="all": logit penalty on all negatives (Lan et al. 2025, Eq. 5)

Dataset: rlhn/remove-250K (100K samples, triplet format: query, positive, negative)
Eval: NanoBEIREvaluator (nDCG@10 on NanoBEIR datasets)
"""

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
from sentence_transformers.evaluation import NanoBEIREvaluator

model = SentenceTransformer("microsoft/mpnet-base")

# Baseline
loss_baseline = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=64)

# hard_only mode (alpha=2.0) — best performer
loss_hard_only = losses.CachedMultipleNegativesRankingLoss(
    model,
    mini_batch_size=64,
    hardness_alpha=2.0,
    hardness_mode="hard_only",
)

# all mode (alpha=2.0) — logit penalty on all negatives
loss_all = losses.CachedMultipleNegativesRankingLoss(
    model,
    mini_batch_size=64,
    hardness_alpha=2.0,
    hardness_mode="all",
)

Results

NanoBEIR (all 13 datasets) — nDCG@10 per task on final model:

| Dataset | Baseline | hard_only α=5.0 | all α=5.0 | hard_only α=2.0 | all α=2.0 |
|---|---|---|---|---|---|
| ClimateFEVER | 0.3370 | 0.3392 | 0.3494 | 0.3529 | 0.3346 |
| DBPedia | 0.5213 | 0.5240 | 0.5184 | 0.5289 | 0.5357 |
| FEVER | 0.8410 | 0.8441 | 0.8588 | 0.8476 | 0.8524 |
| FiQA2018 | 0.4477 | 0.4529 | 0.4361 | 0.4463 | 0.4330 |
| HotpotQA | 0.6748 | 0.6649 | 0.6698 | 0.6763 | 0.6748 |
| MSMARCO | 0.5505 | 0.5651 | 0.5488 | 0.5632 | 0.5610 |
| NFCorpus | 0.2743 | 0.2585 | 0.2761 | 0.2759 | 0.2758 |
| NQ | 0.5567 | 0.5835 | 0.5283 | 0.5827 | 0.5507 |
| QuoraRetrieval | 0.9092 | 0.9186 | 0.9079 | 0.9205 | 0.9127 |
| SCIDOCS | 0.3761 | 0.3750 | 0.3647 | 0.3679 | 0.3788 |
| ArguAna | 0.5725 | 0.5876 | 0.5845 | 0.5961 | 0.5869 |
| SciFact | 0.5994 | 0.6003 | 0.6207 | 0.6029 | 0.5883 |
| Touche2020 | 0.5079 | 0.4814 | 0.5064 | 0.4925 | 0.5070 |
| Mean | 0.5514 | 0.5535 | 0.5515 | 0.5580 | 0.5532 |

Key observations:

  • hard_only with alpha=2.0 is the best setting overall, achieving a +0.66pp mean nDCG@10 improvement on the full NanoBEIR benchmark (13 datasets) with zero additional training cost.
  • Lower alpha (2.0) consistently outperforms higher alpha (5.0) for both modes. Note that EmbeddingGemma (Lee et al. 2025) uses alpha=5.0 in their hard_only-style weighting, but our experiments suggest that a more moderate alpha=2.0 generalizes better for this rlhn dataset and setting.
  • hard_only mode outperforms all mode, likely because per-sample weighting from explicit hard negatives provides a cleaner signal than uniform logit penalty across all in-batch negatives.
  • hard_only α=2.0 wins or ties on 7 of 13 datasets (ClimateFEVER, HotpotQA, MSMARCO, NQ, QuoraRetrieval, ArguAna, SciFact), with particularly strong gains on NQ (+2.6pp) and MSMARCO (+1.3pp).
  • all mode with alpha=5.0 matches baseline overall (0.5515 vs 0.5514) but shows different per-task trade-offs — stronger on FEVER (+1.8pp) and SciFact (+2.1pp), weaker on NQ (-2.8pp).

Usage example

from sentence_transformers import SentenceTransformer, losses

model = SentenceTransformer("microsoft/mpnet-base")

# Recommended: hard_only mode with alpha=2.0
loss = losses.MultipleNegativesRankingLoss(
    model,
    hardness_alpha=2.0,
    hardness_mode="hard_only",
)

# Or with CachedMultipleNegativesRankingLoss
loss = losses.CachedMultipleNegativesRankingLoss(
    model,
    mini_batch_size=32,
    hardness_alpha=2.0,
    hardness_mode="hard_only",
)

# Or with GISTEmbedLoss
guide_model = SentenceTransformer("all-MiniLM-L6-v2")
loss = losses.GISTEmbedLoss(
    model,
    guide=guide_model,
    hardness_alpha=2.0,
    hardness_mode="hard_only",
)

Files changed

  • sentence_transformers/losses/MultipleNegativesRankingLoss.py — add hardness_alpha and hardness_mode parameters
  • sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py — same, with mini-batch compatible implementation
  • sentence_transformers/losses/GISTEmbedLoss.py — same, adapted for temperature-based scoring
  • sentence_transformers/losses/CachedGISTEmbedLoss.py — same, with mini-batch compatible implementation

TODO / Open questions

  • Add unit tests for both hardness_mode options
  • Add documentation page for the hardness weighting feature
  • Consider whether hardness_alpha default should be non-zero (e.g., 2.0) for triplet data
  • (maybe) Run experiments on larger models / more datasets to validate generalization

@yjoonjang
Contributor Author

It might be better to work on the self-guide PR first (#3662), so that we can apply both self-guide and hardness weighting to the GIST-style losses.

@Samoed
Contributor

Samoed commented Feb 19, 2026

Great addition! KaLM-v2 (https://arxiv.org/abs/2506.20923) also uses hardness weighting, but I'm not sure whether it's the same as in LLaVE.

@tomaarsen
Member

This is super interesting! I'm definitely interested in adding hardness normalisation. The KaLM paper has a lot of interesting ideas as well, I believe they call it a focal-style reweighting.

They also do some Contrastive distillation thing that I never 100% figured out.

I'll try and review this soon!

  • Tom Aarsen

@tomaarsen
Member

tomaarsen commented Feb 20, 2026

(Heads up, I pulled main into this branch to update it)

Can you run pre-commit run --all perhaps? I have it pre-commit install-ed so it runs before every commit for this repository (README docs)

  • Tom Aarsen

@yjoonjang
Contributor Author

I ran pre-commit run --all and pushed the changes again!

Thank you.

@yjoonjang
Contributor Author

I initially misread the EmbeddingGemma paper and implemented hard_only as applying a weight to each sample's loss based on the maximum similarity to any single hard negative (i.e., w_i = exp(α · max_j cos_sim(q_i, neg_j))). Looking more carefully at the paper, the intended behavior is to apply α · stop_grad(cos_sim) as a per-logit penalty to every explicit hard negative column individually in the score matrix. This is identical in structure to "all" mode, but restricted only to the explicit hard negative columns and leaving in-batch negative columns untouched.

So I've updated "hard_only" to match this: instead of reweighting the per-sample loss, we now zero out the penalty for in-batch columns ([:, :world_batch_size]) and only add the penalty to the explicit hard negative columns ([:, world_batch_size:]). This is consistent with how EmbeddingGemma actually uses it.
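A sketch of this revised behavior (hypothetical helper, not the PR's actual code; `world_batch_size` counts the in-batch candidate columns, and tensor shapes are assumed):

```python
import torch

def hard_negatives_column_penalty(scores: torch.Tensor, raw_sim: torch.Tensor,
                                  alpha: float, world_batch_size: int) -> torch.Tensor:
    # Columns [:world_batch_size] are in-batch candidates; the remaining
    # columns are explicit hard negatives. Only the latter get the penalty.
    penalty = alpha * raw_sim.detach()          # stop-gradient on the penalty
    penalty[:, :world_batch_size] = 0.0         # in-batch columns untouched
    return scores + penalty
```

Structurally this is the "all" penalty restricted to the explicit hard-negative columns.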

@tomaarsen tomaarsen mentioned this pull request Feb 26, 2026
@tomaarsen
Member

I've made some changes, as I think there were a handful of issues. Specifically, I believe the papers only update the query -> doc similarities. Beyond that, I believe there were issues with the penalty masking, so I've resolved those. And I've updated the modes to hard_negatives, in_batch_negatives, and all_negatives in (C)MNRL.

I also

  • Added a warning if you're using a combination of parameters that allows the loss to go negative.
  • Added focal weighting from the Kalm-Embedding-v2 paper

I didn't update the (C)GIST as that would require more research on my part, as these are implemented a bit differently. I'm also okay with not adding the hardness options to (C)GIST at this time.

I had Claude write me a script to perform hyperparameter optimization on all of the options now. I'd like to know whether these new additions (hardness, directions, partition_mode, etc.) help over standard InfoNCE/MNRL. I'm training a Static Embedding model on https://huggingface.co/datasets/tomaarsen/natural-questions-hard-negatives/viewer/triplet-5 to help me figure that out.

Please have a look at my changes, I hope they make sense to you.

  • Tom Aarsen

@yjoonjang
Contributor Author

yjoonjang commented Feb 26, 2026

Hello!
This implementation looks great!
I especially like the addition of focal weighting from KaLM-v2 and the cleaner penalty masking logic.

I'm also okay with adding the hardness options only to the MNRL losses for now.
We could add them to the GIST losses after the other PRs land.

Since I have some GPUs available right now, I'll test all 5 variations with default configs (baseline (original CMNRL), in_batch_negatives_9, hard_negatives_5, all_negatives_5, focal_0.5) on tomaarsen/natural-questions-hard-negatives (triplet-5) with NanoBEIR evaluation.

Training script
#!/bin/bash
# Run all 5 hardness weight experiments with DDP on GPUs 4,5,6,7
#
# Usage: bash run_all.sh

set -e

export CUDA_VISIBLE_DEVICES=4,5,6,7
NPROC=4
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
TRAIN_SCRIPT="$SCRIPT_DIR/train.py"
MASTER_PORT=29501

echo "============================================"
echo "[1/5] Baseline (no hardness weighting)"
echo "============================================"
torchrun --nproc_per_node=$NPROC --master_port=$MASTER_PORT "$TRAIN_SCRIPT" \
    --hardness_mode none \
    --hardness_strength 0.0

echo "============================================"
echo "[2/5] in_batch_negatives, strength=9"
echo "============================================"
torchrun --nproc_per_node=$NPROC --master_port=$MASTER_PORT "$TRAIN_SCRIPT" \
    --hardness_mode in_batch_negatives \
    --hardness_strength 9

echo "============================================"
echo "[3/5] hard_negatives, strength=5"
echo "============================================"
torchrun --nproc_per_node=$NPROC --master_port=$MASTER_PORT "$TRAIN_SCRIPT" \
    --hardness_mode hard_negatives \
    --hardness_strength 5

echo "============================================"
echo "[4/5] all_negatives, strength=5"
echo "============================================"
torchrun --nproc_per_node=$NPROC --master_port=$MASTER_PORT "$TRAIN_SCRIPT" \
    --hardness_mode all_negatives \
    --hardness_strength 5

echo "============================================"
echo "[5/5] focal, strength=0.5"
echo "============================================"
torchrun --nproc_per_node=$NPROC --master_port=$MASTER_PORT "$TRAIN_SCRIPT" \
    --hardness_mode focal \
    --hardness_strength 0.5

echo "============================================"
echo "All 5 experiments complete!"
echo "============================================"
Training Python file
"""
Training script for hardness-weighted CachedMultipleNegativesRankingLoss experiments.

Dataset: tomaarsen/natural-questions-hard-negatives (triplet-5)
Loss: CachedMultipleNegativesRankingLoss with various hardness_mode settings

Usage (single GPU):
    python train.py --hardness_mode none

Usage (DDP):
    CUDA_VISIBLE_DEVICES=4,5,6,7 torchrun --nproc_per_node=4 train.py --hardness_mode in_batch_negatives --hardness_strength 9
"""

import argparse
import logging
import os

from datasets import load_dataset

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="microsoft/mpnet-base")
    parser.add_argument("--hardness_mode", type=str, default="none", choices=["none", "in_batch_negatives", "hard_negatives", "all_negatives", "focal"])
    parser.add_argument("--hardness_strength", type=float, default=0.0)
    parser.add_argument("--per_device_train_batch_size", type=int, default=256)
    parser.add_argument("--mini_batch_size", type=int, default=32)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--output_dir", type=str, default="models/hardness_weight")
    parser.add_argument("--wandb_project", type=str, default="Hardness_Weight")
    return parser.parse_args()


def main():
    args = parse_args()

    hardness_mode = None if args.hardness_mode == "none" else args.hardness_mode

    if hardness_mode is None:
        run_name = "baseline"
    else:
        run_name = f"{args.hardness_mode}_s{args.hardness_strength}"
    logger.info(f"Run: {run_name} | hardness_mode={hardness_mode}, hardness_strength={args.hardness_strength}")

    # 1. Load model
    model = SentenceTransformer(args.model_name)

    # 2. Load dataset
    dataset = load_dataset("tomaarsen/natural-questions-hard-negatives", "triplet-5", split="train")
    dataset = dataset.rename_columns({"query": "anchor", "answer": "positive"})
    logger.info(f"Dataset size: {len(dataset)}")

    # 3. Define loss
    loss = CachedMultipleNegativesRankingLoss(
        model,
        mini_batch_size=args.mini_batch_size,
        gather_across_devices=True,
        hardness_mode=hardness_mode,
        hardness_strength=args.hardness_strength,
    )

    # 4. Evaluator (subset of NanoBEIR for mid-training eval)
    dev_evaluator = NanoBEIREvaluator(
        dataset_names=["msmarco", "nq", "nfcorpus", "quoraretrieval"],
        show_progress_bar=True,
        batch_size=64,
    )

    # 5. Training arguments
    os.environ["WANDB_PROJECT"] = args.wandb_project
    training_args = SentenceTransformerTrainingArguments(
        output_dir=f"{args.output_dir}/{run_name}",
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        learning_rate=args.learning_rate,
        warmup_ratio=0.1,
        bf16=True,
        # batch_sampler=BatchSamplers.NO_DUPLICATES,
        eval_on_start=True,
        eval_strategy="steps",
        eval_steps=0.1,
        logging_steps=10,
        logging_first_step=True,
        save_strategy="no",
        report_to="wandb",
        run_name=run_name,
        ddp_find_unused_parameters=True,
    )

    # 6. Train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        loss=loss,
        evaluator=dev_evaluator,
    )
    trainer.train()

    # 7. Full NanoBEIR evaluation (all 13 subsets)
    logger.info("Running full NanoBEIR evaluation on all subsets...")
    full_evaluator = NanoBEIREvaluator(
        show_progress_bar=True,
        batch_size=64,
    )
    full_evaluator(model)

    # 8. Save
    model.save_pretrained(f"{args.output_dir}/{run_name}/final")
    logger.info(f"Done: {run_name}")


if __name__ == "__main__":
    main()

Thank you once again,
I'll report the results here when it's done!

@yjoonjang
Contributor Author

Here are the results on NanoBEIR (across all 13 subsets):

| Setting | Recall@10 | NDCG@10 |
|---|---|---|
| Baseline (original CMNRL) | 56.40 | 49.84 |
| focal_s0.5 | 56.11 | 49.56 |
| all_negatives_s5.0 | 56.48 | 49.80 |
| hard_negatives_s5.0 | 56.53 | 49.87 |
| in_batch_negatives_s9.0 | 55.88 | 49.77 |

hard_negatives_s5.0 shows slight improvements (+0.13 Recall@10, +0.03 NDCG@10), but the differences are marginal and likely within noise.

@tomaarsen
Member

tomaarsen commented Feb 27, 2026

I ran a lot of benchmarks on my side as well via optuna, using the static embedding model I previously mentioned. The differences were tiny, but had roughly this order: all_negatives > in_batch_negatives > hard_negatives = main > focal. I would be happy to remove focal again; I think it's best to avoid adding techniques that we can't get to outperform the baseline.

Then with these 2 changes, we're ready I think:

  • Remove focal
  • Undo changes for (C)GIST: this should also fix the tests again.
  • Tom Aarsen

@tomaarsen
Member

Also note that I just merged #3677, which shouldn't interfere with this PR too much, but it's still a merge conflict.

  • Tom Aarsen

@yjoonjang
Contributor Author

Thanks for the review! Applied the requested changes:

  1. Removed focal mode from (C)MNRL
  2. Reverted GISTEmbedLoss and CachedGISTEmbedLoss
  3. Merged upstream/main to resolve conflicts from [loss] Disallow query_to_query/doc_to_doc with partition_mode="per_direction" due to negative loss #3677

@tomaarsen
Member

tomaarsen commented Feb 27, 2026

Oh, I forgot that SparseMultipleNegativesRankingLoss subclasses MultipleNegativesRankingLoss. Can we perhaps also extend SparseMultipleNegativesRankingLoss and then update the currently failing tests/sparse_encoder/test_model_card.py (which fails because SparseMultipleNegativesRankingLoss now outputs new values in its get_config_dict())?

Looks great otherwise!

  • Tom Aarsen

@tomaarsen tomaarsen enabled auto-merge (squash) March 2, 2026 08:48
@tomaarsen
Member

Thanks a bunch, I think this should be good to go once the tests are green!

  • Tom Aarsen

@tomaarsen tomaarsen merged commit 7f180b4 into huggingface:main Mar 2, 2026
17 checks passed
@yjoonjang
Contributor Author

Thanks for all the details!

I'll now work on the self-guide mode.
