【caffe源碼研究】第四章:完整案例源碼篇(1) :LeNetSolver初始化

在訓練lenet的train_lenet.sh中內容爲:

./build/tools/caffe train –solver=examples/mnist/lenet_solver.prototxt

由此可知,訓練網咯模型是由tools/caffe.cpp生成的工具caffe在模式train下完成的。
初始化過程總的來說,從main()、train()中創建Solver,在Solver中創建Net,在Net中創建Layer.

1. 程序入口

找到caffe.cpp的main函數中,通過GetBrewFunction(caffe::string(argv[1]))()調用執行train()函數。
train中,通過參數-examples/mnist/lenet_solver.prototxt把solver參數讀入solver_param中。
隨後註冊並定義solver的指針

  shared_ptr<caffe::Solver<float> > 
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param))

調用solver的Solver()方法。多個GPU涉及到GPU間帶異步處理問題.

if (gpus.size() > 1) {
    caffe::P2PSync<float> sync(solver, NULL, solver->param());
    sync.run(gpus);
} else {
    LOG(INFO) << "Starting Optimization";
    solver->Solve();
}

2. Solver的創建

在1中,Solver的指針solver是通過SolverRegistry::CreateSolver創建的,CreateSolver函數中值得注意帶是return registry[type](param)

  // Get a solver using a SolverParameter.
  static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
    const string& type = param.type();
    CreatorRegistry& registry = Registry();
    CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
        << " (known types: " << SolverTypeListString() << ")";
    return registry[type](param);
  }

其中:

registry是一個map<string,Creator>: typedef std::map<string, Creator> CreatorRegistry
其中Creator是一個函數指針類型:typedef Solver<Dtype>* (*Creator)(const SolverParameter&) ``registry[type]爲一個函數指針變量,在Lenet5中,此處具體的值爲caffe::Creator_SGDSolver<float>(caffe::SolverParameter const&)
其中Creator_SGDSolver在以下宏中定義, REGISTER_SOLVER_CLASS(SGD)

該宏完全展開得到的內容爲:

template <typename Dtype>                                                    \
  Solver<Dtype>* Creator_SGDSolver(                                       \
      const SolverParameter& param)                                            \
  {                                                                            \
    return new SGDSolver<Dtype>(param);                                     \
  }                                                                            \
  static SolverRegisterer<float> g_creator_f_SGD("SGD", Creator_SGDSolver<float>);    \
  static SolverRegisterer<double> g_creator_d_SGD("SGD", Creator_SGDSolver<double>)

從上可以看出,registry[type](param)中實際上調用了SGDSolver帶構造方法,事實上,網絡是在SGDSolver的構造方法中初始化的。

SGDSolver的定義如下:

template <typename Dtype>
class SGDSolver : public Solver<Dtype> {
 public:
  explicit SGDSolver(const SolverParameter& param)
      : Solver<Dtype>(param) { PreSolve(); }
  explicit SGDSolver(const string& param_file)
      : Solver<Dtype>(param_file) { PreSolve(); }
......

SGDSolver繼承與Solver<Dtype>,因而new SGDSolver<Dtype>(param)將執行Solver<Dtype>的構造函數,然後調用自身構造函數。整個網絡帶初始化即在這裏面完成.

3. Solver::Solve()函數

在這個函數裏面,程序執行完網絡的完整訓練過程。
核心代碼如下:

template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {

  Step(param_.max_iter() - iter_);
  //..
    Snapshot();
  //..

  // some additional display 
  // ...
}

說明:

值得關注的代碼是Step(),在該函數中,值得了param_.max_iter()輪迭代(10000)
Snapshot()中序列化model到文件.

4. Solver::Step()函數

template <typename Dtype>
void Solver<Dtype>::Step(int iters) {

  //10000輪迭代
  while (iter_ < stop_iter) {

    // 每隔500輪進行一次測試
    if (param_.test_interval() && iter_ % param_.test_interval() == 0
        && (iter_ > 0 || param_.test_initialization())
        && Caffe::root_solver()) {
      // 測試網絡,實際是執行前向傳播計算loss
      TestAll();
    }

    // accumulate the loss and gradient
    Dtype loss = 0;
    for (int i = 0; i < param_.iter_size(); ++i) {
      // 執行前向反向傳播,前向計算損失loss,並計算loss關於權值的偏導
      loss += net_->ForwardBackward(bottom_vec);
    }

    // 平滑loss,計算結果用於輸出調試等
    loss /= param_.iter_size();
    // average the loss across iterations for smoothed reporting
    UpdateSmoothedLoss(loss, start_iter, average_loss);

    // 通過反向傳播計算的偏導更新權值
    ApplyUpdate();

  }
}

(1). Solver::TestAll()函數

在TestAll()中,調用Test(test_net_id)對每個測試網絡test_net(不是訓練網絡train_net)進行測試。在Lenet中,只有一個測試網絡,所以只調用一次Test(0)進行測試。
Test()函數裏面做了兩件事:

  • 前向計算網絡,得到網絡損失
  • 通過測試網絡的第11層accuracy層,與第12層loss層結果統計accuracy與loss信息。

(2). Net::ForwardBackward()函數

Dtype ForwardBackward(const vector<Blob<Dtype>* > & bottom) {
    Dtype loss;
    Forward(bottom, &loss);
    Backward();
    return loss;
  }

說明:

  • 前向計算。計算網絡損失loss,參考 (Caffe,LeNet)前向計算(五)
  • 反向傳播。計算loss關於網絡權值的偏導,參考 (Caffe,LeNet)反向傳播(六)

(3). Solver::ApplyUpdate()函數

根據反向傳播階段計算的loss關於網絡權值的偏導,使用配置的學習策略,更新網絡權值從而完成本輪學習。

5. 訓練完畢

至此,網絡訓練優化完成。在第3部分solve()函數中,最後對訓練網絡與測試網絡再執行一輪額外的前行計算求得loss,以進行測試。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章