tensorflow模型文件(ckpt)轉(zhuǎn)pb文件的方法(不知道輸出節(jié)點(diǎn)名)
網(wǎng)上關(guān)于tensorflow模型文件ckpt格式轉(zhuǎn)pb文件的帖子很多,本人幾乎嘗試了所有方法,最后終于成功了,現(xiàn)總結(jié)如下。方法無(wú)外乎下面兩種:
- 使用tensorflow.python.tools.freeze_graph.freeze_graph
- 使用graph_util.convert_variables_to_constants
1、tensorflow模型的文件解讀
使用tensorflow訓(xùn)練好的模型會(huì)自動(dòng)保存為四個(gè)文件,如下

checkpoint:記錄近幾次訓(xùn)練好的模型結(jié)果(名稱)。
xxx.data-00000-of-00001: 模型的所有變量的值(weights, biases, placeholders,gradients, hyper-parameters etc),也就是模型訓(xùn)練好參數(shù)和其他值。
xxx.index :模型的元數(shù)據(jù),二進(jìn)制或者其他格式,不可直接查看 。是一個(gè)不可變得字符串表,每一個(gè)鍵都是張量的名稱,它的值是一個(gè)序列化的BundleEntryProto。 每個(gè)BundleEntryProto描述張量的元數(shù)據(jù):“數(shù)據(jù)”文件中的哪個(gè)文件包含張量的內(nèi)容,該文件的偏移量,校驗(yàn)和一些輔助數(shù)據(jù)等。
xxx.meta:模型的meta數(shù)據(jù) ,二進(jìn)制或者其他格式,不可直接查看,保存了TensorFlow計(jì)算圖的結(jié)構(gòu)信息,通俗地講就是神經(jīng)網(wǎng)絡(luò)的網(wǎng)絡(luò)結(jié)構(gòu)。
2、最常見(jiàn)的ckpt轉(zhuǎn)pb文件的方法
2、ckpt轉(zhuǎn)pb文件(freeze_graph.freeze_graph)
此種方法嘗試成功,雖然不知道輸出節(jié)點(diǎn)名,但是只要模型代碼還在就可以操作,直接上代碼。
import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph
from model import network # network是你們自己定義的模型結(jié)構(gòu)(代碼結(jié)構(gòu))
# egs:
# def network(input):
# return tf.layers.softmax(input)
model_path = "model.ckpt-0000" #設(shè)置model的路徑,因新版tensorflow會(huì)生成三個(gè)文件,只需寫到數(shù)字前
def main():
tf.reset_default_graph()
# 設(shè)置輸入網(wǎng)絡(luò)的數(shù)據(jù)維度,根據(jù)訓(xùn)練時(shí)的模型輸入數(shù)據(jù)的維度自行修改
input_node = tf.placeholder(tf.float32, shape=(None, None, 200))
output_node = network(input_node) # 神經(jīng)網(wǎng)絡(luò)的輸出
# 設(shè)置輸出數(shù)據(jù)類型(特別注意,這里必須要跟輸出網(wǎng)絡(luò)參數(shù)的數(shù)據(jù)格式保持一致,不然會(huì)導(dǎo)致模型預(yù)測(cè) 精度或者預(yù)測(cè)能力的丟失)以及重新定義輸出節(jié)點(diǎn)的名字(這樣在后面保存pb文件以及之后使用pb文件時(shí)直接使用重新定義的節(jié)點(diǎn)名字即可)
flow = tf.cast(output_node , tf.float16, 'the_outputs')
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, model_path)
#保存模型圖(結(jié)構(gòu)),為一個(gè)json文件
tf.train.write_graph(sess.graph_def, 'output_model/pb_model', 'model.pb')
#將模型參數(shù)與模型圖結(jié)合,并保存為pb文件
freeze_graph.freeze_graph('output_model/pb_model/model.pb', '', False, model_path, 'the_outputs','save/restore_all', 'save/Const:0', 'output_model/pb_model/frozen_model.pb', False, "")
print("done")
if __name__ == '__main__':
main()
2、ckpt轉(zhuǎn)pb文件(graph_util.convert_variables_to_constants)
沒(méi)有成功,因?yàn)椴恢垒敵龉?jié)點(diǎn)的名字,使用該方法保存后的pb文件只有幾十k,無(wú)法使用,寫在這里主要是為了總結(jié)。直接上代碼,代碼里面沒(méi)有的庫(kù)(函數(shù)),按提示自行import。
def freeze_graph(input_checkpoint,output_graph):
'''
:param input_checkpoint:
:param output_graph: PB模型保存路徑
:return:
'''
# checkpoint = tf.train.get_checkpoint_state(model_folder) #檢查目錄下ckpt文件狀態(tài)是否可用
# input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路徑
# 指定輸出的節(jié)點(diǎn)名稱,該節(jié)點(diǎn)名稱必須是原模型中存在的節(jié)點(diǎn)
output_node_names = "InceptionV3/Logits/SpatialSqueeze"
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
graph = tf.get_default_graph() # 獲得默認(rèn)的圖
input_graph_def = graph.as_graph_def() # 返回一個(gè)序列化的圖代表當(dāng)前的圖
with tf.Session() as sess:
saver.restore(sess, input_checkpoint) #恢復(fù)圖并得到數(shù)據(jù)
output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,將變量值固定
sess=sess,
input_graph_def=input_graph_def,# 等于:sess.graph_def
output_node_names=output_node_names.split(","))# 如果有多個(gè)輸出節(jié)點(diǎn),以逗號(hào)隔開(kāi)
with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
f.write(output_graph_def.SerializeToString()) #序列化輸出
print("%d ops in the final graph." % len(output_graph_def.node)) #得到當(dāng)前圖有幾個(gè)操作節(jié)點(diǎn)
# for op in graph.get_operations():
# print(op.name, op.values())
if __name__ == '__main__':
# 輸入ckpt模型路徑
input_checkpoint='models/model.ckpt-10000'
# 輸出pb模型的路徑
out_pb_path="models/pb/frozen_model.pb"
# 調(diào)用freeze_graph將ckpt轉(zhuǎn)為pb
freeze_graph(input_checkpoint,out_pb_path)
參考鏈接:
http://www.dhdzp.com/article/185209.htm
http://www.dhdzp.com/article/185206.htm
到此這篇關(guān)于tensorflow模型文件(ckpt)轉(zhuǎn)pb文件(不知道輸出節(jié)點(diǎn)名)的文章就介紹到這了,更多相關(guān)tensorflow ckpt轉(zhuǎn)pb文件內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
利用Python實(shí)現(xiàn)K-Means聚類的方法實(shí)例(案例:用戶分類)
k-means是發(fā)現(xiàn)給定數(shù)據(jù)集的k個(gè)簇的算法,也就是將數(shù)據(jù)集聚合為k類的算法,下面這篇文章主要給大家介紹了關(guān)于利用Python實(shí)現(xiàn)K-Means聚類的相關(guān)資料,需要的朋友可以參考下2022-05-05
Django中實(shí)現(xiàn)點(diǎn)擊圖片鏈接強(qiáng)制直接下載的方法
這篇文章主要介紹了Django中實(shí)現(xiàn)點(diǎn)擊圖片鏈接強(qiáng)制直接下載的方法,涉及Python操作圖片的相關(guān)技巧,非常具有實(shí)用價(jià)值,需要的朋友可以參考下2015-05-05
python高效過(guò)濾出文件夾下指定文件名結(jié)尾的文件實(shí)例
今天小編就為大家分享一篇python高效過(guò)濾出文件夾下指定文件名結(jié)尾的文件實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-10-10
使用python 將圖片復(fù)制到系統(tǒng)剪貼中
今天小編就為大家分享一篇使用python 將圖片復(fù)制到系統(tǒng)剪貼中,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-12-12
Python中導(dǎo)入csv數(shù)據(jù)文件的詳細(xì)示例教程
Python中的csv模塊是一種用于讀取和寫入csv文件的模塊,csv可以用于將數(shù)據(jù)從文件或者其他來(lái)源導(dǎo)入到Python中進(jìn)行分析和處理,在這篇文章中,我們將全面介紹Python中如何導(dǎo)入csv文件,并將從多個(gè)方面進(jìn)行詳細(xì)探討,感興趣的朋友一起看看吧2024-03-03
手把手教你使用Django + Vue.js 快速構(gòu)建項(xiàng)目
本篇將基于Django + Vue.js,手把手教大家快速的實(shí)現(xiàn)一個(gè)前后端分離的Web項(xiàng)目。文中通過(guò)示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2021-08-08

