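Fixing trainable=False Having No Effect When Mixing Keras and TensorFlow
I ran into this problem recently. Here is the situation:
I have a trained model (VGG16, for example) and want to modify it, for instance by adding a fully connected layer. For various reasons I can only use TensorFlow to optimize the model. By default, a tf optimizer updates the weights of every variable in tf.trainable_variables(), and that is where the problem lies: even though the VGG16 model is set to trainable=False, the tf optimizer still updates VGG16's weights.
After some searching I found a solution. Below we reproduce the whole issue step by step.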
trainable=False has no effect
First, we load the pretrained VGG16 model and set it to trainable=False.
from keras.applications import VGG16
import tensorflow as tf
from keras import layers

# Load the model
base_mode = VGG16(include_top=False)
# List the trainable variables
tf.trainable_variables()
[<tf.Variable 'block1_conv1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>, <tf.Variable 'block1_conv1/bias:0' shape=(64,) dtype=float32_ref>, <tf.Variable 'block1_conv2/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>, <tf.Variable 'block1_conv2/bias:0' shape=(64,) dtype=float32_ref>, <tf.Variable 'block2_conv1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>, <tf.Variable 'block2_conv1/bias:0' shape=(128,) dtype=float32_ref>, <tf.Variable 'block2_conv2/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>, <tf.Variable 'block2_conv2/bias:0' shape=(128,) dtype=float32_ref>, <tf.Variable 'block3_conv1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>, <tf.Variable 'block3_conv1/bias:0' shape=(256,) dtype=float32_ref>, <tf.Variable 'block3_conv2/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>, <tf.Variable 'block3_conv2/bias:0' shape=(256,) dtype=float32_ref>, <tf.Variable 'block3_conv3/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>, <tf.Variable 'block3_conv3/bias:0' shape=(256,) dtype=float32_ref>, <tf.Variable 'block4_conv1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>, <tf.Variable 'block4_conv1/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block4_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block4_conv2/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block4_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block4_conv3/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block5_conv1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block5_conv1/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block5_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block5_conv2/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block5_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block5_conv3/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block1_conv1_1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>, 
<tf.Variable 'block1_conv1_1/bias:0' shape=(64,) dtype=float32_ref>, <tf.Variable 'block1_conv2_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>, <tf.Variable 'block1_conv2_1/bias:0' shape=(64,) dtype=float32_ref>, <tf.Variable 'block2_conv1_1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>, <tf.Variable 'block2_conv1_1/bias:0' shape=(128,) dtype=float32_ref>, <tf.Variable 'block2_conv2_1/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>, <tf.Variable 'block2_conv2_1/bias:0' shape=(128,) dtype=float32_ref>, <tf.Variable 'block3_conv1_1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>, <tf.Variable 'block3_conv1_1/bias:0' shape=(256,) dtype=float32_ref>, <tf.Variable 'block3_conv2_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>, <tf.Variable 'block3_conv2_1/bias:0' shape=(256,) dtype=float32_ref>, <tf.Variable 'block3_conv3_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>, <tf.Variable 'block3_conv3_1/bias:0' shape=(256,) dtype=float32_ref>, <tf.Variable 'block4_conv1_1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>, <tf.Variable 'block4_conv1_1/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block4_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block4_conv2_1/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block4_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block4_conv3_1/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block5_conv1_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block5_conv1_1/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block5_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block5_conv2_1/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block5_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block5_conv3_1/bias:0' shape=(512,) dtype=float32_ref>]
# Set trainable=False
# base_mode.trainable = False also seems to work
for layer in base_mode.layers:
    layer.trainable = False
After setting trainable=False, listing the trainable variables again shows no change. In other words, the setting had no effect on TensorFlow's list.
# List the trainable variables again
tf.trainable_variables()
[<tf.Variable 'block1_conv1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>, <tf.Variable 'block1_conv1/bias:0' shape=(64,) dtype=float32_ref>, <tf.Variable 'block1_conv2/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>, <tf.Variable 'block1_conv2/bias:0' shape=(64,) dtype=float32_ref>, <tf.Variable 'block2_conv1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>, <tf.Variable 'block2_conv1/bias:0' shape=(128,) dtype=float32_ref>, <tf.Variable 'block2_conv2/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>, <tf.Variable 'block2_conv2/bias:0' shape=(128,) dtype=float32_ref>, <tf.Variable 'block3_conv1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>, <tf.Variable 'block3_conv1/bias:0' shape=(256,) dtype=float32_ref>, <tf.Variable 'block3_conv2/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>, <tf.Variable 'block3_conv2/bias:0' shape=(256,) dtype=float32_ref>, <tf.Variable 'block3_conv3/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>, <tf.Variable 'block3_conv3/bias:0' shape=(256,) dtype=float32_ref>, <tf.Variable 'block4_conv1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>, <tf.Variable 'block4_conv1/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block4_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block4_conv2/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block4_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block4_conv3/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block5_conv1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block5_conv1/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block5_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block5_conv2/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block5_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block5_conv3/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block1_conv1_1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>, 
<tf.Variable 'block1_conv1_1/bias:0' shape=(64,) dtype=float32_ref>, <tf.Variable 'block1_conv2_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>, <tf.Variable 'block1_conv2_1/bias:0' shape=(64,) dtype=float32_ref>, <tf.Variable 'block2_conv1_1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>, <tf.Variable 'block2_conv1_1/bias:0' shape=(128,) dtype=float32_ref>, <tf.Variable 'block2_conv2_1/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>, <tf.Variable 'block2_conv2_1/bias:0' shape=(128,) dtype=float32_ref>, <tf.Variable 'block3_conv1_1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>, <tf.Variable 'block3_conv1_1/bias:0' shape=(256,) dtype=float32_ref>, <tf.Variable 'block3_conv2_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>, <tf.Variable 'block3_conv2_1/bias:0' shape=(256,) dtype=float32_ref>, <tf.Variable 'block3_conv3_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>, <tf.Variable 'block3_conv3_1/bias:0' shape=(256,) dtype=float32_ref>, <tf.Variable 'block4_conv1_1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>, <tf.Variable 'block4_conv1_1/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block4_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block4_conv2_1/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block4_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block4_conv3_1/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block5_conv1_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block5_conv1_1/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block5_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block5_conv2_1/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'block5_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>, <tf.Variable 'block5_conv3_1/bias:0' shape=(512,) dtype=float32_ref>]
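One plausible reading of the unchanged output is that a variable is registered in the graph's trainable collection once, when it is created, while the Keras flag is only consulted later on the Keras side (for example through a layer's trainable_weights). The toy model below sketches that separation in plain Python. It does not use TensorFlow, and every class and name in it is a hypothetical stand-in rather than a real Keras or TensorFlow API.

```python
# Toy model of the mechanism (no TensorFlow needed). All names here are
# hypothetical stand-ins, not Keras or TensorFlow APIs.

GLOBAL_TRAINABLE = []  # plays the role of the graph's TRAINABLE_VARIABLES collection


class ToyVariable:
    def __init__(self, name):
        self.name = name
        # Registration happens exactly once, at creation time.
        GLOBAL_TRAINABLE.append(self)


class ToyLayer:
    def __init__(self, name):
        self.trainable = True
        self._weights = [ToyVariable(name + "/kernel:0"),
                         ToyVariable(name + "/bias:0")]

    @property
    def trainable_weights(self):
        # The flag is only consulted here, on the layer side.
        return list(self._weights) if self.trainable else []


layer = ToyLayer("block1_conv1")
layer.trainable = False

print(len(layer.trainable_weights))  # layer-level view: 0
print(len(GLOBAL_TRAINABLE))         # graph-level view: still 2
```

Flipping the flag changes what the layer reports, but nothing removes the entries that were appended to the global list at creation time. An optimizer that reads the global list therefore still sees every variable.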
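The solution
The fix is to create a variable_scope when loading the pretrained model and to put the variables that should be trained in a different variable_scope. Then use tf.get_collection to fetch only the variables to train, and pass them to the tf optimizer through its var_list argument.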
from keras import models
with tf.variable_scope('base_model'):
    base_model = VGG16(include_top=False, input_shape=(224, 224, 3))
with tf.variable_scope('xxx'):
    model = models.Sequential()
    model.add(base_model)
    model.add(layers.Flatten())
    model.add(layers.Dense(10))
# Collect the variables to train
trainable_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'xxx')
trainable_var
[<tf.Variable 'xxx_2/dense_1/kernel:0' shape=(25088, 10) dtype=float32_ref>,
<tf.Variable 'xxx_2/dense_1/bias:0' shape=(10,) dtype=float32_ref>]
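Note that the names above start with xxx_2 even though the scope passed to tf.get_collection was 'xxx', presumably because the scope was entered more than once in the same graph. As far as I know, tf.get_collection treats the scope argument as a pattern matched against the beginning of each variable name, so a bare 'xxx' also matches 'xxx_2/...'. The sketch below imitates that filtering with re.match on plain strings. The helper name and the first sample name are illustrative assumptions.

```python
import re


def filter_by_scope(names, scope):
    """Keep the names whose beginning matches `scope` (regex prefix match)."""
    pattern = re.compile(scope)
    return [n for n in names if pattern.match(n)]


names = [
    "base_model/block1_conv1/kernel:0",  # illustrative name for a frozen weight
    "xxx_2/dense_1/kernel:0",
    "xxx_2/dense_1/bias:0",
]

print(filter_by_scope(names, "xxx"))
# ['xxx_2/dense_1/kernel:0', 'xxx_2/dense_1/bias:0']
print(filter_by_scope(names, "xxx/"))
# []
```

The prefix match is what makes the example work, but it also means a short scope name can pick up unrelated scopes that share the same prefix, so a distinctive scope name is the safer choice.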
# Define a tf optimizer for training; assume some loss is available
loss = model.output / 2  # arbitrary definition, just for demonstration
train_step = tf.train.AdamOptimizer().minimize(loss, var_list=trainable_var)
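If wrapping the model in scopes is inconvenient, a possible alternative is to build var_list by excluding the frozen model's variables by name. This is not the method described in this article, only a sketch. The helper exclude_frozen is hypothetical, and it only assumes that each variable object has a unique name attribute, which is why it can be shown here with simple stand-in objects.

```python
from types import SimpleNamespace


def exclude_frozen(all_vars, frozen_vars):
    """Return the variables in `all_vars` whose names are not in `frozen_vars`."""
    frozen_names = {v.name for v in frozen_vars}
    return [v for v in all_vars if v.name not in frozen_names]


# Stand-ins for variable objects; only the `.name` attribute is used.
conv_kernel = SimpleNamespace(name="block1_conv1/kernel:0")
conv_bias = SimpleNamespace(name="block1_conv1/bias:0")
dense_kernel = SimpleNamespace(name="dense_1/kernel:0")
dense_bias = SimpleNamespace(name="dense_1/bias:0")

all_vars = [conv_kernel, conv_bias, dense_kernel, dense_bias]
frozen = [conv_kernel, conv_bias]

to_train = exclude_frozen(all_vars, frozen)
print([v.name for v in to_train])
# ['dense_1/kernel:0', 'dense_1/bias:0']
```

In the setting above, the intended call would be along the lines of exclude_frozen(tf.trainable_variables(), base_model.weights), with the result passed as var_list. That usage is an assumption and has not been run against the article's setup.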
Summary
When mixing Keras and TensorFlow, setting trainable=False in Keras has no effect on what a TensorFlow optimizer updates.
The fix is to separate the variables with variable_scope, fetch the ones to train with tf.get_collection, and pass them to the tf optimizer through var_list.