淺談keras通過model.fit_generator訓練模型(節(jié)省內存)
前言
前段時間在訓練模型的時候,發(fā)現(xiàn)當訓練集的數(shù)量過大,并且輸入的圖片維度過大時,很容易就超內存了,舉個簡單例子,如果我們有20000個樣本,輸入圖片的維度是224x224x3,用float32存儲,那么如果我們一次性將全部數(shù)據(jù)載入內存的話,總共就需要20000x224x224x3x32bit/8=11.2GB 這么大的內存,所以如果一次性要加載全部數(shù)據(jù)集的話是需要很大內存的。
如果我們直接用keras的fit函數(shù)來訓練模型的話,是需要傳入全部訓練數(shù)據(jù),但是好在提供了fit_generator,可以分批次的讀取數(shù)據(jù),節(jié)省了我們的內存,我們唯一要做的就是實現(xiàn)一個生成器(generator)。
1.fit_generator函數(shù)簡介
fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)
參數(shù):
generator:一個生成器,或者一個 Sequence (keras.utils.Sequence) 對象的實例。這是我們實現(xiàn)的重點,后面會著介紹生成器和sequence的兩種實現(xiàn)方式。
steps_per_epoch:這個是我們在每個epoch中需要執(zhí)行多少次生成器來生產(chǎn)數(shù)據(jù),fit_generator函數(shù)沒有batch_size這個參數(shù),是通過steps_per_epoch來實現(xiàn)的,每次生產(chǎn)的數(shù)據(jù)就是一個batch,因此steps_per_epoch的值我們通過會設為(樣本數(shù)/batch_size)。如果我們的generator是sequence類型,那么這個參數(shù)是可選的,默認使用len(generator) 。
epochs:即我們訓練的迭代次數(shù)。
verbose:0, 1 或 2。日志顯示模式。 0 = 安靜模式, 1 = 進度條, 2 = 每輪一行
callbacks:在訓練時調用的一系列回調函數(shù)。
validation_data:和我們的generator類似,只是這個使用于驗證的,不參與訓練。
validation_steps:和前面的steps_per_epoch類似。
class_weight:可選的將類索引(整數(shù))映射到權重(浮點)值的字典,用于加權損失函數(shù)(僅在訓練期間)。 這可以用來告訴模型「更多地關注」來自代表性不足的類的樣本。(感覺這個參數(shù)用的比較少)
max_queue_size:整數(shù)。生成器隊列的最大尺寸。默認為10.
workers:整數(shù)。使用的最大進程數(shù)量,如果使用基于進程的多線程。 如未指定,workers 將默認為 1。如果為 0,將在主線程上執(zhí)行生成器。
use_multiprocessing:布爾值。如果 True,則使用基于進程的多線程。默認為False。
shuffle:是否在每輪迭代之前打亂 batch 的順序。 只能與Sequence(keras.utils.Sequence) 實例同用。
initial_epoch: 開始訓練的輪次(有助于恢復之前的訓練)
2.generator實現(xiàn)
2.1生成器的實現(xiàn)方式
樣例代碼:
import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from sklearn.model_selection import train_test_split
from PIL import Image
def process_x(path):
img = Image.open(path)
img = img.resize((96,96))
img = img.convert('RGB')
img = np.array(img)
img = np.asarray(img, np.float32) / 255.0
#也可以進行進行一些數(shù)據(jù)數(shù)據(jù)增強的處理
return img
count =1
def generate_arrays_from_file(x_y):
#x_y 是我們的訓練集包括標簽,每一行的第一個是我們的圖片路徑,后面的是我們的獨熱化后的標簽
global count
batch_size = 8
while 1:
batch_x = x_y[(count - 1) * batch_size:count * batch_size, 0]
batch_y = x_y[(count - 1) * batch_size:count * batch_size, 1:]
batch_x = np.array([process_x(img_path) for img_path in batch_x])
batch_y = np.array(batch_y).astype(np.float32)
print("count:"+str(count))
count = count+1
yield (batch_x, batch_y)
model = Sequential()
model.add(Dense(units=1000, activation='relu', input_dim=2))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',optimizer='sgd',metrics=['accuracy'])
x_y = []
model.fit_generator(generate_arrays_from_file(x_y),steps_per_epoch=10, epochs=2,max_queue_size=1,workers=1)
在理解上面代碼之前我們需要首先了解yield的用法。
yield關鍵字:
我們先通過一個例子看一下yield的用法:
def foo():
print("starting...")
while True:
res = yield 4
print("res:",res)
g = foo()
print(next(g))
print("----------")
print(next(g))
運行結果:
starting... 4 ---------- res: None 4
帶yield的函數(shù)是一個生成器,而不是一個函數(shù)。因為foo函數(shù)中有yield關鍵字,所以foo函數(shù)并不會真的執(zhí)行,而是先得到一個生成器的實例,當我們第一次調用next函數(shù)的時候,foo函數(shù)才開始行,首先先執(zhí)行foo函數(shù)中的print方法,然后進入while循環(huán),循環(huán)執(zhí)行到y(tǒng)ield時,yield其實相當于return,函數(shù)返回4,程序停止。所以我們第一次調用next(g)的輸出結果是前面兩行。
然后當我們再次調用next(g)時,這個時候是從上一次停止的地方繼續(xù)執(zhí)行,也就是要執(zhí)行res的賦值操作,因為4已經(jīng)在上一次執(zhí)行被return了,隨意賦值res為None,然后執(zhí)行print(“res:”,res)打印res: None,再次循環(huán)到y(tǒng)ield返回4,程序停止。
所以yield關鍵字的作用就是我們能夠從上一次程序停止的地方繼續(xù)執(zhí)行,這樣我們用作生成器的時候,就避免一次性讀入數(shù)據(jù)造成內存不足的情況。
現(xiàn)在看到上面的示例代碼:
generate_arrays_from_file函數(shù)就是我們的生成器,每次循環(huán)讀取一個batch大小的數(shù)據(jù),然后處理數(shù)據(jù),并返回。x_y是我們的把路徑和標簽合并后的訓練集,類似于如下形式:
['data/img\\fimg_4092.jpg' '0' '1' '0' '0' '0' ]
至于格式不一定要這樣,可以是自己的格式,至于怎么處理,根于自己的格式,在process_x進行處理,這里因為是存放的圖片路徑,所以在process_x函數(shù)的主要作用就是讀取圖片并進行歸一化等操作,也可以在這里定義自己需要進行的操作,例如對圖像進行實時數(shù)據(jù)增強。
2.2使用Sequence實現(xiàn)generator
示例代碼:
class BaseSequence(Sequence):
"""
基礎的數(shù)據(jù)流生成器,每次迭代返回一個batch
BaseSequence可直接用于fit_generator的generator參數(shù)
fit_generator會將BaseSequence再次封裝為一個多進程的數(shù)據(jù)流生成器
而且能保證在多進程下的一個epoch中不會重復取相同的樣本
"""
def __init__(self, img_paths, labels, batch_size, img_size):
#np.hstack在水平方向上平鋪
self.x_y = np.hstack((np.array(img_paths).reshape(len(img_paths), 1), np.array(labels)))
self.batch_size = batch_size
self.img_size = img_size
def __len__(self):
#math.ceil表示向上取整
#調用len(BaseSequence)時返回,返回的是每個epoch我們需要讀取數(shù)據(jù)的次數(shù)
return math.ceil(len(self.x_y) / self.batch_size)
def preprocess_img(self, img_path):
img = Image.open(img_path)
resize_scale = self.img_size[0] / max(img.size[:2])
img = img.resize((self.img_size[0], self.img_size[0]))
img = img.convert('RGB')
img = np.array(img)
# 數(shù)據(jù)歸一化
img = np.asarray(img, np.float32) / 255.0
return img
def __getitem__(self, idx):
batch_x = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 0]
batch_y = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 1:]
batch_x = np.array([self.preprocess_img(img_path) for img_path in batch_x])
batch_y = np.array(batch_y).astype(np.float32)
print(batch_x.shape)
return batch_x, batch_y
#重寫的父類Sequence中的on_epoch_end方法,在每次迭代完后調用。
def on_epoch_end(self):
#每次迭代后重新打亂訓練集數(shù)據(jù)
np.random.shuffle(self.x_y)
在上面代碼中,__len __和__getitem __,是我們重寫的魔法方法,__len __是當我們調用len(BaseSequence)函數(shù)時調用,這里我們返回(樣本總量/batch_size),供我們傳入fit_generator中的steps_per_epoch參數(shù);__getitem __可以讓對象實現(xiàn)迭代功能,這樣在將BaseSequence的對象傳入fit_generator中后,不斷執(zhí)行generator就可循環(huán)的讀取數(shù)據(jù)了。
舉個例子說明一下getitem的作用:
class Animal: def __init__(self, animal_list): self.animals_name = animal_list def __getitem__(self, index): return self.animals_name[index] animals = Animal(["dog","cat","fish"]) for animal in animals: print(animal)
輸出結果:
dog cat fish
并且使用Sequence類可以保證在多進程的情況下,每個epoch中的樣本只會被訓練一次。
以上這篇淺談keras通過model.fit_generator訓練模型(節(jié)省內存)就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
Windows 下python3.8環(huán)境安裝教程圖文詳解
這篇文章主要介紹了Windows 下python3.8環(huán)境安裝教程圖文詳解,本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2020-03-03
利用Celery實現(xiàn)Django博客PV統(tǒng)計功能詳解
給網(wǎng)站增加pv、uv統(tǒng)計,可以是件很簡單的事,也可以是件很復雜的事。下面這篇文章主要給大家介紹了利用Celery實現(xiàn)Django博客PV統(tǒng)計功能的相關資料,文中介紹的非常詳細,需要的朋友可以參考借鑒,下面來一起看看吧。2017-05-05
Python 序列化和反序列化庫 MarshMallow 的用法實例代碼
marshmallow(Object serialization and deserialization, lightweight and fluffy.)用于對對象進行序列化和反序列化,并同步進行數(shù)據(jù)驗證。這篇文章主要介紹了Python 序列化和反序列化庫 MarshMallow 的用法實例代碼,需要的朋友可以參考下2020-02-02
python使用response.read()接收json數(shù)據(jù)的實例
今天小編就為大家分享一篇python使用response.read()接收json數(shù)據(jù)的實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-12-12
tensorflow2.0實現(xiàn)復雜神經(jīng)網(wǎng)絡(多輸入多輸出nn,Resnet)
這篇文章主要介紹了tensorflow2.0實現(xiàn)復雜神經(jīng)網(wǎng)絡(多輸入多輸出nn,Resnet),文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2021-03-03

