From 76da4d7b3e2937fecda44afb60e37cdb339ac784 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Sun, 23 Feb 2025 10:17:46 -0500 Subject: [PATCH 1/7] Add gitignore to avoid committing pycaches --- .gitignore | 174 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1800114 --- /dev/null +++ b/.gitignore @@ -0,0 +1,174 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
+# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc \ No newline at end of file From e79785e02a484e5f9ab379846b78b4b5f939bc37 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Sun, 23 Feb 2025 10:18:17 -0500 Subject: [PATCH 2/7] Remove circular dependency on model_prototype_contrastive.py --- sccello/src/utils/helpers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sccello/src/utils/helpers.py b/sccello/src/utils/helpers.py index 539f784..b7b5c13 100644 --- a/sccello/src/utils/helpers.py +++ b/sccello/src/utils/helpers.py @@ -11,7 +11,6 @@ EXC_DIR = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) from sccello.src.utils import logging_util -from sccello.src.model_prototype_contrastive import PrototypeContrastiveForMaskedLM def set_seed(seed): From 0c6e6fe1ec9caaf7f238ef2061b99c85a5eb341c Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Sun, 23 Feb 2025 10:20:47 -0500 Subject: [PATCH 3/7] Add quickstart loading model script --- README.md | 5 +++++ sccello/script/run_load_model.py | 8 ++++++++ 2 files changed, 13 insertions(+) create 
mode 100644 sccello/script/run_load_model.py diff --git a/README.md b/README.md index 4fcb44e..ef86686 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,11 @@ from sccello.src.model_prototype_contrastive import PrototypeContrastiveForMaske model = PrototypeContrastiveForMaskedLM.from_pretrained("katarinayuan/scCello-zeroshot", output_hidden_states=True) ``` +or run the following script to verify that the model loads properly: +``` +python ./sccello/script/run_load_model.py +``` + * for linear probing tasks (see details in sccello/script/run_cell_type_classification.py) ``` from sccello.src.model_prototype_contrastive import PrototypeContrastiveForSequenceClassification diff --git a/sccello/script/run_load_model.py b/sccello/script/run_load_model.py new file mode 100644 index 0000000..c6a2bce --- /dev/null +++ b/sccello/script/run_load_model.py @@ -0,0 +1,8 @@ +import os +import sys +EXC_DIR = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) +sys.path.append(EXC_DIR) + +from sccello.src.model_prototype_contrastive import PrototypeContrastiveForMaskedLM + +model = PrototypeContrastiveForMaskedLM.from_pretrained("katarinayuan/scCello-zeroshot", output_hidden_states=True) From 00ef32b75bb1272fabace2122cfa68c62d25a7fe Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Mon, 24 Feb 2025 22:50:27 -0500 Subject: [PATCH 4/7] Attempt to build a subset of the data --- sccello/script/run_cell_type_clustering.py | 14 +++++++++++--- sccello/src/data/dataset.py | 4 ++-- sccello/src/utils/data_loading.py | 11 +++++++++++ sccello/src/utils/evaluating.py | 2 +- 4 files changed, 25 insertions(+), 6 deletions(-) diff --git a/sccello/script/run_cell_type_clustering.py b/sccello/script/run_cell_type_clustering.py index 8eba8ea..2d1191c 100644 --- a/sccello/script/run_cell_type_clustering.py +++ b/sccello/script/run_cell_type_clustering.py @@ -36,6 +36,7 @@ def parse_args(): parser.add_argument("--model_source", type=str, default="model_prototype_contrastive") parser.add_argument("--indist", type=int, 
default=0) + parser.add_argument('--sample_outdist', action='store_true', help='Samples the out of distribution') parser.add_argument("--normalize", type=int, default=0) parser.add_argument("--pass_cell_cls", type=int, default=0) @@ -59,7 +60,7 @@ def parse_args(): return args def solve_clustering(args, all_datasets): - trainset, test_data1, test_data2, label_dict = all_datasets + _, test_data1, test_data2, _ = all_datasets args.output_dir = helpers.create_downstream_output_dir(args) @@ -119,10 +120,17 @@ def solve_clustering(args, all_datasets): names = CellTypeClassificationDataset.subsets["frac"] if args.indist: names = [names[0]] + + if args.sample_outdist: + names = [names[1]] for name in names: # every data is tested under the same seeded setting helpers.set_seed(args.seed) args.data_source = f"frac_{name}" - all_datasets = data_loading.get_fracdata(name, args.data_branch, args.indist, False) - solve_clustering(args, all_datasets) \ No newline at end of file + all_datasets = ( + data_loading.get_fracdata_sample(name) + if args.sample_outdist else + data_loading.get_fracdata(name, args.data_branch, args.indist, False) + ) + solve_clustering(args, all_datasets) diff --git a/sccello/src/data/dataset.py b/sccello/src/data/dataset.py index 4089dee..32b91d0 100644 --- a/sccello/src/data/dataset.py +++ b/sccello/src/data/dataset.py @@ -32,8 +32,8 @@ class CellTypeClassificationDataset(): @classmethod def create_dataset(cls, subset_name="celltype"): assert subset_name in cls.subsets["frac"] - valid_data = load_dataset(f"katarinayuan/scCello_ood_{subset_name}_data1")["train"] - test_data = load_dataset(f"katarinayuan/scCello_ood_{subset_name}_data2")["train"] + valid_data = load_dataset(f"katarinayuan/scCello_ood_{subset_name}_data1", split="train") + test_data = load_dataset(f"katarinayuan/scCello_ood_{subset_name}_data2", split="train") valid_data = valid_data.rename_column("cell_type", "label") test_data = test_data.rename_column("cell_type", "label") diff --git 
a/sccello/src/utils/data_loading.py b/sccello/src/utils/data_loading.py index a9a9709..229eebc 100644 --- a/sccello/src/utils/data_loading.py +++ b/sccello/src/utils/data_loading.py @@ -112,6 +112,17 @@ def get_prestored_data(data_file_name): else: raise NotImplementedError +def get_fracdata_sample(name, num_proc=12): + + from sccello.src.data.dataset import CellTypeClassificationDataset + data1, data2 = CellTypeClassificationDataset.create_dataset(name) + data1 = data1.rename_column("gene_token_ids", "input_ids") + data2 = data2.rename_column("gene_token_ids", "input_ids") + + data1, eval_label_type_idmap = helpers.process_label_type(data1, num_proc, "label") + data2, test_label_type_idmap = helpers.process_label_type(data2, num_proc, "label") + return None, data1, data2, None + def get_fracdata(name, data_branch, indist, batch_effect, num_proc=12): from sccello.src.data.dataset import CellTypeClassificationDataset diff --git a/sccello/src/utils/evaluating.py b/sccello/src/utils/evaluating.py index 533b864..69b34c2 100644 --- a/sccello/src/utils/evaluating.py +++ b/sccello/src/utils/evaluating.py @@ -8,7 +8,7 @@ import ipdb import torch -assert torch.cuda.is_available() +# assert torch.cuda.is_available() import cupy as cp from cuml.metrics.cluster import silhouette_score as cu_silhouette_score From ffed57356d493c143f5816ce419b00d928feeb90 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Mon, 24 Feb 2025 23:17:54 -0500 Subject: [PATCH 5/7] Add back in circular dependency which is needed --- sccello/src/utils/helpers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sccello/src/utils/helpers.py b/sccello/src/utils/helpers.py index b7b5c13..54aaa68 100644 --- a/sccello/src/utils/helpers.py +++ b/sccello/src/utils/helpers.py @@ -76,6 +76,8 @@ def create_downstream_output_dir(args): def load_model_inference(args): + from sccello.src.model_prototype_contrastive import PrototypeContrastiveForMaskedLM + model = 
eval(args.model_class).from_pretrained(args.pretrained_ckpt, output_hidden_states=True).to("cuda") for param in model.bert.parameters(): param.requires_grad = False From b844bd6fedaa7e73cc75c34602e4ffaaf1447745 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Tue, 25 Feb 2025 00:08:05 -0500 Subject: [PATCH 6/7] Provide a subset of the sample --- sccello/script/run_cell_type_clustering.py | 3 ++- sccello/src/utils/data_loading.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/sccello/script/run_cell_type_clustering.py b/sccello/script/run_cell_type_clustering.py index 2d1191c..94c8847 100644 --- a/sccello/script/run_cell_type_clustering.py +++ b/sccello/script/run_cell_type_clustering.py @@ -37,6 +37,7 @@ def parse_args(): parser.add_argument("--indist", type=int, default=0) parser.add_argument('--sample_outdist', action='store_true', help='Samples the out of distribution') + parser.add_argument('--num_samples', type=int, default=10, help='If sample_outdist is set, this sets the number of samples') parser.add_argument("--normalize", type=int, default=0) parser.add_argument("--pass_cell_cls", type=int, default=0) @@ -129,7 +130,7 @@ def solve_clustering(args, all_datasets): helpers.set_seed(args.seed) args.data_source = f"frac_{name}" all_datasets = ( - data_loading.get_fracdata_sample(name) + data_loading.get_fracdata_sample(name, num_samples=args.num_samples) if args.sample_outdist else data_loading.get_fracdata(name, args.data_branch, args.indist, False) ) diff --git a/sccello/src/utils/data_loading.py b/sccello/src/utils/data_loading.py index 229eebc..e0913df 100644 --- a/sccello/src/utils/data_loading.py +++ b/sccello/src/utils/data_loading.py @@ -112,15 +112,18 @@ def get_prestored_data(data_file_name): else: raise NotImplementedError -def get_fracdata_sample(name, num_proc=12): +def get_fracdata_sample(name, num_proc=12, num_samples=10): from sccello.src.data.dataset import CellTypeClassificationDataset data1, data2 = 
CellTypeClassificationDataset.create_dataset(name) data1 = data1.rename_column("gene_token_ids", "input_ids") data2 = data2.rename_column("gene_token_ids", "input_ids") + data1, data2 = data1.select(range(num_samples)), data2.select(range(num_samples)) + data1, eval_label_type_idmap = helpers.process_label_type(data1, num_proc, "label") data2, test_label_type_idmap = helpers.process_label_type(data2, num_proc, "label") + return None, data1, data2, None def get_fracdata(name, data_branch, indist, batch_effect, num_proc=12): From 67778cbbc48d9ac437037441d2796590a8a56c75 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Sat, 8 Mar 2025 15:14:05 -0500 Subject: [PATCH 7/7] Undo the torch commenting out of assertion --- sccello/script/run_novel_cell_type_classification.py | 8 +++++++- sccello/src/utils/evaluating.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sccello/script/run_novel_cell_type_classification.py b/sccello/script/run_novel_cell_type_classification.py index c360ed5..87ea84d 100644 --- a/sccello/script/run_novel_cell_type_classification.py +++ b/sccello/script/run_novel_cell_type_classification.py @@ -65,7 +65,13 @@ def get_cell_type_labelid2nodeid(cell_type_idmap, clid2nodeid): # e.g., {'placental pericyte': CL:2000078, ...} name2clid = {v.lower(): k for k, v in clid2name.items()} - cell_type2nodeid = dict([(k, clid2nodeid[name2clid[cell_type2name[k]]]) if cell_type2name[k] in name2clid else (k, -1) for k in cell_type2name]) + cell_type2nodeid = dict( + [ + (k, clid2nodeid[name2clid[cell_type2name[k]]]) + if cell_type2name[k] in name2clid else (k, -1) + for k in cell_type2name + ] + ) return cell_type2nodeid def load_cell_type_representation(args, model): diff --git a/sccello/src/utils/evaluating.py b/sccello/src/utils/evaluating.py index 69b34c2..533b864 100644 --- a/sccello/src/utils/evaluating.py +++ b/sccello/src/utils/evaluating.py @@ -8,7 +8,7 @@ import ipdb import torch -# assert torch.cuda.is_available() +assert 
torch.cuda.is_available() import cupy as cp from cuml.metrics.cluster import silhouette_score as cu_silhouette_score