
[feat] Implement Quantization Aware Training (QAT) Loss draft#3655

Draft
tomaarsen wants to merge 4 commits into huggingface:main from tomaarsen:feat/qat_loss

Conversation

@tomaarsen
Member

Hello!

Pull Request overview

Details

Dense embedding models typically produce embeddings with a fixed precision (float32), requiring significant storage and compute resources. Quantization-Aware Training trains models to maintain high performance even when embeddings are quantized to lower precision formats (int8, binary), enabling dramatically reduced storage costs and faster similarity computations.

Training

Training using Quantization-Aware Training (QAT) is straightforward: rather than applying a loss function only on full-precision embeddings, we also apply the same loss function on quantized versions of the embeddings. For example, a model can be trained on float32, int8, and binary simultaneously. These losses are then added together, optionally with per-precision weights:

from sentence_transformers import SentenceTransformer, losses

model = SentenceTransformer("microsoft/mpnet-base")

base_loss = losses.MultipleNegativesRankingLoss(model=model)
loss = losses.QuantizationAwareLoss(
    model=model, 
    loss=base_loss, 
    quantization_precisions=["int8", "binary"],
    quantization_weights=[1, 1]
)

Inference

When a model has been trained using Quantization-Aware Training, you can run inference with quantized embeddings using the precision parameter:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("path/to/your/qat-model")

# Encode with int8 quantization
embeddings_int8 = model.encode(
    sentences,
    precision="int8",
    normalize_embeddings=True,  # Recommended for quantized embeddings
)

# Or with binary quantization for maximum compression
embeddings_binary = model.encode(
    sentences,
    precision="binary",
    normalize_embeddings=True,
)

You can also evaluate with quantization by passing the precision parameter to evaluators:

from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

evaluator = EmbeddingSimilarityEvaluator(
    sentences1=sentences1,
    sentences2=sentences2,
    scores=scores,
    precision="int8",  # or "binary", "uint8", etc.
)
results = evaluator(model)

Implementation

The implementation uses a Straight-Through Estimator (STE) for differentiable quantization:

  • Forward pass: Apply actual quantization (rounding to int8, thresholding to binary)
  • Backward pass: Gradients flow through as if quantization is the identity function
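As a rough sketch, an STE for binary quantization can be written in PyTorch with the detach trick. This illustrates the idea only; it is not necessarily the PR's exact code:

```python
import torch

def binary_ste(embeddings: torch.Tensor) -> torch.Tensor:
    # Forward: threshold to {-1.0, +1.0}; backward: identity gradient.
    # Sketch of the STE idea only; the PR's implementation may differ.
    quantized = torch.sign(embeddings)
    # detach() removes the quantization from the autograd graph, so
    # gradients flow through the `embeddings` term as if quantization
    # were the identity function.
    return embeddings + (quantized - embeddings).detach()
```

The forward value equals the quantized embedding, while the gradient with respect to the input is exactly 1 everywhere.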

Two key components:

  • ForwardDecorator: Caches embeddings on the first pass (float32), then reuses them for subsequent quantized passes
  • CachedLossDecorator: Special handling for Cached... losses that pre-compute embeddings

The loss always computes with original float32 embeddings first (for caching efficiency), then optionally computes losses for specified quantization precisions. Training with QAT adds minimal overhead. The embeddings are computed once and cached, then quantized versions are derived from the cache for each specified precision.
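In code, the overall loss could take roughly this shape. The names `base_loss_fn` and `quantize_grad` are hypothetical stand-ins for the wrapped loss and the gradient-preserving quantizer; this is a sketch of the structure, not the PR's API:

```python
def qat_loss(base_loss_fn, embeddings, quantize_grad, precisions, weights):
    # Full-precision loss is always computed first (and the embeddings
    # are cached), then weighted losses over quantized views are added.
    # `quantize_grad` stands for an STE-style quantizer (hypothetical name).
    total = base_loss_fn(embeddings)
    for precision, weight in zip(precisions, weights):
        total = total + weight * base_loss_fn(quantize_grad(embeddings, precision))
    return total
```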

Note

Currently, I've not yet been able to train a model with this whose int8 or binary performance clearly beats "regular" MNRL/InfoNCE training:

In all of these: float32 is a bit worse with QAT, int8 is a bit worse with QAT, and binary is slightly better with QAT. I suspect that I might be missing something.

  • Tom Aarsen


Args:
    embeddings: Input tensor to quantize
    precision: Quantization precision ("float32", "int8", "uint8", "binary", "ubinary")
Contributor


Great work! I've wanted to add this loss for a long time too. I think you could also add (b)float16/8.

Member Author


Yes, right now it's mirroring this function:

def quantize_embeddings(
    embeddings: Tensor | np.ndarray,
    precision: Literal["float32", "int8", "uint8", "binary", "ubinary"],
    ranges: np.ndarray | None = None,
    calibration_embeddings: np.ndarray | None = None,
) -> np.ndarray:

But I'd like to extend that one as well. The tricky part is that there are multiple ways to convert embeddings to specific precisions, and not all evaluators, for example, work nicely with each of them. I could just give those different names, though, like "better_binary" or something.

My new plan is to perhaps implement 3 functions:

  • quantize_embeddings
  • dequantize_embeddings
  • quantize_embeddings_grad

The model.encode method would use the first one to quantize for end users. The evaluators would then each call dequantize_embeddings so that they can still evaluate with fp32 embeddings and there are no complications with the similarity functions, and quantize_embeddings_grad would be used in the QAT loss to allow gradients to flow through the quantization.
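For int8, such a quantize/dequantize pair could look roughly like this. This is a sketch based on min/max bucketing over calibration ranges; the function names and the details of the real quantize_embeddings (e.g. calibration handling) are assumptions:

```python
import numpy as np

def quantize_int8(embeddings: np.ndarray, ranges: np.ndarray) -> np.ndarray:
    # ranges[0] holds per-dimension minima, ranges[1] per-dimension maxima.
    # Map each dimension linearly into 256 buckets -> int8 in [-128, 127].
    starts, steps = ranges[0], (ranges[1] - ranges[0]) / 255.0
    return np.clip((embeddings - starts) / steps - 128.0, -128, 127).astype(np.int8)

def dequantize_int8(quantized: np.ndarray, ranges: np.ndarray) -> np.ndarray:
    # Inverse mapping back to float32; lossy up to one bucket width.
    starts, steps = ranges[0], (ranges[1] - ranges[0]) / 255.0
    return (quantized.astype(np.float32) + 128.0) * steps + starts
```

With a dequantize step like this, evaluators could keep comparing fp32 arrays regardless of the precision the embeddings were stored in.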

Regarding the dequantize_embeddings for the evaluators: do you think that'll work nicely? I want to make sure that the simple evaluation in ST works correctly.

Contributor


> evaluators work nicely with both. I can just give those different names though, like "better_binary" or something.

Yeah, there can be a lot of different combinations. I think you could create configs similar to https://huggingface.co/docs/transformers/en/main_classes/quantization#transformers.QuantoConfig to have more options to expand later. But it might be better to do this in (or after) #3554 and keep things simple for now.

> Regarding the dequantize_embeddings for the evaluators: do you think that'll work nicely? I want to make sure that the simple evaluation in ST works correctly.

I think this should work, and it would allow comparing embeddings with different quantizations. However, it might introduce memory spikes or minor inaccuracies during conversion, so it would be good to have options to compare embeddings as-is as well.

Member Author


Yeah, I think I should do it "properly" with configs, etc. A recurring issue for now is also the int8 calibration to "set the buckets". I think I'll have to look at it after the multi-modality refactor, as that's more important I think.

@tomaarsen
Member Author

Note to self: read up more on https://jina.ai/news/quantization-aware-training-of-jina-embeddings-v4/.

  • Tom Aarsen

    self.cache.append(output)
# Using cache (subsequent passes with quantization):
else:
    output = self.cache[self.idx]
Contributor


-            output = self.cache[self.idx]
+            # Avoid mutating float32 cache in subsequent quantized passes.
+            output = self.cache[self.idx].copy()

Without .copy(), quantizing output mutates the cached float32 entry, so later precision passes may use already-quantized embeddings instead of the original float32 cache.

@hotchpotch
Contributor

hotchpotch commented Feb 15, 2026

Hello!

I found this QAT quantization implementation very interesting, and I also reviewed and validated the implementation/behavior on my side while using this PR as a reference.
I left a separate comment about one potentially unstable behavior (a cache-related issue). The "cache-fixed" variant below is the version with that fix applied.

In addition, I ran extra experiments in my environment.
I used sentence-transformers/gooaq for training, and evaluated with NanoBEIR (msmarco, nq) using nDCG@10 (float32 / int8 / binary).

| Variant | samples | steps | bs | float32 | int8 | binary |
|---|---|---|---|---|---|---|
| MNRL baseline | 90,000 | 1,407 | 64 | 0.5321 | 0.5115 | 0.4844 |
| PR implementation | 90,000 | 1,407 | 64 | 0.4890 | 0.4733 | 0.4435 |
| PR implementation (cache-fixed) | 90,000 | 1,407 | 64 | 0.5229 | 0.5046 | 0.4882 |
| Linear staged warmup | 90,000 | 1,407 | 64 | 0.5233 | 0.5098 | 0.4987 |
| MNRL baseline | 90,000 | 704 | 128 | 0.5094 | 0.5268 | 0.5226 |
| PR implementation | 90,000 | 704 | 128 | 0.5121 | 0.4937 | 0.4749 |
| PR implementation (cache-fixed) | 90,000 | 704 | 128 | 0.5156 | 0.5289 | 0.4912 |
| Linear staged warmup | 90,000 | 704 | 128 | 0.5091 | 0.5267 | 0.5180 |
| MNRL baseline | 990,000 | 7,735 | 128 | 0.5723 | 0.5641 | 0.5383 |
| PR implementation | 990,000 | 7,735 | 128 | 0.5642 | 0.5627 | 0.5533 |
| PR implementation (cache-fixed) | 990,000 | 7,735 | 128 | 0.5646 | 0.5731 | 0.5621 |
| Linear staged warmup | 990,000 | 7,735 | 128 | 0.5703 | 0.5640 | 0.5457 |
| MNRL baseline | 2,990,000 | 46,719 | 64 | 0.5379 | 0.5671 | 0.5335 |
| PR implementation | 2,990,000 | 46,719 | 64 | 0.5193 | 0.5182 | 0.4757 |
| PR implementation (cache-fixed) | 2,990,000 | 46,719 | 64 | 0.5502 | 0.5667 | 0.5022 |
| Linear staged warmup | 2,990,000 | 46,719 | 64 | 0.5722 | 0.5636 | 0.5297 |

Note: These evaluation results were obtained while I was iteratively making small adjustments in my local environment. As such, there may be minor inconsistencies across runs, so please consider them as indicative rather than strictly controlled benchmarks.

Notes (Linear staged warmup):

  • The loss is QuantizationAwareLoss (weights=[1.0, 1.0, 0.5]), and each precision weight is increased linearly by training step.
  • I implemented this linear step-based warmup on my side.
  • Warmup steps are float32=0, int8=200, binary=1800.
  • So float32 starts at full weight 1.0, int8 reaches 1.0 at step 200, and binary reaches 0.5 at step 1800.
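The step-based schedule described above can be sketched as a small helper. This is hypothetical code; hotchpotch's actual warmup implementation is not part of this PR:

```python
def warmup_weight(step: int, warmup_steps: int, target: float) -> float:
    # Linearly ramp a precision's loss weight from 0 to `target` over
    # `warmup_steps`, then hold it at `target`. warmup_steps=0 means
    # the weight starts at its full target immediately (as for float32).
    if warmup_steps <= 0:
        return target
    return target * min(step / warmup_steps, 1.0)
```

With warmup_steps=1800 and target=0.5, the binary weight at the final step of a 1,407-step run is about 0.39 (78% of its target), matching the 90k/bs=64 note below.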

Note for the 90k setting:

  • In 90k runs, total steps are 1,407 (bs=64) or 704 (bs=128), so binary warmup=1800 does not finish before training ends.
  • Therefore, in 90k runs, the effective binary weight does not reach its target 0.5 by the final step (about 78% for bs=64, about 39% for bs=128).

At this point, the behavior (not only with warmup but in general) appears sensitive to dataset size and batch size, and I have not found a fully stable quantization recipe yet.
I also want to apply this QAT implementation to the model I am currently developing, so I will keep exploring better implementations and ideas as well.

@tomaarsen
Member Author

Thanks a bunch for the fix, that's very well spotted. I'm also glad that it helps performance, although it's still not clear that any QAT variant is better than plain MNRL. This is also what I experienced originally.

Is the linear warmup based on any paper? I've not read many works on QAT, but I do feel like this implementation is still missing something. Or perhaps MNRL (a.k.a. InfoNCE) is quite robust with QAT out of the box.

  • Tom Aarsen

@hotchpotch
Contributor

hotchpotch commented Feb 17, 2026

Thanks for the thoughtful feedback.

To directly answer your question: the linear warmup I used is not based on a specific QAT paper. It is the same kind of linear-increase implementation as in the Transformers library, and was a heuristic I wanted to try.

I also have not done a deep survey of QAT papers yet. As a side investigation (not exactly QAT), I looked into how a production vector engine (Qdrant) implements quantization in real search workloads:

The most interesting point for me was asymmetric query quantization when documents are stored as 1-bit binary vectors. In Qdrant, documents stay 1-bit binary, but queries can be encoded as scalar4bits/scalar8bits instead of pure binary, which can improve ranking quality/recall at extra compute cost.

This asymmetric query-quantization path also appears to reference an implementation idea from https://arxiv.org/abs/2405.12497.

Relevant implementation links (SHA-pinned)

This is not exactly a QAT direction, but rather an inference-time optimization, so it may be better suited for a follow-up than this PR. Still, I found asymmetric quantization to be a very interesting approach. The main caveat is implementation: achieving production-grade performance in Python would likely require careful low-level optimization to avoid significant latency overhead.

@tomaarsen
Member Author

I've seen the https://arxiv.org/abs/2405.12497 paper linked a lot; it seems like a promising direction. And asymmetric quantization is very solid in my experience. For context, this demo retrieves an initial candidate set (e.g. top 40) with binary embeddings and then loads the int8 document embeddings for those 40 documents. It then rescores with the fp32 query against the 40 int8 documents to determine a new ranking.

It's only a tad slower than full binary search, while being quite a lot better. So yes: I'm fond of asymmetric quantization for sure.
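The retrieve-then-rescore flow described above can be sketched like this. It is a toy numpy illustration with made-up data scaling; the demo's actual code may differ:

```python
import numpy as np

def asymmetric_search(query_fp32, doc_binary, doc_int8, top_k=40, final_k=10):
    # Stage 1: cheap coarse retrieval against {-1, +1} binary documents.
    coarse_scores = doc_binary.astype(np.float32) @ query_fp32
    candidates = np.argsort(-coarse_scores)[:top_k]
    # Stage 2: rescore only the candidates with the fp32 query against
    # their int8 embeddings, then re-rank the candidate set.
    fine_scores = doc_int8[candidates].astype(np.float32) @ query_fp32
    return candidates[np.argsort(-fine_scores)[:final_k]]
```

Only the small candidate set ever touches the larger int8 vectors, which is why the pipeline stays close to pure binary search in speed.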

  • Tom Aarsen

@hotchpotch
Contributor

It's great to see that you had already implemented retrieval with asymmetric quantization.

In real-world applications, your approach of retrieving a slightly larger candidate set and then reranking with asymmetric quantization to improve accuracy seems very effective. Thank you for sharing this.

Your implementation will be a very good reference for my work.
