PyTorch基于MNIST的手寫數(shù)字識別
1. 深度學(xué)習(xí)與PyTorch簡介
深度學(xué)習(xí)作為機(jī)器學(xué)習(xí)的重要分支,已在計算機(jī)視覺、自然語言處理等領(lǐng)域取得了顯著成果。PyTorch是由Facebook開源的深度學(xué)習(xí)框架,以其動態(tài)計算圖和直觀的API設(shè)計而廣受歡迎。本文以經(jīng)典的MNIST手寫數(shù)字?jǐn)?shù)據(jù)集為例,展示如何利用PyTorch框架構(gòu)建并訓(xùn)練深度學(xué)習(xí)模型。
2. 環(huán)境配置與數(shù)據(jù)準(zhǔn)備
2.1 環(huán)境檢查
首先檢查PyTorch及相關(guān)庫的版本,確保環(huán)境配置正確:
import torch import torchvision import torchaudio from torch import nn from torch.utils.data import DataLoader from torchvision import datasets from torchvision.transforms import ToTensor from matplotlib import pyplot as plt print(torch.__version__) print(torchaudio.__version__) print(torchvision.__version__)

2.2 數(shù)據(jù)加載與預(yù)處理
MNIST數(shù)據(jù)集包含60,000個訓(xùn)練樣本和10,000個測試樣本,每個樣本為28×28像素的灰度手寫數(shù)字圖像。
training_data = datasets.MNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
)
test_data = datasets.MNIST(
root="data",
train=False,
download=True,
transform=ToTensor(),
)
參數(shù):
root:數(shù)據(jù)存儲路徑train:是否為訓(xùn)練集download:是否自動下載transform:數(shù)據(jù)預(yù)處理轉(zhuǎn)換,ToTensor()將PIL圖像轉(zhuǎn)換為張量并歸一化到[0,1]
2.3 數(shù)據(jù)可視化
我們可以查看數(shù)據(jù)集的樣本分布:
print(len(training_data))
figure = plt.figure()
for i in range(9):
img, label = training_data[i + 59000]
figure.add_subplot(3, 3, i + 1)
plt.title(label)
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()


2.4 數(shù)據(jù)批量加載
使用DataLoader實(shí)現(xiàn)數(shù)據(jù)的批量加載和隨機(jī)打亂:
# 增加批次大小
train_dataloader = DataLoader(training_data, batch_size=128) # 增大batch size
test_dataloader = DataLoader(test_data, batch_size=128)
for X, y in test_dataloader:
print(f"Shape of X[N,C,H,W]:{X.shape}")
print(f"Shape of y:{y.shape} {y.dtype}")
break

3. 神經(jīng)網(wǎng)絡(luò)模型設(shè)計
3.1 設(shè)備選擇
根據(jù)可用硬件選擇計算設(shè)備:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

3.2 神經(jīng)網(wǎng)絡(luò)架構(gòu)
設(shè)計一個包含多個全連接層的深度神經(jīng)網(wǎng)絡(luò):
class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__()
self.a = 10
self.flatten = nn.Flatten()
原始架構(gòu)
self.hidden1 = nn.Linear(28 * 28, 128)
self.hidden2 = nn.Linear(128, 256)
self.out = nn.Linear(256, 10)
def forward(self, x):
# 原始前向傳播
x = self.flatten(x)
x = self.hidden1(x)
x = torch.sigmoid(x)
x = self.hidden2(x)
x = torch.sigmoid(x)
return x
3.3 模型實(shí)例化
model = NeuralNetwork().to(device) print(model)

4. 訓(xùn)練與評估流程
4.1 訓(xùn)練函數(shù)
def train(dataloader, model, loss_fn, optimizer):
model.train()
batch_size_num = 1
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model.forward(X)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_value = loss.item()
if batch_size_num % 100 == 0:
print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")
batch_size_num += 1
訓(xùn)練步驟:
model.train():設(shè)置為訓(xùn)練模式(啟用Dropout)- 前向傳播計算預(yù)測值
- 計算損失函數(shù)值
optimizer.zero_grad():清空梯度loss.backward():反向傳播計算梯度optimizer.step():更新模型參數(shù)
4.2 測試函數(shù)
def test(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model.forward(X)
test_loss = loss_fn(pred, y)
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
a = (pred.argmax(1) == y)
b = (pred.argmax(1) == y).type(torch.float)
test_loss /= num_batches
correct /= size
print(f"Test result:\n Accuracy:{(100 * correct):.2f}%, Avg loss: {test_loss}")
測試要點(diǎn):
model.eval():設(shè)置為評估模式(禁用Dropout)torch.no_grad():禁用梯度計算,節(jié)省內(nèi)存pred.argmax(1):獲取預(yù)測類別
5. 損失函數(shù)配置
loss_fn = nn.CrossEntropyLoss()
損失函數(shù)說明:
- 使用
CrossEntropyLoss,適用于多分類問題 - 結(jié)合了LogSoftmax和NLLLoss,直接輸出分類概率
6. 模型訓(xùn)練與評估
6.1 優(yōu)化器配置
# 原始優(yōu)化器 optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
6.2 單次訓(xùn)練與測試
train(train_dataloader, model, loss_fn, optimizer) test(train_dataloader, model, loss_fn)

6.3 多輪訓(xùn)練(可選)
epochs = 10
for t in range(epochs):
print(f"Epoch {t+1}\n----------------------")
train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)

7. 提高準(zhǔn)確率的優(yōu)化方式
- 層數(shù)增加:從2層隱藏層增加到3層,增強(qiáng)模型表達(dá)能力
- 神經(jīng)元增加:第一層從128個神經(jīng)元增加到512個
- 激活函數(shù):用ReLU替代sigmoid,緩解梯度消失問題
- 正則化:添加Dropout層(0.2丟棄率),防止過擬合
- 改進(jìn)優(yōu)化器:降低學(xué)習(xí)率
# 改進(jìn)架構(gòu)
self.hidden1 = nn.Linear(28 * 28, 512) # 增加神經(jīng)元
self.dropout1 = nn.Dropout(0.2) # 添加Dropout
self.hidden2 = nn.Linear(512, 256)
self.dropout2 = nn.Dropout(0.2) # 添加Dropout
self.hidden3 = nn.Linear(256, 128) # 增加一層
self.out = nn.Linear(128, 10)
# 改進(jìn)的前向傳播
x = self.flatten(x)
x = self.hidden1(x)
x = torch.relu(x) # 使用ReLU替代sigmoid
x = self.dropout1(x) # 訓(xùn)練時隨機(jī)丟棄
x = self.hidden2(x)
x = torch.relu(x) # 使用ReLU替代sigmoid
x = self.dropout2(x) # 訓(xùn)練時隨機(jī)丟棄
x = self.hidden3(x)
x = torch.relu(x)
x = self.out(x)
# 改進(jìn)優(yōu)化器 optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 降低學(xué)習(xí)率

到此這篇關(guān)于PyTorch基于MNIST的手寫數(shù)字識別的文章就介紹到這了,更多相關(guān)PyTorch MNIST手寫數(shù)字識別內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python報錯:TypeError:?‘xxx‘?object?is?not?subscriptable解決
這篇文章主要給大家介紹了關(guān)于Python報錯:TypeError:?‘xxx‘?object?is?not?subscriptable的解決辦法,TypeError是Python中的一種錯誤,表示操作或函數(shù)應(yīng)用于不合適類型的對象時發(fā)生,文中將解決辦法介紹的非常詳細(xì),需要的朋友可以參考下2024-08-08
Python實(shí)現(xiàn)OpenCV的安裝與使用示例
這篇文章主要介紹了Python實(shí)現(xiàn)OpenCV的安裝與使用,結(jié)合實(shí)例形式分析了Python中OpenCV的安裝及針對圖片的相關(guān)操作技巧,需要的朋友可以參考下2018-03-03
Python cookbook(數(shù)據(jù)結(jié)構(gòu)與算法)篩選及提取序列中元素的方法
這篇文章主要介紹了Python cookbook(數(shù)據(jù)結(jié)構(gòu)與算法)篩選及提取序列中元素的方法,涉及Python列表推導(dǎo)式、生成器表達(dá)式及filter()函數(shù)相關(guān)使用技巧,需要的朋友可以參考下2018-03-03
python中l(wèi)iteral_eval函數(shù)的使用小結(jié)
literal_eval是Python標(biāo)準(zhǔn)庫ast模塊中的一個安全函數(shù),用于將包含 Python字面量表達(dá)式的字符串安全地轉(zhuǎn)換為對應(yīng)的Python對象,下面就來介紹一下literal_eval函數(shù)的使用2025-08-08
Python實(shí)現(xiàn)對二維碼數(shù)據(jù)進(jìn)行壓縮
當(dāng)前二維碼的應(yīng)用越來越廣泛,包括疫情時期的健康碼也是應(yīng)用二維碼的典型案例。本文的目標(biāo)很明確,就是使用python,實(shí)現(xiàn)一張二維碼顯示更多信息,代碼簡單實(shí)用,感興趣的可以了解一下2023-02-02
Keras自動下載的數(shù)據(jù)集/模型存放位置介紹
這篇文章主要介紹了Keras自動下載的數(shù)據(jù)集/模型存放位置介紹,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-06-06
如何用Python檢查SQLite數(shù)據(jù)庫中表是否存在
Python查詢表中數(shù)據(jù)有多種方法,具體取決于你使用的數(shù)據(jù)庫類型和查詢工具,這篇文章主要介紹了如何用Python檢查SQLite數(shù)據(jù)庫中表是否存在的相關(guān)資料,文中通過代碼介紹的非常詳細(xì),需要的朋友可以參考下2025-11-11

