TensorRT/parsers/caffe/caffeParser/caffeParser.cpp - parseBinaryProto源碼研讀
parseBinaryProto
/*
 * Parse the serialized trtcaffe::BlobProto stored in `fileName` and return a
 * newly allocated BinaryProtoBlob wrapping a malloc'd copy of the blob data.
 *
 * Dimensions come from `shape` when present. A shape with fewer than four dims
 * is right-aligned into NCHW, e.g. {C, H, W} -> {1, C, H, W}. Otherwise the
 * deprecated num/channels/height/width fields are used.
 *
 * Schema (TensorRT/parsers/caffe/proto/trtcaffe.proto):
 *   message BlobShape { repeated int64 dim = 1 [packed = true]; }
 *   message BlobProto {
 *     optional BlobShape shape = 7;
 *     repeated float data = 5 [packed = true];
 *     repeated float diff = 6 [packed = true];
 *     repeated double double_data = 8 [packed = true];
 *     repeated double double_diff = 9 [packed = true];
 *     optional Type raw_data_type = 10;   // enum Type { DOUBLE, FLOAT, FLOAT16, INT, UINT }
 *     optional Type raw_diff_type = 11;
 *     optional bytes raw_data = 12 [packed = false];
 *     optional bytes raw_diff = 13 [packed = false];
 *     // 4D dimensions -- deprecated. Use "shape" instead.
 *     optional int32 num = 1 [default = 0];
 *     optional int32 channels = 2 [default = 0];
 *     optional int32 height = 3 [default = 0];
 *     optional int32 width = 4 [default = 0];
 *   }
 *
 * Returns nullptr (after logging an error) when:
 *   - the file cannot be opened or parsed;
 *   - the shape has more than 4 dims, or any extent is <= 0 or > INT32_MAX;
 *   - the element count overflows or differs from the stored element count;
 *   - the blob data type is neither FLOAT nor FLOAT16;
 *   - the data buffer cannot be allocated.
 * A null fileName is handled by CHECK_NULL_RET_NULL (caffeMacros.h).
 * The result is created with `new`; presumably the caller releases it.
 * TODO(review): confirm against the IBinaryProtoBlob contract.
 */
IBinaryProtoBlob* CaffeParser::parseBinaryProto(const char* fileName)
{
    // Defined in TensorRT/parsers/caffe/caffeMacros.h.
    CHECK_NULL_RET_NULL(fileName)
    using namespace google::protobuf::io;

    std::ifstream stream(fileName, std::ios::in | std::ios::binary);
    if (!stream)
    {
        // Defined in TensorRT/parsers/caffe/caffeMacros.h.
        RETURN_AND_LOG_ERROR(nullptr, "Could not open file " + std::string{fileName});
    }

    trtcaffe::BlobProto blob;
    bool ok = false;
    {
        // Scope the protobuf stream wrappers so they are destroyed before the
        // underlying std::ifstream is closed.
        IstreamInputStream rawInput(&stream);
        CodedInputStream codedInput(&rawInput);
        // Raise the total-bytes limit to INT_MAX so large blobs are not rejected.
        // The second argument (-1) is the warning threshold of the older
        // two-argument protobuf API.
        codedInput.SetTotalBytesLimit(INT_MAX, -1);
        // Parse the protocol buffer from the stream into `blob`.
        ok = blob.ParseFromCodedStream(&codedInput);
    }
    stream.close();
    if (!ok)
    {
        RETURN_AND_LOG_ERROR(nullptr, "parseBinaryProto: Could not parse mean file");
    }

    // The four NCHW extents; leading dims the shape does not specify stay 1.
    int s[4] = {1, 1, 1, 1};
    if (blob.has_shape())
    {
        const int size = blob.shape().dim_size();
        if (size > 4)
        {
            RETURN_AND_LOG_ERROR(nullptr, "parseBinaryProto: blob has more than 4 dimensions");
        }
        // Right-align: shape dim i goes to NCHW slot (4 - size) + i. The shape
        // must be indexed from 0, not by the NCHW slot index. Otherwise a shape
        // with fewer than 4 dims is read past its end.
        const int offset = 4 - size;
        for (int i = 0; i < size; ++i)
        {
            // BlobShape::dim is int64. Reject values that do not fit a positive int.
            // Use a runtime check; an assert would vanish in release builds.
            const int64_t d = blob.shape().dim(i);
            if (d <= 0 || d > INT32_MAX)
            {
                RETURN_AND_LOG_ERROR(nullptr, "parseBinaryProto: invalid blob dimension");
            }
            s[offset + i] = static_cast<int>(d);
        }
    }
    else
    {
        // Deprecated 4-D fields, in NCHW order.
        s[0] = blob.num();
        s[1] = blob.channels();
        s[2] = blob.height();
        s[3] = blob.width();
    }
    DimsNCHW dims{s[0], s[1], s[2], s[3]};

    // Element count, computed in 64 bits with an overflow guard. Every extent
    // must be positive. The deprecated fields default to 0 and are caught here.
    int64_t dataSize = 1;
    for (const int extent : s)
    {
        if (extent <= 0 || dataSize > INT64_MAX / extent)
        {
            RETURN_AND_LOG_ERROR(nullptr, "parseBinaryProto: invalid blob dimensions");
        }
        dataSize *= extent;
    }

    // Element type of the blob's data. Defined in
    // TensorRT/parsers/caffe/caffeWeightFactory/caffeWeightFactory.h/.cpp.
    const trtcaffe::Type blobProtoDataType = CaffeWeightFactory::getBlobProtoDataType(blob);
    // Only FLOAT and FLOAT16 have a matching DataType here. Other types
    // (e.g. DOUBLE, INT, UINT) are rejected rather than mislabeled as kHALF.
    DataType dataType = DataType::kFLOAT;
    if (blobProtoDataType == trtcaffe::FLOAT)
    {
        dataType = DataType::kFLOAT;
    }
    else if (blobProtoDataType == trtcaffe::FLOAT16)
    {
        dataType = DataType::kHALF;
    }
    else
    {
        RETURN_AND_LOG_ERROR(nullptr, "parseBinaryProto: unsupported blob data type");
    }

    // Pair of (pointer to the blob data converted to blobProtoDataType, element
    // count). The source data is raw_data, data, or double_data. Temporary
    // conversion buffers are recorded in mTmpAllocs.
    const auto blobProtoData = CaffeWeightFactory::getBlobProtoData(blob, blobProtoDataType, mTmpAllocs);
    // The dims-derived element count must match what the blob actually stores.
    if (static_cast<int64_t>(blobProtoData.second) != dataSize)
    {
        RETURN_AND_LOG_ERROR(nullptr, "CaffeParser::parseBinaryProto: blob dimensions don't match data size!!");
    }

    // Copy the (converted) data into a fresh buffer handed to BinaryProtoBlob.
    const size_t dataSizeBytes
        = static_cast<size_t>(dataSize) * CaffeWeightFactory::sizeOfCaffeType(blobProtoDataType);
    void* memory = malloc(dataSizeBytes);
    if (memory == nullptr)
    {
        RETURN_AND_LOG_ERROR(nullptr, "parseBinaryProto: could not allocate memory for blob data");
    }
    memcpy(memory, blobProtoData.first, dataSizeBytes);
    return new BinaryProtoBlob(memory, dataType, dims);
}
google::protobuf
parseBinaryProto中用到了IstreamInputStream、CodedInputStream等類別及ParseFromCodedStream等函數,詳見C++ google protobuf。
CHECK_NULL_RET_NULL,RETURN_AND_LOG_ERROR
這兩個macro定義於TensorRT/parsers/caffe/caffeMacros.h,詳見TensorRT/parsers/caffe/caffeMacros.h源碼研讀。
return new
parseBinaryProto函數中有:
return new BinaryProtoBlob(memory,
blobProtoDataType == trtcaffe::FLOAT ? DataType::kFLOAT : DataType::kHALF, dims);
這裡為何要特別使用new這個operator呢?詳見C++ new的使用場景。