torchvision.Transform.ToTensor()將圖片正確輸入網絡

ToTensor()描述如下:
在這裏插入圖片描述
它會將圖片從NHWC轉換爲NCHW且變爲tensor,並且通過除以255將圖片歸一化到(0,1)。
注意,通道的順序與你讀取圖片所用的工具有關:
PIL: (R,G,G)
cv2:(B,G,R)
例子如下:

import torch
from PIL import Image
import cv2

from torchvision import transforms
import numpy


img_PIL = Image.open("000001.jpg")
img_cv2 = cv2.imread("000001.jpg")
print(img_PIL.size)       ##  WH
print(img_cv2.shape)      ##  HWC

img_PIL_np = numpy.array(img_PIL) #轉爲numpy後,變爲CHW
print(img_PIL_np.shape)   ##    HWC


tran = transforms.ToTensor()### 注意用這種寫法
img_PIL_tensor = tran(img_PIL)
img_cv2_tensor = tran(img_cv2)

print(img_PIL_tensor.size())  #CHW (RGB)
print(img_cv2_tensor.size())  #CHW (BGR)

輸出結果:
(409, 687)
(687, 409, 3)
(687, 409, 3)
torch.Size([3, 687, 409])
torch.Size([3, 687, 409])

注意:
當使用PIL.Image.open()打開圖片後,如果要使用img.shape函數,需要先將image形式轉換成array數組
torchvision.Transform.ToTensor()這個工具在數據處理時還是非常方便的。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章