keras使用Sequence類(lèi)調(diào)用大規(guī)模數(shù)據(jù)集進(jìn)行訓(xùn)練的實(shí)現(xiàn)
使用Keras如果要使用大規(guī)模數(shù)據(jù)集對(duì)網(wǎng)絡(luò)進(jìn)行訓(xùn)練,就沒(méi)辦法先加載進(jìn)內(nèi)存再?gòu)膬?nèi)存直接傳到顯存了,除了使用Sequence類(lèi)以外,還可以使用迭代器去生成數(shù)據(jù),但迭代器無(wú)法在fit_generation里開(kāi)啟多進(jìn)程,會(huì)影響數(shù)據(jù)的讀取和預(yù)處理效率,在本文中就不在敘述了,有需要的可以另外去百度。
下面是我所使用的代碼
class SequenceData(Sequence):
def __init__(self, path, batch_size=32):
self.path = path
self.batch_size = batch_size
f = open(path)
self.datas = f.readlines()
self.L = len(self.datas)
self.index = random.sample(range(self.L), self.L)
#返回長(zhǎng)度,通過(guò)len(<你的實(shí)例>)調(diào)用
def __len__(self):
return self.L - self.batch_size
#即通過(guò)索引獲取a[0],a[1]這種
def __getitem__(self, idx):
batch_indexs = self.index[idx:(idx+self.batch_size)]
batch_datas = [self.datas[k] for k in batch_indexs]
img1s,img2s,audios,labels = self.data_generation(batch_datas)
return ({'face1_input_1': img1s, 'face2_input_2': img2s, 'input_3':audios},{'activation_7':labels})
def data_generation(self, batch_datas):
#預(yù)處理操作
return img1s,img2s,audios,labels
然后在代碼里通過(guò)fit_generation函數(shù)調(diào)用并訓(xùn)練
這里要注意,use_multiprocessing參數(shù)是是否開(kāi)啟多進(jìn)程,由于python的多線程不是真的多線程,所以多進(jìn)程還是會(huì)獲得比較客觀的加速,但不支持windows,windows下python無(wú)法使用多進(jìn)程。
D = SequenceData('train.csv')
model_train.fit_generator(generator=D,steps_per_epoch=int(len(D)),
epochs=2, workers=20, #callbacks=[checkpoint],
use_multiprocessing=True, validation_data=SequenceData('vali.csv'),validation_steps=int(20000/32))
同樣的,也可以在測(cè)試的時(shí)候使用
model.evaluate_generator(generator=SequenceData('face_test.csv'),steps=int(125100/32),workers=32)
補(bǔ)充知識(shí):keras數(shù)據(jù)自動(dòng)生成器,繼承keras.utils.Sequence,結(jié)合fit_generator實(shí)現(xiàn)節(jié)約內(nèi)存訓(xùn)練
我就廢話不多說(shuō)了,大家還是直接看代碼吧~
#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense
class DataGenerator(keras.utils.Sequence):
def __init__(self, datas, batch_size=1, shuffle=True):
self.batch_size = batch_size
self.datas = datas
self.indexes = np.arange(len(self.datas))
self.shuffle = shuffle
def __len__(self):
#計(jì)算每一個(gè)epoch的迭代次數(shù)
return math.ceil(len(self.datas) / float(self.batch_size))
def __getitem__(self, index):
#生成每個(gè)batch數(shù)據(jù),這里就根據(jù)自己對(duì)數(shù)據(jù)的讀取方式進(jìn)行發(fā)揮了
# 生成batch_size個(gè)索引
batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
# 根據(jù)索引獲取datas集合中的數(shù)據(jù)
batch_datas = [self.datas[k] for k in batch_indexs]
# 生成數(shù)據(jù)
X, y = self.data_generation(batch_datas)
return X, y
def on_epoch_end(self):
#在每一次epoch結(jié)束是否需要進(jìn)行一次隨機(jī),重新隨機(jī)一下index
if self.shuffle == True:
np.random.shuffle(self.indexes)
def data_generation(self, batch_datas):
images = []
labels = []
# 生成數(shù)據(jù)
for i, data in enumerate(batch_datas):
#x_train數(shù)據(jù)
image = cv2.imread(data)
image = list(image)
images.append(image)
#y_train數(shù)據(jù)
right = data.rfind("\\",0)
left = data.rfind("\\",0,right)+1
class_name = data[left:right]
if class_name=="dog":
labels.append([0,1])
else:
labels.append([1,0])
#如果為多輸出模型,Y的格式要變一下,外層list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
return np.array(images), np.array(labels)
# 讀取樣本名稱(chēng),然后根據(jù)樣本名稱(chēng)去讀取數(shù)據(jù)
class_num = 0
train_datas = []
for file in os.listdir("D:/xxx"):
file_path = os.path.join("D:/xxx", file)
if os.path.isdir(file_path):
class_num = class_num + 1
for sub_file in os.listdir(file_path):
train_datas.append(os.path.join(file_path, sub_file))
# 數(shù)據(jù)生成器
training_generator = DataGenerator(train_datas)
#構(gòu)建網(wǎng)絡(luò)
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
optimizer='sgd',
metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)
以上這篇keras使用Sequence類(lèi)調(diào)用大規(guī)模數(shù)據(jù)集進(jìn)行訓(xùn)練的實(shí)現(xiàn)就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
kafka-python批量發(fā)送數(shù)據(jù)的實(shí)例
今天小編就為大家分享一篇kafka-python批量發(fā)送數(shù)據(jù)的實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-12-12
一篇文章帶你了解python標(biāo)準(zhǔn)庫(kù)--sys模塊
這篇文章主要介紹了Python標(biāo)準(zhǔn)庫(kù)之Sys模塊使用詳解,本文講解了使用sys模塊獲得腳本的參數(shù)、處理模塊、使用sys模塊操作模塊搜索路徑、使用sys模塊查找內(nèi)建模塊、使用sys模塊查找已導(dǎo)入的模塊等使用案例,需要的朋友可以參考下2021-08-08
python常用數(shù)據(jù)結(jié)構(gòu)元組詳解
這篇文章主要介紹了python常用數(shù)據(jù)結(jié)構(gòu)元組詳解,文章圍繞主題展開(kāi)詳細(xì)的內(nèi)容介紹,具有一定的參考價(jià)值,需要的小伙伴可以參考一下2022-08-08
python使用wmi模塊獲取windows下硬盤(pán)信息的方法
這篇文章主要介紹了python使用wmi模塊獲取windows下硬盤(pán)信息的方法,涉及Python獲取系統(tǒng)硬件信息的相關(guān)技巧,需要的朋友可以參考下2015-05-05
瀏覽器常用基本操作之python3+selenium4自動(dòng)化測(cè)試(基礎(chǔ)篇3)
瀏覽器常用基本操作有很多種,今天給大家介紹python3+selenium4自動(dòng)化測(cè)試的操作方法,是最最基礎(chǔ)的一篇,對(duì)python3 selenium4自動(dòng)化測(cè)試相關(guān)知識(shí)感興趣的朋友一起看看吧2021-05-05
python為tornado添加recaptcha驗(yàn)證碼功能
tornado作為微框架,并沒(méi)有自帶驗(yàn)證碼組件,recaptcha是著名的驗(yàn)證碼解決方案,簡(jiǎn)單易用,被很多公司運(yùn)用來(lái)防止惡意注冊(cè)和評(píng)論。tornado添加recaptchaHA非常容易2014-02-02
python實(shí)現(xiàn)指定字符串補(bǔ)全空格的方法
這篇文章主要介紹了python實(shí)現(xiàn)指定字符串補(bǔ)全空格的方法,涉及Python中rjust,ljust和center方法的使用技巧,非常具有實(shí)用價(jià)值,需要的朋友可以參考下2015-04-04
Python中實(shí)現(xiàn)單例模式的n種方式和原理
這篇文章主要介紹了Python中實(shí)現(xiàn)單例模式的n種方式和原理,需要的朋友可以參考下2018-11-11
pandas DataFrame 交集并集補(bǔ)集的實(shí)現(xiàn)
這篇文章主要介紹了pandas DataFrame 交集并集補(bǔ)集的實(shí)現(xiàn),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-06-06

