diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 52c76dc8b2..8b9ec529c0 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -47,6 +47,7 @@ "./sub-packages/bionemo-geneformer/src", "./sub-packages/bionemo-llm/src", "./sub-packages/bionemo-testing/src", + "./sub-packages/bionemo-amplify/src", "./sub-packages/bionemo-example_model/src", "./3rdparty/NeMo", "./3rdparty/Megatron-LM" diff --git a/Dockerfile b/Dockerfile index 6642f3c5e7..dcae4b6a54 100644 --- a/Dockerfile +++ b/Dockerfile @@ -38,7 +38,7 @@ EOF # Reinstall TE to avoid debugpy bug in vscode: https://nvbugspro.nvidia.com/bug/5078830 # Pull the latest TE version from https://github.com/NVIDIA/TransformerEngine/releases # Use the version that matches the pytorch base container. -ARG TE_TAG=v1.13 +ARG TE_TAG=2215fa5c7557b66034068816020f9f611019e457 RUN NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi \ pip --disable-pip-version-check --no-cache-dir install \ git+https://github.com/NVIDIA/TransformerEngine.git@${TE_TAG} @@ -48,10 +48,13 @@ RUN NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi \ RUN CAUSAL_CONV1D_FORCE_BUILD=TRUE pip --disable-pip-version-check --no-cache-dir install \ git+https://github.com/Dao-AILab/causal-conv1d.git@v1.2.2.post1 -# Mamba dependancy installation +# Mamba dependency installation RUN pip --disable-pip-version-check --no-cache-dir install \ git+https://github.com/state-spaces/mamba.git@v2.2.2 +ARG XFORMER_ENGINE_TAG=v0.0.29.post1 +RUN pip install -v -U git+https://github.com/facebookresearch/xformers.git@${XFORMER_ENGINE_TAG}#egg=xformers + RUN pip install hatchling # needed to install nemo-run ARG NEMU_RUN_TAG=34259bd3e752fef94045a9a019e4aaf62bd11ce2 RUN pip install nemo_run@git+https://github.com/NVIDIA/NeMo-Run.git@${NEMU_RUN_TAG} @@ -100,7 +103,7 @@ COPY ./sub-packages /workspace/bionemo2/sub-packages RUN --mount=type=bind,source=./.git,target=./.git \ --mount=type=bind,source=./requirements-test.txt,target=/requirements-test.txt \ --mount=type=bind,source=./requirements-cve.txt,target=/requirements-cve.txt \ - --mount=type=cache,target=/root/.cache <=64", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "bionemo-amplify" +readme = "README.md" +description = "" +authors = [{ name = "BioNeMo Team", email = "bionemofeedback@nvidia.com" }] +requires-python = ">=3.10" +license = { file = "LICENSE" } +dynamic = ["version"] +dependencies = [ + # internal + 'bionemo-core', + 'bionemo-llm', + 'bionemo-esm2', + # external + # 'xformers' +] + +[tool.setuptools.packages.find] +where = ["src"] +include = ["bionemo.*"] +namespaces = true +exclude = ["test*."] + +[tool.uv] +cache-keys = [{ git = true }] + +[tool.setuptools.dynamic] +version = { file = "VERSION" } diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/__init__.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/__init__.py new file mode 100644 index 0000000000..25e6abfbc5 --- /dev/null +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/convert.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/convert.py
new file mode 100644
index 0000000000..d4ac281ab1
--- /dev/null
+++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/convert.py
@@ -0,0 +1,163 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from pathlib import Path
+
+import torch
+from nemo.lightning import io, teardown
+from nemo.lightning.pytorch.utils import dtype_from_hf
+from transformers import AutoConfig as HFAutoConfig
+from transformers import AutoModel
+
+from bionemo.amplify.model import AMPLIFYConfig
+from bionemo.amplify.tokenizer import BioNeMoAMPLIFYTokenizer
+from bionemo.llm.lightning import BionemoLightningModule
+from bionemo.llm.model.biobert.lightning import biobert_lightning_module
+
+
+@io.model_importer(BionemoLightningModule, "hf")
+class HFAMPLIFYImporter(io.ModelConnector[AutoModel, BionemoLightningModule]):
+    """Converts a Hugging Face AMPLIFY model to a NeMo AMPLIFY model."""
+
+    def init(self) -> BionemoLightningModule:
+        """Initialize the converted model."""
+        return biobert_lightning_module(self.config, tokenizer=self.tokenizer)
+
+    def apply(self, output_path: Path) -> Path:
+        """Applies the transformation."""
+        source = AutoModel.from_pretrained(str(self), trust_remote_code=True, torch_dtype="auto")
+        target = self.init()
+        trainer = self.nemo_setup(target)
+        self.convert_state(source, target)
+        self.nemo_save(output_path, trainer)
+        teardown(trainer, target)
+        return output_path
+
+    def convert_state(self, source, target):
+        """Convert the HF state dict to the NeMo state dict."""
+        mapping = {
+            "encoder.weight": "embedding.word_embeddings.weight",
+            "transformer_encoder.*.wo.weight": "encoder.layers.*.self_attention.linear_proj.weight",
+            "transformer_encoder.*.ffn.w12.weight": "encoder.layers.*.mlp.linear_fc1.weight",
+            "transformer_encoder.*.ffn.w3.weight": "encoder.layers.*.mlp.linear_fc2.weight",
+            "transformer_encoder.*.attention_norm.weight": "encoder.layers.*.self_attention.linear_qkv.layer_norm_weight",
+            "transformer_encoder.*.ffn_norm.weight": "encoder.layers.*.mlp.linear_fc1.layer_norm_weight",
+            "layer_norm_2.weight": "encoder.final_layernorm.weight",
+            "decoder.weight": "output_layer.weight",
+            "decoder.bias": "output_layer.bias",
+        }
+
+        # lm_head.bias
+        return io.apply_transforms(
+            source,
+            target,
+            mapping=mapping,
+            transforms=[_import_qkv_weight],
+            # transforms=[_pad_embeddings, _pad_bias, _import_qkv_weight],
+        )
+
+    @property
+    def tokenizer(self) -> BioNeMoAMPLIFYTokenizer:
+        """We just have the one tokenizer for AMPLIFY."""
+        return BioNeMoAMPLIFYTokenizer()
+
+    @property
+    def config(self) -> AMPLIFYConfig:
+        """Returns the transformed AMPLIFY config given the model tag."""
+        source = HFAutoConfig.from_pretrained(str(self), trust_remote_code=True)
+        output = AMPLIFYConfig(
+            num_layers=source.num_hidden_layers,
+            hidden_size=source.hidden_size,
+            ffn_hidden_size=source.intermediate_size,
+            position_embedding_type="rope",
+            num_attention_heads=source.num_attention_heads,
+            seq_length=source.max_length,
+            fp16=(dtype_from_hf(source) == torch.float16),
+            bf16=(dtype_from_hf(source) == torch.bfloat16),
+            params_dtype=dtype_from_hf(source),
+        )
+
+        return output
+
+
+@io.state_transform(
+    source_key="esm.embeddings.word_embeddings.weight",
+    target_key="embedding.word_embeddings.weight",
+)
+def _pad_embeddings(ctx: io.TransformCTX, source_embed):
+    """Pad the embedding layer to the new input dimension."""
+    nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by
+    hf_embedding_dimension = source_embed.size(0)
+    num_padding_rows = nemo_embedding_dimension - hf_embedding_dimension
+    padding_rows = torch.zeros(num_padding_rows, source_embed.size(1))
+    return torch.cat((source_embed, padding_rows), dim=0)
+
+
+@io.state_transform(
+    source_key="lm_head.bias",
+    target_key="output_layer.bias",
+)
+def _pad_bias(ctx: io.TransformCTX, source_bias):
+    """Pad the output bias to the new vocabulary dimension."""
+    nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by
+    hf_embedding_dimension = source_bias.size(0)
+    output_bias = torch.zeros(nemo_embedding_dimension, dtype=source_bias.dtype, device=source_bias.device)
+    output_bias[:hf_embedding_dimension] = source_bias
+    return output_bias
+
+
+@io.state_transform(
+    source_key=(
+        "transformer_encoder.*.q.weight",
+        "transformer_encoder.*.k.weight",
+        "transformer_encoder.*.v.weight",
+    ),
+    target_key="encoder.layers.*.self_attention.linear_qkv.weight",
+)
+def _import_qkv_weight(ctx: io.TransformCTX, query, key, value):
+    """Interleave the HF query, key, and value weights into Megatron's fused QKV weight layout."""
+    concat_weights = torch.cat((query, key, value), dim=0)
+    input_shape = concat_weights.size()
+    np = ctx.target.config.num_attention_heads
+    # transpose weights
+    # [sequence length, batch size, num_splits_model_parallel * attention head size * #attention heads]
+    # --> [sequence length, batch size, attention head size * num_splits_model_parallel * #attention heads]
+    concat_weights = concat_weights.view(3, np, -1, query.size()[-1])
+    concat_weights = concat_weights.transpose(0, 1).contiguous()
+    concat_weights = concat_weights.view(*input_shape)
+    return concat_weights
+
+
+@io.state_transform(
+    source_key=(
+        "esm.encoder.layer.*.attention.self.query.bias",
+        "esm.encoder.layer.*.attention.self.key.bias",
+        "esm.encoder.layer.*.attention.self.value.bias",
+    ),
+    target_key="encoder.layers.*.self_attention.linear_qkv.bias",
)
+def _import_qkv_bias(ctx: io.TransformCTX, query, key, value):
+    """Interleave the HF query, key, and value biases into Megatron's fused QKV bias layout."""
+    concat_biases = torch.cat((query, key, value), dim=0)
+    input_shape = concat_biases.size()
+    np = ctx.target.config.num_attention_heads
+    # transpose biases
+    # [num_splits_model_parallel * attention head size * #attention heads]
+    # --> [attention head size * num_splits_model_parallel * #attention heads]
+    concat_biases = concat_biases.view(3, np, -1)
+    concat_biases =
concat_biases.transpose(0, 1).contiguous() + concat_biases = concat_biases.view(*input_shape) + return concat_biases diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/hf_rotary.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/hf_rotary.py new file mode 100644 index 0000000000..18c55f9bdc --- /dev/null +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/hf_rotary.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Code copied from https://huggingface.co/chandar-lab/AMPLIFY_350M/blob/main/rotary.py + + +from typing import Tuple + +import torch + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + """Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + return torch.polar(torch.ones_like(freqs), freqs) # complex64 + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + """Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + + Raises: + AssertionError: If the frequency tensor doesn't match the expected shape. + AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. + """ + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. 
The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/model.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/model.py new file mode 100644 index 0000000000..3b76881c73 --- /dev/null +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/model.py @@ -0,0 +1,329 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Any, Callable, Literal, Optional, Sequence, Type, TypeVar + +import torch +from megatron.core import tensor_parallel +from megatron.core.models.bert.pooler import Pooler +from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from megatron.core.transformer import spec_utils +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.utils import get_linear_layer +from torch import Tensor +from torch.nn.functional import silu +from torch.optim import Optimizer + +# from bionemo.amplify.data.tokenizer import BioNeMoAMPLIFYTokenizer +from bionemo.llm.api import MegatronLossType +from bionemo.llm.model.biobert.model import BioBertConfig, MegatronBioBertModel, PositionEmbeddingKinds +from bionemo.llm.model.biobert.transformer_specs import BiobertSpecOption +from bionemo.llm.utils import iomixin_utils as iom + + +__all__: Sequence[str] = ( + "AMPLIFYConfig", + "AMPLIFYModel", +) + + +class AMPLIFYLMHead(MegatronModule): + """LM head for AMPLIFY. 
+
+    Args:
+        config (TransformerConfig): TransformerConfig object
+    """
+
+    def __init__(self, config: TransformerConfig):
+        super().__init__(config=config)
+        self.head = IdentityOp()
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        return self.head(hidden_states)
+
+
+class AMPLIFYModel(MegatronBioBertModel):
+    """AMPLIFY protein language model."""
+
+    def __init__(
+        self,
+        config: TransformerConfig,
+        num_tokentypes: int,
+        transformer_layer_spec: spec_utils.ModuleSpec,
+        vocab_size: int,
+        max_sequence_length: int,
+        tokenizer: Optional[Any] = None,
+        pre_process: bool = True,
+        post_process: bool = True,
+        fp16_lm_cross_entropy: bool = False,
+        parallel_output: bool = True,
+        share_embeddings_and_output_weights: bool = False,
+        position_embedding_type: Literal["learned_absolute", "rope"] = "rope",
+        rotary_percent: float = 1.0,
+        seq_len_interpolation_factor: Optional[float] = None,
+        add_binary_head: bool = True,
+        return_embeddings: bool = False,
+        include_embeddings: bool = False,
+        include_input_ids: bool = False,
+        use_full_attention_mask: bool = False,
+        include_hiddens: bool = False,
+        skip_logits: bool = False,
+    ) -> None:
+        """Initialize the AMPLIFY model.
+
+        Args:
+            config (TransformerConfig): transformer config
+            num_tokentypes (int): Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0.
+            transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
+            vocab_size (int): vocabulary size
+            max_sequence_length (int): maximum size of sequence. This is used for positional embedding
+            tokenizer (AutoTokenizer): optional tokenizer object (currently unused by the AMPLIFY constructor)
+            pre_process (bool): Include embedding layer (used with pipeline parallelism)
+            post_process (bool): Include an output layer (used with pipeline parallelism)
+            fp16_lm_cross_entropy: Whether to move the cross entropy unreduced loss calculation for lm head to fp16.
+            parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks
+            share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are shared. Defaults to False.
+            position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope'].
+                Defaults to 'rope'.
+            rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
+                Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
+            seq_len_interpolation_factor (Optional[float]): Interpolation factor for sequence length. Defaults to None.
+            add_binary_head (bool): Whether to add a binary head. Defaults to True.
+            return_embeddings (bool): Whether to return embeddings. Defaults to False.
+            include_embeddings (bool): Whether to include embeddings in the output dictionary. Defaults to False.
+            include_input_ids (bool): Whether to include input_ids in the output dictionary. Defaults to False.
+            use_full_attention_mask (bool): Whether to use full attention mask. Defaults to False.
+            include_hiddens (bool): Whether to include hidden states in the output dictionary. Defaults to False.
+            skip_logits (bool): Skip writing the token logits to the output dict. Defaults to False.
+        """
+        super(MegatronBioBertModel, self).__init__(config=config)
+        self.post_process = post_process
+        self.add_binary_head = add_binary_head
+        if return_embeddings:
+            assert self.post_process, "only return embeddings on the last pipeline stage"
+        # `b` = batch, `s` = sequence.
+ # The old flash attention mechanism apparently wants you to use a b x 1 x s x s attention mask while + # the new one wants a b x 1 x 1 x s attention mask. This is a hack to allow us to switch between the two. + self.use_full_attention_mask = use_full_attention_mask + self.config: TransformerConfig = config + self.transformer_layer_spec: spec_utils.ModuleSpec = transformer_layer_spec + self.vocab_size = vocab_size + self.max_sequence_length = max_sequence_length + self.pre_process = pre_process + self.post_process = post_process + self.fp16_lm_cross_entropy = fp16_lm_cross_entropy + self.parallel_output = parallel_output + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.position_embedding_type = position_embedding_type + self.add_binary_head = add_binary_head + self.return_embeddings = return_embeddings + self.include_embeddings = include_embeddings + self.include_hiddens = include_hiddens + self.include_input_ids = include_input_ids + self.skip_logits = skip_logits + + if config.activation_func is silu: + multiple_of = 8 + intermediate_size = int(2 * config.ffn_hidden_size / 3) + config.ffn_hidden_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of) + + # megatron core pipelining currently depends on model type + self.model_type = ModelType.encoder_or_decoder + + # Embeddings. + if self.pre_process: + self.register_buffer( + "bert_position_id_tensor", + torch.arange(max_sequence_length, dtype=torch.long, requires_grad=False).unsqueeze(0), + persistent=False, + ) + # ESM2 Customization: ESM2Embedding instead of LanguageModelEmbedding + # TODO: call super, overwrite the self.embedding, and setup_embeddings_and_output_layer in constructor. + # Note: need to avoid calling setup twice: skip with super (super(skip_setup=True)) + self.embedding = LanguageModelEmbedding( + config=self.config, + vocab_size=self.vocab_size, + max_sequence_length=self.max_sequence_length, + position_embedding_type=position_embedding_type, + num_tokentypes=num_tokentypes, + ) + + if self.position_embedding_type == "rope": + self.rotary_pos_emb = RotaryEmbedding( + kv_channels=self.config.kv_channels, + rotary_percent=rotary_percent, + rotary_interleaved=self.config.rotary_interleaved, + seq_len_interpolation_factor=seq_len_interpolation_factor, + ) + + # Transformer. + self.encoder = TransformerBlock( + config=self.config, + spec=self.transformer_layer_spec, + pre_process=self.pre_process, + post_process=self.post_process, + ) + + # Output + if post_process: + # TODO: Make sure you are passing in the mpu_vocab_size properly + self.lm_head = AMPLIFYLMHead(config) + + self.output_layer = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + self.vocab_size, + config=config, + init_method=config.init_method, + bias=True, + skip_bias_add=False, + gather_output=not self.parallel_output, + skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights, + ) + + self.binary_head = None + if self.add_binary_head: + # TODO: Shoudl switch this to TE ? 
+ self.binary_head = get_linear_layer( + config.hidden_size, 2, config.init_method, config.perform_initialization + ) + + self.pooler = Pooler(config.hidden_size, config.init_method, config, config.sequence_parallel) + if self.pre_process or self.post_process: + self.setup_embeddings_and_output_layer() + + def embedding_forward( + self, + input_ids: Tensor, + position_ids: Tensor, + tokentype_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + ) -> Tensor: + """Produce embeddings.""" + return self.embedding(input_ids=input_ids, position_ids=position_ids, tokentype_ids=tokentype_ids) + + +AMPLIFYModelT = TypeVar("AMPLIFYModelT", bound=AMPLIFYModel) + + +@dataclass +class AMPLIFYConfig(BioBertConfig[AMPLIFYModelT, MegatronLossType], iom.IOMixinWithGettersSetters): + """Configuration class for AMPLIFY model. + + Attributes: + num_layers: Number of layers in the model. + hidden_size: Hidden size of the model. + num_attention_heads: Number of attention heads in the model. + ffn_hidden_size: Hidden size of the feed-forward network. + hidden_dropout: Dropout rate for hidden layers. + attention_dropout: Dropout rate for attention layers. + apply_residual_connection_post_layernorm: Whether to apply residual connection after layer normalization. + layernorm_epsilon: Epsilon value for layer normalization. + layernorm_zero_centered_gamma: Whether to zero-center the gamma parameter in layer normalization. + activation_func: Activation function used in the model. + init_method_std: Standard deviation for weight initialization. + apply_query_key_layer_scaling: Whether to apply scaling to query and key layers. + masked_softmax_fusion: Whether to use a kernel that fuses attention softmax with its mask. + fp16_lm_cross_entropy: Whether to move the cross entropy unreduced loss calculation for lm head to fp16. + share_embeddings_and_output_weights: Whether to share embeddings and output weights. + enable_autocast: Whether to enable autocast for mixed precision. + biobert_spec_option: BiobertSpecOption for the model. + position_embedding_type: Type of position embedding used in the model. + seq_length: Length of the input sequence. + make_vocab_size_divisible_by: Make the vocabulary size divisible by this value. + token_dropout: Whether to apply token dropout. + use_attention_mask: Whether to use attention mask. + use_esm_attention: Whether to use ESM attention. + attention_softmax_in_fp32: Whether to use fp32 for attention softmax. + optimizer_fn: Optional optimizer function for the model. + parallel_output: Whether to use parallel output. + rotary_base: Base value for rotary positional encoding. + rotary_percent: Percentage of rotary positional encoding. + seq_len_interpolation_factor: Interpolation factor for sequence length. + get_attention_mask_from_fusion: Whether to get attention mask from fusion. + nemo1_ckpt_path: Path to NEMO1 checkpoint. + return_only_hidden_states: Whether to return only hidden states. + loss_reduction_class: Loss reduction class for the model. Default to BERTMLMLossWithReduction. + """ + + # When overriding fields in a dataclass _always_ declare types: https://github.com/python/cpython/issues/123269 + model_cls: Type[AMPLIFYModelT] = AMPLIFYModel + seq_length: int = 512 + num_layers: int = 24 # 32 for 350M, 24 for 120M + hidden_size: int = 640 # 960 for 350M, 640 for 120M + num_attention_heads: int = 10 # 15 for 350M, 10 for 120M + ffn_hidden_size: int = 2560 # Transformer FFN hidden size. Usually 4 * hidden_size. 
+ hidden_dropout: float = 0 # AMPLIFY removes dropout from hidden layers and attention + attention_dropout: float = 0.0 # AMPLIFY does not use attention dropout + apply_residual_connection_post_layernorm: bool = False # TODO: farhadr False is new default, True was BERT pub. + layernorm_epsilon: float = 1.0e-5 + init_method_std: float = 0.02 + + # embedding + token_dropout: bool = False + use_attention_mask: bool = True + + # core attention + use_esm_attention: bool = False # Skip ESM2 custom attention for TE acceleration. Still passes golden value test. + attention_softmax_in_fp32: bool = False + normalize_attention_scores: bool = False + + # From megatron.core.models.gpt.bert_model.GPTModel + fp16_lm_cross_entropy: bool = False # Move the cross entropy unreduced loss calculation for lm head to fp16 + parallel_output: bool = True + share_embeddings_and_output_weights: bool = False + make_vocab_size_divisible_by: int = 1 + position_embedding_type: PositionEmbeddingKinds = "rope" + rotary_interleaved: bool = True + rotary_base: int = 10_000 + rotary_percent: float = 1.0 + + # AMPLIFY specific configuration + add_bias_linear: bool = False # AMPLIFY does not use bias in linear layers + bias_swiglu_fusion: bool = False + bias_activation_fusion: bool = False + bias_dropout_fusion: bool = False + apply_rope_fusion: bool = False + gated_linear_unit: bool = True + masked_softmax_fusion: bool = False + activation_func: str = silu + normalization: str = "RMSNorm" # AMPLIFY uses RMSNorm instead of LayerNorm + layernorm_zero_centered_gamma: bool = False # Zero centered gamma not supported for RMSNorm + biobert_spec_option: BiobertSpecOption = BiobertSpecOption.amplify_with_transformer_engine_spec + apply_query_key_layer_scaling = False + + # TODO: Move this to better places? + get_attention_mask_from_fusion: bool = False + + optimizer_fn: Optional[Callable[[MegatronBioBertModel], Optimizer]] = None + # TODO (@skothenhill,@georgea) update to use the nemo2 checkpoint mixins + # support HF (requires weight interleaving on qkv layer) and nemo1 checkpoints ideally. + nemo1_ckpt_path: str | None = None + # The following checkpoint path is for nemo2 checkpoints. Config parameters not present in + # self.override_parent_fields will be loaded from the checkpoint and override those values here. + initial_ckpt_path: str | None = None + # TODO (@jstjohn) come up with a cleaner way in the biobert module to return user requested + # things as part of the workflow for inference and fine-tuning. + return_embeddings: bool = False + include_embeddings: bool = False + skip_logits: bool = False + return_only_hidden_states: bool = False # return logits diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/tokenizer.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/tokenizer.py new file mode 100644 index 0000000000..95cb691c51 --- /dev/null +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/tokenizer.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import transformers
+from nemo.lightning.io import IOMixin
+
+
+class BioNeMoAMPLIFYTokenizer(transformers.PreTrainedTokenizerFast, IOMixin):  # noqa D101
+    def __init__(self):
+        """A wrapper to make AutoTokenizer serializable for the AMPLIFY tokenizer."""
+        other = transformers.AutoTokenizer.from_pretrained("chandar-lab/AMPLIFY_350M", use_fast=True)
+        self.__dict__.update(dict(other.__dict__))
diff --git a/sub-packages/bionemo-amplify/tests/bionemo/amplify/__init__.py b/sub-packages/bionemo-amplify/tests/bionemo/amplify/__init__.py
new file mode 100644
index 0000000000..25e6abfbc5
--- /dev/null
+++ b/sub-packages/bionemo-amplify/tests/bionemo/amplify/__init__.py
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_convert.py b/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_convert.py
new file mode 100644
index 0000000000..2cdb10a654
--- /dev/null
+++ b/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_convert.py
@@ -0,0 +1,270 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
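+
+# These tests check numerical equivalence between the Hugging Face AMPLIFY implementation and the
+# converted NeMo2/Megatron checkpoint: forward hooks capture intermediate activations from both
+# models (embeddings, rotary-embedded queries/keys, values, attention outputs, and per-block
+# outputs), and the captured tensors are compared either exactly (torch.testing.assert_close) or
+# via cosine-similarity checks that ignore padding positions.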
+ + +from pathlib import Path + +import torch +from megatron.core.transformer.module import Float16Module +from nemo.lightning import io +from transformers import AutoModel + +from bionemo.amplify.convert import HFAMPLIFYImporter # noqa: F401 +from bionemo.amplify.hf_rotary import apply_rotary_emb +from bionemo.amplify.model import AMPLIFYConfig +from bionemo.amplify.tokenizer import BioNeMoAMPLIFYTokenizer +from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype +from bionemo.esm2.testing.compare import ForwardHook, assert_cosine_similarity, get_input_tensors +from bionemo.llm.model.biobert.lightning import biobert_lightning_module +from bionemo.testing import megatron_parallel_state_utils + + +def assert_amplify_equivalence( + ckpt_path: str, + model_tag: str, + precision: PrecisionTypes = "fp32", + rtol: float | None = None, + atol: float | None = None, +) -> None: + tokenizer = BioNeMoAMPLIFYTokenizer() + + input_ids, attention_mask = get_input_tensors(tokenizer) + hf_results = load_and_evaluate_hf_amplify(model_tag, precision, input_ids, attention_mask) + # gc.collect() + # torch.cuda.empty_cache() + nemo_results = load_and_evaluate_nemo_amplify( + tokenizer, + ckpt_path, + precision, + input_ids, + attention_mask, + ) + + torch.testing.assert_close(hf_results["embeddings"], nemo_results["embeddings"], rtol=rtol, atol=atol) + torch.testing.assert_close(hf_results["query_post_rot"], nemo_results["query_post_rot"], rtol=rtol, atol=atol) + torch.testing.assert_close(hf_results["key_post_rot"], nemo_results["key_post_rot"], rtol=rtol, atol=atol) + torch.testing.assert_close(hf_results["value"], nemo_results["value"], rtol=rtol, atol=atol) + + # torch.testing.assert_close(hf_results["attn_output"], nemo_results["attn_output"], rtol=rtol, atol=atol) + assert_cosine_similarity( + hf_results["attn_output"], + nemo_results["attn_output"], + attention_mask.cpu(), + rtol=1e-4, + atol=1e-4, + ) + + assert_cosine_similarity( + hf_results["attn_linear_output"], + nemo_results["attn_linear_output"], + attention_mask.cpu(), + rtol=1e-4, + atol=1e-4, + ) + + for i, (hf_block_output, nemo_block_output) in enumerate( + zip(hf_results["encoder_block_outputs"], nemo_results["encoder_block_outputs"], strict=True) + ): + assert_cosine_similarity( + hf_block_output, + nemo_block_output, + attention_mask.cpu(), + rtol=1e-4, + atol=1e-4, + msg=f"Encoder block output {i}", + ) + + # assert_cosine_similarity(nemo_attn_inputs[0].transpose(0, 1), hf_attn_inputs[0], attention_mask, msg="Attn inputs") + # assert_cosine_similarity( + # nemo_attn_outputs[0].transpose(0, 1), hf_attn_outputs[0], attention_mask, msg="Attn outputs" + # ) + + # assert_cosine_similarity(nemo_hidden_state, hf_hidden_state, attention_mask, rtol, atol) + # assert_cosine_similarity(nemo_logits, hf_logits, attention_mask, rtol, atol) + + +def load_and_evaluate_hf_amplify( + model_tag: str, precision: PrecisionTypes, input_ids: torch.Tensor, attention_mask: torch.Tensor +) -> dict[str, torch.Tensor]: + """Load a HuggingFace model and evaluate it on the given inputs. + + Args: + model_tag: The HuggingFace model tag for the model to compare against. + precision: The precision type to use for the comparison. + input_ids: The input IDs tensor to evaluate. + attention_mask: The attention mask tensor to evaluate. 
+ """ + hf_model = AutoModel.from_pretrained( + model_tag, + torch_dtype=get_autocast_dtype(precision), + trust_remote_code=True, + ) + + embedding_hook = ForwardHook(lambda inputs, outputs: outputs[0]) + hf_model.encoder.register_forward_hook(embedding_hook) + + query_pre_rot_hook = ForwardHook(lambda inputs, outputs: outputs[0]) + hf_model.transformer_encoder[0].q.register_forward_hook(query_pre_rot_hook) + + key_pre_rot_hook = ForwardHook(lambda inputs, outputs: outputs[0]) + hf_model.transformer_encoder[0].k.register_forward_hook(key_pre_rot_hook) + + value_hook = ForwardHook(lambda inputs, outputs: outputs[0]) + hf_model.transformer_encoder[0].v.register_forward_hook(value_hook) + + # The output of the attention layer is the same as the output of the linear layer, but the actual attention function + # isn't wrapped in a nn.Module. + attn_output_hook = ForwardHook(lambda inputs, outputs: inputs[0]) + hf_model.transformer_encoder[0].wo.register_forward_hook(attn_output_hook) + + attn_linear_output_hook = ForwardHook(lambda inputs, outputs: outputs[0]) + hf_model.transformer_encoder[0].wo.register_forward_hook(attn_linear_output_hook) + + encoder_block_hooks = [ + ForwardHook(lambda inputs, outputs: outputs[0]) for _ in range(len(hf_model.transformer_encoder)) + ] + for i, hook in enumerate(encoder_block_hooks): + hf_model.transformer_encoder[i].register_forward_hook(hook) + + hf_model = hf_model.to("cuda").eval() + _ = hf_model(input_ids, attention_mask.float(), output_hidden_states=True) + + # These post-rotary embeddings are applied in the forward pass of the model, so we need to apply them here. + xq = query_pre_rot_hook.data.view( + input_ids.shape[0], + input_ids.shape[1], + hf_model.config.num_attention_heads, + hf_model.transformer_encoder[0].d_head, + ) + xk = key_pre_rot_hook.data.view( + input_ids.shape[0], + input_ids.shape[1], + hf_model.config.num_attention_heads, + hf_model.transformer_encoder[0].d_head, + ) + xq, xk = apply_rotary_emb(xq, xk, hf_model.freqs_cis[: input_ids.shape[1]].cpu()) + + # hf_hidden_state = hf_output_all.hidden_states[-1] + + return { + "embeddings": embedding_hook.data, + "query_post_rot": xq.flatten(-2, -1), + "key_post_rot": xk.flatten(-2, -1), + "value": value_hook.data, + "attn_output": attn_output_hook.data, + "attn_linear_output": attn_linear_output_hook.data, + "encoder_block_outputs": [hook.data for hook in encoder_block_hooks], + } + + +def load_and_evaluate_nemo_amplify( + tokenizer: BioNeMoAMPLIFYTokenizer, + ckpt_path: Path | str, + precision: PrecisionTypes, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, +) -> dict[str, torch.Tensor]: + """Load a AMPLIFY NeMo2 model checkpoint and evaluate it on the input tensors. + + It would be great to make this more ergonomic, i.e., how to create a model from a checkpoint and evaluate it. + + Args: + tokenizer: Not sure why we need to pass a tokenizer to `configure_model`. + ckpt_path: Path to the newly created NeMo2 converted checkpoint. + precision: Precision type to use for the model. 
+ input_ids: Input tokens + attention_mask: Input attention mask + """ + + dtype = get_autocast_dtype(precision) + nemo_config = AMPLIFYConfig( + initial_ckpt_path=str(ckpt_path), + include_embeddings=True, + include_hiddens=True, + params_dtype=dtype, + pipeline_dtype=dtype, + autocast_dtype=dtype, + bf16=dtype is torch.bfloat16, + fp16=dtype is torch.float16, + ) + + nemo_model = nemo_config.configure_model(tokenizer).to("cuda").eval() + + if dtype is torch.float16 or dtype is torch.bfloat16: + nemo_model = Float16Module(nemo_config, nemo_model) + + embedding_hook = ForwardHook(lambda inputs, outputs: outputs[0].transpose(0, 1)) + nemo_model.embedding.register_forward_hook(embedding_hook) + + query_post_rot_hook = ForwardHook(lambda inputs, outputs: inputs[0].transpose(0, 1).flatten(-2, -1)) + nemo_model.encoder.layers[0].self_attention.core_attention.register_forward_hook(query_post_rot_hook) + + key_post_rot_hook = ForwardHook(lambda inputs, outputs: inputs[1].transpose(0, 1).flatten(-2, -1)) + nemo_model.encoder.layers[0].self_attention.core_attention.register_forward_hook(key_post_rot_hook) + + value_post_rot_hook = ForwardHook(lambda inputs, outputs: inputs[2].transpose(0, 1).flatten(-2, -1)) + nemo_model.encoder.layers[0].self_attention.core_attention.register_forward_hook(value_post_rot_hook) + + attn_output_hook = ForwardHook(lambda inputs, outputs: outputs[0].transpose(0, 1)) + nemo_model.encoder.layers[0].self_attention.core_attention.register_forward_hook(attn_output_hook) + + attn_linear_output_hook = ForwardHook(lambda inputs, outputs: outputs[0].transpose(0, 1)) + nemo_model.encoder.layers[0].self_attention.linear_proj.register_forward_hook(attn_linear_output_hook) + + encoder_block_hooks = [ + ForwardHook(lambda inputs, outputs: outputs[0].transpose(0, 1)) for _ in range(len(nemo_model.encoder.layers)) + ] + for i, hook in enumerate(encoder_block_hooks): + nemo_model.encoder.layers[i].register_forward_hook(hook) + + # attn_hook = TestHook() + # nemo_model.encoder.layers[0].register_forward_hook(attn_hook) + + nemo_output = nemo_model(input_ids, attention_mask) + + _ = nemo_output["token_logits"].transpose(0, 1).contiguous()[..., : tokenizer.vocab_size] + # nemo_hidden_state = nemo_output["hidden_states"] + + return { + "embeddings": embedding_hook.data, + "query_post_rot": query_post_rot_hook.data, + "key_post_rot": key_post_rot_hook.data, + "value": value_post_rot_hook.data, + "attn_output": attn_output_hook.data, + "attn_linear_output": attn_linear_output_hook.data, + "encoder_block_outputs": [hook.data for hook in encoder_block_hooks], + } + + +def test_convert_amplify_120M_smoke(tmp_path): + model_tag = "chandar-lab/AMPLIFY_120M" + module = biobert_lightning_module(config=AMPLIFYConfig()) + io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint") + + +def test_convert_amplify_120M(tmp_path): + model_tag = "chandar-lab/AMPLIFY_120M" + module = biobert_lightning_module(config=AMPLIFYConfig()) + io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint") + with megatron_parallel_state_utils.distributed_model_parallel_state(): + assert_amplify_equivalence(tmp_path / "nemo_checkpoint", model_tag) + + +def test_convert_amplify_350M(tmp_path): + model_tag = "chandar-lab/AMPLIFY_350M" + module = biobert_lightning_module(config=AMPLIFYConfig()) + io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint") + with megatron_parallel_state_utils.distributed_model_parallel_state(): + assert_amplify_equivalence(tmp_path / 
"nemo_checkpoint", model_tag) diff --git a/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_hf_rotary.py b/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_hf_rotary.py new file mode 100644 index 0000000000..174b36d610 --- /dev/null +++ b/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_hf_rotary.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from transformers import AutoConfig + +from bionemo.amplify.hf_rotary import apply_rotary_emb, precompute_freqs_cis +from bionemo.amplify.model import AMPLIFYConfig + + +def test_rope_embeddings(): + rng = torch.Generator().manual_seed(42) + q = torch.randn([2, 72, 10, 64], dtype=torch.float32, generator=rng) + k = torch.randn([2, 72, 10, 64], dtype=torch.float32, generator=rng) + + # AMPLIFY HF Rope + hf_config = AutoConfig.from_pretrained("chandar-lab/AMPLIFY_120M", trust_remote_code=True) + freqs_cis = precompute_freqs_cis(hf_config.hidden_size // hf_config.num_attention_heads, hf_config.max_length) + freqs_cis = freqs_cis[: q.shape[1]] + q_post, k_post = apply_rotary_emb(q, k, freqs_cis) + + # NeMo Rope + nemo_config = AMPLIFYConfig(apply_rope_fusion=False, rotary_interleaved=True) + rotary_pos_layer = RotaryEmbedding( + kv_channels=nemo_config.kv_channels, + rotary_percent=nemo_config.rotary_percent, + rotary_interleaved=nemo_config.rotary_interleaved, + seq_len_interpolation_factor=nemo_config.seq_len_interpolation_factor, + ) + rotary_pos_emb = rotary_pos_layer(q.shape[1]) + q_post_nemo = apply_rotary_pos_emb(q.transpose(0, 1).cuda(), rotary_pos_emb.cuda(), config=nemo_config).cpu() + k_post_nemo = apply_rotary_pos_emb(k.transpose(0, 1).cuda(), rotary_pos_emb.cuda(), config=nemo_config).cpu() + + torch.testing.assert_close(q_post, q_post_nemo.transpose(0, 1)) + torch.testing.assert_close(k_post, k_post_nemo.transpose(0, 1)) + + +# TODO: extend this test to try the DotProductAttention and TEDotProductAttention layers and compare how close the +# outputs are; that seems to be where the outputs between the HF and NeMo implementations are diverging. 
+ +# def test_multi_head_attention(): +# rng = torch.Generator().manual_seed(42) +# q = torch.randn([2, 72, 10, 64], dtype=torch.float32, generator=rng) +# k = torch.randn([2, 72, 10, 64], dtype=torch.float32, generator=rng) +# v = torch.randn([2, 72, 10, 64], dtype=torch.float32, generator=rng) + +# attention_mask = torch.ones([2, 72], dtype=torch.float32).bool() +# attention_mask[0, -7:] = False +# attention_mask[1, -5:] = False + +# q_new = torch.randn([2, 72, 10, 64], dtype=torch.float32, generator=rng) +# k_new = torch.randn([2, 72, 10, 64], dtype=torch.float32, generator=rng) +# v_new = torch.randn([2, 72, 10, 64], dtype=torch.float32, generator=rng) + +# q_new[attention_mask] = q[attention_mask] +# k_new[attention_mask] = k[attention_mask] +# v_new[attention_mask] = v[attention_mask] + +# attention_mask_rep = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, 10, attention_mask.size(-1), 1) + +# attn_output = torch.nn.functional.scaled_dot_product_attention( +# query=q.transpose(1, 2), +# key=k.transpose(1, 2), +# value=v.transpose(1, 2), +# attn_mask=attention_mask_rep, +# dropout_p=0, +# ).transpose(1, 2) + +# attn_output_new = torch.nn.functional.scaled_dot_product_attention( +# query=q_new.transpose(1, 2), +# key=k_new.transpose(1, 2), +# value=v_new.transpose(1, 2), +# attn_mask=attention_mask_rep, +# dropout_p=0, +# ).transpose(1, 2) + +# torch.testing.assert_close(attn_output[attention_mask], attn_output_new[attention_mask]) diff --git a/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_model.py b/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_model.py new file mode 100644 index 0000000000..18ee0a798f --- /dev/null +++ b/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_model.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import torch + +from bionemo.amplify.model import AMPLIFYConfig, AMPLIFYModel +from bionemo.amplify.tokenizer import BioNeMoAMPLIFYTokenizer +from bionemo.esm2.model.embedding import ESM2Embedding +from bionemo.llm.model.biobert.model import MegatronBioBertModel +from bionemo.testing import megatron_parallel_state_utils + + +def test_amplify_model_initialized(): + with megatron_parallel_state_utils.distributed_model_parallel_state(): + tokenizer = BioNeMoAMPLIFYTokenizer() + config = AMPLIFYConfig() + model = config.configure_model(tokenizer) + + assert isinstance(model, MegatronBioBertModel) + assert isinstance(model, AMPLIFYModel) + assert isinstance(model.embedding, ESM2Embedding) + + +def test_amplify_model_forward_pass(): + tokenizer = BioNeMoAMPLIFYTokenizer() + + test_proteins = [ + "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA", + "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLAGG", + ] + tokens = tokenizer(test_proteins, return_tensors="pt", padding=True, truncation=True).to("cuda") + input_ids = tokens["input_ids"] + attention_mask = tokens["attention_mask"] + + nemo_config = AMPLIFYConfig( + num_layers=2, + num_attention_heads=2, + hidden_size=4, + ffn_hidden_size=4 * 4, + ) + + with megatron_parallel_state_utils.distributed_model_parallel_state(): + nemo_model = nemo_config.configure_model(tokenizer).to("cuda").eval() + nemo_output = nemo_model(input_ids, attention_mask) + assert isinstance(nemo_output["token_logits"], torch.Tensor) diff --git a/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_tokenizer.py b/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_tokenizer.py new file mode 100644 index 0000000000..7722d3c988 --- /dev/null +++ b/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_tokenizer.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pytest +import torch +from nemo.lightning import io + +from bionemo.amplify.tokenizer import BioNeMoAMPLIFYTokenizer + + +@pytest.fixture +def tokenizer(): + return BioNeMoAMPLIFYTokenizer() + + +def test_tokenizer_serialization(tokenizer, tmp_path): + tokenizer.io_dump(tmp_path / "tokenizer", yaml_attrs=[]) # BioNeMoESMTokenizer takes no __init__ arguments + deserialized_tokenizer = io.load(tmp_path / "tokenizer", tokenizer.__class__) + + our_tokens = deserialized_tokenizer.encode("KAISQ", add_special_tokens=False) + amplify_tokens = torch.tensor([17, 7, 2, 14, 10, 18]) + torch.testing.assert_close(torch.tensor(our_tokens), amplify_tokens) diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/embedding.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/embedding.py index 40783c7b8d..1c8323efa8 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/embedding.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/embedding.py @@ -44,7 +44,7 @@ def __init__( # ESM2 NEW ARGS token_dropout: bool = True, use_attention_mask: bool = True, - mask_token_id: Optional[int] = torch.nan, + mask_token_id: Optional[int] = None, ) -> None: """Initialize the ESM2 Embedding module.""" super().__init__( @@ -65,7 +65,7 @@ def dtype(self) -> torch.dtype: def _apply_esm2_customization( self, word_embeddings: Tensor, input_ids: Tensor, attention_mask: Tensor - ) -> Tuple[Tensor, Tensor]: + ) -> Tuple[Tensor, Tensor | None]: """ESM2 customization for attention masking and token dropout. Args: @@ -95,7 +95,7 @@ def forward( self, input_ids: Tensor, position_ids: Tensor, - tokentype_ids: Optional[int] = None, + tokentype_ids: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None, ) -> Tensor: """Forward pass of the embedding module. @@ -103,7 +103,7 @@ def forward( Args: input_ids (Tensor): The input tokens. Shape: [b, s] position_ids (Tensor): The position id's used to calculate position embeddings. Shape: [b, s] - tokentype_ids (int, optional): The token type ids. Used when args.bert_binary_head is set to True. Defaults to None + tokentype_ids (Tensor, optional): The token type ids. Used when args.bert_binary_head is set to True. Defaults to None attention_mask (Tensor): attention mask. Shape: [b, s] Returns: diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py index e8690c1d04..6db46a5732 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py @@ -15,7 +15,9 @@ import gc +import math from pathlib import Path +from typing import Callable import torch from megatron.core.transformer.module import Float16Module @@ -26,7 +28,7 @@ from bionemo.esm2.model.model import ESM2Config -def assert_model_equivalence( +def assert_esm2_equivalence( ckpt_path: Path | str, model_tag: str, precision: PrecisionTypes = "fp32", @@ -49,13 +51,66 @@ def assert_model_equivalence( """ tokenizer = get_tokenizer() + input_ids, attention_mask = get_input_tensors(tokenizer) + + nemo_logits, nemo_hidden_state = load_and_evaluate_nemo_esm2(ckpt_path, precision, input_ids, attention_mask) + gc.collect() + torch.cuda.empty_cache() + hf_logits, hf_hidden_state = load_and_evaluate_hf_model(model_tag, precision, input_ids, attention_mask) + + # Rather than directly comparing the logit or hidden state tensors, we compare their cosine similarity. 
These + # should be essentially 1 if the outputs are equivalent, but is less sensitive to small numerical differences. + # We don't care about the padding tokens, so we only compare the non-padding tokens. + assert_cosine_similarity(nemo_logits, hf_logits, attention_mask, rtol, atol) + assert_cosine_similarity(nemo_hidden_state, hf_hidden_state, attention_mask, rtol, atol) + + +def get_input_tensors(tokenizer) -> tuple[torch.Tensor, torch.Tensor]: + """Get input tensors for testing. + + Args: + tokenizer: A huggingface-like tokenizer object. + + Returns: + A tuple of the input IDs and attention mask tensors. + """ test_proteins = [ "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA", "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLAGG", ] - tokens = tokenizer(test_proteins, return_tensors="pt", padding=True, truncation=True).to("cuda") - input_ids = tokens["input_ids"] - attention_mask = tokens["attention_mask"] + tokens = tokenizer(test_proteins, return_tensors="pt", padding=True, truncation=True) + input_ids: torch.Tensor = tokens["input_ids"] # type: ignore + attention_mask: torch.Tensor = tokens["attention_mask"] # type: ignore + + # Pad the input IDs and attention mask to be divisible by 8 so xformers doesn't fail. + padded_shape = math.ceil(attention_mask.size(1) / 8) + padded_input_ids = torch.full((input_ids.size(0), padded_shape * 8), tokenizer.pad_token_id, dtype=torch.long) + padded_input_ids[: input_ids.size(0), : input_ids.size(1)] = input_ids + + padded_attention_mask = torch.zeros((attention_mask.size(0), padded_shape * 8), dtype=torch.bool) + padded_attention_mask[: attention_mask.size(0), : attention_mask.size(1)] = attention_mask + + return padded_input_ids.to("cuda"), padded_attention_mask.to("cuda") + + +def load_and_evaluate_nemo_esm2( + ckpt_path: Path | str, + precision: PrecisionTypes, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Load a NeMo2 ESM-2 model and evaluate it on the given inputs. + + Args: + ckpt_path: A path to a NeMo2 checkpoint for an ESM-2 model. + precision: The precision type to use for the comparison. + input_ids: The input IDs tensor to evaluate. + attention_mask: The attention mask tensor to evaluate. + + Returns: + A tuple of the logits and hidden states tensors calculated by the NeMo2 model, respectively. + """ + tokenizer = get_tokenizer() dtype = get_autocast_dtype(precision) nemo_config = ESM2Config( @@ -77,23 +132,122 @@ def assert_model_equivalence( nemo_output = nemo_model(input_ids, attention_mask) nemo_logits = nemo_output["token_logits"].transpose(0, 1).contiguous()[..., : tokenizer.vocab_size] nemo_hidden_state = nemo_output["hidden_states"] + return nemo_logits, nemo_hidden_state - del nemo_model - gc.collect() - torch.cuda.empty_cache() - hf_model = AutoModelForMaskedLM.from_pretrained(model_tag, torch_dtype=get_autocast_dtype(precision)).cuda().eval() +def load_and_evaluate_hf_model( + model_tag: str, precision: PrecisionTypes, input_ids: torch.Tensor, attention_mask: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """Load a HuggingFace model and evaluate it on the given inputs. + + Args: + model_tag: The HuggingFace model tag for the model to compare against. + precision: The precision type to use for the comparison. + input_ids: The input IDs tensor to evaluate. + attention_mask: The attention mask tensor to evaluate. + + Returns: + A tuple of the logits and hidden states tensors calculated by the HuggingFace model, respectively. 
+ """ + hf_model = AutoModelForMaskedLM.from_pretrained( + model_tag, + torch_dtype=get_autocast_dtype(precision), + trust_remote_code=True, + ) + hf_model = hf_model.to("cuda").eval() hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True) hf_hidden_state = hf_output_all.hidden_states[-1] + return hf_output_all.logits, hf_hidden_state - # Rather than directly comparing the logit or hidden state tensors, we compare their cosine similarity. These - # should be essentially 1 if the outputs are equivalent, but is less sensitive to small numerical differences. - # We don't care about the padding tokens, so we only compare the non-padding tokens. - logit_similarity = torch.nn.functional.cosine_similarity(nemo_logits, hf_output_all.logits, dim=2) - logit_similarity = logit_similarity[attention_mask == 1] - hidden_state_similarity = torch.nn.functional.cosine_similarity(nemo_hidden_state, hf_hidden_state, dim=2) - hidden_state_similarity = hidden_state_similarity[attention_mask == 1] +def assert_cosine_similarity( + tensor1: torch.Tensor, + tensor2: torch.Tensor, + mask: torch.Tensor, + rtol: float | None = None, + atol: float | None = None, + msg: str | None = None, +) -> None: + """Assert that both the cosine similarity between two tensors is close to 1, and the ratio of their magnitudes is 1. + + Args: + tensor1: The first tensor to compare. + tensor2: The second tensor to compare. + mask: A mask tensor to apply to the comparison. + rtol: The relative tolerance to use for the comparison. Defaults to 1e-4. + atol: The absolute tolerance to use for the comparison. Defaults to 1e-4. + msg: An optional message to include in the assertion error. + """ + assert tensor1.size() == tensor2.size() + + similarity = torch.nn.functional.cosine_similarity(tensor1, tensor2, dim=2) + similarity = similarity[mask == 1] + + torch.testing.assert_close( + similarity, + torch.ones_like(similarity), + rtol=rtol, + atol=atol, + msg=lambda x: f"{msg} (angle): {x}", + ) + + magnitude_similarity = torch.norm(tensor1, dim=2) / torch.norm(tensor2, dim=2) + magnitude_similarity = magnitude_similarity[mask == 1] + torch.testing.assert_close( + magnitude_similarity, + torch.ones_like(magnitude_similarity), + rtol=1e-2, + atol=1e-2, + msg=lambda x: f"{msg} (magnitude): {x}", + ) + + +TransformFn = Callable[ + [tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]], + torch.Tensor, +] + + +class ForwardHook: + """A forward hook to extract a desired intermediate tensor for later comparison.""" + + def __init__(self, transform_fn: TransformFn) -> None: + """A forward hook to extract a desired intermediate tensor for later comparison. + + The resulting tensor is saved in the `data` attribute of the hook. + + Args: + transform_fn: A function that maps the input and output tensors of the module to the desired tensor. 
+ """ + self._transform_fn = transform_fn + self._data: torch.Tensor | None = None + + def __call__(self, module, module_in, module_out): + """The forward hook function.""" + if not isinstance(module_out, tuple): + module_out = (module_out,) + if not isinstance(module_in, tuple): + module_in = (module_in,) + + self._data = self._transform_fn(module_in, module_out).detach().cpu() + + @property + def data(self) -> torch.Tensor: + """The extracted tensor from the forward hook.""" + if self._data is None: + raise ValueError("No data has been saved in this hook.") + return self._data + + +class TestHook: + """A test hook that just captures the raw inputs and outputs.""" + + def __init__(self) -> None: + """A test hook that just captures the raw inputs and outputs.""" + self.inputs: tuple[torch.Tensor, ...] | None = None + self.outputs: tuple[torch.Tensor, ...] | None = None - torch.testing.assert_close(logit_similarity, torch.ones_like(logit_similarity), rtol=rtol, atol=atol) - torch.testing.assert_close(hidden_state_similarity, torch.ones_like(hidden_state_similarity), rtol=rtol, atol=atol) + def __call__(self, module, inputs, outputs): + """The forward hook function.""" + self.inputs = inputs + self.outputs = outputs diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/data/__init__.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/data/__init__.py new file mode 100644 index 0000000000..25e6abfbc5 --- /dev/null +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/data/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py index de8a23a107..51f6fb7124 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py @@ -19,7 +19,7 @@ from bionemo.esm2.model.convert import HFESM2Importer # noqa: F401 from bionemo.esm2.model.model import ESM2Config -from bionemo.esm2.testing.compare import assert_model_equivalence +from bionemo.esm2.testing.compare import assert_esm2_equivalence from bionemo.llm.model.biobert.lightning import biobert_lightning_module from bionemo.testing import megatron_parallel_state_utils @@ -35,7 +35,7 @@ def test_nemo2_conversion_equivalent_8m(tmp_path): module = biobert_lightning_module(config=ESM2Config()) io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint") with megatron_parallel_state_utils.distributed_model_parallel_state(): - assert_model_equivalence(tmp_path / "nemo_checkpoint", model_tag) + assert_esm2_equivalence(tmp_path / "nemo_checkpoint", model_tag) def test_nemo2_conversion_equivalent_8m_bf16(tmp_path): @@ -43,7 +43,7 @@ def test_nemo2_conversion_equivalent_8m_bf16(tmp_path): module = biobert_lightning_module(config=ESM2Config()) io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint") with megatron_parallel_state_utils.distributed_model_parallel_state(precision="bf16"): - assert_model_equivalence(tmp_path / "nemo_checkpoint", model_tag, precision="bf16") + assert_esm2_equivalence(tmp_path / "nemo_checkpoint", model_tag, precision="bf16") @pytest.mark.slow @@ -52,4 +52,4 @@ def test_nemo2_conversion_equivalent_650m(tmp_path): module = biobert_lightning_module(config=ESM2Config()) io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint") with megatron_parallel_state_utils.distributed_model_parallel_state(): - assert_model_equivalence(tmp_path / "nemo_checkpoint", model_tag, atol=1e-4, rtol=1e-4) + assert_esm2_equivalence(tmp_path / "nemo_checkpoint", model_tag, atol=1e-4, rtol=1e-4) diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py index 8895b3719a..12cf4800ea 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py @@ -29,7 +29,7 @@ from bionemo.esm2.data.datamodule import ESMDataModule from bionemo.esm2.data.tokenizer import get_tokenizer from bionemo.esm2.model.embedding import ESM2Embedding -from bionemo.esm2.testing.compare import assert_model_equivalence +from bionemo.esm2.testing.compare import assert_esm2_equivalence from bionemo.llm.model.biobert.model import MegatronBioBertModel from bionemo.llm.utils.weight_utils import nemo1_to_nemo2_biobert_key_mapping from bionemo.testing import megatron_parallel_state_utils @@ -180,7 +180,7 @@ def test_model_equivalence_with_huggingface_8m(precision): model_tag = "facebook/esm2_t6_8M_UR50D" ckpt_path = load("esm2/8m:2.0") with megatron_parallel_state_utils.distributed_model_parallel_state(precision=precision): - assert_model_equivalence(ckpt_path, model_tag, precision=precision) + assert_esm2_equivalence(ckpt_path, model_tag, precision=precision) @pytest.mark.slow @@ -188,7 +188,7 @@ def test_model_equivalence_with_huggingface_650m(): model_tag = "facebook/esm2_t33_650M_UR50D" ckpt_path = load("esm2/650m:2.0") with 
megatron_parallel_state_utils.distributed_model_parallel_state(): - assert_model_equivalence(ckpt_path, model_tag, atol=1e-4, rtol=1e-4) + assert_esm2_equivalence(ckpt_path, model_tag, atol=1e-4, rtol=1e-4) @pytest.mark.slow @@ -196,7 +196,7 @@ def test_model_equivalence_with_huggingface_650m_bf16(): model_tag = "facebook/esm2_t33_650M_UR50D" ckpt_path = load("esm2/650m:2.0") with megatron_parallel_state_utils.distributed_model_parallel_state(precision="bf16"): - assert_model_equivalence(ckpt_path, model_tag, precision="bf16") + assert_esm2_equivalence(ckpt_path, model_tag, precision="bf16") @pytest.mark.slow @@ -205,4 +205,4 @@ def test_model_equivalence_with_huggingface_3b(): model_tag = "facebook/esm2_t36_3B_UR50D" ckpt_path = load("esm2/3b:2.0") with megatron_parallel_state_utils.distributed_model_parallel_state(): - assert_model_equivalence(ckpt_path, model_tag, atol=1e-4, rtol=1e-4) + assert_esm2_equivalence(ckpt_path, model_tag, atol=1e-4, rtol=1e-4) diff --git a/sub-packages/bionemo-fw/pyproject.toml b/sub-packages/bionemo-fw/pyproject.toml index e19d4c2d8b..bc0622716f 100644 --- a/sub-packages/bionemo-fw/pyproject.toml +++ b/sub-packages/bionemo-fw/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ 'bionemo-scdl', 'bionemo-size-aware-batching', 'bionemo-webdatamodule', + 'bionemo-amplify', # # NOTE: DO **NOT** INCLUDE: # bionemo-testing (test-time only dependency) diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/transformer_specs.py b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/transformer_specs.py index d6baae9044..6de8d98207 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/transformer_specs.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/transformer_specs.py @@ -17,17 +17,17 @@ from enum import Enum from typing import Optional, Sequence, Type +from megatron.core.extensions.transformer_engine import ( + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TERowParallelLinear, +) from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add from megatron.core.fusions.fused_layer_norm import FusedLayerNorm from megatron.core.models.bert import bert_layer_specs from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear from megatron.core.transformer import spec_utils from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules -from megatron.core.transformer.custom_layers.transformer_engine import ( - TEDotProductAttention, - TELayerNormColumnParallelLinear, - TERowParallelLinear, -) from megatron.core.transformer.dot_product_attention import DotProductAttention from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.identity_op import IdentityOp @@ -56,6 +56,7 @@ class BiobertSpecOption(str, Enum): # ESM2 spec esm2_bert_layer_local_spec = "esm2_bert_layer_local_spec" esm2_bert_layer_with_transformer_engine_spec = "esm2_bert_layer_with_transformer_engine_spec" + amplify_with_transformer_engine_spec = "amplify_with_transformer_engine_spec" def get_biobert_spec( # noqa: D417 @@ -221,5 +222,36 @@ def get_biobert_spec( # noqa: D417 ) return esm2_bert_layer_local_spec + case BiobertSpecOption.amplify_with_transformer_engine_spec: + if core_attention is None: + core_attention = DotProductAttention + + esm2_bert_layer_local_spec = spec_utils.ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=spec_utils.ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": 
AttnMaskType.padding}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=core_attention, + linear_proj=TERowParallelLinear, + q_layernorm=IdentityOp, + k_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + mlp=spec_utils.ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ) + return esm2_bert_layer_local_spec + case _: raise NotImplementedError(f"Spec option {biobert_spec_option} not implemented")
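As a usage note for the new BiobertSpecOption.amplify_with_transformer_engine_spec branch above: the returned ModuleSpec wraps a TransformerLayer whose self-attention uses padding-style masking, Transformer Engine layernorm-fused linears (TELayerNormColumnParallelLinear for QKV and FC1, TERowParallelLinear for the projections), identity q/k layernorms, and DotProductAttention as the core attention unless a different class is supplied. A minimal sketch of requesting the spec follows; it assumes the remaining get_biobert_spec arguments keep their defaults, and the assertion is purely illustrative.

from megatron.core.transformer.dot_product_attention import DotProductAttention

from bionemo.llm.model.biobert.transformer_specs import BiobertSpecOption, get_biobert_spec

# Resolve the new AMPLIFY option into a layer spec; no model is built here.
spec = get_biobert_spec(BiobertSpecOption.amplify_with_transformer_engine_spec)

# With no core_attention override, the branch above falls back to DotProductAttention.
assert spec.submodules.self_attention.submodules.core_attention is DotProductAttention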