
[feat] Add EmbedDistillLoss #3665

Open

yjoonjang wants to merge 7 commits into huggingface:main from yjoonjang:feat/add_embeddistill_loss

Conversation

@yjoonjang (Contributor) commented Feb 18, 2026

Hello, @tomaarsen !

Pull Request overview

  • Introduce EmbedDistillLoss, an embedding-level knowledge distillation loss based on the EmbedDistill paper.

Details

This PR adds a new loss function EmbedDistillLoss that minimizes the distance between student and teacher model embeddings. Unlike score-based distillation (e.g., MarginMSELoss, DistillKLDivLoss), this approach directly aligns the embedding spaces, which provides a stronger learning signal for the student model following the geometric distillation approach from Kim et al., 2023.

Key features

Two operating modes:

  • On-the-fly mode: Teacher embeddings are computed during training (convenient, requires more GPU memory)
  • Pre-computed mode: Teacher embeddings are pre-computed and passed as labels (efficient for large-scale training)

Three distance metrics:

  • cosine (default) — cosine distance: 1 - cos_sim(student, teacher)
  • l2 — Euclidean distance: ||student - teacher||_2
  • mse — mean squared error
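For concreteness, here is a minimal sketch of how these three metrics could be computed over a batch of paired embeddings (the function and argument names are assumed for illustration, not the final implementation in the PR):

```python
import torch
import torch.nn.functional as F


def embed_distill_distance(student: torch.Tensor, teacher: torch.Tensor,
                           metric: str = "cosine") -> torch.Tensor:
    """Sketch of the three distance metrics: (batch, dim) tensors in, scalar loss out."""
    if metric == "cosine":
        # 1 - cos_sim per pair, averaged over the batch
        return (1.0 - F.cosine_similarity(student, teacher, dim=-1)).mean()
    if metric == "l2":
        # Euclidean distance per pair, averaged over the batch
        return torch.linalg.vector_norm(student - teacher, dim=-1).mean()
    if metric == "mse":
        # mean squared error over all embedding elements
        return F.mse_loss(student, teacher)
    raise ValueError(f"Unknown metric: {metric}")
```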

Learnable projection layer:

  • When student and teacher have different embedding dimensions, a linear projection ψ(z) = Wz + b maps student embeddings to the teacher dimension (the Jina Embeddings v5 paper reports that this projection direction, student → teacher, works better)
  • save_projection() / load_projection() methods for persisting the learned projection weights
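A minimal sketch of such a projection module (the class name is hypothetical; in the PR this logic lives inside the loss itself):

```python
import torch
from torch import nn


class StudentToTeacherProjection(nn.Module):
    """Learnable projection psi(z) = Wz + b from the student's embedding
    dimension into the teacher's (e.g. 768 -> 1024)."""

    def __init__(self, student_dim: int, teacher_dim: int):
        super().__init__()
        self.linear = nn.Linear(student_dim, teacher_dim)

    def forward(self, student_embeddings: torch.Tensor) -> torch.Tensor:
        return self.linear(student_embeddings)

    def save(self, path: str) -> None:
        # Persist only the projection weights, mirroring save_projection()
        torch.save(self.state_dict(), path)

    def load(self, path: str) -> None:
        # Restore previously trained projection weights, mirroring load_projection()
        self.load_state_dict(torch.load(path))
```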

Automatic retokenization:

  • Handles tokenizer mismatches between student and teacher models (e.g., different vocabularies or max sequence lengths)

Why cosine as default?

The original EmbedDistill paper uses L2 distance, but recent work on embedding distillation such as Jina Embeddings v5 reports strong results with cosine-based alignment. My experiments below confirm that cosine distance significantly outperforms L2 and MSE for this task.

Experiment setup

  • Student: microsoft/mpnet-base (768d, with projection layer)
  • Teacher: BAAI/bge-m3 (1024d)
  • Dataset: sentence-transformers/all-nli (triplet), 100K training samples
  • Evaluation: TripletEvaluator on all-nli test set (cosine accuracy)
  • Training: 2 epochs, batch size 128, lr 2e-5, warmup 10%, bf16
  • wandb: https://wandb.ai/yjoonjang/embed_distill_test
Training test code details
"""
EmbedDistillLoss vs MultipleNegativesRankingLoss comparison test.

Usage:
	# Single GPU
	python test/test_embed_distill_loss.py

	# Multi GPU
	torchrun --nproc_per_node=2 test/test_embed_distill_loss.py
"""

import os
import time
import threading

import torch
import wandb
from datasets import load_dataset
from transformers.integrations import WandbCallback

from sentence_transformers import (
	SentenceTransformer,
	SentenceTransformerTrainer,
	SentenceTransformerTrainingArguments,
	losses,
)
import custom_ls
from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.training_args import BatchSamplers


class WandbCallbackSafeFinish(WandbCallback):
	"""WandbCallback that finishes wandb run with a timeout to avoid hanging."""

	FINISH_TIMEOUT = 30

	def on_train_end(self, args, state, control, **kwargs):
		if self._wandb.run is not None:
			thread = threading.Thread(target=self._wandb.finish)
			thread.daemon = True
			thread.start()
			thread.join(timeout=self.FINISH_TIMEOUT)
			# Force-clear run reference so next trainer creates a fresh run
			if self._wandb.run is not None:
				self._wandb.run = None

# ──────────────────────────────────────────────
# Config
# ──────────────────────────────────────────────
os.environ["WANDB_PROJECT"] = "embed_distill_test"
STUDENT_MODEL = "microsoft/mpnet-base"
TEACHER_MODEL = "BAAI/bge-m3"
NUM_TRAIN_SAMPLES = 100_000
NUM_EPOCHS = 2
BATCH_SIZE = 128
LEARNING_RATE = 2e-5
OUTPUT_BASE = "./output2/embed_distill_test"

STUDENT_MODEL_KWARGS = {}
TEACHER_MODEL_KWARGS = {}


def get_training_args(output_dir: str, run_name: str) -> SentenceTransformerTrainingArguments:
	return SentenceTransformerTrainingArguments(
		output_dir=output_dir,
		num_train_epochs=NUM_EPOCHS,
		per_device_train_batch_size=BATCH_SIZE,
		per_device_eval_batch_size=64,
		learning_rate=LEARNING_RATE,
		warmup_ratio=0.1,
		fp16=False,
		bf16=True,
		batch_sampler=BatchSamplers.NO_DUPLICATES,
		eval_strategy="steps",
		eval_steps=0.2,
		eval_on_start=True,
		save_strategy="no",
		logging_first_step=True,
		logging_steps=100,
		report_to="wandb",
		run_name=run_name,
	)


def swap_wandb_callback(trainer: SentenceTransformerTrainer) -> None:
	"""Replace the default WandbCallback with one that finishes with a timeout."""
	trainer.remove_callback(WandbCallback)
	trainer.add_callback(WandbCallbackSafeFinish())


def main():
	# ──────────────────────────────────────────────
	# Data
	# ──────────────────────────────────────────────
	print("Loading dataset...")
	dataset = load_dataset("sentence-transformers/all-nli", "triplet")
	train_dataset = dataset["train"].select(range(NUM_TRAIN_SAMPLES))
	eval_dataset = dataset["dev"]
	test_dataset = dataset["test"]

	print(f"Train: {len(train_dataset)}, Dev: {len(eval_dataset)}, Test: {len(test_dataset)}")

	dev_evaluator = TripletEvaluator(
		anchors=eval_dataset["anchor"],
		positives=eval_dataset["positive"],
		negatives=eval_dataset["negative"],
		name="all-nli-dev",
	)
	test_evaluator = TripletEvaluator(
		anchors=test_dataset["anchor"],
		positives=test_dataset["positive"],
		negatives=test_dataset["negative"],
		name="all-nli-test",
	)

	results = {}

	# ──────────────────────────────────────────────
	# Experiment 1: MNRL Baseline
	# ──────────────────────────────────────────────
	print("\n" + "=" * 60)
	print("Experiment 1: MultipleNegativesRankingLoss (Baseline)")
	print("=" * 60)

	model_1 = SentenceTransformer(STUDENT_MODEL, model_kwargs=STUDENT_MODEL_KWARGS)
	loss_1 = losses.MultipleNegativesRankingLoss(model_1)

	trainer_1 = SentenceTransformerTrainer(
		model=model_1,
		args=get_training_args(f"{OUTPUT_BASE}/exp1_mnrl", run_name="MNRL"),
		train_dataset=train_dataset,
		eval_dataset=eval_dataset,
		loss=loss_1,
		evaluator=dev_evaluator,
	)
	swap_wandb_callback(trainer_1)

	start = time.time()
	trainer_1.train()
	time_1 = time.time() - start

	test_result_1 = test_evaluator(model_1)
	acc_1 = test_result_1["all-nli-test_cosine_accuracy"]
	results["MNRL"] = {"accuracy": acc_1, "time": time_1}
	print(f"\nTest accuracy: {acc_1:.4f}")
	print(f"Training time: {time_1:.1f}s")

	# ──────────────────────────────────────────────
	# Experiment 2: EmbedDistillLoss (on-the-fly, L2)
	# ──────────────────────────────────────────────
	print("\n" + "=" * 60)
	print("Experiment 2: EmbedDistillLoss (on-the-fly, L2)")
	print("=" * 60)

	model_2 = SentenceTransformer(STUDENT_MODEL, model_kwargs=STUDENT_MODEL_KWARGS)
	teacher_model = SentenceTransformer(TEACHER_MODEL, model_kwargs=TEACHER_MODEL_KWARGS)
	teacher_model.max_seq_length = 512
	loss_2 = custom_ls.EmbedDistillLoss(model_2, teacher_model=teacher_model, distance_metric="l2", add_projection_layer=True)

	trainer_2 = SentenceTransformerTrainer(
		model=model_2,
		args=get_training_args(f"{OUTPUT_BASE}/exp2_embeddistill_l2", run_name="EmbedDistill-L2"),
		train_dataset=train_dataset,
		eval_dataset=eval_dataset,
		loss=loss_2,
		evaluator=dev_evaluator,
	)
	swap_wandb_callback(trainer_2)

	start = time.time()
	trainer_2.train()
	time_2 = time.time() - start

	test_result_2 = test_evaluator(model_2)
	acc_2 = test_result_2["all-nli-test_cosine_accuracy"]
	results["EmbedDistill (L2)"] = {"accuracy": acc_2, "time": time_2}
	print(f"\nTest accuracy: {acc_2:.4f}")
	print(f"Training time: {time_2:.1f}s")

	# ──────────────────────────────────────────────
	# Experiment 3: EmbedDistillLoss (on-the-fly, cosine)
	# ──────────────────────────────────────────────
	print("\n" + "=" * 60)
	print("Experiment 3: EmbedDistillLoss (on-the-fly, cosine)")
	print("=" * 60)

	model_3 = SentenceTransformer(STUDENT_MODEL, model_kwargs=STUDENT_MODEL_KWARGS)
	teacher_model_3 = SentenceTransformer(TEACHER_MODEL, model_kwargs=TEACHER_MODEL_KWARGS)
	teacher_model_3.max_seq_length = 512
	loss_3 = custom_ls.EmbedDistillLoss(model_3, teacher_model=teacher_model_3, distance_metric="cosine", add_projection_layer=True)

	trainer_3 = SentenceTransformerTrainer(
		model=model_3,
		args=get_training_args(f"{OUTPUT_BASE}/exp3_embeddistill_cosine", run_name="EmbedDistill-Cosine"),
		train_dataset=train_dataset,
		eval_dataset=eval_dataset,
		loss=loss_3,
		evaluator=dev_evaluator,
	)
	swap_wandb_callback(trainer_3)

	start = time.time()
	trainer_3.train()
	time_3 = time.time() - start

	test_result_3 = test_evaluator(model_3)
	acc_3 = test_result_3["all-nli-test_cosine_accuracy"]
	results["EmbedDistill (cosine)"] = {"accuracy": acc_3, "time": time_3}
	print(f"\nTest accuracy: {acc_3:.4f}")
	print(f"Training time: {time_3:.1f}s")

	# ──────────────────────────────────────────────
	# Experiment 4: EmbedDistillLoss (on-the-fly, MSE)
	# ──────────────────────────────────────────────
	print("\n" + "=" * 60)
	print("Experiment 4: EmbedDistillLoss (on-the-fly, MSE)")
	print("=" * 60)

	model_4 = SentenceTransformer(STUDENT_MODEL, model_kwargs=STUDENT_MODEL_KWARGS)
	teacher_model_4 = SentenceTransformer(TEACHER_MODEL, model_kwargs=TEACHER_MODEL_KWARGS)
	teacher_model_4.max_seq_length = 512
	loss_4 = custom_ls.EmbedDistillLoss(model_4, teacher_model=teacher_model_4, distance_metric="mse", add_projection_layer=True)

	trainer_4 = SentenceTransformerTrainer(
		model=model_4,
		args=get_training_args(f"{OUTPUT_BASE}/exp4_embeddistill_mse", run_name="EmbedDistill-MSE"),
		train_dataset=train_dataset,
		eval_dataset=eval_dataset,
		loss=loss_4,
		evaluator=dev_evaluator,
	)
	swap_wandb_callback(trainer_4)

	start = time.time()
	trainer_4.train()
	time_4 = time.time() - start

	test_result_4 = test_evaluator(model_4)
	acc_4 = test_result_4["all-nli-test_cosine_accuracy"]
	results["EmbedDistill (MSE)"] = {"accuracy": acc_4, "time": time_4}
	print(f"\nTest accuracy: {acc_4:.4f}")
	print(f"Training time: {time_4:.1f}s")

	# ──────────────────────────────────────────────
	# Summary
	# ──────────────────────────────────────────────
	print("\n" + "=" * 60)
	print("RESULTS SUMMARY")
	print("=" * 60)
	print(f"{'Method':<25} {'Accuracy':>10} {'Time (s)':>10}")
	print("-" * 47)
	for name, res in results.items():
		print(f"{name:<25} {res['accuracy']:>10.4f} {res['time']:>10.1f}")
	print("-" * 47)

	# Teacher reference (use teacher from exp 2)
	teacher_result = test_evaluator(teacher_model)
	teacher_acc = teacher_result["all-nli-test_cosine_accuracy"]
	print(f"{'Teacher (reference)':<25} {teacher_acc:>10.4f} {'N/A':>10}")

	del teacher_model, teacher_model_3, teacher_model_4


if __name__ == "__main__":
	main()

Results

| Method | Cosine Accuracy | Training Time |
|---|---:|---:|
| MNRL (baseline) | 0.9021 | 1,047s |
| EmbedDistill (cosine) | 0.8901 | 2,140s |
| EmbedDistill (L2) | 0.6959 | 2,159s |
| EmbedDistill (MSE) | 0.5928 | 2,140s |
| Teacher (reference) | 0.9434 | N/A |

Key observations:

  • Cosine distance reaches 0.8901, very close to the MNRL baseline (0.9021), retaining ~94% of teacher performance.
  • L2 and MSE perform significantly worse, likely due to scale sensitivity when projecting across different embedding dimensions.
  • Training takes ~2x longer than MNRL due to the additional teacher forward pass.

Usage example

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses

student_model = SentenceTransformer("microsoft/mpnet-base")
teacher_model = SentenceTransformer("BAAI/bge-m3")

loss = losses.EmbedDistillLoss(
    student_model,
    teacher_model=teacher_model,
    distance_metric="cosine",
    add_projection_layer=True,  # student: 768d → teacher: 1024d
)

trainer = SentenceTransformerTrainer(
    model=student_model,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

# Optionally save the learned projection layer
loss.save_projection("output/projection.pt")

TODO / Open questions

  • Add unit tests

  • Should we add a training example script under examples/?

  • Youngjoon Jang

@yjoonjang (Contributor, Author)

I also added KLDiv as a distance_metric (though I'm not sure distance_metric is the right term for it), and these are the final results. (The wandb run above is also updated.)

| Method | Cosine Accuracy | Training Time |
|---|---:|---:|
| MNRL (baseline) | 0.9021 | 1,047s |
| EmbedDistill (cosine) | 0.8901 | 2,140s |
| EmbedDistill (L2) | 0.6959 | 2,159s |
| EmbedDistill (MSE) | 0.5928 | 2,140s |
| EmbedDistill (KLDiv) | 0.5868 | 2,137s |
| Teacher (reference) | 0.9434 | N/A |

@yjoonjang commented Feb 18, 2026

The Jina v5 paper states:

> Figure 3 illustrates training progress for all three loss functions on the MTEB English v2 retrieval benchmark at nDCG@10. We observe clear differences in both convergence speed and final performance. While $L_{score}$ and $L_{NCE}$ provide a significantly faster initial increase in scores, they plateau relatively early, with score-based distillation showing very limited progress in later stages. In contrast, embedding-based distillation ($L_{distill}$) converges more slowly at the beginning, yet improves steadily and ultimately achieves the highest final retrieval performance in both data regimes. This suggests that while score-level matching is efficient for early alignment, directly aligning student and teacher embeddings provides a stronger and more sustained supervisory signal for long-term refinement.

This trend also appears in the experiment above, which makes EmbedDistill with cosine more valuable. I expect even better results with more data or more training steps.

*(Figure 3 from the Jina Embeddings v5 paper: training curves for the three loss functions.)*

@yjoonjang force-pushed the feat/add_embeddistill_loss branch from 067c8f2 to 6f09700 on February 22, 2026
@yjoonjang commented Feb 22, 2026

Hi, I changed the projection_save_path argument to pretrained_projection_path.

The old projection_save_path was confusing: it only stored a path at init, you still had to call save_projection() manually, and it was unclear whether the path was for saving or loading.

I replaced it with pretrained_projection_path: if provided, pre-trained projection weights are loaded automatically at init.
I also made path a required argument for save_projection(path) / load_projection(path), so there is no more hidden state.

# Save after training
loss.save_projection("proj.pt")

# Reuse in next training
loss = EmbedDistillLoss(..., pretrained_projection_path="proj.pt")

Would be great if you could take a look @tomaarsen !

@yjoonjang commented Feb 26, 2026

Hi, I've encountered some problems while training with this EmbedDistillLoss. Modern embedding models (Qwen3-Embedding, pplx-embed, embeddinggemma, NV-Embed, etc.) require task-specific prompts to produce quality embeddings. For example, Qwen3-Embedding expects "Instruct: ...\nQuery: {text}" for queries, but no prompt for documents.

However, EmbedDistillLoss on-the-fly mode has no way to pass teacher-specific prompts.

Current Data Flow

Student prompts are handled by SentenceTransformerDataCollator via training_args.prompts. It prepends prompts to text before tokenization.

However, when EmbedDistillLoss retokenizes for the teacher (different tokenizers), it decodes student input_ids back to text and retokenizes with the teacher's tokenizer:

# EmbedDistillLoss.forward()
decoded = [self.tokenizer.batch_decode(sf["input_ids"], skip_special_tokens=True) ...]
sentence_features = [self.teacher_model.tokenize(sentences) for sentences in decoded]

Two issues:

  1. Teacher prompts are never applied: teacher encodes raw text without instructions.
  2. Student prompts leak into teacher input: decoded text includes student's prompt prefix (e.g., "query: What is X?"), which is passed to the teacher.

Note

This is not just an issue for EmbedDistillLoss. All losses using guide or teacher models (such as GISTEmbedLoss) have this problem.

Proposed Solutions

I think we have two options to solve this.

Option A: decode-and-replace

Add a teacher_prompts dict to EmbedDistillLoss (and all other affected losses). During retokenization, strip the student prompt from the decoded text and prepend the teacher prompt.
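As a rough illustration of Option A (the helper name and prompt handling are hypothetical):

```python
def reprompt_for_teacher(decoded_texts: list[str], student_prompt: str,
                         teacher_prompt: str) -> list[str]:
    """Option A sketch: strip the student's prompt prefix from each
    decoded text and prepend the teacher's prompt instead."""
    reprompted = []
    for text in decoded_texts:
        # Remove the student prompt that leaked into the decoded text
        if student_prompt and text.startswith(student_prompt):
            text = text[len(student_prompt):]
        # Apply the teacher-specific prompt before retokenization
        reprompted.append(teacher_prompt + text)
    return reprompted
```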

Option B: pass original texts to loss by data collator

Modify the data collator to store original (un-prompted) text alongside tokenized features. The loss applies student/teacher prompts independently.

I think Option B might be more promising, since we can deal with future losses that use teacher/guide models.

Would be great if you could share your thoughts !

@tomaarsen (Member)

Hello!

I've been looking at this PR briefly while working on the hardness one & some others. I think overall, the whole re-tokenization approach that's required for 'online distillation/guidance' (e.g. this loss or GISTEmbedLoss) is just conceptually flawed. Especially given the extreme tokenization/preprocessing changes that will be introduced in #3554, it won't be possible to simply detokenize and retokenize.

Passing the raw inputs alongside the tokenized inputs would be a start, but it has its own issues. You've noticed these already with prompts, but there's more concerns. For example, the issue from #3674 (comment) would also apply for dense embedding models, meaning that you can't use iterable datasets anymore. And teacher models frequently have larger dimensionalities, which can be problematic for some distillation techniques.

I think a solid solution is to require offline distillation/guidance only. It's more restrictive, yes, but I'm not sure if there's a good case for online distillation/guidance apart from preprocessing simplicity.

So, users would have to compute teacher embeddings/similarities before training, and then they can use losses like MSELoss or MarginMSELoss to distill this into the student. Then the teacher embeddings only have to be computed once, and they can be reused for multiple training runs or shared as a Dataset to the community. It allows the model author to make sure that the teacher embeddings are correct, etc.

GISTEmbed is already implemented, so we can't really get rid of that one, and 'online self-guidance/self-distillation' is totally fine since the data is already preprocessed correctly, but perhaps we should aim for offline only for new losses? Perhaps with some helper functions/methods for preprocessing datasets?

In relation to your open PRs, this might have some consequences:

  • [feat] Add EmbedDistillLoss #3665 (this PR): Perhaps instead of one new loss (EmbedDistillLoss), we can create a distillation loss superclass that gets subclassed for different types of distance metrics. The class only works with offline data (i.e. the embeddings must be fed in instead of computed on the fly), and the existing MSELoss can subclass this new superclass. Perhaps even the DistillKLDivLoss can subclass this new superclass? Not 100% sure. The superclass can support the trainable projection code, where the student embeddings are projected into the teacher's space.
  • [feat] Add hardness-weighted contrastive learning to losses #3667 (hardness PR): I have some local changes here that I'll push soon, but this should be totally fine. We can update MNRL, CMNRL, GIST, CGIST.
  • Add self-guide mode for CachedGISTEmbedLoss #3662 (self-guide): We can add this to GIST/CGIST "for free" because the scaffolding is already there, but maybe we can expand this and also add it to MNRL, CMNRL without any of the re-tokenization requirements. The primary downside that I see is that a lot of people finetune from a non-embedding model like bert-base-uncased or ModernBERT-base. Initially, the similarity scores will be quite arbitrary, and it would detect a ton of "false negatives" because the model is still dumb. We could either add it with an explicit comment to only use it for finetuning existing embedding models, or we need some warmup approach. Then the question of how to schedule this warmup is still pretty tricky, e.g.:
    • Disable false negative detection before step x (e.g. 0.2) (most likely, this is the only viable one as the others are rather complex, but interesting to benchmark perhaps)
    • Linearly schedule the margin, e.g. from 1.0 (effectively no false negative detection) to 0.05 (pretty strict)
    • Schedule the margin followed by standard false negative detection, e.g. from 1.0 to 0.1 after 30% of training, then 0.1 for the rest.
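For reference, the third option could be sketched as follows (the helper name and defaults are illustrative only):

```python
def scheduled_margin(progress: float, start: float = 1.0, end: float = 0.1,
                     anneal_until: float = 0.3) -> float:
    """Linearly anneal the false-negative margin from `start` (effectively
    no detection) to `end` over the first `anneal_until` fraction of
    training, then hold it constant for the rest of training."""
    if progress >= anneal_until:
        return end
    return start + (end - start) * (progress / anneal_until)
```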

What do you think? I'll try and push the hardness code quickly as there's overlap in which files are edited across the hardness & self-guidance work, so there might be some conflicts there.

cc @NohTow

  • Tom Aarsen

@yjoonjang (Contributor, Author)

Thanks for all these detailed thoughts!
Since there is a lot to respond to, I'll reply to each point in turn.

1. About EmbedDistillLoss → Embedding Distillation Superclass

I fully agree that there are some problems with the re-tokenization approach, and offline-only might be the best fit for new distillation losses.

And I like the idea of an embedding distillation superclass. Since MSELoss is already doing embedding-level distillation (teacher embeddings as labels), and EmbedDistillLoss generalizes this with multiple distance metrics + projection support, it makes sense to unify them:

  • Superclass handles: offline teacher embeddings (passed in via the label column), projection layer, multi-column support
  • Subclasses implement different distance metrics (MSE, cosine, KL divergence, L2, etc.)
  • Existing MSELoss can subclass it (as the MSE distance variant)

(MarginMSELoss and DistillKLDivLoss are score-level distillation rather than embedding-level, so they probably don't fit into this hierarchy.)

I'll remove the online mode (teacher_model parameter, re-tokenization logic) entirely and focus on making offline distillation clean and robust.

For usability, I think providing helper functions for pre-computing teacher embeddings would be valuable: something like a dataset.map() recipe or a utility that handles batched encoding with the proper prompts. For example:

teacher_model = SentenceTransformer("Qwen/Qwen3-Embedding-4B")

# uses the prompts from config_sentence_transformers.json, if present
query_embeddings = teacher_model.encode(dataset["query"], prompt_name="query")
doc_embeddings = teacher_model.encode(dataset["answer"], prompt_name="document")

dataset = dataset.add_column("label", query_embeddings)

loss = EmbedDistillLoss(model=student_model, distance_metric="cosine")
trainer = SentenceTransformerTrainer(model=student_model, train_dataset=dataset, loss=loss)
trainer.train()

2. About hardness-weighted contrastive losses

Thanks for the fixes and the focal weighting feature!
I'll take a look at your changes; we can continue that conversation in that PR.

3. About self-guide mode for contrastive losses

Nice point here. I also initially thought this would fit the MNRL losses, since using the same model means no re-tokenization and no prompt mismatch. I'd be happy to add this feature to both the MNRL and GIST losses if you're okay with it.

I think the warmup concern is also valid. Starting with the simplest option (disable before step x) is reasonable as a first step, and we can benchmark more sophisticated scheduling later if needed.

Also, since users might land on specific PRs directly, I'll leave a summary of our self-guide discussion as a comment on #3662 so the context is discoverable there as well.

Next Steps

I'll start reworking the EmbedDistillLoss PR toward the offline-only superclass design.
Happy to align on the exact API before I start if you have preferences on the class hierarchy.

@yjoonjang (Contributor, Author)

@tomaarsen Before I start the implementation, it would be great if you could share your thoughts on a few design decisions:

  1. Naming & hierarchy: Should EmbedDistillLoss itself become the superclass (with a distance_metric parameter), or should it be a separate abstract base class (e.g., EmbeddingDistillationLoss) with each metric as a concrete subclass? I'm leaning toward a single class with distance_metric parameter since it's simpler and avoids class explosion, but open to your preference.

  2. MSELoss migration: Should the existing MSELoss subclass the new superclass to share the projection/multi-column logic? Or keep MSELoss as-is for backward compatibility and just have the new loss as a standalone replacement?

  3. Helper functions for pre-computing embeddings: Should this be a utility function (e.g., in sentence_transformers.util), a class method on the loss, or just a documented recipe in the examples/docstring?

  4. Self-guide expansion: Should we expand self-guide mode to MNRL/CMNRL as well? And for the warmup, would disabling false negative detection for the first 20% of training (i.e., warmup_ratio=0.2 as default) be a reasonable starting point? (Once these are decided, I'll update the summary comment on Add self-guide mode for CachedGISTEmbedLoss #3662 accordingly.)

@tomaarsen (Member) commented Feb 27, 2026

1. About EmbedDistillLoss → Embedding Distillation Superclass

The code snippet doesn't make too much sense, but I agree with the rest.

3. About self-guide mode for contrastive losses

Nice point here. I also first thought that this may fit to the MNRLs, since same model means no re-tokenization and no prompt mismatch. It'll be fine to add this feature to both MNRLs and GISTs if you're okay with it.

I think that should work 🤗

  1. Naming & hierarchy: Should EmbedDistillLoss itself become the superclass (with a distance_metric parameter), or should it be a separate abstract base class (e.g., EmbeddingDistillationLoss) with each metric as a concrete subclass? I'm leaning toward a single class with distance_metric parameter since it's simpler and avoids class explosion, but open to your preference.

I don't mind the 'class explosion': it's just 3 classes probably (MSE, Cosine, L2), and one already exists. I think perhaps it should be an abstract base class with each metric as a subclass. I'm unsure about the model naming at the moment. Perhaps we'll want to prefix them with Distill, e.g.:

  • DistillMSELoss (where we softly deprecate MSELoss by having it subclass the new DistillMSELoss with just an __init__ that warns users about the upcoming deprecation)
  • DistillCosineLoss
  • DistillL2Loss

I'm not sure what's nicest here yet.

  2. MSELoss migration: Should the existing MSELoss subclass the new superclass to share the projection/multi-column logic? Or keep MSELoss as-is for backward compatibility and just have the new loss as a standalone replacement?

The existing MSELoss should subclass the new superclass and simply initialize the superclass with the MSE distance. Then the MSELoss should still be backwards compatible + it gets the new projection feature, right?

  3. Helper functions for pre-computing embeddings: Should this be a utility function (e.g., in sentence_transformers.util), a class method on the loss, or just a documented recipe in the examples/docstring?

Probably just a documented recipe in the examples/docstring for now.

  4. Self-guide expansion: Should we expand self-guide mode to MNRL/CMNRL as well? And for the warmup, would disabling false negative detection for the first 20% of training (i.e., warmup_ratio=0.2 as default) be a reasonable starting point? (Once these are decided, I'll update the summary comment on Add self-guide mode for CachedGISTEmbedLoss #3662 accordingly.)

Ideally added to MNRL/CMNRL, yes. The main problem is that the loss isn't informed at what step it is, so we'd need a callback like https://github.com/huggingface/sentence-transformers/blob/main/sentence_transformers/sparse_encoder/callbacks/splade_callbacks.py I think. Or perhaps we can add something to pass the trainer state to the loss, and then the loss can check the state to see if it's already at the warmup ratio.
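One possible shape for that callback (class and attribute names here are hypothetical; the real version would subclass transformers.TrainerCallback and receive the usual args/state/control):

```python
class SelfGuideWarmupCallback:
    """Sketch: informs the loss how far training has progressed, so it can
    skip false-negative detection during the warmup phase."""

    def __init__(self, loss, warmup_ratio: float = 0.2):
        self.loss = loss
        self.warmup_ratio = warmup_ratio

    def on_step_begin(self, args, state, control, **kwargs):
        # state.global_step / state.max_steps mirror the TrainerState fields
        progress = state.global_step / max(state.max_steps, 1)
        # The loss would check this flag before filtering false negatives
        self.loss.guide_enabled = progress >= self.warmup_ratio
```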

  • Tom Aarsen

@yjoonjang (Contributor, Author)

Thanks for sharing!
Before I start, I want to raise two concerns about the naming & hierarchy:

1. Naming confusion with DistillKLDivLoss

If we name the new losses DistillMSELoss, DistillCosineLoss, etc., they share the Distill prefix with the existing DistillKLDivLoss. But they do fundamentally different things:

  • DistillKLDivLoss distills scores (teacher similarity scores as labels)
  • DistillMSELoss distills embeddings (teacher embedding vectors as labels)

I think this could be confusing for users. Would it be better to keep the EmbedDistill prefix (e.g., EmbedDistillLoss as the base, with MSELoss subclassing it), or use something like EmbedDistillMSELoss, EmbedDistillCosineLoss to make the distinction clear?

2. Single class vs. subclasses

I'm also wondering if a single EmbedDistillLoss class with a distance_metric parameter might be simpler than separate subclasses. The distance metrics are really just one function call difference, and metric-specific parameters like temperature (for KL) are already handled cleanly as optional args.

MSELoss compatibility would still work either way:

class MSELoss(EmbedDistillLoss):
    def __init__(self, model, **kwargs):
        warnings.warn("MSELoss is deprecated, use EmbedDistillLoss(distance_metric='mse') instead")
        super().__init__(model, distance_metric="mse", **kwargs)

That said, I understand the library convention is separate classes, so happy to go either way.
What do you think?
