diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index abec31e1..40798b25 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -83,6 +83,7 @@ jobs: matrix: python: [3.7, 3.9] os: [ubuntu-latest] + method: [scanorama, scvi-tools, trvae] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python }} @@ -93,11 +94,11 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install '.[test,scanorama,scvi]' + pip install '.[test,${{ matrix.method }}]' - name: Test with pytest run: | - python -m pytest --cov --cov-append --cov-report=term-missing -k integration -vv + python -m pytest --cov --cov-append --cov-report=term-missing -k ${{ matrix.method }} -vv - name: Upload coverage env: CODECOV_NAME: ${{ matrix.os }}-${{ matrix.python }} diff --git a/scib/integration.py b/scib/integration.py index 4e75e3f7..19506522 100644 --- a/scib/integration.py +++ b/scib/integration.py @@ -55,39 +55,44 @@ def trvae(adata, batch, hvg=None): raise OptionalDependencyNotInstalled(e) utils.check_sanity(adata, batch, hvg) - n_batches = len(adata.obs[batch].cat.categories) - train_adata, valid_adata = trvae.utils.train_test_split(adata, train_frac=0.80) + batches = adata.obs[batch].unique().tolist() - condition_encoder = trvae.utils.create_dictionary( - adata.obs[batch].cat.categories.tolist(), [] - ) - - network = trvae.archs.trVAEMulti( - x_dimension=train_adata.shape[1], - n_conditions=n_batches, - output_activation="relu", + network = trvae.models.trVAE( + x_dimension=adata.shape[1], + architecture=[256, 64], + z_dimension=10, + gene_names=adata.var_names.tolist(), + conditions=batches, + model_path="/localscratch/", + alpha=0.0001, + beta=50, + eta=100, + loss_fn="sse", + output_activation="linear", ) network.train( - train_adata, - valid_adata, - condition_key=batch, - condition_encoder=condition_encoder, - verbose=0, - ) - - labels, _ = trvae.tl.label_encoder( adata, - condition_key=batch, - label_encoder=condition_encoder, + batch, + train_size=0.8, + n_epochs=50, + batch_size=512, + early_stop_limit=10, + lr_reducer=20, + verbose=5, + save=False, ) - network.get_corrected(adata, labels, return_z=False) + latent_adata = network.get_latent(adata, batch) + + target_batch = adata.obs[batch].value_counts().index[0] + + corrected_data = network.predict(adata, batch, target_condition=target_batch) - adata.obsm["X_emb"] = adata.obsm["mmd_latent"] - del adata.obsm["mmd_latent"] - adata.X = adata.obsm["reconstructed"] + # Assign trVAE outputs + adata.obsm["X_emb"] = latent_adata + adata.X = corrected_data return adata diff --git a/setup.cfg b/setup.cfg index f90fc0e4..a0314d21 100644 --- a/setup.cfg +++ b/setup.cfg @@ -79,8 +79,8 @@ bbknn = bbknn ==1.3.9 scanorama = scanorama ==1.7.0 mnn = mnnpy ==0.1.9.5 scgen = scgen >=2.1.0 -scvi = scvi-tools >=0.16.1 -trvae = trvae ==1.1.2 +scvi-tools = scvi-tools >=0.16.1 +trvae = tensorflow ==2.5.3; trvae @ git+https://github.com/theislab/trVAE.git@6fd87fcd1fc47a6b93579dcb7caac0d1e85ed10e trvaep = trvaep ==0.1.0 desc = desc ==2.0.3 diff --git a/tests/integration/test_scanorama.py b/tests/integration/test_scanorama.py index ebf45279..471ccc8c 100644 --- a/tests/integration/test_scanorama.py +++ b/tests/integration/test_scanorama.py @@ -7,9 +7,17 @@ def test_scanorama(adata_paul15_template): adata = scib.ig.scanorama(adata_paul15_template, batch="batch") + # check full feature output scib.pp.reduce_data( - adata, n_top_genes=200, neighbors=True, use_rep="X_emb", pca=True, umap=False + adata, pca=True, n_top_genes=200, neighbors=True, use_rep="X_pca", umap=False ) + score = scib.me.graph_connectivity(adata, label_key="celltype") + LOGGER.info(f"\nscore: {score}") + + # check embedding output + scib.pp.reduce_data(adata, pca=False, neighbors=True, use_rep="X_emb", umap=False) + score = scib.me.graph_connectivity(adata, label_key="celltype") + LOGGER.info(f"\nscore: {score}") # check NMI after clustering res_max, score_max, _ = scib.cl.opt_louvain( diff --git a/tests/integration/test_scanvi.py b/tests/integration/test_scanvi.py deleted file mode 100644 index b61b5358..00000000 --- a/tests/integration/test_scanvi.py +++ /dev/null @@ -1,15 +0,0 @@ -import scib -from tests.common import assert_near_exact - - -def test_scanvi(adata_paul15_template): - adata = scib.ig.scanvi( - adata_paul15_template, batch="batch", labels="celltype", max_epochs=20 - ) - - scib.pp.reduce_data( - adata, n_top_genes=200, neighbors=True, use_rep="X_emb", pca=True, umap=False - ) - - score = scib.me.graph_connectivity(adata, label_key="celltype") - assert_near_exact(score, 0.9834078129657216, 1e-2) diff --git a/tests/integration/test_scvi-tools.py b/tests/integration/test_scvi-tools.py new file mode 100644 index 00000000..b57dea3c --- /dev/null +++ b/tests/integration/test_scvi-tools.py @@ -0,0 +1,22 @@ +import scib +from tests.common import assert_near_exact + + +def test_scvi(adata_paul15_template): + adata = scib.ig.scvi(adata_paul15_template, batch="batch", max_epochs=20) + + scib.pp.reduce_data(adata, pca=False, neighbors=True, use_rep="X_emb", umap=False) + + score = scib.me.graph_connectivity(adata, label_key="celltype") + assert_near_exact(score, 0.9684638088694193, 1e-2) + + +def test_scanvi(adata_paul15_template): + adata = scib.ig.scanvi( + adata_paul15_template, batch="batch", labels="celltype", max_epochs=20 + ) + + scib.pp.reduce_data(adata, pca=False, neighbors=True, use_rep="X_emb", umap=False) + + score = scib.me.graph_connectivity(adata, label_key="celltype") + assert_near_exact(score, 0.9834078129657216, 1e-2) diff --git a/tests/integration/test_scvi.py b/tests/integration/test_scvi.py deleted file mode 100644 index 98989d4e..00000000 --- a/tests/integration/test_scvi.py +++ /dev/null @@ -1,13 +0,0 @@ -import scib -from tests.common import assert_near_exact - - -def test_scvi(adata_paul15_template): - adata = scib.ig.scvi(adata_paul15_template, batch="batch", max_epochs=20) - - scib.pp.reduce_data( - adata, n_top_genes=200, neighbors=True, use_rep="X_emb", pca=True, umap=False - ) - - score = scib.me.graph_connectivity(adata, label_key="celltype") - assert_near_exact(score, 0.9684638088694193, 1e-2) diff --git a/tests/integration/test_trvae.py b/tests/integration/test_trvae.py new file mode 100644 index 00000000..c900dcc0 --- /dev/null +++ b/tests/integration/test_trvae.py @@ -0,0 +1,19 @@ +import scib +from tests.common import LOGGER + + +def test_trvae(adata_paul15_template): + adata = scib.ig.trvae(adata_paul15_template, batch="batch") + + # check full feature output + scib.pp.reduce_data( + adata, pca=True, n_top_genes=200, neighbors=True, use_rep="X_pca", umap=False + ) + score = scib.me.graph_connectivity(adata, label_key="celltype") + LOGGER.info(f"\nscore: {score}") + + # check embedding output + scib.pp.reduce_data(adata, pca=False, neighbors=True, use_rep="X_emb", umap=False) + score = scib.me.graph_connectivity(adata, label_key="celltype") + LOGGER.info(f"\nscore: {score}") + # assert_near_exact(score, 0.9834078129657216, 1e-2)