PyTorch 數據處理

# -*- coding: utf-8 -*-
import argparse
import os
import copy

import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from tqdm import tqdm
import numpy as np
from torch import nn

from models import FSRCNN, Discriminator
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr
import PIL.Image as Image
import matplotlib.pyplot as plt

# print('pid:{}   GPU:{}'.format(os.getpid(), os.environ['CUDA_VISIBLE_DEVICES']))
def convert_rgb_to_ycbcr(img, dim_order='hwc'):
    """Convert an RGB image (values in 0~255) to YCbCr using the BT.601 weights.

    :param img: array-like RGB image; channels on the last axis when
        ``dim_order`` is ``'hwc'``, otherwise channels on the first axis
    :param dim_order: ``'hwc'`` for channels-last input; any other value is
        treated as channels-first
    :return: ``np.ndarray`` with the Y, Cb and Cr planes on the LAST axis,
        regardless of the input layout (the stacked planes are always
        transposed with ``[1, 2, 0]``)

    NOTE: the red luma weight used to be ``64.738``; the BT.601 value is
    ``65.738`` (the three luma weights must sum to 219.859 so that white maps
    to Y = 235).  Results produced with the old constant differ by up to one
    grey level.
    """
    if dim_order == 'hwc':
        r, g, b = img[..., 0], img[..., 1], img[..., 2]
    else:
        r, g, b = img[0], img[1], img[2]

    y = 16. + (65.738 * r + 129.057 * g + 25.064 * b) / 256.
    cb = 128. + (-37.945 * r - 74.494 * g + 112.439 * b) / 256.
    cr = 128. + (112.439 * r - 94.154 * g - 18.285 * b) / 256.
    # Stack as (3, H, W), then move the channel axis to the end -> (H, W, 3).
    return np.array([y, cb, cr]).transpose([1, 2, 0])
def convert_rgb_to_y(img, dim_order='hwc'):
    """Return the BT.601 luma (Y) plane of an RGB image with values in 0~255.

    :param img: array-like RGB image; channels on the last axis when
        ``dim_order`` is ``'hwc'``, otherwise channels on the first axis
    :param dim_order: ``'hwc'`` for channels-last input; any other value is
        treated as channels-first
    :return: the Y plane, with the channel axis removed

    NOTE: the red weight used to be ``64.738``; the BT.601 value is ``65.738``
    (the three weights must sum to 219.859 so that white maps to Y = 235).
    Results produced with the old constant differ by up to one grey level.
    """
    if dim_order == 'hwc':
        r, g, b = img[..., 0], img[..., 1], img[..., 2]
    else:
        r, g, b = img[0], img[1], img[2]
    return 16. + (65.738 * r + 129.057 * g + 25.064 * b) / 256.

class MyDataset(Dataset):
    """Dataset of random (HR, LR) luminance patch pairs for super-resolution.

    Every PNG file found under ``data_dir`` is one sample.  Each access loads
    the image, takes a random ``crop_size`` x ``crop_size`` crop, derives the
    low-resolution patch by bicubic down-sampling with factor ``scale`` and
    returns the Y channel of both patches scaled by 1/255.
    """

    def __init__(self, data_dir: str, crop_size: int, scale: int = 3):
        """
        :param data_dir: str, directory searched recursively for PNG images
        :param crop_size: int, side length in pixels of the random square crop
        :param scale: int, down-sampling factor between the HR and LR patches
        """
        super().__init__()
        # Not used by this class; kept so code reading the attribute keeps working.
        self.label_name = {"1": 0, "100": 1}
        # Image paths; samples are looked up here by index in __getitem__.
        self.data_info = self.get_img_info(data_dir)
        self.crop_size = crop_size
        self.scale = scale

    def __getitem__(self, index: int):
        """Return the ``(hr, lr)`` pair for the image at ``index``.

        :param index: int, position in ``self.data_info``
        :return: tuple ``(hr, lr)`` of 2-D float arrays holding the Y channel
            divided by 255; ``lr`` is ``scale`` times smaller than ``hr``
        :raises ValueError: if the image is smaller than ``crop_size``
        """
        path_img = self.data_info[index]
        # The context manager closes the file handle; convert() returns an
        # independent in-memory copy, so the image stays usable afterwards.
        with Image.open(path_img) as source:
            img = source.convert("RGB")     # 0~255

        left, top = self._random_crop_origin(img.width, img.height, self.crop_size)
        img = img.crop((left, top, left + self.crop_size, top + self.crop_size))

        # Trim the HR size to a multiple of scale so the LR size is exact.
        hr_width = (img.width // self.scale) * self.scale
        hr_height = (img.height // self.scale) * self.scale
        hr = img.resize((hr_width, hr_height), resample=Image.BICUBIC)
        # Down-sample from the HR patch itself so the pair is consistent even
        # when crop_size is not a multiple of scale.
        lr = hr.resize((hr_width // self.scale, hr_height // self.scale), resample=Image.BICUBIC)

        hr = convert_rgb_to_y(np.array(hr).astype(np.float32))
        lr = convert_rgb_to_y(np.array(lr).astype(np.float32))

        return hr / 255., lr / 255.

    def __len__(self) -> int:
        """Return the number of samples (image files found)."""
        return len(self.data_info)

    @staticmethod
    def _random_crop_origin(width, height, crop_size):
        """Pick the top-left corner of a random crop_size square inside the image.

        :raises ValueError: if the image is smaller than the crop in either dimension
        """
        if crop_size > width or crop_size > height:
            raise ValueError(
                "crop_size {} exceeds image size {}x{}".format(crop_size, width, height)
            )
        # randint's upper bound is exclusive; the +1 makes the last valid
        # offset reachable and allows images exactly crop_size wide or high.
        left = np.random.randint(0, width - crop_size + 1)
        top = np.random.randint(0, height - crop_size + 1)
        return left, top

    @staticmethod
    def get_img_info(data_dir):
        """Collect the paths of all PNG files below ``data_dir``.

        :param data_dir: str, directory to walk recursively
        :return: sorted list of file paths whose extension is ``.png``
            (case-insensitive); empty if nothing matches
        """
        data_info = []
        for root, _dirs, files in os.walk(data_dir):
            for file in files:
                if file.lower().endswith('.png'):
                    data_info.append(os.path.join(root, file))
        # os.walk order depends on the file system; sort so that a given
        # index always refers to the same file.
        data_info.sort()
        return data_info


if __name__ == '__main__':
    # Demo: build the dataset and loader, then display the first HR patch of
    # every batch.  No torchvision transform is applied here; the random crop
    # is done by MyDataset itself.
    BATCH_SIZE = 2

    split_dir = "./DIV2K_train_HR"
    train_dir = os.path.join(split_dir, "train")
    print(train_dir)

    train_data = MyDataset(data_dir=train_dir, crop_size=600)
    # shuffle=True: samples are drawn in a new random order every epoch.
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

    for hr_batch, _lr_batch in train_loader:
        # Add the channel axis: [B, H, W] => [B, 1, H, W]
        inputs = hr_batch.unsqueeze(1)
        print('inputs.size():', inputs.size())

        # Alternative noted by the original author: transforms.ToPILImage()
        # followed by Image.show(); that route needs a 3-D tensor, so the
        # batch axis has to be dropped first.

        # plt.imshow takes a 2-D array for single-channel data, so drop the
        # batch and channel axes and convert the tensor to an ndarray.
        img = inputs[0][0].numpy()
        # Luminance only: use a grey colormap instead of the false-colour default.
        plt.imshow(img, cmap='gray')
        plt.show()

        print(inputs[0])


#https://blog.csdn.net/qq_37388085/article/details/102663166?utm_medium=distribute.pc_relevant.none-task-blog-baidujs-2
#https://www.pianshen.com/article/3908277840/   Pytorch torchvision.transforms小結

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章