-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
71 lines (57 loc) · 2.06 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# --- 100 characters ------------------------------------------------------------------------------
# Created by: Shaun Spinelli 2019/10/12
import logging as lg
import os

import torch
from tqdm import tqdm
# from . import metrics
# Module-level logger, registered under the fixed name "train".
_logger = lg.getLogger("train")
class Training:
    """Drive a training run: iterate batches, optimise, log metrics, checkpoint.

    The collaborators are duck-typed. From the calls made in this class:

    * ``metrics`` must expose ``update(preds, labels, step)``, ``reset()`` and a
      ``writer`` attribute; when ``writer`` is truthy its
      ``add_scalar(tag, value, step)`` method is called.
    * ``data`` must be iterable and yield ``(data, labels)`` pairs of objects
      supporting ``.to(device)``.
    * ``model`` must be callable on a batch and expose ``state_dict()``.
    """

    def __init__(self, metrics, loss, optim, data, epochs, model, save_dir):
        """Store the collaborators and start the global step counter at zero.

        Args:
            metrics (MetricManager): metric tracker updated after every forward pass.
            loss (torch.nn.modules.loss): callable invoked as ``loss(preds, labels)``.
            optim (torch.optim): optimizer stepped once per batch.
            data (DataLoader): iterable of ``(data, labels)`` batches.
            epochs (int): number of passes over ``data``.
            model (): callable model whose ``state_dict()`` is checkpointed.
            save_dir (str): directory to save model checkpoints in.
        """
        self.metrics = metrics
        self.loss = loss
        self.optim = optim
        self.data = data
        self.model = model
        self.epochs = epochs
        self.save_dir = save_dir
        self.step = 0

    def _input_device(self):
        """Return the device batches should be moved to before the forward pass."""
        parameters = getattr(self.model, "parameters", None)
        if callable(parameters):
            first_param = next(iter(parameters()), None)
            if first_param is not None:
                # Follow the model so inputs and weights always share a device.
                return first_param.device
        # Model has no parameters to inspect: prefer CUDA, fall back to CPU when absent.
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def train_step(self, batch):
        """Run one optimisation step on a single batch.

        Args:
            batch: a ``(data, labels)`` pair.
        """
        data, labels = batch
        device = self._input_device()
        preds = self.model(data.to(device))
        loss = self.loss(preds, labels.to(device))

        # Metrics receive the labels exactly as the loader produced them.
        self.metrics.update(preds, labels, self.step)
        if self.metrics.writer:
            self.metrics.writer.add_scalar("loss", loss.item(), self.step)

        self.optim.zero_grad()  # clear gradients left over from the previous step
        loss.backward()  # compute gradients
        self.optim.step()  # update weights

    def save_checkpoint(self):
        """Save the model's ``state_dict`` to ``<save_dir>/model-<step>.pth``."""
        if self.save_dir:
            # torch.save does not create missing directories.
            os.makedirs(self.save_dir, exist_ok=True)
        torch.save(self.model.state_dict(), f'{self.save_dir}/model-{self.step}.pth')

    def train_loop(self):
        """Train for ``self.epochs`` epochs, incrementing ``self.step`` per batch."""
        for i in range(self.epochs):
            print(f'Epoch {i}/{self.epochs}')
            for batch in tqdm(self.data):
                self.train_step(batch)
                self.step += 1
            # NOTE(review): reset and checkpoint are taken to run once per epoch;
            # confirm against the intended nesting.
            self.metrics.reset()
            self.save_checkpoint()

    def run(self):
        """Run the training loop, handling a user interrupt (Ctrl-C) gracefully."""
        try:
            self.train_loop()
        except KeyboardInterrupt:
            _logger.debug("Quitting due to user cancel")
            # NOTE(review): the checkpoint is taken to be written only on interrupt,
            # since train_loop already saves at the end of every epoch; confirm.
            self.save_checkpoint()