關(guān)于pytorch中網(wǎng)絡(luò)loss傳播和參數(shù)更新的理解
相比于2018年,在ICLR2019提交論文中,提及不同框架的論文數(shù)量發(fā)生了極大變化,網(wǎng)友發(fā)現(xiàn),提及tensorflow的論文數(shù)量從2018年的228篇略微提升到了266篇,keras從42提升到56,但是pytorch的數(shù)量從87篇提升到了252篇。
TensorFlow: 228--->266
Keras: 42--->56
Pytorch: 87--->252
在使用pytorch中,自己有一些思考,如下:
1. loss計(jì)算和反向傳播
import torch.nn as nn criterion = nn.MSELoss().cuda() output = model(input) loss = criterion(output, target) loss.backward()
通過(guò)定義損失函數(shù):criterion,然后通過(guò)計(jì)算網(wǎng)絡(luò)真實(shí)輸出和真實(shí)標(biāo)簽之間的誤差,得到網(wǎng)絡(luò)的損失值:loss;
最后通過(guò)loss.backward()完成誤差的反向傳播,通過(guò)pytorch的內(nèi)在機(jī)制完成自動(dòng)求導(dǎo)得到每個(gè)參數(shù)的梯度。
需要注意,在機(jī)器學(xué)習(xí)或者深度學(xué)習(xí)中,我們需要通過(guò)修改參數(shù)使得損失函數(shù)最小化或最大化,一般是通過(guò)梯度進(jìn)行網(wǎng)絡(luò)模型的參數(shù)更新,通過(guò)loss的計(jì)算和誤差反向傳播,我們得到網(wǎng)絡(luò)中,每個(gè)參數(shù)的梯度值,后面我們?cè)偻ㄟ^(guò)優(yōu)化算法進(jìn)行網(wǎng)絡(luò)參數(shù)優(yōu)化更新。
2. 網(wǎng)絡(luò)參數(shù)更新
在更新網(wǎng)絡(luò)參數(shù)時(shí),我們需要選擇一種調(diào)整模型參數(shù)更新的策略,即優(yōu)化算法。
優(yōu)化算法中,簡(jiǎn)單的有一階優(yōu)化算法:

其中
就是通常說(shuō)的學(xué)習(xí)率,
是函數(shù)的梯度;
自己的理解是,對(duì)于復(fù)雜的優(yōu)化算法,基本原理也是這樣的,不過(guò)計(jì)算更加復(fù)雜。
在pytorch中,torch.optim是一個(gè)實(shí)現(xiàn)各種優(yōu)化算法的包,可以直接通過(guò)這個(gè)包進(jìn)行調(diào)用。
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
注意:
1)在前面部分1中,已經(jīng)通過(guò)loss的反向傳播得到了每個(gè)參數(shù)的梯度,然后再本部分通過(guò)定義優(yōu)化器(優(yōu)化算法),確定了網(wǎng)絡(luò)更新的方式,在上述代碼中,我們將模型的需要更新的參數(shù)傳入優(yōu)化器。
2)注意優(yōu)化器,即optimizer中,傳入的模型更新的參數(shù),對(duì)于網(wǎng)絡(luò)中有多個(gè)模型的網(wǎng)絡(luò),我們可以選擇需要更新的網(wǎng)絡(luò)參數(shù)進(jìn)行輸入即可,上述代碼,只會(huì)更新model中的模型參數(shù)。對(duì)于需要更新多個(gè)模型的參數(shù)的情況,可以參考以下代碼:
optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': gru.parameters()}], lr=0.01) 3) 在優(yōu)化前需要先將梯度歸零,即optimizer.zeros()。
3. loss計(jì)算和參數(shù)更新
import torch.nn as nn import torch criterion = nn.MSELoss().cuda() optimizer = torch.optim.Adam(model.parameters(), lr=0.01) output = model(input) loss = criterion(output, target) optimizer.zero_grad() # 將所有參數(shù)的梯度都置零 loss.backward() # 誤差反向傳播計(jì)算參數(shù)梯度 optimizer.step() # 通過(guò)梯度做一步參數(shù)更新
以上這篇關(guān)于pytorch中網(wǎng)絡(luò)loss傳播和參數(shù)更新的理解就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python的Flask框架及Nginx實(shí)現(xiàn)靜態(tài)文件訪問(wèn)限制功能
這篇文章主要介紹了Python的Flask框架及Nginx實(shí)現(xiàn)靜態(tài)文件訪問(wèn)限制功能,Nginx方面利用到了自帶的XSendfile,需要的朋友可以參考下2016-06-06
python圖形開(kāi)發(fā)GUI庫(kù)pyqt5的基本使用方法詳解
這篇文章主要介紹了python圖形開(kāi)發(fā)GUI庫(kù)pyqt5的基本使用方法詳解,需要的朋友可以參考下2020-02-02
Flask?+?MySQL如何實(shí)現(xiàn)用戶(hù)注冊(cè),登錄和登出的項(xiàng)目實(shí)踐
本文主要介紹了Flask?+?MySQL?如何實(shí)現(xiàn)用戶(hù)注冊(cè),登錄和登出的項(xiàng)目實(shí)踐,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2022-06-06
利用matplotlib實(shí)現(xiàn)兩張子圖分別畫(huà)函數(shù)圖
這篇文章主要介紹了利用matplotlib實(shí)現(xiàn)兩張子圖分別畫(huà)函數(shù)圖問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-08-08
Python之使用adb shell命令啟動(dòng)應(yīng)用的方法詳解
今天小編就為大家分享一篇Python之使用adb shell命令啟動(dòng)應(yīng)用的方法詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-01-01
Python?3.12安裝庫(kù)報(bào)錯(cuò)解決方案
這篇文章主要介紹了Python?3.12安裝庫(kù)報(bào)錯(cuò)的解決方案,講解了Python?3.12移除pkgutil.ImpImporter支持導(dǎo)致的AttributeError錯(cuò)誤,并提供了兩種解決方案,需要的朋友可以參考下2025-03-03
在vscode中配置python環(huán)境過(guò)程解析
這篇文章主要介紹了在vscode中配置python環(huán)境過(guò)程解析,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-09-09

