Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions umap/umap_.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,23 @@
optimize_layout_inverse,
)

from pynndescent import NNDescent
from pynndescent.distances import named_distances as pynn_named_distances
from pynndescent.sparse import sparse_named_distances as pynn_sparse_named_distances
NNDescent = None
pynn_named_distances = None
pynn_sparse_named_distances = None


def _lazy_import_pynndescent():
global NNDescent, pynn_named_distances, pynn_sparse_named_distances
if NNDescent is None:
from pynndescent import NNDescent
if pynn_named_distances is None:
from pynndescent.distances import named_distances as pynn_named_distances
if pynn_sparse_named_distances is None:
from pynndescent.sparse import (
sparse_named_distances as pynn_sparse_named_distances,
)
return NNDescent, pynn_named_distances, pynn_sparse_named_distances


locale.setlocale(locale.LC_NUMERIC, "C")

Expand Down Expand Up @@ -327,6 +341,7 @@ def nearest_neighbors(
n_trees = min(64, 5 + int(round((X.shape[0]) ** 0.5 / 20.0)))
n_iters = max(5, int(round(np.log2(X.shape[0]))))

NNDescent, _, _ = _lazy_import_pynndescent()
knn_search_index = NNDescent(
X,
n_neighbors=n_neighbors,
Expand Down Expand Up @@ -1889,7 +1904,13 @@ def _dist_only(x, y, *kwds):
"inverse_transform will be unavailable".format(self.metric)
)
self._inverse_distance_func = None
elif self.metric in pynn_named_distances:
else:
_, pynn_named_distances, pynn_sparse_named_distances = (
_lazy_import_pynndescent()
)
if self.metric not in pynn_named_distances:
raise ValueError("metric is neither callable nor a recognised string")

if self._sparse_data:
if self.metric in pynn_sparse_named_distances:
self._input_distance_func = pynn_sparse_named_distances[self.metric]
Expand All @@ -1905,8 +1926,7 @@ def _dist_only(x, y, *kwds):
"inverse_transform will be unavailable".format(self.metric)
)
self._inverse_distance_func = None
else:
raise ValueError("metric is neither callable nor a recognised string")

# set output distance metric
if callable(self.output_metric):
out_returns_grad = self._check_custom_metric(
Expand Down Expand Up @@ -2017,6 +2037,7 @@ def _dist_only(x, y, *kwds):
" must be numpy arrays of the same size."
)
# #848: warn but proceed if no search index is present
NNDescent, _, _ = _lazy_import_pynndescent()
if not isinstance(self.knn_search_index, NNDescent):
warn(
"precomputed_knn[2] (knn_search_index) "
Expand Down Expand Up @@ -2621,6 +2642,9 @@ def fit(self, X, y=None, ensure_all_finite=True, **kwargs):
# Standard case
self._small_data = False
# Standard case
_, pynn_named_distances, pynn_sparse_named_distances = (
_lazy_import_pynndescent()
)
if self._sparse_data and self.metric in pynn_sparse_named_distances:
nn_metric = self.metric
elif not self._sparse_data and self.metric in pynn_named_distances:
Expand Down Expand Up @@ -3438,6 +3462,9 @@ def update(self, X, ensure_all_finite=True):
else:
# now large data
self._small_data = False
_, pynn_named_distances, pynn_sparse_named_distances = (
_lazy_import_pynndescent()
)
if self._sparse_data and self.metric in pynn_sparse_named_distances:
nn_metric = self.metric
elif not self._sparse_data and self.metric in pynn_named_distances:
Expand Down Expand Up @@ -3523,6 +3550,9 @@ def update(self, X, ensure_all_finite=True):
self._knn_dists,
) = self._knn_search_index.neighbor_graph

_, pynn_named_distances, pynn_sparse_named_distances = (
_lazy_import_pynndescent()
)
if self._sparse_data and self.metric in pynn_sparse_named_distances:
nn_metric = self.metric
elif not self._sparse_data and self.metric in pynn_named_distances:
Expand Down