Tensorflow Keras 中input_shape引發的維度順序衝突問題(NCHW與NHWC)
原文鏈接:Tensorflow Keras 中input_shape引發的維度順序衝突問題(NCHW與NHWC)
以tf.keras.Sequential
構建卷積層爲例:
tf.keras.layers.Conv2D(10, 3, input_shape=(2, 9, 9),padding='same',activation=tf.nn.relu,kernel_initializer='glorot_normal',
bias_initializer='glorot_normal'),
這是一個簡單的卷積層的定義,主要看input_shape參數:
這是用來指定卷積層輸入形狀的參數,由於Keras提供了兩套後端,Theano和Tensorflow,不同的後端使用時對該參數所指代的維度順序dim_ordering會有衝突。
Theano(th):
- NCHW:順序是 [batch, in_channels, in_height, in_width]
Tensorflow(tf):keras默認使用這種方式
- NHWC:順序是 [batch, in_height, in_width, in_channels]
即對於上述input_shape=(2, 9, 9)
來說:我們先忽略batch,2會被解析爲通道數,矩陣大小爲9*9,符合我們預期。而tf會將矩陣大小解析爲2 * 9 ,且最後一位9代表通道數,與預期不符。
解決
法一:
在卷積層定義中加入參數來讓keras在兩種後端之間切換:
data_format='channels_first'
:代表thdata_format='channels_last
':代表tf
但是該法在某些時候不成功會報錯:
或許是cpu電腦導致的,只支持NHWC即tf模式。
只能修改相應文件的配置來使其支持NCHW,參考這裏:https://github.com/balancap/SSD-Tensorflow/issues/226
法二:(推薦)
使用tf.transpose函數進行高維數據的轉置(維度大於2,軸的轉換)
如將上述(2,9,9)轉爲(9,9,2)並且是以2爲通道數,即矩陣爲9*9,而不是像reshape函數簡單的調整維度,若使用reshape函數來轉換,只會得到通道數爲9,矩陣爲9 * 2的數據。
tf.transpose(待轉矩陣,(1,2,0))
解釋:
其中0,1,2…是原矩陣維度從左到右軸的標號,即(2,9,9)中三個維度分別對應標號0,1,2。而調整過後將標號順序變爲1,2,0 即是把表通道數的軸置於最後,這樣轉置後的矩陣就滿足了keras的默認tf後端。即可正常訓練。