TensorFlow中CNN的兩種padding方式“SAME”和“VALID”

原文鏈接:http://blog.csdn.net/wuzqchom/article/details/74785643

在用tensorflow寫CNN的時候,調用卷積核api的時候,會有填padding方式的參數,找到源碼中的函數定義如下(max pooling也是一樣):

def conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None,
           data_format=None, name=None)


源碼中對於padding參數的說明如下:

padding: A `string` from: `"SAME", "VALID"`.
      The type of padding algorithm to use.


源碼中說明padding可以用SAME和VALID兩種方式,但是對於這兩種方式具體是什麼並沒有多加說明。
這裏用Stack Overflow中的一份代碼來簡單說明一下,代碼如下:i

import tensorflow as tf

x = tf.constant([[1., 2., 3.],
                 [4., 5., 6.]])
x = tf.reshape(x, [1, 2, 3, 1])  # give a shape accepted by tf.nn.max_pool

valid_pad = tf.nn.max_pool(x, [1, 2, 2, 1], [1, 2, 2, 1], padding='VALID')
same_pad = tf.nn.max_pool(x, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME')

print(valid_pad.get_shape())
print(same_pad.get_shape())


最後輸出的結果爲:

(1, 1, 1, 1)
(1, 1, 2, 1)


可以看出SAME的填充方式是比VALID的填充方式多了一列。
讓我們來看看變量xxx是一個2×32\times32×3的矩陣,max pooling窗口爲2×22\times22×2,兩個維度的步長strides=2strides=2strides=2。
第一次由於窗口可以覆蓋,橙色區域做max pooling,沒什麼問題,如下:

                                                                  
接下來就是SAME和VALID的區別所在:由於步長爲2,當向右滑動兩步之後,VALID方式發現餘下的窗口不到2×22\times22×2所以直接將第三列捨棄,而SAME方式並不會把多出的一列丟棄,但是隻有一列了不夠2×22\times22×2怎麼辦?填充!

                                                                
如上圖所示,SAME會增加第四列以保證可以達到2×2,但爲了不影響原始信息,一般以0來填充。這就不難理解不同的padding方式輸出的形狀會有所不同了。

當CNN用於文本中時,一般卷積層設置卷積核的大小爲n×k,其中k爲輸入向量的維度(即[n,k,input_channel_num,output_channel_num]),這時候我們就需要選擇“VALID”填充方式,這時候窗口僅僅是沿着一個維度掃描而不是兩個維度。可以理解爲統計語言模型當中的N-gram。我們設計網絡結構時需要設置輸入輸出的shape,源碼nn_ops.py中的convolution函數和pool函數給出的計算公式如下:

 If padding == "SAME":
      output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])

    If padding == "VALID":
      output_spatial_shape[i] =
        ceil((input_spatial_shape[i] -
              (spatial_filter_shape[i]-1) * dilation_rate[i])
              / strides[i]).


dilation_rate爲一個可選的參數,默認爲1,這裏我們先不管。
整理一下,對於VALID,輸出的形狀計算如下:

                                                  

對於SAME,輸出的形狀計算如下:

​    

其中,W爲輸入的size,F爲filter的size,S爲步長,⌈ ⌉爲向上取整符號。

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章