在Keras中,回調函數是在訓練過程中的特定時間點調用的函數,用于監控模型的性能、調整學習率、保存模型等操作。使用回調函數可以在訓練過程中實時監控模型的性能,并根據需要進行一些操作。
要使用回調函數,首先需要定義一個回調函數的類,并實現對應的方法。Keras已經提供了一些內置的回調函數,比如ModelCheckpoint用于保存模型,EarlyStopping用于提前停止訓練等。
然后,在訓練模型時,通過callbacks參數將定義的回調函數傳遞給fit方法,如下所示:
from keras.callbacks import ModelCheckpoint
# 定義回調函數
checkpoint = ModelCheckpoint(filepath='model.h5', monitor='val_loss', save_best_only=True)
# 訓練模型
model.fit(X_train, y_train, validation_data=(X_val, y_val), callbacks=[checkpoint])
在上面的例子中,ModelCheckpoint回調函數會在每個epoch結束時監測驗證集上的損失值,并保存性能最好的模型到model.h5文件中。
除了內置的回調函數,還可以自定義回調函數。通過繼承keras.callbacks.Callback類,并重寫對應的方法來實現自定義的回調函數。
總之,回調函數是在訓練過程中非常有用的工具,可以幫助我們監控模型的性能,調整參數,保存模型等操作。