keras導(dǎo)入weights方式
keras源碼engine中toplogy.py定義了加載權(quán)重的函數(shù):
load_weights(self, filepath, by_name=False)
其中默認(rèn)by_name為False,這時(shí)候加載權(quán)重按照網(wǎng)絡(luò)拓?fù)浣Y(jié)構(gòu)加載,適合直接使用keras中自帶的網(wǎng)絡(luò)模型,如VGG16
VGG19/resnet50等,源碼描述如下:
If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.
若將by_name改為True則加載權(quán)重按照layer的name進(jìn)行,layer的name相同時(shí)加載權(quán)重,適合用于改變了
模型的相關(guān)結(jié)構(gòu)或增加了節(jié)點(diǎn)但利用了原網(wǎng)絡(luò)的主體結(jié)構(gòu)情況下使用,源碼描述如下:
If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.
在進(jìn)行邊緣檢測時(shí),利用VGG網(wǎng)絡(luò)的主體結(jié)構(gòu),網(wǎng)絡(luò)中增加反卷積層,這時(shí)加載權(quán)重應(yīng)該使用
model.load_weights(filepath,by_name=True)
補(bǔ)充知識(shí):Keras下實(shí)現(xiàn)mnist手寫數(shù)字
之前一直在用tensorflow,被同學(xué)推薦來用keras了,把之前文檔中的mnist手寫數(shù)字?jǐn)?shù)據(jù)集拿來練手,
代碼如下。
import struct
import numpy as np
import os
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
def load_mnist(path, kind):
labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
with open(labels_path, 'rb') as lbpath:
magic, n = struct.unpack('>II', lbpath.read(8))
labels = np.fromfile(lbpath, dtype=np.uint8)
with open(images_path, 'rb') as imgpath:
magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784
return images, labels
#loading train and test data
X_train, Y_train = load_mnist('.\\data', kind='train')
X_test, Y_test = load_mnist('.\\data', kind='t10k')
#turn labels to one_hot code
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)
#define models
model = Sequential()
model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax'))
sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
#start training
model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3)
#count accuracy
y_train_pred = model.predict_classes(X_train, verbose=0)
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0]
print('Training accuracy: %.2f%%' % (train_acc * 100))
y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0]
print('Test accuracy: %.2f%%' % (test_acc * 100))
訓(xùn)練結(jié)果如下:
Epoch 45/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323 Epoch 46/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358 Epoch 47/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347 Epoch 48/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350 Epoch 49/50 42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359 Epoch 50/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346 Training accuracy: 94.11% Test accuracy: 93.61%
以上這篇keras導(dǎo)入weights方式就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
如何將自己的python庫打包成wheel文件并上傳到pypi
這篇文章主要介紹了如何將自己的python庫打包成wheel文件并上傳到pypi,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-04-04
python實(shí)現(xiàn)把二維列表變?yōu)橐痪S列表的方法分析
這篇文章主要介紹了python實(shí)現(xiàn)把二維列表變?yōu)橐痪S列表的方法,結(jié)合實(shí)例形式總結(jié)分析了Python列表推導(dǎo)式、嵌套、循環(huán)等相關(guān)操作技巧,需要的朋友可以參考下2019-10-10
Python+Pygame實(shí)戰(zhàn)之俄羅斯方塊游戲的實(shí)現(xiàn)
俄羅斯方塊,作為是一款家喻戶曉的游戲,陪伴70、80甚至90后,度過無憂的兒時(shí)歲月,它上手簡單能自由組合、拼接技巧也很多。本文就來用Python中的Pygame模塊實(shí)現(xiàn)這一經(jīng)典游戲,需要的可以參考一下2022-12-12

