TensorRT/parsers/caffe/caffeParser/caffeParser.cpp入口函數源碼研讀

前言

由於篇幅有限,筆者將 TensorRT/parsers/caffe/caffeParser/caffeParser.cpp 中的內容拆分成 TensorRT/parsers/caffe/caffeParser/caffeParser.cpp - parse源碼研讀TensorRT/parsers/caffe/caffeParser/caffeParser.cpp - parseBinaryProto源碼研讀TensorRT/parsers/caffe/caffeParser/caffeParser.cpp - parseNormalizeParam源碼研讀 等幾篇,本篇為第一篇。

TensorRT/parsers/caffe/caffeParser/caffeParser.cpp

/*
 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <iostream>

#include "caffeMacros.h"
#include "caffeParser.h"
#include "opParsers.h"
#include "parserUtils.h"
#include "readProto.h"
#include "binaryProtoBlob.h"
#include "google/protobuf/text_format.h"
#include "half.h"
#include "NvInferPluginUtils.h"

using namespace nvinfer1; //定義於多處,此處用的是include/NvInferPluginUtils.h
using namespace nvcaffeparser1; //定義於include/NvCaffeParser.h

CaffeParser::~CaffeParser()
{
    /*
    於TensorRT/parsers/caffe/caffeParser/caffeParser.h
    std::vector<void*> mTmpAllocs;
    */
    for (auto v : mTmpAllocs)
    {
        free(v);
    }
    /*
    於TensorRT/parsers/caffe/caffeParser/caffeParser.h
    std::vector<nvinfer1::IPluginV2*> mNewPlugins;
    */
    for (auto p : mNewPlugins)
    {
        if (p)
        {
            /*
            nvinfer1::IPluginV2::destroy
            宣告於TensorRT/include/NvInferRuntimeCommon.h
            virtual void destroy() TRTNOEXCEPT = 0;
            Destroy the plugin object. This will be called when the network, builder or engine is destroyed.
            */
            p->destroy();
        }
    }
    /*
    在parse(INetworkDefinition&, DataType, bool)中:
    mBlobNameToTensor = new (BlobNameToTensor);
    */
    delete mBlobNameToTensor;
}

/*
讀取buffer中的數據,
設定好mModel及mDeploy後,
填滿network並回傳mBlobNameToTensor這個字典,
如果建構過程都正常,最後會回傳mBlobNameToTensor,否則回傳nullptr
*/
const IBlobNameToTensor* CaffeParser::parseBuffers(const char* deployBuffer,
                                                   std::size_t deployLength,
                                                   const char* modelBuffer,
                                                   std::size_t modelLength,
                                                   INetworkDefinition& network,
                                                   DataType weightType)
{
    //設定CaffeParser成員變數std::shared_ptr<trtcaffe::NetParameter> mDeploy
    //宣告時是shared_ptr,到了這裡變成unique_ptr?
    mDeploy = std::unique_ptr<trtcaffe::NetParameter>(new trtcaffe::NetParameter);
    google::protobuf::io::ArrayInputStream deployStream(deployBuffer, deployLength);
    //從給定的ZeroCopyInputStream裡讀取並解析文字格式的protocol message,存到給定的Message物件當中
    if (!google::protobuf::TextFormat::Parse(&deployStream, mDeploy.get()))
    {
        RETURN_AND_LOG_ERROR(nullptr, "Could not parse deploy file");
    }

    //解析modelBuffer,設定CaffeParser類別的成員變數std::shared_ptr<trtcaffe::NetParameter> mModel
    if (modelBuffer)
    {
       //宣告時是shared_ptr,到了這裡變成unique_ptr?
        mModel = std::unique_ptr<trtcaffe::NetParameter>(new trtcaffe::NetParameter);
        google::protobuf::io::ArrayInputStream modelStream(modelBuffer, modelLength);
        google::protobuf::io::CodedInputStream codedModelStream(&modelStream);
        codedModelStream.SetTotalBytesLimit(modelLength, -1);

        if (!mModel->ParseFromCodedStream(&codedModelStream))
        {
            RETURN_AND_LOG_ERROR(nullptr, "Could not parse model file");
        }
    }
    
    /*
    設定好mModel及mDeploy後呼叫parse,
    用於填滿network並回傳mBlobNameToTensor這個字典,
    如果建構過程都正常,最後會回傳mBlobNameToTensor,否則回傳nullptr
    */
    return parse(network, weightType, modelBuffer != nullptr);
}

/*
讀取deploy檔及model檔中的數據,
設定好mModel及mDeploy後,
填滿network並回傳mBlobNameToTensor這個字典,
如果建構過程都正常,最後會回傳mBlobNameToTensor,否則回傳nullptr
*/
const IBlobNameToTensor* CaffeParser::parse(const char* deployFile,
                                            const char* modelFile,
                                            INetworkDefinition& network,
                                            DataType weightType)
{
    //因為是macro函數,所以結尾不加分號
    //檢查deployFile是否為nullptr,如果是,則回傳nullptr
    CHECK_NULL_RET_NULL(deployFile)

    // this is used to deal with dropout layers which have different input and output
    /*
    trtcaffe::NetParameter來自:
    TensorRT/parsers/caffe/proto/trtcaffe.proto裡的
    trtcaffe package下的NetParameter message
    */
    /*
    CaffeParser的成員變數
    std::shared_ptr<trtcaffe::NetParameter> mModel;
	為何到這裡變成unique_ptr?
    */
    mModel = std::unique_ptr<trtcaffe::NetParameter>(new trtcaffe::NetParameter);
	/*
	readBinaryProto 
	定義於TensorRT/parsers/caffe/caffeParser/readProto.h
	*/
    //將modelFile裡的權重讀到mModel.get()指標所指向的trtcaffe::NetParameter裡
    if (modelFile && !readBinaryProto(mModel.get(), modelFile, mProtobufBufferSize))
    {
        RETURN_AND_LOG_ERROR(nullptr, "Could not parse model file");
    }

	/*
	CaffeParser的成員變數
	std::shared_ptr<trtcaffe::NetParameter> mDeploy;
	為何到這裡變成unique_ptr?
	*/
	/*
	readTextProto 
	定義於TensorRT/parsers/caffe/caffeParser/readProto.h
	*/
	//將deployFile裡的模型架構讀到mDeploy.get()指標所指向的trtcaffe::NetParameter裡
    mDeploy = std::unique_ptr<trtcaffe::NetParameter>(new trtcaffe::NetParameter);
    if (!readTextProto(mDeploy.get(), deployFile))
    {
        RETURN_AND_LOG_ERROR(nullptr, "Could not parse deploy file");
    }

    /*
    設定好mModel及mDeploy後呼叫parse,
    用於填滿network並回傳mBlobNameToTensor這個字典,
    如果建構過程都正常,最後會回傳mBlobNameToTensor,否則回傳nullptr
    */
    return parse(network, weightType, modelFile != nullptr);
}

google::protobuf

用到了 google::protobuf::io::ArrayInputStreamgoogle::protobuf::io::CodedInputStream 等類別及 google::protobuf::TextFormat::ParseMessageLite::ParseFromCodedStream 等函數,詳見C++ google protobuf

RETURN_AND_LOG_ERROR,CHECK_NULL_RET_NULL

用到了RETURN_AND_LOG_ERRORCHECK_NULL_RET_NULL兩個函數,詳見TensorRT/parsers/caffe/caffeMacros.h源碼研讀

參考連結

C++ google protobuf

TensorRT/parsers/caffe/caffeMacros.h源碼研讀

TensorRT/parsers/caffe/caffeParser/caffeParser.cpp - parse源碼研讀

TensorRT/parsers/caffe/caffeParser/caffeParser.cpp - parseBinaryProto源碼研讀

TensorRT/parsers/caffe/caffeParser/caffeParser.cpp - parseNormalizeParam源碼研讀

發佈了145 篇原創文章 · 獲贊 9 · 訪問量 7萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章