PyTorch CNN: recognizing handwritten digits with a self-built image dataset
Updated: 2018-05-20 17:03:26  Author: 瓦力冫
This article shows how to train a CNN in PyTorch to recognize handwritten digits from a self-built image dataset (images plus a text file of labels). The full code is as follows:
# library
# standard library
import os
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# torch.manual_seed(1) # reproducible
# Hyper Parameters
EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 50
LR = 0.001 # learning rate
root = "./mnist/raw/"
def default_loader(path):
    # return Image.open(path).convert('RGB')
    return Image.open(path)
class MyDataset(Dataset):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        fh = open(txt, 'r')
        imgs = []
        for line in fh:
            line = line.strip('\n')
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0], int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        fh.close()

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn)
        img = Image.fromarray(np.array(img), mode='L')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.imgs)
train_data = MyDataset(txt=root + 'train.txt', transform=torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_data = MyDataset(txt=root + 'test.txt', transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE)
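MyDataset expects each line of train.txt and test.txt to hold an image path and an integer label separated by whitespace, because __init__ does words = line.split() and reads words[0] as the path and words[1] as the label. A minimal illustration of the expected format (these file names are made up; use whatever paths your conversion script actually produced):
./mnist/raw/train/00000_5.png 5
./mnist/raw/train/00001_0.png 0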
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(          # input shape (1, 28, 28)
            nn.Conv2d(
                in_channels=1,               # input channels
                out_channels=16,             # n_filters
                kernel_size=5,               # filter size
                stride=1,                    # filter movement/step
                padding=2,                   # to keep the same width and height after Conv2d: padding = (kernel_size - 1) / 2 when stride = 1
            ),                               # output shape (16, 28, 28)
            nn.ReLU(),                       # activation
            nn.MaxPool2d(kernel_size=2),     # take the max value in each 2x2 area, output shape (16, 14, 14)
        )
        self.conv2 = nn.Sequential(          # input shape (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),      # output shape (32, 14, 14)
            nn.ReLU(),                       # activation
            nn.MaxPool2d(2),                 # output shape (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10) # fully connected layer, output 10 classes

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)            # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output, x                     # also return x for visualization
cnn = CNN()
print(cnn) # net architecture
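Before training, a quick sanity check can confirm that a dummy 28x28 input produces the shapes noted in the comments above; a minimal sketch:
dummy = Variable(torch.zeros(1, 1, 28, 28))  # one fake grayscale image with a batch dimension
logits, flat = cnn(dummy)
print(logits.size())                         # torch.Size([1, 10])
print(flat.size())                           # torch.Size([1, 1568]), i.e. 32 * 7 * 7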
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted
# training and testing
for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):   # gives batch data; x is normalized by ToTensor
        b_x = Variable(x)                          # batch x
        b_y = Variable(y)                          # batch y

        output = cnn(b_x)[0]                       # cnn output
        loss = loss_func(output, b_y)              # cross entropy loss
        optimizer.zero_grad()                      # clear gradients for this training step
        loss.backward()                            # backpropagation, compute gradients
        optimizer.step()                           # apply gradients

        if step % 50 == 0:
            cnn.eval()
            eval_loss = 0.
            eval_acc = 0.
            for i, (tx, ty) in enumerate(test_loader):
                t_x = Variable(tx)
                t_y = Variable(ty)
                output = cnn(t_x)[0]
                loss = loss_func(output, t_y)
                eval_loss += loss.item()           # loss.data[0] on PyTorch < 0.4
                pred = torch.max(output, 1)[1]
                num_correct = (pred == t_y).sum()
                eval_acc += float(num_correct.item())
            acc_rate = eval_acc / float(len(test_data))
            print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_data)), acc_rate))
            cnn.train()                            # switch back to training mode after evaluation
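After training, you will usually want to keep the weights and try the model on a single sample. A minimal sketch (the file name cnn_mnist.pkl is arbitrary):
torch.save(cnn.state_dict(), 'cnn_mnist.pkl')   # save only the learned parameters

cnn.eval()
img, label = test_data[0]                       # a (1, 28, 28) tensor and its integer label
logits = cnn(Variable(img.unsqueeze(0)))[0]     # add the batch dimension -> (1, 1, 28, 28)
pred = torch.max(logits, 1)[1]
print('predicted: %d, actual: %d' % (int(pred[0]), label))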
The images and the label txt files are produced in the previous article, “pytorch 把MNIST數(shù)據(jù)集轉(zhuǎn)換成圖片和txt” (converting the MNIST dataset into images and a txt file).
The result is as follows: [screenshot of the printed test loss and accuracy omitted]
That is all for this article; I hope it is helpful for your study.