python手寫神經網絡之權重初始化——梯度消失、表達消失

基於《深度學習入門——基於Python的理論與實現》第六章,但是書上只有一段基礎的展示代碼,和一些刻板結論(xx激活用xx優化),沒有太多過程分析,所以自己進行了擴展與實驗,加入了激活前後的對比,和不同激活函數不同係數等等的對比。

 

 

關於權重初始化的理解:

在一個默認的情況下(w高斯分佈初始化,sigmoid激活,x是一個1000*100的高斯分佈的隨機matrix),會產生這種兩頭高的分佈,這個分佈的主要缺點是,在sigmoid的輸出0和輸出1的區間,梯度近乎於零,梯度消失,無法訓練。下面討論的是解決它的一個方法,vanilla的方法就是w*=weight_init_std,其中std=0.1或者0.01。(x不變)

圖一:std=1(上下分別爲激活前後,注意看激活前的橫軸量級和激活後的分佈關係)

圖二:std=0.1(注意看激活前的橫軸量級和激活後的分佈關係)

 

但是,當std=0.01時,w的scale太小了(下圖),在0.1以內,趨零,梯度區間倒是很好,沒有梯度消失問題,但是輸出全集中在0.5這其實也是一個問題!之前沒注意到的一點,其實這個初始化也是一個trade-off,除了有梯度消失問題(兩頭高),還有“表達消失”問題(中間高),所謂表達消失,就是因爲權重都是0或者接近0,導致所有神經元輸出雷同甚至相同,喪失對稱性和表達能力(100個神經元幹1個神經元的事)。理想的分佈應該是不那麼集中的。

 

之前只是聽過全0初始化網絡會對稱(注意:w=0不是神經元死亡),對稱的壞處是喪失神經網絡的表達能力。但是沒注意到這個解決梯度消失的初始化方案(降低初始化的方差)的極端就是表達消失和神經網絡對稱。反應到分佈,就是從極端的兩頭高,變成極端的中間高(但是這很取決於激活函數,relu例外),這兩張情形其實是一個trade-off,都要避免纔是好的。

梯度消失和表達消失受w的影響的變化,是會產生對立關係的,卻又不是直接的因果對立關係(抽象就抽象在這,你又不能說它沒關係,w是能影響激活前的值,從而影響到激活後的值得分佈的,這樣間接的因果關係)。(這裏邊有很多巧合,x軸對稱激活函數中(tanh與sigmoid)好像是如此對立,relu好像又不同,relu的輸入如果因爲w而集中到0附近,relu的輸出也集中在0附近,那麼relu既會梯度消失,也會表達消失

 

“神經元死亡”這個概念也容易有誤解,神經元死亡是神經元輸出永久爲零,但是輸出其實是f(w*x)輸出的值,它受多方面因素共同影響,不一定w=0輸出就等於0(全0給sigmoid激活,輸出是0.5)。可能略微抽象的是,既然w不是0,怎麼就永久死亡了?爲什麼某一個特定的“死亡組合”就能永遠生效呢?其實,“死亡”這個詞,一個重要前提就是,“在當前訓練集下”,只有在特定的x分佈下,這個特定的w才能導致神經元死亡,但是這已經足夠致命了,因爲你沒有無窮多的訓練集,針對你的訓練集死亡,就等於不再work了。

 

那麼既然是一個trade-off,想兩頭都不得罪怎麼辦?肯定需要一個“動態”方案了,Xavier就是這種,他根據輸入神經元和輸出神經元的數量進行權重的初始化標準化工作(簡化版的也可能只根據輸入,比如本例)。直覺上也很好理解,有多少輸入神經元和輸出神經元,就對應多大量級的數據量,那麼這個操作就根據數據量動態去操作。太詳細的不說了。

另外,“He初始化”是什麼?ReLU只有半軸輸出,所以自然要在Xavier的係數中乘以2(根號下2/n),這樣簡單粗暴理解“He初始化”就好了

(其實有了各種BNs之後,這些trick多少有些派不上用場的感覺,但是從學習角度,研究一下還是有必要的)

 

 

簡單發幾組結果圖示:

經驗:不要光看圖形,可能很多情況圖的形狀真的差不多,要看座標軸的數字!!

比較繞的一點是:這裏打印的是每一層的激活值,不是w,而激活值到底是什麼,除了w相關,和激活函數也有關。所以下面的圖主要展示的是激活後的值,也就是整體分佈的變化。

分佈取向的不同:sigmoid和tanh是兩頭梯度消失(飽和,梯度0)。尤其是tanh,分佈在0附近梯度正好(但是w),但是relu不同,relu的0分佈代表梯度消失,relu的負軸無梯度,分佈趨0代表梯度消失。

 

sigmoid的普通初始化(*0.01)與Xavier對比(單圖內上下圖分別爲激活前後):

 

 

tanh的普通(0.01)和Xavier(得益於tanh的0均值化,效果要比sigmoid好):

ReLU的普通與Xavier與“He初始化”:(std=0.01橫軸非常小,分佈接近0,Xavier好很多,但是He初始化更好,注意橫軸座標)

圖一:純1.0的高斯分佈:輸出集中在0附近,梯度消失,表達也不好,雙輸,這是relu不同於sigmoid的點。(因爲w方差太大了,所以每一層的激活前分佈的scale還是很大的,只是激活後的分佈不好,那麼如果w方差小,輸出前後會怎樣?)

 

這個圖很難解釋,因爲我用了錯誤的hist range,看起來分佈集中到0了?其實不是,只是分佈到更廣的範圍了(因爲0.1~1.0佔比更少了,所以變得更“矮”了),所以說不算很向0集中。

 

 

圖二,std=0.1,分佈好了很多,激活前scale也更集中了,激活後也更分散了關鍵是怎麼解釋這個變化關係,爲什麼w的scale小了,激活後的值分佈更均勻?首先,w的scale小了,那麼激活前的值,也就是zi = a(i-1)*wi更集中在0附近,然後激活後呢?解釋不了了,所以我需要改一下打印,加上weights,修正range(三行圖的關係,w、z=w*x、a=f(z),x是前一列的a)

 

隨着網絡加深,第二個分佈區間從一兩千,變成了一萬多,反而越來越多?怎麼就導致activations的分佈越來均勻?也不是,因爲這裏只看了0~1,其實看總量(下圖),這個區間在增加,其實是從其他的分佈區間剝削來的所以這是符合預期的結果。分佈越來越集中,所以其實也是很差的表現!!!只是比std=1的情況集中一些。但是能說的上叫做好嗎?也說不上(但是直覺上,std=1的情況太離散了,表達不太連續,可能不太好,這個情況的主要問題是,scale過大,波動過大的W,先不說每次訓練會不會結果都很不同,這些極值之間的參數變化可能也很劇烈。另一方面,W方差過大,過擬合也會嚴重),這兩個只是輸出的數值量級不同,從梯度消失和表達消失上看,可能半斤八兩。

 整體的正負分佈都區別不大,只是w的量級不同,

 

 

 

圖三,std=0.01,又變得糟糕起來,太小了。這意味着什麼,5-layer的activations也就是輸出,輸出如果都是0.00000x,那麼雖然經過softmax之後也許會大一點,然後和one-hot比一下,就是說,前向傳播本身還能有東西(但是極小的值,肯定也伴隨着很多精度損失,試想,如果10個類,每個類的輸出都是1e-10,假如float精度只有1e-10,那麼輕微波動,概率就巨幅的變動,可見,也不是個好現象)。但是反向傳播呢?除了前向傳播本身波動就很大之外,另一個問題也會出現——梯度消失,大規模的零值,(在ReLU中零值等於)梯度消失。除了零以外的值呢?途中5-layer還剩下一小點,但是量級已經很小了,隨着層數加深,剩下的輸出越來越少,最終輸出全都變成零,那麼反向傳播的梯度也就全是零。梯度消失就越嚴重。

關於飽和和梯度消失問題,ReLU的特殊性:它必須要輸出完全成爲零的時候纔是零梯度的狀態,纔是飽和狀態,從不飽和到飽和是一個階躍的過程,所以只能從數量級分析,量變引起質變。而sigmoid和tanh都是一個漸變的過程,相對來說更直觀更容易理解,他們的不飽和到飽和是一個漸變的過程,直觀解釋就是分佈越偏向兩邊飽和區,梯度越微小,最後反向相乘,導致梯度下溢消失。)

 

圖四:Xavier初始化,乍一看,和std=0.1不相上下(也受限於例子,如果網絡更深的話,效果會不同),但是如果仔細看最後一層分佈第二高的bin的縱座標,Xavier優化接近2萬,std=0.1是1萬出頭,而第一高的bin,他的縱座標反而降了一些,其實是優於std=0.1的,所以光看形狀很難說明問題,重點要看刻度和座標!

(0,1)區間 

圖五:He初始化,橫軸,明顯(0,1)分佈比Xavier還要均勻一些,範圍更大一些。

 

(0,1)區間 

總的來說,He初始化0,1區間更均勻穩定,分佈的區域也大一些。

 

更多的實驗結果就不發了,改參數,觀察,就可以。

(好險,改了又改,差點就不能自恰了,好在仔細扣了座標,各種分析,座標真的重要,這個是你簡單看一些教程和slides不會注意到的東西,多動手,多觀察)

 

代碼實現:

兩個參數:用來對比普通權重初始化和Xavier或者“He初始化”的區別,還有具體看哪個激活函數的分佈

plain_weight_init = True
activation = 'relu'

代碼根目錄(這本書的其他實踐和其他博客的相關代碼都在這個根目錄下):https://github.com/huqinwei/python_deep_learning_introduction/chap06_weight_init_activation_histogram.py

代碼已經有變動,不更新到這了,這段代碼當做一個簡化版。

#這個還是要好好練一下的,手寫,各種激活的打印和監視。

# from book_dir.common.multi_layer_net import MultiLayerNet#沒用上!!!!!!
import numpy as np
import matplotlib.pyplot as plt
import math

def sigmoid(x):
    return 1.0 / (1 + np.exp(-x))
def tanh(x):#自己平移sigmoid做的tanh和庫裏的tanh,和網上標準公式實現的tanh,是否有區別,
    return (2.0 / (1 + np.exp(-2*x))) - 1
def tanh2(x):#公式推導,其實是等價的
    y=(math.e**(x)-math.e**(-x))/(math.e**(x)+math.e**(-x))
    return y
def relu(x):
    return np.maximum(0,x)

x = np.random.randn(1000,100)
node_num = 100
hidden_layer_size = 5
activations = {}
before_activations = {}
hist_demo = False
plain_weight_init = True
activation = 'relu'

for i in range(hidden_layer_size):
    if i != 0:#有一層沒算
        x = activations[i-1]#那麼第一層呢,直接用原定義,注意這是裸寫的網絡,不是類!

    w = np.random.randn(node_num, node_num)
    if plain_weight_init:
        weight_init_std = 0.01#1,0.1,0.01,0.001#量級越小,後邊就越消失
        w = w * weight_init_std
        
    else:
        if activation == 'tanh':
            w = w / np.sqrt(node_num)
        elif activation == 'sigmoid':
            w = w / np.sqrt(node_num)
        elif activation == 'relu':#作爲對比,relu最好也用weight_init_std=0.01跑一次
            w = w / np.sqrt(node_num) * np.sqrt(2)#可以註釋掉根號2對比,差距明顯,這個很好解釋,因爲relu只有正半軸


    z = np.dot(x,w)
    if activation == 'tanh':
        a = tanh(z)
    elif activation == 'sigmoid':
        a = sigmoid(z)
    elif activation == 'relu':
        a = relu(z)
    else:
        a = sigmoid(z)
    activations[i] = a
    before_activations[i] = z

layer_nums = len(activations)
for i,z in before_activations.items():
    plt.subplot(2,layer_nums,i+1)#i從0起,plot從1起
    plt.title(str(i+1) + "-layer")
    zf = z.flatten()

    if activation == 'tanh':
        plt.hist(zf,30,range=(-1,1))#用tanh可以看到分佈更好看一些,鐘形
    elif activation == 'sigmoid':
        plt.hist(zf,30,range=(0,1))
    elif activation == 'relu':
        plt.hist(zf,range=(-1,1))



for i,a in activations.items():
    plt.subplot(2,layer_nums,i+1 + layer_nums)#i從0起,plot從1起
    plt.title(str(i+1) + "-layer")
    af = a.flatten()

    if activation == 'tanh':
        plt.hist(af,30,range=(-1,1))#用tanh可以看到分佈更好看一些,鐘形
    elif activation == 'sigmoid':
        plt.hist(af,30,range=(0,1))
    elif activation == 'relu':
        plt.hist(af,30)#根據hist的功能,0右側是大於0,那麼relu不可能0左側沒東西吧?0不要面子的嗎????記錯了,左閉右開,那麼其實是[0,0.x],其實是包含0的,也包含極小的非零值


plt.show()








 

 

 

 

 

 

 

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章