Skip to content

Commit da97c51

Browse files
Add lr and grad norm logging
1 parent 01fd8ee commit da97c51

File tree

6 files changed

+79
-0
lines changed

6 files changed

+79
-0
lines changed

configs/callbacks/default.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
defaults:
22
- model_summary
33
- rich_progress_bar
4+
- lr_monitor
45
- _self_
56

67
model_checkpoint:

configs/callbacks/lr_monitor.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Hydra callback config: attaches Lightning's LearningRateMonitor so the
# current learning rate is sent to the experiment logger.
lr_monitor:
  _target_: lightning.pytorch.callbacks.LearningRateMonitor
  # Log on every optimizer step rather than once per epoch, so LR-schedule
  # shape (e.g. warmup) is visible at step resolution.
  logging_interval: step

configs/experiment/logging.yaml

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# @package _global_

# Debug experiment for verifying LR and grad-norm logging: a shrunken model
# trained for only 100 steps with frequent logging and validation.

logger:
  wandb:
    name: debug-lr-gradnorm-logging

trainer:
  max_steps: 100
  log_every_n_steps: 5
  # Step-based validation every 5 steps; check_val_every_n_epoch must be
  # null for val_check_interval to be interpreted in steps.
  val_check_interval: 5
  limit_val_batches: 2
  check_val_every_n_epoch: null

model:
  # Tiny network so the debug run finishes quickly.
  net:
    embedder:
      embedding_dim: 32
    encoder:
      n_layers: 2
  # Cosine schedule with warmup so the learning rate actually varies over
  # the run and shows up in the LR-monitor plots.
  scheduler:
    _target_: transformers.get_cosine_schedule_with_warmup
    _partial_: true
    num_warmup_steps: 10
    num_training_steps: ${trainer.max_steps}

data:
  batch_size: 8
  per_device_batch_size: 8

compile: false

glm_experiments/models/bert_lit_module.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch.nn as nn
77
from biofoundation.model.scoring import compute_llr_mlm
88
from lightning import LightningModule
9+
from lightning.pytorch.utilities import grad_norm
910
from sklearn.metrics import average_precision_score
1011
from torchmetrics.aggregation import CatMetric
1112

@@ -134,3 +135,8 @@ def configure_optimizers(self) -> dict[str, Any]:
134135
"interval": "step",
135136
},
136137
}
138+
139+
def on_before_optimizer_step(self, optimizer: torch.optim.Optimizer) -> None:
    """Log the total L2 gradient norm just before each optimizer step."""
    grad_stats = grad_norm(self, norm_type=2)
    total_norm = grad_stats["grad_2.0_norm_total"]
    self.log("train/grad_norm", total_norm)

tests/test_bert_lit_module.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,29 @@ def test_validation_step_mlm_still_works(bert_lit_module):
209209
# Should not raise
210210
result = bert_lit_module.validation_step(batch, batch_idx=0, dataloader_idx=0)
211211
assert result is None
212+
213+
214+
def test_on_before_optimizer_step_logs_grad_norm(bert_lit_module):
    """Test that on_before_optimizer_step computes and logs the gradient norm.

    The original version of this test only recomputed the norm itself and
    never invoked the hook, so a broken hook would still pass. Here we stub
    ``self.log`` (no Trainer is attached in unit tests), call the hook, and
    assert the expected key and value are actually logged.
    """
    from lightning.pytorch.utilities import grad_norm

    batch_size = 2
    seq_len = 100

    batch = {
        "input_ids": torch.randint(0, 6, (batch_size, seq_len)),
        "labels": torch.randint(0, 6, (batch_size, seq_len)),
        "loss_weight": torch.ones(batch_size, seq_len),
    }

    # Forward and backward to populate gradients
    loss = bert_lit_module.model_step(batch)
    loss.backward()

    # Independently compute the expected total L2 grad norm
    norms = grad_norm(bert_lit_module, norm_type=2)
    expected_norm = norms["grad_2.0_norm_total"]

    # Verify gradients exist and norm is reasonable
    assert expected_norm > 0.0
    assert torch.isfinite(torch.tensor(expected_norm))

    # Capture self.log calls so the hook can run without a Trainer.
    logged: dict = {}
    bert_lit_module.log = (
        lambda name, value, *args, **kwargs: logged.__setitem__(name, value)
    )

    # Exercise the hook under test (optimizer is unused by the hook).
    bert_lit_module.on_before_optimizer_step(optimizer=None)

    assert "train/grad_norm" in logged
    assert logged["train/grad_norm"] == expected_norm

tests/test_configs.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import hydra
2+
from hydra import compose, initialize
23
from hydra.core.hydra_config import HydraConfig
4+
from lightning.pytorch.callbacks import LearningRateMonitor
35
from omegaconf import DictConfig
46

57

@@ -35,3 +37,14 @@ def test_eval_config(cfg_eval: DictConfig) -> None:
3537
hydra.utils.instantiate(cfg_eval.data)
3638
hydra.utils.instantiate(cfg_eval.model)
3739
hydra.utils.instantiate(cfg_eval.trainer)
40+
41+
42+
def test_lr_monitor_callback_config() -> None:
    """Test that the LearningRateMonitor callback config instantiates correctly.

    Instantiates the ``lr_monitor`` node directly: the top-level config has
    no ``_target_``, so ``instantiate(cfg)`` would return an instantiated
    container rather than the callback object, and the ``isinstance`` check
    below would be asserting against the wrong object.
    """
    with initialize(version_base="1.3", config_path="../configs/callbacks"):
        cfg = compose(config_name="lr_monitor")

    callback = hydra.utils.instantiate(cfg.lr_monitor)

    assert isinstance(callback, LearningRateMonitor)
    assert callback.logging_interval == "step"

0 commit comments

Comments
 (0)