Keras自定義實(shí)現(xiàn)帶masking的meanpooling層方式
Keras確實(shí)是一大神器,代碼可以寫得非常簡潔,但是最近在寫LSTM和DeepFM的時(shí)候,遇到了一個(gè)問題:樣本的長度不一樣。對不定長序列的一種預(yù)處理方法是,首先對數(shù)據(jù)進(jìn)行padding補(bǔ)0,然后引入keras的Masking層,它能自動對0值進(jìn)行過濾。
問題在于keras的某些層不支持Masking層處理過的輸入數(shù)據(jù),例如Flatten、AveragePooling1D等等,而其中meanpooling是我需要的一個(gè)運(yùn)算。例如LSTM對每一個(gè)序列的輸出長度都等于該序列的長度,那么均值運(yùn)算就只應(yīng)該除以序列長度,而不是padding后的最長長度。
例如下面這個(gè) 3x4 大小的張量,經(jīng)過補(bǔ)零padding的。我希望做axis=1的meanpooling,則第一行應(yīng)該是 (10+20)/2,第二行應(yīng)該是 (10+20+30)/3,第三行應(yīng)該是 (10+20+30+40)/4。

Keras如何自定義層
在 Keras2.0 版本中(如果你使用的是舊版本請更新),自定義一個(gè)層的方法參考這里。具體地,你只要實(shí)現(xiàn)三個(gè)方法即可。
build(input_shape) : 這是你定義層參數(shù)的地方。這個(gè)方法必須設(shè)self.built = True,可以通過調(diào)用super([Layer], self).build()完成。如果這個(gè)層沒有需要訓(xùn)練的參數(shù),可以不定義。
call(x) : 這里是編寫層的功能邏輯的地方。你只需要關(guān)注傳入call的第一個(gè)參數(shù):輸入張量,除非你希望你的層支持masking。
compute_output_shape(input_shape) : 如果你的層更改了輸入張量的形狀,你應(yīng)該在這里定義形狀變化的邏輯,這讓Keras能夠自動推斷各層的形狀。
下面是一個(gè)簡單的例子:
from keras import backend as K from keras.engine.topology import Layer import numpy as np class MyLayer(Layer): def __init__(self, output_dim, **kwargs): self.output_dim = output_dim super(MyLayer, self).__init__(**kwargs) def build(self, input_shape): # Create a trainable weight variable for this layer. self.kernel = self.add_weight(name='kernel', shape=(input_shape[1], self.output_dim), initializer='uniform', trainable=True) super(MyLayer, self).build(input_shape) # Be sure to call this somewhere! def call(self, x): return K.dot(x, self.kernel) def compute_output_shape(self, input_shape): return (input_shape[0], self.output_dim)
Keras自定義層如何允許masking
觀察了一些支持masking的層,發(fā)現(xiàn)他們對masking的支持體現(xiàn)在兩方面。
在 __init__ 方法中設(shè)置 supports_masking=True。
實(shí)現(xiàn)一個(gè)compute_mask方法,用于將mask傳到下一層。
部分層會在call中調(diào)用傳入的mask。
自定義實(shí)現(xiàn)帶masking的meanpooling
假設(shè)輸入是3d的。首先,在__init__方法中設(shè)置self.supports_masking = True,然后在call中實(shí)現(xiàn)相應(yīng)的計(jì)算。
from keras import backend as K from keras.engine.topology import Layer import tensorflow as tf class MyMeanPool(Layer): def __init__(self, axis, **kwargs): self.supports_masking = True self.axis = axis super(MyMeanPool, self).__init__(**kwargs) def compute_mask(self, input, input_mask=None): # need not to pass the mask to next layers return None def call(self, x, mask=None): if mask is not None: mask = K.repeat(mask, x.shape[-1]) mask = tf.transpose(mask, [0,2,1]) mask = K.cast(mask, K.floatx()) x = x * mask return K.sum(x, axis=self.axis) / K.sum(mask, axis=self.axis) else: return K.mean(x, axis=self.axis) def compute_output_shape(self, input_shape): output_shape = [] for i in range(len(input_shape)): if i!=self.axis: output_shape.append(input_shape[i]) return tuple(output_shape)
使用舉例:
from keras.layers import Input, Masking from keras.models import Model from MyMeanPooling import MyMeanPool data = [[[10,10],[0, 0 ],[0, 0 ],[0, 0 ]], [[10,10],[20,20],[0, 0 ],[0, 0 ]], [[10,10],[20,20],[30,30],[0, 0 ]], [[10,10],[20,20],[30,30],[40,40]]] A = Input(shape=[4,2]) # None * 4 * 2 mA = Masking()(A) out = MyMeanPool(axis=1)(mA) model = Model(inputs=[A], outputs=[out]) print model.summary() print model.predict(data)
結(jié)果如下,每一行對應(yīng)一個(gè)樣本的結(jié)果,例如第一個(gè)樣本只有第一個(gè)時(shí)刻有值,輸出結(jié)果是[10. 10. ],是正確的。
[[10. 10.] [15. 15.] [20. 20.] [25. 25.]]
在DeepFM中,每個(gè)樣本都是由ID構(gòu)成的,多值field往往會導(dǎo)致樣本長度不一的情況,例如interest這樣的field,同一個(gè)樣本可能在該field中有多項(xiàng)取值,畢竟每個(gè)人的興趣點(diǎn)不止一項(xiàng)。
采取padding的方法將每個(gè)field的特征補(bǔ)長到最長的長度,則數(shù)據(jù)尺寸是 [batch_size, max_timestep],經(jīng)過Embedding為每個(gè)樣本的每個(gè)特征ID配一個(gè)latent vector,數(shù)據(jù)尺寸將變?yōu)?[batch_size, max_timestep,latent_dim]。
我們希望每一個(gè)field的Embedding之后的尺寸為[batch_size, latent_dim],然后進(jìn)行concat操作橫向拼接,所以這里就可以使用自定義的MeanPool層了。希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
學(xué)習(xí)createTrackbar的使用方法及步驟
這篇文章主要為大家介紹了學(xué)習(xí)createTrackbar的使用方法及步驟,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步2021-10-10
PyCharm+PySpark遠(yuǎn)程調(diào)試的環(huán)境配置的方法
今天小編就為大家分享一篇PyCharm+PySpark遠(yuǎn)程調(diào)試的環(huán)境配置的方法,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-11-11
python使用正則表達(dá)式分析網(wǎng)頁中的圖片并進(jìn)行替換的方法
這篇文章主要介紹了python使用正則表達(dá)式分析網(wǎng)頁中的圖片并進(jìn)行替換的方法,涉及Python使用正則表達(dá)式的技巧,具有一定參考借鑒價(jià)值,需要的朋友可以參考下2015-03-03
python通過pil將圖片轉(zhuǎn)換成黑白效果的方法
這篇文章主要介紹了python通過pil將圖片轉(zhuǎn)換成黑白效果的方法,實(shí)例分析了Python中pil庫的使用技巧,需要的朋友可以參考下2015-03-03
使用python數(shù)據(jù)清洗代碼實(shí)例
這篇文章主要介紹了使用python數(shù)據(jù)清洗代碼實(shí)例,分享一下近期用python做數(shù)據(jù)清洗匯總的相關(guān)代碼,這里我們用到的python包有pandas、numpy、os等,需要的朋友可以參考下2023-07-07
python pandas 時(shí)間日期的處理實(shí)現(xiàn)
這篇文章主要介紹了python pandas 時(shí)間日期的處理實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-07-07
pytorch 圖像中的數(shù)據(jù)預(yù)處理和批標(biāo)準(zhǔn)化實(shí)例
今天小編就為大家分享一篇pytorch 圖像中的數(shù)據(jù)預(yù)處理和批標(biāo)準(zhǔn)化實(shí)例,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-01-01

