From c70af9073ff06673de7797ea0f971b38f1961d5c Mon Sep 17 00:00:00 2001
From: sichu
Date: Sat, 28 Dec 2024 20:31:16 +0000
Subject: [PATCH] switch to update in validation epoch logging

---
 .../bionemo-llm/src/bionemo/llm/lightning.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/lightning.py b/sub-packages/bionemo-llm/src/bionemo/llm/lightning.py
index e37c665fa5..d020d4b107 100644
--- a/sub-packages/bionemo-llm/src/bionemo/llm/lightning.py
+++ b/sub-packages/bionemo-llm/src/bionemo/llm/lightning.py
@@ -348,8 +348,8 @@ def training_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
         logits = outputs["token_logits"].transpose(0, 1)  # [s, b] -> [b, s]
 
         if self.log_train_ppl and parallel_state.is_pipeline_last_stage():
-            self.train_ppl(logits, batch["labels"])
-            self.log("train_ppl", self.train_ppl, on_step=True, on_epoch=False)
+            train_metric_value = self.train_ppl(logits, batch["labels"])
+            self.log("train_ppl", train_metric_value, on_step=True, on_epoch=False)
 
         return outputs
 
@@ -359,8 +359,7 @@ def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
         logits = outputs["token_logits"].transpose(0, 1)  # [s, b] -> [b, s]
 
         if self.log_val_ppl and parallel_state.is_pipeline_last_stage():
-            self.valid_ppl(logits, batch["labels"])
-            self.log("valid_ppl", self.valid_ppl, on_step=False, on_epoch=True)
+            self.valid_ppl.update(logits, batch["labels"])
 
         return outputs
 
@@ -380,6 +379,11 @@ def validation_loss_reduction(self) -> MegatronLossType:  # noqa: D102
     def test_loss_reduction(self) -> MegatronLossType:  # noqa: D102
         return self.loss_reduction_class(validation_step=True)
 
+    def on_validation_end(self):  # noqa: D102
+        valid_metric_value = self.valid_ppl.compute()
+        self.log("valid_ppl", valid_metric_value, on_step=False, on_epoch=True)
+        self.valid_ppl.reset()
+
 
 def default_megatron_optimizer() -> MegatronOptimizerModule:
     """Default distributed optimizer uses Adam with a 1e-4 learning rate."""
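
Note (not part of the patch): the change above replaces per-step logging of the torchmetrics object in validation_step with an explicit update() during the step and a single compute()/log()/reset() pass when validation ends. Below is a minimal, generic sketch of that update/compute/reset pattern, assuming torchmetrics.text.Perplexity and a plain PyTorch Lightning module; the class name ToyLM, the feature shapes, and the use of the on_validation_epoch_end hook (where stock Lightning permits self.log) are illustrative and do not come from the patched BioNeMo code.

import torch
import lightning.pytorch as pl
from torchmetrics.text import Perplexity


class ToyLM(pl.LightningModule):
    """Illustrative module showing epoch-level perplexity logging."""

    def __init__(self, vocab_size: int = 128):
        super().__init__()
        self.model = torch.nn.Linear(16, vocab_size)
        self.valid_ppl = Perplexity(ignore_index=-100)

    def validation_step(self, batch, batch_idx):
        logits = self.model(batch["features"])  # [b, s, vocab]
        # Accumulate metric state only; no per-step logging of the metric object.
        self.valid_ppl.update(logits, batch["labels"])  # labels: [b, s] int64

    def on_validation_epoch_end(self):
        # Aggregate over the whole validation pass, log once, then clear the state.
        self.log("valid_ppl", self.valid_ppl.compute(), on_epoch=True)
        self.valid_ppl.reset()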