Skip to content

Commit 040b8b7

Browse files
committed
unify metadata assembly
1 parent 4730667 commit 040b8b7

File tree

1 file changed

+65
-72
lines changed

1 file changed

+65
-72
lines changed

src/scanpy/neighbors/__init__.py

Lines changed: 65 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
if TYPE_CHECKING:
3333
from collections.abc import Callable, MutableMapping
34-
from typing import Any, Literal, NotRequired
34+
from typing import Any, Literal, NotRequired, Unpack
3535

3636
from anndata import AnnData
3737
from igraph import Graph
@@ -60,6 +60,13 @@ class KwdsForTransformer(TypedDict):
6060
random_state: _LegacyRandom
6161

6262

63+
class NeighborsDict(TypedDict): # noqa: D101
64+
connectivities_key: str
65+
distances_key: str
66+
params: NeighborsParams
67+
rp_forest: NotRequired[RPForestDict]
68+
69+
6370
class NeighborsParams(TypedDict): # noqa: D101
6471
n_neighbors: int
6572
method: _Method
@@ -138,6 +145,7 @@ def neighbors( # noqa: PLR0913
138145
Use :func:`rapids_singlecell.pp.neighbors` instead.
139146
metric
140147
A known metric’s name or a callable that returns a distance.
148+
If `distances` is given, this parameter is simply stored in `.uns` (see below).
141149
142150
*ignored if ``transformer`` is an instance.*
143151
metric_kwds
@@ -190,12 +198,15 @@ def neighbors( # noqa: PLR0913
190198
191199
"""
192200
if distances is not None:
193-
# Added this to support the new distance matrix function
201+
if callable(metric):
202+
msg = "`metric` must be a string if `distances` is given."
203+
raise TypeError(msg)
194204
# if a precomputed distance matrix is provided, skip the PCA and distance computation
195205
return neighbors_from_distance(
196206
adata,
197207
distances,
198208
n_neighbors=n_neighbors,
209+
metric=metric,
199210
method=method,
200211
)
201212
start = logg.info("computing neighbors")
@@ -215,46 +226,31 @@ def neighbors( # noqa: PLR0913
215226
random_state=random_state,
216227
)
217228

218-
if key_added is None:
219-
key_added = "neighbors"
220-
conns_key = "connectivities"
221-
dists_key = "distances"
222-
else:
223-
conns_key = key_added + "_connectivities"
224-
dists_key = key_added + "_distances"
225-
226-
adata.uns[key_added] = {}
227-
228-
neighbors_dict = adata.uns[key_added]
229-
230-
neighbors_dict["connectivities_key"] = conns_key
231-
neighbors_dict["distances_key"] = dists_key
232-
233-
neighbors_dict["params"] = NeighborsParams(
229+
key_added, neighbors_dict = _get_metadata(
230+
key_added,
234231
n_neighbors=neighbors.n_neighbors,
235232
method=method,
236233
random_state=random_state,
237234
metric=metric,
235+
**({} if not metric_kwds else dict(metric_kwds=metric_kwds)),
236+
**({} if use_rep is None else dict(use_rep=use_rep)),
237+
**({} if n_pcs is None else dict(n_pcs=n_pcs)),
238238
)
239-
if metric_kwds:
240-
neighbors_dict["params"]["metric_kwds"] = metric_kwds
241-
if use_rep is not None:
242-
neighbors_dict["params"]["use_rep"] = use_rep
243-
if n_pcs is not None:
244-
neighbors_dict["params"]["n_pcs"] = n_pcs
245-
246-
adata.obsp[dists_key] = neighbors.distances
247-
adata.obsp[conns_key] = neighbors.connectivities
248239

249240
if neighbors.rp_forest is not None:
250241
neighbors_dict["rp_forest"] = neighbors.rp_forest
242+
243+
adata.uns[key_added] = neighbors_dict
244+
adata.obsp[neighbors_dict["distances_key"]] = neighbors.distances
245+
adata.obsp[neighbors_dict["connectivities_key"]] = neighbors.connectivities
246+
251247
logg.info(
252248
" finished",
253249
time=start,
254250
deep=(
255251
f"added to `.uns[{key_added!r}]`\n"
256-
f" `.obsp[{dists_key!r}]`, distances for each pair of neighbors\n"
257-
f" `.obsp[{conns_key!r}]`, weighted adjacency matrix"
252+
f" `.obsp[{neighbors_dict['distances_key']!r}]`, distances for each pair of neighbors\n"
253+
f" `.obsp[{neighbors_dict['connectivities_key']!r}]`, weighted adjacency matrix"
258254
),
259255
)
260256
return adata if copy else None
@@ -265,6 +261,7 @@ def neighbors_from_distance(
265261
distances: np.ndarray | SpBase,
266262
*,
267263
n_neighbors: int = 15,
264+
metric: _Metric = "euclidean",
268265
method: _Method = "umap", # default to umap
269266
key_added: str | None = None,
270267
) -> AnnData:
@@ -298,63 +295,59 @@ def neighbors_from_distance(
298295
distances = sparse.csr_matrix(distances) # noqa: TID251
299296
distances.setdiag(0)
300297
distances.eliminate_zeros()
301-
# extracting for each observation the indices and distances of the n_neighbors
302-
# being then used by umap or gauss
303-
knn_indices, knn_distances = _get_indices_distances_from_sparse_matrix(
304-
distances, n_neighbors
305-
)
306298
else:
307-
# if it is dense, converting it to ndarray
308-
# and setting the diagonal to 0
309-
# extracting knn indices and distances
310299
distances = np.asarray(distances)
311300
np.fill_diagonal(distances, 0)
312-
knn_indices, knn_distances = _get_indices_distances_from_dense_matrix(
313-
distances, n_neighbors
314-
)
315301

316302
if method == "umap":
317-
# using umap to build connectivities from distances
303+
if isinstance(distances, CSRBase):
304+
knn_indices, knn_distances = _get_indices_distances_from_sparse_matrix(
305+
distances, n_neighbors
306+
)
307+
else:
308+
knn_indices, knn_distances = _get_indices_distances_from_dense_matrix(
309+
distances, n_neighbors
310+
)
318311
connectivities = umap(
319-
knn_indices,
320-
knn_distances,
321-
n_obs=adata.n_obs,
322-
n_neighbors=n_neighbors,
312+
knn_indices, knn_distances, n_obs=adata.n_obs, n_neighbors=n_neighbors
323313
)
324314
elif method == "gauss":
325-
# using gauss to build connectivities from distances
326-
# requires sparse matrix for efficiency
327-
connectivities = _connectivity.gauss(
328-
sparse.csr_matrix(distances), # noqa: TID251
329-
n_neighbors,
330-
knn=True,
331-
)
315+
distances = sparse.csr_matrix(distances) # noqa: TID251
316+
connectivities = _connectivity.gauss(distances, n_neighbors, knn=True)
332317
else:
333318
msg = f"Method {method} not implemented."
334319
raise NotImplementedError(msg)
335-
# defining where to store graph info
336-
key = "neighbors" if key_added is None else key_added
337-
dists_key = "distances" if key_added is None else key_added + "_distances"
338-
conns_key = "connectivities" if key_added is None else key_added + "_connectivities"
339-
# storing the actual distance and connectivity matrices as obsp
340-
adata.obsp[dists_key] = sparse.csr_matrix(distances) # noqa: TID251
341-
adata.obsp[conns_key] = connectivities
342-
# populating with metadata describing how neighbors were computed
343-
# I think this might be important, as many functions downstream rely
344-
# on .uns['neighbors'] to find correct .obsp key
345-
adata.uns[key] = {
346-
"connectivities_key": "connectivities",
347-
"distances_key": "distances",
348-
"params": {
349-
"n_neighbors": n_neighbors,
350-
"method": method,
351-
"random_state": 0,
352-
"metric": "euclidean",
353-
},
354-
}
320+
321+
key_added, neighbors_dict = _get_metadata(
322+
key_added,
323+
n_neighbors=n_neighbors,
324+
method=method,
325+
random_state=0,
326+
metric=metric,
327+
)
328+
adata.uns[key_added] = neighbors_dict
329+
adata.obsp[neighbors_dict["distances_key"]] = distances
330+
adata.obsp[neighbors_dict["connectivities_key"]] = connectivities
355331
return adata
356332

357333

334+
def _get_metadata(
335+
key_added: str | None,
336+
**params: Unpack[NeighborsParams],
337+
) -> tuple[str, NeighborsDict]:
338+
if key_added is None:
339+
return "neighbors", NeighborsDict(
340+
connectivities_key="connectivities",
341+
distances_key="distances",
342+
params=params,
343+
)
344+
return key_added, NeighborsDict(
345+
connectivities_key=f"{key_added}_connectivities",
346+
distances_key=f"{key_added}_distances",
347+
params=params,
348+
)
349+
350+
358351
class FlatTree(NamedTuple): # noqa: D101
359352
hyperplanes: None
360353
offsets: None

0 commit comments

Comments
 (0)