Skip to content

Add neighbors_from_distance for computing neighborhood graphs from precomputed distance matrices #3627

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/release-notes/3627.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Add your info here

Added `neighbors_from_distance`, a function for computing neighborhood graphs from a precomputed distance matrix using the UMAP or Gaussian method. `A. Karesh`
2 changes: 1 addition & 1 deletion src/scanpy/metrics/_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def confusion_matrix(
orig, new = pd.Series(orig), pd.Series(new)
assert len(orig) == len(new)

unique_labels = pd.unique(np.concatenate((orig.values, new.values)))
unique_labels = pd.unique(np.concatenate((orig.to_numpy(), new.to_numpy())))

# Compute
mtx = _confusion_matrix(orig, new, labels=unique_labels)
Expand Down
150 changes: 122 additions & 28 deletions src/scanpy/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,17 @@
from .._utils import NeighborsView, _doc_params, get_literal_vals
from . import _connectivity
from ._common import (
_get_indices_distances_from_dense_matrix,
_get_indices_distances_from_sparse_matrix,
_get_sparse_matrix_from_indices_distances,
)
from ._connectivity import umap
from ._doc import doc_n_pcs, doc_use_rep
from ._types import _KnownTransformer, _Method

if TYPE_CHECKING:
from collections.abc import Callable, MutableMapping
from typing import Any, Literal, NotRequired
from typing import Any, Literal, NotRequired, Unpack

from anndata import AnnData
from igraph import Graph
Expand Down Expand Up @@ -58,6 +60,13 @@ class KwdsForTransformer(TypedDict):
random_state: _LegacyRandom


class NeighborsDict(TypedDict):
    """Metadata describing one neighbors computation, as stored in `.uns[key_added]`."""

    # Key in `.obsp` under which the connectivities matrix is stored.
    connectivities_key: str
    # Key in `.obsp` under which the distances matrix is stored.
    distances_key: str
    # Parameters the neighbors computation was run with.
    params: NeighborsParams
    # Only set when the computation produced an `rp_forest` (i.e. it was not `None`).
    rp_forest: NotRequired[RPForestDict]


class NeighborsParams(TypedDict): # noqa: D101
n_neighbors: int
method: _Method
Expand All @@ -74,6 +83,7 @@ def neighbors( # noqa: PLR0913
n_neighbors: int = 15,
n_pcs: int | None = None,
*,
distances: np.ndarray | SpBase | None = None,
use_rep: str | None = None,
knn: bool = True,
method: _Method = "umap",
Expand Down Expand Up @@ -135,6 +145,7 @@ def neighbors( # noqa: PLR0913
Use :func:`rapids_singlecell.pp.neighbors` instead.
metric
A known metric’s name or a callable that returns a distance.
If `distances` is given, this parameter is simply stored in `.uns` (see below).

*ignored if ``transformer`` is an instance.*
metric_kwds
Expand Down Expand Up @@ -186,6 +197,18 @@ def neighbors( # noqa: PLR0913
:doc:`/how-to/knn-transformers`

"""
if distances is not None:
if callable(metric):
msg = "`metric` must be a string if `distances` is given."
raise TypeError(msg)
# if a precomputed distance matrix is provided, skip the PCA and distance computation
return neighbors_from_distance(
adata,
distances,
n_neighbors=n_neighbors,
metric=metric,
method=method,
)
start = logg.info("computing neighbors")
adata = adata.copy() if copy else adata
if adata.is_view: # we shouldn't need this here...
Expand All @@ -203,51 +226,122 @@ def neighbors( # noqa: PLR0913
random_state=random_state,
)

if key_added is None:
key_added = "neighbors"
conns_key = "connectivities"
dists_key = "distances"
else:
conns_key = key_added + "_connectivities"
dists_key = key_added + "_distances"

adata.uns[key_added] = {}

neighbors_dict = adata.uns[key_added]

neighbors_dict["connectivities_key"] = conns_key
neighbors_dict["distances_key"] = dists_key

neighbors_dict["params"] = NeighborsParams(
key_added, neighbors_dict = _get_metadata(
key_added,
n_neighbors=neighbors.n_neighbors,
method=method,
random_state=random_state,
metric=metric,
**({} if not metric_kwds else dict(metric_kwds=metric_kwds)),
**({} if use_rep is None else dict(use_rep=use_rep)),
**({} if n_pcs is None else dict(n_pcs=n_pcs)),
)
if metric_kwds:
neighbors_dict["params"]["metric_kwds"] = metric_kwds
if use_rep is not None:
neighbors_dict["params"]["use_rep"] = use_rep
if n_pcs is not None:
neighbors_dict["params"]["n_pcs"] = n_pcs

adata.obsp[dists_key] = neighbors.distances
adata.obsp[conns_key] = neighbors.connectivities

if neighbors.rp_forest is not None:
neighbors_dict["rp_forest"] = neighbors.rp_forest

adata.uns[key_added] = neighbors_dict
adata.obsp[neighbors_dict["distances_key"]] = neighbors.distances
adata.obsp[neighbors_dict["connectivities_key"]] = neighbors.connectivities

logg.info(
" finished",
time=start,
deep=(
f"added to `.uns[{key_added!r}]`\n"
f" `.obsp[{dists_key!r}]`, distances for each pair of neighbors\n"
f" `.obsp[{conns_key!r}]`, weighted adjacency matrix"
f" `.obsp[{neighbors_dict['distances_key']!r}]`, distances for each pair of neighbors\n"
f" `.obsp[{neighbors_dict['connectivities_key']!r}]`, weighted adjacency matrix"
),
)
return adata if copy else None


def neighbors_from_distance(
    adata: AnnData,
    distances: np.ndarray | SpBase,
    *,
    n_neighbors: int = 15,
    metric: _Metric = "euclidean",
    method: _Method = "umap",
    key_added: str | None = None,
) -> AnnData:
    """Compute a neighborhood graph from a precomputed distance matrix.

    Parameters
    ----------
    adata
        Annotated data matrix.
    distances
        Precomputed dense or sparse distance matrix of shape `(n_obs, n_obs)`.
        The input is never modified: a copy is made before its diagonal is zeroed.
    n_neighbors
        Number of nearest neighbors to use in the graph.
    metric
        Name of the metric `distances` was computed with.
        It is not used for any computation here, only stored in `.uns`.
    method
        Method to use for computing the connectivities, `'umap'` or `'gauss'`.
    key_added
        Optional key under which to store the results. Default is 'neighbors'.

    Returns
    -------
    adata
        The same object that was passed in, with the metadata added to
        `.uns[key_added]` and the distances and connectivities added to `.obsp`.

    Raises
    ------
    ValueError
        If `distances` does not have shape `(n_obs, n_obs)`.
    NotImplementedError
        If `method` is neither `'umap'` nor `'gauss'`.
    """
    # Work on a private copy: zeroing the diagonal below happens in place and
    # must not leak into the matrix owned by the caller.
    if isinstance(distances, SpBase):
        distances = sparse.csr_matrix(distances, copy=True)  # noqa: TID251
    else:
        distances = np.array(distances, copy=True)

    expected_shape = (adata.n_obs, adata.n_obs)
    if distances.shape != expected_shape:
        msg = (
            f"`distances` must have shape {expected_shape}, "
            f"but has shape {distances.shape}."
        )
        raise ValueError(msg)

    # An observation is at distance 0 from itself.
    if isinstance(distances, CSRBase):
        distances.setdiag(0)
        # drop explicitly stored zeros so they are not counted as neighbors
        distances.eliminate_zeros()
    else:
        np.fill_diagonal(distances, 0)

    if method == "umap":
        if isinstance(distances, CSRBase):
            knn_indices, knn_distances = _get_indices_distances_from_sparse_matrix(
                distances, n_neighbors
            )
        else:
            knn_indices, knn_distances = _get_indices_distances_from_dense_matrix(
                distances, n_neighbors
            )
        connectivities = umap(
            knn_indices, knn_distances, n_obs=adata.n_obs, n_neighbors=n_neighbors
        )
    elif method == "gauss":
        # the stored distances are sparse for this method, even for dense input
        distances = sparse.csr_matrix(distances)  # noqa: TID251
        connectivities = _connectivity.gauss(distances, n_neighbors, knn=True)
    else:
        msg = f"Method {method} not implemented."
        raise NotImplementedError(msg)

    key_added, neighbors_dict = _get_metadata(
        key_added,
        n_neighbors=n_neighbors,
        method=method,
        random_state=0,
        metric=metric,
    )
    adata.uns[key_added] = neighbors_dict
    adata.obsp[neighbors_dict["distances_key"]] = distances
    adata.obsp[neighbors_dict["connectivities_key"]] = connectivities
    return adata


def _get_metadata(
    key_added: str | None,
    **params: Unpack[NeighborsParams],
) -> tuple[str, NeighborsDict]:
    """Resolve the `.uns` key and build the metadata dict for a neighbors run.

    Without `key_added`, the `.uns` key is `'neighbors'` and the `.obsp` keys are
    the bare `'connectivities'` and `'distances'`.
    Otherwise `key_added` is the `.uns` key and both `.obsp` keys get the
    prefix `f'{key_added}_'`.
    """
    if key_added is None:
        uns_key, prefix = "neighbors", ""
    else:
        uns_key, prefix = key_added, f"{key_added}_"

    metadata = NeighborsDict(
        connectivities_key=f"{prefix}connectivities",
        distances_key=f"{prefix}distances",
        params=params,
    )
    return uns_key, metadata


class FlatTree(NamedTuple): # noqa: D101
hyperplanes: None
offsets: None
Expand Down
20 changes: 20 additions & 0 deletions tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from scanpy import Neighbors
from scanpy._compat import CSBase
from testing.scanpy._helpers import anndata_v0_8_constructor_compat
from testing.scanpy._helpers.data import pbmc68k_reduced

if TYPE_CHECKING:
from typing import Literal
Expand Down Expand Up @@ -241,3 +242,22 @@ def test_restore_n_neighbors(neigh, conv):
ad.uns["neighbors"] = dict(connectivities=conv(neigh.connectivities))
neigh_restored = Neighbors(ad)
assert neigh_restored.n_neighbors == 1


def test_neighbors_distance_equivalence():
    """Passing precomputed distances reproduces the regular neighbors result."""
    adata = pbmc68k_reduced()
    adata_d = adata.copy()

    sc.pp.neighbors(adata)
    # Feed the distances computed above back in as a precomputed matrix.
    sc.pp.neighbors(adata_d, distances=adata.obsp["distances"])

    for obsp_key in ("connectivities", "distances"):
        np.testing.assert_allclose(
            adata.obsp[obsp_key].toarray(),
            adata_d.obsp[obsp_key].toarray(),
            rtol=1e-5,
        )
Loading