轉《視覺SLAM十四講》ch6/gaussNewton.cpp

轉自《視覺SLAM十四講》ch6/gaussNewton.cpp
自己稍微加了點點理解的備註

#include <Eigen/Core>
#include <Eigen/Dense>
#include <chrono>
#include <iostream>
#include <opencv2/opencv.hpp>

using namespace std;
using namespace Eigen;

int main(int argc, char **argv) {
  double ar = 1.0, br = 2.0, cr = 1.0;   // 真實參數值
  double ae = 2.0, be = -1.0, ce = 5.0;  // 估計參數值,初始值
  int N = 100;                           // 數據點
  double w_sigma = 1.0;                  // 噪聲Sigma值,可調整
  double inv_sigma = 1.0 / w_sigma;
  cv::RNG rng;  // OpenCV隨機數產生器

  vector<double> x_data, y_data;  // 數據
  for (int i = 0; i < N; i++) {
    double x = i / 100.0;
    x_data.push_back(x);
    y_data.push_back(exp(ar * x * x + br * x + cr) +
                     rng.gaussian(w_sigma * w_sigma));
  }

  // 開始Gauss-Newton迭代
  int iterations = 100;           // 迭代次數
  double cost = 0, lastCost = 0;  // 本次迭代的cost和上一次迭代的cost

  chrono::steady_clock::time_point t1 =
      chrono::steady_clock::now();  //獲得當前時間
  for (int iter = 0; iter < iterations; iter++) {
    Matrix3d H = Matrix3d::Zero();  // Hessian = J^T W^{-1} J in Gauss-Newton
    Vector3d b = Vector3d::Zero();  // bias
    cost = 0;

    for (int i = 0; i < N; i++) {
      double xi = x_data[i], yi = y_data[i];  // 第i個數據點
      double error = yi - exp(ae * xi * xi + be * xi + ce);
      Vector3d J;                                          // 雅可比矩陣
      J[0] = -xi * xi * exp(ae * xi * xi + be * xi + ce);  // de/da
      J[1] = -xi * exp(ae * xi * xi + be * xi + ce);       // de/db
      J[2] = -exp(ae * xi * xi + be * xi + ce);            // de/dc

      H += inv_sigma * inv_sigma * J * J.transpose();
      b += -inv_sigma * inv_sigma * error * J;

      cost += error * error;  //用cost目標函數判斷每次迭代計算的結果好不好
    }                         //所有數據點的H,b都加起來

    // 求解線性方程 Hx=b
    Vector3d dx = H.ldlt().solve(b);
    if (isnan(dx[0])) {
      cout << "result is nan!" << endl;
      break;
    }  // H有問題,跳出計算

    if (iter > 0 && cost >= lastCost) {
      cout << "cost: " << cost << ">= last cost: " << lastCost << ", break."
           << endl;
      break;
    }  // cost目標函數判斷,△x足夠小,參數幾乎無變化,跳出迭代

    ae += dx[0];  //不斷更新優化ae,be,ce,使得cost越來越小,
    be += dx[1];  //即△x足夠小,ae等參數變化不大
    ce += dx[2];

    lastCost = cost;

    cout << "total cost: " << cost << ", \t\tupdate: " << dx.transpose()
         << "\t\testimated params: " << ae << "," << be << "," << ce << endl;
  }

  chrono::steady_clock::time_point t2 = chrono::steady_clock::now();
  chrono::duration<double> time_used =
      chrono::duration_cast<chrono::duration<double>>(t2 - t1);
  cout << "solve time cost = " << time_used.count() << " seconds. " << endl;

  cout << "estimated abc = " << ae << ", " << be << ", " << ce << endl;
  return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章