TensorFlow的權(quán)值更新方法
一. MovingAverage權(quán)值滑動(dòng)平均更新
1.1 示例代碼:
def create_target_q_network(self,state_dim,action_dim,net):
state_input = tf.placeholder("float",[None,state_dim])
action_input = tf.placeholder("float",[None,action_dim])
ema = tf.train.ExponentialMovingAverage(decay=1-TAU)
target_update = ema.apply(net)
target_net = [ema.average(x) for x in net]
layer1 = tf.nn.relu(tf.matmul(state_input,target_net[0]) + target_net[1])
layer2 = tf.nn.relu(tf.matmul(layer1,target_net[2]) + tf.matmul(action_input,target_net[3]) + target_net[4])
q_value_output = tf.identity(tf.matmul(layer2,target_net[5]) + target_net[6])
return state_input,action_input,q_value_output,target_update
def update_target(self):
self.sess.run(self.target_update)
其中,TAU=0.001,net是原始網(wǎng)絡(luò)(該示例代碼來(lái)自DDPG算法,經(jīng)過(guò)滑動(dòng)更新后的target_net是目標(biāo)網(wǎng)絡(luò) )
第一句 tf.train.ExponentialMovingAverage,創(chuàng)建一個(gè)權(quán)值滑動(dòng)平均的實(shí)例;
第二句 apply創(chuàng)建所訓(xùn)練模型參數(shù)的一個(gè)復(fù)制品(shadow_variable),并對(duì)這個(gè)復(fù)制品增加一個(gè)保留權(quán)值滑動(dòng)平均的op,函數(shù)average()或average_name()可以用來(lái)獲取最終這個(gè)復(fù)制品(平滑后)的值的。
更新公式為:
shadow_variable = decay * shadow_variable + (1 - decay) * variable
在上述代碼段中,target_net是shadow_variable,net是variable
1.2 tf.train.ExponentialMovingAverage.apply(var_list=None)
var_list必須是Variable或Tensor形式的列表。這個(gè)方法對(duì)var_list中所有元素創(chuàng)建一個(gè)復(fù)制,當(dāng)其是Variable類(lèi)型時(shí),shadow_variable被初始化為variable的初值,當(dāng)其是Tensor類(lèi)型時(shí),初始化為0,無(wú)偏。
函數(shù)返回一個(gè)進(jìn)行權(quán)值平滑的op,因此更新目標(biāo)網(wǎng)絡(luò)時(shí)單獨(dú)run這個(gè)函數(shù)就行。
1.3 tf.train.ExponentialMovingAverage.average(var)
用于獲取var的滑動(dòng)平均結(jié)果。
二. tf.train.Optimizer更新網(wǎng)絡(luò)權(quán)值
2.1 tf.train.Optimizer
tf.train.Optimizer允許網(wǎng)絡(luò)通過(guò)minimize()損失函數(shù)自動(dòng)進(jìn)行權(quán)值更新,此時(shí)tf.train.Optimizer.minimize()做了兩件事:計(jì)算梯度,并把梯度自動(dòng)更新到權(quán)值上。
此外,tensorflow也允許用戶(hù)自己計(jì)算梯度,并做處理后應(yīng)用給權(quán)值進(jìn)行更新,此時(shí)分為以下三個(gè)步驟:
1.利用tf.train.Optimizer.compute_gradients計(jì)算梯度
2.對(duì)梯度進(jìn)行自定義處理
3.利用tf.train.Optimizer.apply_gradients更新權(quán)值
tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False, grad_loss=None)
返回一個(gè)(梯度,權(quán)值)的列表對(duì)。
tf.train.Optimizer.apply_gradients(grads_and_vars, global_step=None, name=None)
返回一個(gè)更新權(quán)值的op,因此可以用它的返回值ret進(jìn)行sess.run(ret)
2.2 其它
此外,tensorflow還提供了其它計(jì)算梯度的方法:
• tf.gradients(ys, xs, grad_ys=None, name='gradients', colocate_gradients_with_ops=False, gate_gradients=False, aggregation_method=None)
該函數(shù)計(jì)算ys在xs方向上的梯度,需要注意與train.compute_gradients所不同的地方是,該函數(shù)返回一組dydx dydx的列表,而不是梯度-權(quán)值對(duì)。
其中,gate_gradients是在ys方向上的初始梯度,個(gè)人理解可以看做是偏微分鏈?zhǔn)角髮?dǎo)中所需要的。
• tf.stop_gradient(input, name=None)
該函數(shù)告知整個(gè)graph圖中,對(duì)input不進(jìn)行梯度計(jì)算,將其偽裝成一個(gè)constant常量。比如,可以用在類(lèi)似于DQN算法中的目標(biāo)函數(shù):
cost=|r+Q next −Q current | cost=|r+Qnext−Qcurrent|
可以事先聲明
y=tf.stop_gradient(r+Q next r+Qnext)
以上這篇TensorFlow的權(quán)值更新方法就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
videocapture庫(kù)制作python視頻高速傳輸程序
python視頻高速傳輸程序,大家參考使用吧2013-12-12
Pytorch數(shù)據(jù)拼接與拆分操作實(shí)現(xiàn)圖解
這篇文章主要介紹了Pytorch數(shù)據(jù)拼接與拆分操作實(shí)現(xiàn)圖解,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-04-04
基于python實(shí)現(xiàn)cdn日志文件導(dǎo)入mysql進(jìn)行分析
這篇文章主要介紹了基于python實(shí)現(xiàn)cdn日志文件導(dǎo)入mysql進(jìn)行分析,本文以阿里云CDN日志作為輔助查詢(xún)數(shù)據(jù)展開(kāi)主題內(nèi)容,其它云平臺(tái)大同小異,需要的小伙伴可以參考一下2022-05-05
Python實(shí)現(xiàn)二分查找與bisect模塊詳解
二分查找又叫折半查找,二分查找應(yīng)該屬于減治技術(shù)的成功應(yīng)用。python標(biāo)準(zhǔn)庫(kù)中還有一個(gè)灰常給力的模塊,那就是bisect。這個(gè)庫(kù)接受有序的序列,內(nèi)部實(shí)現(xiàn)就是二分。下面這篇文章就詳細(xì)介紹了Python如何實(shí)現(xiàn)二分查找與bisect模塊,需要的朋友可以參考借鑒,下面來(lái)一起看看吧。2017-01-01
基于Python實(shí)現(xiàn)身份證信息識(shí)別功能
身份證是用于證明個(gè)人身份和身份信息的官方證件,在現(xiàn)代社會(huì)中,身份證被廣泛應(yīng)用于各種場(chǎng)景,如就業(yè)、教育、醫(yī)療、金融等,它包含了個(gè)人的基本信息,本文給大家介紹了如何基于Python實(shí)現(xiàn)身份證信息識(shí)別功能,感興趣的朋友可以參考下2024-01-01
python使用yield壓平嵌套字典的超簡(jiǎn)單方法
這篇文章主要給大家介紹了關(guān)于python使用yield壓平嵌套字典的超簡(jiǎn)單方法,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者使用python具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-11-11
使用IPython或Spyder將省略號(hào)表示的內(nèi)容完整輸出
這篇文章主要介紹了使用IPython或Spyder將省略號(hào)表示的內(nèi)容完整輸出,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-04-04

