pytorch 實現(xiàn)在測試的時候啟用dropout
我們知道,dropout一般都在訓(xùn)練的時候使用,那么測試的時候如何也開啟dropout呢?
在pytorch中,網(wǎng)絡(luò)有train和eval兩種模式,在train模式下,dropout和batch normalization會生效,而val模式下,dropout不生效,bn固定參數(shù)。
想要在測試的時候使用dropout,可以把dropout單獨設(shè)為train模式,這里可以使用apply函數(shù):
def apply_dropout(m):
if type(m) == nn.Dropout:
m.train()
下面是完整demo代碼:
# coding: utf-8
import torch
import torch.nn as nn
import numpy as np
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc = nn.Linear(8, 8)
self.dropout = nn.Dropout(0.5)
def forward(self, x):
x = self.fc(x)
x = self.dropout(x)
return x
net = SimpleNet()
x = torch.FloatTensor([1]*8)
net.train()
y = net(x)
print('train mode result: ', y)
net.eval()
y = net(x)
print('eval mode result: ', y)
net.eval()
y = net(x)
print('eval2 mode result: ', y)
def apply_dropout(m):
if type(m) == nn.Dropout:
m.train()
net.eval()
net.apply(apply_dropout)
y = net(x)
print('apply eval result:', y)
運行結(jié)果:

可以看到,在eval模式下,由于dropout未生效,每次跑的結(jié)果不同,利用apply函數(shù),將Dropout單獨設(shè)為train模式,dropout就生效了。
補充:Pytorch之dropout避免過擬合測試
一.做數(shù)據(jù)


二.搭建神經(jīng)網(wǎng)絡(luò)


三.訓(xùn)練

四.對比測試結(jié)果
注意:測試過程中,一定要注意模式切換


以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
- PyTorch使用Tricks:Dropout,R-Dropout和Multi-Sample?Dropout方式
- Pytorch?nn.Dropout的用法示例詳解
- Python深度學(xué)習(xí)pytorch神經(jīng)網(wǎng)絡(luò)Dropout應(yīng)用詳解解
- Pytorch之如何dropout避免過擬合
- PyTorch dropout設(shè)置訓(xùn)練和測試模式的實現(xiàn)
- pytorch Dropout過擬合的操作
- 淺談pytorch中的dropout的概率p
- PyTorch 實現(xiàn)L2正則化以及Dropout的操作
- pytorch 中nn.Dropout的使用說明
- pytorch中Dropout的具體用法
相關(guān)文章
使用Keras構(gòu)造簡單的CNN網(wǎng)絡(luò)實例
這篇文章主要介紹了使用Keras構(gòu)造簡單的CNN網(wǎng)絡(luò)實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-06-06
Python學(xué)習(xí)之12個常用基礎(chǔ)語法詳解
這篇文章主要為大家介紹了12個Python小案例,包含了日常開發(fā)中非常實用的語法,快來跟隨小編一起學(xué)習(xí)一下,看看自己都會多少個呢2022-02-02
PyTorch中torch.load()的用法和應(yīng)用
torch.load()它用于加載由torch.save()保存的模型或張量,本文主要介紹了PyTorch中torch.load()的用法和應(yīng)用,具有一定的參考價值,感興趣的可以了解一下2024-03-03
談一談數(shù)組拼接tf.concat()和np.concatenate()的區(qū)別
今天小編就為大家分享一篇談?wù)剶?shù)組拼接tf.concat()和np.concatenate()的區(qū)別,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-02-02

