Pytorch 實現(xiàn)數(shù)據(jù)集自定義讀取
更新時間:2020年01月18日 17:20:27 作者:_寒潭雁影
今天小編就為大家分享一篇Pytorch 實現(xiàn)數(shù)據(jù)集自定義讀取,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
以讀取VOC2012語義分割數(shù)據(jù)集為例,具體見代碼注釋:
VocDataset.py
from PIL import Image
import torch
import torch.utils.data as data
import numpy as np
import os
import torchvision
import torchvision.transforms as transforms
import time
#VOC數(shù)據(jù)集分類對應(yīng)顏色標簽
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
[0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
[64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
[64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
[0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
[0, 64, 128]]
#顏色標簽空間轉(zhuǎn)到序號標簽空間,就他媽這里浪費巨量的時間,這里還他媽的有問題
def voc_label_indices(colormap, colormap2label):
"""Assign label indices for Pascal VOC2012 Dataset."""
idx = ((colormap[:, :, 2] * 256 + colormap[ :, :,1]) * 256+ colormap[:, :,0])
#out = np.empty(idx.shape, dtype = np.int64)
out = colormap2label[idx]
out=out.astype(np.int64)#數(shù)據(jù)類型轉(zhuǎn)換
end = time.time()
return out
class MyDataset(data.Dataset):#創(chuàng)建自定義的數(shù)據(jù)讀取類
def __init__(self, root, is_train, crop_size=(320,480)):
self.rgb_mean =(0.485, 0.456, 0.406)
self.rgb_std = (0.229, 0.224, 0.225)
self.root=root
self.crop_size=crop_size
images = []#創(chuàng)建空列表存文件名稱
txt_fname = '%s/ImageSets/Segmentation/%s' % (root, 'train.txt' if is_train else 'val.txt')
with open(txt_fname, 'r') as f:
self.images = f.read().split()
#數(shù)據(jù)名稱整理
self.files = []
for name in self.images:
img_file = os.path.join(self.root, "JPEGImages/%s.jpg" % name)
label_file = os.path.join(self.root, "SegmentationClass/%s.png" % name)
self.files.append({
"img": img_file,
"label": label_file,
"name": name
})
self.colormap2label = np.zeros(256**3)
#整個循環(huán)的意思就是將顏色標簽映射為單通道的數(shù)組索引
for i, cm in enumerate(VOC_COLORMAP):
self.colormap2label[(cm[2] * 256 + cm[1]) * 256 + cm[0]] = i
#按照索引讀取每個元素的具體內(nèi)容
def __getitem__(self, index):
datafiles = self.files[index]
name = datafiles["name"]
image = Image.open(datafiles["img"])
label = Image.open(datafiles["label"]).convert('RGB')#打開的是PNG格式的圖片要轉(zhuǎn)到rgb的格式下,不然結(jié)果會比較要命
#以圖像中心為中心截取固定大小圖像,小于固定大小的圖像則自動填0
imgCenterCrop = transforms.Compose([
transforms.CenterCrop(self.crop_size),
transforms.ToTensor(),
transforms.Normalize(self.rgb_mean, self.rgb_std),#圖像數(shù)據(jù)正則化
])
labelCenterCrop = transforms.CenterCrop(self.crop_size)
cropImage=imgCenterCrop(image)
croplabel=labelCenterCrop(label)
croplabel=torch.from_numpy(np.array(croplabel)).long()#把標簽數(shù)據(jù)類型轉(zhuǎn)為torch
#將顏色標簽圖轉(zhuǎn)為序號標簽圖
mylabel=voc_label_indices(croplabel, self.colormap2label)
return cropImage,mylabel
#返回圖像數(shù)據(jù)長度
def __len__(self):
return len(self.files)
Train.py
import matplotlib.pyplot as plt
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from VocDataset import MyDataset
#VOC數(shù)據(jù)集分類對應(yīng)顏色標簽
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
[0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
[64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
[64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
[0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
[0, 64, 128]]
root='../data/VOCdevkit/VOC2012'
train_data=MyDataset(root,True)
trainloader = data.DataLoader(train_data, 4)
#從數(shù)據(jù)集中拿出一個批次的數(shù)據(jù)
for i, data in enumerate(trainloader):
getimgs, labels= data
img = transforms.ToPILImage()(getimgs[0])
labels = labels.numpy()#tensor轉(zhuǎn)numpy
labels=labels[0]#獲得批次標簽集中的一張標簽圖像
labels = labels.transpose((1,0))#數(shù)組維度切換,將第1維換到第0維,第0維換到第1維
##將單通道索引標簽圖片映射回顏色標簽圖片
newIm= Image.new('RGB', (480, 320))#創(chuàng)建一張與標簽大小相同的圖片,用以顯示標簽所對應(yīng)的顏色
for i in range(0, 480):
for j in range(0, 320):
sele=labels[i][j]#取得坐標點對應(yīng)像素的值
newIm.putpixel((i, j), (int(VOC_COLORMAP[sele][0]), int(VOC_COLORMAP[sele][1]), int(VOC_COLORMAP[sele][2])))
#顯示圖像和標簽
plt.figure("image")
ax1 = plt.subplot(1,2,1)
ax2 = plt.subplot(1,2,2)
plt.sca(ax1)
plt.imshow(img)
plt.sca(ax2)
plt.imshow(newIm)
plt.show()
以上這篇Pytorch 實現(xiàn)數(shù)據(jù)集自定義讀取就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python深度學(xué)習albumentations數(shù)據(jù)增強庫
下面開始albumenations的正式介紹,在這里我強烈建議英語基礎(chǔ)還好的讀者去官方網(wǎng)站跟著教程一步步學(xué)習,而這里的內(nèi)容主要是我自己的一個總結(jié)以及方便英語能力較弱的讀者學(xué)習2021-09-09
如何在Python中將字符串轉(zhuǎn)換為數(shù)組詳解
最近在用Python,做一個小腳本,有個操作就是要把內(nèi)容換成數(shù)組對象再進行相關(guān)操作,下面這篇文章主要給大家介紹了關(guān)于如何在Python中將字符串轉(zhuǎn)換為數(shù)組的相關(guān)資料,需要的朋友可以參考下2022-12-12
使用python將csv數(shù)據(jù)導(dǎo)入mysql數(shù)據(jù)庫
這篇文章主要為大家詳細介紹了如何使用python將csv數(shù)據(jù)導(dǎo)入mysql數(shù)據(jù)庫,文中的示例代碼講解詳細,感興趣的小伙伴可以跟隨小編一起學(xué)習一下2024-05-05
在Django下測試與調(diào)試REST API的方法詳解
今天小編就為大家分享一篇在Django下測試與調(diào)試REST API的方法詳解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-08-08
Python Pandas學(xué)習之Pandas數(shù)據(jù)結(jié)構(gòu)詳解
Pandas中一共有三種數(shù)據(jù)結(jié)構(gòu),分別為:Series、DataFrame和MultiIndex(老版本中叫Panel )。其中Series是一維數(shù)據(jù)結(jié)構(gòu),DataFrame是二維的表格型數(shù)據(jù)結(jié)構(gòu),MultiIndex是三維的數(shù)據(jù)結(jié)構(gòu)。本文將詳細為大家講解這三個數(shù)據(jù)結(jié)構(gòu),需要的可以參考一下2022-02-02

