tensorflow 2.0模式下訓(xùn)練的模型轉(zhuǎn)成 tf1.x 版本的pb模型實(shí)例
升級(jí)到tf 2.0后, 訓(xùn)練的模型想轉(zhuǎn)成1.x版本的.pb模型, 但之前提供的通過ckpt轉(zhuǎn)pb模型的方法都不可用(因?yàn)楸4娴腸kpt不再有.meta)文件, 嘗試了好久, 終于找到了一個(gè)方法可以迂回轉(zhuǎn)到1.x版本的pb模型.
Note: 本方法首先有些要求需要滿足:
可以拿的到模型的網(wǎng)絡(luò)結(jié)構(gòu)定義源碼
網(wǎng)絡(luò)結(jié)構(gòu)里面的所有操作都是通過tf.keras完成的, 不能出現(xiàn)類似tf.nn 的tensorflow自己的操作符
tf2.0下保存的模型是.h5格式的,并且僅保存了weights, 即通過model.save_weights保存的模型.
在tf1.x的環(huán)境下, 將tf2.0保存的weights轉(zhuǎn)為pb模型:
如果在tf2.0下保存的模型符合上述的三個(gè)定義, 那么這個(gè).h5文件在1.x環(huán)境下其實(shí)是可以直接用的, 因?yàn)槎际峭ㄟ^tf.keras高級(jí)封裝了,2.0版本和1.x版本不存在特別大的區(qū)別,我自己的模型是可以直接用的.
import tensorflow as tf
import os
from nets.efficientNet import *
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
# 這個(gè)代碼網(wǎng)上說需要加上, 如果模型里有dropout , bn層的話, 我測(cè)試過加不加結(jié)果都一樣, 保險(xiǎn)起見還是加上吧
tf.keras.backend.set_learning_phase(0)
# 首先是定義你的模型, 這個(gè)需要和tf2.0下一毛一樣
inputs = tf.keras.Input(shape=(224, 224, 3), name='modelInput')
outputs = yourModel(inputs, training=False)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.load_weights('save_weights.h5')
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
"""
Freezes the state of a session into a pruned computation graph.
Creates a new computation graph where variable nodes are replaced by
constants taking their current value in the session. The new graph will be
pruned so subgraphs that are not necessary to compute the requested
outputs are removed.
@param session The TensorFlow session to be frozen.
@param keep_var_names A list of variable names that should not be frozen,
or None to freeze all the variables in the graph.
@param output_names Names of the relevant graph outputs.
@param clear_devices Remove the device directives from the graph for better portability.
@return The frozen graph definition.
"""
from tensorflow.python.framework.graph_util import convert_variables_to_constants
graph = session.graph
with graph.as_default():
freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.global_variables()]
# Graph -> GraphDef ProtoBuf
input_graph_def = graph.as_graph_def(add_shapes=True)
if clear_devices:
for node in input_graph_def.node:
node.device = ""
frozen_graph = convert_variables_to_constants(session, input_graph_def,
output_names, freeze_var_names)
return frozen_graph
frozen_graph = freeze_session(tf.keras.backend.get_session(), output_names=[out.op.name for out in model.outputs])
tf.train.write_graph(frozen_graph, "model", "tf_model.pb", as_text=False)
運(yùn)行成功后, 會(huì)在當(dāng)前目錄下生成一個(gè)model文件夾, 里面有生成的tf_model.pb文件, 至此, 我們就完成了將tf2.0下訓(xùn)練的模型轉(zhuǎn)到tf1.x下的pb模型, 這樣,就可以用這個(gè)pb模型做其它推理或者轉(zhuǎn)tvm ncnn等模型轉(zhuǎn)換工作.
這個(gè)轉(zhuǎn)換的重點(diǎn)就是通過keras這個(gè)中間商來完成, 所以我們定義的模型就必須要滿足這個(gè)中間商定義的條件
補(bǔ)充知識(shí):tensorflow2.0降級(jí)及如何從別的版本升到2.0
代碼實(shí)踐《tensorflow實(shí)戰(zhàn)GOOGLE深度學(xué)習(xí)框架》時(shí),由于本機(jī)安裝的tensorflow為2.0版本與配套書籍代碼1.4的API不兼容,只得將tensorflow降級(jí)為1.4.0版本使用,降級(jí)方法如下
1 pip uninstall tensorflow

2 pip install tensorflow==1.14.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

驗(yàn)證
import tensorflow as tf
print(tf.version)

二 從別的版本升級(jí)到2.0
自動(dòng)卸載與其相關(guān)包
pip uninstall tensorflow
安裝某版本
pip install --no-cache-dir tensorflow==x.xx (此處填寫2.0)

驗(yàn)證

以上這篇tensorflow 2.0模式下訓(xùn)練的模型轉(zhuǎn)成 tf1.x 版本的pb模型實(shí)例就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
pandas數(shù)據(jù)處理清洗實(shí)現(xiàn)中文地址拆分案例
因?yàn)楹罄m(xù)數(shù)據(jù)分析工作需要用到地理維度進(jìn)行分析,所以需要把login_place字段進(jìn)行拆分成:國家、省份、地區(qū)。感興趣的可以了解一下2021-06-06
Linux下將Python的Django項(xiàng)目部署到Apache服務(wù)器
這篇文章主要介紹了Python的Django項(xiàng)目部署到Apache服務(wù)器上的要點(diǎn)總結(jié),文中針對(duì)的是wsgi連接方式,需要的朋友可以參考下2015-12-12
python網(wǎng)絡(luò)編程之讀取網(wǎng)站根目錄實(shí)例
這篇文章主要介紹了python網(wǎng)絡(luò)編程之讀取網(wǎng)站根目錄實(shí)例,以quux.org站根目錄為例進(jìn)行了實(shí)例分析,代碼簡(jiǎn)單易懂,需要的朋友可以參考下2014-09-09
python獲取響應(yīng)某個(gè)字段值的3種實(shí)現(xiàn)方法
這篇文章主要介紹了python獲取響應(yīng)某個(gè)字段值的3種實(shí)現(xiàn)方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-04-04
Python容器使用的5個(gè)技巧和2個(gè)誤區(qū)總結(jié)
在本篇文章里小編給大家整理的是關(guān)于Python容器使用的5個(gè)技巧和2個(gè)誤區(qū)的相關(guān)知識(shí)點(diǎn)內(nèi)容,需要的朋友們學(xué)習(xí)下。2019-09-09
使用python flask框架開發(fā)圖片上傳接口的案例詳解
剛領(lǐng)導(dǎo)安排任務(wù),需求是這樣的開發(fā)一個(gè)支持多格式圖片上傳的接口,并且將圖片壓縮,支持在線預(yù)覽圖片,下面小編分享下使用python flask框架開發(fā)圖片上傳接口的案例詳解,感興趣的朋友一起看看吧2022-04-04
Pandas在數(shù)據(jù)分析和機(jī)器學(xué)習(xí)中的應(yīng)用及優(yōu)勢(shì)
Pandas是Python中用于數(shù)據(jù)處理和數(shù)據(jù)分析的庫,它提供了靈活的數(shù)據(jù)結(jié)構(gòu)和數(shù)據(jù)操作工具,包括Series和DataFrame等。Pandas還支持大量數(shù)據(jù)操作和數(shù)據(jù)分析功能,包括數(shù)據(jù)清洗、轉(zhuǎn)換、篩選、聚合、透視表、時(shí)間序列分析等2023-04-04
python 爬取古詩文存入mysql數(shù)據(jù)庫的方法
這篇文章主要介紹了python 爬取古詩文存入mysql數(shù)據(jù)庫的方法,本文通過實(shí)例代碼給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-01-01

