Geneformer modelcard and loss eval script #392

Merged · 12 commits · Nov 6, 2024
61 changes: 50 additions & 11 deletions docs/docs/models/geneformer.md
@@ -1,9 +1,11 @@
# Geneformer
NOTE: this document references performance numbers and runtime engines that are from the bionemo v1 variant of the model.
These numbers will be updated in a coming release to reflect the new bionemo v2 codebase. The model architecture and
training information will be the same, as checkpoints are converted from bionemo v1 format to v2 format, however
performance benchmarks need to be updated to reflect the latest code. Accuracy should be the same within small epsilon
since we have tests in place showing model equivalency between the two versions.
!!! note "Current checkpoints trained in BioNeMo1"

This document references performance numbers and runtime engines that are from the bionemo v1 variant of the model.
These numbers will be updated in a coming release to reflect the new bionemo v2 codebase. The model architecture and
training information will be the same, as checkpoints are converted from bionemo v1 format to v2 format. Benchmarks below
are annotated with which version of bionemo generated them. Accuracy should be the same within a small epsilon
since we have tests in place showing model equivalency between the two versions.
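
A minimal sketch of what such an equivalency check can look like is shown below; the loader names and tolerances are illustrative assumptions, not the repository's actual test code.

```python
# Hedged sketch of a checkpoint-equivalency check between a converted BioNeMo1
# checkpoint and its BioNeMo2 counterpart: run the same batch through both models
# and require the outputs to agree within a small epsilon. The model/batch objects
# are assumed inputs; this is not the actual test in the repository.
import torch


def assert_model_equivalence(model_v1, model_v2, batch: dict, atol: float = 1e-4, rtol: float = 1e-4) -> None:
    model_v1.eval()
    model_v2.eval()
    with torch.no_grad():
        out_v1 = model_v1(**batch)
        out_v2 = model_v2(**batch)
    # assert_close recurses through tensors or containers of tensors.
    torch.testing.assert_close(out_v1, out_v2, atol=atol, rtol=rtol)
```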

## Model Overview

@@ -158,6 +160,15 @@ NVIDIA believes Trustworthy AI is a shared responsibility and we have establishe
This checkpoint was trained for approximately 11 epochs through the CELLxGENE split. Training was performed on 8 servers with 8 A100 GPUs each for a total of 115430 steps of per-gpu micro batch size 32 and global batch size of 2048. Training took a total of 1 day, 20 hours and 19 minutes of wallclock time. As can be seen in the following image, training and validation curves both decreased fairly smoothly throughout the course of training. In fact validation (blue) and training (orange) loss were both still decreasing at the end of 11 epochs through the dataset. The model could likely be trained for more epochs without overfitting.
![Validation and training losses both decreased smoothly through training](../assets/old_images/sc_fm/geneformer-10m-240530-val-train-loss.png)

!!! note "Training curves from BioNeMo1"

Note that these curves were generated on BioNeMo1; however, we see the same general training curves in our initial
testing of BioNeMo2. In the following figure, the blue line is the previous training run of the 10M model and the
red curve is an equivalent training run on BioNeMo2. As we release new checkpoints, they will be trained on BioNeMo2.

![Training curve equivalence](../assets/images/geneformer/loss_curve_new_v_old_geneformer_64_node_10M.png)


### geneformer-106M-240530

This checkpoint was trained for approximately 11 epochs through the CELLxGENE split. Training was performed on 16 servers with 8 A100 GPUs each for a total of 115430 steps of per-gpu micro batch size 16 and global batch size of 2048. Training took a total of 3 days, 18 hours and 55 minutes of wallclock time. As can be seen in the following image, training and validation curves both decreased fairly smoothly throughout the course of training. In fact validation (blue) and training (orange) loss were both still decreasing at the end of 11 epochs through the dataset. The model could likely be trained for more epochs without overfitting.
@@ -166,19 +177,39 @@ This checkpoint was trained for approximately 11 epochs through the CELLxGENE sp
Additionally, validation loss in the 106M parameter model (red) decreased faster than in the 10M parameter model (blue), and continued to decrease at that improved rate throughout training. It would be interesting to test even larger models to see whether this trend of improved performance with scale continues.
![106M parameter model outperformed 10M parameter model](../assets/old_images/sc_fm/geneformer-240530-val-comparison.png)

!!! note "Training curves from BioNeMo1"

As stated in the previous section, the figures are from our BioNeMo1 code base where these checkpoints were originally
trained. As we release new checkpoints they will be trained on BioNeMo2.

## Benchmarking

### Accuracy Benchmarks

#### Masked language model (MLM) loss

The following describes the bert MLM token loss. Like in the original BERT paper, and the geneformer paper, 15% of all tokens are included in the loss. Of the included tokens, 80% are `"[MASK]"` token, 2% are a random gene token, and 18% are the correct output token. Note that this was an unintentional deviation from the original publication, but so far it seems to be working well. In the future we will test the intended 80%/10%/10% mixture proposed in the paper. The token loss in the following table is the mean cross entropy loss of the 15% of tokens included in the loss mask averaged across cells. As a baseline geneformer was downloaded from [the ctheodoris/Geneformer page on hugging face on 2024/05/13](https://huggingface.co/ctheodoris/Geneformer) and applied to the same masking/unmasking problem on this dataset. The held-out `test` dataset from our training splits described previously was used, and it should be noted that some of these cells may have been involved in training the baseline geneformer. Since the baseline performed slightly worse than our new checkpoints, and our goal was an equivalent or better model checkpoint, this possibility was not explored further.
The following describes the BERT MLM token loss. As in the original BERT paper and the geneformer paper, 15% of all tokens are included in the loss. Of the included tokens, 80% are replaced with the `"[MASK]"` token, 2% are replaced with a random gene token, and 18% keep the correct output token. Note that this 80%/2%/18% mixture was an unintentional deviation from the original publication, but so far it seems to be working well; in the future we will test the intended 80%/10%/10% mixture proposed in the paper. The token loss in the following table is the mean cross entropy loss of the 15% of tokens included in the loss mask, averaged across cells. As a baseline, geneformer was downloaded from [the ctheodoris/Geneformer page on hugging face on 2024/11/04](https://huggingface.co/ctheodoris/Geneformer) and applied to the same masking/unmasking problem on this dataset, but with model-specific cell representations, since the baseline uses an updated tokenizer and medians dictionary and was trained with 4096 tokens per cell rather than 2048. The held-out `test` dataset from our training splits described previously was used, and it should be noted that some of these cells may have been involved in training the baseline geneformer.
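
A minimal sketch of this masking mixture and the per-cell token loss is shown below; the tensor names, shapes, and helper functions are illustrative assumptions rather than the bionemo implementation (the actual evaluation is performed by the `geneformer_mlm_loss_eval` script added in this PR).

```python
# Hedged sketch of the 15% masking mixture (80% [MASK], 2% random gene, 18% unchanged)
# and the mean cross-entropy over masked-in positions, averaged across cells.
import torch
import torch.nn.functional as F


def apply_mlm_mask(tokens, mask_id, vocab_size, mask_prob=0.15, mask_token_prob=0.80, random_token_prob=0.02):
    labels = tokens.clone()
    loss_mask = torch.rand_like(tokens, dtype=torch.float) < mask_prob  # ~15% of positions enter the loss
    roll = torch.rand_like(tokens, dtype=torch.float)
    masked = tokens.clone()
    masked[loss_mask & (roll < mask_token_prob)] = mask_id  # 80% of selected -> [MASK]
    rand_sel = loss_mask & (roll >= mask_token_prob) & (roll < mask_token_prob + random_token_prob)
    masked[rand_sel] = torch.randint(0, vocab_size, (int(rand_sel.sum()),))  # 2% -> random gene token
    return masked, labels, loss_mask  # remaining 18% of selected tokens keep the correct token


def mlm_token_loss(logits, labels, loss_mask):
    # Mean cross entropy over masked-in positions, averaged across cells (rows).
    per_token = F.cross_entropy(logits.transpose(1, 2), labels, reduction="none")
    per_cell = (per_token * loss_mask).sum(dim=1) / loss_mask.sum(dim=1).clamp(min=1)
    return per_cell.mean()
```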

| Model Description | Token Loss (lower is better) |
| ---------------------- | ---------------------------- |
| Baseline geneformer | 3.35 |
| geneformer-10M-240530 | 2.79 |
| geneformer-106M-240530 | 2.50 |
| Baseline geneformer | 2.26* |
| geneformer-10M-240530 | 2.64 |
| geneformer-106M-240530 | 2.34 |

!!! bug "Baseline Geneformer was recently updated on huggingface making loss comparisons challenging."

[Geneformer](https://huggingface.co/ctheodoris/Geneformer) was recently updated on hugging face to a new version.
In a future release we will make checkpoint conversion scripts available so that the public model can be run
directly. Some key differences follow:

* Trained on a much larger 95M cell dataset. Our current checkpoints were trained with 23M cells.
* The new 12 layer baseline geneformer variant sits between our 10M and 106M parameter models in parameter count with
approximately 38M parameters.
* The model is trained with a 4096 context rather than a 2048 context. When the model is forced to make predictions
with a 2048 context, its MLM loss worsens to *2.76*; this comparison is probably unfair, though, since a 2048 context
may be "out of domain" relative to its training. The main take-home is that these loss numbers are very hard to compare directly.
* The model was trained on a set of 20275 genes, rather than the older set of 25426 genes. This would also be
expected to improve (lower) the loss, since there are fewer tokens to choose from.

#### Downstream task accuracy

@@ -191,11 +222,19 @@ Elmentaite et al. (2020), Developmental Cell. This dataset contains approximatel

For more details, see the example notebook titled Geneformer-celltype-classification-example.ipynb.

![F1-score for both released models, a random baseline, and a PCA based transformation of the raw expression.](../assets/old_images/sc_fm/F1-score-models.png)
![Average accuracy across cell types for both released models, a random baseline, and a PCA based transformation of the raw expression.](../assets/old_images/sc_fm/average-accuracy-models.png)
![F1-score for both released models, a random baseline, and a PCA based transformation of the raw expression.](../assets/images/geneformer/F1-score-models.png)
![Average accuracy across cell types for both released models, a random baseline, and a PCA based transformation of the raw expression.](../assets/images/geneformer/average-accuracy-models.png)
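
The notebook itself is not part of this diff; the sketch below illustrates, under stated assumptions, how such a benchmark can be scored from frozen cell embeddings and cell type labels (illustrative NumPy arrays), reporting the macro F1-score and the average per-cell-type accuracy shown in the figures above.

```python
# Hedged sketch of scoring the cell type classification benchmark from frozen
# embeddings. The embeddings/labels arrays and the choice of classifier are
# illustrative assumptions; see Geneformer-celltype-classification-example.ipynb
# for the actual procedure.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, f1_score
from sklearn.model_selection import train_test_split


def celltype_benchmark(embeddings: np.ndarray, labels: np.ndarray, seed: int = 0):
    X_train, X_test, y_train, y_test = train_test_split(
        embeddings, labels, test_size=0.2, stratify=labels, random_state=seed
    )
    clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
    preds = clf.predict(X_test)
    macro_f1 = f1_score(y_test, preds, average="macro")
    # Average accuracy across cell types = mean of per-class recalls.
    cm = confusion_matrix(y_test, preds)
    per_class_acc = cm.diagonal() / cm.sum(axis=1).clip(min=1)
    return macro_f1, per_class_acc.mean()
```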

### Performance Benchmarks

The 106M parameter variant of Geneformer achieves over 50 TFLOPS per GPU during training. This is consistent whether trained with 1 or 8 A100s.

![TFLOPs per GPU (A100) shows improved utilization by 106M variant](../assets/old_images/sc_fm/model_tflops_per_gpu_chart_tight_layout.png)
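
As a rough sanity check on figures like this, per-GPU TFLOPS can be estimated with the common `6 * parameters * tokens` approximation for transformer training FLOPs; the step time in the example below is an illustrative assumption, not a measured value from these runs.

```python
# Hedged estimate of per-GPU training TFLOPS from the standard 6 * N * T
# approximation (forward + backward). All inputs are assumed illustrative values,
# not measured benchmark numbers.
def estimated_tflops_per_gpu(n_params: float, tokens_per_step_per_gpu: float, step_time_s: float) -> float:
    flops_per_step = 6.0 * n_params * tokens_per_step_per_gpu
    return flops_per_step / step_time_s / 1e12


# Example with assumed numbers: 106M parameters, per-GPU micro batch 16,
# context length 2048, and a 0.4 s step time give roughly 52 TFLOPS per GPU.
print(estimated_tflops_per_gpu(106e6, 16 * 2048, step_time_s=0.4))
```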

!!! bug "TFLOPS from BioNeMo1"

We have observed an approximately 10% degradation in training performance for the 10M geneformer model on
the new BioNeMo v2 repository compared with the old BioNeMo v1 codebase. We are working to address this regression
and bring BioNeMo2 to parity or better in terms of cluster performance. The numbers above are from the original BioNeMo1 model card.

![64 GPU training time 10% slower training time in BioNeMo2 vs BioNeMo1](../assets/images/geneformer/tflops_bionemo1_vs_bionemo2.png)
1 change: 1 addition & 0 deletions sub-packages/bionemo-geneformer/pyproject.toml
@@ -23,6 +23,7 @@ dependencies = [
sc_memmap = "bionemo.geneformer.scripts.sc_memmap:main_cli"
infer_geneformer = "bionemo.geneformer.scripts.infer_geneformer:geneformer_infer_entrypoint"
train_geneformer = "bionemo.geneformer.scripts.train_geneformer:entrypoint"
geneformer_mlm_loss_eval = "bionemo.geneformer.scripts.geneformer_mlm_loss_eval:entrypoint"

[tool.setuptools.packages.find]
where = ["src"]
@@ -87,6 +87,7 @@ def __init__( # noqa: D107
mask_token_prob: float = 0.8,
random_token_prob: float = 0.1,
prepend_cls_token: bool = True,
eos_token: int | None = None,
assert_increasing_columns: bool = True,
seed: int = np.random.SeedSequence().entropy, # type: ignore
):
@@ -98,6 +99,7 @@ def __init__( # noqa: D107
self.mask_prob = mask_prob
self.prepend_cls_token = prepend_cls_token
self._seed = seed
self.eos_token = eos_token
# check if column indices are increasing for looking up genes. This is a way of spotting if the sc_memmap.py
# script produced properly structured sparse files.
self.assert_increasing_columns = assert_increasing_columns
@@ -210,6 +212,7 @@ def __getitem__(self, index: EpochIndex) -> types.BertSample:
mask_prob=self.mask_prob,
random_token_prob=self.random_token_prob,
prepend_cls_token=self.prepend_cls_token,
eos_token=self.eos_token,
)


@@ -227,6 +230,7 @@ def process_item( # noqa: D417
target_sum: int = 10000,
normalize: bool = True,
prepend_cls_token: bool = True,
eos_token: None | int = None,
) -> types.BertSample:
"""Process a single item in the dataset.

@@ -262,7 +266,10 @@ def process_item( # noqa: D417
if gene_median is None:
raise ValueError("gene_median must be provided for this tokenizer")

max_len = max_len - 1 # - minus 1 for [CLS] token
if prepend_cls_token:
max_len = max_len - 1 # reserve one position for the [CLS] token
if eos_token is not None:
max_len = max_len - 1 # reserve one position for the [EOS] token

gene_names = [feature_ids[idx] for idx in gene_idxs]
genes, tokens, medians = [], [], []
@@ -295,20 +302,20 @@ def process_item( # noqa: D417
random_seed=int(random_utils.get_seed_from_rng(rng)),
mask_config=masking.BertMaskConfig(
tokenizer=tokenizer,
random_tokens=range(5, len(tokenizer.vocab)),
random_tokens=range(len(tokenizer.special_tokens), len(tokenizer.vocab)),
mask_prob=mask_prob,
mask_token_prob=mask_token_prob,
random_token_prob=random_token_prob,
),
)

if prepend_cls_token:
cls_token = tokenizer.token_to_id(tokenizer.cls_token) if prepend_cls_token else None
if cls_token is not None or eos_token is not None:
masked_tokens, labels, loss_mask = masking.add_cls_and_eos_tokens(
sequence=masked_tokens,
labels=labels,
loss_mask=loss_mask,
cls_token=tokenizer.token_to_id(tokenizer.cls_token),
eos_token=None,
cls_token=cls_token,
eos_token=eos_token,
)

# NeMo megatron assumes this return structure.
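
To illustrate the new `eos_token` handling in isolation: one position is reserved in `max_len` per special token, and the masked sequence, labels, and loss mask are then extended. The helpers below are hedged stand-ins with assumed semantics, not the actual `masking.add_cls_and_eos_tokens` implementation.

```python
# Hedged standalone sketch of the special-token plumbing added in this file.
# Assumed semantics: prepend [CLS], append [EOS], and keep both out of the loss.
import torch


def effective_max_len(max_len: int, prepend_cls_token: bool = True, eos_token: int | None = None) -> int:
    if prepend_cls_token:
        max_len -= 1  # reserve a position for [CLS]
    if eos_token is not None:
        max_len -= 1  # reserve a position for [EOS]
    return max_len


def wrap_with_special_tokens(tokens, labels, loss_mask, cls_token=None, eos_token=None, ignore_index=-100):
    if cls_token is not None:
        tokens = torch.cat([tokens.new_tensor([cls_token]), tokens])
        labels = torch.cat([labels.new_tensor([ignore_index]), labels])
        loss_mask = torch.cat([loss_mask.new_zeros(1), loss_mask])  # [CLS] never enters the loss
    if eos_token is not None:
        tokens = torch.cat([tokens, tokens.new_tensor([eos_token])])
        labels = torch.cat([labels, labels.new_tensor([ignore_index])])
        loss_mask = torch.cat([loss_mask, loss_mask.new_zeros(1)])  # [EOS] never enters the loss
    return tokens, labels, loss_mask
```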