
Add self-guide mode for CachedGISTEmbedLoss#3662

Open
yjoonjang wants to merge 6 commits into huggingface:main from yjoonjang:feat/self-guide

Conversation

@yjoonjang
Contributor

Summary

Hello! This PR optimizes CachedGISTEmbedLoss for self-guided training scenarios where the student model serves as its own guide.

Motivation

Recent embedding papers (e.g., Qwen3-Embedding, Diffusion-Pretrained Dense and Contextual Embeddings, a.k.a. pplx-embed) have shown that using the model's own similarity scores with a margin (e.g., margin=-1.0) provides effective self-guidance without requiring a separate guide model.

[Qwen3-Embedding]

$$ m_{ij} = \begin{cases} 0 & \text{if } s_{ij} > s(q_i, d_i^+) + 0.1 \text{ or } d_j = d_i^+, \\ 1 & \text{otherwise,} \end{cases} $$

[pplx-embed]

$$ m_i(\mathbf{x}) = \mathbb{1}_{\{s(\mathbf{q}_i, \mathbf{x}) \le s(\mathbf{q}_i, \mathbf{d}_i) + 0.1\}} $$

This approach:

  • Filters negatives whose student_score > positive_score - margin (a 0.1 tolerance in the papers above)
  • Eliminates the need for a separate teacher model
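To make the rule concrete, here is a minimal standalone PyTorch sketch of margin-based false-negative masking (illustrative only — `false_negative_mask` and its signature are invented for this example; it follows the student_score > positive_score - margin convention, so a negative margin adds tolerance):

```python
import torch

def false_negative_mask(sim: torch.Tensor, positive_idx: torch.Tensor, margin: float) -> torch.Tensor:
    """Mark in-batch candidates as false negatives when their similarity to the
    query exceeds the positive pair's similarity minus the margin.

    sim: (B, N) similarity matrix between queries and candidates.
    positive_idx: (B,) column index of each query's positive candidate.
    """
    pos_sim = sim.gather(1, positive_idx.unsqueeze(1))   # (B, 1): s(q_i, d_i^+)
    mask = sim > (pos_sim - margin)                      # True => treated as a false negative
    mask.scatter_(1, positive_idx.unsqueeze(1), False)   # never mask the positive itself
    return mask

sim = torch.tensor([[0.90, 0.85, 0.20],
                    [0.30, 0.80, 0.75]])
positive_idx = torch.tensor([0, 1])
# margin=0.1: threshold is pos_sim - 0.1, so near-positive candidates are masked
print(false_negative_mask(sim, positive_idx, margin=0.1).tolist())
# → [[False, True, False], [False, False, True]]
```

With a negative margin (e.g., -0.1), the threshold moves above the positive's score, so only candidates that look even more similar than the positive get masked.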

However, when using the model as its own guide, the previous implementation still required passing the same instance twice and performed two forward passes - one for student embeddings and one for guide embeddings. This is computationally wasteful since both would produce identical results.

Changes

  1. guide parameter is now optional: When guide=None (default), the model uses itself as the guide (self-guided mode)
  2. Skip redundant forward pass: When self-guided, reuse reps.detach() as guide_reps instead of calling the guide model again
  3. Skip retokenization check: When self-guided, must_retokenize is always False (same tokenizer)

Code Changes

# In __init__
if guide is None:
    self.guide = model
    self.is_self_guided = True
else:
    self.guide = guide
    self.is_self_guided = model is guide

# In embed_minibatch
if self.is_self_guided:
    guide_reps = reps.detach()  # Reuse student embeddings
else:
    guide_reps = self.guide(...)  # Separate forward pass
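As a quick illustration of the dispatch above, the optional-guide resolution can be isolated into a tiny standalone class (hypothetical; `TinyGuidedLoss` is not part of sentence-transformers):

```python
class TinyGuidedLoss:
    """Minimal sketch of the optional-guide logic from this PR (illustrative only)."""

    def __init__(self, model, guide=None):
        if guide is None:
            self.guide = model           # guide omitted => self-guided mode
            self.is_self_guided = True
        else:
            self.guide = guide
            # passing the same instance twice also counts as self-guided
            self.is_self_guided = model is guide

model = object()
print(TinyGuidedLoss(model).is_self_guided)                  # → True
print(TinyGuidedLoss(model, guide=model).is_self_guided)     # → True
print(TinyGuidedLoss(model, guide=object()).is_self_guided)  # → False
```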

Usage Example

from sentence_transformers import SentenceTransformer, losses

model = SentenceTransformer("BAAI/bge-base-en-v1.5")

# Self-guided mode (recommended): simply omit the guide parameter
loss = losses.CachedGISTEmbedLoss(
    model,
    margin=-1.0,  # Recommended for self-guided training
)

# Or with explicit separate guide model
guide = SentenceTransformer("all-MiniLM-L6-v2")
loss = losses.CachedGISTEmbedLoss(
    model,
    guide=guide,
    margin=0.1,
)

Benefits

| Scenario | Before | After |
|---|---|---|
| Self-guided (guide=None) | N/A (a guide was required) | 1 forward pass |
| Self-guided (guide=model) | 2 forward passes | 1 forward pass |
| Separate guide | 2 forward passes | 2 forward passes (unchanged) |

@yjoonjang
Contributor Author

I also thought about adding a self_guided: bool parameter when initializing CachedGISTEmbedLoss, but this approach seemed a bit simpler. It would be great if you could share some thoughts on this, @tomaarsen.

Thank you!

@tomaarsen
Member

Hello!

This is very cool! Have you been able to run some training tests with this? I imagine it might work pretty nicely compared to e.g. MultipleNegativesRankingLoss.

  • Tom Aarsen

@yjoonjang
Contributor Author

yjoonjang commented Feb 15, 2026

Hi @tomaarsen, thank you for your comment!

I ran some experiments comparing MNRL vs. GIST (self-guided, original implementation) vs. GIST (self-guided, current implementation).

You can find the details in my colab.
To ensure a fair comparison with MNRL, I set the contrast_anchors and contrast_positives parameters to False for the two GIST experiments.

While the performance gain is modest, the results look promising
(more accurate than MNRL and faster than original GIST):

| | MNRL | GIST (original) | GIST (current) |
|---|---|---|---|
| time | 1079.7s | 1923.4s | 1151.3s |
| test acc. | 0.8930 | 0.8982 | 0.8985 |

I wonder why the test results aren't exactly the same for the two GIST experiments.
Since they use the same guide model, they should be identical.

@tomaarsen
Member

Thanks for running and sharing your results, very nice! I imagine a difference between your experiment 2 (self-guided) and experiment 3 (guided with mpnet-base) is that in experiment 2 the guide model is continuously being updated. There is a bit of a risk with the self-guiding that the model learns to consider a lot of samples as false negatives so those can be ignored.

I'm glad to see that the self-guided saves a lot of training time though, that's very helpful.

What was the margin of -1.0 based on? I don't fully remember what margin did; perhaps the docs are a bit lacking there. If the loss uses cosine similarity, then the scores are between -1 and 1, but often between 0 and 1, and then a margin of -1.0 seems like a ton.

  • Tom Aarsen

@yjoonjang
Contributor Author

yjoonjang commented Feb 17, 2026

Ah thanks for the clarification. I forgot that the guide model is being updated 🤣

Answer to your question - how does the margin work?

The filtering condition in the code is:

# absolute strategy
mask = guided_sim_mat > (guided_sim - margin)

Where guided_sim is $s(q, d^+)$ (positive pair similarity from the guide).
So a negative $d^-$ is masked out (treated as false negative) when:

$$s(q, d^-) > s(q, d^+) - \text{margin}$$

Why negative margin for self-guided?

In the original GISTEmbed, the guide is a strong, frozen teacher. Its similarity scores are reliable, so we trust it to identify false negatives (margin=0.0~0.1).

In self-guided mode, the guide is the model itself, which is still being trained and initially unreliable. If we use margin=0.0, the model's noisy similarity estimates could aggressively filter out valid negatives as "false negatives."
As you pointed out, the model might learn to consider a lot of samples as false negatives, so we flip the margin direction to be more permissive.

Corrected experiment

I realized I made a mistake in my previous experiment. I was using margin=-1.0, but this effectively disables filtering entirely: the threshold becomes $s(q, d^+) + 1.0$, and since cosine similarity between text embeddings typically lies in $[0, 1]$, this threshold (in the range $[1.0, 2.0]$) is unreachable. In other words, no negatives were being filtered at all.
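The unreachable threshold is easy to check numerically with a throwaway sketch of the absolute-strategy condition (`is_filtered` is purely illustrative):

```python
def is_filtered(sim, pos_sim, margin):
    # absolute strategy from the loss: drop a negative when sim > pos_sim - margin
    return sim > pos_sim - margin

# margin=-1.0: threshold = pos_sim + 1.0 >= 1.0, so no similarity in [0, 1] can cross it
assert not any(is_filtered(s, 0.9, -1.0) for s in (0.0, 0.5, 0.99, 1.0))

# margin=-0.1: threshold = pos_sim + 0.1; only near-duplicates of the positive are dropped
assert is_filtered(0.95, 0.8, -0.1)
assert not is_filtered(0.85, 0.8, -0.1)
```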

The performance gain I observed in the previous experiment was likely due to the difference in default temperature between the two losses (MNRL: 0.05, GISTEmbed: 0.01), not from the self-guided filtering itself.

I re-ran the experiment with margin=-0.1, which actually should perform meaningful filtering. The results are available in this Colab notebook:

| | MNRL | GIST (original) | GIST (self-guided, margin=-0.1) |
|---|---|---|---|
| Training time | 1146.6s | 2118.0s | 1158.4s |
| Test accuracy | 0.8930 | 0.8979 | 0.8979 |

The self-guided mode matches the original GISTEmbed accuracy while being ~1.8x faster, nearly matching MNRL's training speed. This speedup is expected since self-guided mode eliminates the second forward pass through the guide model.

However, I should note that the accuracy improvement over MNRL is likely still driven by the temperature difference rather than the self-guided filtering itself. The results are very similar to the previous experiment with margin=-1.0 (where no filtering occurred), which suggests that at this stage, the student model is too weak to produce meaningful similarity estimates for effective false negative filtering. Since the model is still early in training, its similarity scores are essentially noise. So even with margin=-0.1, the filtering is not yet contributing a useful signal.

We expect self-guided filtering to become more effective when starting from a stronger base model or later in training, where the model's similarity estimates are more reliable. A more comprehensive evaluation at larger scale with a stronger student would be needed to isolate the true benefit of self-guided filtering from the temperature effect.

Precedent from recent work

The idea of using the model's own similarity scores for in-batch false negative filtering with a tolerance margin is well-established in recent embedding papers:

  • Qwen3 Embedding (arXiv:2506.05176): Uses a self-similarity mask $m_{ij} = 0 \ \text{if} \ s_{ij} > s(q_i, d_i^+) + 0.1$, where the model's own similarity scores are used for filtering with a $+0.1$ tolerance margin.
  • pplx-embed (arXiv:2602.11151): Uses $m_i(x) = \mathbb{1}[s(q_i, x) \leq s(q_i, d_i) + 0.1]$, again self-similarity based filtering with $+0.1$ tolerance.
  • Contextual Document Embeddings (CDE) (arXiv:2410.02525): Uses epsilon-based false negative filtering — $S(q,d) = \{d' \in D \mid f(q,d') \geq f(q,d) + \epsilon\}$.

All three confirm that when using the model's own scores for filtering, a positive tolerance (relaxed threshold) is necessary to avoid over-filtering. Our self-guided mode with margin=-0.1 follows this same principle, translating to a $+0.1$ tolerance threshold, consistent with the values used in these works.

@yjoonjang
Contributor Author

For context, we had a related discussion in #3665 about the broader direction of online vs. offline distillation/guidance.

Here's a quick summary of the relevant points for this PR:

  • The re-tokenization approach used by online guidance (decode student → retokenize for guide) is considered fundamentally flawed, especially with the upcoming tokenization changes in [feat] Introduce cross-modality and multi-modality support; modularize CrossEncoder class #3554 and the prompt requirements of modern models.
  • Self-guide mode is unaffected by this concern, since guide = model means no re-tokenization and no prompt mismatch.
  • @tomaarsen suggested expanding self-guide beyond GIST/CGIST to also cover MNRL and CMNRL.
  • The main open question is warmup: early in training, the model's similarity scores are arbitrary, so self-guidance could produce many false negatives. The simplest approach would be to disable false negative detection before a certain step (e.g., 20% of training), with more sophisticated scheduling as a future improvement.
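The warmup idea in the last bullet could be sketched roughly as follows (hypothetical scheduling logic only; the class name and methods are invented, not the library's API):

```python
class SelfGuideWarmup:
    """Disable false-negative filtering for the first warmup_ratio of training."""

    def __init__(self, total_steps: int, warmup_ratio: float = 0.2):
        self.warmup_steps = int(total_steps * warmup_ratio)
        self.step = 0

    def filtering_active(self) -> bool:
        # Filtering only kicks in once warmup is over
        return self.step >= self.warmup_steps

    def on_step_end(self) -> None:
        self.step += 1

# e.g., 94 total steps with warmup_ratio=0.2 => filtering disabled for the first 18 steps
sched = SelfGuideWarmup(total_steps=94, warmup_ratio=0.2)
print(sched.warmup_steps)  # → 18
```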

Linking here so the context is easy to find for anyone following this PR.

@yjoonjang
Contributor Author

I'll run some tests with the new implementation.

@tomaarsen
Member

Thanks! My guess is that it doesn't work great with a base model (bert-base-uncased, mpnet-base, ModernBERT-base), but does work nicely with a model that's already an embedding model.

  • Tom Aarsen

@yjoonjang
Contributor Author

yjoonjang commented Mar 3, 2026

I've run some experiments.

Experiments

Setup

  • Models:
    • Base MLM: microsoft/mpnet-base, answerdotai/ModernBERT-base
    • Pretrained: Alibaba-NLP/gte-base-en-v1.5, intfloat/e5-small (I made sure to add the prefixes for e5)
  • Dataset: tomaarsen/natural-questions-hard-negatives (triplet-5)
  • Loss: CachedMultipleNegativesRankingLoss with gather_across_devices=True
  • Training: 1 epoch, 4 GPUs (DDP), batch_size=256/GPU, lr=2e-5
  • Evaluation: Full NanoBEIR (13 subsets), aggregated cosine metrics
  • Self-guide config: margin=-0.1, margin_strategy=absolute
  • Comparison: No self-guide(CMNRL) vs. self-guide(warmup=0) vs. self-guide(warmup=0.2)

Results (NanoBEIR-full)

wandb: [link]

| Model | Config | NDCG@10 | MRR@10 | MAP@100 |
|---|---|---|---|---|
| mpnet-base | No self-guide (baseline) | 0.4971 | 0.5461 | 0.4184 |
| mpnet-base | Self-guide (warmup=0.2) | 0.4962 | 0.5455 | 0.4168 |
| mpnet-base | Self-guide (warmup=0) | 0.4955 | 0.5463 | 0.4175 |
| ModernBERT-base | No self-guide (baseline) | 0.4380 | 0.4867 | 0.3661 |
| ModernBERT-base | Self-guide (warmup=0.2) | 0.4305 | 0.4768 | 0.3616 |
| ModernBERT-base | Self-guide (warmup=0) | 0.4364 | 0.4847 | 0.3646 |
| gte-small | No self-guide (baseline) | 0.6209 | 0.6922 | 0.5367 |
| gte-small | Self-guide (warmup=0.2) | 0.6198 | 0.6912 | 0.5366 |
| gte-small | Self-guide (warmup=0) | 0.6197 | 0.6912 | 0.5365 |
| e5-small | No self-guide (baseline) | 0.5517 | 0.6184 | 0.4682 |
| e5-small | Self-guide (warmup=0.2) | 0.5513 | 0.6183 | 0.4680 |
| e5-small | Self-guide (warmup=0) | 0.5510 | 0.6176 | 0.4675 |

I couldn't find a solid trend here, but it looks like:
no self-guide > self-guide (warmup=0.2) > self-guide (warmup=0.0)

Warmup Mechanism Verification

To verify the warmup step, I've added some logs (just for this experiment) to the callback.

With warmup_ratio=0.2 (total 94 steps, warmup for first 18 steps):

[SelfGuide] Warmup: filtering disabled for the first 18/94 steps (warmup_ratio=0.2)
[SelfGuide] step=0  | active=0 | no filtering stats
[SelfGuide] step=1  | active=0 | no filtering stats
[SelfGuide] step=10 | active=0 | no filtering stats
[SelfGuide] Filtering ACTIVATED at step 18/94 (warmup_steps=18)
[SelfGuide] step=19 | active=1 | masked_per_query=2.92 | pos_sim=0.6475 | fn_sim=0.5188
[SelfGuide] step=20 | active=1 | masked_per_query=4.23 | pos_sim=0.6465 | fn_sim=0.4565
...
[SelfGuide] step=90 | active=1 | masked_per_query=0.12 | pos_sim=0.6647 | fn_sim=0.5776

  • Steps 0-17: active=0, no filtering applied, no filtering stats reported.
  • Step 18: Callback activates filtering (int(94 * 0.2) = 18).
  • Steps 18+: active=1, filtering stats (masked_per_query , pos_sim, fn_sim) were properly reported.

With warmup_ratio=0 (no warmup):

[SelfGuide] Filtering active from step 0 (no warmup)
[SelfGuide] step=0 | active=1 | no filtering stats
[SelfGuide] step=1 | active=1 | masked_per_query=84.63 | pos_sim=0.8572 | fn_sim=0.795

  • Filtering is active from the very first step; filtering stats are reported starting from step 1.

In conclusion,

  • The warmup mechanism works properly as configured: warmup_ratio=0.2 disables filtering for the first 20% of training, while warmup_ratio=0 enables filtering immediately.
  • On NanoBEIR-full (13 subsets), self-guide shows minimal to slightly negative impact across all models.
  • The filtering statistics reveal that with margin=-0.1, very few candidates are actually filtered. This is expected since the threshold becomes pos_sim + 0.1, which most negatives won't exceed.
  • More aggressive margins (e.g., margin=0.0 or positive values) or datasets with higher false-negative rates might show a stronger effect, but the papers (Qwen3-Embedding, pplx-embed) use the equivalent of -0.1.

@yjoonjang
Contributor Author

Should I run some extra experiments (e.g. margin=-0.05), @tomaarsen?
The self-guide mode doesn't seem promising right now.

@tomaarsen
Member

I think that would be useful. GIST works pretty nicely, and this is similar, so I imagine there's some settings for which it works.

  • Tom Aarsen

@yjoonjang
Contributor Author

Happy to share some good news.
I ran an extended sweep over margin values with margin_strategy=absolute on mpnet-base.
The results are:

| Self-Guide Margin | warmup_ratio | NDCG@10 |
|---|---|---|
| X (CMNRL baseline) | - | 0.4971 |
| -0.05 | 0 | 0.4945 |
| -0.05 | 0.2 | 0.4942 |
| -0.1 | 0 | 0.4955 |
| -0.1 | 0.2 | 0.4962 |
| -0.2 | 0 | 0.4968 |
| -0.2 | 0.2 | 0.4957 |
| -0.25 | 0 | 0.4991 |
| -0.25 | 0.2 | 0.4971 |
| -0.3 | 0 | 0.496 |
| -0.3 | 0.2 | 0.4962 |
  • It's essentially hyperparameter tuning, but margin=-0.25, warmup_ratio=0 achieves 0.4991 NDCG@10, outperforming the baseline (0.4971).
  • I still don't see a clear trend across warmup_ratio (0.0 vs. 0.2).

@yjoonjang
Contributor Author

Hi @tomaarsen, could you please take a look at the results when you have a moment?
Thank you!

@tomaarsen
Member

Thanks for the detailed benchmarks. It's nice to see that there are more gains when the margin goes further into the negatives, although I'm not sure how many samples are really being filtered away at some point. Beyond that, the -0.25 margin being the strongest, while -0.2 and -0.3 are both worse than baseline, makes me think that perhaps it's a bit arbitrary.
Apologies for asking, but could you also try to run the same experiment on an already trained embedding model? I think a margin like -0.1 or -0.2 on an already trained model could be a solid combination. After all, I trust those scores more, even at the start of training.

  • Tom Aarsen

@yjoonjang
Contributor Author

Following your recommendation, I ran a margin sweep on gte-multilingual-base fine-tuned on English data.

Setup

  • Model: Alibaba-NLP/gte-multilingual-base (pretrained multilingual retrieval model)
  • Dataset: tomaarsen/natural-questions-hard-negatives (triplet-5)
  • Loss: CachedMultipleNegativesRankingLoss with gather_across_devices=False
  • Training: 1 epoch, 4 GPUs (DDP), batch_size=512/GPU, lr=2e-5
  • Final Evaluation: Full NanoBEIR (13 subsets)
  • Self-guide config: margin_strategy=absolute, varying margin and warmup_ratio
Training script
#!/bin/bash
GPUS="${GPUS:-4,5,6,7}"
NUM_GPUS=$(echo "$GPUS" | tr ',' '\n' | wc -l)

MODEL_NAME="Alibaba-NLP/gte-multilingual-base"
SHORT_MODEL_NAME="gte-multilingual-base"

mkdir -p $SHORT_MODEL_NAME-logs

# 1. No self-guide (baseline CachedMNRL)
echo "=== [1/9] No self-guide (baseline) ==="
CUDA_VISIBLE_DEVICES=$GPUS torchrun \
    --nproc_per_node=$NUM_GPUS --master_port 29502 \
    train.py \
    --model_name $MODEL_NAME \
    --query_prefix "" \
    --doc_prefix "" \
    --per_device_train_batch_size 512 \
    --mini_batch_size 64 \
    > $SHORT_MODEL_NAME-logs/no_sg.log 2>&1

# 2. Self-guide
RUN_IDX=2
for MARGIN in -0.1 -0.2 -0.25 -0.3; do
    for WARMUP in 0.2 0; do
        if [ "$WARMUP" == "0.2" ]; then
            WARMUP_LABEL="w/ warmup"
        else
            WARMUP_LABEL="w/o warmup"
        fi
        echo "=== [$RUN_IDX/9] margin=${MARGIN}, ${WARMUP_LABEL} ==="

        CUDA_VISIBLE_DEVICES=$GPUS torchrun \
            --nproc_per_node=$NUM_GPUS --master_port 29502 \
            train.py \
            --model_name $MODEL_NAME \
            --self_guide \
            --self_guide_warmup_ratio $WARMUP \
            --self_guide_margin=$MARGIN \
            --self_guide_margin_strategy absolute \
            --query_prefix "" \
            --doc_prefix "" \
            --per_device_train_batch_size 512 \
            --mini_batch_size 64 \
            > $SHORT_MODEL_NAME-logs/sg_warmup_${WARMUP}_margin_${MARGIN}_absolute.log 2>&1

        RUN_IDX=$((RUN_IDX + 1))
    done
done

echo "=== All 9 runs completed ==="
Training Python file
import argparse
import logging
import os

from datasets import load_dataset

from sentence_transformers import (
	SentenceTransformer,
	SentenceTransformerTrainer,
	SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
logger = logging.getLogger(__name__)


def parse_args():
	parser = argparse.ArgumentParser()
	parser.add_argument("--model_name", type=str, default="microsoft/mpnet-base")

	# Self-guide arguments
	parser.add_argument("--self_guide", action="store_true", help="Enable self-guided false-negative filtering")
	parser.add_argument("--self_guide_margin", type=float, default=0.0, help="Margin for false-negative detection threshold")
	parser.add_argument("--self_guide_margin_strategy", type=str, default="absolute", choices=["absolute", "relative"], help="Strategy for applying the margin")
	parser.add_argument("--self_guide_warmup_ratio", type=float, default=0.0, help="Fraction of training steps to disable filtering (warmup)")

	# Hardness weighting arguments (can be combined with self-guide)
	parser.add_argument("--hardness_mode", type=str, default="none", choices=["none", "in_batch_negatives", "hard_negatives", "all_negatives"])
	parser.add_argument("--hardness_strength", type=float, default=0.0)

	# Prefix arguments (e.g., for e5 models: --query_prefix "query: " --doc_prefix "passage: ")
	parser.add_argument("--query_prefix", type=str, default="", help="Prefix prepended to anchor/query texts")
	parser.add_argument("--doc_prefix", type=str, default="", help="Prefix prepended to positive/negative texts")

	# Training arguments
	parser.add_argument("--per_device_train_batch_size", type=int, default=256)
	parser.add_argument("--mini_batch_size", type=int, default=32)
	parser.add_argument("--num_train_epochs", type=int, default=1)
	parser.add_argument("--learning_rate", type=float, default=2e-5)
	parser.add_argument("--output_dir", type=str, default="models/self_guide")
	parser.add_argument("--wandb_project", type=str, default="Self_Guide")
	return parser.parse_args()


def main():
	args = parse_args()

	hardness_mode = None if args.hardness_mode == "none" else args.hardness_mode
	short_model_name = args.model_name.split("/")[-1]

	# Build run name
	parts = [short_model_name]
	if args.self_guide:
		sg_name = "sg"
		if args.self_guide_margin != 0.0:
			sg_name += f"_m{args.self_guide_margin}_{args.self_guide_margin_strategy}"
		sg_name += f"_w{args.self_guide_warmup_ratio}"
		parts.append(sg_name)
	if hardness_mode is not None:
		parts.append(f"{args.hardness_mode}_s{args.hardness_strength}")
	if not parts:
		parts.append("baseline")
	run_name = "_".join(parts)

	logger.info(
		f"Run: {run_name} | self_guide={args.self_guide}, margin={args.self_guide_margin}, "
		f"margin_strategy={args.self_guide_margin_strategy}, warmup_ratio={args.self_guide_warmup_ratio}, "
		f"hardness_mode={hardness_mode}, hardness_strength={args.hardness_strength}"
	)

	# 1. Load model
	model = SentenceTransformer(args.model_name, trust_remote_code=True)
	model.max_seq_length = 512

	# 2. Load dataset
	dataset = load_dataset("tomaarsen/natural-questions-hard-negatives", "triplet-5", split="train")
	dataset = dataset.rename_columns({"query": "anchor", "answer": "positive"})

	# Apply query/doc prefixes if specified
	if args.query_prefix or args.doc_prefix:
		def add_prefixes(example):
			for key in example:
				if key == "anchor":
					example[key] = args.query_prefix + example[key]
				else:
					example[key] = args.doc_prefix + example[key]
			return example
		dataset = dataset.map(add_prefixes)
		logger.info(f"Applied prefixes: query='{args.query_prefix}', doc='{args.doc_prefix}'")

	logger.info(f"Dataset size: {len(dataset)}")

	# 3. Define loss
	loss = CachedMultipleNegativesRankingLoss(
		model,
		mini_batch_size=args.mini_batch_size,
		gather_across_devices=False,
		hardness_mode=hardness_mode,
		hardness_strength=args.hardness_strength,
		self_guide=args.self_guide,
		self_guide_margin=args.self_guide_margin,
		self_guide_margin_strategy=args.self_guide_margin_strategy,
		self_guide_warmup_ratio=args.self_guide_warmup_ratio,
	)

	# 4. Evaluator (subset of NanoBEIR for mid-training eval)
	evaluator_kwargs = {}
	if args.query_prefix:
		evaluator_kwargs["query_prompts"] = args.query_prefix
	if args.doc_prefix:
		evaluator_kwargs["corpus_prompts"] = args.doc_prefix
	dev_evaluator = NanoBEIREvaluator(
		dataset_names=["msmarco", "nq", "nfcorpus", "quoraretrieval"],
		show_progress_bar=True,
		batch_size=64,
		**evaluator_kwargs,
	)

	# 5. Training arguments
	os.environ["WANDB_PROJECT"] = args.wandb_project
	training_args = SentenceTransformerTrainingArguments(
		output_dir=f"{args.output_dir}/{run_name}",
		num_train_epochs=args.num_train_epochs,
		per_device_train_batch_size=args.per_device_train_batch_size,
		learning_rate=args.learning_rate,
		warmup_ratio=0.1,
		bf16=True,
		eval_on_start=True,
		eval_strategy="steps",
		eval_steps=0.2,
		logging_steps=10,
		logging_first_step=True,
		save_strategy="no",
		report_to="wandb",
		run_name=run_name,
		ddp_find_unused_parameters=True,
	)

	# 6. Train
	trainer = SentenceTransformerTrainer(
		model=model,
		args=training_args,
		train_dataset=dataset,
		loss=loss,
		evaluator=dev_evaluator,
	)
	trainer.train()

	# 7. Full NanoBEIR evaluation (all 13 subsets)
	logger.info("Running full NanoBEIR evaluation on all subsets...")
	full_evaluator = NanoBEIREvaluator(
		show_progress_bar=True,
		batch_size=64,
		**evaluator_kwargs,
	)
	full_evaluator(model)

	# 8. Save
	model.save_pretrained(f"{args.output_dir}/{run_name}/final")
	logger.info(f"Done: {run_name}")


if __name__ == "__main__":
	main()

Results

| Self-Guide Margin | warmup_ratio | NDCG@10 | Recall@10 | MRR@10 |
|---|---|---|---|---|
| X (CMNRL baseline) | - | 0.6111 | 0.6398 | 0.6912 |
| -0.1 | 0 | 0.6116 | 0.6394 | 0.6928 |
| -0.1 | 0.2 | 0.6113 | 0.6400 | 0.6923 |
| -0.2 | 0 | 0.6113 | 0.6395 | 0.6911 |
| -0.2 | 0.2 | 0.6111 | 0.6399 | 0.6917 |
| -0.25 | 0 | 0.6112 | 0.6399 | 0.6916 |
| -0.25 | 0.2 | 0.6122 | 0.6403 | 0.6928 |
| -0.3 | 0 | 0.6116 | 0.6407 | 0.6914 |
| -0.3 | 0.2 | 0.6118 | 0.6405 | 0.6932 |

The best result is margin=-0.25, warmup=0.2 at 0.6122, slightly better than the baseline (0.6111), though again not by a large gap.

I hope these results help.

@yjoonjang
Contributor Author

Hi @tomaarsen, apologies for tagging you a bunch of times.
Could you please take a look at the results?
Thank you!

@tomaarsen
Member

Thank you! I'm surprised at how small the gaps are. Very interesting. I think perhaps I'll have to do more research on this. However, I want to try and push #3554 first as there's a lot of demand there.

  • Tom Aarsen
