python中的Pytorch建模流程匯總
本節(jié)內(nèi)容學(xué)習(xí)幫助大家梳理神經(jīng)網(wǎng)絡(luò)訓(xùn)練的架構(gòu)。
一般我們訓(xùn)練神經(jīng)網(wǎng)絡(luò)有以下步驟:
- 導(dǎo)入庫(kù)
- 設(shè)置訓(xùn)練參數(shù)的初始值
- 導(dǎo)入數(shù)據(jù)集并制作數(shù)據(jù)集
- 定義神經(jīng)網(wǎng)絡(luò)架構(gòu)
- 定義訓(xùn)練流程
- 訓(xùn)練模型
推薦文章:
分享4款 Python 自動(dòng)數(shù)據(jù)分析神器
以下,我就將上述步驟使用代碼進(jìn)行注釋講解:
1 導(dǎo)入庫(kù)
import torch from torch import nn from torch.nn import functional as F from torch import optim from torch.utils.data import DataLoader, DataLoader import torchvision import torchvision.transforms as transforms
2 設(shè)置初始值
# 學(xué)習(xí)率 lr = 0.15 # 優(yōu)化算法參數(shù) gamma = 0.8 # 每次小批次訓(xùn)練個(gè)數(shù) bs = 128 # 整體數(shù)據(jù)循環(huán)次數(shù) epochs = 10
3 導(dǎo)入并制作數(shù)據(jù)集
本次我們使用FashionMNIST圖像數(shù)據(jù)集,每個(gè)圖像是一個(gè)28*28的像素?cái)?shù)組,共有10個(gè)衣物類別,比如連衣裙、運(yùn)動(dòng)鞋、包等。
注:初次運(yùn)行下載需要等待較長(zhǎng)時(shí)間。
# 導(dǎo)入數(shù)據(jù)集 mnist = torchvision.datasets.FashionMNIST( ? ? root = './Datastes' ? ? , train = True ? ? , download = True ? ? , transform = transforms.ToTensor()) ? ?? # 制作數(shù)據(jù)集 batchdata = DataLoader(mnist ? ? ? ? ? ? ? ? ? ? ? ?, batch_size = bs ? ? ? ? ? ? ? ? ? ? ? ?, shuffle = True ? ? ? ? ? ? ? ? ? ? ? ?, drop_last = False)
我們可以對(duì)數(shù)據(jù)進(jìn)行檢查:
for x, y in batchdata: ? ? print(x.shape) ? ? print(y.shape) ? ? break # torch.Size([128, 1, 28, 28]) # torch.Size([128])
可以看到一個(gè)batch中有128個(gè)樣本,每個(gè)樣本的維度是1*28*28。
之后我們確定模型的輸入維度與輸出維度:
# 輸入的維度 input_ = mnist.data[0].numel() # 784 # 輸出的維度 output_ = len(mnist.targets.unique()) # 10
4 定義神經(jīng)網(wǎng)絡(luò)架構(gòu)
先使用一個(gè)128個(gè)神經(jīng)元的全連接層,然后用relu激活函數(shù),再將其結(jié)果映射到標(biāo)簽的維度,并使用softmax進(jìn)行激活。
# 定義神經(jīng)網(wǎng)絡(luò)架構(gòu) class Model(nn.Module): ? ? def __init__(self, in_features, out_features): ? ? ? ? super().__init__() ? ? ? ? self.linear1 = nn.Linear(in_features, 128, bias = True) ? ? ? ? self.output = nn.Linear(128, out_features, bias = True) ? ?? ? ? def forward(self, x): ? ? ? ? x = x.view(-1, 28*28) ? ? ? ? sigma1 = torch.relu(self.linear1(x)) ? ? ? ? sigma2 = F.log_softmax(self.output(sigma1), dim = -1) ? ? ? ? return sigma2
5 定義訓(xùn)練流程
在實(shí)際應(yīng)用中,我們一般會(huì)將訓(xùn)練模型部分封裝成一個(gè)函數(shù),而這個(gè)函數(shù)可以繼續(xù)細(xì)分為以下幾步:
- 定義損失函數(shù)與優(yōu)化器
- 完成向前傳播
- 計(jì)算損失
- 反向傳播
- 梯度更新
- 梯度清零
在此六步核心操作的基礎(chǔ)上,我們通常還需要對(duì)模型的訓(xùn)練進(jìn)度、損失值與準(zhǔn)確度進(jìn)行監(jiān)視。
注釋代碼如下:
# 封裝訓(xùn)練模型的函數(shù)
def fit(net, batchdata, lr, gamma, epochs):
# 參數(shù):模型架構(gòu)、數(shù)據(jù)、學(xué)習(xí)率、優(yōu)化算法參數(shù)、遍歷數(shù)據(jù)次數(shù)
? ? # 5.1 定義損失函數(shù)
? ? criterion = nn.NLLLoss()
? ? # 5.1 定義優(yōu)化算法
? ? opt = optim.SGD(net.parameters(), lr = lr, momentum = gamma)
? ??
? ? # 監(jiān)視進(jìn)度:循環(huán)之前,一個(gè)樣本都沒(méi)有看過(guò)
? ? samples = 0
? ? # 監(jiān)視準(zhǔn)確度:循環(huán)之前,預(yù)測(cè)正確的個(gè)數(shù)為0
? ? corrects = 0
? ??
? ? # 全數(shù)據(jù)訓(xùn)練幾次
? ? for epoch in range(epochs):
? ? ? ? # 對(duì)每個(gè)batch進(jìn)行訓(xùn)練
? ? ? ? for batch_idx, (x, y) in enumerate(batchdata):
? ? ? ? ? ? # 保險(xiǎn)起見,將標(biāo)簽轉(zhuǎn)為1維,與樣本對(duì)齊
? ? ? ? ? ? y = y.view(x.shape[0])
? ? ? ? ? ??
? ? ? ? ? ? # 5.2 正向傳播
? ? ? ? ? ? sigma = net.forward(x)
? ? ? ? ? ? # 5.3 計(jì)算損失
? ? ? ? ? ? loss = criterion(sigma, y)
? ? ? ? ? ? # 5.4 反向傳播
? ? ? ? ? ? loss.backward()
? ? ? ? ? ? # 5.5 更新梯度
? ? ? ? ? ? opt.step()
? ? ? ? ? ? # 5.6 梯度清零
? ? ? ? ? ? opt.zero_grad()
? ? ? ? ? ??
? ? ? ? ? ? # 監(jiān)視進(jìn)度:每訓(xùn)練一個(gè)batch,模型見過(guò)的數(shù)據(jù)就會(huì)增加x.shape[0]
? ? ? ? ? ? samples += x.shape[0]
? ? ? ? ? ??
? ? ? ? ? ? # 求解準(zhǔn)確度:全部判斷正確的樣本量/已經(jīng)看過(guò)的總樣本量
? ? ? ? ? ? # 得到預(yù)測(cè)標(biāo)簽
? ? ? ? ? ? yhat = torch.max(sigma, -1)[1]
? ? ? ? ? ? # 將正確的加起來(lái)
? ? ? ? ? ? corrects += torch.sum(yhat == y)
? ? ? ? ? ??
? ? ? ? ? ? # 每200個(gè)batch和最后結(jié)束時(shí),打印模型的進(jìn)度
? ? ? ? ? ? if (batch_idx + 1) % 200 == 0 or batch_idx == (len(batchdata) - 1):
? ? ? ? ? ? ? ? # 監(jiān)督模型進(jìn)度
? ? ? ? ? ? ? ? print("Epoch{}:[{}/{} {: .0f}%], Loss:{:.6f}, Accuracy:{:.6f}".format(
? ? ? ? ? ? ? ? ? ? epoch + 1
? ? ? ? ? ? ? ? ? ? , samples
? ? ? ? ? ? ? ? ? ? , epochs*len(batchdata.dataset)
? ? ? ? ? ? ? ? ? ? , 100*samples/(epochs*len(batchdata.dataset))
? ? ? ? ? ? ? ? ? ? , loss.data.item()
? ? ? ? ? ? ? ? ? ? , float(100.0*corrects/samples)))6 訓(xùn)練模型
# 設(shè)置隨機(jī)種子 torch.manual_seed(51) # 實(shí)例化模型 net = Model(input_, output_) # 訓(xùn)練模型 fit(net, batchdata, lr, gamma, epochs) # Epoch1:[25600/600000 ?4%], Loss:0.524430, Accuracy:69.570312 # Epoch1:[51200/600000 ?9%], Loss:0.363422, Accuracy:74.984375 # ...... # Epoch10:[600000/600000 ?100%], Loss:0.284664, Accuracy:85.771835
現(xiàn)在我們已經(jīng)用Pytorch訓(xùn)練了最基礎(chǔ)的神經(jīng)網(wǎng)絡(luò),并且可以查看其訓(xùn)練成果。大家可以將代碼復(fù)制進(jìn)行運(yùn)行!
雖然沒(méi)有用到復(fù)雜的模型,但是我們?cè)诿看谓r(shí)的基本思想都是一致的
到此這篇關(guān)于python中的Pytorch建模流程匯總的文章就介紹到這了,更多相關(guān)Pytorch建模流程內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
詳解python如何在django中為用戶模型添加自定義權(quán)限
這篇文章主要介紹了python如何在django中為用戶模型添加自定義權(quán)限,小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2018-10-10
Python列表刪除元素del、pop()和remove()的區(qū)別小結(jié)
這篇文章主要給大家介紹了關(guān)于Python列表刪除元素del、pop()和remove()的區(qū)別,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家學(xué)習(xí)或者使用Python具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-09-09
python中Flask框架簡(jiǎn)單入門實(shí)例
這篇文章主要介紹了python中Flask框架簡(jiǎn)單入門實(shí)例,以一個(gè)hello程序簡(jiǎn)單分析了Flask框架的使用技巧,需要的朋友可以參考下2015-03-03
Python使用gluon/mxnet模塊實(shí)現(xiàn)的mnist手寫數(shù)字識(shí)別功能完整示例
這篇文章主要介紹了Python使用gluon/mxnet模塊實(shí)現(xiàn)的mnist手寫數(shù)字識(shí)別功能,結(jié)合完整實(shí)例形式分析了Python調(diào)用gluon/mxnet模塊識(shí)別手寫字的具體實(shí)現(xiàn)技巧,需要的朋友可以參考下2019-12-12
Django框架實(shí)現(xiàn)的分頁(yè)demo示例
這篇文章主要介紹了Django框架實(shí)現(xiàn)的分頁(yè)demo,結(jié)合實(shí)例形式分析了Django框架分頁(yè)的步驟、原理、相關(guān)操作技巧與注意事項(xiàng),需要的朋友可以參考下2019-05-05
Scrapy爬蟲框架集成selenium及全面詳細(xì)講解
這篇文章主要為大家介紹了Scrapy集成selenium,以及scarpy爬蟲框架全面講解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步早日升職加薪2022-04-04

