Tensorflow Keras 中input_shape引發的維度順序衝突問題(NCHW與NHWC)

Tensorflow Keras 中input_shape引發的維度順序衝突問題(NCHW與NHWC)

原文鏈接:Tensorflow Keras 中input_shape引發的維度順序衝突問題(NCHW與NHWC)

tf.keras.Sequential構建卷積層爲例:

tf.keras.layers.Conv2D(10, 3, input_shape=(2, 9, 9),padding='same',activation=tf.nn.relu,kernel_initializer='glorot_normal',
        bias_initializer='glorot_normal'),

這是一個簡單的卷積層的定義,主要看input_shape參數:

這是用來指定卷積層輸入形狀的參數,由於Keras提供了兩套後端,Theano和Tensorflow,不同的後端使用時對該參數所指代的維度順序dim_ordering會有衝突。

  • Theano(th):

    • NCHW:順序是 [batch, in_channels, in_height, in_width]
  • Tensorflow(tf):keras默認使用這種方式

    • NHWC:順序是 [batch, in_height, in_width, in_channels]

即對於上述input_shape=(2, 9, 9)來說:我們先忽略batch,2會被解析爲通道數,矩陣大小爲9*9,符合我們預期。而tf會將矩陣大小解析爲2 * 9 ,且最後一位9代表通道數,與預期不符。

解決

法一:

在卷積層定義中加入參數來讓keras在兩種後端之間切換:

  • data_format='channels_first':代表th
  • data_format='channels_last':代表tf

但是該法在某些時候不成功會報錯:

或許是cpu電腦導致的,只支持NHWC即tf模式。

只能修改相應文件的配置來使其支持NCHW,參考這裏:https://github.com/balancap/SSD-Tensorflow/issues/226


法二:(推薦)

使用tf.transpose函數進行高維數據的轉置(維度大於2,軸的轉換)

如將上述(2,9,9)轉爲(9,9,2)並且是以2爲通道數,即矩陣爲9*9,而不是像reshape函數簡單的調整維度,若使用reshape函數來轉換,只會得到通道數爲9,矩陣爲9 * 2的數據。

tf.transpose(待轉矩陣,(1,2,0))

解釋:

​ 其中0,1,2…是原矩陣維度從左到右軸的標號,即(2,9,9)中三個維度分別對應標號0,1,2。而調整過後將標號順序變爲1,2,0 即是把表通道數的軸置於最後,這樣轉置後的矩陣就滿足了keras的默認tf後端。即可正常訓練。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章