Pytorch DataLoader 變長數(shù)據(jù)處理方式
關(guān)于Pytorch中怎么自定義Dataset數(shù)據(jù)集類、怎樣使用DataLoader迭代加載數(shù)據(jù),這篇官方文檔已經(jīng)說得很清楚了,這里就不在贅述。
現(xiàn)在的問題:有的時(shí)候,特別對(duì)于NLP任務(wù)來說,輸入的數(shù)據(jù)可能不是定長的,比如多個(gè)句子的長度一般不會(huì)一致,這時(shí)候使用DataLoader加載數(shù)據(jù)時(shí),不定長的句子會(huì)被胡亂切分,這肯定是不行的。
解決方法是重寫DataLoader的collate_fn,具體方法如下:
# 假如每一個(gè)樣本為:
sample = {
# 一個(gè)句子中各個(gè)詞的id
'token_list' : [5, 2, 4, 1, 9, 8],
# 結(jié)果y
'label' : 5,
}
# 重寫collate_fn函數(shù),其輸入為一個(gè)batch的sample數(shù)據(jù)
def collate_fn(batch):
# 因?yàn)閠oken_list是一個(gè)變長的數(shù)據(jù),所以需要用一個(gè)list來裝這個(gè)batch的token_list
token_lists = [item['token_list'] for item in batch]
# 每個(gè)label是一個(gè)int,我們把這個(gè)batch中的label也全取出來,重新組裝
labels = [item['label'] for item in batch]
# 把labels轉(zhuǎn)換成Tensor
labels = torch.Tensor(labels)
return {
'token_list': token_lists,
'label': labels,
}
# 在使用DataLoader加載數(shù)據(jù)時(shí),注意collate_fn參數(shù)傳入的是重寫的函數(shù)
DataLoader(trainset, batch_size=4, shuffle=True, num_workers=4, collate_fn=collate_fn)
使用以上方法,可以保證DataLoader能Load出一個(gè)batch的數(shù)據(jù),load出來的東西就是重寫的collate_fn函數(shù)最后return出來的字典。
以上這篇Pytorch DataLoader 變長數(shù)據(jù)處理方式就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
YOLOv5改進(jìn)系列之增加小目標(biāo)檢測層
yolov5出來已經(jīng)很長時(shí)間了,所以有關(guān)yolov5的一些詳細(xì)介紹在這里就不一一介紹了,下面這篇文章主要給大家介紹了關(guān)于YOLOv5改進(jìn)系列之增加小目標(biāo)檢測層的相關(guān)資料,需要的朋友可以參考下2022-09-09
python 根據(jù)時(shí)間來生成唯一的字符串方法
今天小編就為大家分享一篇python 根據(jù)時(shí)間來生成唯一的字符串方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-01-01
Python寫入MySQL數(shù)據(jù)庫的三種方式詳解
Python 讀取數(shù)據(jù)自動(dòng)寫入 MySQL 數(shù)據(jù)庫,這個(gè)需求在工作中是非常普遍的,主要涉及到 python 操作數(shù)據(jù)庫,讀寫更新等。本文總結(jié)了Python寫入MySQL數(shù)據(jù)庫的三種方式,需要的可以參考一下2022-06-06
django和flask哪個(gè)值得研究學(xué)習(xí)
在本篇文章里小編給大家整理的是一篇關(guān)于django和flask哪個(gè)值得研究學(xué)習(xí)內(nèi)容,需要的朋友們可以參考下。2020-07-07
教你用python提取txt文件中的特定信息并寫入Excel
這篇文章主要給大家介紹了如何利用python提取txt文件中的特定信息并寫入Excel的相關(guān)資料,Python是一個(gè)強(qiáng)大的語言,解決這點(diǎn)問題非常簡單,文中通過示例代碼介紹的非常詳細(xì),需要的朋友可以參考下2021-11-11

