pytorch+resnet18實現長尾數據集分類(一)

實驗基於論文: Class-Balanced Loss Based on Effective Number of Samples

Class-balanced-loss代碼地址:https://github.com/vandit15/Class-balanced-loss-pytorch

resnet18代碼參考鏈接:https://blog.csdn.net/sunqiande88/article/details/80100891

製作數據集

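論文中通過公式 $n = n_i \mu^{i}$ 製作長尾 CIFAR-10 數據集,其中 $i$ 爲類索引($i = 0, 1, \dots, 9$),$n_i$ 爲每類原始訓練樣本數(CIFAR-10 中爲 5000),$\mu \in (0, 1)$。以下代碼以不平衡比例 100 爲例,即 $\mu = (1/100)^{1/9}$,使樣本最多的類(5000 張)與最少的類(50 張)之比爲 100。該數據集也可以通過科學上網從谷歌雲鏈接直接下載。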

loadcifar.py

import torch
import torch.utils.data as Data
import torchvision.transforms as transforms
import numpy as np
from PIL import Image

def unpickle(file):
    """Load one pickled CIFAR-10 batch file and return its contents.

    The file is loaded with ``encoding='bytes'``, so the returned dict has
    ``bytes`` keys such as ``b'data'`` and ``b'labels'``.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading;
    only call this on trusted files (e.g. the official CIFAR-10 download).
    """
    import pickle
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch


def get_data(train=False, root='data/cifar-10-batches-py', imbalance_ratio=100):
    """Read CIFAR-10 from the pickled batch files under ``root``.

    Training split: the five ``data_batch_*`` files are merged and then
    subsampled into a long-tailed set. Class ``c`` keeps only its first
    ``floor(5000 * mu**c)`` samples (in file order), where
    ``mu = (1 / imbalance_ratio) ** (1 / 9)``. With the default ratio of 100
    this gives train_data[12406, 3072] and labels[12406].

    Test split: ``test_batch`` is returned untouched
    (test_data[10000, 3072] and labels[10000]).

    Args:
        train: if true, build the long-tailed training split; otherwise
            load the test batch.
        root: directory holding the extracted ``cifar-10-batches-py`` files.
        imbalance_ratio: size ratio between the largest class (class 0) and
            the smallest class (class 9) of the training split. It is only
            used when ``train`` is true.

    Returns:
        ``(data, labels)``. For training, ``data`` is an ``(N, 3072)`` array
        and ``labels`` an ``(N,)`` integer array. For the test split, they
        are the ``b'data'`` array and ``b'labels'`` list exactly as stored.
    """
    import os

    if not train:
        batch = unpickle(os.path.join(root, 'test_batch'))
        return batch[b'data'], batch[b'labels']

    batches = [unpickle(os.path.join(root, 'data_batch_' + str(i)))
               for i in range(1, 6)]
    data = np.concatenate([b[b'data'] for b in batches])
    labels = np.concatenate([b[b'labels'] for b in batches])

    num_classes = 10
    max_per_class = 5000
    # Geometric decay factor: class 0 keeps max_per_class samples and
    # class (num_classes - 1) keeps max_per_class / imbalance_ratio.
    mu = (1 / imbalance_ratio) ** (1 / (num_classes - 1))
    # np.int64 exponent mirrors the numpy-scalar labels the cap is applied to.
    limits = [int(np.floor(max_per_class * mu ** np.int64(c)))
              for c in range(num_classes)]

    # Mark the first `limits[label]` samples of each class, in file order.
    # A single boolean mask replaces the per-row np.concatenate, which was
    # quadratic in the number of kept rows.
    counts = np.zeros(num_classes, dtype=np.int64)
    keep = np.zeros(len(labels), dtype=bool)
    for i, label in enumerate(labels):
        if counts[label] < limits[label]:
            counts[label] += 1
            keep[i] = True

    return data[keep], labels[keep]

# Image preprocessing pipeline; Compose runs the steps below in order.
#   1. ToTensor: a PIL.Image (RGB) or numpy.ndarray (H x W x C) with values
#      in [0, 255] becomes a float tensor with values in [0, 1].
#   2. Normalize: per-channel (x - 0.5) / 0.5 with one mean/std per RGB
#      channel, mapping [0, 1] onto [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

def target_transform(label):
    """Convert a class label into a ``torch.long`` tensor.

    The label is first copied into a fresh numpy array (``np.array`` copies),
    so the returned tensor never shares memory with the caller's object.
    """
    return torch.from_numpy(np.array(label)).long()

class Cifar10_Dataset(Data.Dataset):
    """Custom map-style dataset that loads CIFAR-10 via ``get_data``.

    It subclasses ``Data.Dataset``. The training split is whatever
    ``get_data(train=True)`` returns (the long-tailed subset). The test split
    is the raw ``test_batch``. Images are stored as ``(N, 32, 32, 3)`` arrays
    (height, width, channels) so they can be turned into ``PIL.Image`` objects
    and preprocessed.

    Args:
        train: load the training split if true, otherwise the test split.
        transform: optional callable applied to each item's ``PIL.Image``.
        target_transform: optional callable applied to each label. When it
            is ``None``, the raw label is returned unchanged.
    """

    def __init__(self, train=True, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform
        self.train = train

        if self.train:
            self.train_data, self.train_labels = get_data(train=True)
            self.train_data = self._to_hwc(self.train_data)
        else:
            self.test_data, self.test_labels = get_data(train=False)
            self.test_data = self._to_hwc(self.test_data)

    @staticmethod
    def _to_hwc(flat):
        """Turn rows of 3072 values (3 x 32 x 32, channel-first) into (N, 32, 32, 3)."""
        # Use the real row count instead of a hard-coded size, for both splits.
        return flat.reshape((flat.shape[0], 3, 32, 32)).transpose((0, 2, 3, 1))

    def __getitem__(self, index):
        """Return the preprocessed ``(image, target)`` pair at ``index``."""
        if self.train:
            img, label = self.train_data[index], self.train_labels[index]
        else:
            img, label = self.test_data[index], self.test_labels[index]
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        # Without a target_transform, return the raw label; previously
        # `target` was left unbound here and raised UnboundLocalError.
        if self.target_transform is None:
            target = label
        else:
            target = self.target_transform(label)
        return img, target

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.train_data if self.train else self.test_data)

if __name__ == '__main__':
    # Smoke check: build both splits and print their sizes.
    train_data = Cifar10_Dataset(True, transform, target_transform)
    print('size of train_data:{}'.format(len(train_data)))
    test_data = Cifar10_Dataset(False, transform, target_transform)
    print('size of test_data:{}'.format(len(test_data)))

第二步:定義損失函數
第三步:訓練

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章