Pytorch模型部署 - Libtorch(crnn模型部署)

Pytorch模型部署 - Libtorch

簡介

libtorch是facebook提供的一套C++推理接口庫,便於工業級別部署和性能優化。

配置

  • cmake 3.0
  • libtorch-1.4(cpu)
  • opencv-4.1.1

安裝:

libtorch+opencv聯合編譯,這裏採用libtorch-1.4(cpu)+opencv-4.1.1.

  • 可能出現的問題

    • libtorch與opencv聯合編譯項目時,報錯Undefined reference to cv::imread(std::string const&, int).
    • 解決方案:
      • 在相同編譯環境下,重新編譯libtorch和opencv源碼.(未測試…)
      • 在opencv的CMakeLists.txt中加上add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)重新編譯opencv.(測試通過)
  • libtorch安裝:解壓下載包就可以,在代碼編譯時指定庫的路徑即可。

  • opencv安裝: 下載源碼 https://opencv.org/releases/

    # Build and install OpenCV 4.1.1 from the downloaded source archive.
    unzip opencv-4.1.1.zip
    cd opencv-4.1.1
    # vim CMakeLists.txt   # if the undefined-reference error described above occurs, add the add_definitions(...) line here, then rebuild and reinstall
    mkdir build && cd build
    cmake -D CMAKE_BUILD_TYPE=RELEASE -D OPENCV_GENERATE_PKGCONFIG=ON -D CMAKE_INSTALL_PREFIX=/usr/local ..
    make -j4
    sudo make install
    

    ls /usr/local/lib查看安裝好的opencv庫.

案例:libtorch部署crnn-英文識別模型.

crnn: 文本識別模型,常用於OCR.

Step1: 模型轉換

將pytorch訓練好的crnn模型轉換爲libtorch能夠讀取的模型.

# covertion.py
# Export a trained PyTorch CRNN checkpoint as a TorchScript file that libtorch can load.
# NOTE: CRNN, keys and model_path are defined in the full project; this excerpt omits them.
from collections import OrderedDict

import torch
import torchvision

model = CRNN(32, 1, len(keys.alphabetEnglish) + 1, 256, 1).cpu()

# Load the checkpoint onto the CPU regardless of the device it was saved from.
state_dict = torch.load(
    model_path, map_location=lambda storage, loc: storage)

# Strip the leading `module.` prefix from parameter names (presumably added by a
# DataParallel wrapper at training time) so they match the bare model.
prefix = 'module.'
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[len(prefix):] if k.startswith(prefix) else k
    new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)

# Tracing records the ops executed on the example input, so switch to inference
# mode first; otherwise training-mode behaviour of layers such as BatchNorm or
# Dropout would be baked into the exported graph.
model.eval()

# convert pth-model to pt-model
example = torch.rand(1, 1, 32, 512)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("src/crnn.pt")

代碼過長,github附完整代碼。github: crnn_libtorch

Step2: 模型部署

利用libtorch+opencv實現對文字條的識別.

//CrnnDeploy.h
#include <torch/torch.h>
#include <torch/script.h>
#include <opencv2/highgui.hpp>
#include <opencv2/imgproc.hpp>

#include <iostream>
#include <cassert>
#include <vector>

#ifndef CRNN_H
#define CRNN_H

/// CRNN text recognizer backed by a TorchScript module.
///
/// Construction loads the module and the key (character) table; loadImg()
/// turns an image file into an input tensor and infer() runs the module and
/// prints the decoded text to standard output.
class Crnn
{
public:
    /// @param modelFile path of the TorchScript model file.
    /// @param keyFile   path of the text file holding the recognizable characters.
    Crnn(std::string& modelFile, std::string& keyFile);

    /// Reads an image as grayscale, resizes it to a height of 32 pixels and
    /// returns it as a float tensor normalized with (x / 255 - 0.5) / 0.5.
    /// @param imgFile path of the image to load.
    /// @param isbath  when true the tensor has no leading batch dimension
    ///                (channel, height, width); otherwise a batch dimension
    ///                of size 1 is prepended.
    torch::Tensor loadImg(std::string& imgFile, bool isbath = false);

    /// Runs the model on `input` and prints one decoded line per image.
    void infer(torch::Tensor& input);

private:
    /// Splits the content of `keyFile` into one string per UTF-8 character.
    std::vector<std::string> readKeys(const std::string& keyFile);

    /// Loads the TorchScript module stored in `modelFile`.
    torch::jit::script::Module loadModule(const std::string& modelFile);

    torch::jit::script::Module m_module;  ///< loaded recognition model
    std::vector<std::string> m_keys;      ///< class index -> character
};

#endif//CRNN_H
/*
@author
date: 2020-03-17
Introduce:
    Deploy crnn model with libtorch.
*/

#include "CrnnDeploy.h"

#include <sys/time.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <exception>
#include <fstream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>

// Constructor: loads the TorchScript module from modelFile and the character
// table from keyFile. Neither helper reads member state, so both can run from
// the member initializer list.
Crnn::Crnn(std::string& modelFile, std::string& keyFile)
    : m_module(loadModule(modelFile)),
      m_keys(readKeys(keyFile))
{
}


/**
 * Loads an image as grayscale and converts it into a normalized float tensor.
 *
 * The image is resized to a height of 32 pixels, keeping the aspect ratio,
 * and each pixel is mapped with (x / 255 - 0.5) / 0.5.
 *
 * @param imgFile path of the image to read.
 * @param isbath  when true the result has shape (1, 32, W) with no batch
 *                dimension; otherwise the shape is (1, 1, 32, W).
 * @return the normalized float tensor.
 * @throws std::runtime_error if the image cannot be read or decoded.
 */
torch::Tensor Crnn::loadImg(std::string& imgFile, bool isbath){
    cv::Mat input = cv::imread(imgFile, 0);  // flag 0: single-channel grayscale
    if(input.empty()){
        // Continuing would divide by input.rows == 0 below, so fail fast.
        throw std::runtime_error(
            "Error: not image data, imgFile input wrong!! (" + imgFile + ")");
    }

    // Width that preserves the aspect ratio at height 32. Clamp to 1 so that
    // a very tall, narrow image does not request a zero-width resize.
    const int resizedWidth = std::max(1, input.cols * 32 / input.rows);
    cv::resize(input, input, cv::Size(resizedWidth, 32));

    // from_blob reads the buffer as one dense block, so make sure it is dense.
    if(!input.isContinuous()){
        input = input.clone();
    }

    // from_blob does not take ownership of the pixels. The float conversion
    // below allocates new storage, so the returned tensor stays valid after
    // `input` is destroyed.
    torch::Tensor imgTensor;
    if(isbath){
        imgTensor = torch::from_blob(input.data, {32, resizedWidth, 1}, torch::kByte);
        imgTensor = imgTensor.permute({2, 0, 1});       // HWC -> CHW
    }else{
        imgTensor = torch::from_blob(input.data, {1, 32, resizedWidth, 1}, torch::kByte);
        imgTensor = imgTensor.permute({0, 3, 1, 2});    // NHWC -> NCHW
    }

    imgTensor = imgTensor.toType(torch::kFloat);
    imgTensor = imgTensor.div(255);
    imgTensor = imgTensor.sub(0.5);
    imgTensor = imgTensor.div(0.5);
    return imgTensor;
}

/// Collapses a per-time-step class path into text: class 0 is skipped, a class
/// equal to the one at the previous time step is skipped, and every remaining
/// class is looked up in `keys`.
static std::string collapsePath(const std::vector<int64_t>& path,
                                const std::vector<std::string>& keys){
    std::string text;
    for(std::size_t t = 0; t < path.size(); ++t){
        const int64_t cls = path[t];
        if(cls == 0){
            continue;
        }
        if(t > 0 && path[t - 1] == cls){
            continue;
        }
        if(cls < 0 || static_cast<std::size_t>(cls) >= keys.size()){
            // The model predicted a class the key table does not cover;
            // indexing keys here would read out of bounds.
            std::cerr << "warning: predicted class " << cls
                      << " has no entry in the key table\n";
            continue;
        }
        text += keys[static_cast<std::size_t>(cls)];
    }
    return text;
}

/**
 * Runs the model on `input` and prints the recognized text, one line per image.
 *
 * The model output is indexed as (time step, image, class). For every image
 * the highest-scoring class at each time step is taken and the resulting path
 * is collapsed into text. A single image is handled as a batch of size one.
 *
 * @param input tensor passed unchanged to the module's forward method.
 */
void Crnn::infer(torch::Tensor& input){
    const torch::Tensor output = m_module.forward({input}).toTensor();
    if(output.dim() != 3){
        std::cerr << "Error: expected a 3-dimensional model output, got "
                  << output.dim() << " dimension(s)\n";
        return;
    }

    // Best class index over the class dimension, for all steps and images at
    // once. The index tensor is int64, so read it as integers.
    const torch::Tensor best =
        std::get<1>(output.max(2)).to(torch::kCPU).contiguous();
    const auto bestIdx = best.accessor<int64_t, 2>();
    const int64_t numSteps = best.size(0);
    const int64_t numImgs = best.size(1);

    std::vector<int64_t> path;
    path.reserve(static_cast<std::size_t>(numSteps));
    for(int64_t img = 0; img < numImgs; ++img){
        path.clear();
        for(int64_t step = 0; step < numSteps; ++step){
            path.push_back(bestIdx[step][img]);
        }
        std::cout << collapsePath(path, m_keys) << std::endl;
    }
}

/**
 * Reads the key file and splits its content into one string per UTF-8 character.
 *
 * Element 0 of the result is always a single space; the characters read from
 * the file follow from index 1 onward.
 *
 * @param keyFile path of the UTF-8 encoded key file.
 * @return the key table. If the file cannot be opened an error is written to
 *         std::cerr and the table contains only the leading space entry.
 */
std::vector<std::string> Crnn::readKeys(const std::string& keyFile){
    std::vector<std::string> words;
    // Original author's note: the first space gets filtered out, so it is added back here.
    words.push_back(" ");

    std::ifstream in(keyFile);
    if(!in){
        std::cerr << "error opening the key file: " << keyFile << "\n";
        return words;
    }
    std::ostringstream tmp;
    tmp << in.rdbuf();
    const std::string keys = tmp.str();

    std::size_t i = 0;
    while(i < keys.size()){
        // Inspect the lead byte as unsigned so the bit tests do not depend on
        // whether plain char is signed.
        const unsigned char lead = static_cast<unsigned char>(keys[i]);
        assert((lead & 0xF8) <= 0xF0);

        // The lead byte encodes how many bytes the character occupies.
        std::size_t next = 1;
        if((lead & 0xE0) == 0xC0){
            next = 2;
        }else if((lead & 0xF0) == 0xE0){
            next = 3;
        }else if((lead & 0xF8) == 0xF0){
            next = 4;
        }
        // substr clamps at the end of the string, so a truncated final
        // character cannot read out of bounds.
        words.push_back(keys.substr(i, next));
        i += next;
    }
    return words;
}

/**
 * Loads a TorchScript module from disk.
 *
 * @param modelFile path of the serialized TorchScript model.
 * @return the loaded module.
 * @throws c10::Error rethrown after logging when the file cannot be loaded,
 *         so callers never continue with an unusable module.
 */
torch::jit::script::Module Crnn::loadModule(const std::string& modelFile){
    try{
        return torch::jit::load(modelFile);
    }catch(const c10::Error& e){
        std::cerr << "error loading the model '" << modelFile << "': "
                  << e.what() << "\n";
        throw;
    }
}


/// Returns the current wall-clock time reported by gettimeofday(), in milliseconds.
long getCurrentTime(void)
{
    struct timeval now;
    gettimeofday(&now, nullptr);

    const long wholeSecondsMs = now.tv_sec * 1000;
    const long fractionMs = now.tv_usec / 1000;
    return wholeSecondsMs + fractionMs;
}

/**
 * Command-line entry point: recognizes the text in one image and prints it
 * together with the elapsed time.
 *
 * Arguments: <modelFile> <keyFile> <imgFile>.
 *
 * @return 0 on success, -1 on missing arguments or when loading or inference fails.
 */
int main(int argc, const char* argv[]){

    if(argc<4){
        printf("Error use CrnnDeploy: loss input param !!! \n");
        printf("usage: %s <modelFile> <keyFile> <imgFile>\n", argv[0]);
        return -1;
    }
    std::string modelFile = argv[1];
    std::string keyFile = argv[2];
    std::string imgFile = argv[3];

    try{
        const long t1 = getCurrentTime();
        {
            // Automatic storage: the recognizer is released even if loading
            // or inference throws, and teardown is still inside the timing.
            Crnn crnn(modelFile, keyFile);
            torch::Tensor input = crnn.loadImg(imgFile);
            crnn.infer(input);
        }
        const long t2 = getCurrentTime();

        printf("ocr time : %ld ms \n", (t2-t1));
    }catch(const std::exception& e){
        // Report the failure instead of terminating on an uncaught exception.
        fprintf(stderr, "Error: %s\n", e.what());
        return -1;
    }
    return 0;
}

完整代碼和測試模型:
github: crnn_libtorch

獲取代碼: git clone https://github.com/chenyangMl/crnn_libtorch.git

參考

  • opencv installation: https://docs.opencv.org/master/d7/d9f/tutorial_linux_install.html
  • libtorch : https://pytorch.org/tutorials/advanced/cpp_frontend.html
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章