超詳細(xì)PyTorch實(shí)現(xiàn)手寫數(shù)字識(shí)別器的示例代碼
前言
深度學(xué)習(xí)中有很多玩具數(shù)據(jù),mnist就是其中一個(gè),一個(gè)人能否入門深度學(xué)習(xí)往往就是以能否玩轉(zhuǎn)mnist數(shù)據(jù)來(lái)判斷的,在前面很多基礎(chǔ)介紹后我們就可以來(lái)實(shí)現(xiàn)一個(gè)簡(jiǎn)單的手寫數(shù)字識(shí)別的網(wǎng)絡(luò)了
數(shù)據(jù)的處理
我們使用pytorch自帶的包進(jìn)行數(shù)據(jù)的預(yù)處理
import torch import torchvision import torchvision.transforms as transforms import numpy as np import matplotlib.pyplot as plt transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5), (0.5)) ]) trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True,num_workers=2)
注釋:transforms.Normalize用于數(shù)據(jù)的標(biāo)準(zhǔn)化,具體實(shí)現(xiàn)
mean:均值 總和后除個(gè)數(shù)
std:方差 每個(gè)元素減去均值再平方再除個(gè)數(shù)
norm_data = (tensor - mean) / std
這里就直接將圖片標(biāo)準(zhǔn)化到了-1到1的范圍,標(biāo)準(zhǔn)化的原因就是因?yàn)槿绻硞€(gè)數(shù)在數(shù)據(jù)中很大很大,就導(dǎo)致其權(quán)重較大,從而影響到其他數(shù)據(jù),而本身我們的數(shù)據(jù)都是平等的,所以標(biāo)準(zhǔn)化后將數(shù)據(jù)分布到-1到1的范圍,使得所有數(shù)據(jù)都不會(huì)有太大的權(quán)重導(dǎo)致網(wǎng)絡(luò)出現(xiàn)巨大的波動(dòng)
trainloader現(xiàn)在是一個(gè)可迭代的對(duì)象,那么我們可以使用for循環(huán)進(jìn)行遍歷了,由于是使用yield返回的數(shù)據(jù),為了節(jié)約內(nèi)存
觀察一下數(shù)據(jù)
def imshow(img): img = img / 2 + 0.5 # unnormalize npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) plt.show() # torchvision.utils.make_grid 將圖片進(jìn)行拼接 imshow(torchvision.utils.make_grid(iter(trainloader).next()[0]))

構(gòu)建網(wǎng)絡(luò)
from torch import nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=28, kernel_size=5) # 14
self.pool = nn.MaxPool2d(kernel_size=2, stride=2) # 無(wú)參數(shù)學(xué)習(xí)因此無(wú)需設(shè)置兩個(gè)
self.conv2 = nn.Conv2d(in_channels=28, out_channels=28*2, kernel_size=5) # 7
self.fc1 = nn.Linear(in_features=28*2*4*4, out_features=1024)
self.fc2 = nn.Linear(in_features=1024, out_features=10)
def forward(self, inputs):
x = self.pool(F.relu(self.conv1(inputs)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(inputs.size()[0],-1)
x = F.relu(self.fc1(x))
return self.fc2(x)
下面是卷積的動(dòng)態(tài)演示

in_channels:為輸入通道數(shù) 彩色圖片有3個(gè)通道 黑白有1個(gè)通道
out_channels:輸出通道數(shù)
kernel_size:卷積核的大小
stride:卷積的步長(zhǎng)
padding:外邊距大小
輸出的size計(jì)算公式
- h = (h - kernel_size + 2*padding)/stride + 1
- w = (w - kernel_size + 2*padding)/stride + 1
MaxPool2d:是沒有參數(shù)進(jìn)行運(yùn)算的

實(shí)例化網(wǎng)絡(luò)優(yōu)化器,并且使用GPU進(jìn)行訓(xùn)練
net = Net()
opt = torch.optim.Adam(params=net.parameters(), lr=0.001)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)
Net( (conv1): Conv2d(1, 28, kernel_size=(5, 5), stride=(1, 1)) (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv2): Conv2d(28, 56, kernel_size=(5, 5), stride=(1, 1)) (fc1): Linear(in_features=896, out_features=1024, bias=True) (fc2): Linear(in_features=1024, out_features=10, bias=True) )
訓(xùn)練主要代碼
for epoch in range(50):
for images, labels in trainloader:
images = images.to(device)
labels = labels.to(device)
pre_label = net(images)
loss = F.cross_entropy(input=pre_label, target=labels).mean()
pre_label = torch.argmax(pre_label, dim=1)
acc = (pre_label==labels).sum()/torch.tensor(labels.size()[0], dtype=torch.float32)
net.zero_grad()
loss.backward()
opt.step()
print(acc.detach().cpu().numpy(), loss.detach().cpu().numpy())
F.cross_entropy交叉熵函數(shù)

源碼中已經(jīng)幫助我們實(shí)現(xiàn)了softmax因此不需要自己進(jìn)行softmax操作了
torch.argmax計(jì)算最大數(shù)所在索引值
acc = (pre_label==labels).sum()/torch.tensor(labels.size()[0], dtype=torch.float32) # pre_label==labels 相同維度進(jìn)行比較相同返回True不同的返回False,True為1 False為0, 即可獲取到相等的個(gè)數(shù),再除總個(gè)數(shù),就得到了Accuracy準(zhǔn)確度了
預(yù)測(cè)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform) testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=True,num_workers=2) images, labels = iter(testloader).next() images = images.to(device) labels = labels.to(device) with torch.no_grad(): pre_label = net(images) pre_label = torch.argmax(pre_label, dim=1) acc = (pre_label==labels).sum()/torch.tensor(labels.size()[0], dtype=torch.float32) print(acc)
總結(jié)
本節(jié)我們了解了標(biāo)準(zhǔn)化數(shù)據(jù)·、卷積的原理、簡(jiǎn)答的構(gòu)建了一個(gè)網(wǎng)絡(luò),并讓它去識(shí)別手寫體,也是對(duì)前面章節(jié)的總匯了
到此這篇關(guān)于超詳細(xì)PyTorch實(shí)現(xiàn)手寫數(shù)字識(shí)別器的示例代碼的文章就介紹到這了,更多相關(guān)PyTorch 手寫數(shù)字識(shí)別器內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python連接MySQL數(shù)據(jù)庫(kù)的四種方法
用?Python?連接到?MySQL?數(shù)據(jù)庫(kù)的方法不是很系統(tǒng),實(shí)際中有幾種不同的連接方法,而且不是所有的方法都能與不同的操作系統(tǒng)很好地配合,本文涵蓋了四種方法,你可以用它們來(lái)連接你的Python應(yīng)用程序和MySQL,需要的朋友可以參考下2024-08-08
Python使用asyncio標(biāo)準(zhǔn)庫(kù)對(duì)異步IO的支持
Python中,所有程序的執(zhí)行都是單線程的,但可同時(shí)執(zhí)行多個(gè)任務(wù),不同的任務(wù)被時(shí)間循環(huán)(Event Loop)控制及調(diào)度,Asyncio是Python并發(fā)編程的一種實(shí)現(xiàn)方式;是Python 3.4版本引入的標(biāo)準(zhǔn)庫(kù),直接內(nèi)置了對(duì)異步IO的支持2023-11-11
對(duì)numpy中的transpose和swapaxes函數(shù)詳解
今天小編就為大家分享一篇對(duì)numpy中的transpose和swapaxes函數(shù)詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來(lái)看看吧2018-08-08
Python機(jī)器學(xué)習(xí)應(yīng)用之基于LightGBM的分類預(yù)測(cè)篇解讀
這篇文章我們繼續(xù)學(xué)習(xí)一下GBDT模型的另一個(gè)進(jìn)化版本:LightGBM,LigthGBM是boosting集合模型中的新進(jìn)成員,由微軟提供,它和XGBoost一樣是對(duì)GBDT的高效實(shí)現(xiàn),原理上它和GBDT及XGBoost類似,都采用損失函數(shù)的負(fù)梯度作為當(dāng)前決策樹的殘差近似值,去擬合新的決策樹2022-01-01
Python面向?qū)ο蟪绦蛟O(shè)計(jì)構(gòu)造函數(shù)和析構(gòu)函數(shù)用法分析
這篇文章主要介紹了Python面向?qū)ο蟪绦蛟O(shè)計(jì)構(gòu)造函數(shù)和析構(gòu)函數(shù)用法,結(jié)合具體實(shí)例形式分析了Python面向?qū)ο蟪绦蛟O(shè)計(jì)中構(gòu)造函數(shù)與析構(gòu)函數(shù)的概念、原理、功能及相關(guān)使用技巧,需要的朋友可以參考下2019-04-04
解決python刪除文件的權(quán)限錯(cuò)誤問題
下面小編就為大家分享一篇解決python刪除文件的權(quán)限錯(cuò)誤問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來(lái)看看吧2018-04-04
wxPython:python首選的GUI庫(kù)實(shí)例分享
wxPython是Python語(yǔ)言的一套優(yōu)秀的GUI圖形庫(kù)。允許Python程序員很方便的創(chuàng)建完整的、功能鍵全的GUI用戶界面。 wxPython是作為優(yōu)秀的跨平臺(tái)GUI庫(kù)wxWidgets的Python封裝和Python模塊的方式提供給用戶的2019-10-10

