PyTorch 如何檢查模型梯度是否可導(dǎo)
一、PyTorch 檢查模型梯度是否可導(dǎo)
當(dāng)我們構(gòu)建復(fù)雜網(wǎng)絡(luò)模型或在模型中加入復(fù)雜操作時,可能會需要驗證該模型或操作是否可導(dǎo),即模型是否能夠優(yōu)化,在PyTorch框架下,我們可以使用torch.autograd.gradcheck函數(shù)來實現(xiàn)這一功能。
首先看一下官方文檔中關(guān)于該函數(shù)的介紹:


可以看到官方文檔中介紹了該函數(shù)基于何種方法,以及其參數(shù)列表,下面給出幾個例子介紹其使用方法,注意:
Tensor需要是雙精度浮點型且設(shè)置requires_grad = True
第一個例子:檢查某一操作是否可導(dǎo)
from torch.autograd import gradcheck
import torch
import torch.nn as nn
inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)
輸出為:
Are the gradients correct: True
第二個例子:檢查某一網(wǎng)絡(luò)模型是否可導(dǎo)
from torch.autograd import gradcheck
import torch
import torch.nn as nn
# 定義神經(jīng)網(wǎng)絡(luò)模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.net = nn.Sequential(
nn.Linear(15, 30),
nn.ReLU(),
nn.Linear(30, 15),
nn.ReLU(),
nn.Linear(15, 1),
nn.Sigmoid()
)
def forward(self, x):
y = self.net(x)
return y
net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)
輸出為:
Are the gradients correct: True
二、Pytorch求導(dǎo)
1.標(biāo)量對矩陣求導(dǎo)

驗證:
>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]]) # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True) #4*3矩陣,注意,值必須要是float類型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b) # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad #df/dX = a.dot(b^T)
tensor([[ 2., 3., 4.],
[ 4., 6., 8.],
[ 6., 9., 12.],
[ 8., 12., 16.]])
>>>a.grad b.grad # a和b的requires_grad都為默認(默認為False),所以求導(dǎo)時,沒有梯度
(None, None)
>>>a.mm(b.view(1,-1)) # a.dot(b^T)
tensor([[ 2., 3., 4.],
[ 4., 6., 8.],
[ 6., 9., 12.],
[ 8., 12., 16.]])
2.矩陣對矩陣求導(dǎo)

驗證:
>>>A = torch.tensor([[1,2],[3,4.]]) #2*2矩陣
>>>X = torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True) # 2*3矩陣
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
[19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括號里要加上這句
>>>X.grad
tensor([[4., 4., 4.],
[6., 6., 6.]])
注意:
requires_grad為True的數(shù)組必須是float類型
進行backgrad的必須是標(biāo)量,如果是向量,必須在后面括號里加上torch.ones_like(X)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python繪圖系統(tǒng)之散點圖和條形圖的實現(xiàn)代碼
這篇文章主要為大家詳細介紹了如何使用Python繪制散點圖和條形圖,文中的示例代碼講解詳細,對我們的學(xué)習(xí)或工作有一定的幫助,感興趣的可以了解一下2023-08-08
用sqlalchemy構(gòu)建Django連接池的實例
今天小編就為大家分享一篇用sqlalchemy構(gòu)建Django連接池的實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-08-08
NoSql數(shù)據(jù)庫介紹及使用Python連接MongoDB
MongoDB是一個非常流行的NoSQL數(shù)據(jù)庫,常用于大規(guī)模數(shù)據(jù)存儲應(yīng)用,下面這篇文章主要給大家介紹了關(guān)于NoSql數(shù)據(jù)庫及使用Python連接MongoDB的相關(guān)資料,需要的朋友可以參考下2023-06-06
Python從ZabbixAPI獲取信息及實現(xiàn)Zabbix-API 監(jiān)控的方法
這篇文章主要介紹了Python從ZabbixAPI獲取信息及實現(xiàn)Zabbix-API 監(jiān)控的方法,需要的朋友可以參考下2018-09-09
Python基于stuck實現(xiàn)scoket文件傳輸
這篇文章主要介紹了Python基于stuck實現(xiàn)scoket文件傳輸,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2020-04-04

