Open
Labels: bug (Something isn't working)
Description
RegressionModel.load fails on a save that I have just created.
Minimal code sample (that we can run without your data, using public data)
import scanpy as sc
import cell2location
from cell2location.models import RegressionModel
adata = sc.datasets.visium_sge(sample_id="V1_Human_Lymph_Node")
adata.obs['sample'] = list(adata.uns['spatial'].keys())[0]
adata.var['SYMBOL'] = adata.var_names
adata.var.set_index('gene_ids', drop=True, inplace=True)
# find mitochondria-encoded (MT) genes
adata.var['MT_gene'] = [gene.startswith('MT-') for gene in adata.var['SYMBOL']]
# remove MT genes for spatial mapping (keeping their counts in the object)
adata.obsm['MT'] = adata[:, adata.var['MT_gene'].values].X.toarray()
adata = adata[:, ~adata.var['MT_gene'].values]
adata_ref = sc.read(
    './data/sc.h5ad',
    backup_url='https://cell2location.cog.sanger.ac.uk/paper/integrated_lymphoid_organ_scrna/RegressionNBV4Torch_57covariates_73260cells_10237genes/sc.h5ad'
)
adata_ref.var['SYMBOL'] = adata_ref.var.index
# rename 'GeneID-2' as necessary for your data
adata_ref.var.set_index('GeneID-2', drop=True, inplace=True)
# delete unnecessary raw slot (to be removed in a future version of the tutorial)
del adata_ref.raw
# filter the object
selected = cell2location.utils.filtering.filter_genes(adata_ref, cell_count_cutoff=5, cell_percentage_cutoff2=0.03, nonz_mean_cutoff=1.12)
adata_ref = adata_ref[:, selected].copy()
cell2location.models.RegressionModel.setup_anndata(
    adata=adata_ref,
    batch_key='Sample',
    labels_key='Subset',
    categorical_covariate_keys=['Method'])
mod = RegressionModel(adata_ref)
# save
mod.save('mod-save-with-anndata', save_anndata=True, overwrite=True)
# load
mod2 = RegressionModel.load('mod-save-with-anndata')
INFO File mod-save-with-anndata/model.pt already downloaded
/package/python-cbrg/current/3.11.14/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /package/python-cbrg/current/3.11.3/lib/python3.11/s ...
/package/python-cbrg/current/3.11.14/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /package/python-cbrg/current/3.11.3/lib/python3.11/s ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
/package/python-cbrg/current/3.11.14/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /package/python-cbrg/current/3.11.3/lib/python3.11/s ...
/package/python-cbrg/current/3.11.14/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:68: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
/package/python-cbrg/current/3.11.14/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.
Epoch 1/273: 0%
1/273 [00:02<09:10, 2.02s/it, v_num=1]
`Trainer.fit` stopped: `max_steps=1` reached.
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[19], line 3
1 # load
----> 2 mod2 = RegressionModel.load('mod-save-with-anndata')
File /package/python-cbrg/current/3.11.14/lib/python3.11/site-packages/scvi/model/base/_base_model.py:873, in BaseModelClass.load(cls, dir_path, adata, accelerator, device, prefix, backup_url, datamodule, allowed_classes_names_list)
871 else:
872 model.module.on_load(model, pyro_param_store=pyro_param_store)
--> 873 model.module.load_state_dict(model_state_dict)
875 model.to_device(device)
877 model.module.eval()
File /package/python-cbrg/current/3.11.13/lib/python3.11/site-packages/torch/nn/modules/module.py:2624, in Module.load_state_dict(self, state_dict, strict, assign)
2616 error_msgs.insert(
2617 0,
2618 "Missing key(s) in state_dict: {}. ".format(
2619 ", ".join(f'"{k}"' for k in missing_keys)
2620 ),
2621 )
2623 if len(error_msgs) > 0:
-> 2624 raise RuntimeError(
2625 "Error(s) in loading state_dict for {}:\n\t{}".format(
2626 self.__class__.__name__, "\n\t".join(error_msgs)
2627 )
2628 )
2629 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for RegressionBaseModule:
Missing key(s) in state_dict: "_guide.locs.per_cluster_mu_fg_unconstrained", "_guide.locs.detection_tech_gene_tg_unconstrained", "_guide.locs.detection_mean_y_e_unconstrained", "_guide.locs.s_g_gene_add_alpha_hyp_unconstrained", "_guide.locs.s_g_gene_add_mean_unconstrained", "_guide.locs.s_g_gene_add_alpha_e_inv_unconstrained", "_guide.locs.s_g_gene_add_unconstrained", "_guide.locs.alpha_g_phi_hyp_unconstrained", "_guide.locs.alpha_g_inverse_unconstrained", "_guide.scales.per_cluster_mu_fg_unconstrained", "_guide.scales.detection_tech_gene_tg_unconstrained", "_guide.scales.detection_mean_y_e_unconstrained", "_guide.scales.s_g_gene_add_alpha_hyp_unconstrained", "_guide.scales.s_g_gene_add_mean_unconstrained", "_guide.scales.s_g_gene_add_alpha_e_inv_unconstrained", "_guide.scales.s_g_gene_add_unconstrained", "_guide.scales.alpha_g_phi_hyp_unconstrained", "_guide.scales.alpha_g_inverse_unconstrained".
Versions:
cell2location 0.1.5
lightning 2.5.6
pyro-ppl 1.8.4
scvi-tools 1.4.0
torch 2.8.0+cu129
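For triage, it may help to see which submodule the missing keys belong to. The snippet below is a minimal sketch in plain Python that groups the quoted keys from the error message by their prefix. The helper name group_missing_keys is hypothetical, and the message is an excerpt containing four of the 18 keys reported above.

```python
import re
from collections import defaultdict

# Excerpt of the RuntimeError message from the traceback (4 of the 18 keys).
message = (
    'Missing key(s) in state_dict: '
    '"_guide.locs.per_cluster_mu_fg_unconstrained", '
    '"_guide.locs.alpha_g_inverse_unconstrained", '
    '"_guide.scales.per_cluster_mu_fg_unconstrained", '
    '"_guide.scales.alpha_g_inverse_unconstrained".'
)


def group_missing_keys(msg):
    """Group quoted state_dict keys by everything before their last dot."""
    groups = defaultdict(list)
    for key in re.findall(r'"([^"]+)"', msg):
        prefix, _, name = key.rpartition(".")
        groups[prefix].append(name)
    return dict(groups)


groups = group_missing_keys(message)
for prefix, names in groups.items():
    print(prefix, len(names))
```

In the full message, every missing key falls under _guide.locs or _guide.scales, so only the guide parameters appear to be absent from the saved state.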
Related other question
Can load work without extra arguments when the model was saved without save_anndata=True, or can save detect that the AnnData needs to be stored? Naively, I would expect the code below to work without problems, but load raises a different error:
mod.save('mod-save')
mod.load('mod-save')
...
ValueError: Save path contains no saved anndata and no adata was passed.
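As a possible workaround for the ValueError, the AnnData can be passed to load explicitly (the load signature in the traceback above accepts an adata argument). The helper below is only a sketch: load_saved_model is a hypothetical name, and it assumes that a save with save_anndata=True writes the object to a file called adata.h5ad inside the save directory, which should be checked against the installed scvi-tools version.

```python
import os


def load_saved_model(model_cls, dir_path, adata=None):
    """Call model_cls.load, passing adata only when no saved copy is found.

    Assumption: a saved AnnData lives at <dir_path>/adata.h5ad.
    """
    saved_adata = os.path.join(dir_path, "adata.h5ad")
    if os.path.exists(saved_adata):
        # The save directory already contains the AnnData.
        return model_cls.load(dir_path)
    if adata is None:
        raise ValueError(
            f"{dir_path} contains no saved AnnData; pass adata explicitly."
        )
    # Fall back to the AnnData object supplied by the caller.
    return model_cls.load(dir_path, adata=adata)
```

With the objects from the sample above, this would be called as load_saved_model(RegressionModel, 'mod-save', adata=adata_ref). It only addresses the missing-AnnData error, not the state_dict error reported in the description.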