TensorRT in Practice (1): How to Build a Batch Normalization Layer
2020/01/15: added the bonus section on the BN1D and hswish implementations
This article builds the layer with the TensorRT Python API; doing the same with the C++ API works much the same way. The BN layer used here is the PyTorch one; Caffe and TF should not differ much. The TRT version used is 6.0.1.5, and both lower and higher versions should also work. See the TensorRT documentation linked in the references.
Batch Normalization in PyTorch
Take a look at PyTorch's definition of the BN layer, torch.nn.BatchNorm2d; the formula is given right in its docstring, or you can read the documentation directly:

$$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta$$

Simply put, $\mathrm{E}[x]$ is the batch mean, $\mathrm{Var}[x]$ is the batch variance, $\epsilon$ guards against division by zero, $\gamma$ is the weight learned over the batch, and $\beta$ is the bias. Correspondingly, in PyTorch any BN layer has the following structure:
import torch

weights = torch.load(your_model_dict_state_path)
bn_gamma = weights['bn.weight'].numpy()        # bn gamma
bn_beta = weights['bn.bias'].numpy()           # bn beta
bn_mean = weights['bn.running_mean'].numpy()   # bn running mean
bn_var = weights['bn.running_var'].numpy()     # bn running variance
The weights dict above is what torch.load() returns, and bn is simply the name of the BN layer you defined in your own model.
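As a quick sanity check of the formula (a minimal sketch with made-up layer sizes, not taken from any particular model), you can confirm that applying the stored statistics by hand reproduces the eval-mode output of nn.BatchNorm2d:

import torch
import torch.nn as nn

# Hypothetical 8-channel BN layer, switched to eval mode so running statistics are used
bn = nn.BatchNorm2d(8).eval()
x = torch.randn(1, 8, 4, 4)

mean = bn.running_mean.view(1, -1, 1, 1)
var = bn.running_var.view(1, -1, 1, 1)
gamma = bn.weight.view(1, -1, 1, 1)
beta = bn.bias.view(1, -1, 1, 1)

# y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta, applied per channel
manual = (x - mean) / torch.sqrt(var + bn.eps) * gamma + beta
print(torch.allclose(bn(x), manual, atol=1e-6))  # True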
TRT API Implementation
Since we already know the BN formula, we just implement it as written. Because the input x here is the output of a convolution, it is generally a 4-dimensional tensor, and the multiplication in BN is applied to that tensor channel by channel, so we use the IScaleLayer provided by the TRT API. The official documentation also mentions building BN from IElementWiseLayer, but that is unnecessarily complicated and not recommended.

The IScaleLayer documentation is linked in the references. It computes $output = (input \cdot scale + shift)^{power}$ and supports three modes; the one we need is trt.ScaleMode.CHANNEL. The code is as follows:
import numpy as np
import tensorrt as trt
import torch

weights = torch.load(your_model_dict_state_path)
bn_gamma = weights['bn.weight'].numpy()        # bn gamma
bn_beta = weights['bn.bias'].numpy()           # bn beta
bn_mean = weights['bn.running_mean'].numpy()   # bn running mean
bn_var = weights['bn.running_var'].numpy()     # bn running variance
eps = 1e-05

bn_var = np.sqrt(bn_var + eps)                 # standard deviation
bn_scale = bn_gamma / bn_var
bn_shift = - bn_mean / bn_var * bn_gamma + bn_beta
bn = network.add_scale(input=last_layer.get_output(0), mode=trt.ScaleMode.CHANNEL, shift=bn_shift, scale=bn_scale)
Here power is left unspecified, so it defaults to 1.
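To see that the scale/shift folding above is just the BN formula rearranged, here is a small NumPy check (channel count and data are invented for illustration):

import numpy as np

c = 8                                                  # hypothetical channel count
x = np.random.randn(c, 4, 4).astype(np.float32)        # a CHW activation, as TRT sees it
gamma, beta = np.random.randn(c), np.random.randn(c)
mean, var = np.random.randn(c), np.random.rand(c)
eps = 1e-05

# Reference: the BN formula applied per channel
ref = (x - mean[:, None, None]) / np.sqrt(var + eps)[:, None, None] \
      * gamma[:, None, None] + beta[:, None, None]

# Folded form: exactly what IScaleLayer computes with power = 1
std = np.sqrt(var + eps)
scale = gamma / std
shift = -mean / std * gamma + beta
folded = x * scale[:, None, None] + shift[:, None, None]

print(np.allclose(ref, folded, atol=1e-5))  # True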
fused Batch Normalization
Going further, the convolution layer and the BN layer can in fact be fused at inference time. Briefly, the convolution computes:

$$y = W \cdot x + b$$

Substituting this $y$ for the $x$ in the BN formula gives:

$$\mathrm{BN}(y) = \frac{\gamma}{\sqrt{\mathrm{Var} + \epsilon}} (W \cdot x + b - \mathrm{E}) + \beta = W_{fused} \cdot x + b_{fused}$$

These are of course matrix operations as well. $W_{fused} = \frac{\gamma}{\sqrt{\mathrm{Var} + \epsilon}} W$ is the new weight, and $b_{fused} = \frac{\gamma (b - \mathrm{E})}{\sqrt{\mathrm{Var} + \epsilon}} + \beta$ is the new bias.
The code is as follows:
import numpy as np
import tensorrt as trt
import torch

weights = torch.load(your_model_dict_state_path)
conv_w = weights['conv.weight'].numpy()        # conv weight
conv_b = weights['conv.bias'].numpy()          # conv bias
bn_gamma = weights['bn.weight'].numpy()        # bn gamma
bn_beta = weights['bn.bias'].numpy()           # bn beta
bn_mean = weights['bn.running_mean'].numpy()   # bn running mean
bn_var = weights['bn.running_var'].numpy()     # bn running variance
eps = 1e-05

bn_var = np.sqrt(bn_var + eps)                 # standard deviation
fused_conv_w = conv_w * (bn_gamma / bn_var).reshape([conv_w.shape[0], 1, 1, 1])
fused_conv_b = (conv_b - bn_mean) / bn_var * bn_gamma + bn_beta
fused_conv = network.add_convolution(input=last_layer.get_output(0), num_output_maps=your_conv_out, kernel_shape=(your_conv_kernel, your_conv_kernel), kernel=fused_conv_w, bias=fused_conv_b)
fused_conv.padding = (your_conv_pad, your_conv_pad)
fused_conv.stride = (your_conv_stride, your_conv_stride)
Here conv is the convolution layer to be fused and fused_conv is the convolution layer after fusion with bn; you need to give fused_conv the same parameters as conv (padding, stride, kernel_shape, num_output_maps).
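Before wiring this into an engine, the fusion can be verified numerically in PyTorch (a sketch; conv and bn below are hypothetical stand-ins for your own layers, not the layers built above):

import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)   # hypothetical conv layer
bn = nn.BatchNorm2d(16).eval()                       # eval mode: running stats are used
x = torch.randn(1, 3, 8, 8)

std = torch.sqrt(bn.running_var + bn.eps)
fused_w = conv.weight * (bn.weight / std).view(-1, 1, 1, 1)
fused_b = (conv.bias - bn.running_mean) / std * bn.weight + bn.bias

ref = bn(conv(x))                                    # conv followed by BN
fused = F.conv2d(x, fused_w, fused_b, padding=1)     # single fused convolution
print(torch.allclose(ref, fused, atol=1e-5))         # True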
References
[1] TensorRT API documentation: https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/python_api/index.html
[2] TensorRT Developer Guide: https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html
[3] PyTorch documentation: https://pytorch.org/docs/stable/
[4] Batch Normalization in PyTorch: https://www.cnblogs.com/yongjieShi/p/9332655.html
Bonus
TRT Implementation of BatchNorm1d
Likewise, following the BatchNorm2d approach, here we add a tensorrt.IShuffleLayer to reshape the 1D tensor into a CHW-style layout, apply BN on it, and finally reshape back to 1D. You need to specify the input tensor size, because TRT needs that parameter for the shuffle. The rough implementation code is as follows:
import numpy as np
import tensorrt as trt
import torch

weights = torch.load(your_model_dict_state_path)
bn_gamma = weights['bn.weight'].numpy()        # bn gamma
bn_beta = weights['bn.bias'].numpy()           # bn beta
bn_mean = weights['bn.running_mean'].numpy()   # bn running mean
bn_var = weights['bn.running_var'].numpy()     # bn running variance
eps = 1e-05

bn_var = np.sqrt(bn_var + eps)                 # standard deviation
bn_scale = bn_gamma / bn_var
bn_shift = - bn_mean / bn_var * bn_gamma + bn_beta
# reshape the 1D tensor to (your_input_shape, 1, 1) so the scale layer can work in CHANNEL mode
shuffle = network.add_shuffle(last_layer.get_output(0))
shuffle.reshape_dims = (your_input_shape, 1, 1)
# do bn1d
bn = network.add_scale(input=shuffle.get_output(0), mode=trt.ScaleMode.CHANNEL, shift=bn_shift, scale=bn_scale)
# reshape back to 1D
shuffle = network.add_shuffle(bn.get_output(0))
shuffle.reshape_dims = (your_input_shape,)
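The same scale/shift can be sanity-checked against nn.BatchNorm1d in eval mode (a sketch with an invented feature size):

import torch
import torch.nn as nn

bn1d = nn.BatchNorm1d(32).eval()      # hypothetical 32-feature BN1d layer
x = torch.randn(4, 32)                # (batch, features)

std = torch.sqrt(bn1d.running_var + bn1d.eps)
scale = bn1d.weight / std
shift = -bn1d.running_mean / std * bn1d.weight + bn1d.bias

print(torch.allclose(bn1d(x), x * scale + shift, atol=1e-6))  # True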
TRT Implementation of hswish
Refer to the PyTorch implementation of hswish:
import torch.nn as nn
import torch.nn.functional as F

class hswish(nn.Module):
    def forward(self, x):
        out = x * F.relu6(x + 3, inplace=True) / 6
        return out
So how is relu6 implemented? Its formula is:

$$\mathrm{relu6}(x) = \min(\max(0, x), 6)$$
From this we can write the following TRT implementation:
import numpy as np
import tensorrt as trt

# x + 3
shape = (1, ) * len(your_input_shape)                 # all-ones dims, broadcast over the input
tensor = 3.0 * np.ones(shape, dtype=np.float32)
trt_3 = network.add_constant(shape, tensor)
tmp = network.add_elementwise(last_layer.get_output(0), trt_3.get_output(0), trt.ElementWiseOperation.SUM)
# relu6(x + 3) = min(relu(x + 3), 6)
relu = network.add_activation(input=tmp.get_output(0), type=trt.ActivationType.RELU)
shape = (1, ) * len(your_input_shape)
tensor = 6.0 * np.ones(shape, dtype=np.float32)
trt_6 = network.add_constant(shape, tensor)
relu_6 = network.add_elementwise(relu.get_output(0), trt_6.get_output(0), trt.ElementWiseOperation.MIN)
# x * relu6(x + 3)
tmp = network.add_elementwise(last_layer.get_output(0), relu_6.get_output(0), trt.ElementWiseOperation.PROD)
# x * relu6(x + 3) / 6
out = network.add_elementwise(tmp.get_output(0), trt_6.get_output(0), trt.ElementWiseOperation.DIV)
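As a reference for what this layer graph should compute, hswish boils down to a simple clip expression, which is handy for checking the engine output (a NumPy sketch):

import numpy as np

def hswish_ref(x):
    # x * relu6(x + 3) / 6, where relu6(y) = min(max(y, 0), 6)
    return x * np.clip(x + 3.0, 0.0, 6.0) / 6.0

x = np.linspace(-5.0, 5.0, 11, dtype=np.float32)
print(hswish_ref(x))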