pytorch實(shí)現(xiàn)圖像識別(實(shí)戰(zhàn))
1. 代碼講解
1.1 導(dǎo)庫
import os.path from os import listdir import numpy as np import pandas as pd from PIL import Image import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.nn import AdaptiveAvgPool2d from torch.utils.data.sampler import SubsetRandomSampler from torch.utils.data import Dataset import torchvision.transforms as transforms from sklearn.model_selection import train_test_split
1.2 標(biāo)準(zhǔn)化、transform、設(shè)置GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
normalize = transforms.Normalize(
? ?mean=[0.485, 0.456, 0.406],
? ?std=[0.229, 0.224, 0.225]
)
transform = transforms.Compose([transforms.ToTensor(), normalize]) ?# 轉(zhuǎn)換1.3 預(yù)處理數(shù)據(jù)
class DogDataset(Dataset):
# 定義變量
? ? def __init__(self, img_paths, img_labels, size_of_images): ?
? ? ? ? self.img_paths = img_paths
? ? ? ? self.img_labels = img_labels
? ? ? ? self.size_of_images = size_of_images
# 多少長圖片
? ? def __len__(self):
? ? ? ? return len(self.img_paths)
# 打開每組圖片并處理每張圖片
? ? def __getitem__(self, index):
? ? ? ? PIL_IMAGE = Image.open(self.img_paths[index]).resize(self.size_of_images)
? ? ? ? TENSOR_IMAGE = transform(PIL_IMAGE)
? ? ? ? label = self.img_labels[index]
? ? ? ? return TENSOR_IMAGE, label
print(len(listdir(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\train')))
print(len(pd.read_csv(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\labels.csv')))
print(len(listdir(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\test')))
train_paths = []
test_paths = []
labels = []
# 訓(xùn)練集圖片路徑
train_paths_lir = r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\train'
for path in listdir(train_paths_lir):
? ? train_paths.append(os.path.join(train_paths_lir, path)) ?
# 測試集圖片路徑
labels_data = pd.read_csv(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\labels.csv')
labels_data = pd.DataFrame(labels_data) ?
# 把字符標(biāo)簽離散化,因?yàn)閿?shù)據(jù)有120種狗,不離散化后面把數(shù)據(jù)給模型時會報錯:字符標(biāo)簽過多。把字符標(biāo)簽從0-119編號
size_mapping = {}
value = 0
size_mapping = dict(labels_data['breed'].value_counts())
for kay in size_mapping:
? ? size_mapping[kay] = value
? ? value += 1
# print(size_mapping)
labels = labels_data['breed'].map(size_mapping)
labels = list(labels)
# print(labels)
print(len(labels))
# 劃分訓(xùn)練集和測試集
X_train, X_test, y_train, y_test = train_test_split(train_paths, labels, test_size=0.2)
train_set = DogDataset(X_train, y_train, (32, 32))
test_set = DogDataset(X_test, y_test, (32, 32))
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64)1.4 建立模型
class LeNet(nn.Module): ? ? def __init__(self): ? ? ? ? super(LeNet, self).__init__() ? ? ? ? self.features = nn.Sequential( ? ? ? ? ? ? nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5), ? ? ? ? ? ? ? nn.ReLU(), ? ? ? ? ? ? nn.AvgPool2d(kernel_size=2, stride=2), ? ? ? ? ? ? nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5), ? ? ? ? ? ? nn.ReLU(), ? ? ? ? ? ? nn.AvgPool2d(kernel_size=2, stride=2) ? ? ? ? ) ? ? ? ? self.classifier = nn.Sequential( ? ? ? ? ? ? nn.Linear(16 * 5 * 5, 120), ? ? ? ? ? ? nn.ReLU(), ? ? ? ? ? ? nn.Linear(120, 84), ? ? ? ? ? ? nn.ReLU(), ? ? ? ? ? ? nn.Linear(84, 120) ? ? ? ? ) ? ? def forward(self, x): ? ? ? ? batch_size = x.shape[0] ? ? ? ? x = self.features(x) ? ? ? ? x = x.view(batch_size, -1) ? ? ? ? x = self.classifier(x) ? ? ? ? return x model = LeNet().to(device) criterion = nn.CrossEntropyLoss().to(device) optimizer = optim.Adam(model.parameters()) TRAIN_LOSS = [] ?# 損失 TRAIN_ACCURACY = [] ?# 準(zhǔn)確率
1.5 訓(xùn)練模型
def train(epoch):
? ? model.train()
? ? epoch_loss = 0.0 # 損失
? ? correct = 0 ?# 精確率
? ? for batch_index, (Data, Label) in enumerate(train_loader):
? ? # 扔到GPU中
? ? ? ? Data = Data.to(device)
? ? ? ? Label = Label.to(device)
? ? ? ? output_train = model(Data)
? ? # 計算損失
? ? ? ? loss_train = criterion(output_train, Label)
? ? ? ? epoch_loss = epoch_loss + loss_train.item()
? ? # 計算精確率
? ? ? ? pred = torch.max(output_train, 1)[1]
? ? ? ? train_correct = (pred == Label).sum()
? ? ? ? correct = correct + train_correct.item()
? ? # 梯度歸零、反向傳播、更新參數(shù)
? ? ? ? optimizer.zero_grad()
? ? ? ? loss_train.backward()
? ? ? ? optimizer.step()
? ? print('Epoch: ', epoch, 'Train_loss: ', epoch_loss / len(train_set), 'Train correct: ', correct / len(train_set))1.6 測試模型
和訓(xùn)練集差不多。
def test():
? ? model.eval()
? ? correct = 0.0
? ? test_loss = 0.0
? ? with torch.no_grad():
? ? ? ? for Data, Label in test_loader:
? ? ? ? ? ? Data = Data.to(device)
? ? ? ? ? ? Label = Label.to(device)
? ? ? ? ? ? test_output = model(Data)
? ? ? ? ? ? loss = criterion(test_output, Label)
? ? ? ? ? ? pred = torch.max(test_output, 1)[1]
? ? ? ? ? ? test_correct = (pred == Label).sum()
? ? ? ? ? ? correct = correct + test_correct.item()
? ? ? ? ? ? test_loss = test_loss + loss.item()
? ? print('Test_loss: ', test_loss / len(test_set), 'Test correct: ', correct / len(test_set))1.7結(jié)果
epoch = 10 for n_epoch in range(epoch): ? ? train(n_epoch) test()

到此這篇關(guān)于pytorch實(shí)現(xiàn)圖像識別(實(shí)戰(zhàn))的文章就介紹到這了,更多相關(guān)pytorch實(shí)現(xiàn)圖像識別內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python中property屬性的介紹及其應(yīng)用詳解
這篇文章主要介紹了python中property屬性的介紹及其應(yīng)用詳解,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2019-08-08
python機(jī)器學(xué)習(xí)Sklearn實(shí)戰(zhàn)adaboost算法示例詳解
這篇文章主要為大家介紹了python機(jī)器學(xué)習(xí)Sklearn實(shí)戰(zhàn)adaboost算法的示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步早日升職加薪2021-11-11
python?Pandas庫read_excel()參數(shù)實(shí)例詳解
人們經(jīng)常用pandas處理表格型數(shù)據(jù),時常需要讀入excel表格數(shù)據(jù),下面這篇文章主要給大家介紹了關(guān)于python?Pandas庫read_excel()參數(shù)的相關(guān)資料,文中通過實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下2022-07-07
淺析AST抽象語法樹及Python代碼實(shí)現(xiàn)
Abstract Syntax Tree抽象語法樹簡寫為ATS,是相當(dāng)于用樹結(jié)構(gòu)將代碼程式表現(xiàn)出來的一種數(shù)據(jù)結(jié)構(gòu),這里我們就來淺析AST抽象語法樹及Python代碼實(shí)現(xiàn)2016-06-06
Python實(shí)用秘技之快速優(yōu)化導(dǎo)包順序詳解
這篇文章主要來和大家分享一個Python中的實(shí)用秘技,那就是如何快速優(yōu)化導(dǎo)包順序,文中的示例代碼簡潔易懂,快跟隨小編一起學(xué)習(xí)起來吧2023-06-06
python中concurrent.futures的具體使用
concurrent.futures是Python標(biāo)準(zhǔn)庫的一部分,提供了ThreadPoolExecutor和ProcessPoolExecutor兩種執(zhí)行器,用于管理線程池和進(jìn)程池,通過這些執(zhí)行器,可以簡化多線程和多進(jìn)程任務(wù)的管理,提高程序執(zhí)行效率2024-09-09
Numpy實(shí)現(xiàn)矩陣運(yùn)算及線性代數(shù)應(yīng)用
這篇文章主要介紹了Numpy實(shí)現(xiàn)矩陣運(yùn)算及線性代數(shù)應(yīng)用,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-03-03

