-
Notifications
You must be signed in to change notification settings - Fork 69
Description
version info:
PyTorch: 2.6.0+cu124
Pyro: 1.9.1
Cell2location: 0.1.4
CUDA available: True
CUDA version: 12.4
GPU name: Tesla V100-PCIE-32GB
Lightning: 2.5.3
Pyro: 1.9.1
ref_adata
AnnData object with n_obs × n_vars = 29851 × 26255
obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'percent.mito', 'sample', 'percent.ribo', 'RNA_snn_res.0.6', 'seurat_clusters', 'condition', 'sample2', 'cell_type', 'cell_type2', 'cell_type3', 'cell_type4'
var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable'
obsm: 'X_pca', 'X_tsne', 'X_umap'
ref_adata.X = ref_adata.X.astype("int32")
from cell2location.utils.filtering import filter_genes
selected = filter_genes(ref_adata, cell_count_cutoff=5, cell_percentage_cutoff2=0.03, nonz_mean_cutoff=1.12)
ref_adata = ref_adata[:, selected].copy()
cell2location.models.RegressionModel.setup_anndata(adata=ref_adata,
batch_key='sample',
labels_key='cell_type'
)
from cell2location.models import RegressionModel
mod = RegressionModel(ref_adata)
mod.view_anndata_setup()
mod.train(
max_epochs=250,
batch_size=128,
accelerator="gpu"
)
# error:
......
File [~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/infer/enum.py:80](http://localhost:8880/lab/tree/work/Stereo_seq_AGA/~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/infer/enum.py#line=79), in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
78 for site in model_trace.nodes.values():
79 if site["type"] == "sample":
---> 80 check_site_shape(site, max_plate_nesting)
81 for site in guide_trace.nodes.values():
82 if site["type"] == "sample":
File [~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/util.py:437](http://localhost:8880/lab/tree/work/Stereo_seq_AGA/~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/util.py#line=436), in check_site_shape(site, max_plate_nesting)
433 for actual_size, expected_size in zip_longest(
434 reversed(actual_shape), reversed(expected_shape), fillvalue=1
435 ):
436 if expected_size != -1 and expected_size != actual_size:
--> 437 raise ValueError(
438 "\n ".join(
439 [
440 'at site "{}", invalid log_prob shape'.format(site["name"]),
441 "Expected {}, actual {}".format(expected_shape, actual_shape),
442 "Try one of the following fixes:",
443 "- enclose the batched tensor in a with pyro.plate(...): context",
444 "- .to_event(...) the distribution being sampled",
445 "- .permute() data dimensions",
446 ]
447 )
448 )
450 # Check parallel dimensions on the left of max_plate_nesting.
451 enum_dim = site["infer"].get("_enumerate_dim")
ValueError: at site "data_target", invalid log_prob shape
Expected [128, -1], actual [128, 128, 15125]
Try one of the following fixes:
- enclose the batched tensor in a with pyro.plate(...): context
- .to_event(...) the distribution being sampled
- .permute() data dimensions
I'm sure I used the count matrix, and I have tried both ref_adata.X.astype("int32") and ref_adata.X.astype("int"), but the problem still exists. Moreover, after changing to accelerator="cpu", the problem still exists.
When I run the lymph node tutorial, the same error still occurs:
adata_ref = sc.read(
f'./data/sc.h5ad',
backup_url='https://cell2location.cog.sanger.ac.uk/paper/integrated_lymphoid_organ_scrna/RegressionNBV4Torch_57covariates_73260cells_10237genes/sc.h5ad'
)
adata_ref.var['SYMBOL'] = adata_ref.var.index
adata_ref.var.set_index('GeneID-2', drop=True, inplace=True)
del adata_ref.raw
from cell2location.utils.filtering import filter_genes
selected = filter_genes(adata_ref, cell_count_cutoff=5, cell_percentage_cutoff2=0.03, nonz_mean_cutoff=1.12)
adata_ref = adata_ref[:, selected].copy()
cell2location.models.RegressionModel.setup_anndata(adata=adata_ref,
# 10X reaction / sample / batch
batch_key='Sample',
# cell type, covariate used for constructing signatures
labels_key='Subset',
# multiplicative technical effects (platform, 3' vs 5', donor effect)
categorical_covariate_keys=['Method']
)
from cell2location.models import RegressionModel
mod = RegressionModel(adata_ref)
mod.view_anndata_setup()
mod.train(max_epochs=250, batch_size=32,accelerator='gpu')
......
File [~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/infer/trace_elbo.py:57](http://localhost:8880/lab/tree/work/Stereo_seq_AGA/st_AGA3/~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/infer/trace_elbo.py#line=56), in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
52 def _get_trace(self, model, guide, args, kwargs):
53 """
54 Returns a single trace from the guide, and the model that is run
55 against it.
56 """
---> 57 model_trace, guide_trace = get_importance_trace(
58 "flat", self.max_plate_nesting, model, guide, args, kwargs
59 )
60 if is_validation_enabled():
61 check_if_enumerated(guide_trace)
File [~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/infer/enum.py:80](http://localhost:8880/lab/tree/work/Stereo_seq_AGA/st_AGA3/~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/infer/enum.py#line=79), in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
78 for site in model_trace.nodes.values():
79 if site["type"] == "sample":
---> 80 check_site_shape(site, max_plate_nesting)
81 for site in guide_trace.nodes.values():
82 if site["type"] == "sample":
File [~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/util.py:437](http://localhost:8880/lab/tree/work/Stereo_seq_AGA/st_AGA3/~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/util.py#line=436), in check_site_shape(site, max_plate_nesting)
433 for actual_size, expected_size in zip_longest(
434 reversed(actual_shape), reversed(expected_shape), fillvalue=1
435 ):
436 if expected_size != -1 and expected_size != actual_size:
--> 437 raise ValueError(
438 "\n ".join(
439 [
440 'at site "{}", invalid log_prob shape'.format(site["name"]),
441 "Expected {}, actual {}".format(expected_shape, actual_shape),
442 "Try one of the following fixes:",
443 "- enclose the batched tensor in a with pyro.plate(...): context",
444 "- .to_event(...) the distribution being sampled",
445 "- .permute() data dimensions",
446 ]
447 )
448 )
450 # Check parallel dimensions on the left of max_plate_nesting.
451 enum_dim = site["infer"].get("_enumerate_dim")
ValueError: at site "data_target", invalid log_prob shape
Expected [32, -1], actual [32, 32, 10237]
Try one of the following fixes:
- enclose the batched tensor in a with pyro.plate(...): context
- .to_event(...) the distribution being sampled
- .permute() data dimensions