PyTorch中model.zero_grad()和optimizer.zero_grad()用法
廢話不多說,直接上代碼吧~
model.zero_grad()
optimizer.zero_grad()
首先,這兩種方式都是把模型中參數(shù)的梯度設為0
當optimizer = optim.Optimizer(net.parameters())時,二者等效,其中Optimizer可以是Adam、SGD等優(yōu)化器
def zero_grad(self): """Sets gradients of all model parameters to zero.""" for p in self.parameters(): if p.grad is not None: p.grad.data.zero_()
補充知識:Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解
引言
一般訓練神經網(wǎng)絡,總是逃不開optimizer.zero_grad之后是loss(后面有的時候還會寫forward,看你網(wǎng)絡怎么寫了)之后是是net.backward之后是optimizer.step的這個過程。
real_a, real_b = batch[0].to(device), batch[1].to(device) fake_b = net_g(real_a) optimizer_d.zero_grad() # 判別器對虛假數(shù)據(jù)進行訓練 fake_ab = torch.cat((real_a, fake_b), 1) pred_fake = net_d.forward(fake_ab.detach()) loss_d_fake = criterionGAN(pred_fake, False) # 判別器對真實數(shù)據(jù)進行訓練 real_ab = torch.cat((real_a, real_b), 1) pred_real = net_d.forward(real_ab) loss_d_real = criterionGAN(pred_real, True) # 判別器損失 loss_d = (loss_d_fake + loss_d_real) * 0.5 loss_d.backward() optimizer_d.step()
上面這是一段cGAN的判別器訓練過程。標題中所涉及到的這些方法,其實整個神經網(wǎng)絡的參數(shù)更新過程(特別是反向傳播),具體是怎么操作的,我們一起來探討一下。
參數(shù)更新和反向傳播

上圖為一個簡單的梯度下降示意圖。比如以SGD為例,是算一個batch計算一次梯度,然后進行一次梯度更新。這里梯度值就是對應偏導數(shù)的計算結果。顯然,我們進行下一次batch梯度計算的時候,前一個batch的梯度計算結果,沒有保留的必要了。所以在下一次梯度更新的時候,先使用optimizer.zero_grad把梯度信息設置為0。
我們使用loss來定義損失函數(shù),是要確定優(yōu)化的目標是什么,然后以目標為頭,才可以進行鏈式法則和反向傳播。
調用loss.backward方法時候,Pytorch的autograd就會自動沿著計算圖反向傳播,計算每一個葉子節(jié)點的梯度(如果某一個變量是由用戶創(chuàng)建的,則它為葉子節(jié)點)。使用該方法,可以計算鏈式法則求導之后計算的結果值。
optimizer.step用來更新參數(shù),就是圖片中下半部分的w和b的參數(shù)更新操作。
以上這篇PyTorch中model.zero_grad()和optimizer.zero_grad()用法就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
用Python寫飛機大戰(zhàn)游戲之pygame入門(4):獲取鼠標的位置及運動
這篇文章主要介紹了用Python寫飛機大戰(zhàn)游戲之pygame入門(4):獲取鼠標的位置及運動,需要的朋友可以參考下2015-11-11
ubuntu?20.04系統(tǒng)下如何切換gcc/g++/python的版本
這篇文章主要給大家介紹了關于ubuntu?20.04系統(tǒng)下如何切換gcc/g++/python版本的相關資料,文中通過代碼介紹的非常詳細,對大家學習或者使用ubuntu具有一定的參考借鑒價值,需要的朋友可以參考下2023-12-12
最好的Python DateTime 庫之 Pendulum 長篇解析
datetime 模塊是 Python 中最重要的內置模塊之一,它為實際編程問題提供許多開箱即用的解決方案,非常靈活和強大。例如,timedelta 是我最喜歡的工具之一2021-11-11

