在訓練lenet的train_lenet.sh中內容爲:
./build/tools/caffe train –solver=examples/mnist/lenet_solver.prototxt
由此可知,訓練網咯模型是由tools/caffe.cpp
生成的工具caffe在模式train下完成的。
初始化過程總的來說,從main()、train()中創建Solver,在Solver中創建Net,在Net中創建Layer.
1. 程序入口
找到caffe.cpp的main函數中,通過GetBrewFunction(caffe::string(argv[1]))()
調用執行train()函數。
train中,通過參數-examples/mnist/lenet_solver.prototxt
把solver參數讀入solver_param中。
隨後註冊並定義solver的指針
shared_ptr<caffe::Solver<float> >
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param))
調用solver的Solver()方法。多個GPU涉及到GPU間帶異步處理問題.
if (gpus.size() > 1) {
caffe::P2PSync<float> sync(solver, NULL, solver->param());
sync.run(gpus);
} else {
LOG(INFO) << "Starting Optimization";
solver->Solve();
}
2. Solver的創建
在1中,Solver的指針solver是通過SolverRegistry::CreateSolver
創建的,CreateSolver
函數中值得注意帶是return registry[type](param)
// Get a solver using a SolverParameter.
static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
const string& type = param.type();
CreatorRegistry& registry = Registry();
CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
<< " (known types: " << SolverTypeListString() << ")";
return registry[type](param);
}
其中:
registry是一個map<string,Creator>: typedef std::map<string, Creator> CreatorRegistry
其中Creator是一個函數指針類型:typedef Solver<Dtype>* (*Creator)(const SolverParameter&) ``registry[type]
爲一個函數指針變量,在Lenet5中,此處具體的值爲caffe::Creator_SGDSolver<float>(caffe::SolverParameter const&)
其中Creator_SGDSolver
在以下宏中定義, REGISTER_SOLVER_CLASS(SGD)
該宏完全展開得到的內容爲:
template <typename Dtype> \
Solver<Dtype>* Creator_SGDSolver( \
const SolverParameter& param) \
{ \
return new SGDSolver<Dtype>(param); \
} \
static SolverRegisterer<float> g_creator_f_SGD("SGD", Creator_SGDSolver<float>); \
static SolverRegisterer<double> g_creator_d_SGD("SGD", Creator_SGDSolver<double>)
從上可以看出,registry[type](param)
中實際上調用了SGDSolver帶構造方法,事實上,網絡是在SGDSolver的構造方法中初始化的。
SGDSolver的定義如下:
template <typename Dtype>
class SGDSolver : public Solver<Dtype> {
public:
explicit SGDSolver(const SolverParameter& param)
: Solver<Dtype>(param) { PreSolve(); }
explicit SGDSolver(const string& param_file)
: Solver<Dtype>(param_file) { PreSolve(); }
......
SGDSolver繼承與Solver<Dtype>
,因而new SGDSolver<Dtype>(param)
將執行Solver<Dtype>
的構造函數,然後調用自身構造函數。整個網絡帶初始化即在這裏面完成.
3. Solver::Solve()函數
在這個函數裏面,程序執行完網絡的完整訓練過程。
核心代碼如下:
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
Step(param_.max_iter() - iter_);
//..
Snapshot();
//..
// some additional display
// ...
}
說明:
值得關注的代碼是Step(),在該函數中,值得了param_.max_iter()
輪迭代(10000)
在Snapshot()
中序列化model到文件.
4. Solver::Step()函數
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
//10000輪迭代
while (iter_ < stop_iter) {
// 每隔500輪進行一次測試
if (param_.test_interval() && iter_ % param_.test_interval() == 0
&& (iter_ > 0 || param_.test_initialization())
&& Caffe::root_solver()) {
// 測試網絡,實際是執行前向傳播計算loss
TestAll();
}
// accumulate the loss and gradient
Dtype loss = 0;
for (int i = 0; i < param_.iter_size(); ++i) {
// 執行前向反向傳播,前向計算損失loss,並計算loss關於權值的偏導
loss += net_->ForwardBackward(bottom_vec);
}
// 平滑loss,計算結果用於輸出調試等
loss /= param_.iter_size();
// average the loss across iterations for smoothed reporting
UpdateSmoothedLoss(loss, start_iter, average_loss);
// 通過反向傳播計算的偏導更新權值
ApplyUpdate();
}
}
(1). Solver::TestAll()函數
在TestAll()中,調用Test(test_net_id)
對每個測試網絡test_net
(不是訓練網絡train_net
)進行測試。在Lenet中,只有一個測試網絡,所以只調用一次Test(0)進行測試。
Test()函數裏面做了兩件事:
- 前向計算網絡,得到網絡損失
- 通過測試網絡的第11層accuracy層,與第12層loss層結果統計accuracy與loss信息。
(2). Net::ForwardBackward()函數
Dtype ForwardBackward(const vector<Blob<Dtype>* > & bottom) {
Dtype loss;
Forward(bottom, &loss);
Backward();
return loss;
}
說明:
- 前向計算。計算網絡損失loss,參考 (Caffe,LeNet)前向計算(五)
- 反向傳播。計算loss關於網絡權值的偏導,參考 (Caffe,LeNet)反向傳播(六)
(3). Solver::ApplyUpdate()函數
根據反向傳播階段計算的loss關於網絡權值的偏導,使用配置的學習策略,更新網絡權值從而完成本輪學習。
5. 訓練完畢
至此,網絡訓練優化完成。在第3部分solve()函數中,最後對訓練網絡與測試網絡再執行一輪額外的前行計算求得loss,以進行測試。