diff --git a/docs/release-notes/3627.feature.md b/docs/release-notes/3627.feature.md new file mode 100644 index 0000000000..3443aa5cea --- /dev/null +++ b/docs/release-notes/3627.feature.md @@ -0,0 +1,3 @@ +Add your info here + +Added `neighbors_from_distance`, a function for computing graphs from a precomputed distance matrix using UMAP or Gaussian methods. `A. Karesh` diff --git a/src/scanpy/metrics/_metrics.py b/src/scanpy/metrics/_metrics.py index c24cfe9fd1..83957fcd4a 100644 --- a/src/scanpy/metrics/_metrics.py +++ b/src/scanpy/metrics/_metrics.py @@ -60,7 +60,7 @@ def confusion_matrix( orig, new = pd.Series(orig), pd.Series(new) assert len(orig) == len(new) - unique_labels = pd.unique(np.concatenate((orig.values, new.values))) + unique_labels = pd.unique(np.concatenate((orig.to_numpy(), new.to_numpy()))) # Compute mtx = _confusion_matrix(orig, new, labels=unique_labels) diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 4a42111b74..574c036f63 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -21,15 +21,17 @@ from .._utils import NeighborsView, _doc_params, get_literal_vals from . 
import _connectivity from ._common import ( + _get_indices_distances_from_dense_matrix, _get_indices_distances_from_sparse_matrix, _get_sparse_matrix_from_indices_distances, ) +from ._connectivity import umap from ._doc import doc_n_pcs, doc_use_rep from ._types import _KnownTransformer, _Method if TYPE_CHECKING: from collections.abc import Callable, MutableMapping - from typing import Any, Literal, NotRequired + from typing import Any, Literal, NotRequired, Unpack from anndata import AnnData from igraph import Graph @@ -58,6 +60,13 @@ class KwdsForTransformer(TypedDict): random_state: _LegacyRandom +class NeighborsDict(TypedDict): # noqa: D101 + connectivities_key: str + distances_key: str + params: NeighborsParams + rp_forest: NotRequired[RPForestDict] + + class NeighborsParams(TypedDict): # noqa: D101 n_neighbors: int method: _Method @@ -74,6 +83,7 @@ def neighbors( # noqa: PLR0913 n_neighbors: int = 15, n_pcs: int | None = None, *, + distances: np.ndarray | SpBase | None = None, use_rep: str | None = None, knn: bool = True, method: _Method = "umap", @@ -135,6 +145,7 @@ def neighbors( # noqa: PLR0913 Use :func:`rapids_singlecell.pp.neighbors` instead. metric A known metric’s name or a callable that returns a distance. + If `distances` is given, this parameter is simply stored in `.uns` (see below). *ignored if ``transformer`` is an instance.* metric_kwds @@ -186,6 +197,18 @@ def neighbors( # noqa: PLR0913 :doc:`/how-to/knn-transformers` """ + if distances is not None: + if callable(metric): + msg = "`metric` must be a string if `distances` is given." + raise TypeError(msg) + # if a precomputed distance matrix is provided, skip the PCA and distance computation + return neighbors_from_distance( + adata, + distances, + n_neighbors=n_neighbors, + metric=metric, + method=method, + ) start = logg.info("computing neighbors") adata = adata.copy() if copy else adata if adata.is_view: # we shouldn't need this here... 
def neighbors_from_distance(
    adata: AnnData,
    distances: np.ndarray | SpBase,
    *,
    n_neighbors: int = 15,
    metric: _Metric = "euclidean",
    method: _Method = "umap",
    key_added: str | None = None,
) -> AnnData:
    """Compute a neighborhood graph from a precomputed distance matrix.

    Parameters
    ----------
    adata
        Annotated data matrix.
        It is modified in place: results are written to `.uns` and `.obsp`.
    distances
        Precomputed dense or sparse distance matrix of shape
        `(adata.n_obs, adata.n_obs)`.
        The passed object is never modified: a private copy with a zeroed
        diagonal is used for the computation and stored in `adata.obsp`.
    n_neighbors
        Number of nearest neighbors to use in the graph.
    metric
        Metric to record in `.uns[key_added]['params']`.
        It is only stored and not used for any computation here.
    method
        Method used to compute the connectivities from the distances,
        either `'umap'` or `'gauss'`.
    key_added
        If `None`, results are stored in `.uns['neighbors']`,
        `.obsp['distances']` and `.obsp['connectivities']`.
        Otherwise they are stored in `.uns[key_added]`,
        `.obsp[f'{key_added}_distances']` and
        `.obsp[f'{key_added}_connectivities']`.

    Returns
    -------
    adata
        The same `adata` object, with distances and connectivities added.

    Raises
    ------
    ValueError
        If `distances` does not have shape `(adata.n_obs, adata.n_obs)`.
    NotImplementedError
        If `method` is neither `'umap'` nor `'gauss'`.
    """
    # Always work on a private copy: zeroing the diagonal below must not
    # write into the matrix owned by the caller (e.g. another `.obsp` entry).
    if isinstance(distances, SpBase):
        distances = sparse.csr_matrix(distances, copy=True)  # noqa: TID251
    else:
        distances = np.array(distances, copy=True)

    expected_shape = (adata.n_obs, adata.n_obs)
    if distances.shape != expected_shape:
        msg = (
            f"`distances` must have shape {expected_shape}, "
            f"but has shape {distances.shape}."
        )
        raise ValueError(msg)

    # Distance of an observation to itself is 0; for sparse input the
    # resulting explicit zeros are dropped from the stored structure.
    if isinstance(distances, CSRBase):
        distances.setdiag(0)
        distances.eliminate_zeros()
    else:
        np.fill_diagonal(distances, 0)

    if method == "umap":
        if isinstance(distances, CSRBase):
            knn_indices, knn_distances = _get_indices_distances_from_sparse_matrix(
                distances, n_neighbors
            )
        else:
            knn_indices, knn_distances = _get_indices_distances_from_dense_matrix(
                distances, n_neighbors
            )
        connectivities = umap(
            knn_indices, knn_distances, n_obs=adata.n_obs, n_neighbors=n_neighbors
        )
    elif method == "gauss":
        # `gauss` is always handed a CSR matrix, also for dense input.
        distances = sparse.csr_matrix(distances)  # noqa: TID251
        connectivities = _connectivity.gauss(distances, n_neighbors, knn=True)
    else:
        msg = f"Method {method} not implemented."
        raise NotImplementedError(msg)

    key_added, neighbors_dict = _get_metadata(
        key_added,
        n_neighbors=n_neighbors,
        method=method,
        # this function takes no `random_state` argument; a fixed 0 is recorded
        random_state=0,
        metric=metric,
    )
    adata.uns[key_added] = neighbors_dict
    adata.obsp[neighbors_dict["distances_key"]] = distances
    adata.obsp[neighbors_dict["connectivities_key"]] = connectivities
    return adata


def _get_metadata(
    key_added: str | None,
    **params: Unpack[NeighborsParams],
) -> tuple[str, NeighborsDict]:
    """Build the `.uns` key and the metadata dict describing a neighbors run.

    Parameters
    ----------
    key_added
        Requested key. `None` selects the default key `'neighbors'` together
        with the unprefixed `.obsp` keys `'connectivities'` and `'distances'`;
        any other value is used as the key and as a `{key_added}_` prefix for
        both `.obsp` keys.
    **params
        Parameters of the run, stored unchanged under `'params'`.

    Returns
    -------
    The key to use in `.uns` and the :class:`NeighborsDict` to store there.
    """
    if key_added is None:
        return "neighbors", NeighborsDict(
            connectivities_key="connectivities",
            distances_key="distances",
            params=params,
        )
    return key_added, NeighborsDict(
        connectivities_key=f"{key_added}_connectivities",
        distances_key=f"{key_added}_distances",
        params=params,
    )