在keras下實(shí)現(xiàn)多個(gè)模型的融合方式
在網(wǎng)上搜過發(fā)現(xiàn)關(guān)于keras下的模型融合框架其實(shí)很簡單,奈何網(wǎng)上說了一大堆,這個(gè)東西官方文檔上就有,自己寫了個(gè)demo:
# Function:基于keras框架下實(shí)現(xiàn),多個(gè)獨(dú)立任務(wù)分類
# Writer: PQF
# Time: 2019/9/29
import numpy as np
from keras.layers import Input, Dense
from keras.models import Model
import tensorflow as tf
# 生成訓(xùn)練集
dataset_size = 128*3
rdm = np.random.RandomState(1)
X = rdm.rand(dataset_size,2)
Y1 = [[int(x1+x2<1)] for (x1,x2) in X]
Y2 = [[int(x1+x2*x2<0.5)] for (x1,x2) in X]
X_train = X[:-2]
Y_train1 = Y1[:-2]
Y_train2 = Y2[:-2]
X_test = X[-2:dataset_size]
Y_test1 = Y1[-2:dataset_size]
Y_test2 = Y2[-2:dataset_size]
#網(wǎng)絡(luò)一
input = Input(shape=(2,))
x = Dense(units=16,activation='relu')(input)
output = Dense(units=1,activation='sigmoid',name='output1')(x)
#網(wǎng)絡(luò)二
input2 = Input(shape=(2,))
x2 = Dense(units=16,activation='relu')(input2)
output2 = Dense(units=1,activation='sigmoid',name='output2')(x2)
#模型合并
model = Model(inputs=[input,input2],outputs=[output,output2])
model.summary()
model.compile(optimizer='rmsprop',loss='binary_crossentropy',loss_weights=[1.0,1.0])
model.fit([X_train,X_train],[Y_train1,Y_train2],batch_size=48,epochs=200)
print('x_test is :\n')
print(X_test)
print('y_test1 is :\n')
print(Y_test1)
print('y_test2 is :\n')
print(Y_test2)
predict = model.predict([X_test,X_test])
print('prediction is : \n')
print(predict[0])
print(predict[1])
補(bǔ)充知識:keras的融合層使用理解
最近開始研究U-net網(wǎng)絡(luò),其中接觸到了融合層的概念,做個(gè)筆記。

上圖為U-net網(wǎng)絡(luò),其中上采樣層(綠色箭頭)需要與下采樣層池化層(紅色箭頭)層進(jìn)行融合,要求每層的圖片大小一致,維度依照融合的方式可以不同,融合之后輸出的圖片相較于沒有融合層的網(wǎng)絡(luò),邊緣處要清晰很多!
這時(shí)候就要用到keras的融合層概念(Keras中文文檔https://keras.io/zh/)
文檔中分別講述了加減乘除的四中融合方式,這種方式要求兩層之間shape必須一致。
重點(diǎn)講述一下Concatenate(拼接)方式
拼接方式默認(rèn)依照最后一維也就是通道來進(jìn)行拼接

如同上圖(128*128*64)與(128*128*128)進(jìn)行Concatenate之后的shape為128*128*192
ps:
中文文檔為老版本,最新版本的keras.layers.merge方法進(jìn)行了整合

上圖為新版本整合之后的方法,具體使用方法一看就懂,不再贅述。
以上這篇在keras下實(shí)現(xiàn)多個(gè)模型的融合方式就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python利用GDAL模塊實(shí)現(xiàn)讀取柵格數(shù)據(jù)并對指定數(shù)據(jù)加以篩選掩膜
這篇文章主要為大家詳細(xì)介紹了如何基于Python語言中g(shù)dal模塊,對遙感影像數(shù)據(jù)進(jìn)行柵格讀取與計(jì)算,同時(shí)基于QA波段對像元加以篩選、掩膜的操作,需要的可以參考一下2023-02-02
Python的內(nèi)建模塊itertools的使用解析
這篇文章主要介紹了Python的內(nèi)建模塊itertools的使用解析,itertools是python的迭代器模塊,itertools提供的工具相當(dāng)高效且節(jié)省內(nèi)存,Python的內(nèi)建模塊itertools提供了非常有用的用于操作迭代對象的函數(shù),需要的朋友可以參考下2023-09-09
Python通過串口實(shí)現(xiàn)收發(fā)文件
這篇文章主要為大家詳細(xì)介紹了Python如何通過串口實(shí)現(xiàn)收發(fā)文件功能,文中的示例代碼簡潔易懂,具有一定的借鑒價(jià)值,感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下2023-11-11

