keras 自定義loss層+接受輸入實(shí)例
loss函數(shù)如何接受輸入值
keras封裝的比較厲害,官網(wǎng)給的例子寫(xiě)的云里霧里,
在stackoverflow找到了答案
You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).
def custom_loss_wrapper(input_tensor): def custom_loss(y_true, y_pred): return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor) return custom_loss
input_tensor = Input(shape=(10,)) hidden = Dense(100, activation='relu')(input_tensor) out = Dense(1, activation='sigmoid')(hidden) model = Model(input_tensor, out) model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')
You can verify that input_tensor and the loss value will change as different X is passed to the model.
X = np.random.rand(1000, 10) y = np.random.randint(2, size=1000) model.test_on_batch(X, y) # => 1.1974642 X *= 1000 model.test_on_batch(X, y) # => 511.15466
fit_generator
fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.
Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)
### generator yield [inputX_1,inputX_2],y ### model model = Model(inputs=[inputX_1,inputX_2],outputs=...)
補(bǔ)充知識(shí):keras中自定義 loss損失函數(shù)和修改不同樣本的loss權(quán)重(樣本權(quán)重、類(lèi)別權(quán)重)
首先辨析一下概念:
1. loss是整體網(wǎng)絡(luò)進(jìn)行優(yōu)化的目標(biāo), 是需要參與到優(yōu)化運(yùn)算,更新權(quán)值W的過(guò)程的
2. metric只是作為評(píng)價(jià)網(wǎng)絡(luò)表現(xiàn)的一種“指標(biāo)”, 比如accuracy,是為了直觀(guān)地了解算法的效果,充當(dāng)view的作用,并不參與到優(yōu)化過(guò)程
一、keras自定義損失函數(shù)
在keras中實(shí)現(xiàn)自定義loss, 可以有兩種方式,一種自定義 loss function, 例如:
# 方式一 def vae_loss(x, x_decoded_mean): xent_loss = objectives.binary_crossentropy(x, x_decoded_mean) kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1) return xent_loss + kl_loss vae.compile(optimizer='rmsprop', loss=vae_loss)
或者通過(guò)自定義一個(gè)keras的層(layer)來(lái)達(dá)到目的, 作為model的最后一層,最后令model.compile中的loss=None:
# 方式二 # Custom loss layer class CustomVariationalLayer(Layer): def __init__(self, **kwargs): self.is_placeholder = True super(CustomVariationalLayer, self).__init__(**kwargs) def vae_loss(self, x, x_decoded_mean_squash): x = K.flatten(x) x_decoded_mean_squash = K.flatten(x_decoded_mean_squash) xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash) kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1) return K.mean(xent_loss + kl_loss) def call(self, inputs): x = inputs[0] x_decoded_mean_squash = inputs[1] loss = self.vae_loss(x, x_decoded_mean_squash) self.add_loss(loss, inputs=inputs) # We don't use this output. return x y = CustomVariationalLayer()([x, x_decoded_mean_squash]) vae = Model(x, y) vae.compile(optimizer='rmsprop', loss=None)
在keras中自定義metric非常簡(jiǎn)單,需要用y_pred和y_true作為自定義metric函數(shù)的輸入?yún)?shù) 點(diǎn)擊查看metric的設(shè)置
注意事項(xiàng):
1. keras中定義loss,返回的是batch_size長(zhǎng)度的tensor, 而不是像tensorflow中那樣是一個(gè)scalar
2. 為了能夠?qū)⒆远x的loss保存到model, 以及可以之后能夠順利load model, 需要把自定義的loss拷貝到keras.losses.py 源代碼文件下,否則運(yùn)行時(shí)找不到相關(guān)信息,keras會(huì)報(bào)錯(cuò)
有時(shí)需要不同的sample的loss施加不同的權(quán)重,這時(shí)需要用到sample_weight,例如
discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)
二、keras中的樣本權(quán)重
# Import
import numpy as np
from sklearn.utils import class_weight
# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))
# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
loss='binary_crossentropy',
metrics=['accuracy'])
# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
np.unique(y_train),
y_train)
# Add the class weights to the training
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)
Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].
以上這篇keras 自定義loss層+接受輸入實(shí)例就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Django實(shí)現(xiàn)從數(shù)據(jù)庫(kù)中獲取到的數(shù)據(jù)轉(zhuǎn)換為dict
這篇文章主要介紹了Django實(shí)現(xiàn)從數(shù)據(jù)庫(kù)中獲取到的數(shù)據(jù)轉(zhuǎn)換為dict,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-03-03
對(duì)pyqt5多線(xiàn)程正確的開(kāi)啟姿勢(shì)詳解
今天小編就為大家分享一篇對(duì)pyqt5多線(xiàn)程正確的開(kāi)啟姿勢(shì)詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-06-06
Django跨域請(qǐng)求原理及實(shí)現(xiàn)代碼
這篇文章主要介紹了Django跨域請(qǐng)求原理及實(shí)現(xiàn)代碼,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-11-11
Python數(shù)據(jù)分析之缺失值檢測(cè)與處理詳解
在實(shí)際的數(shù)據(jù)處理中,缺失值是普遍存在的,如何使用 Python 檢測(cè)和處理缺失值,就是本文要講的主要內(nèi)容。感興趣的同學(xué)可以關(guān)注一下2021-12-12
詳解如何使用Python處理INI、YAML和JSON配置文件
在軟件開(kāi)發(fā)中,配置文件是存儲(chǔ)程序配置信息的常見(jiàn)方式,INI、YAML和JSON是常用的配置文件格式,各自有著特定的結(jié)構(gòu)和用途,Python擁有豐富的庫(kù)和模塊,本文將重點(diǎn)探討如何使用Python處理這三種格式的配置文件,需要的朋友可以參考下2023-12-12
Python干貨實(shí)戰(zhàn)之八音符醬小游戲全過(guò)程詳解
讀萬(wàn)卷書(shū)不如行萬(wàn)里路,只學(xué)書(shū)上的理論是遠(yuǎn)遠(yuǎn)不夠的,只有在實(shí)戰(zhàn)中才能獲得能力的提升,本篇文章手把手帶你用Python實(shí)現(xiàn)一個(gè)八音符醬小游戲,大家可以在過(guò)程中查缺補(bǔ)漏,提升水平2021-10-10
基于Python開(kāi)發(fā)云主機(jī)類(lèi)型管理腳本分享
這篇文章主要為大家詳細(xì)介紹了如何基于Python開(kāi)發(fā)一個(gè)云主機(jī)類(lèi)型管理腳本,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下2023-02-02
Python中的省略號(hào)(Ellipsis)賦值方式詳解
在Python編程中,省略號(hào)(...)是一種特殊對(duì)象,主要用作函數(shù)占位、未實(shí)現(xiàn)的方法示例和NumPy數(shù)組處理,本文通過(guò)示例詳細(xì)解釋了省略號(hào)的賦值方式及其在不同編程場(chǎng)景下的應(yīng)用,幫助提升Python編程技巧2024-10-10
在Python程序中進(jìn)行文件讀取和寫(xiě)入操作的教程
這篇文章主要介紹了在Python程序中進(jìn)行文件讀取和寫(xiě)入操作的教程,是Python學(xué)習(xí)當(dāng)中的基礎(chǔ)知識(shí),需要的朋友可以參考下2015-04-04

