Commit

switch to update in validation epoch logging
sichu2023 committed Dec 28, 2024
1 parent fbc1b59 commit c70af90
Showing 1 changed file with 8 additions and 4 deletions.
sub-packages/bionemo-llm/src/bionemo/llm/lightning.py: 12 changes (8 additions & 4 deletions)
@@ -348,8 +348,8 @@ def training_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
        logits = outputs["token_logits"].transpose(0, 1)  # [s, b] -> [b, s]

        if self.log_train_ppl and parallel_state.is_pipeline_last_stage():
-            self.train_ppl(logits, batch["labels"])
-            self.log("train_ppl", self.train_ppl, on_step=True, on_epoch=False)
+            train_metric_value = self.train_ppl(logits, batch["labels"])
+            self.log("train_ppl", train_metric_value, on_step=True, on_epoch=False)

        return outputs
@@ -359,8 +359,7 @@ def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
        logits = outputs["token_logits"].transpose(0, 1)  # [s, b] -> [b, s]

        if self.log_val_ppl and parallel_state.is_pipeline_last_stage():
-            self.valid_ppl(logits, batch["labels"])
-            self.log("valid_ppl", self.valid_ppl, on_step=False, on_epoch=True)
+            self.valid_ppl.update(logits, batch["labels"])

        return outputs

@@ -380,6 +379,11 @@ def validation_loss_reduction(self) -> MegatronLossType:  # noqa: D102
    def test_loss_reduction(self) -> MegatronLossType:  # noqa: D102
        return self.loss_reduction_class(validation_step=True)

+    def on_validation_end(self):  # noqa: D102
+        valid_metric_value = self.valid_ppl.compute()
+        self.log("valid_ppl", valid_metric_value, on_step=False, on_epoch=True)
+        self.valid_ppl.reset()
+

def default_megatron_optimizer() -> MegatronOptimizerModule:
    """Default distributed optimizer uses Adam with a 1e-4 learning rate."""
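For context, the commit moves validation perplexity from per-step self.log calls to the torchmetrics accumulate-then-compute pattern: validation_step only calls update(), and on_validation_end calls compute(), logs the aggregated value, and reset()s the state. The sketch below illustrates that pattern in isolation; the Perplexity class here is a hypothetical stand-in for the actual self.valid_ppl metric, and the logit/label shapes (and the lack of ignore_index handling) are assumptions.

import torch
from torchmetrics import Metric


class Perplexity(Metric):
    """Accumulates token-level NLL across batches; perplexity = exp(mean NLL)."""

    def __init__(self) -> None:
        super().__init__()
        # Distributed-friendly running sums, reduced across ranks on compute().
        self.add_state("total_nll", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_tokens", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, logits: torch.Tensor, labels: torch.Tensor) -> None:
        # Assumed shapes: logits [b, s, vocab], labels [b, s].
        self.total_nll += torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)), labels.reshape(-1), reduction="sum"
        )
        self.total_tokens += labels.numel()

    def compute(self) -> torch.Tensor:
        return torch.exp(self.total_nll / self.total_tokens)


valid_ppl = Perplexity()
for _ in range(3):  # one update() per batch, as in validation_step
    logits = torch.randn(2, 8, 32)
    labels = torch.randint(0, 32, (2, 8))
    valid_ppl.update(logits, labels)  # accumulate state only; nothing is logged here
print(valid_ppl.compute())  # aggregate once at epoch end, as in on_validation_end
valid_ppl.reset()  # clear accumulated state before the next validation epoch

Computing once at epoch end yields a token-weighted aggregate over the full validation set, whereas logging a fresh value each step with on_epoch=True leaves the reduction to Lightning's default mean over steps, which is presumably what this change avoids.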
