TensorRT/parsers/caffe/caffeParser/caffeParser.cpp - parseBinaryProto源碼研讀

TensorRT/parsers/caffe/caffeParser/caffeParser.cpp - parseBinaryProto源碼研讀

parseBinaryProto

/*
解析filename的內容,新建一個BinaryProtoBlob物件後回傳
*/
IBinaryProtoBlob* CaffeParser::parseBinaryProto(const char* fileName)
{
    //定義於TensorRT/parsers/caffe/caffeMacros.h
    CHECK_NULL_RET_NULL(fileName)
    using namespace google::protobuf::io;

    std::ifstream stream(fileName, std::ios::in | std::ios::binary);
    if (!stream)
    {
        //定義於TensorRT/parsers/caffe/caffeMacros.h
        RETURN_AND_LOG_ERROR(nullptr, "Could not open file " + std::string{fileName});
    }

    IstreamInputStream rawInput(&stream);
    CodedInputStream codedInput(&rawInput);
    codedInput.SetTotalBytesLimit(INT_MAX, -1);

	/*
	於TensorRT/parsers/caffe/proto/trtcaffe.proto:
	// Specifies the shape (dimensions) of a Blob.
	message BlobShape {
	  repeated int64 dim = 1 [packed = true];
	}
	
	message BlobProto {
	  optional BlobShape shape = 7;
	  repeated float data = 5 [packed = true];
	  repeated float diff = 6 [packed = true];
	  repeated double double_data = 8 [packed = true];
	  repeated double double_diff = 9 [packed = true];
	  // New raw storage (faster and takes 1/2 of space for FP16)
	  optional Type raw_data_type = 10;
	  optional Type raw_diff_type = 11;
	  optional bytes raw_data = 12 [packed = false];
	  optional bytes raw_diff = 13 [packed = false];
	  // 4D dimensions -- deprecated.  Use "shape" instead.
	  optional int32 num = 1 [default = 0];
	  optional int32 channels = 2 [default = 0];
	  optional int32 height = 3 [default = 0];
	  optional int32 width = 4 [default = 0];
	}
	*/
    trtcaffe::BlobProto blob;
    /*
    從給定的input stream裡解析出protocol buffer,
    並填入blob這個message物件中
    */
    bool ok = blob.ParseFromCodedStream(&codedInput);
    stream.close();

    if (!ok)
    {
        RETURN_AND_LOG_ERROR(nullptr, "parseBinaryProto: Could not parse mean file");
    }

    DimsNCHW dims{1, 1, 1, 1};
    if (blob.has_shape())
    {
        //blob.shape().dim_size():blob維度的數量
        int size = blob.shape().dim_size(), s[4] = {1, 1, 1, 1};
        //由後往前將s中的元素用blob.shape().dim中的元素取代
        for (int i = 4 - size; i < 4; i++)
        {
            /*
            BlobShape::dim是int64型別的陣列,
            所以這裡需要做檢查及型別轉換
            */
            assert(blob.shape().dim(i) < INT32_MAX);
            s[i] = static_cast<int>(blob.shape().dim(i));
        }
        dims = DimsNCHW{s[0], s[1], s[2], s[3]};
    }
    else
    {
        //維度順序:NCHW
        dims = DimsNCHW{blob.num(), blob.channels(), blob.height(), blob.width()};
    }

    const int dataSize = dims.n() * dims.c() * dims.h() * dims.w();
    assert(dataSize > 0);

	/*
	trtcaffe::Type
	定義於TensorRT/parsers/caffe/proto/trtcaffe.proto
	enum Type {
	  DOUBLE = 0;
	  FLOAT = 1;
	  FLOAT16 = 2;
	  INT = 3;  // math not supported
	  UINT = 4;  // math not supported
	}
	*/
	/*
	CaffeWeightFactory::getBlobProtoDataType
	定義於TensorRT/parsers/caffe/caffeWeightFactory/caffeWeightFactory.h,caffeWeightFactory.cpp
	獲取blobMsg的資料型別
	*/
    const trtcaffe::Type blobProtoDataType = CaffeWeightFactory::getBlobProtoDataType(blob);
    /*
    回傳一個pair,
    第一個元素為一個指標,
    指向存放轉為blobProtoDataType型別的blob裡的數據
    (raw_data,data或double_data)
    第二個元素為記憶體中的元素個數
    mTmpAllocs中則存放了上述指標
    */
    const auto blobProtoData = CaffeWeightFactory::getBlobProtoData(blob, blobProtoDataType, mTmpAllocs);

	/*
	dataSize:由blob的shape或num,channels,height及width計算而來
	blobProtoData.second:blob裡的數據(raw_data,data或double_data)的實際元素個數
	*/
    if (dataSize != (int) blobProtoData.second)
    {
        std::cout << "CaffeParser::parseBinaryProto: blob dimensions don't match data size!!" << std::endl;
        return nullptr;
    }

    //將已轉換為blobProtoDataType型別的blob中的數據複製到memory這個新申請的記憶體內
    const int dataSizeBytes = dataSize * CaffeWeightFactory::sizeOfCaffeType(blobProtoDataType);
    void* memory = malloc(dataSizeBytes);
    memcpy(memory, blobProtoData.first, dataSizeBytes);
    //新建一個BinaryProtoBlob物件後回傳
    return new BinaryProtoBlob(memory,
                               blobProtoDataType == trtcaffe::FLOAT ? DataType::kFLOAT : DataType::kHALF, dims);

    //unreachable code?
    std::cout << "CaffeParser::parseBinaryProto: couldn't find any data!!" << std::endl;
    return nullptr;
}

google::protobuf

parseBinaryProto中用到了IstreamInputStreamCodedInputStream等類別及ParseFromCodedStream等函數,詳見C++ google protobuf

CHECK_NULL_RET_NULL,RETURN_AND_LOG_ERROR

這兩個macro定義於定義於TensorRT/parsers/caffe/caffeMacros.h,詳見TensorRT/parsers/caffe/caffeMacros.h源碼研讀

return new

parseBinaryProto函數中有:

return new BinaryProtoBlob(memory,
    blobProtoDataType == trtcaffe::FLOAT ? DataType::kFLOAT : DataType::kHALF, dims);

這裡為何要特別使用new這個operator呢?詳見C++ new的使用場景

參考連結

C++ google protobuf

TensorRT/parsers/caffe/caffeMacros.h源碼研讀

C++ new的使用場景

發佈了145 篇原創文章 · 獲贊 9 · 訪問量 7萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章