keras的siamese(孿生網(wǎng)絡)實現(xiàn)案例
更新時間:2020年06月12日 14:20:25 作者:李上花開
這篇文章主要介紹了keras的siamese(孿生網(wǎng)絡)實現(xiàn)案例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
代碼位于keras的官方樣例,并做了微量修改和大量學習?。
最終效果:


import keras
import numpy as np
import matplotlib.pyplot as plt
import random
from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K
num_classes = 10
epochs = 20
def euclidean_distance(vects):
x, y = vects
sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
return K.sqrt(K.maximum(sum_square, K.epsilon()))
def eucl_dist_output_shape(shapes):
shape1, shape2 = shapes
return (shape1[0], 1)
def contrastive_loss(y_true, y_pred):
'''Contrastive loss from Hadsell-et-al.'06
http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
'''
margin = 1
sqaure_pred = K.square(y_pred)
margin_square = K.square(K.maximum(margin - y_pred, 0))
return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)
def create_pairs(x, digit_indices):
'''Positive and negative pair creation.
Alternates between positive and negative pairs.
'''
pairs = []
labels = []
n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
for d in range(num_classes):
for i in range(n):
z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
pairs += [[x[z1], x[z2]]]
inc = random.randrange(1, num_classes)
dn = (d + inc) % num_classes
z1, z2 = digit_indices[d][i], digit_indices[dn][i]
pairs += [[x[z1], x[z2]]]
labels += [1, 0]
return np.array(pairs), np.array(labels)
def create_base_network(input_shape):
'''Base network to be shared (eq. to feature extraction).
'''
input = Input(shape=input_shape)
x = Flatten()(input)
x = Dense(128, activation='relu')(x)
x = Dropout(0.1)(x)
x = Dense(128, activation='relu')(x)
x = Dropout(0.1)(x)
x = Dense(128, activation='relu')(x)
return Model(input, x)
def compute_accuracy(y_true, y_pred): # numpy上的操作
'''Compute classification accuracy with a fixed threshold on distances.
'''
pred = y_pred.ravel() < 0.5
return np.mean(pred == y_true)
def accuracy(y_true, y_pred): # Tensor上的操作
'''Compute classification accuracy with a fixed threshold on distances.
'''
return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))
def plot_train_history(history, train_metrics, val_metrics):
plt.plot(history.history.get(train_metrics), '-o')
plt.plot(history.history.get(val_metrics), '-o')
plt.ylabel(train_metrics)
plt.xlabel('Epochs')
plt.legend(['train', 'validation'])
# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]
# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)
digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)
# network definition
base_network = create_base_network(input_shape)
input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)
# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)
distance = Lambda(euclidean_distance,
output_shape=eucl_dist_output_shape)([processed_a, processed_b])
model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()
# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
batch_size=128,
epochs=epochs,verbose=2,
validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()
# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)
print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))
以上這篇keras的siamese(孿生網(wǎng)絡)實現(xiàn)案例就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
Python+selenium 獲取瀏覽器窗口坐標、句柄的方法
今天小編就為大家分享一篇Python+selenium 獲取瀏覽器窗口坐標、句柄的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-10-10
linux環(huán)境打包python工程為可執(zhí)行程序的過程
本次需求,在ubuntu上面開發(fā)的python代碼程序需要打包成一個可執(zhí)行程序然后交付給甲方,因為不能直接給源碼給甲方,所以尋找方法將python開發(fā)的源碼打包成一個可執(zhí)行程序,本次在ubuntu上打包python源碼的方法和在window上打包的有點類似,感興趣的朋友跟隨小編一起看看吧2024-01-01
Python實現(xiàn)打印詳細報錯日志,獲取報錯信息位置行數(shù)
這篇文章主要介紹了Python實現(xiàn)打印詳細報錯日志,獲取報錯信息位置行數(shù)方式,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教2023-08-08
一篇文章帶你搞定Ubuntu中打開Pycharm總是卡頓崩潰
這篇文章主要介紹了一篇文章帶你搞定Ubuntu中打開Pycharm總是卡頓崩潰,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2020-11-11

