Java實現線性迴歸模型算法

今天跟大家一起學習機器學習比較簡單的一個算法,也就是線性迴歸算法。

讓我們通過一個例子開始:這個例子就是預測住房價格的,我們要使用一個數據集,數據集包含一個地方的住房價格,這裏我們要根據不同房屋尺寸所售出的價格,畫出我們的數據集,比方說,如果你朋友的房子是1250平方尺大小,你要告訴他這個房子能賣多少錢。那麼,你可以做的一件事就是構建一個模型,也許是條直線,從這個數據模型上來看,也許你可以告訴你的朋友,他能以大約220000(美元)左右的價格賣掉這個房子。這就是監督學習算法的一個例子。

它被稱作監督學習是因爲對於每個數據來說,我們給出了“正確的答案”,即告訴我們:根據我們的數據來說,房子實際的價格是多少,而且,更具體來說,這是一個迴歸問題。迴歸一詞指的是,我們根據之前的數據預測出一個準確的輸出值,對於這個例子就是價格。同時,還有另外一種最常見的監督學洗方式,叫做分類問題,當我們想要預測離散的輸出值,例如,我們正在薛兆癌症腫瘤,並想要確定腫瘤是良性還是惡性的,這就是0/1離散輸出的問題。更近一步說,在監督學習中我們有一個數據集,這個數據集被稱作是訓練集。

下面就是實現一元線性迴歸模型的Java版本的代碼,其中繪製數據集,和繪製迴歸模型使用的是JfreeChart,核心代碼如下:

package cn.rocket.ml;


import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import cn.rocket.data.DataSet;
import cn.rocket.utils.ScatterPlot;

public class LinearRegression {

    private double theta0 = 0.0 ;  //截距
    private double theta1 = 0.0 ;  //斜率
    private double alpha = 0.01 ;  //學習速率

    private int max_itea = 20000 ; //最大迭代步數

    private DataSet dataSet = new DataSet() ;

    public  LinearRegression() throws IOException{
        dataSet.loadDataFromTxt("datas/house_price.txt", ",",1);
    }


    public double predict(double x){
        return theta0+theta1*x ;
    }

    public double calc_error(double x, double y) {
        return predict(x)-y;
    }



    public void gradientDescient(){
        double sum0 =0.0 ;
        double sum1 =0.0 ;

        for(int i = 0 ; i < dataSet.getSize() ;i++) {
            sum0 += calc_error(dataSet.getDatas().get(i).get(0), dataSet.getLabels().get(i)) ;
            sum1 += calc_error(dataSet.getDatas().get(i).get(0), dataSet.getLabels().get(i))*dataSet.getDatas().get(i).get(0) ;
        }

        this.theta0 = theta0 - alpha*sum0/dataSet.getSize() ; 
        this.theta1 = theta1 - alpha*sum1/dataSet.getSize() ; 

    }

    public void lineGre() {
        int itea = 0 ;
        while( itea< max_itea){
            //System.out.println(error_rate);
            System.out.println("The current step is :"+itea);
            System.out.println("theta0 "+theta0);
            System.out.println("theta1 "+theta1);
            System.out.println();
            gradientDescient();
            itea ++ ;
        }
    } ;

    public static void main(String[] args) throws IOException {
        LinearRegression linearRegression = new LinearRegression() ;
        linearRegression.lineGre();
        List<Double> list = new ArrayList<Double>() ;

        for(int i = 0 ; i < linearRegression.dataSet.getSize() ;i++) {
            list.add(linearRegression.dataSet.getDatas().get(i).get(0));
        }


        ScatterPlot.data("Datas", list, linearRegression.dataSet.getLabels(),linearRegression.theta0,linearRegression.theta1);

    }

}

這段代碼值得我們注意的問題有很多,一個是學習步長alpha的設置,如果設置的太大最後結果會無法收斂,但是如果設置的太小訓練會非常緩慢。

下面是結果,我們可以看到,散點圖是訓練數據,紅色的直線表示我們訓練出來的一元線性模型,我們可以看出該模型能對訓練數據做一個較好的線性擬合。

一元線性迴歸模型

該項目的項目源碼我已經放在GitHub上。項目地址:https://github.com/ShengPengYu/MachineLearning

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章