Skip to content

ValueError: at site "data_target", invalid log_prob shape #416

@nextgenius-ai

Description

@nextgenius-ai

version info:
PyTorch: 2.6.0+cu124
Pyro: 1.9.1
Cell2location: 0.1.4
CUDA available: True
CUDA version: 12.4
GPU name: Tesla V100-PCIE-32GB
Lightning: 2.5.3
Pyro: 1.9.1

ref_adata

AnnData object with n_obs × n_vars = 29851 × 26255
obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'percent.mito', 'sample', 'percent.ribo', 'RNA_snn_res.0.6', 'seurat_clusters', 'condition', 'sample2', 'cell_type', 'cell_type2', 'cell_type3', 'cell_type4'
var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable'
obsm: 'X_pca', 'X_tsne', 'X_umap'

ref_adata.X = ref_adata.X.astype("int32")

from cell2location.utils.filtering import filter_genes
selected = filter_genes(ref_adata, cell_count_cutoff=5, cell_percentage_cutoff2=0.03, nonz_mean_cutoff=1.12)
ref_adata = ref_adata[:, selected].copy()

cell2location.models.RegressionModel.setup_anndata(adata=ref_adata,
                        batch_key='sample',
                        labels_key='cell_type'
                       )

from cell2location.models import RegressionModel
mod = RegressionModel(ref_adata)
mod.view_anndata_setup()

mod.train(
    max_epochs=250,
    batch_size=128,
    accelerator="gpu"
)

# error:
......
File [~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/infer/enum.py:80](http://localhost:8880/lab/tree/work/Stereo_seq_AGA/~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/infer/enum.py#line=79), in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     78 for site in model_trace.nodes.values():
     79     if site["type"] == "sample":
---> 80         check_site_shape(site, max_plate_nesting)
     81 for site in guide_trace.nodes.values():
     82     if site["type"] == "sample":

File [~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/util.py:437](http://localhost:8880/lab/tree/work/Stereo_seq_AGA/~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/util.py#line=436), in check_site_shape(site, max_plate_nesting)
    433 for actual_size, expected_size in zip_longest(
    434     reversed(actual_shape), reversed(expected_shape), fillvalue=1
    435 ):
    436     if expected_size != -1 and expected_size != actual_size:
--> 437         raise ValueError(
    438             "\n  ".join(
    439                 [
    440                     'at site "{}", invalid log_prob shape'.format(site["name"]),
    441                     "Expected {}, actual {}".format(expected_shape, actual_shape),
    442                     "Try one of the following fixes:",
    443                     "- enclose the batched tensor in a with pyro.plate(...): context",
    444                     "- .to_event(...) the distribution being sampled",
    445                     "- .permute() data dimensions",
    446                 ]
    447             )
    448         )
    450 # Check parallel dimensions on the left of max_plate_nesting.
    451 enum_dim = site["infer"].get("_enumerate_dim")

ValueError: at site "data_target", invalid log_prob shape
  Expected [128, -1], actual [128, 128, 15125]
  Try one of the following fixes:
  - enclose the batched tensor in a with pyro.plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions

I'm sure I used the count matrix, and I have tried both ref_adata.X.astype("int32") and ref_adata.X.astype("int"), but the problem still exists. Moreover, after changing to accelerator="cpu", the problem still exists.

When I run the lymph node tutorial, the same error still occurs:

adata_ref = sc.read(
    f'./data/sc.h5ad',
    backup_url='https://cell2location.cog.sanger.ac.uk/paper/integrated_lymphoid_organ_scrna/RegressionNBV4Torch_57covariates_73260cells_10237genes/sc.h5ad'
)
adata_ref.var['SYMBOL'] = adata_ref.var.index
adata_ref.var.set_index('GeneID-2', drop=True, inplace=True)
del adata_ref.raw
from cell2location.utils.filtering import filter_genes
selected = filter_genes(adata_ref, cell_count_cutoff=5, cell_percentage_cutoff2=0.03, nonz_mean_cutoff=1.12)
adata_ref = adata_ref[:, selected].copy()
cell2location.models.RegressionModel.setup_anndata(adata=adata_ref,
                        # 10X reaction / sample / batch
                        batch_key='Sample',
                        # cell type, covariate used for constructing signatures
                        labels_key='Subset',
                        # multiplicative technical effects (platform, 3' vs 5', donor effect)
                        categorical_covariate_keys=['Method']
                       )
from cell2location.models import RegressionModel
mod = RegressionModel(adata_ref)
mod.view_anndata_setup()
mod.train(max_epochs=250, batch_size=32,accelerator='gpu')

......
File [~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/infer/trace_elbo.py:57](http://localhost:8880/lab/tree/work/Stereo_seq_AGA/st_AGA3/~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/infer/trace_elbo.py#line=56), in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
     52 def _get_trace(self, model, guide, args, kwargs):
     53     """
     54     Returns a single trace from the guide, and the model that is run
     55     against it.
     56     """
---> 57     model_trace, guide_trace = get_importance_trace(
     58         "flat", self.max_plate_nesting, model, guide, args, kwargs
     59     )
     60     if is_validation_enabled():
     61         check_if_enumerated(guide_trace)

File [~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/infer/enum.py:80](http://localhost:8880/lab/tree/work/Stereo_seq_AGA/st_AGA3/~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/infer/enum.py#line=79), in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     78 for site in model_trace.nodes.values():
     79     if site["type"] == "sample":
---> 80         check_site_shape(site, max_plate_nesting)
     81 for site in guide_trace.nodes.values():
     82     if site["type"] == "sample":

File [~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/util.py:437](http://localhost:8880/lab/tree/work/Stereo_seq_AGA/st_AGA3/~/.conda/envs/cell2loc_env/lib/python3.10/site-packages/pyro/util.py#line=436), in check_site_shape(site, max_plate_nesting)
    433 for actual_size, expected_size in zip_longest(
    434     reversed(actual_shape), reversed(expected_shape), fillvalue=1
    435 ):
    436     if expected_size != -1 and expected_size != actual_size:
--> 437         raise ValueError(
    438             "\n  ".join(
    439                 [
    440                     'at site "{}", invalid log_prob shape'.format(site["name"]),
    441                     "Expected {}, actual {}".format(expected_shape, actual_shape),
    442                     "Try one of the following fixes:",
    443                     "- enclose the batched tensor in a with pyro.plate(...): context",
    444                     "- .to_event(...) the distribution being sampled",
    445                     "- .permute() data dimensions",
    446                 ]
    447             )
    448         )
    450 # Check parallel dimensions on the left of max_plate_nesting.
    451 enum_dim = site["infer"].get("_enumerate_dim")

ValueError: at site "data_target", invalid log_prob shape
  Expected [32, -1], actual [32, 32, 10237]
  Try one of the following fixes:
  - enclose the batched tensor in a with pyro.plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions

Metadata

Metadata

Assignees

No one assigned

    Labels

    question — Further information is requested

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions