diff --git a/implicit/cpu/matrix_factorization_base.py b/implicit/cpu/matrix_factorization_base.py index dcd37fa..5579d7e 100644 --- a/implicit/cpu/matrix_factorization_base.py +++ b/implicit/cpu/matrix_factorization_base.py @@ -4,7 +4,7 @@ import numpy as np from scipy.sparse import csr_matrix, lil_matrix -from ..recommender_base import ModelFitError, RecommenderBase +from ..recommender_base import RecommenderBase from .topk import topk @@ -246,10 +246,7 @@ def item_norms(self): return self._item_norms def _check_fit_errors(self): - is_nan = np.any(np.isnan(self.user_factors), axis=None) - is_nan |= np.any(np.isnan(self.item_factors), axis=None) - if is_nan: - raise ModelFitError("NaN encountered in factors") + self._check_factors(self.user_factors, self.item_factors) def _filter_items_from_sparse_matrix(items, query_items): diff --git a/implicit/gpu/bpr.py b/implicit/gpu/bpr.py index cffc471..cf33176 100644 --- a/implicit/gpu/bpr.py +++ b/implicit/gpu/bpr.py @@ -157,6 +157,8 @@ def fit(self, user_items, show_progress=True, callback=None): if callback: callback(_epoch, time.time() - s, correct, skipped) + self._check_fit_errors() + def to_cpu(self) -> implicit.cpu.bpr.BayesianPersonalizedRanking: """Converts this model to an equivalent version running on the cpu""" ret = implicit.cpu.bpr.BayesianPersonalizedRanking( diff --git a/implicit/gpu/matrix_factorization_base.py b/implicit/gpu/matrix_factorization_base.py index 4b5933c..6b1d5a5 100644 --- a/implicit/gpu/matrix_factorization_base.py +++ b/implicit/gpu/matrix_factorization_base.py @@ -201,6 +201,9 @@ def similar_items( similar_items.__doc__ = RecommenderBase.similar_items.__doc__ + def _check_fit_errors(self): + self._check_factors(self.user_factors.to_numpy(), self.item_factors.to_numpy()) + def recalculate_user(self, userid, user_items): raise NotImplementedError("recalculate_user is not supported with this model") diff --git a/implicit/recommender_base.py b/implicit/recommender_base.py index 25c7c0a..76a2d5b 100644 --- a/implicit/recommender_base.py +++ b/implicit/recommender_base.py @@ -213,3 +213,10 @@ def rank_items(self, userid, user_items, selected_items, recalculate_user=False) items=selected_items, filter_already_liked_items=False, ) + + @staticmethod + def _check_factors(user_factors, item_factors): + is_nan = np.any(np.isnan(user_factors), axis=None) + is_nan |= np.any(np.isnan(item_factors), axis=None) + if is_nan: + raise ModelFitError("NaN encountered in factors")