Tensorflow 使用pb文件保存(恢復(fù))模型計算圖和參數(shù)實例詳解
一、保存:
graph_util.convert_variables_to_constants 可以把當(dāng)前session的計算圖串行化成一個字節(jié)流(二進(jìn)制),這個函數(shù)包含三個參數(shù):參數(shù)1:當(dāng)前活動的session,它含有各變量
參數(shù)2:GraphDef 對象,它描述了計算網(wǎng)絡(luò)
參數(shù)3:Graph圖中需要輸出的節(jié)點的名稱的列表
返回值:精簡版的GraphDef 對象,包含了原始輸入GraphDef和session的網(wǎng)絡(luò)和變量信息,它的成員函數(shù)SerializeToString()可以把這些信息串行化為字節(jié)流,然后寫入文件里:
constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] ) with open( pbName, mode='wb') as f: f.write(constant_graph.SerializeToString())
需要指出的是,如果原始張量(包含在參數(shù)1和參數(shù)2中的組成部分)不參與參數(shù)3指定的輸出節(jié)點列表所指定的張量計算的話,這些張量將不會存在返回的GraphDef對象里,也不會被串行化寫入pb文件。
二、恢復(fù):
恢復(fù)時,創(chuàng)建一個GraphDef,然后從上述的文件里加載進(jìn)來,接著輸入到當(dāng)前的session:
graph0 = tf.GraphDef()
with open( pbName, mode='rb') as f:
graph0.ParseFromString( f.read() )
tf.import_graph_def( graph0 , name = '' )
三、代碼:
import tensorflow as tf
from tensorflow.python.framework import graph_util
pbName = 'graphA.pb'
def graphCreate() :
with tf.Session() as sess :
var1 = tf.placeholder ( tf.int32 , name='var1' )
var2 = tf.Variable( 20 , name='var2' )#實參name='var2'指定了操作名,該操作返回的張量名是在
#'var2'后面:0 ,即var2:0 是返回的張量名,也就是說變量
# var2的名稱是'var2:0'
var3 = tf.Variable( 30 , name='var3' )
var4 = tf.Variable( 40 , name='var4' )
var4op = tf.assign( var4 , 1000 , name = 'var4op1' )
sum = tf.Variable( 4, name='sum' )
sum = tf.add ( var1 , var2, name = 'var1_var2' )
sum = tf.add( sum , var3 , name='sum_var3' )
sumOps = tf.add( sum , var4 , name='sum_operation' )
oper = tf.get_default_graph().get_operations()
with open( 'operation.csv','wt' ) as f:
s = 'name,type,output\n'
f.write( s )
for o in oper:
s = o.name
s += ','+ o.type
inp = o.inputs
oup = o.outputs
for iip in inp :
s #s += ','+ str(iip)
for iop in oup :
s += ',' + str(iop)
s += '\n'
f.write( s )
for var in tf.global_variables():
print('variable=> ' , var.name) #張量是tf.Variable/tf.Add之類操作的結(jié)果,
#張量的名字使用操作名加:0來表示
init = tf.global_variables_initializer()
sess.run( init )
sess.run( var4op )
print('sum_operation result is Tensor ' , sess.run( sumOps , feed_dict={var1:1}) )
constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
with open( pbName, mode='wb') as f:
f.write(constant_graph.SerializeToString())
def graphGet() :
print("start get:" )
with tf.Graph().as_default():
graph0 = tf.GraphDef()
with open( pbName, mode='rb') as f:
graph0.ParseFromString( f.read() )
tf.import_graph_def( graph0 , name = '' )
with tf.Session() as sess :
init = tf.global_variables_initializer()
sess.run(init)
v1 = sess.graph.get_tensor_by_name('var1:0' )
v2 = sess.graph.get_tensor_by_name('var2:0' )
v3 = sess.graph.get_tensor_by_name('var3:0' )
v4 = sess.graph.get_tensor_by_name('var4:0' )
sumTensor = sess.graph.get_tensor_by_name("sum_operation:0")
print('sumTensor is : ' , sumTensor )
print( sess.run( sumTensor , feed_dict={v1:1} ) )
graphCreate()
graphGet()
四、保存pb函數(shù)代碼里的操作名稱/類型/返回的張量:
| operation name | operation type | output | ||
| var1 | Placeholder | Tensor("var1:0" | dtype=int32) | |
| var2/initial_value | Const | Tensor("var2/initial_value:0" | shape=() | dtype=int32) |
| var2 | VariableV2 | Tensor("var2:0" | shape=() | dtype=int32_ref) |
| var2/Assign | Assign | Tensor("var2/Assign:0" | shape=() | dtype=int32_ref) |
| var2/read | Identity | Tensor("var2/read:0" | shape=() | dtype=int32) |
| var3/initial_value | Const | Tensor("var3/initial_value:0" | shape=() | dtype=int32) |
| var3 | VariableV2 | Tensor("var3:0" | shape=() | dtype=int32_ref) |
| var3/Assign | Assign | Tensor("var3/Assign:0" | shape=() | dtype=int32_ref) |
| var3/read | Identity | Tensor("var3/read:0" | shape=() | dtype=int32) |
| var4/initial_value | Const | Tensor("var4/initial_value:0" | shape=() | dtype=int32) |
| var4 | VariableV2 | Tensor("var4:0" | shape=() | dtype=int32_ref) |
| var4/Assign | Assign | Tensor("var4/Assign:0" | shape=() | dtype=int32_ref) |
| var4/read | Identity | Tensor("var4/read:0" | shape=() | dtype=int32) |
| var4op1/value | Const | Tensor("var4op1/value:0" | shape=() | dtype=int32) |
| var4op1 | Assign | Tensor("var4op1:0" | shape=() | dtype=int32_ref) |
| sum/initial_value | Const | Tensor("sum/initial_value:0" | shape=() | dtype=int32) |
| sum | VariableV2 | Tensor("sum:0" | shape=() | dtype=int32_ref) |
| sum/Assign | Assign | Tensor("sum/Assign:0" | shape=() | dtype=int32_ref) |
| sum/read | Identity | Tensor("sum/read:0" | shape=() | dtype=int32) |
| var1_var2 | Add | Tensor("var1_var2:0" | dtype=int32) | |
| sum_var3 | Add | Tensor("sum_var3:0" | dtype=int32) | |
| sum_operation | Add | Tensor("sum_operation:0" | dtype=int32) |
以上這篇Tensorflow 使用pb文件保存(恢復(fù))模型計算圖和參數(shù)實例詳解就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
python按照行來讀取txt文件全部內(nèi)容(去除空行處理掉\t,\n后以列表方式返回)
這篇文章主要介紹了python按照行來讀取txt文件全部內(nèi)容 ,去除空行,處理掉\t,\n后,以列表方式返回,本文通過實例代碼給大家介紹的非常詳細(xì),需要的朋友可以參考下2023-06-06
Python基于多線程操作數(shù)據(jù)庫相關(guān)問題分析
這篇文章主要介紹了Python基于多線程操作數(shù)據(jù)庫相關(guān)問題,結(jié)合實例形式分析了Python使用數(shù)據(jù)庫連接池并發(fā)操作數(shù)據(jù)庫避免超時、連接丟失相關(guān)實現(xiàn)技巧,需要的朋友可以參考下2018-07-07
使用 PyTorch-BigGraph 構(gòu)建和部署大規(guī)模圖嵌入的完整步驟
本文深入探討了使用 PyTorch-BigGraph (PBG) 構(gòu)建和部署大規(guī)模圖嵌入的完整流程,涵蓋了從環(huán)境設(shè)置、數(shù)據(jù)準(zhǔn)備、模型配置與訓(xùn)練,到高級優(yōu)化技術(shù)、評估指標(biāo)、部署策略以及實際案例研究等各個方面,感興趣的朋友跟隨小編一起看看吧2024-11-11
Python+PyQt5+MySQL實現(xiàn)天氣管理系統(tǒng)
這篇文章主要為大家詳細(xì)介紹了Python+PyQt5+MySQL實現(xiàn)天氣管理系統(tǒng),文中示例代碼介紹的非常詳細(xì),具有一定的參考價值,感興趣的小伙伴們可以參考一下2020-06-06
python GUI庫圖形界面開發(fā)之PyQt5 UI主線程與耗時線程分離詳細(xì)方法實例
這篇文章主要介紹了python GUI庫圖形界面開發(fā)之PyQt5 UI主線程與耗時線程分離詳細(xì)方法實例,需要的朋友可以參考下2020-02-02
Python創(chuàng)建普通菜單示例【基于win32ui模塊】
這篇文章主要介紹了Python創(chuàng)建普通菜單,結(jié)合實例形式分析了Python基于win32ui模塊創(chuàng)建普通菜單及添加菜單項的相關(guān)操作技巧,并附帶說明了win32ui模塊的安裝命令,需要的朋友可以參考下2018-05-05

