在pytorch中對非葉節(jié)點的變量計算梯度實例
在pytorch中一般只對葉節(jié)點進行梯度計算,也就是下圖中的d,e節(jié)點,而對非葉節(jié)點,也即是c,b節(jié)點則沒有顯式地去保留其中間計算過程中的梯度(因為一般來說只有葉節(jié)點才需要去更新),這樣可以節(jié)省很大部分的顯存,但是在調(diào)試過程中,有時候我們需要對中間變量梯度進行監(jiān)控,以確保網(wǎng)絡的有效性,這個時候我們需要打印出非葉節(jié)點的梯度,為了實現(xiàn)這個目的,我們可以通過兩種手段進行。

注冊hook函數(shù)
Tensor.register_hook[2] 可以注冊一個反向梯度傳導時的hook函數(shù),這個hook函數(shù)將會在每次計算 關于該張量
的時候 被調(diào)用,經(jīng)常用于調(diào)試的時候打印出非葉節(jié)點梯度。當然,通過這個手段,你也可以自定義某一層的梯度更新方法。[3] 具體到這里的打印非葉節(jié)點的梯度,代碼如:
def hook_y(grad): print(grad) x = Variable(torch.ones(2, 2), requires_grad=True) y = x + 2 z = y * y * 3 y.register_hook(hook_y) out = z.mean() out.backward()
輸出如:
tensor([[4.5000, 4.5000], [4.5000, 4.5000]])
retain_grad()
Tensor.retain_grad()顯式地保存非葉節(jié)點的梯度,當然代價就是會增加顯存的消耗,而用hook函數(shù)的方法則是在反向計算時直接打印,因此不會增加顯存消耗,但是使用起來retain_grad()要比hook函數(shù)方便一些。代碼如:
x = Variable(torch.ones(2, 2), requires_grad=True) y = x + 2 y.retain_grad() z = y * y * 3 out = z.mean() out.backward() print(y.grad)
輸出如:
tensor([[4.5000, 4.5000], [4.5000, 4.5000]])
以上這篇在pytorch中對非葉節(jié)點的變量計算梯度實例就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
python3 將階乘改成函數(shù)形式進行調(diào)用的操作
這篇文章主要介紹了python3 將階乘改成函數(shù)形式進行調(diào)用的操作,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2021-03-03
詳解Python操作RabbitMQ服務器消息隊列的遠程結果返回
RabbitMQ是一款基于MQ的服務器,Python可以通過Pika庫來進行程序操控,這里我們將來詳解Python操作RabbitMQ服務器消息隊列的遠程結果返回:2016-06-06

