網上看到一個使用opencv讀取圖片然後計算數據集的均值和標準差的,但是那個讀取圖片後把圖片的每個值append到一個列表,要是數據集大的話內存真的會爆掉的啊,所以藉助網上另一個使用pytorch的數據讀取方式來計算的,原文https://www.cnblogs.com/wanghui-garcia/p/11448460.html 這篇是分別計算了訓練集、測試集和驗證集數據的均值和標準差並將均值和標準差保存到了一個文件中,我不需要那樣子,我只需要計算我總數據集的均值標準差並輸出就好了,所以做了一點修改。
首先說一下我的文件夾格式,沒有分訓練集測試集啥的,就是一個文件夾下面分類別放
'/home/jfw/tomato/'這個路徑下有個tomatodata文件夾,這個文件夾下面有10個文件夾分別是10類,朋友們可以對應下面代碼適合下自己數據集。
我輸入網絡層時需要resize到224*224,所以transforms裏先resize到這個尺寸,然後轉成tensor,最後來計算所有數據的均值標準差,註釋掉的不用看,有興趣可以看下上面原博主寫的
# coding:utf-8
import os
import numpy as np
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
# from options import options
import pickle
"""
在訓練前先運行該函數獲得數據的均值和標準差
"""
class Dataloader():
def __init__(self, dataroot):
# 訓練,驗證,測試數據集文件夾名
# self.dataroot = dataroot
self.dirs = ['tomatodata']
self.means = [0, 0, 0]
self.stdevs = [0, 0, 0]
self.transform = transforms.Compose([transforms.Resize((224,224)),
transforms.ToTensor(), # 數據值從[0,255]範圍轉爲[0,1],相當於除以255操作
# transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225))
])
# 因爲這裏使用的是ImageFolder,按文件夾給數據分類,一個文件夾爲一類,label會自動標註好
self.dataset = {x: ImageFolder(os.path.join(dataroot, x), self.transform) for x in self.dirs}
def get_mean_std(self):
"""
計算數據集的均值和標準差
:param type: 使用的是那個數據集的數據,有'train', 'test', 'testing'
:param mean_std_path: 計算出來的均值和標準差存儲的文件
:return:
"""
num_imgs = len(self.dataset['tomatodata'])
for data in self.dataset['tomatodata']:
img = data[0]
for i in range(3):
# 一個通道的均值和標準差
self.means[i] += img[i, :, :].mean()
self.stdevs[i] += img[i, :, :].std()
self.means = np.asarray(self.means) / num_imgs
self.stdevs = np.asarray(self.stdevs) / num_imgs
print("{} : normMean = {}".format(type, self.means))
print("{} : normstdevs = {}".format(type, self.stdevs))
# # 將得到的均值和標準差寫到文件中,之後就能夠從中讀取
# with open(mean_std_path, 'wb') as f:
# pickle.dump(self.means, f)
# pickle.dump(self.stdevs, f)
# print('pickle done')
if __name__ == '__main__':
dataroot = '/home/jfw/tomato/'
dataloader = Dataloader(dataroot)
# for x in dataloader.dirs:
# mean_std_path = 'mean_std_value_' + x + '.pkl'
dataloader.get_mean_std()