
Batch export for posteriors #2

@sezginerr

Hello,

I have been trying to export posteriors batch by batch with a defined batch size, rather than exporting them in a single batch. In the scvi library, this is done by concatenating the local variables exported from each batch along the cell dimension. I believe the scvi library assumes that the variables defined with 'obs_plate' are the local variables, while the global variables (all other variables) are sampled from the last batch. In the cell2fate model, however, local and global variables seem to be treated the same way, because the guide in the cell2fate model is an instance of 'poutine.messenger.Messenger'. In that case, posterior sampling does not use the 'return_sites' information and appears to return every site. Here is the function in the scvi library that returns one posterior sample:

```python
from typing import Optional

from pyro import poutine


def _get_one_posterior_sample(
    self,
    args,
    kwargs,
    return_sites: Optional[list] = None,
    return_observed: bool = False,
):
    if isinstance(self.module.guide, poutine.messenger.Messenger):
        # This already includes trace-replay behavior.
        sample = self.module.guide(*args, **kwargs)
    else:
        guide_trace = poutine.trace(self.module.guide).get_trace(*args, **kwargs)
        model_trace = poutine.trace(
            poutine.replay(self.module.model, guide_trace)
        ).get_trace(*args, **kwargs)
        sample = {
            name: site["value"]
            for name, site in model_trace.nodes.items()
            if (
                (site["type"] == "sample")  # sample statement
                and (
                    (return_sites is None) or (name in return_sites)
                )  # selected in return_sites list
                and (
                    (
                        (not site.get("is_observed", True)) or return_observed
                    )  # don't save observed unless requested
                    or (site.get("infer", False).get("_deterministic", False))
                )  # unless it is deterministic
                and not isinstance(
                    site.get("fn", None), poutine.subsample_messenger._Subsample
                )  # don't save plates
            )
        }
```
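The minibatch behaviour described above (local sites concatenated along the cell axis, global sites taken from the last batch, and a `return_sites` filter that a Messenger-style guide skips) can be sketched without Pyro. This is a toy NumPy illustration, not scvi-tools code: `combine_batches`, `select_sites` and the site name `beta_g` are hypothetical, and it assumes every sample array has shape `(num_samples, ...)` with cells on axis 1.

```python
import numpy as np


def select_sites(sample, return_sites=None):
    """Keep only the requested sites, mirroring the return_sites condition.

    Hypothetical helper: it could be applied after a Messenger-style guide
    has returned every site.
    """
    if return_sites is None:
        return dict(sample)
    return {name: value for name, value in sample.items() if name in return_sites}


def combine_batches(batch_samples, local_sites, cell_axis=1):
    """Merge per-minibatch posterior samples (hypothetical helper).

    Local (obs_plate) sites are concatenated along the cell axis;
    global sites are taken from the last batch only.
    """
    merged = {}
    for name in batch_samples[0]:
        if name in local_sites:
            merged[name] = np.concatenate(
                [batch[name] for batch in batch_samples], axis=cell_axis
            )
        else:
            merged[name] = batch_samples[-1][name]
    return merged


# Toy data: 3 posterior samples, batches of 4 and 2 cells, 5 genes.
rng = np.random.default_rng(0)
batches = [
    {"t_c": rng.normal(size=(3, n_cells, 1)), "beta_g": rng.normal(size=(3, 5))}
    for n_cells in (4, 2)
]
merged = combine_batches(batches, local_sites={"t_c"})
print(merged["t_c"].shape)  # (3, 6, 1)
print(select_sites(merged, ["t_c"]).keys())  # dict_keys(['t_c'])
```

With a guide that returns every site, every site would reach the merging step. Without an explicit local/global split, a per-cell site could then be treated as global and silently truncated to the last batch.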

To adapt batch sampling to the cell2fate model, I made a few changes to the function. First, I removed the global sampling step, which would otherwise sample everything again in the last batch. I'm not sure this is correct, because the RF/GO analysis changes slightly when I use single-batch sampling with versus without global variable sampling. Second, I listed the variables that have a cell dimension and concatenated them along that dimension after sampling. For all other variables, I applied element-wise averaging across batches, weighted by the number of cells in each batch. Here is the adapted function that I use:

```python
from typing import Optional

import numpy as np
from scvi import settings
from scvi.dataloaders import AnnDataLoader
from scvi.model._utils import parse_use_gpu_arg
from scvi.utils import track


def _posterior_samples_minibatch(
    self, use_gpu: bool = None, batch_size: Optional[int] = None, **sample_kwargs
):
    samples = dict()

    _, device = parse_use_gpu_arg(use_gpu)

    batch_size = batch_size if batch_size is not None else settings.batch_size

    train_dl = AnnDataLoader(
        self.adata_manager, shuffle=False, batch_size=batch_size
    )
    # variables with a cell dimension (axis 1), concatenated across batches
    cell_specific = ["t_c", "T_c", "mu_expression", "detection_y_c", "mu", "data_target"]
    # sample local parameters
    i = 0  # number of batches processed so far
    for tensor_dict in track(
        train_dl,
        style="tqdm",
        description="Sampling local variables, batch: ",
    ):
        args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
        args = [a.to(device) for a in args]
        kwargs = {k: v.to(device) for k, v in kwargs.items()}
        self.to_device(device)
        if i == 0:
            # sample_kwargs is a dict, so read it with .get() rather than getattr()
            return_observed = sample_kwargs.get("return_observed", False)
            obs_plate_sites = self._get_obs_plate_sites(
                args, kwargs, return_observed=return_observed
            )
            if len(obs_plate_sites) == 0:
                # if no local variables - don't sample
                break
            sample_kwargs_obs_plate = sample_kwargs.copy()
            sample_kwargs_obs_plate[
                "return_sites"
            ] = self._get_obs_plate_return_sites(
                sample_kwargs["return_sites"], list(obs_plate_sites.keys())
            )
            sample_kwargs_obs_plate["show_progress"] = False

            samples = self._get_posterior_samples(
                args, kwargs, **sample_kwargs_obs_plate
            )
        else:
            samples_ = self._get_posterior_samples(
                args, kwargs, **sample_kwargs_obs_plate
            )
            num_cells_in_batch = samples_["t_c"].shape[1]
            # the last batch may hold fewer than batch_size cells
            ratio_cells = num_cells_in_batch / batch_size
            for k in samples.keys():
                if k in cell_specific:
                    # cell-specific: concatenate along the cell dimension
                    samples[k] = np.concatenate([samples[k], samples_[k]], axis=1)
                else:
                    # all other variables: running element-wise mean over batches
                    samples[k] = (samples[k] * i + samples_[k] * ratio_cells) / (
                        i + ratio_cells
                    )
        # increment exactly once per batch so the running mean weights are correct
        i += 1

    self.module.to(device)

    return samples
```

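The running-mean update in the loop, `(samples[k] * i + samples_[k] * ratio_cells) / (i + ratio_cells)`, should equal a per-cell weighted average of the batch values, provided `i` counts the batches already merged and only the last batch is short (true with `shuffle=False`). A small NumPy check of that claim, using made-up per-batch values for one hypothetical non-cell-specific site:

```python
import numpy as np


def running_mean_update(mean, new, i, ratio_cells):
    """One step of the update used in the loop:
    (mean * i + new * ratio_cells) / (i + ratio_cells)."""
    return (mean * i + new * ratio_cells) / (i + ratio_cells)


# Hypothetical setup: 1000 cells with batch_size=400 -> batches of 400, 400, 200.
batch_size = 400
cells_per_batch = [400, 400, 200]
batch_values = [np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([6.0, 9.0])]

mean = batch_values[0]
for i, (new, n_cells) in enumerate(
    zip(batch_values[1:], cells_per_batch[1:]), start=1
):
    mean = running_mean_update(mean, new, i, n_cells / batch_size)

# Reference: weight each batch by its number of cells.
weighted = np.average(np.stack(batch_values), axis=0, weights=cells_per_batch)
print(mean, weighted)
```

If `i` were incremented twice per batch, earlier batches would be over-weighted and the two results would no longer agree.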

This approach seems to work well: the results with and without batch export are quite similar. However, there is again a slight difference between the RF/GO analyses. I believe this difference arises because, even though the seed is set to 0, changing the sampling strategy changes the order in which random numbers are drawn, and therefore the sampled values. I also tried sampling the variables that are not cell-specific from the last batch only, but that does not seem to work well, even though in theory it should, since those variables are sampled from distributions that have already been learned.
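The seed explanation can be illustrated with a toy NumPy model. This is an assumption-laden sketch, not cell2fate's actual sampler (Pyro draws from torch's RNG, and the site sizes here are made up). With the same seed, a single pass draws the global site once and then all local values. A minibatch pass instead redraws the global site in every batch, so the local values after the first batch come from a different part of the random stream:

```python
import numpy as np

n_cells, batch_size, n_global = 6, 3, 2

# Single pass: draw the global site once, then all local (per-cell) values.
rng = np.random.default_rng(0)
global_single = rng.normal(size=n_global)
local_single = rng.normal(size=n_cells)

# Minibatch pass with the same seed: each batch draws the global site again
# before its local values, so the random stream is consumed in a different order.
rng = np.random.default_rng(0)
global_draws, local_draws = [], []
for start in range(0, n_cells, batch_size):
    global_draws.append(rng.normal(size=n_global))
    local_draws.append(rng.normal(size=min(batch_size, n_cells - start)))
local_batched = np.concatenate(local_draws)

print(np.allclose(local_single[:batch_size], local_batched[:batch_size]))  # True
print(np.allclose(local_single[batch_size:], local_batched[batch_size:]))  # False
```

Both passes are valid draws from the same distributions, so small downstream differences are expected even with a fixed seed.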

I wanted to ask whether my approach is in line with the cell2fate model's assumptions. Thank you very much for your help.

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests