tensorflow模型保存、加載之變量重命名實(shí)例
話(huà)不多說(shuō),干就完了。
變量重命名的用處?
簡(jiǎn)單定義:簡(jiǎn)單來(lái)說(shuō)就是將模型A中的參數(shù)parameter_A賦給模型B中的parameter_B
使用場(chǎng)景:當(dāng)需要使用已經(jīng)訓(xùn)練好的模型參數(shù),尤其是使用別人訓(xùn)練好的模型參數(shù)時(shí),往往別人模型中的參數(shù)命名方式與自己當(dāng)前的命名方式不同,所以在加載模型參數(shù)時(shí)需要對(duì)參數(shù)進(jìn)行重命名,使得代碼更簡(jiǎn)潔易懂。
實(shí)現(xiàn)方法:
1)、模型保存
import os
import tensorflow as tf
weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
mean=0.0,
stddev=0.1),
dtype=tf.float32,
name="weights")
biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
dtype=tf.float32,
name="biases")
weights_2 = tf.Variable(initial_value=weights.initialized_value(),
dtype=tf.float32,
name="weights_2")
# saver checkpoint
if os.path.exists("checkpoints") is False:
os.makedirs("checkpoints")
saver = tf.train.Saver()
with tf.Session() as sess:
init_op = [tf.global_variables_initializer()]
sess.run(init_op)
saver.save(sess=sess, save_path="checkpoints/variable.ckpt")
2)、模型加載(變量名稱(chēng)保持不變)
import tensorflow as tf
from matplotlib import pyplot as plt
import os
current_path = os.path.dirname(os.path.abspath(__file__))
def restore_variable(sess):
# need not initilize variable, but need to define the same variable like checkpoint
weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
mean=0.0,
stddev=0.1),
dtype=tf.float32,
name="weights")
biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
dtype=tf.float32,
name="biases")
weights_2 = tf.Variable(initial_value=weights.initialized_value(),
dtype=tf.float32,
name="weights_2")
saver = tf.train.Saver()
ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
saver.restore(sess=sess, save_path=ckpt_path)
weights_val, weights_2_val = sess.run(
[
tf.reshape(weights, shape=[2048]),
tf.reshape(weights_2, shape=[2048])
]
)
plt.subplot(1, 2, 1)
plt.scatter([i for i in range(len(weights_val))], weights_val)
plt.subplot(1, 2, 2)
plt.scatter([i for i in range(len(weights_2_val))], weights_2_val)
plt.show()
if __name__ == '__main__':
with tf.Session() as sess:
restore_variable(sess)
3)、模型加載(變量重命名)
import tensorflow as tf
from matplotlib import pyplot as plt
import os
current_path = os.path.dirname(os.path.abspath(__file__))
def restore_variable_renamed(sess):
conv1_w = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
mean=0.0,
stddev=0.1),
dtype=tf.float32,
name="conv1_w")
conv1_b = tf.Variable(initial_value=tf.zeros(shape=[2]),
dtype=tf.float32,
name="conv1_b")
conv2_w = tf.Variable(initial_value=conv1_w.initialized_value(),
dtype=tf.float32,
name="conv2_w")
# variable named 'weights' in ckpt assigned to current variable conv1_w
# variable named 'biases' in ckpt assigned to current variable conv1_b
# variable named 'weights_2' in ckpt assigned to current variable conv2_w
saver = tf.train.Saver({
"weights": conv1_w,
"biases": conv1_b,
"weights_2": conv2_w
})
ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
saver.restore(sess=sess, save_path=ckpt_path)
conv1_w__val, conv2_w__val = sess.run(
[
tf.reshape(conv1_w, shape=[2048]),
tf.reshape(conv2_w, shape=[2048])
]
)
plt.subplot(1, 2, 1)
plt.scatter([i for i in range(len(conv1_w__val))], conv1_w__val)
plt.subplot(1, 2, 2)
plt.scatter([i for i in range(len(conv2_w__val))], conv2_w__val)
plt.show()
if __name__ == '__main__':
with tf.Session() as sess:
restore_variable_renamed(sess)
總結(jié):
# 之前模型中叫 'weights'的變量賦值給當(dāng)前的conv1_w變量
# 之前模型中叫 'biases' 的變量賦值給當(dāng)前的conv1_b變量
# 之前模型中叫 'weights_2'的變量賦值給當(dāng)前的conv2_w變量
saver = tf.train.Saver({
"weights": conv1_w,
"biases": conv1_b,
"weights_2": conv2_w
})
以上這篇tensorflow模型保存、加載之變量重命名實(shí)例就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Django在Win7下的安裝及創(chuàng)建項(xiàng)目hello word簡(jiǎn)明教程
這篇文章主要介紹了Django在Win7下的安裝及創(chuàng)建項(xiàng)目hello word,需要的朋友可以參考下2014-07-07
Pytorch的torch.utils.data中Dataset以及DataLoader示例詳解
torch.utils.data?是?PyTorch?提供的一個(gè)模塊,用于處理和加載數(shù)據(jù),該模塊提供了一系列工具類(lèi)和函數(shù),用于創(chuàng)建、操作和批量加載數(shù)據(jù)集,這篇文章主要介紹了Pytorch的torch.utils.data中Dataset以及DataLoader等詳解,需要的朋友可以參考下2023-08-08
Python math庫(kù) ln(x)運(yùn)算的實(shí)現(xiàn)及原理
這篇文章主要介紹了Python math庫(kù) ln(x)運(yùn)算的實(shí)現(xiàn)及原理,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-07-07
解決pycharm無(wú)法識(shí)別本地site-packages的問(wèn)題
今天小編就為大家分享一篇解決pycharm無(wú)法識(shí)別本地site-packages的問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-10-10
Python標(biāo)準(zhǔn)庫(kù)之Sys模塊使用詳解
這篇文章主要介紹了Python標(biāo)準(zhǔn)庫(kù)之Sys模塊使用詳解,本文講解了使用sys模塊獲得腳本的參數(shù)、處理模塊、使用sys模塊操作模塊搜索路徑、使用sys模塊查找內(nèi)建模塊、使用sys模塊查找已導(dǎo)入的模塊等使用案例,需要的朋友可以參考下2015-05-05
關(guān)于tf.matmul() 和tf.multiply() 的區(qū)別說(shuō)明
這篇文章主要介紹了關(guān)于tf.matmul() 和tf.multiply() 的區(qū)別說(shuō)明,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-06-06

