計算智能 -- BP神經網絡(2)

本文是對懶惰的gler–Java實現BP神經網絡完成 Iris 數據分類:http://blog.csdn.net/u010858605/article/details/72898178 這篇博客的理解和對一些小問題的改進。原文的相關描述請點擊上面的鏈接即可。

(1)數據選取

由於採用Iris 鳶尾花數據集,該數據集一共有150條記錄,選取 Iris 數據集中的120條數據作爲訓練集(train.txt),剩餘的30條數據作爲測試集(test.txt)。
數據採取編號

(2)java-code的理解

一共包含三個類: BPNN.java 、DataUtil.java 、Test.java
這裏寫圖片描述

(3)代碼的問題

3.1 在數據寫入reslt.txt中的時候,判定的花型全是”Iris-setosa”,與事實不符合。
3.2 如果訓練次數超過MaxTrain則沒有給出判斷的條件

(4)改進

4.1 – BPNN.java

import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;

class BPNN {
    // private static int LAYER = 3; // 三層神經網絡
    private static int NodeNum = 10; // 每層的最多節點數
    private static final int ADJUST = 5; // 隱層節點數調節常數
    private static final int MaxTrain = 2000; // 最大訓練次數
    private static final double ACCU = 0.015; // 每次迭代允許的誤差 iris:0.015
    private double ETA_W = 0.5; // 權值學習效率0.5
    private double ETA_T = 0.5; // 閾值學習效率0.5
    private double accu;

    // 附加動量項
    //private static final double ETA_A = 0.3; // 動量常數0.1
    //private double[][] in_hd_last; // 上一次的權值調整量
    //private double[][] hd_out_last;

    private int in_num;     // 輸入層節點數
    private int hd_num;     // 隱層節點數
    private int out_num;    // 輸入出節點數

    private ArrayList<ArrayList<Double>> list = new ArrayList<>(); // 輸入輸出數據

    private double[][] in_hd_weight;    // BP網絡in-hidden突觸權值
    private double[][] hd_out_weight;   // BP網絡hidden_out突觸權值
    private double[] in_hd_th;          // BP網絡in-hidden閾值
    private double[] hd_out_th;         // BP網絡hidden-out閾值

    private double[][] out;     // 每個神經元的值經S型函數轉化後的輸出值,輸入層就爲原值
    private double[][] delta;       // delta學習規則中的值

    // 獲得網絡三層中神經元最多的數量
    public int GetMaxNum() {
        return Math.max(Math.max(in_num, hd_num), out_num);
    }

    // 設置權值學習率
    public void SetEtaW() {
        ETA_W = 0.5;
    }

    // 設置閾值學習率
    public void SetEtaT() {
        ETA_T = 0.5;
    }

    // BPNN訓練
    public void Train(int in_number, int out_number, ArrayList<ArrayList<Double>> arraylist) throws IOException {

        list = arraylist;
        in_num = in_number;
        out_num = out_number;

        GetNums(in_num, out_num); // 獲取輸入層、隱層、輸出層的節點數
        // SetEtaW(); // 設置學習率
        // SetEtaT();

        InitNetWork(); // 初始化網絡的權值和閾值

        int datanum = list.size();        // 訓練數據的組數
        int createsize = GetMaxNum();     // 比較每一層的節點數,取max
        out = new double[3][createsize];  // 創建輸出數組 out[3][7]

        //訓練次數爲MaxTrain以內,如果訓練次數超過MaxTrain則沒有給出判斷的條件
        for (int iter = 0; iter < MaxTrain; iter++) {

            for (int cnd = 0; cnd < datanum; cnd++) {
                // 第一層輸入節點賦值  out[0][4]
                for (int i = 0; i < in_num; i++) {
                    //list.get(cnd).get(i) 取樣本數據的第cnd組中第i個數據放入到out[0][i]中
                    out[0][i] = list.get(cnd).get(i); // 爲輸入層節點賦值,其輸入與輸出相同
                }
                Forward(); // 前向傳播
                Backward(cnd); // 誤差反向傳播

            }
            System.out.println("This is the " + (iter + 1) + " th trainning NetWork !");
            accu = GetAccu();
            System.out.println("All Samples Accuracy is " + accu);
            if (accu < ACCU)
                break;

        }

    }

    // 獲取輸入層、隱層、輸出層的節點數,in_number、out_number分別爲輸入層節點數和輸出層節點數
    public void GetNums(int in_number, int out_number) {
        in_num = in_number;
        out_num = out_number;
        hd_num = (int) Math.sqrt(in_num + out_num) + ADJUST;
        if (hd_num > NodeNum)
            hd_num = NodeNum; // 隱層節點數不能大於最大節點數
    }

    // 初始化網絡的權值和閾值
    public void InitNetWork() {
        // 初始化上一次權值量,範圍爲-0.5-0.5之間
        //in_hd_last = new double[in_num][hd_num];
        //hd_out_last = new double[hd_num][out_num];

        in_hd_weight = new double[in_num][hd_num];
        for (int i = 0; i < in_num; i++)
            for (int j = 0; j < hd_num; j++) {
                int flag = 1; // 符號標誌位(-1或者1)
                if ((new Random().nextInt(2)) == 1)
                    flag = 1;
                else
                    flag = -1;
                // New Random.nextDouble()的取值範圍: [0,1.0)
                in_hd_weight[i][j] = ( new Random().nextDouble() / 2 ) * flag; // 初始化in-hidden的權值
                //in_hd_last[i][j] = 0;
            }

        hd_out_weight = new double[hd_num][out_num];
        for (int i = 0; i < hd_num; i++)
            for (int j = 0; j < out_num; j++) {
                int flag = 1; // 符號標誌位(-1或者1)
                if ((new Random().nextInt(2)) == 1)
                    flag = 1;
                else
                    flag = -1;
                hd_out_weight[i][j] = (new Random().nextDouble() / 2) * flag; // 初始化hidden-out的權值
                //hd_out_last[i][j] = 0;
            }

        // 閾值均初始化爲0
        // 輸入層不處理數據,只接收數據,所以不設置閾值
        in_hd_th = new double[hd_num];
        for (int k = 0; k < hd_num; k++)
            in_hd_th[k] = 0;

        hd_out_th = new double[out_num];
        for (int k = 0; k < out_num; k++)
            hd_out_th[k] = 0;
    }

   /* // 計算單個樣本的誤差
    public double GetError(int cnd) {
        double ans = 0;
        for (int i = 0; i < out_num; i++)
        {
            System.out.println(out[2][i]);
            ans += 0.5 * (out[2][i] - list.get(cnd).get(in_num + i)) * (out[2][i] - list.get(cnd).get(in_num + i));
        }

        return ans;
    }*/

    // 計算所有樣本的平均精度
    public double GetAccu() {
        double ans = 0;
        int num = list.size();
        for (int i = 0; i < num; i++) {
            int m = in_num;
            for (int j = 0; j < m; j++)
                out[0][j] = list.get(i).get(j);
            Forward();
            int n = out_num;
            for (int k = 0; k < n; k++){
                //定義了輸入與輸出之間的平方誤差
                //System.out.println(list.get(i).get(in_num + k));
                //System.out.println(out[2][k]);
                ans += 0.5 * (list.get(i).get(in_num + k) - out[2][k]) * (list.get(i).get(in_num + k) - out[2][k]);
            }
        }

        return ans / num;
    }

    // 前向傳播
    public void Forward() {
        /**
         * 計算隱層節點的輸出值 
         * v = 求和( 每個輸入層數據 * 每個隱層的權重 ) + 對應  隱層 的閾值 
         * in_hd_weight[4][7]   out[0][4]   in_hd_th[7]
         */
        for (int j = 0; j < hd_num; j++) {
            double v = 0;
            for (int i = 0; i < in_num; i++)
                v += in_hd_weight[i][j] * out[0][i];
            v += in_hd_th[j];
            out[1][j] = Sigmoid(v);
        }

        /**
         * 計算輸出層節點的輸出值
         * v = 求和( 每個隱層輸出數據 * 每個輸出層的權重 ) + 對應 輸出層 的閾值 
         * hd_out_weight[7][3]  out[1][3]   hd_out_th[3]
         */
        for (int j = 0; j < out_num; j++) {
            double v = 0;
            for (int i = 0; i < hd_num; i++)
                v += hd_out_weight[i][j] * out[1][i];
            v += hd_out_th[j];
            out[2][j] = Sigmoid(v);
        }
    }

    // 誤差反向傳播 = 計算權值調整量 + 更新BP神經網絡的權值和閾值
    public void Backward(int cnd) {
        CalcDelta(cnd); // 計算權值調整量
        UpdateNetWork(); // 更新BP神經網絡的權值和閾值
    }

    // 計算delta調整量
    public void CalcDelta(int cnd) {

        int createsize = GetMaxNum(); // 比較創建數組
        delta = new double[3][createsize];

        // 計算輸出層的delta值  cnd ( 0 - 119 )
        for (int i = 0; i < out_num; i++) {
            //System.out.println(list.size());
            delta[2][i] = (list.get(cnd).get(in_num + i) - out[2][i]) * SigmoidDerivative(out[2][i]);
        }

        // 計算隱層的delta值
        for (int i = 0; i < hd_num; i++) {
            double t = 0;
            for (int j = 0; j < out_num; j++)
                t += hd_out_weight[i][j] * delta[2][j];
            delta[1][i] = t * SigmoidDerivative(out[1][i]);
        }
    }

    // 更新BP神經網絡的權值和閾值
    public void UpdateNetWork() {

        // 隱含層和輸出層之間權值和閥值調整
        for (int i = 0; i < hd_num; i++) {
            for (int j = 0; j < out_num; j++) {
                hd_out_weight[i][j] += ETA_W * delta[2][j] * out[1][i]; // 未加權值動量項
                /* 動量項
                 * hd_out_weight[i][j] += (ETA_A * hd_out_last[i][j] + ETA_W
                 * delta[2][j] * out[1][i]); hd_out_last[i][j] = ETA_A *
                 * hd_out_last[i][j] + ETA_W delta[2][j] * out[1][i];
                 */
            }

        }
        for (int i = 0; i < out_num; i++)
            hd_out_th[i] += ETA_T * delta[2][i];

        // 輸入層和隱含層之間權值和閥值調整
        for (int i = 0; i < in_num; i++) {
            for (int j = 0; j < hd_num; j++) {
                in_hd_weight[i][j] += ETA_W * delta[1][j] * out[0][i]; // 未加權值動量項
                /* 動量項
                 * in_hd_weight[i][j] += (ETA_A * in_hd_last[i][j] + ETA_W
                 * delta[1][j] * out[0][i]); in_hd_last[i][j] = ETA_A *
                 * in_hd_last[i][j] + ETA_W delta[1][j] * out[0][i];
                 */
            }
        }
        for (int i = 0; i < hd_num; i++)
            in_hd_th[i] += ETA_T * delta[1][i];
    }

    // 符號函數sign
    public int Sign(double x) {
        if (x > 0)
            return 1;
        else if (x < 0)
            return -1;
        else
            return 0;
    }

    // 返回最大值
    public double Maximum(double x, double y) {
        if (x >= y)
            return x;
        else
            return y;
    }

    // 返回最小值
    public double Minimum(double x, double y) {
        if (x <= y)
            return x;
        else
            return y;
    }

    // log-sigmoid函數
    public double Sigmoid(double x) {
        return (double) (1 / (1 + Math.exp(-x)));
    }

    // log-sigmoid函數的倒數
    public double SigmoidDerivative(double y) {
        return (double) (y * (1 - y));
    }

 /*   // tan-sigmoid函數
    public double TSigmoid(double x) {
        return (double) ((1 - Math.exp(-x)) / (1 + Math.exp(-x)));
    }

    // tan-sigmoid函數的倒數
    public double TSigmoidDerivative(double y) {
        return (double) (1 - (y * y));
    }*/

    // 分類預測函數
    public ArrayList<ArrayList<Double>> ForeCast(
            ArrayList<ArrayList<Double>> arraylist) {

        ArrayList<ArrayList<Double>> alloutlist = new ArrayList<>();
        ArrayList<Double> outlist = new ArrayList<Double>();
        int datanum = arraylist.size();
        for (int cnd = 0; cnd < datanum; cnd++) {
            for (int i = 0; i < in_num; i++)
                out[0][i] = arraylist.get(cnd).get(i); // 爲輸入節點賦值
            Forward();
            for (int i = 0; i < out_num; i++) {
                if (out[2][i] > 0 && out[2][i] < 0.5)
                    out[2][i] = 0;
                else if (out[2][i] > 0.5 && out[2][i] < 1) {
                    out[2][i] = 1;
                }
                outlist.add(out[2][i]);
                //System.out.println( out[2][i] );
            }
            alloutlist.add(outlist);
            outlist = new ArrayList<Double>();
            outlist.clear();
        }
        return alloutlist;
    }

}

4.2 – DataUtil.java

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;

class DataUtil {

    private ArrayList<ArrayList<Double>> alllist = new ArrayList<ArrayList<Double>>(); // 存放所有數據
    private ArrayList<String> outlist = new ArrayList<String>(); // 存放輸出數據,索引對應每個everylist的輸出
    private ArrayList<String> checklist = new ArrayList<String>();  //存放測試集的真實輸出字符串
    private int in_num = 0;
    private int out_num = 0; // 輸入輸出數據的個數
    private int type_num = 0; // 輸出的類型數量
    private double[][] nom_data; //歸一化輸入數據中的最大值和最小值
    private int in_data_num = 0; //提前獲得輸入數據的個數

    // 獲取輸出類型的個數
    public int GetTypeNum() {
        return type_num;
    }

    // 設置輸出類型的個數
    public void SetTypeNum(int type_num) {
        this.type_num = type_num;
    }

    // 獲取輸入數據的個數
    public int GetInNum() {
        return in_num;
    }

    // 獲取輸出數據的個數
    public int GetOutNum() {
        return out_num;
    }

    // 獲取所有數據的數組
    public ArrayList<ArrayList<Double>> GetList() {
        return alllist;
    }

    // 獲取輸出爲字符串形式的數據
    public ArrayList<String> GetOutList() {
        return outlist;
    }

    // 獲取輸出爲字符串形式的數據
    public ArrayList<String> GetCheckList() {
        return checklist;
    }

    //返回歸一化數據所需最大最小值
    public double[][] GetMaxMin(){

        return nom_data;
    }

    // 讀取文件初始化數據
    public void ReadFile( String filepath, String sep, int flag ) throws Exception {

        ArrayList<Double> everylist = new ArrayList<Double>(); // 存放每一組輸入輸出數據
        int readflag = flag; // flag=0,train;flag=1,test

        String encoding = "GBK"; //編碼格式"GBK"
        File file = new File(filepath);

        if (file.isFile() && file.exists()) { // 判斷文件是否存在
            InputStreamReader read = new InputStreamReader(new FileInputStream( file ), encoding);// 考慮到編碼格式
            BufferedReader bufferedReader = new BufferedReader(read);
            String lineTxt = null;

            while ((lineTxt = bufferedReader.readLine()) != null) {
                int in_number = 0;
                //將每一行的數據按','截取字符串
                String splits[] = lineTxt.split(sep); 
                if (readflag == 0) {
                    for (int i = 0; i < splits.length; i++)
                        try {
                            //對數據進行歸一化處理
                            everylist.add(Normalize(Double.valueOf(splits[i]),nom_data[i][0],nom_data[i][1]));
                            in_number++;
                        } catch (Exception e) {
                            //outlist:存放輸出數據的類型
                            if (!outlist.contains(splits[i]))
                                outlist.add(splits[i]); // 存放字符串形式的輸出數據
                            //初始化[-,-,-,-,0.0,0.0,0.0]
                            for (int k = 0; k < type_num; k++) {
                                everylist.add(0.0);
                            }
                            // 0-3:四個屬性   4-6:輸出節點處理,進行one-hot編程 
                            // outlist.indexOf(splits[i]):獲取第幾位的不爲空
                            // everylist 存放着[ 0 - 6 ] 位
                            everylist.set(in_number + outlist.indexOf(splits[i]),1.0);
                        }
                } else if (readflag == 1) {
                    for (int i = 0; i < splits.length; i++)
                        try {
                            everylist.add(Normalize(Double.valueOf(splits[i]),nom_data[i][0],nom_data[i][1]));
                            in_number++;
                        } catch (Exception e) {
                            checklist.add(splits[i]); // 存放字符串形式的輸出數據
                        }
                }
                alllist.add(everylist); // 存放所有數據
                in_num = in_number;
                out_num = type_num;
                everylist = new ArrayList<Double>();
                everylist.clear();
            }
            bufferedReader.close();
        }
    }

    //向文件寫入分類結果
    public void WriteFile(String filepath, ArrayList<ArrayList<Double>> list, int in_number,  ArrayList<String> resultlist) throws IOException{
        File file = new File(filepath);
        FileWriter fw = null;
        BufferedWriter writer = null;
        try {
            fw = new FileWriter(file);
            writer = new BufferedWriter(fw);
            for(int i=0;i<list.size();i++){
                for(int j=0;j<in_number;j++){
                    writer.write(list.get(i).get(j)+",");
                }
                writer.write(resultlist.get(i));
                writer.newLine();
            }
            writer.flush();
        } catch (IOException e) {
            e.printStackTrace();
        }finally{
            writer.close();
            fw.close();
        }
    }


    //學習樣本歸一化,找到輸入樣本數據的最大值和最小值
    public void NormalizeData(String filepath) throws IOException{
        //提前獲得輸入數據的個數   
        GetBeforIn(filepath);
        int flag=1;
        //nom_data存放輸入節點的max和min   in_data_num:4
        nom_data = new double[in_data_num][2];
        String encoding = "GBK";
        File file = new File(filepath);
        if ( file.isFile() && file.exists() ) { // 判斷文件是否存在
            InputStreamReader read = new InputStreamReader( new FileInputStream(file), encoding );// 考慮到編碼格式
            BufferedReader bufferedReader = new BufferedReader(read);
            String lineTxt = null;
            while ((lineTxt = bufferedReader.readLine()) != null) {
                String splits[] = lineTxt.split(",");   // 按','截取字符串
                for (int i = 0; i < splits.length-1; i++){
                    if(flag==1){
                        nom_data[i][0]=Double.valueOf(splits[i]);
                        nom_data[i][1]=Double.valueOf(splits[i]);
                    }
                    else{
                        if(Double.valueOf(splits[i])>nom_data[i][0])
                            nom_data[i][0]=Double.valueOf(splits[i]);
                        if(Double.valueOf(splits[i])<nom_data[i][1])
                            nom_data[i][1]=Double.valueOf(splits[i]);
                    }
                }
                flag=0;
            }
            bufferedReader.close();
        }
    }

    //歸一化前獲得輸入數據的個數
    public void GetBeforIn(String filepath) throws IOException{
        String encoding = "GBK";
        File file = new File(filepath);
        if (file.isFile() && file.exists()) { // 判斷文件是否存在
            InputStreamReader read = new InputStreamReader(new FileInputStream(
                    file), encoding);// 考慮到編碼格式
            //提前獲得輸入數據的個數
            BufferedReader beforeReader = new BufferedReader(read);
            String beforetext = beforeReader.readLine();
            String splits[] = beforetext.split(",");
            in_data_num = splits.length-1;
            beforeReader.close();
        }
    }

    //歸一化公式 -- 用於讀取文件中
    public double Normalize(double x, double max, double min){
        double y = 0.1+0.8*(x-min)/(max-min);
        return y;
    }
}

4.3 – Test.java

import java.util.ArrayList;

public class Test {
    public static void main(String args[]) throws Exception {

        //alllist = 4 + 3 即輸入和輸出
        ArrayList<ArrayList<Double>> alllist = new ArrayList<ArrayList<Double>>(); // 存放所有數據
        ArrayList<String> outlist = new ArrayList<String>();  // 存放分類的字符串
        int in_num = 0, out_num = 0; // 輸入輸出數據的個數

        DataUtil dataUtil = new DataUtil(); // 初始化數據

        dataUtil.NormalizeData("F:\\實訓\\code\\BPNN_three\\data\\train.txt");  //對數據進行歸一化處理

        dataUtil.SetTypeNum(3); // 設置輸出類型的數量
        dataUtil.ReadFile("F:\\實訓\\code\\BPNN_three\\data\\train.txt", ",", 0);

        in_num = dataUtil.GetInNum();   // 獲得輸入數據的個數
        out_num = dataUtil.GetOutNum(); // 獲得輸出數據的個數(個數代表類型個數)
        alllist = dataUtil.GetList();   // 獲得初始化後的數據

        outlist = dataUtil.GetOutList();
        //System.out.println(outlist);
        System.out.print("分類的類型:");
        for(int i =0 ;i<outlist.size();i++)
            System.out.print(outlist.get(i)+"  ");
        System.out.println();
        System.out.println("訓練集的數量:"+alllist.size());

        BPNN bpnn = new BPNN();
        // 訓練
        System.out.println("Train Start!");
        System.out.println(".............");
        bpnn.Train(in_num, out_num, alllist);
        System.out.println("Train End!");

        // 測試
        DataUtil testUtil = new DataUtil();

        testUtil.NormalizeData("F:\\實訓\\code\\BPNN_three\\data\\test.txt");

        testUtil.SetTypeNum(3); // 設置輸出類型的數量
        testUtil.ReadFile("F:\\實訓\\code\\BPNN_three\\data\\test.txt", ",", 1);

        ArrayList<ArrayList<Double>> testList = new ArrayList<ArrayList<Double>>();
        ArrayList<ArrayList<Double>> resultList = new ArrayList<ArrayList<Double>>();
        ArrayList<String> normallist = new ArrayList<String>(); // 存放測試集標準的輸出字符串
        ArrayList<String> resultlist = new ArrayList<String>(); // 存放測試集計算後的輸出字符串

        int right = 0;          // 分類正確的數量
        int type_num = 0;       // 類型的數量
        int all_num = 0;        //測試集的數量
        type_num = outlist.size();

        testList = testUtil.GetList();          // 獲取測試數據
        normallist = testUtil.GetCheckList(); 

        //int errorcount = 0; // 分類錯誤的數量
        resultList = bpnn.ForeCast(testList);   // 測試
        all_num  = resultList.size();

        //resultList:[-,-,-] normallist:[-] outlist:[-,-,-]
        //System.out.println(resultList);
        //System.out.println(normallist);
        //System.out.println(outlist);

        //臨時存放結果
        ArrayList<String> Temp = new ArrayList<String>();
        //resultList=[30][3] 這裏的輸出有問題???解決方式:增加一個臨時存放結果的數組
        for (int i = 0; i < resultList.size(); i++) {
            String checkString = "unknow";
            for (int j = 0; j < type_num; j++) {
                //System.out.println(resultList.get(i).get(j));
                if( resultList.get(i).get(j) == 1.0 ){
                    //System.out.println(outlist.get(j));
                    checkString = outlist.get(j);
                    Temp.add(checkString);
                }
                else{
                    resultlist.add(checkString);
                }
            }

           /* if(checkString.equals("unknow"))
                errorcount++;*/

            //normallist.get(i)爲實際的判定值
            if(checkString.equals(normallist.get(i)))
                right++;
        }
        //System.out.println(Temp);

        testUtil.WriteFile("F:\\實訓\\code\\BPNN_three\\data\\result.txt",testList,in_num,Temp);

        System.out.println("測試集的數量:"+ all_num );
        System.out.println("分類正確的數量:"+ right );
        //System.out.println("分類正確的數量:"+(new Double(right)).intValue());
        System.out.println("算法的分類正確率爲:"+  (new Double( (double) right/all_num )));
        System.out.println("分類結果存儲在:F:\\實訓\\code\\BPNN_three\\data\\result.txt");     

        //bpnn.GetError(1);

    }
}

(5)運行截圖

1

2
…………………………………………….一共30組

(6)參考資料

1.原作者博客:http://blog.csdn.net/u010858605/article/details/72898178
2.數據集下載:http://archive.ics.uci.edu/ml/index.php
3.歸一化處理:http://www.cnblogs.com/chaosimple/p/3227271.html
4.one-hot編程:http://www.cnblogs.com/daguankele/p/6595470.html
5.delta學習:http://blog.csdn.net/u012562273/article/details/56297648
6.機器學習之BP神經網絡(三) : https://zhuanlan.zhihu.com/p/28993795

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章