算法#03--詳解最小二乘法原理和代碼

最小二乘法原理

最小二乘法的目標:求誤差的最小平方和,對應有兩種:線性和非線性。線性最小二乘的解是closed-form(如下文),而非線性最小二乘沒有closed-form,通常用迭代法求解(如高斯牛頓迭代法,本文不作介紹)。

【首先得到線性方程組】

1.概念

最小二乘法(又稱最小平方法)是一種數學優化技術。它通過最小化誤差的平方和尋找數據的最佳函數匹配。

利用最小二乘法可以簡便地求得未知的數據,並使得這些求得的數據與實際數據之間誤差的平方和爲最小。

最小二乘法還可用於曲線擬合。

2.原理

函數原型:

這裏寫圖片描述

已知:

(x0,y0),(x1,y1)…(xi,yi)…(xn,yn)個點,n>=k。

偏差平方和:

這裏寫圖片描述

偏差平方和最小值可以通過使偏導數等於零得到:

這裏寫圖片描述

簡化左邊等式有:

這裏寫圖片描述

寫成矩陣形式:公式①

這裏寫圖片描述

將這個範德蒙得矩陣化簡後可得到:公式②

這裏寫圖片描述

也就是說X*A=Y,那麼A = (X’*X)-1*X’*Y,便得到了係數矩陣A,同時,我們也就得到了擬合曲線。

高斯消元法

【然後解線性方程組,即公式①】

1.概念

數學上,高斯消元法(或譯:高斯消去法)(英語:Gaussian Elimination),是線性代數中的一個算法,可用來爲線性方程組求解,求出矩陣的秩,以及求出可逆方陣的逆矩陣。當用於一個矩陣時,高斯消元法會產生出一個“行梯陣式”。

2.原理

這裏寫圖片描述

這裏寫圖片描述

這裏寫圖片描述

這裏寫圖片描述

這裏寫圖片描述

這裏寫圖片描述

這裏寫圖片描述

3.僞代碼

這個算法和上面談到的有點不同,它由絕對值最大的部分開始做起,這樣可以改善算法的穩定性。本算法由左至右地計算,每作出以下三個步驟,才跳到下一列和下一行:

  • 定出i列的絕對值最大的一個非0的數,將第i行的值與該行交換,使得該行擁有該列的最大值;
  • 將i列的數字除以該數,使得i列i行的數成爲1;
  • 第(i+1)行以下(包括第(j+1)行)所有元素都轉化爲0。

所有步驟完成後,這個矩陣會變成一個行梯矩陣,再用代入法就可以求解該方程組。

 i = 1
 j = 1
 while (i ≤ m and j ≤ n) do
   Find pivot in column j, starting in row i    // 從第i行開始,找出第j列中的最大值(i、j值應保持不變)  
   maxi = i
   for k = i+1 to m do
     if abs(A[k,j]) > abs(A[maxi,j]) then
       maxi = k   // 使用交換法找出最大值(絕對值最大)
     end if
   end for
   if A[maxi,j] ≠ 0 then  // 判定找到的絕對值最大值是否爲零:若不爲零就進行以下操作;若爲零則說明該列第(i+1)行以下(包括第(i+1)行)均爲零,不需要再處理,直接跳轉至第(j+1)列第(i+1)行
     swap rows i and maxi, but do not change the value of i   // 將第i行與找到的最大值所在行做交換,保持i值不變(i值記錄了本次操作的起始行)
     Now A[i,j] will contain the old value of A[maxi,j].
     divide each entry in row i by A[i,j]    // 將交換後的第i行歸一化(第i行所有元素分別除以A[i,j])
     Now A[i,j] will have the value 1.
     for u = i+1 to m do    // 第j列中,第(i+1)行以下(包括第(i+1)行)所有元素都減去A[i,j],直到第j列的i+1行以後元素均為零
       subtract A[u,j] * row i from row u
       Now A[u,j] will be 0, since A[u,j] - A[i,j] * A[u,j] = A[u,j] - 1 * A[u,j] = 0.
     end for
     i = i + 1   
   end if
   j = j + 1  // 第j列中,第(i+1)行以下(包括第(i+1)行)所有元素均爲零。移至第(j+1)列,從第(i+1)行開始重複上述步驟。
 end while

代碼

public class CurveFitting {
     ///<summary>
    ///最小二乘法擬合二元多次曲線
    ///例如y=ax+b
    ///其中MultiLine將返回a,b兩個參數。
    ///a對應MultiLine[1]
    ///b對應MultiLine[0]
    ///</summary>
    ///<param name="arrX">已知點的x座標集合</param>
    ///<param name="arrY">已知點的y座標集合</param>
    ///<param name="length">已知點的個數</param>
    ///<param name="dimension">方程的最高次數</param>
    public static double[] MultiLine(double[] arrX, double[] arrY, int length, int dimension) {
        int n = dimension + 1;                 //dimension次方程需要求 dimension+1個 係數              
        double[][] Guass = new double[n][n + 1];      
        for (int i = 0; i < n; i++){ //求矩陣公式①
            int j;
            for (j = 0; j < n; j++){
                Guass[i][j] = SumArr(arrX, j + i, length);//公式①等號左邊第一個矩陣,即Ax=b中的A
            }
            Guass[i][j] = SumArr(arrX, i, arrY, 1, length);//公式①等號右邊的矩陣,即Ax=b中的b
        }        

        return ComputGauss(Guass, n);//高斯消元法
    }

    //求數組的元素的n次方的和,即矩陣A中的元素
    private static double SumArr(double[] arr, int n, int length) {
        double s = 0;
        for (int i = 0; i < length; i++){
            if (arr[i] != 0 || n != 0){
                s = s + Math.pow(arr[i], n);
            }
            else{
                s = s + 1;
            }
        }
        return s;
    }

    //求數組的元素的n次方的和,即矩陣b中的元素
    private static double SumArr(double[] arr1, int n1, double[] arr2, int n2, int length) {
        double s = 0;
        for (int i = 0; i < length; i++)
        {
            if ((arr1[i] != 0 || n1 != 0) && (arr2[i] != 0 || n2 != 0))
                s = s + Math.pow(arr1[i], n1) * Math.pow(arr2[i], n2);
            else
                s = s + 1;
        }
        return s;        
    }

    //高斯消元法解線性方程組
    private static double[] ComputGauss(double[][] Guass, int n) {
        int i, j;
        int k, m;
        double temp;
        double max;
        double s;
        double[] x = new double[n];

        for (i = 0; i < n; i++) {
            x[i] = 0.0;//初始化
        }

        for (j = 0; j < n; j++) {
            max = 0;
            k = j;

            // 從第i行開始,找出第j列中的最大值(i、j值應保持不變)  
            for (i = j; i < n; i++) {
                if (Math.abs(Guass[i][j]) > max){
                    max = Guass[i][j];// 使用交換法找出最大值(絕對值最大)
                    k = i;
                }
            }

            if (k != j) {
                //將第j行與找到的最大值所在行做交換,保持i值不變(j值記錄了本次操作的起始行)
                for (m = j; m < n + 1; m++) {
                    temp = Guass[j][m];
                    Guass[j][m] = Guass[k][m];
                    Guass[k][m] = temp;
                }
            }

            if (max == 0) {
                // "此線性方程爲奇異線性方程" 
                return x;
            }

            // 第m列中,第(j+1)行以下(包括第(j+1)行)所有元素都減去Guass[j][m] * s / (Guass[j][j])
            //直到第m列的i+1行以後元素均爲零
            for (i = j + 1; i < n; i++) {
                s = Guass[i][j];                
                for (m = j; m < n + 1; m++) {
                    Guass[i][m] = Guass[i][m] - Guass[j][m] * s / (Guass[j][j]);                 
                }
            }
        }//結束for (j=0;j<n;j++)

        //回代過程(見公式4.1.5)
        for (i = n - 1; i >= 0; i--) {
            s = 0;
            for (j = i + 1; j < n; j++) {
                s = s + Guass[i][j] * x[j];
            }
            x[i] = (Guass[i][n] - s) / Guass[i][i];
        }

        return x;
    }//返回值是函數的係數

    public static void main(String[] args)  {
        double[] x = {0, 1, 2, 3, 4, 5, 6, 7};
        double[] y = {0, 1, 4, 9, 16, 25, 36, 49};
        double[] a = MultiLine(x, y, 8, 2);

        for(int i =0; i <a.length;i++){
            System.out.println(a[i]);
        }
    }  
}

輸出:

0.708333333333342
-0.37500000000000583
1.0416666666666674

取整就得到y=x^2。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章