Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sentence_transformers/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

from sentence_transformers.callbacks.self_guide_callbacks import SelfGuideWarmupCallback

__all__ = ["SelfGuideWarmupCallback"]
61 changes: 61 additions & 0 deletions sentence_transformers/callbacks/self_guide_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from __future__ import annotations

import logging

from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments

logger = logging.getLogger(__name__)


class SelfGuideWarmupCallback(TrainerCallback):
    """Disables self-guide false-negative filtering during warmup, then enables it.

    This callback holds a reference to a loss module that has a ``self_guide_filtering_active``
    attribute and toggles it from ``False`` to ``True`` once the warmup phase is over.

    Args:
        loss: A loss module with a ``self_guide_filtering_active`` boolean attribute.
        warmup_ratio: Fraction of total training steps during which filtering is disabled.
            Defaults to 0.2 (i.e., filtering starts after the first 20% of training).

    Raises:
        ValueError: If ``warmup_ratio`` is outside the ``[0, 1]`` range.
    """

    def __init__(self, loss, warmup_ratio: float = 0.2) -> None:
        super().__init__()
        # An out-of-range ratio would silently disable (or never enable) filtering;
        # fail fast instead of producing a confusing training run.
        if not 0.0 <= warmup_ratio <= 1.0:
            raise ValueError(f"warmup_ratio must be in the range [0, 1], but got {warmup_ratio}")
        self.loss = loss
        self.warmup_ratio = warmup_ratio
        # Resolved in on_train_begin, once the trainer knows the total number of steps.
        self.warmup_steps: int | None = None

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> None:
        # state.max_steps is resolved by the Trainer before this event fires; guard
        # against non-positive values (e.g. a misconfigured dry run) all the same.
        self.warmup_steps = max(0, int(state.max_steps * self.warmup_ratio))
        # Start with filtering disabled; on_step_begin re-enables it immediately
        # when warmup_steps == 0.
        self.loss.self_guide_filtering_active = False

    def on_step_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> None:
        if self.warmup_steps is None:
            # on_train_begin has not fired yet; nothing to toggle.
            return
        should_be_active = state.global_step >= self.warmup_steps
        if should_be_active != self.loss.self_guide_filtering_active:
            self.loss.self_guide_filtering_active = should_be_active
            if should_be_active:
                logger.info(
                    "Self-guide warmup finished after %d steps; false-negative filtering is now active.",
                    self.warmup_steps,
                )

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: dict | None = None,
        **kwargs,
    ) -> None:
        # Surface the current filtering state in the training logs for easier debugging.
        if logs is not None:
            logs["self_guide_filtering_active"] = self.loss.self_guide_filtering_active
131 changes: 101 additions & 30 deletions sentence_transformers/losses/CachedGISTEmbedLoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class CachedGISTEmbedLoss(nn.Module):
def __init__(
self,
model: SentenceTransformer,
guide: SentenceTransformer,
guide: SentenceTransformer | None = None,
temperature: float = 0.01,
mini_batch_size: int = 32,
show_progress_bar: bool = False,
Expand All @@ -79,6 +79,7 @@ def __init__(
contrast_anchors: bool = True,
contrast_positives: bool = True,
gather_across_devices: bool = False,
self_guide_warmup_ratio: float = 0.0,
) -> None:
"""
This loss is a combination of :class:`GISTEmbedLoss` and :class:`CachedMultipleNegativesRankingLoss`.
Expand All @@ -100,7 +101,10 @@ def __init__(

Args:
model: SentenceTransformer model
guide: SentenceTransformer model to guide the in-batch negative sample selection.
guide: SentenceTransformer model to guide the in-batch negative sample selection. If None, the model
itself is used as the guide (self-guided mode), which is more efficient as it requires only a single
forward pass. This is useful for self-guided approaches where the student model's own similarity
scores are used with a margin (e.g., ``margin=-1.0``) to filter negatives.
temperature: Temperature parameter to scale the cosine similarities.
mini_batch_size: Mini-batch size for the forward pass, this denotes how much memory is actually used during
training and evaluation. The larger the mini-batch size, the more memory efficient the training is, but
Expand All @@ -110,6 +114,7 @@ def __init__(
margin_strategy: Strategy used for false negative filtering. One of {"absolute", "relative"}.
margin: The margin value for filtering negatives. Defaults to 0.0, together with the "absolute" strategy,
this only removes negatives that are more similar to the query than the positive is to the query.
For self-guided mode, a negative margin (e.g., ``margin=-1.0``) is recommended to keep most negatives.
contrast_anchors: If True, include anchor-anchor pairs in the loss computation, resulting in the embeddings
of the anchors being pushed further apart. Defaults to True, following the original GISTEmbed paper.
contrast_positives: If True, include positive-positive pairs in the loss computation, resulting in the embeddings
Expand All @@ -118,6 +123,10 @@ def __init__(
gather_across_devices: If True, gather the embeddings across all devices before computing the loss.
Recommended when training on multiple GPUs, as it allows for larger batch sizes, but it may slow down
training due to communication overhead, and can potentially lead to out-of-memory errors.
self_guide_warmup_ratio: Fraction of total training steps during which self-guide filtering is
disabled (warmup phase). Only relevant when using self-guided mode (``guide=None``). When set to a
value > 0, the :class:`~sentence_transformers.callbacks.SelfGuideWarmupCallback` is automatically
added by the trainer. Defaults to 0.0 (no warmup, filtering is always active).

References:
- Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4: https://huggingface.co/papers/1705.00652
Expand Down Expand Up @@ -147,7 +156,27 @@ def __init__(
- Equivalent to :class:`GISTEmbedLoss`, but with caching that allows for much higher batch sizes

Example:
::
Self-guided mode::

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
from datasets import Dataset

model = SentenceTransformer("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
"anchor": ["It's nice weather outside today.", "He drove to work."],
"positive": ["It's so sunny.", "He took the car to the office."],
})
# Self-guided: no guide parameter, model guides itself
loss = losses.CachedGISTEmbedLoss(model, mini_batch_size=64, margin=-1.0)

trainer = SentenceTransformerTrainer(
model=model,
train_dataset=train_dataset,
loss=loss,
)
trainer.train()

With a separate guide model::

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
from datasets import Dataset
Expand Down Expand Up @@ -180,24 +209,46 @@ def __init__(
"Consider using GISTEmbedLoss instead."
)
self.model = model
self.guide = guide
self.temperature = temperature
self.similarity_fct = nn.CosineSimilarity(dim=-1)
if not hasattr(model, "tokenizer") or not hasattr(guide, "tokenizer"):
raise ValueError("Both the training model and the guiding model must have a tokenizer attribute.")
if not isinstance(model.tokenizer, PreTrainedTokenizerBase) or not isinstance(
guide.tokenizer, PreTrainedTokenizerBase
):
raise ValueError(
"Both the training model and the guiding model must use a PreTrainedTokenizer from transformers."
)

# Self-guided mode: if guide is None, use the model itself as the guide
# This is more efficient as it requires only a single forward pass
if guide is None:
self.guide = model
self.is_self_guided = True
else:
self.guide = guide
self.is_self_guided = model is guide

# Validate tokenizer requirements
if self.is_self_guided:
if not hasattr(model, "tokenizer"):
raise ValueError("The training model must have a tokenizer attribute.")
if not isinstance(model.tokenizer, PreTrainedTokenizerBase):
raise ValueError("The training model must use a PreTrainedTokenizer from transformers.")
else:
if not hasattr(model, "tokenizer") or not hasattr(self.guide, "tokenizer"):
raise ValueError("Both the training model and the guiding model must have a tokenizer attribute.")
if not isinstance(model.tokenizer, PreTrainedTokenizerBase) or not isinstance(
self.guide.tokenizer, PreTrainedTokenizerBase
):
raise ValueError(
"Both the training model and the guiding model must use a PreTrainedTokenizer from transformers."
)

self.mini_batch_size = mini_batch_size
self.cache: list[list[Tensor]] | None = None
self.random_states: list[list[RandContext]] | None = None
self.show_progress_bar = show_progress_bar
self.must_retokenize = (
model.tokenizer.vocab != guide.tokenizer.vocab or guide.max_seq_length < model.max_seq_length
)

# No need to retokenize if self-guided (same model, same tokenizer)
if self.is_self_guided:
self.must_retokenize = False
else:
self.must_retokenize = (
model.tokenizer.vocab != guide.tokenizer.vocab or guide.max_seq_length < model.max_seq_length
)
if self.must_retokenize:
self.tokenizer = model.tokenizer
if margin_strategy not in ("absolute", "relative"):
Expand All @@ -209,6 +260,13 @@ def __init__(
self.gather_across_devices = gather_across_devices
self.cross_entropy_loss = nn.CrossEntropyLoss()

self.self_guide_warmup_ratio = self_guide_warmup_ratio
if self.is_self_guided and self_guide_warmup_ratio > 0:
# Filtering starts disabled; the SelfGuideWarmupCallback enables it after warmup.
self.self_guide_filtering_active: bool = False
else:
self.self_guide_filtering_active: bool = True

def sim_matrix(self, embed1: Tensor, embed2: Tensor) -> Tensor:
    """Return the pairwise cosine-similarity matrix between two batches of embeddings.

    Broadcasting ``(n, 1, d)`` against ``(m, d)`` yields an ``(n, m)`` matrix where
    entry ``(i, j)`` is the similarity between ``embed1[i]`` and ``embed2[j]``.
    """
    rows = embed1.unsqueeze(1)
    cols = embed2.unsqueeze(0)
    return self.similarity_fct(rows, cols)

Expand All @@ -232,16 +290,21 @@ def embed_minibatch(
with grad_context():
random_state = RandContext(*sentence_feature_minibatch.values()) if copy_random_state else None
reps = self.model(sentence_feature_minibatch)["sentence_embedding"] # (mbsz, hdim)
with torch.no_grad():
if self.must_retokenize:
decoded = self.tokenizer.batch_decode(
sentence_feature_minibatch["input_ids"], skip_special_tokens=True
)
sentence_feature_minibatch = self.guide.tokenize(decoded)
sentence_feature_minibatch = {
key: value.to(self.guide.device) for key, value in sentence_feature_minibatch.items()
}
guide_reps = self.guide(sentence_feature_minibatch)["sentence_embedding"]

# If self-guided, reuse student embeddings as guide embeddings (no need for second forward pass)
if self.is_self_guided:
guide_reps = reps.detach()
else:
with torch.no_grad():
if self.must_retokenize:
decoded = self.tokenizer.batch_decode(
sentence_feature_minibatch["input_ids"], skip_special_tokens=True
)
sentence_feature_minibatch = self.guide.tokenize(decoded)
sentence_feature_minibatch = {
key: value.to(self.guide.device) for key, value in sentence_feature_minibatch.items()
}
guide_reps = self.guide(sentence_feature_minibatch)["sentence_embedding"]

return reps, guide_reps, random_state

Expand Down Expand Up @@ -352,14 +415,20 @@ def mask_false_negatives(guided_sim_mat, sim_mat, positive_mask: Tensor | None =
positive_mask = torch.eye(*guided_ap_sim.shape, dtype=torch.bool, device=guided_ap_sim.device)
positive_mask = positive_mask.roll(begin)

# Apply false negative suppression to each similarity matrix using guided similarity as anchor
ap_sim = mask_false_negatives(guided_ap_sim, ap_sim, positive_mask=positive_mask) # anchor-positive
# Apply false negative suppression to each similarity matrix using guided similarity as anchor.
# For self-guided mode, respect the warmup: skip filtering while self_guide_filtering_active is False.
# For external guide mode, always filter (self_guide_filtering_active is always True).
apply_filtering = not self.is_self_guided or self.self_guide_filtering_active

if apply_filtering:
ap_sim = mask_false_negatives(guided_ap_sim, ap_sim, positive_mask=positive_mask) # anchor-positive
scores = [ap_sim]

if self.contrast_anchors:
aa_sim = self.sim_matrix(anchors[begin:end], anchors)
guided_aa_sim = self.sim_matrix(anchors_guide[begin:end], anchors_guide)
aa_sim = mask_false_negatives(guided_aa_sim, aa_sim) # anchor-anchor
if apply_filtering:
aa_sim = mask_false_negatives(guided_aa_sim, aa_sim) # anchor-anchor
scores.append(aa_sim)

if self.contrast_positives:
Expand All @@ -369,15 +438,17 @@ def mask_false_negatives(guided_sim_mat, sim_mat, positive_mask: Tensor | None =
guided_pp_sim = self.sim_matrix(
candidates_guide[0][offset + begin : min(offset + end, offset + batch_size)], candidates_guide[0]
)
pp_sim = mask_false_negatives(guided_pp_sim, pp_sim) # positive-positive
if apply_filtering:
pp_sim = mask_false_negatives(guided_pp_sim, pp_sim) # positive-positive
scores.append(pp_sim)

# If there are negatives (len(candidates) > 1), process them
if len(candidates) > 1:
for i in range(1, len(candidates)): # Start from 1 since the first is the positive
neg_sim = self.sim_matrix(anchors[begin:end], candidates[i])
guided_neg_sim = self.sim_matrix(anchors_guide[begin:end], candidates_guide[i])
neg_sim = mask_false_negatives(guided_neg_sim, neg_sim)
if apply_filtering:
neg_sim = mask_false_negatives(guided_neg_sim, neg_sim)
scores.append(neg_sim) # anchor-negative

# Concatenate all scores into a single tensor
Expand Down
Loading