詳解PyTorch手寫數(shù)字識(shí)別(MNIST數(shù)據(jù)集)
MNIST 手寫數(shù)字識(shí)別是一個(gè)比較簡單的入門項(xiàng)目,相當(dāng)于深度學(xué)習(xí)中的 Hello World,可以讓我們快速了解構(gòu)建神經(jīng)網(wǎng)絡(luò)的大致過程。雖然網(wǎng)上的案例比較多,但還是要自己實(shí)現(xiàn)一遍。代碼采用 PyTorch 1.0 編寫并運(yùn)行。
導(dǎo)入相關(guān)庫
import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms import torchvision from torch.autograd import Variable from torch.utils.data import DataLoader import cv2
torchvision 用于下載并導(dǎo)入數(shù)據(jù)集
cv2 用于展示數(shù)據(jù)的圖像
獲取訓(xùn)練集和測試集
# 下載訓(xùn)練集
train_dataset = datasets.MNIST(root='./num/',
train=True,
transform=transforms.ToTensor(),
download=True)
# 下載測試集
test_dataset = datasets.MNIST(root='./num/',
train=False,
transform=transforms.ToTensor(),
download=True)
root 用于指定數(shù)據(jù)集在下載之后的存放路徑
transform 用于指定導(dǎo)入數(shù)據(jù)集需要對(duì)數(shù)據(jù)進(jìn)行那種變化操作
train是指定在數(shù)據(jù)集下載完成后需要載入的那部分?jǐn)?shù)據(jù),設(shè)置為 True 則說明載入的是該數(shù)據(jù)集的訓(xùn)練集部分,設(shè)置為 False 則說明載入的是該數(shù)據(jù)集的測試集部分
download 為 True 表示數(shù)據(jù)集需要程序自動(dòng)幫你下載
這樣設(shè)置并運(yùn)行后,就會(huì)在指定路徑中下載 MNIST 數(shù)據(jù)集,之后就可以使用了。
數(shù)據(jù)裝載和預(yù)覽
# dataset 參數(shù)用于指定我們載入的數(shù)據(jù)集名稱
# batch_size參數(shù)設(shè)置了每個(gè)包中的圖片數(shù)據(jù)個(gè)數(shù)
# 在裝載的過程會(huì)將數(shù)據(jù)隨機(jī)打亂順序并進(jìn)打包
# 裝載訓(xùn)練集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
# 裝載測試集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=True)
在裝載完成后,可以選取其中一個(gè)批次的數(shù)據(jù)進(jìn)行預(yù)覽:
images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)
img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)
在以上代碼中使用了 iter 和 next 來獲取取一個(gè)批次的圖片數(shù)據(jù)和其對(duì)應(yīng)的圖片標(biāo)簽,然后使用 torchvision.utils 中的 make_grid 類方法將一個(gè)批次的圖片構(gòu)造成網(wǎng)格模式。
預(yù)覽圖片如下:

并且打印出了圖片相對(duì)應(yīng)的數(shù)字:

搭建神經(jīng)網(wǎng)絡(luò)
# 卷積層使用 torch.nn.Conv2d
# 激活層使用 torch.nn.ReLU
# 池化層使用 torch.nn.MaxPool2d
# 全連接層使用 torch.nn.Linear
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),
nn.MaxPool2d(2, 2))
self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
nn.MaxPool2d(2, 2))
self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
nn.BatchNorm1d(120), nn.ReLU())
self.fc2 = nn.Sequential(
nn.Linear(120, 84),
nn.BatchNorm1d(84),
nn.ReLU(),
nn.Linear(84, 10))
# 最后的結(jié)果一定要變?yōu)?10,因?yàn)閿?shù)字的選項(xiàng)是 0 ~ 9
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size()[0], -1)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
return x
前向傳播內(nèi)容:
首先經(jīng)過 self.conv1() 和 self.conv1() 進(jìn)行卷積處理
然后進(jìn)行 x = x.view(x.size()[0], -1),對(duì)參數(shù)實(shí)現(xiàn)扁平化(便于后面全連接層輸入)
最后通過 self.fc1() 和 self.fc2() 定義的全連接層進(jìn)行最后的分類
訓(xùn)練模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64
LR = 0.001
net = LeNet().to(device)
# 損失函數(shù)使用交叉熵
criterion = nn.CrossEntropyLoss()
# 優(yōu)化函數(shù)使用 Adam 自適應(yīng)優(yōu)化算法
optimizer = optim.Adam(
net.parameters(),
lr=LR,
)
epoch = 1
if __name__ == '__main__':
for epoch in range(epoch):
sum_loss = 0.0
for i, data in enumerate(train_loader):
inputs, labels = data
inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
optimizer.zero_grad() #將梯度歸零
outputs = net(inputs) #將數(shù)據(jù)傳入網(wǎng)絡(luò)進(jìn)行前向運(yùn)算
loss = criterion(outputs, labels) #得到損失函數(shù)
loss.backward() #反向傳播
optimizer.step() #通過梯度做一步參數(shù)更新
# print(loss)
sum_loss += loss.item()
if i % 100 == 99:
print('[%d,%d] loss:%.03f' %
(epoch + 1, i + 1, sum_loss / 100))
sum_loss = 0.0
測試模型
net.eval() #將模型變換為測試模式
correct = 0
total = 0
for data_test in test_loader:
images, labels = data_test
images, labels = Variable(images).cuda(), Variable(labels).cuda()
output_test = net(images)
_, predicted = torch.max(output_test, 1)
total += labels.size(0)
correct += (predicted == labels).sum()
print("correct1: ", correct)
print("Test acc: {0}".format(correct.item() /
len(test_dataset)))
訓(xùn)練及測試的情況:

98% 以上的成功率,效果還不錯(cuò)。
以上就是本文的全部內(nèi)容,希望對(duì)大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
Python內(nèi)置random模塊生成隨機(jī)數(shù)的方法
這篇文章主要介紹了Python內(nèi)置random模塊生成隨機(jī)數(shù)的方法,本文給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-05-05
使用Python操作Elasticsearch數(shù)據(jù)索引的教程
這篇文章主要介紹了使用Python操作Elasticsearch數(shù)據(jù)索引的教程,Elasticsearch處理數(shù)據(jù)索引非常高效,要的朋友可以參考下2015-04-04
使用Jupyter notebooks上傳文件夾或大量數(shù)據(jù)到服務(wù)器
這篇文章主要介紹了使用Jupyter notebooks上傳文件夾或大量數(shù)據(jù)到服務(wù)器,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-04-04
在 Jupyter 中重新導(dǎo)入特定的 Python 文件(場景分析)
Jupyter 是數(shù)據(jù)分析領(lǐng)域非常有名的開發(fā)環(huán)境,使用 Jupyter 寫數(shù)據(jù)分析相關(guān)的代碼會(huì)大大節(jié)約開發(fā)時(shí)間。這篇文章主要介紹了在 Jupyter 中如何重新導(dǎo)入特定的 Python 文件,需要的朋友可以參考下2019-10-10
python判斷all函數(shù)輸出結(jié)果是否為true的方法
在本篇內(nèi)容里小編給各位整理的是一篇關(guān)于python判斷all函數(shù)輸出結(jié)果是否為true的方法,有需要的朋友們可以學(xué)習(xí)下。2020-12-12
pyspark創(chuàng)建DataFrame的幾種方法
為了便于操作,使用pyspark時(shí)我們通常將數(shù)據(jù)轉(zhuǎn)為DataFrame的形式來完成清洗和分析動(dòng)作。那么你知道pyspark創(chuàng)建DataFrame有幾種方法嗎,下面就一起來了解一下2021-05-05
Python制作數(shù)據(jù)導(dǎo)入導(dǎo)出工具
正好最近在學(xué)習(xí)python,于是打算用python實(shí)現(xiàn)了數(shù)據(jù)導(dǎo)入導(dǎo)出工具,由于是新手,所以寫的有些不完善的地方還請見諒2015-07-07

