在pytorch中實現(xiàn)只讓指定變量向后傳播梯度
pytorch中如何只讓指定變量向后傳播梯度?
(或者說如何讓指定變量不參與后向傳播?)
有以下公式,假如要讓L對xvar求導(dǎo):

(1)中,L對xvar的求導(dǎo)將同時計算out1部分和out2部分;
(2)中,L對xvar的求導(dǎo)只計算out2部分,因為out1的requires_grad=False;
(3)中,L對xvar的求導(dǎo)只計算out1部分,因為out2的requires_grad=False;
驗證如下:
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed May 23 10:02:04 2018
@author: hy
"""
import torch
from torch.autograd import Variable
print("Pytorch version: {}".format(torch.__version__))
x=torch.Tensor([1])
xvar=Variable(x,requires_grad=True)
y1=torch.Tensor([2])
y2=torch.Tensor([7])
y1var=Variable(y1)
y2var=Variable(y2)
#(1)
print("For (1)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()
#(2)
print("For (2)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
out1 = out1.detach()
print("after out1.detach(), out1 requres_grad: {}".format(out1.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()
#(3)
print("For (3)")
print("xvar requres_grad: {}".format(xvar.requires_grad))
print("y1var requres_grad: {}".format(y1var.requires_grad))
print("y2var requres_grad: {}".format(y2var.requires_grad))
out1 = xvar*y1var
print("out1 requres_grad: {}".format(out1.requires_grad))
out2 = xvar*y2var
print("out2 requres_grad: {}".format(out2.requires_grad))
#out1 = out1.detach()
out2 = out2.detach()
print("after out2.detach(), out2 requres_grad: {}".format(out1.requires_grad))
L=torch.pow(out1-out2,2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()
pytorch中,將變量的requires_grad設(shè)為False,即可讓變量不參與梯度的后向傳播;
但是不能直接將out1.requires_grad=False;
其實,Variable類型提供了detach()方法,所返回變量的requires_grad為False。
注意:如果out1和out2的requires_grad都為False的話,那么xvar.grad就出錯了,因為梯度沒有傳到xvar
補充:
volatile=True表示這個變量不計算梯度, 參考:Volatile is recommended for purely inference mode, when you're sure you won't be even calling .backward(). It's more efficient than any other autograd setting - it will use the absolute minimal amount of memory to evaluate the model. volatile also determines that requires_grad is False.
以上這篇在pytorch中實現(xiàn)只讓指定變量向后傳播梯度就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
flask框架使用orm連接數(shù)據(jù)庫的方法示例
這篇文章主要介紹了flask框架使用orm連接數(shù)據(jù)庫的方法,結(jié)合實例形式分析了flask框架使用flask_sqlalchemy包進行mysql數(shù)據(jù)庫連接操作的具體步驟與相關(guān)實現(xiàn)技巧,需要的朋友可以參考下2018-07-07
Python面向?qū)ο笾鄳B(tài)原理與用法案例分析
這篇文章主要介紹了Python面向?qū)ο笾鄳B(tài)原理與用法,結(jié)合具體案例形式分析了Python多態(tài)的具體功能、原理、使用方法與操作注意事項,需要的朋友可以參考下2019-12-12
利用python在excel里面直接使用sql函數(shù)的方法
今天小編就為大家分享一篇利用python在excel里面直接使用sql函數(shù)的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-02-02
基于python內(nèi)置函數(shù)與匿名函數(shù)詳解
下面小編就為大家分享一篇基于python內(nèi)置函數(shù)與匿名函數(shù)詳解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-01-01
springboot配置文件抽離 git管理統(tǒng) 配置中心詳解
在本篇文章里小編給大家整理的是關(guān)于springboot配置文件抽離 git管理統(tǒng) 配置中心的相關(guān)知識點內(nèi)容,有需要的朋友們可以學(xué)習(xí)下。2019-09-09
Python數(shù)據(jù)分析之?Matplotlib?3D圖詳情
本文主要介紹了Python數(shù)據(jù)分析之Matplotlib 3D圖詳情,Matplotlib提供了mpl_toolkits.mplot3d工具包來進行3D圖表的繪制,下文總結(jié)了更多相關(guān)資料,需要的小伙伴可以參考一下2022-05-05
Python 工具類實現(xiàn)大文件斷點續(xù)傳功能詳解
用python進行大文件下載的時候,一旦出現(xiàn)網(wǎng)絡(luò)波動問題,導(dǎo)致文件下載到一半。如果將下載不完全的文件刪掉,那么又需要從頭開始,如果連續(xù)網(wǎng)絡(luò)波動,是不是要頭禿了。本文提供斷點續(xù)傳下載工具方法,希望可以幫助到你2021-10-10
Django項目uwsgi+Nginx保姆級部署教程實現(xiàn)
這篇文章主要介紹了Django項目uwsgi+Nginx保姆級部署教程實現(xiàn),文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-04-04

