對(duì)比分析BN和dropout在預(yù)測(cè)和訓(xùn)練時(shí)區(qū)別
Batch Normalization和Dropout是深度學(xué)習(xí)模型中常用的結(jié)構(gòu)。
但BN和dropout在訓(xùn)練和測(cè)試時(shí)使用卻不相同。
Batch Normalization
BN在訓(xùn)練時(shí)是在每個(gè)batch上計(jì)算均值和方差來進(jìn)行歸一化,每個(gè)batch的樣本量都不大,所以每次計(jì)算出來的均值和方差就存在差異。預(yù)測(cè)時(shí)一般傳入一個(gè)樣本,所以不存在歸一化,其次哪怕是預(yù)測(cè)一個(gè)batch,但batch計(jì)算出來的均值和方差是偏離總體樣本的,所以通常是通過滑動(dòng)平均結(jié)合訓(xùn)練時(shí)所有batch的均值和方差來得到一個(gè)總體均值和方差。
以tensorflow代碼實(shí)現(xiàn)為例:
def bn_layer(self, inputs, training, name='bn', moving_decay=0.9, eps=1e-5):
# 獲取輸入維度并判斷是否匹配卷積層(4)或者全連接層(2)
shape = inputs.shape
param_shape = shape[-1]
with tf.variable_scope(name):
# 聲明BN中唯一需要學(xué)習(xí)的兩個(gè)參數(shù),y=gamma*x+beta
gamma = tf.get_variable('gamma', param_shape, initializer=tf.constant_initializer(1))
beta = tf.get_variable('beat', param_shape, initializer=tf.constant_initializer(0))
# 計(jì)算當(dāng)前整個(gè)batch的均值與方差
axes = list(range(len(shape)-1))
batch_mean, batch_var = tf.nn.moments(inputs , axes, name='moments')
# 采用滑動(dòng)平均更新均值與方差
ema = tf.train.ExponentialMovingAverage(moving_decay, name="ema")
def mean_var_with_update():
ema_apply_op = ema.apply([batch_mean, batch_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean), tf.identity(batch_var)
# 訓(xùn)練時(shí),更新均值與方差,測(cè)試時(shí)使用之前最后一次保存的均值與方差
mean, var = tf.cond(tf.equal(training,True), mean_var_with_update,
lambda:(ema.average(batch_mean), ema.average(batch_var)))
# 最后執(zhí)行batch normalization
return tf.nn.batch_normalization(inputs ,mean, var, beta, gamma, eps)training參數(shù)可以通過tf.placeholder傳入,這樣就可以控制訓(xùn)練和預(yù)測(cè)時(shí)training的值。
self.training = tf.placeholder(tf.bool, name="training")
Dropout
Dropout在訓(xùn)練時(shí)會(huì)隨機(jī)丟棄一些神經(jīng)元,這樣會(huì)導(dǎo)致輸出的結(jié)果變小。而預(yù)測(cè)時(shí)往往關(guān)閉dropout,保證預(yù)測(cè)結(jié)果的一致性(不關(guān)閉dropout可能同一個(gè)輸入會(huì)得到不同的輸出,不過輸出會(huì)服從某一分布。另外有些情況下可以不關(guān)閉dropout,比如文本生成下,不關(guān)閉會(huì)增大輸出的多樣性)。
為了對(duì)齊Dropout訓(xùn)練和預(yù)測(cè)的結(jié)果,通常有兩種做法,假設(shè)dropout rate = 0.2。一種是訓(xùn)練時(shí)不做處理,預(yù)測(cè)時(shí)輸出乘以(1 - dropout rate)。另一種是訓(xùn)練時(shí)留下的神經(jīng)元除以(1 - dropout rate),預(yù)測(cè)時(shí)不做處理。以tensorflow為例。
x = tf.nn.dropout(x, self.keep_prob)
self.keep_prob = tf.placeholder(tf.float32, name="keep_prob")
tf.nn.dropout就是采用了第二種做法,訓(xùn)練時(shí)除以(1 - dropout rate),源碼如下:
binary_tensor = math_ops.floor(random_tensor) ret = math_ops.div(x, keep_prob) * binary_tensor if not context.executing_eagerly(): ret.set_shape(x.get_shape()) return ret
binary_tensor就是一個(gè)mask tensor,即里面的值由0或1組成。keep_prob = 1 - dropout rate。
以上就是對(duì)比分析BN和dropout在預(yù)測(cè)和訓(xùn)練時(shí)區(qū)別的詳細(xì)內(nèi)容,更多關(guān)于BN與dropout預(yù)測(cè)訓(xùn)練對(duì)比的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
Python中的XML庫(kù)4Suite Server的介紹
這篇文章主要介紹了Python中的XML庫(kù)4Suite Server,來自于IBM官方網(wǎng)站,需要的朋友可以參考下2015-04-04
簡(jiǎn)單實(shí)現(xiàn)Python爬取網(wǎng)絡(luò)圖片
這篇文章主要教大家如何簡(jiǎn)單實(shí)現(xiàn)Python爬取網(wǎng)絡(luò)圖片,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-04-04
WxPython開發(fā)之列表數(shù)據(jù)的自定義打印處理
這篇文章主要為大家詳細(xì)介紹了如何利用WxPython內(nèi)置的打印數(shù)據(jù)組件實(shí)現(xiàn)列表數(shù)據(jù)的自定義打印處理,以及對(duì)記錄進(jìn)行分頁(yè)等常規(guī)操作,需要的可以參考下2025-03-03
淺談Django QuerySet對(duì)象(模型.objects)的常用方法
這篇文章主要介紹了淺談Django QuerySet對(duì)象(模型.objects)的常用方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-03-03
Python學(xué)習(xí)筆記之字典,元組,布爾類型和讀寫文件
這篇文章主要為大家詳細(xì)介紹了Python的字典,元組,布爾類型和讀寫文件,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下,希望能夠給你帶來幫助2022-02-02

