Description
Hello,
I have been trying to export posteriors batch by batch with a defined batch size, rather than exporting them in a single batch. In the scvi library, this is done by concatenating the local variables exported from each batch along the cell dimension. I believe the scvi library assumes that the variables defined inside `obs_plate` are the local variables, while the global variables (all other variables) are sampled from the last batch only. For the cell2fate model, however, local and global variables seem to be treated the same, because the guide in the cell2fate model is an instance of `poutine.messenger.Messenger`. In that case, posterior sampling ignores the `return_sites` information and appears to return everything. Here is the function that returns one posterior sample in the scvi library:
```python
def _get_one_posterior_sample(
    self,
    args,
    kwargs,
    return_sites: Optional[list] = None,
    return_observed: bool = False,
):
    if isinstance(self.module.guide, poutine.messenger.Messenger):
        # This already includes trace-replay behavior.
        sample = self.module.guide(*args, **kwargs)
    else:
        guide_trace = poutine.trace(self.module.guide).get_trace(*args, **kwargs)
        model_trace = poutine.trace(
            poutine.replay(self.module.model, guide_trace)
        ).get_trace(*args, **kwargs)
        sample = {
            name: site["value"]
            for name, site in model_trace.nodes.items()
            if (
                (site["type"] == "sample")  # sample statement
                and (
                    (return_sites is None) or (name in return_sites)
                )  # selected in return_sites list
                and (
                    (
                        (not site.get("is_observed", True)) or return_observed
                    )  # don't save observed unless requested
                    or (site.get("infer", False).get("_deterministic", False))
                )  # unless it is deterministic
                and not isinstance(
                    site.get("fn", None), poutine.subsample_messenger._Subsample
                )  # don't save plates
            )
        }
    return sample
```

To adapt batch sampling for the cell2fate model, I made a few changes to the minibatch sampling function. First, I removed the global-variable sampling step, which would sample everything again from the last batch. I am not sure whether this is correct, as the RF/GO analysis changes slightly when I use single-batch sampling with or without global-variable sampling. Second, I listed the variables that have a cell dimension and concatenated them along that dimension after sampling; for the remaining variables, I applied an element-wise weighted average across batches. Here is the adapted function that I use:
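Since the Messenger branch above bypasses the `return_sites` filter entirely, one workaround is to apply the same filtering to the returned dictionary after the fact. A minimal sketch of that idea (the helper name `filter_sample_sites` is mine, not part of scvi):

```python
from typing import Optional


def filter_sample_sites(sample: dict, return_sites: Optional[list] = None) -> dict:
    """Keep only the requested sites from a guide's full sample dict.

    Mirrors the `return_sites` filtering that the non-Messenger branch
    applies via the model trace, applied here after sampling.
    """
    if return_sites is None:
        return sample  # no filtering requested: return everything
    return {name: value for name, value in sample.items() if name in return_sites}


# usage: restrict a full sample dict to two sites of interest
full_sample = {"t_c": [0.1], "detection_y_c": [0.9], "s_g_gene_add": [1.2]}
subset = filter_sample_sites(full_sample, return_sites=["t_c", "detection_y_c"])
```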
```python
def _posterior_samples_minibatch(
    self, use_gpu: bool = None, batch_size: Optional[int] = None, **sample_kwargs
):
    samples = dict()
    _, device = parse_use_gpu_arg(use_gpu)
    batch_size = batch_size if batch_size is not None else settings.batch_size
    train_dl = AnnDataLoader(self.adata_manager, shuffle=False, batch_size=batch_size)
    # sample local parameters
    i = 0
    # variables with a cell dimension: concatenated along axis 1 (the obs plate)
    cell_specific = ["t_c", "T_c", "mu_expression", "detection_y_c", "mu", "data_target"]
    for tensor_dict in track(
        train_dl,
        style="tqdm",
        description="Sampling local variables, batch: ",
    ):
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        self.to_device(device)
        if i == 0:
            # sample_kwargs is a dict, so use .get rather than getattr
            return_observed = sample_kwargs.get("return_observed", False)
            obs_plate_sites = self._get_obs_plate_sites(
                args, kwargs, return_observed=return_observed
            )
            if len(obs_plate_sites) == 0:
                # if there are no local variables, don't sample
                break
            obs_plate_dim = list(obs_plate_sites.values())[0]
            sample_kwargs_obs_plate = sample_kwargs.copy()
            sample_kwargs_obs_plate["return_sites"] = self._get_obs_plate_return_sites(
                sample_kwargs["return_sites"], list(obs_plate_sites.keys())
            )
            sample_kwargs_obs_plate["show_progress"] = False
            samples = self._get_posterior_samples(args, kwargs, **sample_kwargs_obs_plate)
        else:
            samples_ = self._get_posterior_samples(args, kwargs, **sample_kwargs_obs_plate)
            # the last batch can be smaller than batch_size, so weight it accordingly
            num_cells_in_batch = samples_["t_c"].shape[1]
            for k in samples.keys():
                if samples_[k].ndim > 1:
                    if k in cell_specific:
                        # concatenate cell-specific variables along the cell dimension
                        samples[k] = np.concatenate([samples[k], samples_[k]], axis=1)
                    else:
                        # batch-size-weighted running average for all other variables
                        ratio_cells = num_cells_in_batch / batch_size
                        samples[k] = (samples[k] * i + samples_[k] * ratio_cells) / (
                            i + ratio_cells
                        )
        i += 1
    self.module.to(device)
    return samples
```
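The merge rule I use (concatenate cell-dimension sites along axis 1, batch-size-weighted running average for the rest) can be checked in isolation. A minimal sketch with toy shapes; the helper name `merge_batch_samples` and the site names are only illustrative:

```python
import numpy as np


def merge_batch_samples(samples, samples_, cell_specific, i, num_cells_in_batch, batch_size):
    """Merge one batch of posterior samples into the running result.

    Cell-specific sites are concatenated along axis 1 (the cell dimension);
    all other multi-dimensional sites get a batch-size-weighted running average.
    """
    for k in samples:
        if samples_[k].ndim > 1:
            if k in cell_specific:
                samples[k] = np.concatenate([samples[k], samples_[k]], axis=1)
            else:
                ratio_cells = num_cells_in_batch / batch_size
                samples[k] = (samples[k] * i + samples_[k] * ratio_cells) / (i + ratio_cells)
    return samples


# toy example: 3 posterior draws; first batch has 4 cells, second has 2, batch_size=4
samples = {"t_c": np.zeros((3, 4)), "detection_mean": np.full((3, 1), 2.0)}
batch2 = {"t_c": np.ones((3, 2)), "detection_mean": np.full((3, 1), 4.0)}
merged = merge_batch_samples(samples, batch2, ["t_c"], i=1, num_cells_in_batch=2, batch_size=4)
# t_c grows to (3, 6); detection_mean becomes (2.0 * 1 + 4.0 * 0.5) / 1.5
```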
This approach seems to work well: the results with and without batch export are quite similar. However, there is again a slight difference in the RF/GO analyses. I believe this difference arises because, even though the seed is set to 0, changing the sampling strategy alters the random number stream. I also tried sampling the non-cell-specific variables from only the last batch, but this does not work well in practice, even though it should in theory, since those variables are sampled from already-learned distributions.
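The sensitivity to the sampling strategy can be reproduced with a plain NumPy generator: with the same seed, interleaving extra draws per batch shifts the stream, so later local draws no longer match the single-batch run. A minimal illustration (not cell2fate code, just a demonstration of the seeding behavior):

```python
import numpy as np

# single-batch run: 10 "local" draws, then 3 "global" draws
rng = np.random.default_rng(0)
local_single = rng.normal(size=10)
global_single = rng.normal(size=3)

# two-batch run with the same seed: global draws interleave with local ones
rng = np.random.default_rng(0)
local_b1 = rng.normal(size=5)
global_b1 = rng.normal(size=3)  # these draws advance the stream
local_b2 = rng.normal(size=5)

# the first batch matches the single-batch run,
# but everything after the interleaved global draws differs
```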
I wanted to ask whether my approach is in line with the cell2fate assumptions. Thank you very much for your help.