Pytorch寫數(shù)字識別LeNet模型
LeNet網(wǎng)絡(luò)

LeNet網(wǎng)絡(luò)過卷積層時候保持分辨率不變,過池化層時候分辨率變小。實現(xiàn)如下
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import numpy as np
import tqdm as tqdm
class LeNet(nn.Module):
? ? def __init__(self) -> None:
? ? ? ? super().__init__()
? ? ? ? self.sequential = nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.AvgPool2d(kernel_size=2,stride=2),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.AvgPool2d(kernel_size=2,stride=2),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Flatten(),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Linear(16*25,120),nn.Sigmoid(),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Linear(120,84),nn.Sigmoid(),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Linear(84,10))
? ? ? ??
? ??
? ? def forward(self,x):
? ? ? ? return self.sequential(x)
class MLP(nn.Module):
? ? def __init__(self) -> None:
? ? ? ? super().__init__()
? ? ? ? self.sequential = nn.Sequential(nn.Flatten(),
? ? ? ? ? ? ? ? ? ? ? ? ? nn.Linear(28*28,120),nn.Sigmoid(),
? ? ? ? ? ? ? ? ? ? ? ? ? nn.Linear(120,84),nn.Sigmoid(),
? ? ? ? ? ? ? ? ? ? ? ? ? nn.Linear(84,10))
? ? ? ??
? ??
? ? def forward(self,x):
? ? ? ? return self.sequential(x)
epochs = 15
batch = 32
lr=0.9
loss = nn.CrossEntropyLoss()
model = LeNet()
optimizer = torch.optim.SGD(model.parameters(),lr)
device = torch.device('cuda')
root = r"./"
trans_compose ?= transforms.Compose([transforms.ToTensor(),
? ? ? ? ? ? ? ? ? ? ])
train_data = torchvision.datasets.MNIST(root,train=True,transform=trans_compose,download=True)
test_data = torchvision.datasets.MNIST(root,train=False,transform=trans_compose,download=True)
train_loader = DataLoader(train_data,batch_size=batch,shuffle=True)
test_loader = DataLoader(test_data,batch_size=batch,shuffle=False)
model.to(device)
loss.to(device)
# model.apply(init_weights)
for epoch in range(epochs):
? train_loss = 0
? test_loss = 0
? correct_train = 0
? correct_test = 0
? for index,(x,y) in enumerate(train_loader):
? ? x = x.to(device)
? ? y = y.to(device)
? ? predict = model(x)
? ? L = loss(predict,y)
? ? optimizer.zero_grad()
? ? L.backward()
? ? optimizer.step()
? ? train_loss = train_loss + L
? ? correct_train += (predict.argmax(dim=1)==y).sum()
? acc_train = correct_train/(batch*len(train_loader))
? with torch.no_grad():
? ? for index,(x,y) in enumerate(test_loader):
? ? ? [x,y] = [x.to(device),y.to(device)]
? ? ? predict = model(x)
? ? ? L1 = loss(predict,y)
? ? ? test_loss = test_loss + L1
? ? ? correct_test += (predict.argmax(dim=1)==y).sum()
? ? acc_test = correct_test/(batch*len(test_loader))
? print(f'epoch:{epoch},train_loss:{train_loss/batch},test_loss:{test_loss/batch},acc_train:{acc_train},acc_test:{acc_test}')訓(xùn)練結(jié)果
epoch:12,train_loss:2.235553741455078,test_loss:0.3947642743587494,acc_train:0.9879833459854126,acc_test:0.9851238131523132
epoch:13,train_loss:2.028963804244995,test_loss:0.3220392167568207,acc_train:0.9891499876976013,acc_test:0.9875199794769287
epoch:14,train_loss:1.8020273447036743,test_loss:0.34837451577186584,acc_train:0.9901833534240723,acc_test:0.98702073097229
泛化能力測試
找了一張圖片,將其分割成只含一個數(shù)字的圖片進(jìn)行測試

images_np = cv2.imread("/content/R-C.png",cv2.IMREAD_GRAYSCALE)
h,w = images_np.shape
images_np = np.array(255*torch.ones(h,w))-images_np#圖片反色
images = Image.fromarray(images_np)
plt.figure(1)
plt.imshow(images)
test_images = []
for i in range(10):
? for j in range(16):
? ? test_images.append(images_np[h//10*i:h//10+h//10*i,w//16*j:w//16*j+w//16])
sample = test_images[77]
sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device)
sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28))
predict = model(sample_tensor)
output = predict.argmax()
print(output)
plt.figure(2)
plt.imshow(np.array(sample_tensor.squeeze().to('cpu')))
此時預(yù)測結(jié)果為4,預(yù)測正確。從這段代碼中可以看到有一個反色的步驟,若不反色,結(jié)果會受到影響,如下圖所示,預(yù)測為0,錯誤。
模型用于輸入的圖片是單通道的黑白圖片,這里由于可視化出現(xiàn)了黃色,但實際上是黑白色,反色操作說明了數(shù)據(jù)的預(yù)處理十分的重要,很多數(shù)據(jù)如果是不清理過是無法直接用于推理的。

將所有用來泛化性測試的圖片進(jìn)行準(zhǔn)確率測試:
correct = 0
i = 0
cnt = 1
for sample in test_images:
? sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device)
? sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28))
? predict = model(sample_tensor)
? output = predict.argmax()
? if(output==i):
? ? correct+=1
? if(cnt%16==0):
? ? i+=1
? cnt+=1
acc_g = correct/len(test_images)
print(f'acc_g:{acc_g}')如果不反色,acc_g=0.15
acc_g:0.50625
到此這篇關(guān)于Pytorch寫數(shù)字識別LeNet模型的文章就介紹到這了,更多相關(guān)Pytorch寫數(shù)字識別LeNet模型內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python 基于Twisted框架的文件夾網(wǎng)絡(luò)傳輸源碼
這篇文章主要介紹了Python 基于Twisted框架的文件夾網(wǎng)絡(luò)傳輸源碼,需要的朋友可以參考下2016-08-08
Python日期時間模塊datetime詳解與Python 日期時間的比較,計算實例代碼
python中的datetime模塊提供了操作日期和時間功能,本文為大家講解了datetime模塊的使用方法及與其相關(guān)的日期比較,計算實例2018-09-09
使用python/pytorch讀取數(shù)據(jù)集的示例代碼
這篇文章主要為大家詳細(xì)介紹了使用python/pytorch讀取數(shù)據(jù)集的示例,文中的示例代碼講解詳細(xì),具有一定參考價值,感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下2023-12-12
python 時間 T 去掉 帶上ms 毫秒 時間格式的操作
這篇文章主要介紹了python 時間 T 去掉 帶上ms 毫秒 時間格式的操作,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2021-04-04
實踐Python的爬蟲框架Scrapy來抓取豆瓣電影TOP250
這篇文章主要介紹了實踐Python的爬蟲框架Scrapy來抓取豆瓣電影TOP250的過程,文中的環(huán)境基于Windows操作系統(tǒng),需要的朋友可以參考下2016-01-01
python?selenium在打開的瀏覽器中動態(tài)調(diào)整User?Agent
這篇文章主要介紹的是python?selenium在打開的瀏覽器中動態(tài)調(diào)整User?Agent,具體相關(guān)資料請需要的朋友參考下面文章詳細(xì)內(nèi)容,希望對你有所幫助2022-02-02

