
model load: Missing key(s) in state_dict #430

@aowenson-imm

Description


RegressionModel.load fails to load a model that I have just saved.


Minimal code sample (that we can run without your data, using public data)

import scanpy as sc
import cell2location
from cell2location.models import RegressionModel

adata = sc.datasets.visium_sge(sample_id="V1_Human_Lymph_Node")
adata.obs['sample'] = list(adata.uns['spatial'].keys())[0]
adata.var['SYMBOL'] = adata.var_names
adata.var.set_index('gene_ids', drop=True, inplace=True)
# find mitochondria-encoded (MT) genes
adata.var['MT_gene'] = [gene.startswith('MT-') for gene in adata.var['SYMBOL']]
# remove MT genes for spatial mapping (keeping their counts in the object)
adata.obsm['MT'] = adata[:, adata.var['MT_gene'].values].X.toarray()
adata = adata[:, ~adata.var['MT_gene'].values]

adata_ref = sc.read(
    './data/sc.h5ad',
    backup_url='https://cell2location.cog.sanger.ac.uk/paper/integrated_lymphoid_organ_scrna/RegressionNBV4Torch_57covariates_73260cells_10237genes/sc.h5ad'
)
adata_ref.var['SYMBOL'] = adata_ref.var.index
# rename 'GeneID-2' as necessary for your data
adata_ref.var.set_index('GeneID-2', drop=True, inplace=True)
# delete unnecessary raw slot (to be removed in a future version of the tutorial)
del adata_ref.raw
# filter the object
selected = cell2location.utils.filtering.filter_genes(adata_ref, cell_count_cutoff=5, cell_percentage_cutoff2=0.03, nonz_mean_cutoff=1.12)
adata_ref = adata_ref[:, selected].copy()

cell2location.models.RegressionModel.setup_anndata(
    adata=adata_ref, 
    batch_key='Sample',
    labels_key='Subset',
    categorical_covariate_keys=['Method'])
mod = RegressionModel(adata_ref)

# save
mod.save('mod-save-with-anndata', save_anndata=True, overwrite=True)

# load
mod2 = RegressionModel.load('mod-save-with-anndata')
INFO     File mod-save-with-anndata/model.pt already downloaded                                                    
/package/python-cbrg/current/3.11.14/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /package/python-cbrg/current/3.11.3/lib/python3.11/s ...
/package/python-cbrg/current/3.11.14/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /package/python-cbrg/current/3.11.3/lib/python3.11/s ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
/package/python-cbrg/current/3.11.14/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /package/python-cbrg/current/3.11.3/lib/python3.11/s ...
/package/python-cbrg/current/3.11.14/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:68: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
/package/python-cbrg/current/3.11.14/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.
Epoch 1/273:   0%  1/273 [00:02<09:10,  2.02s/it, v_num=1]
`Trainer.fit` stopped: `max_steps=1` reached.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[19], line 3
      1 # load
----> 2 mod2 = RegressionModel.load('mod-save-with-anndata')

File /package/python-cbrg/current/3.11.14/lib/python3.11/site-packages/scvi/model/base/_base_model.py:873, in BaseModelClass.load(cls, dir_path, adata, accelerator, device, prefix, backup_url, datamodule, allowed_classes_names_list)
    871 else:
    872     model.module.on_load(model, pyro_param_store=pyro_param_store)
--> 873 model.module.load_state_dict(model_state_dict)
    875 model.to_device(device)
    877 model.module.eval()

File /package/python-cbrg/current/3.11.13/lib/python3.11/site-packages/torch/nn/modules/module.py:2624, in Module.load_state_dict(self, state_dict, strict, assign)
   2616         error_msgs.insert(
   2617             0,
   2618             "Missing key(s) in state_dict: {}. ".format(
   2619                 ", ".join(f'"{k}"' for k in missing_keys)
   2620             ),
   2621         )
   2623 if len(error_msgs) > 0:
-> 2624     raise RuntimeError(
   2625         "Error(s) in loading state_dict for {}:\n\t{}".format(
   2626             self.__class__.__name__, "\n\t".join(error_msgs)
   2627         )
   2628     )
   2629 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for RegressionBaseModule:
	Missing key(s) in state_dict: "_guide.locs.per_cluster_mu_fg_unconstrained", "_guide.locs.detection_tech_gene_tg_unconstrained", "_guide.locs.detection_mean_y_e_unconstrained", "_guide.locs.s_g_gene_add_alpha_hyp_unconstrained", "_guide.locs.s_g_gene_add_mean_unconstrained", "_guide.locs.s_g_gene_add_alpha_e_inv_unconstrained", "_guide.locs.s_g_gene_add_unconstrained", "_guide.locs.alpha_g_phi_hyp_unconstrained", "_guide.locs.alpha_g_inverse_unconstrained", "_guide.scales.per_cluster_mu_fg_unconstrained", "_guide.scales.detection_tech_gene_tg_unconstrained", "_guide.scales.detection_mean_y_e_unconstrained", "_guide.scales.s_g_gene_add_alpha_hyp_unconstrained", "_guide.scales.s_g_gene_add_mean_unconstrained", "_guide.scales.s_g_gene_add_alpha_e_inv_unconstrained", "_guide.scales.s_g_gene_add_unconstrained", "_guide.scales.alpha_g_phi_hyp_unconstrained", "_guide.scales.alpha_g_inverse_unconstrained".
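To check whether the checkpoint is missing the whole guide rather than a few parameters, a small helper can parse the error message and group the keys by guide sample site. In the error above, all nine sites appear under both `_guide.locs` and `_guide.scales`. A hedged guess, not confirmed: the minimal sample saves the model without calling `mod.train()`, so the guide parameters may never have been created before `save`, while `load` runs one initialisation step (the `max_steps=1` run in the log) that does create them. The helper names below are hypothetical, and `saved_guide_keys` assumes that scvi-tools stores the module weights under a `model_state_dict` entry in `model.pt`.

```python
import re
from collections import defaultdict


def missing_keys_from_error(message):
    """Extract the quoted key names from a torch 'Missing key(s)' error message."""
    # Keep only the text after the 'Missing key(s)' marker ...
    _, _, tail = message.partition("Missing key(s) in state_dict:")
    # ... and before any 'Unexpected key(s)' section that torch may append.
    tail = tail.split("Unexpected key(s)", 1)[0]
    # Each key is wrapped in double quotes.
    return re.findall(r'"([^"]+)"', tail)


def group_guide_keys(keys):
    """Map each guide sample site to the parameter kinds (locs/scales) that are missing."""
    sites = defaultdict(set)
    for key in keys:
        # e.g. ['_guide', 'locs', 'per_cluster_mu_fg_unconstrained']
        parts = key.split(".", 2)
        if len(parts) == 3 and parts[0] == "_guide":
            sites[parts[2]].add(parts[1])
    return dict(sites)


def saved_guide_keys(checkpoint_path):
    """List the '_guide.' entries stored in a saved model.pt (requires torch)."""
    import torch  # imported lazily so the helpers above work without torch

    # weights_only=False because model.pt also stores non-tensor objects.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    # Assumption: module weights live under 'model_state_dict'; else use ckpt itself.
    state = ckpt.get("model_state_dict", ckpt)
    return sorted(k for k in state if k.startswith("_guide."))
```

For example, if `saved_guide_keys('mod-save-with-anndata/model.pt')` returns an empty list, that would support the hypothesis. Re-running the sample with `mod.train(max_epochs=...)` before `mod.save(...)` is one way to test it.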

Versions:

cell2location 0.1.5
lightning 2.5.6
pyro-ppl 1.8.4
scvi-tools 1.4.0
torch 2.8.0+cu129

Related question

Could save auto-detect whether the model needs to be saved with save_anndata=True? Naively, I would expect the following code to work, but load raises a different error:

mod.save('mod-save')
mod2 = RegressionModel.load('mod-save')
...
ValueError: Save path contains no saved anndata and no adata was passed.
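Until save/load handle this automatically, a thin wrapper can check whether the save directory contains an AnnData file and, if it does not, require the caller to pass one. The traceback above shows that `BaseModelClass.load` accepts an `adata` argument, so `RegressionModel.load('mod-save', adata=adata_ref)` is likely the direct workaround. `has_saved_anndata` and `load_model` are hypothetical helpers, and the `adata.h5ad` file name is an assumption about how scvi-tools lays out the save directory.

```python
import os


def has_saved_anndata(dir_path, prefix=""):
    """Return True if dir_path holds an AnnData file written by save(save_anndata=True)."""
    # Assumption: scvi-tools writes the AnnData as f"{prefix}adata.h5ad".
    return os.path.isfile(os.path.join(dir_path, f"{prefix}adata.h5ad"))


def load_model(model_cls, dir_path, adata=None, prefix=""):
    """Load a saved model, requiring adata only when the save directory has none."""
    if adata is None and not has_saved_anndata(dir_path, prefix):
        # Fail early with a hint instead of the generic ValueError from load.
        raise ValueError(
            f"{dir_path!r} was saved without save_anndata=True; "
            "pass the AnnData used to set up the model, e.g. adata=adata_ref."
        )
    # With adata=None, load falls back to the AnnData stored in dir_path.
    return model_cls.load(dir_path, adata=adata, prefix=prefix)
```

With the objects from the sample above, usage would be `mod2 = load_model(RegressionModel, 'mod-save', adata=adata_ref)`.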

Metadata


Assignees

No one assigned

    Labels

    bug (Something isn't working)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
