在做語義分割的時候會經過讀取圖像的步驟,根據 TensorFlow 官方教程 我使用了 tf.data.Dataset
這個 API。
根據官方讀取圖像的例子,一開始我的代碼如下:
def load_image(filename, resized_shape):
'''
:param filename: 圖像文件名
:param resized_shape: 縮放後圖像大小
'''
image = tf.read_file(filename)
image = tf.image.decode_png(image)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize_images(
image, size=resized_shape, method=tf.image.ResizeMethod.AREA)
return image
因爲訓練圖像是 png
格式,因此使用了 decode_png
。
這裏還使用了 resize_images
函數,因爲在進行小範圍測試時將圖像縮小可以加快訓練速度。
至此都是沒有問題的,但是最近處理數據集的時候遇到了 jpg
格式的訓練圖像,之前看到 TensorFlow 有 decode_image
這個函數,好像可以自動判定圖像格式然後 decode。
但是使用了之後報錯了,這個錯誤是在 resize_images
的時候發生的:
ValueError: 'images' contains no shape.
根據 decode_image
的官方文檔:
Returns:
Tensor with type uint8 with shape [height, width, num_channels] for BMP
, JPEG, and PNG images and shape [num_frames, height, width, 3] for GIF images.
返回的 Tensor 是有形狀的,但是從調試中可以看到 shape 是 unknown 的,所以返回應該是沒有形狀,google 了一下也沒有發現能說清這個問題的,因此這個函數暫時用不了了。
不過我發現了一種解決辦法:
def load_image(filename, resized_shape):
'''
:param filename: 圖像文件名
:param resized_shape: 縮放後圖像大小
'''
image = tf.read_file(filename)
image = tf.cond(
tf.image.is_jpeg(image),
lambda: tf.image.decode_jpeg(image),
lambda: tf.image.decode_png(image))
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize_images(
image, size=resized_shape, method=tf.image.ResizeMethod.AREA)
return image
就是使用 tf.cond()
函數,通過 tf.image.is_jpeg(image)
判斷圖像是不是 jpg
格式,如果是,就執行 decode_jpeg
,如果不是,就執行 decode_png
。因爲在語義分割中,大部分訓練圖像都是 jpg
或 png
,很少會有其他格式的圖像,因此用一個條件就夠了。
其實 decode_image
函數裏面就是使用的 tf.cond
來判斷的,判斷之後 decode,然後再 convert_image_dtype
,至於爲什麼返回沒有 shape 我也不清楚。
還有一種方法是使用 tf.Tensor.set_shape ,我的代碼裏不方便使用這個方法, 所以就沒有嘗試。