BP神經網絡,BP推導過程,反向傳播算法,誤差反向傳播,梯度下降,權值閾值更新推導,隱含層權重更新公式

一、BP神經網絡的概念

BP神經網絡是一種多層的前饋神經網絡,其主要的特點是:信號是前向傳播的,而誤差是反向傳播的。具體來說,對於如下的只含一個隱層的神經網絡模型:
技術分享
(三層BP神經網絡模型)
BP神經網絡的過程主要分爲兩個階段,第一階段是信號的前向傳播,從輸入層經過隱含層,最後到達輸出層;第二階段是誤差的反向傳播,從輸出層到隱含層,最後到輸入層,依次調節隱含層到輸出層的權重和偏置,輸入層到隱含層的權重和偏置。

二、BP神經網絡的流程

    在知道了BP神經網絡的特點後,我們需要依據信號的前向傳播和誤差的反向傳播來構建整個網絡。

1、網絡的初始化

    假設輸入層的節點個數爲技術分享,隱含層的節點個數爲技術分享,輸出層的節點個數爲技術分享。輸入層到隱含層的權重技術分享,隱含層到輸出層的權重爲技術分享,輸入層到隱含層的偏置爲技術分享,隱含層到輸出層的偏置爲技術分享。學習速率爲技術分享,激勵函數爲技術分享。其中激勵函數爲技術分享取Sigmoid函數。形式爲:
技術分享

2、隱含層的輸出

    如上面的三層BP網絡所示,隱含層的輸出技術分享
技術分享

3、輸出層的輸出

技術分享

4、誤差的計算

    我們取誤差公式爲:
技術分享
其中技術分享爲期望輸出。我們記技術分享,則技術分享可以表示爲
技術分享
以上公式中,技術分享技術分享技術分享

5、權值的更新

    權值的更新公式爲:
技術分享
這裏需要解釋一下公式的由來:
這是誤差反向傳播的過程,我們的目標是使得誤差函數達到最小值,即技術分享,我們使用梯度下降法:
  • 隱含層到輸出層的權重更新
技術分享
則權重的更新公式爲:
技術分享
  • 輸入層到隱含層的權重更新
技術分享
其中
技術分享
 
技術分享
則權重的更新公式爲:
技術分享

6、偏置的更新

    偏置的更新公式爲:
技術分享
  • 隱含層到輸出層的偏置更新
技術分享
則偏置的更新公式爲:
技術分享
  • 輸入層到隱含層的偏置更新
技術分享
其中
技術分享
 
技術分享
則偏置的更新公式爲:
技術分享

7、判斷算法迭代是否結束

    有很多的方法可以判斷算法是否已經收斂,常見的有指定迭代的代數,判斷相鄰的兩次誤差之間的差別是否小於指定的值等等。

三、實驗的仿真

    在本試驗中,我們利用BP神經網絡處理一個四分類問題,最終的分類結果爲:
技術分享

MATLAB代碼

主程序
[plain] view plaincopy技術分享技術分享
 
  1. %% BP的主函數  
  2.   
  3. % 清空  
  4. clear all;  
  5. clc;  
  6.   
  7. % 導入數據  
  8. load data;  
  9.   
  10. %從1到2000間隨機排序  
  11. k=rand(1,2000);  
  12. [m,n]=sort(k);  
  13.   
  14. %輸入輸出數據  
  15. input=data(:,2:25);  
  16. output1 =data(:,1);  
  17.   
  18. %把輸出從1維變成4維  
  19. for i=1:2000  
  20.     switch output1(i)  
  21.         case 1  
  22.             output(i,:)=[1 0 0 0];  
  23.         case 2  
  24.             output(i,:)=[0 1 0 0];  
  25.         case 3  
  26.             output(i,:)=[0 0 1 0];  
  27.         case 4  
  28.             output(i,:)=[0 0 0 1];  
  29.     end  
  30. end  
  31.   
  32. %隨機提取1500個樣本爲訓練樣本,500個樣本爲預測樣本  
  33. trainCharacter=input(n(1:1600),:);  
  34. trainOutput=output(n(1:1600),:);  
  35. testCharacter=input(n(1601:2000),:);  
  36. testOutput=output(n(1601:2000),:);  
  37.   
  38. % 對訓練的特徵進行歸一化  
  39. [trainInput,inputps]=mapminmax(trainCharacter‘);  
  40.   
  41. %% 參數的初始化  
  42.   
  43. % 參數的初始化  
  44. inputNum = 24;%輸入層的節點數  
  45. hiddenNum = 50;%隱含層的節點數  
  46. outputNum = 4;%輸出層的節點數  
  47.   
  48. % 權重和偏置的初始化  
  49. w1 = rands(inputNum,hiddenNum);  
  50. b1 = rands(hiddenNum,1);  
  51. w2 = rands(hiddenNum,outputNum);  
  52. b2 = rands(outputNum,1);  
  53.   
  54. % 學習率  
  55. yita = 0.1;  
  56.   
  57. %% 網絡的訓練  
  58. for r = 1:30  
  59.     E(r) = 0;% 統計誤差  
  60.     for m = 1:1600  
  61.         % 信息的正向流動  
  62.         x = trainInput(:,m);  
  63.         % 隱含層的輸出  
  64.         for j = 1:hiddenNum  
  65.             hidden(j,:) = w1(:,j)‘*x+b1(j,:);  
  66.             hiddenOutput(j,:) = g(hidden(j,:));  
  67.         end  
  68.         % 輸出層的輸出  
  69.         outputOutput = w2‘*hiddenOutput+b2;  
  70.           
  71.         % 計算誤差  
  72.         e = trainOutput(m,:)‘-outputOutput;  
  73.         E(r) = E(r) + sum(abs(e));  
  74.           
  75.         % 修改權重和偏置  
  76.         % 隱含層到輸出層的權重和偏置調整  
  77.         dw2 = hiddenOutput*e‘;  
  78.         db2 = e;  
  79.           
  80.         % 輸入層到隱含層的權重和偏置調整  
  81.         for j = 1:hiddenNum  
  82.             partOne(j) = hiddenOutput(j)*(1-hiddenOutput(j));  
  83.             partTwo(j) = w2(j,:)*e;  
  84.         end  
  85.           
  86.         for i = 1:inputNum  
  87.             for j = 1:hiddenNum  
  88.                 dw1(i,j) = partOne(j)*x(i,:)*partTwo(j);  
  89.                 db1(j,:) = partOne(j)*partTwo(j);  
  90.             end  
  91.         end  
  92.           
  93.         w1 = w1 + yita*dw1;  
  94.         w2 = w2 + yita*dw2;  
  95.         b1 = b1 + yita*db1;  
  96.         b2 = b2 + yita*db2;    
  97.     end  
  98. end  
  99.   
  100. %% 語音特徵信號分類  
  101. testInput=mapminmax(‘apply‘,testCharacter‘,inputps);  
  102.   
  103. for m = 1:400  
  104.     for j = 1:hiddenNum  
  105.         hiddenTest(j,:) = w1(:,j)‘*testInput(:,m)+b1(j,:);  
  106.         hiddenTestOutput(j,:) = g(hiddenTest(j,:));  
  107.     end  
  108.     outputOfTest(:,m) = w2‘*hiddenTestOutput+b2;  
  109. end  
  110.   
  111. %% 結果分析  
  112. %根據網絡輸出找出數據屬於哪類  
  113. for m=1:400  
  114.     output_fore(m)=find(outputOfTest(:,m)==max(outputOfTest(:,m)));  
  115. end  
  116.   
  117. %BP網絡預測誤差  
  118. error=output_fore-output1(n(1601:2000))‘;  
  119.   
  120. k=zeros(1,4);    
  121. %找出判斷錯誤的分類屬於哪一類  
  122. for i=1:400  
  123.     if error(i)~=0  
  124.         [b,c]=max(testOutput(i,:));  
  125.         switch c  
  126.             case 1   
  127.                 k(1)=k(1)+1;  
  128.             case 2   
  129.                 k(2)=k(2)+1;  
  130.             case 3   
  131.                 k(3)=k(3)+1;  
  132.             case 4   
  133.                 k(4)=k(4)+1;  
  134.         end  
  135.     end  
  136. end  
  137.   
  138. %找出每類的個體和  
  139. kk=zeros(1,4);  
  140. for i=1:400  
  141.     [b,c]=max(testOutput(i,:));  
  142.     switch c  
  143.         case 1  
  144.             kk(1)=kk(1)+1;  
  145.         case 2  
  146.             kk(2)=kk(2)+1;  
  147.         case 3  
  148.             kk(3)=kk(3)+1;  
  149.         case 4  
  150.             kk(4)=kk(4)+1;  
  151.     end  
  152. end  
  153.   
  154. %正確率  
  155. rightridio=(kk-k)./kk  

激活函數
[plain] view plaincopy技術分享技術分享
 
  1. %% 激活函數  
  2. function [ y ] = g( x )  
  3.     y = 1./(1+exp(-x));  
  4. end  


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章