【caffe源碼研究】第三章:源碼篇(4) :Solver

一個典型的solver文件如下

# The train/test net protocol buffer definition
net: "examples/mnist/lenet_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 10000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
# solver mode: CPU or GPU
solver_mode: CPU

Solver通過協調Net的前向推斷計算和反向梯度計算(forward inference and backward gradients),來對參數進行更新,從而達到減小loss的目的。Caffe模型的學習被分爲兩個部分:由Solver進行優化、更新參數,由Net計算出loss和gradient。

caffe 支持的solvers包括:

  • Stochastic Gradient Descent (type: “SGD”),隨機梯度下降
  • AdaDelta (type: “AdaDelta”)
  • Adaptive Gradient (type: “AdaGrad”),自適應梯度
  • Adam (type: “Adam”)
  • Nesterov’s Accelerated Gradient (type: “Nesterov”)
  • RMSprop (type: “RMSProp”)

solver作用有

  • 提供優化日誌支持、創建用於學習的訓練網絡、創建用於評估的測試網絡
  • 通過調用forward / backward迭代地優化,更新權值
  • 週期性地評估測試網絡
  • 通過優化了解model及solver的狀態

每一個Solver都會繼承Solve和Step函數,而每個Solver中獨有的僅僅是ApplyUpdate這個函數裏面執行的內容不一樣,接口是一致的,這也就類似於工廠生產出來的產品一樣功能一樣,細節上有差異。接下里我們看看Solver中的關鍵函數。

核心代碼如下:

/**
 * @brief An interface for classes that perform optimization on Net%s.
 *
 * Requires implementation of ApplyUpdate to compute a parameter update
 * given the current state of the Net parameters.
 */
template <typename Dtype>
class Solver {
 public:
  explicit Solver(const SolverParameter& param,
      const Solver* root_solver = NULL);
  explicit Solver(const string& param_file, const Solver* root_solver = NULL);
  void Init(const SolverParameter& param);
  void InitTrainNet();
  void InitTestNets();
 ...
  // The main entry of the solver function. In default, iter will be zero. Pass
  // in a non-zero iter number to resume training for a pre-trained net.
  virtual void Solve(const char* resume_file = NULL);
  inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
  void Step(int iters);
...

 protected:
  // Make and apply the update value for the current iteration.
  virtual void ApplyUpdate() = 0;
  ...

  SolverParameter param_;
  int iter_;
  int current_step_;
  shared_ptr<Net<Dtype> > net_;
  vector<shared_ptr<Net<Dtype> > > test_nets_;
  vector<Callback*> callbacks_;
  vector<Dtype> losses_;
  Dtype smoothed_loss_;

  // The root solver that holds root nets (actually containing shared layers)
  // in data parallelism
  const Solver* const root_solver_;
...
};

說明:

  • shared_ptr<Net<Dtype>> net_爲訓練網絡的指針,vector<shared_ptr<Net<Dtype>>> test_nets爲測試網絡的指針組,可見測試網絡可以有多個
  • 一般來說訓練網絡跟測試網絡在實現上會有區別,但是絕大部分網絡層是相同的。
  • 不同的模型訓練方法通過重載函數ComputeUpdateValue( )實現計算update參數的核心功能
  • caffe.cpp中的train( )函數訓練模型,在這裏實例化一個Solver對象,初始化後調用了Solver中的Solve( )方法。而這個Solve( )函數主要就是在迭代運行下面這兩個函數。
    • ComputeUpdateValue();
    • net_->Update();

每一次迭代過稱中:

  • 調用Net的前向過程計算出輸出和loss;
  • 調用Net的後向過程計算出梯度(loss對每層的權重w和偏置b求導);
  • 根據Solver方法,利用梯度更新參數;
  • 根據學習率(learning rate),歷史數據和求解方法更新solver的狀態,使權重從初始化狀態逐步更新到最終的學習到的狀態。solvers的運行模式有CPU/GPU兩種模式。

Solver中Solve函數的流程圖如下:

這裏寫圖片描述

Solver類中Step函數流程圖:

這裏寫圖片描述

總結一下Solve執行中的關鍵步驟

Created with Raphaël 2.1.0Solve Step TestAll結束

其中Step步驟

Created with Raphaël 2.1.0Step是否大於最大迭代次數?ForwardBackwardUpdateSmoothedLossApplyUpdate結束yesno

其中
Net::ForwardBackward()函數如下,在Net小節中再詳細介紹。

Dtype ForwardBackward(const vector<Blob<Dtype>* > & bottom) {
    Dtype loss;
    Forward(bottom, &loss);
    Backward();
    return loss;
  }

說明:

  • 前向計算。計算網絡損失loss.
  • 反向傳播。計算loss關於網絡權值的偏導.

而不同的Solver子類實現不同的ApplyUpdate函數。例如SGDSolver的函數實現如下

template <typename Dtype>
void SGDSolver<Dtype>::ApplyUpdate() {
  CHECK(Caffe::root_solver());
  //得到學習率
  Dtype rate = GetLearningRate();
  if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
    LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
  }
  ClipGradients();
  for (int param_id = 0; param_id < this->net_->learnable_params().size();
       ++param_id) {
    Normalize(param_id);
    Regularize(param_id);
    ComputeUpdateValue(param_id, rate);
  }
  this->net_->Update();
}

優化目標是
這裏寫圖片描述

Normalize是歸一化操作。Normalize核心代碼如下

template <typename Dtype>
void SGDSolver<Dtype>::Normalize(int param_id) {

  // Scale gradient to counterbalance accumulation.
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();

  caffe_scal(net_params[param_id]->count(), accum_normalization, net_params[param_id]->mutable_cpu_diff());

}

其中caffe_scal 函數:

void caffe_scal<float>(const int N, const float alpha, float *X) {
  cblas_sscal(N, alpha, X, 1);
}

功能:X = alpha*X, N: X中element的個數

其中net_params 就是需要學習更新的參數。

Regularize函數大致類似。L2正則執行的是

losswij=decaywij+losswij

下面看ComputeUpdateValue函數。
計算公式

vij=lrratelosswij+momentumvij

losswij=vij
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const vector<float>& net_params_lr = this->net_->params_lr();
  // momentum = 0.9 in lenet
  Dtype momentum = this->param_.momentum();
  // local_rate = lr_mult * global_rate
  // lr_mult爲該層學習率乘子,在lenet_train_test.prototxt中設置
  Dtype local_rate = rate * net_params_lr[param_id];

  // Compute the update to history, then copy it to the parameter diff.

  ...
    // axpby means ax_plus_by. i.e., y = ax + by
    // 計算新的權值更新變化值 \delta w,結果保存在歷史權值變化中
    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
              net_params[param_id]->cpu_diff(), momentum,
              history_[param_id]->mutable_cpu_data());

    // 從歷史權值變化中把變化值 \delta w 保存到歷史權值中diff中
    caffe_copy(net_params[param_id]->count(),
        history_[param_id]->cpu_data(),
        net_params[param_id]->mutable_cpu_diff());
   ... 
}

最後一步是執行this->net_->Update();更新參數,計算公式

wij=wij+(1)losswij

這裏寫圖片描述

關鍵代碼


template <typename Dtype>
void Net<Dtype>::Update() {
  for (int i = 0; i < learnable_params_.size(); ++i) {
    learnable_params_[i]->Update();
  }
}

其中,learnable_params_是一個blob的vector,它的update核心如下

caffe_axpy<Dtype>(count_, Dtype(-1),
        static_cast<const Dtype*>(diff_->cpu_data()),
        static_cast<Dtype*>(data_->mutable_cpu_data()));
發佈了319 篇原創文章 · 獲贊 77 · 訪問量 31萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章