Caffe工廠模式解析

Caffe有五個基本組件,分別是Blob,Solver,Net,Layer和Proto,其中Solver和Layer使用了工廠模式,下面以Slover爲例說明下。
Solver的工廠模式在註冊和調用的過程中體現,所以在說明工廠模式之前,我們首先要弄明白Solver在Caffe內部是如何被使用的。

Solver註冊機制

什麼是Solver註冊

我們都知道Layer和Slover是需要被註冊的,而所謂的註冊就是把這個類型的Slover(比如SDGSlover)找個地方記錄下來,好告訴後面的過程,有這個Slover了,需要的話可以來這裏調用。
這就和在CSDN註冊會員一樣,我們成功註冊爲會員,“用戶名”和“密碼”就被記錄下來了,然後可以進一步的完善信息,寫博客等等,這些都是我們這個賬戶裏面的內容了。下一次登錄的時候,我們需要使用“用戶名”來匹配,登錄我們的賬戶,而密碼只是一個安全措施。
Caffe中Slover有SGDSlover,AdaGradSolver,AdaDeltaSolver,AdamSolver,NesterovSolver,RMSPropSolver這六種,註冊的代碼在它們各自的源文件中,比如SGDSlover的註冊在sgd_solver.cpp的最下面:

REGISTER_SOLVER_CLASS(SGD);

SGD的就是solver.proto中type對應的字符串。
下面我們就從這行代碼開始,往前追蹤SGDSlover的註冊。

Solver如何被註冊

在這裏插入圖片描述

solver_factory.hpp中可以找到REGISTER_SOLVER_CLASS的定義,它是一個宏

    #define REGISTER_SOLVER_CLASS(type)                                            \
      template <typename Dtype>                                                    \
      Solver<Dtype>* Creator_##type##Solver(                                       \
          const SolverParameter& param)                                            \
      {                                                                            \
        return new type##Solver<Dtype>(param);                                     \
      }                                                                            \
      REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)

define 裏的 ##是一個連接符號,用於把參數連在一起 。而type其實就是SGD,編譯的時候這個宏會被替換,並將type換成SGD
,所以實際上這個宏就是完成了。

  template <typename Dtype>                                                    
  Solver<Dtype>* Creator_SGDSolver(const SolverParameter& param){                                                                            
    return new SGDSolver<Dtype>(param);                                     
  }                                                                            
  REGISTER_SOLVER_CREATOR(SGD, Creator_SGDSolver)

它定義了一個函數Creator_SGDSolver(),參數爲SolverParameter&類型的引用,返回值爲SGDSolver<Dtype>(param)

最後又調用了另一個宏REGISTER_SOLVER_CREATOR

#define REGISTER_SOLVER_CREATOR(type, creator)                                 \
  static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>);    \
  static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>)   \

還是想上面那樣替換它:

  static SolverRegisterer<float> g_creator_f_SGD("SGD", Creator_SGDSolver<float>);    
  static SolverRegisterer<double> g_creator_d_SGD("SGD", Creator_SGDSolver<double>);  

最後的目的就是要實例化SolverRegisterer類的兩個對象。SolverRegisterer是一個模板類,所以在實例化時候有SolverRegisterer<float>SolverRegisterer<double>,以支持兩種Slove的數據類型,分別對應float和double。
實例化時會調用SolverRegisterer類的構造函數,通過SolverRegisterer類定義,發現構造函數裏面調用了AddCreator()方法。

template <typename Dtype>
class SolverRegisterer {
 public:
  SolverRegisterer(const string& type,
      Solver<Dtype>* (*creator)(const SolverParameter&)) {
    // LOG(INFO) << "Registering solver type: " << type;
    SolverRegistry<Dtype>::AddCreator(type, creator);
  }
};

AddCreator()方法是另一個類SolverRegistry的成員,我們暫時只看SolverRegistry類下面這些成員就夠了,細節的地方做了註釋。

// LayerRegistry:註冊類,主要實現兩個方法,AddCreator()和CreateSolver(),下面代碼只有AddCreator()
template <typename Dtype>
class SolverRegistry {
 public:
  //定義名爲Creator的函數指針類型,參數爲SolverParameter&類型的引用,返回值爲一個Solver類型的指針
  typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
  //將一個map類型定義一個別名,叫做CreatorRegistry
  //map將“字符串-函數指針”行成映射
  typedef std::map<string, Creator> CreatorRegistry;

 // Registry()靜態函數,只創建一個map實例,僅第一次調用時會new,其它直接return
 //創建的map其實就是solver的內部註冊表
  static CreatorRegistry& Registry() {
    static CreatorRegistry* g_registry_ = new CreatorRegistry();
    return *g_registry_;
  }

  // Adds a creator.
  // AddCreator函數用來向Registry列表中添加一組<type, creator>
  static void AddCreator(const string& type, Creator creator) {
    CreatorRegistry& registry = Registry();
    CHECK_EQ(registry.count(type), 0)
        << "Solver type " << type << " already registered.";
    // 向map中加入一個映射
    registry[type] = creator;
  }
};

所以,當我們看到了 registry[type] = creator;這一行代碼時,也就找到了slover的註冊到底在做什麼,他其實就是在往registry變量裏添加一組映射,registry是靜態的,它只有一個,就是slover的註冊表;一組映射是CreatorRegistry,它實際是一個map,建立映射的兩個值分別stringCreator,string不用說,他就是像“SGD”,“Adam”,“AdaDelta”這樣的一個字符串,關鍵是和它建立映射的東西:Creator
Creator是一個函數指針,這個指針可以指向的函數要以SolverParameter&類型的引用作爲參數,並且返回值爲一個Solver類型的指針,Caffe裏面那個函數是這個樣子呢?就是在宏裏定義的那個函數:Creator_SGDSolver()
最終,SGDSlover的註冊是將字符串"SGD"和指向函數Creator_SGDSolver()的指針成對存儲到registry變量裏面。

Solver的調用

在這裏插入圖片描述
說完了註冊的部分,下面說明下調用,也就是程序的運行過程。
caffe的程序入庫在caffe.cpp的main()函數中,比如執行train的時候,調用了SolverRegistry類的CreateSolver()函數:

  shared_ptr<caffe::Solver<float> >
      solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

此時的Dtype已經指定爲了float類型,solver_param是從slover.proto裏面解析出來的。
CreateSolver()也在SolverRegistry類中定義:

template <typename Dtype>
class SolverRegistry {
 public:
  // Get a solver using a SolverParameter.
  static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
    const string& type = param.type();
    CreatorRegistry& registry = Registry();
    CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
        << " (known types: " << SolverTypeListString() << ")";
    return registry[type](param);
  }
}

它實現了registry[type](param)的操作,實際上就是AddCreator()反過來的過程,一個是取,一個是存。同樣在"SGD"的時候,取出來的就應該是上面提到的Creator_SGDSolver(),而Creator_SGDSolver()的返回值是SGDSolver<Dtype>(param)
這個SGDSolver<Dtype>(param)就在sgd_solvers.hpp中定義,就是SGDSolver的構造函數:

/**
 * @brief Optimizes the parameters of a Net using
 *        stochastic gradient descent (SGD) with momentum.
 */
template <typename Dtype>
class SGDSolver : public Solver<Dtype> {
 public:
  explicit SGDSolver(const SolverParameter& param)
      : Solver<Dtype>(param) { PreSolve(); }
  explicit SGDSolver(const string& param_file)
      : Solver<Dtype>(param_file) { PreSolve(); }
  virtual inline const char* type() const { return "SGD"; }

  const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }
}

通過main()中的調用,Dtype指定爲了float。

Solver註冊發生在什麼時候

通過上面的分析,我們知道了所謂的註冊就是往map裏面存入,調用就是從map取出來,那就會有一個問題,註冊是在什麼時候發生的?
因爲registry就是個靜態變量,它的生命週期的開始一定在程序運行起來之後,但是程序運行起來就要從入口執行train了,這就要求在這之前registry裏就要完成註冊了,我們加個斷點調試一下。
在這裏插入圖片描述
一個斷點打在程序的入口處:
在這裏插入圖片描述
一個斷點打在註冊的地方:
在這裏插入圖片描述
啓動調試之後,先斷到了註冊的地方:
在這裏插入圖片描述
此時的type是"AdaDelta",因爲還沒有存入,所以registy的size=0,再走一步的話:
在這裏插入圖片描述

type變成了"AdaGrad",因爲已經存入了"AdaDelta",所以registy的size=1。
於是可以得到一個結論是,註冊的過程是在進入main函數之前完成。

此外,還可以用代碼圖的當時看下,首先改一下斷點的位置到:
在這裏插入圖片描述
開始執行調試,直到代碼執行到main中,生成代碼圖,就像下面這樣:
在這裏插入圖片描述

Solver的工廠模式

最後就是Solver的工廠模式了,上面的說明包含了工廠模式思想,下面我們工廠模式的角度再說明下。
Caffe中Slover的工廠模式是一種簡單工廠模式,只有一個工廠,負責生產多種產品。在solver_factory.hppSolverRegistry類定義了一個工廠,前面提到的註冊,是在完善工廠中選擇的邏輯,在很多簡單工廠的例子中,這個邏輯可以靠switch,case來實現,只是在caffe中它變成了一個“字符串”-“函數指針”的映射。
上面提到的調用的過程,就是工廠生產產品的過程,還拿SDG的例子:

  shared_ptr<caffe::Solver<float> >
      solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

儘管solver_param參數的不同,但是都調用工廠中的方法CreateSolver(),最終將生產的過程交給了產品的子類去實現,產品的子類實現就在各個優化器對應的源碼中。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章