pytorch加載語(yǔ)音類自定義數(shù)據(jù)集的方法教程
前言
pytorch對(duì)一下常用的公開(kāi)數(shù)據(jù)集有很方便的API接口,但是當(dāng)我們需要使用自己的數(shù)據(jù)集訓(xùn)練神經(jīng)網(wǎng)絡(luò)時(shí),就需要自定義數(shù)據(jù)集,在pytorch中,提供了一些類,方便我們定義自己的數(shù)據(jù)集合
- torch.utils.data.Dataset:所有繼承他的子類都應(yīng)該重寫(xiě) __len()__ , __getitem()__ 這兩個(gè)方法
- __len()__ :返回?cái)?shù)據(jù)集中數(shù)據(jù)的數(shù)量
- __getitem()__ :返回支持下標(biāo)索引方式獲取的一個(gè)數(shù)據(jù)
- torch.utils.data.DataLoader:對(duì)數(shù)據(jù)集進(jìn)行包裝,可以設(shè)置batch_size、是否shuffle....
第一步
自定義的 Dataset 都需要繼承 torch.utils.data.Dataset 類,并且重寫(xiě)它的兩個(gè)成員方法:
- __len()__:讀取數(shù)據(jù),返回?cái)?shù)據(jù)和標(biāo)簽
- __getitem()__:返回?cái)?shù)據(jù)集的長(zhǎng)度
from torch.utils.data import Dataset class AudioDataset(Dataset): def __init__(self, ...): """類的初始化""" pass def __getitem__(self, item): """每次怎么讀數(shù)據(jù),返回?cái)?shù)據(jù)和標(biāo)簽""" return data, label def __len__(self): """返回整個(gè)數(shù)據(jù)集的長(zhǎng)度""" return total
注意事項(xiàng):Dataset只負(fù)責(zé)數(shù)據(jù)的抽象,一次調(diào)用getiitem只返回一個(gè)樣本
案例:
文件目錄結(jié)構(gòu)
- p225
- ***.wav
- ***.wav
- ***.wav
- ...
- dataset.py
目的:讀取p225文件夾中的音頻數(shù)據(jù)
class AudioDataset(Dataset): def __init__(self, data_folder, sr=16000, dimension=8192): self.data_folder = data_folder self.sr = sr self.dim = dimension # 獲取音頻名列表 self.wav_list = [] for root, dirnames, filenames in os.walk(data_folder): for filename in fnmatch.filter(filenames, "*.wav"): # 實(shí)現(xiàn)列表特殊字符的過(guò)濾或篩選,返回符合匹配“.wav”字符列表 self.wav_list.append(os.path.join(root, filename)) def __getitem__(self, item): # 讀取一個(gè)音頻文件,返回每個(gè)音頻數(shù)據(jù) filename = self.wav_list[item] wb_wav, _ = librosa.load(filename, sr=self.sr) # 取 幀 if len(wb_wav) >= self.dim: max_audio_start = len(wb_wav) - self.dim audio_start = np.random.randint(0, max_audio_start) wb_wav = wb_wav[audio_start: audio_start + self.dim] else: wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant") return wb_wav, filename def __len__(self): # 音頻文件的總數(shù) return len(self.wav_list)
注意事項(xiàng):19-24行:每個(gè)音頻的長(zhǎng)度不一樣,如果直接讀取數(shù)據(jù)返回出來(lái)的話,會(huì)造成維度不匹配而報(bào)錯(cuò),因此只能每次取一個(gè)音頻文件讀取一幀,這樣顯然并沒(méi)有用到所有的語(yǔ)音數(shù)據(jù),
第二步
實(shí)例化 Dataset 對(duì)象
Dataset= AudioDataset("./p225", sr=16000)
如果要通過(guò)batch讀取數(shù)據(jù)的可直接跳到第三步,如果你想一個(gè)一個(gè)讀取數(shù)據(jù)的可以看我接下來(lái)的操作
# 實(shí)例化AudioDataset對(duì)象
train_set = AudioDataset("./p225", sr=16000)
for i, data in enumerate(train_set):
wb_wav, filname = data
print(i, wb_wav.shape, filname)
if i == 3:
break
# 0 (8192,) ./p225\p225_001.wav
# 1 (8192,) ./p225\p225_002.wav
# 2 (8192,) ./p225\p225_003.wav
# 3 (8192,) ./p225\p225_004.wav
第三步
如果想要通過(guò)batch讀取數(shù)據(jù),需要使用DataLoader進(jìn)行包裝
為何要使用DataLoader?
- 深度學(xué)習(xí)的輸入是mini_batch形式
- 樣本加載時(shí)候可能需要隨機(jī)打亂順序,shuffle操作
- 樣本加載需要采用多線程
pytorch提供的 DataLoader 封裝了上述的功能,這樣使用起來(lái)更方便。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
參數(shù):
- dataset:加載的數(shù)據(jù)集(Dataset對(duì)象)
- batch_size:每個(gè)批次要加載多少個(gè)樣本(默認(rèn)值:1)
- shuffle:每個(gè)epoch是否將數(shù)據(jù)打亂
- sampler:定義從數(shù)據(jù)集中抽取樣本的策略。如果指定,則不能指定洗牌。
- batch_sampler:類似于sampler,但每次返回一批索引。與batch_size、shuffle、sampler和drop_last相互排斥。
- num_workers:使用多進(jìn)程加載的進(jìn)程數(shù),0代表不使用多線程
- collate_fn:如何將多個(gè)樣本數(shù)據(jù)拼接成一個(gè)batch,一般使用默認(rèn)拼接方式
- pin_memory:是否將數(shù)據(jù)保存在pin memory區(qū),pin memory中的數(shù)據(jù)轉(zhuǎn)到GPU會(huì)快一些
- drop_last:dataset中的數(shù)據(jù)個(gè)數(shù)可能不是batch_size的整數(shù)倍,drop_last為T(mén)rue會(huì)將多出來(lái)不足一個(gè)batch的數(shù)據(jù)丟棄
返回:數(shù)據(jù)加載器
案例:
# 實(shí)例化AudioDataset對(duì)象
train_set = AudioDataset("./p225", sr=16000)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
for (i, data) in enumerate(train_loader):
wav_data, wav_name = data
print(wav_data.shape) # torch.Size([8, 8192])
print(i, wav_name)
# ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
# './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')
我們來(lái)吃幾個(gè)栗子消化一下:
栗子1
這個(gè)例子就是本文一直舉例的,栗子1只是合并了一下而已
文件目錄結(jié)構(gòu)
- p225
- ***.wav
- ***.wav
- ***.wav
- ...
- dataset.py
目的:讀取p225文件夾中的音頻數(shù)據(jù)
import fnmatch
import os
import librosa
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
class Aduio_DataLoader(Dataset):
def __init__(self, data_folder, sr=16000, dimension=8192):
self.data_folder = data_folder
self.sr = sr
self.dim = dimension
# 獲取音頻名列表
self.wav_list = []
for root, dirnames, filenames in os.walk(data_folder):
for filename in fnmatch.filter(filenames, "*.wav"): # 實(shí)現(xiàn)列表特殊字符的過(guò)濾或篩選,返回符合匹配“.wav”字符列表
self.wav_list.append(os.path.join(root, filename))
def __getitem__(self, item):
# 讀取一個(gè)音頻文件,返回每個(gè)音頻數(shù)據(jù)
filename = self.wav_list[item]
print(filename)
wb_wav, _ = librosa.load(filename, sr=self.sr)
# 取 幀
if len(wb_wav) >= self.dim:
max_audio_start = len(wb_wav) - self.dim
audio_start = np.random.randint(0, max_audio_start)
wb_wav = wb_wav[audio_start: audio_start + self.dim]
else:
wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")
return wb_wav, filename
def __len__(self):
# 音頻文件的總數(shù)
return len(self.wav_list)
train_set = Aduio_DataLoader("./p225", sr=16000)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
for (i, data) in enumerate(train_loader):
wav_data, wav_name = data
print(wav_data.shape) # torch.Size([8, 8192])
print(i, wav_name)
# ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
# './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')
注意事項(xiàng):
- 27-33行:每個(gè)音頻的長(zhǎng)度不一樣,如果直接讀取數(shù)據(jù)返回出來(lái)的話,會(huì)造成維度不匹配而報(bào)錯(cuò),因此只能每次取一個(gè)音頻文件讀取一幀,這樣顯然并沒(méi)有用到所有的語(yǔ)音數(shù)據(jù),
- 48行:我們?cè)赺_getitem__中并沒(méi)有將numpy數(shù)組轉(zhuǎn)換為tensor格式,可是第48行顯示數(shù)據(jù)是tensor格式的。這里需要引起注意
栗子2
相比于案例1,案例二才是重點(diǎn),因?yàn)槲覀儾豢赡苊看沃粡囊灰纛l文件中讀取一幀,然后讀取另一個(gè)音頻文件,通常情況下,一段音頻有很多幀,我們需要的是按順序的讀取一個(gè)batch_size的音頻幀,先讀取第一個(gè)音頻文件,如果滿足一個(gè)batch,則不用讀取第二個(gè)batch,如果不足一個(gè)batch則讀取第二個(gè)音頻文件,來(lái)補(bǔ)充。
我給出一個(gè)建議,先按順序讀取每個(gè)音頻文件,以窗長(zhǎng)8192、幀移4096對(duì)語(yǔ)音進(jìn)行分幀,然后拼接。得到(幀數(shù),幀長(zhǎng),1)(frame_num, frame_len, 1)的數(shù)組保存到h5中。然后用上面講到的 torch.utils.data.Dataset 和 torch.utils.data.DataLoader 讀取數(shù)據(jù)。
具體實(shí)現(xiàn)代碼:
第一步:創(chuàng)建一個(gè)H5_generation腳本用來(lái)將數(shù)據(jù)轉(zhuǎn)換為h5格式文件:
第二步:通過(guò)Dataset從h5格式文件中讀取數(shù)據(jù)
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import h5py
def load_h5(h5_path):
# load training data
with h5py.File(h5_path, 'r') as hf:
print('List of arrays in input file:', hf.keys())
X = np.array(hf.get('data'), dtype=np.float32)
Y = np.array(hf.get('label'), dtype=np.float32)
return X, Y
class AudioDataset(Dataset):
"""數(shù)據(jù)加載器"""
def __init__(self, data_folder):
self.data_folder = data_folder
self.X, self.Y = load_h5(data_folder) # (3392, 8192, 1)
def __getitem__(self, item):
# 返回一個(gè)音頻數(shù)據(jù)
X = self.X[item]
Y = self.Y[item]
return X, Y
def __len__(self):
return len(self.X)
train_set = AudioDataset("./speaker225_resample_train.h5")
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True)
for (i, wav_data) in enumerate(train_loader):
X, Y = wav_data
print(i, X.shape)
# 0 torch.Size([64, 8192, 1])
# 1 torch.Size([64, 8192, 1])
# ...
我嘗試在__init__中生成h5文件,但是會(huì)導(dǎo)致內(nèi)存爆炸,就很奇怪,因此我只好分開(kāi)了,
參考
pytorch學(xué)習(xí)(四)—自定義數(shù)據(jù)集(講的比較詳細(xì))
總結(jié)
到此這篇關(guān)于pytorch加載語(yǔ)音類自定義數(shù)據(jù)集的文章就介紹到這了,更多相關(guān)pytorch加載語(yǔ)音類自定義數(shù)據(jù)集內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Windows環(huán)境下python環(huán)境安裝使用圖文教程
這篇文章主要為大家詳細(xì)介紹了Windows環(huán)境下python安裝使用圖文教程,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-03-03
使用國(guó)內(nèi)鏡像源優(yōu)化pip install下載的方法步驟
在Python開(kāi)發(fā)中,pip 是一個(gè)不可或缺的工具,用于安裝和管理Python包,然而,由于默認(rèn)的PyPI服務(wù)器位于國(guó)外,國(guó)內(nèi)用戶在安裝依賴時(shí)可能會(huì)遇到下載速度慢、連接不穩(wěn)定等問(wèn)題,所以本文將詳細(xì)介紹如何使用國(guó)內(nèi)鏡像源來(lái)加速pip install -r requirements.txt的過(guò)程2025-03-03
Python畫(huà)圖學(xué)習(xí)入門(mén)教程
這篇文章主要介紹了Python畫(huà)圖的方法,結(jié)合實(shí)例形式分析了Python基本的線性圖、餅狀圖等繪制技巧,具有一定參考借鑒價(jià)值,需要的朋友可以參考下2016-07-07
Pytorch平均池化nn.AvgPool2d()使用方法實(shí)例
平均池化層,又叫平均匯聚層,下面這篇文章主要給大家介紹了關(guān)于Pytorch平均池化nn.AvgPool2d()使用方法的相關(guān)資料,文中通過(guò)實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下2023-02-02
PyTorch中常見(jiàn)損失函數(shù)的使用詳解
損失函數(shù),又叫目標(biāo)函數(shù),是指計(jì)算機(jī)標(biāo)簽值和預(yù)測(cè)值直接差異的函數(shù),本文為大家整理了PyTorch中常見(jiàn)損失函數(shù)的簡(jiǎn)單解釋和使用,希望對(duì)大家有所幫助2023-06-06
python-redis-lock實(shí)現(xiàn)鎖自動(dòng)續(xù)期的源碼邏輯
這篇文章主要介紹了python-redis-lock實(shí)現(xiàn)鎖自動(dòng)續(xù)期的源碼邏輯,其中用到了多線程threading、弱引用weakref和Lua腳本等相關(guān)知識(shí),需要的朋友可以參考下2024-07-07
python 邊緣擴(kuò)充方式的實(shí)現(xiàn)示例
本文主要介紹了python 邊緣擴(kuò)充方式的實(shí)現(xiàn)示例,文中通過(guò)示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2022-03-03

