Skip to content

Commit 80a5683

Browse files
committed
FIX: ensure compatibility between sklearn and spark tfidf vectors for skl>=1.5
Adapt to the TfidfTransformer change in sklearn v1.5 to keep the pandas and Spark versions of emm compatible. From sklearn v1.5 onwards, TfidfTransformer no longer has the _idf_diag attribute, which was previously overwritten to make the pandas tf-idf vectors match the Spark ones; the idf_ values are now adjusted directly instead.
1 parent 338b410 commit 80a5683

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

emm/indexing/pandas_normalized_tfidf.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,25 @@ def fit(self, X: pd.Series | pd.DataFrame) -> TfidfVectorizer:
7676
super().fit(X)
7777

7878
timer.label("normalize")
79-
idf_diag = self._tfidf._idf_diag
79+
n_features = self.idf_.shape[0]
80+
81+
# 1. this max_idf_square value is used in normalization step for simulating out-of-vocabulary tokens
82+
idf_diag = scipy.sparse.diags(
83+
self.idf_, offsets=0, shape=(n_features, n_features), format="csr", dtype=self.dtype
84+
)
8085
idf_diag = idf_diag - scipy.sparse.diags(np.ones(idf_diag.shape[0]), shape=idf_diag.shape, dtype=self.dtype)
81-
self._tfidf._idf_diag = idf_diag
82-
assert self._tfidf._idf_diag.dtype == self.dtype
83-
# this value is used in normalization step for simulating out-of-vocabulary tokens
8486
self.max_idf_square = idf_diag.max() ** 2
8587

88+
# 2. ensure compatibility between sklearn and spark tfidf vectors
89+
if hasattr(self._tfidf, "_idf_diag"):
90+
# sklearn < 1.5
91+
self._tfidf._idf_diag = idf_diag
92+
assert self._tfidf._idf_diag.dtype == self.dtype
93+
else:
94+
# sklearn >= 1.5
95+
self.idf_ = self.idf_ - np.ones(n_features, dtype=self.dtype)
96+
assert self.idf_.dtype == self.dtype
97+
8698
timer.log_params({"n": len(X), "n_features": idf_diag.shape[0]})
8799

88100
return self

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ dependencies = [
1818
# Fix for error ValueError: numpy.ndarray size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject.
1919
"numpy>=1.20.1",
2020
"scipy",
21-
"scikit-learn<1.5.0",
21+
"scikit-learn>=1.0.0",
2222
"pandas>=1.1.0,!=1.5.0",
2323
"jinja2", # for pandas https://pandas.pydata.org/docs/getting_started/install.html#visualization
2424
"rapidfuzz<3.0.0",

0 commit comments

Comments
 (0)