Skip to content

Commit edffd4a

Browse files
authored
Merge pull request #1359 from bact/fix-keybert-type
Fix type hints in keybert
2 parents 08657f0 + 2557659 commit edffd4a

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -429,8 +429,10 @@ python_version = "3.9"
429429
show_column_numbers = true
430430
show_error_code_links = true
431431
show_error_context = true
432-
strict_optional = true
432+
strict_bytes = true
433433
strict_equality = true
434+
strict_equality_for_none = true
435+
strict_optional = true
434436
warn_no_return = true
435437
warn_redundant_casts = true
436438
warn_return_any = true

pythainlp/summarize/keybert.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -225,25 +225,31 @@ def _rank_keywords(
225225
) -> list[tuple[str, float]]:
226226
import numpy as np
227227

228-
def l2_norm(v: np.ndarray) -> np.ndarray:
228+
def l2_norm(v: "NDArray[np.float32]") -> "NDArray[np.float32]":
229229
vec_size = v.shape[1]
230230
result = np.divide(
231231
v,
232232
np.linalg.norm(v, axis=1).reshape(-1, 1).repeat(vec_size, axis=1),
233+
dtype=np.float32,
233234
)
234235
if not np.isclose(np.linalg.norm(result, axis=1), 1).all():
235236
raise ValueError("Cannot normalize a vector to unit vector.")
236-
return result
237+
return cast("NDArray[np.float32]", result)
237238

238-
def cosine_sim(a: np.ndarray, b: np.ndarray) -> np.ndarray:
239-
return (np.matmul(a, b.T).T).sum(axis=1)
239+
def cosine_sim(
240+
a: "NDArray[np.float32]", b: "NDArray[np.float32]"
241+
) -> "NDArray[np.float32]":
242+
# `a` has one row (document embedding), so flatten to get 1-D scores.
243+
scores = np.matmul(a, b.T).reshape(-1)
244+
return cast("NDArray[np.float32]", scores.astype(np.float32, copy=False))
240245

241246
doc_vector = l2_norm(doc_vector)
242247
word_vectors = l2_norm(word_vectors)
243248
cosine_sims = cosine_sim(doc_vector, word_vectors)
244249
ranking_desc = np.argsort(-cosine_sims)
245250

251+
top_indices = cast("list[int]", ranking_desc[:max_keywords].tolist())
246252
final_ranks = [
247-
(keywords[r], cosine_sims[r]) for r in ranking_desc[:max_keywords]
253+
(keywords[idx], float(cosine_sims[idx])) for idx in top_indices
248254
]
249255
return final_ranks

0 commit comments

Comments
 (0)