淺談keras中Dropout在預(yù)測過程中是否仍要起作用
因?yàn)樾枰貙懹?xùn)練好的keras模型,雖然只具備預(yù)測功能,但是發(fā)現(xiàn)還是有很多坑要趟過。其中Dropout這個(gè)坑,我記憶猶新。
一開始,我以為預(yù)測時(shí)要保持和訓(xùn)練時(shí)完全一樣的網(wǎng)絡(luò)結(jié)構(gòu),也就是預(yù)測時(shí)用的網(wǎng)絡(luò)也是有丟棄的網(wǎng)絡(luò)節(jié)點(diǎn),但是這樣想就掉進(jìn)了一個(gè)大坑!因?yàn)闊o法通過已經(jīng)訓(xùn)練好的模型,來獲取其訓(xùn)練時(shí)隨機(jī)丟棄的網(wǎng)絡(luò)節(jié)點(diǎn)是那些,這本身就根本不可能。
更重要的是:我發(fā)現(xiàn)每一個(gè)迭代周期丟棄的神經(jīng)元也不完全一樣。
假若迭代500次,網(wǎng)絡(luò)共有1000個(gè)神經(jīng)元, 在第n(1<= n <500)個(gè)迭代周期內(nèi),從1000個(gè)神經(jīng)元里隨機(jī)丟棄了200個(gè)神經(jīng)元,在n+1個(gè)迭代周期內(nèi),會(huì)在這1000個(gè)神經(jīng)元里(不是在剩余得800個(gè))重新隨機(jī)丟棄200個(gè)神經(jīng)元。
訓(xùn)練過程中,使用Dropout,其實(shí)就是對(duì)部分權(quán)重和偏置在某次迭代訓(xùn)練過程中,不參與計(jì)算和更新而已,并不是不再使用這些權(quán)重和偏置了(預(yù)測時(shí),會(huì)使用全部的神經(jīng)元,包括使用訓(xùn)練時(shí)丟棄的神經(jīng)元)。
也就是說在預(yù)測過程中完全沒有Dropout什么事了,他只是在訓(xùn)練時(shí)有用,特別是針對(duì)訓(xùn)練集比較小時(shí)防止過擬合非常有用。
補(bǔ)充知識(shí):TensorFlow直接使用ckpt模型predict不用restore
我就廢話不多說了,大家還是直接看代碼吧~
# -*- coding: utf-8 -*-
# from util import *
import cv2
import numpy as np
import tensorflow as tf
# from tensorflow.python.framework import graph_util
import os
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
image_path = './8760.pgm'
input_checkpoint = './model/xu_spatial_model_1340.ckpt'
sess = tf.Session()
saver = tf.train.import_meta_graph(input_checkpoint + '.meta')
saver.restore(sess, input_checkpoint)
# input:0作為輸入圖像,keep_prob:0作為dropout的參數(shù),測試時(shí)值為1,is_training:0訓(xùn)練參數(shù)
input_image_tensor = sess.graph.get_tensor_by_name("coef_input:0")
is_training = sess.graph.get_tensor_by_name('is_training:0')
batch_size = sess.graph.get_tensor_by_name('batch_size:0')
# 定義輸出的張量名稱
output_tensor_name = sess.graph.get_tensor_by_name("xuNet/logits:0") # xuNet/Logits/logits
image = cv2.imread(image_path, 0)
# 讀取測試圖片
out = sess.run(output_tensor_name, feed_dict={input_image_tensor: np.reshape(image, (1, 512, 512, 1)),
is_training: False,
batch_size: 1})
print(out)
ckpt模型中的所有節(jié)點(diǎn)名稱,可以這樣查看
[n.name for n in tf.get_default_graph().as_graph_def().node]
以上這篇淺談keras中Dropout在預(yù)測過程中是否仍要起作用就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
pytest參數(shù)化:@pytest.mark.parametrize詳解
pytest.mark.parametrize裝飾器能夠?qū)y試函數(shù)進(jìn)行參數(shù)化處理,使得一個(gè)測試函數(shù)可以用多組數(shù)據(jù)執(zhí)行多次,這有助于檢查不同輸入下的期望輸出是否匹配,提高測試的效率和覆蓋率,裝飾器可以應(yīng)用于函數(shù)、模塊或類,支持多個(gè)裝飾器組合使用,增強(qiáng)測試的靈活性和綜合性2024-10-10
python使用Apriori算法進(jìn)行關(guān)聯(lián)性解析
這篇文章主要為大家分享了python使用Apriori算法進(jìn)行關(guān)聯(lián)性的解析,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2017-12-12
Python獲取系統(tǒng)所有進(jìn)程PID及進(jìn)程名稱的方法示例
這篇文章主要介紹了Python獲取系統(tǒng)所有進(jìn)程PID及進(jìn)程名稱的方法,涉及Python使用psutil對(duì)系統(tǒng)進(jìn)程進(jìn)行操作的相關(guān)實(shí)現(xiàn)技巧,需要的朋友可以參考下2018-05-05
詳解Django框架中用戶的登錄和退出的實(shí)現(xiàn)
這篇文章主要介紹了詳解Django框架中用戶的登錄和退出的實(shí)現(xiàn),Django是重多Python人氣框架中最為知名的一個(gè),需要的朋友可以參考下2015-07-07
python使用os.listdir和os.walk獲得文件的路徑的方法
本篇文章主要介紹了python使用os.listdir和os.walk獲得文件的路徑的方法,小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2017-12-12
python中數(shù)組array和列表list的基本用法及區(qū)別解析
大家都知道數(shù)組array是同類型數(shù)據(jù)的有限集合,列表list是一系列按特定順序排列的元素組成,可以將任何數(shù)據(jù)放入列表,且其中元素之間沒有任何關(guān)系,本文介紹python中數(shù)組array和列表list的基本用法及區(qū)別,感興趣的朋友一起看看吧2022-05-05

