[feat] Implement Quantization Aware Training (QAT) Loss · draft #3655
tomaarsen wants to merge 4 commits into huggingface:main
Conversation
Args:
    embeddings: Input tensor to quantize
    precision: Quantization precision ("float32", "int8", "uint8", "binary", "ubinary")
Great work, I've wanted to add this loss for a long time too! I think you can also add (b)float16/8.
Yess, right now it's mirroring this function:
sentence_transformers/quantization.py, lines 371 to 376 in 8c4be2c
But I'd like to extend that one as well. The tricky part is that there are multiple ways to convert embeddings to specific precisions, and not all evaluators, for example, work nicely with every variant. I can just give those different names though, like "better_binary" or something.
My new plan is to perhaps implement 3 functions:
- quantize_embeddings
- dequantize_embeddings
- quantize_embeddings_grad
The model.encode method would use the first one to quantize for the end users. The evaluators each then call dequantize_embeddings so that they can still evaluate with fp32 embeddings and there's no complications with the similarity functions, and quantize_embeddings_grad is used in the QAT loss to allow gradients to flow through the quantization.
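To make the split concrete, here is a minimal NumPy sketch of the first two proposed functions for the "binary" precision only. The function names come from the plan above; the bodies are a hypothetical illustration (binary quantization as sign-thresholding plus bit-packing, mirroring how sentence-transformers stores packed binary embeddings), and `quantize_embeddings_grad` is omitted since it needs an autograd-aware straight-through estimator:

```python
import numpy as np

def quantize_embeddings(emb: np.ndarray, precision: str = "binary") -> np.ndarray:
    # "binary": threshold at zero, then pack 8 dimensions per byte for storage
    if precision == "binary":
        return np.packbits(emb > 0, axis=-1)
    raise NotImplementedError(precision)

def dequantize_embeddings(emb: np.ndarray, precision: str = "binary") -> np.ndarray:
    # Invert the packing back to float32 values in {-1, +1}, so evaluators
    # can keep using cosine/dot similarity on full-precision arrays
    if precision == "binary":
        bits = np.unpackbits(emb, axis=-1)
        return bits.astype(np.float32) * 2.0 - 1.0
    raise NotImplementedError(precision)
```

A round trip loses the magnitudes but preserves the signs, which is exactly the information the binary similarity search relies on.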
Regarding the dequantize_embeddings for the evaluators: do you think that'll work nicely? I want to make sure that the simple evaluation in ST works correctly.
> evaluators work nicely with both. I can just give those different names though, like "better_binary" or something.
Yeah, there can be a lot of different combinations. I think you could create configs similar to https://huggingface.co/docs/transformers/en/main_classes/quantization#transformers.QuantoConfig to have more options to expand later. But it might be better to do this in (or after) #3554 and keep things simple for now.
> Regarding the dequantize_embeddings for the evaluators: do you think that'll work nicely? I want to make sure that the simple evaluation in ST works correctly.

I think this should work and would allow comparing embeddings with different quantizations. However, it might introduce memory spikes or minor inaccuracies during conversion, so it would be good to have options to compare embeddings as-is as well.
Yeah, I think I should do it "properly" with configs, etc. A recurring issue for now is also the int8 calibration to "set the buckets". I think I'll have to look at it after the multi-modality refactor, as that's more important I think.
Note to self, read more up on https://jina.ai/news/quantization-aware-training-of-jina-embeddings-v4/.
    self.cache.append(output)
# Using cache (subsequent passes with quantization):
else:
    output = self.cache[self.idx]
- output = self.cache[self.idx]
+ # Avoid mutating float32 cache in subsequent quantized passes.
+ output = self.cache[self.idx].copy()

Without .copy(), quantizing output mutates the cached float32 entry, so later precision passes may use already-quantized embeddings instead of the original float32 cache.
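The aliasing bug being fixed here is easy to reproduce in isolation. A small NumPy sketch (hypothetical stand-in for the loss's cache, using in-place sign as the "quantization"):

```python
import numpy as np

# A float32 embedding cached on the first pass
cache = [np.array([0.3, -0.7], dtype=np.float32)]

# BUG: `output` is the same array object as the cache entry, so an
# in-place quantization (here: binary via sign) corrupts the cache
output = cache[0]
np.sign(output, out=output)

# FIX: copy before quantizing, so the float32 cache survives later passes
cache2 = [np.array([0.3, -0.7], dtype=np.float32)]
output2 = cache2[0].copy()
np.sign(output2, out=output2)
```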
Hello! I found this QAT quantization implementation very interesting, and I also reviewed and validated the implementation/behavior on my side while using this PR as a reference. In addition, I ran extra experiments in my environment.
Note: These evaluation results were obtained while I was iteratively making small adjustments in my local environment. As such, there may be minor inconsistencies across runs, so please consider them indicative rather than strictly controlled benchmarks.
Notes (Linear staged warmup):
Note for the 90k setting:
At this point (not only for warmup, but in general), the behavior appears sensitive to dataset size and batch size, and I have not found a fully stable quantization recipe yet.
Thanks a bunch for the fix, that's very well spotted. I'm also glad that it helps the performance, although it's not clear that any QAT variant is clearly better than MNRL. This is also what I experienced originally. Is the linear warmup based on any paper? I've not read many works on QAT, but I do feel like this implementation is still missing something. Or perhaps MNRL (a.k.a. InfoNCE) is quite robust with QAT out of the box.
Thanks for the thoughtful feedback. To directly answer your question: the linear warmup I used is not based on a specific QAT paper. It is the same kind of linear-increase schedule as in the Transformers library, and was a heuristic I wanted to try. I also have not done a deep survey of QAT papers yet.

As a side investigation (not exactly QAT), I looked into how a production vector engine (Qdrant) implements quantization in real search workloads. The most interesting point for me was asymmetric query quantization when documents are stored as 1-bit binary vectors: in Qdrant, documents stay 1-bit binary, but queries can be encoded as scalar4bits/scalar8bits instead of pure binary, which can improve ranking quality/recall at extra compute cost. This asymmetric query-quantization path also appears to reference an implementation idea from https://arxiv.org/abs/2405.12497.

Relevant implementation links (SHA-pinned):
This is not exactly a QAT direction, but rather an inference-time optimization, so it may be better suited for a follow-up than this PR. Still, I found asymmetric quantization to be a very interesting approach. The main caveat is implementation: achieving production-grade performance in Python would likely require careful low-level optimization to avoid significant latency overhead.
I've seen the https://arxiv.org/abs/2405.12497 paper linked a lot, it seems like a promising direction. And the asymmetric quantization is very solid in my experience. For context, this demo uses binary search for an initial retrieval (e.g. top 40) and then loads int8 document embeddings for those 40 documents. It then rescores using the fp32 query and the 40 int8 documents to determine a new ranking. It's only a tad slower than full binary search, while being quite a lot better. So yes: I'm fond of asymmetric quantization for sure.
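The binary-retrieve-then-int8-rescore pattern described above can be sketched in a few lines of NumPy. This is a hypothetical illustration, assuming documents are stored both as packed binary vectors and as int8 vectors with per-document scales (the function name and parameters are my own, not from the demo):

```python
import numpy as np

def binary_then_int8_rescore(query_fp32, docs_binary_packed, docs_int8,
                             docs_scale, top_k=40, final_k=10):
    # Stage 1: cheap initial retrieval via Hamming distance on packed binary docs
    query_bits = np.packbits(query_fp32 > 0)
    hamming = np.unpackbits(docs_binary_packed ^ query_bits, axis=-1).sum(axis=-1)
    candidates = np.argsort(hamming)[:top_k]
    # Stage 2: rescore only the candidates with the full-precision query
    # against dequantized int8 document embeddings (asymmetric quantization)
    docs_fp32 = docs_int8[candidates].astype(np.float32) * docs_scale[candidates]
    scores = docs_fp32 @ query_fp32
    return candidates[np.argsort(-scores)[:final_k]]
```

Only `top_k` int8 rows ever need to be loaded, which is why the rescoring step adds so little latency over pure binary search.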
It's great to see that you had already implemented retrieval with asymmetric quantization. In real-world applications, your approach of retrieving a slightly larger candidate set and then reranking with asymmetric quantization to improve accuracy seems very effective. Thank you for sharing this. Your implementation will be a very good reference for my work.
Hello!
Pull Request overview
Details
Dense embedding models typically produce embeddings with a fixed precision (float32), requiring significant storage and compute resources. Quantization-Aware Training trains models to maintain high performance even when embeddings are quantized to lower precision formats (int8, binary), enabling dramatically reduced storage costs and faster similarity computations.
Training
Training using Quantization-Aware Training (QAT) is straightforward: rather than applying a loss function only on full-precision embeddings, we also apply that same loss function on quantized versions of the embeddings. For example, a model can be trained on float32, int8, and binary simultaneously. Each of these losses will be added together, optionally with some weight:
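A minimal sketch of that weighted sum, assuming a `quantize_fn(embeddings, precision)` helper exists; `base_loss_fn` and the helper names are hypothetical stand-ins for the PR's actual loss and quantization functions:

```python
import torch

def qat_loss(base_loss_fn, embeddings, quantize_fn,
             precisions=("float32", "int8", "binary"), weights=None):
    # Apply the same loss on the full-precision embeddings and on each
    # quantized variant, then sum with optional per-precision weights.
    weights = weights or {p: 1.0 for p in precisions}
    total = embeddings.new_zeros(())
    for p in precisions:
        emb_p = embeddings if p == "float32" else quantize_fn(embeddings, p)
        total = total + weights[p] * base_loss_fn(emb_p)
    return total
```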
Inference
When a model has been trained using Quantization-Aware Training, you can run inference with quantized embeddings using the `precision` parameter. You can also evaluate with quantization by passing the `precision` parameter to evaluators.
Implementation
The implementation uses a Straight-Through Estimator (STE) for differentiable quantization:
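The STE trick can be written compactly with `x + (quantize(x) - x).detach()`: the forward pass takes the quantized value, while the backward pass treats quantization as the identity so gradients flow through unchanged. A hedged sketch (not the PR's exact code; the int8 branch uses a simple per-row absmax scale as an assumption):

```python
import torch

def quantize_ste(embeddings: torch.Tensor, precision: str = "binary") -> torch.Tensor:
    # Forward: hard quantization; backward: identity (Straight-Through Estimator)
    if precision == "binary":
        quantized = torch.sign(embeddings)
    elif precision == "int8":
        # Hypothetical calibration: per-row absmax scaling to [-127, 127]
        scale = embeddings.abs().amax(dim=-1, keepdim=True) / 127.0
        quantized = torch.round(embeddings / scale).clamp(-128, 127) * scale
    else:
        raise NotImplementedError(precision)
    # Value of `quantized`, gradient of `embeddings`
    return embeddings + (quantized - embeddings).detach()
```

Because the subtraction term is detached, `d(output)/d(embeddings)` is exactly 1 everywhere, sidestepping the zero gradient of the true quantization step function.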
Two key components:
The loss always computes with original float32 embeddings first (for caching efficiency), then optionally computes losses for specified quantization precisions. Training with QAT adds minimal overhead. The embeddings are computed once and cached, then quantized versions are derived from the cache for each specified precision.
Note
Currently, I've not yet been able to train a model with this that actually clearly outperformed int8 or binary performance versus "regular" MNRL/InfoNCE training:
In all of these: float32 is a bit worse with QAT, int8 is a bit worse with QAT, and binary is slightly better with QAT. I suspect that I might be missing something.