diff --git a/sentence_transformers/callbacks/__init__.py b/sentence_transformers/callbacks/__init__.py new file mode 100644 index 000000000..2bb657802 --- /dev/null +++ b/sentence_transformers/callbacks/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from sentence_transformers.callbacks.self_guide_callbacks import SelfGuideWarmupCallback + +__all__ = ["SelfGuideWarmupCallback"] diff --git a/sentence_transformers/callbacks/self_guide_callbacks.py b/sentence_transformers/callbacks/self_guide_callbacks.py new file mode 100644 index 000000000..ad45ce734 --- /dev/null +++ b/sentence_transformers/callbacks/self_guide_callbacks.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import logging + +from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState +from transformers.training_args import TrainingArguments + +logger = logging.getLogger(__name__) + + +class SelfGuideWarmupCallback(TrainerCallback): + """Disables self-guide false-negative filtering during warmup, then enables it. + + This callback holds a reference to a loss module that has a ``self_guide_filtering_active`` + attribute and toggles it from ``False`` to ``True`` once the warmup phase is over. + + Args: + loss: A loss module with a ``self_guide_filtering_active`` boolean attribute. + warmup_ratio: Fraction of total training steps during which filtering is disabled. + Defaults to 0.2 (i.e., filtering starts after the first 20% of training). + """ + + def __init__(self, loss, warmup_ratio: float = 0.2) -> None: + super().__init__() + self.loss = loss + self.warmup_ratio = warmup_ratio + self.warmup_steps: int | None = None + + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ) -> None: + self.warmup_steps = int(state.max_steps * self.warmup_ratio) + self.loss.self_guide_filtering_active = False + + def on_step_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ) -> None: + if self.warmup_steps is None: + return + should_be_active = state.global_step >= self.warmup_steps + if should_be_active != self.loss.self_guide_filtering_active: + self.loss.self_guide_filtering_active = should_be_active + + def on_log( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + logs: dict | None = None, + **kwargs, + ) -> None: + if logs is not None: + logs["self_guide_filtering_active"] = self.loss.self_guide_filtering_active diff --git a/sentence_transformers/losses/CachedGISTEmbedLoss.py b/sentence_transformers/losses/CachedGISTEmbedLoss.py index 1d36b9696..63df0c27c 100644 --- a/sentence_transformers/losses/CachedGISTEmbedLoss.py +++ b/sentence_transformers/losses/CachedGISTEmbedLoss.py @@ -70,7 +70,7 @@ class CachedGISTEmbedLoss(nn.Module): def __init__( self, model: SentenceTransformer, - guide: SentenceTransformer, + guide: SentenceTransformer | None = None, temperature: float = 0.01, mini_batch_size: int = 32, show_progress_bar: bool = False, @@ -79,6 +79,7 @@ def __init__( contrast_anchors: bool = True, contrast_positives: bool = True, gather_across_devices: bool = False, + self_guide_warmup_ratio: float = 0.0, ) -> None: """ This loss is a combination of :class:`GISTEmbedLoss` and :class:`CachedMultipleNegativesRankingLoss`. @@ -100,7 +101,10 @@ def __init__( Args: model: SentenceTransformer model - guide: SentenceTransformer model to guide the in-batch negative sample selection. + guide: SentenceTransformer model to guide the in-batch negative sample selection. If None, the model + itself is used as the guide (self-guided mode), which is more efficient as it requires only a single + forward pass. This is useful for self-guided approaches where the student model's own similarity + scores are used with a margin (e.g., ``margin=-1.0``) to filter negatives. temperature: Temperature parameter to scale the cosine similarities. mini_batch_size: Mini-batch size for the forward pass, this denotes how much memory is actually used during training and evaluation. The larger the mini-batch size, the more memory efficient the training is, but @@ -110,6 +114,7 @@ def __init__( margin_strategy: Strategy used for false negative filtering. One of {"absolute", "relative"}. margin: The margin value for filtering negatives. Defaults to 0.0, together with the "absolute" strategy, this only removes negatives that are more similar to the query than the positive is to the query. + For self-guided mode, a negative margin (e.g., ``margin=-1.0``) is recommended to keep most negatives. contrast_anchors: If True, include anchor-anchor pairs in the loss computation, resulting in the embeddings of the anchors being pushed further apart. Defaults to True, following the original GISTEmbed paper. contrast_positives: If True, include positive-positive pairs in the loss computation, resulting in the embeddings @@ -118,6 +123,10 @@ def __init__( gather_across_devices: If True, gather the embeddings across all devices before computing the loss. Recommended when training on multiple GPUs, as it allows for larger batch sizes, but it may slow down training due to communication overhead, and can potentially lead to out-of-memory errors. + self_guide_warmup_ratio: Fraction of total training steps during which self-guide filtering is + disabled (warmup phase). Only relevant when using self-guided mode (``guide=None``). When set to a + value > 0, the :class:`~sentence_transformers.callbacks.SelfGuideWarmupCallback` is automatically + added by the trainer. Defaults to 0.0 (no warmup, filtering is always active). References: - Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4: https://huggingface.co/papers/1705.00652 @@ -147,7 +156,27 @@ def __init__( - Equivalent to :class:`GISTEmbedLoss`, but with caching that allows for much higher batch sizes Example: - :: + Self-guided mode:: + + from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses + from datasets import Dataset + + model = SentenceTransformer("microsoft/mpnet-base") + train_dataset = Dataset.from_dict({ + "anchor": ["It's nice weather outside today.", "He drove to work."], + "positive": ["It's so sunny.", "He took the car to the office."], + }) + # Self-guided: no guide parameter, model guides itself + loss = losses.CachedGISTEmbedLoss(model, mini_batch_size=64, margin=-1.0) + + trainer = SentenceTransformerTrainer( + model=model, + train_dataset=train_dataset, + loss=loss, + ) + trainer.train() + + With a separate guide model:: from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses from datasets import Dataset @@ -180,24 +209,46 @@ def __init__( "Consider using GISTEmbedLoss instead." ) self.model = model - self.guide = guide self.temperature = temperature self.similarity_fct = nn.CosineSimilarity(dim=-1) - if not hasattr(model, "tokenizer") or not hasattr(guide, "tokenizer"): - raise ValueError("Both the training model and the guiding model must have a tokenizer attribute.") - if not isinstance(model.tokenizer, PreTrainedTokenizerBase) or not isinstance( - guide.tokenizer, PreTrainedTokenizerBase - ): - raise ValueError( - "Both the training model and the guiding model must use a PreTrainedTokenizer from transformers." - ) + + # Self-guided mode: if guide is None, use the model itself as the guide + # This is more efficient as it requires only a single forward pass + if guide is None: + self.guide = model + self.is_self_guided = True + else: + self.guide = guide + self.is_self_guided = model is guide + + # Validate tokenizer requirements + if self.is_self_guided: + if not hasattr(model, "tokenizer"): + raise ValueError("The training model must have a tokenizer attribute.") + if not isinstance(model.tokenizer, PreTrainedTokenizerBase): + raise ValueError("The training model must use a PreTrainedTokenizer from transformers.") + else: + if not hasattr(model, "tokenizer") or not hasattr(self.guide, "tokenizer"): + raise ValueError("Both the training model and the guiding model must have a tokenizer attribute.") + if not isinstance(model.tokenizer, PreTrainedTokenizerBase) or not isinstance( + self.guide.tokenizer, PreTrainedTokenizerBase + ): + raise ValueError( + "Both the training model and the guiding model must use a PreTrainedTokenizer from transformers." + ) + self.mini_batch_size = mini_batch_size self.cache: list[list[Tensor]] | None = None self.random_states: list[list[RandContext]] | None = None self.show_progress_bar = show_progress_bar - self.must_retokenize = ( - model.tokenizer.vocab != guide.tokenizer.vocab or guide.max_seq_length < model.max_seq_length - ) + + # No need to retokenize if self-guided (same model, same tokenizer) + if self.is_self_guided: + self.must_retokenize = False + else: + self.must_retokenize = ( + model.tokenizer.vocab != guide.tokenizer.vocab or guide.max_seq_length < model.max_seq_length + ) if self.must_retokenize: self.tokenizer = model.tokenizer if margin_strategy not in ("absolute", "relative"): @@ -209,6 +260,13 @@ def __init__( self.gather_across_devices = gather_across_devices self.cross_entropy_loss = nn.CrossEntropyLoss() + self.self_guide_warmup_ratio = self_guide_warmup_ratio + if self.is_self_guided and self_guide_warmup_ratio > 0: + # Filtering starts disabled; the SelfGuideWarmupCallback enables it after warmup. + self.self_guide_filtering_active: bool = False + else: + self.self_guide_filtering_active: bool = True + def sim_matrix(self, embed1: Tensor, embed2: Tensor) -> Tensor: return self.similarity_fct(embed1.unsqueeze(1), embed2.unsqueeze(0)) @@ -232,16 +290,21 @@ def embed_minibatch( with grad_context(): random_state = RandContext(*sentence_feature_minibatch.values()) if copy_random_state else None reps = self.model(sentence_feature_minibatch)["sentence_embedding"] # (mbsz, hdim) - with torch.no_grad(): - if self.must_retokenize: - decoded = self.tokenizer.batch_decode( - sentence_feature_minibatch["input_ids"], skip_special_tokens=True - ) - sentence_feature_minibatch = self.guide.tokenize(decoded) - sentence_feature_minibatch = { - key: value.to(self.guide.device) for key, value in sentence_feature_minibatch.items() - } - guide_reps = self.guide(sentence_feature_minibatch)["sentence_embedding"] + + # If self-guided, reuse student embeddings as guide embeddings (no need for second forward pass) + if self.is_self_guided: + guide_reps = reps.detach() + else: + with torch.no_grad(): + if self.must_retokenize: + decoded = self.tokenizer.batch_decode( + sentence_feature_minibatch["input_ids"], skip_special_tokens=True + ) + sentence_feature_minibatch = self.guide.tokenize(decoded) + sentence_feature_minibatch = { + key: value.to(self.guide.device) for key, value in sentence_feature_minibatch.items() + } + guide_reps = self.guide(sentence_feature_minibatch)["sentence_embedding"] return reps, guide_reps, random_state @@ -352,14 +415,20 @@ def mask_false_negatives(guided_sim_mat, sim_mat, positive_mask: Tensor | None = positive_mask = torch.eye(*guided_ap_sim.shape, dtype=torch.bool, device=guided_ap_sim.device) positive_mask = positive_mask.roll(begin) - # Apply false negative suppression to each similarity matrix using guided similarity as anchor - ap_sim = mask_false_negatives(guided_ap_sim, ap_sim, positive_mask=positive_mask) # anchor-positive + # Apply false negative suppression to each similarity matrix using guided similarity as anchor. + # For self-guided mode, respect the warmup: skip filtering while self_guide_filtering_active is False. + # For external guide mode, always filter (self_guide_filtering_active is always True). + apply_filtering = not self.is_self_guided or self.self_guide_filtering_active + + if apply_filtering: + ap_sim = mask_false_negatives(guided_ap_sim, ap_sim, positive_mask=positive_mask) # anchor-positive scores = [ap_sim] if self.contrast_anchors: aa_sim = self.sim_matrix(anchors[begin:end], anchors) guided_aa_sim = self.sim_matrix(anchors_guide[begin:end], anchors_guide) - aa_sim = mask_false_negatives(guided_aa_sim, aa_sim) # anchor-anchor + if apply_filtering: + aa_sim = mask_false_negatives(guided_aa_sim, aa_sim) # anchor-anchor scores.append(aa_sim) if self.contrast_positives: @@ -369,7 +438,8 @@ def mask_false_negatives(guided_sim_mat, sim_mat, positive_mask: Tensor | None = guided_pp_sim = self.sim_matrix( candidates_guide[0][offset + begin : min(offset + end, offset + batch_size)], candidates_guide[0] ) - pp_sim = mask_false_negatives(guided_pp_sim, pp_sim) # positive-positive + if apply_filtering: + pp_sim = mask_false_negatives(guided_pp_sim, pp_sim) # positive-positive scores.append(pp_sim) # If there are negatives (len(candidates) > 1), process them @@ -377,7 +447,8 @@ def mask_false_negatives(guided_sim_mat, sim_mat, positive_mask: Tensor | None = for i in range(1, len(candidates)): # Start from 1 since the first is the positive neg_sim = self.sim_matrix(anchors[begin:end], candidates[i]) guided_neg_sim = self.sim_matrix(anchors_guide[begin:end], candidates_guide[i]) - neg_sim = mask_false_negatives(guided_neg_sim, neg_sim) + if apply_filtering: + neg_sim = mask_false_negatives(guided_neg_sim, neg_sim) scores.append(neg_sim) # anchor-negative # Concatenate all scores into a single tensor diff --git a/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py b/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py index 54e6f20c1..6b120e280 100644 --- a/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py +++ b/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py @@ -85,6 +85,10 @@ def __init__( show_progress_bar: bool = False, hardness_mode: Literal["in_batch_negatives", "hard_negatives", "all_negatives"] | None = None, hardness_strength: float = 0.0, + self_guide: bool = False, + self_guide_margin: float = 0.0, + self_guide_margin_strategy: Literal["absolute", "relative"] = "absolute", + self_guide_warmup_ratio: float = 0.0, ) -> None: """ Boosted version of :class:`MultipleNegativesRankingLoss` (https://huggingface.co/papers/1705.00652) by GradCache (https://huggingface.co/papers/2101.06983). @@ -156,6 +160,21 @@ def __init__( Must be non-negative. Ignored when ``hardness_mode`` is ``None``. + self_guide: If True, enable self-guided false-negative filtering. The model's own similarity scores + are used to detect and suppress likely false negatives in the in-batch negatives before computing + the loss. This can improve training quality when the batch contains semantically similar samples + that are not explicitly paired. Only applied to the ``"query_to_doc"`` direction. + self_guide_margin: Margin for false-negative detection threshold. With ``self_guide_margin_strategy="absolute"``, + a negative with similarity above ``positive_sim - margin`` is considered a false negative. With + ``"relative"``, the threshold is ``positive_sim * (1 - margin)``. Defaults to 0.0. + self_guide_margin_strategy: Strategy for applying the margin. One of ``"absolute"`` or ``"relative"``. + Defaults to ``"absolute"``. + self_guide_warmup_ratio: Fraction of total training steps during which self-guide filtering is + disabled (warmup phase). Requires the :class:`~sentence_transformers.callbacks.SelfGuideWarmupCallback` + to be added to the trainer, which is done automatically by + :class:`~sentence_transformers.SentenceTransformerTrainer`. Defaults to 0.0 (no warmup, filtering + is always active). + References: - Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4: https://huggingface.co/papers/1705.00652 - Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup: https://huggingface.co/papers/2101.06983 @@ -254,6 +273,17 @@ def __init__( "effect. Set hardness_strength to a positive value to enable hardness weighting." ) + self.self_guide = self_guide + if self_guide: + if self_guide_margin_strategy not in ("absolute", "relative"): + raise ValueError("self_guide_margin_strategy must be 'absolute' or 'relative'.") + self.self_guide_margin = self_guide_margin + self.self_guide_margin_strategy = self_guide_margin_strategy + self.self_guide_warmup_ratio = self_guide_warmup_ratio + # If warmup_ratio == 0, filtering is always active (no warmup needed). + # Otherwise, it starts disabled and the SelfGuideWarmupCallback enables it. + self.self_guide_filtering_active: bool = self_guide_warmup_ratio == 0.0 + self.cache: list[list[Tensor]] | None = None self.random_states: list[list[RandContext]] | None = None @@ -382,6 +412,21 @@ def calculate_loss(self, reps: list[list[Tensor]], with_backward: bool = False) same_query_doc_mask = identity[local_batch].repeat(1, num_docs).bool() sim_matrices["doc_to_doc"].masked_fill_(same_query_doc_mask, -torch.inf) + # Self-guide false-negative filtering: use the model's own similarity scores to detect + # and mask likely false negatives. Applied only to query_to_doc before temperature scaling. + if self.self_guide and self.self_guide_filtering_active: + guide_sim = sim_matrices["query_to_doc"].detach() + positive_sim = guide_sim[row_indices, local_batch].unsqueeze(1) # (mbs, 1) + + if self.self_guide_margin_strategy == "absolute": + fn_mask = guide_sim > (positive_sim - self.self_guide_margin) + else: + fn_mask = guide_sim > (positive_sim * (1 - self.self_guide_margin)) + + # Protect true positives from masking + fn_mask[row_indices, local_batch] = False + sim_matrices["query_to_doc"] = sim_matrices["query_to_doc"].masked_fill(fn_mask, -torch.inf) + # Compute hardness penalties on the unscaled (raw cosine) similarities (Lan et al. 2025, Eq. 5). # penalty = alpha * stop_grad(cos_sim), making harder negatives contribute more to the # softmax denominator. Computed before temperature scaling so no rescaling is needed. @@ -476,7 +521,7 @@ def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor return loss def get_config_dict(self) -> dict[str, Any]: - return { + config = { "scale": self.scale, "similarity_fct": self.similarity_fct.__name__, "mini_batch_size": self.mini_batch_size, @@ -486,6 +531,12 @@ def get_config_dict(self) -> dict[str, Any]: "hardness_mode": self.hardness_mode, "hardness_strength": self.hardness_strength, } + if self.self_guide: + config["self_guide"] = self.self_guide + config["self_guide_margin"] = self.self_guide_margin + config["self_guide_margin_strategy"] = self.self_guide_margin_strategy + config["self_guide_warmup_ratio"] = self.self_guide_warmup_ratio + return config @property def temperature(self) -> float: diff --git a/sentence_transformers/losses/GISTEmbedLoss.py b/sentence_transformers/losses/GISTEmbedLoss.py index dc2e06871..aee957673 100644 --- a/sentence_transformers/losses/GISTEmbedLoss.py +++ b/sentence_transformers/losses/GISTEmbedLoss.py @@ -16,13 +16,14 @@ class GISTEmbedLoss(nn.Module): def __init__( self, model: SentenceTransformer, - guide: SentenceTransformer, + guide: SentenceTransformer | None = None, temperature: float = 0.01, margin_strategy: Literal["absolute", "relative"] = "absolute", margin: float = 0.0, contrast_anchors: bool = True, contrast_positives: bool = True, gather_across_devices: bool = False, + self_guide_warmup_ratio: float = 0.0, ) -> None: """ This loss is used to train a SentenceTransformer model using the GISTEmbed algorithm. @@ -38,12 +39,16 @@ def __init__( Args: model: SentenceTransformer model based on a `transformers` model. - guide: SentenceTransformer model to guide the in-batch negative sample selection. + guide: SentenceTransformer model to guide the in-batch negative sample selection. If None, the model + itself is used as the guide (self-guided mode), which is more efficient as it requires only a single + forward pass. This is useful for self-guided approaches where the student model's own similarity + scores are used with a margin (e.g., ``margin=-1.0``) to filter negatives. temperature: Temperature parameter to scale the cosine similarities. Inverse of the ``scale`` parameter in :class:`MultipleNegativesRankingLoss`. margin_strategy: Strategy used for false negative filtering. One of {"absolute", "relative"}. margin: The margin value for filtering negatives. Defaults to 0.0, together with the "absolute" strategy, this only removes negatives that are more similar to the query than the positive is to the query. + For self-guided mode, a negative margin (e.g., ``margin=-1.0``) is recommended to keep most negatives. contrast_anchors: If True, include anchor-anchor pairs in the loss computation, resulting in the embeddings of the anchors being pushed further apart. Defaults to True, following the original GISTEmbed paper. contrast_positives: If True, include positive-positive pairs in the loss computation, resulting in the embeddings @@ -52,6 +57,10 @@ def __init__( gather_across_devices: If True, gather the embeddings across all devices before computing the loss. Recommended when training on multiple GPUs, as it allows for larger batch sizes, but it may slow down training due to communication overhead, and can potentially lead to out-of-memory errors. + self_guide_warmup_ratio: Fraction of total training steps during which self-guide filtering is + disabled (warmup phase). Only relevant when using self-guided mode (``guide=None``). When set to a + value > 0, the :class:`~sentence_transformers.callbacks.SelfGuideWarmupCallback` is automatically + added by the trainer. Defaults to 0.0 (no warmup, filtering is always active). References: - For further details, see: https://huggingface.co/papers/2402.16829 @@ -79,7 +88,27 @@ def __init__( a stronger training signal at the cost of some training overhead. Example: - :: + Self-guided mode:: + + from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses + from datasets import Dataset + + model = SentenceTransformer("microsoft/mpnet-base") + train_dataset = Dataset.from_dict({ + "anchor": ["It's nice weather outside today.", "He drove to work."], + "positive": ["It's so sunny.", "He took the car to the office."], + }) + # Self-guided: no guide parameter, model guides itself + loss = losses.GISTEmbedLoss(model, margin=-1.0) + + trainer = SentenceTransformerTrainer( + model=model, + train_dataset=train_dataset, + loss=loss, + ) + trainer.train() + + With a separate guide model:: from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses from datasets import Dataset @@ -101,20 +130,37 @@ def __init__( """ super().__init__() self.model = model - self.guide = guide self.temperature = temperature self.similarity_fct = nn.CosineSimilarity(dim=-1) - if not hasattr(model, "tokenizer") or not hasattr(guide, "tokenizer"): - raise ValueError("Both the training model and the guiding model must have a tokenizer attribute.") - if not isinstance(model.tokenizer, PreTrainedTokenizerBase) or not isinstance( - guide.tokenizer, PreTrainedTokenizerBase - ): - raise ValueError( - "Both the training model and the guiding model must use a PreTrainedTokenizer from transformers." + + # Self-guided mode: if guide is None, use the model itself as the guide + if guide is None: + self.guide = model + self.is_self_guided = True + else: + self.guide = guide + self.is_self_guided = model is guide + + # Validate tokenizer requirements + if self.is_self_guided: + if not hasattr(model, "tokenizer"): + raise ValueError("The training model must have a tokenizer attribute.") + if not isinstance(model.tokenizer, PreTrainedTokenizerBase): + raise ValueError("The training model must use a PreTrainedTokenizer from transformers.") + self.must_retokenize = False + else: + if not hasattr(model, "tokenizer") or not hasattr(self.guide, "tokenizer"): + raise ValueError("Both the training model and the guiding model must have a tokenizer attribute.") + if not isinstance(model.tokenizer, PreTrainedTokenizerBase) or not isinstance( + self.guide.tokenizer, PreTrainedTokenizerBase + ): + raise ValueError( + "Both the training model and the guiding model must use a PreTrainedTokenizer from transformers." + ) + self.must_retokenize = ( + model.tokenizer.get_vocab() != self.guide.tokenizer.get_vocab() + or self.guide.max_seq_length < model.max_seq_length ) - self.must_retokenize = ( - model.tokenizer.get_vocab() != guide.tokenizer.get_vocab() or guide.max_seq_length < model.max_seq_length - ) if self.must_retokenize: self.tokenizer = self.model.tokenizer @@ -133,26 +179,38 @@ def __init__( self.gather_across_devices = gather_across_devices self.cross_entropy_loss = nn.CrossEntropyLoss() + self.self_guide_warmup_ratio = self_guide_warmup_ratio + if self.is_self_guided and self_guide_warmup_ratio > 0: + # Filtering starts disabled; the SelfGuideWarmupCallback enables it after warmup. + self.self_guide_filtering_active: bool = False + else: + self.self_guide_filtering_active: bool = True + def sim_matrix(self, embed1: Tensor, embed2: Tensor) -> Tensor: return self.similarity_fct(embed1.unsqueeze(1), embed2.unsqueeze(0)) def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor: embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features] - with torch.no_grad(): - if self.must_retokenize: - decoded = [ - self.tokenizer.batch_decode(sentence_feature["input_ids"], skip_special_tokens=True) - for sentence_feature in sentence_features - ] - sentence_features = [self.guide.tokenize(sentences) for sentences in decoded] - sentence_features = [ - {key: value.to(self.guide.device) for key, value in sentence_feature.items()} - for sentence_feature in sentence_features - ] - guide_embeddings = [ - self.guide(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features - ] + # If self-guided, reuse student embeddings as guide embeddings (no second forward pass) + if self.is_self_guided: + guide_embeddings = [emb.detach() for emb in embeddings] + else: + with torch.no_grad(): + if self.must_retokenize: + decoded = [ + self.tokenizer.batch_decode(sentence_feature["input_ids"], skip_special_tokens=True) + for sentence_feature in sentence_features + ] + sentence_features = [self.guide.tokenize(sentences) for sentences in decoded] + sentence_features = [ + {key: value.to(self.guide.device) for key, value in sentence_feature.items()} + for sentence_feature in sentence_features + ] + + guide_embeddings = [ + self.guide(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features + ] negative = None negative_guide = None @@ -208,27 +266,35 @@ def mask_false_negatives(guided_sim_mat, sim_mat, positive_mask: Tensor | None = # Create a mask to protect true positive pairs in the anchor-positive matrix (i.e., diagonal elements) positive_mask = torch.eye(*guided_ap_sim.shape, dtype=torch.bool, device=guided_ap_sim.device) - # Apply false negative suppression to each similarity matrix using guided similarity as anchor - ap_sim = mask_false_negatives(guided_ap_sim, ap_sim, positive_mask=positive_mask) # anchor-positive + # Apply false negative suppression to each similarity matrix using guided similarity as anchor. + # For self-guided mode, respect the warmup: skip filtering while self_guide_filtering_active is False. + # For external guide mode, always filter (self_guide_filtering_active is always True). + apply_filtering = not self.is_self_guided or self.self_guide_filtering_active + + if apply_filtering: + ap_sim = mask_false_negatives(guided_ap_sim, ap_sim, positive_mask=positive_mask) # anchor-positive scores = [ap_sim] if self.contrast_anchors: aa_sim = self.sim_matrix(anchor, anchor) guided_aa_sim = self.sim_matrix(anchor_guide, anchor_guide) - aa_sim = mask_false_negatives(guided_aa_sim, aa_sim) # anchor-anchor + if apply_filtering: + aa_sim = mask_false_negatives(guided_aa_sim, aa_sim) # anchor-anchor scores.append(aa_sim) if self.contrast_positives: pp_sim = self.sim_matrix(positive[offset : offset + batch_size], positive) guided_pp_sim = self.sim_matrix(positive_guide[offset : offset + batch_size], positive_guide) - pp_sim = mask_false_negatives(guided_pp_sim, pp_sim) # positive-positive + if apply_filtering: + pp_sim = mask_false_negatives(guided_pp_sim, pp_sim) # positive-positive scores.append(pp_sim) # Handle the case where we have a negative sample if negative is not None: an_sim = self.sim_matrix(anchor, negative) guided_an_sim = self.sim_matrix(anchor_guide, negative_guide) - an_sim = mask_false_negatives(guided_an_sim, an_sim) # anchor-negative + if apply_filtering: + an_sim = mask_false_negatives(guided_an_sim, an_sim) # anchor-negative scores.append(an_sim) scores = torch.cat(scores, dim=1) / self.temperature diff --git a/sentence_transformers/losses/MultipleNegativesRankingLoss.py b/sentence_transformers/losses/MultipleNegativesRankingLoss.py index dd5f21883..063255cca 100644 --- a/sentence_transformers/losses/MultipleNegativesRankingLoss.py +++ b/sentence_transformers/losses/MultipleNegativesRankingLoss.py @@ -28,6 +28,10 @@ def __init__( partition_mode: Literal["joint", "per_direction"] = "joint", hardness_mode: Literal["in_batch_negatives", "hard_negatives", "all_negatives"] | None = None, hardness_strength: float = 0.0, + self_guide: bool = False, + self_guide_margin: float = 0.0, + self_guide_margin_strategy: Literal["absolute", "relative"] = "absolute", + self_guide_warmup_ratio: float = 0.0, ) -> None: """ Given a dataset of (anchor, positive) pairs, (anchor, positive, negative) triplets, or (anchor, positive, negative_1, ..., negative_n) @@ -105,6 +109,21 @@ def __init__( Must be non-negative. Ignored when ``hardness_mode`` is ``None``. + self_guide: If True, enable self-guided false-negative filtering. The model's own similarity scores + are used to detect and suppress likely false negatives in the in-batch negatives before computing + the loss. This can improve training quality when the batch contains semantically similar samples + that are not explicitly paired. Only applied to the ``"query_to_doc"`` direction. + self_guide_margin: Margin for false-negative detection threshold. With ``self_guide_margin_strategy="absolute"``, + a negative with similarity above ``positive_sim - margin`` is considered a false negative. With + ``"relative"``, the threshold is ``positive_sim * (1 - margin)``. Defaults to 0.0. + self_guide_margin_strategy: Strategy for applying the margin. One of ``"absolute"`` or ``"relative"``. + Defaults to ``"absolute"``. + self_guide_warmup_ratio: Fraction of total training steps during which self-guide filtering is + disabled (warmup phase). Requires the :class:`~sentence_transformers.callbacks.SelfGuideWarmupCallback` + to be added to the trainer, which is done automatically by + :class:`~sentence_transformers.SentenceTransformerTrainer`. Defaults to 0.0 (no warmup, filtering + is always active). + Requirements: 1. (anchor, positive) pairs, (anchor, positive, negative) triplets, or (anchor, positive, negative_1, ..., negative_n) n-tuples @@ -227,6 +246,17 @@ def __init__( "effect. Set hardness_strength to a positive value to enable hardness weighting." ) + self.self_guide = self_guide + if self_guide: + if self_guide_margin_strategy not in ("absolute", "relative"): + raise ValueError("self_guide_margin_strategy must be 'absolute' or 'relative'.") + self.self_guide_margin = self_guide_margin + self.self_guide_margin_strategy = self_guide_margin_strategy + self.self_guide_warmup_ratio = self_guide_warmup_ratio + # If warmup_ratio == 0, filtering is always active (no warmup needed). + # Otherwise, it starts disabled and the SelfGuideWarmupCallback enables it. + self.self_guide_filtering_active: bool = self_guide_warmup_ratio == 0.0 + def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor: # Compute the embeddings and distribute them to anchor and candidates (positive and optionally negatives) embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features] @@ -282,6 +312,22 @@ def compute_loss_from_embeddings(self, embeddings: list[Tensor], labels: Tensor) same_query_doc_mask = same_query_doc_mask.repeat(1, len(docs)).bool() sim_matrices["doc_to_doc"].masked_fill_(same_query_doc_mask, -torch.inf) + # Self-guide false-negative filtering: use the model's own similarity scores to detect + # and mask likely false negatives (negatives that are more similar to the query than the + # positive). Applied only to the query_to_doc direction before temperature scaling. + if self.self_guide and self.self_guide_filtering_active: + guide_sim = sim_matrices["query_to_doc"].detach() + positive_sim = guide_sim[row_indices, local_indices].unsqueeze(1) # (bs, 1) + + if self.self_guide_margin_strategy == "absolute": + fn_mask = guide_sim > (positive_sim - self.self_guide_margin) + else: + fn_mask = guide_sim > (positive_sim * (1 - self.self_guide_margin)) + + # Protect true positives from masking + fn_mask[row_indices, local_indices] = False + sim_matrices["query_to_doc"] = sim_matrices["query_to_doc"].masked_fill(fn_mask, -torch.inf) + # Compute hardness penalties on the unscaled (raw cosine) similarities (Lan et al. 2025, Eq. 5). # penalty = alpha * stop_grad(cos_sim), making harder negatives contribute more to the # softmax denominator. Computed before temperature scaling so no rescaling is needed. @@ -337,7 +383,7 @@ def compute_loss_from_embeddings(self, embeddings: list[Tensor], labels: Tensor) return loss def get_config_dict(self) -> dict[str, Any]: - return { + config = { "scale": self.scale, "similarity_fct": self.similarity_fct.__name__, "gather_across_devices": self.gather_across_devices, @@ -346,6 +392,12 @@ def get_config_dict(self) -> dict[str, Any]: "hardness_mode": self.hardness_mode, "hardness_strength": self.hardness_strength, } + if self.self_guide: + config["self_guide"] = self.self_guide + config["self_guide_margin"] = self.self_guide_margin + config["self_guide_margin_strategy"] = self.self_guide_margin_strategy + config["self_guide_warmup_ratio"] = self.self_guide_warmup_ratio + return config @property def temperature(self) -> float: diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py index 6e12c2343..d07ab519c 100644 --- a/sentence_transformers/trainer.py +++ b/sentence_transformers/trainer.py @@ -388,8 +388,23 @@ def prepare_loss( model: SentenceTransformer, ) -> torch.nn.Module: if isinstance(loss, torch.nn.Module): - return loss.to(model.device) - return loss(model).to(model.device) + loss = loss.to(model.device) + else: + loss = loss(model).to(model.device) + + # Auto-add SelfGuideWarmupCallback if the loss has self_guide_warmup_ratio > 0 + warmup_ratio = getattr(loss, "self_guide_warmup_ratio", 0.0) + if warmup_ratio and warmup_ratio > 0: + from sentence_transformers.callbacks import SelfGuideWarmupCallback + + has_callback = any( + isinstance(cb, SelfGuideWarmupCallback) and cb.loss is loss for cb in self.callback_handler.callbacks + ) + if not has_callback: + callback = SelfGuideWarmupCallback(loss=loss, warmup_ratio=warmup_ratio) + self.add_callback(callback) + + return loss def compute_loss( self,