python簡單實現 反向傳播算法

1 一些鋪墊

1、本文所使用例子來自於《一文弄懂神經網絡中的反向傳播法——BackPropagation

I1,I2是輸入層,h1,h2是隱含層,o1,o2是輸出層,b1,b2是偏置。

其中,輸入數據 i1=0.05,i2=0.10;

輸出數據 o1=0.01,o2=0.99;

初始權重 w1=0.15,w2=0.20,w3=0.25,w4=0.30; w5=0.40,w6=0.45,w7=0.50,w8=0.55

目標:給出輸入數據i1,i2(0.05和0.10),使輸出儘可能與原始輸出o1,o2(0.01和0.99)接近。

2、本文所使用的反向傳播算法的公式,都是來自吳恩達的神經網絡與深度學習視頻

2 前向傳播與反向傳播

設代價函數爲

J(A,Y)=-1/m*(Y^Tlog(A)+(1-Y)^Tlog(1-A))

其中A爲預測值,Y爲輸出值。

由上面兩層神經網絡的的圖片可以得到正向傳播的步驟如下(上標代表所在神經元的層數):

根據正向傳播可以得到反向傳播的各值爲:

\large dZ^{[2]}= A^{[2]}-Y

dW^{[2]}= \frac{1}{m}dZ^{[2]}A^{{[1]}T}

db^{[2]}= \frac{1}{m}np.sum(dZ^{[2]},axis=1,keepdims=True)

dZ^{[1]}=W^{{[2]}T}dZ^{[2]}*\sigma'(Z^{[1]})

dW^{[1]}=\frac{1}{m}dZ^{[1]}X^T

db^{[1]}=\frac{1}{m}np.sum(dZ^{[1]},axis=1,keepdims=True)

3 python代碼實現

import numpy as np
from matplotlib import pyplot as plt
# 設置中文可以顯示
from pylab import mpl
mpl.rcParams['font.sans-serif']=['SimHei']
# 變量的定義
# 權重和偏置
W1=np.array([[0.15,0.20],[0.25,0.30]])
b1=0.35  # 因爲有廣播機制,所以不用寫成向量形式
W2=np.array([[0.40,0.45],[0.50,0.55]])
b2=0.6
# 輸入
X=np.array([[0.05],[0.10]])
# 輸出
Y=np.array([[0.01],[0.99]])
# 學習率
alpha=0.5
# 迭代次數
Count=10000
# 樣本個數
m=2

def sigmoid(x):
    return 1/(1+np.exp(-x))

# 神經網絡的訓練
def NNtraining(W1,X,b1,W2,b2,Y,m,alpha):
    Count=1000
    J = np.zeros((Count, 1))
    for i in range(Count):
        # 前向傳播
        Z1 = np.dot(W1, X) + b1
        A1 = sigmoid(Z1)
        Z2 = np.dot(W2, A1) + b2
        A2 = sigmoid(Z2)
        # 代價函數計算
        J[i]=-1/m*(np.dot(Y.T,np.log(A2)+np.dot((1-Y).T,np.log(1-A2))))
        # 反向傳播
        dZ2=A2-Y
        dW2=1/m*np.dot(dZ2,A1.T)
        db2=1/m*np.sum(dZ2,axis=1,keepdims=True)
        dZ1=np.dot(W2.T,dZ2)*np.dot(Z1.T,1-Z1)
        dW1=1/m*np.dot(dZ1,X.T)
        db1=1/m*np.sum(dZ1,axis=1,keepdims=True)
        # 梯度下降
        W1=W1-alpha*dW1
        b1=b1-alpha*db1
        W2=W1-alpha*dW2
        b2=b2-alpha*db2
        print("result:",A2[0],A2[1])
    return J


if __name__ == '__main__':
    # W1,X,b1,W2,b2,Y,m,alpha=InitNN()
    J=NNtraining(W1,X,b1,W2,b2,Y,m,alpha)
    fig=plt.figure(1)
    plt.plot(J)
    plt.title(u'代價函數隨迭代次數的變化')
    plt.xlabel(u'迭代次數')
    plt.ylabel(u'代價函數的值')
    plt.show()

4 結果展示

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章