Pytorch反向求導(dǎo)更新網(wǎng)絡(luò)參數(shù)的方法
方法一:手動計算變量的梯度,然后更新梯度
import torch from torch.autograd import Variable # 定義參數(shù) w1 = Variable(torch.FloatTensor([1,2,3]),requires_grad = True) # 定義輸出 d = torch.mean(w1) # 反向求導(dǎo) d.backward() # 定義學(xué)習(xí)率等參數(shù) lr = 0.001 # 手動更新參數(shù) w1.data.zero_() # BP求導(dǎo)更新參數(shù)之前,需先對導(dǎo)數(shù)置0 w1.data.sub_(lr*w1.grad.data)
一個網(wǎng)絡(luò)中通常有很多變量,如果按照上述的方法手動求導(dǎo),然后更新參數(shù),是很麻煩的,這個時候可以調(diào)用torch.optim
方法二:使用torch.optim
import torch from torch.autograd import Variable import torch.nn as nn import torch.optim as optim # 這里假設(shè)我們定義了一個網(wǎng)絡(luò),為net steps = 10000 # 定義一個optim對象 optimizer = optim.SGD(net.parameters(), lr = 0.01) # 在for循環(huán)中更新參數(shù) for i in range(steps): optimizer.zero_grad() # 對網(wǎng)絡(luò)中參數(shù)當(dāng)前的導(dǎo)數(shù)置0 output = net(input) # 網(wǎng)絡(luò)前向計算 loss = criterion(output, target) # 計算損失 loss.backward() # 得到模型中參數(shù)對當(dāng)前輸入的梯度 optimizer.step() # 更新參數(shù)
注意:torch.optim只用于參數(shù)更新和對參數(shù)的梯度置0,不能計算參數(shù)的梯度,在使用torch.optim進(jìn)行參數(shù)更新之前,需要寫前向與反向傳播求導(dǎo)的代碼
以上這篇Pytorch反向求導(dǎo)更新網(wǎng)絡(luò)參數(shù)的方法就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
- 使用pytorch進(jìn)行張量計算、自動求導(dǎo)和神經(jīng)網(wǎng)絡(luò)構(gòu)建功能
- pytorch如何定義新的自動求導(dǎo)函數(shù)
- 在?pytorch?中實(shí)現(xiàn)計算圖和自動求導(dǎo)
- Pytorch自動求導(dǎo)函數(shù)詳解流程以及與TensorFlow搭建網(wǎng)絡(luò)的對比
- 淺談Pytorch中的自動求導(dǎo)函數(shù)backward()所需參數(shù)的含義
- pytorch中的自定義反向傳播,求導(dǎo)實(shí)例
- 關(guān)于PyTorch 自動求導(dǎo)機(jī)制詳解
- 關(guān)于pytorch求導(dǎo)總結(jié)(torch.autograd)
相關(guān)文章
Python實(shí)現(xiàn)比較兩個列表(list)范圍
這篇文章主要介紹了Python實(shí)現(xiàn)比較兩個列表(list)范圍,本文根據(jù)一道題目實(shí)現(xiàn)解決代碼,本文分別給出題目和解答源碼,需要的朋友可以參考下2015-06-06
Python實(shí)現(xiàn)常見網(wǎng)絡(luò)通信的示例詳解
這篇文章主要為大家詳細(xì)介紹了Python實(shí)現(xiàn)常見網(wǎng)絡(luò)通信的相關(guān)方法,文中的示例代碼講解詳細(xì),感興趣的小伙伴就跟隨小編一起學(xué)習(xí)一下吧2025-04-04

