From 9c5a936b0ecf21ba09b1bc803712a42c2a942fbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michaela=20M=C3=BCller?= <51025211+mumichae@users.noreply.github.com> Date: Mon, 22 Apr 2024 19:59:15 +0200 Subject: [PATCH] Use precomputed clustering (#406) * cluster only when clustering key does not exist or if forced * add tests for precomputed clusters * allow passing arguments to cluster_optimal_resolution from isolated label score * fix docstring issues * add test for isolated label score with precomputed clusters * fix neighbors check * fix resolution getter function * update explicit get_resolutions functions * return no isolated labels when minimum number of batches per label is the same as total batches * include update docstrings for clustering * add use_rep to optimal clustering * use graph connectivity for scanorama --- docs/source/api.rst | 17 ++++ scib/metrics/clustering.py | 106 ++++++++++++++++--------- scib/metrics/isolated_labels.py | 60 ++++++++++++-- setup.cfg | 18 ++--- tests/integration/test_harmony.py | 6 +- tests/integration/test_scanorama.py | 10 +-- tests/integration/test_scanvi.py | 3 +- tests/integration/test_scvi.py | 3 +- tests/metrics/test_isolated_label.py | 19 ++++- tests/preprocessing/test_clustering.py | 37 ++++++++- 10 files changed, 214 insertions(+), 65 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 60c9efa2..4da63fd1 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -91,6 +91,22 @@ For these, you need to additionally provide the corresponding label column of `` :skip: runTrVaep :skip: issparse + +Clustering +---------- +.. currentmodule:: scib.metrics + +After integration, one of the first ways to determine the quality of the integration is to cluster the integrated data and compare the clusters to the original annotations. +This is exactly what some of the metrics do. + +.. 
autosummary:: + :toctree: api/ + + cluster_optimal_resolution + get_resolutions + opt_louvain + + Metrics ------- @@ -184,6 +200,7 @@ Some parts of metrics can be used individually, these are listed below. :toctree: api/ cluster_optimal_resolution + get_resolutions lisi_graph pcr pc_regression diff --git a/scib/metrics/clustering.py b/scib/metrics/clustering.py index 61685370..d9a9b7df 100644 --- a/scib/metrics/clustering.py +++ b/scib/metrics/clustering.py @@ -1,5 +1,6 @@ +import warnings + import matplotlib.pyplot as plt -import numpy as np import pandas as pd import scanpy as sc import seaborn as sns @@ -8,11 +9,27 @@ from .nmi import nmi -def get_resolutions(n=20, min=0.1, max=2): - min = np.max([1, int(min * 10)]) - max = np.max([min, max * 10]) - frac = n / 10 - return [frac * x / n for x in range(min, max + 1)] +def get_resolutions(n=20, min=0, max=2): + """ + Get equally spaced resolutions for optimised clustering + + :param n: number of resolutions + :param min: minimum resolution + :param max: maximum resolution + + .. code-block:: python + + scib.cl.get_resolutions(n=10) + + Output: + + .. code-block:: + + [0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0] + + """ + res_range = max - min + return [res_range * (x + 1) / n for x in range(n)] def cluster_optimal_resolution( @@ -23,7 +40,7 @@ def cluster_optimal_resolution( metric=None, resolutions=None, use_rep=None, - force=True, + force=False, verbose=True, return_all=False, metric_kwargs=None, @@ -36,14 +53,15 @@ def cluster_optimal_resolution( :param adata: anndata object :param label_key: name of column in adata.obs containing biological labels to be optimised against - :param cluster_key: name of column to be added to adata.obs during clustering. - Will be overwritten if exists and ``force=True`` + :param cluster_key: name and prefix of columns to be added to adata.obs during clustering. 
+ Each resolution will be saved under "{cluster_key}_{resolution}", while the optimal clustering will be under ``cluster_key``. + If ``force=True`` and one of the keys already exists, it will be overwritten. :param cluster_function: a clustering function that takes an anndata.Anndata object. Default: Leiden clustering :param metric: function that computes the cost to be optimised over. Must take as arguments ``(adata, label_key, cluster_key, **metric_kwargs)`` and returns a number for maximising Default is :func:`~scib.metrics.nmi()` :param resolutions: list of resolutions to be optimised over. If ``resolutions=None``, - default resolutions of 10 values ranging between 0.1 and 2 will be used + by default 10 equally spaced resolutions ranging between 0 and 2 will be used (see :func:`~scib.metrics.get_resolutions`) :param use_rep: key of embedding to use only if ``adata.uns['neighbors']`` is not defined, otherwise will be ignored :param force: whether to overwrite the cluster assignments in the ``.obs[cluster_key]`` @@ -56,22 +74,33 @@ def cluster_optimal_resolution( ``res_max``: resolution of maximum score; ``score_max``: maximum score; ``score_all``: ``pd.DataFrame`` containing all scores at resolutions. Can be used to plot the score profile. + + If you specify an embedding that was not used for the kNN graph (i.e. ``adata.uns["neighbors"]["params"]["use_rep"]`` is not the same as ``use_rep``), + the neighbors will be recomputed in-place. """ - if cluster_key in adata.obs.columns: - if force: - print( - f"WARNING: cluster key {cluster_key} already exists in adata.obs and will be overwritten because " - "force=True " - ) - else: - raise ValueError( - f"cluster key {cluster_key} already exists in adata, please remove the key or choose a different " - "name. 
If you want to force overwriting the key, specify `force=True` " + + def call_cluster_function(adata, res, resolution_key, cluster_function, **kwargs): + if resolution_key in adata.obs.columns: + warnings.warn( + f"Overwriting existing key {resolution_key} in adata.obs", stacklevel=2 ) + # check or recompute neighbours + knn_rep = adata.uns.get("neighbors", {}).get("params", {}).get("use_rep") + if use_rep is not None and use_rep != knn_rep: + print(f"Recompute neighbors on rep {use_rep} instead of {knn_rep}") + sc.pp.neighbors(adata, use_rep=use_rep) + + # call clustering function + print(f"Cluster for {resolution_key} with {cluster_function.__name__}") + cluster_function(adata, resolution=res, key_added=resolution_key, **kwargs) + if cluster_function is None: cluster_function = sc.tl.leiden + if cluster_key is None: + cluster_key = cluster_function.__name__ + if metric is None: metric = nmi @@ -86,30 +115,27 @@ def cluster_optimal_resolution( clustering = None score_all = [] - if use_rep is None: - try: - adata.uns["neighbors"] - except KeyError: - raise RuntimeError( - "Neighbours must be computed when setting use_rep to None" + for res in resolutions: + resolution_key = f"{cluster_key}_{res}" + + # check if clustering exists + if resolution_key not in adata.obs.columns or force: + call_cluster_function( + adata, res, resolution_key, cluster_function, **kwargs ) - else: - print(f"Compute neighbors on rep {use_rep}") - sc.pp.neighbors(adata, use_rep=use_rep) - for res in resolutions: - cluster_function(adata, resolution=res, key_added=cluster_key, **kwargs) - score = metric(adata, label_key, cluster_key, **metric_kwargs) - if verbose: - print(f"resolution: {res}, {metric.__name__}: {score}") + # score cluster resolution + score = metric(adata, label_key, resolution_key, **metric_kwargs) score_all.append(score) + if verbose: + print(f"resolution: {res}, {metric.__name__}: {score}", flush=True) + # optimise score if score_max < score: score_max = score res_max = 
res - clustering = adata.obs[cluster_key] - del adata.obs[cluster_key] + clustering = adata.obs[resolution_key] if verbose: print(f"optimised clustering against {label_key}") @@ -120,10 +146,16 @@ def cluster_optimal_resolution( zip(resolutions, score_all), columns=["resolution", "score"] ) + # save optimal clustering in adata.obs + if cluster_key in adata.obs.columns: + warnings.warn( + f"Overwriting existing key {cluster_key} in adata.obs", stacklevel=2 + ) adata.obs[cluster_key] = clustering if return_all: return res_max, score_max, score_all + return res_max, score_max @deprecated @@ -142,6 +174,8 @@ def opt_louvain( ): """Optimised Louvain clustering + DEPRECATED: Use :func:`~scib.metrics.cluster_optimal_resolution` instead + Louvain clustering with resolution optimised against a metric :param adata: anndata object diff --git a/scib/metrics/isolated_labels.py b/scib/metrics/isolated_labels.py index 2e8258ab..22b97bb7 100644 --- a/scib/metrics/isolated_labels.py +++ b/scib/metrics/isolated_labels.py @@ -1,3 +1,5 @@ +import warnings + import pandas as pd from sklearn.metrics import f1_score, silhouette_samples @@ -9,8 +11,11 @@ def isolated_labels_f1( label_key, batch_key, embed, + cluster_key="iso_label", + resolutions=None, iso_threshold=None, verbose=True, + **kwargs, ): """Isolated label score F1 @@ -25,11 +30,17 @@ def isolated_labels_f1( :param iso_threshold: max number of batches per label for label to be considered as isolated, if iso_threshold is integer. If ``iso_threshold=None``, consider minimum number of batches that labels are present in + :param cluster_key: clustering key prefix to look or recompute for each resolution in resolutions. 
+ Is passed to :func:`~scib.metrics.cluster_optimal_resolution` + :param resolutions: list of resolutions to be passed to :func:`~scib.metrics.cluster_optimal_resolution` :param verbose: + :params \\**kwargs: additional arguments to be passed to :func:`~scib.metrics.cluster_optimal_resolution` :return: Mean of F1 scores over all isolated labels This function performs clustering on a kNN graph and can be applied to all integration output types. - For this metric the ``adata`` needs a kNN graph. + For this metric the ``adata`` needs a kNN graph and can optionally make use of precomputed clustering (see example below). + The precomputed clusters must be saved under ``adata.obs[cluster_key]`` as well as ``adata.obs[f"{cluster_key}_{resolution}"]`` for all resolutions. + See :ref:`preprocessing` for more information on preprocessing. **Examples** @@ -49,6 +60,13 @@ def isolated_labels_f1( # knn output scib.me.isolated_labels_f1(adata, batch_key="batch", label_key="celltype") + # use precomputed clustering + scib.cl.cluster_optimal_resolution(adata, cluster_key="iso_label", label_key="celltype") + scib.me.isolated_labels_f1(adata, batch_key="batch", label_key="celltype") + + # overwrite existing clustering + scib.me.isolated_labels_f1(adata, batch_key="batch", label_key="celltype", force=True) + """ return isolated_labels( adata, @@ -56,8 +74,11 @@ def isolated_labels_f1( batch_key=batch_key, embed=embed, cluster=True, + cluster_key=cluster_key, + resolutions=resolutions, iso_threshold=iso_threshold, verbose=verbose, + **kwargs, ) @@ -84,6 +105,7 @@ def isolated_labels_asw( If ``iso_threshold=None``, consider minimum number of batches that labels are present in :param scale: Whether to scale the score between 0 and 1. Only relevant for ASW scores. 
:param verbose: + :params \\**kwargs: additional arguments to be passed to :func:`~scib.metrics.cluster_optimal_resolution` :return: Mean of ASW over all isolated labels The function requires an embedding to be stored in ``adata.obsm`` and can only be applied to feature and embedding @@ -125,10 +147,13 @@ def isolated_labels( batch_key, embed, cluster=True, + cluster_key="iso_label", + resolutions=None, iso_threshold=None, scale=True, return_all=False, verbose=True, + **kwargs, ): """Isolated label score @@ -146,9 +171,12 @@ def isolated_labels( :param iso_threshold: max number of batches per label for label to be considered as isolated, if iso_threshold is integer. If iso_threshold=None, consider minimum number of batches that labels are present in + :param cluster_key: name of key to be passed to :func:`~scib.metrics.cluster_optimal_resolution` + :param resolutions: list of resolutions to be passed to :func:`~scib.metrics.cluster_optimal_resolution` :param scale: Whether to scale the score between 0 and 1. Only relevant for ASW scores. :param return_all: return scores for all isolated labels instead of aggregated mean :param verbose: + :params \\**kwargs: additional arguments to be passed to :func:`~scib.metrics.cluster_optimal_resolution` :return: Mean of scores for each isolated label or dictionary of scores for each label if `return_all=True` @@ -158,6 +186,8 @@ def isolated_labels( isolated_labels = get_isolated_labels( adata, label_key, batch_key, iso_threshold, verbose ) + if verbose: + print(f"isolated labels: {isolated_labels}") # 2. 
compute isolated label score for each isolated label scores = {} @@ -171,9 +201,12 @@ def isolated_labels( label_key, label, embed, - cluster, + cluster_key=cluster_key, + cluster=cluster, scale=scale, verbose=verbose, + resolutions=resolutions, + **kwargs, ) scores[label] = score scores = pd.Series(scores) @@ -189,10 +222,12 @@ def score_isolated_label( label_key, isolated_label, embed, + cluster_key, cluster=True, - iso_label_key="iso_label", + resolutions=None, scale=True, verbose=False, + **kwargs, ): """ Compute label score for a single label @@ -203,10 +238,12 @@ def score_isolated_label( :param embed: embedding to be passed to opt_louvain, if adata.uns['neighbors'] is missing :param cluster: if True, compute clustering-based F1 score, otherwise compute silhouette score on grouping of isolated label vs all other remaining labels - :param iso_label_key: name of key to use for cluster assignment for F1 score or + :param cluster_key: name of key to use for cluster assignment for F1 score or isolated-vs-rest assignment for silhouette score + :param resolutions: list of resolutions to be passed to :func:`~scib.metrics.cluster_optimal_resolution` :param scale: Whether to scale the score between 0 and 1. Only relevant for ASW scores. 
:param verbose: + :params \\**kwargs: additional arguments to be passed to :func:`~scib.metrics.cluster_optimal_resolution` :return: Isolated label score """ @@ -233,13 +270,15 @@ def max_f1(adata, label_key, cluster_key, label, argmax=False): cluster_optimal_resolution( adata, label_key, - cluster_key=iso_label_key, + cluster_key=cluster_key, use_rep=embed, metric=max_f1, metric_kwargs={"label": isolated_label}, - verbose=False, + resolutions=resolutions, + force=False, + verbose=verbose, ) - score = max_f1(adata, label_key, iso_label_key, isolated_label, argmax=False) + score = max_f1(adata, label_key, cluster_key, isolated_label, argmax=False) else: # AWS score between isolated label vs rest if "silhouette_temp" not in adata.obs: @@ -275,6 +314,13 @@ def get_isolated_labels(adata, label_key, batch_key, iso_threshold, verbose): if iso_threshold is None: iso_threshold = batch_per_lab.min().tolist()[0] + if iso_threshold == adata.obs[batch_key].nunique(): + warnings.warn( + "iso_threshold is equal to number of batches in data, no isolated labels will be found", + stacklevel=2, + ) + return [] + if verbose: print(f"isolated labels: no more than {iso_threshold} batches per label") diff --git a/setup.cfg b/setup.cfg index e82d0c17..8c18984d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,15 +17,15 @@ author = Malte D. Luecken, Maren Buettner, Daniel C. Strobl, Michaela F. 
Mueller author_email = malte.luecken@helmholtz-muenchen.de, michaela.mueller@helmholtz-muenchen.de license = MIT url = https://github.com/theislab/scib -project_urls = +project_urls = Pipeline = https://github.com/theislab/scib-pipeline Reproducibility = https://theislab.github.io/scib-reproducibility Bug Tracker = https://github.com/theislab/scib/issues -keywords = +keywords = benchmarking single cell data integration -classifiers = +classifiers = Development Status :: 3 - Alpha Intended Audience :: Developers Intended Audience :: Science/Research @@ -40,11 +40,11 @@ classifiers = build_number = 1 [options] -packages = +packages = scib scib.metrics python_requires = >=3.8 -install_requires = +install_requires = numpy pandas>=2 seaborn @@ -65,7 +65,7 @@ install_requires = zip_safe = False [options.package_data] -scib = +scib = resources/*.txt knn_graph/* @@ -95,7 +95,7 @@ skip_glob = docs/* line-length = 120 target-version = py38 include = \.pyi?$ -exclude = +exclude = .eggs .git .venv @@ -104,7 +104,7 @@ exclude = [flake8] max-line-length = 88 -ignore = +ignore = W503 W504 E501 @@ -126,7 +126,7 @@ ignore = RST304 C408 exclude = .git,__pycache__,build,docs/_build,dist -per-file-ignores = +per-file-ignores = scib/*: D tests/*: D */__init__.py: F401 diff --git a/tests/integration/test_harmony.py b/tests/integration/test_harmony.py index 29abf800..6b7b9f62 100644 --- a/tests/integration/test_harmony.py +++ b/tests/integration/test_harmony.py @@ -11,7 +11,11 @@ def test_harmony(adata_paul15_template): # check NMI after clustering res_max, score_max, _ = scib.cl.cluster_optimal_resolution( - adata, label_key="celltype", cluster_key="cluster", return_all=True + adata, + label_key="celltype", + cluster_key="cluster", + use_rep="X_emb", + return_all=True, ) LOGGER.info(f"max resolution: {res_max}, max score: {score_max}") diff --git a/tests/integration/test_scanorama.py b/tests/integration/test_scanorama.py index 2f2f9b3c..23de802c 100644 --- 
a/tests/integration/test_scanorama.py +++ b/tests/integration/test_scanorama.py @@ -11,13 +11,9 @@ def test_scanorama(adata_paul15_template): adata, n_top_genes=200, neighbors=True, use_rep="X_emb", pca=True, umap=False ) - # check NMI after clustering - res_max, score_max, _ = scib.cl.cluster_optimal_resolution( - adata, label_key="celltype", cluster_key="cluster", return_all=True - ) - LOGGER.info(f"max resolution: {res_max}, max score: {score_max}") - - assert_near_exact(score_max, 0.6610082444492823, 1e-2) + score = scib.me.graph_connectivity(adata, label_key="celltype") + LOGGER.info(f"score: {score}") + assert_near_exact(score, 0.9922324725135062, 1e-2) def test_scanorama_batch_cols(adata_paul15_template): diff --git a/tests/integration/test_scanvi.py b/tests/integration/test_scanvi.py index 1aa71791..3c45f5b3 100644 --- a/tests/integration/test_scanvi.py +++ b/tests/integration/test_scanvi.py @@ -1,5 +1,5 @@ import scib -from tests.common import assert_near_exact +from tests.common import LOGGER, assert_near_exact def test_scanvi(adata_paul15_template): @@ -12,4 +12,5 @@ def test_scanvi(adata_paul15_template): ) score = scib.me.graph_connectivity(adata, label_key="celltype") + LOGGER.info(f"score: {score}") assert_near_exact(score, 1.0, 1e-1) diff --git a/tests/integration/test_scvi.py b/tests/integration/test_scvi.py index 6e41ac80..e0c33633 100644 --- a/tests/integration/test_scvi.py +++ b/tests/integration/test_scvi.py @@ -1,5 +1,5 @@ import scib -from tests.common import assert_near_exact +from tests.common import LOGGER, assert_near_exact def test_scvi(adata_paul15_template): @@ -10,4 +10,5 @@ def test_scvi(adata_paul15_template): ) score = scib.me.graph_connectivity(adata, label_key="celltype") + LOGGER.info(f"score: {score}") assert_near_exact(score, 0.96, 1e-1) diff --git a/tests/metrics/test_isolated_label.py b/tests/metrics/test_isolated_label.py index 43a209ec..c8117515 100644 --- a/tests/metrics/test_isolated_label.py +++ 
b/tests/metrics/test_isolated_label.py @@ -22,7 +22,7 @@ def _random_embedding(partition): return embedding -def test_isolated_labels_F1(adata_neighbors): +def test_isolated_labels_f1(adata_neighbors): score = scib.me.isolated_labels_f1( adata_neighbors, label_key="celltype", @@ -34,7 +34,22 @@ def test_isolated_labels_F1(adata_neighbors): assert_near_exact(score, 0.5581395348837209, diff=1e-12) -def test_isolated_labels_ASW(adata_pca): +def test_isolated_labels_f1_precomputed(adata_clustered): + score = scib.me.isolated_labels_f1( + adata_clustered, + label_key="celltype", + batch_key="batch", + embed="X_pca", + cluster_key="cluster", + verbose=True, + ) + assert "iso_label" not in adata_clustered.obs.columns + + LOGGER.info(f"score: {score}") + assert_near_exact(score, 0.5581395348837209, diff=1e-12) + + +def test_isolated_labels_asw(adata_pca): score = scib.me.isolated_labels_asw( adata_pca, label_key="celltype", diff --git a/tests/preprocessing/test_clustering.py b/tests/preprocessing/test_clustering.py index 004a4743..1921356b 100644 --- a/tests/preprocessing/test_clustering.py +++ b/tests/preprocessing/test_clustering.py @@ -27,10 +27,11 @@ def test_cluster_optimal_resolution_louvain(adata_neighbors): label_key="celltype", cluster_key="cluster", cluster_function=sc.tl.louvain, - resolutions=scib.cl.get_resolutions(n=20, min=0.1, max=2), + resolutions=scib.cl.get_resolutions(n=10, min=0, max=1), return_all=True, ) assert isinstance(score_all, pd.DataFrame) + assert "cluster" in adata_neighbors.obs.columns LOGGER.info(f"max resolution: {res_max}, max score: {score_max}") assert res_max == 0.7 @@ -43,10 +44,44 @@ def test_cluster_optimal_resolution_leiden(adata_neighbors): label_key="celltype", cluster_key="cluster", cluster_function=sc.tl.leiden, + resolutions=scib.cl.get_resolutions(n=10, min=0, max=1), return_all=True, ) assert isinstance(score_all, pd.DataFrame) + assert "cluster" in adata_neighbors.obs.columns LOGGER.info(f"max resolution: {res_max}, max 
score: {score_max}") assert res_max == 0.5 assert_near_exact(score_max, 0.7424614219722735, diff=1e-3) + + +def test_precomputed_cluster(adata): + resolutions = scib.cl.get_resolutions(n=10, min=0, max=1) + for res in resolutions: + adata.obs[f"cluster_{res}"] = adata.obs["celltype"] + + res_max, score_max = scib.cl.cluster_optimal_resolution( + adata, + cluster_key="cluster", + label_key="celltype", + force=False, + resolutions=resolutions, + ) + assert res_max == 0.1 + assert_near_exact(score_max, 1, diff=0) + + +def test_precomputed_cluster_force(adata_neighbors): + resolutions = scib.cl.get_resolutions(n=10, min=0, max=1) + for res in resolutions: + adata_neighbors.obs[f"cluster_{res}"] = adata_neighbors.obs["celltype"] + + res_max, score_max = scib.cl.cluster_optimal_resolution( + adata_neighbors, + cluster_key="cluster", + label_key="celltype", + resolutions=resolutions, + force=True, + ) + assert res_max == 0.5 + assert_near_exact(score_max, 0.7424614219722736, diff=1e-5)