31
31
32
32
if TYPE_CHECKING :
33
33
from collections .abc import Callable , MutableMapping
34
- from typing import Any , Literal , NotRequired
34
+ from typing import Any , Literal , NotRequired , Unpack
35
35
36
36
from anndata import AnnData
37
37
from igraph import Graph
@@ -60,6 +60,13 @@ class KwdsForTransformer(TypedDict):
60
60
random_state : _LegacyRandom
61
61
62
62
63
class NeighborsDict(TypedDict):
    """Neighbor-graph metadata that is stored under ``adata.uns[key_added]``."""

    # Key in ``adata.obsp`` under which the connectivities matrix is stored.
    connectivities_key: str
    # Key in ``adata.obsp`` under which the distances matrix is stored.
    distances_key: str
    # Parameters recorded for the neighbors computation.
    params: NeighborsParams
    # Optional: only set when the neighbors computation yields an ``rp_forest``.
    rp_forest: NotRequired[RPForestDict]
68
+
69
+
63
70
class NeighborsParams (TypedDict ): # noqa: D101
64
71
n_neighbors : int
65
72
method : _Method
@@ -138,6 +145,7 @@ def neighbors( # noqa: PLR0913
138
145
Use :func:`rapids_singlecell.pp.neighbors` instead.
139
146
metric
140
147
A known metric’s name or a callable that returns a distance.
148
+ If `distances` is given, this parameter is simply stored in `.uns` (see below).
141
149
142
150
*ignored if ``transformer`` is an instance.*
143
151
metric_kwds
@@ -190,12 +198,15 @@ def neighbors( # noqa: PLR0913
190
198
191
199
"""
192
200
if distances is not None :
193
- # Added this to support the new distance matrix function
201
+ if callable (metric ):
202
+ msg = "`metric` must be a string if `distances` is given."
203
+ raise TypeError (msg )
194
204
# if a precomputed distance matrix is provided, skip the PCA and distance computation
195
205
return neighbors_from_distance (
196
206
adata ,
197
207
distances ,
198
208
n_neighbors = n_neighbors ,
209
+ metric = metric ,
199
210
method = method ,
200
211
)
201
212
start = logg .info ("computing neighbors" )
@@ -215,46 +226,31 @@ def neighbors( # noqa: PLR0913
215
226
random_state = random_state ,
216
227
)
217
228
218
- if key_added is None :
219
- key_added = "neighbors"
220
- conns_key = "connectivities"
221
- dists_key = "distances"
222
- else :
223
- conns_key = key_added + "_connectivities"
224
- dists_key = key_added + "_distances"
225
-
226
- adata .uns [key_added ] = {}
227
-
228
- neighbors_dict = adata .uns [key_added ]
229
-
230
- neighbors_dict ["connectivities_key" ] = conns_key
231
- neighbors_dict ["distances_key" ] = dists_key
232
-
233
- neighbors_dict ["params" ] = NeighborsParams (
229
+ key_added , neighbors_dict = _get_metadata (
230
+ key_added ,
234
231
n_neighbors = neighbors .n_neighbors ,
235
232
method = method ,
236
233
random_state = random_state ,
237
234
metric = metric ,
235
+ ** ({} if not metric_kwds else dict (metric_kwds = metric_kwds )),
236
+ ** ({} if use_rep is None else dict (use_rep = use_rep )),
237
+ ** ({} if n_pcs is None else dict (n_pcs = n_pcs )),
238
238
)
239
- if metric_kwds :
240
- neighbors_dict ["params" ]["metric_kwds" ] = metric_kwds
241
- if use_rep is not None :
242
- neighbors_dict ["params" ]["use_rep" ] = use_rep
243
- if n_pcs is not None :
244
- neighbors_dict ["params" ]["n_pcs" ] = n_pcs
245
-
246
- adata .obsp [dists_key ] = neighbors .distances
247
- adata .obsp [conns_key ] = neighbors .connectivities
248
239
249
240
if neighbors .rp_forest is not None :
250
241
neighbors_dict ["rp_forest" ] = neighbors .rp_forest
242
+
243
+ adata .uns [key_added ] = neighbors_dict
244
+ adata .obsp [neighbors_dict ["distances_key" ]] = neighbors .distances
245
+ adata .obsp [neighbors_dict ["connectivities_key" ]] = neighbors .connectivities
246
+
251
247
logg .info (
252
248
" finished" ,
253
249
time = start ,
254
250
deep = (
255
251
f"added to `.uns[{ key_added !r} ]`\n "
256
- f" `.obsp[{ dists_key !r} ]`, distances for each pair of neighbors\n "
257
- f" `.obsp[{ conns_key !r} ]`, weighted adjacency matrix"
252
+ f" `.obsp[{ neighbors_dict [ 'distances_key' ] !r} ]`, distances for each pair of neighbors\n "
253
+ f" `.obsp[{ neighbors_dict [ 'connectivities_key' ] !r} ]`, weighted adjacency matrix"
258
254
),
259
255
)
260
256
return adata if copy else None
@@ -265,6 +261,7 @@ def neighbors_from_distance(
265
261
distances : np .ndarray | SpBase ,
266
262
* ,
267
263
n_neighbors : int = 15 ,
264
+ metric : _Metric = "euclidean" ,
268
265
method : _Method = "umap" , # default to umap
269
266
key_added : str | None = None ,
270
267
) -> AnnData :
@@ -298,63 +295,59 @@ def neighbors_from_distance(
298
295
distances = sparse .csr_matrix (distances ) # noqa: TID251
299
296
distances .setdiag (0 )
300
297
distances .eliminate_zeros ()
301
- # extracting for each observation the indices and distances of the n_neighbors
302
- # being then used by umap or gauss
303
- knn_indices , knn_distances = _get_indices_distances_from_sparse_matrix (
304
- distances , n_neighbors
305
- )
306
298
else :
307
- # if it is dense, converting it to ndarray
308
- # and setting the diagonal to 0
309
- # extracting knn indices and distances
310
299
distances = np .asarray (distances )
311
300
np .fill_diagonal (distances , 0 )
312
- knn_indices , knn_distances = _get_indices_distances_from_dense_matrix (
313
- distances , n_neighbors
314
- )
315
301
316
302
if method == "umap" :
317
- # using umap to build connectivities from distances
303
+ if isinstance (distances , CSRBase ):
304
+ knn_indices , knn_distances = _get_indices_distances_from_sparse_matrix (
305
+ distances , n_neighbors
306
+ )
307
+ else :
308
+ knn_indices , knn_distances = _get_indices_distances_from_dense_matrix (
309
+ distances , n_neighbors
310
+ )
318
311
connectivities = umap (
319
- knn_indices ,
320
- knn_distances ,
321
- n_obs = adata .n_obs ,
322
- n_neighbors = n_neighbors ,
312
+ knn_indices , knn_distances , n_obs = adata .n_obs , n_neighbors = n_neighbors
323
313
)
324
314
elif method == "gauss" :
325
- # using gauss to build connectivities from distances
326
- # requires sparse matrix for efficiency
327
- connectivities = _connectivity .gauss (
328
- sparse .csr_matrix (distances ), # noqa: TID251
329
- n_neighbors ,
330
- knn = True ,
331
- )
315
+ distances = sparse .csr_matrix (distances ) # noqa: TID251
316
+ connectivities = _connectivity .gauss (distances , n_neighbors , knn = True )
332
317
else :
333
318
msg = f"Method { method } not implemented."
334
319
raise NotImplementedError (msg )
335
- # defining where to store graph info
336
- key = "neighbors" if key_added is None else key_added
337
- dists_key = "distances" if key_added is None else key_added + "_distances"
338
- conns_key = "connectivities" if key_added is None else key_added + "_connectivities"
339
- # storing the actual distance and connectivitiy matrices as obsp
340
- adata .obsp [dists_key ] = sparse .csr_matrix (distances ) # noqa: TID251
341
- adata .obsp [conns_key ] = connectivities
342
- # populating with metadata describing how neighbors were computed
343
- # I think might be important as many functions downstream rely
344
- # on .uns['neighbors'] to find correct .obsp key
345
- adata .uns [key ] = {
346
- "connectivities_key" : "connectivities" ,
347
- "distances_key" : "distances" ,
348
- "params" : {
349
- "n_neighbors" : n_neighbors ,
350
- "method" : method ,
351
- "random_state" : 0 ,
352
- "metric" : "euclidean" ,
353
- },
354
- }
320
+
321
+ key_added , neighbors_dict = _get_metadata (
322
+ key_added ,
323
+ n_neighbors = n_neighbors ,
324
+ method = method ,
325
+ random_state = 0 ,
326
+ metric = metric ,
327
+ )
328
+ adata .uns [key_added ] = neighbors_dict
329
+ adata .obsp [neighbors_dict ["distances_key" ]] = distances
330
+ adata .obsp [neighbors_dict ["connectivities_key" ]] = connectivities
355
331
return adata
356
332
357
333
334
def _get_metadata(
    key_added: str | None,
    **params: Unpack[NeighborsParams],
) -> tuple[str, NeighborsDict]:
    """Build the ``.uns`` key and the metadata dict for a neighbors result.

    Parameters
    ----------
    key_added
        Requested storage key, or ``None`` for the default location.
    **params
        Neighbors parameters, stored as-is under ``"params"``.

    Returns
    -------
    A ``(key, metadata)`` tuple. With ``key_added=None`` the key is
    ``"neighbors"`` and the matrix keys are ``"connectivities"`` and
    ``"distances"``; otherwise the key is ``key_added`` and both matrix
    keys are prefixed with ``f"{key_added}_"``.
    """
    if key_added is None:
        uns_key, prefix = "neighbors", ""
    else:
        uns_key, prefix = key_added, f"{key_added}_"

    metadata = NeighborsDict(
        connectivities_key=f"{prefix}connectivities",
        distances_key=f"{prefix}distances",
        params=params,
    )
    return uns_key, metadata
349
+
350
+
358
351
class FlatTree (NamedTuple ): # noqa: D101
359
352
hyperplanes : None
360
353
offsets : None
0 commit comments