Skip to content

Commit

Permalink
add comment on unreduced_token_loss
Browse files Browse the repository at this point in the history
  • Loading branch information
sichu2023 committed Jan 22, 2025
1 parent 3616736 commit 1ca668e
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ def forward(
forward_out_report = {}

# NOTE: token_logits is [sequence, batch] but labels and other fiels, including the loss are [batch, sequence]
# TODO: logits always match on tp=1 and tp=2, and among tp ranks
# however, unreduced_token_loss does not match between tp=1 and tp=2. only match among tp ranks at tp=2
unreduced_token_loss = unreduced_token_loss_fn(forward_out["token_logits"], batch["labels"]) # [b s]

# TODO(@jstjohn) also handle different output keys, like the sequence loss.
Expand Down

0 comments on commit 1ca668e

Please sign in to comment.