將TensorFlow的模型網(wǎng)絡(luò)導(dǎo)出為單個(gè)文件的方法
有時(shí)候,我們需要將TensorFlow的模型導(dǎo)出為單個(gè)文件(同時(shí)包含模型架構(gòu)定義與權(quán)重),方便在其他地方使用(如在c++中部署網(wǎng)絡(luò))。利用tf.train.write_graph()默認(rèn)情況下只導(dǎo)出了網(wǎng)絡(luò)的定義(沒(méi)有權(quán)重),而利用tf.train.Saver().save()導(dǎo)出的文件graph_def與權(quán)重是分離的,因此需要采用別的方法。
我們知道,graph_def文件中沒(méi)有包含網(wǎng)絡(luò)中的Variable值(通常情況存儲(chǔ)了權(quán)重),但是卻包含了constant值,所以如果我們能把Variable轉(zhuǎn)換為constant,即可達(dá)到使用一個(gè)文件同時(shí)存儲(chǔ)網(wǎng)絡(luò)架構(gòu)與權(quán)重的目標(biāo)。
我們可以采用以下方式凍結(jié)權(quán)重并保存網(wǎng)絡(luò):
import tensorflow as tf from tensorflow.python.framework.graph_util import convert_variables_to_constants # 構(gòu)造網(wǎng)絡(luò) a = tf.Variable([[3],[4]], dtype=tf.float32, name='a') b = tf.Variable(4, dtype=tf.float32, name='b') # 一定要給輸出tensor取一個(gè)名字?。? output = tf.add(a, b, name='out') # 轉(zhuǎn)換Variable為constant,并將網(wǎng)絡(luò)寫(xiě)入到文件 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 這里需要填入輸出tensor的名字 graph = convert_variables_to_constants(sess, sess.graph_def, ["out"]) tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
當(dāng)恢復(fù)網(wǎng)絡(luò)時(shí),可以使用如下方式:
import tensorflow as tf
with tf.Session() as sess:
with open('./graph.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
output = tf.import_graph_def(graph_def, return_elements=['out:0'])
print(sess.run(output))
輸出結(jié)果為:
[array([[ 7.],
[ 8.]], dtype=float32)]
可以看到之前的權(quán)重確實(shí)保存了下來(lái)!!
問(wèn)題來(lái)了,我們的網(wǎng)絡(luò)需要能有一個(gè)輸入自定義數(shù)據(jù)的接口?。〔蝗贿@玩意有什么用。。別急,當(dāng)然有辦法。
import tensorflow as tf from tensorflow.python.framework.graph_util import convert_variables_to_constants a = tf.Variable([[3],[4]], dtype=tf.float32, name='a') b = tf.Variable(4, dtype=tf.float32, name='b') input_tensor = tf.placeholder(tf.float32, name='input') output = tf.add((a+b), input_tensor, name='out') with tf.Session() as sess: sess.run(tf.global_variables_initializer()) graph = convert_variables_to_constants(sess, sess.graph_def, ["out"]) tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
用上述代碼重新保存網(wǎng)絡(luò)至graph.pb,這次我們有了一個(gè)輸入placeholder,下面來(lái)看看怎么恢復(fù)網(wǎng)絡(luò)并輸入自定義數(shù)據(jù)。
import tensorflow as tf
with tf.Session() as sess:
with open('./graph.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a')
print(sess.run(output))
輸出結(jié)果為:
[array([[ 11.],
[ 12.]], dtype=float32)]
可以看到結(jié)果沒(méi)有問(wèn)題,當(dāng)然在input_map那里可以替換為新的自定義的placeholder,如下所示:
import tensorflow as tf
new_input = tf.placeholder(tf.float32, shape=())
with tf.Session() as sess:
with open('./graph.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a')
print(sess.run(output, feed_dict={new_input:4}))
看看輸出,同樣沒(méi)有問(wèn)題。
[array([[ 11.],
[ 12.]], dtype=float32)]
另外需要說(shuō)明的一點(diǎn)是,在利用tf.train.write_graph寫(xiě)網(wǎng)絡(luò)架構(gòu)的時(shí)候,如果令as_text=True了,則在導(dǎo)入網(wǎng)絡(luò)的時(shí)候,需要做一點(diǎn)小修改。
import tensorflow as tf
from google.protobuf import text_format
with tf.Session() as sess:
# 不使用'rb'模式
with open('./graph.pb', 'r') as f:
graph_def = tf.GraphDef()
# 不使用graph_def.ParseFromString(f.read())
text_format.Merge(f.read(), graph_def)
output = tf.import_graph_def(graph_def, return_elements=['out:0'])
print(sess.run(output))
參考資料
Is there an example on how to generate protobuf files holding trained Tensorflow graphs
以上就是本文的全部?jī)?nèi)容,希望對(duì)大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
- TensorFlow模型保存/載入的兩種方法
- python使用tensorflow保存、加載和使用模型的方法
- 淺談Tensorflow模型的保存與恢復(fù)加載
- TensorFlow模型保存和提取的方法
- 利用TensorFlow訓(xùn)練簡(jiǎn)單的二分類(lèi)神經(jīng)網(wǎng)絡(luò)模型的方法
- TensorFlow入門(mén)使用 tf.train.Saver()保存模型
- TensorFlow 模型載入方法匯總(小結(jié))
- TensorFlow實(shí)現(xiàn)Softmax回歸模型
- TensorFlow實(shí)現(xiàn)MLP多層感知機(jī)模型
- TensorFlow實(shí)現(xiàn)模型評(píng)估
相關(guān)文章
如何將numpy二維數(shù)組中的np.nan值替換為指定的值
這篇文章主要介紹了將numpy二維數(shù)組中的np.nan值替換為指定的值操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2021-05-05
Python編寫(xiě)一個(gè)趣味問(wèn)答小游戲
隨著六一兒童節(jié)的到來(lái),我們可以為孩子們編寫(xiě)一個(gè)有趣的小游戲,讓他們?cè)谟螒蛑袑W(xué)習(xí)有關(guān)六一兒童節(jié)的知識(shí)。本文將介紹如何用Python編寫(xiě)一個(gè)六一兒童節(jié)問(wèn)答小游戲及趣味比賽,需要的可以參考一下2023-06-06
Python文件基本操作open函數(shù)應(yīng)用與示例詳解
這篇文章主要為大家介紹了Python文件基本操作open函數(shù)應(yīng)用與示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-12-12
Pytorch中關(guān)于nn.Conv2d()參數(shù)的使用
這篇文章主要介紹了Pytorch中關(guān)于nn.Conv2d()參數(shù)的使用,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-06-06
使用python實(shí)現(xiàn)下載我們想聽(tīng)的歌曲,速度超快
這篇文章主要介紹了使用python實(shí)現(xiàn)下載我們想聽(tīng)的歌曲,速度超快,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-07-07

