fix: use caller's class name as default model_type for subclasses #3692

Open
AtharvaJaiswal005 wants to merge 1 commit into huggingface:main from
AtharvaJaiswal005:fix/subclass-model-type-default
Summary

  • Use self.__class__.__name__ instead of hardcoded "SentenceTransformer" as the default model_type in _get_model_type()
  • Fixes subclasses of SentenceTransformer loading older models (those without model_type in their config) with the wrong pooling defaults

Root Cause

_get_model_type() hardcoded the fallback to "SentenceTransformer" in two places:

  1. When no config_sentence_transformers.json exists
  2. When the config exists but has no model_type key

Meanwhile, _model_config["model_type"] is set to self.__class__.__name__ (e.g. "MySentenceTransformer"). Because of this mismatch, the comparison at line 325 failed for subclasses and loading fell through to _load_auto_model instead of _load_sbert_model, which applied the wrong pooling defaults (mean pooling instead of CLS token).
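
The mismatch can be illustrated with a small standalone sketch. This is a toy model of the behavior described above, not the library's actual code: ToyTransformer, MyToyTransformer, _pick_loader, and the use_class_name_default flag are hypothetical names introduced for illustration, and the loader is represented only by its name as a string.

```python
# Toy model of the default-mismatch described in this PR.
# Hypothetical classes; this does NOT import or reproduce sentence_transformers.

class ToyTransformer:
    def __init__(self, saved_config, use_class_name_default):
        # use_class_name_default=False mimics the old hardcoded fallback,
        # True mimics the fallback proposed in this PR.
        self.use_class_name_default = use_class_name_default
        # As stated in the PR description, the in-memory config records
        # the runtime class name.
        self._model_config = {"model_type": self.__class__.__name__}
        self.loader = self._pick_loader(saved_config)

    def _get_model_type(self, saved_config):
        if self.use_class_name_default:
            default = self.__class__.__name__
        else:
            default = "SentenceTransformer"
        # Fallback path 1: no config file at all.
        if saved_config is None:
            return default
        # Fallback path 2: config exists but has no model_type key.
        return saved_config.get("model_type", default)

    def _pick_loader(self, saved_config):
        # Stand-in for the comparison the PR refers to: a match selects
        # the sentence-transformers loader, a mismatch falls through.
        if self._get_model_type(saved_config) == self._model_config["model_type"]:
            return "_load_sbert_model"
        return "_load_auto_model"


class MyToyTransformer(ToyTransformer):
    pass


old_config = {}  # older model: config present, no model_type key

before = MyToyTransformer(old_config, use_class_name_default=False)
after = MyToyTransformer(old_config, use_class_name_default=True)
print(before.loader)  # falls through to the auto loader
print(after.loader)   # picks the sentence-transformers loader
```

In this sketch the subclass only reaches the intended loader when the fallback follows the runtime class name, which is the change proposed here for both fallback paths.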

Reproduction

from sentence_transformers import SentenceTransformer

class MySentenceTransformer(SentenceTransformer): ...

model = MySentenceTransformer("BAAI/bge-small-en-v1.5")
# Before fix: "No sentence-transformers model found... Creating a new one with mean pooling."
# After fix: loads correctly with CLS token pooling

Changes

  • sentence_transformers/SentenceTransformer.py: Changed default from "SentenceTransformer" to self.__class__.__name__ in both fallback paths of _get_model_type()

Fixes #3536

When subclassing SentenceTransformer, _get_model_type() hardcoded the
default to "SentenceTransformer" which mismatched with _model_config
model_type (set to self.__class__.__name__). This caused subclassed
models loading older models without model_type in their config to
incorrectly use _load_auto_model instead of _load_sbert_model, resulting
in wrong pooling defaults (mean pooling instead of CLS token).

Fixes huggingface#3536
Development

Successfully merging this pull request may close these issues.

Wrong defaults used when loading older non-mean-pooled models via subclass