Skip to content

Commit 4ef017a

Browse files
committed
lenet实现
0 parents  commit 4ef017a

File tree

2 files changed

+147
-0
lines changed

2 files changed

+147
-0
lines changed

lenet/main.py

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
#-*- coding:utf-8-*-
2+
3+
import torch
4+
from torch.utils.data import Dataset, DataLoader
5+
6+
from moudle import LeNet
7+
from torchvision.datasets.mnist import MNIST
8+
import torchvision.transforms as transforms
9+
import time
10+
11+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12+
13+
batch_size=1
14+
epoch_num=5 #2-0.8323, 5-0.9545
15+
LR = 0.001
16+
17+
data_train = MNIST('../data/mnist',
18+
download=True,
19+
transform=transforms.Compose([
20+
# transforms.Resize((32, 32)),
21+
transforms.ToTensor(),
22+
23+
]))
24+
25+
data_test = MNIST('../data/mnist',
26+
train=False,
27+
download=True,
28+
transform=transforms.Compose([
29+
30+
#transforms.Resize((32, 32)),
31+
transforms.ToTensor()]))
32+
33+
34+
35+
train_loader = DataLoader(
36+
dataset= data_train, #CustomDataset(),
37+
batch_size=batch_size, # 批大小
38+
39+
shuffle=True, # 是否随机打乱顺序
40+
num_workers=8, # 多线程读取数据的线程数
41+
)
42+
43+
test_loader = DataLoader(
44+
dataset= data_test, #CustomDataset(),
45+
batch_size=batch_size, # 批大小
46+
47+
shuffle=True, # 是否随机打乱顺序
48+
num_workers=8, # 多线程读取数据的线程数
49+
)
50+
51+
52+
net = LeNet().to(device)
53+
54+
opt = torch.optim.SGD(net.parameters(), lr=LR)
55+
56+
loss_function = torch.nn.CrossEntropyLoss()
57+
58+
def train():
59+
net.train()
60+
for epoch in range(epoch_num):
61+
total_loss = 0
62+
epoch_step = 0
63+
tic = time.time()
64+
65+
for batch_image, batch_label in train_loader:
66+
batch_image = batch_image.to(device)
67+
batch_label = batch_label.to(device)
68+
69+
opt.zero_grad()
70+
output = net(batch_image)
71+
loss = loss_function(output, batch_label)
72+
73+
total_loss += loss
74+
epoch_step += 1
75+
76+
loss.backward()
77+
opt.step()
78+
79+
toc = time.time()
80+
print("one epoch does take approximately " + str((toc - tic)) + " seconds),average loss: " + str(total_loss/epoch_step))
81+
82+
#torch.save(net.state_dict(), "./moudle/moudle")
83+
84+
def test():
85+
net.eval()
86+
total_correct = 0
87+
for batch_image, batch_label in test_loader:
88+
batch_image = batch_image.to(device)
89+
batch_label = batch_label.to(device)
90+
91+
output = net(batch_image)
92+
93+
pred = output.detach().max(1)[1]
94+
total_correct += pred.eq(batch_label.view_as(pred)).sum()
95+
96+
print("total_correct:", float(total_correct) / len(data_test))
97+
98+
99+
100+
if __name__ == "__main__":
101+
train()
102+
test()
103+

lenet/moudle.py

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#-*-coding:utf-8-*-
2+
3+
'''
4+
《深度学习-卷积神经网络从入门到精通》中的lenet5实现(p43 - p44),
5+
但是原书公式的效果极差,所以相对原书做了如下修改:
6+
1. 将sigmoid改换成Relu
7+
2. 增加一个84 -> 10的全连接层
8+
3. 将平均池化换成最大池化
9+
'''
10+
11+
import torch
12+
import torch.nn as nn
13+
import torch.nn.functional as F
14+
from torch.autograd import Variable
15+
16+
class LeNet(nn.Module):
17+
def __init__(self):
18+
super(LeNet, self).__init__()
19+
20+
self.conv1 = nn.Conv2d(1, 6, 5, padding=2)
21+
self.conv2 = nn.Conv2d(6, 16, 5)
22+
self.conv3 = nn.Conv2d(16, 120, 5)
23+
24+
self.fc1 = nn.Linear(120, 84)
25+
self.fc2 = nn.Linear(84, 10)
26+
27+
def forward(self, input):
28+
# 28 * 28 - > 28 * 28 -> 14 * 14
29+
x = F.max_pool2d(F.relu(self.conv1(input)), 2, stride=2)
30+
# 14 * 14 -> 10 * 10 - > 5 * 5
31+
x = F.max_pool2d(F.relu(self.conv2(x)), 2, stride=2)
32+
33+
x = F.relu(self.conv3(x))
34+
35+
x = x.view(x.size(0),-1)
36+
37+
# 84 * 84 -> 10*10
38+
x = self.fc2(F.relu(self.fc1(x)))
39+
40+
x = F.softmax(x,dim=1)
41+
#print(x.size)
42+
return x
43+
44+

0 commit comments

Comments
 (0)