Tensorflow之Saver的用法詳解
Saver的用法
1. Saver的背景介紹
我們經(jīng)常在訓(xùn)練完一個(gè)模型之后希望保存訓(xùn)練的結(jié)果,這些結(jié)果指的是模型的參數(shù),以便下次迭代的訓(xùn)練或者用作測(cè)試。Tensorflow針對(duì)這一需求提供了Saver類。
Saver類提供了向checkpoints文件保存和從checkpoints文件中恢復(fù)變量的相關(guān)方法。Checkpoints文件是一個(gè)二進(jìn)制文件,它把變量名映射到對(duì)應(yīng)的tensor值 。
只要提供一個(gè)計(jì)數(shù)器,當(dāng)計(jì)數(shù)器觸發(fā)時(shí),Saver類可以自動(dòng)的生成checkpoint文件。這讓我們可以在訓(xùn)練過(guò)程中保存多個(gè)中間結(jié)果。例如,我們可以保存每一步訓(xùn)練的結(jié)果。
為了避免填滿整個(gè)磁盤,Saver可以自動(dòng)的管理Checkpoints文件。例如,我們可以指定保存最近的N個(gè)Checkpoints文件。
2. Saver的實(shí)例
下面以一個(gè)例子來(lái)講述如何使用Saver類
import tensorflow as tf
import numpy as np
x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4
w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b
loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
isTrain = False
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
if isTrain:
for i in xrange(train_steps):
sess.run(train, feed_dict={x: x_data})
if (i + 1) % checkpoint_steps == 0:
saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
else:
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
pass
print(sess.run(w))
print(sess.run(b))
- isTrain:用來(lái)區(qū)分訓(xùn)練階段和測(cè)試階段,True表示訓(xùn)練,F(xiàn)alse表示測(cè)試
- train_steps:表示訓(xùn)練的次數(shù),例子中使用100
- checkpoint_steps:表示訓(xùn)練多少次保存一下checkpoints,例子中使用50
- checkpoint_dir:表示checkpoints文件的保存路徑,例子中使用當(dāng)前路徑
2.1 訓(xùn)練階段
使用Saver.save()方法保存模型:
- sess:表示當(dāng)前會(huì)話,當(dāng)前會(huì)話記錄了當(dāng)前的變量值
- checkpoint_dir + 'model.ckpt':表示存儲(chǔ)的文件名
- global_step:表示當(dāng)前是第幾步
訓(xùn)練完成后,當(dāng)前目錄底下會(huì)多出5個(gè)文件。

打開(kāi)名為“checkpoint”的文件,可以看到保存記錄,和最新的模型存儲(chǔ)位置。

2.1測(cè)試階段
測(cè)試階段使用saver.restore()方法恢復(fù)變量:
sess:表示當(dāng)前會(huì)話,之前保存的結(jié)果將被加載入這個(gè)會(huì)話
ckpt.model_checkpoint_path:表示模型存儲(chǔ)的位置,不需要提供模型的名字,它會(huì)去查看checkpoint文件,看看最新的是誰(shuí),叫做什么。
運(yùn)行結(jié)果如下圖所示,加載了之前訓(xùn)練的參數(shù)w和b的結(jié)果

以上就是本文的全部?jī)?nèi)容,希望對(duì)大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
numpy多級(jí)排序lexsort函數(shù)的使用
本文主要介紹了numpy多級(jí)排序lexsort函數(shù)的使用,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2023-03-03
python自動(dòng)獲取微信公眾號(hào)最新文章的實(shí)現(xiàn)代碼
這篇文章主要介紹了python自動(dòng)獲取微信公眾號(hào)最新文章,本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2022-07-07
Python一句代碼實(shí)現(xiàn)找出所有水仙花數(shù)的方法
今天小編就為大家分享一篇Python一句代碼實(shí)現(xiàn)找出所有水仙花數(shù)的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-11-11
Django應(yīng)用程序中如何發(fā)送電子郵件詳解
我們常常會(huì)用到一些發(fā)送郵件的功能,比如有人提交了應(yīng)聘的表單,可以向HR的郵箱發(fā)郵件,這樣,HR不看網(wǎng)站就可以知道有人在網(wǎng)站上提交了應(yīng)聘信息。下面這篇文章就介紹了在Django應(yīng)用程序中如何發(fā)送電子郵件的相關(guān)資料,需要的朋友可以參考借鑒。2017-02-02
python制作小說(shuō)爬蟲(chóng)實(shí)錄
本文給大家介紹的是作者所寫(xiě)的第一個(gè)爬蟲(chóng)程序的全過(guò)程,從構(gòu)思到思路到程序的編寫(xiě),非常的細(xì)致,有需要的小伙伴可以參考下2017-08-08
python使用xlrd實(shí)現(xiàn)檢索excel中某列含有指定字符串記錄的方法
這篇文章主要介紹了python使用xlrd實(shí)現(xiàn)檢索excel中某列含有指定字符串記錄的方法,涉及Python使用xlrd模塊檢索Excel的技巧,非常具有實(shí)用價(jià)值,需要的朋友可以參考下2015-05-05
Python?ArcPy實(shí)現(xiàn)批量拼接長(zhǎng)時(shí)間序列柵格圖像
這篇文章主要介紹了如何基于Python中ArcPy模塊,對(duì)大量不同時(shí)相的柵格遙感影像按照其成像時(shí)間依次執(zhí)行批量拼接的方法,感興趣的可以了解一下2023-03-03
深入淺析Python 函數(shù)注解與匿名函數(shù)
這篇文章主要介紹了Python 函數(shù)注解與匿名函數(shù)的相關(guān)知識(shí),本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-02-02

