pytorch 利用lstm做mnist手寫數(shù)字識別分類的實例
更新時間:2020年01月10日 10:43:23 作者:xckkcxxck
今天小編就為大家分享一篇pytorch 利用lstm做mnist手寫數(shù)字識別分類的實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
代碼如下,U我認為對于新手來說最重要的是學會rnn讀取數(shù)據(jù)的格式。
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
import sys
sys.path.append('..')
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
#定義數(shù)據(jù)
data_tf = tfs.Compose([
tfs.ToTensor(),
tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
#定義模型
class rnn_classify(nn.Module):
def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
super(rnn_classify, self).__init__()
self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用兩層lstm
self.classifier = nn.Linear(hidden_feature, num_class)#將最后一個的rnn使用全連接的到最后的輸出結果
def forward(self, x):
#x的大小為(batch,1,28,28),所以我們需要將其轉化為rnn的輸入格式(28,batch,28)
x = x.squeeze() #去掉(batch,1,28,28)中的1,變成(batch, 28,28)
x = x.permute(2, 0, 1)#將最后一維放到第一維,變成(batch,28,28)
out, _ = self.rnn(x) #使用默認的隱藏狀態(tài),得到的out是(28, batch, hidden_feature)
out = out[-1,:,:]#取序列中的最后一個,大小是(batch, hidden_feature)
out = self.classifier(out) #得到分類結果
return out
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
#定義訓練過程
def get_acc(output, label):
total = output.shape[0]
_, pred_label = output.max(1)
num_correct = (pred_label == label).sum().item()
return num_correct / total
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
if torch.cuda.is_available():
net = net.cuda()
prev_time = datetime.datetime.now()
for epoch in range(num_epochs):
train_loss = 0
train_acc = 0
net = net.train()
for im, label in train_data:
if torch.cuda.is_available():
im = Variable(im.cuda()) # (bs, 3, h, w)
label = Variable(label.cuda()) # (bs, h, w)
else:
im = Variable(im)
label = Variable(label)
# forward
output = net(im)
loss = criterion(output, label)
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item()
train_acc += get_acc(output, label)
cur_time = datetime.datetime.now()
h, remainder = divmod((cur_time - prev_time).seconds, 3600)
m, s = divmod(remainder, 60)
time_str = "Time %02d:%02d:%02d" % (h, m, s)
if valid_data is not None:
valid_loss = 0
valid_acc = 0
net = net.eval()
for im, label in valid_data:
if torch.cuda.is_available():
im = Variable(im.cuda())
label = Variable(label.cuda())
else:
im = Variable(im)
label = Variable(label)
output = net(im)
loss = criterion(output, label)
valid_loss += loss.item()
valid_acc += get_acc(output, label)
epoch_str = (
"Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
% (epoch, train_loss / len(train_data),
train_acc / len(train_data), valid_loss / len(valid_data),
valid_acc / len(valid_data)))
else:
epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
(epoch, train_loss / len(train_data),
train_acc / len(train_data)))
prev_time = cur_time
print(epoch_str + time_str)
train(net, train_data, test_data, 10, optimizer, criterion)
以上這篇pytorch 利用lstm做mnist手寫數(shù)字識別分類的實例就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
使用Python在Excel工作表中創(chuàng)建圖表的實現(xiàn)步驟
在現(xiàn)代企業(yè)中,數(shù)據(jù)驅動的決策變得越來越重要,Excel作為企業(yè)中最常用的數(shù)據(jù)分析工具,其強大的表格和圖表功能在日常工作中不可或缺,然而,當面對成百上千條數(shù)據(jù)或需要生成定期報告時,手動制作圖表不僅耗時,還容易出錯,所以本文介紹了如何實現(xiàn)Excel圖表自動化生成2025-12-12
使用Python 統(tǒng)計文件夾內所有pdf頁數(shù)的小工具
這篇文章主要介紹了Python 統(tǒng)計文件夾內所有pdf頁數(shù)的小工具,本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2021-03-03
使用Python的Twisted框架編寫非阻塞程序的代碼示例
Twisted是基于異步模式的開發(fā)框架,因而利用Twisted進行非阻塞編程自然也是必會的用法,下面我們就來一起看一下使用Python的Twisted框架編寫非阻塞程序的代碼示例:2016-05-05
torch.optim優(yōu)化算法理解之optim.Adam()解讀
這篇文章主要介紹了torch.optim優(yōu)化算法理解之optim.Adam()解讀,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2022-11-11

