diff --git a/docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb b/docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb
deleted file mode 100644
index ef1088b9f1..0000000000
--- a/docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb
+++ /dev/null
@@ -1,898 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "[![ Click here to deploy.](https://uohmivykqgnnbiouffke.supabase.co/storage/v1/object/public/landingpage/brevdeploynavy.svg)](https://console.brev.dev/launchable/deploy?launchableID=env-2rPWpPzzJIxq7SMRJIQehCxBymV)\n",
- "\n",
- "
NOTE It takes about 10 minutes to deploy this notebook as a Launchable. As of this writing, we are working on a free tier so a credit card may be required. You can reach out to your NVIDIA rep for credits.
"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# ESM-2 Fine-tuning"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "vscode": {
- "languageId": "plaintext"
- }
- },
- "source": [
- "The [ESM-2](https://www.science.org/doi/abs/10.1126/science.ade2574) model is a transformer-based protein language model that has achieved state-of-the-art results in various protein-related tasks. When fine-tuning ESM2, the task-head plays a crucial role. A task head refers to the additional layer or set of layers added on top of a pre-trained model, like the ESM-2 transformer-based protein language model, to adapt it for a specific downstream task. As a part of transfer learning, a pre-trained model is often utilized to learn generic features from a large-scale dataset. However, these features might not be directly applicable to the specific task at hand. By incorporating a task head, which consists of learnable parameters, the model can adapt and specialize to the target task. The task head serves as a flexible and adaptable component that learns task-specific representations by leveraging the pre-trained features as a foundation. Through fine-tuning, the task head enables the model to learn and extract task-specific patterns, improving performance and addressing the nuances of the downstream task. It acts as a critical bridge between the pre-trained model and the specific task, enabling efficient and effective transfer of knowledge."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- " NOTE This tutorial will guide you through the steps for creating a basic regression fine-tuning task for simplicity. The utilities described in this tutorial are available in:\n",
- "\n",
- "
bionemo.esm2.model.finetune.finetune_regressor
\n",
- "\n",
- "The techniques demonstrated here can be adapted for classification and per-token classification tasks. Utilities needed for secondary structure prediction (token-level classification) are available in \n",
- "\n",
- "
bionemo.esm2.model.finetune.finetune_token_classifier
\n",
- "\n",
- "In the second part of the tutorial, we will cover loading a pre-trained model, fine-tuning it for both regression and per-token classification tasks, and using the fine-tuned models for inference. For instructions on pre-training the ESM-2 model, please refer to the [ESM-2 Pretraining](./pretrain.md) tutorial.
"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Building a Regression Fine-tune Module"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Wwe need to define some key classes to successfully build a fine-tuning module in BioNeMo framework: \n",
- "\n",
- "1. **Loss Reduction Class** - To compute the supervised fine-tuning loss.\n",
- "2. **Fine-Tuned Model Head** - Downstream task head model.\n",
- "3. **Fine-Tuned Model** - Model that combines ESM-2 with the task head model.\n",
- "4. **Fine-Tuning Config** - Configures the fine-tuning model and loss to use in the training and inference framework.\n",
- "5. **Dataset** - Training and inference datasets for ESM2 fine-tuning."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 1 - Loss Reduction Class"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "A class for calculating the supervised loss of the fine-tune model from targets. We inherit from Megatron Bert Masked Language Model Loss (`BERTMLMLossWithReduction`) and override the `forward()` pass to compute MSE loss of the regression head within a micro-batch. The `reduce()` method is used for computing the average over the micro-batches and is only used for logging.\n",
- "\n",
- "```python\n",
- "class RegressorLossReduction(BERTMLMLossWithReduction):\n",
- " def forward(\n",
- " self, batch: Dict[str, torch.Tensor], forward_out: Dict[str, torch.Tensor]\n",
- " ) -> Tuple[torch.Tensor, Union[PerTokenLossDict, SameSizeLossDict]]:\n",
- "\n",
- " regression_output = forward_out[\"regression_output\"]\n",
- " targets = batch[\"labels\"].to(dtype=regression_output.dtype) # [b, 1]\n",
- "\n",
- " loss = torch.nn.functional.mse_loss(regression_output, targets)\n",
- " return loss, {\"avg\": loss}\n",
- "\n",
- " def reduce(self, losses_reduced_per_micro_batch: Sequence[ReductionT]) -> torch.Tensor:\n",
- " losses = torch.stack([loss[\"avg\"] for loss in losses_reduced_per_micro_batch])\n",
- " return losses.mean()\n",
- "```"
- ]
- },
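- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The same pattern can be adapted for classification. Below is a minimal sketch (an illustration, not the BioNeMo implementation) of a cross-entropy variant, assuming the model's forward output exposes a `classification_output` logits tensor and the batch provides integer class `labels`; it reuses the same imports and type aliases as the excerpt above:\n",
- "\n",
- "```python\n",
- "class ClassifierLossReduction(BERTMLMLossWithReduction):  # hypothetical sketch\n",
- "    def forward(\n",
- "        self, batch: Dict[str, torch.Tensor], forward_out: Dict[str, torch.Tensor]\n",
- "    ) -> Tuple[torch.Tensor, Union[PerTokenLossDict, SameSizeLossDict]]:\n",
- "        logits = forward_out[\"classification_output\"]  # [b, num_classes] (assumed output key)\n",
- "        targets = batch[\"labels\"]  # [b] integer class ids\n",
- "\n",
- "        loss = torch.nn.functional.cross_entropy(logits, targets)\n",
- "        return loss, {\"avg\": loss}\n",
- "\n",
- "    def reduce(self, losses_reduced_per_micro_batch: Sequence[ReductionT]) -> torch.Tensor:\n",
- "        losses = torch.stack([loss[\"avg\"] for loss in losses_reduced_per_micro_batch])\n",
- "        return losses.mean()\n",
- "```"
- ]
- },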
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 2 - Fine-Tuned Model Head"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "An MLP class for sequence-level regression. This class inherits `MegatronModule` and uses the fine-tune config (`TransformerConfig`) to configure the regression head for the fine-tuned ESM-2 model.\n",
- "\n",
- "```python\n",
- "class MegatronMLPHead(MegatronModule):\n",
- " def __init__(self, config: TransformerConfig):\n",
- " super().__init__(config)\n",
- " layer_sizes = [config.hidden_size, 256, 1]\n",
- " self.linear_layers = torch.nn.ModuleList(\n",
- " [torch.nn.Linear(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])]\n",
- " )\n",
- " self.act = torch.nn.ReLU()\n",
- " self.dropout = torch.nn.Dropout(p=config.ft_dropout)\n",
- "\n",
- " def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:\n",
- " ...\n",
- "```"
- ]
- },
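- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The `forward()` body is elided in the excerpt above. Purely as an illustration of the usual pattern for a head with `layer_sizes = [hidden_size, 256, 1]` (a hypothetical sketch, not the library code), the hidden states would pass through dropout, the linear layers, and the activation in turn:\n",
- "\n",
- "```python\n",
- "# Hypothetical sketch of an MLP-head forward pass.\n",
- "def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n",
- "    x = hidden_states\n",
- "    for i, layer in enumerate(self.linear_layers):\n",
- "        if i > 0:\n",
- "            x = self.act(x)  # non-linearity between linear layers\n",
- "        x = self.dropout(x)\n",
- "        x = layer(x)\n",
- "    return x  # [b, 1] regression output\n",
- "```"
- ]
- },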
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 3 - Fine-Tuned Model"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "A fine-tuned ESM-2 model class for token classification tasks. This class inherits from the `ESM2Model` class and adds the custom regression head `MegatronMLPHead` the we created in the previous step. Optionally one can freeze all or parts of the encoder by parsing through the model parameters in the model constructor.\n",
- "\n",
- "```python\n",
- "class ESM2FineTuneSeqModel(ESM2Model):\n",
- " def __init__(self, config, *args, post_process: bool = True, include_embeddings: bool = False, **kwargs):\n",
- " super().__init__(config, *args, post_process=post_process, include_embeddings=True, **kwargs)\n",
- "\n",
- " # freeze encoder parameters\n",
- " if config.encoder_frozen:\n",
- " for _, param in self.named_parameters():\n",
- " param.requires_grad = False\n",
- "\n",
- " if post_process:\n",
- " self.regression_head = MegatronMLPHead(config)\n",
- "\n",
- " def forward(self, *args, **kwargs,):\n",
- " output = super().forward(*args, **kwargs)\n",
- " ...\n",
- " output[\"regression_output\"] = self.regression_head(embeddings)\n",
- " return output\n",
- "```"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 4 - Fine-Tuning Config"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "A `dataclass` that configures the fine-tuned ESM-2 model. In this example `ESM2FineTuneSeqConfig` inherits from `ESM2GenericConfig` and adds custom arguments to setup the fine-tuned model. The `configure_model()` method of this `dataclass` is called within the `Lightning` module to call the model constructor with the `dataclass` arguments.\n",
- "\n",
- "The common arguments among different fine-tuning tasks are\n",
- "\n",
- "- `model_cls`: The fine-tune model class defined in previous step (`ESM2FineTuneSeqModel`)\n",
- "- `initial_ckpt_path`: BioNeMo 2.0 ESM-2 pre-trained checkpoint\n",
- "- `initial_ckpt_skip_keys_with_these_prefixes`: skips keys when loading parameters from a checkpoint. For example, we should not look for `regression_head` in the pre-trained checkpoint.\n",
- "- `get_loss_reduction_class()`: Implements selection of the appropriate `MegatronLossReduction` class that we defined in the first step of this tutorial.\n",
- "\n",
- "```python\n",
- "\n",
- "@dataclass\n",
- "class ESM2FineTuneSeqConfig(\n",
- " ESM2GenericConfig[ESM2FineTuneSeqModel, RegressorLossReduction], iom.IOMixinWithGettersSetters\n",
- "):\n",
- " model_cls: Type[ESM2FineTuneSeqModel] = ESM2FineTuneSeqModel\n",
- " # The following checkpoint path is for nemo2 checkpoints. Config parameters not present in\n",
- " # self.override_parent_fields will be loaded from the checkpoint and override those values here.\n",
- " initial_ckpt_path: str | None = None\n",
- " # typical case is fine-tune the base biobert that doesn't have this head. If you are instead loading a checkpoint\n",
- " # that has this new head and want to keep using these weights, please drop this next line or set to []\n",
- " initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: [\"regression_head\"])\n",
- "\n",
- " encoder_frozen: bool = True # freeze encoder parameters\n",
- " ft_dropout: float = 0.25 # MLP layer dropout\n",
- "\n",
- " def get_loss_reduction_class(self) -> Type[MegatronLossReduction]:\n",
- " return RegressorLossReduction\n",
- "```"
- ]
- },
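- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "As a usage sketch (illustrative only; the argument values are assumptions, and `checkpoint_path` refers to a pre-trained ESM-2 checkpoint like the one downloaded later in this tutorial), the config ties everything together and is what the training framework uses to build the model and its loss:\n",
- "\n",
- "```python\n",
- "config = ESM2FineTuneSeqConfig(\n",
- "    initial_ckpt_path=str(checkpoint_path),  # nemo2-format ESM-2 checkpoint\n",
- "    encoder_frozen=True,   # train only the regression head\n",
- "    ft_dropout=0.25,       # dropout used inside MegatronMLPHead\n",
- ")\n",
- "\n",
- "# The framework calls these for us; shown here only to make the wiring explicit.\n",
- "assert config.model_cls is ESM2FineTuneSeqModel\n",
- "assert config.get_loss_reduction_class() is RegressorLossReduction\n",
- "```"
- ]
- },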
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 5 - Dataset"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "We will use a sample dataset for demonstration purposes. Create a dataset class by extending ```bionemo.esm2.model.finetune.dataset.InMemoryProteinDataset```. The `InMemoryProteinDataset` has a `classmethod` (`from_csv`) that reads data from a CSV file that has `sequences` and optionally `labels` columns. It is important to override the `transform_label()` method that returns a `torch.Tensor` containing the label in correct format. As an example we can use this method to add custom tokenization if `label` is a string."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "The custom dataset class will be appropriate (found in ```bionemo.esm2.model.finetune.dataset.InMemorySingleValueDataset```) as it facilitates predicting on a single value. An excerpt from the class is shown below. This example dataset has a class method `from_csv()` that expects a `data_path` to a CSV file that has `sequences`, and `labels` columns.\n",
- "\n",
- "```python\n",
- "class InMemorySingleValueDataset(InMemoryProteinDataset):\n",
- " def __init__(\n",
- " self,\n",
- " sequences: pd.Series,\n",
- " labels: pd.Series | None = None,\n",
- " tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),\n",
- " seed: int = np.random.SeedSequence().entropy,\n",
- " ):\n",
- " super().__init__(sequences, labels, tokenizer, seed)\n",
- "\n",
- " def transform_label(self, label: float) -> Tensor:\n",
- " return torch.tensor([label], dtype=torch.float)\n",
- "```\n",
- "\n",
- "The `transform_label` method allows for custom transformation of raw labels by casting or tokenization and need to be adjusted based on the data. Here we use this method to create a `float` tensor of the regression value."
- ]
- },
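- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "For string labels, such as the per-token secondary structure labels used later in this tutorial, `transform_label()` is the natural place to tokenize. A minimal sketch of such an override (an illustration, not the `InMemoryPerTokenValueDataset` implementation; it assumes a `self.label_tokenizer` built over the label vocabulary, for example a `Label2IDTokenizer`):\n",
- "\n",
- "```python\n",
- "# Hypothetical sketch: turn a per-residue label string into a tensor of class ids.\n",
- "def transform_label(self, label: str) -> Tensor:\n",
- "    label_ids = self.label_tokenizer.text_to_ids(label)  # assumed tokenizer API\n",
- "    return torch.tensor(label_ids, dtype=torch.int64)\n",
- "```"
- ]
- },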
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "#### DataModule"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "To coordinate the creation of training, validation and testing datasets from your data, we need to use a `datamodule` class. To do this we can directly use or extend the ```ESM2FineTuneDataModule``` class (located at ```bionemo.esm2.model.finetune.datamodule.ESM2FineTuneDataModule```) which defines helpful abstract methods that use your dataset class.\n",
- "\n",
- "```python\n",
- "dataset = InMemorySingleValueDataset.from_csv(data_path)\n",
- "data_module = ESM2FineTuneDataModule(\n",
- " train_dataset=dataset,\n",
- " valid_dataset=dataset\n",
- " micro_batch_size=4, # size of a batch to be processed in a device\n",
- " global_batch_size=8, # size of batch across all devices. Should be multiple of micro_batch_size\n",
- ")\n",
- "```"
- ]
- },
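- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The relationship between the two batch-size arguments can be summarized with a little arithmetic (a sketch with assumed numbers, shown for a single data-parallel device):\n",
- "\n",
- "```python\n",
- "micro_batch_size = 4    # samples processed per device per forward/backward pass\n",
- "global_batch_size = 8   # samples consumed per optimizer step across all devices\n",
- "num_devices = 1         # data-parallel world size assumed here\n",
- "\n",
- "# gradient-accumulation steps needed to reach the global batch size\n",
- "grad_accum_steps = global_batch_size // (micro_batch_size * num_devices)  # -> 2\n",
- "```"
- ]
- },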
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "In the next part of this tutorial we will prepare the input needed to run regression and per-token-classification fine-tuning examples."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Setup and Assumptions"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "All commands should be executed inside the BioNeMo docker container, which has all ESM-2 dependencies pre-installed. For more information on how to build or pull the BioNeMo2 container, refer to the [Initialization Guide](../../getting-started/initialization-guide.md)."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- " NOTE Some of the cells below generate long text output. We're using
%%capture --no-display --no-stderr cell_output
to suppress this output. Comment or delete this line in the cells below to restore full output.
"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Import Required Libraries"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "%%capture --no-display --no-stderr cell_output\n",
- "\n",
- "import os\n",
- "import shutil\n",
- "import pandas as pd\n",
- "\n",
- "import warnings\n",
- "warnings.filterwarnings('ignore')\n",
- "warnings.simplefilter('ignore')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Work Directory"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Set the work directory to store data and results:\n",
- "\n",
- " NOTE We set the following to clean up the work directory created by this notebook
cleanup : bool = True
"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "cleanup : bool = True"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Directory '/workspace/bionemo2/esm2_finetune_tutorial' created.\n"
- ]
- }
- ],
- "source": [
- "work_dir=\"/workspace/bionemo2/esm2_finetune_tutorial\"\n",
- "\n",
- "if cleanup and os.path.exists(work_dir):\n",
- " shutil.rmtree(work_dir)\n",
- "\n",
- "if not os.path.exists(work_dir):\n",
- " os.makedirs(work_dir)\n",
- " print(f\"Directory '{work_dir}' created.\")\n",
- "else:\n",
- " print(f\"Directory '{work_dir}' already exists.\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Download Pre-trained Model Checkpoints"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The following code will download the internally pre-trained model, `esm2/8m:2.0`, from the NGC registry. Please refer to [ESM-2 Model Overview](../../../models/ESM-2/index.md) for a list of available checkpoints."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "/home/ubuntu/.cache/bionemo/b4ea4d52eea8a25d2c2838617ff678f0da22d384cee195b0c192686816078dcd-esm2_8m_checkpoint.tar.gz.untar\n"
- ]
- }
- ],
- "source": [
- "from bionemo.core.data.load import load\n",
- "\n",
- "checkpoint_path = load(\"esm2/8m:2.0\")\n",
- "print(checkpoint_path)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The above example is downloading an internally trained 8M ESM-2 model. The pre-trained checkpoints can be downloaded from NGC resources using either the following bash command or the `load` function in `bionemo.core.data.load` as shown above.\n",
- "\n",
- "```bash\n",
- "download_bionemo_data esm2/650m:2.0\n",
- "```\n",
- "\n",
- "which returns the checkpoint path (e.g. `.../.cache/bionemo/975d29ee980fcb08c97401bbdfdcf8ce-esm2_650M_nemo2.tar.gz.untar`)\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Fine-tuning"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can take advantage of the ESM2 fine-tuning script in ```bionemo.esm2.scripts.finetune_esm2``` or use the ```finetune_esm2``` executable the fine-tuning process given:\n",
- "\n",
- "- Pre-trained checkpoint of ESM2\n",
- "- Finetune config class name that configures the finetune model and loss reduction\n",
- "- Path to train and validation CSV data files\n",
- "- Dataset class name\n",
- "\n",
- "To get the full list of arguments to tune a finetuning run use:\n",
- "```bash\n",
- "finetune_esm2 --help \n",
- "```\n",
- "For a detailed description of training loop and the arguments please refer to the [ESM-2 Pretraining](./pretrain.md) tutorial."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- " NOTE\n",
- "\n",
- "Due to Megatron limitations, the log produced by the training run iterates on steps/iterations and not epochs. Therefore, `Training epoch` counter stays at value zero while `iteration` and `global_step` increase during the course of training (example in the following).\n",
- "\n",
- "
\n",
- "Training epoch 0, iteration | ... | global_step: | reduced_train_loss: ... | val_loss: ...\n",
- "
\n",
- "\n",
- "to achieve the same epoch-based effect while training, please choose the number of training steps (`num_steps`) so that:\n",
- "\n",
- "
\n",
- "num_steps * global_batch_size = len(dataset) * desired_num_epochs\n",
- "
"
- ]
- },
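- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "As a quick worked example with the numbers assumed in this tutorial (10 training sequences, an effective global batch size of 2), 10 epochs corresponds to 50 steps:\n",
- "\n",
- "```python\n",
- "dataset_len = 10         # sequences in the sample training CSV created below\n",
- "global_batch_size = 2    # assumed effective batch size per optimizer step\n",
- "desired_num_epochs = 10\n",
- "\n",
- "num_steps = dataset_len * desired_num_epochs // global_batch_size\n",
- "print(num_steps)  # 50, matching the --num-steps value used in this tutorial\n",
- "```"
- ]
- },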
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Regression"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "For the purposes of this demo, we'll assume dataset consists of small set of protein sequences with a target value of `len(sequence) / 100.0` as their labels:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "import pandas as pd\n",
- "\n",
- "artificial_sequence_data = [\n",
- " \"TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI\",\n",
- " \"LYSGDHSTQGARFLRDLAENTGRAEYELLSLF\",\n",
- " \"GRFNVWLGGNESKIRQVLKAVKEIGVSPTLFAVYEKN\",\n",
- " \"DELTALGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF\",\n",
- " \"KLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI\",\n",
- " \"LFGAIGNAISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP\",\n",
- " \"LGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF\",\n",
- " \"LYSGDHSTQGARFLRDLAENTGRAEYELLSLF\",\n",
- " \"ISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP\",\n",
- " \"SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT\",\n",
- "]\n",
- "\n",
- "data = [(seq, len(seq)/100.0) for seq in artificial_sequence_data]\n",
- "\n",
- "# Create a DataFrame\n",
- "df = pd.DataFrame(data, columns=[\"sequences\", \"labels\"])\n",
- "\n",
- "# Save the DataFrame to a CSV file\n",
- "data_path = os.path.join(work_dir, \"regression_data.csv\")\n",
- "df.to_csv(data_path, index=False)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "%%capture --no-display --no-stderr cell_output\n",
- "\n",
- "! finetune_esm2 \\\n",
- " --restore-from-checkpoint-path {checkpoint_path} \\\n",
- " --train-data-path {data_path} \\\n",
- " --valid-data-path {data_path} \\\n",
- " --config-class ESM2FineTuneSeqConfig \\\n",
- " --dataset-class InMemorySingleValueDataset \\\n",
- " --experiment-name \"regression\" \\\n",
- " --num-steps 50 \\\n",
- " --num-gpus 1 \\\n",
- " --val-check-interval 10 \\\n",
- " --log-every-n-steps 10 \\\n",
- " --lr 5e-3 \\\n",
- " --result-dir {work_dir} \\\n",
- " --micro-batch-size 2 \\\n",
- " --num-gpus 1 \\\n",
- " --precision \"bf16-mixed\"\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The previous cell executes the finetuning and saves the checkpoints at the end of the run. The checkpoint path is logged at the end of the finetuning log file: \n",
- "\n",
- "```\n",
- "[NeMo I $$$$-$$-$$ 22:04:28 nemo_logging:393] Async checkpoint save for step 50 (/workspace/bionemo2/esm2_finetune_tutorial/regression/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0-last-v1.ckpt) finalized successfully.\n",
- "```\n",
- "\n",
- "To avoid long text output from the previous cell, the log is captured and stored into the `cell_output` variable. To visualize the log file uncomment and execute the next cell:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [],
- "source": [
- "# print(cell_output.stdout)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can now use the checkpoint stored in the previous step to run inference. We will drop the `.ckpt` from the checkpoint path and provide that to the `--checkpoint-path` argument of `infer_esm2` executable.\n",
- "\n",
- "The input `--data-path` for inference is a CSV file with `sequences` column. It is also required to provide the appropriate `--config-class` name to load the model from the checkpoint. For a detailed description of inference arguments please refer to the [ESM-2 Inference](./inference.ipynb) tutorial."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Create a DataFrame\n",
- "df = pd.DataFrame(artificial_sequence_data, columns=[\"sequences\"])\n",
- "\n",
- "# Save the DataFrame to a CSV file\n",
- "data_path = os.path.join(work_dir, \"sequences.csv\")\n",
- "df.to_csv(data_path, index=False)\n",
- "\n",
- "checkpoint_path = f\"{work_dir}/regression/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0-last\"\n",
- "results_path = f\"{work_dir}/regression/infer/\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [],
- "source": [
- "%%capture --no-display --no-stderr cell_output\n",
- "\n",
- "! infer_esm2 --checkpoint-path {checkpoint_path} \\\n",
- " --config-class ESM2FineTuneSeqConfig \\\n",
- " --data-path {data_path} \\\n",
- " --results-path {results_path} \\\n",
- " --micro-batch-size 3 \\\n",
- " --num-gpus 1 \\\n",
- " --precision \"bf16-mixed\" \\\n",
- " --include-embeddings \\\n",
- " --include-input-ids"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The inference results are written into a `.pt` file which can be loaded using PyTorch library:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "input_ids\ttorch.Size([10, 1024])\n",
- "embeddings\ttorch.Size([10, 320])\n",
- "regression_output\ttorch.Size([10, 1])\n"
- ]
- }
- ],
- "source": [
- "import torch\n",
- "results = torch.load(f\"{results_path}/predictions__rank_0.pt\")\n",
- "\n",
- "for key, val in results.items():\n",
- " if val is not None:\n",
- " print(f'{key}\\t{val.shape}')"
- ]
- },
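- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Since each target was simply `len(sequence) / 100.0`, we can sanity-check the predictions against those values. A small sketch, assuming the prediction order matches the order of `artificial_sequence_data` in the input CSV:\n",
- "\n",
- "```python\n",
- "# Compare fine-tuned regression outputs with the known targets.\n",
- "predictions = results[\"regression_output\"].float().squeeze(-1)  # [10]\n",
- "targets = [len(seq) / 100.0 for seq in artificial_sequence_data]\n",
- "\n",
- "for pred, target in zip(predictions.tolist(), targets):\n",
- "    print(f\"predicted: {pred:.3f}\\ttarget: {target:.3f}\")\n",
- "```"
- ]
- },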
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Toke-level Classification data"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "For this task we assign secondary structure label to each token in the sequence:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [],
- "source": [
- "secondary_structure_labels = [\n",
- " \"EEEECCCCCHHHHHHHHHHHHHHHCCCEEEEEECCCHHHHHHHHHCCCCCCCCCEEE\",\n",
- " \"CCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC\",\n",
- " \"HHHHHCCCCCHHHHHHHHHHHHHHCCCHHHHHHHHHH\",\n",
- " \"HHHHHHHHHHCCCHHHHHCCCCCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC\",\n",
- " \"CHHHHHHHHHHHHHHHCCCEEEEEECCCHHHHHHHHHCCCCCCCCCEEE\",\n",
- " \"HHHHHHHHHHHHHCHHHHHHHHHHHHCCCEECCCEEEECCEEEEECC\",\n",
- " \"HHHHHCCCHHHHHCCCCCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC\",\n",
- " \"CCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC\",\n",
- " \"HHHHHCHHHHHHHHHHHHCCCEECCCEEEECCEEEEECC\",\n",
- " \"CCCCCCCCCCCCCCCCCCCCCCCCCCEEECCCCEEECHHHHHHHHHCCCCCCCCEEECCCCCC\",\n",
- "]\n",
- "\n",
- "data = [(seq, label) for (seq, label) in zip(artificial_sequence_data, secondary_structure_labels)]\n",
- "\n",
- "# Create a DataFrame\n",
- "df = pd.DataFrame(data, columns=[\"sequences\", \"labels\"])\n",
- "\n",
- "# Save the DataFrame to a CSV file\n",
- "data_path = os.path.join(work_dir, \"token_classification_data.csv\")\n",
- "df.to_csv(data_path, index=False)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [],
- "source": [
- "%%capture --no-display --no-stderr cell_output\n",
- "\n",
- "! finetune_esm2 \\\n",
- " --restore-from-checkpoint-path {checkpoint_path} \\\n",
- " --train-data-path {data_path} \\\n",
- " --valid-data-path {data_path} \\\n",
- " --config-class ESM2FineTuneTokenConfig \\\n",
- " --dataset-class InMemoryPerTokenValueDataset \\\n",
- " --experiment-name \"token_level_classification\" \\\n",
- " --num-steps 50 \\\n",
- " --num-gpus 1 \\\n",
- " --val-check-interval 10 \\\n",
- " --log-every-n-steps 10 \\\n",
- " --lr 5e-3 \\\n",
- " --result-dir {work_dir} \\\n",
- " --micro-batch-size 2 \\\n",
- " --num-gpus 1 \\\n",
- " --precision \"bf16-mixed\"\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The previous cell executes the finetuning and saves the checkpoints at the end of the run. The checkpoint path is logged at the end of the finetuning log file: \n",
- "\n",
- "```\n",
- "[NeMo I $$$$-$$-$$ 22:16:46 nemo_logging:393] Async checkpoint save for step 50 (/workspace/bionemo2/esm2_finetune_tutorial/token_level_classification/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0-last.ckpt) finalized successfully.\n",
- "```\n",
- "\n",
- "To avoid long text output from the previous cell, the log is captured and stored into the `cell_output` variable. To visualize the log file uncomment and execute the next cell:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [],
- "source": [
- "# print(cell_output.stdout)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can now use the checkpoint stored in the previous step to run inference. We will drop the `.ckpt` from the checkpoint path and provide that to the `--checkpoint-path` argument of `infer_esm2` executable.\n",
- "\n",
- "The input `--data-path` for inference is a CSV file with `sequences` column. It is also required to provide the appropriate `--config-class` name to load the model from the checkpoint. For a detailed description of inference arguments please refer to the [ESM-2 Inference](./inference.ipynb) tutorial."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Create a DataFrame\n",
- "df = pd.DataFrame(artificial_sequence_data, columns=[\"sequences\"])\n",
- "\n",
- "# Save the DataFrame to a CSV file\n",
- "data_path = os.path.join(work_dir, \"sequences.csv\")\n",
- "df.to_csv(data_path, index=False)\n",
- "\n",
- "checkpoint_path = f\"{work_dir}/token_level_classification/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0-last\"\n",
- "results_path = f\"{work_dir}/token_level_classification/infer/\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {},
- "outputs": [],
- "source": [
- "%%capture --no-display --no-stderr cell_output\n",
- "\n",
- "! infer_esm2 --checkpoint-path {checkpoint_path} \\\n",
- " --config-class ESM2FineTuneTokenConfig \\\n",
- " --data-path {data_path} \\\n",
- " --results-path {results_path} \\\n",
- " --micro-batch-size 3 \\\n",
- " --num-gpus 1 \\\n",
- " --precision \"bf16-mixed\" \\\n",
- " --include-embeddings \\\n",
- " --include-hiddens \\\n",
- " --include-input-ids"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The inference results are written into a `.pt` file which can be loaded using PyTorch library:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "hidden_states\ttorch.Size([10, 1024, 320])\n",
- "input_ids\ttorch.Size([10, 1024])\n",
- "embeddings\ttorch.Size([10, 320])\n",
- "classification_output\ttorch.Size([10, 1024, 3])\n"
- ]
- }
- ],
- "source": [
- "import torch\n",
- "results = torch.load(f\"{results_path}/predictions__rank_0.pt\")\n",
- "\n",
- "for key, val in results.items():\n",
- " if val is not None:\n",
- " print(f'{key}\\t{val.shape}')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can use the label tokenizer to convert the classification output to class names. Note that for demonstration purposes we are using a small dataset of artificial sequences in this example. You may experience over-fitting and observe no change in the validation metrics. This amount of data and the short training run does not result in accurate predictions."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [],
- "source": [
- "from bionemo.esm2.data.tokenizer import get_tokenizer\n",
- "\n",
- "\n",
- "tokenizer = get_tokenizer()\n",
- "tokens = tokenizer.all_tokens\n",
- "aa_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']\n",
- "aa_indices = [i for i, token in enumerate(tokens) if token in aa_tokens]\n",
- "extra_indices = [i for i, token in enumerate(tokens) if token not in aa_tokens]\n",
- "\n",
- "input_ids = results['input_ids'] # b, s\n",
- "# mask where non-amino acid tokens are True\n",
- "mask = ~torch.isin(input_ids, torch.tensor(extra_indices))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [],
- "source": [
- "from bionemo.llm.data.label2id_tokenizer import Label2IDTokenizer\n",
- "\n",
- "label_tokenizer = Label2IDTokenizer()\n",
- "label_tokenizer = label_tokenizer.build_vocab(secondary_structure_labels)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Predicted Secondary Structures:\n",
- "HHHHEEEEECCCCCCCCCCCCCCCEEEHHHHHHEEECCCCCCCCCEEEEEEEEEHHH\n",
- "EEEEECCCCCCCCCCCCCCEEEEECCCCCCEE\n",
- "CCCCCEEEEECCCCCCCCCCCCCEEEECCCCCCCCCC\n",
- "CCCCCCCCCCEEECCCCCEEEEEEEECCCCCCCCCCCCCCEEEEECCCCCCEE\n",
- "ECCCCCCCCCCCCCCCEEEHHHHHHEEECCCCCCCCCEEEEEEEEEHHH\n",
- "CCCCCCCCCCCCCECCCCCCCCCCCCEEEHHEEEHHHHEEHHHHHEE\n",
- "CCCCCEEECCCCCEEEEEEEECCCCCCCCCCCCCCEEEEECCCCCCEE\n",
- "EEEEECCCCCCCCCCCCCCEEEEECCCCCCEE\n",
- "CCCCCECCCCCCCCCCCCEEEHHEEEHHHHEEHHHHHEE\n",
- "EEEEEEEEEEEEEEEEEEEEEEEEEEHHHEEEEHHHECCCCCCCCCEEEEEEEEHHHEEEEEE\n"
- ]
- }
- ],
- "source": [
- "output_ids = torch.argmax(results[\"classification_output\"], dim=-1)\n",
- "\n",
- "print(\"Predicted Secondary Structures:\")\n",
- "for i in range(output_ids.shape[0]):\n",
- " ss_ids = output_ids[i][mask[i]]\n",
- " print(label_tokenizer.ids_to_text(ss_ids.tolist()))"
- ]
- },
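- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "To put a rough number on these predictions, we can compare them against the `secondary_structure_labels` used for training. A minimal sketch, assuming the predicted string and the true label string align token by token for each sequence:\n",
- "\n",
- "```python\n",
- "# Per-token accuracy of the predicted secondary structures (sanity check only).\n",
- "correct, total = 0, 0\n",
- "for i, true_labels in enumerate(secondary_structure_labels):\n",
- "    ss_ids = output_ids[i][mask[i]]\n",
- "    predicted = label_tokenizer.ids_to_text(ss_ids.tolist())\n",
- "    correct += sum(p == t for p, t in zip(predicted, true_labels))\n",
- "    total += len(true_labels)\n",
- "\n",
- "print(f\"Per-token accuracy: {correct / total:.2%}\")\n",
- "```"
- ]
- },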
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}