基于Python實(shí)現(xiàn)的簡單數(shù)字識別程序
這里我們使用全連接神經(jīng)網(wǎng)絡(luò)(MLP) 實(shí)現(xiàn)的 MNIST 數(shù)字識別代碼,結(jié)構(gòu)更簡單,僅包含幾個(gè)線性層和激活函數(shù)。
簡易代碼
模型定義代碼,model.py
import torch.nn as nn
# 定義一個(gè)簡單的 CNN 模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.2)
def forward(self, x):
x = self.flatten(x) # [B, 1, 28, 28] -> [B, 784]
x = self.relu(self.fc1(x))
x = self.dropout(x)
x = self.relu(self.fc2(x))
x = self.dropout(x)
x = self.fc3(x) # 輸出層不加激活(CrossEntropyLoss 內(nèi)部含 softmax)
return x然后訓(xùn)練代碼,train.py
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from model import SimpleModel # ?? 從 model.py 導(dǎo)入
# 配置
batch_size = 64
learning_rate = 0.001
num_epochs = 10
model_save_path = 'mnist_mlp.pth'
# 數(shù)據(jù)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# 模型、損失、優(yōu)化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# 訓(xùn)練
print(f"Training on {device}...")
model.train()
for epoch in range(num_epochs):
total_loss = 0.0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}')
# 保存
torch.save(model.state_dict(), model_save_path)
print(f"? Model saved to {model_save_path}")訓(xùn)練
在訓(xùn)練之前我們需要安裝下python依賴
pip install torch torchvision
然后我們就可以開始訓(xùn)練模型啦!執(zhí)行命令python ./train.py,你會看到類似輸出
Training on cpu... Epoch [1/10], Loss: 0.3501 Epoch [2/10], Loss: 0.1702 Epoch [3/10], Loss: 0.1335 Epoch [4/10], Loss: 0.1141 Epoch [5/10], Loss: 0.1027 Epoch [6/10], Loss: 0.0915 Epoch [7/10], Loss: 0.0884 Epoch [8/10], Loss: 0.0801 Epoch [9/10], Loss: 0.0769 Epoch [10/10], Loss: 0.0715 ? Model saved to mnist_mlp.pth
目錄下會生成一個(gè)mnist_mlp.pth,mnist_mlp.pth 是一個(gè) PyTorch 模型權(quán)重保存文件,本質(zhì)上是一個(gè) 序列化后的字典(state_dict),存儲了神經(jīng)網(wǎng)絡(luò)中所有可學(xué)習(xí)參數(shù)(如權(quán)重和偏置)的數(shù)值。
測試模型
現(xiàn)在我們拿我們的模型去試試我們的數(shù)字圖片了~
predict.py
# predict.py
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import SimpleModel
import argparse
import os
def predict_image(image_path, model_path='mnist_mlp.pth', device='cpu'):
# 1. 加載模型
model = SimpleModel()
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval() # 推理模式
# 2. 圖像預(yù)處理(必須和訓(xùn)練時(shí)一致!)
transform = transforms.Compose([
transforms.Grayscale(num_output_channels=1), # 轉(zhuǎn)灰度
transforms.Resize((28, 28)), # 調(diào)整為 28x28
transforms.ToTensor(), # 轉(zhuǎn)為 Tensor [0,1]
transforms.Normalize((0.1307,), (0.3081,)) # 用 MNIST 的均值/標(biāo)準(zhǔn)差
])
# 3. 加載并預(yù)處理圖像
image = Image.open(image_path).convert('L') # 強(qiáng)制灰度(兼容 RGB 輸入)
input_tensor = transform(image) # shape: [1, 28, 28]
input_batch = input_tensor.unsqueeze(0) # 增加 batch 維度 → [1, 1, 28, 28]
# 4. 推理
with torch.no_grad():
output = model(input_batch)
probabilities = torch.softmax(output, dim=1)
predicted_class = torch.argmax(probabilities, dim=1).item()
confidence = probabilities[0][predicted_class].item()
return predicted_class, confidence
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Predict digit in an image using trained MLP')
parser.add_argument('image_path', type=str, help='Path to the input image (e.g., digit.png)')
args = parser.parse_args()
if not os.path.exists(args.image_path):
print(f"? Error: Image file '{args.image_path}' not found!")
exit(1)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
digit, conf = predict_image(args.image_path, device=device)
print(f"? Predicted digit: {digit}")
print(f"?? Confidence: {conf:.4f} ({conf*100:.2f}%)")我們可以python .\predict.py .\data\digit.png來看看預(yù)測的結(jié)果如何。
到此這篇關(guān)于基于Python實(shí)現(xiàn)的簡單數(shù)字識別程序的文章就介紹到這了,更多相關(guān)Python數(shù)字識別程序內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python包裝和授權(quán)學(xué)習(xí)教程
包裝是指對一個(gè)已經(jīng)存在的對象進(jìn)行系定義加工,實(shí)現(xiàn)授權(quán)是包裝的一個(gè)特性,下面這篇文章主要給大家介紹了關(guān)于python包裝和授權(quán)的相關(guān)資料,文中通過實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下2023-06-06
tensorflow之獲取tensor的shape作為max_pool的ksize實(shí)例
今天小編就為大家分享一篇tensorflow之獲取tensor的shape作為max_pool的ksize實(shí)例,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-01-01
Python實(shí)現(xiàn)根據(jù)Excel生成Model和數(shù)據(jù)導(dǎo)入腳本
最近遇到一個(gè)需求,有幾十個(gè)Excel,每個(gè)的字段都不一樣,然后都差不多是第一行是表頭,后面幾千上萬的數(shù)據(jù),需要把這些Excel中的數(shù)據(jù)全都加入某個(gè)已經(jīng)上線的Django項(xiàng)目。所以我造了個(gè)自動生成?Model和導(dǎo)入腳本的輪子,希望對大家有所幫助2022-11-11
使用Python和OpenCV實(shí)現(xiàn)實(shí)時(shí)文檔掃描與矯正系統(tǒng)
在日常工作和學(xué)習(xí)中,我們經(jīng)常需要將紙質(zhì)文檔數(shù)字化,手動拍攝文檔照片常常會出現(xiàn)角度傾斜、透?視變形等問題,影響后續(xù)使用,本文將介紹如何使用Python和OpenCV構(gòu)建一個(gè)實(shí)時(shí)文檔掃描與矯正系統(tǒng),能夠通過攝像頭自動檢測文檔邊緣并進(jìn)行透?視變換矯正,需要的朋友可以參考下2025-05-05
解決phantomjs截圖失敗,phantom.exit位置的問題
今天小編就為大家分享一篇解決phantomjs截圖失敗,phantom.exit位置的問題,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-05-05
自定義Django_rest_framework_jwt登陸錯誤返回的解決
這篇文章主要介紹了自定義Django_rest_framework_jwt登陸錯誤返回的解決,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-10-10
pandas dataframe rolling移動計(jì)算方式
在Pandas中,rolling()方法用于執(zhí)行移動窗口計(jì)算,常用于時(shí)間序列數(shù)據(jù)分析,例如,計(jì)算某商品的7天或1個(gè)月銷售總量,可以通過rolling()輕松實(shí)現(xiàn),該方法的關(guān)鍵參數(shù)包括window(窗口大?。?min_periods(最小計(jì)算周期)2024-09-09
使用Python實(shí)現(xiàn)一鍵往Word文檔的表格中填寫數(shù)據(jù)
在工作中,我們經(jīng)常遇到將Excel表中的部分信息填寫到Word文檔的對應(yīng)表格中,以生成報(bào)告,方便打印,所以本文小編就給大家介紹了如何使用Python實(shí)現(xiàn)一鍵往Word文檔的表格中填寫數(shù)據(jù),文中有詳細(xì)的代碼示例供大家參考,需要的朋友可以參考下2023-12-12

