Keras 實現 動態調整學習率 保存最佳模型

深度學習的訓練過程中,經常需要動態調整學習率同時保存最佳模型,本教程使用 Keras 框架,通過其自設的回調函數,實現所需

Keras使用手冊網址:Keras Documentation(英文版)Keras 中文文檔(中文版)
中文版的使用手冊部分內容無法查看,但可在英文版中查詢所有
Keras使用手冊可輸入查詢,但查詢函數時,需要在函數前添加 ` (英文格式的反引號,鍵盤上的波浪線所在按鍵,其下方表示的符號),同時函數輸入不需要添加 ()

動態調整學習率

本教程參考 Keras學習率調整 實現,該篇博客還提出了兩種調整學習率的方法,博主僅使用了其中的一種,也是最常用的方法

# 動態調整學習率lr
def scheduler(epoch):
	lr = K.get_value(model.optimizer.lr)
    # // 取整除 - 向下取接近商的整數
    # ** 冪 - 返回x的y次冪
    K.set_value(model.optimizer.lr, lr * (0.1 ** (epoch // lr_epochs)))
    return K.get_value(model.optimizer.lr)

lr_new = LearningRateScheduler(scheduler)
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, callbacks=[lr_new], validation_data=(x_test, y_test))

其中,對於 LearningRateScheduler 的描述,中文文檔給出瞭如下的回答
在這裏插入圖片描述
scheduler 函數並沒有具體的描述,但可在其內部,通過 K.get_value(model.optimizer.lr) 獲取當前學習率,用 K.set_value(model.optimizer.lr, lr * (0.1 ** (epoch // lr_epochs))) 修改學習率,而後將其加入LearningRateScheduler

在模型訓練的函數中需要添加回調 callbacks=[lr_new],以實現學習率的動態調整

至此,在模型的訓練過程中,將會以每一epoch的週期自動修改學習率

保存最佳模型

該部分涉及 ModelCheckpoint 的使用,其中文描述如下
在這裏插入圖片描述
代碼實現如下,其中的具體參數上圖已作詳細說明,就不再贅述,但同樣是使用了回調函數,所以在模型訓練的函數中要添加 callbacks=[checkpoint](若涉及準確率,需要在模型設置時,添加 metrics=[‘accuracy’] 參數)

model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
# 保存最佳模型
filepath = 'weights.{epoch:02d}-{val_loss:.2f}.hdf5'
checkpoint = ModelCheckpoint(filepath, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, callbacks=[checkpoint], validation_data=(x_test, y_test))

需要注意一下,文檔中在使用 ModelCheckpoint 時,其參數 monitor 使用了 val_acc,但實際運行時該參數命名存在問題,需要修改爲 val_accuracy

代碼

最後將博主使用的完整代碼附上,以MNIST數據集爲例(沒有優化),其中 def scheduler(epoch) 函數嵌套在了 def main() 函數中,算是Python的一個使用的便捷技巧

from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras import optimizers

import keras.backend as K
from keras.callbacks import LearningRateScheduler, ModelCheckpoint


def main(batch_size=100, test_batch_size=100, lr=1e-3, momentum=0.9, decay=1e-5, epochs=10, lr_epochs=20):
    # 載入數據
    (x_train, y_train), (x_test, y_test) = mnist.load_data()  # 自動從網絡上下載
    # x_shape: (60000, 28, 28)
    # y_shape: (60000,)

    # 數據格式轉換,歸一
    # -1自動轉換合適的數列
    # (60000, 28, 28) -> (60000, 784)
    x_train = x_train.reshape(x_train.shape[0], -1) / 255.0
    x_test = x_test.reshape(x_test.shape[0], -1) / 255.0
    # 轉換one_hot格式,num_classes種類,數字有10個
    y_train = np_utils.to_categorical(y_train, num_classes=10)
    y_test = np_utils.to_categorical(y_test, num_classes=10)

    # 創建模型,輸入784個神經元,輸出10個神經元
    # bias_initializer偏置值初始化
    model = Sequential([Dense(units=10, input_dim=784, bias_initializer='one', activation='softmax')])

    # 定義優化器
    # sgd = optimizers.SGD(lr=lr, momentum=momentum, decay=decay, nesterov=True)
    adam = optimizers.Adam(lr=lr, beta_1=0.9, beta_2=0.999, epsilon=None, decay=decay, amsgrad=False)

    # metrices=['accuracy']準確率
    # categorical_crossentropy交叉熵
    model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])

    # 動態調整學習率lr
    def scheduler(epoch):
        lr = K.get_value(model.optimizer.lr)
        # // 取整除 - 向下取接近商的整數
        # ** 冪 - 返回x的y次冪
        K.set_value(model.optimizer.lr, lr * (0.1 ** (epoch // lr_epochs)))
        return K.get_value(model.optimizer.lr)

    lr_new = LearningRateScheduler(scheduler)

    # 保存最佳模型
    filepath = 'weights.{epoch:02d}-{val_loss:.2f}.hdf5'
    checkpoint = ModelCheckpoint(filepath, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')

    # epochs迭代週期,圖片全部訓練一次爲一週期
    history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, callbacks=[lr_new, checkpoint], validation_data=(x_test, y_test))

    # 評估模型
    loss, accuracy = model.evaluate(x_test, y_test, batch_size=test_batch_size)

    print('\nFinally Test loss:', loss, '\taccuracy:', accuracy)


if __name__ == '__main__':
    main()

該代碼運行結果如下,僅展示部分內容

在這裏插入圖片描述
保存的最佳模型文件列表如下
在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章