Tensorflow實(shí)現(xiàn)部分參數(shù)梯度更新操作
在深度學(xué)習(xí)中,遷移學(xué)習(xí)經(jīng)常被使用,在大數(shù)據(jù)集上預(yù)訓(xùn)練的模型遷移到特定的任務(wù),往往需要保持模型參數(shù)不變,而微調(diào)與任務(wù)相關(guān)的模型層。
本文主要介紹,使用tensorflow部分更新模型參數(shù)的方法。
1. 根據(jù)Variable scope剔除需要固定參數(shù)的變量
def get_variable_via_scope(scope_lst):
vars = []
for sc in scope_lst:
sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope)
vars.extend(sc_variable)
return vars
trainable_vars = tf.trainable_variables()
no_change_scope = ['your_unchange_scope_name']
no_change_vars = get_variable_via_scope(no_change_scope)
for v in no_change_vars:
trainable_vars.remove(v)
grads, _ = tf.gradients(loss, trainable_vars)
optimizer = tf.train.AdamOptimizer(lr)
train_op = optimizer.apply_gradient(zip(grads, trainable_vars), global_step=global_step)
2. 使用tf.stop_gradient()函數(shù)
在建立Graph過(guò)程中使用該函數(shù),非常簡(jiǎn)潔地避免了使用scope獲取參數(shù)
3. 一個(gè)矩陣中部分行或列參數(shù)更新
如果一個(gè)矩陣,只有部分行或列需要更新參數(shù),其它保持不變,該場(chǎng)景很常見(jiàn),例如word embedding中,一些預(yù)定義的領(lǐng)域相關(guān)詞保持不變(使用領(lǐng)域相關(guān)word embedding初始化),而另一些通用詞變化。
import tensorflow as tf import numpy as np def entry_stop_gradients(target, mask): mask_h = tf.abs(mask-1) return tf.stop_gradient(mask_h * target) + mask * target mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1]) mask_h = np.abs(mask-1) emb = tf.constant(np.ones([10, 5])) matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1)) parm = np.random.randn(5, 1) t_parm = tf.constant(parm) loss = tf.reduce_sum(tf.matmul(matrix, t_parm)) grad1 = tf.gradients(loss, emb) grad2 = tf.gradients(loss, matrix) print matrix with tf.Session() as sess: print sess.run(loss) print sess.run([grad1, grad2])
以上這篇Tensorflow實(shí)現(xiàn)部分參數(shù)梯度更新操作就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
- tensorflow 實(shí)現(xiàn)自定義梯度反向傳播代碼
- 有關(guān)Tensorflow梯度下降常用的優(yōu)化方法分享
- TensorFlow梯度求解tf.gradients實(shí)例
- 基于TensorFlow中自定義梯度的2種方式
- tensorflow 查看梯度方式
- tensorflow求導(dǎo)和梯度計(jì)算實(shí)例
- Tensorflow的梯度異步更新示例
- 在Tensorflow中實(shí)現(xiàn)梯度下降法更新參數(shù)值
- 運(yùn)用TensorFlow進(jìn)行簡(jiǎn)單實(shí)現(xiàn)線性回歸、梯度下降示例
- Tensorflow 卷積的梯度反向傳播過(guò)程
相關(guān)文章
Python提示[Errno 32]Broken pipe導(dǎo)致線程crash錯(cuò)誤解決方法
這篇文章主要介紹了Python提示[Errno 32]Broken pipe導(dǎo)致線程crash錯(cuò)誤解決方法,是ThreadingHTTPServer實(shí)現(xiàn)http服務(wù)中經(jīng)常會(huì)遇到的問(wèn)題,需要的朋友可以參考下2014-11-11
Python利用treap實(shí)現(xiàn)雙索引的方法
所遍歷的元素一定是遞增(小堆)或是遞減(大堆)關(guān)系,但是我們無(wú)法得知左子樹(shù)與右子樹(shù)兩部分節(jié)點(diǎn)的排序關(guān)系。本文就來(lái)講講算法和數(shù)據(jù)結(jié)構(gòu)共同滿足一組特性,感興趣的小伙伴請(qǐng)參考下面文章的內(nèi)容2021-09-09
python不到50行代碼完成了多張excel合并的實(shí)現(xiàn)示例
這篇文章主要介紹了python不到50行代碼完成了多張excel合并的實(shí)現(xiàn)示例,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-05-05
Python-Seaborn熱圖繪制的實(shí)現(xiàn)方法
這篇文章主要介紹了Python-Seaborn熱圖繪制的實(shí)現(xiàn)方法,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-07-07
深入解析Python編程中super關(guān)鍵字的用法
Python的子類調(diào)用父類成員時(shí)可以用到super關(guān)鍵字,初始化時(shí)需要注意super()和__init__()的區(qū)別,下面我們就來(lái)深入解析Python編程中super關(guān)鍵字的用法:2016-06-06

