Tensorflow實(shí)現(xiàn)在訓(xùn)練好的模型上進(jìn)行測(cè)試
Tensorflow可以使用訓(xùn)練好的模型對(duì)新的數(shù)據(jù)進(jìn)行測(cè)試,有兩種方法:第一種方法是調(diào)用模型和訓(xùn)練在同一個(gè)py文件中,中情況比較簡(jiǎn)單;第二種是訓(xùn)練過(guò)程和調(diào)用模型過(guò)程分別在兩個(gè)py文件中。本文將講解第二種方法。
模型的保存
tensorflow提供可保存訓(xùn)練模型的接口,使用起來(lái)也不是很難,直接上代碼講解:
#網(wǎng)絡(luò)結(jié)構(gòu)
w1 = tf.Variable(tf.truncated_normal([in_units, h1_units], stddev=0.1))
b1 = tf.Variable(tf.zeros([h1_units]))
y = tf.nn.softmax(tf.matmul(w1, x) + b1)
tf.add_to_collection('network-output', y)
x = tf.placeholder(tf.float32, [None, in_units], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')
#損失函數(shù)與優(yōu)化函數(shù)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(rate).minimize(cross_entropy)
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
saver.save(sess,"save/model.ckpt")
train_step.run({x: train_x, y_: train_y})
以上代碼就完成了模型的保存,值得注意的是下面這行代碼
tf.add_to_collection('network-output', y)
這行代碼保存了神經(jīng)網(wǎng)絡(luò)的輸出,這個(gè)在后面使用導(dǎo)入模型過(guò)程中起到關(guān)鍵作用。
模型的導(dǎo)入
模型訓(xùn)練并保存后就可以導(dǎo)入來(lái)評(píng)估模型在測(cè)試集上的表現(xiàn),網(wǎng)上很多文章只用簡(jiǎn)單的四則運(yùn)算來(lái)做例子,讓人看的頭大。還是先上代碼:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('./model.ckpt.meta')
saver.restore(sess, './model.ckpt')# .data文件
pred = tf.get_collection('network-output')[0]
graph = tf.get_default_graph()
x = graph.get_operation_by_name('x').outputs[0]
y_ = graph.get_operation_by_name('y_').outputs[0]
y = sess.run(pred, feed_dict={x: test_x, y_: test_y})
講解一下關(guān)鍵的代碼,首先是pred = tf.get_collection('pred_network')[0],這行代碼獲得訓(xùn)練過(guò)程中網(wǎng)絡(luò)輸出的“接口”,簡(jiǎn)單理解就是,通過(guò)tf.get_collection() 這個(gè)方法獲取了整個(gè)網(wǎng)絡(luò)結(jié)構(gòu)。獲得網(wǎng)絡(luò)結(jié)構(gòu)后我們就需要喂它對(duì)應(yīng)的數(shù)據(jù)y = sess.run(pred, feed_dict={x: test_x, y_: test_y}) 在訓(xùn)練過(guò)程中我們的輸入是
x = tf.placeholder(tf.float32, [None, in_units], name='x') y_ = tf.placeholder(tf.float32, [None, 10], name='y_')
因此導(dǎo)入模型后所需的輸入也要與之對(duì)應(yīng)可使用以下代碼獲得:
x = graph.get_operation_by_name('x').outputs[0]
y_ = graph.get_operation_by_name('y_').outputs[0]
使用模型的最后一步就是輸入測(cè)試集,然后按照訓(xùn)練好的網(wǎng)絡(luò)進(jìn)行評(píng)估
sess.run(pred, feed_dict={x: test_x, y_: test_y})
理解下這行代碼,sess.run() 的函數(shù)原型為
run(fetches, feed_dict=None, options=None, run_metadata=None)
Tensorflow對(duì) feed_dict 執(zhí)行fetches操作,因此在導(dǎo)入模型后的運(yùn)算就是,按照訓(xùn)練的網(wǎng)絡(luò)計(jì)算測(cè)試輸入的數(shù)據(jù)。
以上這篇Tensorflow實(shí)現(xiàn)在訓(xùn)練好的模型上進(jìn)行測(cè)試就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Django異步任務(wù)線程池實(shí)現(xiàn)原理
這篇文章主要介紹了Django異步任務(wù)線程池實(shí)現(xiàn)原理,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-12-12
Python yield生成器和return對(duì)比代碼實(shí)例
這篇文章主要介紹了Python yield生成器和return對(duì)比代碼實(shí)例,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-04-04
Python 3.8正式發(fā)布,來(lái)嘗鮮這些新特性吧
今天 Python3.8 發(fā)布啦,它是 Python2 終結(jié)前最后一個(gè)大版本,我們一起看看這個(gè)版本都添加了那些新功能和特性2019-10-10
Python實(shí)現(xiàn)ElGamal加密算法的示例代碼
ElGamal加密算法是一個(gè)基于迪菲-赫爾曼密鑰交換的非對(duì)稱加密算法。這篇文章通過(guò)示例代碼給大家介紹Python實(shí)現(xiàn)ElGamal加密算法的相關(guān)知識(shí),感興趣的朋友一起看看吧2020-06-06
Centos7 Python3下安裝scrapy的詳細(xì)步驟
這篇文章主要介紹了Centos7 Python3下安裝scrapy的詳細(xì)步驟,小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2018-03-03
Django websocket原理及功能實(shí)現(xiàn)代碼
這篇文章主要介紹了Django websocket原理及功能實(shí)現(xiàn)代碼,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-11-11
Python制作簡(jiǎn)單的網(wǎng)頁(yè)爬蟲
自己寫的一個(gè)爬蟲,模仿了python核心編程書里的程序,有詳細(xì)的注釋。 是我一個(gè)理解學(xué)習(xí)的過(guò)程吧。 有需要的小伙伴可以參考下2015-11-11

