在Keras中使用回調函數可以通過在模型訓練時傳入回調函數的列表來實現。回調函數是在訓練過程中的特定時刻被調用的函數,可以用來實現一些功能,比如保存模型、動態調整學習率、可視化訓練過程等。
以下是一個簡單的示例,展示了如何在Keras中使用回調函數:
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import ModelCheckpoint
# 創建一個簡單的Sequential模型
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=100))
model.add(Dense(64, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# 編譯模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# 定義一個回調函數,用來保存模型的權重
checkpoint = ModelCheckpoint(filepath='weights.{epoch:02d}-{val_loss:.2f}.hdf5',
monitor='val_loss', save_best_only=True)
# 模型訓練,并傳入回調函數的列表
model.fit(x_train, y_train, epochs=10, batch_size=32, validation_data=(x_val, y_val), callbacks=[checkpoint])
在上面的示例中,我們定義了一個ModelCheckpoint回調函數,用來保存模型的權重。在模型訓練時,我們將這個回調函數傳入callbacks參數中,這樣在每個epoch結束時,如果驗證集的損失值有改善,就會保存模型的權重。
除了ModelCheckpoint回調函數,Keras還提供了許多其他內置的回調函數,比如EarlyStopping、TensorBoard等,可以根據具體的需求選擇合適的回調函數來使用。