Learning Tensorflow(2)--- 深度神經網絡

在構建深度神經網絡時候,不只需要構建網絡的主體部分,還需要定義損失函數,網絡優化操作,以及訓練過程, loss及精確度輸出。

損失函數

         神經網絡模型的效果以及優化的目標是通過損失函數來定義的。

交叉熵

         交叉熵用於評判兩個概率分佈之間的距離。是分類問題中使用比較廣的一種損失函數。

         給定兩個概率分佈p和q,則通過q來表示p的交叉熵爲

Hp,q=-xp(x)logq(x)

那麼如何得到一個分類結果的概率分佈呢?

Softmax迴歸

Softmax迴歸將神經網絡的輸出變成一個概率分佈。

這個新的輸出可以理解爲一個樣例爲不同類別的概率分別是多大。

例如:

某個樣例的正確答案爲(1,0,0),softmax迴歸之後的預測答案爲(0.5,0.4,0.1);

則這個預測和正確答案之間的交叉熵爲0.3(依據上述公式)

在tensorflow中,損失函數loss的一個實現樣例爲:

cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y_, 1), logits=y)
cross_entropy_mean = tf.reduce_mean(cross_entropy) loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))

 

神經網絡優化

學習率設置

         學習率用於控制參數更新的速度。

         學習率過大時,參數將無法收斂,學習率過小時,參數收斂速度變得極慢。

爲了找到合適的學習率,tensorflow採用一種更加靈活的學習率設置方法---指數衰減法,該方法先使用較大的學習率來快速得到一個比較優的解,然後隨着迭代的繼續逐步減小學習率。

Decayed_learning_rate = \ learning_rate * decay_rate ^(global_step / decay_setps)

其中:

  • Decayed_learning_rate:每一輪使用的學習率
  • learning_rate:初始學習率
  • decay_rate:衰減速度
  • decay_setps:完整的使用一遍訓練數據所需要的迭代輪數,一般爲樣本總數/batch_size
  • global_step:全局的迭代輪數,用於迭代輪數的計數。指數衰減的學習率是伴隨global_step的變化而衰減的。
learning_rate = tf.train.exponential_decay(
    LEARNING_RATE_BASE,
    global_step,
    mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY,
    staircase=True)

當staircase=true時,global_step/decay_steps會被轉化成整數,學習率會成爲一個階梯函數。

當staircase=true時,學習率是一個連續函數。

 

滑動平均模型

         滑動平均模型可以是模型在測試數據上更健壯。在採用隨機梯度下降算法訓練神經網絡時,使用滑動平均模型在很多應用中都可以再一定程度上提高最終模型在測試數據上的表現。

Tf.train.ExpoenentialMovingAverage

在初始化滑動平均模型時,需要提供一個衰減率decay,用於模型更新的速度。該函數會對每一個變量維護一個影子變量,每次運行變量更新時,影子變量的值會更新爲:

Shadow_variable = decay * shadow_variable + (1-decay) * variable
variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
variables_averages_op = variable_averages.apply(tf.trainable_variables())

 

Loss 及 Accuracy輸出

y爲標準輸出,y_爲預測輸出

correction_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correction_prediction, tf.float32))

feed_dict爲每一輪batch的輸入數據和label
 

train_accuracy = accuracy.eval(feed_dict={x:reshaped_xs,y_:ys})
print("test accuracy %g"%train_accuracy)

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章