將keras的h5模型轉(zhuǎn)換為tensorflow的pb模型操作
背景:目前keras框架使用簡(jiǎn)單,很容易上手,深得廣大算法工程師的喜愛,但是當(dāng)部署到客戶端時(shí),可能會(huì)出現(xiàn)各種各樣的bug,甚至不支持使用keras,本文來解決的是將keras的h5模型轉(zhuǎn)換為客戶端常用的tensorflow的pb模型并使用tensorflow加載pb模型。
h5_to_pb.py
from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K
#路徑參數(shù)
input_path = 'input path'
weight_file = 'weight.h5'
weight_file_path = osp.join(input_path,weight_file)
output_graph_name = weight_file[:-3] + '.pb'
#轉(zhuǎn)換函數(shù)
def h5_to_pb(h5_model,output_dir,model_name,out_prefix = "output_",log_tensorboard = True):
if osp.exists(output_dir) == False:
os.mkdir(output_dir)
out_nodes = []
for i in range(len(h5_model.outputs)):
out_nodes.append(out_prefix + str(i + 1))
tf.identity(h5_model.output[i],out_prefix + str(i + 1))
sess = K.get_session()
from tensorflow.python.framework import graph_util,graph_io
init_graph = sess.graph.as_graph_def()
main_graph = graph_util.convert_variables_to_constants(sess,init_graph,out_nodes)
graph_io.write_graph(main_graph,output_dir,name = model_name,as_text = False)
if log_tensorboard:
from tensorflow.python.tools import import_pb_to_tensorboard
import_pb_to_tensorboard.import_to_tensorboard(osp.join(output_dir,model_name),output_dir)
#輸出路徑
output_dir = osp.join(os.getcwd(),"trans_model")
#加載模型
h5_model = load_model(weight_file_path)
h5_to_pb(h5_model,output_dir = output_dir,model_name = output_graph_name)
print('model saved')
將轉(zhuǎn)換成的pb模型進(jìn)行加載
load_pb.py
import tensorflow as tf
from tensorflow.python.platform import gfile
def load_pb(pb_file_path):
sess = tf.Session()
with gfile.FastGFile(pb_file_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
print(sess.run('b:0'))
#輸入
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')
#輸出
op = sess.graph.get_tensor_by_name('op_to_store:0')
#預(yù)測(cè)結(jié)果
ret = sess.run(op, {input_x: 3, input_y: 4})
print(ret)
補(bǔ)充知識(shí):h5模型轉(zhuǎn)化為pb模型,代碼及排坑
我是在實(shí)際工程中要用到tensorflow訓(xùn)練的pb模型,但是訓(xùn)練的代碼是用keras寫的,所以生成keras特定的h5模型,所以用到了h5_to_pb.py函數(shù)。
附上h5_to_pb.py(python3)
#*-coding:utf-8-*
"""
將keras的.h5的模型文件,轉(zhuǎn)換成TensorFlow的pb文件
"""
# ==========================================================
from keras.models import load_model
import tensorflow as tf
import os.path as osp
import os
from keras import backend
#from keras.models import Sequential
def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
""".h5模型文件轉(zhuǎn)換成pb模型文件
Argument:
h5_model: str
.h5模型文件
output_dir: str
pb模型文件保存路徑
model_name: str
pb模型文件名稱
out_prefix: str
根據(jù)訓(xùn)練,需要修改
log_tensorboard: bool
是否生成日志文件
Return:
pb模型文件
"""
if os.path.exists(output_dir) == False:
os.mkdir(output_dir)
out_nodes = []
for i in range(len(h5_model.outputs)):
out_nodes.append(out_prefix + str(i + 1))
tf.identity(h5_model.output[i], out_prefix + str(i + 1))
sess = backend.get_session()
from tensorflow.python.framework import graph_util, graph_io
# 寫入pb模型文件
init_graph = sess.graph.as_graph_def()
main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
# 輸出日志文件
if log_tensorboard:
from tensorflow.python.tools import import_pb_to_tensorboard
import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)
if __name__ == '__main__':
# .h模型文件路徑參數(shù)
input_path = 'D:/CSP'
weight_file = 'xingren.h5'
weight_file_path = os.path.join(input_path, weight_file)
output_graph_name = weight_file[:-3] + '.pb'
# pb模型文件輸出輸出路徑
output_dir = osp.join(os.getcwd(),"trans_model")
#model.save(xingren.h5)
# 加載模型
#h5_model = Sequential()
h5_model = load_model(weight_file_path)
#h5_model.save(weight_file_path)
#h5_model.save('xingren.h5')
h5_to_pb(h5_model, output_dir=output_dir, model_name=output_graph_name)
print ('Finished')
在運(yùn)行的時(shí)候遇到了下面問題:

原因:我們訓(xùn)練模型的時(shí)候用save_weights函數(shù)保存模型,但是這個(gè)函數(shù)只保存了權(quán)重文件,并沒有又保存模型的參數(shù)。要把save_weights改為save。
下邊是兩個(gè)函數(shù)介紹:
save()保存的模型結(jié)果,它既保持了模型的圖結(jié)構(gòu),又保存了模型的參數(shù)。
save_weights()保存的模型結(jié)果,它只保存了模型的參數(shù),但并沒有保存模型的圖結(jié)構(gòu)
以上這篇將keras的h5模型轉(zhuǎn)換為tensorflow的pb模型操作就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
django settings.py配置文件的詳細(xì)介紹
本文主要介紹了django settings.py配置文件的詳細(xì)介紹,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2022-04-04
jupyter notebook運(yùn)行代碼沒反應(yīng)且in[ ]沒有*
本文主要介紹了jupyter notebook運(yùn)行代碼沒反應(yīng)且in[ ]沒有*,文中通過示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2022-03-03
Python算法練習(xí)之二分查找算法的實(shí)現(xiàn)
二分查找也稱折半查找(Binary Search),它是一種效率較高的查找方法。本文將介紹python如何實(shí)現(xiàn)二分查找算法,幫助大家更好的理解和使用python,感興趣的朋友可以了解下2022-06-06
Python實(shí)現(xiàn)打磚塊小游戲代碼實(shí)例
這篇文章主要介紹了Python打磚塊小游戲,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-05-05
Python爬蟲實(shí)戰(zhàn)之12306搶票開源
今天小編就為大家分享一篇關(guān)于Python爬蟲實(shí)戰(zhàn)之12306搶票開源,小編覺得內(nèi)容挺不錯(cuò)的,現(xiàn)在分享給大家,具有很好的參考價(jià)值,需要的朋友一起跟隨小編來看看吧2019-01-01
Python?sklearn庫(kù)中的隨機(jī)森林模型詳解
本文主要說明?Python?的?sklearn?庫(kù)中的隨機(jī)森林模型的常用接口、屬性以及參數(shù)調(diào)優(yōu)說明,需要讀者或多或少了解過sklearn庫(kù)和一些基本的機(jī)器學(xué)習(xí)知識(shí)2023-08-08
Python NumPy中的隨機(jī)數(shù)及ufuncs函數(shù)使用示例詳解
這篇文章主要介紹了Python NumPy中的隨機(jī)數(shù)及ufuncs函數(shù)使用,ufunc函數(shù)是NumPy中的一種通用函數(shù),它可以對(duì)數(shù)組中的每個(gè)元素進(jìn)行操作,而不需要使用循環(huán)語(yǔ)句,文中通過示例代碼介紹的非常詳細(xì),需要的朋友們下面隨著小編來一起學(xué)習(xí)吧2023-05-05

