train_reslpr.py
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from utils.dataset_utils import ResLPRTrainDataset
from net.model import ResLPR
from utils.schedulers import LinearWarmupCosineAnnealingLR
import numpy as np
import wandb
from options import options as opt
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint
import torch.multiprocessing as mp
import os
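# ResLPRModel wraps the ResLPR restoration network in a LightningModule and
# trains it with an L1 loss between restored patches and their clean targets.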
class ResLPRModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # ResLPR network with its decoder enabled; trained with an L1 loss.
        self.net = ResLPR(decoder=True)
        self.loss_fn = nn.L1Loss()

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        # training_step defines the training loop; it is independent of forward.
        ([clean_name, de_id], degrad_patch, clean_patch) = batch
        restored = self.net(degrad_patch)
        loss = self.loss_fn(restored, clean_patch)
        # Logged to the configured logger (Weights & Biases or TensorBoard).
        self.log("train_loss", loss)
        return loss

    def lr_scheduler_step(self, scheduler, metric):
        # Advance the scheduler once per epoch, keyed on the current epoch index.
        scheduler.step(self.current_epoch)
        lr = scheduler.get_lr()  # current learning rate(s); retrieved but not logged here

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=2e-4)
        scheduler = LinearWarmupCosineAnnealingLR(optimizer=optimizer, warmup_epochs=15, max_epochs=150)
        return [optimizer], [scheduler]
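# Note (a reading of the configuration above, not part of the original script): as its
# name suggests, LinearWarmupCosineAnnealingLR ramps the learning rate up linearly to
# the base value of 2e-4 over the first 15 epochs, then decays it along a cosine curve
# until epoch 150; lr_scheduler_step advances it once per epoch via current_epoch.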
if __name__ == '__main__':
    print("Options")
    print(opt)

    if opt.wblogger is not None:
        logger = WandbLogger(project=opt.wblogger, name="ResLPR")
    else:
        logger = TensorBoardLogger(save_dir="logs/")

    trainset = ResLPRTrainDataset(opt)
    # Save a checkpoint every epoch and keep all of them (save_top_k=-1) until training ends.
    checkpoint_callback = ModelCheckpoint(dirpath=opt.ckpt_dir, every_n_epochs=1, save_top_k=-1)
    trainloader = DataLoader(trainset, batch_size=opt.batch_size, pin_memory=True, shuffle=True,
                             drop_last=True, num_workers=opt.num_workers)

    model = ResLPRModel()
    trainer = pl.Trainer(max_epochs=opt.epochs, accelerator="gpu", devices=opt.num_gpus,
                         strategy="ddp_find_unused_parameters_true", logger=logger,
                         callbacks=[checkpoint_callback])
    # Resume from opt.resume_ckpt only when opt.resume is set and the checkpoint file exists.
    ckpt_path = opt.resume_ckpt if opt.resume and opt.resume_ckpt and os.path.isfile(opt.resume_ckpt) else None
    trainer.fit(model=model, train_dataloaders=trainloader, ckpt_path=ckpt_path)
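# Example launch (a sketch; it assumes options.py exposes the fields used above as
# command-line flags, which is not shown in this file, and the values are illustrative):
#   python train_reslpr.py --epochs 150 --batch_size 8 --num_gpus 1 \
#       --num_workers 4 --ckpt_dir ckpt/ --wblogger my_project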