keras回調(diào)函數(shù)的使用
回調(diào)函數(shù)
- 回調(diào)函數(shù)是一個對象(實現(xiàn)了特定方法的類實例),它在調(diào)用fit()時被傳入模型,并在訓練過程中的不同時間點被模型調(diào)用
- 可以訪問關于模型狀態(tài)與模型性能的所有可用數(shù)據(jù)
- 模型檢查點(model checkpointing):在訓練過程中的不同時間點保存模型的當前狀態(tài)。
- 提前終止(early stopping):如果驗證損失不再改善,則中斷訓練(當然,同時保存在訓練過程中的最佳模型)。
- 在訓練過程中動態(tài)調(diào)節(jié)某些參數(shù)值:比如調(diào)節(jié)優(yōu)化器的學習率。
- 在訓練過程中記錄訓練指標和驗證指標,或者將模型學到的表示可視化(這些表示在不斷更新):fit()進度條實際上就是一個回調(diào)函數(shù)。
fit()方法中使用callbacks參數(shù)
# 這里有兩個callback函數(shù):早停和模型檢查點
callbacks_list=[
keras.callbacks.EarlyStopping(
monitor="val_accuracy",#監(jiān)控指標
patience=2 #兩輪內(nèi)不再改善中斷訓練
),
keras.callbacks.ModelCheckpoint(
filepath="checkpoint_path",
monitor="val_loss",
save_best_only=True
)
]
#模型獲取
model=get_minist_model()
model.compile(optimizer="rmsprop",
loss="sparse_categorical_crossentropy",
metrics=["accuracy"])
model.fit(train_images,train_labels,
epochs=10,callbacks=callbacks_list, #該參數(shù)使用回調(diào)函數(shù)
validation_data=(val_images,val_labels))
test_metrics=model.evaluate(test_images,test_labels)#計算模型在新數(shù)據(jù)上的損失和指標
predictions=model.predict(test_images)#計算模型在新數(shù)據(jù)上的分類概率

模型的保存和加載
#也可以在訓練完成后手動保存模型,只需調(diào)用model.save('my_checkpoint_path')。
#重新加載模型
model_new=keras.models.load_model("checkpoint_path.keras")
通過對Callback類子類化來創(chuàng)建自定義回調(diào)函數(shù)
on_epoch_begin(epoch, logs) ←----在每輪開始時被調(diào)用
on_epoch_end(epoch, logs) ←----在每輪結(jié)束時被調(diào)用
on_batch_begin(batch, logs) ←----在處理每個批量之前被調(diào)用
on_batch_end(batch, logs) ←----在處理每個批量之后被調(diào)用
on_train_begin(logs) ←----在訓練開始時被調(diào)用
on_train_end(logs ←----在訓練結(jié)束時被調(diào)用
from matplotlib import pyplot as plt
# 實現(xiàn)記錄每一輪中每個batch訓練后的損失,并為每個epoch繪制一個圖
class LossHistory(keras.callbacks.Callback):
def on_train_begin(self, logs):
self.per_batch_losses = []
def on_batch_end(self, batch, logs):
self.per_batch_losses.append(logs.get("loss"))
def on_epoch_end(self, epoch, logs):
plt.clf()
plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses,
label="Training loss for each batch")
plt.xlabel(f"Batch (epoch {epoch})")
plt.ylabel("Loss")
plt.legend()
plt.savefig(f"plot_at_epoch_{epoch}")
self.per_batch_losses = [] #清空,方便下一輪的技術model = get_mnist_model()
model.compile(optimizer="rmsprop",
loss="sparse_categorical_crossentropy",
metrics=["accuracy"])
model.fit(train_images, train_labels,
epochs=10,
callbacks=[LossHistory()],
validation_data=(val_images, val_labels))

【其他】模型的定義 和 數(shù)據(jù)加載
def get_minist_model():
inputs=keras.Input(shape=(28*28,))
features=layers.Dense(512,activation="relu")(inputs)
features=layers.Dropout(0.5)(features)
outputs=layers.Dense(10,activation="softmax")(features)
model=keras.Model(inputs,outputs)
return model
#datset
from tensorflow.keras.datasets import mnist
(train_images,train_labels),(test_images,test_labels)=mnist.load_data()
train_images=train_images.reshape((60000,28*28)).astype("float32")/255
test_images=test_images.reshape((10000,28*28)).astype("float32")/255
train_images,val_images=train_images[10000:],train_images[:10000]
train_labels,val_labels=train_labels[10000:],train_labels[:10000]
到此這篇關于keras回調(diào)函數(shù)的使用的文章就介紹到這了,更多相關keras回調(diào)函數(shù)內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
Python使用函數(shù)默認值實現(xiàn)函數(shù)靜態(tài)變量的方法
這篇文章主要介紹了Python使用函數(shù)默認值實現(xiàn)函數(shù)靜態(tài)變量的方法,是很實用的功能,需要的朋友可以參考下2014-08-08
Python內(nèi)置函數(shù)memoryview()的實現(xiàn)示例
本文主要介紹了Python內(nèi)置函數(shù)memoryview()的實現(xiàn)示例,它允許你在不復制其內(nèi)容的情況下操作同一個數(shù)組的不同切片,具有一定的參考價值,感興趣的可以了解一下2024-05-05
Python OOP類中的幾種函數(shù)或方法總結(jié)
今天小編就為大家分享一篇關于Python OOP類中的幾種函數(shù)或方法總結(jié),小編覺得內(nèi)容挺不錯的,現(xiàn)在分享給大家,具有很好的參考價值,需要的朋友一起跟隨小編來看看吧2019-02-02
python matplotlib imshow熱圖坐標替換/映射實例
這篇文章主要介紹了python matplotlib imshow熱圖坐標替換/映射實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-03-03
pytorch中節(jié)約顯卡內(nèi)存的方法和技巧
顯存不足是很多人感到頭疼的問題,畢竟能擁有大量顯存的實驗室還是少數(shù),而現(xiàn)在的模型已經(jīng)越跑越大,模型參數(shù)量和數(shù)據(jù)集也越來越大,所以這篇文章給大家總結(jié)了一些pytorch中節(jié)約顯卡內(nèi)存的方法和技巧,需要的朋友可以參考下2023-11-11

