詳解model.train()和model.eval()兩種模式的原理與用法
一、兩種模式
pytorch可以給我們提供兩種方式來切換訓(xùn)練和評(píng)估(推斷)的模式,分別是:model.train() 和 model.eval()。
一般用法是:在訓(xùn)練開始之前寫上 model.trian() ,在測試時(shí)寫上 model.eval() 。
二、功能
1. model.train()
在使用 pytorch 構(gòu)建神經(jīng)網(wǎng)絡(luò)的時(shí)候,訓(xùn)練過程中會(huì)在程序上方添加一句model.train(),作用是 啟用 batch normalization 和 dropout 。
如果模型中有BN層(Batch Normalization)和 Dropout ,需要在 訓(xùn)練時(shí) 添加 model.train()。
model.train() 是保證 BN 層能夠用到 每一批數(shù)據(jù) 的均值和方差。對(duì)于 Dropout,model.train() 是 隨機(jī)取一部分 網(wǎng)絡(luò)連接來訓(xùn)練更新參數(shù)。
2. model.eval()
model.eval()的作用是 不啟用 Batch Normalization 和 Dropout。
如果模型中有 BN 層(Batch Normalization)和 Dropout,在 測試時(shí) 添加 model.eval()。
model.eval() 是保證 BN 層能夠用 全部訓(xùn)練數(shù)據(jù) 的均值和方差,即測試過程中要保證 BN 層的均值和方差不變。對(duì)于 Dropout,model.eval() 是利用到了 所有 網(wǎng)絡(luò)連接,即不進(jìn)行隨機(jī)舍棄神經(jīng)元。
為什么測試時(shí)要用 model.eval() ?
訓(xùn)練完 train 樣本后,生成的模型 model 要用來測試樣本了。在 model(test) 之前,需要加上model.eval(),否則的話,有輸入數(shù)據(jù),即使不訓(xùn)練,它也會(huì)改變權(quán)值。這是 model 中含有 BN 層和 Dropout 所帶來的的性質(zhì)。
eval() 時(shí),pytorch 會(huì)自動(dòng)把 BN 和 DropOut 固定住,不會(huì)取平均,而是用訓(xùn)練好的值。
不然的話,一旦 test 的 batch_size 過小,很容易就會(huì)被 BN 層導(dǎo)致生成圖片顏色失真極大。
eval() 在非訓(xùn)練的時(shí)候是需要加的,沒有這句代碼,一些網(wǎng)絡(luò)層的值會(huì)發(fā)生變動(dòng),不會(huì)固定,你神經(jīng)網(wǎng)絡(luò)每一次生成的結(jié)果也是不固定的,生成質(zhì)量可能好也可能不好。
也就是說,測試過程中使用model.eval(),這時(shí)神經(jīng)網(wǎng)絡(luò)會(huì) 沿用 batch normalization 的值,而并 不使用 dropout。
3. 總結(jié)與對(duì)比
如果模型中有 BN 層(Batch Normalization)和 Dropout,需要在訓(xùn)練時(shí)添加 model.train(),在測試時(shí)添加 model.eval()。
其中 model.train() 是保證 BN 層用每一批數(shù)據(jù)的均值和方差,而 model.eval() 是保證 BN 用全部訓(xùn)練數(shù)據(jù)的均值和方差;
而對(duì)于 Dropout,model.train() 是隨機(jī)取一部分網(wǎng)絡(luò)連接來訓(xùn)練更新參數(shù),而 model.eval() 是利用到了所有網(wǎng)絡(luò)連接。
三、Dropout 簡介
dropout 常常用于抑制過擬合。
設(shè)置Dropout時(shí),torch.nn.Dropout(0.5),這里的 0.5 是指該層(layer)的神經(jīng)元在每次迭代訓(xùn)練時(shí)會(huì)隨機(jī)有 50% 的可能性被丟棄(失活),不參與訓(xùn)練。也就是將上一層數(shù)據(jù)減少一半傳播。
到此這篇關(guān)于詳解model.train()和model.eval()兩種模式的原理與用法的文章就介紹到這了,更多相關(guān)model.train()和model.eval()原理用法內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python實(shí)現(xiàn)單項(xiàng)鏈表的最全教程
單向鏈表也叫單鏈表,是鏈表中最簡單的一種形式,它的每個(gè)節(jié)點(diǎn)包含兩個(gè)域,一個(gè)信息域(元素域)和一個(gè)鏈接域,這個(gè)鏈接指向鏈表中的下一個(gè)節(jié)點(diǎn),而最后一個(gè)節(jié)點(diǎn)的鏈接域則指向一個(gè)空值,這篇文章主要介紹了Python實(shí)現(xiàn)單項(xiàng)鏈表,需要的朋友可以參考下2023-01-01
python環(huán)境配置方式(服務(wù)器+本地)
這篇文章詳細(xì)介紹了在服務(wù)器上安裝和配置Anaconda3、TensorFlow、PyTorch等深度學(xué)習(xí)環(huán)境的步驟,包括下載、初始化、創(chuàng)建環(huán)境、驗(yàn)證安裝以及解決一些常見問題2025-01-01
Python/MySQL實(shí)現(xiàn)Excel文件自動(dòng)處理數(shù)據(jù)功能
在沒有服務(wù)器存儲(chǔ)數(shù)據(jù),只有excel文件的情況下,如何利用SQL和python實(shí)現(xiàn)數(shù)據(jù)分析和數(shù)據(jù)自動(dòng)處理的功能?本文就來和大家聊聊解決辦法2023-02-02
解決python -m pip install --upgrade pip 升級(jí)不成功問題
這篇文章主要介紹了python -m pip install --upgrade pip 解決升級(jí)不成功問題,需要的朋友可以參考下2020-03-03
解決使用Spyder IDE時(shí)matplotlib繪圖的顯示問題
這篇文章主要介紹了解決使用Spyder IDE時(shí)matplotlib繪圖的顯示問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2021-04-04
python實(shí)現(xiàn)文本界面網(wǎng)絡(luò)聊天室
這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)文本界面網(wǎng)絡(luò)聊天室,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-12-12

