Conversation
I also added KLDiv for the
Hi, I replaced the old approach with:

```python
# Save after training
loss.save_projection("proj.pt")

# Reuse in next training
loss = EmbedDistillLoss(..., pretrained_projection_path="proj.pt")
```

Would be great if you could take a look @tomaarsen !
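For context, a minimal numpy-based sketch of what such save/reload helpers might look like; the class name, constructor arguments, and `.npz` storage format here are illustrative assumptions, not the PR's actual (torch-based) implementation:

```python
import numpy as np

class ProjectionMixin:
    """Toy stand-in for a loss that owns a learnable linear projection W z + b."""

    def __init__(self, student_dim, teacher_dim, pretrained_projection_path=None):
        if pretrained_projection_path is not None:
            # Reload a previously trained projection instead of reinitializing
            data = np.load(pretrained_projection_path)
            self.W, self.b = data["W"], data["b"]
        else:
            rng = np.random.default_rng(0)
            self.W = rng.normal(scale=0.02, size=(teacher_dim, student_dim))
            self.b = np.zeros(teacher_dim)

    def project(self, z):
        # psi(z) = W z + b: map a student embedding into the teacher's space
        return z @ self.W.T + self.b

    def save_projection(self, path):
        # Persist the learned weights so a later run can reuse them
        np.savez(path, W=self.W, b=self.b)
```

The round trip is then: train, call `save_projection("proj.npz")`, and pass that path back in on the next run.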
Hi, I've encountered some problems while training with this.

**Current Data Flow**

Student prompts are applied before tokenization, so they are already baked into the tokenized features. However, when the loss decodes those ids and re-tokenizes them for the teacher:

```python
# EmbedDistillLoss.forward()
decoded = [self.tokenizer.batch_decode(sf["input_ids"], skip_special_tokens=True) for sf in sentence_features]
sentence_features = [self.teacher_model.tokenize(sentences) for sentences in decoded]
```

Two issues: the student prompt leaks into the decoded text fed to the teacher, and the teacher's own prompt is never applied.

Note: this is not just an issue for this loss; any loss that re-tokenizes student inputs for a teacher/guide model is affected.

**Proposed Solutions**

I think we have two options to solve this.

*Option A: decode-and-replace.* Add logic that strips the student prompt from the decoded text and applies the teacher prompt before re-tokenizing.

*Option B: pass original texts to loss by data collator.* Modify the data collator to store original (un-prompted) text alongside tokenized features. The loss applies student/teacher prompts independently.

I think Option B might be more promising, since we can deal with future losses that use teacher/guide models. Would be great if you could share your thoughts !
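To make Option B concrete, here is a rough pure-Python sketch of the intended data flow; the function names, the `"texts"` key, and the stub tokenizers are all hypothetical, just to illustrate where the raw text would travel:

```python
def collate(texts, student_prompt, student_tokenize_fn):
    """Hypothetical collator: tokenize prompted text for the student,
    but also keep the original un-prompted texts for the loss."""
    prompted = [student_prompt + t for t in texts]
    features = student_tokenize_fn(prompted)
    features["texts"] = texts  # raw texts ride alongside the tensors
    return features

def loss_forward(features, teacher_prompt, teacher_tokenize_fn):
    """Hypothetical loss side: re-prompt the raw texts for the teacher
    instead of decoding the student's input_ids."""
    teacher_inputs = teacher_tokenize_fn(
        [teacher_prompt + t for t in features["texts"]]
    )
    return teacher_inputs
```

The point of the sketch: neither prompt ever contaminates the other side, because each model prompts the pristine text itself.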
Hello! I've been looking at this PR briefly while working on the hardness one & some others. I think overall, the whole re-tokenization approach that's required for 'online distillation/guidance' (e.g. this loss or GISTEmbedLoss) is just conceptually flawed. Especially given the extreme tokenization/preprocessing changes that will be introduced in #3554, it won't be possible to simply detokenize and retokenize.

Passing the raw inputs alongside the tokenized inputs would be a start, but it has its own issues. You've noticed these already with prompts, but there are more concerns. For example, the issue from #3674 (comment) would also apply for dense embedding models, meaning that you can't use iterable datasets anymore. And teacher models frequently have larger dimensionalities, which can be problematic for some distillation techniques.

I think a solid solution is to require offline distillation/guidance only. It's more restrictive, yes, but I'm not sure if there's a good case for online distillation/guidance apart from preprocessing simplicity. So, users would have to compute teacher embeddings/similarities before training, and then they can use losses like MSELoss or MarginMSELoss to distill this into the student. Then the teacher embeddings only have to be computed once, and they can be reused for multiple training runs or shared as a Dataset to the community. It allows the model author to make sure that the teacher embeddings are correct, etc.

GISTEmbed is already implemented, so we can't really get rid of this one somehow, and 'online self-guidance/self-distillation' is totally fine as the data is already preprocessed correctly, but perhaps we should aim for offline only for new losses? Perhaps with some helper functions/methods for preprocessing datasets or something?

In relation to your open PRs, this might have some consequences:
What do you think? I'll try and push the hardness code quickly as there's overlap in which files are edited across the hardness & self-guidance work, so there might be some conflicts there. cc @NohTow
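The offline recipe described above can be sketched with a toy numpy example; the helper names and the stub teacher encoder are assumptions, the point is only that the expensive teacher pass happens once, up front, and its output becomes plain labels:

```python
import numpy as np

def precompute_teacher_embeddings(texts, teacher_encode):
    # Run the (expensive) teacher exactly once; the resulting array can be
    # saved to disk or shared as a Dataset and reused across training runs.
    return np.stack([teacher_encode(t) for t in texts])

def mse_distill_loss(student_embeddings, teacher_embeddings):
    # MSELoss-style objective: pull student vectors toward the
    # precomputed teacher vectors.
    return float(np.mean((student_embeddings - teacher_embeddings) ** 2))
```

In the real library, the precomputed vectors would go into the dataset's label column and be consumed by something like `MSELoss`, with no teacher model loaded at training time.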
Thanks for all these detailed thoughts!

**1. About EmbedDistillLoss → Embedding Distillation Superclass**

I fully agree that there are some problems with the re-tokenization approach, and offline-only might be the best fit for new distillation losses. And I like the idea of an embedding distillation superclass. I'll remove the online mode (the re-tokenization path).

For usability, I think providing helper functions for pre-computing teacher embeddings would be valuable. Something like:

```python
teacher_model = SentenceTransformer("Qwen/Qwen3-Embedding-4B")
# if it has its prompt in config_sentence_transformers.json
query_embeddings = teacher_model.encode(dataset["query"], prompt_name="query")
doc_embeddings = teacher_model.encode(dataset["answer"], prompt_name="document")
dataset = dataset.add_column("label", precomputed_embeddings)
loss = EmbedDistillLoss(model=student_model, distance_metric="cosine")
trainer = SentenceTransformerTrainer(model=student_model, train_dataset=dataset, loss=loss)
trainer.train()
```

**2. About hardness-weighted contrastive losses**

Thanks for the fixes and the focal weighting feature !!

**3. About self-guide mode for contrastive losses**

Nice point here. I also first thought that this might fit there. I think the warmup concern is also valid. Starting with the simplest option (disable before step x) is reasonable as a first step, and we can benchmark more sophisticated scheduling later if needed.

Also, since users might land on specific PRs directly, I'll leave a summary of our self-guide discussion as a comment on #3662 so the context is discoverable there as well.

**Next Steps**

I'll start reworking the EmbedDistillLoss PR toward the offline-only superclass design.
@tomaarsen Before I start the implementation, it would be great if you could share your thoughts on a few design decisions:
The code snippet doesn't make too much sense, but I agree with the rest.
I think that should work 🤗
I don't mind the 'class explosion': it's just 3 classes probably (MSE, Cosine, L2), and one already exists. I think perhaps it should be an abstract base class with each metric as a subclass. I'm unsure about the model naming at the moment. Perhaps we'll want to prefix them with
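The abstract-base-class layout suggested here might look roughly like this (a numpy sketch; all class names are placeholder assumptions, and the real implementation would of course be torch-based):

```python
import abc
import numpy as np

class EmbeddingDistillationLoss(abc.ABC):
    """Hypothetical superclass: shared machinery (projection, label handling)
    would live here; subclasses only define the distance."""

    @abc.abstractmethod
    def distance(self, student, teacher):
        ...

    def __call__(self, student, teacher):
        return self.distance(student, teacher)

class MSEDistillationLoss(EmbeddingDistillationLoss):
    def distance(self, student, teacher):
        return float(np.mean((student - teacher) ** 2))

class L2DistillationLoss(EmbeddingDistillationLoss):
    def distance(self, student, teacher):
        # mean Euclidean distance per embedding pair
        return float(np.mean(np.linalg.norm(student - teacher, axis=-1)))

class CosineDistillationLoss(EmbeddingDistillationLoss):
    def distance(self, student, teacher):
        cos = np.sum(student * teacher, axis=-1) / (
            np.linalg.norm(student, axis=-1) * np.linalg.norm(teacher, axis=-1)
        )
        return float(np.mean(1.0 - cos))
```

Under this shape, the existing MSELoss could become a thin subclass that initializes the base with the MSE distance, which would keep it backwards compatible while inheriting the projection feature.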
I'm not sure what's nicest here yet.
The existing MSELoss should subclass the new superclass and simply initialize the superclass with the MSE distance. Then the MSELoss should still be backwards compatible + it gets the new projection feature, right?
Probably just a documented recipe in the examples/docstring for now.
Ideally added to MNRL/CMNRL, yes. The main problem is that the loss isn't informed at what step it is, so we'd need a callback like https://github.com/huggingface/sentence-transformers/blob/main/sentence_transformers/sparse_encoder/callbacks/splade_callbacks.py I think. Or perhaps we can add something to pass the trainer state to the loss, and then the loss can check the state to see if it's already at the warmup ratio.
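The callback idea might look roughly like this: a pure-Python sketch mirroring the `transformers` `TrainerCallback` pattern (which exposes `state.global_step` and `state.max_steps`); the `self_guide_enabled` flag and all names here are hypothetical:

```python
class SelfGuideWarmupCallback:
    """Enable self-guidance on the loss only after a warmup fraction of training."""

    def __init__(self, loss, warmup_ratio=0.1):
        self.loss = loss
        self.warmup_ratio = warmup_ratio

    def on_step_begin(self, state):
        # `state` mimics transformers' TrainerState: it carries
        # global_step and max_steps, which is all the loss needs.
        self.loss.self_guide_enabled = (
            state.global_step >= self.warmup_ratio * state.max_steps
        )
```

The alternative mentioned above (passing trainer state into the loss directly) would do the same comparison inside `forward()` instead of flipping a flag from outside.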
Thanks for sharing!

**1. Naming confusion with**

Hello, @tomaarsen!
**Pull Request overview**

- `EmbedDistillLoss`, an embedding-level knowledge distillation loss based on the EmbedDistill paper.

**Details**

This PR adds a new loss function `EmbedDistillLoss` that minimizes the distance between student and teacher model embeddings. Unlike score-based distillation (e.g., `MarginMSELoss`, `DistillKLDivLoss`), this approach directly aligns the embedding spaces, which provides a stronger learning signal for the student model, following the geometric distillation approach from Kim et al., 2023.

**Key features**
**Two operating modes:**
**Three distance metrics:**

- `cosine` (default) — cosine distance: `1 - cos_sim(student, teacher)`
- `l2` — Euclidean distance: `||student - teacher||_2`
- `mse` — mean squared error

**Learnable projection layer:**

- `ψ(z) = Wz + b` maps student embeddings to the teacher dimension (the Jina Embeddings v5 paper reports that this projection direction, student → teacher, works better)
- `save_projection()` / `load_projection()` methods for persisting the learned projection weights

**Automatic retokenization:**
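As an illustration of the projection described above, here is a numpy sketch (in the PR itself this would presumably be a learnable `torch.nn.Linear`, and the 768/1024 sizes match the experiment setup below):

```python
import numpy as np

rng = np.random.default_rng(0)
student_dim, teacher_dim = 768, 1024

# psi(z) = W z + b: maps a student embedding into the teacher's space,
# so the distillation distance is computed in the (larger) teacher dimension.
W = rng.normal(scale=0.02, size=(teacher_dim, student_dim))
b = np.zeros(teacher_dim)

def project(z):
    return W @ z + b

student_embedding = rng.normal(size=student_dim)
projected = project(student_embedding)  # shape: (1024,)
```

During training, `W` and `b` would be optimized jointly with the student; `save_projection()` would then persist exactly these two tensors.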
**Why cosine as default?**
The original EmbedDistill paper uses L2 distance, but recent work on embedding distillation such as Jina Embeddings v5 reports strong results with cosine-based alignment. My experiments below confirm that cosine distance significantly outperforms L2 and MSE for this task.
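One intuition for why the two metrics can diverge (a toy numpy check, not a claim about the experiments reported here): L2 distance is sensitive to embedding norm, while cosine distance only compares directions:

```python
import numpy as np

def l2(a, b):
    return float(np.linalg.norm(a - b))

def cosine_distance(a, b):
    return float(1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

teacher = np.array([1.0, 1.0])
student = np.array([10.0, 10.0])  # right direction, wrong scale

# L2 heavily penalizes the norm mismatch...
norm_penalty = l2(teacher, student)           # ~12.7
# ...while cosine distance is ~0 because the directions match.
direction_penalty = cosine_distance(teacher, student)
```

So a cosine objective lets the student match the teacher's geometry without also chasing the teacher's embedding magnitudes.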
**Experiment setup**

- Student: `microsoft/mpnet-base` (768d, with projection layer)
- Teacher: `BAAI/bge-m3` (1024d)
- Data: `sentence-transformers/all-nli` (triplet), 100K training samples
- Evaluation: `TripletEvaluator` on the all-nli test set (cosine accuracy)

Training test code details
**Results**
Key observations:
**Usage example**
**TODO / Open questions**

- Add unit tests
- Should we add a training example script under `examples/`?

Youngjoon Jang