diff --git a/docs/docs/assets/images/geneformer/F1-score-models.png b/docs/docs/assets/images/geneformer/F1-score-models.png new file mode 100644 index 0000000000..431f1fcfb4 Binary files /dev/null and b/docs/docs/assets/images/geneformer/F1-score-models.png differ diff --git a/docs/docs/assets/images/geneformer/average-accuracy-models.png b/docs/docs/assets/images/geneformer/average-accuracy-models.png new file mode 100644 index 0000000000..3abd706602 Binary files /dev/null and b/docs/docs/assets/images/geneformer/average-accuracy-models.png differ diff --git a/docs/docs/assets/images/geneformer/loss_curve_new_v_old_geneformer_64_node_10M.png b/docs/docs/assets/images/geneformer/loss_curve_new_v_old_geneformer_64_node_10M.png new file mode 100644 index 0000000000..7dbb980086 Binary files /dev/null and b/docs/docs/assets/images/geneformer/loss_curve_new_v_old_geneformer_64_node_10M.png differ diff --git a/docs/docs/assets/images/geneformer/tflops_bionemo1_vs_bionemo2.png b/docs/docs/assets/images/geneformer/tflops_bionemo1_vs_bionemo2.png new file mode 100644 index 0000000000..834c77c7d7 Binary files /dev/null and b/docs/docs/assets/images/geneformer/tflops_bionemo1_vs_bionemo2.png differ diff --git a/docs/docs/models/geneformer.md b/docs/docs/models/geneformer.md index dd8a979483..45d8a4c955 100644 --- a/docs/docs/models/geneformer.md +++ b/docs/docs/models/geneformer.md @@ -1,9 +1,11 @@ # Geneformer -NOTE: this document references performance numbers and runtime engines that are from the bionemo v1 variant of the model. -These numbers will be updated in a coming release to reflect the new bionemo v2 codebase. The model architecture and -training information will be the same, as checkpoints are converted from bionemo v1 format to v2 format, however -performance benchmarks need to be updated to reflect the latest code. Accuracy should be the same within small epsilon -since we have tests in place showing model equivalency between the two versions. +!!! note "Current checkpoints trained in BioNeMo1" + + This document references performance numbers and runtime engines that are from the bionemo v1 variant of the model. + These numbers will be updated in a coming release to reflect the new bionemo v2 codebase. The model architecture and + training information will be the same, as checkpoints are converted from bionemo v1 format to v2 format. Benchmarks below + are annotated with which version of bionemo generated them. Accuracy should be the same within a small epsilon + since we have tests in place showing model equivalency between the two versions. ## Model Overview @@ -158,6 +160,15 @@ NVIDIA believes Trustworthy AI is a shared responsibility and we have establishe This checkpoint was trained for approximately 11 epochs through the CELLxGENE split. Training was performed on 8 servers with 8 A100 GPUs each for a total of 115430 steps of per-gpu micro batch size 32 and global batch size of 2048. Training took a total of 1 day, 20 hours and 19 minutes of wallclock time. As can be seen in the following image, training and validation curves both decreased fairly smoothly throughout the course of training. In fact validation (blue) and training (orange) loss were both still decreasing at the end of 11 epochs through the dataset. The model could likely be trained for more epochs without overfitting. ![Validation and training losses both decreased smoothly through training](../assets/old_images/sc_fm/geneformer-10m-240530-val-train-loss.png) +!!! 
note "Training curves from BioNeMo1" + + Note that these curves were generated on BioNeMo1. We see the same general training curves in our initial testing of + BioNeMo2, however. In the following figure the blue line is the previous training run of the 10M model and the + red curve is an equivalent training run on BioNeMo2. As we release new checkpoints they will be trained on BioNeMo2. + + ![Training curve equivalence](../assets/images/geneformer/loss_curve_new_v_old_geneformer_64_node_10M.png) + + ### geneformer-106M-240530 This checkpoint was trained for approximately 11 epochs through the CELLxGENE split. Training was performed on 16 servers with 8 A100 GPUs each for a total of 115430 steps of per-gpu micro batch size 16 and global batch size of 2048. Training took a total of 3 days, 18 hours and 55 minutes of wallclock time. As can be seen in the following image, training and validation curves both decreased fairly smoothly throughout the course of training. In fact validation (blue) and training (orange) loss were both still decreasing at the end of 11 epochs through the dataset. The model could likely be trained for more epochs without overfitting. @@ -166,19 +177,39 @@ This checkpoint was trained for approximately 11 epochs through the CELLxGENE sp Additionally, validation loss decreased both faster and continued to decrease at the same improved rate throughout training in the 106M parameter model (red) as compared to the 10M parameter model (blue). It would be interesting to test even larger models to see if we continue to observe improved performance in larger models. ![106M parameter model outperformed 10M parameter model](../assets/old_images/sc_fm/geneformer-240530-val-comparison.png) +!! note "Training curves from BioNeMo1" + + As stated in the previous section, the figures are from our BioNeMo1 code base where these checkpoints were originally + trained. As we release new checkpoints they will be trained on BioNeMo2. + ## Benchmarking ### Accuracy Benchmarks #### Masked language model (MLM) loss -The following describes the bert MLM token loss. Like in the original BERT paper, and the geneformer paper, 15% of all tokens are included in the loss. Of the included tokens, 80% are `"[MASK]"` token, 2% are a random gene token, and 18% are the correct output token. Note that this was an unintentional deviation from the original publication, but so far it seems to be working well. In the future we will test the intended 80%/10%/10% mixture proposed in the paper. The token loss in the following table is the mean cross entropy loss of the 15% of tokens included in the loss mask averaged across cells. As a baseline geneformer was downloaded from [the ctheodoris/Geneformer page on hugging face on 2024/05/13](https://huggingface.co/ctheodoris/Geneformer) and applied to the same masking/unmasking problem on this dataset. The held-out `test` datset from our training splits described previously was used, and it should be noted that some of these cells may have been involved in training the baseline geneformer. Since the baseline performed slightly worse than our new checkpoints, and our goal was an equivalent or better model checkpoint, this possibility was not explored further. +The following describes the bert MLM token loss. Like in the original BERT paper, and the geneformer paper, 15% of all tokens are included in the loss. Of the included tokens, 80% are `"[MASK]"` token, 2% are a random gene token, and 18% are the correct output token. 
Note that this was an unintentional deviation from the original publication, but so far it seems to be working well. In the future we will test the intended 80%/10%/10% mixture proposed in the paper. The token loss in the following table is the mean cross entropy loss of the 15% of tokens included in the loss mask averaged across cells. As a baseline, geneformer was downloaded from [the ctheodoris/Geneformer page on hugging face on 2024/11/04](https://huggingface.co/ctheodoris/Geneformer) and applied to the same masking/unmasking problem on this dataset, but with model-specific cell representations due to the updated tokenizer and medians dictionary used for training, and the increase from 2048 to 4096 tokens per cell. The held-out `test` dataset from our training splits described previously was used, and it should be noted that some of these cells may have been involved in training the baseline geneformer. | Model Description | Token Loss (lower is better) | | ---------------------- | ---------------------------- | -| Baseline geneformer | 3.35 | -| geneformer-10M-240530 | 2.79 | -| geneformer-106M-240530 | 2.50 | +| Baseline geneformer | 2.26* | +| geneformer-10M-240530 | 2.64 | +| geneformer-106M-240530 | 2.34 | + +!!! bug "Baseline Geneformer was recently updated on huggingface, making loss comparisons challenging." + + [Geneformer](https://huggingface.co/ctheodoris/Geneformer) was recently updated on hugging face to a new version. + In a future release we will make checkpoint conversion scripts available so that the public model can be run + directly. Some key differences follow: + + * Trained on a much larger 95M cell dataset. Our current checkpoints were trained with 23M cells. + * The new 12 layer baseline geneformer variant sits between our 10M and 106M parameter models in parameter count with + approximately 38M parameters. + * The model is trained with a 4096 context rather than a 2048 context. When forcing the model to make predictions + with a 2048 context, the MLM loss rises to *2.76*, which is probably unfair because this may be "out of domain" for + training. The main take-home is that it is really hard to compare these loss numbers directly. + * The model was trained on a set of 20,275 genes, rather than the older set of 25,426 genes. This would also be + expected to lower the loss, since there are fewer tokens to choose from. #### Downstream task accuracy @@ -191,11 +222,19 @@ Elmentaite et al. (2020), Developmental Cell. This dataset contains approximatel For more details see the example notebook titled Geneformer-celltype-classification-example.ipynb -![F1-score for both released models, a random baseline, and a PCA based transformation of the raw expression.](../assets/old_images/sc_fm/F1-score-models.png) -![Average accuracy across cell types for both released models, a random baseline, and a PCA based transformation of the raw expression.](../assets/old_images/sc_fm/average-accuracy-models.png) +![F1-score for both released models, a random baseline, and a PCA based transformation of the raw expression.](../assets/images/geneformer/F1-score-models.png) +![Average accuracy across cell types for both released models, a random baseline, and a PCA based transformation of the raw expression.](../assets/images/geneformer/average-accuracy-models.png) ### Performance Benchmarks The 106M parameter variant of Geneformer achieves over 50 TFLOPS per GPU during training. This is consistent whether trained with 1 or 8 A100s.
![TFLOPs per GPU (A100) shows improved utilization by 106M variant](../assets/old_images/sc_fm/model_tflops_per_gpu_chart_tight_layout.png) + +!!! bug "TFLOPS from BioNeMo1" + + We have observed an approximately 10% degradation in training performance comparing the 10M geneformer model on + the new BioNeMo v2 repository vs the old BioNeMo v1 codebase. We are working to address this regression and make them + comparable or better in terms of cluster performance. The numbers above are from the original BioNeMo1 model card. + + ![64 GPU training time is approximately 10% slower in BioNeMo2 vs BioNeMo1](../assets/images/geneformer/tflops_bionemo1_vs_bionemo2.png) diff --git a/sub-packages/bionemo-geneformer/pyproject.toml b/sub-packages/bionemo-geneformer/pyproject.toml index 5bde43b92b..fb77cb7994 100644 --- a/sub-packages/bionemo-geneformer/pyproject.toml +++ b/sub-packages/bionemo-geneformer/pyproject.toml @@ -25,6 +25,7 @@ bionemo-geneformer-recipe= "bionemo.geneformer.run.recipes:main" sc_memmap = "bionemo.geneformer.scripts.sc_memmap:main_cli" infer_geneformer = "bionemo.geneformer.scripts.infer_geneformer:geneformer_infer_entrypoint" train_geneformer = "bionemo.geneformer.scripts.train_geneformer:entrypoint" +geneformer_mlm_loss_eval = "bionemo.geneformer.scripts.geneformer_mlm_loss_eval:entrypoint" [tool.setuptools.packages.find] where = ["src"] diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py index 2a0cff74c6..50d0315971 100644 --- a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py @@ -87,6 +87,7 @@ def __init__( # noqa: D107 mask_token_prob: float = 0.8, random_token_prob: float = 0.1, prepend_cls_token: bool = True, + eos_token: int | None = None, assert_increasing_columns: bool = True, seed: int = np.random.SeedSequence().entropy, # type: ignore ): @@ -98,6 +99,7 @@ self.mask_prob = mask_prob self.prepend_cls_token = prepend_cls_token self._seed = seed + self.eos_token = eos_token # check if column indices are increasing for looking up genes. This is a way of spotting if the sc_memmap.py # script produced properly strctured sparse files. self.assert_increasing_columns = assert_increasing_columns @@ -210,6 +212,7 @@ def __getitem__(self, index: EpochIndex) -> types.BertSample: mask_prob=self.mask_prob, random_token_prob=self.random_token_prob, prepend_cls_token=self.prepend_cls_token, + eos_token=self.eos_token, ) @@ -227,6 +230,7 @@ def process_item( # noqa: D417 target_sum: int = 10000, normalize: bool = True, prepend_cls_token: bool = True, + eos_token: None | int = None, ) -> types.BertSample: """Process a single item in the dataset.
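For context, the hunks above thread an optional `eos_token` through `SingleCellDataset.__init__`, `__getitem__`, and `process_item`. A minimal usage sketch (not part of the diff) follows, assuming a placeholder data directory and the `GeneformerPreprocess` tokenizer/medians setup that the new evaluation script later in this PR also uses; appending an EOS id consumes one position of the `max_len` budget, as the next hunk shows.

```python
# Minimal sketch: exercising the new `eos_token` argument.
# `data_dir` is a placeholder; the tokenizer and medians dictionary are built with
# GeneformerPreprocess, mirroring the evaluation script added later in this PR.
from pathlib import Path

from bionemo.core.data.multi_epoch_dataset import EpochIndex
from bionemo.geneformer.data.singlecell.dataset import SingleCellDataset
from bionemo.geneformer.data.singlecell.preprocess import GeneformerPreprocess

data_dir = Path("/path/to/processed_data/train")  # placeholder sc_memmap output directory
preprocessor = GeneformerPreprocess(
    download_directory=data_dir,
    medians_file_path=data_dir / "medians.json",
    tokenizer_vocab_path=data_dir / "geneformer.vocab",
)
match preprocessor.preprocess():
    case {"tokenizer": tokenizer, "median_dict": median_dict}:
        pass
    case _:
        raise RuntimeError("preprocessing failed")

ds = SingleCellDataset(
    data_dir,
    tokenizer=tokenizer,
    median_dict=median_dict,
    max_len=2048,
    mask_prob=0.15,
    # New in this PR: optionally append an EOS token per cell (here the tokenizer's SEP id,
    # as the evaluation script does for the HF-style dataset). Leaving it as None keeps the
    # previous [CLS]-only behavior.
    eos_token=tokenizer.token_to_id(tokenizer.sep_token),
    seed=513,
)
sample = ds[EpochIndex(epoch=0, idx=0)]  # BertSample dict with text, labels, loss_mask, etc.
```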
@@ -262,7 +266,10 @@ def process_item( # noqa: D417 if gene_median is None: raise ValueError("gene_median must be provided for this tokenizer") - max_len = max_len - 1 # - minus 1 for [CLS] token + if prepend_cls_token: + max_len = max_len - 1 # - minus 1 for [CLS] token + if eos_token is not None: + max_len = max_len - 1 # - minus 1 for [EOS] token gene_names = [feature_ids[idx] for idx in gene_idxs] genes, tokens, medians = [], [], [] @@ -295,20 +302,20 @@ def process_item( # noqa: D417 random_seed=int(random_utils.get_seed_from_rng(rng)), mask_config=masking.BertMaskConfig( tokenizer=tokenizer, - random_tokens=range(5, len(tokenizer.vocab)), + random_tokens=range(len(tokenizer.special_tokens), len(tokenizer.vocab)), mask_prob=mask_prob, mask_token_prob=mask_token_prob, random_token_prob=random_token_prob, ), ) - - if prepend_cls_token: + cls_token = tokenizer.token_to_id(tokenizer.cls_token) if prepend_cls_token else None + if cls_token is not None or eos_token is not None: masked_tokens, labels, loss_mask = masking.add_cls_and_eos_tokens( sequence=masked_tokens, labels=labels, loss_mask=loss_mask, - cls_token=tokenizer.token_to_id(tokenizer.cls_token), - eos_token=None, + cls_token=cls_token, + eos_token=eos_token, ) # NeMo megatron assumes this return structure. diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/geneformer_mlm_loss_eval.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/geneformer_mlm_loss_eval.py new file mode 100644 index 0000000000..a79ae52269 --- /dev/null +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/geneformer_mlm_loss_eval.py @@ -0,0 +1,314 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
+ +import argparse +import functools +import pickle +import sys +from copy import deepcopy +from pathlib import Path +from typing import Dict, Type + +import torch +import torch.distributed +import torch.utils +import torch.utils.data +from megatron.core.transformer.module import Float16Module +from nemo.utils import logging +from torch.utils.data import DataLoader +from tqdm import trange +from transformers import AutoModelForMaskedLM + +from bionemo.core.data.multi_epoch_dataset import EpochIndex +from bionemo.core.utils.dtypes import get_autocast_dtype +from bionemo.geneformer.api import GeneformerConfig +from bionemo.geneformer.data.singlecell.dataset import SingleCellDataset +from bionemo.geneformer.data.singlecell.preprocess import GeneformerPreprocess +from bionemo.geneformer.tokenizer.gene_tokenizer import GeneTokenizer +from bionemo.llm.data import collate +from bionemo.llm.model.biobert.model import BioBertConfig +from bionemo.testing import megatron_parallel_state_utils +from bionemo.testing.data.load import load + + +class GeneformerHFAdapter(torch.nn.Module): + """An adapter class for running the HF model against our subset of tokens.""" + + def __init__(self, hf_path: str, my_token_dict: Dict[str, int], nv_tokenizer: GeneTokenizer): + """An adapter that filters and re-orders tokens to match our tokenizer but with the original indices.""" + super().__init__() + self.model = AutoModelForMaskedLM.from_pretrained(hf_path) + self.my_token_dict = deepcopy(my_token_dict) + self.nv_tokenizer = deepcopy(nv_tokenizer) + self.n_tokens_nv = len(self.nv_tokenizer.vocab) + self.n_tokens_hf = len(my_token_dict) + + # nvidia tokenizer has [cls] and [pad] first along with some others that do not overlap. This mapper re-keys the NV special tokens and the shared Ensembl genes onto the HF token indices. + hf_ordered_nv_tokenizer = { + self.nv_tokenizer.pad_token: my_token_dict["<pad>"], + self.nv_tokenizer.mask_token: my_token_dict["<mask>"], + self.nv_tokenizer.cls_token: my_token_dict["<cls>"], + self.nv_tokenizer.sep_token: my_token_dict["<eos>"], # name doesn't really matter here + } + tokens = list(my_token_dict.items()) + for k, t in tokens[:4]: + assert k.startswith("<") + + missing_nv_tokens = [] + extra_tokens_not_covered = [] + for ens, idx in list(my_token_dict.items())[4:]: + assert ens.startswith("ENSG") + if ens in nv_tokenizer.vocab.keys(): + hf_ordered_nv_tokenizer[ens] = idx + else: + if idx < self.n_tokens_hf: + missing_nv_tokens.append(idx) + else: + extra_tokens_not_covered.append(idx) + self.hf_ordered_nv_tokenizer = hf_ordered_nv_tokenizer + self.extra_tokens_not_covered = extra_tokens_not_covered + self.register_buffer("missing_nv_tokens", torch.tensor(missing_nv_tokens, dtype=int)) + + @property + def device(self) -> torch.device: + """Return the device of this model.""" + # This is populated through the self.register_buffer call in init. + return self.missing_nv_tokens.device + + def get_tokenizer(self) -> GeneTokenizer: + """Return the filtered tokenizer with keys that match the order of the nv model.""" + nv_tok = deepcopy(self.nv_tokenizer) + # HF tokenizer only has pad and mask, no other special tokens.
+ nv_tok.special_tokens = (nv_tok.mask_token, nv_tok.pad_token) # type: ignore + nv_tok.vocab = self.hf_ordered_nv_tokenizer + nv_tok.decode_vocab = {v: k for k, v in nv_tok.vocab.items()} + return nv_tok + + def forward(self, *args, **kwargs): + """Run forward and return the logits.""" + logits = self.model(*args, **kwargs).logits + # logits[:, :, self.missing_nv_tokens] = -torch.inf + # breakpoint() + return logits + + +def main( + model_path: Path | None, + hf_model_path: str, + dataset_path: Path, + hf_token_dictionary_path: Path, + hf_medians_dictionary_path: Path, + mask_prob: float = 0.15, + batch_size: int = 16, + precision: str = "bf16-mixed", + config_class: Type[BioBertConfig] = GeneformerConfig, + seq_len_nv: int = 2048, + seq_len_hf: int = 2048, + seed: int = 513, +): + """Inference function (requires DDP and only training data that fits in memory).""" + # This is just used to get the tokenizer :( + train_data_path: Path = ( + load("single_cell/testdata-20240506") / "cellxgene_2023-12-15_small" / "processed_data" / "train" + ) + n_devices: int = torch.cuda.device_count() + assert n_devices > 0 + preprocessor = GeneformerPreprocess( + download_directory=train_data_path, + medians_file_path=train_data_path / "medians.json", + tokenizer_vocab_path=train_data_path / "geneformer.vocab", + ) + match preprocessor.preprocess(): + case {"tokenizer": tokenizer, "median_dict": median_dict}: + logging.info("*************** Preprocessing Finished ************") + case _: + logging.error("Failed to download the tokenizer for the NV geneformer model.") + assert False + with open(hf_token_dictionary_path, "rb") as geneformer_hf_token_file: + geneformer_hf_token_dict = pickle.load(geneformer_hf_token_file) + with open(hf_medians_dictionary_path, "rb") as geneformer_hf_median_file: + geneformer_hf_medians_dict = pickle.load(geneformer_hf_median_file) + with megatron_parallel_state_utils.distributed_model_parallel_state(): + geneformer_nv_inferer_cfg = config_class( + seq_length=seq_len_nv, + params_dtype=get_autocast_dtype(precision), + pipeline_dtype=get_autocast_dtype(precision), + autocast_dtype=get_autocast_dtype(precision), # setting this speeds things up a lot + # handle checkpoint resumption here rather than auto-resume so this supports fine-tuning capabilities + initial_ckpt_path=str(model_path) if model_path is not None else None, + initial_ckpt_skip_keys_with_these_prefixes=[], # load everything from the checkpoint. + ) + geneformer_nv_inferer = Float16Module( + geneformer_nv_inferer_cfg, geneformer_nv_inferer_cfg.configure_model(tokenizer).cuda(0 % n_devices) + ).eval() + + # TODO only predict with tokens that exist in both models. + + hf_model = GeneformerHFAdapter(hf_model_path, geneformer_hf_token_dict, tokenizer).eval().cuda(1 % n_devices) + hf_total_params = sum(p.numel() for p in hf_model.parameters() if p.requires_grad) + nv_total_params = sum(p.numel() for p in geneformer_nv_inferer.parameters() if p.requires_grad) + print(f"HF Model Params: {hf_total_params}, NV Model Params: {nv_total_params}", file=sys.stdout) + tokenizer_filt = deepcopy(tokenizer) + ori_nv_vocab_size: int = len(tokenizer.vocab) + hf_tokenizer = hf_model.get_tokenizer() + tokenizer_filt.vocab = { + k: v for k, v in tokenizer.vocab.items() if k in hf_tokenizer.vocab or k in tokenizer.special_tokens + } + + ds_nv = SingleCellDataset( + dataset_path, + tokenizer=tokenizer_filt, # TODO replace with the filtered one. 
+ median_dict=median_dict, + max_len=seq_len_nv, + mask_prob=mask_prob, + seed=seed, + ) + ds_hf_nvfilt = SingleCellDataset( + dataset_path, + hf_tokenizer, + geneformer_hf_medians_dict, + max_len=seq_len_hf, + mask_prob=mask_prob, + eos_token=hf_tokenizer.token_to_id(hf_tokenizer.sep_token), # Stored in the special token + seed=seed, + ) + print(f"Loaded dataset of length (NV): {len(ds_nv)}, (HF): {len(ds_hf_nvfilt)}") + + dl_hf = DataLoader( + ds_hf_nvfilt, + batch_size=batch_size, + sampler=[EpochIndex(epoch=0, idx=i) for i in range(len(ds_hf_nvfilt))], + shuffle=False, + num_workers=0, + drop_last=False, + collate_fn=functools.partial( + collate.bert_padding_collate_fn, + padding_value=ds_hf_nvfilt.tokenizer.pad_id, + min_length=seq_len_hf, + max_length=seq_len_hf, + ), + ) + dl_nv = DataLoader( + ds_nv, + batch_size=batch_size, + sampler=[EpochIndex(epoch=0, idx=i) for i in range(len(ds_nv))], + shuffle=False, + num_workers=0, + drop_last=False, + collate_fn=functools.partial( + collate.bert_padding_collate_fn, + padding_value=ds_nv.tokenizer.pad_id, + min_length=seq_len_nv, + max_length=seq_len_nv, + ), + ) + + with torch.no_grad(): + dl_hf_iter = iter(dl_hf) + dl_nv_iter = iter(dl_nv) + loss_hf = 0.0 + n_hf = 0 + loss_nv = 0.0 + n_nv = 0 + nv_device = geneformer_nv_inferer.module.embedding.position_embeddings.weight.device + hf_device = hf_model.device + for _ in trange(len(dl_hf)): + batch_hf = {k: v.to(hf_device) for k, v in next(dl_hf_iter).items()} + batch_nv = {k: v.to(nv_device) for k, v in next(dl_nv_iter).items()} + logits_hf = hf_model(batch_hf["text"].long(), batch_hf["attention_mask"]) + loss_hf += ( + torch.nn.functional.cross_entropy( + logits_hf[batch_hf["loss_mask"]], + batch_hf["labels"][batch_hf["loss_mask"]], + reduction="sum", + ) + .cpu() + .sum() + .item() + ) + n_hf += batch_hf["loss_mask"].sum().cpu().item() + + logits_nv = ( + geneformer_nv_inferer(batch_nv["text"], batch_nv["attention_mask"])["token_logits"] + .transpose(0, 1) + .contiguous() + ) + loss_nv += ( + torch.nn.functional.cross_entropy( + logits_nv[batch_nv["loss_mask"]][..., :ori_nv_vocab_size], + batch_nv["labels"][batch_nv["loss_mask"]], + reduction="sum", + ) + .cpu() + .sum() + .item() + ) + n_nv += batch_nv["loss_mask"].sum().cpu().item() + print(f"NV mean loss: {loss_nv / n_nv}") + print(f"HF mean loss: {loss_hf / n_hf}") + + +def entrypoint(): + """Main entry point for running the evaluation.""" + parser = argparse.ArgumentParser(description="MLM Performance vs HF Script") + parser.add_argument( + "--model-path", + type=Path, + help="Path to nvidia geneformer model checkpoint (unless you want random weights)", + required=False, + default=None, + ) + parser.add_argument( + "--hf-token-dictionary-path", + type=Path, + help="Path to token dictionary file. " + "Eg `wget https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer/token_dictionary_gc95M.pkl`" + "then provide the path to the downloaded file.", + required=True, + ) + parser.add_argument( + "--hf-medians-dictionary-path", + type=Path, + help="Path to token dictionary file. 
" + "Eg `wget https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer/gene_median_dictionary_gc95M.pkl` " + "then provide the path to the downloaded file.", + required=True, + ) + parser.add_argument("--hf-model-path", type=str, default="ctheodoris/Geneformer", help="HF model path") + parser.add_argument("--dataset-path", type=Path, help="Path to dataset directory", required=True) + + args = parser.parse_args() + main( + args.model_path, + args.hf_model_path, + args.dataset_path, + args.hf_token_dictionary_path, + args.hf_medians_dictionary_path, + ) + + +if __name__ == "__main__": + entrypoint()