Pytorch:將圖像tensor數據用Opencv顯示

*Pytorch:將圖像tensor數據用Opencv顯示

首先導入相關庫:*

import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2

利用PIL中的Image打開一張圖片

image2=Image.open('pikachu.jpg')

這裏print看一下image2的圖像數據類型,這裏可以直接調用image2.show()直接顯示:
在這裏插入圖片描述

print(image2)
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=814x982 at 0x1E988BA74A8>

將image2轉化爲tensor數據(爲什麼轉化爲tensor,當然是爲了方便計算)

transform2=transforms.Compose([transforms.ToTensor()])
tensor2=transform2(image2)
print('tensor2:',tensor2)#打印看一下tensor的數據
print(tensor2.dtype)#torch.float32
print(tensor2.shape)#返回tensor2_shape torch.Size([3, 982, 814])->3通道982*814的RGB圖像
tensor2: tensor([[[0.5647, 0.5686, 0.5686,  ..., 0.5725, 0.5725, 0.5686],
         [0.5647, 0.5686, 0.5725,  ..., 0.5725, 0.5725, 0.5686],
         [0.5686, 0.5725, 0.5725,  ..., 0.5725, 0.5725, 0.5686],
         ...,
         [0.3176, 0.3216, 0.3216,  ..., 0.3098, 0.3098, 0.3098],
         [0.3176, 0.3216, 0.3216,  ..., 0.3098, 0.3098, 0.3098],
         [0.3176, 0.3216, 0.3216,  ..., 0.3137, 0.3137, 0.3137]],

        [[0.7412, 0.7451, 0.7451,  ..., 0.7490, 0.7490, 0.7451],
         [0.7412, 0.7451, 0.7490,  ..., 0.7490, 0.7490, 0.7451],
         [0.7451, 0.7490, 0.7490,  ..., 0.7490, 0.7490, 0.7451],
         ...,
         [0.5529, 0.5569, 0.5569,  ..., 0.5451, 0.5451, 0.5451],
         [0.5529, 0.5569, 0.5569,  ..., 0.5451, 0.5451, 0.5451],
         [0.5529, 0.5569, 0.5569,  ..., 0.5490, 0.5490, 0.5490]],

        [[0.9059, 0.9098, 0.9098,  ..., 0.9098, 0.9098, 0.9059],
         [0.9059, 0.9098, 0.9137,  ..., 0.9098, 0.9098, 0.9059],
         [0.9098, 0.9137, 0.9137,  ..., 0.9098, 0.9098, 0.9059],
         ...,
         [0.8275, 0.8314, 0.8314,  ..., 0.8275, 0.8275, 0.8275],
         [0.8275, 0.8314, 0.8314,  ..., 0.8275, 0.8275, 0.8275],
         [0.8275, 0.8314, 0.8314,  ..., 0.8314, 0.8314, 0.8314]]])

要將tensor圖像數據轉爲opencv支持的圖像數據,首先要了解opencv所支持的圖像數據:

image3=cv2.imread('pokeman/pikachu/00000000.jpg')
print(image3)
[[[231 189 144]
  [232 190 145]
  [232 190 145]
  ...
  [232 191 146]
  [232 191 146]
  [231 190 145]]]
print(image3.shape)
(982, 814, 3)
print(type(image3))
<class 'numpy.ndarray'>
print(image3.dtype)
uint8

所以我們知道opencv支持的圖像數據時numpy格式,數據類型爲uint8,而且像素值分佈在[0,255]之間。 但是從上面的tensor數據可以看出,像素值並不是分佈在[0,255],且數據類型爲float32,所以需要做一下normalize和數據變換,將圖像數據擴展到[0,255]。還有一點不同的是tensor(3,982, 814)、numpy(982, 814, 3)存儲的數據維度順序不同。

array1=tensor2.numpy()#將tensor數據轉爲numpy數據
maxValue=array1.max()
array1=array1*255/maxValue#normalize,將圖像數據擴展到[0,255]
mat=np.uint8(array1)#float32-->uint8
print('mat_shape:',mat.shape)#mat_shape: (3, 982, 814)
mat=mat.transpose(1,2,0)#mat_shape: (982, 814,3)
cv2.imshow("img",mat)
cv2.waitKey()

在這裏插入圖片描述
這是由於opencv中的顏色通道順序是BGR而PIL、torch裏面的圖像顏色通道是RGB,利用cvtColor對顏色通道進行轉換


mat=cv2.cvtColor(mat,cv2.COLOR_BGR2RGB)
cv2.imshow("img",mat)
cv2.waitKey()

在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章