add shuffling time capability
mvinyard committed May 9, 2023
1 parent 0ca693b commit 8277c63
Showing 4 changed files with 25 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -9,7 +9,7 @@ Create [`PyTorch Datasets`](https://pytorch.org/tutorials/beginner/basics/data_t

## Installation

- Install from PYPI (current version: **[`0.0.22`](https://pypi.org/project/torch-adata/)**):
+ Install from PYPI (current version: **[`0.0.23`](https://pypi.org/project/torch-adata/)**):
```BASH
pip install torch-adata
```
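A minimal quick-start sketch (editorial, not part of this commit), assuming `AnnDataset` is exposed at the package top level and accepts an `AnnData` object via `adata=`, as it is called in `_lightning_anndata_module.py` below; everything else in the snippet is illustrative.
```python
# Illustrative sketch -- `AnnDataset(adata=...)` mirrors the call used in
# _lightning_anndata_module.py below; the top-level import and default
# arguments are assumptions.
import anndata
import numpy as np
import torch
import torch_adata

adata = anndata.AnnData(X=np.random.randn(100, 25))
dataset = torch_adata.AnnDataset(adata=adata)

# AnnDataset is returned as a torch.utils.data.Dataset in the module below,
# so it should plug directly into a DataLoader.
loader = torch.utils.data.DataLoader(dataset, batch_size=16)
```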
2 changes: 1 addition & 1 deletion setup.py
@@ -15,7 +15,7 @@
# -- run setup: ----------------------------------------------------------------
setuptools.setup(
name="torch-adata",
version="0.0.22",
version="0.0.23rc0",
python_requires=">3.9.0",
author="Michael E. Vinyard",
author_email="[email protected]",
2 changes: 1 addition & 1 deletion torch_adata/__init__.py
@@ -6,7 +6,7 @@


# -- specify package version: --------------------------------------------------
__version__ = "0.0.22"
__version__ = "0.0.23rc0"


# -- import modules: -----------------------------------------------------------
27 changes: 22 additions & 5 deletions torch_adata/_core/_lightning/_lightning_anndata_module.py
@@ -13,6 +13,7 @@
from lightning import LightningDataModule
from licorice_font import font_format
import pandas as pd
+ import numpy as np
import anndata
import torch

@@ -42,6 +43,7 @@ def __init__(
test_key="test",
predict_key="predict",
shuffle=True,
+ shuffle_time_labels = False,
silent=True,
**kwargs,
):
@@ -78,8 +80,7 @@ def configure_train_val_split(self):
)
self._data_keys = train_val_split.configure_validation(self.init_train_adata, force_reallocate=True)

- def subset_adata(self, key: str):
-
+ def subset_adata(self, key: str):
self.df = self.adata.obs.copy()
access_key = self.data_keys[key]
if not hasattr(self.df, access_key):
@@ -96,17 +97,33 @@ def to_dataset(self, key: str) -> torch.utils.data.Dataset:
adata = getattr(self, "{}_adata".format(key))
return AnnDataset(adata=adata, **self.AnnDatasetKWARGS)

def shuffle_time_labels(self):

df = self._adata.obs.copy()
non_t0 = df.loc[df['t'] != 0]['t']

shuffled_t = np.zeros(len(df))
shuffled_t[non_t0.index.astype(int)] = np.random.choice(non_t0.values, len(non_t0))
self._adata.obs["t"] = shuffled_t

def _return_loader(self, dataset_key):

if dataset_key in ["train", "val"]:
# could probably add a flag to make this optional if desired
# these happen every time the loader is called, which is useful when
# you want to shuffle the organization of labels.
# maybe a better way to do it....
self.configure_train_val_split()


if self.hparams["shuffle_time_labels"]:
self.shuffle_time_labels()

if dataset_key == "train":
shuffle=self.hparams["shuffle"]
else:
shuffle = False

shuffle = shuffle_labels = False


if dataset_key == "test":
if not hasattr(self, "n_test_cells"):
self.setup(stage="test")
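For reference, a standalone toy run of the logic inside the new `shuffle_time_labels` method: the time labels of all non-`t0` cells are redrawn with replacement via `np.random.choice`, while cells at `t = 0` keep a label of zero. The toy data below is illustrative; like the method itself, it relies on the `obs` index being castable to integer positions.
```python
import anndata
import numpy as np
import pandas as pd

# Toy AnnData: six cells, three at t = 0 and three at later time points.
adata = anndata.AnnData(
    X=np.random.randn(6, 4),
    obs=pd.DataFrame({"t": [0.0, 0.0, 0.0, 1.0, 2.0, 3.0]}),
)

# Same steps as shuffle_time_labels(), applied to `adata` directly.
df = adata.obs.copy()
non_t0 = df.loc[df["t"] != 0]["t"]

shuffled_t = np.zeros(len(df))  # every cell starts at t = 0
shuffled_t[non_t0.index.astype(int)] = np.random.choice(non_t0.values, len(non_t0))
adata.obs["t"] = shuffled_t     # only non-t0 cells receive redrawn labels

print(adata.obs["t"].tolist())  # e.g. [0.0, 0.0, 0.0, 3.0, 1.0, 3.0]
```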

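Finally, a hedged sketch of switching the new behavior on from user code. The class name `LightningAnnDataModule`, the `adata=` argument, and the h5ad path are assumptions inferred from the file name and the `__init__` fragment above; only `shuffle` and `shuffle_time_labels` appear in the diff itself.
```python
# Hypothetical usage -- the class name, `adata=` argument, and file path are
# assumptions; `shuffle` / `shuffle_time_labels` come from the diff above.
import anndata
import torch_adata

adata = anndata.read_h5ad("adata.h5ad")

data_module = torch_adata.LightningAnnDataModule(
    adata=adata,
    shuffle=True,              # shuffle cells in the train loader
    shuffle_time_labels=True,  # new in this commit: redraw non-t0 time labels
)

# Assuming train_dataloader() is wired to _return_loader("train"), each call
# re-runs configure_train_val_split() and then shuffle_time_labels() when the
# flag is set, so the non-zero `t` labels are randomized per loader build.
train_loader = data_module.train_dataloader()
```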