TensorRT INT8量化原理以及如何編寫校準器類進行校準

上一篇博客中介紹了從Pytorch模型到ONNX中間格式文件再到TensorRT推理引擎的整個過程,其中在進行INT8格式的轉換時,需要額外的工作來做,這篇博客就針對INT8轉換的具體過程及如何準備校準集、編寫校準器進行詳細介紹。

同時,使用TensorRT進行INT8量化的過程也分享到了GitHub,歡迎大家參考。

目錄

1、INT8量化過程

2、編寫校準器,並進行INT8量化


1、INT8量化過程

衆所周知,一個訓練好的深度學習模型,其數據包含了權重(weights)和偏移(biases)兩部分,在其進行前向推理(forward)時,中間會根據權重和偏移產生激活值(activation)。

關於INT8的量化原理,這篇回答非常詳盡,我這裏就簡單說結論了:

  • TensorRT在進行INT8量化時,對權重直接使用了最大值量化,對偏移直接忽略,對前向計算中的激活值的量化是重點;
  • 對激活值進行INT8量化採用飽和量化:因爲激活值通常分佈不均勻,直接使用非飽和量化會使得量化後的值都擠在一個很小的範圍從而浪費了INT8範圍內的其他空間,也就是說沒有充分利用INT8(-128~+127)的值域;而進行飽和量化後,使得映射後的-128~+127範圍內分佈相對均勻,這相當於去掉了一些不重要的因素,保留了主要成分。
圖1. 直接忽略bias

圖1告訴我們,直接忽略bias就完事了,這是官方給出的實驗結論。 

圖1. 非飽和量化(左圖)和飽和量化(右圖)

圖2告訴我們權重沒必要使用飽和映射,因爲沒啥提高,而激活值使用飽和映射能調高性能,這好理解,因爲權重通常分別較爲均勻直接最大值非飽和映射和費勁力氣找閾值再進行飽和映射,其量化後的分佈很可能是極其相似的,而激活值分佈不均,尋找一個合適的閾值進行飽和映射就顯得比較重要了;並展示了直接使用最大值量化到INT8和選擇一個合適的閾值後飽和地量化到INT的區別,可以看出:右圖的關鍵在於選擇一個合適的閾值T,來對原來的分佈進行一個截取,將-T~+T之間的值映射到-128~+127,而>T和<-T的值則忽略掉。

如何尋找這個閾值T就成了INT量化的關鍵

圖3. 各模型激活值分佈

圖3可以看出,不同模型的激活值分佈差異很大,這就需要進行動態的量化,也即針對每一個模型,尋找一個對它來說最合適的T。

於是,NVIDIA就選擇了KL散度也即相對熵來對量化前後的激活值分佈進行評價,來找出使得量化後INT8分佈相對於原來的FP32分佈信息損失最小的那個閾值。如圖4所示:

圖4. 相對熵:KL散度

 於是,整個的量化過程就給了出來,如圖5所示:

圖5. INT8量化校準過程

意思就是:

  • 先在一個校準數據集上跑一遍原FP32的模型;
  • 然後,對每一層都收集激活值的直方圖,並生成在不同閾值下的飽和量化分佈;
  • 最後,找出使得KL散度最小的那個閾值T,即爲所求。

這個過程同時也告訴了我們,要做INT8量化,需要準備哪些東西——原來的未量化的模型(廢話,沒有原模型拿什麼量化!)、一個校準數據集進行量化過程的校準器。如圖6所示:

圖6. TensorRT的INT8工作流程

圖6可以看出,校準過程我們是不用參與的,全部都由TensorRT內部完成,但是,我們需要告訴校準器如何獲取一個batch的數據,也就是說,我們需要重寫校準器類中的一些方法。下面,我們就開始介紹如何繼承原校準器類並重寫其中的部分方法,來獲取我們自己的數據集來校準我們自己的模型。

2、編寫校準器,並進行INT8量化

我們需要繼承父類——trt.IInt8EntropyCalibrator2,並重寫他的一些方法:get_batch_size, get_batch, read_calibration_cache, write_calibration_cache。

這些方法分別是:獲取batch大小、獲取一個batch的數據、將校準集寫入緩存、從緩存讀出校準集。前兩個是必須的,不然校準器不知道用什麼數據來校準,後兩個方法可以忽略,但當你需要多次嘗試時,後兩個方法將很有用,它們會大大減少數據讀取的時間!

下面給出我寫的一個讀取我自己的數據集的校準器示例,完整工程可參考GitHub

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

import os
import numpy as np
from PIL import Image

import torchvision.transforms as transforms

class CenterNetEntropyCalibrator(trt.IInt8EntropyCalibrator2):

    def __init__(self, args, files_path='/home/user/Downloads/datasets/train_val_files/val.txt'):
        trt.IInt8EntropyCalibrator2.__init__(self)

        self.cache_file = 'CenterNet.cache'

        self.batch_size = args.batch_siz
        self.Channel = args.channel
        self.Height = args.height
        self.Width = args.width
        self.transform = transforms.Compose([
            transforms.Resize([self.Height, self.Width]),  # [h,w]
            transforms.ToTensor(),
        ])

        self._txt_file = open(files_path, 'r')
        self._lines = self._txt_file.readlines()
        np.random.shuffle(self._lines)
        self.imgs = [os.path.join('/home/user/Downloads/datasets/train_val_files/images',
                                  line.rstrip() + '.jpg') for line in self._lines]
        self.batch_idx = 0
        self.max_batch_idx = len(self.imgs)//self.batch_size
        self.data_size = trt.volume([self.batch_size, self.Channel,self.Height, self.Width]) * trt.float32.itemsize
        self.device_input = cuda.mem_alloc(self.data_size)

    def next_batch(self):
        if self.batch_idx < self.max_batch_idx:
            batch_files = self.imgs[self.batch_idx * self.batch_size:\
                                    (self.batch_idx + 1) * self.batch_size]
            batch_imgs = np.zeros((self.batch_size, self.Channel, self.Height, self.Width),
                                  dtype=np.float32)
            for i, f in enumerate(batch_files):
                img = Image.open(f)
                img = self.transform(img).numpy()
                assert (img.nbytes == self.data_size/self.batch_size), 'not valid img!'+f
                batch_imgs[i] = img
            self.batch_idx += 1
            print("batch:[{}/{}]".format(self.batch_idx, self.max_batch_idx))
            return np.ascontiguousarray(batch_imgs)
        else:
            return np.array([])

    def get_batch_size(self):
        return self.batch_size

    def get_batch(self, names, p_str=None):
        try:
            batch_imgs = self.next_batch()
            if batch_imgs.size == 0 or batch_imgs.size != self.batch_size*self.Channel*self.Height*self.Width:
                return None
            cuda.memcpy_htod(self.device_input, batch_imgs.astype(np.float32))
            return [int(self.device_input)]
        except:
            return None

    def read_calibration_cache(self):
        # If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)

上述代碼中,你需要改動的並不多,只需要根據你的數據集存放路徑及格式,讀取一個batch即可。需要注意的是,讀取的一個batch數據,數據類型是np.ndarray,shape爲[batch_size, C, H, W],也即[batch大小, 通道, 高, 寬]。

OK,現在編寫好了校準器,那麼如何進行INT量化呢?這一步,上一篇博客已經介紹過了,這裏就不多說了,僅給出示例代碼,直接看也很清晰,解釋可以看上篇博客

def ONNX2TRT(args, calib=None):
    ''' convert onnx to tensorrt engine, use mode of ['fp32', 'fp16', 'int8']
    :return: trt engine
    '''

    assert args.mode.lower() in ['fp32', 'fp16', 'int8'], "mode should be in ['fp32', 'fp16', 'int8']"

    G_LOGGER = trt.Logger(trt.Logger.WARNING)
    with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, \
            trt.OnnxParser(network, G_LOGGER) as parser:

        builder.max_batch_size = args.batch_size
        builder.max_workspace_size = 1 << 30
        if args.mode.lower() == 'int8':
            assert (builder.platform_has_fast_int8 == True), "not support int8"
            builder.int8_mode = True
            builder.int8_calibrator = calib
        elif args.mode.lower() == 'fp16':
            assert (builder.platform_has_fast_fp16 == True), "not support fp16"
            builder.fp16_mode = True

        print('Loading ONNX file from path {}...'.format(args.onnx_file_path))
        with open(args.onnx_file_path, 'rb') as model:
            print('Beginning ONNX file parsing')
            parser.parse(model.read())
        print('Completed parsing of ONNX file')

        print('Building an engine from file {}; this may take a while...'.format(args.onnx_file_path))
        engine = builder.build_cuda_engine(network)
        print("Created engine success! ")

        # 保存計劃文件
        print('Saving TRT engine file to path {}...'.format(args.engine_file_path))
        with open(args.engine_file_path, "wb") as f:
            f.write(engine.serialize())
        print('Engine file has already saved to {}!'.format(args.engine_file_path))
        return engine

 

參考:

NVIDIA TensorRT量化介紹PPT——8-bit Inference with TensorRT:

http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf

知乎上@章小龍的Int8量化-介紹:

https://zhuanlan.zhihu.com/p/58182172

我整理的INT8量化GitHub工程:

https://github.com/qq995431104/Pytorch2TensorRT.git

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章