實驗基於論文: Class-Balanced Loss Based on Effective Number of Samples
Class-balanced-loss代碼地址:https://github.com/vandit15/Class-balanced-loss-pytorch
resnet18代碼參考鏈接:https://blog.csdn.net/sunqiande88/article/details/80100891
製作數據集
論文中通過以類索引爲指數的衰減公式製作長尾cifar10數據集。以下代碼以不均勻比例100爲例;也可以通過科學上網在谷歌雲鏈接下載。
loadcifar.py
import torch
import torch.utils.data as Data
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
def unpickle(file):
    """Load one pickled CIFAR-10 batch file.

    Parameters
    ----------
    file : str or path-like
        Path to a CIFAR-10 batch file (e.g. ``data_batch_1``).

    Returns
    -------
    dict
        The batch dictionary; keys are bytes (``b'data'``, ``b'labels'``)
        because the files were pickled under Python 2.
    """
    import pickle
    with open(file, 'rb') as fo:
        # Renamed from `dict` to avoid shadowing the builtin.
        batch = pickle.load(fo, encoding='bytes')
    return batch
# 從源文件讀取數據
# 返回 train_data[12406,3072]和labels[12406]
# test_data[10000,3072]和labels[10000]
def get_data(train=False, imb_factor=100, n_max=5000):
    """Read CIFAR-10 from the pickled batch files.

    For the training split, sub-samples a long-tailed distribution:
    class ``c`` keeps ``floor(n_max * (1/imb_factor) ** (c/9))`` images,
    so class 0 keeps ``n_max`` images and class 9 keeps
    ``n_max / imb_factor`` (12406 images total with the defaults).
    The test split is returned in full (10000 images).

    Parameters
    ----------
    train : bool
        If True, build the long-tailed training set; otherwise return
        the full test set.
    imb_factor : int
        Imbalance ratio between the most and least frequent class
        (default 100, matching the original hard-coded value).
    n_max : int
        Images kept for the most frequent class (default 5000).

    Returns
    -------
    tuple
        ``(data, labels)`` where ``data`` has shape (N, 3072).
    """
    if not train:
        batch = unpickle('data/cifar-10-batches-py/test_batch')
        return batch[b'data'], batch[b'labels']

    data = None
    labels = None
    for i in range(1, 6):
        batch = unpickle('data/cifar-10-batches-py/data_batch_' + str(i))
        if i == 1:
            data = batch[b'data']
            labels = batch[b'labels']
        else:
            data = np.concatenate([data, batch[b'data']])
            labels = np.concatenate([labels, batch[b'labels']])

    # Per-class quota: geometric decay from n_max down to n_max/imb_factor.
    quota = [int(np.floor(n_max * ((1 / imb_factor) ** (1 / 9)) ** c))
             for c in range(10)]
    # `np.int` was removed in NumPy 1.20+; plain `int` is the replacement.
    count = np.zeros(10, dtype=int)
    keep = []
    for idx, label in enumerate(labels):
        if count[label] < quota[label]:
            count[label] += 1
            keep.append(idx)

    # One fancy-index instead of the original O(n^2) row-by-row
    # np.concatenate loop; the selected rows are identical.
    new_data = data[keep].reshape(-1, 3072)
    new_labels = np.asarray(labels)[keep]
    return new_data, new_labels
# Image preprocessing pipeline; Compose chains multiple transforms.
# (For color images the channel statistics are not stationary, hence
# the fixed per-channel normalization below.)
transform = transforms.Compose([
# ToTensor converts a PIL.Image (RGB) or numpy.ndarray (H x W x C)
# with values in 0..255 into a Tensor with values in [0, 1].
transforms.ToTensor(),
# Normalize with mean 0.5 / std 0.5 maps the data into [-1, 1].
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 將標籤轉換爲torch.LongTensor
def target_transform(label):
    """Convert a label (int or sequence of ints) to a torch.LongTensor."""
    return torch.from_numpy(np.array(label)).long()
'''
自定義數據集讀取框架來載入cifar10數據集
需要繼承data.Dataset
'''
class Cifar10_Dataset(Data.Dataset):
    """Custom Dataset serving the (long-tailed) CIFAR-10 arrays from get_data()."""

    def __init__(self, train=True, transform=None, target_transform=None):
        # Transforms are applied lazily, per item, in __getitem__.
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        if self.train:
            # Long-tailed training split.
            self.train_data, self.train_labels = get_data(train)
            num = self.train_data.shape[0]
            self.train_data = self.train_data.reshape((num, 3, 32, 32))
            # To (N, H, W, C) so PIL can consume each row directly.
            self.train_data = self.train_data.transpose((0, 2, 3, 1))
        else:
            # Full 10000-image test split.
            self.test_data, self.test_labels = get_data()
            self.test_data = self.test_data.reshape((10000, 3, 32, 32))
            self.test_data = self.test_data.transpose((0, 2, 3, 1))

    def __getitem__(self, index):
        """Return one (image, target) pair with optional transforms applied."""
        if self.train:
            img, label = self.train_data[index], self.train_labels[index]
        else:
            img, label = self.test_data[index], self.test_labels[index]
        img = Image.fromarray(img)
        # Image preprocessing.
        if self.transform is not None:
            img = self.transform(img)
        # Label preprocessing.  Bug fix: the original left `target`
        # undefined (NameError on return) when target_transform was None;
        # fall back to the raw label in that case.
        if self.target_transform is not None:
            target = self.target_transform(label)
        else:
            target = label
        return img, target

    def __len__(self):
        # Size of the selected split.
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)
if __name__ == '__main__':
    # Build both splits and report their sizes.
    train_set = Cifar10_Dataset(True, transform, target_transform)
    print('size of train_data:{}'.format(len(train_set)))
    test_set = Cifar10_Dataset(False, transform, target_transform)
    print('size of test_data:{}'.format(len(test_set)))