Caffe源碼解析

樓燚(yì)航的blog,http://www.cnblogs.com/louyihang-loves-baiyan/

首先看到的是Blob這個類,Blob是作爲Caffe中數據流通的一個基本類,網絡各層之間的數據是通過Blob來傳遞的。這裏整個代碼是非常規範的,基本上條件編譯,命名空間,模板類,各種不太經常看到的關鍵字如exlicit,inline等等。
首先提一下explicit關鍵字的作用是禁止單參數構造函數的隱式轉換,具體含義谷歌即可。還有inline的作用,iniline主要是將代碼進行復制,擴充,會使代碼總量上升,好處就是可以節省調用的開銷,能提高執行效率。

1主要變量

shared_ptr<SyncedMemory> data_;shared_ptr<SyncedMemory> diff_;shared_ptr<SyncedMemory> shape_data_;vector<int> shape_;int count_;int capacity_;

BLob只是一個基本的數據結構,因此內部的變量相對較少,首先是data_指針,指針類型是shared_ptr,屬於boost庫的一個智能指針,這一部分主要用來申請內存存儲data,data主要是正向傳播的時候用的。同理,diff_主要用來存儲偏差,update data,shape_datashape_都是存儲Blob的形狀,一個是老版本一個是新版本。count表示Blob中的元素個數,也就是個數*通道數*高度*寬度,capacity表示當前的元素個數,因爲Blob可能會reshape。

2主要函數

template <typename Dtype>class Blob { public:
  Blob()
       : data_(), diff_(), count_(0), capacity_(0) {}  /// @brief Deprecated; use <code>Blob(const vector<int>& shape)</code>.  explicit Blob(const int num, const int channels, const int height,      const int width);  explicit Blob(const vector<int>& shape);  /// @brief Deprecated; use <code>Reshape(const vector<int>& shape)</code>.  void Reshape(const int num, const int channels, const int height,      const int width);

其中Blob作爲一個最基礎的類,其中構造函數開闢一個內存空間來存儲數據,Reshape函數在Layer中的reshape或者forward操作中來adjust dimension。同時在改變Blob大小時,內存將會被重新分配如果內存大小不夠了,並且額外的內存將不會被釋放。對input的blob進行reshape,如果立馬調用Net::Backward是會出錯的,因爲reshape之後,要麼Net::forward或者Net::Reshape就會被調用來將新的input shape 傳播到高層

Blob類裏面有重載很多個count()函數,主要還是爲了統計Blob的容量(volume),或者是某一片(slice),從某個axis到具體某個axis的shape乘積。

inline int count(int start_axis, int end_axis)

並且Blob的Index是可以從負座標開始讀的,這一點跟Python好像

inline int CanonicalAxisIndex(int axis_index)

對於Blob中的4個基本變量num,channel,height,width可以直接通過shape(0),shape(1),shape(2),shape(3)來訪問。

計算offset

inline int offset(const int n, const int c = 0, const int h = 0, const int w = 0)inline int offset(const vector<int>& indices)

offset計算的方式也支持兩種方式,一種直接指定n,c,h,w或者放到一個vector中進行計算,偏差是根據對應的n,c,h,w,返回的offset是((n * channels() + c) * height() + h) * width() + w

其實裏面稍加留意可以看到有很多的

CHECK_GE
CHECK_LE
CHECK_EQ
....

等等看意思就知道了,肯定是在做比較Geater or Eqal這樣的意思。這其實是GLOG,谷歌的一個日誌庫,Caffe裏面用用了大量這樣的宏,看起來也比較直觀

void CopyFrom(const Blob<Dtype>& source, bool copy_diff = false,bool reshape = false);

從一個blob中copy數據 ,通過開關控制是否copy_diff,如果是False則copy data。reshape控制是否需要reshape。好我們接着往下看

inline Dtype data_at(const int n, const int c, const int h, const int w)inline Dtype diff_at(const int n, const int c, const int h, const int w)inline Dtype data_at(const vector<int>& index)inline Dtype diff_at(const vector<int>& index)inline const shared_ptr<SyncedMemory>& data()inline const shared_ptr<SyncedMemory>& diff()

這一部分函數主要通過給定的位置訪問數據,根據位置計算與數據起始的偏差offset,在通過cpu_data*指針獲得地址。下面幾個函數都是獲得

const Dtype* cpu_data() const;void set_cpu_data(Dtype* data);const int* gpu_shape() const;const Dtype* gpu_data() const;const Dtype* cpu_diff() const;const Dtype* gpu_diff() const;Dtype* mutable_cpu_data();Dtype* mutable_gpu_data();Dtype* mutable_cpu_diff();Dtype* mutable_gpu_diff();

可以看到這裏有data和diff兩類數據,而這個diff就是我們所熟知的偏差,前者主要存儲前向傳遞的數據,而後者存儲的是反向傳播中的梯度

void Update();

看到update裏面面調用了

caffe_axpy<float>(const int N, const float alpha, const float* X,float* Y)
{ cblas_saxpy(N, alpha, X, 1, Y, 1); }

這個函數在caffe的util下面的match-functions.cpp裏面,主要是負責了線性代數庫的調用,實現的功能是

Y=alphaX+betaYY=alphaX+betaY


也就是blob裏面的data部分減去diff部分

 

void FromProto(const BlobProto& proto, bool reshape = true);void ToProto(BlobProto* proto, bool write_diff = false) const;

這兩個函數主要是將數據序列化,存儲到BlobProto,這裏說到Proto是谷歌的一個數據序列化的存儲格式,可以實現語言、平臺無關、可擴展的序列化結構數據格式。Caffe裏面數據的存儲都採用這一結構,這裏就不深入展開,具體可以參照這篇文章,對於proto的序列化和反序列都講解的非常詳細http://***/Article/34963

Dtype asum_data() const;//計算data的L1範數Dtype asum_diff() const;//計算diff的L1範數Dtype sumsq_data() const;//計算data的L2範數Dtype sumsq_diff() const;//計算diff的L2範數void scale_data(Dtype scale_factor);//將data部分乘以一個因子void scale_diff(Dtype scale_factor);//將diff部分乘一個因子

這幾個函數是一些零散的功能,一看就懂。

void ShareData(const Blob& other);void ShareData(const Blob& other);

這兩個函數看名字就知道了一個是共享data,一個是共享diff,具體就是將別的blob的data和響應的diff指針給這個Blob,實現數據的共享。同時需要注意的是這個操作會引起這個Blob裏面的SyncedMemory被釋放,因爲shared_ptr指針被用=重置的時候回調用響應的析構器。

bool ShapeEquals(const BlobProto& other);

這函數就不用說了,比較兩個Blob形狀是否相同
好了,基本上Blob的主要參數功能基本就涵蓋在裏面了,以上只是我的拙見,如有紕漏,還望指出,萬分感謝。


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章