-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
44 lines (30 loc) · 1.36 KB
/
Copy pathtrain.py
File metadata and controls
44 lines (30 loc) · 1.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
from models import *
from torch import nn
from torch.optim import SGD
from dataset import *
class GAN():
    """Train the project's ``CNN`` classifier on ``LoadDataset`` with SGD.

    Despite the class name, nothing adversarial happens here: a single
    network is fit with cross-entropy loss. The name is kept so existing
    callers keep working.

    ``args`` must provide ``device``, ``batch_size`` and ``epochs``. It may
    optionally provide ``lr`` (default 0.001) and ``momentum`` (default 0.9).
    """

    # Number of values per input sequence. Each batch is reshaped to
    # (batch, SEQUENCE_LENGTH, 1) before being passed to the network.
    SEQUENCE_LENGTH = 32768

    def __init__(self, args):
        self.args = args
        self.data = LoadDataset()
        self.cnn = CNN().to(self.args.device)
        self.criterion = nn.CrossEntropyLoss()
        # Hyper-parameters fall back to the previously hard-coded values.
        self.optimizer = SGD(
            self.cnn.parameters(),
            lr=getattr(args, "lr", 0.001),
            momentum=getattr(args, "momentum", 0.9),
        )
        self.dataLoader = torch.utils.data.DataLoader(
            self.data, batch_size=args.batch_size, shuffle=True
        )
        print("Training Dataset : {} prepared.".format(len(self.data)))
        print("Network prepared.")

    def run(self, checkpoint_path="models.pt"):
        """Run ``args.epochs`` epochs of training.

        After every epoch the sample-weighted mean training loss is printed
        and the network's ``state_dict`` is saved under the key ``'CNN'``
        to ``checkpoint_path``, overwriting the previous checkpoint.

        Args:
            checkpoint_path: Where to write the checkpoint. Defaults to
                ``"models.pt"`` in the current working directory.
        """
        self.cnn.train()
        for epoch in range(self.args.epochs):
            total_loss = 0.0
            total_samples = 0
            for sequence, labels in self.dataLoader:
                # Use the real batch size: the last batch is smaller than
                # args.batch_size unless the loader drops incomplete batches.
                batch_size = sequence.size(0)
                sequence = sequence.reshape(batch_size, self.SEQUENCE_LENGTH, 1)
                sequence = sequence.to(self.args.device)
                labels = labels.to(self.args.device)

                self.optimizer.zero_grad()
                # The network returns a pair; only the first element
                # (the logits fed to the loss) is used here.
                outputs, _ = self.cnn(sequence)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

                # CrossEntropyLoss averages over the batch by default, so
                # weight by batch size to get a true per-sample epoch mean.
                total_loss += loss.item() * batch_size
                total_samples += batch_size

            epoch_loss = total_loss / total_samples if total_samples else float("nan")
            print(f"[Epoch] {epoch} - [Loss] {epoch_loss}")
            torch.save({'CNN': self.cnn.state_dict()}, checkpoint_path)
        print("Finished Training")