33from __future__ import annotations
44
55import contextlib
6+ from inspect import signature
67from textwrap import indent
78from types import MappingProxyType
89from typing import TYPE_CHECKING , NamedTuple , TypedDict
1920from .._utils import NeighborsView , _doc_params , get_literal_vals
2021from . import _connectivity
2122from ._common import (
23+ _get_indices_distances_from_dense_matrix ,
2224 _get_indices_distances_from_sparse_matrix ,
2325 _get_sparse_matrix_from_indices_distances ,
2426)
27+ from ._connectivity import umap
2528from ._doc import doc_n_pcs , doc_use_rep
2629from ._types import _KnownTransformer , _Method
2730
2831if TYPE_CHECKING :
2932 from collections .abc import Callable , Mapping , MutableMapping
30- from typing import Any , Literal , NotRequired , TypeAlias
33+ from typing import Any , Literal , NotRequired , TypeAlias , Unpack
3134
3235 from anndata import AnnData
3336 from igraph import Graph
37+ from numpy .typing import NDArray
3438
3539 from .._utils .random import _LegacyRandom
3640 from ._types import KnnTransformerLike , _Metric , _MetricFn
@@ -56,11 +60,18 @@ class KwdsForTransformer(TypedDict):
5660 random_state : _LegacyRandom
5761
5862
63+ class NeighborsDict (TypedDict ): # noqa: D101
64+ connectivities_key : str
65+ distances_key : str
66+ params : NeighborsParams
67+ rp_forest : NotRequired [RPForestDict ]
68+
69+
5970class NeighborsParams (TypedDict ): # noqa: D101
6071 n_neighbors : int
6172 method : _Method
6273 random_state : _LegacyRandom
63- metric : _Metric | _MetricFn
74+ metric : _Metric | _MetricFn | None
6475 metric_kwds : NotRequired [Mapping [str , Any ]]
6576 use_rep : NotRequired [str ]
6677 n_pcs : NotRequired [int ]
@@ -72,11 +83,12 @@ def neighbors( # noqa: PLR0913
7283 n_neighbors : int = 15 ,
7384 n_pcs : int | None = None ,
7485 * ,
86+ distances : np .ndarray | SpBase | None = None ,
7587 use_rep : str | None = None ,
7688 knn : bool = True ,
7789 method : _Method = "umap" ,
7890 transformer : KnnTransformerLike | _KnownTransformer | None = None ,
79- metric : _Metric | _MetricFn = "euclidean" ,
91+ metric : _Metric | _MetricFn | None = None ,
8092 metric_kwds : Mapping [str , Any ] = MappingProxyType ({}),
8193 random_state : _LegacyRandom = 0 ,
8294 key_added : str | None = None ,
@@ -138,6 +150,8 @@ def neighbors( # noqa: PLR0913
138150 Use :func:`rapids_singlecell.pp.neighbors` instead.
139151 metric
140152 A known metric’s name or a callable that returns a distance.
153+ If `distances` is given, this parameter is simply stored in `.uns` (see below),
154+ otherwise defaults to `'euclidean'`.
141155
142156 *ignored if ``transformer`` is an instance.*
143157 metric_kwds
@@ -153,18 +167,18 @@ def neighbors( # noqa: PLR0913
153167 distances and connectivities are stored in `.obsp['distances']` and
154168 `.obsp['connectivities']` respectively.
155169 If specified, the neighbors data is added to .uns[key_added],
156- distances are stored in `.obsp[key_added+' _distances']` and
157- connectivities in `.obsp[key_added+' _connectivities']`.
170+ distances are stored in `.obsp[f'{{key_added}} _distances']` and
171+ connectivities in `.obsp[f'{{key_added}} _connectivities']`.
158172 copy
159173 Return a copy instead of writing to adata.
160174
161175 Returns
162176 -------
163177 Returns `None` if `copy=False`, else returns an `AnnData` object. Sets the following fields:
164178
165- `adata.obsp['distances' | key_added+' _distances']` : :class:`scipy.sparse.csr_matrix` (dtype `float`)
179+ `adata.obsp['distances' | f'{{key_added}} _distances']` : :class:`scipy.sparse.csr_matrix` (dtype `float`)
166180 Distance matrix of the nearest neighbors search. Each row (cell) has `n_neighbors`-1 non-zero entries. These are the distances to their `n_neighbors`-1 nearest neighbors (excluding the cell itself).
167- `adata.obsp['connectivities' | key_added+' _connectivities']` : :class:`scipy.sparse._csr.csr_matrix` (dtype `float`)
181+ `adata.obsp['connectivities' | f'{{key_added}} _connectivities']` : :class:`scipy.sparse._csr.csr_matrix` (dtype `float`)
168182 Weighted adjacency matrix of the neighborhood graph of data
169183 points. Weights should be interpreted as connectivities.
170184 `adata.uns['neighbors' | key_added]` : :class:`dict`
@@ -189,68 +203,108 @@ def neighbors( # noqa: PLR0913
189203 :doc:`/how-to/knn-transformers`
190204
191205 """
192- start = logg .info ("computing neighbors" )
193- adata = adata .copy () if copy else adata
194- if adata .is_view : # we shouldn't need this here...
195- adata ._init_as_actual (adata .copy ())
196- neighbors = Neighbors (adata )
197- neighbors .compute_neighbors (
198- n_neighbors ,
199- n_pcs = n_pcs ,
200- use_rep = use_rep ,
201- knn = knn ,
206+ if distances is None :
207+ if metric is None :
208+ metric = "euclidean"
209+ start = logg .info ("computing neighbors" )
210+ adata = adata .copy () if copy else adata
211+ if adata .is_view : # we shouldn't need this here...
212+ adata ._init_as_actual (adata .copy ())
213+ neighbors_ = Neighbors (adata )
214+ neighbors_ .compute_neighbors (
215+ n_neighbors ,
216+ n_pcs = n_pcs ,
217+ use_rep = use_rep ,
218+ knn = knn ,
219+ method = method ,
220+ transformer = transformer ,
221+ metric = metric ,
222+ metric_kwds = metric_kwds ,
223+ random_state = random_state ,
224+ )
225+ else :
226+ params = locals ()
227+ if ignored := {
228+ p .name
229+ for p in signature (neighbors ).parameters .values ()
230+ if p .name in {"use_rep" , "knn" , "n_pcs" , "metric_kwds" , "random_state" }
231+ if params [p .name ] != p .default
232+ }:
233+ warn (
234+ f"Parameter(s) ignored if `distances` is given: { ignored } " ,
235+ UserWarning ,
236+ )
237+ random_state = 0
238+ if callable (metric ):
239+ msg = "`metric` must be a string if `distances` is given."
240+ raise TypeError (msg )
241+ start = logg .info ("computing connectivities" )
242+ # if a precomputed distance matrix is provided, skip the PCA and distance computation
243+ if isinstance (distances , SpBase ):
244+ if TYPE_CHECKING :
245+ from scipy .sparse ._base import _spbase
246+
247+ assert isinstance (distances , _spbase )
248+ distances = distances .tocsr (copy = True )
249+ distances .setdiag (0 )
250+ distances .eliminate_zeros ()
251+ else :
252+ distances = np .asarray (distances )
253+ np .fill_diagonal (distances , 0 )
254+
255+ neighbors_ = Neighbors (adata )
256+ neighbors_ .n_neighbors = n_neighbors
257+ neighbors_ .knn = True
258+ neighbors_ ._distances = distances
259+ neighbors_ ._connectivities = neighbors_ ._compute_connectivites (method )
260+
261+ key_added , neighbors_dict = _get_metadata (
262+ key_added ,
263+ n_neighbors = neighbors_ .n_neighbors ,
202264 method = method ,
203- transformer = transformer ,
204- metric = metric ,
205- metric_kwds = metric_kwds ,
206265 random_state = random_state ,
266+ metric = metric ,
267+ ** ({} if not metric_kwds else dict (metric_kwds = metric_kwds )),
268+ ** ({} if use_rep is None else dict (use_rep = use_rep )),
269+ ** ({} if n_pcs is None else dict (n_pcs = n_pcs )),
207270 )
208271
209- if key_added is None :
210- key_added = "neighbors"
211- conns_key = "connectivities"
212- dists_key = "distances"
213- else :
214- conns_key = f"{ key_added } _connectivities"
215- dists_key = f"{ key_added } _distances"
216-
217- adata .uns [key_added ] = {}
218-
219- neighbors_dict = adata .uns [key_added ]
272+ if neighbors_ .rp_forest is not None :
273+ neighbors_dict ["rp_forest" ] = neighbors_ .rp_forest
220274
221- neighbors_dict ["connectivities_key" ] = conns_key
222- neighbors_dict ["distances_key" ] = dists_key
275+ adata .uns [key_added ] = neighbors_dict
276+ adata .obsp [neighbors_dict ["distances_key" ]] = neighbors_ .distances
277+ adata .obsp [neighbors_dict ["connectivities_key" ]] = neighbors_ .connectivities
223278
224- neighbors_dict ["params" ] = NeighborsParams (
225- n_neighbors = neighbors .n_neighbors ,
226- method = method ,
227- random_state = random_state ,
228- metric = metric ,
229- )
230- if metric_kwds :
231- neighbors_dict ["params" ]["metric_kwds" ] = metric_kwds
232- if use_rep is not None :
233- neighbors_dict ["params" ]["use_rep" ] = use_rep
234- if n_pcs is not None :
235- neighbors_dict ["params" ]["n_pcs" ] = n_pcs
236-
237- adata .obsp [dists_key ] = neighbors .distances
238- adata .obsp [conns_key ] = neighbors .connectivities
239-
240- if neighbors .rp_forest is not None :
241- neighbors_dict ["rp_forest" ] = neighbors .rp_forest
242279 logg .info (
243280 " finished" ,
244281 time = start ,
245282 deep = (
246283 f"added to `.uns[{ key_added !r} ]`\n "
247- f" `.obsp[{ dists_key !r} ]`, distances for each pair of neighbors\n "
248- f" `.obsp[{ conns_key !r} ]`, weighted adjacency matrix"
284+ f" `.obsp[{ neighbors_dict [ 'distances_key' ] !r} ]`, distances for each pair of neighbors\n "
285+ f" `.obsp[{ neighbors_dict [ 'connectivities_key' ] !r} ]`, weighted adjacency matrix"
249286 ),
250287 )
251288 return adata if copy else None
252289
253290
291+ def _get_metadata (
292+ key_added : str | None ,
293+ ** params : Unpack [NeighborsParams ],
294+ ) -> tuple [str , NeighborsDict ]:
295+ if key_added is None :
296+ return "neighbors" , NeighborsDict (
297+ connectivities_key = "connectivities" ,
298+ distances_key = "distances" ,
299+ params = params ,
300+ )
301+ return key_added , NeighborsDict (
302+ connectivities_key = f"{ key_added } _connectivities" ,
303+ distances_key = f"{ key_added } _distances" ,
304+ params = params ,
305+ )
306+
307+
254308class FlatTree (NamedTuple ): # noqa: D101
255309 hyperplanes : None
256310 offsets : None
@@ -358,7 +412,7 @@ class Neighbors:
358412 n_dcs
359413 Number of diffusion components to use.
360414 neighbors_key
361- Where to look in `.uns` and `.obsp` for neighbors data
415+ Where to look in `.uns` and `.obsp` for neighbors data.
362416
363417 """
364418
@@ -518,7 +572,7 @@ def to_igraph(self) -> Graph:
518572 return _utils .get_igraph_from_adjacency (self .connectivities )
519573
520574 @_doc_params (n_pcs = doc_n_pcs , use_rep = doc_use_rep )
521- def compute_neighbors ( # noqa: PLR0912
575+ def compute_neighbors (
522576 self ,
523577 n_neighbors : int = 30 ,
524578 n_pcs : int | None = None ,
@@ -601,26 +655,11 @@ def compute_neighbors( # noqa: PLR0912
601655 self ._rp_forest = _make_forest_dict (index )
602656 start_connect = logg .debug ("computed neighbors" , time = start_neighbors )
603657
604- if method == "umap" :
605- self ._connectivities = _connectivity .umap (
606- knn_indices ,
607- knn_distances ,
608- n_obs = self ._adata .shape [0 ],
609- n_neighbors = self .n_neighbors ,
610- )
611- elif method == "gauss" :
612- self ._connectivities = _connectivity .gauss (
613- self ._distances , self .n_neighbors , knn = self .knn
614- )
615- elif method == "jaccard" :
616- self ._connectivities = _connectivity .jaccard (
617- knn_indices ,
618- n_obs = self ._adata .shape [0 ],
619- n_neighbors = self .n_neighbors ,
620- )
621- elif method is not None :
622- msg = f"{ method !r} should have been coerced in _handle_transform_args"
623- raise AssertionError (msg )
658+ self ._connectivities = (
659+ None
660+ if method is None
661+ else self ._compute_connectivites (method , (knn_indices , knn_distances ))
662+ )
624663 self ._number_connected_components = 1
625664 if isinstance (self ._connectivities , CSBase ):
626665 from scipy .sparse .csgraph import connected_components
@@ -630,6 +669,43 @@ def compute_neighbors( # noqa: PLR0912
630669 if method is not None :
631670 logg .debug ("computed connectivities" , time = start_connect )
632671
672+ def _compute_connectivites (
673+ self ,
674+ method : _Method ,
675+ knn_ind_dist : (
676+ tuple [NDArray [np .int32 | np .int64 ], NDArray [np .float32 | np .float64 ]] | None
677+ ) = None ,
678+ ) -> CSRBase | NDArray [np .float32 | np .float64 ] | None :
679+ def get_knn ():
680+ if knn_ind_dist is not None :
681+ return knn_ind_dist
682+ if isinstance (self ._distances , CSBase ):
683+ return _get_indices_distances_from_sparse_matrix (
684+ self ._distances .tocsr (), self .n_neighbors
685+ )
686+ assert self ._distances is not None
687+ return _get_indices_distances_from_dense_matrix (
688+ self ._distances , self .n_neighbors
689+ )
690+
691+ if method == "umap" :
692+ knn_indices , knn_distances = get_knn ()
693+ return umap (
694+ knn_indices ,
695+ knn_distances ,
696+ n_obs = self ._adata .n_obs ,
697+ n_neighbors = self .n_neighbors ,
698+ )
699+ if method == "gauss" :
700+ return _connectivity .gauss (self ._distances , self .n_neighbors , knn = self .knn )
701+ if method == "jaccard" :
702+ knn_indices , _ = get_knn ()
703+ return _connectivity .jaccard (
704+ knn_indices , n_obs = self ._adata .n_obs , n_neighbors = self .n_neighbors
705+ )
706+ msg = f"Method { method } not implemented."
707+ raise NotImplementedError (msg )
708+
633709 def _handle_transformer (
634710 self ,
635711 method : _Method | Literal ["gauss" ] | None ,
0 commit comments