Pytorch技法之繼承Subset類完成自定義數(shù)據(jù)拆分
我們?cè)?《torch.utils.data.DataLoader與迭代器轉(zhuǎn)換操作》 中介紹了如何使用Pytorch內(nèi)置的數(shù)據(jù)集進(jìn)行論文實(shí)驗(yàn),如 torchvision.datasets 。下面是加載內(nèi)置訓(xùn)練數(shù)據(jù)集的常見操作:
from torchvision.datasets import FashionMNIST from torchvision.transforms import Compose, ToTensor, Normalize RAW_DATA_PATH = './rawdata' transform = Compose( ? ? ? ? [ToTensor(), ? ? ? ? ?Normalize((0.1307,), (0.3081,)) ? ? ? ? ?] ? ? ) train_data = FashionMNIST( ? ? ? ? root=RAW_DATA_PATH, ? ? ? ? download=True, ? ? ? ? train=True, ? ? ? ? transform=transform ? ? )
這里的train_data 做為 dataset 對(duì)象,它擁有許多熟悉,我們可以通過以下方法獲取樣本數(shù)據(jù)的分類類別集合、樣本的特征維度、樣本的標(biāo)簽集合等信息。
classes = train_data.classes num_features = train_data.data[0].shape[0] train_labels = train_data.targets print(classes) print(num_features) print(train_labels)
輸出如下:
['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
28
tensor([9, 0, 0, ..., 3, 0, 5])
但是,我們常常會(huì)在訓(xùn)練集的基礎(chǔ)上拆分出驗(yàn)證集(或者只用部分?jǐn)?shù)據(jù)來(lái)進(jìn)行訓(xùn)練)。我們想到的第一個(gè)方法是使用 torch.utils.data.random_split 對(duì) dataset 進(jìn)行劃分,下面我們假設(shè)劃分10000個(gè)樣本做為訓(xùn)練集,其余樣本做為驗(yàn)證集:
from torch.utils.data import random_split k = 10000 train_data, valid_data = random_split(train_data, [k, len(train_data)-k])
注意我們?nèi)绻蛴?train_data 和 valid_data 的類型,可以看到顯示:
<class 'torch.utils.data.dataset.Subset'>
已經(jīng)不再是torchvision.datasets.mnist.FashionMNIST 對(duì)象,而是一個(gè)所謂的 Subset 對(duì)象!此時(shí) Subset 對(duì)象雖然仍然還存有 data 屬性,但是內(nèi)置的 target 和 classes 屬性已經(jīng)不復(fù)存在,
比如如果我們強(qiáng)行訪問 valid_data 的 target 屬性:
valid_target = valid_data.target
就會(huì)報(bào)如下錯(cuò)誤:
'Subset' object has no attribute 'target'
但如果我們?cè)诤罄m(xù)的代碼中常常會(huì)將拆分后的數(shù)據(jù)集也默認(rèn)為 dataset 對(duì)象,那么該如何做到代碼的一致性呢?
這里有一個(gè)trick,那就是以繼承 SubSet 類的方式的方式定義一個(gè)新的 CustomSubSet 類,使新類在保持 SubSet 類的基本屬性的基礎(chǔ)上,擁有和原本數(shù)據(jù)集類相似的屬性,如 targets 和 classes 等:
from torch.utils.data import Subset class CustomSubset(Subset): ? ? '''A custom subset class''' ? ? def __init__(self, dataset, indices): ? ? ? ? super().__init__(dataset, indices) ? ? ? ? self.targets = dataset.targets # 保留targets屬性 ? ? ? ? self.classes = dataset.classes # 保留classes屬性 ? ? def __getitem__(self, idx): #同時(shí)支持索引訪問操作 ? ? ? ? x, y = self.dataset[self.indices[idx]] ? ? ? ? ? ? ? return x, y? ? ? def __len__(self): # 同時(shí)支持取長(zhǎng)度操作 ? ? ? ? return len(self.indices)
然后就引出了第二種劃分方法,即通過初始化 CustomSubset 對(duì)象的方式直接對(duì)數(shù)據(jù)集進(jìn)行劃分(這里為了簡(jiǎn)化省略了shuffle的步驟):
import numpy as np from copy import deepcopy origin_data = deepcopy(train_data) train_data = CustomSubset(origin_data, np.arange(k)) valid_data = CustomSubset(origin_data, np.arange(k, len(origin_data))-k)
注意: CustomSubset 類的初始化方法的第二個(gè)參數(shù) indices 為樣本索引,我們可以通過 np.arange() 的方法來(lái)創(chuàng)建。
然后,我們?cè)僭L問 valid_data 對(duì)應(yīng)的 classes 和 targes 屬性:
print(valid_data.classes) print(valid_data.targets)
此時(shí),我們發(fā)現(xiàn)可以成功訪問這些屬性了:
['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] tensor([9, 0, 0, ?..., 3, 0, 5])
當(dāng)然, CustomSubset 的作用并不只是添加數(shù)據(jù)集的屬性,我們還可以自定義一些數(shù)據(jù)預(yù)處理操作。
我們將類的結(jié)構(gòu)修改如下:
class CustomSubset(Subset): ? ? '''A custom subset class with customizable data transformation''' ? ? def __init__(self, dataset, indices, subset_transform=None): ? ? ? ? super().__init__(dataset, indices) ? ? ? ? self.targets = dataset.targets ? ? ? ? self.classes = dataset.classes ? ? ? ? self.subset_transform = subset_transform ? ? def __getitem__(self, idx): ? ? ? ? x, y = self.dataset[self.indices[idx]] ? ? ? ?? ? ? ? ? if self.subset_transform: ? ? ? ? ? ? x = self.subset_transform(x) ? ? ?? ? ? ? ? return x, y ?? ? ?? ? ? def __len__(self):? ? ? ? ? return len(self.indices)
我們可以在使用樣本前設(shè)置好數(shù)據(jù)預(yù)處理算子:
from torchvision import transforms valid_data.subset_transform = transforms.Compose(\ ? ? [transforms.RandomRotation((180,180))])
這樣,我們?cè)傧裣铝羞@樣用索引訪問取出數(shù)據(jù)集樣本時(shí),就會(huì)自動(dòng)調(diào)用算子完成預(yù)處理操作:
print(valid_data[0])
打印結(jié)果縮略如下:
(tensor([[[-0.4242, -0.4242, -0.4242, ......-0.4242, -0.4242, -0.4242, -0.4242, -0.4242]]]), 9)
到此這篇關(guān)于Pytorch技法之繼承Subset類完成自定義數(shù)據(jù)拆分的文章就介紹到這了,更多相關(guān)繼承Subset類完成自定義數(shù)據(jù)拆分內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python?slack桌面自動(dòng)化開發(fā)工具
這篇文章主要為大家介紹了python?slack桌面自動(dòng)化開發(fā)工具使用示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-08-08
Python利用Diagrams繪制漂亮的系統(tǒng)架構(gòu)圖
Diagrams 是一個(gè)基于Python繪制云系統(tǒng)架構(gòu)的模塊,它能夠通過非常簡(jiǎn)單的描述就能可視化架構(gòu)。本文將利用它繪制漂亮的系統(tǒng)架構(gòu)圖,感興趣的可以了解一下2023-01-01
Python?獲取指定開頭指定結(jié)尾所夾中間內(nèi)容(推薦)
獲取文章中指定開頭、指定結(jié)尾中所夾的內(nèi)容。其中,開頭和結(jié)尾均有多種,但最多也就十幾種,所以代碼還是具有可行性的,今天小編給大家介紹通過Python?獲取指定開頭指定結(jié)尾所夾中間內(nèi)容,感興趣的朋友一起看看吧2023-02-02
python開發(fā)之for循環(huán)操作實(shí)例詳解
這篇文章主要介紹了python開發(fā)之for循環(huán)操作,以實(shí)例形式較為詳細(xì)的分析了Python中for循環(huán)的具體使用技巧,具有一定參考借鑒價(jià)值,需要的朋友可以參考下2015-11-11
tensorflow模型繼續(xù)訓(xùn)練 fineturn實(shí)例
今天小編就為大家分享一篇tensorflow模型繼續(xù)訓(xùn)練 fineturn實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來(lái)看看吧2020-01-01
Python面向?qū)ο罂偨Y(jié)及類與正則表達(dá)式詳解
Python中的類提供了面向?qū)ο缶幊痰乃谢竟δ埽侯惖睦^承機(jī)制允許多個(gè)基類,派生類可以覆蓋基類中的任何方法,方法中可以調(diào)用基類中的同名方法。這篇文章主要介紹了Python面向?qū)ο罂偨Y(jié)及類與正則表達(dá)式 ,需要的朋友可以參考下2019-04-04

