python使用tensorflow保存、加載和使用模型的方法
使用Tensorflow進(jìn)行深度學(xué)習(xí)訓(xùn)練的時(shí)候,需要對(duì)訓(xùn)練好的網(wǎng)絡(luò)模型和各種參數(shù)進(jìn)行保存,以便在此基礎(chǔ)上繼續(xù)訓(xùn)練或者使用。介紹這方面的博客有很多,我發(fā)現(xiàn)寫(xiě)的最好的是這一篇官方英文介紹:
http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/
我對(duì)這篇文章進(jìn)行了整理和匯總。
首先是模型的保存。直接上代碼:
#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut1_save.py
#Author: Wang
#Mail: wang19920419@hotmail.com
#Created Time:2017-08-30 11:04:25
############################
import tensorflow as tf
# prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2')
b1 = tf.Variable(2.0, name = 'bias1')
feed_dict = {w1:[10,3], w2:[5,5]}
# define a test operation that will be restored
w3 = tf.add(w1, w2) # without name, w3 will not be stored
w4 = tf.multiply(w3, b1, name = "op_to_restore")
#saver = tf.train.Saver()
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print sess.run(w4, feed_dict)
#saver.save(sess, 'my_test_model', global_step = 100)
saver.save(sess, 'my_test_model')
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)
需要說(shuō)明的有以下幾點(diǎn):
1. 創(chuàng)建saver的時(shí)候可以指明要存儲(chǔ)的tensor,如果不指明,就會(huì)全部存下來(lái)。在這里也可以指明最大存儲(chǔ)數(shù)量和checkpoint的記錄時(shí)間。具體細(xì)節(jié)看英文博客。
2. saver.save()函數(shù)里面可以設(shè)定global_step和write_meta_graph,meta存儲(chǔ)的是網(wǎng)絡(luò)結(jié)構(gòu),只在開(kāi)始運(yùn)行程序的時(shí)候存儲(chǔ)一次即可,后續(xù)可以通過(guò)設(shè)置write_meta_graph = False加以限制。
3. 這個(gè)程序執(zhí)行結(jié)束后,會(huì)在程序目錄下生成四個(gè)文件,分別是.meta(存儲(chǔ)網(wǎng)絡(luò)結(jié)構(gòu))、.data和.index(存儲(chǔ)訓(xùn)練好的參數(shù))、checkpoint(記錄最新的模型)。
下面是如何加載已經(jīng)保存的網(wǎng)絡(luò)模型。這里有兩種方法,第一種是saver.restore(sess, 'aaaa.ckpt'),這種方法的本質(zhì)是讀取全部參數(shù),并加載到已經(jīng)定義好的網(wǎng)絡(luò)結(jié)構(gòu)上,因此相當(dāng)于給網(wǎng)絡(luò)的weights和biases賦值并執(zhí)行tf.global_variables_initializer()。這種方法的缺點(diǎn)是使用前必須重寫(xiě)網(wǎng)絡(luò)結(jié)構(gòu),而且網(wǎng)絡(luò)結(jié)構(gòu)要和保存的參數(shù)完全對(duì)上。第二種就比較高端了,直接把網(wǎng)絡(luò)結(jié)構(gòu)加載進(jìn)來(lái)(.meta),上代碼:
#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut2_import.py
#Author: Wang
#Mail: wang19920419@hotmail.com
#Created Time:2017-08-30 14:16:38
############################
import tensorflow as tf
sess = tf.Session()
new_saver = tf.train.import_meta_graph('my_test_model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
print sess.run('w1:0')
使用加載的模型,輸入新數(shù)據(jù),計(jì)算輸出,還是直接上代碼:
#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut3_reuse.py
#Author: Wang
#Mail: wang19920419@hotmail.com
#Created Time:2017-08-30 14:33:35
############################
import tensorflow as tf
sess = tf.Session()
# First, load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
# Second, access and create placeholders variables and create feed_dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name('w1:0')
w2 = graph.get_tensor_by_name('w2:0')
feed_dict = {w1:[-1,1], w2:[4,6]}
# Access the op that want to run
op_to_restore = graph.get_tensor_by_name('op_to_restore:0')
print sess.run(op_to_restore, feed_dict) # ouotput: [6. 14.]
在已經(jīng)加載的網(wǎng)絡(luò)后繼續(xù)加入新的網(wǎng)絡(luò)層:
import tensorflow as tf
sess=tf.Session()
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}
#Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
#Add more to the current graph
add_on_op = tf.multiply(op_to_restore,2)
print sess.run(add_on_op,feed_dict)
#This will print 120.
對(duì)加載的網(wǎng)絡(luò)進(jìn)行局部修改和處理(這個(gè)最麻煩,我還沒(méi)搞太明白,后續(xù)會(huì)繼續(xù)補(bǔ)充):
......
......
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning
#Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0')
#use this if you only want to change gradients of the last layer
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list()
new_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)
# Now, you run this with fine-tuning data in sess.run()
有了這樣的方法,無(wú)論是自行訓(xùn)練、加載模型繼續(xù)訓(xùn)練、使用經(jīng)典模型還是finetune經(jīng)典模型抑或是加載網(wǎng)絡(luò)跑前項(xiàng),效果都是杠杠的。
以上就是本文的全部?jī)?nèi)容,希望對(duì)大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
詳解如何在Python中有效調(diào)用JavaScript
JavaScript和Python都是極為流行的編程語(yǔ)言,并在前端開(kāi)發(fā)和后端開(kāi)發(fā)領(lǐng)域扮演著重要的角色,那么Python如何更好的契合JavaScript呢,下面就跟隨小編一起學(xué)習(xí)一下吧2024-02-02
python 實(shí)現(xiàn)添加標(biāo)簽&打標(biāo)簽的操作
這篇文章主要介紹了python 實(shí)現(xiàn)添加標(biāo)簽&打標(biāo)簽的操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2021-05-05
Python功能點(diǎn)實(shí)現(xiàn):函數(shù)級(jí)/代碼塊級(jí)計(jì)時(shí)器
今天小編就為大家分享一篇關(guān)于Python功能點(diǎn)實(shí)現(xiàn):函數(shù)級(jí)/代碼塊級(jí)計(jì)時(shí)器,小編覺(jué)得內(nèi)容挺不錯(cuò)的,現(xiàn)在分享給大家,具有很好的參考價(jià)值,需要的朋友一起跟隨小編來(lái)看看吧2019-01-01
五個(gè)Pandas?實(shí)戰(zhàn)案例帶你分析操作數(shù)據(jù)
pandas是基于NumPy的一種工具,該工具是為了解決數(shù)據(jù)分析任務(wù)而創(chuàng)建的。Pandas納入了大量庫(kù)和一些標(biāo)準(zhǔn)的數(shù)據(jù)模型,提供了高效操作大型數(shù)據(jù)集的工具。pandas提供大量快速便捷地處理數(shù)據(jù)的函數(shù)和方法。你很快就會(huì)發(fā)現(xiàn),它是使Python強(qiáng)大而高效的數(shù)據(jù)分析環(huán)境的重要因素之一2022-01-01
對(duì)python 生成拼接xml報(bào)文的示例詳解
今天小編就為大家分享一篇對(duì)python 生成拼接xml報(bào)文的示例詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-12-12
python 讀取txt,json和hdf5文件的實(shí)例
今天小編就為大家分享一篇python 讀取txt,json和hdf5文件的實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-06-06
Python實(shí)現(xiàn)修改IE注冊(cè)表功能示例
這篇文章主要介紹了Python實(shí)現(xiàn)修改IE注冊(cè)表功能,結(jié)合完整實(shí)例形式分析了Python操作IE注冊(cè)表項(xiàng)的相關(guān)實(shí)現(xiàn)技巧與注意事項(xiàng),需要的朋友可以參考下2018-05-05
基于python不同開(kāi)根號(hào)的速度對(duì)比分析
這篇文章主要介紹了基于python不同開(kāi)根號(hào)的速度對(duì)比分析,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2021-03-03

