Skip to content

Commit d5901ab

Browse files
feat: add distances parameter to sc.pp.neighbors (#3627)
Co-authored-by: Philipp A. <flying-sheep@web.de>
1 parent 2ff2a9b commit d5901ab

File tree

5 files changed

+179
-77
lines changed

5 files changed

+179
-77
lines changed

docs/release-notes/3627.feat.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added `distances` parameter to {func}`scanpy.pp.neighbors`, allowing to compute graphs from a precomputed distance matrix {smaller}`A. Karesh`

src/scanpy/metrics/_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def confusion_matrix(
6060
orig, new = pd.Series(orig), pd.Series(new)
6161
assert len(orig) == len(new)
6262

63-
unique_labels = pd.unique(np.concatenate((orig.values, new.values)))
63+
unique_labels = pd.unique(np.concatenate((orig.to_numpy(), new.to_numpy())))
6464

6565
# Compute
6666
mtx = _confusion_matrix(orig, new, labels=unique_labels)

src/scanpy/neighbors/__init__.py

Lines changed: 151 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import contextlib
6+
from inspect import signature
67
from textwrap import indent
78
from types import MappingProxyType
89
from typing import TYPE_CHECKING, NamedTuple, TypedDict
@@ -19,18 +20,21 @@
1920
from .._utils import NeighborsView, _doc_params, get_literal_vals
2021
from . import _connectivity
2122
from ._common import (
23+
_get_indices_distances_from_dense_matrix,
2224
_get_indices_distances_from_sparse_matrix,
2325
_get_sparse_matrix_from_indices_distances,
2426
)
27+
from ._connectivity import umap
2528
from ._doc import doc_n_pcs, doc_use_rep
2629
from ._types import _KnownTransformer, _Method
2730

2831
if TYPE_CHECKING:
2932
from collections.abc import Callable, Mapping, MutableMapping
30-
from typing import Any, Literal, NotRequired, TypeAlias
33+
from typing import Any, Literal, NotRequired, TypeAlias, Unpack
3134

3235
from anndata import AnnData
3336
from igraph import Graph
37+
from numpy.typing import NDArray
3438

3539
from .._utils.random import _LegacyRandom
3640
from ._types import KnnTransformerLike, _Metric, _MetricFn
@@ -56,11 +60,18 @@ class KwdsForTransformer(TypedDict):
5660
random_state: _LegacyRandom
5761

5862

63+
class NeighborsDict(TypedDict): # noqa: D101
64+
connectivities_key: str
65+
distances_key: str
66+
params: NeighborsParams
67+
rp_forest: NotRequired[RPForestDict]
68+
69+
5970
class NeighborsParams(TypedDict): # noqa: D101
6071
n_neighbors: int
6172
method: _Method
6273
random_state: _LegacyRandom
63-
metric: _Metric | _MetricFn
74+
metric: _Metric | _MetricFn | None
6475
metric_kwds: NotRequired[Mapping[str, Any]]
6576
use_rep: NotRequired[str]
6677
n_pcs: NotRequired[int]
@@ -72,11 +83,12 @@ def neighbors( # noqa: PLR0913
7283
n_neighbors: int = 15,
7384
n_pcs: int | None = None,
7485
*,
86+
distances: np.ndarray | SpBase | None = None,
7587
use_rep: str | None = None,
7688
knn: bool = True,
7789
method: _Method = "umap",
7890
transformer: KnnTransformerLike | _KnownTransformer | None = None,
79-
metric: _Metric | _MetricFn = "euclidean",
91+
metric: _Metric | _MetricFn | None = None,
8092
metric_kwds: Mapping[str, Any] = MappingProxyType({}),
8193
random_state: _LegacyRandom = 0,
8294
key_added: str | None = None,
@@ -138,6 +150,8 @@ def neighbors( # noqa: PLR0913
138150
Use :func:`rapids_singlecell.pp.neighbors` instead.
139151
metric
140152
A known metric’s name or a callable that returns a distance.
153+
If `distances` is given, this parameter is simply stored in `.uns` (see below),
154+
otherwise defaults to `'euclidean'`.
141155
142156
*ignored if ``transformer`` is an instance.*
143157
metric_kwds
@@ -153,18 +167,18 @@ def neighbors( # noqa: PLR0913
153167
distances and connectivities are stored in `.obsp['distances']` and
154168
`.obsp['connectivities']` respectively.
155169
If specified, the neighbors data is added to .uns[key_added],
156-
distances are stored in `.obsp[key_added+'_distances']` and
157-
connectivities in `.obsp[key_added+'_connectivities']`.
170+
distances are stored in `.obsp[f'{{key_added}}_distances']` and
171+
connectivities in `.obsp[f'{{key_added}}_connectivities']`.
158172
copy
159173
Return a copy instead of writing to adata.
160174
161175
Returns
162176
-------
163177
Returns `None` if `copy=False`, else returns an `AnnData` object. Sets the following fields:
164178
165-
`adata.obsp['distances' | key_added+'_distances']` : :class:`scipy.sparse.csr_matrix` (dtype `float`)
179+
`adata.obsp['distances' | f'{{key_added}}_distances']` : :class:`scipy.sparse.csr_matrix` (dtype `float`)
166180
Distance matrix of the nearest neighbors search. Each row (cell) has `n_neighbors`-1 non-zero entries. These are the distances to their `n_neighbors`-1 nearest neighbors (excluding the cell itself).
167-
`adata.obsp['connectivities' | key_added+'_connectivities']` : :class:`scipy.sparse._csr.csr_matrix` (dtype `float`)
181+
`adata.obsp['connectivities' | f'{{key_added}}_connectivities']` : :class:`scipy.sparse._csr.csr_matrix` (dtype `float`)
168182
Weighted adjacency matrix of the neighborhood graph of data
169183
points. Weights should be interpreted as connectivities.
170184
`adata.uns['neighbors' | key_added]` : :class:`dict`
@@ -189,68 +203,108 @@ def neighbors( # noqa: PLR0913
189203
:doc:`/how-to/knn-transformers`
190204
191205
"""
192-
start = logg.info("computing neighbors")
193-
adata = adata.copy() if copy else adata
194-
if adata.is_view: # we shouldn't need this here...
195-
adata._init_as_actual(adata.copy())
196-
neighbors = Neighbors(adata)
197-
neighbors.compute_neighbors(
198-
n_neighbors,
199-
n_pcs=n_pcs,
200-
use_rep=use_rep,
201-
knn=knn,
206+
if distances is None:
207+
if metric is None:
208+
metric = "euclidean"
209+
start = logg.info("computing neighbors")
210+
adata = adata.copy() if copy else adata
211+
if adata.is_view: # we shouldn't need this here...
212+
adata._init_as_actual(adata.copy())
213+
neighbors_ = Neighbors(adata)
214+
neighbors_.compute_neighbors(
215+
n_neighbors,
216+
n_pcs=n_pcs,
217+
use_rep=use_rep,
218+
knn=knn,
219+
method=method,
220+
transformer=transformer,
221+
metric=metric,
222+
metric_kwds=metric_kwds,
223+
random_state=random_state,
224+
)
225+
else:
226+
params = locals()
227+
if ignored := {
228+
p.name
229+
for p in signature(neighbors).parameters.values()
230+
if p.name in {"use_rep", "knn", "n_pcs", "metric_kwds", "random_state"}
231+
if params[p.name] != p.default
232+
}:
233+
warn(
234+
f"Parameter(s) ignored if `distances` is given: {ignored}",
235+
UserWarning,
236+
)
237+
random_state = 0
238+
if callable(metric):
239+
msg = "`metric` must be a string if `distances` is given."
240+
raise TypeError(msg)
241+
start = logg.info("computing connectivities")
242+
# if a precomputed distance matrix is provided, skip the PCA and distance computation
243+
if isinstance(distances, SpBase):
244+
if TYPE_CHECKING:
245+
from scipy.sparse._base import _spbase
246+
247+
assert isinstance(distances, _spbase)
248+
distances = distances.tocsr(copy=True)
249+
distances.setdiag(0)
250+
distances.eliminate_zeros()
251+
else:
252+
distances = np.asarray(distances)
253+
np.fill_diagonal(distances, 0)
254+
255+
neighbors_ = Neighbors(adata)
256+
neighbors_.n_neighbors = n_neighbors
257+
neighbors_.knn = True
258+
neighbors_._distances = distances
259+
neighbors_._connectivities = neighbors_._compute_connectivites(method)
260+
261+
key_added, neighbors_dict = _get_metadata(
262+
key_added,
263+
n_neighbors=neighbors_.n_neighbors,
202264
method=method,
203-
transformer=transformer,
204-
metric=metric,
205-
metric_kwds=metric_kwds,
206265
random_state=random_state,
266+
metric=metric,
267+
**({} if not metric_kwds else dict(metric_kwds=metric_kwds)),
268+
**({} if use_rep is None else dict(use_rep=use_rep)),
269+
**({} if n_pcs is None else dict(n_pcs=n_pcs)),
207270
)
208271

209-
if key_added is None:
210-
key_added = "neighbors"
211-
conns_key = "connectivities"
212-
dists_key = "distances"
213-
else:
214-
conns_key = f"{key_added}_connectivities"
215-
dists_key = f"{key_added}_distances"
216-
217-
adata.uns[key_added] = {}
218-
219-
neighbors_dict = adata.uns[key_added]
272+
if neighbors_.rp_forest is not None:
273+
neighbors_dict["rp_forest"] = neighbors_.rp_forest
220274

221-
neighbors_dict["connectivities_key"] = conns_key
222-
neighbors_dict["distances_key"] = dists_key
275+
adata.uns[key_added] = neighbors_dict
276+
adata.obsp[neighbors_dict["distances_key"]] = neighbors_.distances
277+
adata.obsp[neighbors_dict["connectivities_key"]] = neighbors_.connectivities
223278

224-
neighbors_dict["params"] = NeighborsParams(
225-
n_neighbors=neighbors.n_neighbors,
226-
method=method,
227-
random_state=random_state,
228-
metric=metric,
229-
)
230-
if metric_kwds:
231-
neighbors_dict["params"]["metric_kwds"] = metric_kwds
232-
if use_rep is not None:
233-
neighbors_dict["params"]["use_rep"] = use_rep
234-
if n_pcs is not None:
235-
neighbors_dict["params"]["n_pcs"] = n_pcs
236-
237-
adata.obsp[dists_key] = neighbors.distances
238-
adata.obsp[conns_key] = neighbors.connectivities
239-
240-
if neighbors.rp_forest is not None:
241-
neighbors_dict["rp_forest"] = neighbors.rp_forest
242279
logg.info(
243280
" finished",
244281
time=start,
245282
deep=(
246283
f"added to `.uns[{key_added!r}]`\n"
247-
f" `.obsp[{dists_key!r}]`, distances for each pair of neighbors\n"
248-
f" `.obsp[{conns_key!r}]`, weighted adjacency matrix"
284+
f" `.obsp[{neighbors_dict['distances_key']!r}]`, distances for each pair of neighbors\n"
285+
f" `.obsp[{neighbors_dict['connectivities_key']!r}]`, weighted adjacency matrix"
249286
),
250287
)
251288
return adata if copy else None
252289

253290

291+
def _get_metadata(
292+
key_added: str | None,
293+
**params: Unpack[NeighborsParams],
294+
) -> tuple[str, NeighborsDict]:
295+
if key_added is None:
296+
return "neighbors", NeighborsDict(
297+
connectivities_key="connectivities",
298+
distances_key="distances",
299+
params=params,
300+
)
301+
return key_added, NeighborsDict(
302+
connectivities_key=f"{key_added}_connectivities",
303+
distances_key=f"{key_added}_distances",
304+
params=params,
305+
)
306+
307+
254308
class FlatTree(NamedTuple): # noqa: D101
255309
hyperplanes: None
256310
offsets: None
@@ -358,7 +412,7 @@ class Neighbors:
358412
n_dcs
359413
Number of diffusion components to use.
360414
neighbors_key
361-
Where to look in `.uns` and `.obsp` for neighbors data
415+
Where to look in `.uns` and `.obsp` for neighbors data.
362416
363417
"""
364418

@@ -518,7 +572,7 @@ def to_igraph(self) -> Graph:
518572
return _utils.get_igraph_from_adjacency(self.connectivities)
519573

520574
@_doc_params(n_pcs=doc_n_pcs, use_rep=doc_use_rep)
521-
def compute_neighbors( # noqa: PLR0912
575+
def compute_neighbors(
522576
self,
523577
n_neighbors: int = 30,
524578
n_pcs: int | None = None,
@@ -601,26 +655,11 @@ def compute_neighbors( # noqa: PLR0912
601655
self._rp_forest = _make_forest_dict(index)
602656
start_connect = logg.debug("computed neighbors", time=start_neighbors)
603657

604-
if method == "umap":
605-
self._connectivities = _connectivity.umap(
606-
knn_indices,
607-
knn_distances,
608-
n_obs=self._adata.shape[0],
609-
n_neighbors=self.n_neighbors,
610-
)
611-
elif method == "gauss":
612-
self._connectivities = _connectivity.gauss(
613-
self._distances, self.n_neighbors, knn=self.knn
614-
)
615-
elif method == "jaccard":
616-
self._connectivities = _connectivity.jaccard(
617-
knn_indices,
618-
n_obs=self._adata.shape[0],
619-
n_neighbors=self.n_neighbors,
620-
)
621-
elif method is not None:
622-
msg = f"{method!r} should have been coerced in _handle_transform_args"
623-
raise AssertionError(msg)
658+
self._connectivities = (
659+
None
660+
if method is None
661+
else self._compute_connectivites(method, (knn_indices, knn_distances))
662+
)
624663
self._number_connected_components = 1
625664
if isinstance(self._connectivities, CSBase):
626665
from scipy.sparse.csgraph import connected_components
@@ -630,6 +669,43 @@ def compute_neighbors( # noqa: PLR0912
630669
if method is not None:
631670
logg.debug("computed connectivities", time=start_connect)
632671

672+
def _compute_connectivites(
673+
self,
674+
method: _Method,
675+
knn_ind_dist: (
676+
tuple[NDArray[np.int32 | np.int64], NDArray[np.float32 | np.float64]] | None
677+
) = None,
678+
) -> CSRBase | NDArray[np.float32 | np.float64] | None:
679+
def get_knn():
680+
if knn_ind_dist is not None:
681+
return knn_ind_dist
682+
if isinstance(self._distances, CSBase):
683+
return _get_indices_distances_from_sparse_matrix(
684+
self._distances.tocsr(), self.n_neighbors
685+
)
686+
assert self._distances is not None
687+
return _get_indices_distances_from_dense_matrix(
688+
self._distances, self.n_neighbors
689+
)
690+
691+
if method == "umap":
692+
knn_indices, knn_distances = get_knn()
693+
return umap(
694+
knn_indices,
695+
knn_distances,
696+
n_obs=self._adata.n_obs,
697+
n_neighbors=self.n_neighbors,
698+
)
699+
if method == "gauss":
700+
return _connectivity.gauss(self._distances, self.n_neighbors, knn=self.knn)
701+
if method == "jaccard":
702+
knn_indices, _ = get_knn()
703+
return _connectivity.jaccard(
704+
knn_indices, n_obs=self._adata.n_obs, n_neighbors=self.n_neighbors
705+
)
706+
msg = f"Method {method} not implemented."
707+
raise NotImplementedError(msg)
708+
633709
def _handle_transformer(
634710
self,
635711
method: _Method | Literal["gauss"] | None,

src/scanpy/neighbors/_connectivity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
)
1515

1616

17-
def gauss[D: (NDArray[np.float32], CSRBase)]( # noqa: PLR0912
17+
def gauss[D: (NDArray[np.float32 | np.float64], CSRBase)]( # noqa: PLR0912
1818
distances: D, n_neighbors: int, *, knn: bool
1919
) -> D:
2020
"""Derive gaussian connectivities between data points from their distances.

0 commit comments

Comments
 (0)