【小白學PyTorch】16 TF2讀取圖片的方法

【新聞】:機器學習煉丹術的粉絲的人工智能交流羣已經建立,目前有目標檢測、醫學圖像、NLP等多個學術交流分羣和水羣嘮嗑的總羣,歡迎大家加煉丹兄爲好友,加入煉丹協會。微信:cyx645016617.

參考目錄:

本文的代碼已經上傳公衆號後臺,回覆【PyTorch】獲取。

1 PIL讀取圖片

想要把一個圖片,轉換成RGB3通道的一個張量,我們怎麼做呢?大家第一反應應該是PIL這個庫

from PIL import Image
import numpy as np
image = Image.open('./bug1.jpg')
image.show()

展示的圖片:

然後我們這個image現在是PIL格式的,我們使用numpy.array()來將其轉換成numpy的張量的形式:

image = np.array(image)
print(image.shape)
>>>(326, 312, 3)

可以看到,這個第三維度是3。對於pytorch而言,數據的第一維度應該是樣本數量,第二維度是通道數,第三四是圖像的寬高,因此PIL讀入的圖片,往往需要把通道數的這個維度移動到第二維度上才能對接上pytorch的形式。(transpose方法來實現這個功能,這裏不細說)

2 TF讀取圖片

下面是重點啦,對於tensorflow,tf中自己帶了一個解碼函數,先看一下我的文件目錄:

import tensorflow as tf
images = tf.io.gfile.glob('./*.jpeg')
print(images,type(images))
> ['.\\bug1.jpeg', '.\\bug2.jpeg'] <class 'list'>

可以看出來:

  • 這個tensorflow.io.gfile.glob()是讀取路徑下的所有符合條件的文件,並且把路徑做成一個list返回;
  • 這個功能也可以用glob庫函數實現,我記得是glob.glob()方法;
  • 這裏的bug1和bug2其實是同一張圖片,都是上面的那個小兔子。
image = tf.io.read_file('./bug1.jpeg')
image = tf.image.decode_jpeg(image,channels=3)
print(image.shape,type(image))
> (326, 312, 3) <class 'tensorflow.python.framework.ops.EagerTensor'>

需要注意的是:

  • tf.io.read_file()這個得到的返回值是二進制格式,所以需要下面的tf.image.decode_jpeg進行一個解碼;
  • decode_jpeg的第一個參數就是讀取的二進制文件,然後channels是輸出的圖片的通道數,3就是RPB三個通道,如果是1的話,就是灰度圖片,ratio是圖片大小的一個縮小比例,默認是1,可以是2和4,一會看一下ratio=2的情況;
  • 這個image的type是一個tensorflow特別的Tensor的形式,而不是pytorch的那種tensor的形式了。
image = tf.io.read_file('./bug1.jpeg')
image = tf.image.decode_jpeg(image,channels=1,ratio=2)
print(image.shape,type(image))
> (163, 156, 1) <class 'tensorflow.python.framework.ops.EagerTensor'>

寬高都變成了原來的一半,然後通道數是1,都和預想的一樣。使用decode_jpeg等解碼函數得到的結果,是uint8的類型的,簡單地說就是整數,0到255範圍的。在對圖片進行操作的時候,我們需要將其標準化到0到1區間的,因此需要將其轉換成float32類型的。所以對上述代碼進行補充:

image = tf.io.read_file('./bug1.jpeg')
image = tf.image.decode_jpeg(image,channels=1,ratio=2)
print(image.shape,type(image))
image = tf.image.resize(image,[256,256]) # 統一圖片大小
image = tf.cast(image,tf.float32) # 轉換類型
image = image/255 # 歸一化
print(image)

從結果來看,數據類型已經改變:

3 TF構建數據集

下面是dataset更正式的寫法,關於TF2的問題,不要百度!百度到的都是TF1的解答,看的我暈死了,TF的API的結構真是不太友好。。。

def read_image(path):
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3, ratio=1)
    image = tf.image.resize(image, [256, 256])  # 統一圖片大小
    image = tf.cast(image, tf.float32)  # 轉換類型
    image = image / 255  # 歸一化
    return image
images = tf.io.gfile.glob('./*.jpeg')
dataset = tf.data.Dataset.from_tensor_slices(images)
AUTOTUNE = tf.data.experimental.AUTOTUNE
dataset = dataset.map(read_image,num_parallel_calls=AUTOTUNE)
dataset = dataset.shuffle(1).batch(1)
for a in dataset.take(2):
    print(a.shape)

代碼中需要注意的是:

  • glob獲取一個文件的list,本次就兩個文件名字,一個bug1.jpeg,一個bug2.jpeg;
  • tf.data.Dataset.from_tensor_slices()返回的就是一個tensorflow的dataset類型,可以簡單理解爲一個可迭代的list,並且有很多其他方法;
  • dataset.map就是用實現定義好的函數,對處理dataset中每一個元素,在上面代碼中是把路徑的字符串變成該路徑讀取的圖片張量,對圖片的預處理應該也在這部分進行吧;
  • dataset.shuffle就是亂序,.batch()就是把dataset中的元素組裝batch;
  • 在獲取dataset中的元素的時候,TF1中有什麼迭代器的定義啊,什麼iter,但是TF2不用這些,直接.take(num)就行了,這個num就是從dataset中取出來的batch的數量,也就是循環的次數吧。
  • AUTOTUNE = tf.data.experimental.AUTOTUNE 就是根據你的cpu的情況,自動判斷多線程的數量。
    上面代碼的輸出結果爲:
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章