Caffemodel解析

轉載自http://www.w2bc.com/Article/34963

因爲工作需要最近一直在琢磨Caffe,純粹新手,寫博客供以後查閱方便,請大神們批評指正!

Caffe中,數據的讀取、運算、存儲都是採用Google Protocol Buffer來進行的,所以首先來較爲詳細的介紹下Protocol Buffer(PB)。

PB是一種輕便、高效的結構化數據存儲格式,可以用於結構化數據串行化,很適合做數據存儲或 RPC 數據交換格式。它可用於通訊協議、數據存儲等領域的語言無關、平臺無關、可擴展的序列化結構數據格式。是一種效率和兼容性都很優秀的二進制數據傳輸格式,目前提供了 C++、Java、Python 三種語言的 API。Caffe採用的是C++和Python的API。

接下來,我用一個簡單的例子來說明一下。

使用PB和 C++ 編寫一個十分簡單的例子程序。該程序由兩部分組成。第一部分被稱爲Writer,第二部分叫做Reader。Writer 負責將一些結構化的數據寫入一個磁盤文件,Reader則負責從該磁盤文件中讀取結構化數據並打印到屏幕上。準備用於演示的結構化數據是HelloWorld,它包含兩個基本數據:

ID,爲一個整數類型的數據;

Str,這是一個字符串。

首先我們需要編寫一個proto文件,定義我們程序中需要處理的結構化數據,Caffe是定義在caffe.proto文件中。在PB的術語中,結構化數據被稱爲 Message。proto文件非常類似java或C語言的數據定義。代碼清單 1 顯示了例子應用中的proto文件內容。

清單 1. proto 文件
package lm; 

message helloworld 

 { 

    required int32     id = 1;   // ID    

    required string    str = 2;  // str 

    optional int32     opt = 3;  // optional field 

 }

一個比較好的習慣是認真對待proto文件的文件名。比如將命名規則定於如下: packageName.MessageName.proto

在上例中,package名字叫做 lm,定義了一個消息helloworld,該消息有三個成員,類型爲int32的id,另一個爲類型爲string的成員str。optional是一個可選的成員,即消息中可以不包含該成員,required表明是必須包含該成員。一般在定義中會出現如下三個字段屬性:

對於required的字段而言,初值是必須要提供的,否則字段的便是未初始化的。 在Debug模式的buffer庫下編譯的話,序列化話的時候可能會失敗,而且在反序列化的時候對於該字段的解析會總是失敗的。所以,對於修飾符爲required的字段,請在序列化的時候務必給予初始化。

對於optional的字段而言,如果未進行初始化,那麼一個默認值將賦予該字段,當然也可以指定默認值。

對於repeated的字段而言,該字段可以重複多個,谷歌提供的這個 addressbook例子便有個很好的該修飾符的應用場景,即每個人可能有多個電話號碼。在高級語言裏面,我們可以通過數組來實現,而在proto定義文件中可以使用repeated來修飾,從而達到相同目的。當然,出現0次也是包含在內的。

寫好proto文件之後就可以用PB編譯器(protoc)將該文件編譯成目標語言了。本例中我們將使用C++。假設proto文件存放在 $SRC_DIR 下面,您也想把生成的文件放在同一個目錄下,則可以使用如下命令:

protoc -I=$SRC_DIR --cpp_out=$DST_DIR $SRC_DIR/addressbook.proto

命令將生成兩個文件:

lm.helloworld.pb.h, 定義了C++ 類的頭文件;

lm.helloworld.pb.cc,C++類的實現文件。

在生成的頭文件中,定義了一個 C++ 類 helloworld,後面的 Writer 和 Reader 將使用這個類來對消息進行操作。諸如對消息的成員進行賦值,將消息序列化等等都有相應的方法。

如前所述,Writer將把一個結構化數據寫入磁盤,以便其他人來讀取。假如我們不使用 PB,其實也有許多的選擇。一個可能的方法是將數據轉換爲字符串,然後將字符串寫入磁盤。轉換爲字符串的方法可以使用 sprintf(),這非常簡單。數字 123 可以變成字符串”123”。這樣做似乎沒有什麼不妥,但是仔細考慮一下就會發現,這樣的做法對寫Reader的那個人的要求比較高,Reader 的作者必須瞭解Writer 的細節。比如”123”可以是單個數字 123,但也可以是三個數字 1、2 和 3等等。這麼說來,我們還必須讓Writer定義一種分隔符一樣的字符,以便Reader可以正確讀取。但分隔符也許還會引起其他的什麼問題。最後我們發現一個簡單的Helloworld 也需要寫許多處理消息格式的代碼。

如果使用 PB,那麼這些細節就可以不需要應用程序來考慮了。使用PB,Writer 的工作很簡單,需要處理的結構化數據由 .proto 文件描述,經過上一節中的編譯過程後,該數據化結構對應了一個 C++ 的類,並定義在 lm.helloworld.pb.h 中。對於本例,類名爲lm::helloworld。

Writer 需要include該頭文件,然後便可以使用這個類了。現在,在Writer代碼中,將要存入磁盤的結構化數據由一個lm::helloworld類的對象表示,它提供了一系列的 get/set 函數用來修改和讀取結構化數據中的數據成員,或者叫field。

當我們需要將該結構化數據保存到磁盤上時,類 lm::helloworld 已經提供相應的方法來把一個複雜的數據變成一個字節序列,我們可以將這個字節序列寫入磁盤。

對於想要讀取這個數據的程序來說,也只需要使用類 lm::helloworld 的相應反序列化方法來將這個字節序列重新轉換會結構化數據。這同我們開始時那個“123”的想法類似,不過PB想的遠遠比我們那個粗糙的字符串轉換要全面,因此,我們可以放心將這類事情交給PB吧。程序清單 2 演示了 Writer 的主要代碼。

清單 2. Writer 的主要代碼
 #include "lm.helloworld.pb.h"

…

 int main(void) 

 { 

  lm::helloworld msg1; 

  msg1.set_id(101);          //設置id

  msg1.set_str(“hello”);   //設置str

  // 向磁盤中寫入數據流fstream 

  fstream output("./log", ios::out | ios::trunc | ios::binary);  

  if (!msg1.SerializeToOstream(&output)) { 
 
      cerr << "Failed to write msg." << endl; 
 
      return -1; 

  }         

  return 0; 

 }

Msg1 是一個helloworld類的對象,set_id()用來設置id的值。SerializeToOstream將對象序列化後寫入一個fstream流。我們可以寫出Reader代碼,程序清單3列出了 reader 的主要代碼。

清單 3. Reader的主要代碼
#include "lm.helloworld.pb.h" 

…
 
 void ListMsg(const lm::helloworld & msg) { 

  cout << msg.id() << endl; 

  cout << msg.str() << endl; 

 } 

 int main(int argc, char* argv[]) { 

  lm::helloworld msg1; 

  { 

    fstream input("./log", ios::in | ios::binary); 

    if (!msg1.ParseFromIstream(&input)) { 

      cerr << "Failed to parse address book." << endl; 

      return -1; 

    } 
 
  } 
 
  ListMsg(msg1); 
 
  … 

 }

同樣,Reader 聲明類helloworld的對象msg1,然後利用ParseFromIstream從一個fstream流中讀取信息並反序列化。此後,ListMsg中採用get方法讀取消息的內部信息,並進行打印輸出操作。

運行Writer和Reader的結果如下:

 >writer 
 >reader 
 101 
 Hello

Reader 讀取文件 log 中的序列化信息並打印到屏幕上。這個例子本身並無意義,但只要稍加修改就可以將它變成更加有用的程序。比如將磁盤替換爲網絡 socket,那麼就可以實現基於網絡的數據交換任務。而存儲和交換正是PB最有效的應用領域。

到這裏爲止,我們只給出了一個簡單的沒有任何用處的例子。在實際應用中,人們往往需要定義更加複雜的 Message。我們用“複雜”這個詞,不僅僅是指從個數上說有更多的 fields 或者更多類型的 fields,而是指更加複雜的數據結構:嵌套 Message,Caffe.proto文件中定義了大量的嵌套Message。使得Message的表達能力增強很多。代碼清單 4 給出一個嵌套 Message 的例子。

清單 4. 嵌套 Message 的例子
 message Person {
  required string name = 1;
  required int32 id = 2;        // Unique ID number for this person.
  optional string email = 3;
  enum PhoneType {
    MOBILE = 0;
    HOME = 1;
    WORK = 2;
  }
 
  message PhoneNumber {
    required string number = 1;
    optional PhoneType type = 2 [default = HOME];
  }
  repeated PhoneNumber phone = 4;
 }

在 Message Person 中,定義了嵌套消息 PhoneNumber,並用來定義 Person 消息中的 phone 域。這使得人們可以定義更加複雜的數據結構。

以上部分參考網址:http://www.ibm.com/developerworks/cn/linux/l-cn-gpb/

在Caffe中也是類似於上例中的Writer和Reader去讀寫PB數據的。接下來,具體說明下Caffe中是如何存儲Caffemodel的。在Caffe主目錄下的solver.cpp文件中的一段代碼展示了Caffe是如何存儲Caffemodel的,代碼清單5如下:

清單 5. Caffemodel存儲代碼
template <typename Dtype>

void Solver<Dtype>::Snapshot() {

  NetParameter net_param;    // NetParameter爲網絡參數類

  // 爲了中間結果,也會寫入梯度值
 
  net_->ToProto(&net_param, param_.snapshot_diff());

  string filename(param_.snapshot_prefix());

  string model_filename, snapshot_filename;

  const int kBufferSize = 20;

  char iter_str_buffer[kBufferSize];

  // 每訓練完1次,iter_就加1 

snprintf(iter_str_buffer, kBufferSize, "_iter_%d", iter_ + 1);

  filename += iter_str_buffer;
 
  model_filename = filename + ".caffemodel"; //XX_iter_YY.caffemodel
 
  LOG(INFO) << "Snapshotting to " << model_filename;

  // 向磁盤寫入網絡參數
 
  WriteProtoToBinaryFile(net_param, model_filename.c_str());

  SolverState state;
 
  SnapshotSolverState(&state);
 
  state.set_iter(iter_ + 1);    //set

  state.set_learned_net(model_filename);
 
  state.set_current_step(current_step_);
 
  snapshot_filename = filename + ".solverstate";
 
  LOG(INFO) << "Snapshotting solver state to " << snapshot_filename;
 
  // 向磁盤寫入網絡state
 
  WriteProtoToBinaryFile(state, snapshot_filename.c_str());
 
}

在清單5代碼中,我們可以看到,其實Caffemodel存儲的數據也就是網絡參數net_param的PB,Caffe可以保存每一次訓練完成後的網絡參數,我們可以通過XX.prototxt文件來進行參數設置。在這裏的 WriteProtoToBinaryFile函數與之前HelloWorld例子中的Writer函數類似,在這就不在貼出。那麼我們只要弄清楚NetParameter類的組成,也就明白了Caffemodel的具體數據構成。在caffe.proto這個文件中定義了NetParameter類,如代碼清單6所示。

清單6. Caffemodel存儲代碼
 message NetParameter {
 
   optional string name = 1;   // 網絡名稱
 
   repeated string input = 3;  // 網絡輸入input blobs
 
   repeated BlobShape input_shape = 8; // The shape of the input blobs
  
   // 輸入維度blobs,4維(num, channels, height and width)

  repeated int32 input_dim = 4;
 
   // 網絡是否強制每層進行反饋操作開關

  // 如果設置爲False,則會根據網絡結構和學習率自動確定是否進行反饋操作
 
   optional bool force_backward = 5 [default = false];
  
 // 網絡的state,部分網絡層依賴,部分不依賴,需要看具體網絡
 
   optional NetState state = 6;
 
   // 是否打印debug log
 
   optional bool debug_info = 7 [default = false];
 
   // 網絡層參數,Field Number 爲100,所以網絡層參數在最後
 
   repeated LayerParameter layer = 100; 
 
   // 棄用: 用 'layer' 代替
 
   repeated V1LayerParameter layers = 2;
 
 }
 
 // Specifies the shape (dimensions) of a Blob.
 
 message BlobShape {
 
   repeated int64 dim = 1 [packed = true];
 
 }
 
 message BlobProto {
 
   optional BlobShape shape = 7;
 
   repeated float data = 5 [packed = true];
 
   repeated float diff = 6 [packed = true];
 
   optional int32 num = 1 [default = 0];
 
   optional int32 channels = 2 [default = 0];
 
   optional int32 height = 3 [default = 0];
 
   optional int32 width = 4 [default = 0];
 
 }
 
  
 
 // The BlobProtoVector is simply a way to pass multiple blobproto instances
 
 around.
 
 message BlobProtoVector {
 
   repeated BlobProto blobs = 1;
 
 }
 
 message NetState {
 
   optional Phase phase = 1 [default = TEST];
 
   optional int32 level = 2 [default = 0];
 
   repeated string stage = 3;
 
 }
 
 message LayerParameter {
 
   optional string name = 1;   // the layer name

   optional string type = 2;   // the layer type
 
   repeated string bottom = 3; // the name of each bottom blob
 
   repeated string top = 4;    // the name of each top blob
 
   // The train/test phase for computation.
 
   optional Phase phase = 10;
 
   // Loss weight值:float
 
   // 每一層爲每一個top blob都分配了一個默認值,通常是0或1
 
   repeated float loss_weight = 5;
 
   // 指定的學習參數
 
   repeated ParamSpec param = 6;
 
   // The blobs containing the numeric parameters of the layer.
 
   repeated BlobProto blobs = 7;
 
   // included/excluded.
 
   repeated NetStateRule include = 8;
 
   repeated NetStateRule exclude = 9;
 
   // Parameters for data pre-processing.
 
   optional TransformationParameter transform_param = 100;
 
   // Parameters shared by loss layers.
 
   optional LossParameter loss_param = 101;
 
   // 各種類型層參數
 
   optional AccuracyParameter accuracy_param = 102;
 
   optional ArgMaxParameter argmax_param = 103;
 
   optional ConcatParameter concat_param = 104;
 
   optional ContrastiveLossParameter contrastive_loss_param = 105;
 
   optional ConvolutionParameter convolution_param = 106;
 
   optional DataParameter data_param = 107;
 
   optional DropoutParameter dropout_param = 108;
 
   optional DummyDataParameter dummy_data_param = 109;
 
   optional EltwiseParameter eltwise_param = 110;
 
   optional ExpParameter exp_param = 111;
 
   optional HDF5DataParameter hdf5_data_param = 112;
 
   optional HDF5OutputParameter hdf5_output_param = 113;
 
   optional HingeLossParameter hinge_loss_param = 114;

   optional ImageDataParameter image_data_param = 115;
 
   optional InfogainLossParameter infogain_loss_param = 116;
 
   optional InnerProductParameter inner_product_param = 117;
 
   optional LRNParameter lrn_param = 118;
 
   optional MemoryDataParameter memory_data_param = 119;
 
   optional MVNParameter mvn_param = 120;
 
   optional PoolingParameter pooling_param = 121;
 
   optional PowerParameter power_param = 122;
 
   optional PythonParameter python_param = 130;
 
   optional ReLUParameter relu_param = 123;
 
   optional SigmoidParameter sigmoid_param = 124;
 
   optional SoftmaxParameter softmax_param = 125;
 
   optional SliceParameter slice_param = 126;
 
   optional TanHParameter tanh_param = 127;
 
   optional ThresholdParameter threshold_param = 128;
 
   optional WindowDataParameter window_data_param = 129;
 
 }

那麼接下來的一段代碼來演示如何解析Caffemodel,我解析用的model爲MNIST手寫庫訓練後的model,Lenet_iter_10000.caffemodel。

清單7. Caffemodel解析代碼
 #include <stdio.h>
 #include <string.h>
 #include <fstream>
 #include <iostream>
 #include "proto/caffe.pb.h"

 using namespace std;
 using namespace caffe;

 int main(int argc, char* argv[]) 
 { 
 
  caffe::NetParameter msg; 

  fstream input("lenet_iter_10000.caffemodel", ios::in | ios::binary); 
  if (!msg.ParseFromIstream(&input)) 
  { 
    cerr << "Failed to parse address book." << endl; 
    return -1; 
  } 
  printf("length = %d\n", length);
  printf("Repeated Size = %d\n", msg.layer_size());

  ::google::protobuf::RepeatedPtrField< LayerParameter >* layer = msg.mutable_layer();
  ::google::protobuf::RepeatedPtrField< LayerParameter >::iterator it = layer->begin();
  for (; it != layer->end(); ++it)
  {
    cout << it->name() << endl;
    cout << it->type() << endl;
    cout << it->convolution_param().weight_filler().max() << endl;
  } 

  return 0;
 }
參考網址:http://www.cnblogs.com/stephen-liu74/archive/2013/01/04/2842533.html


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章