From 1940ee05ed5681202b0ccad1203ea4a8c51db75a Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Fri, 31 Jan 2025 07:29:46 -0800 Subject: [PATCH] Add model and tokenizer smoke tests Signed-off-by: Peter St. John --- .../src/bionemo/amplify/model.py | 354 ++++++++++++++++++ .../src/bionemo/amplify/tokenizer.py | 25 ++ .../tests/bionemo/amplify/test_model.py | 58 +++ .../tests/bionemo/amplify/test_tokenizer.py | 35 ++ .../src/bionemo/esm2/model/embedding.py | 8 +- 5 files changed, 476 insertions(+), 4 deletions(-) create mode 100644 sub-packages/bionemo-amplify/src/bionemo/amplify/model.py create mode 100644 sub-packages/bionemo-amplify/src/bionemo/amplify/tokenizer.py create mode 100644 sub-packages/bionemo-amplify/tests/bionemo/amplify/test_model.py create mode 100644 sub-packages/bionemo-amplify/tests/bionemo/amplify/test_tokenizer.py diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/model.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/model.py new file mode 100644 index 0000000000..37d0bcb50f --- /dev/null +++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/model.py @@ -0,0 +1,354 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Any, Callable, Literal, Optional, Sequence, Type, TypeVar + +import torch +from megatron.core import tensor_parallel +from megatron.core.models.bert.pooler import Pooler +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from megatron.core.transformer import spec_utils +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.utils import get_linear_layer +from torch import Tensor +from torch.nn.functional import silu +from torch.optim import Optimizer + +# from bionemo.amplify.data.tokenizer import BioNeMoAMPLIFYTokenizer +from bionemo.esm2.model.embedding import ESM2Embedding +from bionemo.llm.api import MegatronLossType +from bionemo.llm.model.biobert.model import BioBertConfig, MegatronBioBertModel, PositionEmbeddingKinds +from bionemo.llm.model.biobert.transformer_specs import BiobertSpecOption +from bionemo.llm.utils import iomixin_utils as iom + + +__all__: Sequence[str] = ( + "AMPLIFYConfig", + "AMPLIFYModel", +) + + +class AMPLIFYLMHead(MegatronModule): + """LM head for AMPLIFY. 
+
+    Args:
+        config (TransformerConfig): The TransformerConfig object used to construct
+            this Megatron module.
+    """
+
+    def __init__(self, config: TransformerConfig):
+        super().__init__(config=config)
+        self.head = IdentityOp()
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        return self.head(hidden_states)
+
+
+class AMPLIFYModel(MegatronBioBertModel):
+    """AMPLIFY protein language model."""
+
+    def __init__(
+        self,
+        config: TransformerConfig,
+        num_tokentypes: int,
+        transformer_layer_spec: spec_utils.ModuleSpec,
+        vocab_size: int,
+        max_sequence_length: int,
+        tokenizer: Optional[Any] = None,
+        pre_process: bool = True,
+        post_process: bool = True,
+        fp16_lm_cross_entropy: bool = False,
+        parallel_output: bool = True,
+        share_embeddings_and_output_weights: bool = False,
+        position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute",
+        rotary_percent: float = 1.0,
+        seq_len_interpolation_factor: Optional[float] = None,
+        add_binary_head: bool = True,
+        return_embeddings: bool = False,
+        include_embeddings: bool = False,
+        include_input_ids: bool = False,
+        use_full_attention_mask: bool = False,
+        include_hiddens: bool = False,
+        skip_logits: bool = False,
+    ) -> None:
+        """Initialize the AMPLIFY model.
+
+        Args:
+            config (TransformerConfig): transformer config
+            num_tokentypes (int): Set to 2 when args.bert_binary_head is True, and 0 otherwise.
+            transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
+            vocab_size (int): vocabulary size
+            max_sequence_length (int): maximum size of sequence. This is used for positional embedding
+            tokenizer (AutoTokenizer): optional tokenizer object (currently only used in the constructor to look up the mask token id for the embedding layer)
+            pre_process (bool): Include embedding layer (used with pipeline parallelism)
+            post_process (bool): Include an output layer (used with pipeline parallelism)
+            fp16_lm_cross_entropy: Whether to move the cross entropy unreduced loss calculation for lm head to fp16.
+            parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks
+            share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are shared. Defaults to False.
+            position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope'].
+                Defaults to 'learned_absolute'.
+            rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
+                Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
+            seq_len_interpolation_factor (Optional[float]): Interpolation factor for sequence length. Defaults to None.
+            add_binary_head (bool): Whether to add a binary head. Defaults to True.
+            return_embeddings (bool): Whether to return embeddings. Defaults to False.
+            include_embeddings (bool): Whether to include embeddings in the output dictionary. Defaults to False.
+            include_input_ids (bool): Whether to include input_ids in the output dictionary. Defaults to False.
+            use_full_attention_mask (bool): Whether to use full attention mask. Defaults to False.
+            include_hiddens (bool): Whether to include hidden states in the output dictionary. Defaults to False.
+            skip_logits (bool): Skip writing the token logits to the output dict. Defaults to False.
+        """
+        super(MegatronBioBertModel, self).__init__(config=config)
+        self.post_process = post_process
+        self.add_binary_head = add_binary_head
+        if return_embeddings:
+            assert self.post_process, "only return embeddings on the last pipeline stage"
+        # `b` = batch, `s` = sequence.
+ # The old flash attention mechanism apparently wants you to use a b x 1 x s x s attention mask while + # the new one wants a b x 1 x 1 x s attention mask. This is a hack to allow us to switch between the two. + self.use_full_attention_mask = use_full_attention_mask + self.config: TransformerConfig = config + self.transformer_layer_spec: spec_utils.ModuleSpec = transformer_layer_spec + self.vocab_size = vocab_size + self.max_sequence_length = max_sequence_length + self.pre_process = pre_process + self.post_process = post_process + self.fp16_lm_cross_entropy = fp16_lm_cross_entropy + self.parallel_output = parallel_output + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.position_embedding_type = position_embedding_type + self.add_binary_head = add_binary_head + self.return_embeddings = return_embeddings + self.include_embeddings = include_embeddings + self.include_hiddens = include_hiddens + self.include_input_ids = include_input_ids + self.skip_logits = skip_logits + + # megatron core pipelining currently depends on model type + self.model_type = ModelType.encoder_or_decoder + + # Embeddings. + if self.pre_process: + self.register_buffer( + "bert_position_id_tensor", + torch.arange(max_sequence_length, dtype=torch.long, requires_grad=False).unsqueeze(0), + persistent=False, + ) + # ESM2 Customization: ESM2Embedding instead of LanguageModelEmbedding + # TODO: call super, overwrite the self.embedding, and setup_embeddings_and_output_layer in constructor. + # Note: need to avoid calling setup twice: skip with super (super(skip_setup=True)) + self.embedding = ESM2Embedding( + config=self.config, + vocab_size=self.vocab_size, + max_sequence_length=self.max_sequence_length, + position_embedding_type=position_embedding_type, + num_tokentypes=num_tokentypes, + # ESM2 NEW ARGS + token_dropout=self.config.token_dropout, + use_attention_mask=self.config.use_attention_mask, + mask_token_id=tokenizer.mask_token_id, + ) + + if self.position_embedding_type == "rope": + self.rotary_pos_emb = RotaryEmbedding( + kv_channels=self.config.kv_channels, + rotary_percent=rotary_percent, + rotary_interleaved=self.config.rotary_interleaved, + seq_len_interpolation_factor=seq_len_interpolation_factor, + ) + + # Transformer. + self.encoder = TransformerBlock( + config=self.config, + spec=self.transformer_layer_spec, + pre_process=self.pre_process, + post_process=self.post_process, + ) + + # Output + if post_process: + # TODO: Make sure you are passing in the mpu_vocab_size properly + self.lm_head = AMPLIFYLMHead(config) + + self.output_layer = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + self.vocab_size, + config=config, + init_method=config.init_method, + bias=True, + skip_bias_add=False, + gather_output=not self.parallel_output, + skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights, + ) + + self.binary_head = None + if self.add_binary_head: + # TODO: Shoudl switch this to TE ? + self.binary_head = get_linear_layer( + config.hidden_size, 2, config.init_method, config.perform_initialization + ) + + self.pooler = Pooler(config.hidden_size, config.init_method, config, config.sequence_parallel) + if self.pre_process or self.post_process: + self.setup_embeddings_and_output_layer() + + def embedding_forward( + self, + input_ids: Tensor, + position_ids: Tensor, + tokentype_ids: Tensor | None = None, + attention_mask: Tensor | None = None, + ): + """Forward pass of the embedding layer. 
+
+        Args:
+            input_ids: The input tensor of shape (batch_size, sequence_length) containing the input IDs.
+            position_ids: The tensor of shape (batch_size, sequence_length) containing the position IDs.
+            tokentype_ids: The tensor of shape (batch_size, sequence_length) containing the token type IDs. Defaults to None.
+            attention_mask: The tensor of shape (batch_size, sequence_length) containing the attention mask. Defaults to None.
+
+        Returns:
+            Tensor: The output tensor of shape (batch_size, sequence_length, hidden_size) containing the embedded representations.
+        """
+        # ESM2 Customization: ESM2Embedding forward takes attention_mask
+        # in addition to the args required by LanguageModelEmbedding
+        return self.embedding(
+            input_ids=input_ids, position_ids=position_ids, tokentype_ids=tokentype_ids, attention_mask=attention_mask
+        )
+
+
+AMPLIFYModelT = TypeVar("AMPLIFYModelT", bound=AMPLIFYModel)
+
+
+@dataclass
+class AMPLIFYConfig(BioBertConfig[AMPLIFYModelT, MegatronLossType], iom.IOMixinWithGettersSetters):
+    """Configuration class for AMPLIFY model.
+
+    Attributes:
+        num_layers: Number of layers in the model.
+        hidden_size: Hidden size of the model.
+        num_attention_heads: Number of attention heads in the model.
+        ffn_hidden_size: Hidden size of the feed-forward network.
+        hidden_dropout: Dropout rate for hidden layers.
+        attention_dropout: Dropout rate for attention layers.
+        apply_residual_connection_post_layernorm: Whether to apply residual connection after layer normalization.
+        layernorm_epsilon: Epsilon value for layer normalization.
+        layernorm_zero_centered_gamma: Whether to zero-center the gamma parameter in layer normalization.
+        activation_func: Activation function used in the model.
+        init_method_std: Standard deviation for weight initialization.
+        apply_query_key_layer_scaling: Whether to apply scaling to query and key layers.
+        masked_softmax_fusion: Whether to use a kernel that fuses attention softmax with its mask.
+        fp16_lm_cross_entropy: Whether to move the cross entropy unreduced loss calculation for lm head to fp16.
+        share_embeddings_and_output_weights: Whether to share embeddings and output weights.
+        enable_autocast: Whether to enable autocast for mixed precision.
+        biobert_spec_option: BiobertSpecOption for the model.
+        position_embedding_type: Type of position embedding used in the model.
+        seq_length: Length of the input sequence.
+        make_vocab_size_divisible_by: Make the vocabulary size divisible by this value.
+        token_dropout: Whether to apply token dropout.
+        use_attention_mask: Whether to use attention mask.
+        use_esm_attention: Whether to use ESM attention.
+        attention_softmax_in_fp32: Whether to use fp32 for attention softmax.
+        optimizer_fn: Optional optimizer function for the model.
+        parallel_output: Whether to use parallel output.
+        rotary_base: Base value for rotary positional encoding.
+        rotary_percent: Percentage of rotary positional encoding.
+        seq_len_interpolation_factor: Interpolation factor for sequence length.
+        get_attention_mask_from_fusion: Whether to get attention mask from fusion.
+        nemo1_ckpt_path: Path to NEMO1 checkpoint.
+        return_only_hidden_states: Whether to return only hidden states.
+        loss_reduction_class: Loss reduction class for the model. Defaults to BERTMLMLossWithReduction.
+    """
+
+    # When overriding fields in a dataclass _always_ declare types: https://github.com/python/cpython/issues/123269
+    model_cls: Type[AMPLIFYModelT] = AMPLIFYModel
+    seq_length: int = 512
+    num_layers: int = 24  # 32 for 350M, 24 for 120M
+    hidden_size: int = 640  # 960 for 350M, 640 for 120M
+    num_attention_heads: int = 10  # 15 for 350M, 10 for 120M
+    ffn_hidden_size: int = 2560  # Transformer FFN hidden size. Usually 4 * hidden_size.
+    hidden_dropout: float = 0.0  # AMPLIFY removes dropout from hidden layers and attention
+    attention_dropout: float = 0.0  # AMPLIFY does not use attention dropout
+    apply_residual_connection_post_layernorm: bool = False  # TODO (@farhadr): False is the new default; True was used in the original BERT publication.
+    layernorm_epsilon: float = 1.0e-5
+    init_method_std: float = 0.02
+
+    # embedding
+    token_dropout: bool = True
+    use_attention_mask: bool = True
+
+    # core attention
+    use_esm_attention: bool = False  # Skip ESM2 custom attention in favor of TE acceleration. Still passes golden value test.
+    attention_softmax_in_fp32: bool = False
+    normalize_attention_scores: bool = False
+
+    # From megatron.core.models.gpt.bert_model.GPTModel
+    fp16_lm_cross_entropy: bool = False  # Move the cross entropy unreduced loss calculation for lm head to fp16
+    parallel_output: bool = True
+    share_embeddings_and_output_weights: bool = True
+    make_vocab_size_divisible_by: int = 1
+    position_embedding_type: PositionEmbeddingKinds = "rope"
+    rotary_base: int = 10000
+    rotary_percent: float = 1.0
+
+    # AMPLIFY specific configuration
+    add_bias_linear: bool = False  # AMPLIFY does not use bias in linear layers
+    bias_swiglu_fusion: bool = True
+    bias_activation_fusion: bool = False
+    bias_dropout_fusion: bool = False
+    apply_rope_fusion: bool = True
+    gated_linear_unit: bool = True
+    activation_func: Callable = silu  # AMPLIFY uses SwiGLU, i.e. SiLU with gated_linear_unit=True
+    normalization: str = "RMSNorm"  # AMPLIFY uses RMSNorm instead of LayerNorm
+    layernorm_zero_centered_gamma: bool = False  # Zero-centered gamma is not supported for RMSNorm
+    biobert_spec_option: BiobertSpecOption = BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec
+
+    # TODO: Move this to a better place?
+    get_attention_mask_from_fusion: bool = False
+
+    optimizer_fn: Optional[Callable[[MegatronBioBertModel], Optimizer]] = None
+    # TODO (@skothenhill,@georgea) update to use the nemo2 checkpoint mixins
+    # support HF (requires weight interleaving on qkv layer) and nemo1 checkpoints ideally.
+    nemo1_ckpt_path: str | None = None
+    # The following checkpoint path is for nemo2 checkpoints. Config parameters not present in
+    # self.override_parent_fields will be loaded from the checkpoint and override those values here.
+    initial_ckpt_path: str | None = None
+    # TODO (@jstjohn) come up with a cleaner way in the biobert module to return user requested
+    # things as part of the workflow for inference and fine-tuning.
+    return_embeddings: bool = False
+    include_embeddings: bool = False
+    skip_logits: bool = False
+    return_only_hidden_states: bool = False  # When True, return hidden states instead of token logits
+
+    def __post_init__(self):
+        """Check compatibility between biobert_spec_option and apply_query_key_layer_scaling, and resize ffn_hidden_size for SwiGLU, after initialization."""
+        super().__post_init__()
+        if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
+            self.apply_query_key_layer_scaling = False
+            if self.gated_linear_unit:
+                # To keep the number of parameters and the amount of computation constant, we reduce the number of
+                # hidden units by a factor of 2/3 (https://arxiv.org/pdf/2002.05202.pdf) and make it a multiple of 8 to
+                # avoid a RuntimeError due to misaligned operands.
+                multiple_of = 8
+                self.ffn_hidden_size = int(2 * self.ffn_hidden_size / 3)
+                self.ffn_hidden_size = multiple_of * ((self.ffn_hidden_size + multiple_of - 1) // multiple_of)
+        else:
+            raise ValueError(f"Unknown biobert_spec_option: {self.biobert_spec_option}")
diff --git a/sub-packages/bionemo-amplify/src/bionemo/amplify/tokenizer.py b/sub-packages/bionemo-amplify/src/bionemo/amplify/tokenizer.py
new file mode 100644
index 0000000000..95cb691c51
--- /dev/null
+++ b/sub-packages/bionemo-amplify/src/bionemo/amplify/tokenizer.py
@@ -0,0 +1,25 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import transformers
+from nemo.lightning.io import IOMixin
+
+
+class BioNeMoAMPLIFYTokenizer(transformers.PreTrainedTokenizerFast, IOMixin):  # noqa: D101
+    def __init__(self):
+        """A wrapper to make the AMPLIFY AutoTokenizer serializable."""
+        other = transformers.AutoTokenizer.from_pretrained("chandar-lab/AMPLIFY_350M", use_fast=True)
+        self.__dict__.update(dict(other.__dict__))
diff --git a/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_model.py b/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_model.py
new file mode 100644
index 0000000000..18ee0a798f
--- /dev/null
+++ b/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_model.py
@@ -0,0 +1,58 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ + +import torch + +from bionemo.amplify.model import AMPLIFYConfig, AMPLIFYModel +from bionemo.amplify.tokenizer import BioNeMoAMPLIFYTokenizer +from bionemo.esm2.model.embedding import ESM2Embedding +from bionemo.llm.model.biobert.model import MegatronBioBertModel +from bionemo.testing import megatron_parallel_state_utils + + +def test_amplify_model_initialized(): + with megatron_parallel_state_utils.distributed_model_parallel_state(): + tokenizer = BioNeMoAMPLIFYTokenizer() + config = AMPLIFYConfig() + model = config.configure_model(tokenizer) + + assert isinstance(model, MegatronBioBertModel) + assert isinstance(model, AMPLIFYModel) + assert isinstance(model.embedding, ESM2Embedding) + + +def test_amplify_model_forward_pass(): + tokenizer = BioNeMoAMPLIFYTokenizer() + + test_proteins = [ + "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA", + "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLAGG", + ] + tokens = tokenizer(test_proteins, return_tensors="pt", padding=True, truncation=True).to("cuda") + input_ids = tokens["input_ids"] + attention_mask = tokens["attention_mask"] + + nemo_config = AMPLIFYConfig( + num_layers=2, + num_attention_heads=2, + hidden_size=4, + ffn_hidden_size=4 * 4, + ) + + with megatron_parallel_state_utils.distributed_model_parallel_state(): + nemo_model = nemo_config.configure_model(tokenizer).to("cuda").eval() + nemo_output = nemo_model(input_ids, attention_mask) + assert isinstance(nemo_output["token_logits"], torch.Tensor) diff --git a/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_tokenizer.py b/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_tokenizer.py new file mode 100644 index 0000000000..7722d3c988 --- /dev/null +++ b/sub-packages/bionemo-amplify/tests/bionemo/amplify/test_tokenizer.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pytest +import torch +from nemo.lightning import io + +from bionemo.amplify.tokenizer import BioNeMoAMPLIFYTokenizer + + +@pytest.fixture +def tokenizer(): + return BioNeMoAMPLIFYTokenizer() + + +def test_tokenizer_serialization(tokenizer, tmp_path): + tokenizer.io_dump(tmp_path / "tokenizer", yaml_attrs=[]) # BioNeMoESMTokenizer takes no __init__ arguments + deserialized_tokenizer = io.load(tmp_path / "tokenizer", tokenizer.__class__) + + our_tokens = deserialized_tokenizer.encode("KAISQ", add_special_tokens=False) + amplify_tokens = torch.tensor([17, 7, 2, 14, 10, 18]) + torch.testing.assert_close(torch.tensor(our_tokens), amplify_tokens) diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/embedding.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/embedding.py index 40783c7b8d..1c8323efa8 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/embedding.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/embedding.py @@ -44,7 +44,7 @@ def __init__( # ESM2 NEW ARGS token_dropout: bool = True, use_attention_mask: bool = True, - mask_token_id: Optional[int] = torch.nan, + mask_token_id: Optional[int] = None, ) -> None: """Initialize the ESM2 Embedding module.""" super().__init__( @@ -65,7 +65,7 @@ def dtype(self) -> torch.dtype: def _apply_esm2_customization( self, word_embeddings: Tensor, input_ids: Tensor, attention_mask: Tensor - ) -> Tuple[Tensor, Tensor]: + ) -> Tuple[Tensor, Tensor | None]: """ESM2 customization for attention masking and token dropout. Args: @@ -95,7 +95,7 @@ def forward( self, input_ids: Tensor, position_ids: Tensor, - tokentype_ids: Optional[int] = None, + tokentype_ids: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None, ) -> Tensor: """Forward pass of the embedding module. @@ -103,7 +103,7 @@ def forward( Args: input_ids (Tensor): The input tokens. Shape: [b, s] position_ids (Tensor): The position id's used to calculate position embeddings. Shape: [b, s] - tokentype_ids (int, optional): The token type ids. Used when args.bert_binary_head is set to True. Defaults to None + tokentype_ids (Tensor, optional): The token type ids. Used when args.bert_binary_head is set to True. Defaults to None attention_mask (Tensor): attention mask. Shape: [b, s] Returns:
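
Note on the SwiGLU sizing rule in AMPLIFYConfig.__post_init__ above: when gated_linear_unit is enabled, ffn_hidden_size is reduced by 2/3 and rounded up to a multiple of 8. The sketch below restates that arithmetic as a standalone helper so the resulting sizes can be checked outside the patch; the helper name and the assert are illustrative only and are not part of this change.

    def swiglu_ffn_hidden_size(ffn_hidden_size: int, multiple_of: int = 8) -> int:
        # Reduce by 2/3 to keep the parameter count roughly constant for gated linear
        # units (https://arxiv.org/pdf/2002.05202.pdf), then round up to a multiple of 8.
        reduced = int(2 * ffn_hidden_size / 3)
        return multiple_of * ((reduced + multiple_of - 1) // multiple_of)

    # With the 120M-preset default of 2560, the adjusted FFN hidden size is 1712.
    assert swiglu_ffn_hidden_size(2560) == 1712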