轉載自http://www.w2bc.com/Article/34963
因爲工作需要最近一直在琢磨Caffe,純粹新手,寫博客供以後查閱方便,請大神們批評指正!
Caffe中,數據的讀取、運算、存儲都是採用Google Protocol Buffer來進行的,所以首先來較爲詳細的介紹下Protocol Buffer(PB)。
PB是一種輕便、高效的結構化數據存儲格式,可以用於結構化數據串行化,很適合做數據存儲或 RPC 數據交換格式。它可用於通訊協議、數據存儲等領域的語言無關、平臺無關、可擴展的序列化結構數據格式。是一種效率和兼容性都很優秀的二進制數據傳輸格式,目前提供了 C++、Java、Python 三種語言的 API。Caffe採用的是C++和Python的API。
接下來,我用一個簡單的例子來說明一下。
使用PB和 C++ 編寫一個十分簡單的例子程序。該程序由兩部分組成。第一部分被稱爲Writer,第二部分叫做Reader。Writer 負責將一些結構化的數據寫入一個磁盤文件,Reader則負責從該磁盤文件中讀取結構化數據並打印到屏幕上。準備用於演示的結構化數據是HelloWorld,它包含兩個基本數據:
ID,爲一個整數類型的數據;
Str,這是一個字符串。
首先我們需要編寫一個proto文件,定義我們程序中需要處理的結構化數據,Caffe是定義在caffe.proto文件中。在PB的術語中,結構化數據被稱爲 Message。proto文件非常類似java或C語言的數據定義。代碼清單 1 顯示了例子應用中的proto文件內容。
清單 1. proto 文件
package lm;
message helloworld
{
required int32 id = 1; // ID
required string str = 2; // str
optional int32 opt = 3; // optional field
}
一個比較好的習慣是認真對待proto文件的文件名。比如將命名規則定於如下: packageName.MessageName.proto
在上例中,package名字叫做 lm,定義了一個消息helloworld,該消息有三個成員,類型爲int32的id,另一個爲類型爲string的成員str。optional是一個可選的成員,即消息中可以不包含該成員,required表明是必須包含該成員。一般在定義中會出現如下三個字段屬性:
對於required的字段而言,初值是必須要提供的,否則字段的便是未初始化的。 在Debug模式的buffer庫下編譯的話,序列化話的時候可能會失敗,而且在反序列化的時候對於該字段的解析會總是失敗的。所以,對於修飾符爲required的字段,請在序列化的時候務必給予初始化。
對於optional的字段而言,如果未進行初始化,那麼一個默認值將賦予該字段,當然也可以指定默認值。
對於repeated的字段而言,該字段可以重複多個,谷歌提供的這個 addressbook例子便有個很好的該修飾符的應用場景,即每個人可能有多個電話號碼。在高級語言裏面,我們可以通過數組來實現,而在proto定義文件中可以使用repeated來修飾,從而達到相同目的。當然,出現0次也是包含在內的。
寫好proto文件之後就可以用PB編譯器(protoc)將該文件編譯成目標語言了。本例中我們將使用C++。假設proto文件存放在 $SRC_DIR 下面,您也想把生成的文件放在同一個目錄下,則可以使用如下命令:
protoc -I=$SRC_DIR --cpp_out=$DST_DIR $SRC_DIR/addressbook.proto
命令將生成兩個文件:
lm.helloworld.pb.h, 定義了C++ 類的頭文件;
lm.helloworld.pb.cc,C++類的實現文件。
在生成的頭文件中,定義了一個 C++ 類 helloworld,後面的 Writer 和 Reader 將使用這個類來對消息進行操作。諸如對消息的成員進行賦值,將消息序列化等等都有相應的方法。
如前所述,Writer將把一個結構化數據寫入磁盤,以便其他人來讀取。假如我們不使用 PB,其實也有許多的選擇。一個可能的方法是將數據轉換爲字符串,然後將字符串寫入磁盤。轉換爲字符串的方法可以使用 sprintf(),這非常簡單。數字 123 可以變成字符串”123”。這樣做似乎沒有什麼不妥,但是仔細考慮一下就會發現,這樣的做法對寫Reader的那個人的要求比較高,Reader 的作者必須瞭解Writer 的細節。比如”123”可以是單個數字 123,但也可以是三個數字 1、2 和 3等等。這麼說來,我們還必須讓Writer定義一種分隔符一樣的字符,以便Reader可以正確讀取。但分隔符也許還會引起其他的什麼問題。最後我們發現一個簡單的Helloworld 也需要寫許多處理消息格式的代碼。
如果使用 PB,那麼這些細節就可以不需要應用程序來考慮了。使用PB,Writer 的工作很簡單,需要處理的結構化數據由 .proto 文件描述,經過上一節中的編譯過程後,該數據化結構對應了一個 C++ 的類,並定義在 lm.helloworld.pb.h 中。對於本例,類名爲lm::helloworld。
Writer 需要include該頭文件,然後便可以使用這個類了。現在,在Writer代碼中,將要存入磁盤的結構化數據由一個lm::helloworld類的對象表示,它提供了一系列的 get/set 函數用來修改和讀取結構化數據中的數據成員,或者叫field。
當我們需要將該結構化數據保存到磁盤上時,類 lm::helloworld 已經提供相應的方法來把一個複雜的數據變成一個字節序列,我們可以將這個字節序列寫入磁盤。
對於想要讀取這個數據的程序來說,也只需要使用類 lm::helloworld 的相應反序列化方法來將這個字節序列重新轉換會結構化數據。這同我們開始時那個“123”的想法類似,不過PB想的遠遠比我們那個粗糙的字符串轉換要全面,因此,我們可以放心將這類事情交給PB吧。程序清單 2 演示了 Writer 的主要代碼。
清單 2. Writer 的主要代碼
#include "lm.helloworld.pb.h"
…
int main(void)
{
lm::helloworld msg1;
msg1.set_id(101); //設置id
msg1.set_str(“hello”); //設置str
// 向磁盤中寫入數據流fstream
fstream output("./log", ios::out | ios::trunc | ios::binary);
if (!msg1.SerializeToOstream(&output)) {
cerr << "Failed to write msg." << endl;
return -1;
}
return 0;
}
Msg1 是一個helloworld類的對象,set_id()用來設置id的值。SerializeToOstream將對象序列化後寫入一個fstream流。我們可以寫出Reader代碼,程序清單3列出了 reader 的主要代碼。
清單 3. Reader的主要代碼
#include "lm.helloworld.pb.h"
…
void ListMsg(const lm::helloworld & msg) {
cout << msg.id() << endl;
cout << msg.str() << endl;
}
int main(int argc, char* argv[]) {
lm::helloworld msg1;
{
fstream input("./log", ios::in | ios::binary);
if (!msg1.ParseFromIstream(&input)) {
cerr << "Failed to parse address book." << endl;
return -1;
}
}
ListMsg(msg1);
…
}
同樣,Reader 聲明類helloworld的對象msg1,然後利用ParseFromIstream從一個fstream流中讀取信息並反序列化。此後,ListMsg中採用get方法讀取消息的內部信息,並進行打印輸出操作。
運行Writer和Reader的結果如下:
>writer
>reader
101
Hello
Reader 讀取文件 log 中的序列化信息並打印到屏幕上。這個例子本身並無意義,但只要稍加修改就可以將它變成更加有用的程序。比如將磁盤替換爲網絡 socket,那麼就可以實現基於網絡的數據交換任務。而存儲和交換正是PB最有效的應用領域。
到這裏爲止,我們只給出了一個簡單的沒有任何用處的例子。在實際應用中,人們往往需要定義更加複雜的 Message。我們用“複雜”這個詞,不僅僅是指從個數上說有更多的 fields 或者更多類型的 fields,而是指更加複雜的數據結構:嵌套 Message,Caffe.proto文件中定義了大量的嵌套Message。使得Message的表達能力增強很多。代碼清單 4 給出一個嵌套 Message 的例子。
清單 4. 嵌套 Message 的例子
message Person {
required string name = 1;
required int32 id = 2; // Unique ID number for this person.
optional string email = 3;
enum PhoneType {
MOBILE = 0;
HOME = 1;
WORK = 2;
}
message PhoneNumber {
required string number = 1;
optional PhoneType type = 2 [default = HOME];
}
repeated PhoneNumber phone = 4;
}
在 Message Person 中,定義了嵌套消息 PhoneNumber,並用來定義 Person 消息中的 phone 域。這使得人們可以定義更加複雜的數據結構。
以上部分參考網址:http://www.ibm.com/developerworks/cn/linux/l-cn-gpb/
在Caffe中也是類似於上例中的Writer和Reader去讀寫PB數據的。接下來,具體說明下Caffe中是如何存儲Caffemodel的。在Caffe主目錄下的solver.cpp文件中的一段代碼展示了Caffe是如何存儲Caffemodel的,代碼清單5如下:
清單 5. Caffemodel存儲代碼
template <typename Dtype>
void Solver<Dtype>::Snapshot() {
NetParameter net_param; // NetParameter爲網絡參數類
// 爲了中間結果,也會寫入梯度值
net_->ToProto(&net_param, param_.snapshot_diff());
string filename(param_.snapshot_prefix());
string model_filename, snapshot_filename;
const int kBufferSize = 20;
char iter_str_buffer[kBufferSize];
// 每訓練完1次,iter_就加1
snprintf(iter_str_buffer, kBufferSize, "_iter_%d", iter_ + 1);
filename += iter_str_buffer;
model_filename = filename + ".caffemodel"; //XX_iter_YY.caffemodel
LOG(INFO) << "Snapshotting to " << model_filename;
// 向磁盤寫入網絡參數
WriteProtoToBinaryFile(net_param, model_filename.c_str());
SolverState state;
SnapshotSolverState(&state);
state.set_iter(iter_ + 1); //set
state.set_learned_net(model_filename);
state.set_current_step(current_step_);
snapshot_filename = filename + ".solverstate";
LOG(INFO) << "Snapshotting solver state to " << snapshot_filename;
// 向磁盤寫入網絡state
WriteProtoToBinaryFile(state, snapshot_filename.c_str());
}
在清單5代碼中,我們可以看到,其實Caffemodel存儲的數據也就是網絡參數net_param的PB,Caffe可以保存每一次訓練完成後的網絡參數,我們可以通過XX.prototxt文件來進行參數設置。在這裏的 WriteProtoToBinaryFile函數與之前HelloWorld例子中的Writer函數類似,在這就不在貼出。那麼我們只要弄清楚NetParameter類的組成,也就明白了Caffemodel的具體數據構成。在caffe.proto這個文件中定義了NetParameter類,如代碼清單6所示。
清單6. Caffemodel存儲代碼
message NetParameter {
optional string name = 1; // 網絡名稱
repeated string input = 3; // 網絡輸入input blobs
repeated BlobShape input_shape = 8; // The shape of the input blobs
// 輸入維度blobs,4維(num, channels, height and width)
repeated int32 input_dim = 4;
// 網絡是否強制每層進行反饋操作開關
// 如果設置爲False,則會根據網絡結構和學習率自動確定是否進行反饋操作
optional bool force_backward = 5 [default = false];
// 網絡的state,部分網絡層依賴,部分不依賴,需要看具體網絡
optional NetState state = 6;
// 是否打印debug log
optional bool debug_info = 7 [default = false];
// 網絡層參數,Field Number 爲100,所以網絡層參數在最後
repeated LayerParameter layer = 100;
// 棄用: 用 'layer' 代替
repeated V1LayerParameter layers = 2;
}
// Specifies the shape (dimensions) of a Blob.
message BlobShape {
repeated int64 dim = 1 [packed = true];
}
message BlobProto {
optional BlobShape shape = 7;
repeated float data = 5 [packed = true];
repeated float diff = 6 [packed = true];
optional int32 num = 1 [default = 0];
optional int32 channels = 2 [default = 0];
optional int32 height = 3 [default = 0];
optional int32 width = 4 [default = 0];
}
// The BlobProtoVector is simply a way to pass multiple blobproto instances
around.
message BlobProtoVector {
repeated BlobProto blobs = 1;
}
message NetState {
optional Phase phase = 1 [default = TEST];
optional int32 level = 2 [default = 0];
repeated string stage = 3;
}
message LayerParameter {
optional string name = 1; // the layer name
optional string type = 2; // the layer type
repeated string bottom = 3; // the name of each bottom blob
repeated string top = 4; // the name of each top blob
// The train/test phase for computation.
optional Phase phase = 10;
// Loss weight值:float
// 每一層爲每一個top blob都分配了一個默認值,通常是0或1
repeated float loss_weight = 5;
// 指定的學習參數
repeated ParamSpec param = 6;
// The blobs containing the numeric parameters of the layer.
repeated BlobProto blobs = 7;
// included/excluded.
repeated NetStateRule include = 8;
repeated NetStateRule exclude = 9;
// Parameters for data pre-processing.
optional TransformationParameter transform_param = 100;
// Parameters shared by loss layers.
optional LossParameter loss_param = 101;
// 各種類型層參數
optional AccuracyParameter accuracy_param = 102;
optional ArgMaxParameter argmax_param = 103;
optional ConcatParameter concat_param = 104;
optional ContrastiveLossParameter contrastive_loss_param = 105;
optional ConvolutionParameter convolution_param = 106;
optional DataParameter data_param = 107;
optional DropoutParameter dropout_param = 108;
optional DummyDataParameter dummy_data_param = 109;
optional EltwiseParameter eltwise_param = 110;
optional ExpParameter exp_param = 111;
optional HDF5DataParameter hdf5_data_param = 112;
optional HDF5OutputParameter hdf5_output_param = 113;
optional HingeLossParameter hinge_loss_param = 114;
optional ImageDataParameter image_data_param = 115;
optional InfogainLossParameter infogain_loss_param = 116;
optional InnerProductParameter inner_product_param = 117;
optional LRNParameter lrn_param = 118;
optional MemoryDataParameter memory_data_param = 119;
optional MVNParameter mvn_param = 120;
optional PoolingParameter pooling_param = 121;
optional PowerParameter power_param = 122;
optional PythonParameter python_param = 130;
optional ReLUParameter relu_param = 123;
optional SigmoidParameter sigmoid_param = 124;
optional SoftmaxParameter softmax_param = 125;
optional SliceParameter slice_param = 126;
optional TanHParameter tanh_param = 127;
optional ThresholdParameter threshold_param = 128;
optional WindowDataParameter window_data_param = 129;
}
那麼接下來的一段代碼來演示如何解析Caffemodel,我解析用的model爲MNIST手寫庫訓練後的model,Lenet_iter_10000.caffemodel。
清單7. Caffemodel解析代碼
#include <stdio.h>
#include <string.h>
#include <fstream>
#include <iostream>
#include "proto/caffe.pb.h"
using namespace std;
using namespace caffe;
int main(int argc, char* argv[])
{
caffe::NetParameter msg;
fstream input("lenet_iter_10000.caffemodel", ios::in | ios::binary);
if (!msg.ParseFromIstream(&input))
{
cerr << "Failed to parse address book." << endl;
return -1;
}
printf("length = %d\n", length);
printf("Repeated Size = %d\n", msg.layer_size());
::google::protobuf::RepeatedPtrField< LayerParameter >* layer = msg.mutable_layer();
::google::protobuf::RepeatedPtrField< LayerParameter >::iterator it = layer->begin();
for (; it != layer->end(); ++it)
{
cout << it->name() << endl;
cout << it->type() << endl;
cout << it->convolution_param().weight_filler().max() << endl;
}
return 0;
}
參考網址:http://www.cnblogs.com/stephen-liu74/archive/2013/01/04/2842533.html