PyTorch中model.eval()使用與作用小結(jié)
一、??model.train()與model.eval()是什么?
使用 PyTorch 進(jìn)行深度學(xué)習(xí)訓(xùn)練時(shí),我們經(jīng)常會(huì)看到如下的代碼片段:
model.train() # 訓(xùn)練階段... model.eval() # 驗(yàn)證或測(cè)試階段...
很多初學(xué)者第一次看到時(shí)都會(huì)問(wèn):
“為什么要在測(cè)試前加一句 model.eval()?
不加行不行?到底起了什么作用?”
eval,英文意即為評(píng)估

在 PyTorch 中,每個(gè)神經(jīng)網(wǎng)絡(luò)模型都是一個(gè) nn.Module 的子類。
而 nn.Module 中有兩個(gè)非常重要的模式:
| 模式 | 含義 | 常用于 |
|---|---|---|
| model.train() | 開(kāi)啟訓(xùn)練模式(默認(rèn)) | 模型訓(xùn)練階段 |
| model.eval() | 開(kāi)啟評(píng)估模式 | 驗(yàn)證、測(cè)試階段 |
?? 它們的區(qū)別不在于是否計(jì)算梯度,
而在于模型內(nèi)部某些層(如 Dropout、BatchNorm)的行為發(fā)生變化。
二、為什么需要model.eval()
神經(jīng)網(wǎng)絡(luò)中有些層在“訓(xùn)練”和“推理”階段需要不同的行為,例如:
Dropout 層
- 在訓(xùn)練時(shí),會(huì)隨機(jī)“丟棄”一部分神經(jīng)元(防止過(guò)擬合);
- 在測(cè)試時(shí),則應(yīng)該關(guān)閉 Dropout,讓所有神經(jīng)元都參與計(jì)算。
如果你不調(diào)用 model.eval(),
那在測(cè)試階段 Dropout 仍然會(huì)隨機(jī)丟棄神經(jīng)元,導(dǎo)致結(jié)果不穩(wěn)定、性能下降。
Batch Normalization 層(BN層)
- 在訓(xùn)練時(shí),BatchNorm 會(huì)根據(jù)當(dāng)前 mini-batch 的均值和方差進(jìn)行標(biāo)準(zhǔn)化;
- 在測(cè)試時(shí),應(yīng)該使用在訓(xùn)練中統(tǒng)計(jì)到的“全局均值和方差”來(lái)規(guī)范化。
如果不切換到 eval 模式,
BN 層會(huì)繼續(xù)更新統(tǒng)計(jì)信息,導(dǎo)致推理結(jié)果偏差甚至錯(cuò)誤。
? 結(jié)論:
model.eval() 的核心作用是讓模型中某些層(Dropout、BatchNorm)進(jìn)入“推理模式”。
三、model.eval()與torch.no_grad()的區(qū)別
這兩個(gè)經(jīng)常一起出現(xiàn),很多人容易混淆:
| 功能 | 是否影響 Dropout/BN | 是否停止計(jì)算梯度 | 使用場(chǎng)景 |
|---|---|---|---|
| model.eval() | ? 是 | ? 否 | 切換模型狀態(tài)(推理模式) |
| torch.no_grad() | ? 否 | ? 是 | 禁止梯度計(jì)算,加快推理速度、節(jié)省顯存 |
因此,推理時(shí)我們通常會(huì)這樣寫??:
model.eval() # 切換為推理模式
with torch.no_grad(): # 不計(jì)算梯度
outputs = model(inputs)
四、完整示例:對(duì)比train()和eval()
讓我們用一個(gè)小例子直觀看看區(qū)別 ??
import torch
import torch.nn as nn
# 一個(gè)簡(jiǎn)單的網(wǎng)絡(luò),包含 Dropout
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(4, 4)
self.dropout = nn.Dropout(p=0.5)
def forward(self, x):
return self.dropout(self.fc(x))
# 創(chuàng)建模型和輸入
x = torch.ones(4)
model = SimpleNet()
# 訓(xùn)練模式
model.train()
print("Train Mode Output:")
for _ in range(3):
print(model(x))
# 推理模式
model.eval()
print("\nEval Mode Output:")
for _ in range(3):
print(model(x))
輸出對(duì)比
Train Mode Output: tensor([-0.0000, -1.4387, 0.7793, 0.0000], grad_fn=<MulBackward0>) tensor([-0.0000, -1.4387, 0.0000, 0.0000], grad_fn=<MulBackward0>) tensor([-0.0000, -1.4387, 0.7793, 0.0000], grad_fn=<MulBackward0>) Eval Mode Output: tensor([-0.2442, -0.7194, 0.3897, 0.9389], grad_fn=<ViewBackward0>) tensor([-0.2442, -0.7194, 0.3897, 0.9389], grad_fn=<ViewBackward0>) tensor([-0.2442, -0.7194, 0.3897, 0.9389], grad_fn=<ViewBackward0>)
? 說(shuō)明:
- 訓(xùn)練模式下 Dropout 隨機(jī)屏蔽神經(jīng)元,因此每次輸出不同;
- 推理模式下 Dropout 被關(guān)閉,輸出穩(wěn)定。
五、與model.train()的區(qū)別總結(jié)
| 比較項(xiàng) | model.train() | model.eval() |
|---|---|---|
| 模型狀態(tài) | 訓(xùn)練模式 | 推理模式 |
| Dropout | 啟用隨機(jī)丟棄 | 關(guān)閉 |
| BatchNorm | 使用批次統(tǒng)計(jì) | 使用全局統(tǒng)計(jì) |
| 是否影響梯度 | ? 不影響 | ? 不影響 |
| 常用場(chǎng)景 | 模型訓(xùn)練階段 | 驗(yàn)證、推理階段 |
六、完整實(shí)戰(zhàn)代碼(訓(xùn)練 + 驗(yàn)證)
import torch
import torch.nn as nn
import torch.optim as optim
# 定義簡(jiǎn)單模型
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(4, 10)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.5)
self.fc2 = nn.Linear(10, 3)
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.dropout(x)
return self.fc2(x)
model = Net()
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
for epoch in range(3):
# ===== 訓(xùn)練階段 =====
model.train()
optimizer.zero_grad()
x = torch.randn(5, 4)
y = torch.randint(0, 3, (5,))
out = model(x)
loss = criterion(out, y)
loss.backward()
optimizer.step()
# ===== 驗(yàn)證階段 =====
model.eval()
with torch.no_grad():
val_x = torch.randn(5, 4)
val_out = model(val_x)
val_pred = val_out.argmax(dim=1)
print(f"Epoch {epoch}: loss={loss.item():.4f}, val_pred={val_pred.tolist()}")
輸出如下:
Epoch 0: loss=1.0044, val_pred=[1, 1, 1, 2, 2] Epoch 1: loss=0.9953, val_pred=[2, 1, 2, 2, 2] Epoch 2: loss=1.2143, val_pred=[2, 2, 1, 2, 1]
? 訓(xùn)練時(shí):
- Dropout 啟用;
- BatchNorm 統(tǒng)計(jì)更新。
? 驗(yàn)證時(shí):
- Dropout 關(guān)閉;
- BatchNorm 使用訓(xùn)練統(tǒng)計(jì)參數(shù)。
七、常見(jiàn)錯(cuò)誤與避坑指南
| 錯(cuò)誤用法 | 后果 |
|---|---|
| 在測(cè)試時(shí)忘記 model.eval() | Dropout、BN 層仍隨機(jī),導(dǎo)致結(jié)果波動(dòng)、不穩(wěn)定 |
| 在推理時(shí)忘記 torch.no_grad() | 會(huì)記錄梯度,浪費(fèi)顯存、速度變慢 |
| 在訓(xùn)練時(shí)調(diào)用了 model.eval() | 模型學(xué)不動(dòng),BN 不更新統(tǒng)計(jì)信息 |
| 忘記在訓(xùn)練開(kāi)始前加 model.train() | 模型仍在推理模式,訓(xùn)練效果不佳 |
八、小結(jié)
| 項(xiàng)目 | 說(shuō)明 |
|---|---|
| 函數(shù)名 | model.eval() |
| 所屬模塊 | torch.nn.Module |
| 作用 | 切換模型到評(píng)估(推理)模式 |
| 影響層 | Dropout、BatchNorm |
| 與 no_grad 區(qū)別 | eval() 控制模式,no_grad 控制梯度 |
| 使用場(chǎng)景 | 驗(yàn)證、測(cè)試、推理階段 |
| 常用組合 | model.eval() + with torch.no_grad(): |
到此這篇關(guān)于PyTorch中model.eval()使用與作用小結(jié)的文章就介紹到這了,更多相關(guān)PyTorch model.eval()使用內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python中的copy()函數(shù)詳解(list,array)
這篇文章主要介紹了Python中的copy()函數(shù)詳解(list,array),具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-09-09
Python的numpy庫(kù)中將矩陣轉(zhuǎn)換為列表等函數(shù)的方法
下面小編就為大家分享一篇Python的numpy庫(kù)中將矩陣轉(zhuǎn)換為列表等函數(shù)的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-04-04
使用python實(shí)現(xiàn)群發(fā)微信消息的工具
如果您想批量向微信好友發(fā)送相同的內(nèi)容,手動(dòng)一個(gè)個(gè)操作非常費(fèi)時(shí)費(fèi)力,這時(shí)候可以用Python實(shí)現(xiàn)自動(dòng)化處理,更加高效方便,下面小編就來(lái)和大家講講具體操作吧2025-05-05
Tensorflow中k.gradients()和tf.stop_gradient()用法說(shuō)明
這篇文章主要介紹了Tensorflow中k.gradients()和tf.stop_gradient()用法說(shuō)明,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-06-06
python同義詞替換的實(shí)現(xiàn)(jieba分詞)
這篇文章主要介紹了python同義詞替換的實(shí)現(xiàn)(jieba分詞),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-01-01
基于進(jìn)程內(nèi)通訊的python聊天室實(shí)現(xiàn)方法
這篇文章主要介紹了基于進(jìn)程內(nèi)通訊的python聊天室實(shí)現(xiàn)方法,實(shí)例分析了Python聊天室的相關(guān)實(shí)現(xiàn)技巧,需要的朋友可以參考下2015-06-06
pytorch 使用單個(gè)GPU與多個(gè)GPU進(jìn)行訓(xùn)練與測(cè)試的方法
今天小編就為大家分享一篇pytorch 使用單個(gè)GPU與多個(gè)GPU進(jìn)行訓(xùn)練與測(cè)試的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-08-08

