聊聊pytorch測(cè)試的時(shí)候?yàn)楹我由蟤odel.eval()
Do need to use model.eval() when I test?
Sure, Dropout works as a regularization for preventing overfitting during training.
It randomly zeros the elements of inputs in Dropout layer on forward call.
It should be disabled during testing since you may want to use full model (no element is masked)
使用PyTorch進(jìn)行訓(xùn)練和測(cè)試時(shí)一定注意要把實(shí)例化的model指定train/eval,eval()時(shí),框架會(huì)自動(dòng)把BN和DropOut固定住,不會(huì)取平均,而是用訓(xùn)練好的值,不然的話,一旦test的batch_size過(guò)小,很容易就會(huì)被BN層導(dǎo)致生成圖片顏色失真極大?。。。。?!
補(bǔ)充:pytorch中model eval和torch no grad()的區(qū)別
model.eval()和with torch.no_grad()的區(qū)別
在PyTorch中進(jìn)行validation時(shí),會(huì)使用model.eval()切換到測(cè)試模式,在該模式下,
主要用于通知dropout層和batchnorm層在train和val模式間切換
在train模式下,dropout網(wǎng)絡(luò)層會(huì)按照設(shè)定的參數(shù)p設(shè)置保留激活單元的概率(保留概率=p); batchnorm層會(huì)繼續(xù)計(jì)算數(shù)據(jù)的mean和var等參數(shù)并更新。
在val模式下,dropout層會(huì)讓所有的激活單元都通過(guò),而batchnorm層會(huì)停止計(jì)算和更新mean和var,直接使用在訓(xùn)練階段已經(jīng)學(xué)出的mean和var值。
該模式不會(huì)影響各層的gradient計(jì)算行為,即gradient計(jì)算和存儲(chǔ)與training模式一樣,只是不進(jìn)行反傳(backprobagation)
而with torch.no_grad()則主要是用于停止autograd模塊的工作,以起到加速和節(jié)省顯存的作用,具體行為就是停止gradient計(jì)算,從而節(jié)省了GPU算力和顯存,但是并不會(huì)影響dropout和batchnorm層的行為。
使用場(chǎng)景
如果不在意顯存大小和計(jì)算時(shí)間的話,僅僅使用model.eval()已足夠得到正確的validation的結(jié)果;而with torch.zero_grad()則是更進(jìn)一步加速和節(jié)省gpu空間(因?yàn)椴挥糜?jì)算和存儲(chǔ)gradient),從而可以更快計(jì)算,也可以跑更大的batch來(lái)測(cè)試。
補(bǔ)充:Pytorch的modle.train,model.eval,with torch.no_grad的個(gè)人理解
1. 最近在學(xué)習(xí)pytorch過(guò)程中遇到了幾個(gè)問(wèn)題
不理解為什么在訓(xùn)練和測(cè)試函數(shù)中model.eval(),和model.train()的區(qū)別,經(jīng)查閱后做如下整理
一般情況下,我們訓(xùn)練過(guò)程如下:
1、拿到數(shù)據(jù)后進(jìn)行訓(xùn)練,在訓(xùn)練過(guò)程中,使用
model.train():告訴我們的網(wǎng)絡(luò),這個(gè)階段是用來(lái)訓(xùn)練的,可以更新參數(shù)。
2、訓(xùn)練完成后進(jìn)行預(yù)測(cè),在預(yù)測(cè)過(guò)程中,使用
model.eval() : 告訴我們的網(wǎng)絡(luò),這個(gè)階段是用來(lái)測(cè)試的,于是模型的參數(shù)在該階段不進(jìn)行更新。
2. 但是為什么在eval()階段會(huì)使用with torch.no_grad()?
查閱相關(guān)資料:傳送門
with torch.no_grad - disables tracking of gradients in autograd.
model.eval() changes the forward() behaviour of the module it is called upon
eg, it disables dropout and has batch norm use the entire population statistics
總結(jié)一下就是說(shuō),在eval階段了,即使不更新,但是在模型中所使用的dropout或者batch norm也就失效了,直接都會(huì)進(jìn)行預(yù)測(cè),而使用no_grad則設(shè)置讓梯度Autograd設(shè)置為False(因?yàn)樵谟?xùn)練中我們默認(rèn)是True),這樣保證了反向過(guò)程為純粹的測(cè)試,而不變參數(shù)。
另外,參考文檔說(shuō)這樣避免每一個(gè)參數(shù)都要設(shè)置,解放了GPU底層的時(shí)間開(kāi)銷,在測(cè)試階段統(tǒng)一梯度設(shè)置為False
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python數(shù)據(jù)結(jié)構(gòu)集合的相關(guān)詳解
集合是Python中一種無(wú)序且元素唯一的數(shù)據(jù)結(jié)構(gòu),主要用于存儲(chǔ)不重復(fù)的元素,Python提供set類型表示集合,可通過(guò){}或set()創(chuàng)建,集合元素不可重復(fù)且無(wú)序,不支持索引訪問(wèn),但可迭代,集合可變,支持添加、刪除元素,集合操作包括并集、交集、差集等,可通過(guò)運(yùn)算符或方法執(zhí)行2024-09-09
基于python實(shí)現(xiàn)一個(gè)簡(jiǎn)單的瀏覽器引擎
瀏覽器引擎是用來(lái)處理、渲染和顯示網(wǎng)頁(yè)內(nèi)容的核心組件,其主要任務(wù)是將用戶輸入的URL所代表的網(wǎng)頁(yè)資源加載并呈現(xiàn)出來(lái),通常包括HTML、CSS、JavaScript以及各種多媒體內(nèi)容,本文給大家介紹了如何基于python實(shí)現(xiàn)一個(gè)簡(jiǎn)單的瀏覽器引擎,需要的朋友可以參考下2024-10-10
Python3基礎(chǔ)之基本數(shù)據(jù)類型概述
這篇文章主要介紹了Python3的基本數(shù)據(jù)類型,需要的朋友可以參考下2014-08-08
詳解如何使用pandas進(jìn)行時(shí)間序列數(shù)據(jù)的周期轉(zhuǎn)換
時(shí)間序列數(shù)據(jù)是數(shù)據(jù)分析中經(jīng)常遇到的類型,為了更多的挖掘出數(shù)據(jù)內(nèi)部的信息,我們常常依據(jù)原始數(shù)據(jù)中的時(shí)間周期,將其轉(zhuǎn)換成不同跨度的周期,下面以模擬的K線數(shù)據(jù)為例,演示如何使用pandas來(lái)進(jìn)行周期轉(zhuǎn)換,感興趣的朋友可以參考下2024-05-05

