From 3616736c95d19466039c0c892e0278df238aebe6 Mon Sep 17 00:00:00 2001
From: sichu
Date: Tue, 21 Jan 2025 23:20:55 +0000
Subject: [PATCH] replace tensor_parallel.ColumnParallelLinear with torch.nn.Linear to debug

---
 .../src/bionemo/esm2/model/model.py        | 26 +++++++++++--------
 .../src/bionemo/llm/model/biobert/model.py |  6 +++--
 2 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py
index 7f89afbb78..48ed843382 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py
@@ -21,7 +21,6 @@
 
 import torch
 import torch.distributed
-from megatron.core import tensor_parallel
 from megatron.core.models.bert.bert_lm_head import BertLMHead
 from megatron.core.models.bert.pooler import Pooler
 from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
@@ -178,16 +177,21 @@ def __init__(
                 config,
             )
 
-            self.output_layer = tensor_parallel.ColumnParallelLinear(
-                config.hidden_size,
-                self.vocab_size,
-                config=config,
-                init_method=config.init_method,
-                bias=True,
-                skip_bias_add=False,
-                gather_output=not self.parallel_output,
-                skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
-            )
+            # TODO replace tensor_parallel.ColumnParallelLinear with torch.nn.Linear to debug; remove once complete
+            # self.output_layer = tensor_parallel.ColumnParallelLinear(
+            #     config.hidden_size,
+            #     self.vocab_size,
+            #     config=config,
+            #     init_method=config.init_method,
+            #     is_expert=False,
+            #     bias=True,
+            #     skip_bias_add=False,
+            #     gather_output=not self.parallel_output,
+            #     skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
+            #     embedding_activation_buffer=self.embedding_activation_buffer,
+            #     grad_output_buffer=self.grad_output_buffer,
+            # )
+            self.output_layer = torch.nn.Linear(config.hidden_size, self.vocab_size, bias=True)
 
             self.binary_head = None
             if self.add_binary_head:
diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py
index cc14b6f7d9..a328312e8c 100644
--- a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py
+++ b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py
@@ -444,12 +444,14 @@ def forward(
         # logits and loss
         output_weight = None
         if self.share_embeddings_and_output_weights:
-            output_weight = self.shared_embedding_or_output_weight()
+            output_weight = self.shared_embedding_or_output_weight()  # noqa: F841
 
         hidden_states_after_lm_head = self.lm_head(hidden_states=hidden_states)
         if not self.skip_logits:
             # TODO add , runtime_gather_output=runtime_gather_output once supported in ColumnParallelLinear
-            logits, _ = self.output_layer(hidden_states_after_lm_head, weight=output_weight)
+            # TODO replace tensor_parallel.ColumnParallelLinear with torch.nn.Linear to debug; remove once complete
+            # logits, _ = self.output_layer(hidden_states_after_lm_head, weight=output_weight)
+            logits = self.output_layer(hidden_states_after_lm_head)
         else:
             logits = None
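
Reviewer note (not part of the patch): the second hunk changes the call site because Megatron's ColumnParallelLinear.forward returns an (output, bias) tuple and accepts an optional weight= argument for embedding/output weight tying, whereas a plain torch.nn.Linear returns the logits tensor directly and owns its own weight; hence the tuple unpacking and the weight=output_weight argument drop out, and output_weight becomes unused (noqa: F841). Below is a minimal standalone sketch of that difference, assuming only torch; the DebugOutputLayer name and the tensor sizes are hypothetical and chosen for illustration only.

# Minimal sketch of the debug swap, assuming only torch (no megatron dependency).
# DebugOutputLayer, the hidden/vocab sizes, and the dummy batch are hypothetical.
import torch


class DebugOutputLayer(torch.nn.Module):
    """Plain nn.Linear stand-in for the tensor-parallel output projection."""

    def __init__(self, hidden_size: int, vocab_size: int) -> None:
        super().__init__()
        # bias=True mirrors the replaced ColumnParallelLinear configuration.
        self.linear = torch.nn.Linear(hidden_size, vocab_size, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Returns a single tensor, so the caller writes `logits = layer(h)`
        # instead of `logits, _ = layer(h, weight=output_weight)`.
        return self.linear(hidden_states)


if __name__ == "__main__":
    layer = DebugOutputLayer(hidden_size=320, vocab_size=33)
    hidden = torch.randn(4, 72, 320)  # (batch, sequence, hidden)
    logits = layer(hidden)
    print(logits.shape)  # torch.Size([4, 72, 33])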