Commit

replace tensor_parallel.ColumnParallelLinear with torch.nn.Linear to debug
sichu2023 committed Jan 21, 2025
1 parent 7b69b9e commit 3616736
Showing 2 changed files with 19 additions and 13 deletions.
26 changes: 15 additions & 11 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py
@@ -21,7 +21,6 @@
 
 import torch
 import torch.distributed
-from megatron.core import tensor_parallel
 from megatron.core.models.bert.bert_lm_head import BertLMHead
 from megatron.core.models.bert.pooler import Pooler
 from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
@@ -178,16 +177,21 @@ def __init__(
     config,
 )
 
-self.output_layer = tensor_parallel.ColumnParallelLinear(
-    config.hidden_size,
-    self.vocab_size,
-    config=config,
-    init_method=config.init_method,
-    bias=True,
-    skip_bias_add=False,
-    gather_output=not self.parallel_output,
-    skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
-)
+# TODO replace tensor_parallel.ColumnParallelLinear with torch.nn.Linear to debug; remove once complete
+# self.output_layer = tensor_parallel.ColumnParallelLinear(
+#     config.hidden_size,
+#     self.vocab_size,
+#     config=config,
+#     init_method=config.init_method,
+#     is_expert=False,
+#     bias=True,
+#     skip_bias_add=False,
+#     gather_output=not self.parallel_output,
+#     skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
+#     embedding_activation_buffer=self.embedding_activation_buffer,
+#     grad_output_buffer=self.grad_output_buffer,
+# )
+self.output_layer = torch.nn.Linear(config.hidden_size, self.vocab_size, bias=True)
 
 self.binary_head = None
 if self.add_binary_head:
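
For reference, a minimal standalone sketch of the swap made in __init__ above: a plain torch.nn.Linear head standing in for the tensor-parallel layer. The hidden_size and vocab_size values below are placeholders rather than the real ESM-2 configuration; the point is only that torch.nn.Linear owns its weight and bias and has no equivalent of gather_output or skip_weight_param_allocation.

import torch

# Placeholder dimensions, for illustration only.
hidden_size, vocab_size = 1280, 33

# Plain (non-parallel) LM head: allocates its own weight and bias, does no
# output gathering, and cannot reuse the embedding weight.
output_layer = torch.nn.Linear(hidden_size, vocab_size, bias=True)

hidden_states = torch.randn(2, 8, hidden_size)  # (batch, sequence, hidden)
logits = output_layer(hidden_states)            # a plain tensor, not an (output, bias) tuple
assert logits.shape == (2, 8, vocab_size)
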
6 changes: 4 additions & 2 deletions
@@ -444,12 +444,14 @@ def forward(
 # logits and loss
 output_weight = None
 if self.share_embeddings_and_output_weights:
-    output_weight = self.shared_embedding_or_output_weight()
+    output_weight = self.shared_embedding_or_output_weight()  # noqa: F841
 
 hidden_states_after_lm_head = self.lm_head(hidden_states=hidden_states)
 if not self.skip_logits:
     # TODO add , runtime_gather_output=runtime_gather_output once supported in ColumnParallelLinear
-    logits, _ = self.output_layer(hidden_states_after_lm_head, weight=output_weight)
+    # TODO replace tensor_parallel.ColumnParallelLinear with torch.nn.Linear to debug; remove once complete
+    # logits, _ = self.output_layer(hidden_states_after_lm_head, weight=output_weight)
+    logits = self.output_layer(hidden_states_after_lm_head)
 else:
     logits = None
 
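
The forward-path change mirrors the constructor swap: the parallel layer is called as logits, _ = self.output_layer(x, weight=output_weight) and returns a tuple, while torch.nn.Linear returns the logits tensor directly, which leaves output_weight unused (hence the # noqa: F841). A rough sketch, with placeholder shapes, of the weight-sharing pattern the debug path temporarily gives up, approximated here with torch.nn.functional.linear:

import torch
import torch.nn.functional as F

# Placeholder dimensions, for illustration only.
hidden_size, vocab_size = 1280, 33

# Shared-weight pattern: project hidden states with the token-embedding matrix
# instead of a separately owned output weight.
embedding_weight = torch.randn(vocab_size, hidden_size)  # same shape as an nn.Embedding weight
bias = torch.zeros(vocab_size)

hidden_states = torch.randn(2, 8, hidden_size)

# Rough analogue of `logits, _ = output_layer(hidden_states, weight=output_weight)`
# when the layer skips its own weight allocation and reuses the embedding weight.
logits_shared = F.linear(hidden_states, embedding_weight, bias)

# The debug path in this commit: a plain Linear with its own weight, returning a tensor.
debug_head = torch.nn.Linear(hidden_size, vocab_size, bias=True)
logits_debug = debug_head(hidden_states)

assert logits_shared.shape == logits_debug.shape == (2, 8, vocab_size)
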
