TensorFlow——Checkpoint為模型添加檢查點的實例
1.檢查點
保存模型并不限于在訓練模型后,在訓練模型之中也需要保存,因為TensorFlow訓練模型時難免會出現(xiàn)中斷的情況,我們自然希望能夠將訓練得到的參數(shù)保存下來,否則下次又要重新訓練。
這種在訓練中保存模型,習慣上稱之為保存檢查點。
2.添加保存點
通過添加檢查點,可以生成載入檢查點文件,并能夠指定生成檢查文件的個數(shù),例如使用saver的另一個參數(shù)——max_to_keep=1,表明最多只保存一個檢查點文件,在保存時使用如下的代碼傳入迭代次數(shù)。
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5
plt.plot(train_x, train_y, 'r.')
plt.grid(True)
plt.show()
tf.reset_default_graph()
X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)
w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')
z = tf.multiply(X, w) + b
cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
init = tf.global_variables_initializer()
training_epochs = 20
display_step = 2
saver = tf.train.Saver(max_to_keep=15)
savedir = "model/"
if __name__ == '__main__':
with tf.Session() as sess:
sess.run(init)
loss_list = []
for epoch in range(training_epochs):
for (x, y) in zip(train_x, train_y):
sess.run(optimizer, feed_dict={X: x, Y: y})
if epoch % display_step == 0:
loss = sess.run(cost, feed_dict={X: x, Y: y})
loss_list.append(loss)
print('Iter: ', epoch, ' Loss: ', loss)
w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
saver.save(sess, savedir + "linear.cpkt", global_step=epoch)
print(" Finished ")
print("W: ", w_, " b: ", b_, " loss: ", loss)
plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
plt.grid(True)
plt.show()
load_epoch = 10
with tf.Session() as sess2:
sess2.run(tf.global_variables_initializer())
saver.restore(sess2, savedir + "linear.cpkt-" + str(load_epoch))
print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))
在上述的代碼中,我們使用saver.save(sess, savedir + "linear.cpkt", global_step=epoch)將訓練的參數(shù)傳入檢查點進行保存,saver = tf.train.Saver(max_to_keep=1)表示只保存一個文件,這樣在訓練過程中得到的新的模型就會覆蓋以前的模型。
cpkt = tf.train.get_checkpoint_state(savedir) if cpkt and cpkt.model_checkpoint_path: saver.restore(sess2, cpkt.model_checkpoint_path) kpt = tf.train.latest_checkpoint(savedir) saver.restore(sess2, kpt)
上述的兩種方法也可以對checkpoint文件進行加載,tf.train.latest_checkpoint(savedir)為加載最后的檢查點文件。這種方式,我們可以通過保存指定訓練次數(shù)的檢查點,比如保存5的倍數(shù)次保存一下檢查點。
3.簡便保存檢查點
我們還可以用更加簡單的方法進行檢查點的保存,tf.train.MonitoredTrainingSession()函數(shù),該函數(shù)可以直接實現(xiàn)保存載入檢查點模型的文件,與前面的方法不同的是,它是按照訓練時間來保存檢查點的,可以通過指定save_checkpoint_secs參數(shù)的具體秒數(shù),設置多久保存一次檢查點。
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5
# plt.plot(train_x, train_y, 'r.')
# plt.grid(True)
# plt.show()
tf.reset_default_graph()
X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)
w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')
z = tf.multiply(X, w) + b
cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
init = tf.global_variables_initializer()
training_epochs = 30
display_step = 2
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
saver = tf.train.Saver()
savedir = "check-point/"
if __name__ == '__main__':
with tf.train.MonitoredTrainingSession(checkpoint_dir=savedir + 'linear.cpkt', save_checkpoint_secs=5) as sess:
sess.run(init)
loss_list = []
for epoch in range(training_epochs):
sess.run(global_step)
for (x, y) in zip(train_x, train_y):
sess.run(optimizer, feed_dict={X: x, Y: y})
if epoch % display_step == 0:
loss = sess.run(cost, feed_dict={X: x, Y: y})
loss_list.append(loss)
print('Iter: ', epoch, ' Loss: ', loss)
w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
sess.run(step)
print(" Finished ")
print("W: ", w_, " b: ", b_, " loss: ", loss)
plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
plt.grid(True)
plt.show()
load_epoch = 10
with tf.Session() as sess2:
sess2.run(tf.global_variables_initializer())
# saver.restore(sess2, savedir + 'linear.cpkt-' + str(load_epoch))
# cpkt = tf.train.get_checkpoint_state(savedir)
# if cpkt and cpkt.model_checkpoint_path:
# saver.restore(sess2, cpkt.model_checkpoint_path)
#
kpt = tf.train.latest_checkpoint(savedir + 'linear.cpkt')
saver.restore(sess2, kpt)
print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))
上述的代碼中,我們設置了沒訓練了5秒中之后,就保存一次檢查點,它默認的保存時間間隔是10分鐘,這種按照時間的保存模式更適合使用大型數(shù)據(jù)集訓練復雜模型的情況,注意在使用上述的方法時,要定義global_step變量,在訓練完一個批次或者一個樣本之后,要將其進行加1的操作,否則將會報錯。

以上這篇TensorFlow——Checkpoint為模型添加檢查點的實例就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
Python socket實現(xiàn)的文件下載器功能示例
這篇文章主要介紹了Python socket實現(xiàn)的文件下載器功能,結合實例形式分析了Python使用socket模塊實現(xiàn)的文件下載器客戶端與服務器端相關操作技巧,需要的朋友可以參考下2019-11-11
淺談selenium如何應對網(wǎng)頁內(nèi)容需要鼠標滾動加載的問題
這篇文章主要介紹了淺談selenium如何應對網(wǎng)頁內(nèi)容需要鼠標滾動加載的問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-03-03
python 統(tǒng)計數(shù)組中元素出現(xiàn)次數(shù)并進行排序的實例
今天小編就為大家分享一篇python 統(tǒng)計數(shù)組中元素出現(xiàn)次數(shù)并進行排序的實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-07-07
Python調(diào)用olmOCR大模型實現(xiàn)提取復雜PDF文件內(nèi)容
olmocr是由Allen人工智能研究所(AI2)開發(fā)的一個開源工具包,旨在高效地將PDF和其他文檔轉換為結構化的純文本,同時保持自然閱讀順序,下面我們來看看如何使用olmOCR大模型實現(xiàn)提取復雜PDF文件內(nèi)容吧2025-03-03
Appium+python自動化之連接模擬器并啟動淘寶APP(超詳解)
這篇文章主要介紹了Appium+python自動化之 連接模擬器并啟動淘寶APP(超詳解)本文以淘寶app為例,通過實例代碼給大家介紹的非常詳細,需要的朋友可以參考下2019-06-06

