diff --git a/.gitignore b/.gitignore index 5fe5a0ab48..0f191692fd 100644 --- a/.gitignore +++ b/.gitignore @@ -194,3 +194,4 @@ coverage.xml Thumbs.db .python_history +test_experiment/* diff --git a/sub-packages/bionemo-amplify/LICENSE b/sub-packages/bionemo-amplify/LICENSE new file mode 120000 index 0000000000..61bc2cda7e --- /dev/null +++ b/sub-packages/bionemo-amplify/LICENSE @@ -0,0 +1 @@ +../../LICENSE/license.txt \ No newline at end of file diff --git a/sub-packages/bionemo-amplify/README.md b/sub-packages/bionemo-amplify/README.md new file mode 100644 index 0000000000..debd96bb63 --- /dev/null +++ b/sub-packages/bionemo-amplify/README.md @@ -0,0 +1,13 @@ +# bionemo-amplify + + +### Setup +To install, execute the following: +```bash +pip install -e . +``` + +To run unit tests, execute: +```bash +pytest -v . +``` diff --git a/sub-packages/bionemo-amplify/VERSION b/sub-packages/bionemo-amplify/VERSION new file mode 120000 index 0000000000..558194c5a5 --- /dev/null +++ b/sub-packages/bionemo-amplify/VERSION @@ -0,0 +1 @@ +../../VERSION \ No newline at end of file diff --git a/sub-packages/bionemo-amplify/pyproject.toml b/sub-packages/bionemo-amplify/pyproject.toml new file mode 100644 index 0000000000..37998a798d --- /dev/null +++ b/sub-packages/bionemo-amplify/pyproject.toml @@ -0,0 +1,38 @@ +[build-system] +requires = ["setuptools>=64", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "bionemo-amplify" +readme = "README.md" +description = "BioNeMo AMPLIFY" +authors = [{ name = "BioNeMo Team", email = "bionemofeedback@nvidia.com" }] +requires-python = ">=3.10" +license = { file = "LICENSE" } +dynamic = ["version"] +dependencies = [ + # bionemo sub-packages + 'bionemo-core', + 'bionemo-esm2', + 'bionemo-llm', + # external +] + +[project.scripts] +train_amplify = "bionemo.amplify.scripts.train_amplify:train_amplify_entrypoint" + +# Make sure that the tokenizer files are included along with the python files during installation. +[tool.setuptools.package-data] +"bionemo.amplify" = ["data/tokenizer/*.json", "data/tokenizer/*.txt"] + +[tool.setuptools.packages.find] +where = ["src"] +include = ["bionemo.*"] +namespaces = true +exclude = ["test*."] + +[tool.setuptools.dynamic] +version = { file = "VERSION" } + +[tool.uv] +cache-keys = [{ git = true }] diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/__init__.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/__init__.py new file mode 100644 index 0000000000..25e6abfbc5 --- /dev/null +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
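As a quick check of the packaging above, the `train_amplify` console script declared under `[project.scripts]` should be on the PATH after an editable install. A minimal invocation sketch using only flags defined in `train_amplify.py` further down in this diff (the dataset name is the datamodule's default; the remaining values are illustrative, not a tuned recipe):

```bash
pip install -e sub-packages/bionemo-amplify

train_amplify \
  --hf-dataset-name chandar-lab/UR100P \
  --precision bf16-mixed \
  --lr 1e-3 \
  --result-dir ./results \
  --resume-if-exists
```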
diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/api.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/api.py new file mode 100644 index 0000000000..1a481b94e5 --- /dev/null +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/api.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Sequence + +from bionemo.amplify.model.model import AMPLIFYConfig, AMPLIFYModel + + +__all__: Sequence[str] = ( + "AMPLIFYConfig", + "AMPLIFYModel", +) diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/data/__init__.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/data/__init__.py new file mode 100644 index 0000000000..25e6abfbc5 --- /dev/null +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/data/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/data/datamodule.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/data/datamodule.py new file mode 100644 index 0000000000..71ba0a99a0 --- /dev/null +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/data/datamodule.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import functools +from typing import Literal + +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from nemo.lightning.data import WrappedDataLoader +from nemo.lightning.pytorch.plugins import MegatronDataSampler +from nemo.utils import logging + +from bionemo.core.data.multi_epoch_dataset import MultiEpochDatasetResampler +from bionemo.amplify.data import dataset, tokenizer +from bionemo.llm.data import collate +from bionemo.llm.data.datamodule import MegatronDataModule +from bionemo.llm.utils.datamodule_utils import infer_num_samples + + +Mode = Literal["train", "validation", "test"] + + +class AMPLIFYDataModule(MegatronDataModule): + """LightningDataModule wrapper of `AMPLIFYDataset`.""" + def __init__( + self, + hf_dataset_name: str = "chandar-lab/UR100P", + seed: int | None = 42, + min_seq_length: int | None = None, + max_seq_length: int = 512, + micro_batch_size: int = 512, + global_batch_size: int = 4096, + num_workers: int = 10, # TODO(@jomitchell) can this be automatically set? + persistent_workers: bool = True, + pin_memory: bool = True, + rampup_batch_size: list[int] | None = None, + mask_prob: float = 0.15, + mask_token_prob: float = 0.8, + mask_random_prob: float = 0.1, + random_mask_strategy: dataset.RandomMaskStrategy = dataset.RandomMaskStrategy.AMINO_ACIDS_ONLY, + tokenizer: tokenizer.BioNeMoAMPLIFYTokenizer = tokenizer.get_tokenizer(), + dataloader_type: Literal["single", "cyclic"] = "single", + ) -> None: + """Initialize the AMPLIFYDataModule. + + Args: + hf_dataset_name: The name of the HuggingFace dataset. Defaults to "chandar-lab/UR100P". + seed: Input random seed. If None, initializes randomly. Defaults to 42. + min_seq_length: Whether to pad sequences to a minimum length. If None, no extra padding is added. Defaults + to None. + max_seq_length: The maximum context length for the AMPLIFY transformer. Defaults to 512. + micro_batch_size: Passed to MegatronDataSampler. Defaults to 512. + global_batch_size: Passed to MegatronDataSampler. Defaults to 4096. + num_workers: The number of workers for the pytorch Dataloaders. Defaults to 10. + persistent_workers: Whether to keep the workers alive between epochs. Defaults to True. + pin_memory: Whether to pin GPU memory in the pytorch Dataloaders. Defaults to True. + rampup_batch_size: Passed to MegatronDataSampler. Defaults to None. + mask_prob: The overall chance of masking a token and having it appear in the loss fn. Defaults to 0.15. + mask_token_prob: Percentage of masked tokens that get assigned the id. Defaults to 0.8. + mask_random_prob: Percentage of masked tokens assigned to a random amino acid. Defaults to 0.1. + random_mask_strategy: Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.AMINO_ACIDS_ONLY. + tokenizer: The AMPLIFY tokenizer. Defaults to the one returned by `tokenizer.get_tokenizer()`. + dataloader_type: The type of dataloader to use. Defaults to "single". 
+ """ + super().__init__() + self._hf_dataset_name = hf_dataset_name + self._seed = seed + self._min_seq_length = min_seq_length + self._max_seq_length = max_seq_length + self._mask_prob = mask_prob + self._mask_token_prob = mask_token_prob + self._mask_random_prob = mask_random_prob + self._random_mask_strategy = random_mask_strategy + self._tokenizer = tokenizer + + self._micro_batch_size = micro_batch_size + self._num_workers = num_workers + self._persistent_workers = persistent_workers + self._pin_memory = pin_memory + + self.data_sampler = MegatronDataSampler( + seq_len=max_seq_length, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + dataloader_type=dataloader_type, # `MegatronPretrainingRandomSampler` from "cyclic" is failing. + rampup_batch_size=rampup_batch_size, + ) + + @property + def tokenizer(self) -> tokenizer.BioNeMoAMPLIFYTokenizer: + """Returns the tokenizer.""" + return self._tokenizer + + def setup(self, stage: str = "") -> None: + """Setup the AMPLIFYDataModule. + + Args: + stage: Unused. + + Raises: + RuntimeError: If the trainer is not attached, or if the trainer's max_steps is not set. + """ + del stage # Unused. + + if not hasattr(self, "trainer") or self.trainer is None: + raise RuntimeError("Setup should be completed when trainer and config are attached.") + + if self.trainer.max_epochs is not None and self.trainer.max_epochs > 1: + logging.warning( + "Trainer is set to run for multiple epochs. This is not recommended due to the same shuffle being used " + "in each. Instead set max_epochs to 1 and increase the number of max_steps." + ) + + max_train_steps = self.trainer.max_steps + if max_train_steps <= 0: + raise RuntimeError("Please specify trainer.max_steps") + + # Create training dataset + num_train_samples = int( + max_train_steps * self.data_sampler.global_batch_size + ) # training data requires upsampling (multiply by max_train_steps) on single MegatronPretrainingRandomSampler + _train_ds = dataset.AMPLIFYMaskedResidueDataset(hf_dataset_name=self._hf_dataset_name, + dataset_subset=None, + split="train", + seed=self._seed, + max_seq_length=self._max_seq_length, + mask_prob=self._mask_prob, + mask_token_prob=self._mask_token_prob, + mask_random_prob=self._mask_random_prob, + random_mask_strategy=self._random_mask_strategy, + tokenizer=self._tokenizer) + self._train_ds = MultiEpochDatasetResampler(_train_ds, num_samples=num_train_samples, shuffle=True, seed=self._seed) + + # Create validation dataset + _valid_ds = dataset.AMPLIFYMaskedResidueDataset(hf_dataset_name=self._hf_dataset_name, + dataset_subset="UniProt", + split="test", + seed=self._seed, + max_seq_length=self._max_seq_length, + mask_prob=self._mask_prob, + mask_token_prob=self._mask_token_prob, + mask_random_prob=self._mask_random_prob, + random_mask_strategy=self._random_mask_strategy, + tokenizer=self._tokenizer) + num_val_samples = infer_num_samples(limit_batches=self.trainer.limit_val_batches, + num_samples_in_dataset=len(_valid_ds), + global_batch_size=self.data_sampler.global_batch_size, + stage="val") + self._valid_ds = MultiEpochDatasetResampler(_valid_ds, num_samples=num_val_samples, shuffle=False, seed=self._seed) + + assert ( + hasattr(self, "trainer") and self.trainer is not None + ), "Setup should be completed when trainer and config are attached." + + def _create_dataloader(self, dataset, mode: Mode, **kwargs) -> WrappedDataLoader: + """Create dataloader for train, validation, and test stages. + + Args: + dataset: The dataset to create the dataloader for. 
+ mode: Stage of training, which is used to determined if consumed_samples in MegatronPretrainingSampler should be initialized to 0 (validation/test), or be set to the previous value from state_dict in case of checkpoint resumption (train). + **kwargs: Additional arguments to pass to the dataloader. + """ + self.update_init_global_step() + assert self._tokenizer.pad_token_id is not None, "Tokenizer must have a pad token id." + + return WrappedDataLoader( + mode=mode, + dataset=dataset, + num_workers=self._num_workers, + pin_memory=self._pin_memory, + persistent_workers=self._persistent_workers, + collate_fn=functools.partial( + collate.bert_padding_collate_fn, + padding_value=self._tokenizer.pad_token_id, + min_length=self._min_seq_length, + max_length=self._max_seq_length, + ), + **kwargs, + ) + + def train_dataloader(self) -> TRAIN_DATALOADERS: + """Returns the dataloader for training data.""" + return self._create_dataloader(self._train_ds, mode="train") + + def val_dataloader(self) -> EVAL_DATALOADERS: + """Returns the dataloader for validation data.""" + return self._create_dataloader(self._valid_ds, mode="validation") + + def test_dataloader(self) -> EVAL_DATALOADERS: + """Raises a not implemented error.""" + raise NotImplementedError("No test dataset provided for AMPLIFY") diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/data/dataset.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/data/dataset.py new file mode 100644 index 0000000000..0f7da3bdee --- /dev/null +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/data/dataset.py @@ -0,0 +1,138 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Literal + +import numpy as np +import torch +from torch.utils.data import Dataset + +from datasets import load_dataset as hf_load_dataset + +from bionemo.core.data.multi_epoch_dataset import EpochIndex +from bionemo.core.utils import random_utils +from bionemo.amplify.data import tokenizer +from bionemo.llm.data import masking +from bionemo.llm.data.types import BertSample +from bionemo.esm2.data.dataset import RandomMaskStrategy, _random_crop + +class AMPLIFYMaskedResidueDataset(Dataset): + """Dataset class for AMPLIFY pretraining that implements sampling of UR100P sequences. + """ + + def __init__( + self, + hf_dataset_name: str, + dataset_subset: str = None, + split: Literal["train", "test"] = "train", + seed: int = np.random.SeedSequence().entropy, # type: ignore + max_seq_length: int = 512, + mask_prob: float = 0.15, + mask_token_prob: float = 0.8, + mask_random_prob: float = 0.1, + random_mask_strategy: RandomMaskStrategy = RandomMaskStrategy.AMINO_ACIDS_ONLY, + tokenizer: tokenizer.BioNeMoAMPLIFYTokenizer = tokenizer.get_tokenizer(), + ) -> None: + """Initializes the dataset. + + Args: + hf_dataset_name: Name of the HuggingFace dataset containing UR100P protein sequences. 
+            dataset_subset: Name of the UR100P HuggingFace dataset subset (or relative data_dir).
+            split: The split of the dataset to use ["train", "test"]. Defaults to "train".
+            seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
+                that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
+                generated.
+            max_seq_length: Crop long sequences to a maximum of this length, including BOS and EOS tokens.
+            mask_prob: The overall probability a token is included in the loss function. Defaults to 0.15.
+            mask_token_prob: Proportion of masked tokens that get assigned the mask token id. Defaults to 0.8.
+            mask_random_prob: Proportion of tokens that get assigned a random natural amino acid. Defaults to 0.1.
+            random_mask_strategy: Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.AMINO_ACIDS_ONLY.
+            tokenizer: The input AMPLIFY tokenizer. Defaults to the standard AMPLIFY tokenizer.
+        """
+        self.protein_dataset = hf_load_dataset(hf_dataset_name, data_dir=dataset_subset, split=split)
+        self.total_samples = len(self.protein_dataset)
+        self.seed = seed
+        self.max_seq_length = max_seq_length
+        self.random_mask_strategy = random_mask_strategy
+
+        if tokenizer.mask_token_id is None:
+            raise ValueError("Tokenizer does not have a mask token.")
+
+        self.mask_config = masking.BertMaskConfig(
+            tokenizer=tokenizer,
+            random_tokens=range(tokenizer.vocab_size)
+            if self.random_mask_strategy == RandomMaskStrategy.ALL_TOKENS
+            else range(6, tokenizer.vocab_size),
+            mask_prob=mask_prob,
+            mask_token_prob=mask_token_prob,
+            random_token_prob=mask_random_prob,
+        )
+
+        self.tokenizer = tokenizer
+
+    def __len__(self) -> int:
+        """Returns the total number of sequences in the dataset."""
+        return self.total_samples
+
+    def __getitem__(self, index: EpochIndex) -> BertSample:
+        """Deterministically masks and returns a protein sequence from the dataset.
+
+        Args:
+            index: The current epoch and the index of the sequence to sample.
+
+        Returns:
+            A (possibly truncated), masked protein sequence with BOS and EOS tokens and associated mask fields.
+        """
+        # Initialize a random number generator with a seed that is a combination of the dataset seed, epoch, and index.
+        rng = np.random.default_rng([self.seed, index.epoch, index.idx])
+        if index.idx >= len(self):
+            raise IndexError(f"Index {index.idx} out of range [0, {len(self)}).")
+
+        sequence = self.protein_dataset[int(index.idx)]["sequence"]
+
+        # We don't want special tokens before we pass the input to the masking function; we add these in the collate_fn.
+        tokenized_sequence = self._tokenize(sequence)
+        cropped_sequence = _random_crop(tokenized_sequence, self.max_seq_length, rng)
+
+        # Get a single integer seed for torch from our rng, since the index tuple is hard to pass directly to torch.
+ torch_seed = random_utils.get_seed_from_rng(rng) + masked_sequence, labels, loss_mask = masking.apply_bert_pretraining_mask( + tokenized_sequence=cropped_sequence, # type: ignore + random_seed=torch_seed, + mask_config=self.mask_config, + ) + + return { + "text": masked_sequence, + "types": torch.zeros_like(masked_sequence, dtype=torch.int64), + "attention_mask": torch.ones_like(masked_sequence, dtype=torch.int64), + "labels": labels, + "loss_mask": loss_mask, + "is_random": torch.zeros_like(masked_sequence, dtype=torch.int64), + } + + def _tokenize(self, sequence: str) -> torch.Tensor: + """Tokenize a protein sequence. + + Args: + sequence: The protein sequence. + + Returns: + The tokenized sequence. + """ + tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt") + return tensor.flatten() # type: ignore diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/data/tokenizer/README.md b/sub-packages/bionemo-amplify/src/bionemo/amplify/data/tokenizer/README.md new file mode 100644 index 0000000000..126d80f048 --- /dev/null +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/data/tokenizer/README.md @@ -0,0 +1,10 @@ +# Vendored tokenizer config for chandar-lab/AMPLIFY_350M + +This directory contains the output of + +```python +from transformers import AutoTokenizer +AutoTokenizer.from_pretrained("chandar-lab/AMPLIFY_350M").save_pretrained(".") +``` + +for reproducible results and to reduce reliance on external API calls. diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/data/tokenizer/__init__.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/data/tokenizer/__init__.py new file mode 100644 index 0000000000..2b00b8bd27 --- /dev/null +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/data/tokenizer/__init__.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from importlib.resources import files + +from transformers import AutoTokenizer +from nemo.lightning.io import IOMixin + + +class BioNeMoAMPLIFYTokenizer(AutoTokenizer, IOMixin): # noqa D101 + def __init__(self): + """A wrapper to make AutoTokenizer serializable. 
+ """ + other = AutoTokenizer.from_pretrained(str(files("bionemo.amplify.data.tokenizer")), use_fast=True) + for attr in dir(other): + if not attr.startswith("_"): + setattr(self, attr, getattr(other, attr)) + #In case PreTrainedTokenizer is inherited and special token IDs are not in dir + if hasattr(other, "mask_token_id"): + setattr(self, "mask_token_id", getattr(other, "mask_token_id")) + if hasattr(other, "pad_token_id"): + setattr(self, "pad_token_id", getattr(other, "pad_token_id")) + + +@functools.cache +def get_tokenizer() -> BioNeMoAMPLIFYTokenizer: + """Get the tokenizer for the AMPLIFY model.""" + return BioNeMoAMPLIFYTokenizer() diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/data/tokenizer/special_tokens_map.json b/sub-packages/bionemo-amplify/src/bionemo/amplify/data/tokenizer/special_tokens_map.json new file mode 100644 index 0000000000..91c92322af --- /dev/null +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/data/tokenizer/special_tokens_map.json @@ -0,0 +1,37 @@ +{ + "bos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "eos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "mask_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "pad_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "unk_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/data/tokenizer/tokenizer.json b/sub-packages/bionemo-amplify/src/bionemo/amplify/data/tokenizer/tokenizer.json new file mode 100644 index 0000000000..3d89b94f15 --- /dev/null +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/data/tokenizer/tokenizer.json @@ -0,0 +1,154 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + { + "id": 0, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 1, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 2, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 3, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 4, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + } + ], + "normalizer": null, + "pre_tokenizer": { + "type": "Split", + "pattern": { + "String": "" + }, + "behavior": "Removed", + "invert": false + }, + "post_processor": { + "type": "TemplateProcessing", + "single": [ + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + } + ], + "pair": [ + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "B", + "type_id": 1 + } + } + ], + "special_tokens": { + "": { + "id": "", + "ids": [ + 3 + ], + "tokens": [ + "" + ] + }, + "": { + "id": "", + "ids": [ + 4 + ], + "tokens": [ + "" + ] + } + } + }, + "decoder": null, + "model": { + "type": "WordPiece", + "unk_token": "", + 
"continuing_subword_prefix": "##", + "max_input_chars_per_word": 100, + "vocab": { + "": 0, + "": 1, + "": 2, + "": 3, + "": 4, + "|": 5, + "L": 6, + "A": 7, + "G": 8, + "V": 9, + "S": 10, + "E": 11, + "R": 12, + "T": 13, + "I": 14, + "D": 15, + "P": 16, + "K": 17, + "Q": 18, + "N": 19, + "F": 20, + "Y": 21, + "M": 22, + "H": 23, + "W": 24, + "C": 25, + "B": 26 + } + } +} \ No newline at end of file diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/data/tokenizer/tokenizer_config.json b/sub-packages/bionemo-amplify/src/bionemo/amplify/data/tokenizer/tokenizer_config.json new file mode 100644 index 0000000000..d42c71d862 --- /dev/null +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/data/tokenizer/tokenizer_config.json @@ -0,0 +1,58 @@ +{ + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "3": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "4": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "bos_token": "", + "clean_up_tokenization_spaces": true, + "eos_token": "", + "mask_token": "", + "model_input_names": [ + "input_ids", + "attention_mask" + ], + "model_max_length": 2048, + "pad_token": "", + "padding_side": "right", + "tokenizer_class": "PreTrainedTokenizerFast", + "truncation_side": "right", + "unk_token": "" +} diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/data/tokenizer/vocab.txt b/sub-packages/bionemo-amplify/src/bionemo/amplify/data/tokenizer/vocab.txt new file mode 100644 index 0000000000..dfd20a325f --- /dev/null +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/data/tokenizer/vocab.txt @@ -0,0 +1,27 @@ + + + + + +| +L +A +G +V +S +E +R +T +I +D +P +K +Q +N +F +Y +M +H +W +C +B \ No newline at end of file diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/model/__init__.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/model/__init__.py new file mode 100644 index 0000000000..25e6abfbc5 --- /dev/null +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/model/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/model/model.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/model/model.py new file mode 100644 index 0000000000..2bee47e37e --- /dev/null +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/model/model.py @@ -0,0 +1,362 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from dataclasses import dataclass +from typing import Callable, Literal, Optional, Sequence, Type, TypeVar + +import torch +from torch import Tensor +from torch.nn.functional import silu +from torch.optim import Optimizer + +from bionemo.amplify.data.tokenizer import BioNeMoAMPLIFYTokenizer +from bionemo.esm2.model.attention import ESM2TEDotProductAttention +from bionemo.esm2.model.embedding import ESM2Embedding + +from bionemo.llm.model.biobert.model import BioBertConfig, MegatronBioBertModel, PositionEmbeddingKinds +from bionemo.llm.api import MegatronLossType +from bionemo.llm.utils import iomixin_utils as iom +from bionemo.llm.model.biobert.transformer_specs import BiobertSpecOption + +from megatron.core import tensor_parallel +from megatron.core.models.bert.pooler import Pooler +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from megatron.core.transformer import spec_utils +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.utils import get_linear_layer + + +__all__: Sequence[str] = ( + "AMPLIFYConfig", + "AMPLIFYModel", +) + + +class AMPLIFYLMHead(MegatronModule): + """LM head for AMPLIFY + + Args: + hidden_size: hidden size + config (TransformerConfig): TransformerConfig object + """ + + def __init__(self, config: TransformerConfig): + super().__init__(config=config) + self.head = IdentityOp() + + def forward(self, hidden_states: Tensor) -> Tensor: + return self.head(hidden_states) + +class AMPLIFYModel(MegatronBioBertModel): + """AMPLIFY protein language model.""" + def __init__( + self, + config: TransformerConfig, + num_tokentypes: int, + transformer_layer_spec: spec_utils.ModuleSpec, + vocab_size: int, + max_sequence_length: int, + tokenizer: Optional[BioNeMoAMPLIFYTokenizer] = None, + pre_process: bool = True, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + share_embeddings_and_output_weights: bool = False, + position_embedding_type: Literal["learned_absolute", "rope"] = "rope", + rotary_percent: float = 1.0, + seq_len_interpolation_factor: Optional[float] = None, + add_binary_head: bool = True, + return_embeddings: bool = False, + include_embeddings: bool = False, + 
include_input_ids: bool = False,
+        use_full_attention_mask: bool = False,
+        include_hiddens: bool = False,
+        skip_logits: bool = False,
+    ) -> None:
+        """Initialize the AMPLIFY model.
+
+        Args:
+            config (TransformerConfig): transformer config
+            num_tokentypes (int): Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0.
+            transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
+            vocab_size (int): vocabulary size
+            max_sequence_length (int): maximum size of sequence. This is used for positional embedding
+            tokenizer (AutoTokenizer): optional tokenizer object (currently only used in the constructor of AMPLIFYModel)
+            pre_process (bool): Include embedding layer (used with pipeline parallelism)
+            post_process (bool): Include an output layer (used with pipeline parallelism)
+            fp16_lm_cross_entropy: Whether to move the cross entropy unreduced loss calculation for lm head to fp16.
+            parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks
+            share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are shared. Defaults to False.
+            position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope'].
+                Defaults to 'rope'.
+            rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
+                Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
+            seq_len_interpolation_factor (Optional[float]): Interpolation factor for sequence length. Defaults to None.
+            add_binary_head (bool): Whether to add a binary head. Defaults to True.
+            return_embeddings (bool): Whether to return embeddings. Defaults to False.
+            include_embeddings (bool): Whether to include embeddings in the output dictionary. Defaults to False.
+            include_input_ids (bool): Whether to include input_ids in the output dictionary. Defaults to False.
+            use_full_attention_mask (bool): Whether to use full attention mask. Defaults to False.
+            include_hiddens (bool): Whether to include hidden states in the output dictionary. Defaults to False.
+            skip_logits (bool): Skip writing the token logits to the output dict. Defaults to False.
+        """
+        super(MegatronBioBertModel, self).__init__(config=config)
+        self.post_process = post_process
+        self.add_binary_head = add_binary_head
+        if return_embeddings:
+            assert self.post_process, "only return embeddings on the last pipeline stage"
+        # `b` = batch, `s` = sequence.
+        # The old flash attention mechanism apparently wants you to use a b x 1 x s x s attention mask while
+        # the new one wants a b x 1 x 1 x s attention mask. This is a hack to allow us to switch between the two.
+        self.use_full_attention_mask = use_full_attention_mask
+        self.config: TransformerConfig = config
+        self.transformer_layer_spec: spec_utils.ModuleSpec = transformer_layer_spec
+        self.vocab_size = vocab_size
+        self.max_sequence_length = max_sequence_length
+        self.pre_process = pre_process
+        self.post_process = post_process
+        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
+        self.parallel_output = parallel_output
+        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
+        self.position_embedding_type = position_embedding_type
+        self.add_binary_head = add_binary_head
+        self.return_embeddings = return_embeddings
+        self.include_embeddings = include_embeddings
+        self.include_hiddens = include_hiddens
+        self.include_input_ids = include_input_ids
+        self.skip_logits = skip_logits
+
+        # megatron core pipelining currently depends on model type
+        self.model_type = ModelType.encoder_or_decoder
+
+        if config.gated_linear_unit:
+            # To keep the number of parameters and the amount of computation constant, we reduce the number of
+            # hidden units by a factor of 2/3 (https://arxiv.org/pdf/2002.05202.pdf) and make it a multiple of 8 to
+            # avoid RuntimeError due to misaligned operand
+            multiple_of = 8
+            config.ffn_hidden_size = int(2 * config.ffn_hidden_size / 3)
+            config.ffn_hidden_size = multiple_of * ((config.ffn_hidden_size + multiple_of - 1) // multiple_of)
+            self.config.ffn_hidden_size = config.ffn_hidden_size
+
+        # Embeddings.
+        if self.pre_process:
+            self.register_buffer(
+                "bert_position_id_tensor",
+                torch.arange(max_sequence_length, dtype=torch.long, requires_grad=False).unsqueeze(0),
+                persistent=False,
+            )
+            # ESM2 Customization: ESM2Embedding instead of LanguageModelEmbedding
+            # TODO: call super, overwrite the self.embedding, and setup_embeddings_and_output_layer in constructor.
+            # Note: need to avoid calling setup twice: skip with super (super(skip_setup=True))
+            self.embedding = ESM2Embedding(
+                config=self.config,
+                vocab_size=self.vocab_size,
+                max_sequence_length=self.max_sequence_length,
+                position_embedding_type=position_embedding_type,
+                num_tokentypes=num_tokentypes,
+                # ESM2 NEW ARGS
+                token_dropout=self.config.token_dropout,
+                use_attention_mask=self.config.use_attention_mask,
+                mask_token_id=tokenizer.mask_token_id,
+            )
+
+        if self.position_embedding_type == "rope":
+            self.rotary_pos_emb = RotaryEmbedding(
+                kv_channels=self.config.kv_channels,
+                rotary_percent=rotary_percent,
+                rotary_interleaved=self.config.rotary_interleaved,
+                seq_len_interpolation_factor=seq_len_interpolation_factor,
+            )
+
+        # Transformer.
+        self.encoder = TransformerBlock(
+            config=self.config,
+            spec=self.transformer_layer_spec,
+            pre_process=self.pre_process,
+            post_process=self.post_process,
+        )
+
+        # Output
+        if post_process:
+            # TODO: Make sure you are passing in the mpu_vocab_size properly
+            self.lm_head = AMPLIFYLMHead(config)
+
+            self.output_layer = tensor_parallel.ColumnParallelLinear(
+                config.hidden_size,
+                self.vocab_size,
+                config=config,
+                init_method=config.init_method,
+                bias=True,
+                skip_bias_add=False,
+                gather_output=not self.parallel_output,
+                skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
+            )
+
+            self.binary_head = None
+            if self.add_binary_head:
+                # TODO: Should switch this to TE?
+ self.binary_head = get_linear_layer( + config.hidden_size, 2, config.init_method, config.perform_initialization + ) + + self.pooler = Pooler(config.hidden_size, config.init_method, config, config.sequence_parallel) + if self.pre_process or self.post_process: + self.setup_embeddings_and_output_layer() + + def embedding_forward( + self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: Tensor = None, attention_mask: Tensor = None + ): + """Forward pass of the embedding layer. + + Args: + input_ids: The input tensor of shape (batch_size, sequence_length) containing the input IDs. + position_ids: The tensor of shape (batch_size, sequence_length) containing the position IDs. + tokentype_ids: The tensor of shape (batch_size, sequence_length) containing the token type IDs. Defaults to None. + attention_mask: The tensor of shape (batch_size, sequence_length) containing the attention mask. Defaults to None. + + Returns: + Tensor: The output tensor of shape (batch_size, sequence_length, hidden_size) containing the embedded representations. + """ + # ESM2 Customization: ESM2Embedding forward takes attention_mask + # in addition to the args required by LanguageModelEmbedding + return self.embedding( + input_ids=input_ids, position_ids=position_ids, tokentype_ids=tokentype_ids, attention_mask=attention_mask + ) + + +AMPLIFYModelT = TypeVar("AMPLIFYModelT", bound=AMPLIFYModel) + +@dataclass +class AMPLIFYConfig(BioBertConfig[AMPLIFYModelT, MegatronLossType], iom.IOMixinWithGettersSetters): + """Configuration class for AMPLIFY model. + + Attributes: + num_layers: Number of layers in the model. + hidden_size: Hidden size of the model. + num_attention_heads: Number of attention heads in the model. + ffn_hidden_size: Hidden size of the feed-forward network. + hidden_dropout: Dropout rate for hidden layers. + attention_dropout: Dropout rate for attention layers. + apply_residual_connection_post_layernorm: Whether to apply residual connection after layer normalization. + layernorm_epsilon: Epsilon value for layer normalization. + layernorm_zero_centered_gamma: Whether to zero-center the gamma parameter in layer normalization. + activation_func: Activation function used in the model. + init_method_std: Standard deviation for weight initialization. + apply_query_key_layer_scaling: Whether to apply scaling to query and key layers. + masked_softmax_fusion: Whether to use a kernel that fuses attention softmax with its mask. + fp16_lm_cross_entropy: Whether to move the cross entropy unreduced loss calculation for lm head to fp16. + share_embeddings_and_output_weights: Whether to share embeddings and output weights. + enable_autocast: Whether to enable autocast for mixed precision. + biobert_spec_option: BiobertSpecOption for the model. + position_embedding_type: Type of position embedding used in the model. + seq_length: Length of the input sequence. + make_vocab_size_divisible_by: Make the vocabulary size divisible by this value. + token_dropout: Whether to apply token dropout. + use_attention_mask: Whether to use attention mask. + use_esm_attention: Whether to use ESM attention. + attention_softmax_in_fp32: Whether to use fp32 for attention softmax. + optimizer_fn: Optional optimizer function for the model. + parallel_output: Whether to use parallel output. + rotary_base: Base value for rotary positional encoding. + rotary_percent: Percentage of rotary positional encoding. + seq_len_interpolation_factor: Interpolation factor for sequence length. 
+ get_attention_mask_from_fusion: Whether to get attention mask from fusion. + nemo1_ckpt_path: Path to NEMO1 checkpoint. + return_only_hidden_states: Whether to return only hidden states. + loss_reduction_class: Loss reduction class for the model. Default to BERTMLMLossWithReduction. + """ + + # When overriding fields in a dataclass _always_ declare types: https://github.com/python/cpython/issues/123269 + model_cls: Type[AMPLIFYModelT] = AMPLIFYModel + seq_length: int = 512 + num_layers: int = 24 # 32 for 350M, 24 for 120M + hidden_size: int = 640 # 960 for 350M, 640 for 120M + num_attention_heads: int = 10 # 15 for 350M, 10 for 120M + ffn_hidden_size: int = 2560 # Transformer FFN hidden size. Usually 4 * hidden_size. + hidden_dropout: float = 0 # AMPLIFY removes dropout from hidden layers and attention + attention_dropout: float = 0.0 # AMPLIFY does not use attention dropout + apply_residual_connection_post_layernorm: bool = False # TODO: farhadr False is new default, True was BERT pub. + layernorm_epsilon: float = 1.0e-5 + init_method_std: float = 0.02 + + # embedding + token_dropout: bool = True + use_attention_mask: bool = True + + # core attention + use_esm_attention: bool = False # Skip ESM2 custom attention for TE acceleration. Still passes golden value test. + attention_softmax_in_fp32: bool = False + normalize_attention_scores: bool = False + + # From megatron.core.models.gpt.bert_model.GPTModel + fp16_lm_cross_entropy: bool = False # Move the cross entropy unreduced loss calculation for lm head to fp16 + parallel_output: bool = True + share_embeddings_and_output_weights: bool = True + make_vocab_size_divisible_by: int = 1 + position_embedding_type: PositionEmbeddingKinds = "rope" + rotary_base: int = 10000 + rotary_percent: float = 1. + + #AMPLIFY specific configuration + add_bias_linear: bool = False # AMPLIFY does not use bias in linear layers + bias_swiglu_fusion: bool = True + bias_activation_fusion: bool = False + bias_dropout_fusion: bool = False + apply_rope_fusion: bool = True + gated_linear_unit: bool = True + activation_func: str = silu + normalization: str = "RMSNorm" # AMPLIFY uses RMSNorm instead of LayerNorm + layernorm_zero_centered_gamma: bool = False # Zero centered gamma not supported for RMSNorm + biobert_spec_option: BiobertSpecOption = BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec + + # TODO: Move this to better places? + get_attention_mask_from_fusion: bool = False + + optimizer_fn: Optional[Callable[[MegatronBioBertModel], Optimizer]] = None + # TODO (@skothenhill,@georgea) update to use the nemo2 checkpoint mixins + # support HF (requires weight interleaving on qkv layer) and nemo1 checkpoints ideally. + nemo1_ckpt_path: str | None = None + # The following checkpoint path is for nemo2 checkpoints. Config parameters not present in + # self.override_parent_fields will be loaded from the checkpoint and override those values here. + initial_ckpt_path: str | None = None + # TODO (@jstjohn) come up with a cleaner way in the biobert module to return user requested + # things as part of the workflow for inference and fine-tuning. 
+    return_embeddings: bool = False
+    include_embeddings: bool = False
+    include_input_ids: bool = False
+    skip_logits: bool = False
+    return_only_hidden_states: bool = False  # False: return logits rather than only the hidden states
+
+    def __post_init__(self):
+        """Check compatibility between biobert_spec_option and apply_query_key_layer_scaling post initialization."""
+        super().__post_init__()
+
+        if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
+            self.apply_query_key_layer_scaling = False
+        elif self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_local_spec:
+            logging.warning(
+                "BiobertSpecOption.esm2_bert_layer_local_spec is deprecated. Use BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec instead."
+            )
+            self.apply_query_key_layer_scaling = True
+        else:
+            raise ValueError(f"Unknown biobert_spec_option: {self.biobert_spec_option}")
+
+        self.core_attention_override = ESM2TEDotProductAttention
diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/run/__init__.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/run/__init__.py
new file mode 100644
index 0000000000..25e6abfbc5
--- /dev/null
+++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/run/__init__.py
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/scripts/README.md b/sub-packages/bionemo-amplify/src/bionemo/amplify/scripts/README.md
new file mode 100644
index 0000000000..e7d951c23d
--- /dev/null
+++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/scripts/README.md
@@ -0,0 +1,3 @@
+## AMPLIFY Scripts Directory
+This is a collection of one-off scripts that can be run through the command line. See the `[project.scripts]` section
+of the pyproject.toml file for how these are generated.
diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/scripts/__init__.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/scripts/__init__.py
new file mode 100644
index 0000000000..25e6abfbc5
--- /dev/null
+++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/scripts/__init__.py
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
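The training script below wires these components together behind an argparse CLI. For orientation, a minimal programmatic sketch of the same pieces (assuming a working BioNeMo/NeMo environment; batch sizes here are illustrative, not tuned):

```python
from bionemo.amplify.api import AMPLIFYConfig
from bionemo.amplify.data.datamodule import AMPLIFYDataModule
from bionemo.amplify.data.tokenizer import get_tokenizer

tokenizer = get_tokenizer()

# Data module over the UR100P default, with small batch sizes for illustration.
data = AMPLIFYDataModule(
    hf_dataset_name="chandar-lab/UR100P",
    micro_batch_size=4,
    global_batch_size=4,
    max_seq_length=512,
    num_workers=1,
    tokenizer=tokenizer,
)

# 120M-scale hyperparameters matching the AMPLIFYConfig defaults above.
config = AMPLIFYConfig(num_layers=24, hidden_size=640, num_attention_heads=10, ffn_hidden_size=2560)
```

`main()` in `train_amplify.py` builds the same objects and then hands them to `llm.train` together with a `MegatronStrategy`-backed trainer.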
diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/scripts/train_amplify.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/scripts/train_amplify.py new file mode 100644 index 0000000000..a3a6daff76 --- /dev/null +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/scripts/train_amplify.py @@ -0,0 +1,664 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from pathlib import Path +from typing import List, Optional, Sequence, get_args + +from lightning.pytorch.callbacks import LearningRateMonitor, RichModelSummary +from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.optimizer import OptimizerConfig +from nemo import lightning as nl +from nemo.collections import llm +from nemo.lightning import resume +from nemo.lightning.pytorch import callbacks as nl_callbacks +from nemo.lightning.pytorch.optim import MegatronOptimizerModule + +from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype +from bionemo.amplify.api import AMPLIFYConfig +from bionemo.amplify.data.datamodule import AMPLIFYDataModule +from bionemo.esm2.data.dataset import RandomMaskStrategy +from bionemo.amplify.data.tokenizer import get_tokenizer +from bionemo.llm.lightning import PerplexityLoggingCallback +from bionemo.llm.model.biobert.lightning import biobert_lightning_module +from bionemo.llm.model.biobert.model import BiobertSpecOption +from bionemo.llm.utils.datamodule_utils import float_or_int_or_none, infer_global_batch_size +from bionemo.llm.utils.logger_utils import WandbConfig, setup_nemo_lightning_logger + + +__all__: Sequence[str] = ("main", "parser") + + +def main( + hf_dataset_name: str, + num_nodes: int, + devices: int, + min_seq_length: Optional[int], + max_seq_length: int, + result_dir: Path, + num_steps: int, + warmup_steps: int, + decay_steps: int, + limit_val_batches: int, + val_check_interval: int, + log_every_n_steps: Optional[int], + num_dataset_workers: int, + biobert_spec_option: BiobertSpecOption, # TODO(@farhadrgh) clarify how to parse this. 
+    lr: float,
+    micro_batch_size: int,
+    accumulate_grad_batches: int,
+    experiment_name: str,
+    resume_if_exists: bool,
+    precision: PrecisionTypes,
+    wandb_entity: Optional[str] = None,
+    wandb_project: Optional[str] = None,
+    wandb_offline: bool = False,
+    wandb_tags: Optional[List[str]] = None,
+    wandb_group: Optional[str] = None,
+    wandb_id: Optional[str] = None,
+    wandb_anonymous: Optional[bool] = False,
+    wandb_log_model: bool = True,
+    pipeline_model_parallel_size: int = 1,
+    tensor_model_parallel_size: int = 1,
+    create_tensorboard_logger: bool = False,
+    nemo1_init_path: Optional[Path] = None,
+    restore_from_checkpoint_path: Optional[str] = None,
+    save_best_checkpoint: bool = True,
+    save_last_checkpoint: bool = True,
+    metric_to_monitor_for_checkpoints: str = "val_loss",
+    save_top_k: int = 2,
+    nsys_profiling: bool = False,
+    nsys_start_step: int = 0,
+    nsys_end_step: Optional[int] = None,
+    nsys_ranks: List[int] = [0],
+    random_mask_strategy: RandomMaskStrategy = RandomMaskStrategy.ALL_TOKENS,
+    num_layers: int = 32,
+    hidden_size: int = 960,
+    num_attention_heads: int = 15,
+    ffn_hidden_size: int = 960 * 4,
+    overlap_grad_reduce: bool = True,
+    overlap_param_gather: bool = False,  # TODO waiting for a NeMo fix
+    average_in_collective: bool = True,
+    grad_reduce_in_fp32: bool = False,
+) -> None:
+    """Train an AMPLIFY model on UR100P data.
+
+    Args:
+        hf_dataset_name: Name of the HuggingFace dataset containing UR100P protein sequences.
+        num_nodes (int): Number of nodes to run on
+        devices (int): number of devices
+        min_seq_length (Optional[int]): minimum sequence length; shorter sequences are padded up to this length
+        max_seq_length (int): maximum sequence length
+        result_dir (Path): directory to store results, logs and checkpoints
+        wandb_entity (str): The team posting this run (default: your username or your default team)
+        wandb_project (str): The name of the project to which this run will belong.
+        wandb_tags (List[str]): Tags associated with this run.
+        wandb_group (str): A unique string shared by all runs in a given group
+        wandb_offline (bool): Run offline (data can be streamed later to wandb servers).
+        wandb_id (str): Sets the version, mainly used to resume a previous run.
+        wandb_anonymous (bool): Enables or explicitly disables anonymous logging.
+        wandb_log_model (bool): Save checkpoints in wandb dir to upload on W&B servers.
+        num_steps (int): number of steps to train the model for
+        limit_val_batches (int): limit the number of validation global batches to this many
+        val_check_interval (int): number of steps between validation checks (checkpoints are also saved at this interval)
+        num_dataset_workers (int): number of dataset workers
+        biobert_spec_option (BiobertSpecOption): the biobert spec option (architecture) to use for this run
+        lr (float): learning rate
+        micro_batch_size (int): micro batch size, from this and parallelism settings we infer the global batch size
+        experiment_name (str): experiment name, this is the name used for the wandb run, and the sub-directory of the
+            result_dir that stores the logs and checkpoints.
+        resume_if_exists (bool): attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet]
+        create_tensorboard_logger (bool): create the tensorboard logger
+        restore_from_checkpoint_path (Optional[str]): If set, restores the model from the directory passed in. Expects the
+            checkpoint to be created by using the ModelCheckpoint class and always_save_context=True.
+ save_best_checkpoint (bool): whether to save the best checkpoint + save_last_checkpoint (bool): whether to save the last checkpoint + metric_to_monitor_for_checkpoints (str): metric to monitor for checkpoints + save_top_k (int): number of top checkpoints to save + nsys_profiling (bool): whether to enable nsys profiling + nsys_start_step (int): start step for nsys profiling + nsys_end_step (Optional[int]): end step for nsys profiling + nsys_ranks (List[int]): ranks for nsys profiling + random_mask_strategy (RandomMaskStrategy): random mask strategy + num_layers (int): number of layers + hidden_size (int): hidden size + num_attention_heads (int): number of attention heads + ffn_hidden_size (int): feed forward hidden size + overlap_grad_reduce (bool): overlap gradient reduction + overlap_param_gather (bool): overlap parameter gather + average_in_collective (bool): average in collective + grad_reduce_in_fp32 (bool): gradient reduction in fp32 + """ + # Create the result directory if it does not exist. + result_dir.mkdir(parents=True, exist_ok=True) + + # Setup the strategy and trainer + global_batch_size = infer_global_batch_size( + micro_batch_size=micro_batch_size, + num_nodes=num_nodes, + devices=devices, + accumulate_grad_batches=accumulate_grad_batches, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + ) + + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + pipeline_dtype=get_autocast_dtype(precision), + ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + overlap_grad_reduce=overlap_grad_reduce, + overlap_param_gather=overlap_param_gather, + average_in_collective=average_in_collective, + grad_reduce_in_fp32=grad_reduce_in_fp32, + use_distributed_optimizer=False, + ), + find_unused_parameters=True, + gradient_as_bucket_view=True, + ckpt_include_optimizer=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ) + + # for wandb integration + # Please refer to https://pytorch-lightning.readthedocs.io/en/0.7.6/api/pytorch_lightning.loggers.html" + wandb_config: Optional[WandbConfig] = ( + None + if wandb_project is None + else WandbConfig( + offline=wandb_offline, + project=wandb_project, + entity=wandb_entity, + tags=wandb_tags, + group=wandb_group, + id=wandb_id, + anonymous=wandb_anonymous, + log_model=wandb_log_model, + ) + ) + + callbacks = [ + PerplexityLoggingCallback(log_train=False, log_val=True), + RichModelSummary(max_depth=4), + LearningRateMonitor(), + nl_callbacks.PreemptionCallback(), + ] + if nsys_profiling: + if nsys_end_step is None: + nsys_end_step = num_steps + callbacks.append( + nl_callbacks.NsysCallback( + start_step=nsys_start_step, end_step=nsys_end_step, ranks=nsys_ranks, gen_shape=True + ) + ) + + trainer = nl.Trainer( + devices=devices, + max_steps=num_steps, + accelerator="gpu", + strategy=strategy, + limit_val_batches=limit_val_batches, # This controls upsampling and downsampling + val_check_interval=val_check_interval, + log_every_n_steps=log_every_n_steps, + num_nodes=num_nodes, + callbacks=callbacks, + plugins=nl.MegatronMixedPrecision( + precision=precision, + params_dtype=get_autocast_dtype(precision), + pipeline_dtype=get_autocast_dtype(precision), + grad_reduce_in_fp32=grad_reduce_in_fp32, + autocast_enabled=False, + ), + ) + + tokenizer = get_tokenizer() + + # Initialize the data module. 
+ data = AMPLIFYDataModule( + hf_dataset_name=hf_dataset_name, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + min_seq_length=min_seq_length, + max_seq_length=max_seq_length, + num_workers=num_dataset_workers, + random_mask_strategy=random_mask_strategy, + tokenizer=tokenizer, + ) + # Configure the model + amplify_config = AMPLIFYConfig( + seq_length=max_seq_length, + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + ffn_hidden_size=ffn_hidden_size, + params_dtype=get_autocast_dtype(precision), + pipeline_dtype=get_autocast_dtype(precision), + autocast_dtype=get_autocast_dtype(precision), # setting this speeds things up a lot + biobert_spec_option=biobert_spec_option, + nemo1_ckpt_path=str(nemo1_init_path) if nemo1_init_path is not None else None, + # handle checkpoint resumption here rather than auto-resume so this supports fine-tuning capabilities + initial_ckpt_path=str(restore_from_checkpoint_path) if restore_from_checkpoint_path is not None else None, + variable_seq_lengths=min_seq_length != max_seq_length, + ) + + model = biobert_lightning_module( + amplify_config, + tokenizer=tokenizer, + optimizer=MegatronOptimizerModule( + config=OptimizerConfig( + lr=lr, + optimizer="adam", # fused_adam not supported + use_distributed_optimizer=False, + weight_decay=0.01, + adam_beta1=0.9, + adam_beta2=0.95, + clip_grad=1.0, + ), + lr_scheduler=nl.lr_scheduler.CosineAnnealingScheduler( + min_lr=0.1*lr, + max_steps=decay_steps, + warmup_steps=warmup_steps, + constant_steps=0, + ), + ), + ) + + # Configure our custom Checkpointer + checkpoint_callback = nl_callbacks.ModelCheckpoint( + save_last=save_last_checkpoint, + monitor=metric_to_monitor_for_checkpoints, # "val_loss", + save_top_k=save_top_k, + every_n_train_steps=val_check_interval, + always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe + filename="{epoch}-{val_loss:.2f}-{step}-{consumed_samples}", # Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this. + ) + + # Setup the logger and train the model + nemo_logger = setup_nemo_lightning_logger( + root_dir=result_dir, + name=experiment_name, + initialize_tensorboard_logger=create_tensorboard_logger, + wandb_config=wandb_config, + ckpt_callback=checkpoint_callback, + ) + + llm.train( + model=model, + data=data, + trainer=trainer, + log=nemo_logger, + resume=resume.AutoResume( + resume_if_exists=resume_if_exists, # Looks for the -last checkpoint to continue training. + resume_ignore_no_checkpoint=True, # When false this will throw an error with no existing checkpoint. + ), + ) + + +def train_amplify_entrypoint(): + """Entrypoint for running pretraining for amplify.""" + # 1. get arguments + parser = get_parser() + args = parser.parse_args() + # 2. 
Call pretrain with args + main( + hf_dataset_name=args.hf_dataset_name, + num_nodes=args.num_nodes, + devices=args.num_gpus, + min_seq_length=args.min_seq_length, + max_seq_length=args.max_seq_length, + result_dir=args.result_dir, + wandb_entity=args.wandb_entity, + wandb_project=args.wandb_project, + wandb_tags=args.wandb_tags, + wandb_group=args.wandb_group, + wandb_id=args.wandb_id, + wandb_anonymous=args.wandb_anonymous, + wandb_log_model=args.wandb_log_model, + wandb_offline=args.wandb_offline, + num_steps=args.num_steps, + warmup_steps=args.warmup_steps, + decay_steps=args.decay_steps, + limit_val_batches=args.limit_val_batches, + val_check_interval=args.val_check_interval, + log_every_n_steps=args.log_every_n_steps, + num_dataset_workers=args.num_dataset_workers, + biobert_spec_option=args.biobert_spec_option, + lr=args.lr, + micro_batch_size=args.micro_batch_size, + pipeline_model_parallel_size=args.pipeline_model_parallel_size, + tensor_model_parallel_size=args.tensor_model_parallel_size, + accumulate_grad_batches=args.accumulate_grad_batches, + precision=args.precision, + experiment_name=args.experiment_name, + resume_if_exists=args.resume_if_exists, + nemo1_init_path=args.nemo1_init_path, + restore_from_checkpoint_path=args.restore_from_checkpoint_path, + save_best_checkpoint=args.save_best_checkpoint, + save_last_checkpoint=args.save_last_checkpoint, + metric_to_monitor_for_checkpoints=args.metric_to_monitor_for_checkpoints, + save_top_k=args.save_top_k, + nsys_profiling=args.nsys_profiling, + nsys_start_step=args.nsys_start_step, + nsys_end_step=args.nsys_end_step, + nsys_ranks=args.nsys_ranks, + random_mask_strategy=args.random_mask_strategy, + num_layers=args.num_layers, + hidden_size=args.hidden_size, + num_attention_heads=args.num_attention_heads, + ffn_hidden_size=args.ffn_hidden_size, + overlap_grad_reduce=not args.no_overlap_grad_reduce, + overlap_param_gather=args.overlap_param_gather, + average_in_collective=not args.no_average_in_collective, + grad_reduce_in_fp32=args.grad_reduce_in_fp32, + ) + +def get_parser(): + """Return the cli parser for this tool.""" + # TODO migrate to hydra config + # Parse the arguments and pull them out into local variables for ease of future refactor to a + # config management system. + parser = argparse.ArgumentParser(description="Pretrain AMPLIFY with UR100P data.") + parser.add_argument( + "--hf-dataset-name", + type=str, + required=True, + help="Name of the HuggingFace dataset containing UR100P protein sequences", + ) + parser.add_argument( + "--precision", + type=str, + choices=get_args(PrecisionTypes), + required=False, + default="bf16-mixed", + help="Precision type to use for training.", + ) + parser.add_argument( + "--lr", + type=float, + required=False, + default=1e-3, + help="Learning rate for training. Default is 1e-3", + ) + parser.add_argument( + "--create-tensorboard-logger", action="store_true", default=False, help="Create a tensorboard logger." + ) + # FIXME (@skothenhill) figure out how checkpointing and resumption should work with the new nemo trainer + parser.add_argument( + "--resume-if-exists", action="store_true", default=False, help="Resume training if a checkpoint exists." + ) + parser.add_argument( + "--result-dir", type=Path, required=False, default=Path("./results"), help="Path to the result directory." 
+ ) + parser.add_argument("--experiment-name", type=str, required=False, default="amplify", help="Name of the experiment.") + + parser.add_argument("--wandb-entity", type=str, default=None, help="The team posting this run") + parser.add_argument("--wandb-project", type=str, default=None, help="Wandb project name ") + parser.add_argument("--wandb-tags", nargs="+", type=str, default=None, help="Tags associated with this run") + parser.add_argument( + "--wandb-group", type=str, default=None, help="A unique string shared by all runs in a given group" + ) + parser.add_argument( + "--wandb-id", type=str, default=None, help="Sets the version, mainly used to resume a previous run" + ) + parser.add_argument("--wandb-anonymous", action="store_true", help="Enable or explicitly disable anonymous logging") + parser.add_argument( + "--wandb-log-model", action="store_true", help="Save checkpoints in wandb dir to upload on W&B servers" + ) + parser.add_argument("--wandb-offline", action="store_true", help="Use wandb in offline mode") + parser.add_argument( + "--num-gpus", + type=int, + required=False, + default=1, + help="Number of GPUs to use for training. Default is 1.", + ) + parser.add_argument( + "--num-nodes", + type=int, + required=False, + default=1, + help="Number of nodes to use for training. Default is 1.", + ) + parser.add_argument( + "--num-steps", + type=int, + required=False, + default=1_000_000, + help="Number of steps to use for training. Default is 1,000,000.", + ) + parser.add_argument( + "--warmup-steps", + type=int, + required=False, + default=1000, + help="Number of warmup steps for WarmupAnnealDecayHold Scheduler. Default is 1000.", + ) + parser.add_argument( + "--decay-steps", + type=int, + required=False, + default=900_000, + help="Number of decay steps for WarmupAnnealDecayHold Scheduler. Default is 900,000.", + ) + parser.add_argument( + "--num-dataset-workers", + type=int, + required=False, + default=1, + help="Number of workers to use for training. Default is 1.", + ) + parser.add_argument( + "--val-check-interval", + type=int, + required=False, + default=10000, + help="Number of steps between validation. Default is 10000.", + ) + parser.add_argument( + "--log-every-n-steps", + type=int, + required=False, + default=100, + help="Number of steps between logging. Default is 100.", + ) + parser.add_argument( + "--min-seq-length", + type=int, + required=False, + default=512, + help="Minimum sequence length. Sampled will be padded if less than this value.", + ) + parser.add_argument( + "--max-seq-length", + type=int, + required=False, + default=512, + help="Maximum sequence length. Samples will be truncated if exceeds this value.", + ) + parser.add_argument( + "--limit-val-batches", + type=float_or_int_or_none, + required=False, + default=1.0, + help="Number of global batches used for validation if int. Fraction of validation dataset if float. Default is 1.0.", + ) + parser.add_argument( + "--micro-batch-size", + type=int, + required=False, + default=64, + help="Micro-batch size. Global batch size is inferred from this.", + ) + parser.add_argument( + "--pipeline-model-parallel-size", + type=int, + required=False, + default=1, + help="Pipeline model parallel size. Default is 1.", + ) + parser.add_argument( + "--tensor-model-parallel-size", + type=int, + required=False, + default=1, + help="Tensor model parallel size. Default is 1.", + ) + parser.add_argument( + "--accumulate-grad-batches", + type=int, + required=False, + default=1, + help="Gradient accumulation steps. 
Global batch size is inferred from this.", + ) + parser.add_argument( + "--biobert-spec-option", + type=BiobertSpecOption, + choices=[e.value for e in BiobertSpecOption], + required=False, + default=BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec.value, + help="Biobert spec option to use for the model. Default is 'esm2_bert_layer_with_transformer_engine_spec'.", + ) + parser.add_argument( + "--nemo1-init-path", + type=Path, + required=False, + help="Path to nemo1 file, if desired to load at init time.", + ) + parser.add_argument( + "--save-best-checkpoint", + action="store_true", + default=True, + help="Save the best checkpoint based on the metric to monitor.", + ) + parser.add_argument( + "--save-last-checkpoint", + action="store_true", + default=True, + help="Save the last checkpoint.", + ) + parser.add_argument( + "--metric-to-monitor-for-checkpoints", + type=str, + required=False, + default="val_loss", + help="The metric to monitor for checkpointing.", + ) + parser.add_argument( + "--save-top-k", + type=int, + required=False, + default=2, + help="Save the top k checkpoints.", + ) + parser.add_argument( + "--restore-from-checkpoint-path", + type=Path, + required=False, + default=None, + help="Path to the checkpoint directory to restore from. Will override `--resume-if-exists` when set.", + ) + parser.add_argument( + "--nsys-profiling", + action="store_true", + default=False, + help="Enable targeted `nsys` profiling on the training loop for a defined step range. To actually get profiling output you must run the whole program with `nsys`. For example: " + " `nsys profile -s none -o output_report_name -t cuda,nvtx --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop [regular python command here]`", + ) + # start, end, rank + parser.add_argument( + "--nsys-start-step", + type=int, + required=False, + default=0, + help="Start nsys profiling after this step.", + ) + parser.add_argument( + "--nsys-end-step", + type=int, + required=False, + help="End nsys profiling after this step.", + ) + # rank as list of integers + parser.add_argument( + "--nsys-ranks", + type=int, + nargs="+", + required=False, + default=[0], + help="Enable nsys profiling for these ranks.", + ) + + # AMPLIFY specific configuration (default: 120M) + parser.add_argument( + "--random-mask-strategy", + type=RandomMaskStrategy, + choices=[e.value for e in RandomMaskStrategy], + default=RandomMaskStrategy.ALL_TOKENS.value, + help=f"""In pretraining, 15%% of all tokens are masked and among which 10%% are replaced with a random token. This class controls the set of random tokens to choose from. Options are: '{"', '".join([e.value for e in RandomMaskStrategy])}'. Note that 'all_token' will introduce non-canonical amino acid tokens as effective mask tokens, and the resultant loss will appear lower than that from 'amino_acids_only'. Note that 'all_token' is the method used in hugging face as well as portions of fairseq.""", + ) + parser.add_argument( + "--num-layers", + type=int, + required=False, + default=24, + help="Number of layers in the model. Default is 24.", + ) + parser.add_argument( + "--hidden-size", + type=int, + required=False, + default=640, + help="Hidden size of the model. Default is 640.", + ) + parser.add_argument( + "--num-attention-heads", + type=int, + required=False, + default=10, + help="Number of attention heads in the model. 
Default is 10.", + ) + parser.add_argument( + "--ffn-hidden-size", + type=int, + required=False, + default=4 * 640, + help="FFN hidden size of the model. Default is 4 * 640.", + ) + # DDP config + parser.add_argument( + "--no-overlap-grad-reduce", + action="store_true", + default=False, + ) + parser.add_argument( + "--overlap-param-gather", + action="store_true", + default=False, + ) # TODO waiting for a NeMo fix + parser.add_argument( + "--no-average-in-collective", + action="store_true", + default=False, + ) + parser.add_argument( + "--grad-reduce-in-fp32", + action="store_true", + default=False, + ) + return parser + + +if __name__ == "__main__": + train_amplify_entrypoint() diff --git a/sub-packages/bionemo-amplify/tests/bionemo/amplify/__init__.py b/sub-packages/bionemo-amplify/tests/bionemo/amplify/__init__.py new file mode 100644 index 0000000000..25e6abfbc5 --- /dev/null +++ b/sub-packages/bionemo-amplify/tests/bionemo/amplify/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sub-packages/bionemo-amplify/tests/bionemo/amplify/conftest.py b/sub-packages/bionemo-amplify/tests/bionemo/amplify/conftest.py new file mode 100644 index 0000000000..535127ba6e --- /dev/null +++ b/sub-packages/bionemo-amplify/tests/bionemo/amplify/conftest.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
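Before the test fixtures, a quick sanity sketch of the training CLI defined above. The import path follows the `train_amplify` console-script target declared in the package metadata, and the dataset id passed to `--hf-dataset-name` is a placeholder rather than a verified dataset name; everything else is read straight from the parser defaults.

```python
# Not part of the patch: a smoke check of get_parser()'s defaults.
from bionemo.amplify.scripts.train_amplify import get_parser  # path assumed from the console-script entry

args = get_parser().parse_args(["--hf-dataset-name", "my-org/UR100P"])  # placeholder dataset id

# AMPLIFY-specific architecture defaults (the 120M configuration noted above).
assert args.num_layers == 24
assert args.hidden_size == 640
assert args.num_attention_heads == 10
assert args.ffn_hidden_size == 4 * 640

# Training defaults.
assert args.precision == "bf16-mixed"
assert args.lr == 1e-3
assert args.micro_batch_size == 64
assert args.warmup_steps == 1000 and args.decay_steps == 900_000
```

Running the installed `train_amplify` entrypoint with only `--hf-dataset-name` set would pick up the same defaults.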
+ + +import sqlite3 + +import pandas as pd +import pytest + +from bionemo.amplify.data.tokenizer import get_tokenizer + + +@pytest.fixture +def tokenizer(): + """Return the AMPLIFY tokenizer.""" + return get_tokenizer() + + +@pytest.fixture +def dummy_protein_dataset(tmp_path): + """Create a mock protein dataset.""" + db_file = tmp_path / "protein_dataset.db" + conn = sqlite3.connect(str(db_file)) + cursor = conn.cursor() + + cursor.execute( + """ + CREATE TABLE protein ( + id TEXT PRIMARY KEY, + sequence TEXT + ) + """ + ) + + proteins = [ + ("UniRef90_A", "ACDEFGHIKLMNPQRSTVWY"), + ("UniRef90_B", "DEFGHIKLMNPQRSTVWYAC"), + ("UniRef90_C", "MGHIKLMNPQRSTVWYACDE"), + ("UniRef50_A", "MKTVRQERLKSIVRI"), + ("UniRef50_B", "MRILERSKEPVSGAQLA"), + ] + cursor.executemany("INSERT INTO protein VALUES (?, ?)", proteins) + + conn.commit() + conn.close() + + return db_file + + +@pytest.fixture +def dummy_parquet_train_val_inputs(tmp_path): + """Create a mock protein train and val cluster parquet.""" + train_cluster_path = tmp_path / "train_clusters.parquet" + train_clusters = pd.DataFrame( + { + "ur90_id": [["UniRef90_A"], ["UniRef90_B", "UniRef90_C"]], + } + ) + train_clusters.to_parquet(train_cluster_path) + + valid_cluster_path = tmp_path / "valid_clusters.parquet" + valid_clusters = pd.DataFrame( + { + "ur50_id": ["UniRef50_A", "UniRef50_B"], + } + ) + valid_clusters.to_parquet(valid_cluster_path) + return train_cluster_path, valid_cluster_path diff --git a/sub-packages/bionemo-amplify/tests/bionemo/amplify/data/test_datamodule.py b/sub-packages/bionemo-amplify/tests/bionemo/amplify/data/test_datamodule.py new file mode 100644 index 0000000000..86c8054707 --- /dev/null +++ b/sub-packages/bionemo-amplify/tests/bionemo/amplify/data/test_datamodule.py @@ -0,0 +1,377 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +import pytest +import torch.utils.data + +from bionemo.amplify.data.datamodule import AMPLIFYDataModule +from bionemo.llm.utils.datamodule_utils import tensor_dict_hash + + +def test_create_amplify_datamodule_raises_without_trainer(dummy_protein_dataset, dummy_parquet_train_val_inputs): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + # Initialize the data module. + data_module = AMPLIFYDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + ) + assert data_module is not None + + with pytest.raises(RuntimeError, match="Setup should be completed when trainer and config are attached."): + data_module.setup() + + +def test_create_amplify_datamodule_raises_without_trainer_max_steps(dummy_protein_dataset, dummy_parquet_train_val_inputs): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + # Initialize the data module. 
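+    # setup() requires a trainer that defines max_steps > 0; the mocked trainer below sets
+    # max_steps=0 to trigger the expected RuntimeError.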
+ data_module = AMPLIFYDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + ) + assert data_module is not None + + data_module.trainer = mock.Mock() + data_module.trainer.max_epochs = 1 + data_module.trainer.max_steps = 0 + + with pytest.raises(RuntimeError, match="Please specify trainer.max_steps"): + data_module.setup() + + +def test_create_amplify_datamodule_creates_valid_dataloaders(dummy_protein_dataset, dummy_parquet_train_val_inputs): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + # Initialize the data module. + data_module = AMPLIFYDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + global_batch_size=8, + micro_batch_size=4, + min_seq_length=36, + max_seq_length=36, + ) + assert data_module is not None + + data_module.trainer = mock.Mock() + data_module.trainer.max_epochs = 1 + data_module.trainer.max_steps = 10 + data_module.trainer.val_check_interval = 2 + data_module.trainer.limit_val_batches = 1 + + data_module.setup() + + train_dataloader = data_module.train_dataloader() + assert isinstance(train_dataloader, torch.utils.data.DataLoader) + + val_dataloader = data_module.val_dataloader() + assert isinstance(val_dataloader, torch.utils.data.DataLoader) + + assert len(train_dataloader) == 10 * 8 # max steps * global batch size + assert len(val_dataloader) == 8 # global batch size; index reset every val epoch + + for batch in train_dataloader: + assert isinstance(batch, dict) + assert isinstance(batch["text"], torch.Tensor) + assert isinstance(batch["attention_mask"], torch.Tensor) + assert isinstance(batch["labels"], torch.Tensor) + assert isinstance(batch["loss_mask"], torch.Tensor) + assert isinstance(batch["is_random"], torch.Tensor) + + for batch in val_dataloader: + assert isinstance(batch, dict) + assert isinstance(batch["text"], torch.Tensor) + assert isinstance(batch["attention_mask"], torch.Tensor) + assert isinstance(batch["labels"], torch.Tensor) + assert isinstance(batch["loss_mask"], torch.Tensor) + assert isinstance(batch["is_random"], torch.Tensor) + + +def test_create_amplify_datamodule_creates_valid_dataloaders_with_fractional_limit_val_batches( + dummy_protein_dataset, dummy_parquet_train_val_inputs +): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + # Initialize the data module. 
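+    # With 2 validation clusters, limit_val_batches=0.5 and global_batch_size=1, the validation
+    # dataloader should contain int(2 * 0.5) // 1 == 1 batch (asserted at the end of this test).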
+ data_module = AMPLIFYDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + global_batch_size=1, + micro_batch_size=1, + min_seq_length=36, + max_seq_length=36, + ) + assert data_module is not None + + data_module.trainer = mock.Mock() + data_module.trainer.max_epochs = 1 + data_module.trainer.max_steps = 10 + data_module.trainer.val_check_interval = 2 + data_module.trainer.limit_val_batches = 0.5 # fractional value + + data_module.setup() + + train_dataloader = data_module.train_dataloader() + assert isinstance(train_dataloader, torch.utils.data.DataLoader) + + val_dataloader = data_module.val_dataloader() + assert isinstance(val_dataloader, torch.utils.data.DataLoader) + + assert len(train_dataloader) == 10 * 1 # max steps * global batch size + assert len(val_dataloader) == int(2 * 0.5) // 1 # number of validation clusters // global batch size + + +def test_create_amplify_datamodule_creates_valid_dataloaders_fractional_limit_val_batches_smaller_than_global_batch_size( + dummy_protein_dataset, dummy_parquet_train_val_inputs +): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + # Initialize the data module. + data_module = AMPLIFYDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + global_batch_size=8, + micro_batch_size=4, + min_seq_length=36, + max_seq_length=36, + ) + assert data_module is not None + + data_module.trainer = mock.Mock() + data_module.trainer.max_epochs = 1 + data_module.trainer.max_steps = 10 + data_module.trainer.val_check_interval = 2 + data_module.trainer.limit_val_batches = 0.5 # fractional value + + # num_val_cluster * limit_val_batches = 2 * 0.5 = 1 < global_batch_size + with pytest.raises(ValueError, match="The limited number of val samples 1 is less than the global batch size 8"): + data_module.setup() + + +@pytest.mark.parametrize("limit_val_batches", [0, 0.0]) +def test_create_amplify_datamodule_creates_valid_dataloaders_fractional_limit_val_batches_0( + dummy_protein_dataset, dummy_parquet_train_val_inputs, limit_val_batches +): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + # Initialize the data module. + data_module = AMPLIFYDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + global_batch_size=8, + micro_batch_size=4, + min_seq_length=36, + max_seq_length=36, + ) + assert data_module is not None + + data_module.trainer = mock.Mock() + data_module.trainer.max_epochs = 1 + data_module.trainer.max_steps = 10 + data_module.trainer.val_check_interval = 2 + data_module.trainer.limit_val_batches = limit_val_batches + + with pytest.raises(ValueError, match="Invalid choice of limit_val_batches size: %s" % limit_val_batches): + data_module.setup() + + +def test_create_amplify_datamodule_creates_valid_dataloaders_fractional_limit_val_batches_not_multiple_of_global_batch_size( + dummy_protein_dataset, dummy_parquet_train_val_inputs +): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + # Initialize the data module. 
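+    # limit_val_batches=0.7 of 2 clusters floors to int(2 * 0.7) == 1 sample, which at
+    # global_batch_size=1 still yields exactly one validation batch (asserted below).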
+ data_module = AMPLIFYDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + global_batch_size=1, + micro_batch_size=1, + min_seq_length=36, + max_seq_length=36, + ) + assert data_module is not None + + data_module.trainer = mock.Mock() + data_module.trainer.max_epochs = 1 + data_module.trainer.max_steps = 10 + data_module.trainer.val_check_interval = 2 + data_module.trainer.limit_val_batches = 0.7 # fractional value + + data_module.setup() + + train_dataloader = data_module.train_dataloader() + assert isinstance(train_dataloader, torch.utils.data.DataLoader) + + val_dataloader = data_module.val_dataloader() + assert isinstance(val_dataloader, torch.utils.data.DataLoader) + + assert len(train_dataloader) == 10 * 1 # max steps * global batch size + assert len(val_dataloader) == int(2 * 0.7) // 1 # number of validation clusters // global batch size + + +def test_create_amplify_datamodule_creates_valid_dataloaders_fractional_limit_val_batches_1p0( + dummy_protein_dataset, dummy_parquet_train_val_inputs +): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + # Initialize the data module. + data_module = AMPLIFYDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + global_batch_size=1, + micro_batch_size=1, + min_seq_length=36, + max_seq_length=36, + ) + assert data_module is not None + + data_module.trainer = mock.Mock() + data_module.trainer.max_epochs = 1 + data_module.trainer.max_steps = 10 + data_module.trainer.val_check_interval = 2 + data_module.trainer.limit_val_batches = 1.0 # fractional value to use the whole dataset + + data_module.setup() + + train_dataloader = data_module.train_dataloader() + assert isinstance(train_dataloader, torch.utils.data.DataLoader) + + val_dataloader = data_module.val_dataloader() + assert isinstance(val_dataloader, torch.utils.data.DataLoader) + + assert len(train_dataloader) == 10 * 1 # max steps * global batch size + assert len(val_dataloader) == 2 // 1 # number of validation clusters // global batch size + + +def test_create_amplify_datamodule_limit_val_batches_none_equals_limit_val_batches_1p0( + dummy_protein_dataset, dummy_parquet_train_val_inputs +): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + # Initialize the data module with limit_val_batches = 1.0 + data_module_one = AMPLIFYDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + global_batch_size=1, + micro_batch_size=1, + min_seq_length=36, + max_seq_length=36, + ) + assert data_module_one is not None + + data_module_one.trainer = mock.Mock() + data_module_one.trainer.max_epochs = 1 + data_module_one.trainer.max_steps = 10 + data_module_one.trainer.val_check_interval = 2 + data_module_one.trainer.limit_val_batches = 1.0 # fractional value to use the whole dataset + + data_module_one.setup() + + # Initialize the data module with limit_val_batches = None + data_module_none = AMPLIFYDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + global_batch_size=1, + micro_batch_size=1, + min_seq_length=36, + max_seq_length=36, + ) + 
assert data_module_none is not None + + data_module_none.trainer = mock.Mock() + data_module_none.trainer.max_epochs = 1 + data_module_none.trainer.max_steps = 10 + data_module_none.trainer.val_check_interval = 2 + data_module_none.trainer.limit_val_batches = None # None to use the whole dataset + + data_module_none.setup() + + # Check that the two dataloaders have the same number of samples. + assert len(data_module_one.val_dataloader()) == len(data_module_none.val_dataloader()) + + +def test_create_amplify_datamodule_valid_dataloaders_has_consistent_samples_per_epoch( + dummy_protein_dataset, dummy_parquet_train_val_inputs +): + """ + Test that the AMPLIFYDataModule dataloaders produce consistent samples per epoch. + + This test ensures that the AMPLIFYDataModule creates dataloaders that produce consistent + samples across epochs, even if the data is reshuffled (controlled by `is_ordered`). + + Parameters: + - dummy_protein_dataset: A dummy protein dataset used for testing. + - dummy_parquet_train_val_inputs: A tuple containing paths to dummy parquet files + for training and validation clusters. + """ + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + micro_batch_size = 2 + is_ordered = False # allow random sampling to be independent between epoches + + # Initialize the data module. + data_module = AMPLIFYDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + global_batch_size=1, + micro_batch_size=micro_batch_size, + min_seq_length=36, + max_seq_length=36, + ) + assert data_module is not None + + data_module.trainer = mock.Mock() + data_module.trainer.max_epochs = 1 + data_module.trainer.max_steps = 1 + data_module.trainer.val_check_interval = 1 + data_module.trainer.limit_val_batches = 1.0 # use the whole validation dataset + + data_module.setup() + + # hash values from batches of the first epoch + batch_hashes1 = [tensor_dict_hash(batch) for batch in data_module.val_dataloader()] + + if is_ordered: # second epoch should have exactly the same output including order + for batch in data_module.val_dataloader(): + batch_hash = tensor_dict_hash(batch) + assert batch_hash == batch_hashes1.pop() + else: # second epoch should have the same output but can be reshuffled + batch_hashes1 = set(batch_hashes1) + batch_hashes2 = {tensor_dict_hash(batch) for batch in data_module.val_dataloader()} + assert batch_hashes1 == batch_hashes2 diff --git a/sub-packages/bionemo-amplify/tests/bionemo/amplify/data/test_dataset.py b/sub-packages/bionemo-amplify/tests/bionemo/amplify/data/test_dataset.py new file mode 100644 index 0000000000..eb2c665549 --- /dev/null +++ b/sub-packages/bionemo-amplify/tests/bionemo/amplify/data/test_dataset.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
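The datamodule tests above all reduce to the same bookkeeping for the validation set size. A small standalone sketch of that arithmetic (not the datamodule's actual implementation; the function name and structure are local to this example):

```python
def expected_val_batches(num_val_samples: int, limit_val_batches, global_batch_size: int) -> int:
    """Mirror the size assertions and error cases in the tests above."""
    if limit_val_batches is None or limit_val_batches == 1.0:
        return num_val_samples // global_batch_size          # None behaves like 1.0: use everything
    if not limit_val_batches:                                 # 0 or 0.0 is rejected
        raise ValueError(f"Invalid choice of limit_val_batches size: {limit_val_batches}")
    if isinstance(limit_val_batches, float):
        limited = int(num_val_samples * limit_val_batches)    # fractions are floored to whole samples
        if limited < global_batch_size:
            raise ValueError(
                f"The limited number of val samples {limited} is less than the global batch size {global_batch_size}"
            )
        return limited // global_batch_size
    return limit_val_batches                                  # an int is an explicit number of global batches


assert expected_val_batches(2, 0.5, 1) == 1    # int(2 * 0.5) // 1
assert expected_val_batches(2, 0.7, 1) == 1    # int(2 * 0.7) // 1
assert expected_val_batches(2, None, 1) == 2   # same as limit_val_batches=1.0
```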
+
+
+import pandas as pd
+import pytest
+import torch
+
+from bionemo.amplify.data.dataset import (
+    AMPLIFYMaskedResidueDataset,
+    ProteinSQLiteDataset,
+    create_train_dataset,
+    create_valid_dataset,
+)
+
+
+def test_protein_sqlite_dataset(dummy_protein_dataset):
+    """Test the ProteinSQLiteDataset class."""
+
+    dataset = ProteinSQLiteDataset(dummy_protein_dataset)
+
+    assert len(dataset) == 5
+
+    assert dataset["UniRef90_A"] == "ACDEFGHIKLMNPQRSTVWY"
+    assert dataset["UniRef90_B"] == "DEFGHIKLMNPQRSTVWYAC"
+    assert dataset["UniRef90_C"] == "MGHIKLMNPQRSTVWYACDE"
+    assert dataset["UniRef50_A"] == "MKTVRQERLKSIVRI"
+    assert dataset["UniRef50_B"] == "MRILERSKEPVSGAQLA"
+
+
+def test_AMPLIFYPreTrainingDataset_getitem_has_expected_structure(dummy_protein_dataset, tokenizer):
+    """Test that AMPLIFYMaskedResidueDataset's __getitem__ returns samples with the expected structure."""
+
+    protein_dataset = ProteinSQLiteDataset(dummy_protein_dataset)
+    clusters = [["UniRef90_A"], ["UniRef90_B", "UniRef90_C"]]
+    amplify_dataset = AMPLIFYMaskedResidueDataset(
+        protein_dataset=protein_dataset, clusters=clusters, total_samples=10, seed=123
+    )
+
+    sample = amplify_dataset[0]
+    assert len(sample["text"]) == len(protein_dataset["UniRef90_A"]) + 2
+
+    # Make sure all masked tokens are standard amino acids.
+    for token in sample["labels"][sample["loss_mask"]].tolist():
+        assert token in range(4, 24)
+
+    # Make sure non-masked tokens are -1.
+    assert torch.all(sample["labels"][~sample["loss_mask"]] == -1)
+
+    assert sample["text"][0] == tokenizer.cls_token_id
+    assert sample["text"][-1] == tokenizer.eos_token_id
+
+
+def test_AMPLIFYPreTrainingDataset_getitem_match_for_identical_seeds(dummy_protein_dataset):
+    """Test that datasets built with identical seeds return identical samples."""
+
+    dataset = ProteinSQLiteDataset(dummy_protein_dataset)
+    clusters = [["UniRef90_A"], ["UniRef90_B", "UniRef90_C"]]
+
+    dataset1 = AMPLIFYMaskedResidueDataset(protein_dataset=dataset, clusters=clusters, total_samples=10, seed=123)
+    dataset2 = AMPLIFYMaskedResidueDataset(protein_dataset=dataset, clusters=clusters, total_samples=10, seed=123)
+
+    # Check that the datasets are equal.
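+    # Identical seeds must reproduce both the cluster sampling and the mask pattern, so every
+    # tensor of every sample is compared element-wise.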
+ for i in range(len(dataset1)): + sample1 = dataset1[i] + sample2 = dataset2[i] + + for key in sample1: + torch.testing.assert_close(sample1[key], sample2[key]) + + +def test_AMPLIFYPreTrainingDataset_getitem_is_deterministic(dummy_protein_dataset): + """Test that the AMPLIFYPreTrainingDataset's __getitem__ method is deterministic.""" + + dataset = ProteinSQLiteDataset(dummy_protein_dataset) + clusters = [["UniRef90_A"], ["UniRef90_B", "UniRef90_C"]] + + dataset = AMPLIFYMaskedResidueDataset(protein_dataset=dataset, clusters=clusters, total_samples=10, seed=123) + + sample1 = dataset[8] + + for _ in range(10): + sample2 = dataset[8] + for key in sample1: + torch.testing.assert_close(sample1[key], sample2[key]) + + +def test_AMPLIFYPreTrainingDataset_getitem_differs_with_different_seeds(dummy_protein_dataset): + """Test that the AMPLIFYPreTrainingDataset's __getitem__ method is deterministic.""" + + dataset = ProteinSQLiteDataset(dummy_protein_dataset) + clusters = [["UniRef90_A"], ["UniRef90_B", "UniRef90_C"]] + + dataset1 = AMPLIFYMaskedResidueDataset(protein_dataset=dataset, clusters=clusters, total_samples=10, seed=123) + dataset2 = AMPLIFYMaskedResidueDataset(protein_dataset=dataset, clusters=clusters, total_samples=10, seed=321) + + for i in range(len(dataset)): + sample1 = dataset1[i] + sample2 = dataset2[i] + assert not torch.equal(sample1["text"], sample2["text"]) + + +def test_AMPLIFYPreTrainingDataset_getitem_changes_each_epoch(dummy_protein_dataset): + """Test that the AMPLIFYPreTrainingDataset's __getitem__ method is deterministic.""" + + dataset = ProteinSQLiteDataset(dummy_protein_dataset) + clusters = [["UniRef90_A"], ["UniRef90_B", "UniRef90_C"]] + + dataset = AMPLIFYMaskedResidueDataset(protein_dataset=dataset, clusters=clusters, total_samples=10, seed=123) + + sample1 = dataset[0] + sample2 = dataset[2] + assert len(sample1["text"]) == len(sample2["text"]) # These should both be UniRef90_A + assert not torch.equal(sample1["text"], sample2["text"]) + + sample1 = dataset[0] + sample2 = dataset[4] + assert len(sample1["text"]) == len(sample2["text"]) + assert not torch.equal(sample1["text"], sample2["text"]) + + +def test_AMPLIFYPreTrainingDataset_fails_with_empty_cluster(dummy_protein_dataset): + """Test that the AMPLIFYPreTrainingDataset's __getitem__ method is deterministic.""" + + dataset = ProteinSQLiteDataset(dummy_protein_dataset) + clusters = [["UniRef90_A"], [], ["UniRef90_B", "UniRef90_C"]] + + dataset = AMPLIFYMaskedResidueDataset(protein_dataset=dataset, clusters=clusters, total_samples=10, seed=123) + + with pytest.raises(ValueError, match="Cluster 1 is empty."): + dataset[1] + + +def test_AMPLIFYPreTrainingDataset_crops_out_start_and_end(dummy_protein_dataset, tokenizer): + prot_dataset = ProteinSQLiteDataset(dummy_protein_dataset) + clusters = [["UniRef90_A"]] + + dataset = AMPLIFYMaskedResidueDataset( + protein_dataset=prot_dataset, clusters=clusters, seed=123, total_samples=10, max_seq_length=1024 + ) + + assert len(dataset[0]["text"]) == len(prot_dataset["UniRef90_A"]) + 2 + assert dataset[0]["text"][0] == tokenizer.cls_token_id + assert dataset[0]["text"][-1] == tokenizer.eos_token_id + + dataset = AMPLIFYMaskedResidueDataset( + protein_dataset=prot_dataset, clusters=clusters, seed=123, total_samples=10, max_seq_length=3 + ) + + assert len(dataset[0]["text"]) == 3 + + # With a max length of 3, both the start and end tokens cant be present. 
+ assert not ((dataset[0]["text"][0] == tokenizer.cls_token_id) & (dataset[0]["text"][-1] == tokenizer.eos_token_id)) + + +def test_AMPLIFYPreTrainingDataset_raises_index_error_outside_bounds(dummy_protein_dataset): + """Test that the AMPLIFYPreTrainingDataset's __getitem__ method is deterministic.""" + + dataset = ProteinSQLiteDataset(dummy_protein_dataset) + clusters = [["UniRef90_A"], [], ["UniRef90_B", "UniRef90_C"]] + + dataset = AMPLIFYMaskedResidueDataset(protein_dataset=dataset, clusters=clusters, total_samples=10, seed=123) + + with pytest.raises(IndexError, match="Index 10 out of range \\[0, 10\\)."): + dataset[10] + + with pytest.raises(IndexError, match="Index -1 out of range \\[0, 10\\)."): + dataset[-1] + + +def test_create_train_dataset(dummy_protein_dataset, tmp_path): + cluster_file = pd.DataFrame( + { + "ur90_id": [["UniRef90_A"], ["UniRef90_B", "UniRef90_C"]], + } + ) + + cluster_file.to_parquet(tmp_path / "train_clusters.parquet") + + dataset = create_train_dataset( + cluster_file=tmp_path / "train_clusters.parquet", db_path=dummy_protein_dataset, total_samples=10, seed=123 + ) + assert len(dataset) == 10 + dataset[6] # Make sure it doesn't crash. + + +def test_create_valid_dataset(dummy_protein_dataset, tmp_path): + cluster_file = pd.DataFrame( + { + "ur50_id": ["UniRef90_A", "UniRef90_B", "UniRef90_C"], + } + ) + + cluster_file.to_parquet(tmp_path / "valid_clusters.parquet") + + dataset = create_valid_dataset( + clusters=tmp_path / "valid_clusters.parquet", db_path=dummy_protein_dataset, total_samples=10, seed=123 + ) + assert len(dataset) == 10 + dataset[6] # Make sure it doesn't crash. diff --git a/sub-packages/bionemo-amplify/tests/bionemo/amplify/data/test_tokenizer.py b/sub-packages/bionemo-amplify/tests/bionemo/amplify/data/test_tokenizer.py new file mode 100644 index 0000000000..163bca5a2d --- /dev/null +++ b/sub-packages/bionemo-amplify/tests/bionemo/amplify/data/test_tokenizer.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
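A hedged usage sketch of the masked-residue dataset exercised above. It assumes `ProteinSQLiteDataset` is importable from the same module as the other dataset helpers (as the tests expect), uses a placeholder database path, and takes its constructor arguments and sample keys directly from the tests.

```python
import torch

from bionemo.amplify.data.dataset import AMPLIFYMaskedResidueDataset, ProteinSQLiteDataset  # import location assumed
from bionemo.amplify.data.tokenizer import get_tokenizer

tokenizer = get_tokenizer()
protein_dataset = ProteinSQLiteDataset("protein_dataset.db")  # placeholder path to a UniRef SQLite database
clusters = [["UniRef90_A"], ["UniRef90_B", "UniRef90_C"]]     # each cluster is a list of member sequence ids

dataset = AMPLIFYMaskedResidueDataset(
    protein_dataset=protein_dataset,
    clusters=clusters,
    total_samples=10,     # virtual length; indices outside [0, 10) raise IndexError
    seed=123,             # same seed -> same masking; the sample index also perturbs the RNG
    max_seq_length=1024,  # sequences longer than this are cropped (see the crop test above)
)

sample = dataset[0]
# Unmasked positions are labelled -1; masked positions carry the original amino-acid token ids.
assert torch.all(sample["labels"][~sample["loss_mask"]] == -1)
assert sample["text"][0] == tokenizer.cls_token_id and sample["text"][-1] == tokenizer.eos_token_id
```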
+ + +import pytest +import torch + +from bionemo.amplify.data.tokenizer import get_tokenizer + + +@pytest.fixture +def tokenizer(): + return get_tokenizer() + + +def test_tokenize_protein1(tokenizer): + our_tokens = tokenizer.encode( + "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG", add_special_tokens=False + ) + + # fmt: off + amplify_tokens = torch.tensor( + [22, 17, 13, 9, 12, 18, 11, 12, 6, 17, 10, 14, 9, 12, 14, 6, 11, 12, + 10, 17, 11, 16, 9, 10, 8, 7, 18, 6, 7, 11, 11, 6, 10, 9, 10, 12, 18, + 9, 14, 9, 18, 15, 14, 7, 21, 6, 12, 10, 6, 8, 21, 19, 14, 9, 7, 13, + 16, 12, 8, 21, 9, 6, 7, 8, 8]) + # fmt: on + + torch.testing.assert_close(torch.tensor(our_tokens), amplify_tokens) + + +def test_tokenize_protein2(tokenizer): + our_tokens = tokenizer.encode( + "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE", add_special_tokens=False + ) + + # fmt: off + amplify_tokens = torch.tensor( + [17, 7, 6, 13, 7, 12, 18, 18, 11, 9, 20, 15, 6, 14, 12, 15, 23, 14, 10, + 18, 13, 8, 22, 16, 16, 13, 12, 7, 11, 14, 7, 18, 12, 6, 8, 20, 12, 10, + 16, 19, 7, 7, 11, 11, 23, 6, 17, 7, 6, 7, 12, 17, 8, 9, 14, 11, 14, 9, + 10, 8, 7, 10, 12, 8, 14, 12, 6, 6, 18, 11, 11]) + # fmt: on + + torch.testing.assert_close(torch.tensor(our_tokens), amplify_tokens) + + +def test_tokenize_protein2_with_mask(tokenizer): + our_tokens = tokenizer.encode( + "KALTARQQEVFDLIRDISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE", add_special_tokens=False + ) + + # fmt: off + amplify_tokens = torch.tensor( + [17, 7, 6, 13, 7, 12, 18, 18, 11, 9, 20, 15, 6, 14, 12, 15, 2, 14, 10, 18, + 13, 8, 22, 16, 16, 13, 12, 7, 11, 14, 7, 18, 12, 6, 8, 20, 12, 10, 16, 19, + 7, 7, 11, 11, 23, 6, 17, 7, 6, 7, 12, 17, 8, 9, 14, 11, 14, 9, 10, 8, 7, + 10, 12, 8, 14, 12, 6, 6, 18, 11, 11]) + # fmt: on + + torch.testing.assert_close(torch.tensor(our_tokens), amplify_tokens) + + +def test_tokenize_protein3(tokenizer): + our_tokens = tokenizer.encode("KAI SQ", add_special_tokens=False) + amplify_tokens = torch.tensor([17, 7, 2, 14, 1, 10, 18]) + torch.testing.assert_close(torch.tensor(our_tokens), amplify_tokens) + + +def test_tokenize_non_standard_tokens(tokenizer): + our_tokens = tokenizer.encode("".join(["", "", "", ""]), add_special_tokens=False) + amplify_tokens = torch.tensor([0, 4, 1, 2]) + torch.testing.assert_close(torch.tensor(our_tokens), amplify_tokens) + + +def test_tokenize_with_invalid_token(tokenizer): + assert tokenizer.encode("", add_special_tokens=False) == [1, 1, 1] + + +def test_tokenize_with_empty_string(tokenizer): + assert tokenizer.encode("", add_special_tokens=True) == [3, 4] diff --git a/sub-packages/bionemo-amplify/tests/bionemo/amplify/model/__init__.py b/sub-packages/bionemo-amplify/tests/bionemo/amplify/model/__init__.py new file mode 100644 index 0000000000..25e6abfbc5 --- /dev/null +++ b/sub-packages/bionemo-amplify/tests/bionemo/amplify/model/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sub-packages/bionemo-amplify/tests/bionemo/amplify/model/test_attention.py b/sub-packages/bionemo-amplify/tests/bionemo/amplify/model/test_attention.py new file mode 100644 index 0000000000..e084e6b466 --- /dev/null +++ b/sub-packages/bionemo-amplify/tests/bionemo/amplify/model/test_attention.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math + +import pytest +import torch +from megatron.core.transformer.enums import AttnMaskType + +from bionemo.amplify.api import AMPLIFYConfig +from bionemo.amplify.model.attention import AMPLIFYDotProductAttention +from bionemo.testing import megatron_parallel_state_utils + + +@pytest.fixture(scope="module") +def config() -> AMPLIFYConfig: + with megatron_parallel_state_utils.distributed_model_parallel_state(): + yield AMPLIFYConfig( + seq_length=20, hidden_size=4, num_attention_heads=4, attention_dropout=0.1, use_amplify_attention=True + ) + + +@pytest.fixture +def attention_layer(config): + return AMPLIFYDotProductAttention( + config=config, + layer_number=0, + attn_mask_type=AttnMaskType.padding, + attention_type="normal", + ).eval() + + +def test_init(attention_layer, config): + assert attention_layer.config.use_amplify_attention + assert attention_layer.config == config + + +def test_forward(attention_layer, config): + batch_size = 2 + sequence_length = config.seq_length + hidden_size = config.hidden_size + device = torch.device("cuda") + + query = torch.randn(sequence_length, batch_size, 1, hidden_size, device=device) + key = torch.randn(sequence_length, batch_size, 1, hidden_size, device=device) + value = torch.randn(sequence_length, batch_size, 1, hidden_size, device=device) + random_ints = torch.randint(0, 2, (batch_size, 1, sequence_length, sequence_length), device=device) + attention_mask = ((random_ints + torch.transpose(random_ints, dim0=2, dim1=3)) / 2).to( + dtype=torch.bool + ) # symmetric mask tensor + + output = attention_layer(query, key, value, attention_mask) + assert output.shape == (sequence_length, batch_size, hidden_size) + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.half]) +def test_attention_with_mask(attention_layer, dtype): + sequence_length_val = 3 + sequence_length_query = 1 + batch_size = 2 + emb_dim = 4 + device = torch.device("cuda") + + # query and key such that the dot prod is an all-ones tensor + query = torch.ones(batch_size, sequence_length_query, 1, emb_dim, device=device, dtype=dtype) / math.sqrt(emb_dim) + key = torch.ones(batch_size, sequence_length_val, 1, emb_dim, device=device, dtype=dtype) / math.sqrt(emb_dim) + + query = query.transpose(0, 1) + key = key.transpose(0, 1) + + attention_mask = torch.zeros(batch_size, 1, 1, sequence_length_val, device=device, dtype=dtype) + attention_mask[0, 
:, :, 2:] = 1 # average first two tensors in val + attention_mask[1, :, :, 1:] = 1 # select first item from val + + values = torch.stack([torch.arange(sequence_length_val)] * batch_size).to(device=device, dtype=dtype) + 1.0 + values = torch.stack([values] * emb_dim, dim=2).unsqueeze(2).transpose(0, 1) + + assert values.shape == (sequence_length_val, batch_size, 1, emb_dim) + + # softmax will make the the avg first 2 tensors in vals (ones + twos)/2 and second row is just ones + output = attention_layer(query, key, values, attention_mask) + expected_output = torch.tensor( + [[[1.5000, 1.5000, 1.5000, 1.5000], [1.0000, 1.0000, 1.0000, 1.0000]]], device=device, dtype=dtype + ) + assert torch.equal(output, expected_output) + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.half]) +def test_amplify_scale_mask_softmax(attention_layer, config, dtype): + batch_size = 2 + sequence_length = config.seq_length + num_attention_heads = config.num_attention_heads + + input_tensor = torch.randn(batch_size, num_attention_heads, sequence_length, sequence_length, dtype=dtype) + + output = attention_layer.amplify_scale_mask_softmax(input_tensor) + assert output.shape == input_tensor.shape + assert output.dtype == dtype diff --git a/sub-packages/bionemo-amplify/tests/bionemo/amplify/model/test_embedding.py b/sub-packages/bionemo-amplify/tests/bionemo/amplify/model/test_embedding.py new file mode 100644 index 0000000000..856ad61991 --- /dev/null +++ b/sub-packages/bionemo-amplify/tests/bionemo/amplify/model/test_embedding.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
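The masked-attention test above can be sanity-checked by hand: the query/key construction makes every raw score identical, so the masked softmax collapses to a uniform average over the value vectors that survive the mask. A standalone re-derivation in plain PyTorch (not the Megatron attention class):

```python
import torch

batch_size, s_kv, emb = 2, 3, 4
scores = torch.ones(batch_size, 1, s_kv)                # every query-key dot product is 1, as in the test

# True marks key positions that are masked out (the padding convention used in the test).
mask = torch.tensor([[[False, False, True]],            # batch 0: keep keys 0 and 1
                     [[False, True, True]]])            # batch 1: keep key 0 only

# Value vectors [1,1,1,1], [2,2,2,2], [3,3,3,3] for each batch element.
values = (torch.arange(s_kv, dtype=torch.float32) + 1.0).view(1, s_kv, 1).expand(batch_size, s_kv, emb)

weights = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
context = weights @ values                              # [b, 1, emb]

expected = torch.tensor([[[1.5, 1.5, 1.5, 1.5]],        # mean of value rows 0 and 1
                         [[1.0, 1.0, 1.0, 1.0]]])       # value row 0 only
assert torch.allclose(context, expected)
```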
+ + +import pytest +import torch + +from bionemo.amplify.api import AMPLIFYConfig +from bionemo.amplify.data.tokenizer import BioNeMoAutoTokenizer, get_tokenizer +from bionemo.amplify.model.embedding import AMPLIFY_MASK_RATIO_TRAIN, AMPLIFYEmbedding +from bionemo.llm.lightning import get_dtype_device +from bionemo.testing import megatron_parallel_state_utils + + +@pytest.fixture(scope="module") +def tokenizer() -> BioNeMoAutoTokenizer: + yield get_tokenizer() + + +@pytest.fixture(scope="module") +def embedding(tokenizer) -> AMPLIFYEmbedding: + with megatron_parallel_state_utils.distributed_model_parallel_state(): + config = AMPLIFYConfig(seq_length=20, hidden_size=128) + model = config.configure_model(tokenizer) + yield model.embedding + + +def test_init(embedding, tokenizer): + assert isinstance(embedding, AMPLIFYEmbedding) + assert embedding.token_dropout is True + assert embedding.use_attention_mask is True + assert embedding.mask_token_id == tokenizer.mask_token_id + + +def test_forward(embedding): + _, device = get_dtype_device(embedding) + vocab_size = embedding.vocab_size + max_sequence_length = embedding.max_sequence_length + + input_ids = torch.randint(0, vocab_size, (2, 10), device=device) # [b, s] + position_ids = torch.randint(0, max_sequence_length, (2, 10), device=device) # [b, s] + attention_mask = torch.randint(0, 2, (2, 10), device=device, dtype=torch.bool) # [b, s, s] + output = embedding(input_ids, position_ids, attention_mask=attention_mask) + assert output.shape == (10, 2, 128) # [s, b, h] + + +def test_apply_amplify_customization(embedding): + # Create mock input tensors + batch_size = 2 + sequence_length = 5 + hidden_size = embedding.config.hidden_size + mask_token_id = embedding.mask_token_id + + input_ids = torch.tensor([[1, 2, 3, mask_token_id, 5], [6, 7, 8, 9, mask_token_id]]) # (b, s) + attention_mask = torch.tensor([[1, 0, 1, 1, 1], [1, 1, 1, 0, 1]], dtype=torch.bool) # (b, s, s) + word_embeddings = torch.randn(batch_size, sequence_length, hidden_size) # (b, s, h) + + # Call the _apply_amplify_customization function + output_embeddings, embeddings_mask = embedding._apply_amplify_customization( + word_embeddings, input_ids, attention_mask + ) + + # Check the output shapes + assert output_embeddings.shape == (batch_size, sequence_length, hidden_size) + assert embeddings_mask.shape == (batch_size, sequence_length) + + # Check the token dropout and attention masking logic + assert torch.allclose(output_embeddings[0, 3, :], torch.zeros_like(output_embeddings[0, 3, :])) + assert torch.allclose(output_embeddings[1, 4, :], torch.zeros_like(output_embeddings[1, 4, :])) + assert torch.allclose(embeddings_mask[0, 1], torch.zeros_like(embeddings_mask[0, 1])) + assert torch.allclose(embeddings_mask[1, 3], torch.zeros_like(embeddings_mask[1, 3])) + + # Check the mask ratio calculation + mask_ratio_observed = (input_ids == mask_token_id).sum(-1).float() / attention_mask.sum(-1) + assert torch.allclose(mask_ratio_observed, torch.tensor([0.25, 0.25])) + + # Check the word embeddings scaling + scale_factor = (1 - AMPLIFY_MASK_RATIO_TRAIN) / (1 - mask_ratio_observed)[:, None, None] + word_embeddings = word_embeddings.masked_fill((input_ids == mask_token_id).unsqueeze(-1), 0.0) + assert torch.allclose(output_embeddings, word_embeddings * scale_factor * embeddings_mask.unsqueeze(-1)) + + # Check the attention masking + assert torch.equal(embeddings_mask, attention_mask) diff --git a/sub-packages/bionemo-amplify/tests/bionemo/amplify/model/test_lr_scheduler.py 
b/sub-packages/bionemo-amplify/tests/bionemo/amplify/model/test_lr_scheduler.py new file mode 100644 index 0000000000..d7c4ad5070 --- /dev/null +++ b/sub-packages/bionemo-amplify/tests/bionemo/amplify/model/test_lr_scheduler.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from bionemo.amplify.model.lr_scheduler import WarmupAnnealDecayHold, WarmupAnnealDecayHoldScheduler + + +def test_warmup_anneal_decay_hold_scheduler_exists(): + scheduler = WarmupAnnealDecayHoldScheduler(warmup_steps=2000, min_lr=4e-5, max_steps=500000, max_lr=4e-4) + assert scheduler is not None + assert scheduler.max_steps == 500000 + assert scheduler.warmup_steps == 2000 + assert scheduler.max_lr == 4e-4 + assert scheduler.min_lr == 4e-5 + + +def test_warmup_anneal_decay_hold_works(): + optim = torch.optim.Adam(torch.nn.Linear(10, 1).parameters(), lr=4e-4, weight_decay=0.01, betas=[0.9, 0.98]) + max_lr = 0.1 + min_lr = 0.01 + anneal_percentage = 0.50 + constant_value = anneal_percentage * max_lr + scheduler = WarmupAnnealDecayHold( + optimizer=optim, + warmup_steps=20, + min_lr=min_lr, + max_steps=100, + max_lr=max_lr, + anneal_percentage=anneal_percentage, + ) + + assert scheduler.get_lr()[0] == min_lr + # Check initial LR + for _ in range(20): + scheduler.step() + # Check warmup phase + assert scheduler.get_lr()[0] == max_lr + + # Check decay is lower than max + for _ in range(20): + scheduler.step() + + decay_lr = scheduler.get_lr()[0] + # Check decay is lower than last decay + assert decay_lr < max_lr + + # Keep decay stepping + for _ in range(20): + scheduler.step() + + decay_low = scheduler.get_lr()[0] + assert decay_low < decay_lr + assert decay_low == constant_value + + for _ in range(30): + scheduler.step() + + assert scheduler.get_lr()[0] == constant_value + + # Check hold phase. Run it much longer and confirm + for _ in range(300): + scheduler.step() + + assert scheduler.get_lr()[0] == constant_value diff --git a/sub-packages/bionemo-amplify/tests/bionemo/amplify/model/test_model.py b/sub-packages/bionemo-amplify/tests/bionemo/amplify/model/test_model.py new file mode 100644 index 0000000000..409bcdee4d --- /dev/null +++ b/sub-packages/bionemo-amplify/tests/bionemo/amplify/model/test_model.py @@ -0,0 +1,281 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tarfile +from copy import deepcopy +from pathlib import Path +from typing import List, Tuple +from unittest import mock + +import pytest +import torch +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer +from torch import Tensor +from transformers import AMPLIFYForMaskedLM + +from bionemo import amplify +from bionemo.core.utils.dtypes import get_autocast_dtype +from bionemo.core.utils.random_utils import random_numpy_context +from bionemo.amplify.api import AMPLIFYConfig, AMPLIFYModel +from bionemo.amplify.data.datamodule import AMPLIFYDataModule +from bionemo.amplify.data.tokenizer import get_tokenizer +from bionemo.amplify.model.embedding import AMPLIFYEmbedding +from bionemo.llm.model.biobert.model import MegatronBioBertModel +from bionemo.llm.utils.weight_utils import nemo1_to_nemo2_biobert_key_mapping +from bionemo.testing import megatron_parallel_state_utils +from bionemo.testing.data.load import load + + +bionemo2_root: Path = ( + # amplify module's path is the most dependable --> don't expect this to change! + Path(amplify.__file__) + # This gets us from 'sub-packages/bionemo-amplify/src/bionemo/amplify/__init__.py' to 'sub-packages/bionemo-amplify' + .parent.parent.parent.parent + # From here, we want to get to the root of the repository: _before_ sub-packages/ + .parent.parent +).absolute() +assert bionemo2_root != Path("/") +nemo1_checkpoint_path: Path = load("amplify/nv_650m:1.0") + + +def reduce_hiddens(hiddens: Tensor, attention_mask: Tensor) -> Tensor: + """reduce last layer's hidden values to embeddings + + Args: + hiddens: [b, s, h] tensor of hidden values + attention_mask: [b, s] attention mask tensor + + Returns: + reduced embedding tensor [b, h] + """ + masks = torch.sum(attention_mask, dim=1) + embeddings = torch.zeros( + size=(hiddens.shape[0], hiddens.shape[2]), + dtype=torch.float32, + device=torch.cuda.current_device(), + ) + for i, (hidden, mask) in enumerate(zip(hiddens, masks)): + embeddings[i, :] = torch.mean(hidden[1 : mask - 1], dim=0) + return embeddings + + +@pytest.fixture(scope="module") +def amplify_config() -> AMPLIFYConfig: + with megatron_parallel_state_utils.distributed_model_parallel_state(): + yield AMPLIFYConfig() + + +@pytest.fixture(scope="module") +def amplify_650M_config_w_ckpt() -> AMPLIFYConfig: + with megatron_parallel_state_utils.distributed_model_parallel_state(): + yield AMPLIFYConfig(nemo1_ckpt_path=nemo1_checkpoint_path) + + +@pytest.fixture(scope="module") +def amplify_model(amplify_config) -> AMPLIFYModel: + with megatron_parallel_state_utils.distributed_model_parallel_state(): + tokenizer = get_tokenizer() + model = amplify_config.configure_model(tokenizer) + yield model + + +@pytest.fixture(scope="module") +def sample_data() -> List[Tuple[str, str]]: + """Generates sample protein sequences for sanity checks, including mask tokens.""" + max_length = 1022 # The maximum length of the protein sequences to be considered. 
+ sample_data = [ + ( + "protein1", + "MNGTEGPNFYVPFSNATGVVRSPFEYPQYYLAEPWQFSMLAAYMFLLIVLGFPINFLTLYVTVQHKKLRTPLNYILLNLAVADLFMVLGGFTSTLYTSLHGYFVFGPTGCNLEGFFATLGGEIALWSLVVLAIERYVVVCKPMSNFRFGENHAIMGVAFTWVMALACAAPPLAGWSRYIPEGLQCSCGIDYYTLKPEVNNESFVIYMFVVHFTIPMIIIFFCYGQLVFTVKEAAAQQQESATTQKAEKEVTRMVIIMVIAFLICWVPYASVAFYIFTHQGSNFGPIFMTIPAFFAKSAAIYNPVIYIMMNKQFRNCMLTTICCGKNPLGDDEASATVSKTETSQVAPA", + ), + ("protein2", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA"), + ( + "protein3", + "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLAGG", + ), + ( + "protein4", + "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLA", + ), + ] + # add another sample protein that uses the maximum length to test this edge case + sample_data.append(("protein5", (sample_data[0][1] * 3)[:max_length])) + yield sample_data + + +def _compute_loss(model, dataloader, vocab_size=None): + loss = 0 + n = 0 + limit_batches = 10 + for i, batch in enumerate(dataloader): + assert isinstance(batch, dict) + result = model(input_ids=batch["text"].cuda(), attention_mask=batch["attention_mask"].cuda()) + + # bionemo AMPLIFY vocab_size + if vocab_size is not None: + logits = result["token_logits"][..., :vocab_size] + else: + logits = result.logits + + loss_mask = batch["loss_mask"].cuda() + target = batch["labels"].cuda() + + loss += torch.nn.functional.cross_entropy(logits[loss_mask].float(), target[loss_mask], reduction="sum") + n += loss_mask.sum() + + if limit_batches is not None and i + 1 >= limit_batches: + break + mean_loss: Tensor = loss / n + return mean_loss + + +def test_amplify_model_initialized(amplify_model): + assert isinstance(amplify_model, MegatronBioBertModel) + assert isinstance(amplify_model, AMPLIFYModel) + assert isinstance(amplify_model.embedding, AMPLIFYEmbedding) + + +def test_amplify_650m_checkpoint(amplify_model): + with tarfile.open(nemo1_checkpoint_path, "r") as ckpt, torch.no_grad(): + ckpt_file = ckpt.extractfile("./model_weights.ckpt") + + old_state_dict = torch.load(ckpt_file) + # megatron is not registering inv_freq params anymore. + # TODO: update Bionemo checkpoints + old_state_dict.pop("model.language_model.rotary_pos_emb.inv_freq") + + new_state_dict = amplify_model.state_dict_for_save_checkpoint() + + # Set the new_model_prefix to "" since we are looking at the base megatron model and not the lightning module which stores a copy of + # this model into self.module + old_keys = {nemo1_to_nemo2_biobert_key_mapping(k, new_model_prefix="") for k in old_state_dict} + assert len(old_keys) == len(old_state_dict), "Mapping unexpectedly discarded some keys." + + new_keys = set(new_state_dict) + for k, v in old_state_dict.items(): + # Make sure the shapes of the weights match. + assert new_state_dict[nemo1_to_nemo2_biobert_key_mapping(k, new_model_prefix="")].shape == v.shape + + extra_keys = new_keys.difference(old_keys) + extra_non_null_keys = {k for k in extra_keys if new_state_dict[k] is not None} + assert not extra_non_null_keys, "There are new keys that have state that is missing from the old checkpoint." + + missing_old_keys = old_keys.difference(new_keys) + assert not missing_old_keys, "There are keys in the old checkpoint that are missing from the new model." 
+ + +def test_amplify_golden_values(amplify_650M_config_w_ckpt, sample_data): + assert amplify_650M_config_w_ckpt.core_attention_override is not None + tokenizer = AutoTokenizer(pretrained_model_name="facebook/amplify_t33_650M_UR50D") + tokens = tokenizer.tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True).to("cuda") + input_ids = tokens["input_ids"] + attention_mask = tokens["attention_mask"] + + # HF 650M model + hf_model = AMPLIFYForMaskedLM.from_pretrained( + "facebook/amplify_t33_650M_UR50D", torch_dtype=get_autocast_dtype(32) + ).cuda() + + with torch.no_grad(): + hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True) + hf_logits = hf_output_all.logits * attention_mask.unsqueeze(-1) + hf_embeddings = reduce_hiddens(hf_output_all.hidden_states[-1], attention_mask) + + # free GPU RAM + del hf_model + gc.collect() + torch.cuda.empty_cache() + + # configure the model to return logits + model = amplify_650M_config_w_ckpt.configure_model(get_tokenizer()).cuda() + model.eval() + result = model(input_ids, attention_mask) + logits = result["token_logits"][..., : tokenizer.vocab_size] + logits = logits * attention_mask.unsqueeze(-1) # incorporate masking logic + + # free GPU RAM + del model + gc.collect() + torch.cuda.empty_cache() + + # configure the model to return hiddens + amplify_650M_config_hiddens = deepcopy(amplify_650M_config_w_ckpt) + amplify_650M_config_hiddens.set_hparam("return_only_hidden_states", True) + model = amplify_650M_config_hiddens.configure_model(get_tokenizer()).cuda() + model.eval() + hiddens = model(input_ids, attention_mask) + embeddings = reduce_hiddens(torch.transpose(hiddens, 0, 1).float(), attention_mask) + + torch.testing.assert_close(logits, hf_logits, atol=9e-2, rtol=0.0) + torch.testing.assert_close(embeddings, hf_embeddings, atol=5e-3, rtol=0.0) + + +def test_amplify_loss(amplify_650M_config_w_ckpt, dummy_protein_dataset, dummy_parquet_train_val_inputs): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + compute_hf_reference: bool = False + seed: int = 42 + + with ( + torch.inference_mode(), + megatron_parallel_state_utils.distributed_model_parallel_state(seed), + random_numpy_context(seed), + ): + tokenizer = get_tokenizer() + + # AMPLIFY model initialized with 650M params + model = amplify_650M_config_w_ckpt.configure_model(tokenizer).cuda() + + # Initialize the data module. 
+ data_module = AMPLIFYDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + global_batch_size=8, + micro_batch_size=4, + min_seq_length=None, + max_seq_length=1024, + seed=seed, + ) + assert data_module is not None + data_module.trainer = mock.Mock() + data_module.trainer.max_epochs = 1 + data_module.trainer.max_steps = 10 + data_module.trainer.val_check_interval = 2 + data_module.trainer.limit_val_batches = 1 + + data_module.setup() + + train_dataloader = data_module.train_dataloader() + assert isinstance(train_dataloader, torch.utils.data.DataLoader) + + val_dataloader = data_module.val_dataloader() + assert isinstance(val_dataloader, torch.utils.data.DataLoader) + + mean_loss = _compute_loss(model, train_dataloader, vocab_size=tokenizer.vocab_size) + + if compute_hf_reference: + # HF model initialized with 650M params + hf_model = AMPLIFYForMaskedLM.from_pretrained( + "facebook/amplify_t33_650M_UR50D", torch_dtype=get_autocast_dtype(32) + ).cuda() + hf_mean_loss = _compute_loss(hf_model, train_dataloader) + print(f"hf_mean_loss: {hf_mean_loss}") + else: + hf_mean_loss = torch.tensor(3.0298714637756348).cuda() + + torch.testing.assert_close(mean_loss, hf_mean_loss, atol=1e-4, rtol=0.0) diff --git a/tach.toml b/tach.toml index 8bb10ef323..838057b798 100644 --- a/tach.toml +++ b/tach.toml @@ -8,6 +8,7 @@ exclude = [ "build", ] source_roots = [ + 'sub-packages/bionemo-amplify/src', 'sub-packages/bionemo-core/src', 'sub-packages/bionemo-esm2/src', 'sub-packages/bionemo-example_model/src', @@ -21,6 +22,14 @@ source_roots = [ 'sub-packages/bionemo-webdatamodule/src', ] +[[modules]] +path = "bionemo.amplify" +depends_on = [ + { path = "bionemo.core" }, + { path = "bionemo.esm2" }, + { path = "bionemo.llm" }, +] + [[modules]] path = "bionemo.core" depends_on = [] @@ -42,6 +51,7 @@ depends_on = [ [[modules]] path = "bionemo.fw" depends_on = [ + { path = "bionemo.amplify" }, { path = "bionemo.core" }, { path = "bionemo.esm2" }, { path = "bionemo.geneformer" },