TensorFlow Notes -- Learning Rate Decay

在"求解二次函數最小值對應的值"例子中,我們已經直觀看到TensorFlow如何求解;今天我們來討論一下“學習率衰減”。

1. Why?

Suppose we train a model with gradient descent and keep the learning rate α fixed. If α is too small, training is slow; if α is too large, the loss will not converge precisely late in training but will oscillate around the minimum. We might say: just pick a suitable α and the problem is solved. That sounds right, but it is hard to do in practice, and this is where learning rate decay comes in.
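To make this concrete, here is a tiny plain-Python sketch (my own illustration, not from the original post) on the same loss (w+1)^2 used in section 3.2: with a fixed, overly large learning rate the iterate keeps oscillating around the minimum, while a rate that decays each step lets it settle.

# Toy illustration in plain Python: gradient descent on loss = (w + 1)^2,
# whose gradient is 2 * (w + 1). decay_rate = 1.0 means a fixed learning rate.
def run(lr_base, decay_rate, steps=50):
    w = 10.0
    for t in range(steps):
        lr = lr_base * (decay_rate ** t)
        w -= lr * 2.0 * (w + 1.0)
    return w

print(run(lr_base=1.0, decay_rate=1.0))   # fixed and too large: w just bounces between 10 and -12
print(run(lr_base=1.0, decay_rate=0.99))  # decayed: w settles close to the minimum at -1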

2. What?

Learning rate decay: to speed up the learning algorithm, the learning rate is gradually reduced over time; this is called learning rate decay. Let's see how it is done.

3. How?

3.1 Exponential decay formula

decayed\_learning\_rate = learning\_rate * decay\_rate^{global\_step / decay\_steps}

Parameters:
learning_rate: A scalar float32 or float64 Tensor or a Python number. The initial learning rate.
global_step: A scalar int32 or int64 Tensor or a Python number. The global step to use for the decay computation. Must not be negative.
decay_steps: A scalar int32 or int64 Tensor or a Python number. Must be positive.
decay_rate: A scalar float32 or float64 Tensor or a Python number. The decay rate.
staircase: Boolean. If True, decay the learning rate at discrete intervals.
name: String. Optional name of the operation. Defaults to 'ExponentialDecay'.

Returns:
A scalar Tensor of the same type as learning_rate: the decayed learning rate.
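A quick hand computation of the formula (the numbers below are only illustrative) shows what staircase changes: with staircase=False the exponent global_step / decay_steps is a real number, so the rate decays smoothly; with staircase=True the exponent is truncated to an integer, so the rate drops in discrete steps.

# Hand computation of the decay formula (illustrative numbers).
learning_rate = 0.1
decay_rate = 0.99
decay_steps = 100
global_step = 250

smooth = learning_rate * decay_rate ** (float(global_step) / decay_steps)  # staircase=False: 0.1 * 0.99^2.5 ~ 0.0975
stepwise = learning_rate * decay_rate ** (global_step // decay_steps)      # staircase=True:  0.1 * 0.99^2   ~ 0.0980
print(smooth, stepwise)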

3.2 Example

# coding:utf-8
# Loss function: loss = (w + 1)^2, with w initialized to 10.
# Backpropagation finds the optimal w, i.e. the w that minimizes loss.
# An exponentially decaying learning rate gives a fast descent early on and
# better convergence within fewer training iterations.

import tensorflow as tf

LEARNING_RATE_BASE = 0.1    # initial learning rate
LEARNING_RATE_DECAY = 0.99  # learning rate decay rate
LEARNING_RATE_STEP = 1      # update the learning rate after this many batches of BATCH_SIZE; usually total_samples / BATCH_SIZE

# Counter for how many batches have been run; initialized to 0 and set as not trainable.
global_step = tf.Variable(0, trainable=False)
# Define the exponentially decaying learning rate.
learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,
                                           global_step,
                                           LEARNING_RATE_STEP,
                                           LEARNING_RATE_DECAY,
                                           staircase=True)
# Define the parameter to optimize, w, initialized to 10.
w = tf.Variable(tf.constant(10, dtype=tf.float32))
# Define the loss function.
loss = tf.square(w+1)
# Define the backpropagation (training) step.
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

# Create a session and train for 100 steps.
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    for i in range(100):
        sess.run(train_step)
        learning_rate_val = sess.run(learning_rate)
        global_step_val = sess.run(global_step)
        w_val = sess.run(w)
        loss_val = sess.run(loss)
        print("After %s steps: global_step is %f, w is %f, learning rate is %f, loss is %f."%(i, global_step_val,w_val, learning_rate_val, loss_val))

Output:
After 0 steps: global_step is 1.000000, w is 7.800000, learning rate is 0.099000, loss is 77.440002.
After 1 steps: global_step is 2.000000, w is 6.057600, learning rate is 0.098010, loss is 49.809719.
After 2 steps: global_step is 3.000000, w is 4.674169, learning rate is 0.097030, loss is 32.196194.
……
After 55 steps: global_step is 56.000000, w is -0.999063, learning rate is 0.056960, loss is 0.000001.
After 56 steps: global_step is 57.000000, w is -0.999170, learning rate is 0.056391, loss is 0.000001.
After 57 steps: global_step is 58.000000, w is -0.999263, learning rate is 0.055827, loss is 0.000001.
After 58 steps: global_step is 59.000000, w is -0.999346, learning rate is 0.055268, loss is 0.000000.
After 59 steps: global_step is 60.000000, w is -0.999418, learning rate is 0.054716, loss is 0.000000.
……
After 97 steps: global_step is 98.000000, w is -0.999985, learning rate is 0.037346, loss is 0.000000.
After 98 steps: global_step is 99.000000, w is -0.999986, learning rate is 0.036973, loss is 0.000000.
After 99 steps: global_step is 100.000000, w is -0.999987, learning rate is 0.036603, loss is 0.000000.
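As a sanity check (my own plain-Python reproduction, not part of the original run), the first few printed values can be recomputed by hand: each update uses the rate for the current global_step, and the printed rate is the one already decayed for the next step.

# Reproducing the first few lines of the output above in plain Python.
w = 10.0
lr_base, decay = 0.1, 0.99
for step in range(3):
    lr = lr_base * decay ** step                 # rate used for this update (0.1, 0.099, 0.09801, ...)
    w -= lr * 2.0 * (w + 1.0)                    # gradient of (w + 1)^2 is 2 * (w + 1)
    print("After %d steps: w is %f, learning rate is %f, loss is %f."
          % (step, w, lr_base * decay ** (step + 1), (w + 1.0) ** 2))
# Prints w = 7.800000, 6.057600, 4.674169 -- matching the TensorFlow run above.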

4. Summary

That completes our look at exponential learning rate decay. If you come across discrete (step-wise) decay or manually scheduled decay in the literature, they all serve the same purpose of helping the model train better; all we need to do is adjust learning_rate accordingly. That's it for today; next time we will look at the exponential moving average.
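For the "discrete decay" case mentioned above, the same TF 1.x API family also provides tf.train.piecewise_constant; a minimal sketch, with boundaries and values chosen purely for illustration:

import tensorflow as tf

global_step = tf.Variable(0, trainable=False)
boundaries = [100, 200]        # step counts at which the rate changes (illustrative)
values = [0.1, 0.05, 0.01]     # rate before step 100, between steps 100 and 200, and after step 200
learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
# learning_rate can then be passed to GradientDescentOptimizer exactly as in section 3.2.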
