TensorRT/parsers/caffe/caffeParser/caffeParser.cpp入口函數源碼研讀
前言
由於篇幅有限,筆者將 TensorRT/parsers/caffe/caffeParser/caffeParser.cpp 中的內容拆分成 TensorRT/parsers/caffe/caffeParser/caffeParser.cpp - parse源碼研讀、TensorRT/parsers/caffe/caffeParser/caffeParser.cpp - parseBinaryProto源碼研讀、TensorRT/parsers/caffe/caffeParser/caffeParser.cpp - parseNormalizeParam源碼研讀等幾篇。本篇為第一篇,介紹解構子以及 parseBuffers 與 parse(deployFile, modelFile, ...) 兩個入口函數。
TensorRT/parsers/caffe/caffeParser/caffeParser.cpp
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <limits>
#include "caffeMacros.h"
#include "caffeParser.h"
#include "opParsers.h"
#include "parserUtils.h"
#include "readProto.h"
#include "binaryProtoBlob.h"
#include "google/protobuf/text_format.h"
#include "half.h"
#include "NvInferPluginUtils.h"
using namespace nvinfer1; //定義於多處,此處用的是include/NvInferPluginUtils.h
using namespace nvcaffeparser1; //定義於include/NvCaffeParser.h
/*
 * Destructor: releases everything the parser still owns.
 *  - mTmpAllocs (declared in caffeParser.h as std::vector<void*>) holds raw
 *    buffers that are released with free().
 *  - mNewPlugins (declared in caffeParser.h as std::vector<nvinfer1::IPluginV2*>)
 *    holds plugin objects; each non-null entry is destroyed through
 *    IPluginV2::destroy() (declared in TensorRT/include/NvInferRuntimeCommon.h).
 *  - mBlobNameToTensor is created with `new` in
 *    parse(INetworkDefinition&, DataType, bool) and is deleted here.
 */
CaffeParser::~CaffeParser()
{
    // Temporary raw allocations.
    for (auto buffer : mTmpAllocs)
    {
        free(buffer);
    }

    // Plugins created while parsing; null slots are skipped.
    for (auto plugin : mNewPlugins)
    {
        if (plugin == nullptr)
        {
            continue;
        }
        plugin->destroy();
    }

    // Blob-name -> tensor dictionary built by parse().
    delete mBlobNameToTensor;
}
/*
 * Parse a Caffe network held in memory.
 *
 * deployBuffer / deployLength : text-format deploy description (prototxt);
 *                               must not be nullptr.
 * modelBuffer  / modelLength  : binary model (weights); may be nullptr, in which
 *                               case only the deploy description is parsed.
 * network                     : network definition filled in by parse().
 * weightType                  : forwarded unchanged to parse().
 *
 * After mDeploy (and, when modelBuffer is given, mModel) have been populated,
 * parse(network, weightType, hasModel) fills `network`. That call returns
 * mBlobNameToTensor on success and nullptr otherwise. This function also
 * returns nullptr, after logging an error, if a buffer is missing, too large
 * for protobuf, or cannot be parsed.
 */
const IBlobNameToTensor* CaffeParser::parseBuffers(const char* deployBuffer,
                                                   std::size_t deployLength,
                                                   const char* modelBuffer,
                                                   std::size_t modelLength,
                                                   INetworkDefinition& network,
                                                   DataType weightType)
{
    // Match the file-based parse() overload: a deploy description is mandatory.
    // Like there, CHECK_NULL_RET_NULL is a macro written without a trailing semicolon.
    CHECK_NULL_RET_NULL(deployBuffer)

    // protobuf's ArrayInputStream / CodedInputStream take `int` sizes. Reject
    // lengths that would not survive the conversion instead of narrowing them.
    constexpr std::size_t kMaxProtobufSize = static_cast<std::size_t>(std::numeric_limits<int>::max());
    if (deployLength > kMaxProtobufSize)
    {
        RETURN_AND_LOG_ERROR(nullptr, "Deploy buffer is too large");
    }

    // mDeploy is declared in the header as std::shared_ptr<trtcaffe::NetParameter>.
    // A std::shared_ptr can be assigned from a std::unique_ptr rvalue, which hands
    // ownership over to the shared_ptr, so this assignment is well-formed.
    mDeploy = std::unique_ptr<trtcaffe::NetParameter>(new trtcaffe::NetParameter);
    google::protobuf::io::ArrayInputStream deployStream(deployBuffer, static_cast<int>(deployLength));
    // Read and parse a text-format protocol message from the stream into *mDeploy.
    if (!google::protobuf::TextFormat::Parse(&deployStream, mDeploy.get()))
    {
        RETURN_AND_LOG_ERROR(nullptr, "Could not parse deploy file");
    }

    // Parse the binary model into mModel (also a shared_ptr; see note above).
    if (modelBuffer)
    {
        if (modelLength > kMaxProtobufSize)
        {
            RETURN_AND_LOG_ERROR(nullptr, "Model buffer is too large");
        }
        mModel = std::unique_ptr<trtcaffe::NetParameter>(new trtcaffe::NetParameter);
        google::protobuf::io::ArrayInputStream modelStream(modelBuffer, static_cast<int>(modelLength));
        google::protobuf::io::CodedInputStream codedModelStream(&modelStream);
        // Allow the coded stream to consume the whole buffer, so a large model is
        // not rejected by protobuf's default total-bytes limit. Newer protobuf
        // releases removed the two-argument (limit, warning threshold) overload.
#if GOOGLE_PROTOBUF_VERSION >= 3011000
        codedModelStream.SetTotalBytesLimit(static_cast<int>(modelLength));
#else
        codedModelStream.SetTotalBytesLimit(static_cast<int>(modelLength), -1);
#endif
        if (!mModel->ParseFromCodedStream(&codedModelStream))
        {
            RETURN_AND_LOG_ERROR(nullptr, "Could not parse model file");
        }
    }

    // With mDeploy (and optionally mModel) set, build the network.
    return parse(network, weightType, modelBuffer != nullptr);
}
/*
 * Parse a Caffe network from files on disk.
 *
 * deployFile : path to the text-format deploy description (prototxt). If it is
 *              nullptr, nullptr is returned immediately.
 * modelFile  : path to the binary model (weights); may be nullptr, in which case
 *              only the deploy description is read.
 * network    : network definition filled in by parse().
 * weightType : forwarded unchanged to parse().
 *
 * After populating mModel and mDeploy, this calls
 * parse(network, weightType, hasModel) to fill `network`. That call returns
 * mBlobNameToTensor on success and nullptr otherwise. This function also
 * returns nullptr, after logging an error, if either file cannot be read.
 */
const IBlobNameToTensor* CaffeParser::parse(const char* deployFile,
                                            const char* modelFile,
                                            INetworkDefinition& network,
                                            DataType weightType)
{
    // CHECK_NULL_RET_NULL is a macro, so it is written without a trailing semicolon.
    // It returns nullptr when deployFile is nullptr.
    CHECK_NULL_RET_NULL(deployFile)

    const bool hasModel = modelFile != nullptr;

    // this is used to deal with dropout layers which have different input and output
    // mModel (a std::shared_ptr<trtcaffe::NetParameter>, where trtcaffe::NetParameter
    // comes from TensorRT/parsers/caffe/proto/trtcaffe.proto) is replaced with a
    // fresh message every time, even when no model file is given.
    mModel.reset(new trtcaffe::NetParameter);
    // readBinaryProto (readProto.h) reads the weights stored in modelFile into *mModel.
    if (hasModel && !readBinaryProto(mModel.get(), modelFile, mProtobufBufferSize))
    {
        RETURN_AND_LOG_ERROR(nullptr, "Could not parse model file");
    }

    // readTextProto (readProto.h) reads the network architecture in deployFile into *mDeploy.
    mDeploy.reset(new trtcaffe::NetParameter);
    if (!readTextProto(mDeploy.get(), deployFile))
    {
        RETURN_AND_LOG_ERROR(nullptr, "Could not parse deploy file");
    }

    // With mModel and mDeploy set, build the network.
    return parse(network, weightType, hasModel);
}
google::protobuf
用到了 google::protobuf::io::ArrayInputStream、google::protobuf::io::CodedInputStream 等類別,以及 google::protobuf::TextFormat::Parse、google::protobuf::MessageLite::ParseFromCodedStream、google::protobuf::io::CodedInputStream::SetTotalBytesLimit 等函數,詳見C++ google protobuf。
RETURN_AND_LOG_ERROR,CHECK_NULL_RET_NULL
用到了RETURN_AND_LOG_ERROR
及CHECK_NULL_RET_NULL
兩個巨集(macro,並非一般函數),詳見TensorRT/parsers/caffe/caffeMacros.h源碼研讀。
參考連結
TensorRT/parsers/caffe/caffeMacros.h源碼研讀
TensorRT/parsers/caffe/caffeParser/caffeParser.cpp - parse源碼研讀
TensorRT/parsers/caffe/caffeParser/caffeParser.cpp - parseBinaryProto源碼研讀
TensorRT/parsers/caffe/caffeParser/caffeParser.cpp - parseNormalizeParam源碼研讀