From 0b830c7efff07228e613f782f0776fb36dff48fa Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Mon, 24 Oct 2022 14:02:35 +0200 Subject: [PATCH 1/5] updated trVAE code --- scib/integration.py | 53 +++++++++++++++++++++++++-------------------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/scib/integration.py b/scib/integration.py index 4e75e3f7..19506522 100644 --- a/scib/integration.py +++ b/scib/integration.py @@ -55,39 +55,44 @@ def trvae(adata, batch, hvg=None): raise OptionalDependencyNotInstalled(e) utils.check_sanity(adata, batch, hvg) - n_batches = len(adata.obs[batch].cat.categories) - train_adata, valid_adata = trvae.utils.train_test_split(adata, train_frac=0.80) + batches = adata.obs[batch].unique().tolist() - condition_encoder = trvae.utils.create_dictionary( - adata.obs[batch].cat.categories.tolist(), [] - ) - - network = trvae.archs.trVAEMulti( - x_dimension=train_adata.shape[1], - n_conditions=n_batches, - output_activation="relu", + network = trvae.models.trVAE( + x_dimension=adata.shape[1], + architecture=[256, 64], + z_dimension=10, + gene_names=adata.var_names.tolist(), + conditions=batches, + model_path="/localscratch/", + alpha=0.0001, + beta=50, + eta=100, + loss_fn="sse", + output_activation="linear", ) network.train( - train_adata, - valid_adata, - condition_key=batch, - condition_encoder=condition_encoder, - verbose=0, - ) - - labels, _ = trvae.tl.label_encoder( adata, - condition_key=batch, - label_encoder=condition_encoder, + batch, + train_size=0.8, + n_epochs=50, + batch_size=512, + early_stop_limit=10, + lr_reducer=20, + verbose=5, + save=False, ) - network.get_corrected(adata, labels, return_z=False) + latent_adata = network.get_latent(adata, batch) + + target_batch = adata.obs[batch].value_counts().index[0] + + corrected_data = network.predict(adata, batch, target_condition=target_batch) - adata.obsm["X_emb"] = adata.obsm["mmd_latent"] - del adata.obsm["mmd_latent"] - adata.X = adata.obsm["reconstructed"] + # Assign trVAE outputs + adata.obsm["X_emb"] = latent_adata + adata.X = corrected_data return adata From 2eca5339aa427b98b581c714ece1682376174bbf Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Mon, 24 Oct 2022 14:49:56 +0200 Subject: [PATCH 2/5] add trvae test --- .github/workflows/test.yml | 2 +- setup.cfg | 2 +- tests/integration/test_trvae.py | 19 +++++++++++++++++++ 3 files changed, 21 insertions(+), 2 deletions(-) create mode 100644 tests/integration/test_trvae.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index abec31e1..e534af1d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -93,7 +93,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install '.[test,scanorama,scvi]' + pip install '.[test,scanorama,scvi,trvae]' - name: Test with pytest run: | diff --git a/setup.cfg b/setup.cfg index f90fc0e4..14dcf8be 100644 --- a/setup.cfg +++ b/setup.cfg @@ -80,7 +80,7 @@ scanorama = scanorama ==1.7.0 mnn = mnnpy ==0.1.9.5 scgen = scgen >=2.1.0 scvi = scvi-tools >=0.16.1 -trvae = trvae ==1.1.2 +trvae = trvae @ git+https://github.com/theislab/trVAE.git@6fd87fcd1fc47a6b93579dcb7caac0d1e85ed10e trvaep = trvaep ==0.1.0 desc = desc ==2.0.3 diff --git a/tests/integration/test_trvae.py b/tests/integration/test_trvae.py new file mode 100644 index 00000000..c900dcc0 --- /dev/null +++ b/tests/integration/test_trvae.py @@ -0,0 +1,19 @@ +import scib +from tests.common import LOGGER + + +def test_trvae(adata_paul15_template): + adata = scib.ig.trvae(adata_paul15_template, batch="batch") + + # check full feature output + scib.pp.reduce_data( + adata, pca=True, n_top_genes=200, neighbors=True, use_rep="X_pca", umap=False + ) + score = scib.me.graph_connectivity(adata, label_key="celltype") + LOGGER.info(f"\nscore: {score}") + + # check embedding output + scib.pp.reduce_data(adata, pca=False, neighbors=True, use_rep="X_emb", umap=False) + score = scib.me.graph_connectivity(adata, label_key="celltype") + LOGGER.info(f"\nscore: {score}") + # assert_near_exact(score, 0.9834078129657216, 1e-2) From 55d5313598bc442073c0b58e3170e38693d6465e Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Mon, 24 Oct 2022 14:55:50 +0200 Subject: [PATCH 3/5] update processing for integration tests --- tests/integration/test_scanorama.py | 10 +++++++++- tests/integration/test_scanvi.py | 4 +--- tests/integration/test_scvi.py | 4 +--- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/integration/test_scanorama.py b/tests/integration/test_scanorama.py index ebf45279..471ccc8c 100644 --- a/tests/integration/test_scanorama.py +++ b/tests/integration/test_scanorama.py @@ -7,9 +7,17 @@ def test_scanorama(adata_paul15_template): adata = scib.ig.scanorama(adata_paul15_template, batch="batch") + # check full feature output scib.pp.reduce_data( - adata, n_top_genes=200, neighbors=True, use_rep="X_emb", pca=True, umap=False + adata, pca=True, n_top_genes=200, neighbors=True, use_rep="X_pca", umap=False ) + score = scib.me.graph_connectivity(adata, label_key="celltype") + LOGGER.info(f"\nscore: {score}") + + # check embedding output + scib.pp.reduce_data(adata, pca=False, neighbors=True, use_rep="X_emb", umap=False) + score = scib.me.graph_connectivity(adata, label_key="celltype") + LOGGER.info(f"\nscore: {score}") # check NMI after clustering res_max, score_max, _ = scib.cl.opt_louvain( diff --git a/tests/integration/test_scanvi.py b/tests/integration/test_scanvi.py index b61b5358..3e7b7b33 100644 --- a/tests/integration/test_scanvi.py +++ b/tests/integration/test_scanvi.py @@ -7,9 +7,7 @@ def test_scanvi(adata_paul15_template): adata_paul15_template, batch="batch", labels="celltype", max_epochs=20 ) - scib.pp.reduce_data( - adata, n_top_genes=200, neighbors=True, use_rep="X_emb", pca=True, umap=False - ) + scib.pp.reduce_data(adata, pca=False, neighbors=True, use_rep="X_emb", umap=False) score = scib.me.graph_connectivity(adata, label_key="celltype") assert_near_exact(score, 0.9834078129657216, 1e-2) diff --git a/tests/integration/test_scvi.py b/tests/integration/test_scvi.py index 98989d4e..f56c0b9a 100644 --- a/tests/integration/test_scvi.py +++ b/tests/integration/test_scvi.py @@ -5,9 +5,7 @@ def test_scvi(adata_paul15_template): adata = scib.ig.scvi(adata_paul15_template, batch="batch", max_epochs=20) - scib.pp.reduce_data( - adata, n_top_genes=200, neighbors=True, use_rep="X_emb", pca=True, umap=False - ) + scib.pp.reduce_data(adata, pca=False, neighbors=True, use_rep="X_emb", umap=False) score = scib.me.graph_connectivity(adata, label_key="celltype") assert_near_exact(score, 0.9684638088694193, 1e-2) From 062783cf3f640e8f1db9ceaa81b02b4ac3695510 Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Mon, 24 Oct 2022 15:10:03 +0200 Subject: [PATCH 4/5] split integration tests to different matrix jobs --- .github/workflows/test.yml | 5 +++-- setup.cfg | 2 +- .../{test_scanvi.py => test_scvi-tools.py} | 9 +++++++++ tests/integration/test_scvi.py | 11 ----------- 4 files changed, 13 insertions(+), 14 deletions(-) rename tests/integration/{test_scanvi.py => test_scvi-tools.py} (56%) delete mode 100644 tests/integration/test_scvi.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e534af1d..40798b25 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -83,6 +83,7 @@ jobs: matrix: python: [3.7, 3.9] os: [ubuntu-latest] + method: [scanorama, scvi-tools, trvae] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python }} @@ -93,11 +94,11 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install '.[test,scanorama,scvi,trvae]' + pip install '.[test,${{ matrix.method }}]' - name: Test with pytest run: | - python -m pytest --cov --cov-append --cov-report=term-missing -k integration -vv + python -m pytest --cov --cov-append --cov-report=term-missing -k ${{ matrix.method }} -vv - name: Upload coverage env: CODECOV_NAME: ${{ matrix.os }}-${{ matrix.python }} diff --git a/setup.cfg b/setup.cfg index 14dcf8be..d25d2112 100644 --- a/setup.cfg +++ b/setup.cfg @@ -79,7 +79,7 @@ bbknn = bbknn ==1.3.9 scanorama = scanorama ==1.7.0 mnn = mnnpy ==0.1.9.5 scgen = scgen >=2.1.0 -scvi = scvi-tools >=0.16.1 +scvi-tools = scvi-tools >=0.16.1 trvae = trvae @ git+https://github.com/theislab/trVAE.git@6fd87fcd1fc47a6b93579dcb7caac0d1e85ed10e trvaep = trvaep ==0.1.0 desc = desc ==2.0.3 diff --git a/tests/integration/test_scanvi.py b/tests/integration/test_scvi-tools.py similarity index 56% rename from tests/integration/test_scanvi.py rename to tests/integration/test_scvi-tools.py index 3e7b7b33..b57dea3c 100644 --- a/tests/integration/test_scanvi.py +++ b/tests/integration/test_scvi-tools.py @@ -2,6 +2,15 @@ from tests.common import assert_near_exact +def test_scvi(adata_paul15_template): + adata = scib.ig.scvi(adata_paul15_template, batch="batch", max_epochs=20) + + scib.pp.reduce_data(adata, pca=False, neighbors=True, use_rep="X_emb", umap=False) + + score = scib.me.graph_connectivity(adata, label_key="celltype") + assert_near_exact(score, 0.9684638088694193, 1e-2) + + def test_scanvi(adata_paul15_template): adata = scib.ig.scanvi( adata_paul15_template, batch="batch", labels="celltype", max_epochs=20 diff --git a/tests/integration/test_scvi.py b/tests/integration/test_scvi.py deleted file mode 100644 index f56c0b9a..00000000 --- a/tests/integration/test_scvi.py +++ /dev/null @@ -1,11 +0,0 @@ -import scib -from tests.common import assert_near_exact - - -def test_scvi(adata_paul15_template): - adata = scib.ig.scvi(adata_paul15_template, batch="batch", max_epochs=20) - - scib.pp.reduce_data(adata, pca=False, neighbors=True, use_rep="X_emb", umap=False) - - score = scib.me.graph_connectivity(adata, label_key="celltype") - assert_near_exact(score, 0.9684638088694193, 1e-2) From 07b227e40d55da77c100932d92f833bac37d58ef Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Mon, 24 Oct 2022 15:19:45 +0200 Subject: [PATCH 5/5] fix tensorflow version for trvae --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 14dcf8be..b50c24c4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -80,7 +80,7 @@ scanorama = scanorama ==1.7.0 mnn = mnnpy ==0.1.9.5 scgen = scgen >=2.1.0 scvi = scvi-tools >=0.16.1 -trvae = trvae @ git+https://github.com/theislab/trVAE.git@6fd87fcd1fc47a6b93579dcb7caac0d1e85ed10e +trvae = tensorflow ==2.5.3; trvae @ git+https://github.com/theislab/trVAE.git@6fd87fcd1fc47a6b93579dcb7caac0d1e85ed10e trvaep = trvaep ==0.1.0 desc = desc ==2.0.3