在使用tensorflow訓練CNN時,通常會通過將輸入設置成一個placeholder,如下所示:
image_dims = [224, 224, 3]
inputs = tf.placeholder(tf.float32, [None] + image_dims, name='input_images')
其中None所在位置,表示訓練網絡時,一個迭代需要用到多少個樣本,也是就batch_size,這一維可以不用明確指定,主要是方便後續測試採用不同的數目如[1, 224, 224, 3](一個樣本)。
但此時如果在網絡定義中使用了tf.reshape函數的話(全卷積層中用到),就很容易報錯,如下代碼
resh0 = tf.reshape(h0, [h0.shape[0], -1])
此時報錯TypeError: Expected binary or unicode string, got -1
需要將上述代碼改爲
resh0 = tf.reshape(h0, [-1, h0.get_shape().as_list()[1] * h0.get_shape().as_list()[2] *
h0.get_shape().as_list()[3]])
其中需要使用.as_list()將獲取到的shape轉換成list才行。
Reference
stackoverflow:Tensorflow reshape on convolution output gives TypeError