使用pytorch的dataload方式计算自己的图片数据集的均值和标准差

 

网上看到一个使用opencv读取图片然后计算数据集的均值和标准差的,但是那个读取图片后把图片的每个值append到一个列表,要是数据集大的话内存真的会爆掉的啊,所以借助网上另一个使用pytorch的数据读取方式来计算的,原文https://www.cnblogs.com/wanghui-garcia/p/11448460.html 这篇是分别计算了训练集、测试集和验证集数据的均值和标准差并将均值和标准差保存到了一个文件中,我不需要那样子,我只需要计算我总数据集的均值标准差并输出就好了,所以做了一点修改。

首先说一下我的文件夹格式,没有分训练集测试集啥的,就是一个文件夹下面分类别放

'/home/jfw/tomato/'这个路径下有个tomatodata文件夹,这个文件夹下面有10个文件夹分别是10类,朋友们可以对应下面代码适合下自己数据集。

我输入网络层时需要resize到224*224,所以transforms里先resize到这个尺寸,然后转成tensor,最后来计算所有数据的均值标准差,注释掉的不用看,有兴趣可以看下上面原博主写的

 

# coding:utf-8
import os
import numpy as np
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms

# from options import options
import pickle

"""
    在训练前先运行该函数获得数据的均值和标准差
"""


class Dataloader():
    def __init__(self, dataroot):
        # 训练,验证,测试数据集文件夹名
        # self.dataroot = dataroot
        self.dirs = ['tomatodata']

        self.means = [0, 0, 0]
        self.stdevs = [0, 0, 0]

        self.transform = transforms.Compose([transforms.Resize((224,224)),
                                             transforms.ToTensor(),  # 数据值从[0,255]范围转为[0,1],相当于除以255操作
                                             # transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225))
                                             ])

        # 因为这里使用的是ImageFolder,按文件夹给数据分类,一个文件夹为一类,label会自动标注好
        self.dataset = {x: ImageFolder(os.path.join(dataroot, x), self.transform) for x in self.dirs}

    def get_mean_std(self):
        """
        计算数据集的均值和标准差
        :param type: 使用的是那个数据集的数据,有'train', 'test', 'testing'
        :param mean_std_path: 计算出来的均值和标准差存储的文件
        :return:
        """
        num_imgs = len(self.dataset['tomatodata'])
        for data in self.dataset['tomatodata']:
            img = data[0]
            for i in range(3):
                # 一个通道的均值和标准差
                self.means[i] += img[i, :, :].mean()
                self.stdevs[i] += img[i, :, :].std()

        self.means = np.asarray(self.means) / num_imgs
        self.stdevs = np.asarray(self.stdevs) / num_imgs

        print("{} : normMean = {}".format(type, self.means))
        print("{} : normstdevs = {}".format(type, self.stdevs))

        # # 将得到的均值和标准差写到文件中,之后就能够从中读取
        # with open(mean_std_path, 'wb') as f:
        #     pickle.dump(self.means, f)
        #     pickle.dump(self.stdevs, f)
        #     print('pickle done')


if __name__ == '__main__':

    dataroot = '/home/jfw/tomato/'
    dataloader = Dataloader(dataroot)
    # for x in dataloader.dirs:
    #     mean_std_path = 'mean_std_value_' + x + '.pkl'
    dataloader.get_mean_std()

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章