keras實(shí)現(xiàn)theano和tensorflow訓(xùn)練的模型相互轉(zhuǎn)換
我就廢話不多說了,大家還是直接看代碼吧~
</pre><pre code_snippet_id="1947416" snippet_file_name="blog_20161025_1_3331239" name="code" class="python">
# coding:utf-8
"""
If you want to load pre-trained weights that include convolutions (layers Convolution2D or Convolution1D),
be mindful of this: Theano and TensorFlow implement convolution in different ways (TensorFlow actually implements correlation, much like Caffe),
and thus, convolution kernels trained with Theano (resp. TensorFlow) need to be converted before being with TensorFlow (resp. Theano).
"""
from keras import backend as K
from keras.utils.np_utils import convert_kernel
from text_classifier import keras_text_classifier
import sys
def th2tf( model):
import tensorflow as tf
ops = []
for layer in model.layers:
if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
original_w = K.get_value(layer.W)
converted_w = convert_kernel(original_w)
ops.append(tf.assign(layer.W, converted_w).op)
K.get_session().run(ops)
return model
def tf2th(model):
for layer in model.layers:
if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
original_w = K.get_value(layer.W)
converted_w = convert_kernel(original_w)
K.set_value(layer.W, converted_w)
return model
def conv_layer_converted(tf_weights, th_weights, m = 0):
"""
:param tf_weights:
:param th_weights:
:param m: 0-tf2th, 1-th2tf
:return:
"""
if m == 0: # tf2th
tc = keras_text_classifier(weights_path=tf_weights)
model = tc.loadmodel()
model = tf2th(model)
model.save_weights(th_weights)
elif m == 1: # th2tf
tc = keras_text_classifier(weights_path=th_weights)
model = tc.loadmodel()
model = th2tf(model)
model.save_weights(tf_weights)
else:
print("0-tf2th, 1-th2tf")
return
if __name__ == '__main__':
if len(sys.argv) < 4:
print("python tf_weights th_weights <0|1>\n0-tensorflow to theano\n1-theano to tensorflow")
sys.exit(0)
tf_weights = sys.argv[1]
th_weights = sys.argv[2]
m = int(sys.argv[3])
conv_layer_converted(tf_weights, th_weights, m)
補(bǔ)充知識(shí):keras學(xué)習(xí)之修改底層為TensorFlow還是theano
我們知道,keras的底層是TensorFlow或者theano
要知道我們是用的哪個(gè)為底層,只需要import keras即可顯示
修改方法:
打開

修改

以上這篇keras實(shí)現(xiàn)theano和tensorflow訓(xùn)練的模型相互轉(zhuǎn)換就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
python3 sleep 延時(shí)秒 毫秒實(shí)例
這篇文章主要介紹了python3 sleep 延時(shí)秒 毫秒實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-05-05
Queue隊(duì)列中join()與task_done()的關(guān)系及說明
Python抓取通過Ajax加載數(shù)據(jù)的示例
python實(shí)現(xiàn)大戰(zhàn)外星人小游戲?qū)嵗a
python正則實(shí)現(xiàn)計(jì)算器功能

