Pytorch中的modle.train,model.eval,with torch.no_grad解讀
modle.train,model.eval,with torch.no_grad解讀
1. 最近在學(xué)習(xí)pytorch過程中遇到了幾個問題
不理解為什么在訓(xùn)練和測試函數(shù)中model.eval(),和model.train()的區(qū)別,經(jīng)查閱后做如下整理
一般情況下,我們訓(xùn)練過程如下:
拿到數(shù)據(jù)后進(jìn)行訓(xùn)練,在訓(xùn)練過程中,使用
model.train():告訴我們的網(wǎng)絡(luò),這個階段是用來訓(xùn)練的,可以更新參數(shù)。
訓(xùn)練完成后進(jìn)行預(yù)測,在預(yù)測過程中,使用
model.eval(): 告訴我們的網(wǎng)絡(luò),這個階段是用來測試的,于是模型的參數(shù)在該階段不進(jìn)行更新。
2. 但是為什么在eval()階段會使用with torch.no_grad()?
查閱相關(guān)資料:傳送門
with torch.no_grad - disables tracking of gradients in autograd.
model.eval() changes the forward() behaviour of the module it is called upon
eg, it disables dropout and has batch norm use the entire population statistics
總結(jié)一下就是說,在eval階段了,即使不更新,但是在模型中所使用的dropout或者batch norm也就失效了,直接都會進(jìn)行預(yù)測,而使用no_grad則設(shè)置讓梯度Autograd設(shè)置為False(因?yàn)樵谟?xùn)練中我們默認(rèn)是True),這樣保證了反向過程為純粹的測試,而不變參數(shù)。
另外,參考文檔說這樣避免每一個參數(shù)都要設(shè)置,解放了GPU底層的時間開銷,在測試階段統(tǒng)一梯度設(shè)置為False
model.eval()與torch.no_grad()的作用
model.eval()
經(jīng)常在模型推理代碼的前面, 都會添加model.eval(), 主要有3個作用:
- 1.不進(jìn)行dropout
- 2.不更新batchnorm的mean 和var 參數(shù)
- 3.不進(jìn)行梯度反向傳播, 但梯度仍然會計(jì)算
torch.no_grad()
torch.no_grad的一般使用方法是, 在代碼塊外面用with torch.no_grad()給包起來。 如下面這樣:
with torch.no_grad(): ?? ?# your code?
它的主要作用有2個:
- 1.不進(jìn)行梯度的計(jì)算(當(dāng)然也就沒辦法反向傳播了), 節(jié)約顯存和算力
- 2.dropout和batchnorn還是會正常更新
異同
從上面的介紹中可以非常明確的看出,它們的相同點(diǎn)是一般都用在推理階段, 但它們的作用是完全不同的, 也沒有重疊。 可以一起使用。
總結(jié)
以上為個人經(jīng)驗(yàn),希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
matplotlib實(shí)現(xiàn)區(qū)域顏色填充
這篇文章主要為大家詳細(xì)介紹了matplotlib實(shí)現(xiàn)區(qū)域顏色填充,具有一定的參考價值,感興趣的小伙伴們可以參考一下2019-03-03
Python ''takes exactly 1 argument (2 given)'' Python error
這篇文章主要介紹了Python 'takes exactly 1 argument (2 given)' Python error的相關(guān)資料,需要的朋友可以參考下2016-12-12
深入了解Python中字符串格式化工具f-strings的使用
從Python?3.6版本開始,引入了一種新的字符串格式化機(jī)制,即f-strings,它強(qiáng)大且易于使用的字符串格式化方式,本文就來聊聊他的具體使用,希望對大家有所幫助2023-05-05
Pyecharts 中Geo函數(shù)常用參數(shù)的用法說明
這篇文章主要介紹了Pyecharts 中Geo函數(shù)常用參數(shù)的用法說明,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2021-02-02

