Pytorch中DataLoader的使用方法詳解
在Pytorch中,torch.utils.data中的Dataset與DataLoader是處理數(shù)據(jù)集的兩個函數(shù),用來處理加載數(shù)據(jù)集。通常情況下,使用的關(guān)鍵在于構(gòu)建dataset類。
一:dataset類構(gòu)建。
在構(gòu)建數(shù)據(jù)集類時,除了__init__(self),還要有__len__(self)與__getitem__(self,item)兩個方法,這三個是必不可少的,至于其它用于數(shù)據(jù)處理的函數(shù),可以任意定義。
class dataset:
def __init__(self,...):
...
def __len__(self,...):
return n
def __getitem__(self,item):
return data[item]正常情況下,該數(shù)據(jù)集是要繼承Pytorch中Dataset類的,但實際操作中,即使不繼承,數(shù)據(jù)集類構(gòu)建后仍可以用Dataloader()加載的。
在dataset類中,__len__(self)返回數(shù)據(jù)集中數(shù)據(jù)個數(shù),__getitem__(self,item)表示每次返回第item條數(shù)據(jù)。
二:DataLoader使用
在構(gòu)建dataset類后,即可使用DataLoader加載。DataLoader中常用參數(shù)如下:
1.dataset:需要載入的數(shù)據(jù)集,如前面構(gòu)造的dataset類。
2.batch_size:批大小,在神經(jīng)網(wǎng)絡(luò)訓(xùn)練時我們很少逐條數(shù)據(jù)訓(xùn)練,而是幾條數(shù)據(jù)作為一個batch進行訓(xùn)練。
3.shuffle:是否在打亂數(shù)據(jù)集樣本順序。True為打亂,F(xiàn)alse反之。
4.drop_last:是否舍去最后一個batch的數(shù)據(jù)(很多情況下數(shù)據(jù)總數(shù)N與batch size不整除,導(dǎo)致最后一個batch不為batch size)。True為舍去,F(xiàn)alse反之。
三:舉例
兔兔以指標為1,數(shù)據(jù)個數(shù)為100的數(shù)據(jù)為例。
import torch
from torch.utils.data import DataLoader
class dataset:
def __init__(self):
self.x=torch.randint(0,20,size=(100,1),dtype=torch.float32)
self.y=(torch.sin(self.x)+1)/2
def __len__(self):
return 100
def __getitem__(self, item):
return self.x[item],self.y[item]
data=DataLoader(dataset(),batch_size=10,shuffle=True)
for batch in data:
print(batch)當然,利用這個數(shù)據(jù)集可以進行簡單的神經(jīng)網(wǎng)絡(luò)訓(xùn)練。
from torch import nn
data=DataLoader(dataset(),batch_size=10,shuffle=True)
bp=nn.Sequential(nn.Linear(1,5),
nn.Sigmoid(),
nn.Linear(5,1),
nn.Sigmoid())
optim=torch.optim.Adam(params=bp.parameters())
Loss=nn.MSELoss()
for epoch in range(10):
print('the {} epoch'.format(epoch))
for batch in data:
yp=bp(batch[0])
loss=Loss(yp,batch[1])
optim.zero_grad()
loss.backward()
optim.step()ps:下面再給大家補充介紹下Pytorch中DataLoader的使用。
前言
最近開始接觸pytorch,從跑別人寫好的代碼開始,今天需要把輸入數(shù)據(jù)根據(jù)每個batch的最長輸入數(shù)據(jù),填充到一樣的長度(之前是將所有的數(shù)據(jù)直接填充到一樣的長度再輸入)。
剛開始是想偷懶,沒有去認真了解輸入的機制,結(jié)果一直報錯…還是要認真學(xué)習(xí)呀!
加載數(shù)據(jù)
pytorch中加載數(shù)據(jù)的順序是:
①創(chuàng)建一個dataset對象
②創(chuàng)建一個dataloader對象
③循環(huán)dataloader對象,將data,label拿到模型中去訓(xùn)練
dataset
你需要自己定義一個class,里面至少包含3個函數(shù):
①__init__:傳入數(shù)據(jù),或者像下面一樣直接在函數(shù)里加載數(shù)據(jù)
②__len__:返回這個數(shù)據(jù)集一共有多少個item
③__getitem__:返回一條訓(xùn)練數(shù)據(jù),并將其轉(zhuǎn)換成tensor
import torch
from torch.utils.data import Dataset
class Mydata(Dataset):
def __init__(self):
a = np.load("D:/Python/nlp/NRE/a.npy",allow_pickle=True)
b = np.load("D:/Python/nlp/NRE/b.npy",allow_pickle=True)
d = np.load("D:/Python/nlp/NRE/d.npy",allow_pickle=True)
c = np.load("D:/Python/nlp/NRE/c.npy")
self.x = list(zip(a,b,d,c))
def __getitem__(self, idx):
assert idx < len(self.x)
return self.x[idx]
def __len__(self):
return len(self.x)
dataloader
參數(shù):
dataset:傳入的數(shù)據(jù)
shuffle = True:是否打亂數(shù)據(jù)
collate_fn:使用這個參數(shù)可以自己操作每個batch的數(shù)據(jù)
dataset = Mydata() dataloader = DataLoader(dataset, batch_size = 2, shuffle=True,collate_fn = mycollate)
下面是將每個batch的數(shù)據(jù)填充到該batch的最大長度
def mycollate(data):
a = []
b = []
c = []
d = []
max_len = len(data[0][0])
for i in data:
if len(i[0])>max_len:
max_len = len(i[0])
if len(i[1])>max_len:
max_len = len(i[1])
if len(i[2])>max_len:
max_len = len(i[2])
print(max_len)
# 填充
for i in data:
if len(i[0])<max_len:
i[0].extend([27] * (max_len-len(i[0])))
if len(i[1])<max_len:
i[1].extend([27] * (max_len-len(i[1])))
if len(i[2])<max_len:
i[2].extend([27] * (max_len-len(i[2])))
a.append(i[0])
b.append(i[1])
d.append(i[2])
c.extend(i[3])
# 這里要自己轉(zhuǎn)成tensor
a = torch.Tensor(a)
b = torch.Tensor(b)
c = torch.Tensor(c)
d = torch.Tensor(d)
data1 = [a,b,d,c]
print("data1",data1)
return data1
結(jié)果:

最后循環(huán)該dataloader ,拿到數(shù)據(jù)放入模型進行訓(xùn)練:
for ii, data in enumerate(test_data_loader):
if opt.use_gpu:
data = list(map(lambda x: torch.LongTensor(x.long()).cuda(), data))
else:
data = list(map(lambda x: torch.LongTensor(x.long()), data))
out = model(data[:-1]) #數(shù)據(jù)data[:-1]
loss = F.cross_entropy(out, data[-1])# 最后一列是標簽
寫在最后:建議像我一樣剛開始不太熟練的小伙伴,在處理數(shù)據(jù)輸入的時候可以打印出來仔細查看。
到此這篇關(guān)于Pytorch中DataLoader的使用方法的文章就介紹到這了,更多相關(guān)Pytorch DataLoader內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
使用Pyhton集合set()實現(xiàn)成果查漏的例子
今天小編就為大家分享一篇使用Pyhton集合set()實現(xiàn)成果查漏的例子,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-11-11
Django中日期處理注意事項與自定義時間格式轉(zhuǎn)換詳解
這篇文章主要給大家介紹了關(guān)于Django中日期處理注意事項與自定義時間格式轉(zhuǎn)換的相關(guān)資料,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2018-08-08
python+splinter實現(xiàn)12306網(wǎng)站刷票并自動購票流程
這篇文章主要為大家詳細介紹了python+splinter實現(xiàn)12306網(wǎng)站刷票并自動購票流程,具有一定的參考價值,感興趣的小伙伴們可以參考一下2018-09-09
Python復(fù)制Word內(nèi)容并使用格式設(shè)字體與大小實例代碼
這篇文章主要介紹了Python復(fù)制Word內(nèi)容并使用格式設(shè)字體與大小實例代碼,小編覺得還是挺不錯的,具有一定借鑒價值,需要的朋友可以參考下2018-01-01

