Keras模型轉(zhuǎn)成tensorflow的.pb操作
Keras的.h5模型轉(zhuǎn)成tensorflow的.pb格式模型,方便后期的前端部署。直接上代碼
from keras.models import Model
from keras.layers import Dense, Dropout
from keras.applications.mobilenet import MobileNet
from keras.applications.mobilenet import preprocess_input
from keras.preprocessing.image import load_img, img_to_array
import tensorflow as tf
from keras import backend as K
import os
base_model = MobileNet((None, None, 3), alpha=1, include_top=False, pooling='avg', weights=None)
x = Dropout(0.75)(base_model.output)
x = Dense(10, activation='softmax')(x)
model = Model(base_model.input, x)
model.load_weights('mobilenet_weights.h5')
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
from tensorflow.python.framework.graph_util import convert_variables_to_constants
graph = session.graph
with graph.as_default():
freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.global_variables()]
input_graph_def = graph.as_graph_def()
if clear_devices:
for node in input_graph_def.node:
node.device = ""
frozen_graph = convert_variables_to_constants(session, input_graph_def,
output_names, freeze_var_names)
return frozen_graph
output_graph_name = 'NIMA.pb'
output_fld = ''
#K.set_learning_phase(0)
print('input is :', model.input.name)
print ('output is:', model.output.name)
sess = K.get_session()
frozen_graph = freeze_session(K.get_session(), output_names=[model.output.op.name])
from tensorflow.python.framework import graph_io
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)
print('saved the constant graph (ready for inference) at: ', os.path.join(output_fld, output_graph_name))
補(bǔ)充知識(shí):keras h5 model 轉(zhuǎn)換為tflite
在移動(dòng)端的模型,若選擇tensorflow或者keras最基本的就是生成tflite文件,以本文記錄一次轉(zhuǎn)換過(guò)程。
環(huán)境
tensorflow 1.12.0
python 3.6.5
h5 model saved by `model.save('tf.h5')`
直接轉(zhuǎn)換
`tflite_convert --output_file=tf.tflite --keras_model_file=tf.h5` output `TypeError: __init__() missing 2 required positional arguments: 'filters' and 'kernel_size'`
先轉(zhuǎn)成pb再轉(zhuǎn)tflite
``` git clone git@github.com:amir-abdi/keras_to_tensorflow.git cd keras_to_tensorflow python keras_to_tensorflow.py --input_model=path/to/tf.h5 --output_model=path/to/tf.pb tflite_convert \ --output_file=tf.tflite \ --graph_def_file=tf.pb \ --input_arrays=convolution2d_1_input \ --output_arrays=dense_3/BiasAdd \ --input_shape=1,3,448,448 ```
參數(shù)說(shuō)明,input_arrays和output_arrays是model的起始輸入變量名和結(jié)束變量名,input_shape是和input_arrays對(duì)應(yīng)
官網(wǎng)是說(shuō)需要用到tenorboard來(lái)查看,一個(gè)比較trick的方法
先執(zhí)行上面的命令,會(huì)報(bào)convolution2d_1_input找不到,在堆棧里面有convert_saved_model.py文件,get_tensors_from_tensor_names()這個(gè)方法,添加`print(list(tensor_name_to_tensor))` 到 tensor_name_to_tensor 這個(gè)變量下面,再執(zhí)行一遍,會(huì)打印出所有tensor的名字,再根據(jù)自己的模型很容易就能判斷出實(shí)際的name。
以上這篇Keras模型轉(zhuǎn)成tensorflow的.pb操作就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python values()與itervalues()的用法詳解
今天小編就為大家分享一篇Python values()與itervalues()的用法詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-11-11
pytorch中的numel函數(shù)用法說(shuō)明
這篇文章主要介紹了pytorch中的numel函數(shù)用法說(shuō)明,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2021-05-05
在pycharm中輸入import torch報(bào)錯(cuò)如何解決
這篇文章主要介紹了在pycharm中輸入import torch報(bào)錯(cuò)如何解決問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2024-01-01
解決py2exe打包后,總是多顯示一個(gè)DOS黑色窗口的問(wèn)題
今天小編就為大家分享一篇解決py2exe打包后,總是多顯示一個(gè)DOS黑色窗口的問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-06-06
Python version 2.7 required, which was not found in the regi
這篇文章主要介紹了安裝PIL庫(kù)時(shí)提示錯(cuò)誤Python version 2.7 required, which was not found in the registry問(wèn)題的解決方法,需要的朋友可以參考下2014-08-08
Python爬取視頻時(shí)長(zhǎng)場(chǎng)景實(shí)踐示例
這篇文章主要為大家介紹了Python獲取視頻時(shí)長(zhǎng)場(chǎng)景實(shí)踐示例,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-07-07

