include/caffe/solver_factory.hpp 中的 SolverRegistry 類和 SolverRegisterer 類
/**
* @brief A solver factory that allows one to register solvers, similar to
* layer factory. During runtime, registered solvers could be called by passing
* a SolverParameter protobuffer to the CreateSolver function:
*
* SolverRegistry<Dtype>::CreateSolver(param);
*
* There are two ways to register a solver. Assuming that we have a solver like:
*
* template <typename Dtype>
* class MyAwesomeSolver : public Solver<Dtype> {
* // your implementations
* };
*
* and its type is its C++ class name, but without the "Solver" at the end
* ("MyAwesomeSolver" -> "MyAwesome").
*
* If the solver is going to be created simply by its constructor, in your C++
* file, add the following line:
*
* REGISTER_SOLVER_CLASS(MyAwesome);
*
* Or, if the solver is going to be created by another creator function, in the
* format of:
*
* template <typename Dtype>
* Solver<Dtype>* GetMyAwesomeSolver(const SolverParameter& param) {
* // your implementation
* }
*
* then you can register the creator function instead, like
*
* REGISTER_SOLVER_CREATOR(MyAwesome, GetMyAwesomeSolver)
*
* Note that each solver type should only be registered once.
*/
#ifndef CAFFE_SOLVER_FACTORY_H_
#define CAFFE_SOLVER_FACTORY_H_
#include <map>
#include <string>
#include <vector>
#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
namespace caffe {
// Forward declaration: the registry below only stores and returns
// Solver<Dtype>* pointers, so the full class definition is not needed here.
template <typename Dtype>
class Solver;
template <typename Dtype>
class SolverRegistry {
public:
typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
typedef std::map<string, Creator> CreatorRegistry;
//所有成員函數都是靜態的,通過類名調用
static CreatorRegistry& Registry() {
static CreatorRegistry* g_registry_ = new CreatorRegistry();//g_registry是指向CreatorRegistry
這個map類型的指針,然後直接返回,因爲這個變量是static的,所以即使多次調用這個函數,也只會定義一個g_registry,而且在其他地方修改這個map裏的內容,是存儲在這個map中的。事實上各個Solver的register的過程正是往g_registry指向的那個map裏添加以Solver的type爲key,對應的Creator函數指針爲value的內容。
return *g_registry_;
}
// 添加一個creator指針
static void AddCreator(const string& type, Creator creator) {
CreatorRegistry& registry = Registry();
CHECK_EQ(registry.count(type), 0)
<< "Solver type " << type << " already registered.";
registry[type] = creator;//如果沒有註冊就添加到registor靜態指針指向的map中
}
// Get a solver using a SolverParameter.
static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
const string& type = param.type();//先定義了一個string類型的變量表示Solver的類型
CreatorRegistry& registry = Registry();//通過調用Registry()函數,Registry()中創建CreatorRegistry類的對象,定義了一個key類型爲string,value類型爲Creator
的map:registry.其中Creator
是一個solver函數指針類型,指向的函數的參數爲SolverParameter
類型
CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type//果是一個已經register過的Solver類型,那麼registry.count(type)
應該爲1
<< " (known types: " << SolverTypeListString() << ")";
return registry[type](param);//返回registry中type對應的creator對象,並調用這個creator函數,將creator返回的Solver<Dtype>*
返回
}
static vector<string> SolverTypeList() {
CreatorRegistry& registry = Registry();
vector<string> solver_types;
for (typename CreatorRegistry::iterator iter = registry.begin();
iter != registry.end(); ++iter) {
solver_types.push_back(iter->first);
}
return solver_types;
}
private:
// Solver registry should never be instantiated - everything is done with its
// static variables.
SolverRegistry() {} //構造函數時私有,所以沒辦法創造該類的變量.直接用類名調用
static string SolverTypeListString() {
vector<string> solver_types = SolverTypeList();
string solver_types_str;
for (vector<string>::iterator iter = solver_types.begin();
iter != solver_types.end(); ++iter) {
if (iter != solver_types.begin()) {
solver_types_str += ", ";
}
solver_types_str += *iter;
}
return solver_types_str;
}
};
template <typename Dtype>
class SolverRegisterer {
public:
SolverRegisterer(const string& type,
Solver<Dtype>* (*creator)(const SolverParameter&)) {
// LOG(INFO) << "Registering solver type: " << type;
SolverRegistry<Dtype>::AddCreator(type, creator);
}
};
在 sgd_solver.cpp(SGD Solver 對應的 cpp 文件)末尾使用了 REGISTER_SOLVER_CLASS 這個宏。這個宏會定義一個名爲 Creator_SGDSolver 的函數,這個函數即爲 Creator 類型的指針所指向的函數:它調用 SGDSolver 的構造函數,並返回構造得到的對象的指針。這也就是 Creator 類型函數的作用——構造一個對應類型的 Solver 對象,並將其指針返回。然後這個宏裏又調用了 REGISTER_SOLVER_CREATOR 這個宏,後者分別定義了 SolverRegisterer 這個模板類的 float 和 double 類型的 static 變量,這會去調用各自的構造函數;而在 SolverRegisterer 的構造函數中調用了之前提到的 SolverRegistry 類的 AddCreator 函數,這個函數就是將剛纔定義的 Creator_SGDSolver 這個函數的指針存到 g_registry 指向的 map 裏面。類似地,所有 Solver 對應的 cpp 文件的末尾都調用了這個宏來完成註冊。在所有的 Solver 都註冊之後,我們就可以通過之前描述的方式,通過 g_registry 得到對應的 Creator 函數的指針,並通過調用這個 Creator 函數來構造對應的 Solver。
// NOTE: SolverRegisterer is already defined once earlier in this header.
// Repeating the class template definition here is a redefinition error, so
// this second copy has been dropped; the earlier definition is the only one.
//
下面的 REGISTER_SOLVER_CREATOR 宏分別定義了 SolverRegisterer 這個模板類的 float 和 double 類型的 static 變量,這會去調用各自的構造函數;而在 SolverRegisterer 的構造函數中調用了之前提到的 SolverRegistry 類的 AddCreator 函數,這個函數就是將剛纔定義的 Creator_SGDSolver 這個函數的指針存到 g_registry 指向的 map 裏面。
// Registers creator<float> and creator<double> under the name #type by
// defining two file-static SolverRegisterer objects; their constructors call
// SolverRegistry<...>::AddCreator. The last line intentionally has no
// trailing backslash so the macro ends here and does not swallow the next
// line of the file.
#define REGISTER_SOLVER_CREATOR(type, creator)                                 \
  static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>);    \
  static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>)
下面的 REGISTER_SOLVER_CLASS 宏會定義一個名爲 Creator_×××Solver 的函數,這個函數即爲 Creator 類型的指針所指向的函數:在這個函數中調用了 ×××Solver 的構造函數,並將構造得到的對象指針返回。這也就是 Creator 類型函數的作用——構造一個對應類型的 Solver 對象,並將其指針返回。然後在這個宏裏又調用了 REGISTER_SOLVER_CREATOR 這個宏。
// Defines Creator_<type>Solver<Dtype>, which just heap-allocates a
// <type>Solver<Dtype> from the given parameters, and registers it for both
// float and double through REGISTER_SOLVER_CREATOR.
#define REGISTER_SOLVER_CLASS(type)                                            \
  template <typename Dtype>                                                    \
  Solver<Dtype>* Creator_##type##Solver(const SolverParameter& solver_param) { \
    return new type##Solver<Dtype>(solver_param);                              \
  }                                                                            \
  REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)
} // namespace caffe
#endif // CAFFE_SOLVER_FACTORY_H_
include/caffe/solver.hpp
#ifndef CAFFE_SOLVER_HPP_
#define CAFFE_SOLVER_HPP_
#include <boost/function.hpp>
#include <string>
#include <vector>
#include "caffe/net.hpp"
#include "caffe/solver_factory.hpp"
#include "caffe/util/benchmark.hpp"
/*
 * (1) solver_factory 的 register 和 create 不同類型 Solver 的機制;
 * (2) 通過 signal_handler 來獲取系統信號,並根據用戶或默認的設置進行相應的處理;
 * (3) Solver::Solve 函數的具體實現的分析;
 * (4) SGDSolver::ApplyUpdate 函數的具體實現。
 * 前面三個部分都屬於基類,最後一個屬於 SGDSolver 這個子類。如果用戶想要實現自己的
 * Solver 類,也應該類似地去繼承基類,並實現自己的 ApplyUpdate 函數,在代碼的末尾通過
 * register 宏完成註冊,便可以被成功地調用。
 */
namespace caffe {
/**
按Ctrl-C時,會保存當前訓練時的模型
如果還在訓練終端不小心被關閉時,可以接着上次繼續訓練
*/
// Actions that a client of the solver (e.g. a signal handler installed via
// SetActionFunction) can request while training is running.
namespace SolverAction {
enum Enum {
  NONE = 0,      // Take no special action.
  STOP = 1,      // Stop training; snapshot_after_train decides whether a
                 // final snapshot is written.
  SNAPSHOT = 2   // Take a snapshot, and keep training.
};
}  // namespace SolverAction
/**
 * @brief Type of a function that returns a Solver Action enumeration.
 *
 * Installed with Solver::SetActionFunction(); the solver polls it through
 * GetRequestedAction() to decide whether to snapshot and/or stop early.
 */
typedef boost::function<SolverAction::Enum()> ActionCallback;
/**
* @brief An interface for classes that perform optimization on Net%s.
*
* Requires implementation of ApplyUpdate to compute a parameter update
* given the current state of the Net parameters.
*/
template <typename Dtype>
class Solver {
 public:
  // Construct either from an already-parsed SolverParameter or from the path
  // of a text-format parameter file; both forward to Init().
  explicit Solver(const SolverParameter& param);
  explicit Solver(const string& param_file);
  // Stores the parameters, seeds the RNG and builds the train/test nets.
  void Init(const SolverParameter& param);
  void InitTrainNet();
  void InitTestNets();
  // Client of the Solver optionally may call this in order to set the function
  // that the solver uses to see what action it should take (e.g. snapshot or
  // exit training early).
  void SetActionFunction(ActionCallback func);
  // Returns the action reported by that function, or NONE if none is set.
  SolverAction::Enum GetRequestedAction();
  // The main entry of the solver function. Training continues from the
  // current iter_ (zero after Init). Pass the path of a previously saved
  // solver state to restore it and resume training from that point.
  virtual void Solve(const char* resume_file = NULL);
  inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
  // Runs `iters` training iterations starting from the current iter_.
  void Step(int iters);
  // The Restore method simply dispatches to one of the
  // RestoreSolverStateFrom___ protected methods. You should implement these
  // methods to restore the state from the appropriate snapshot type.
  void Restore(const char* resume_file);
  // The Solver::Snapshot function implements the basic snapshotting utility
  // that stores the learned net. You should implement the SnapshotSolverState()
  // function that produces a SolverState protocol buffer that needs to be
  // written to disk together with the learned net.
  void Snapshot();
  virtual ~Solver() {}
  inline const SolverParameter& param() const { return param_; }
  inline shared_ptr<Net<Dtype> > net() { return net_; }
  inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
    return test_nets_;
  }
  // Number of parameter updates performed so far.
  int iter() const { return iter_; }

  // Invoked at specific points during an iteration
  class Callback {
   protected:
    // Called at the start of each training iteration, before the
    // forward/backward passes.
    virtual void on_start() = 0;
    // Called after the iteration's gradients have been computed, before the
    // parameter update is applied.
    virtual void on_gradients_ready() = 0;

    template <typename T>
    friend class Solver;
  };
  const vector<Callback*>& callbacks() const { return callbacks_; }
  void add_callback(Callback* value) {
    callbacks_.push_back(value);
  }

  // Fails fast if snapshots are enabled but cannot be written.
  void CheckSnapshotWritePermissions();
  /**
   * @brief Returns the solver type.
   */
  virtual inline const char* type() const { return ""; }

 protected:
  // Make and apply the update value for the current iteration.
  virtual void ApplyUpdate() = 0;
  // "<snapshot_prefix>_iter_<iter><extension>"
  string SnapshotFilename(const string extension);
  string SnapshotToBinaryProto();
  string SnapshotToHDF5();
  // The test routine
  void TestAll();
  void Test(const int test_net_id = 0);
  virtual void SnapshotSolverState(const string& model_filename) = 0;
  virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
  virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
  void DisplayOutputBlobs(const int net_id);
  // Folds `loss` into smoothed_loss_, a mean over the last average_loss
  // iterations.
  void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);

  SolverParameter param_;
  // Number of completed weight updates.
  int iter_;
  // Current learning-rate step index (used by the step/multistep policies).
  int current_step_;
  // The training net.
  shared_ptr<Net<Dtype> > net_;
  // Test nets (built on the root solver only).
  vector<shared_ptr<Net<Dtype> > > test_nets_;
  vector<Callback*> callbacks_;
  // Recent per-iteration losses backing smoothed_loss_.
  vector<Dtype> losses_;
  Dtype smoothed_loss_;

  // A function that can be set by a client of the Solver to provide indication
  // that it wants a snapshot saved and/or to exit early.
  ActionCallback action_request_function_;

  // True iff a request to stop early was received.
  bool requested_early_exit_;

  // Timing information, handy to tune e.g. nbr of GPUs
  Timer iteration_timer_;
  // iter_ at the time of the last iterations/second report.
  float iterations_last_;

  DISABLE_COPY_AND_ASSIGN(Solver);
};
} // namespace caffe
#endif // CAFFE_SOLVER_HPP_
src/caffe/solver.cpp
#include <cstdio>
#include <string>
#include <vector>
#include "caffe/solver.hpp"
#include "caffe/util/format.hpp"
#include "caffe/util/hdf5.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/upgrade_proto.hpp"
namespace caffe {
// Installs the callback the solver polls (through GetRequestedAction) to
// learn whether the client wants a snapshot and/or an early stop.
// The `template <typename Dtype>` header was missing, so this out-of-class
// member definition of a class template did not compile.
template <typename Dtype>
void Solver<Dtype>::SetActionFunction(ActionCallback func) {
  action_request_function_ = func;
}
// Asks the client-installed callback which action to take next. Without a
// callback there is never anything to do, so NONE is returned.
template <typename Dtype>
SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
  return action_request_function_ ? action_request_function_()
                                  : SolverAction::NONE;
}
// Constructors: set up the object to optimize, i.e. the training net used
// for learning and the test nets used for evaluation. Both forward to Init().

// (1) From an already-parsed SolverParameter.
template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param)
    : net_(), callbacks_(), requested_early_exit_(false) {
  Init(param);
}

// (2) From the path of a text-format SolverParameter file, parsed with
// ReadSolverParamsFromTextFileOrDie. The output argument was mis-encoded as
// "¶m" (an HTML-entity artifact); it must be the address-of &param.
template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file)
    : net_(), callbacks_(), requested_early_exit_(false) {
  SolverParameter param;
  ReadSolverParamsFromTextFileOrDie(param_file, &param);
  Init(param);
}
// Stores and validates the solver parameters, checks snapshot permissions,
// seeds the RNG, builds the training net (every solver) and the test nets
// (root solver only), and resets the iteration counters.
template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
  LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
      << std::endl << param.DebugString();
  param_ = param;
  // average_loss is the length of the loss-smoothing window (see
  // UpdateSmoothedLoss), so it must be at least 1. The previous message said
  // "non-negative", which wrongly suggested that 0 is accepted.
  CHECK_GE(param_.average_loss(), 1) << "average_loss should be positive.";
  CheckSnapshotWritePermissions();
  if (param_.random_seed() >= 0) {
    // Offset by the solver rank so each solver gets its own seed.
    Caffe::set_random_seed(param_.random_seed() + Caffe::solver_rank());
  }
  // Scaffolding code
  InitTrainNet();  // builds net_
  if (Caffe::root_solver()) {
    InitTestNets();  // builds test_nets_
    LOG(INFO) << "Solver scaffolding done.";
  }
  iter_ = 0;
  current_step_ = 0;
}
// Builds the training net (net_). Exactly one of the SolverParameter fields
// net, net_param, train_net, train_net_param must describe it. Its NetState
// starts at phase TRAIN, then is merged with the net's own state and finally
// with train_state.
template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
  // How many of the four train-net fields are set.
  const int num_train_nets = param_.has_net() + param_.has_net_param() +
      param_.has_train_net() + param_.has_train_net_param();
  const string& field_names = "net, net_param, train_net, train_net_param";
  CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
      << "using one of these fields: " << field_names;
  // More than one train-net source is an error.
  CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
      << "one of these fields specifying a train_net: " << field_names;
  NetParameter net_param;
  if (param_.has_train_net_param()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in train_net_param.";
    // Train net given inline in the solver parameters.
    net_param.CopyFrom(param_.train_net_param());
  } else if (param_.has_train_net()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from train_net file: " << param_.train_net();
    // Train net read from a text-format file.
    ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);
  }
  if (param_.has_net_param()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in net_param.";
    // Generic net (also used for testing by InitTestNets), given inline.
    net_param.CopyFrom(param_.net_param());
  }
  if (param_.has_net()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from net file: " << param_.net();
    // Generic net read from a text-format file.
    ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
  }
  // Set the correct NetState. We start with the solver defaults (lowest
  // precedence); then, merge in any NetState specified by the net_param itself;
  // finally, merge in any NetState specified by the train_state (highest
  // precedence).
  NetState net_state;
  net_state.set_phase(TRAIN);
  net_state.MergeFrom(net_param.state());
  net_state.MergeFrom(param_.train_state());
  net_param.mutable_state()->CopyFrom(net_state);
  // Construct the training net from the final NetParameter.
  net_.reset(new Net<Dtype>(net_param));
}
// Builds test_nets_ (root solver only). Test nets come, in this order, from
// every test_net_param, then every test_net file, then (if a generic net or
// net_param is given) one copy of the generic net for each remaining
// test_iter entry. Each net's NetState starts at phase TEST, is merged with
// the net's own state, and finally with test_state(i) if test states exist.
template <typename Dtype>
void Solver<Dtype>::InitTestNets() {
  CHECK(Caffe::root_solver());
  const bool has_net_param = param_.has_net_param();
  const bool has_net_file = param_.has_net();
  // Number of generic (train + test) net definitions.
  const int num_generic_nets = has_net_param + has_net_file;
  CHECK_LE(num_generic_nets, 1)
      << "Both net_param and net_file may not be specified.";
  const int num_test_net_params = param_.test_net_param_size();
  const int num_test_net_files = param_.test_net_size();
  // Number of explicitly listed test nets.
  const int num_test_nets = num_test_net_params + num_test_net_files;
  // Every explicit test net needs its own test_iter. With a generic net,
  // extra test_iter entries are allowed (they become generic test instances).
  if (num_generic_nets) {
    CHECK_GE(param_.test_iter_size(), num_test_nets)
        << "test_iter must be specified for each test network.";
  } else {
    CHECK_EQ(param_.test_iter_size(), num_test_nets)
        << "test_iter must be specified for each test network.";
  }
  // If we have a generic net (specified by net or net_param, rather than
  // test_net or test_net_param), we may have an unlimited number of actual
  // test networks -- the actual number is given by the number of remaining
  // test_iters after any test nets specified by test_net_param and/or test_net
  // are evaluated.
  const int num_generic_net_instances = param_.test_iter_size() - num_test_nets;
  const int num_test_net_instances = num_test_nets + num_generic_net_instances;
  if (param_.test_state_size()) {
    CHECK_EQ(param_.test_state_size(), num_test_net_instances)
        << "test_state must be unspecified or specified once per test net.";
  }
  if (num_test_net_instances) {
    CHECK_GT(param_.test_interval(), 0);
  }
  int test_net_id = 0;
  // sources[i] describes where test net i came from (used for logging).
  vector<string> sources(num_test_net_instances);
  vector<NetParameter> net_params(num_test_net_instances);
  for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) {
    sources[test_net_id] = "test_net_param";
    net_params[test_net_id].CopyFrom(param_.test_net_param(i));
  }
  for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) {
    sources[test_net_id] = "test_net file: " + param_.test_net(i);
    ReadNetParamsFromTextFileOrDie(param_.test_net(i),
        &net_params[test_net_id]);
  }
  // Remaining test_iter entries are filled with copies of the generic net.
  const int remaining_test_nets = param_.test_iter_size() - test_net_id;
  if (has_net_param) {
    for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
      sources[test_net_id] = "net_param";
      net_params[test_net_id].CopyFrom(param_.net_param());
    }
  }
  if (has_net_file) {
    for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
      sources[test_net_id] = "net file: " + param_.net();
      ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
    }
  }
  test_nets_.resize(num_test_net_instances);
  for (int i = 0; i < num_test_net_instances; ++i) {
    // Set the correct NetState. We start with the solver defaults (lowest
    // precedence); then, merge in any NetState specified by the net_param
    // itself; finally, merge in any NetState specified by the test_state
    // (highest precedence).
    NetState net_state;
    net_state.set_phase(TEST);
    net_state.MergeFrom(net_params[i].state());
    if (param_.test_state_size()) {
      net_state.MergeFrom(param_.test_state(i));
    }
    net_params[i].mutable_state()->CopyFrom(net_state);
    LOG(INFO)
        << "Creating test net (#" << i << ") specified by " << sources[i];
    test_nets_[i].reset(new Net<Dtype>(net_params[i]));
    test_nets_[i]->set_debug_info(param_.debug_info());
  }
}
// Runs `iters` training iterations starting at the current iter_.
// Each iteration does the following:
//  - optionally runs the test nets;
//  - accumulates loss and gradients over iter_size forward/backward passes;
//  - on display iterations, logs the smoothed loss and the train outputs;
//  - applies the parameter update (ApplyUpdate);
//  - advances iter_ and then handles snapshot/stop requests.
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
  // When resuming from a snapshot, iter_ is already the snapshot's iteration.
  const int start_iter = iter_;
  const int stop_iter = iter_ + iters;
  // The reported loss is the mean over the last average_loss iterations
  // (solver.prototxt setting).
  int average_loss = this->param_.average_loss();
  losses_.clear();
  smoothed_loss_ = 0;
  // Reference point for the first iterations/second report; without this the
  // first report read an uninitialized member.
  iterations_last_ = iter_;
  iteration_timer_.Start();

  while (iter_ < stop_iter) {
    // Zero the parameter gradients left over from the previous iteration.
    net_->ClearParamDiffs();
    if (param_.test_interval() && iter_ % param_.test_interval() == 0
        && (iter_ > 0 || param_.test_initialization())) {
      if (Caffe::root_solver()) {
        TestAll();
      }
      if (requested_early_exit_) {
        // A stop was requested while testing.
        break;
      }
    }

    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_start();
    }
    const bool display = param_.display() && iter_ % param_.display() == 0;
    net_->set_debug_info(display && param_.debug_info());
    // accumulate the loss and gradient
    Dtype loss = 0;
    for (int i = 0; i < param_.iter_size(); ++i) {
      loss += net_->ForwardBackward();
    }
    loss /= param_.iter_size();  // mean loss over the iter_size passes
    // average the loss across iterations for smoothed reporting
    UpdateSmoothedLoss(loss, start_iter, average_loss);
    if (display) {
      float lapse = iteration_timer_.Seconds();
      float per_s = (iter_ - iterations_last_) / (lapse ? lapse : 1);
      LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
          << " (" << per_s << " iter/s, " << lapse << "s/"
          << param_.display() << " iters), loss = " << smoothed_loss_;
      iteration_timer_.Start();
      iterations_last_ = iter_;
      // Log every element of every output blob of the train net.
      const vector<Blob<Dtype>*>& result = net_->output_blobs();
      int score_index = 0;
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        const string& output_name =
            net_->blob_names()[net_->output_blob_indices()[j]];
        const Dtype loss_weight =
            net_->blob_loss_weights()[net_->output_blob_indices()[j]];
        for (int k = 0; k < result[j]->count(); ++k) {
          ostringstream loss_msg_stream;
          if (loss_weight) {
            loss_msg_stream << " (* " << loss_weight
                            << " = " << loss_weight * result_vec[k] << " loss)";
          }
          LOG_IF(INFO, Caffe::root_solver()) << "    Train net output #"
              << score_index++ << ": " << output_name << " = "
              << result_vec[k] << loss_msg_stream.str();
        }
      }
    }
    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_gradients_ready();
    }
    // Compute and apply the parameter update (implemented by the subclass,
    // e.g. SGDSolver). This call was missing, so weights were never updated.
    ApplyUpdate();

    // Increment the internal iter_ counter -- its value should always indicate
    // the number of times the weights have been updated.
    ++iter_;

    SolverAction::Enum request = GetRequestedAction();

    // Save a snapshot if needed.
    if ((param_.snapshot()
         && iter_ % param_.snapshot() == 0
         && Caffe::root_solver()) ||
         (request == SolverAction::SNAPSHOT)) {
      Snapshot();
    }
    if (SolverAction::STOP == request) {
      requested_early_exit_ = true;
      // Break out of training loop.
      break;
    }
  }
}
// Trains from the current iter_ up to max_iter, optionally restoring solver
// state from `resume_file` first. Afterwards it writes a final snapshot
// unless that iteration was already snapshotted (or snapshot_after_train is
// false). It then, if due, runs one extra forward pass and the test nets so
// the final loss and outputs get reported.
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
  // Only the root solver runs this (in multi-GPU mode other solvers do not).
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Solving " << net_->name();
  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();

  // Initialize to false every time we start solving.
  requested_early_exit_ = false;

  if (resume_file) {
    // Reload the previously saved training state.
    LOG(INFO) << "Restoring previous solver status from " << resume_file;
    Restore(resume_file);
  }

  int start_iter = iter_;
  // Do the actual iterations.
  Step(param_.max_iter() - iter_);
  // If we haven't already, save a snapshot after optimization, unless
  // overridden by setting snapshot_after_train := false
  if (param_.snapshot_after_train()
      && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
    Snapshot();
  }
  // Step() sets requested_early_exit_ when a STOP request was received.
  if (requested_early_exit_) {
    LOG(INFO) << "Optimization stopped early.";
    return;
  }
  // After the optimization is done, run an additional train and test pass to
  // display the train and test loss/outputs if appropriate (based on the
  // display and test_interval settings). For the train net only a forward
  // pass is run: the parameters are already updated, and this pass exists
  // only to report the loss.
  if (param_.display() && iter_ % param_.display() == 0) {
    int average_loss = this->param_.average_loss();
    Dtype loss;
    net_->Forward(&loss);

    UpdateSmoothedLoss(loss, start_iter, average_loss);

    LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
  }
  if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
    TestAll();
  }
  LOG(INFO) << "Optimization Done.";
}
// Runs Test() on each test net in order, stopping as soon as an early exit
// has been requested.
template <typename Dtype>
void Solver<Dtype>::TestAll() {
  int net_index = 0;
  while (net_index < test_nets_.size() && !requested_early_exit_) {
    Test(net_index);
    ++net_index;
  }
}
// Evaluates test net `test_net_id` for test_iter(test_net_id) batches.
// It first calls ShareTrainedLayersWith(net_) on it, then logs the per-batch
// mean of every output element, plus the mean loss when test_compute_loss is
// set. Snapshot/stop requests are serviced before each batch; a stop request
// aborts the test.
template <typename Dtype>
void Solver<Dtype>::Test(const int test_net_id) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Iteration " << iter_
      << ", Testing net (#" << test_net_id << ")";
  CHECK_NOTNULL(test_nets_[test_net_id].get())->
      ShareTrainedLayersWith(net_.get());
  // Per-element sums of all output blobs across batches (flattened), and for
  // each element the index of the output blob it came from.
  vector<Dtype> test_score;
  vector<int> test_score_output_id;
  const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
  Dtype loss = 0;
  for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
    SolverAction::Enum request = GetRequestedAction();
    // Check to see if stoppage of testing/training has been requested.
    // Keep polling until no action is pending: take a snapshot on SNAPSHOT,
    // flag an early exit on STOP.
    while (request != SolverAction::NONE) {
      if (SolverAction::SNAPSHOT == request) {
        Snapshot();
      } else if (SolverAction::STOP == request) {
        requested_early_exit_ = true;
      }
      request = GetRequestedAction();
    }
    if (requested_early_exit_) {
      // break out of test loop.
      break;
    }

    Dtype iter_loss;
    const vector<Blob<Dtype>*>& result =
        test_net->Forward(&iter_loss);

    if (param_.test_compute_loss()) {
      loss += iter_loss;
    }

    if (i == 0) {
      // First batch: flatten every element of every output blob into
      // test_score, remembering which blob each element belongs to.
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        for (int k = 0; k < result[j]->count(); ++k) {
          test_score.push_back(result_vec[k]);
          test_score_output_id.push_back(j);
        }
      }
    } else {
      // Later batches: add each element to its slot (same flattening order).
      int idx = 0;
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        for (int k = 0; k < result[j]->count(); ++k) {
          test_score[idx++] += result_vec[k];
        }
      }
    }
  }

  if (requested_early_exit_) {
    LOG(INFO) << "Test interrupted.";
    return;
  }

  if (param_.test_compute_loss()) {
    // Mean loss over the test batches.
    loss /= param_.test_iter(test_net_id);
    LOG(INFO) << "Test loss: " << loss;
  }
  // Report the mean of each element over the test batches, plus its weighted
  // loss contribution when the output has a non-zero loss weight.
  for (int i = 0; i < test_score.size(); ++i) {
    const int output_blob_index =
        test_net->output_blob_indices()[test_score_output_id[i]];
    const string& output_name = test_net->blob_names()[output_blob_index];
    const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];
    ostringstream loss_msg_stream;
    const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);
    if (loss_weight) {
      loss_msg_stream << " (* " << loss_weight
                      << " = " << loss_weight * mean_score << " loss)";
    }
    LOG(INFO) << "    Test net output #" << i << ": " << output_name << " = "
              << mean_score << loss_msg_stream.str();
  }
}
// Writes the learned weights in the configured snapshot_format, then lets
// the concrete solver save its own state (SnapshotSolverState) alongside.
template <typename Dtype>
void Solver<Dtype>::Snapshot() {
  CHECK(Caffe::root_solver());
  string weights_file;
  if (param_.snapshot_format() ==
      caffe::SolverParameter_SnapshotFormat_BINARYPROTO) {
    weights_file = SnapshotToBinaryProto();
  } else if (param_.snapshot_format() ==
             caffe::SolverParameter_SnapshotFormat_HDF5) {
    weights_file = SnapshotToHDF5();
  } else {
    LOG(FATAL) << "Unsupported snapshot format.";
  }
  SnapshotSolverState(weights_file);
}
// On the root solver, when snapshots are enabled, verifies that a
// snapshot_prefix is set and that a file can be created there. It does this
// by creating and then deleting a probe file; failure is fatal.
template <typename Dtype>
void Solver<Dtype>::CheckSnapshotWritePermissions() {
  if (!Caffe::root_solver() || !param_.snapshot()) {
    return;
  }
  CHECK(param_.has_snapshot_prefix())
      << "In solver params, snapshot is specified but snapshot_prefix is not";
  const string probe_path = SnapshotFilename(".tempfile");
  std::ofstream probe(probe_path.c_str());
  if (!probe.good()) {
    LOG(FATAL) << "Cannot write to snapshot prefix '"
        << param_.snapshot_prefix() << "'. Make sure "
        << "that the directory exists and is writeable.";
  } else {
    probe.close();
    std::remove(probe_path.c_str());
  }
}
// Builds the snapshot path "<snapshot_prefix>_iter_<iter_><extension>".
template <typename Dtype>
string Solver<Dtype>::SnapshotFilename(const string extension) {
  string filename = param_.snapshot_prefix();
  filename += "_iter_";
  filename += caffe::format_int(iter_);
  filename += extension;
  return filename;
}
// Serializes the training net (including diffs when snapshot_diff is set)
// to "<prefix>_iter_<N>.caffemodel" and returns that path.
template <typename Dtype>
string Solver<Dtype>::SnapshotToBinaryProto() {
  const string out_file = SnapshotFilename(".caffemodel");
  LOG(INFO) << "Snapshotting to binary proto file " << out_file;
  NetParameter snapshot_proto;
  net_->ToProto(&snapshot_proto, param_.snapshot_diff());
  WriteProtoToBinaryFile(snapshot_proto, out_file);
  return out_file;
}
// Writes the training net (including diffs when snapshot_diff is set) to
// "<prefix>_iter_<N>.caffemodel.h5" and returns that path.
template <typename Dtype>
string Solver<Dtype>::SnapshotToHDF5() {
  const string out_file = SnapshotFilename(".caffemodel.h5");
  LOG(INFO) << "Snapshotting to HDF5 file " << out_file;
  net_->ToHDF5(out_file, param_.snapshot_diff());
  return out_file;
}
// Restores solver state, choosing the format by file extension: names ending
// in ".h5" go to RestoreSolverStateFromHDF5, everything else to
// RestoreSolverStateFromBinaryProto.
template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
  const string path(state_file);
  const string hdf5_suffix(".h5");
  const bool is_hdf5 = path.size() >= hdf5_suffix.size() &&
      path.compare(path.size() - hdf5_suffix.size(), hdf5_suffix.size(),
                   hdf5_suffix) == 0;
  if (is_hdf5) {
    RestoreSolverStateFromHDF5(path);
  } else {
    RestoreSolverStateFromBinaryProto(path);
  }
}
// Keeps smoothed_loss_ equal to the mean of the most recent `average_loss`
// losses. While the window is still filling it is a running mean; once full,
// losses_ is used as a ring buffer indexed by iterations since start_iter.
template <typename Dtype>
void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter,
    int average_loss) {
  if (losses_.size() >= average_loss) {
    const int slot = (iter_ - start_iter) % average_loss;
    smoothed_loss_ += (loss - losses_[slot]) / average_loss;
    losses_[slot] = loss;
    return;
  }
  losses_.push_back(loss);
  const int count = losses_.size();
  smoothed_loss_ = (smoothed_loss_ * (count - 1) + loss) / count;
}
// Explicitly instantiates Solver for the supported Dtypes. INSTANTIATE_CLASS
// is defined elsewhere in Caffe (presumably instantiating float and double;
// confirm in the macro definition).
INSTANTIATE_CLASS(Solver);
} // namespace caffe
src/caffe/solvers/sgd_solver.cpp
#include <string>
#include <vector>
#include "caffe/sgd_solvers.hpp"
#include "caffe/util/hdf5.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/upgrade_proto.hpp"
namespace caffe {
// Return the current learning rate. The currently implemented learning rate
// policies are as follows:
// - fixed: always return base_lr.
// - step: return base_lr * gamma ^ (floor(iter / step))
// - exp: return base_lr * gamma ^ iter
// - inv: return base_lr * (1 + gamma * iter) ^ (- power)
// - multistep: similar to step but it allows non uniform steps defined by
// stepvalue
// - poly: the effective learning rate follows a polynomial decay, to be
// zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
// - sigmoid: the effective learning rate follows a sigmod decay
// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
//
// where base_lr, max_iter, gamma, step, stepvalue and power are defined
// in the solver parameter protocol buffer, and iter is the current iteration.
// Computes the learning rate for the current iter_ according to lr_policy
// (see the policy list above). The step and multistep policies also update
// current_step_.
template <typename Dtype>
Dtype SGDSolver<Dtype>::GetLearningRate() {
  Dtype rate;
  const string& lr_policy = this->param_.lr_policy();
  if (lr_policy == "fixed") {
    rate = this->param_.base_lr();
  } else if (lr_policy == "step") {
    this->current_step_ = this->iter_ / this->param_.stepsize();
    rate = this->param_.base_lr() *
        pow(this->param_.gamma(), this->current_step_);
  } else if (lr_policy == "exp") {
    rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
  } else if (lr_policy == "inv") {
    rate = this->param_.base_lr() *
        pow(Dtype(1) + this->param_.gamma() * this->iter_,
            - this->param_.power());
  } else if (lr_policy == "multistep") {
    // Advance past every stepvalue already reached. A single `if` moved at
    // most one step per call, so the rate lagged whenever several boundaries
    // were crossed at once (e.g. repeated stepvalues).
    while (this->current_step_ < this->param_.stepvalue_size() &&
           this->iter_ >= this->param_.stepvalue(this->current_step_)) {
      this->current_step_++;
      LOG(INFO) << "MultiStep Status: Iteration " <<
          this->iter_ << ", step = " << this->current_step_;
    }
    rate = this->param_.base_lr() *
        pow(this->param_.gamma(), this->current_step_);
  } else if (lr_policy == "poly") {
    rate = this->param_.base_lr() * pow(Dtype(1.) -
        (Dtype(this->iter_) / Dtype(this->param_.max_iter())),
        this->param_.power());
  } else if (lr_policy == "sigmoid") {
    rate = this->param_.base_lr() * (Dtype(1.) /
        (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
          Dtype(this->param_.stepsize())))));
  } else {
    LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
  }
  return rate;
}
// (Re)allocates the per-parameter working blobs history_, update_ and temp_,
// each shaped like the corresponding learnable parameter of the train net.
template <typename Dtype>
void SGDSolver<Dtype>::PreSolve() {
  // Initialize the history
  const vector<Blob<Dtype>*>& learnable = this->net_->learnable_params();
  history_.clear();
  update_.clear();
  temp_.clear();
  for (int idx = 0; idx < learnable.size(); ++idx) {
    const vector<int>& blob_shape = learnable[idx]->shape();
    history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(blob_shape)));
    update_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(blob_shape)));
    temp_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(blob_shape)));
  }
}
// Global-norm gradient clipping. If the L2 norm of all learnable-parameter
// gradients exceeds clip_gradients, every gradient is scaled by
// clip_gradients / norm. A negative clip_gradients disables clipping.
template <typename Dtype>
void SGDSolver<Dtype>::ClipGradients() {
  const Dtype threshold = this->param_.clip_gradients();
  if (threshold < 0) {
    return;  // clipping disabled
  }
  const vector<Blob<Dtype>*>& learnable = this->net_->learnable_params();
  Dtype total_sumsq = 0;
  for (int idx = 0; idx < learnable.size(); ++idx) {
    total_sumsq += learnable[idx]->sumsq_diff();
  }
  const Dtype global_norm = std::sqrt(total_sumsq);
  // Written as !(a > b) so that a NaN norm does not trigger scaling.
  if (!(global_norm > threshold)) {
    return;
  }
  const Dtype shrink = threshold / global_norm;
  LOG(INFO) << "Gradient clipping: scaling down gradients (L2 norm "
      << global_norm << " > " << threshold << ") "
      << "by scale factor " << shrink;
  for (int idx = 0; idx < learnable.size(); ++idx) {
    learnable[idx]->scale_diff(shrink);
  }
}
// One SGD update step:
//  1. compute this iteration's learning rate from the lr policy;
//  2. clip gradients if they are too large;
//  3. normalize, regularize and compute the update value for each
//     learnable parameter;
//  4. let the net apply the updates.
template <typename Dtype>
void SGDSolver<Dtype>::ApplyUpdate() {
  const Dtype rate = GetLearningRate();
  const bool log_rate =
      this->param_.display() && this->iter_ % this->param_.display() == 0;
  if (log_rate) {
    LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << this->iter_
        << ", lr = " << rate;
  }
  // Rescale gradients whose global L2 norm exceeds clip_gradients.
  ClipGradients();
  const int num_params = this->net_->learnable_params().size();
  for (int param_id = 0; param_id < num_params; ++param_id) {
    // Divide by iter_size so accumulated gradients act like one batch of
    // iter_size * batch_size.
    Normalize(param_id);
    // Add the weight-decay gradient.
    Regularize(param_id);
    // Compute the SGD update value for this parameter.
    ComputeUpdateValue(param_id, rate);
  }
  this->net_->Update();
}
// Normalize: scales the gradient of parameter `param_id` by 1 / iter_size,
// so gradients accumulated over iter_size forward/backward passes act like
// one batch of iter_size * batch_size. The early-return comment used a single
// "/", which did not compile.
template <typename Dtype>
void SGDSolver<Dtype>::Normalize(int param_id) {
  if (this->param_.iter_size() == 1) { return; }  // nothing was accumulated
  // Scale gradient to counterbalance accumulation.
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();
  switch (Caffe::mode()) {
  case Caffe::CPU: {
    // caffe_scal (src/caffe/util/math_functions.cpp) wraps BLAS scal:
    // (element count, scale factor, data pointer).
    caffe_scal(net_params[param_id]->count(), accum_normalization,
        net_params[param_id]->mutable_cpu_diff());
    break;
  }
  case Caffe::GPU: {
#ifndef CPU_ONLY
    caffe_gpu_scal(net_params[param_id]->count(), accum_normalization,
        net_params[param_id]->mutable_gpu_diff());
#else
    NO_GPU;
#endif
    break;
  }
  default:
    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}
.//Regularize
template <typename Dtype>
void SGDSolver<Dtype>::Regularize(int param_id) {
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();//獲取所有可以學習的參數
const vector<float>& net_params_weight_decay =
this->net_->params_weight_decay();//獲取所有參數對應的權重衰減
Dtype weight_decay = this->param_.weight_decay();//模型整體的權重衰減數值
string regularization_type = this->param_.regularization_type();//獲取正則化類型
Dtype local_decay = weight_decay * net_params_weight_decay[param_id];//實際的衰減等於整體模型的衰減乘以具體每個參數的數值
switch (Caffe::mode()) {
case Caffe::CPU: {
if (local_decay) {
if (regularization_type == "L2") {
// L2的梯度是diff_=weight_decay*data_+diff_
caffe_axpy(net_params[param_id]->count(),
local_decay,
net_params[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());//caffe_axpy函數是計算y=a*x+y即diff_=weight_delay*data+diff_第一個參數是數據的個數,第二個是a,第三個是data指針,第四個是y指針.
} else if (regularization_type == "L1") {
//L1的梯度是diff_=diff_+sign(data)
caffe_cpu_sign(net_params[param_id]->count(),
net_params[param_id]->cpu_data(),
temp_[param_id]->mutable_cpu_data());//temp_=sign(data)
caffe_axpy(net_params[param_id]->count(),
local_decay,
temp_[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());//將temp_添加到diff_中,diff_=weight_decay*temp_+diff_
} else {
LOG(FATAL) << "Unknown regularization type: " << regularization_type;
}
}
break;
}
case Caffe::GPU: {
#ifndef CPU_ONLY
if (local_decay) {
if (regularization_type == "L2") {
// add weight decay
caffe_gpu_axpy(net_params[param_id]->count(),
local_decay,
net_params[param_id]->gpu_data(),
net_params[param_id]->mutable_gpu_diff());
} else if (regularization_type == "L1") {
caffe_gpu_sign(net_params[param_id]->count(),
net_params[param_id]->gpu_data(),
temp_[param_id]->mutable_gpu_data());
caffe_gpu_axpy(net_params[param_id]->count(),
local_decay,
temp_[param_id]->gpu_data(),
net_params[param_id]->mutable_gpu_diff());
} else {
LOG(FATAL) << "Unknown regularization type: " << regularization_type;
}
}
#else
NO_GPU;
#endif
break;
}
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}
#ifndef CPU_ONLY
// Forward declaration of the GPU momentum-SGD update used by
// ComputeUpdateValue. It is defined outside this chunk (presumably in the CUDA
// translation unit). The arguments are:
//   - the element count,
//   - the gradient (diff) buffer,
//   - the history buffer,
//   - the momentum coefficient,
//   - the effective learning rate.
template <typename Dtype>
void sgd_update_gpu(int count, Dtype* grad, Dtype* history, Dtype momentum,
    Dtype local_rate);
#endif
// ComputeUpdateValue(param_id, rate): momentum SGD step for one learnable
// parameter. On the CPU path:
//   history = local_rate * diff + momentum * history
//   diff    = history
// where local_rate = rate * the parameter's own learning-rate multiplier.
// The resulting diff holds the update value for this parameter.
template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  Blob<Dtype>* const blob = this->net_->learnable_params()[param_id];
  Blob<Dtype>* const history_blob = history_[param_id].get();
  const Dtype momentum = this->param_.momentum();
  // Effective rate = global rate scaled by this parameter's lr multiplier.
  const Dtype local_rate = rate * this->net_->params_lr()[param_id];
  const int count = blob->count();
  // Compute the update to history, then copy it to the parameter diff.
  switch (Caffe::mode()) {
  case Caffe::CPU:
    // caffe_cpu_axpby: y = alpha * x + beta * y, applied to the history.
    caffe_cpu_axpby(count, local_rate, blob->cpu_diff(), momentum,
        history_blob->mutable_cpu_data());
    caffe_copy(count, history_blob->cpu_data(), blob->mutable_cpu_diff());
    break;
  case Caffe::GPU:
#ifndef CPU_ONLY
    // GPU implementation declared before this function.
    sgd_update_gpu(count, blob->mutable_gpu_diff(),
        history_blob->mutable_gpu_data(), momentum, local_rate);
#else
    NO_GPU;
#endif
    break;
  default:
    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}
// Writes the solver state (iteration, current step and momentum history) in
// the format selected by the solver's snapshot_format setting. Any other
// format is a fatal error.
template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverState(const string& model_filename) {
  const SolverParameter_SnapshotFormat format = this->param_.snapshot_format();
  if (format == caffe::SolverParameter_SnapshotFormat_BINARYPROTO) {
    SnapshotSolverStateToBinaryProto(model_filename);
  } else if (format == caffe::SolverParameter_SnapshotFormat_HDF5) {
    SnapshotSolverStateToHDF5(model_filename);
  } else {
    LOG(FATAL) << "Unsupported snapshot format.";
  }
}
// Serializes the solver state to a binary SolverState protobuf. The state
// contains:
//   - iteration,
//   - the path of the model snapshot,
//   - current step,
//   - one BlobProto per momentum-history blob.
// The output is written to the ".solverstate" snapshot file.
template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverStateToBinaryProto(
    const string& model_filename) {
  SolverState state;
  state.set_iter(this->iter_);
  state.set_learned_net(model_filename);
  state.set_current_step(this->current_step_);
  // Rebuild the repeated history field from scratch, in history_ order.
  state.clear_history();
  for (int blob_idx = 0; blob_idx < history_.size(); ++blob_idx) {
    BlobProto* const slot = state.add_history();
    history_[blob_idx]->ToProto(slot);
  }
  const string snapshot_filename =
      Solver<Dtype>::SnapshotFilename(".solverstate");
  LOG(INFO)
      << "Snapshotting solver state to binary proto file " << snapshot_filename;
  WriteProtoToBinaryFile(state, snapshot_filename.c_str());
}
// Serializes the solver state to an HDF5 ".solverstate.h5" file. The file
// holds the scalars "iter", "learned_net" and "current_step", plus a "history"
// group. That group has one dataset per history blob, named by its decimal
// index ("0", "1", ...).
template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverStateToHDF5(
    const string& model_filename) {
  const string snapshot_filename =
      Solver<Dtype>::SnapshotFilename(".solverstate.h5");
  LOG(INFO) << "Snapshotting solver state to HDF5 file " << snapshot_filename;
  // H5F_ACC_TRUNC: truncate the file if it already exists.
  const hid_t file_hid = H5Fcreate(snapshot_filename.c_str(), H5F_ACC_TRUNC,
      H5P_DEFAULT, H5P_DEFAULT);
  CHECK_GE(file_hid, 0)
      << "Couldn't open " << snapshot_filename << " to save solver state.";
  // Scalar metadata.
  hdf5_save_int(file_hid, "iter", this->iter_);
  hdf5_save_string(file_hid, "learned_net", model_filename);
  hdf5_save_int(file_hid, "current_step", this->current_step_);
  // Momentum history, one dataset per blob.
  const hid_t history_hid = H5Gcreate2(file_hid, "history", H5P_DEFAULT,
      H5P_DEFAULT, H5P_DEFAULT);
  CHECK_GE(history_hid, 0)
      << "Error saving solver state to " << snapshot_filename << ".";
  for (int blob_idx = 0; blob_idx < history_.size(); ++blob_idx) {
    ostringstream dataset_name;
    dataset_name << blob_idx;
    hdf5_save_nd_dataset<Dtype>(history_hid, dataset_name.str(),
        *history_[blob_idx]);
  }
  H5Gclose(history_hid);
  H5Fclose(file_hid);
}
// Restores the solver state from a binary SolverState protobuf.
// If the state names a model file, the net's weights are reloaded from it.
// The number of stored history blobs must match history_.size().
template <typename Dtype>
void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
    const string& state_file) {
  SolverState state;
  ReadProtoFromBinaryFile(state_file, &state);
  this->iter_ = state.iter();
  // The model path is optional; only reload weights when it is present.
  if (state.has_learned_net()) {
    NetParameter net_param;
    ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
    this->net_->CopyTrainedLayersFrom(net_param);
  }
  this->current_step_ = state.current_step();
  const int stored_history = state.history_size();
  CHECK_EQ(stored_history, history_.size())
      << "Incorrect length of history blobs.";
  LOG(INFO) << "SGDSolver: restoring history";
  for (int blob_idx = 0; blob_idx < history_.size(); ++blob_idx) {
    const BlobProto& saved = state.history(blob_idx);
    history_[blob_idx]->FromProto(saved);
  }
}
// Restores the solver state from an HDF5 solver-state file. The layout is:
//   "iter", optional "learned_net", "current_step",
//   a "history" group with one dataset per history blob.
// The dataset count must match history_.size().
template <typename Dtype>
void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
  const hid_t file_hid =
      H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
  CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file;
  this->iter_ = hdf5_load_int(file_hid, "iter");
  // "learned_net" is optional; reload the weights only if it was saved.
  if (H5LTfind_dataset(file_hid, "learned_net")) {
    const string learned_net = hdf5_load_string(file_hid, "learned_net");
    this->net_->CopyTrainedLayersFrom(learned_net);
  }
  this->current_step_ = hdf5_load_int(file_hid, "current_step");
  const hid_t history_hid = H5Gopen2(file_hid, "history", H5P_DEFAULT);
  CHECK_GE(history_hid, 0) << "Error reading history from " << state_file;
  const int stored_history = hdf5_get_num_links(history_hid);
  CHECK_EQ(stored_history, history_.size())
      << "Incorrect length of history blobs.";
  // Datasets are named by the blob's decimal index.
  for (int blob_idx = 0; blob_idx < history_.size(); ++blob_idx) {
    ostringstream dataset_name;
    dataset_name << blob_idx;
    hdf5_load_nd_dataset<Dtype>(history_hid, dataset_name.str().c_str(), 0,
        kMaxBlobAxes, history_[blob_idx].get());
  }
  H5Gclose(history_hid);
  H5Fclose(file_hid);
}
// Presumably expands to explicit instantiations of SGDSolver for the supported
// Dtypes (macro defined outside this chunk) -- TODO confirm.
INSTANTIATE_CLASS(SGDSolver);
// At the end of the file, register SGDSolver with SolverRegistry under the
// type name "SGD". Once registered, SolverRegistry<Dtype>::CreateSolver can
// construct it from a SolverParameter whose type is "SGD".
REGISTER_SOLVER_CLASS(SGD);
} // namespace caffe