Skip to content

Commit

Permalink
Merge pull request #15 from mvinyard/demo
Browse files Browse the repository at this point in the history
Demo
  • Loading branch information
mvinyard authored Jun 22, 2023
2 parents fd9b3a6 + 85fcb13 commit 506766d
Show file tree
Hide file tree
Showing 12 changed files with 347 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Create [`PyTorch Datasets`](https://pytorch.org/tutorials/beginner/basics/data_t

## Installation

Install from PYPI (current version: **[`0.0.23`](https://pypi.org/project/torch-adata/)**):
Install from PYPI (current version: **[`0.0.24`](https://pypi.org/project/torch-adata/)**):
```BASH
pip install torch-adata
```
Expand Down
15 changes: 10 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
__module_name__ = "setup.py"
__doc__ = """PYPI package distribution setup."""
__author__ = ", ".join(["Michael E. Vinyard"])
__email__ = ", ".join(["[email protected]"])
__email__ = ", ".join(["[email protected]"])


# -- import packages: ----------------------------------------------------------
Expand All @@ -15,10 +15,10 @@
# -- run setup: ----------------------------------------------------------------
setuptools.setup(
name="torch-adata",
version="0.0.23",
version="0.0.24",
python_requires=">3.9.0",
author="Michael E. Vinyard",
author_email="mvinyard@g.harvard.edu",
author_email="mvinyard[email protected]",
url=None,
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
Expand All @@ -28,11 +28,16 @@
"anndata>=0.9.1",
"licorice_font>=0.0.3",
"lightning>=2.0.1",
"torch>=2.0",
"torch>=2.0",
# "numpy==1.23",
"scanpy==1.9.3",
"scikit-learn==1.2.2",
"webfiles",
"vinplots",
],
classifiers=[
"Development Status :: 2 - Pre-Alpha",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.9",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering :: Bio-Informatics",
],
Expand Down
4 changes: 2 additions & 2 deletions torch_adata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


# -- specify package version: --------------------------------------------------
__version__ = "0.0.23"
__version__ = "0.0.24"


# -- import modules: -----------------------------------------------------------
Expand All @@ -18,4 +18,4 @@
# -- import API-hidden core modules (for dev): ---------------------------------
from ._core import _core_ancilliary as _core
from . import _utils
from . import _demo
from . import _demo
5 changes: 5 additions & 0 deletions torch_adata/_demo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

from ._plot_umap import PlotUMAP
from ._pbmc_3k import PBMC3k
from ._larry import LARRY
from ._demo_data import DemoData
26 changes: 26 additions & 0 deletions torch_adata/_demo/_demo_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@


from .. import _utils as utils
from ._pbmc_3k import PBMC3k
from ._larry import LARRY


class DemoData(utils.ABCParse):
def __init__(self):
""""""

@property
def PBMC3k(self):
if not hasattr(self, "_pbmcs"):
self._pbmcs = PBMC3k()
else:
self._pbmcs.plot()
return self._pbmcs.adata

@property
def LARRY(self):
if not hasattr(self, "_larry"):
self._larry = LARRY()
else:
self._larry.plot()
return self._larry.adata
56 changes: 56 additions & 0 deletions torch_adata/_demo/_larry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@

import vinplots
import anndata
import os
import web_files

from .. import _utils as utils
from ._plot_umap import PlotUMAP

class LARRY(utils.ABCParse):
_HTTP_ADDRESS = "https://figshare.com/ndownloader/files/38171943"

def __init__(self, filename="adata.LARRY.h5ad", data_dir="data", silent=False):
self._config(locals())

def _config(self, kwargs):
""""""

self.__parse__(kwargs, public=[None])

self._INFO = utils.InfoMessage()

if not os.path.exists(self._data_dir):
os.mkdir(self._data_dir)

def _download(self):
msg = "Downloading {:.<10}...".format(self._filename)
self._INFO(msg, end=" ")
f = web_files.WebFile(
http_address=self._HTTP_ADDRESS,
local_path=self.local_path,
)
f.download()
print("Done.")

@property
def local_path(self):
return os.path.join(self._data_dir, self._filename)

@property
def adata(self):
if not hasattr(self, "_adata"):
if not os.path.exists(self.local_path):
self._download()
self._adata = anndata.read_h5ad(self.local_path)
self._adata.uns["cmap"] = vinplots.colors.LARRY_in_vitro
if not self._silent:
print(self._adata)
self.plot()
return self._adata

def plot(self):
umap_plot = PlotUMAP(
self._adata, title="LARRY hematopoiesis"
)
umap_plot(groupby="Cell type annotation")
39 changes: 39 additions & 0 deletions torch_adata/_demo/_pbmc_3k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

# -- import packages: --------------------------------------------------------
import scanpy as sc
import pandas as pd
import vinplots


from .. import _utils as utils
from ._plot_umap import PlotUMAP

class PBMC3k(utils.ABCParse):
def __init__(self, silent=False):
self.__parse__(locals(), public=[None])

def _configure(self):
_adata = sc.datasets.pbmc3k_processed()
cell_type_df = pd.DataFrame(_adata.obs["louvain"].unique()).reset_index()
cell_type_df.columns = ["cell_type_idx", "louvain"]
df_obs = _adata.obs.merge(cell_type_df, on="louvain", how="left")
df_obs.index = df_obs.index.astype(str)
_adata.obs = df_obs
_adata.uns["cmap"] = vinplots.colors.pbmc3k
return _adata

@property
def adata(self):
if not hasattr(self, "_adata"):
self._adata = self._configure()
if not self._silent:
print(self._adata)
self.plot()
return self._adata

def plot(self):
umap_plot = PlotUMAP(
self._adata, title="10x PBMCs (~3k cells)"
)
umap_plot(groupby="louvain")

65 changes: 65 additions & 0 deletions torch_adata/_demo/_plot_umap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@

from .. import _utils as utils

import vinplots


class PlotUMAP(utils.ABCParse):
def __init__(
self,
adata,
title="10x PBMCs (~3k cells)",
wspace=0.1,
rm_ticks=True,
spines_to_delete=["top", "bottom", "right", "left"],
labels=["UMAP-1", "UMAP-2"],
use_key="X_umap",
cmap_key="cmap",
alpha=0.65,
):
self.__parse__(locals(), public=["adata"])

def __config__(self):
qiuck_plot_params = utils.FunctionInspector(func=vinplots.quick_plot)

self.fig, self.axes = vinplots.quick_plot(
nplots=1, ncols=1, **qiuck_plot_params(self._PARAMS)
)

@property
def CMAP(self):
return self.adata.uns[self._cmap_key]

def _plot_grouped(self):
for en, (group, group_df) in enumerate(self.adata.obs.groupby(self.groupby)):
if group in self.plot_behind:
z = 0
else:
z = en + 1
xu = self.adata[group_df.index].obsm[self._use_key]
self.axes[0].scatter(
xu[:, 0],
xu[:, 1],
label=group,
color=self.CMAP[group],
alpha=self._alpha,
zorder = z,
)

def _format(self):
for ax in self.axes:
ax.set_xlabel(self._labels[0])
ax.set_ylabel(self._labels[1])
self.axes[0].legend(loc=(1, 0.25), edgecolor="w")
self.axes[0].set_title(self._title)

def __call__(
self,
groupby=None,
plot_behind = ["Undifferentiated"],
):
self.__update__(locals())

self.__config__()
self._plot_grouped()
self._format()
4 changes: 4 additions & 0 deletions torch_adata/_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

from ._abc_parse import ABCParse
from ._info_message import InfoMessage
from ._function_inspector import FunctionInspector
73 changes: 73 additions & 0 deletions torch_adata/_utils/_abc_parse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from abc import ABC
from typing import Dict, List, Any
NoneType = type(None)

class ABCParse(ABC):
_BUILT = False

def __init__(self, *args, **kwargs):
"""
we avoid defining things in __init__ because this subsequently
mandates the use of `super().__init__()`
"""
pass

def __build__(self) -> None:

self._PARAMS = {}
self._IGNORE = ["self", "__class__"]
self._stored_private = []
self._stored_public = []

self._BUILT = True

def __set__(
self, key: str, val: Any, public: List = [], private: List = []
) -> None:

self._PARAMS[key] = val
if (key in private) and (not key in public):
self._stored_private.append(key)
key = f"_{key}"
else:
self._stored_public.append(key)
setattr(self, key, val)

def __set_existing__(self, key: str, val: Any) -> None:

self._PARAMS[key] = val

if key in self._stored_private:
key = f"_{key}"
setattr(self, key, val)

@property
def _STORED(self) -> List:
return self._stored_private + self._stored_public

def __parse__(
self, kwargs: Dict, public: List = [], private: List = [], ignore: List = []
):

if not self._BUILT:
self.__build__()

self._IGNORE += ignore

if len(public) > 0:
private = list(kwargs.keys())

for key, val in kwargs.items():
if not key in self._IGNORE:
self.__set__(key, val, public, private)

def __update__(self, kwargs: dict, public: List = [], private: List = []) -> None:

if not self._BUILT:
self.__build__()

for key, val in kwargs.items():
if not isinstance(val, NoneType) and key in self._STORED:
self.__set_existing__(key, val)
elif not isinstance(val, NoneType) and not key in self._IGNORE:
self.__set__(key, val, public, private)
53 changes: 53 additions & 0 deletions torch_adata/_utils/_function_inspector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@


import inspect


class FunctionInspector:
_FUNC_KWARGS = {}

def __init__(self, func, ignore=[]):
"""
Parameters:
-----------
func
type: Any
ignore
type: list
Returns:
--------
None
"""
self._func = func
self._ignore = ignore

def _extract_func_params(self, func):
return list(inspect.signature(func).parameters.keys())

@property
def FUNC_PARAMS(self):
return self._extract_func_params(self._func)

def forward(self, key, val):
if (key in self.FUNC_PARAMS) and (not key in self._ignore):
self._FUNC_KWARGS[key] = val

def __call__(self, kwargs) -> dict:
"""
Parameters:
-----------
kwargs
type: dict
Returns:
--------
func_kwargs
type: dict
"""

for key, val in kwargs.items():
self.forward(key, val)

return self._FUNC_KWARGS
Loading

0 comments on commit 506766d

Please sign in to comment.